cehrgpt 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +1 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +454 -68
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +232 -17
- cehrgpt/data/sample_packing_sampler.py +36 -6
- cehrgpt/generation/cehrgpt_conditional_generation.py +314 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +15 -3
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +25 -0
- cehrgpt/models/hf_cehrgpt.py +244 -39
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +354 -71
- cehrgpt/runners/data_utils.py +131 -5
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +84 -51
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +59 -7
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +60 -0
- cehrgpt/runners/hyperparameter_search_util.py +6 -7
- cehrgpt/runners/sample_packing_trainer.py +17 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +80 -62
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/METADATA +102 -7
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/RECORD +29 -26
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/WHEEL +1 -1
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/top_level.txt +0 -0
@@ -35,6 +35,13 @@ class SamplePackingTrainer(Trainer):
|
|
35
35
|
self.max_tokens_per_batch,
|
36
36
|
)
|
37
37
|
|
38
|
+
self.negative_sampling_probability = kwargs.pop(
|
39
|
+
"negative_sampling_probability", None
|
40
|
+
)
|
41
|
+
if self.negative_sampling_probability:
|
42
|
+
LOG.info(
|
43
|
+
"negative_sampling_probability: %s", self.negative_sampling_probability
|
44
|
+
)
|
38
45
|
self.train_lengths = kwargs.pop("train_lengths", None)
|
39
46
|
self.validation_lengths = kwargs.pop("validation_lengths", None)
|
40
47
|
super().__init__(*args, **kwargs)
|
@@ -70,6 +77,14 @@ class SamplePackingTrainer(Trainer):
|
|
70
77
|
data_collator = self._get_collator_with_removed_columns(
|
71
78
|
data_collator, description="training"
|
72
79
|
)
|
80
|
+
|
81
|
+
labels = None
|
82
|
+
if (
|
83
|
+
self.negative_sampling_probability is not None
|
84
|
+
and "classifier_label" in train_dataset.column_names
|
85
|
+
):
|
86
|
+
labels = train_dataset["classifier_label"]
|
87
|
+
|
73
88
|
# Create our custom batch sampler
|
74
89
|
batch_sampler = SamplePackingBatchSampler(
|
75
90
|
lengths=lengths,
|
@@ -77,6 +92,8 @@ class SamplePackingTrainer(Trainer):
|
|
77
92
|
max_position_embeddings=self.max_position_embeddings,
|
78
93
|
drop_last=self.args.dataloader_drop_last,
|
79
94
|
seed=self.args.seed,
|
95
|
+
negative_sampling_probability=self.negative_sampling_probability,
|
96
|
+
labels=labels,
|
80
97
|
)
|
81
98
|
dataloader_params = {
|
82
99
|
"collate_fn": data_collator,
|
@@ -0,0 +1,23 @@
|
|
1
|
+
task_name: "cabg_prediction"
|
2
|
+
outcome_events: [
|
3
|
+
"43528001",
|
4
|
+
"43528003",
|
5
|
+
"43528004",
|
6
|
+
"43528002",
|
7
|
+
"4305852",
|
8
|
+
"4168831",
|
9
|
+
"2107250",
|
10
|
+
"2107216",
|
11
|
+
"2107222",
|
12
|
+
"2107231",
|
13
|
+
"4336464",
|
14
|
+
"4231998",
|
15
|
+
"4284104",
|
16
|
+
"2100873",
|
17
|
+
]
|
18
|
+
future_visit_start: 0
|
19
|
+
future_visit_end: -1
|
20
|
+
prediction_window_start: 0
|
21
|
+
prediction_window_end: 365
|
22
|
+
max_new_tokens: 1024
|
23
|
+
include_descendants: true
|
@@ -80,20 +80,9 @@ class TimeToEventModel:
|
|
80
80
|
return token in self.outcome_events
|
81
81
|
|
82
82
|
def simulate(
|
83
|
-
self,
|
83
|
+
self,
|
84
|
+
partial_history: Union[np.ndarray, List[str]],
|
84
85
|
) -> List[List[str]]:
|
85
|
-
|
86
|
-
sequence_is_demographics = len(partial_history) == 4 and partial_history[
|
87
|
-
0
|
88
|
-
].startswith("year")
|
89
|
-
sequence_ends_ve = is_visit_end(partial_history[-1])
|
90
|
-
|
91
|
-
if not (sequence_is_demographics | sequence_ends_ve):
|
92
|
-
raise ValueError(
|
93
|
-
"There are only two types of sequences allowed. 1) the sequence only contains "
|
94
|
-
"demographics; 2) the sequence ends on VE;"
|
95
|
-
)
|
96
|
-
|
97
86
|
token_ids = self.tokenizer.encode(partial_history)
|
98
87
|
prompt = torch.tensor(token_ids).unsqueeze(0).to(self.device)
|
99
88
|
|
@@ -118,9 +118,9 @@ def main(args):
|
|
118
118
|
LOG.info(f"Top P {args.top_p}")
|
119
119
|
LOG.info(f"Top K {args.top_k}")
|
120
120
|
|
121
|
-
cehrgpt_model.resize_position_embeddings(
|
122
|
-
|
123
|
-
)
|
121
|
+
# cehrgpt_model.resize_position_embeddings(
|
122
|
+
# cehrgpt_model.config.max_position_embeddings + task_config.max_new_tokens
|
123
|
+
# )
|
124
124
|
|
125
125
|
generation_config = TimeToEventModel.get_generation_config(
|
126
126
|
tokenizer=cehrgpt_tokenizer,
|
@@ -190,14 +190,22 @@ def main(args):
|
|
190
190
|
args.max_n_trial,
|
191
191
|
)
|
192
192
|
visit_counter = sum([int(is_visit_end(_)) for _ in partial_history])
|
193
|
+
predicted_boolean_probability = (
|
194
|
+
sum([event != "0" for event in concept_time_to_event.outcome_events])
|
195
|
+
/ len(concept_time_to_event.outcome_events)
|
196
|
+
if concept_time_to_event
|
197
|
+
else 0.0
|
198
|
+
)
|
193
199
|
tte_outputs.append(
|
194
200
|
{
|
195
|
-
"
|
196
|
-
"
|
201
|
+
"subject_id": record["person_id"],
|
202
|
+
"prediction_time": record["index_date"],
|
197
203
|
"visit_counter": visit_counter,
|
198
|
-
"
|
204
|
+
"boolean_value": label,
|
205
|
+
"predicted_boolean_probability": predicted_boolean_probability,
|
206
|
+
"predicted_boolean_value": None,
|
199
207
|
"time_to_event": time_to_event,
|
200
|
-
"
|
208
|
+
"trials": (
|
201
209
|
asdict(concept_time_to_event) if concept_time_to_event else None
|
202
210
|
),
|
203
211
|
}
|
@@ -263,9 +271,13 @@ def filter_out_existing_results(
|
|
263
271
|
parquet_files = glob.glob(os.path.join(prediction_output_folder_name, "*parquet"))
|
264
272
|
if parquet_files:
|
265
273
|
cohort_members = set()
|
266
|
-
results_dataframe = pd.read_parquet(parquet_files)[
|
274
|
+
results_dataframe = pd.read_parquet(parquet_files)[
|
275
|
+
["subject_id", "prediction_time"]
|
276
|
+
]
|
267
277
|
for row in results_dataframe.itertuples():
|
268
|
-
cohort_members.add(
|
278
|
+
cohort_members.add(
|
279
|
+
(row.subject_id, row.prediction_time.strftime("%Y-%m-%d"))
|
280
|
+
)
|
269
281
|
|
270
282
|
def filter_func(batched):
|
271
283
|
return [
|
@@ -292,12 +304,14 @@ def flush_to_disk_if_full(
|
|
292
304
|
pd.DataFrame(
|
293
305
|
tte_outputs,
|
294
306
|
columns=[
|
295
|
-
"
|
296
|
-
"
|
307
|
+
"subject_id",
|
308
|
+
"prediction_time",
|
297
309
|
"visit_counter",
|
298
|
-
"
|
310
|
+
"boolean_value",
|
311
|
+
"predicted_boolean_probability",
|
312
|
+
"predicted_boolean_value",
|
299
313
|
"time_to_event",
|
300
|
-
"
|
314
|
+
"trials",
|
301
315
|
],
|
302
316
|
).to_parquet(output_parquet_file)
|
303
317
|
tte_outputs.clear()
|
@@ -1,8 +1,8 @@
|
|
1
|
+
import datetime
|
1
2
|
import glob
|
2
3
|
import os
|
3
4
|
import shutil
|
4
5
|
import uuid
|
5
|
-
from datetime import datetime
|
6
6
|
from functools import partial
|
7
7
|
from pathlib import Path
|
8
8
|
from typing import Optional, Union
|
@@ -29,8 +29,12 @@ from cehrgpt.models.hf_cehrgpt import (
|
|
29
29
|
CEHRGPT2Model,
|
30
30
|
extract_features_from_packed_sequence,
|
31
31
|
)
|
32
|
+
from cehrgpt.models.special_tokens import LINEAR_PROB_TOKEN
|
32
33
|
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
|
33
|
-
from cehrgpt.runners.data_utils import
|
34
|
+
from cehrgpt.runners.data_utils import (
|
35
|
+
extract_cohort_sequences,
|
36
|
+
prepare_finetune_dataset,
|
37
|
+
)
|
34
38
|
from cehrgpt.runners.gpt_runner_util import parse_runner_args
|
35
39
|
from cehrgpt.runners.hf_cehrgpt_pretrain_runner import tokenizer_exists
|
36
40
|
|
@@ -112,6 +116,11 @@ def main():
|
|
112
116
|
.eval()
|
113
117
|
.to(device)
|
114
118
|
)
|
119
|
+
|
120
|
+
if LINEAR_PROB_TOKEN not in cehrgpt_tokenizer.get_vocab():
|
121
|
+
cehrgpt_tokenizer.add_tokens(LINEAR_PROB_TOKEN)
|
122
|
+
cehrgpt_model.resize_token_embeddings(cehrgpt_tokenizer.vocab_size)
|
123
|
+
|
115
124
|
prepared_ds_path = generate_prepared_ds_path(
|
116
125
|
data_args, model_args, data_folder=data_args.cohort_folder
|
117
126
|
)
|
@@ -137,39 +146,31 @@ def main():
|
|
137
146
|
|
138
147
|
if processed_dataset is None:
|
139
148
|
if is_main_process(training_args.local_rank):
|
140
|
-
#
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
cehrgpt_tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
|
152
|
-
cehrgpt_tokenizer=cehrgpt_tokenizer,
|
153
|
-
dataset=final_splits["train"],
|
154
|
-
data_args=data_args,
|
155
|
-
concept_name_mapping={},
|
156
|
-
)
|
157
|
-
cehrgpt_tokenizer.save_pretrained(
|
158
|
-
os.path.expanduser(training_args.output_dir)
|
159
|
-
)
|
160
|
-
|
149
|
+
# If the full dataset has been tokenized, we don't want to tokenize the cohort containing
|
150
|
+
# the subset of the data. We should slice out the portion of the tokenized sequences for each sample
|
151
|
+
if cehrgpt_args.tokenized_full_dataset_path is not None:
|
152
|
+
processed_dataset = extract_cohort_sequences(
|
153
|
+
data_args, cehrgpt_args, cache_file_collector
|
154
|
+
)
|
155
|
+
else:
|
156
|
+
# Organize them into a single DatasetDict
|
157
|
+
final_splits = prepare_finetune_dataset(
|
158
|
+
data_args, training_args, cehrgpt_args, cache_file_collector
|
159
|
+
)
|
161
160
|
# TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
161
|
+
if not data_args.streaming:
|
162
|
+
all_columns = final_splits["train"].column_names
|
163
|
+
if "visit_concept_ids" in all_columns:
|
164
|
+
final_splits = final_splits.remove_columns(
|
165
|
+
["visit_concept_ids"]
|
166
|
+
)
|
167
|
+
|
168
|
+
processed_dataset = create_cehrgpt_finetuning_dataset(
|
169
|
+
dataset=final_splits,
|
170
|
+
cehrgpt_tokenizer=cehrgpt_tokenizer,
|
171
|
+
data_args=data_args,
|
172
|
+
cache_file_collector=cache_file_collector,
|
173
|
+
)
|
173
174
|
if not data_args.streaming:
|
174
175
|
processed_dataset.save_to_disk(prepared_ds_path)
|
175
176
|
processed_dataset.cleanup_cache_files()
|
@@ -218,10 +219,6 @@ def main():
|
|
218
219
|
len(processed_dataset["test"]),
|
219
220
|
)
|
220
221
|
|
221
|
-
LOG.info(f"cehrgpt_model.config.vocab_size: {cehrgpt_model.config.vocab_size}")
|
222
|
-
LOG.info(f"cehrgpt_tokenizer.vocab_size: {cehrgpt_tokenizer.vocab_size}")
|
223
|
-
if cehrgpt_model.config.vocab_size < cehrgpt_tokenizer.vocab_size:
|
224
|
-
cehrgpt_model.resize_token_embeddings(cehrgpt_tokenizer.vocab_size)
|
225
222
|
if (
|
226
223
|
cehrgpt_model.config.max_position_embeddings
|
227
224
|
< model_args.max_position_embeddings
|
@@ -244,6 +241,7 @@ def main():
|
|
244
241
|
SamplePackingCehrGptDataCollator,
|
245
242
|
cehrgpt_args.max_tokens_per_batch,
|
246
243
|
cehrgpt_model.config.max_position_embeddings,
|
244
|
+
add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
|
247
245
|
)
|
248
246
|
train_batch_sampler = SamplePackingBatchSampler(
|
249
247
|
lengths=train_set["num_of_concepts"],
|
@@ -278,6 +276,7 @@ def main():
|
|
278
276
|
include_ttv_prediction=False,
|
279
277
|
use_sub_time_tokenization=False,
|
280
278
|
include_demographics=cehrgpt_args.include_demographics,
|
279
|
+
add_linear_prob_token=True,
|
281
280
|
)
|
282
281
|
|
283
282
|
train_loader = DataLoader(
|
@@ -298,30 +297,38 @@ def main():
|
|
298
297
|
batch_sampler=test_batch_sampler,
|
299
298
|
)
|
300
299
|
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
300
|
+
if data_args.is_data_in_meds:
|
301
|
+
demographics_dict = dict()
|
302
|
+
else:
|
303
|
+
# Loading demographics
|
304
|
+
print("Loading demographics as a dictionary")
|
305
|
+
demographics_df = pd.concat(
|
306
|
+
[
|
307
|
+
pd.read_parquet(
|
308
|
+
data_dir,
|
309
|
+
columns=[
|
310
|
+
"person_id",
|
311
|
+
"index_date",
|
312
|
+
"gender_concept_id",
|
313
|
+
"race_concept_id",
|
314
|
+
],
|
315
|
+
)
|
316
|
+
for data_dir in [data_args.data_folder, data_args.test_data_folder]
|
317
|
+
]
|
318
|
+
)
|
319
|
+
|
320
|
+
demographics_df["index_date"] = (
|
321
|
+
demographics_df["index_date"].dt.tz_localize("UTC")
|
322
|
+
- datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
|
323
|
+
).dt.total_seconds()
|
324
|
+
|
325
|
+
demographics_dict = {
|
326
|
+
(row["person_id"], row["index_date"]): {
|
327
|
+
"gender_concept_id": row["gender_concept_id"],
|
328
|
+
"race_concept_id": row["race_concept_id"],
|
329
|
+
}
|
330
|
+
for _, row in demographics_df.iterrows()
|
322
331
|
}
|
323
|
-
for _, row in demographics_df.iterrows()
|
324
|
-
}
|
325
332
|
|
326
333
|
data_loaders = [("train", train_loader), ("test", test_dataloader)]
|
327
334
|
|
@@ -351,9 +358,16 @@ def main():
|
|
351
358
|
prediction_time_posix = batch.pop("index_date").numpy().squeeze()
|
352
359
|
if prediction_time_posix.ndim == 0:
|
353
360
|
prediction_time_posix = np.asarray([prediction_time_posix])
|
361
|
+
|
354
362
|
prediction_time = list(
|
355
|
-
map(
|
363
|
+
map(
|
364
|
+
lambda posix_time: datetime.datetime.utcfromtimestamp(
|
365
|
+
posix_time
|
366
|
+
).replace(tzinfo=None),
|
367
|
+
prediction_time_posix,
|
368
|
+
)
|
356
369
|
)
|
370
|
+
|
357
371
|
labels = (
|
358
372
|
batch.pop("classifier_label")
|
359
373
|
.float()
|
@@ -365,6 +379,10 @@ def main():
|
|
365
379
|
if labels.ndim == 0:
|
366
380
|
labels = np.asarray([labels])
|
367
381
|
|
382
|
+
# Right now the model does not support this column, we need to pop it
|
383
|
+
if "epoch_times" in batch:
|
384
|
+
batch.pop("epoch_times")
|
385
|
+
|
368
386
|
batch = {k: v.to(device) for k, v in batch.items()}
|
369
387
|
# Forward pass
|
370
388
|
cehrgpt_output = cehrgpt_model(
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: cehrgpt
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.2
|
4
4
|
Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
|
5
5
|
Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
|
6
6
|
License: MIT License
|
@@ -12,14 +12,15 @@ Classifier: Programming Language :: Python :: 3
|
|
12
12
|
Requires-Python: >=3.10.0
|
13
13
|
Description-Content-Type: text/markdown
|
14
14
|
License-File: LICENSE
|
15
|
-
Requires-Dist: cehrbert==1.4.
|
16
|
-
Requires-Dist: cehrbert_data==0.0.
|
15
|
+
Requires-Dist: cehrbert==1.4.5
|
16
|
+
Requires-Dist: cehrbert_data==0.0.11
|
17
17
|
Requires-Dist: openai==1.54.3
|
18
18
|
Requires-Dist: optuna==4.0.0
|
19
|
-
Requires-Dist: transformers==4.44.
|
19
|
+
Requires-Dist: transformers==4.44.1
|
20
20
|
Requires-Dist: tokenizers==0.19.0
|
21
21
|
Requires-Dist: peft==0.10.0
|
22
22
|
Requires-Dist: lightgbm
|
23
|
+
Requires-Dist: polars
|
23
24
|
Provides-Extra: dev
|
24
25
|
Requires-Dist: pre-commit; extra == "dev"
|
25
26
|
Requires-Dist: pytest; extra == "dev"
|
@@ -36,9 +37,9 @@ Dynamic: license-file
|
|
36
37
|
|
37
38
|
[](https://pypi.org/project/cehrgpt/)
|
38
39
|

|
39
|
-
[](https://github.com/knatarajan-lab/cehrgpt
|
41
|
-
[](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml)
|
41
|
+
[](https://github.com/knatarajan-lab/cehrgpt/blob/main/LICENSE)
|
42
|
+
[](https://github.com/knatarajan-lab/cehrgpt/graphs/contributors)
|
42
43
|
|
43
44
|
## Description
|
44
45
|
CEHRGPT is a synthetic data generation model developed to handle structured electronic health records (EHR) with enhanced privacy and reliability. It leverages state-of-the-art natural language processing techniques to create realistic, anonymized patient data that can be used for research and development without compromising patient privacy.
|
@@ -104,6 +105,100 @@ sh scripts/omop_pipeline.sh \
|
|
104
105
|
$OMOP_VOCAB_DIR
|
105
106
|
```
|
106
107
|
|
108
|
+
# MEDS Support
|
109
|
+
|
110
|
+
This section demonstrates how to pretrain CEHR-GPT using MIMIC-IV data in the MEDS (Medical Event Data Standard) format.
|
111
|
+
|
112
|
+
## Prerequisites
|
113
|
+
|
114
|
+
Set up the required environment variables before beginning:
|
115
|
+
|
116
|
+
```bash
|
117
|
+
export CEHR_GPT_MODEL_DIR="" # Path to CEHR-GPT model directory
|
118
|
+
export MEDS_DIR="" # Path to MEDS data directory
|
119
|
+
export MEDS_READER_DIR="" # Path to MEDS reader output directory
|
120
|
+
```
|
121
|
+
|
122
|
+
## Step 1: Create MIMIC MEDS Data
|
123
|
+
|
124
|
+
Transform your MIMIC files into MEDS format by following the instructions in the [MEDS_transforms](https://github.com/mmcdermott/MEDS_transforms/) repository.
|
125
|
+
|
126
|
+
## Step 2: Create the MEDS Reader
|
127
|
+
|
128
|
+
Convert the MEDS data for use with CEHR-GPT:
|
129
|
+
|
130
|
+
```bash
|
131
|
+
meds_reader_convert $MEDS_DIR $MEDS_READER_DIR --num_threads 10
|
132
|
+
```
|
133
|
+
|
134
|
+
## Step 3: Pretrain CEHR-GPT
|
135
|
+
|
136
|
+
Run the pretraining process using the prepared MEDS data:
|
137
|
+
|
138
|
+
```bash
|
139
|
+
python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
|
140
|
+
--model_name_or_path $CEHR_GPT_MODEL_DIR \
|
141
|
+
--tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
|
142
|
+
--output_dir $CEHR_GPT_MODEL_DIR \
|
143
|
+
--data_folder $MEDS_READER_DIR \
|
144
|
+
--dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
|
145
|
+
--do_train true --seed 42 \
|
146
|
+
--dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
|
147
|
+
--hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 8192 \
|
148
|
+
--evaluation_strategy epoch --save_strategy epoch \
|
149
|
+
--sample_packing --max_tokens_per_batch 16384 \
|
150
|
+
--warmup_steps 500 --weight_decay 0.01 \
|
151
|
+
--num_train_epochs 50 --learning_rate 0.0002 \
|
152
|
+
--use_early_stopping --early_stopping_threshold 0.001 \
|
153
|
+
--is_data_in_meds --inpatient_att_function_type day \
|
154
|
+
--att_function_type day --include_inpatient_hour_token \
|
155
|
+
--include_auxiliary_token --include_demographic_prompt \
|
156
|
+
--meds_to_cehrbert_conversion_type "MedsToBertMimic4"
|
157
|
+
```
|
158
|
+
|
159
|
+
## Step 4: Generate MEDS Trajectories
|
160
|
+
|
161
|
+
### Environment Setup for Trajectory Generation
|
162
|
+
|
163
|
+
Configure additional environment variables for trajectory generation with task labels (`subject_id`, `prediction_time`, `boolean_value` [optional]):
|
164
|
+
|
165
|
+
```bash
|
166
|
+
# MEDS_LABEL_COHORT_DIR must contain a set of parquet files
|
167
|
+
export MEDS_LABEL_COHORT_DIR="" # Path to cohort labels directory
|
168
|
+
export MEDS_TRAJECTORY_DIR="" # Path for trajectory output
|
169
|
+
```
|
170
|
+
|
171
|
+
### Generate Trajectories
|
172
|
+
|
173
|
+
Create synthetic patient trajectories using the trained model:
|
174
|
+
|
175
|
+
> **Important:** The total sequence length (`generation_input_length` + `generation_max_new_tokens`) cannot exceed the `max_position_embeddings` value (8192) defined during pretraining.
|
176
|
+
|
177
|
+
```bash
|
178
|
+
python -u -m cehrgpt.generation.cehrgpt_conditional_generation \
|
179
|
+
--cohort_folder $MEDS_LABEL_COHORT_DIR \
|
180
|
+
--data_folder $MEDS_READER_DIR \
|
181
|
+
--dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
|
182
|
+
--model_name_or_path $CEHR_GPT_MODEL_DIR \
|
183
|
+
--tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
|
184
|
+
--output_dir $MEDS_TRAJECTORY_DIR \
|
185
|
+
--per_device_eval_batch_size 16 \
|
186
|
+
--num_of_trajectories_per_sample 2 \
|
187
|
+
--generation_input_length 4096 \
|
188
|
+
--generation_max_new_tokens 4096 \
|
189
|
+
--is_data_in_meds \
|
190
|
+
--att_function_type day --inpatient_att_function_type day \
|
191
|
+
--meds_to_cehrbert_conversion_type MedsToBertMimic4 \
|
192
|
+
--include_auxiliary_token --include_demographic_prompt \
|
193
|
+
--include_inpatient_hour_token
|
194
|
+
```
|
195
|
+
|
196
|
+
### Parameters Explanation
|
197
|
+
|
198
|
+
- `generation_input_length`: Controls the length of input context for generation
|
199
|
+
- `generation_max_new_tokens`: Maximum number of new tokens to generate
|
200
|
+
- `num_of_trajectories_per_sample`: Number of trajectories to generate per patient sample
|
201
|
+
|
107
202
|
## Citation
|
108
203
|
```
|
109
204
|
@article{cehrgpt2024,
|
@@ -1,8 +1,9 @@
|
|
1
1
|
__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
cehrgpt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
cehrgpt/cehrgpt_args.py,sha256=zPLp9Qjlq5PapWx3R15BNnyaX8zV3dxr4PuWj71r0Lg,3516
|
4
|
-
cehrgpt/gpt_utils.py,sha256=
|
4
|
+
cehrgpt/gpt_utils.py,sha256=IA5qw-hxcKkGO07AB47lDNRU6mlb9jblpKO7KeLLN78,11342
|
5
5
|
cehrgpt/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
|
+
cehrgpt/analysis/irregularity.py,sha256=Rfl_daMvSh9cZ68vUwfmuH-JYCFXdAph2ITHHffYC0Y,1047
|
6
7
|
cehrgpt/analysis/privacy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
7
8
|
cehrgpt/analysis/privacy/attribute_inference.py,sha256=0ANVW0I5uvOl6IxQ15-vMVQd0mugOgSGReBUQQESImg,9368
|
8
9
|
cehrgpt/analysis/privacy/attribute_inference_config.yml,sha256=hfLfpBlDqqsNOynpRHK414vV24edKA6ta-inmEhM2ao,103272
|
@@ -11,22 +12,23 @@ cehrgpt/analysis/privacy/nearest_neighbor_inference.py,sha256=qoJgWW7VsUMzjMGpTa
|
|
11
12
|
cehrgpt/analysis/privacy/reid_inference.py,sha256=Pypd3QJXQNY8VljpnIEa5zeAbTZHMjQOazaL-9VsBGw,13955
|
12
13
|
cehrgpt/analysis/privacy/utils.py,sha256=CRA4H9mPLBjMQGKzZ_x_3ro3tMap-NjsMDVqSOjHSVQ,8226
|
13
14
|
cehrgpt/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
|
-
cehrgpt/data/hf_cehrgpt_dataset.py,sha256=
|
15
|
-
cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=
|
16
|
-
cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=
|
17
|
-
cehrgpt/data/sample_packing_sampler.py,sha256=
|
15
|
+
cehrgpt/data/hf_cehrgpt_dataset.py,sha256=hwJlGW7XiJIr6cXtmwvReQf9yLZJPD-dvJGvRg5ERqU,3755
|
16
|
+
cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=juM5HeZScgj8w15Bl1qC83Swld4gY6avh0QkSWLqITA,45465
|
17
|
+
cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=_QDX9NXfmQ_S3kOf3yndb3AhoEeFiSzAOv836uYW0AY,26230
|
18
|
+
cehrgpt/data/sample_packing_sampler.py,sha256=vovGMtmhG70DRkSCeiaDEJ_rjKZ38y-YLaI1kkhFEkI,6747
|
18
19
|
cehrgpt/generation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
20
|
+
cehrgpt/generation/cehrgpt_conditional_generation.py,sha256=AM76yaPyw1B-bcdei24HO0uspGZWHGKWpYpHywotTIQ,11972
|
19
21
|
cehrgpt/generation/chatgpt_generation.py,sha256=SrnLwHLdNtnAOEg36gNjqfoT9yd12iyPgpZffL2AFJo,4428
|
20
|
-
cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256=
|
21
|
-
cehrgpt/generation/omop_converter_batch.py,sha256
|
22
|
+
cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256=P8al4-zqymqEkCHCCu2sqz_45akcKF2o_AtQIjJdVmQ,11919
|
23
|
+
cehrgpt/generation/omop_converter_batch.py,sha256=LUmCD-t_6ZP1YfNDZCqYewl-XIIaIgRZ_dAxuR_VdCQ,26275
|
22
24
|
cehrgpt/generation/omop_entity.py,sha256=Q5Sr0AlyuPAm1FRPfnJO13q-u1fqRgYVHXruZ9g4xNE,19400
|
23
25
|
cehrgpt/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
24
|
-
cehrgpt/models/config.py,sha256=
|
25
|
-
cehrgpt/models/hf_cehrgpt.py,sha256=
|
26
|
-
cehrgpt/models/hf_modeling_outputs.py,sha256=
|
26
|
+
cehrgpt/models/config.py,sha256=nOAKgH5420HLCcy7n1hE7MbqR861Iq4DTutKoAd25tg,11090
|
27
|
+
cehrgpt/models/hf_cehrgpt.py,sha256=3P7bOLDr7NMSedGszhmlJJN4Mhpd_65-x6uzwvSjigE,92837
|
28
|
+
cehrgpt/models/hf_modeling_outputs.py,sha256=5X4WEYKqT37phv_e5ZAv3A_N0wqdAUJLJRm6TxS6dDQ,10356
|
27
29
|
cehrgpt/models/pretrained_embeddings.py,sha256=vLLVs17TLpXRqCVEWQxGGwPHkUJUO7laNTeBuyBK_yk,3238
|
28
|
-
cehrgpt/models/special_tokens.py,sha256
|
29
|
-
cehrgpt/models/tokenization_hf_cehrgpt.py,sha256=
|
30
|
+
cehrgpt/models/special_tokens.py,sha256=lrw45B4tea4Dsajn09Cz6w5D2TfHmYXikZkgwnstu_o,521
|
31
|
+
cehrgpt/models/tokenization_hf_cehrgpt.py,sha256=cAxHTctpVBxfWfC3XcwDQavN1zwWN9Nid_Fajd5zQWQ,53159
|
30
32
|
cehrgpt/omop/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
31
33
|
cehrgpt/omop/condition_era.py,sha256=hPZALz2XaWnro_1bwIYNkI48foOJjueyg3CZ1BliCno,626
|
32
34
|
cehrgpt/omop/observation_period.py,sha256=TRMgv5Ya2RaS2im7oQ6BLC_5JL9EJYNYR62ApxIuHvg,1211
|
@@ -37,22 +39,23 @@ cehrgpt/omop/queries/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hS
|
|
37
39
|
cehrgpt/omop/queries/condition_era.py,sha256=LFB6vBAvshHJxtYIRkl7cfrF0kf7ay0piBKpmHBwrpE,2578
|
38
40
|
cehrgpt/omop/queries/observation_period.py,sha256=fpzr5DMNw-QLoSwp2Iatfch88E3hyhZ75usiIdG3A0U,6410
|
39
41
|
cehrgpt/runners/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
40
|
-
cehrgpt/runners/data_utils.py,sha256=
|
42
|
+
cehrgpt/runners/data_utils.py,sha256=i-krtBx_6rvPYtdLdDoWwOTtJcaovd0wH8gBYmgN2l4,16013
|
41
43
|
cehrgpt/runners/gpt_runner_util.py,sha256=YJQSRW9Mo4TjXSOUOTf6BUFcs1MGFiXU5T4ztKZcYhU,3485
|
42
|
-
cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=
|
43
|
-
cehrgpt/runners/hf_cehrgpt_pretrain_runner.py,sha256=
|
44
|
-
cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=
|
45
|
-
cehrgpt/runners/hyperparameter_search_util.py,sha256=
|
46
|
-
cehrgpt/runners/sample_packing_trainer.py,sha256=
|
44
|
+
cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=1OgxLm4T7iHv5pKi2QaSdaz9ogWo2n3sSUGp6cHDF9s,28309
|
45
|
+
cehrgpt/runners/hf_cehrgpt_pretrain_runner.py,sha256=ERSnvB38fPYVghtKQeNTZ8VfeXnoRcCHB0cWISWaZ84,26523
|
46
|
+
cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=fJR4RHPqal1YI6_KUH-WlkoQLSZuBT5bKUGfPHDFrWI,9350
|
47
|
+
cehrgpt/runners/hyperparameter_search_util.py,sha256=YWdFQ1igQs-G_wqWUrUzYraGiz8OSpSYyvid-I5nhWA,9262
|
48
|
+
cehrgpt/runners/sample_packing_trainer.py,sha256=Zb7Aqwnk8-VqrjEKUVeg5XzZWmHxXOU2sDn1YURS-FU,7960
|
47
49
|
cehrgpt/simulations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
48
50
|
cehrgpt/simulations/generate_plots.py,sha256=BTZ71r8Kah0PMorkiO3vw55_p_9U1Z8KiD3GsPfaV0s,2520
|
49
51
|
cehrgpt/simulations/run_simulation.sh,sha256=DcJ6B19jIteUO0pZ0Tc21876lB9XxQHFAxlre7MtAzk,795
|
50
52
|
cehrgpt/simulations/time_embedding_simulation.py,sha256=HZ-imXH-bN-QYZN1PAfcERmNtaWIwKjbf0UrZduwCiA,8687
|
51
53
|
cehrgpt/simulations/time_token_simulation.py,sha256=sLg8vVXydvR_zk3BbqyrlA7sDIdhFnS-s5pSKcCilSc,6057
|
52
54
|
cehrgpt/time_to_event/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
53
|
-
cehrgpt/time_to_event/time_to_event_model.py,sha256=
|
54
|
-
cehrgpt/time_to_event/time_to_event_prediction.py,sha256=
|
55
|
+
cehrgpt/time_to_event/time_to_event_model.py,sha256=Plm0bZxvlAbnMl82DTBXWvaXLvrqcdkzcP_celX8WC4,8055
|
56
|
+
cehrgpt/time_to_event/time_to_event_prediction.py,sha256=W2e7UqIV7ELdfTy997HS66vggjnhdncCKt840knI0Dw,13183
|
55
57
|
cehrgpt/time_to_event/time_to_event_utils.py,sha256=KN4hwGgxy2nJtO7osbYQBF3-HpmGUWefNfexzPYiEwc,1937
|
58
|
+
cehrgpt/time_to_event/config/1_year_cabg.yaml,sha256=SFF2-F5D02pDSMRddDrEUoERBCd0t2Hzln_xC-Mo2hA,407
|
56
59
|
cehrgpt/time_to_event/config/30_day_readmission.yaml,sha256=Hn5KnEXMtSV_CtCpmAU4wjkc0-gTXvniaH991TSbUXA,234
|
57
60
|
cehrgpt/time_to_event/config/next_visit_type_prediction.yaml,sha256=WMj2ZutEvHKIMyGG51xtXaL6MyRANKvpg9xT8ouctLc,319
|
58
61
|
cehrgpt/time_to_event/config/t2dm_hf.yaml,sha256=_oMQzh2eJTYzEaMOpmhAzbX-qmdsKlkORELL6HxOxHo,202
|
@@ -63,10 +66,10 @@ cehrgpt/tools/generate_pretrained_embeddings.py,sha256=lhFSacGv8bMld6qigKZN8Op8e
|
|
63
66
|
cehrgpt/tools/merge_synthetic_real_dataasets.py,sha256=O1dbQ32Le0t15fwymwAh9mfNVLEWuFwW53DNvESrWbY,7589
|
64
67
|
cehrgpt/tools/upload_omop_tables.py,sha256=vdBAbkeAsGPA4NsyhNjelPVj3gS8yzmS1sKNM1Qk96g,3791
|
65
68
|
cehrgpt/tools/linear_prob/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
66
|
-
cehrgpt/tools/linear_prob/compute_cehrgpt_features.py,sha256=
|
69
|
+
cehrgpt/tools/linear_prob/compute_cehrgpt_features.py,sha256=Hpx7WvAWm2WwPHFfimCADXh019I7bwdzJ4_5_YCxQzU,19817
|
67
70
|
cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py,sha256=w0UvzMKYGenN_KDVnbzutmy8IPLUxW5hPvpKKxDSL5U,5820
|
68
|
-
cehrgpt-0.1.
|
69
|
-
cehrgpt-0.1.
|
70
|
-
cehrgpt-0.1.
|
71
|
-
cehrgpt-0.1.
|
72
|
-
cehrgpt-0.1.
|
71
|
+
cehrgpt-0.1.2.dist-info/licenses/LICENSE,sha256=LOfC32zkfUIdGm8e_098jPbt8OHKtNWymDzxn2pA9Zk,1093
|
72
|
+
cehrgpt-0.1.2.dist-info/METADATA,sha256=D7gGKrQThiLivViFeNm711NCP8J-wXfkueMGb6RKqV0,8481
|
73
|
+
cehrgpt-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
74
|
+
cehrgpt-0.1.2.dist-info/top_level.txt,sha256=akNCJBbMSLV8nkOzdVzdy13hMJ5CIQURnAS_YYEDVwA,17
|
75
|
+
cehrgpt-0.1.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|