cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
- cehrgpt/data/sample_packing_sampler.py +181 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +35 -0
- cehrgpt/models/hf_cehrgpt.py +470 -106
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
- cehrgpt/runners/data_utils.py +358 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
- cehrgpt/runners/hyperparameter_search_util.py +10 -8
- cehrgpt/runners/sample_packing_trainer.py +185 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,495 @@
|
|
1
|
+
import glob
|
2
|
+
import os
|
3
|
+
import shutil
|
4
|
+
import uuid
|
5
|
+
from datetime import datetime
|
6
|
+
from functools import partial
|
7
|
+
from pathlib import Path
|
8
|
+
from typing import Optional, Union
|
9
|
+
|
10
|
+
import numpy as np
|
11
|
+
import pandas as pd
|
12
|
+
import polars as pl
|
13
|
+
import torch
|
14
|
+
import torch.distributed as dist
|
15
|
+
from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
|
16
|
+
from cehrbert.runners.runner_util import generate_prepared_ds_path
|
17
|
+
from datasets import concatenate_datasets, load_from_disk
|
18
|
+
from torch.utils.data import DataLoader
|
19
|
+
from tqdm import tqdm
|
20
|
+
from transformers.trainer_utils import is_main_process
|
21
|
+
from transformers.utils import is_flash_attn_2_available, logging
|
22
|
+
|
23
|
+
from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_finetuning_dataset
|
24
|
+
from cehrgpt.data.hf_cehrgpt_dataset_collator import (
|
25
|
+
CehrGptDataCollator,
|
26
|
+
SamplePackingCehrGptDataCollator,
|
27
|
+
)
|
28
|
+
from cehrgpt.data.hf_cehrgpt_dataset_mapping import ExtractTokenizedSequenceDataMapping
|
29
|
+
from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler
|
30
|
+
from cehrgpt.models.hf_cehrgpt import (
|
31
|
+
CEHRGPT2Model,
|
32
|
+
extract_features_from_packed_sequence,
|
33
|
+
)
|
34
|
+
from cehrgpt.models.special_tokens import LINEAR_PROB_TOKEN
|
35
|
+
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
|
36
|
+
from cehrgpt.runners.data_utils import (
|
37
|
+
extract_cohort_sequences,
|
38
|
+
prepare_finetune_dataset,
|
39
|
+
)
|
40
|
+
from cehrgpt.runners.gpt_runner_util import parse_runner_args
|
41
|
+
from cehrgpt.runners.hf_cehrgpt_pretrain_runner import tokenizer_exists
|
42
|
+
|
43
|
+
LOG = logging.get_logger("transformers")
|
44
|
+
|
45
|
+
|
46
|
+
def get_torch_dtype(torch_dtype: Optional[str] = None) -> Union[torch.dtype, str]:
    """
    Resolve a dtype name such as ``"bfloat16"`` to the matching ``torch.dtype``.

    Args:
        torch_dtype: name of an attribute of the ``torch`` module, or an
            already-resolved ``torch.dtype`` (returned unchanged).

    Returns:
        The resolved ``torch.dtype``. ``torch.float32`` is returned when the
        argument is empty, is not a string, or names something in ``torch``
        that is not a dtype.
    """
    if isinstance(torch_dtype, torch.dtype):
        return torch_dtype
    if isinstance(torch_dtype, str) and torch_dtype:
        candidate = getattr(torch, torch_dtype, None)
        # ``torch`` also exposes modules and functions (e.g. "nn", "cuda");
        # only a real dtype may be handed to ``from_pretrained``.
        if isinstance(candidate, torch.dtype):
            return candidate
    return torch.float32
