cehrgpt 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- __init__.py +0 -0
- cehrgpt/__init__.py +0 -0
- cehrgpt/analysis/__init__.py +0 -0
- cehrgpt/analysis/privacy/__init__.py +0 -0
- cehrgpt/analysis/privacy/attribute_inference.py +275 -0
- cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
- cehrgpt/analysis/privacy/member_inference.py +172 -0
- cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
- cehrgpt/analysis/privacy/reid_inference.py +407 -0
- cehrgpt/analysis/privacy/utils.py +255 -0
- cehrgpt/cehrgpt_args.py +142 -0
- cehrgpt/data/__init__.py +0 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
- cehrgpt/generation/__init__.py +0 -0
- cehrgpt/generation/chatgpt_generation.py +106 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
- cehrgpt/generation/omop_converter_batch.py +644 -0
- cehrgpt/generation/omop_entity.py +515 -0
- cehrgpt/gpt_utils.py +331 -0
- cehrgpt/models/__init__.py +0 -0
- cehrgpt/models/config.py +205 -0
- cehrgpt/models/hf_cehrgpt.py +1817 -0
- cehrgpt/models/hf_modeling_outputs.py +158 -0
- cehrgpt/models/pretrained_embeddings.py +82 -0
- cehrgpt/models/special_tokens.py +30 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
- cehrgpt/omop/__init__.py +0 -0
- cehrgpt/omop/condition_era.py +20 -0
- cehrgpt/omop/observation_period.py +43 -0
- cehrgpt/omop/omop_argparse.py +38 -0
- cehrgpt/omop/omop_table_builder.py +86 -0
- cehrgpt/omop/queries/__init__.py +0 -0
- cehrgpt/omop/queries/condition_era.py +86 -0
- cehrgpt/omop/queries/observation_period.py +135 -0
- cehrgpt/omop/sample_omop_tables.py +71 -0
- cehrgpt/runners/__init__.py +0 -0
- cehrgpt/runners/gpt_runner_util.py +99 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
- cehrgpt/runners/hyperparameter_search_util.py +223 -0
- cehrgpt/time_to_event/__init__.py +0 -0
- cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
- cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
- cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
- cehrgpt/time_to_event/time_to_event_model.py +226 -0
- cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
- cehrgpt/time_to_event/time_to_event_utils.py +55 -0
- cehrgpt/tools/__init__.py +0 -0
- cehrgpt/tools/ehrshot_benchmark.py +74 -0
- cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
- cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
- cehrgpt/tools/upload_omop_tables.py +108 -0
- cehrgpt-0.0.1.dist-info/LICENSE +21 -0
- cehrgpt-0.0.1.dist-info/METADATA +66 -0
- cehrgpt-0.0.1.dist-info/RECORD +60 -0
- cehrgpt-0.0.1.dist-info/WHEEL +5 -0
- cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
cehrgpt/time_to_event/time_to_event_prediction.py
@@ -0,0 +1,347 @@
import datetime
import glob
import os
import shutil
import uuid
from dataclasses import asdict, dataclass
from typing import Any, Dict, List

import pandas as pd
import torch
import yaml
from cehrbert.runners.runner_util import load_parquet_as_dataset
from datasets import Dataset
from tqdm import tqdm
from transformers.utils import is_flash_attn_2_available, logging

from cehrgpt.cehrgpt_args import create_inference_base_arg_parser
from cehrgpt.gpt_utils import get_cehrgpt_output_folder, is_visit_end, is_visit_start
from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
from cehrgpt.time_to_event.time_to_event_model import TimeToEventModel

LOG = logging.get_logger("transformers")


@dataclass
class TaskConfig:
    task_name: str
    outcome_events: List[str]
    include_descendants: bool = False
    future_visit_start: int = 0
    future_visit_end: int = -1
    prediction_window_start: int = 0
    prediction_window_end: int = 365
    max_new_tokens: int = 128


def load_task_config_from_yaml(task_config_yaml_file_path: str) -> TaskConfig:
    # Read YAML file
    try:
        with open(task_config_yaml_file_path, "r") as stream:
            task_definition = yaml.safe_load(stream)
            return TaskConfig(**task_definition)
    except (yaml.YAMLError, OSError) as e:
        raise ValueError(
            f"Could not open the task_config yaml file from {task_config_yaml_file_path}"
        ) from e


def get_device():
    return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def main(args):
    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_folder)
    cehrgpt_model = (
        CEHRGPT2LMHeadModel.from_pretrained(
            args.model_folder,
            attn_implementation=(
                "flash_attention_2" if is_flash_attn_2_available() else "eager"
            ),
            torch_dtype=(
                torch.bfloat16
                if is_flash_attn_2_available() and args.use_bfloat16
                else "auto"
            ),
        )
        .eval()
        .to(get_device())
    )
    cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
    cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
    cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id

    folder_name = get_cehrgpt_output_folder(args, cehrgpt_tokenizer)

    task_config = load_task_config_from_yaml(args.task_config)
    task_name = task_config.task_name
    outcome_events = task_config.outcome_events

    if task_config.include_descendants:
        if not args.concept_ancestor:
            raise RuntimeError(
                "When include_descendants is set to True, the concept_ancestor data needs to be provided."
            )
        concept_ancestor = pd.read_parquet(args.concept_ancestor)
        ancestor_concept_ids = [int(_) for _ in outcome_events if _.isnumeric()]
        descendant_concept_ids = (
            concept_ancestor[
                concept_ancestor.ancestor_concept_id.isin(ancestor_concept_ids)
            ]
            .descendant_concept_id.unique()
            .astype(str)
            .tolist()
        )
        descendant_concept_ids = [
            _ for _ in descendant_concept_ids if _ not in outcome_events
        ]
        outcome_events += descendant_concept_ids

    prediction_output_folder_name = os.path.join(
        args.output_folder, folder_name, task_name
    )
    temp_folder = os.path.join(args.output_folder, folder_name, "temp")
    os.makedirs(prediction_output_folder_name, exist_ok=True)
    os.makedirs(temp_folder, exist_ok=True)

    LOG.info(f"Loading tokenizer at {args.tokenizer_folder}")
    LOG.info(f"Loading model at {args.model_folder}")
    LOG.info(f"Loading dataset_folder at {args.dataset_folder}")
    LOG.info(f"Write time sensitive predictions to {prediction_output_folder_name}")
    LOG.info(f"Context window {args.context_window}")
    LOG.info(f"Number of new tokens {task_config.max_new_tokens}")
    LOG.info(f"Temperature {args.temperature}")
    LOG.info(f"Repetition Penalty {args.repetition_penalty}")
    LOG.info(f"Sampling Strategy {args.sampling_strategy}")
    LOG.info(f"Epsilon cutoff {args.epsilon_cutoff}")
    LOG.info(f"Top P {args.top_p}")
    LOG.info(f"Top K {args.top_k}")

    cehrgpt_model.resize_position_embeddings(
        cehrgpt_model.config.max_position_embeddings + task_config.max_new_tokens
    )

    generation_config = TimeToEventModel.get_generation_config(
        tokenizer=cehrgpt_tokenizer,
        max_length=cehrgpt_model.config.n_positions,
        num_return_sequences=args.num_return_sequences,
        top_p=args.top_p,
        top_k=args.top_k,
        temperature=args.temperature,
        repetition_penalty=args.repetition_penalty,
        epsilon_cutoff=args.epsilon_cutoff,
        max_new_tokens=task_config.max_new_tokens,
    )
    ts_pred_model = TimeToEventModel(
        tokenizer=cehrgpt_tokenizer,
        model=cehrgpt_model,
        outcome_events=outcome_events,
        generation_config=generation_config,
        batch_size=args.batch_size,
        device=get_device(),
    )
    dataset = load_parquet_as_dataset(args.dataset_folder)

    def filter_func(examples):
        return [_ >= args.min_num_of_concepts for _ in examples["num_of_concepts"]]

    test_dataset = dataset.filter(filter_func, batched=True, batch_size=1000)
    test_dataset = test_dataset.shuffle(seed=42)

    # Filter out the records for which the predictions have been generated previously
    test_dataset = filter_out_existing_results(
        test_dataset, prediction_output_folder_name
    )
    tte_outputs = []
    for record in tqdm(test_dataset, total=len(test_dataset)):
        sample_identifier = (
            f"{record['person_id']}_{record['index_date'].strftime('%Y_%m_%d')}"
        )
        if acquire_lock_or_skip_if_already_exist(
            output_folder=temp_folder, sample_id=sample_identifier
        ):
            continue
        partial_history = record["concept_ids"]
        label = record["label"]
        time_to_event = record["time_to_event"] if "time_to_event" in record else None
        seq_length = len(partial_history)
        if (
            generation_config.max_length
            <= seq_length + generation_config.max_new_tokens
        ):
            start_index = seq_length - (
                generation_config.max_length - generation_config.max_new_tokens
            )
            # Make sure the first token starts on VS
            for i, token in enumerate(partial_history[start_index:]):
                if is_visit_start(token):
                    start_index += i
                    break
            partial_history = partial_history[start_index:]

        concept_time_to_event = ts_pred_model.predict_time_to_events(
            partial_history,
            task_config.future_visit_start,
            task_config.future_visit_end,
            task_config.prediction_window_start,
            task_config.prediction_window_end,
            args.debug,
            args.max_n_trial,
        )
        visit_counter = sum([int(is_visit_end(_)) for _ in partial_history])
        tte_outputs.append(
            {
                "person_id": record["person_id"],
                "index_date": record["index_date"],
                "visit_counter": visit_counter,
                "label": label,
                "time_to_event": time_to_event,
                "prediction": (
                    asdict(concept_time_to_event) if concept_time_to_event else None
                ),
            }
        )
        delete_lock_create_processed_flag(
            output_folder=temp_folder, sample_id=sample_identifier
        )
        flush_to_disk_if_full(
            tte_outputs, prediction_output_folder_name, args.buffer_size
        )

    # Final flush
    flush_to_disk_if_full(tte_outputs, prediction_output_folder_name, args.buffer_size)
    # Remove the temp folder
    shutil.rmtree(temp_folder)


def delete_lock_create_processed_flag(output_folder: str, sample_id: str):
    processed_flag_file = os.path.join(output_folder, f"{sample_id}.done")
    # Mark this example as processed by creating an empty flag file
    try:
        # Using 'x' mode for exclusive creation; fails if the file already exists
        with open(processed_flag_file, "x"):
            pass  # The file is created; nothing is written to it
    except FileExistsError as e:
        raise FileExistsError(
            f"The processed flag file {processed_flag_file} already exists."
        ) from e

    lock_file = os.path.join(output_folder, f"{sample_id}.lock")
    # Safely attempt to delete the lock file
    try:
        os.remove(lock_file)
    except OSError as e:
        raise OSError(f"Can not remove the lock file at {lock_file}") from e


def acquire_lock_or_skip_if_already_exist(output_folder: str, sample_id: str):
    lock_file = os.path.join(output_folder, f"{sample_id}.lock")
    if os.path.exists(lock_file):
        LOG.info("Other process acquired the lock --> %s. Skipping...", sample_id)
        return True
    processed_flag_file = os.path.join(output_folder, f"{sample_id}.done")
    if os.path.exists(processed_flag_file):
        LOG.info("The sample has been processed --> %s. Skipping...", sample_id)
        return True

    # Obtain the lock for this example by creating an empty lock file
    try:
        # Using 'x' mode for exclusive creation; fails if the file already exists
        with open(lock_file, "x"):
            pass  # The file is created; nothing is written to it
    except FileExistsError:
        LOG.info("Other process acquired the lock --> %s. Skipping...", sample_id)
        return True
    return False


def filter_out_existing_results(
    test_dataset: Dataset, prediction_output_folder_name: str
):
    parquet_files = glob.glob(os.path.join(prediction_output_folder_name, "*parquet"))
    if parquet_files:
        cohort_members = set()
        results_dataframe = pd.read_parquet(parquet_files)[["person_id", "index_date"]]
        for row in results_dataframe.itertuples():
            cohort_members.add((row.person_id, row.index_date.strftime("%Y-%m-%d")))

        def filter_func(batched):
            return [
                (person_id, index_date.strftime("%Y-%m-%d")) not in cohort_members
                for person_id, index_date in zip(
                    batched["person_id"], batched["index_date"]
                )
            ]

        test_dataset = test_dataset.filter(filter_func, batched=True, batch_size=1000)
    return test_dataset


def flush_to_disk_if_full(
    tte_outputs: List[Dict[str, Any]], prediction_output_folder_name, buffer_size: int
) -> None:
    if len(tte_outputs) >= buffer_size:
        LOG.info(
            f"{datetime.datetime.now()}: Flushing time to visit predictions to disk"
        )
        output_parquet_file = os.path.join(
            prediction_output_folder_name, f"{uuid.uuid4()}.parquet"
        )
        pd.DataFrame(
            tte_outputs,
            columns=[
                "person_id",
                "index_date",
                "visit_counter",
                "label",
                "time_to_event",
                "prediction",
            ],
        ).to_parquet(output_parquet_file)
        tte_outputs.clear()


def create_arg_parser():
    base_arg_parser = create_inference_base_arg_parser(
        description="Arguments for time sensitive prediction"
    )
    base_arg_parser.add_argument(
        "--dataset_folder",
        dest="dataset_folder",
        action="store",
        help="The path for your dataset",
        required=True,
    )
    base_arg_parser.add_argument(
        "--num_return_sequences",
        dest="num_return_sequences",
        action="store",
        type=int,
        required=True,
    )
    base_arg_parser.add_argument(
        "--task_config", dest="task_config", action="store", required=True
    )
    base_arg_parser.add_argument(
        "--concept_ancestor", dest="concept_ancestor", action="store", required=False
    )
    base_arg_parser.add_argument(
        "--debug",
        dest="debug",
        action="store_true",
    )
    base_arg_parser.add_argument(
        "--max_n_trial",
        dest="max_n_trial",
        action="store",
        type=int,
        default=2,
        required=False,
    )
    return base_arg_parser


if __name__ == "__main__":
    main(create_arg_parser().parse_args())
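The shipped task configs under cehrgpt/time_to_event/config/ (e.g. 30_day_readmission.yaml, 8 lines each) appear to map one field per line onto the TaskConfig dataclass above. Below is a minimal sketch of such a definition and how load_task_config_from_yaml consumes it; the task name and outcome concept id are hypothetical placeholders, not values shipped with the package.

import yaml

from cehrgpt.time_to_event.time_to_event_prediction import TaskConfig

# Hypothetical task definition; field names come from the TaskConfig dataclass,
# values (including the concept id) are made up for illustration.
EXAMPLE_TASK_YAML = """\
task_name: example_30_day_readmission
outcome_events: ["9201"]
include_descendants: false
future_visit_start: 1
future_visit_end: 2
prediction_window_start: 0
prediction_window_end: 30
max_new_tokens: 128
"""

task_config = TaskConfig(**yaml.safe_load(EXAMPLE_TASK_YAML))
print(task_config.task_name, task_config.prediction_window_end)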
cehrgpt/time_to_event/time_to_event_utils.py
@@ -0,0 +1,55 @@
from collections import defaultdict
from typing import Any, Dict, List, Tuple

import numpy as np

from cehrgpt.gpt_utils import is_att_token


def convert_month_token_to_upperbound_days(
    month_token: str, time_bucket_size: int = 90
) -> str:
    if is_att_token(month_token):
        if month_token == "LT":
            return ">= 1095 days"
        else:
            base = (int(month_token[1:]) + 1) * 30 // (time_bucket_size + 1)
            return (
                f"{base * time_bucket_size} days - {(base + 1) * time_bucket_size} days"
            )
    raise ValueError(f"month_token: {month_token} is not a valid month token")


def calculate_time_bucket_probability(
    predictions: List[Dict[str, Any]], time_bucket_size: int = 90
) -> List[Tuple[str, Any]]:
    predictions_with_time_buckets = [
        {
            "probability": p["probability"],
            "time_bucket": convert_month_token_to_upperbound_days(
                p["time_interval"], time_bucket_size
            ),
        }
        for p in predictions
    ]
    # Dictionary to store summed probabilities per time bucket
    grouped_probabilities = defaultdict(float)
    # Loop through the data
    for entry in predictions_with_time_buckets:
        time_bucket = entry["time_bucket"]
        probability = entry["probability"]
        grouped_probabilities[time_bucket] += probability
    return sorted(grouped_probabilities.items(), key=lambda item: item[1], reverse=True)


def calculate_accumulative_time_bucket_probability(
    predictions: List[Dict[str, Any]], time_bucket_size: int = 90
) -> List[Tuple[str, Any]]:
    time_bucket_probability = calculate_time_bucket_probability(
        predictions, time_bucket_size
    )
    accumulative_probs = np.cumsum([_[1] for _ in time_bucket_probability])
    return [
        (*_, accumulative_prob)
        for _, accumulative_prob in zip(time_bucket_probability, accumulative_probs)
    ]
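To make the bucket arithmetic concrete, here is a self-contained sketch. It reproduces the integer arithmetic above inline rather than importing the module (convert_month_token_to_upperbound_days additionally requires is_att_token from cehrgpt.gpt_utils, which is not shown in this diff), and it assumes month tokens of the form "M<n>", as int(month_token[1:]) implies; the probabilities are hypothetical.

from collections import defaultdict

import numpy as np


def bucket(month_token: str, time_bucket_size: int = 90) -> str:
    # Same arithmetic as convert_month_token_to_upperbound_days, minus the
    # is_att_token guard; assumes tokens look like "M<n>" or "LT".
    if month_token == "LT":
        return ">= 1095 days"
    base = (int(month_token[1:]) + 1) * 30 // (time_bucket_size + 1)
    return f"{base * time_bucket_size} days - {(base + 1) * time_bucket_size} days"


# Hypothetical model outputs: "M1" -> (1+1)*30 // 91 = 0, "M5" -> 180 // 91 = 1
predictions = [
    {"time_interval": "M1", "probability": 0.4},
    {"time_interval": "M2", "probability": 0.1},
    {"time_interval": "M5", "probability": 0.3},
]
grouped = defaultdict(float)
for p in predictions:
    grouped[bucket(p["time_interval"])] += p["probability"]
ranked = sorted(grouped.items(), key=lambda item: item[1], reverse=True)
print(ranked)  # [('0 days - 90 days', 0.5), ('90 days - 180 days', 0.3)]
print(np.cumsum([prob for _, prob in ranked]))  # [0.5 0.8]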
cehrgpt/tools/__init__.py: File without changes
cehrgpt/tools/ehrshot_benchmark.py
@@ -0,0 +1,74 @@
import argparse
import os
import select  # Import select for monitoring stdout and stderr
import subprocess
from pathlib import Path

import yaml


def create_arg_parser():
    parser = argparse.ArgumentParser(
        description="Arguments for benchmarking CEHRGPT on ehrshot cohorts"
    )
    parser.add_argument("--cohort_dir", required=True)
    parser.add_argument("--base_yaml_file", required=True)
    parser.add_argument("--output_folder", required=True)
    return parser.parse_args()


if __name__ == "__main__":
    args = create_arg_parser()

    with open(args.base_yaml_file, "rb") as stream:
        base_config = yaml.safe_load(stream)

    for cohort_name in os.listdir(args.cohort_dir):
        if cohort_name.endswith("/"):
            cohort_name = cohort_name[:-1]
        individual_output = os.path.join(args.output_folder, cohort_name)
        if os.path.exists(individual_output):
            continue
        Path(individual_output).mkdir(parents=True, exist_ok=True)
        base_config["data_folder"] = os.path.join(args.cohort_dir, cohort_name, "train")
        base_config["test_data_folder"] = os.path.join(
            args.cohort_dir, cohort_name, "test"
        )
        base_config["output_dir"] = individual_output

        # Write YAML data to a file
        config_path = os.path.join(individual_output, "config.yaml")
        with open(config_path, "w") as yaml_file:
            yaml.dump(base_config, yaml_file, default_flow_style=False)

        command = [
            "python",
            "-u",
            "-m",
            "cehrgpt.runners.hf_cehrgpt_finetune_runner",
            config_path,
        ]

        # Start the subprocess
        with subprocess.Popen(
            command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
        ) as process:
            while True:
                # Use select to wait for either stdout or stderr to have data
                reads = [process.stdout.fileno(), process.stderr.fileno()]
                ret = select.select(reads, [], [])

                # Read data from stdout and stderr as it becomes available
                for fd in ret[0]:
                    if fd == process.stdout.fileno():
                        line = process.stdout.readline()
                        if line:
                            print(line, end="")
                    elif fd == process.stderr.fileno():
                        line = process.stderr.readline()
                        if line:
                            print(line, end="")

                # Break loop when process finishes
                if process.poll() is not None:
                    break
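The select-based loop above interleaves the child's stdout and stderr as data arrives; note that select on pipe file descriptors works on POSIX systems but not on Windows, where select is restricted to sockets. A simpler variant, sketched below, merges stderr into stdout so a single blocking readline loop suffices, at the cost of losing the distinction between the two streams; "config.yaml" stands in for the per-cohort config_path written above.

import subprocess

# Minimal alternative sketch: redirect stderr into stdout and stream the
# merged output line by line.
command = [
    "python", "-u", "-m",
    "cehrgpt.runners.hf_cehrgpt_finetune_runner",
    "config.yaml",
]
with subprocess.Popen(
    command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
) as process:
    for line in process.stdout:
        print(line, end="")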
cehrgpt/tools/generate_pretrained_embeddings.py
@@ -0,0 +1,130 @@
import argparse
import os
import pickle

import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import normalize
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

from cehrgpt.models.pretrained_embeddings import (
    PRETRAINED_EMBEDDING_CONCEPT_FILE_NAME,
    PRETRAINED_EMBEDDING_VECTOR_FILE_NAME,
)
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer


def generate_embeddings_batch(texts, tokenizer, device, model):
    inputs = tokenizer(
        texts, return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
    return normalize(embeddings)


def main(args):
    tokenizer = AutoTokenizer.from_pretrained(
        "dunzhang/stella_en_1.5B_v5", trust_remote_code=True
    )
    model = AutoModel.from_pretrained(
        "dunzhang/stella_en_1.5B_v5", trust_remote_code=True
    ).eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    print("Load cehrgpt tokenizer")
    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_path)
    concept_ids = [_ for _ in cehrgpt_tokenizer.get_vocab().keys() if _.isnumeric()]
    vocab = pd.DataFrame(concept_ids, columns=["concept_id"])
    vocab.drop_duplicates(subset=["concept_id"], inplace=True)
    vocab = vocab.astype(str)

    print("Load concept dataframe")
    concept = pd.read_parquet(args.concept_parquet_file_path)
    concept = concept.astype(str)

    print("Merge concept ids and concept names")
    vocab_with_name = vocab.merge(
        concept, how="inner", left_on="concept_id", right_on="concept_id"
    )

    concept_ids = vocab_with_name["concept_id"].to_list()
    concept_names = vocab_with_name["concept_name"].to_list()
    all_embeddings = []
    concept_dict = []
    # Step through the vocabulary one batch at a time
    for i in tqdm(range(0, len(concept_names), args.batch_size)):
        batched_concept_names = concept_names[i : i + args.batch_size]
        batched_concept_ids = concept_ids[i : i + args.batch_size]
        try:
            batch_embeddings = generate_embeddings_batch(
                batched_concept_names, tokenizer, device, model
            )
            all_embeddings.extend(batch_embeddings)
            concept_dict.extend(
                [
                    {"concept_id": concept_id, "concept_name": concept_name}
                    for concept_id, concept_name in zip(
                        batched_concept_ids, batched_concept_names
                    )
                ]
            )
        except Exception as e:
            print(f"Error processing batch: {str(e)}")

    np.save(
        os.path.join(args.output_folder_path, PRETRAINED_EMBEDDING_VECTOR_FILE_NAME),
        all_embeddings,
    )

    with open(
        os.path.join(args.output_folder_path, PRETRAINED_EMBEDDING_CONCEPT_FILE_NAME),
        "wb",
    ) as file:
        pickle.dump(concept_dict, file)


def create_arg_parser():
    parser = argparse.ArgumentParser(description="Create pretrained embeddings")
    parser.add_argument(
        "--tokenizer_path",
        dest="tokenizer_path",
        action="store",
        help="The path for the vocabulary json file",
        required=True,
    )
    parser.add_argument(
        "--concept_parquet_file_path",
        dest="concept_parquet_file_path",
        action="store",
        help="The path for your concept_path",
        required=True,
    )
    parser.add_argument(
        "--batch_size",
        dest="batch_size",
        type=int,
        default=16,
        action="store",
        help="Batch size to process the concept_names",
        required=False,
    )
    parser.add_argument(
        "--output_folder_path",
        dest="output_folder_path",
        action="store",
        help="Output folder path for saving the embeddings and concept_names",
        required=True,
    )
    return parser


if __name__ == "__main__":
    main(create_arg_parser().parse_args())
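The artifacts this script writes can be read back with the matching numpy and pickle calls. A minimal sketch, assuming the vector file-name constant already carries the .npy suffix (np.save appends one otherwise; the concrete constant values are not shown in this diff) and using a hypothetical output folder:

import os
import pickle

import numpy as np

from cehrgpt.models.pretrained_embeddings import (
    PRETRAINED_EMBEDDING_CONCEPT_FILE_NAME,
    PRETRAINED_EMBEDDING_VECTOR_FILE_NAME,
)

output_folder_path = "embeddings_output"  # hypothetical --output_folder_path value

# One L2-normalized row per concept, aligned with the pickled metadata list
embeddings = np.load(
    os.path.join(output_folder_path, PRETRAINED_EMBEDDING_VECTOR_FILE_NAME)
)
with open(
    os.path.join(output_folder_path, PRETRAINED_EMBEDDING_CONCEPT_FILE_NAME), "rb"
) as file:
    concept_dict = pickle.load(file)

assert len(embeddings) == len(concept_dict)
print(embeddings.shape, concept_dict[0]["concept_id"])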