cehrgpt 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.
Files changed (29)
  1. cehrgpt/analysis/irregularity.py +36 -0
  2. cehrgpt/data/hf_cehrgpt_dataset.py +1 -0
  3. cehrgpt/data/hf_cehrgpt_dataset_collator.py +454 -68
  4. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +232 -17
  5. cehrgpt/data/sample_packing_sampler.py +36 -6
  6. cehrgpt/generation/cehrgpt_conditional_generation.py +314 -0
  7. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +15 -3
  8. cehrgpt/generation/omop_converter_batch.py +32 -2
  9. cehrgpt/gpt_utils.py +20 -2
  10. cehrgpt/models/config.py +25 -0
  11. cehrgpt/models/hf_cehrgpt.py +244 -39
  12. cehrgpt/models/hf_modeling_outputs.py +1 -0
  13. cehrgpt/models/special_tokens.py +1 -0
  14. cehrgpt/models/tokenization_hf_cehrgpt.py +354 -71
  15. cehrgpt/runners/data_utils.py +131 -5
  16. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +84 -51
  17. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +59 -7
  18. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +60 -0
  19. cehrgpt/runners/hyperparameter_search_util.py +6 -7
  20. cehrgpt/runners/sample_packing_trainer.py +17 -0
  21. cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  22. cehrgpt/time_to_event/time_to_event_model.py +2 -13
  23. cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  24. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +80 -62
  25. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/METADATA +102 -7
  26. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/RECORD +29 -26
  27. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/WHEEL +1 -1
  28. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/licenses/LICENSE +0 -0
  29. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/top_level.txt +0 -0
cehrgpt/runners/sample_packing_trainer.py
@@ -35,6 +35,13 @@ class SamplePackingTrainer(Trainer):
             self.max_tokens_per_batch,
         )

+        self.negative_sampling_probability = kwargs.pop(
+            "negative_sampling_probability", None
+        )
+        if self.negative_sampling_probability:
+            LOG.info(
+                "negative_sampling_probability: %s", self.negative_sampling_probability
+            )
         self.train_lengths = kwargs.pop("train_lengths", None)
         self.validation_lengths = kwargs.pop("validation_lengths", None)
         super().__init__(*args, **kwargs)
@@ -70,6 +77,14 @@ class SamplePackingTrainer(Trainer):
         data_collator = self._get_collator_with_removed_columns(
             data_collator, description="training"
         )
+
+        labels = None
+        if (
+            self.negative_sampling_probability is not None
+            and "classifier_label" in train_dataset.column_names
+        ):
+            labels = train_dataset["classifier_label"]
+
         # Create our custom batch sampler
         batch_sampler = SamplePackingBatchSampler(
             lengths=lengths,
@@ -77,6 +92,8 @@ class SamplePackingTrainer(Trainer):
             max_position_embeddings=self.max_position_embeddings,
             drop_last=self.args.dataloader_drop_last,
             seed=self.args.seed,
+            negative_sampling_probability=self.negative_sampling_probability,
+            labels=labels,
         )
         dataloader_params = {
             "collate_fn": data_collator,
cehrgpt/time_to_event/config/1_year_cabg.yaml (new file)
@@ -0,0 +1,23 @@
+task_name: "cabg_prediction"
+outcome_events: [
+  "43528001",
+  "43528003",
+  "43528004",
+  "43528002",
+  "4305852",
+  "4168831",
+  "2107250",
+  "2107216",
+  "2107222",
+  "2107231",
+  "4336464",
+  "4231998",
+  "4284104",
+  "2100873",
+]
+future_visit_start: 0
+future_visit_end: -1
+prediction_window_start: 0
+prediction_window_end: 365
+max_new_tokens: 1024
+include_descendants: true
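For orientation, the new `1_year_cabg.yaml` task file defines a one-year CABG prediction task: any of the listed outcome concept IDs (and, because `include_descendants: true`, their vocabulary descendants) occurring in the 0-365 day window after the index date counts as an outcome. A minimal sketch of reading such a config with PyYAML; the field names come from the file above, while the loader itself is our illustration rather than the package's API:

```python
import yaml

# Hypothetical standalone loader; cehrgpt parses these task files internally.
with open("1_year_cabg.yaml") as f:
    task_config = yaml.safe_load(f)

print(task_config["task_name"])              # cabg_prediction
print(len(task_config["outcome_events"]))    # 14 outcome concept IDs
print(task_config["prediction_window_end"])  # 365 (days after the index date)
```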
cehrgpt/time_to_event/time_to_event_model.py
@@ -80,20 +80,9 @@ class TimeToEventModel:
         return token in self.outcome_events

     def simulate(
-        self, partial_history: Union[np.ndarray, List[str]]
+        self,
+        partial_history: Union[np.ndarray, List[str]],
     ) -> List[List[str]]:
-
-        sequence_is_demographics = len(partial_history) == 4 and partial_history[
-            0
-        ].startswith("year")
-        sequence_ends_ve = is_visit_end(partial_history[-1])
-
-        if not (sequence_is_demographics | sequence_ends_ve):
-            raise ValueError(
-                "There are only two types of sequences allowed. 1) the sequence only contains "
-                "demographics; 2) the sequence ends on VE;"
-            )
-
         token_ids = self.tokenizer.encode(partial_history)
         prompt = torch.tensor(token_ids).unsqueeze(0).to(self.device)

cehrgpt/time_to_event/time_to_event_prediction.py
@@ -118,9 +118,9 @@ def main(args):
     LOG.info(f"Top P {args.top_p}")
     LOG.info(f"Top K {args.top_k}")

-    cehrgpt_model.resize_position_embeddings(
-        cehrgpt_model.config.max_position_embeddings + task_config.max_new_tokens
-    )
+    # cehrgpt_model.resize_position_embeddings(
+    #     cehrgpt_model.config.max_position_embeddings + task_config.max_new_tokens
+    # )

     generation_config = TimeToEventModel.get_generation_config(
         tokenizer=cehrgpt_tokenizer,
@@ -190,14 +190,22 @@ def main(args):
                 args.max_n_trial,
             )
             visit_counter = sum([int(is_visit_end(_)) for _ in partial_history])
+            predicted_boolean_probability = (
+                sum([event != "0" for event in concept_time_to_event.outcome_events])
+                / len(concept_time_to_event.outcome_events)
+                if concept_time_to_event
+                else 0.0
+            )
             tte_outputs.append(
                 {
-                    "person_id": record["person_id"],
-                    "index_date": record["index_date"],
+                    "subject_id": record["person_id"],
+                    "prediction_time": record["index_date"],
                     "visit_counter": visit_counter,
-                    "label": label,
+                    "boolean_value": label,
+                    "predicted_boolean_probability": predicted_boolean_probability,
+                    "predicted_boolean_value": None,
                     "time_to_event": time_to_event,
-                    "prediction": (
+                    "trials": (
                         asdict(concept_time_to_event) if concept_time_to_event else None
                     ),
                 }
@@ -263,9 +271,13 @@ def filter_out_existing_results(
     parquet_files = glob.glob(os.path.join(prediction_output_folder_name, "*parquet"))
     if parquet_files:
         cohort_members = set()
-        results_dataframe = pd.read_parquet(parquet_files)[["person_id", "index_date"]]
+        results_dataframe = pd.read_parquet(parquet_files)[
+            ["subject_id", "prediction_time"]
+        ]
         for row in results_dataframe.itertuples():
-            cohort_members.add((row.person_id, row.index_date.strftime("%Y-%m-%d")))
+            cohort_members.add(
+                (row.subject_id, row.prediction_time.strftime("%Y-%m-%d"))
+            )

         def filter_func(batched):
             return [
@@ -292,12 +304,14 @@ def flush_to_disk_if_full(
         pd.DataFrame(
             tte_outputs,
             columns=[
-                "person_id",
-                "index_date",
+                "subject_id",
+                "prediction_time",
                 "visit_counter",
-                "label",
+                "boolean_value",
+                "predicted_boolean_probability",
+                "predicted_boolean_value",
                 "time_to_event",
-                "prediction",
+                "trials",
             ],
         ).to_parquet(output_parquet_file)
         tte_outputs.clear()
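The column renames above (`person_id` → `subject_id`, `index_date` → `prediction_time`, `label` → `boolean_value`) appear to align the prediction output with the MEDS label schema, with the extra `predicted_boolean_*` columns following the convention used by MEDS prediction and evaluation tooling; `trials` keeps the raw per-trial simulation output. An illustrative row in the new layout (all values made up):

```python
import pandas as pd

# Made-up example row; real values come from the generation loop above.
row = {
    "subject_id": 12345,
    "prediction_time": pd.Timestamp("2020-01-01"),
    "visit_counter": 7,
    "boolean_value": True,                  # ground-truth label
    "predicted_boolean_probability": 0.25,  # fraction of simulated trials with an outcome
    "predicted_boolean_value": None,        # left unset by this script
    "time_to_event": 90,
    "trials": None,
}
pd.DataFrame([row]).to_parquet("example_predictions.parquet")
```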
cehrgpt/tools/linear_prob/compute_cehrgpt_features.py
@@ -1,8 +1,8 @@
+import datetime
 import glob
 import os
 import shutil
 import uuid
-from datetime import datetime
 from functools import partial
 from pathlib import Path
 from typing import Optional, Union
@@ -29,8 +29,12 @@ from cehrgpt.models.hf_cehrgpt import (
     CEHRGPT2Model,
     extract_features_from_packed_sequence,
 )
+from cehrgpt.models.special_tokens import LINEAR_PROB_TOKEN
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
-from cehrgpt.runners.data_utils import prepare_finetune_dataset
+from cehrgpt.runners.data_utils import (
+    extract_cohort_sequences,
+    prepare_finetune_dataset,
+)
 from cehrgpt.runners.gpt_runner_util import parse_runner_args
 from cehrgpt.runners.hf_cehrgpt_pretrain_runner import tokenizer_exists

@@ -112,6 +116,11 @@ def main():
         .eval()
         .to(device)
     )
+
+    if LINEAR_PROB_TOKEN not in cehrgpt_tokenizer.get_vocab():
+        cehrgpt_tokenizer.add_tokens(LINEAR_PROB_TOKEN)
+        cehrgpt_model.resize_token_embeddings(cehrgpt_tokenizer.vocab_size)
+
     prepared_ds_path = generate_prepared_ds_path(
         data_args, model_args, data_folder=data_args.cohort_folder
     )
@@ -137,39 +146,31 @@ def main():

     if processed_dataset is None:
         if is_main_process(training_args.local_rank):
-            # Organize them into a single DatasetDict
-            final_splits = prepare_finetune_dataset(
-                data_args, training_args, cehrgpt_args, cache_file_collector
-            )
-            if cehrgpt_args.expand_tokenizer:
-                new_tokenizer_path = os.path.expanduser(training_args.output_dir)
-                if tokenizer_exists(new_tokenizer_path):
-                    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
-                        new_tokenizer_path
-                    )
-                else:
-                    cehrgpt_tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
-                        cehrgpt_tokenizer=cehrgpt_tokenizer,
-                        dataset=final_splits["train"],
-                        data_args=data_args,
-                        concept_name_mapping={},
-                    )
-                    cehrgpt_tokenizer.save_pretrained(
-                        os.path.expanduser(training_args.output_dir)
-                    )
-
+            # If the full dataset has been tokenized, we don't want to tokenize the cohort containing
+            # the subset of the data. We should slice out the portion of the tokenized sequences for each sample
+            if cehrgpt_args.tokenized_full_dataset_path is not None:
+                processed_dataset = extract_cohort_sequences(
+                    data_args, cehrgpt_args, cache_file_collector
+                )
+            else:
+                # Organize them into a single DatasetDict
+                final_splits = prepare_finetune_dataset(
+                    data_args, training_args, cehrgpt_args, cache_file_collector
+                )
             # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
-            if not data_args.streaming:
-                all_columns = final_splits["train"].column_names
-                if "visit_concept_ids" in all_columns:
-                    final_splits = final_splits.remove_columns(["visit_concept_ids"])
-
-            processed_dataset = create_cehrgpt_finetuning_dataset(
-                dataset=final_splits,
-                cehrgpt_tokenizer=cehrgpt_tokenizer,
-                data_args=data_args,
-                cache_file_collector=cache_file_collector,
-            )
+                if not data_args.streaming:
+                    all_columns = final_splits["train"].column_names
+                    if "visit_concept_ids" in all_columns:
+                        final_splits = final_splits.remove_columns(
+                            ["visit_concept_ids"]
+                        )
+
+                processed_dataset = create_cehrgpt_finetuning_dataset(
+                    dataset=final_splits,
+                    cehrgpt_tokenizer=cehrgpt_tokenizer,
+                    data_args=data_args,
+                    cache_file_collector=cache_file_collector,
+                )
         if not data_args.streaming:
             processed_dataset.save_to_disk(prepared_ds_path)
             processed_dataset.cleanup_cache_files()
@@ -218,10 +219,6 @@ def main():
             len(processed_dataset["test"]),
         )

-    LOG.info(f"cehrgpt_model.config.vocab_size: {cehrgpt_model.config.vocab_size}")
-    LOG.info(f"cehrgpt_tokenizer.vocab_size: {cehrgpt_tokenizer.vocab_size}")
-    if cehrgpt_model.config.vocab_size < cehrgpt_tokenizer.vocab_size:
-        cehrgpt_model.resize_token_embeddings(cehrgpt_tokenizer.vocab_size)
    if (
        cehrgpt_model.config.max_position_embeddings
        < model_args.max_position_embeddings
@@ -244,6 +241,7 @@ def main():
            SamplePackingCehrGptDataCollator,
            cehrgpt_args.max_tokens_per_batch,
            cehrgpt_model.config.max_position_embeddings,
+            add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
        )
        train_batch_sampler = SamplePackingBatchSampler(
            lengths=train_set["num_of_concepts"],
@@ -278,6 +276,7 @@ def main():
        include_ttv_prediction=False,
        use_sub_time_tokenization=False,
        include_demographics=cehrgpt_args.include_demographics,
+        add_linear_prob_token=True,
    )

    train_loader = DataLoader(
@@ -298,30 +297,38 @@ def main():
         batch_sampler=test_batch_sampler,
     )

-    # Loading demographics
-    print("Loading demographics as a dictionary")
-    demographics_df = pd.concat(
-        [
-            pd.read_parquet(
-                data_dir,
-                columns=[
-                    "person_id",
-                    "index_date",
-                    "gender_concept_id",
-                    "race_concept_id",
-                ],
-            )
-            for data_dir in [data_args.data_folder, data_args.test_data_folder]
-        ]
-    )
-    demographics_df["index_date"] = demographics_df.index_date.dt.date
-    demographics_dict = {
-        (row["person_id"], row["index_date"]): {
-            "gender_concept_id": row["gender_concept_id"],
-            "race_concept_id": row["race_concept_id"],
+    if data_args.is_data_in_meds:
+        demographics_dict = dict()
+    else:
+        # Loading demographics
+        print("Loading demographics as a dictionary")
+        demographics_df = pd.concat(
+            [
+                pd.read_parquet(
+                    data_dir,
+                    columns=[
+                        "person_id",
+                        "index_date",
+                        "gender_concept_id",
+                        "race_concept_id",
+                    ],
+                )
+                for data_dir in [data_args.data_folder, data_args.test_data_folder]
+            ]
+        )
+
+        demographics_df["index_date"] = (
+            demographics_df["index_date"].dt.tz_localize("UTC")
+            - datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
+        ).dt.total_seconds()
+
+        demographics_dict = {
+            (row["person_id"], row["index_date"]): {
+                "gender_concept_id": row["gender_concept_id"],
+                "race_concept_id": row["race_concept_id"],
+            }
+            for _, row in demographics_df.iterrows()
         }
-        for _, row in demographics_df.iterrows()
-    }

     data_loaders = [("train", train_loader), ("test", test_dataloader)]

@@ -351,9 +358,16 @@ def main():
                 prediction_time_posix = batch.pop("index_date").numpy().squeeze()
                 if prediction_time_posix.ndim == 0:
                     prediction_time_posix = np.asarray([prediction_time_posix])
+
                 prediction_time = list(
-                    map(datetime.fromtimestamp, prediction_time_posix)
+                    map(
+                        lambda posix_time: datetime.datetime.utcfromtimestamp(
+                            posix_time
+                        ).replace(tzinfo=None),
+                        prediction_time_posix,
+                    )
                 )
+
                 labels = (
                     batch.pop("classifier_label")
                     .float()
@@ -365,6 +379,10 @@ def main():
                 if labels.ndim == 0:
                     labels = np.asarray([labels])

+                # Right now the model does not support this column, we need to pop it
+                if "epoch_times" in batch:
+                    batch.pop("epoch_times")
+
                 batch = {k: v.to(device) for k, v in batch.items()}
                 # Forward pass
                 cehrgpt_output = cehrgpt_model(
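Both timestamp changes in `compute_cehrgpt_features.py` hinge on the same representation: `index_date` is stored as POSIX seconds (computed above by localizing to UTC and subtracting the epoch), then converted back to naive UTC datetimes in the prediction loop. A minimal sketch of that round trip under the same assumption (seconds since the Unix epoch):

```python
import datetime

import pandas as pd

# Forward: datetime -> POSIX seconds, mirroring the tz_localize("UTC") arithmetic above.
index_date = pd.Series(pd.to_datetime(["2020-06-01"]))
posix = (
    index_date.dt.tz_localize("UTC")
    - datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
).dt.total_seconds()

# Backward: POSIX seconds -> naive UTC datetime, as in the prediction loop above.
restored = datetime.datetime.utcfromtimestamp(posix.iloc[0]).replace(tzinfo=None)
print(restored)  # 2020-06-01 00:00:00
```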
{cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cehrgpt
-Version: 0.1.0
+Version: 0.1.2
 Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
 Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
 License: MIT License
@@ -12,14 +12,15 @@ Classifier: Programming Language :: Python :: 3
 Requires-Python: >=3.10.0
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: cehrbert==1.4.1
-Requires-Dist: cehrbert_data==0.0.7
+Requires-Dist: cehrbert==1.4.5
+Requires-Dist: cehrbert_data==0.0.11
 Requires-Dist: openai==1.54.3
 Requires-Dist: optuna==4.0.0
-Requires-Dist: transformers==4.44.0
+Requires-Dist: transformers==4.44.1
 Requires-Dist: tokenizers==0.19.0
 Requires-Dist: peft==0.10.0
 Requires-Dist: lightgbm
+Requires-Dist: polars
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
 Requires-Dist: pytest; extra == "dev"
@@ -36,9 +37,9 @@ Dynamic: license-file

 [![PyPI - Version](https://img.shields.io/pypi/v/cehrgpt)](https://pypi.org/project/cehrgpt/)
 ![Python](https://img.shields.io/badge/-Python_3.11-blue?logo=python&logoColor=white)
-[![tests](https://github.com/knatarajan-lab/cehrgpt-public/actions/workflows/tests.yaml/badge.svg)](https://github.com/knatarajan-lab/cehrgpt-public/actions/workflows/tests.yml)
-[![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/knatarajan-lab/cehrgpt-public/blob/main/LICENSE)
-[![contributors](https://img.shields.io/github/contributors/knatarajan-lab/cehrgpt-public.svg)](https://github.com/knatarajan-lab/cehrgpt-public/graphs/contributors)
+[![tests](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml/badge.svg)](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml)
+[![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/knatarajan-lab/cehrgpt/blob/main/LICENSE)
+[![contributors](https://img.shields.io/github/contributors/knatarajan-lab/cehrgpt.svg)](https://github.com/knatarajan-lab/cehrgpt/graphs/contributors)

 ## Description
 CEHRGPT is a synthetic data generation model developed to handle structured electronic health records (EHR) with enhanced privacy and reliability. It leverages state-of-the-art natural language processing techniques to create realistic, anonymized patient data that can be used for research and development without compromising patient privacy.
@@ -104,6 +105,100 @@ sh scripts/omop_pipeline.sh \
   $OMOP_VOCAB_DIR
 ```

+# MEDS Support
+
+This section demonstrates how to pretrain CEHR-GPT using MIMIC-IV data in the MEDS (Medical Event Data Standard) format.
+
+## Prerequisites
+
+Set up the required environment variables before beginning:
+
+```bash
+export CEHR_GPT_MODEL_DIR=""   # Path to CEHR-GPT model directory
+export MEDS_DIR=""             # Path to MEDS data directory
+export MEDS_READER_DIR=""      # Path to MEDS reader output directory
+```
+
+## Step 1: Create MIMIC MEDS Data
+
+Transform your MIMIC files into MEDS format by following the instructions in the [MEDS_transforms](https://github.com/mmcdermott/MEDS_transforms/) repository.
+
+## Step 2: Create the MEDS Reader
+
+Convert the MEDS data for use with CEHR-GPT:
+
+```bash
+meds_reader_convert $MEDS_DIR $MEDS_READER_DIR --num_threads 10
+```
+
+## Step 3: Pretrain CEHR-GPT
+
+Run the pretraining process using the prepared MEDS data:
+
+```bash
+python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
+  --model_name_or_path $CEHR_GPT_MODEL_DIR \
+  --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
+  --output_dir $CEHR_GPT_MODEL_DIR \
+  --data_folder $MEDS_READER_DIR \
+  --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
+  --do_train true --seed 42 \
+  --dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
+  --hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 8192 \
+  --evaluation_strategy epoch --save_strategy epoch \
+  --sample_packing --max_tokens_per_batch 16384 \
+  --warmup_steps 500 --weight_decay 0.01 \
+  --num_train_epochs 50 --learning_rate 0.0002 \
+  --use_early_stopping --early_stopping_threshold 0.001 \
+  --is_data_in_meds --inpatient_att_function_type day \
+  --att_function_type day --include_inpatient_hour_token \
+  --include_auxiliary_token --include_demographic_prompt \
+  --meds_to_cehrbert_conversion_type "MedsToBertMimic4"
+```
+
+## Step 4: Generate MEDS Trajectories
+
+### Environment Setup for Trajectory Generation
+
+Configure additional environment variables for trajectory generation with task labels (`subject_id`, `prediction_time`, `boolean_value` [optional]):
+
+```bash
+# MEDS_LABEL_COHORT_DIR must contain a set of parquet files
+export MEDS_LABEL_COHORT_DIR=""   # Path to cohort labels directory
+export MEDS_TRAJECTORY_DIR=""     # Path for trajectory output
+```
+
+### Generate Trajectories
+
+Create synthetic patient trajectories using the trained model:
+
+> **Important:** The total sequence length (`generation_input_length` + `generation_max_new_tokens`) cannot exceed the `max_position_embeddings` value (8192) defined during pretraining.
+
+```bash
+python -u -m cehrgpt.generation.cehrgpt_conditional_generation \
+  --cohort_folder $MEDS_LABEL_COHORT_DIR \
+  --data_folder $MEDS_READER_DIR \
+  --dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
+  --model_name_or_path $CEHR_GPT_MODEL_DIR \
+  --tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
+  --output_dir $MEDS_TRAJECTORY_DIR \
+  --per_device_eval_batch_size 16 \
+  --num_of_trajectories_per_sample 2 \
+  --generation_input_length 4096 \
+  --generation_max_new_tokens 4096 \
+  --is_data_in_meds \
+  --att_function_type day --inpatient_att_function_type day \
+  --meds_to_cehrbert_conversion_type MedsToBertMimic4 \
+  --include_auxiliary_token --include_demographic_prompt \
+  --include_inpatient_hour_token
+```
+
+### Parameters Explanation
+
+- `generation_input_length`: Controls the length of input context for generation
+- `generation_max_new_tokens`: Maximum number of new tokens to generate
+- `num_of_trajectories_per_sample`: Number of trajectories to generate per patient sample
+
 ## Citation
 ```
 @article{cehrgpt2024,
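The "Important" note in the new MEDS Support section above boils down to a single arithmetic constraint; a quick sanity check one might run before launching generation (a sketch, with variable names mirroring the CLI flags above):

```python
# Values from the example command above; max_position_embeddings from pretraining.
max_position_embeddings = 8192
generation_input_length = 4096
generation_max_new_tokens = 4096

total = generation_input_length + generation_max_new_tokens
assert total <= max_position_embeddings, (
    f"Total sequence length {total} exceeds max_position_embeddings "
    f"{max_position_embeddings}; reduce input length or max_new_tokens."
)
```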
{cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/RECORD
@@ -1,8 +1,9 @@
 __init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cehrgpt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cehrgpt/cehrgpt_args.py,sha256=zPLp9Qjlq5PapWx3R15BNnyaX8zV3dxr4PuWj71r0Lg,3516
-cehrgpt/gpt_utils.py,sha256=bksHCXMX4j859VSv1Q284rVr4gn1Y8dCx4a_V-g4mug,10939
+cehrgpt/gpt_utils.py,sha256=IA5qw-hxcKkGO07AB47lDNRU6mlb9jblpKO7KeLLN78,11342
 cehrgpt/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+cehrgpt/analysis/irregularity.py,sha256=Rfl_daMvSh9cZ68vUwfmuH-JYCFXdAph2ITHHffYC0Y,1047
 cehrgpt/analysis/privacy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cehrgpt/analysis/privacy/attribute_inference.py,sha256=0ANVW0I5uvOl6IxQ15-vMVQd0mugOgSGReBUQQESImg,9368
 cehrgpt/analysis/privacy/attribute_inference_config.yml,sha256=hfLfpBlDqqsNOynpRHK414vV24edKA6ta-inmEhM2ao,103272
@@ -11,22 +12,23 @@ cehrgpt/analysis/privacy/nearest_neighbor_inference.py,sha256=qoJgWW7VsUMzjMGpTa
 cehrgpt/analysis/privacy/reid_inference.py,sha256=Pypd3QJXQNY8VljpnIEa5zeAbTZHMjQOazaL-9VsBGw,13955
 cehrgpt/analysis/privacy/utils.py,sha256=CRA4H9mPLBjMQGKzZ_x_3ro3tMap-NjsMDVqSOjHSVQ,8226
 cehrgpt/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cehrgpt/data/hf_cehrgpt_dataset.py,sha256=t9vpN05e--CiKgIlxLP0aLacISnvWWDPXtuFuJi3ksE,3736
-cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=DOvIF4Wzkd8-IO3zpIRZkX1j0IdvefaiSnrDn1YivCk,27912
-cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=eI8CTk6yJ4DlNJWrNAkEmhWh353NeLqg5rwPpKqKT-U,17308
-cehrgpt/data/sample_packing_sampler.py,sha256=0uKTbvtXpfS81esy_3epJ88eohyJPK46bfmxhle1fws,5419
+cehrgpt/data/hf_cehrgpt_dataset.py,sha256=hwJlGW7XiJIr6cXtmwvReQf9yLZJPD-dvJGvRg5ERqU,3755
+cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=juM5HeZScgj8w15Bl1qC83Swld4gY6avh0QkSWLqITA,45465
+cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=_QDX9NXfmQ_S3kOf3yndb3AhoEeFiSzAOv836uYW0AY,26230
+cehrgpt/data/sample_packing_sampler.py,sha256=vovGMtmhG70DRkSCeiaDEJ_rjKZ38y-YLaI1kkhFEkI,6747
 cehrgpt/generation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+cehrgpt/generation/cehrgpt_conditional_generation.py,sha256=AM76yaPyw1B-bcdei24HO0uspGZWHGKWpYpHywotTIQ,11972
 cehrgpt/generation/chatgpt_generation.py,sha256=SrnLwHLdNtnAOEg36gNjqfoT9yd12iyPgpZffL2AFJo,4428
-cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256=uSEh8aMmPD61nGewIaPSkIqm-2AxDjCBiu4cBfxHxU4,11503
-cehrgpt/generation/omop_converter_batch.py,sha256=-c0AlDVy5pJ5Afhr8ERiCHhoRrEk8ozJi3g0yFdWaMI,25348
+cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256=P8al4-zqymqEkCHCCu2sqz_45akcKF2o_AtQIjJdVmQ,11919
+cehrgpt/generation/omop_converter_batch.py,sha256=LUmCD-t_6ZP1YfNDZCqYewl-XIIaIgRZ_dAxuR_VdCQ,26275
 cehrgpt/generation/omop_entity.py,sha256=Q5Sr0AlyuPAm1FRPfnJO13q-u1fqRgYVHXruZ9g4xNE,19400
 cehrgpt/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cehrgpt/models/config.py,sha256=Y3CiXZWniLP9_RlpU80Oe9gjn5leLmTYnNe_fWqfJLQ,10158
-cehrgpt/models/hf_cehrgpt.py,sha256=3EQIOfa--oz4f8bM8KzbDi98G3XrUEQkox1vmBN001M,83321
-cehrgpt/models/hf_modeling_outputs.py,sha256=LaWa1jI6BRIKMEjWOy1QUeOfTur5y_p2c-JyuGVTdtw,10301
+cehrgpt/models/config.py,sha256=nOAKgH5420HLCcy7n1hE7MbqR861Iq4DTutKoAd25tg,11090
+cehrgpt/models/hf_cehrgpt.py,sha256=3P7bOLDr7NMSedGszhmlJJN4Mhpd_65-x6uzwvSjigE,92837
+cehrgpt/models/hf_modeling_outputs.py,sha256=5X4WEYKqT37phv_e5ZAv3A_N0wqdAUJLJRm6TxS6dDQ,10356
 cehrgpt/models/pretrained_embeddings.py,sha256=vLLVs17TLpXRqCVEWQxGGwPHkUJUO7laNTeBuyBK_yk,3238
-cehrgpt/models/special_tokens.py,sha256=-a7HPJBbdIH0qQ6B3CcRKqvpG6FZlm4nbVPTswGSJ4U,485
-cehrgpt/models/tokenization_hf_cehrgpt.py,sha256=jjCRqS29IzMnKp40jNOs80UKh2z9lK5S6M02GSB-4mk,42351
+cehrgpt/models/special_tokens.py,sha256=lrw45B4tea4Dsajn09Cz6w5D2TfHmYXikZkgwnstu_o,521
+cehrgpt/models/tokenization_hf_cehrgpt.py,sha256=cAxHTctpVBxfWfC3XcwDQavN1zwWN9Nid_Fajd5zQWQ,53159
 cehrgpt/omop/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cehrgpt/omop/condition_era.py,sha256=hPZALz2XaWnro_1bwIYNkI48foOJjueyg3CZ1BliCno,626
 cehrgpt/omop/observation_period.py,sha256=TRMgv5Ya2RaS2im7oQ6BLC_5JL9EJYNYR62ApxIuHvg,1211
@@ -37,22 +39,23 @@ cehrgpt/omop/queries/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hS
 cehrgpt/omop/queries/condition_era.py,sha256=LFB6vBAvshHJxtYIRkl7cfrF0kf7ay0piBKpmHBwrpE,2578
 cehrgpt/omop/queries/observation_period.py,sha256=fpzr5DMNw-QLoSwp2Iatfch88E3hyhZ75usiIdG3A0U,6410
 cehrgpt/runners/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cehrgpt/runners/data_utils.py,sha256=ScZZnfXwgXKaMvKgFzdb4vtQ7F_lw97O5uNsFbfsyP4,10620
+cehrgpt/runners/data_utils.py,sha256=i-krtBx_6rvPYtdLdDoWwOTtJcaovd0wH8gBYmgN2l4,16013
 cehrgpt/runners/gpt_runner_util.py,sha256=YJQSRW9Mo4TjXSOUOTf6BUFcs1MGFiXU5T4ztKZcYhU,3485
-cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=bkPl30Y9CSXBlmMkH-3cA3-aW8XJK36Q-adx___WjkE,26921
-cehrgpt/runners/hf_cehrgpt_pretrain_runner.py,sha256=ViVa_flEGdk_SO0psMR7ho-o79igsz_l1x80u81WJ3A,23875
-cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=VrqgDSiAMfGyHEIodoOg_8LU5O0ndWf9EE0YOKDFKKA,7019
-cehrgpt/runners/hyperparameter_search_util.py,sha256=pWFmGo9Ezju4YmuZ-ohbAbYB0GGMfIDVUCyvcTxS1iU,9153
-cehrgpt/runners/sample_packing_trainer.py,sha256=aezX30vxpP1DDcH5hO-yn395NqBKi2Xhb0mFNHi9OBs,7340
+cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=1OgxLm4T7iHv5pKi2QaSdaz9ogWo2n3sSUGp6cHDF9s,28309
+cehrgpt/runners/hf_cehrgpt_pretrain_runner.py,sha256=ERSnvB38fPYVghtKQeNTZ8VfeXnoRcCHB0cWISWaZ84,26523
+cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=fJR4RHPqal1YI6_KUH-WlkoQLSZuBT5bKUGfPHDFrWI,9350
+cehrgpt/runners/hyperparameter_search_util.py,sha256=YWdFQ1igQs-G_wqWUrUzYraGiz8OSpSYyvid-I5nhWA,9262
+cehrgpt/runners/sample_packing_trainer.py,sha256=Zb7Aqwnk8-VqrjEKUVeg5XzZWmHxXOU2sDn1YURS-FU,7960
 cehrgpt/simulations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cehrgpt/simulations/generate_plots.py,sha256=BTZ71r8Kah0PMorkiO3vw55_p_9U1Z8KiD3GsPfaV0s,2520
 cehrgpt/simulations/run_simulation.sh,sha256=DcJ6B19jIteUO0pZ0Tc21876lB9XxQHFAxlre7MtAzk,795
 cehrgpt/simulations/time_embedding_simulation.py,sha256=HZ-imXH-bN-QYZN1PAfcERmNtaWIwKjbf0UrZduwCiA,8687
 cehrgpt/simulations/time_token_simulation.py,sha256=sLg8vVXydvR_zk3BbqyrlA7sDIdhFnS-s5pSKcCilSc,6057
 cehrgpt/time_to_event/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cehrgpt/time_to_event/time_to_event_model.py,sha256=tfXa24l_0q1TBZ68BPRrHRC_3KRWYxrWGIv4myJlIb8,8497
-cehrgpt/time_to_event/time_to_event_prediction.py,sha256=Ajesq2gSsILghWHCTLiiBhcyOCa7m6JPPMdi_xvBlR4,12624
+cehrgpt/time_to_event/time_to_event_model.py,sha256=Plm0bZxvlAbnMl82DTBXWvaXLvrqcdkzcP_celX8WC4,8055
+cehrgpt/time_to_event/time_to_event_prediction.py,sha256=W2e7UqIV7ELdfTy997HS66vggjnhdncCKt840knI0Dw,13183
 cehrgpt/time_to_event/time_to_event_utils.py,sha256=KN4hwGgxy2nJtO7osbYQBF3-HpmGUWefNfexzPYiEwc,1937
+cehrgpt/time_to_event/config/1_year_cabg.yaml,sha256=SFF2-F5D02pDSMRddDrEUoERBCd0t2Hzln_xC-Mo2hA,407
 cehrgpt/time_to_event/config/30_day_readmission.yaml,sha256=Hn5KnEXMtSV_CtCpmAU4wjkc0-gTXvniaH991TSbUXA,234
 cehrgpt/time_to_event/config/next_visit_type_prediction.yaml,sha256=WMj2ZutEvHKIMyGG51xtXaL6MyRANKvpg9xT8ouctLc,319
 cehrgpt/time_to_event/config/t2dm_hf.yaml,sha256=_oMQzh2eJTYzEaMOpmhAzbX-qmdsKlkORELL6HxOxHo,202
@@ -63,10 +66,10 @@ cehrgpt/tools/generate_pretrained_embeddings.py,sha256=lhFSacGv8bMld6qigKZN8Op8e
 cehrgpt/tools/merge_synthetic_real_dataasets.py,sha256=O1dbQ32Le0t15fwymwAh9mfNVLEWuFwW53DNvESrWbY,7589
 cehrgpt/tools/upload_omop_tables.py,sha256=vdBAbkeAsGPA4NsyhNjelPVj3gS8yzmS1sKNM1Qk96g,3791
 cehrgpt/tools/linear_prob/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cehrgpt/tools/linear_prob/compute_cehrgpt_features.py,sha256=jVgAmBrZKp7ABfqKkzwV5Vl_G9jDCjPl98NSVmSwHpE,19291
+cehrgpt/tools/linear_prob/compute_cehrgpt_features.py,sha256=Hpx7WvAWm2WwPHFfimCADXh019I7bwdzJ4_5_YCxQzU,19817
 cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py,sha256=w0UvzMKYGenN_KDVnbzutmy8IPLUxW5hPvpKKxDSL5U,5820
-cehrgpt-0.1.0.dist-info/licenses/LICENSE,sha256=LOfC32zkfUIdGm8e_098jPbt8OHKtNWymDzxn2pA9Zk,1093
-cehrgpt-0.1.0.dist-info/METADATA,sha256=V02vsptjJRD_bybXVRFXPrJa-By9CX4j-oAA3EfXFq4,4933
-cehrgpt-0.1.0.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-cehrgpt-0.1.0.dist-info/top_level.txt,sha256=akNCJBbMSLV8nkOzdVzdy13hMJ5CIQURnAS_YYEDVwA,17
-cehrgpt-0.1.0.dist-info/RECORD,,
+cehrgpt-0.1.2.dist-info/licenses/LICENSE,sha256=LOfC32zkfUIdGm8e_098jPbt8OHKtNWymDzxn2pA9Zk,1093
+cehrgpt-0.1.2.dist-info/METADATA,sha256=D7gGKrQThiLivViFeNm711NCP8J-wXfkueMGb6RKqV0,8481
+cehrgpt-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cehrgpt-0.1.2.dist-info/top_level.txt,sha256=akNCJBbMSLV8nkOzdVzdy13hMJ5CIQURnAS_YYEDVwA,17
+cehrgpt-0.1.2.dist-info/RECORD,,
{cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/WHEEL
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (80.7.1)
+Generator: setuptools (80.9.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