|
50
|
+
|
51
|
+
|
52
|
+
def extract_averaged_embeddings_from_packed_sequence(
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    ve_token_indicators: torch.BoolTensor,
) -> torch.Tensor:
    """
    Average the VE-token embeddings of every sample in one packed sequence.

    A sample is a contiguous run of attended positions (``attention_mask != 0``);
    runs are separated by one or more PAD positions (``attention_mask == 0``).

    Args:

        hidden_states: (batch_size=1, seq_len, hidden_dim) tensor
        attention_mask: (batch_size=1, seq_len) tensor, where 0 indicates padding
        ve_token_indicators: (batch_size=1, seq_len) bool tensor, True if token is VE token
    Returns:
        (num_samples, hidden_dim) float32 tensor: averaged embeddings over VE tokens
        for each sample, in order of appearance. A sample without any VE token
        gets an all-zero row. The result has zero rows when the sequence is
        empty or fully padded.
    """
    hidden_dim = hidden_states.size(-1)
    device = hidden_states.device

    # Step 1: Create segment IDs
    mask = attention_mask[0].to(torch.bool)  # (seq_len,)
    if mask.numel() == 0 or not bool(mask.any()):
        return torch.zeros(0, hidden_dim, dtype=torch.float32, device=device)

    # A segment starts wherever an attended token follows a PAD (or opens the
    # sequence). Counting starts instead of PADs keeps the ids contiguous, so
    # leading or repeated PAD tokens cannot create empty phantom samples.
    previous = torch.cat([mask.new_zeros(1), mask[:-1]])
    segment_starts = mask & ~previous
    segment_ids = segment_starts.cumsum(dim=0) * mask  # 1-based, 0 on PAD positions

    # Step 2: Only keep tokens that are both valid and VE tokens
    valid = mask & ve_token_indicators[0].to(torch.bool)
    valid_embeddings = hidden_states[0, valid].to(
        torch.float32
    )  # (num_valid_ve_tokens, hidden_dim)
    valid_segments = segment_ids[valid] - 1  # (num_valid_ve_tokens,), 0-based rows

    # Step 3: Group by segment id and average
    num_segments = int(segment_ids.max().item())

    sample_embeddings = torch.zeros(
        num_segments, hidden_dim, dtype=torch.float32, device=device
    )
    counts = torch.zeros(num_segments, dtype=torch.float32, device=device)

    sample_embeddings.index_add_(0, valid_segments, valid_embeddings)
    counts.index_add_(
        0, valid_segments, torch.ones_like(valid_segments, dtype=counts.dtype)
    )

    # Avoid divide-by-zero (segments without VE tokens keep their zero embedding)
    counts = counts.clamp(min=1.0)

    return sample_embeddings / counts.unsqueeze(-1)
