cehrgpt 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +57 -33
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +22 -9
- cehrgpt/generation/cehrgpt_conditional_generation.py +314 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +15 -3
- cehrgpt/models/hf_cehrgpt.py +17 -6
- cehrgpt/runners/data_utils.py +17 -6
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +9 -1
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +12 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +20 -30
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.2.dist-info}/METADATA +95 -1
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.2.dist-info}/RECORD +14 -13
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.2.dist-info}/WHEEL +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.2.dist-info}/top_level.txt +0 -0
@@ -162,6 +162,22 @@ class CehrGptDataCollator:
|
|
162
162
|
f"batch['input_ids']: {batch['input_ids']} "
|
163
163
|
)
|
164
164
|
|
165
|
+
if "epoch_times" in examples[0]:
|
166
|
+
batch_epoch_times = [
|
167
|
+
self._try_reverse_tensor(
|
168
|
+
self._convert_to_tensor(example["epoch_times"])
|
169
|
+
)
|
170
|
+
for example in examples
|
171
|
+
]
|
172
|
+
# Pad sequences to the max length in the batch
|
173
|
+
batch["epoch_times"] = self._try_reverse_tensor(
|
174
|
+
pad_sequence(
|
175
|
+
batch_epoch_times,
|
176
|
+
batch_first=True,
|
177
|
+
padding_value=0,
|
178
|
+
).to(torch.float32)
|
179
|
+
)
|
180
|
+
|
165
181
|
if "position_ids" in examples[0]:
|
166
182
|
batch_position_ids = [
|
167
183
|
self._try_reverse_tensor(
|
@@ -663,7 +679,9 @@ class CehrGptDataCollator:
|
|
663
679
|
|
664
680
|
# Subtract one for the [END] token when sample_packing is not enabled
|
665
681
|
new_max_length = (
|
666
|
-
max_length_allowed
|
682
|
+
max_length_allowed - 1
|
683
|
+
if not sample_packing and self.pretraining
|
684
|
+
else max_length_allowed
|
667
685
|
)
|
668
686
|
|
669
687
|
if self.include_ttv_prediction:
|
@@ -685,13 +703,20 @@ class CehrGptDataCollator:
|
|
685
703
|
|
686
704
|
# Return the record directly if the actual sequence length is less than the max sequence
|
687
705
|
if seq_length <= new_max_length:
|
688
|
-
if not sample_packing:
|
706
|
+
if not sample_packing and self.pretraining:
|
689
707
|
record["input_ids"] = torch.concat(
|
690
708
|
[
|
691
709
|
self._convert_to_tensor(record["input_ids"]),
|
692
710
|
self._convert_to_tensor([eos_token]),
|
693
711
|
]
|
694
712
|
)
|
713
|
+
if "epoch_times" in record:
|
714
|
+
record["epoch_times"] = torch.concat(
|
715
|
+
[
|
716
|
+
self._convert_to_tensor(record["epoch_times"]),
|
717
|
+
self._convert_to_tensor([record["epoch_times"][-1]]),
|
718
|
+
]
|
719
|
+
)
|
695
720
|
if self.include_values:
|
696
721
|
record["value_indicators"] = torch.concat(
|
697
722
|
[
|
@@ -727,6 +752,10 @@ class CehrGptDataCollator:
|
|
727
752
|
record["input_ids"] = self._convert_to_tensor(
|
728
753
|
record["input_ids"][start_index : end_index + 1]
|
729
754
|
)
|
755
|
+
if "epoch_times" in record:
|
756
|
+
record["epoch_times"] = self._convert_to_tensor(
|
757
|
+
record["epoch_times"][start_index : end_index + 1]
|
758
|
+
)
|
730
759
|
if self.include_values:
|
731
760
|
record["value_indicators"] = self._convert_to_tensor(
|
732
761
|
record["value_indicators"][start_index : end_index + 1]
|
@@ -760,6 +789,11 @@ class CehrGptDataCollator:
|
|
760
789
|
if sample_packing and "position_ids" in record:
|
761
790
|
record["position_ids"] = record["position_ids"][0:end_index]
|
762
791
|
|
792
|
+
if "epoch_times" in record:
|
793
|
+
record["epoch_times"] = self._convert_to_tensor(
|
794
|
+
record["epoch_times"][0:end_index]
|
795
|
+
)
|
796
|
+
|
763
797
|
if self.include_values:
|
764
798
|
record["value_indicators"] = self._convert_to_tensor(
|
765
799
|
record["value_indicators"][0:end_index]
|
@@ -792,6 +826,17 @@ class CehrGptDataCollator:
|
|
792
826
|
),
|
793
827
|
]
|
794
828
|
)
|
829
|
+
if "epoch_times" in record:
|
830
|
+
record["epoch_times"] = torch.concat(
|
831
|
+
[
|
832
|
+
torch.zeros(
|
833
|
+
[record["epoch_times"][0]], dtype=torch.float32
|
834
|
+
),
|
835
|
+
self._convert_to_tensor(
|
836
|
+
record["epoch_times"][token_index:seq_length]
|
837
|
+
),
|
838
|
+
]
|
839
|
+
)
|
795
840
|
if self.include_values:
|
796
841
|
record["value_indicators"] = torch.concat(
|
797
842
|
[
|
@@ -830,7 +875,7 @@ class CehrGptDataCollator:
|
|
830
875
|
)
|
831
876
|
break
|
832
877
|
else:
|
833
|
-
start_index = seq_length - new_max_length
|
878
|
+
start_index = max(seq_length - new_max_length, 0)
|
834
879
|
end_index = seq_length
|
835
880
|
for i in range(start_index, end_index):
|
836
881
|
current_token = record["input_ids"][i]
|
@@ -842,6 +887,11 @@ class CehrGptDataCollator:
|
|
842
887
|
]
|
843
888
|
if sample_packing and "position_ids" in record:
|
844
889
|
record["position_ids"] = record["position_ids"][i:end_index]
|
890
|
+
|
891
|
+
if "epoch_times" in record:
|
892
|
+
record["epoch_times"] = self._convert_to_tensor(
|
893
|
+
record["epoch_times"][i:end_index]
|
894
|
+
)
|
845
895
|
if self.include_values:
|
846
896
|
record["value_indicators"] = record["value_indicators"][
|
847
897
|
i:end_index
|
@@ -863,6 +913,10 @@ class CehrGptDataCollator:
|
|
863
913
|
]
|
864
914
|
if sample_packing and "position_ids" in record:
|
865
915
|
record["position_ids"] = record["position_ids"][-new_max_length:]
|
916
|
+
if "epoch_times" in record:
|
917
|
+
record["epoch_times"] = self._convert_to_tensor(
|
918
|
+
record["epoch_times"][-new_max_length:]
|
919
|
+
)
|
866
920
|
if self.include_values:
|
867
921
|
record["value_indicators"] = record["value_indicators"][
|
868
922
|
-new_max_length:
|
@@ -873,36 +927,6 @@ class CehrGptDataCollator:
|
|
873
927
|
-new_max_length:
|
874
928
|
]
|
875
929
|
|
876
|
-
if not sample_packing:
|
877
|
-
# Finally we add the end token to the end of the sequence
|
878
|
-
record["input_ids"] = torch.concat(
|
879
|
-
[
|
880
|
-
self._convert_to_tensor(record["input_ids"]),
|
881
|
-
self._convert_to_tensor([eos_token]),
|
882
|
-
]
|
883
|
-
)
|
884
|
-
if self.include_values:
|
885
|
-
record["value_indicators"] = torch.concat(
|
886
|
-
[
|
887
|
-
self._convert_to_tensor(record["value_indicators"]),
|
888
|
-
self._convert_to_tensor([False]),
|
889
|
-
]
|
890
|
-
).to(torch.bool)
|
891
|
-
record["values"] = torch.concat(
|
892
|
-
[
|
893
|
-
self._convert_to_tensor(record["values"]),
|
894
|
-
self._convert_to_tensor(
|
895
|
-
[self.tokenizer.pad_value_token_id]
|
896
|
-
),
|
897
|
-
]
|
898
|
-
)
|
899
|
-
if self.include_ttv_prediction:
|
900
|
-
record["time_to_visits"] = torch.concat(
|
901
|
-
[
|
902
|
-
record["time_to_visits"],
|
903
|
-
self._convert_to_tensor([-100.0]),
|
904
|
-
]
|
905
|
-
)
|
906
930
|
return record
|
907
931
|
|
908
932
|
|
@@ -21,7 +21,6 @@ from cehrbert_data.const.artificial_tokens import (
|
|
21
21
|
DISCHARGE_UNKNOWN_TOKEN,
|
22
22
|
GENDER_UNKNOWN_TOKEN,
|
23
23
|
RACE_UNKNOWN_TOKEN,
|
24
|
-
VISIT_UNKNOWN_TOKEN,
|
25
24
|
)
|
26
25
|
from cehrbert_data.const.common import NA
|
27
26
|
from cehrbert_data.decorators.patient_event_decorator_base import get_att_function
|
@@ -47,10 +46,16 @@ CEHRGPT_COLUMNS = [
|
|
47
46
|
]
|
48
47
|
|
49
48
|
|
50
|
-
def convert_date_to_posix_time(index_date: datetime.date) -> float:
|
51
|
-
|
52
|
-
|
53
|
-
|
49
|
+
def convert_date_to_posix_time(index_date: Union[datetime.date, int, float]) -> float:
    """Convert an index date to POSIX seconds, treating the value as UTC.

    Args:
        index_date: A ``datetime.datetime`` (its tzinfo is overwritten with UTC),
            a ``datetime.date`` (interpreted as midnight UTC), or a value that is
            already a POSIX timestamp (returned unchanged).

    Returns:
        Seconds since the Unix epoch; numeric inputs are passed through as-is.
    """
    # ``datetime.datetime`` subclasses ``datetime.date`` and must be checked first.
    # Otherwise ``datetime.combine`` would keep only the date part and silently
    # drop the time of day.
    if isinstance(index_date, datetime.datetime):
        return index_date.replace(tzinfo=datetime.timezone.utc).timestamp()
    if isinstance(index_date, datetime.date):
        return (
            datetime.datetime.combine(index_date, datetime.datetime.min.time())
            .replace(tzinfo=datetime.timezone.utc)
            .timestamp()
        )
    return index_date
|
54
59
|
|
55
60
|
|
56
61
|
class DatasetMappingDecorator(DatasetMapping):
|
@@ -128,7 +133,9 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
|
|
128
133
|
cehrgpt_record["concept_as_values"].append(concept_as_value)
|
129
134
|
cehrgpt_record["units"].append(unit)
|
130
135
|
cehrgpt_record["is_numeric_types"].append(is_numeric_type)
|
131
|
-
cehrgpt_record["epoch_times"].append(
|
136
|
+
cehrgpt_record["epoch_times"].append(
|
137
|
+
time.replace(tzinfo=datetime.timezone.utc).timestamp()
|
138
|
+
)
|
132
139
|
|
133
140
|
def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
|
134
141
|
cehrgpt_record = {
|
@@ -360,7 +367,9 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
|
|
360
367
|
cehrgpt_record["num_of_visits"] = len(visits)
|
361
368
|
|
362
369
|
if record.get("index_date", None) is not None:
|
363
|
-
cehrgpt_record["index_date"] =
|
370
|
+
cehrgpt_record["index_date"] = (
|
371
|
+
record["index_date"].replace(tzinfo=datetime.timezone.utc).timestamp()
|
372
|
+
)
|
364
373
|
if record.get("label", None) is not None:
|
365
374
|
cehrgpt_record["label"] = record["label"]
|
366
375
|
if record.get("age_at_index", None) is not None:
|
@@ -529,9 +538,13 @@ class ExtractTokenizedSequenceDataMapping:
|
|
529
538
|
prediction_start_end_times = [
|
530
539
|
(
|
531
540
|
self._calculate_prediction_start_time(
|
532
|
-
prediction_time_label_map["index_date"]
|
541
|
+
prediction_time_label_map["index_date"]
|
542
|
+
.replace(tzinfo=datetime.timezone.utc)
|
543
|
+
.timestamp()
|
533
544
|
),
|
534
|
-
prediction_time_label_map["index_date"]
|
545
|
+
prediction_time_label_map["index_date"]
|
546
|
+
.replace(tzinfo=datetime.timezone.utc)
|
547
|
+
.timestamp(),
|
535
548
|
prediction_time_label_map["label"],
|
536
549
|
)
|
537
550
|
for prediction_time_label_map in prediction_times
|
@@ -0,0 +1,314 @@
|
|
1
|
+
import datetime
|
2
|
+
import os
|
3
|
+
import random
|
4
|
+
import shutil
|
5
|
+
from pathlib import Path
|
6
|
+
from typing import Any, Dict
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
import polars as pl
|
10
|
+
import torch
|
11
|
+
import torch.distributed as dist
|
12
|
+
from cehrbert.runners.runner_util import generate_prepared_ds_path
|
13
|
+
from datasets import load_from_disk
|
14
|
+
from meds import held_out_split, train_split, tuning_split
|
15
|
+
from torch.utils.data import DataLoader
|
16
|
+
from tqdm import tqdm
|
17
|
+
from transformers.trainer_utils import is_main_process
|
18
|
+
from transformers.utils import is_flash_attn_2_available, logging
|
19
|
+
|
20
|
+
from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_finetuning_dataset
|
21
|
+
from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
|
22
|
+
from cehrgpt.generation.generate_batch_hf_gpt_sequence import (
|
23
|
+
generate_single_batch,
|
24
|
+
normalize_value,
|
25
|
+
)
|
26
|
+
from cehrgpt.gpt_utils import (
|
27
|
+
extract_time_interval_in_days,
|
28
|
+
extract_time_interval_in_hours,
|
29
|
+
is_att_token,
|
30
|
+
is_inpatient_hour_token,
|
31
|
+
is_visit_end,
|
32
|
+
is_visit_start,
|
33
|
+
)
|
34
|
+
from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
|
35
|
+
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
|
36
|
+
from cehrgpt.runners.data_utils import (
|
37
|
+
extract_cohort_sequences,
|
38
|
+
prepare_finetune_dataset,
|
39
|
+
)
|
40
|
+
from cehrgpt.runners.gpt_runner_util import parse_runner_args
|
41
|
+
from cehrgpt.runners.hf_cehrgpt_pretrain_runner import tokenizer_exists
|
42
|
+
|
43
|
+
# Module logger: the Hugging Face "transformers" library logger, shared with the runners.
LOG = logging.get_logger(name="transformers")
|
44
|
+
|
45
|
+
|
46
|
+
def map_data_split_name(split: str) -> str:
    """Translate a Hugging Face dataset split name into its MEDS split name.

    Args:
        split: ``"train"``, ``"validation"`` or ``"test"``.

    Returns:
        ``train_split``, ``tuning_split`` or ``held_out_split`` respectively.

    Raises:
        ValueError: If ``split`` is none of the recognised names.
    """
    hf_to_meds_pairs = (
        ("train", train_split),
        ("validation", tuning_split),
        ("test", held_out_split),
    )
    for hf_name, meds_name in hf_to_meds_pairs:
        if split == hf_name:
            return meds_name
    raise ValueError(f"Unknown split: {split}")
|
54
|
+
|
55
|
+
|
56
|
+
def seed_all(seed: int = 42):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Args:
        seed: Value applied to every generator and exported as ``PYTHONHASHSEED``.
    """
    random.seed(seed)
    np.random.seed(seed)

    # CPU generator first, then the current CUDA device and every CUDA device;
    # the CUDA seeding calls are safe when CUDA is unavailable.
    torch.manual_seed(seed)
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)

    # PYTHONHASHSEED is only read at interpreter start-up. It therefore affects
    # Python subprocesses launched after this point, not the current process.
    os.environ["PYTHONHASHSEED"] = str(seed)
|
66
|
+
|
67
|
+
|
68
|
+
def _posix_to_naive_utc(posix_time: float) -> datetime.datetime:
    """Return the naive (tz-less) UTC datetime for ``posix_time`` seconds."""
    # Same result as the deprecated ``datetime.datetime.utcfromtimestamp``.
    return datetime.datetime.fromtimestamp(
        posix_time, tz=datetime.timezone.utc
    ).replace(tzinfo=None)


def generate_trajectories_per_batch(
    batch: Dict[str, Any],
    cehrgpt_tokenizer: CehrGptTokenizer,
    cehrgpt_model: CEHRGPT2LMHeadModel,
    device,
    data_output_path: Path,
    max_length: int,
):
    """Sample one continuation per patient in ``batch`` and write it to parquet.

    Each prompt (``batch["input_ids"]`` plus ``values``/``value_indicators``) is
    extended by ``generate_single_batch`` up to ``max_length`` tokens. Event times
    for the generated tokens are rebuilt as follows:
    - start from the last prompt epoch time;
    - advance by each ATT token (days) or inpatient hour token (hours);
    - give every remaining clinical token the current cursor time.
    Padding, end and visit start/end tokens produce no event.

    Args:
        batch: Collated batch. Reads ``person_id``, ``index_date`` (POSIX seconds),
            ``epoch_times`` (POSIX seconds, assumed aligned with ``input_ids``),
            ``input_ids``, ``values`` and ``value_indicators``.
        cehrgpt_tokenizer: Supplies special tokens and value normalisation.
        cehrgpt_model: Model used for sampling.
        device: Device used for generation.
        data_output_path: Parquet file to write.
        max_length: Maximum total sequence length (prompt + generated tokens).

    The parquet file has one row per generated event, with the columns
    ``subject_id``, ``prediction_time``, ``window_last_observed_time``, ``time``,
    ``code``, ``numeric_value``, ``text_value`` and ``unit``.
    """
    # ``reshape(-1)`` rather than ``squeeze()``. A batch holding a single patient
    # would otherwise collapse to a 0-d tensor, and ``tolist()`` would return a
    # scalar that cannot be indexed per sample.
    subject_ids = batch["person_id"].detach().cpu().reshape(-1).tolist()
    prediction_times = batch["index_date"].detach().cpu().reshape(-1).tolist()
    batched_epoch_times = batch["epoch_times"].detach().cpu().tolist()

    batch_sequences = generate_single_batch(
        cehrgpt_model,
        cehrgpt_tokenizer,
        batch["input_ids"],
        values=batch["values"],
        value_indicators=batch["value_indicators"],
        max_length=max_length,
        top_p=1.0,
        top_k=cehrgpt_tokenizer.vocab_size,
        device=device,
    )
    # Release cached GPU memory left over from generation.
    torch.cuda.empty_cache()

    # Tokens that carry no event and no time shift.
    non_event_tokens = (cehrgpt_tokenizer.pad_token, cehrgpt_tokenizer.end_token)

    trajectories = []
    for sample_i, (generated_concepts, _value_indicators, generated_values) in enumerate(
        zip(
            batch_sequences["sequences"],
            batch_sequences["value_indicators"],
            batch_sequences["values"],
        )
    ):
        (
            concept_ids,
            _is_numeric_types,
            number_as_values,
            concept_as_values,
            units,
        ) = normalize_value(generated_concepts, generated_values, cehrgpt_tokenizer)

        epoch_times = batched_epoch_times[sample_i]
        input_length = len(epoch_times)
        # Assumes the collator left-pads, so the last prompt position holds the
        # most recent observed event time.
        window_last_observed = epoch_times[-1]
        current_cursor = window_last_observed

        generated_epoch_times = []
        valid_indices = []
        # Only positions after the prompt are newly generated.
        for i in range(input_length, len(concept_ids)):
            concept_id = concept_ids[i]
            if concept_id in non_event_tokens:
                continue
            if is_att_token(concept_id):
                current_cursor += extract_time_interval_in_days(concept_id) * 24 * 3600
            elif is_inpatient_hour_token(concept_id):
                current_cursor += extract_time_interval_in_hours(concept_id) * 3600
            elif is_visit_start(concept_id) or is_visit_end(concept_id):
                continue
            else:
                valid_indices.append(i)
                generated_epoch_times.append(_posix_to_naive_utc(current_cursor))

        trajectories.append(
            {
                "subject_id": subject_ids[sample_i],
                "prediction_time": _posix_to_naive_utc(prediction_times[sample_i]),
                "window_last_observed_time": _posix_to_naive_utc(
                    window_last_observed
                ),
                "times": generated_epoch_times,
                "concept_ids": np.asarray(concept_ids)[valid_indices].tolist(),
                "numeric_values": np.asarray(number_as_values)[valid_indices].tolist(),
                "text_value": np.asarray(concept_as_values)[valid_indices].tolist(),
                "units": np.asarray(units)[valid_indices].tolist(),
            }
        )

    # NOTE(review): polars explodes an empty list into a single null. A patient
    # with no generated events therefore still yields one row whose event
    # columns are null.
    trajectories = (
        pl.DataFrame(trajectories)
        .explode(["times", "concept_ids", "numeric_values", "text_value", "units"])
        .rename(
            {
                "times": "time",
                "concept_ids": "code",
                "numeric_values": "numeric_value",
                "units": "unit",
            }
        )
        .select(
            "subject_id",
            "prediction_time",
            "window_last_observed_time",
            "time",
            "code",
            "numeric_value",
            "text_value",
            "unit",
        )
    )
    trajectories.write_parquet(data_output_path)
|
181
|
+
|
182
|
+
|
183
|
+
def main():
    """Generate synthetic future trajectories for every sample in the cohort.

    Steps:
    1. Load the CEHR-GPT tokenizer and model.
    2. Prepare, or reload from disk, the tokenized cohort dataset.
    3. For each of ``cehrgpt_args.num_of_trajectories_per_sample`` rounds, write
       one parquet file per batch to
       ``<output_dir>/<meds_split>/<trajectory_index>/<batch_index>.parquet``.

    Batches whose parquet file already exists are skipped, so an interrupted run
    can be resumed.
    """
    cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
        model_args.tokenizer_name_or_path
    )
    cehrgpt_model = (
        CEHRGPT2LMHeadModel.from_pretrained(
            model_args.model_name_or_path,
            attn_implementation=(
                "flash_attention_2" if is_flash_attn_2_available() else "eager"
            ),
        )
        .eval()
        .to(device)
    )
    cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
    cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
    # NOTE(review): bos is set to the END token id (same as eos); confirm intended.
    cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id

    os.makedirs(training_args.output_dir, exist_ok=True)

    prepared_ds_path = generate_prepared_ds_path(
        data_args, model_args, data_folder=data_args.cohort_folder
    )

    processed_dataset = None
    if any(prepared_ds_path.glob("*")):
        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
        processed_dataset = load_from_disk(str(prepared_ds_path))
        LOG.info("Prepared dataset loaded from disk...")
        if cehrgpt_args.expand_tokenizer:
            if tokenizer_exists(training_args.output_dir):
                cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
                    training_args.output_dir
                )
            else:
                LOG.warning(
                    f"CehrGptTokenizer must exist in {training_args.output_dir} "
                    f"when the dataset has been processed and expand_tokenizer is set to True. "
                    f"Please delete the processed dataset at {prepared_ds_path}."
                )
                processed_dataset = None
                shutil.rmtree(prepared_ds_path)

    if processed_dataset is None and is_main_process(training_args.local_rank):
        # If the full dataset has already been tokenized, slice out each cohort
        # sample's tokenized sequence instead of re-tokenizing the cohort subset.
        if cehrgpt_args.tokenized_full_dataset_path is not None:
            processed_dataset = extract_cohort_sequences(data_args, cehrgpt_args)
        else:
            # Organize them into a single DatasetDict
            final_splits = prepare_finetune_dataset(
                data_args, training_args, cehrgpt_args
            )
            # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
            if not data_args.streaming:
                all_columns = final_splits["train"].column_names
                if "visit_concept_ids" in all_columns:
                    final_splits = final_splits.remove_columns(["visit_concept_ids"])

            processed_dataset = create_cehrgpt_finetuning_dataset(
                dataset=final_splits,
                cehrgpt_tokenizer=cehrgpt_tokenizer,
                data_args=data_args,
            )
        if not data_args.streaming:
            processed_dataset.save_to_disk(prepared_ds_path)
            processed_dataset.cleanup_cache_files()

    # After main-process-only operations, synchronize all processes to ensure consistency
    if dist.is_available() and dist.is_initialized():
        dist.barrier()

    # Only the main process builds the dataset above. Other ranks still hold None
    # here and must load the copy the main process saved to disk.
    if processed_dataset is None:
        processed_dataset = load_from_disk(str(prepared_ds_path))

    # Fine-tuning collator: the additional pre-training objectives are disabled.
    data_collator = CehrGptDataCollator(
        tokenizer=cehrgpt_tokenizer,
        max_length=cehrgpt_args.generation_input_length,
        include_values=cehrgpt_model.config.include_values,
        pretraining=False,
        include_ttv_prediction=False,
        use_sub_time_tokenization=False,
        include_demographics=False,
        add_linear_prob_token=False,
    )
    # Total length budget for generation: prompt plus new tokens.
    generation_max_length = (
        cehrgpt_args.generation_max_new_tokens + cehrgpt_args.generation_input_length
    )

    LOG.info(
        "Generating %s trajectories per sample",
        cehrgpt_args.num_of_trajectories_per_sample,
    )
    # NOTE(review): no distributed sampler is used, so every rank iterates and
    # writes every batch; confirm this script is meant to run single-process.
    for sample_i in range(cehrgpt_args.num_of_trajectories_per_sample):
        for split, dataset in processed_dataset.items():
            meds_split = map_data_split_name(split)
            dataloader = DataLoader(
                dataset=dataset,
                batch_size=training_args.per_device_eval_batch_size,
                num_workers=training_args.dataloader_num_workers,
                collate_fn=data_collator,
                pin_memory=training_args.dataloader_pin_memory,
            )
            sample_output_dir = (
                Path(training_args.output_dir) / meds_split / f"{sample_i}"
            )
            sample_output_dir.mkdir(exist_ok=True, parents=True)
            for batch_i, batch in tqdm(
                enumerate(dataloader),
                desc=f"Generating Trajectories for split {meds_split} with trajectory {sample_i + 1}",
            ):
                output_parquet_file = sample_output_dir / f"{batch_i}.parquet"
                if output_parquet_file.exists():
                    LOG.info("%s already exists, skip...", output_parquet_file)
                    continue

                generate_trajectories_per_batch(
                    batch,
                    cehrgpt_tokenizer,
                    cehrgpt_model,
                    device,
                    output_parquet_file,
                    generation_max_length,
                )
|
309
|
+
|
310
|
+
|
311
|
+
if __name__ == "__main__":
    # Seed every RNG before main() runs so data loading and sampling are reproducible.
    seed_all(42)
    main()
|
@@ -74,7 +74,10 @@ def generate_single_batch(
|
|
74
74
|
model: CEHRGPT2LMHeadModel,
|
75
75
|
tokenizer: CehrGptTokenizer,
|
76
76
|
prompts: List[List[int]],
|
77
|
-
|
77
|
+
max_length: int,
|
78
|
+
values: Optional[torch.Tensor] = None,
|
79
|
+
value_indicators: Optional[torch.Tensor] = None,
|
80
|
+
max_new_tokens: Optional[int] = None,
|
78
81
|
mini_num_of_concepts=1,
|
79
82
|
top_p=0.95,
|
80
83
|
top_k=50,
|
@@ -88,7 +91,8 @@ def generate_single_batch(
|
|
88
91
|
with torch.no_grad():
|
89
92
|
generation_config = GenerationConfig(
|
90
93
|
repetition_penalty=repetition_penalty,
|
91
|
-
|
94
|
+
max_new_tokens=max_new_tokens,
|
95
|
+
max_length=max_length,
|
92
96
|
min_length=mini_num_of_concepts,
|
93
97
|
temperature=temperature,
|
94
98
|
top_p=top_p,
|
@@ -107,9 +111,17 @@ def generate_single_batch(
|
|
107
111
|
num_beam_groups=num_beam_groups,
|
108
112
|
epsilon_cutoff=epsilon_cutoff,
|
109
113
|
)
|
114
|
+
|
110
115
|
batched_prompts = torch.tensor(prompts).to(device)
|
116
|
+
if values is not None:
|
117
|
+
values = values.to(device)
|
118
|
+
if value_indicators is not None:
|
119
|
+
value_indicators = value_indicators.to(device)
|
120
|
+
|
111
121
|
results = model.generate(
|
112
122
|
inputs=batched_prompts,
|
123
|
+
values=values,
|
124
|
+
value_indicators=value_indicators,
|
113
125
|
generation_config=generation_config,
|
114
126
|
lab_token_ids=tokenizer.lab_token_ids,
|
115
127
|
)
|
@@ -226,7 +238,7 @@ def main(args):
|
|
226
238
|
cehrgpt_model,
|
227
239
|
cehrgpt_tokenizer,
|
228
240
|
random_prompts[: args.batch_size],
|
229
|
-
|
241
|
+
max_length=args.context_window,
|
230
242
|
mini_num_of_concepts=args.min_num_of_concepts,
|
231
243
|
top_p=args.top_p,
|
232
244
|
top_k=args.top_k,
|
cehrgpt/models/hf_cehrgpt.py
CHANGED
@@ -102,7 +102,9 @@ def is_sample_pack(attention_mask: torch.Tensor) -> bool:
|
|
102
102
|
attention_mask = attention_mask.flip(dims=[1])
|
103
103
|
|
104
104
|
nonzero_counts = attention_mask.sum(dim=1)
|
105
|
-
max_token_positions = torch.argmax(
|
105
|
+
max_token_positions = torch.argmax(
|
106
|
+
attention_mask.to(torch.int32).flip(dims=[1]), dim=1
|
107
|
+
)
|
106
108
|
max_indices = attention_mask.shape[1] - 1 - max_token_positions
|
107
109
|
return torch.any(nonzero_counts < (max_indices + 1)).item()
|
108
110
|
|
@@ -1848,6 +1850,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
|
|
1848
1850
|
|
1849
1851
|
# keep track of which sequences are already finished
|
1850
1852
|
batch_size, cur_len = input_ids.shape
|
1853
|
+
model_kwargs["attention_mask"] = input_ids != pad_token_id
|
1851
1854
|
if "inputs_embeds" in model_kwargs:
|
1852
1855
|
cur_len = model_kwargs["inputs_embeds"].shape[1]
|
1853
1856
|
this_peer_finished = False
|
@@ -1866,11 +1869,19 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
|
|
1866
1869
|
[] if self.config.lab_token_ids is None else self.config.lab_token_ids,
|
1867
1870
|
dtype=torch.int32,
|
1868
1871
|
)
|
1869
|
-
|
1870
|
-
|
1871
|
-
|
1872
|
-
|
1873
|
-
|
1872
|
+
|
1873
|
+
if model_kwargs.get("value_indicators", None) is not None:
|
1874
|
+
value_indicators = model_kwargs.get("value_indicators")
|
1875
|
+
else:
|
1876
|
+
value_indicators = torch.zeros_like(input_ids).to(torch.bool)
|
1877
|
+
|
1878
|
+
if model_kwargs.get("values", None) is not None:
|
1879
|
+
values = model_kwargs.get("values")
|
1880
|
+
else:
|
1881
|
+
values = torch.zeros_like(
|
1882
|
+
input_ids,
|
1883
|
+
dtype=torch.int32,
|
1884
|
+
)
|
1874
1885
|
# Generate initial random_vectors
|
1875
1886
|
if self.cehrgpt.config.causal_sfm:
|
1876
1887
|
model_kwargs["random_vectors"] = torch.rand(
|
cehrgpt/runners/data_utils.py
CHANGED
@@ -47,7 +47,7 @@ def prepare_finetune_dataset(
|
|
47
47
|
data_args: DataTrainingArguments,
|
48
48
|
training_args: TrainingArguments,
|
49
49
|
cehrgpt_args: CehrGPTArguments,
|
50
|
-
cache_file_collector: CacheFileCollector,
|
50
|
+
cache_file_collector: Optional[CacheFileCollector] = None,
|
51
51
|
) -> DatasetDict:
|
52
52
|
# If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
|
53
53
|
if data_args.is_data_in_meds:
|
@@ -91,8 +91,9 @@ def prepare_finetune_dataset(
|
|
91
91
|
"Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
|
92
92
|
stats,
|
93
93
|
)
|
94
|
-
|
95
|
-
|
94
|
+
if cache_file_collector:
|
95
|
+
# Clean up the files created from the data generator
|
96
|
+
cache_file_collector.remove_cache_files()
|
96
97
|
dataset = load_from_disk(str(meds_extension_path))
|
97
98
|
|
98
99
|
train_set = dataset["train"]
|
@@ -271,7 +272,7 @@ def create_dataset_splits(data_args: DataTrainingArguments, seed: int):
|
|
271
272
|
def extract_cohort_sequences(
|
272
273
|
data_args: DataTrainingArguments,
|
273
274
|
cehrgpt_args: CehrGPTArguments,
|
274
|
-
cache_file_collector: CacheFileCollector,
|
275
|
+
cache_file_collector: Optional[CacheFileCollector] = None,
|
275
276
|
) -> DatasetDict:
|
276
277
|
"""
|
277
278
|
Extracts and processes cohort-specific tokenized sequences from a pre-tokenized dataset,.
|
@@ -309,9 +310,18 @@ def extract_cohort_sequences(
|
|
309
310
|
mapping={
|
310
311
|
"prediction_time": "index_date",
|
311
312
|
"subject_id": "person_id",
|
313
|
+
"boolean_value": "label",
|
312
314
|
}
|
313
315
|
)
|
314
316
|
all_person_ids = cohort["person_id"].unique().to_list()
|
317
|
+
# In case the label column does not exist, we add a fake column to the dataframe so subsequent process can work
|
318
|
+
if "label" not in cohort.columns:
|
319
|
+
cohort = cohort.with_columns(
|
320
|
+
pl.Series(
|
321
|
+
name="label", values=np.zeros_like(cohort["person_id"].to_numpy())
|
322
|
+
)
|
323
|
+
)
|
324
|
+
|
315
325
|
# data_args.observation_window
|
316
326
|
tokenized_dataset = load_from_disk(cehrgpt_args.tokenized_full_dataset_path)
|
317
327
|
filtered_tokenized_dataset = tokenized_dataset.filter(
|
@@ -353,6 +363,7 @@ def extract_cohort_sequences(
|
|
353
363
|
num_proc=data_args.preprocessing_num_workers,
|
354
364
|
remove_columns=filtered_tokenized_dataset["train"].column_names,
|
355
365
|
)
|
356
|
-
cache_file_collector
|
357
|
-
|
366
|
+
if cache_file_collector:
|
367
|
+
cache_file_collector.add_cache_files(filtered_tokenized_dataset)
|
368
|
+
cache_file_collector.add_cache_files(processed_dataset)
|
358
369
|
return processed_dataset
|
@@ -580,7 +580,15 @@ def do_predict(
|
|
580
580
|
index_dates = batch.pop("index_date").numpy().squeeze()
|
581
581
|
if index_dates.ndim == 0:
|
582
582
|
index_dates = np.asarray([index_dates])
|
583
|
-
|
583
|
+
|
584
|
+
index_dates = list(
|
585
|
+
map(
|
586
|
+
lambda posix_time: datetime.utcfromtimestamp(posix_time).replace(
|
587
|
+
tzinfo=None
|
588
|
+
),
|
589
|
+
index_dates.tolist(),
|
590
|
+
)
|
591
|
+
)
|
584
592
|
|
585
593
|
batch = {k: v.to(device) for k, v in batch.items()}
|
586
594
|
# Forward pass
|
@@ -229,3 +229,15 @@ class CehrGPTArguments:
|
|
229
229
|
"help": "The probability of negative samples will be included in the training data"
|
230
230
|
},
|
231
231
|
)
|
232
|
+
num_of_trajectories_per_sample: Optional[int] = dataclasses.field(
|
233
|
+
default=1,
|
234
|
+
metadata={"help": "The number of trajectories per sample"},
|
235
|
+
)
|
236
|
+
generation_input_length: Optional[int] = dataclasses.field(
|
237
|
+
default=1024,
|
238
|
+
metadata={"help": "The length of the input sequence"},
|
239
|
+
)
|
240
|
+
generation_max_new_tokens: Optional[int] = dataclasses.field(
|
241
|
+
default=1024,
|
242
|
+
metadata={"help": "The maximum number of tokens in the generation sequence"},
|
243
|
+
)
|
@@ -1,15 +1,14 @@
|
|
1
|
+
import datetime
|
1
2
|
import glob
|
2
3
|
import os
|
3
4
|
import shutil
|
4
5
|
import uuid
|
5
|
-
from datetime import datetime
|
6
6
|
from functools import partial
|
7
7
|
from pathlib import Path
|
8
8
|
from typing import Optional, Union
|
9
9
|
|
10
10
|
import numpy as np
|
11
11
|
import pandas as pd
|
12
|
-
import polars as pl
|
13
12
|
import torch
|
14
13
|
import torch.distributed as dist
|
15
14
|
from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
|
@@ -25,7 +24,6 @@ from cehrgpt.data.hf_cehrgpt_dataset_collator import (
|
|
25
24
|
CehrGptDataCollator,
|
26
25
|
SamplePackingCehrGptDataCollator,
|
27
26
|
)
|
28
|
-
from cehrgpt.data.hf_cehrgpt_dataset_mapping import ExtractTokenizedSequenceDataMapping
|
29
27
|
from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler
|
30
28
|
from cehrgpt.models.hf_cehrgpt import (
|
31
29
|
CEHRGPT2Model,
|
@@ -159,24 +157,7 @@ def main():
|
|
159
157
|
final_splits = prepare_finetune_dataset(
|
160
158
|
data_args, training_args, cehrgpt_args, cache_file_collector
|
161
159
|
)
|
162
|
-
|
163
|
-
new_tokenizer_path = os.path.expanduser(training_args.output_dir)
|
164
|
-
if tokenizer_exists(new_tokenizer_path):
|
165
|
-
cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
|
166
|
-
new_tokenizer_path
|
167
|
-
)
|
168
|
-
else:
|
169
|
-
cehrgpt_tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
|
170
|
-
cehrgpt_tokenizer=cehrgpt_tokenizer,
|
171
|
-
dataset=final_splits["train"],
|
172
|
-
data_args=data_args,
|
173
|
-
concept_name_mapping={},
|
174
|
-
)
|
175
|
-
cehrgpt_tokenizer.save_pretrained(
|
176
|
-
os.path.expanduser(training_args.output_dir)
|
177
|
-
)
|
178
|
-
|
179
|
-
# TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
|
160
|
+
# TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
|
180
161
|
if not data_args.streaming:
|
181
162
|
all_columns = final_splits["train"].column_names
|
182
163
|
if "visit_concept_ids" in all_columns:
|
@@ -238,10 +219,6 @@ def main():
|
|
238
219
|
len(processed_dataset["test"]),
|
239
220
|
)
|
240
221
|
|
241
|
-
LOG.info(f"cehrgpt_model.config.vocab_size: {cehrgpt_model.config.vocab_size}")
|
242
|
-
LOG.info(f"cehrgpt_tokenizer.vocab_size: {cehrgpt_tokenizer.vocab_size}")
|
243
|
-
if cehrgpt_model.config.vocab_size < cehrgpt_tokenizer.vocab_size:
|
244
|
-
cehrgpt_model.resize_token_embeddings(cehrgpt_tokenizer.vocab_size)
|
245
222
|
if (
|
246
223
|
cehrgpt_model.config.max_position_embeddings
|
247
224
|
< model_args.max_position_embeddings
|
@@ -339,10 +316,12 @@ def main():
|
|
339
316
|
for data_dir in [data_args.data_folder, data_args.test_data_folder]
|
340
317
|
]
|
341
318
|
)
|
342
|
-
|
343
|
-
demographics_df["index_date"] =
|
344
|
-
demographics_df["index_date"]
|
345
|
-
|
319
|
+
|
320
|
+
demographics_df["index_date"] = (
|
321
|
+
demographics_df["index_date"].dt.tz_localize("UTC")
|
322
|
+
- datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
|
323
|
+
).dt.total_seconds()
|
324
|
+
|
346
325
|
demographics_dict = {
|
347
326
|
(row["person_id"], row["index_date"]): {
|
348
327
|
"gender_concept_id": row["gender_concept_id"],
|
@@ -379,9 +358,16 @@ def main():
|
|
379
358
|
prediction_time_posix = batch.pop("index_date").numpy().squeeze()
|
380
359
|
if prediction_time_posix.ndim == 0:
|
381
360
|
prediction_time_posix = np.asarray([prediction_time_posix])
|
361
|
+
|
382
362
|
prediction_time = list(
|
383
|
-
map(
|
363
|
+
map(
|
364
|
+
lambda posix_time: datetime.datetime.utcfromtimestamp(
|
365
|
+
posix_time
|
366
|
+
).replace(tzinfo=None),
|
367
|
+
prediction_time_posix,
|
368
|
+
)
|
384
369
|
)
|
370
|
+
|
385
371
|
labels = (
|
386
372
|
batch.pop("classifier_label")
|
387
373
|
.float()
|
@@ -393,6 +379,10 @@ def main():
|
|
393
379
|
if labels.ndim == 0:
|
394
380
|
labels = np.asarray([labels])
|
395
381
|
|
382
|
+
# Right now the model does not support this column, we need to pop it
|
383
|
+
if "epoch_times" in batch:
|
384
|
+
batch.pop("epoch_times")
|
385
|
+
|
396
386
|
batch = {k: v.to(device) for k, v in batch.items()}
|
397
387
|
# Forward pass
|
398
388
|
cehrgpt_output = cehrgpt_model(
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: cehrgpt
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.2
|
4
4
|
Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
|
5
5
|
Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
|
6
6
|
License: MIT License
|
@@ -105,6 +105,100 @@ sh scripts/omop_pipeline.sh \
|
|
105
105
|
$OMOP_VOCAB_DIR
|
106
106
|
```
|
107
107
|
|
108
|
+
# MEDS Support
|
109
|
+
|
110
|
+
This section demonstrates how to pretrain CEHR-GPT using MIMIC-IV data in the MEDS (Medical Event Data Standard) format.
|
111
|
+
|
112
|
+
## Prerequisites
|
113
|
+
|
114
|
+
Set up the required environment variables before beginning:
|
115
|
+
|
116
|
+
```bash
|
117
|
+
export CEHR_GPT_MODEL_DIR="" # Path to CEHR-GPT model directory
|
118
|
+
export MEDS_DIR="" # Path to MEDS data directory
|
119
|
+
export MEDS_READER_DIR="" # Path to MEDS reader output directory
|
120
|
+
```
|
121
|
+
|
122
|
+
## Step 1: Create MIMIC MEDS Data
|
123
|
+
|
124
|
+
Transform your MIMIC files into MEDS format by following the instructions in the [MEDS_transforms](https://github.com/mmcdermott/MEDS_transforms/) repository.
|
125
|
+
|
126
|
+
## Step 2: Create the MEDS Reader
|
127
|
+
|
128
|
+
Convert the MEDS data for use with CEHR-GPT:
|
129
|
+
|
130
|
+
```bash
|
131
|
+
meds_reader_convert $MEDS_DIR $MEDS_READER_DIR --num_threads 10
|
132
|
+
```
|
133
|
+
|
134
|
+
## Step 3: Pretrain CEHR-GPT
|
135
|
+
|
136
|
+
Run the pretraining process using the prepared MEDS data:
|
137
|
+
|
138
|
+
```bash
|
139
|
+
python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
|
140
|
+
--model_name_or_path $CEHR_GPT_MODEL_DIR \
|
141
|
+
--tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
|
142
|
+
--output_dir $CEHR_GPT_MODEL_DIR \
|
143
|
+
--data_folder $MEDS_READER_DIR \
|
144
|
+
--dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
|
145
|
+
--do_train true --seed 42 \
|
146
|
+
--dataloader_num_workers 16 --dataloader_prefetch_factor 8 \
|
147
|
+
--hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 8192 \
|
148
|
+
--evaluation_strategy epoch --save_strategy epoch \
|
149
|
+
--sample_packing --max_tokens_per_batch 16384 \
|
150
|
+
--warmup_steps 500 --weight_decay 0.01 \
|
151
|
+
--num_train_epochs 50 --learning_rate 0.0002 \
|
152
|
+
--use_early_stopping --early_stopping_threshold 0.001 \
|
153
|
+
--is_data_in_meds --inpatient_att_function_type day \
|
154
|
+
--att_function_type day --include_inpatient_hour_token \
|
155
|
+
--include_auxiliary_token --include_demographic_prompt \
|
156
|
+
--meds_to_cehrbert_conversion_type "MedsToBertMimic4"
|
157
|
+
```
|
158
|
+
|
159
|
+
## Step 4: Generate MEDS Trajectories
|
160
|
+
|
161
|
+
### Environment Setup for Trajectory Generation
|
162
|
+
|
163
|
+
Configure additional environment variables for trajectory generation with task labels (`subject_id`, `prediction_time`, `boolean_value` [optional]):
|
164
|
+
|
165
|
+
```bash
|
166
|
+
# MEDS_LABEL_COHORT_DIR must contain a set of parquet files
|
167
|
+
export MEDS_LABEL_COHORT_DIR="" # Path to cohort labels directory
|
168
|
+
export MEDS_TRAJECTORY_DIR="" # Path for trajectory output
|
169
|
+
```
|
170
|
+
|
171
|
+
### Generate Trajectories
|
172
|
+
|
173
|
+
Create synthetic patient trajectories using the trained model:
|
174
|
+
|
175
|
+
> **Important:** The total sequence length (`generation_input_length` + `generation_max_new_tokens`) cannot exceed the `max_position_embeddings` value (8192) defined during pretraining.
|
176
|
+
|
177
|
+
```bash
|
178
|
+
python -u -m cehrgpt.generation.cehrgpt_conditional_generation \
|
179
|
+
--cohort_folder $MEDS_LABEL_COHORT_DIR \
|
180
|
+
--data_folder $MEDS_READER_DIR \
|
181
|
+
--dataset_prepared_path "$CEHR_GPT_MODEL_DIR/dataset_prepared" \
|
182
|
+
--model_name_or_path $CEHR_GPT_MODEL_DIR \
|
183
|
+
--tokenizer_name_or_path $CEHR_GPT_MODEL_DIR \
|
184
|
+
--output_dir $MEDS_TRAJECTORY_DIR \
|
185
|
+
--per_device_eval_batch_size 16 \
|
186
|
+
--num_of_trajectories_per_sample 2 \
|
187
|
+
--generation_input_length 4096 \
|
188
|
+
--generation_max_new_tokens 4096 \
|
189
|
+
--is_data_in_meds \
|
190
|
+
--att_function_type day --inpatient_att_function_type day \
|
191
|
+
--meds_to_cehrbert_conversion_type MedsToBertMimic4 \
|
192
|
+
--include_auxiliary_token --include_demographic_prompt \
|
193
|
+
--include_inpatient_hour_token
|
194
|
+
```
|
195
|
+
|
196
|
+
### Parameters Explanation
|
197
|
+
|
198
|
+
- `generation_input_length`: Controls the length of input context for generation
|
199
|
+
- `generation_max_new_tokens`: Maximum number of new tokens to generate
|
200
|
+
- `num_of_trajectories_per_sample`: Number of trajectories to generate per patient sample
|
201
|
+
|
108
202
|
## Citation
|
109
203
|
```
|
110
204
|
@article{cehrgpt2024,
|
@@ -13,17 +13,18 @@ cehrgpt/analysis/privacy/reid_inference.py,sha256=Pypd3QJXQNY8VljpnIEa5zeAbTZHMj
|
|
13
13
|
cehrgpt/analysis/privacy/utils.py,sha256=CRA4H9mPLBjMQGKzZ_x_3ro3tMap-NjsMDVqSOjHSVQ,8226
|
14
14
|
cehrgpt/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
15
15
|
cehrgpt/data/hf_cehrgpt_dataset.py,sha256=hwJlGW7XiJIr6cXtmwvReQf9yLZJPD-dvJGvRg5ERqU,3755
|
16
|
-
cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=
|
17
|
-
cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=
|
16
|
+
cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=juM5HeZScgj8w15Bl1qC83Swld4gY6avh0QkSWLqITA,45465
|
17
|
+
cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=_QDX9NXfmQ_S3kOf3yndb3AhoEeFiSzAOv836uYW0AY,26230
|
18
18
|
cehrgpt/data/sample_packing_sampler.py,sha256=vovGMtmhG70DRkSCeiaDEJ_rjKZ38y-YLaI1kkhFEkI,6747
|
19
19
|
cehrgpt/generation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
20
|
+
cehrgpt/generation/cehrgpt_conditional_generation.py,sha256=AM76yaPyw1B-bcdei24HO0uspGZWHGKWpYpHywotTIQ,11972
|
20
21
|
cehrgpt/generation/chatgpt_generation.py,sha256=SrnLwHLdNtnAOEg36gNjqfoT9yd12iyPgpZffL2AFJo,4428
|
21
|
-
cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256=
|
22
|
+
cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256=P8al4-zqymqEkCHCCu2sqz_45akcKF2o_AtQIjJdVmQ,11919
|
22
23
|
cehrgpt/generation/omop_converter_batch.py,sha256=LUmCD-t_6ZP1YfNDZCqYewl-XIIaIgRZ_dAxuR_VdCQ,26275
|
23
24
|
cehrgpt/generation/omop_entity.py,sha256=Q5Sr0AlyuPAm1FRPfnJO13q-u1fqRgYVHXruZ9g4xNE,19400
|
24
25
|
cehrgpt/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
25
26
|
cehrgpt/models/config.py,sha256=nOAKgH5420HLCcy7n1hE7MbqR861Iq4DTutKoAd25tg,11090
|
26
|
-
cehrgpt/models/hf_cehrgpt.py,sha256=
|
27
|
+
cehrgpt/models/hf_cehrgpt.py,sha256=3P7bOLDr7NMSedGszhmlJJN4Mhpd_65-x6uzwvSjigE,92837
|
27
28
|
cehrgpt/models/hf_modeling_outputs.py,sha256=5X4WEYKqT37phv_e5ZAv3A_N0wqdAUJLJRm6TxS6dDQ,10356
|
28
29
|
cehrgpt/models/pretrained_embeddings.py,sha256=vLLVs17TLpXRqCVEWQxGGwPHkUJUO7laNTeBuyBK_yk,3238
|
29
30
|
cehrgpt/models/special_tokens.py,sha256=lrw45B4tea4Dsajn09Cz6w5D2TfHmYXikZkgwnstu_o,521
|
@@ -38,11 +39,11 @@ cehrgpt/omop/queries/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hS
|
|
38
39
|
cehrgpt/omop/queries/condition_era.py,sha256=LFB6vBAvshHJxtYIRkl7cfrF0kf7ay0piBKpmHBwrpE,2578
|
39
40
|
cehrgpt/omop/queries/observation_period.py,sha256=fpzr5DMNw-QLoSwp2Iatfch88E3hyhZ75usiIdG3A0U,6410
|
40
41
|
cehrgpt/runners/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
41
|
-
cehrgpt/runners/data_utils.py,sha256=
|
42
|
+
cehrgpt/runners/data_utils.py,sha256=i-krtBx_6rvPYtdLdDoWwOTtJcaovd0wH8gBYmgN2l4,16013
|
42
43
|
cehrgpt/runners/gpt_runner_util.py,sha256=YJQSRW9Mo4TjXSOUOTf6BUFcs1MGFiXU5T4ztKZcYhU,3485
|
43
|
-
cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=
|
44
|
+
cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=1OgxLm4T7iHv5pKi2QaSdaz9ogWo2n3sSUGp6cHDF9s,28309
|
44
45
|
cehrgpt/runners/hf_cehrgpt_pretrain_runner.py,sha256=ERSnvB38fPYVghtKQeNTZ8VfeXnoRcCHB0cWISWaZ84,26523
|
45
|
-
cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=
|
46
|
+
cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=fJR4RHPqal1YI6_KUH-WlkoQLSZuBT5bKUGfPHDFrWI,9350
|
46
47
|
cehrgpt/runners/hyperparameter_search_util.py,sha256=YWdFQ1igQs-G_wqWUrUzYraGiz8OSpSYyvid-I5nhWA,9262
|
47
48
|
cehrgpt/runners/sample_packing_trainer.py,sha256=Zb7Aqwnk8-VqrjEKUVeg5XzZWmHxXOU2sDn1YURS-FU,7960
|
48
49
|
cehrgpt/simulations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -65,10 +66,10 @@ cehrgpt/tools/generate_pretrained_embeddings.py,sha256=lhFSacGv8bMld6qigKZN8Op8e
|
|
65
66
|
cehrgpt/tools/merge_synthetic_real_dataasets.py,sha256=O1dbQ32Le0t15fwymwAh9mfNVLEWuFwW53DNvESrWbY,7589
|
66
67
|
cehrgpt/tools/upload_omop_tables.py,sha256=vdBAbkeAsGPA4NsyhNjelPVj3gS8yzmS1sKNM1Qk96g,3791
|
67
68
|
cehrgpt/tools/linear_prob/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
68
|
-
cehrgpt/tools/linear_prob/compute_cehrgpt_features.py,sha256=
|
69
|
+
cehrgpt/tools/linear_prob/compute_cehrgpt_features.py,sha256=Hpx7WvAWm2WwPHFfimCADXh019I7bwdzJ4_5_YCxQzU,19817
|
69
70
|
cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py,sha256=w0UvzMKYGenN_KDVnbzutmy8IPLUxW5hPvpKKxDSL5U,5820
|
70
|
-
cehrgpt-0.1.
|
71
|
-
cehrgpt-0.1.
|
72
|
-
cehrgpt-0.1.
|
73
|
-
cehrgpt-0.1.
|
74
|
-
cehrgpt-0.1.
|
71
|
+
cehrgpt-0.1.2.dist-info/licenses/LICENSE,sha256=LOfC32zkfUIdGm8e_098jPbt8OHKtNWymDzxn2pA9Zk,1093
|
72
|
+
cehrgpt-0.1.2.dist-info/METADATA,sha256=D7gGKrQThiLivViFeNm711NCP8J-wXfkueMGb6RKqV0,8481
|
73
|
+
cehrgpt-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
74
|
+
cehrgpt-0.1.2.dist-info/top_level.txt,sha256=akNCJBbMSLV8nkOzdVzdy13hMJ5CIQURnAS_YYEDVwA,17
|
75
|
+
cehrgpt-0.1.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|