|
97
|
+
|
98
|
+
|
99
|
+
def main():
    """
    Extract CEHR-GPT hidden-state features for linear probing and write them to parquet.

    Steps performed, in order:
        1. Parse the runner arguments, load the tokenizer and the model (eval mode).
        2. Make sure ``LINEAR_PROB_TOKEN`` is in the vocabulary.
        3. Load the prepared dataset from disk, or build and save it on the main process.
        4. Drop samples whose features were already written by a previous run.
        5. Grow the token / position embeddings if the tokenizer or the requested
           context window is larger than what the model currently has.
        6. Run the model over train+validation and over test, pool one feature
           vector per sample, and write one parquet file per batch to
           ``<output_dir>/features_with_label/<split>_features``.
    """
    cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
        model_args.tokenizer_name_or_path
    )
    torch_dtype = get_torch_dtype(model_args.torch_dtype)
    cehrgpt_model = (
        CEHRGPT2Model.from_pretrained(
            model_args.model_name_or_path,
            attn_implementation=(
                "flash_attention_2" if is_flash_attn_2_available() else "eager"
            ),
            torch_dtype=torch_dtype,
        )
        .eval()
        .to(device)
    )

    if LINEAR_PROB_TOKEN not in cehrgpt_tokenizer.get_vocab():
        cehrgpt_tokenizer.add_tokens(LINEAR_PROB_TOKEN)
        cehrgpt_model.resize_token_embeddings(cehrgpt_tokenizer.vocab_size)

    prepared_ds_path = generate_prepared_ds_path(
        data_args, model_args, data_folder=data_args.cohort_folder
    )
    cache_file_collector = CacheFileCollector()
    processed_dataset = None
    if any(prepared_ds_path.glob("*")):
        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
        processed_dataset = load_from_disk(str(prepared_ds_path))
        LOG.info("Prepared dataset loaded from disk...")
        if cehrgpt_args.expand_tokenizer:
            if tokenizer_exists(training_args.output_dir):
                cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
                    training_args.output_dir
                )
            else:
                # The prepared dataset was tokenized with an expanded tokenizer
                # that is no longer available, so it has to be rebuilt.
                LOG.warning(
                    f"CehrGptTokenizer must exist in {training_args.output_dir} "
                    f"when the dataset has been processed and expand_tokenizer is set to True. "
                    f"Please delete the processed dataset at {prepared_ds_path}."
                )
                processed_dataset = None
                shutil.rmtree(prepared_ds_path)

    if processed_dataset is None:
        if is_main_process(training_args.local_rank):
            # If the full dataset has been tokenized, we don't want to tokenize the cohort containing
            # the subset of the data. We should slice out the portion of the tokenized sequences for each sample
            if cehrgpt_args.tokenized_full_dataset_path is not None:
                processed_dataset = extract_cohort_sequences(
                    data_args, cehrgpt_args, cache_file_collector
                )
            else:
                # Organize them into a single DatasetDict
                final_splits = prepare_finetune_dataset(
                    data_args, training_args, cehrgpt_args, cache_file_collector
                )
                if cehrgpt_args.expand_tokenizer:
                    new_tokenizer_path = os.path.expanduser(training_args.output_dir)
                    if tokenizer_exists(new_tokenizer_path):
                        cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
                            new_tokenizer_path
                        )
                    else:
                        cehrgpt_tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
                            cehrgpt_tokenizer=cehrgpt_tokenizer,
                            dataset=final_splits["train"],
                            data_args=data_args,
                            concept_name_mapping={},
                        )
                        cehrgpt_tokenizer.save_pretrained(
                            os.path.expanduser(training_args.output_dir)
                        )

                # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
                if not data_args.streaming:
                    all_columns = final_splits["train"].column_names
                    if "visit_concept_ids" in all_columns:
                        final_splits = final_splits.remove_columns(
                            ["visit_concept_ids"]
                        )

                processed_dataset = create_cehrgpt_finetuning_dataset(
                    dataset=final_splits,
                    cehrgpt_tokenizer=cehrgpt_tokenizer,
                    data_args=data_args,
                    cache_file_collector=cache_file_collector,
                )
            if not data_args.streaming:
                processed_dataset.save_to_disk(prepared_ds_path)
                processed_dataset.cleanup_cache_files()

            # Remove all the cached files if processed_dataset.cleanup_cache_files() did not remove them already
            cache_file_collector.remove_cache_files()

        # After main-process-only operations, synchronize all processes to ensure consistency
        if dist.is_available() and dist.is_initialized():
            dist.barrier()

        # Every process (re)loads the prepared dataset that the main process saved.
        # NOTE(review): nothing is saved when data_args.streaming is True, so this
        # load presumably fails in streaming mode - confirm.
        processed_dataset = load_from_disk(str(prepared_ds_path))

    # Getting the existing features so that a restarted run skips finished samples.
    # The second pattern is the layout this script writes to (see
    # feature_output_folder below); the first one is kept for feature folders
    # laid out as <output_dir>/<name>/features.
    feature_files = sorted(
        set(
            glob.glob(
                os.path.join(training_args.output_dir, "*", "features", "*.parquet")
            )
        )
        | set(
            glob.glob(
                os.path.join(
                    training_args.output_dir,
                    "features_with_label",
                    "*_features",
                    "*.parquet",
                )
            )
        )
    )
    if feature_files:
        existing_features = pd.concat(
            [
                pd.read_parquet(f, columns=["subject_id", "prediction_time_posix"])
                for f in feature_files
            ],
            ignore_index=True,
        )
        subject_prediction_tuples = {
            f"{int(subject)}-{int(time)}"
            for subject, time in zip(
                existing_features["subject_id"],
                existing_features["prediction_time_posix"],
            )
        }
        # NOTE(review): assumes the dataset's index_date column holds POSIX
        # timestamps like prediction_time_posix does - confirm.
        processed_dataset = processed_dataset.filter(
            lambda _batch: [
                f"{int(subject)}-{int(time)}" not in subject_prediction_tuples
                for subject, time in zip(_batch["person_id"], _batch["index_date"])
            ],
            num_proc=data_args.preprocessing_num_workers,
            batch_size=data_args.preprocessing_batch_size,
            batched=True,
        )
        LOG.info(
            "The datasets after filtering (train: %s, validation: %s, test: %s)",
            len(processed_dataset["train"]),
            len(processed_dataset["validation"]),
            len(processed_dataset["test"]),
        )

    LOG.info(f"cehrgpt_model.config.vocab_size: {cehrgpt_model.config.vocab_size}")
    LOG.info(f"cehrgpt_tokenizer.vocab_size: {cehrgpt_tokenizer.vocab_size}")
    if cehrgpt_model.config.vocab_size < cehrgpt_tokenizer.vocab_size:
        cehrgpt_model.resize_token_embeddings(cehrgpt_tokenizer.vocab_size)
    if (
        cehrgpt_model.config.max_position_embeddings
        < model_args.max_position_embeddings
    ):
        LOG.info(
            f"Increase model.config.max_position_embeddings to {model_args.max_position_embeddings}"
        )
        cehrgpt_model.config.max_position_embeddings = (
            model_args.max_position_embeddings
        )
        cehrgpt_model.resize_position_embeddings(model_args.max_position_embeddings)

    # Train and validation are merged into a single "train" feature split.
    train_set = concatenate_datasets(
        [processed_dataset["train"], processed_dataset["validation"]]
    )

    if cehrgpt_args.sample_packing:
        # The batch sampler decides what goes into a batch, so the DataLoader's
        # own batch_size has to stay at its default of 1.
        per_device_eval_batch_size = 1
        data_collator_fn = partial(
            SamplePackingCehrGptDataCollator,
            cehrgpt_args.max_tokens_per_batch,
            cehrgpt_model.config.max_position_embeddings,
            add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
        )
        train_batch_sampler = SamplePackingBatchSampler(
            lengths=train_set["num_of_concepts"],
            max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
            max_position_embeddings=cehrgpt_model.config.max_position_embeddings,
            drop_last=training_args.dataloader_drop_last,
            seed=training_args.seed,
        )
        test_batch_sampler = SamplePackingBatchSampler(
            lengths=processed_dataset["test"]["num_of_concepts"],
            max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
            max_position_embeddings=cehrgpt_model.config.max_position_embeddings,
            drop_last=training_args.dataloader_drop_last,
            seed=training_args.seed,
        )
    else:
        data_collator_fn = CehrGptDataCollator
        train_batch_sampler = None
        test_batch_sampler = None
        per_device_eval_batch_size = training_args.per_device_eval_batch_size

    # We suppress the additional learning objectives in fine-tuning
    data_collator = data_collator_fn(
        tokenizer=cehrgpt_tokenizer,
        max_length=(
            cehrgpt_args.max_tokens_per_batch
            if cehrgpt_args.sample_packing
            else model_args.max_position_embeddings
        ),
        include_values=cehrgpt_model.config.include_values,
        pretraining=False,
        include_ttv_prediction=False,
        use_sub_time_tokenization=False,
        include_demographics=cehrgpt_args.include_demographics,
        add_linear_prob_token=True,
    )

    train_loader = DataLoader(
        dataset=train_set,
        batch_size=per_device_eval_batch_size,
        num_workers=training_args.dataloader_num_workers,
        collate_fn=data_collator,
        pin_memory=training_args.dataloader_pin_memory,
        batch_sampler=train_batch_sampler,
    )

    test_dataloader = DataLoader(
        dataset=processed_dataset["test"],
        batch_size=per_device_eval_batch_size,
        num_workers=training_args.dataloader_num_workers,
        collate_fn=data_collator,
        pin_memory=training_args.dataloader_pin_memory,
        batch_sampler=test_batch_sampler,
    )

    if data_args.is_data_in_meds:
        # No demographics lookup for MEDS data; gender/race are written as 0 below.
        demographics_dict = dict()
    else:
        # Loading demographics
        print("Loading demographics as a dictionary")
        demographics_df = pd.concat(
            [
                pd.read_parquet(
                    data_dir,
                    columns=[
                        "person_id",
                        "index_date",
                        "gender_concept_id",
                        "race_concept_id",
                    ],
                )
                for data_dir in [data_args.data_folder, data_args.test_data_folder]
            ]
        )
        # This is a pre-caution in case the index_date is not a datetime type
        demographics_df["index_date"] = pd.to_datetime(
            demographics_df["index_date"]
        ).dt.date
        # Column-wise zip instead of iterrows(): same mapping, without building
        # a Series per row.
        demographics_dict = {
            (person_id, index_date): {
                "gender_concept_id": gender_concept_id,
                "race_concept_id": race_concept_id,
            }
            for person_id, index_date, gender_concept_id, race_concept_id in zip(
                demographics_df["person_id"],
                demographics_df["index_date"],
                demographics_df["gender_concept_id"],
                demographics_df["race_concept_id"],
            )
        }

    data_loaders = [("train", train_loader), ("test", test_dataloader)]

    ve_token_id = cehrgpt_tokenizer._convert_token_to_id("[VE]")
    for split, data_loader in data_loaders:
        # Ensure prediction folder exists
        feature_output_folder = (
            Path(training_args.output_dir) / "features_with_label" / f"{split}_features"
        )
        feature_output_folder.mkdir(parents=True, exist_ok=True)

        LOG.info("Generating features for %s set at %s", split, feature_output_folder)

        with torch.no_grad():
            for batch in tqdm(data_loader, desc="Generating features"):
                # Pull the per-sample metadata out of the batch before the
                # forward pass; atleast_1d keeps a one-sample batch iterable
                # after squeeze() collapsed it to a scalar.
                prediction_time_ages = np.atleast_1d(
                    batch.pop("age_at_index").numpy().astype(float).squeeze()
                )
                person_ids = np.atleast_1d(
                    batch.pop("person_id").numpy().astype(int).squeeze()
                )
                prediction_time_posix = np.atleast_1d(
                    batch.pop("index_date").numpy().squeeze()
                )
                # NOTE(review): fromtimestamp() converts using the machine's
                # local timezone - confirm this matches how index_date was encoded.
                prediction_time = list(
                    map(datetime.fromtimestamp, prediction_time_posix)
                )
                labels = np.atleast_1d(
                    batch.pop("classifier_label")
                    .float()
                    .cpu()
                    .numpy()
                    .astype(bool)
                    .squeeze()
                )

                batch = {k: v.to(device) for k, v in batch.items()}
                # Forward pass
                cehrgpt_output = cehrgpt_model(
                    **batch, output_attentions=False, output_hidden_states=False
                )
                if cehrgpt_args.sample_packing:
                    if cehrgpt_args.average_over_sequence:
                        ve_token_indicators: torch.BoolTensor = (
                            batch["input_ids"] == ve_token_id
                        )
                        features = (
                            extract_averaged_embeddings_from_packed_sequence(
                                cehrgpt_output.last_hidden_state,
                                batch["attention_mask"],
                                ve_token_indicators,
                            )
                            .cpu()
                            .float()
                            .detach()
                            .numpy()
                        )
                    else:
                        features = (
                            extract_features_from_packed_sequence(
                                cehrgpt_output.last_hidden_state,
                                batch["attention_mask"],
                            )
                            .cpu()
                            .float()
                            .detach()
                            .numpy()
                            .squeeze(axis=0)
                        )
                else:
                    if cehrgpt_args.average_over_sequence:
                        token_mask = (
                            batch["attention_mask"].unsqueeze(dim=-1).to(torch.bool)
                        )
                        masked_states = torch.where(
                            token_mask,
                            cehrgpt_output.last_hidden_state,
                            0,
                        ).to(torch.float32)
                        # Average across the attended tokens only: dividing by
                        # the padded length would shrink the embeddings of
                        # shorter sequences.
                        token_counts = token_mask.sum(dim=1).clamp(min=1)
                        features = (
                            (masked_states.sum(dim=1) / token_counts)
                            .cpu()
                            .float()
                            .detach()
                            .numpy()
                        )
                    else:
                        # NOTE(review): for a batch with more than one sample
                        # tolist() yields nested lists, so this comparison never
                        # matches; it also looks for [END] anywhere in the
                        # sequence rather than at its end - confirm the intent.
                        last_end_token = any(
                            [
                                cehrgpt_tokenizer.end_token_id == input_id
                                for input_id in batch.pop("input_ids")
                                .cpu()
                                .numpy()
                                .squeeze()
                                .tolist()
                            ]
                        )
                        last_token_index = -2 if last_end_token else -1
                        if last_end_token:
                            LOG.debug(
                                "The last token is [END], we need to use the token index before that: %s",
                                last_token_index,
                            )
                        features = (
                            cehrgpt_output.last_hidden_state[..., last_token_index, :]
                            .cpu()
                            .float()
                            .detach()
                            .numpy()
                        )

                # Handle the features as a list of arrays (one array per row)
                features_list = list(features)
                race_concept_ids = []
                gender_concept_ids = []
                for person_id, index_date in zip(person_ids, prediction_time):
                    key = (person_id, index_date.date())
                    if key in demographics_dict:
                        demographics = demographics_dict[key]
                        gender_concept_ids.append(demographics["gender_concept_id"])
                        race_concept_ids.append(demographics["race_concept_id"])
                    else:
                        gender_concept_ids.append(0)
                        race_concept_ids.append(0)

                features_pd = pd.DataFrame(
                    {
                        "subject_id": person_ids,
                        "prediction_time": prediction_time,
                        "prediction_time_posix": prediction_time_posix,
                        "boolean_value": labels,
                        "age_at_index": prediction_time_ages,
                    }
                )
                # Adding features as a separate column where each row contains a feature array
                features_pd["features"] = features_list
                features_pd["race_concept_id"] = race_concept_ids
                features_pd["gender_concept_id"] = gender_concept_ids
                # One uniquely named file per batch, so an interrupted run keeps
                # whatever has been written so far.
                features_pd.to_parquet(
                    feature_output_folder / f"{uuid.uuid4()}.parquet"
                )
|
492
|
+
|
493
|
+
|
494
|
+
# Run the feature extraction only when this module is executed as a script.
if __name__ == "__main__":
    main()
|
@@ -0,0 +1,152 @@
|
|
1
|
+
import argparse
|
2
|
+
import json
|
3
|
+
import pickle
|
4
|
+
from pathlib import Path
|
5
|
+
from typing import Any, Dict, Union
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
import pandas as pd
|
9
|
+
import polars as pl
|
10
|
+
from sklearn.linear_model import LogisticRegressionCV
|
11
|
+
from sklearn.metrics import auc, precision_recall_curve, roc_auc_score
|
12
|
+
from sklearn.preprocessing import OneHotEncoder, StandardScaler
|
13
|
+
|
14
|
+
|
15
|
+
def prepare_dataset(
    df: pd.DataFrame, feature_processor: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Turn a frame of extracted features into the arrays used for model fitting.

    Args:
        df: frame with the columns ``subject_id``, ``prediction_time``,
            ``features`` (one array-like per row) and ``boolean_value``.
        feature_processor: fitted demographic preprocessors. Accepted for
            interface compatibility only: the returned feature matrix holds the
            embedding vectors alone, so the preprocessors are not applied.

    Returns:
        A dict with ``subject_id`` (ndarray), ``prediction_time`` (list),
        ``features`` (2-D ndarray, one flattened embedding per row) and
        ``boolean_value`` (ndarray).
    """
    # NOTE: the demographic columns (age, gender, race) are deliberately not
    # appended to the embeddings; their transforms used to be computed here and
    # then discarded, which was pure overhead.
    flattened = [np.asarray(x).flatten() for x in df["features"]]
    if flattened:
        features = np.stack(flattened)
    else:
        # np.stack refuses an empty list; an empty frame maps to an empty matrix.
        features = np.empty((0, 0), dtype=float)
    return {
        "subject_id": df["subject_id"].to_numpy(),
        "prediction_time": df["prediction_time"].tolist(),
        "features": features,
        "boolean_value": df["boolean_value"].to_numpy(),
    }
|
40
|
+
|
41
|
+
|
42
|
+
def main(args):
    """
    Train (or reload) a logistic regression on CEHR-GPT features and score the test set.

    Args:
        args: parsed command-line arguments providing ``features_data_dir``
            (the folder that contains ``features_with_label/train_features``
            and ``features_with_label/test_features``) and ``output_dir``.

    Files written under ``output_dir``:
        - ``feature_processor.pickle``: the fitted age scaler and gender/race encoders
        - ``logistic/model.pickle``: the fitted ``LogisticRegressionCV``
        - ``logistic/test_predictions/predictions.parquet``: per-sample probabilities
        - ``logistic/metrics.json``: ROC-AUC and PR-AUC on the test set

    The function returns immediately when ``logistic/metrics.json`` already exists.
    """
    features_data_dir = Path(args.features_data_dir)
    output_dir = Path(args.output_dir)
    feature_processor_path = output_dir / "feature_processor.pickle"
    logistic_dir = output_dir / "logistic"
    logistic_dir.mkdir(exist_ok=True, parents=True)
    logistic_test_result_file = logistic_dir / "metrics.json"
    if logistic_test_result_file.exists():
        # The metrics file is written last, so its presence means a completed run.
        print("The models have been trained, and skip ...")
        return

    feature_train = pd.read_parquet(
        features_data_dir / "features_with_label" / "train_features"
    )
    feature_test = pd.read_parquet(
        features_data_dir / "features_with_label" / "test_features"
    )

    # Sort first so that the seeded shuffle does not depend on the order in
    # which the parquet files were read.
    feature_train = feature_train.sort_values(["subject_id", "prediction_time"]).sample(
        frac=1.0,
        random_state=42,
        replace=False,
    )

    # NOTE: pickle.load executes arbitrary code; only point output_dir at
    # folders produced by this script.
    if feature_processor_path.exists():
        with open(feature_processor_path, "rb") as f:
            feature_processor = pickle.load(f)
    else:
        age_scaler, gender_encoder, race_encoder = (
            StandardScaler(),
            OneHotEncoder(handle_unknown="ignore"),
            OneHotEncoder(handle_unknown="ignore"),
        )
        age_scaler = age_scaler.fit(feature_train[["age_at_index"]].to_numpy())
        gender_encoder = gender_encoder.fit(
            feature_train[["gender_concept_id"]].to_numpy()
        )
        race_encoder = race_encoder.fit(feature_train[["race_concept_id"]].to_numpy())
        feature_processor = {
            "age_scaler": age_scaler,
            "gender_encoder": gender_encoder,
            "race_encoder": race_encoder,
        }
        with open(feature_processor_path, "wb") as f:
            pickle.dump(feature_processor, f)

    logistic_model_file = logistic_dir / "model.pickle"
    if logistic_model_file.exists():
        print(
            f"The logistic regression model already exist, loading it from {logistic_model_file}"
        )
        with open(logistic_model_file, "rb") as f:
            model = pickle.load(f)
    else:
        train_dataset = prepare_dataset(feature_train, feature_processor)
        # Train logistic regression
        model = LogisticRegressionCV(scoring="roc_auc", random_state=42)
        model.fit(train_dataset["features"], train_dataset["boolean_value"])
        with open(logistic_model_file, "wb") as f:
            pickle.dump(model, f)

    test_dataset = prepare_dataset(feature_test, feature_processor)
    # Probability of the positive class
    y_pred = model.predict_proba(test_dataset["features"])[:, 1]
    logistic_predictions = pl.DataFrame(
        {
            "subject_id": test_dataset["subject_id"].tolist(),
            "prediction_time": test_dataset["prediction_time"],
            "predicted_boolean_probability": y_pred.tolist(),
            "predicted_boolean_value": None,
            "boolean_value": test_dataset["boolean_value"].astype(bool).tolist(),
        }
    )
    # The all-null column needs an explicit Boolean dtype in the parquet schema.
    logistic_predictions = logistic_predictions.with_columns(
        pl.col("predicted_boolean_value").cast(pl.Boolean())
    )
    logistic_test_predictions = logistic_dir / "test_predictions"
    logistic_test_predictions.mkdir(exist_ok=True, parents=True)
    logistic_predictions.write_parquet(
        logistic_test_predictions / "predictions.parquet"
    )

    roc_auc = roc_auc_score(test_dataset["boolean_value"], y_pred)
    precision, recall, _ = precision_recall_curve(
        test_dataset["boolean_value"], y_pred
    )
    pr_auc = auc(recall, precision)

    metrics = {"roc_auc": roc_auc, "pr_auc": pr_auc}
    print("Logistic:", features_data_dir.name, metrics)
    with open(logistic_test_result_file, "w") as f:
        json.dump(metrics, f, indent=4)
|
138
|
+
|
139
|
+
|
140
|
+
if __name__ == "__main__":
    # Command-line entry point: both options are mandatory and are handed to
    # main() as parsed.
    cli = argparse.ArgumentParser(
        description="Train logistic regression model with cehrgpt features"
    )
    required_options = (
        ("--features_data_dir", "Directory containing training and test feature files"),
        ("--output_dir", "Directory to save the output results"),
    )
    for option_name, option_help in required_options:
        cli.add_argument(option_name, required=True, help=option_help)
    main(cli.parse_args())
|
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: cehrgpt
|
3
|
-
Version: 0.0.2
|
3
|
+
Version: 0.1.1
|
4
4
|
Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
|
5
5
|
Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
|
6
6
|
License: MIT License
|
@@ -12,13 +12,15 @@ Classifier: Programming Language :: Python :: 3
|
|
12
12
|
Requires-Python: >=3.10.0
|
13
13
|
Description-Content-Type: text/markdown
|
14
14
|
License-File: LICENSE
|
15
|
-
Requires-Dist: cehrbert==1.
|
15
|
+
Requires-Dist: cehrbert==1.4.5
|
16
|
+
Requires-Dist: cehrbert_data==0.0.11
|
16
17
|
Requires-Dist: openai==1.54.3
|
17
18
|
Requires-Dist: optuna==4.0.0
|
18
|
-
Requires-Dist: transformers==4.
|
19
|
+
Requires-Dist: transformers==4.44.1
|
19
20
|
Requires-Dist: tokenizers==0.19.0
|
20
21
|
Requires-Dist: peft==0.10.0
|
21
|
-
Requires-Dist:
|
22
|
+
Requires-Dist: lightgbm
|
23
|
+
Requires-Dist: polars
|
22
24
|
Provides-Extra: dev
|
23
25
|
Requires-Dist: pre-commit; extra == "dev"
|
24
26
|
Requires-Dist: pytest; extra == "dev"
|
@@ -29,14 +31,15 @@ Requires-Dist: hypothesis; extra == "dev"
|
|
29
31
|
Requires-Dist: black; extra == "dev"
|
30
32
|
Provides-Extra: flash-attn
|
31
33
|
Requires-Dist: flash_attn; extra == "flash-attn"
|
34
|
+
Dynamic: license-file
|
32
35
|
|
33
36
|
# CEHRGPT
|
34
37
|
|
35
38
|
[](https://pypi.org/project/cehrgpt/)
|
36
39
|

|
37
|
-
[](https://github.com/knatarajan-lab/cehrgpt
|
39
|
-
[](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml)
|
41
|
+
[](https://github.com/knatarajan-lab/cehrgpt/blob/main/LICENSE)
|
42
|
+
[](https://github.com/knatarajan-lab/cehrgpt/graphs/contributors)
|
40
43
|
|
41
44
|
## Description
|
42
45
|
CEHRGPT is a synthetic data generation model developed to handle structured electronic health records (EHR) with enhanced privacy and reliability. It leverages state-of-the-art natural language processing techniques to create realistic, anonymized patient data that can be used for research and development without compromising patient privacy.
|