cehrgpt-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- __init__.py +0 -0
- cehrgpt/__init__.py +0 -0
- cehrgpt/analysis/__init__.py +0 -0
- cehrgpt/analysis/privacy/__init__.py +0 -0
- cehrgpt/analysis/privacy/attribute_inference.py +275 -0
- cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
- cehrgpt/analysis/privacy/member_inference.py +172 -0
- cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
- cehrgpt/analysis/privacy/reid_inference.py +407 -0
- cehrgpt/analysis/privacy/utils.py +255 -0
- cehrgpt/cehrgpt_args.py +142 -0
- cehrgpt/data/__init__.py +0 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
- cehrgpt/generation/__init__.py +0 -0
- cehrgpt/generation/chatgpt_generation.py +106 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
- cehrgpt/generation/omop_converter_batch.py +644 -0
- cehrgpt/generation/omop_entity.py +515 -0
- cehrgpt/gpt_utils.py +331 -0
- cehrgpt/models/__init__.py +0 -0
- cehrgpt/models/config.py +205 -0
- cehrgpt/models/hf_cehrgpt.py +1817 -0
- cehrgpt/models/hf_modeling_outputs.py +158 -0
- cehrgpt/models/pretrained_embeddings.py +82 -0
- cehrgpt/models/special_tokens.py +30 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
- cehrgpt/omop/__init__.py +0 -0
- cehrgpt/omop/condition_era.py +20 -0
- cehrgpt/omop/observation_period.py +43 -0
- cehrgpt/omop/omop_argparse.py +38 -0
- cehrgpt/omop/omop_table_builder.py +86 -0
- cehrgpt/omop/queries/__init__.py +0 -0
- cehrgpt/omop/queries/condition_era.py +86 -0
- cehrgpt/omop/queries/observation_period.py +135 -0
- cehrgpt/omop/sample_omop_tables.py +71 -0
- cehrgpt/runners/__init__.py +0 -0
- cehrgpt/runners/gpt_runner_util.py +99 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
- cehrgpt/runners/hyperparameter_search_util.py +223 -0
- cehrgpt/time_to_event/__init__.py +0 -0
- cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
- cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
- cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
- cehrgpt/time_to_event/time_to_event_model.py +226 -0
- cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
- cehrgpt/time_to_event/time_to_event_utils.py +55 -0
- cehrgpt/tools/__init__.py +0 -0
- cehrgpt/tools/ehrshot_benchmark.py +74 -0
- cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
- cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
- cehrgpt/tools/upload_omop_tables.py +108 -0
- cehrgpt-0.0.1.dist-info/LICENSE +21 -0
- cehrgpt-0.0.1.dist-info/METADATA +66 -0
- cehrgpt-0.0.1.dist-info/RECORD +60 -0
- cehrgpt-0.0.1.dist-info/WHEEL +5 -0
- cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
__init__.py
ADDED
File without changes
cehrgpt/__init__.py
ADDED
File without changes
cehrgpt/analysis/__init__.py
ADDED
File without changes
cehrgpt/analysis/privacy/__init__.py
ADDED
File without changes
cehrgpt/analysis/privacy/attribute_inference.py
ADDED
@@ -0,0 +1,275 @@
+import logging
+import os
+import random
+import sys
+from datetime import datetime
+
+import numpy as np
+import pandas as pd
+import yaml
+
+from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+
+from .utils import (
+    batched_pairwise_euclidean_distance_indices,
+    create_demographics,
+    create_gender_encoder,
+    create_race_encoder,
+    create_vector_representations_for_attribute,
+    find_match,
+    find_match_self,
+    scale_age,
+)
+
+RANDOM_SEED = 42
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+LOG = logging.getLogger("attribute_inference")
+
+
+def safe_divide(numerator_vector, denominator_vector):
+    # Element-wise division that returns 0 wherever the denominator is 0.
+    return np.where(denominator_vector > 0, numerator_vector / denominator_vector, 0)
+
+
+def cal_f1_score(vector_a, vector_b, index_matrix):
+    # vector_a is the train data and vector_b is the synthetic data (or the
+    # train data itself); index_matrix maps each row of vector_a to its
+    # nearest-neighbor row in vector_b.
+    matched_a = vector_a[: len(index_matrix)]
+    matched_b = vector_b[index_matrix]
+    shared_vector = np.logical_and(matched_a, matched_b).astype(int)
+    shared_vector_sum = np.sum(shared_vector, axis=1)
+
+    precision = safe_divide(shared_vector_sum, np.sum(matched_b, axis=1))
+    recall = safe_divide(shared_vector_sum, np.sum(matched_a, axis=1))
+
+    f1 = safe_divide(2 * recall * precision, recall + precision)
+    return f1, precision, recall
+
+
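For intuition, cal_f1_score above compares each training patient's multi-hot sensitive-attribute vector against the vector of its matched synthetic neighbor. A minimal illustration of the same precision/recall/F1 arithmetic on toy arrays (the data below is made up for demonstration and is not from the package):

import numpy as np

# Toy multi-hot sensitive-attribute vectors: 3 train rows, 3 synthetic rows.
train = np.array([[1, 1, 0, 0],
                  [0, 1, 1, 0],
                  [1, 0, 0, 1]])
synthetic = np.array([[1, 1, 0, 0],
                      [0, 1, 0, 0],
                      [0, 0, 0, 1]])
# Pretend nearest-neighbor matching paired train row i with synthetic row i.
index = np.array([0, 1, 2])

shared = np.logical_and(train, synthetic[index]).astype(int).sum(axis=1)
precision = shared / synthetic[index].sum(axis=1)  # guessed attributes that are real
recall = shared / train.sum(axis=1)                # real attributes that were guessed
f1 = 2 * precision * recall / (precision + recall)
print(f1)  # approximately [1.0, 0.67, 0.67]

A high F1 between synthetic matches and training records (relative to the train-vs-train baseline computed below) would suggest the synthetic data leaks sensitive attributes of real patients.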
+def main(args):
+    try:
+        with open(args.attribute_config, "r") as file:
+            data = yaml.safe_load(file)
+        common_attributes = data.get("common_attributes", [])
+        sensitive_attributes = data.get("sensitive_attributes", [])
+    except (FileNotFoundError, PermissionError, OSError) as e:
+        sys.exit(e)
+
+    attribute_inference_folder = os.path.join(args.output_folder, "attribute_inference")
+    if not os.path.exists(attribute_inference_folder):
+        LOG.info(
+            f"Creating the attribute_inference output folder at {attribute_inference_folder}"
+        )
+        os.makedirs(attribute_inference_folder, exist_ok=True)
+
+    LOG.info(f"Started loading tokenizer at {args.tokenizer_path}")
+    concept_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_path)
+
+    LOG.info(f"Started loading training data at {args.training_data_folder}")
+    train_data = pd.read_parquet(args.training_data_folder)
+
+    LOG.info(f"Started loading synthetic data at {args.synthetic_data_folder}")
+    synthetic_data = pd.read_parquet(args.synthetic_data_folder)
+
+    LOG.info(
+        "Started extracting the demographic information from the patient sequences"
+    )
+    train_data = create_demographics(train_data)
+    synthetic_data = create_demographics(synthetic_data)
+
+    LOG.info("Started rescaling age columns")
+    train_data = scale_age(train_data)
+    synthetic_data = scale_age(synthetic_data)
+
+    LOG.info("Started encoding gender")
+    gender_encoder = create_gender_encoder(
+        train_data,
+        # TODO need to change this function to be generic
+        train_data[:10],
+        synthetic_data,
+    )
+    LOG.info("Completed encoding gender")
+
+    LOG.info("Started encoding race")
+    race_encoder = create_race_encoder(
+        train_data,
+        # TODO need to change this function to be generic
+        train_data[:10],
+        synthetic_data,
+    )
+    LOG.info("Completed encoding race")
+
+    random.seed(RANDOM_SEED)
+    for i in range(1, args.n_iterations + 1):
+        LOG.info(f"Iteration {i}: Started creating data samples")
+        train_data_sample = train_data.sample(args.num_of_samples)
+        synthetic_data_sample = synthetic_data.sample(args.num_of_samples)
+        LOG.info(f"Iteration {i}: Started creating train sample vectors")
+        train_common_vectors, train_sensitive_vectors = (
+            create_vector_representations_for_attribute(
+                train_data_sample,
+                concept_tokenizer,
+                gender_encoder,
+                race_encoder,
+                common_attributes=common_attributes,
+                sensitive_attributes=sensitive_attributes,
+            )
+        )
+
+        LOG.info(f"Iteration {i}: Started creating synthetic vectors")
+        synthetic_common_vectors, synthetic_sensitive_vectors = (
+            create_vector_representations_for_attribute(
+                synthetic_data_sample,
+                concept_tokenizer,
+                gender_encoder,
+                race_encoder,
+                common_attributes=common_attributes,
+                sensitive_attributes=sensitive_attributes,
+            )
+        )
+
+        LOG.info(
+            "Started calculating the distances between synthetic and training vectors"
+        )
+        if args.batched:
+            train_synthetic_index = batched_pairwise_euclidean_distance_indices(
+                train_common_vectors,
+                synthetic_common_vectors,
+                batch_size=args.batch_size,
+            )
+            train_train_index = batched_pairwise_euclidean_distance_indices(
+                train_common_vectors,
+                train_common_vectors,
+                batch_size=args.batch_size,
+                self_exclude=True,
+            )
+        else:
+            train_synthetic_index = find_match(
+                train_common_vectors, synthetic_common_vectors, return_index=True
+            )
+            train_train_index = find_match_self(
+                train_common_vectors, train_common_vectors, return_index=True
+            )
+
+        f1_syn_train, precision_syn_train, recall_syn_train = cal_f1_score(
+            train_sensitive_vectors, synthetic_sensitive_vectors, train_synthetic_index
+        )
+        f1_train_train, precision_train_train, recall_train_train = cal_f1_score(
+            train_sensitive_vectors, train_sensitive_vectors, train_train_index
+        )
+
+        results = {
+            "Precision Synthetic Train": precision_syn_train,
+            "Recall Synthetic Train": recall_syn_train,
+            "F1 Synthetic Train": f1_syn_train,
+            "Precision Train Train": precision_train_train,
+            "Recall Train Train": recall_train_train,
+            "F1 Train Train": f1_train_train,
+        }
+        LOG.info(
+            f"Attribute Inference: Average Precision Synthetic Train: {np.mean(precision_syn_train)} \n"
+            f"Attribute Inference: Average Recall Synthetic Train: {np.mean(recall_syn_train)} \n"
+            f"Attribute Inference: Average F1 Synthetic Train: {np.mean(f1_syn_train)} \n"
+            f"Attribute Inference: Average Precision Train Train: {np.mean(precision_train_train)} \n"
+            f"Attribute Inference: Average Recall Train Train: {np.mean(recall_train_train)} \n"
+            f"Attribute Inference: Average F1 Train Train: {np.mean(f1_train_train)}"
+        )
+        current_time = datetime.now().strftime("%m-%d-%Y-%H-%M-%S")
+        pd.DataFrame(
+            results,
+            columns=[
+                "Precision Synthetic Train",
+                "Recall Synthetic Train",
+                "F1 Synthetic Train",
+                "Precision Train Train",
+                "Recall Train Train",
+                "F1 Train Train",
+            ],
+        ).to_parquet(
+            os.path.join(
+                attribute_inference_folder,
+                f"attribute_inference_{current_time}.parquet",
+            )
+        )
+
+
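The matching helpers used above (batched_pairwise_euclidean_distance_indices, find_match, find_match_self) live in cehrgpt/analysis/privacy/utils.py and are not shown in this diff. A rough sketch of what a batched Euclidean nearest-neighbor lookup with these parameters could look like, assuming (not confirmed by this diff) that the helper returns, for each training row, the index of its closest counterpart:

import numpy as np

def nearest_indices(a, b, batch_size=1000, self_exclude=False):
    """For each row of `a`, return the index of the closest row of `b` by
    Euclidean distance, processing `a` in batches to bound memory use."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    indices = []
    for start in range(0, len(a), batch_size):
        batch = a[start : start + batch_size]
        # Squared Euclidean distances between the batch and all rows of b,
        # via the expansion |x - y|^2 = |x|^2 - 2<x, y> + |y|^2.
        dists = (
            np.sum(batch**2, axis=1, keepdims=True)
            - 2 * batch @ b.T
            + np.sum(b**2, axis=1)
        )
        if self_exclude:
            # When matching a set against itself, mask the trivial self-match.
            rows = np.arange(start, start + len(batch))
            dists[np.arange(len(batch)), rows] = np.inf
        indices.append(np.argmin(dists, axis=1))
    return np.concatenate(indices)

Batching keeps the distance matrix at batch_size x len(b) instead of materializing the full pairwise matrix, which is what makes the --batched path usable on large cohorts.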
+def create_argparser():
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description="Attribute Inference Analysis Arguments"
+    )
+    parser.add_argument(
+        "--training_data_folder",
+        dest="training_data_folder",
+        action="store",
+        help="The path to the training data folder",
+        required=True,
+    )
+    parser.add_argument(
+        "--synthetic_data_folder",
+        dest="synthetic_data_folder",
+        action="store",
+        help="The path to the synthetic data folder",
+        required=True,
+    )
+    parser.add_argument(
+        "--output_folder",
+        dest="output_folder",
+        action="store",
+        help="The output folder that stores the metrics",
+        required=True,
+    )
+    parser.add_argument(
+        "--tokenizer_path",
+        dest="tokenizer_path",
+        action="store",
+        help="The path to the CehrGptTokenizer",
+        required=True,
+    )
+    parser.add_argument(
+        "--attribute_config",
+        dest="attribute_config",
+        action="store",
+        help="The configuration yaml file for common and sensitive attributes",
+        required=True,
+    )
+    parser.add_argument(
+        "--batch_size",
+        dest="batch_size",
+        action="store",
+        type=int,
+        default=1000,
+        help="The batch size of the matching algorithm",
+        required=False,
+    )
+    parser.add_argument(
+        "--batched",
+        dest="batched",
+        action="store_true",
+        help="Whether to use the batched matrix operation",
+    )
+    parser.add_argument(
+        "--num_of_samples",
+        dest="num_of_samples",
+        action="store",
+        type=int,
+        required=False,
+        default=5000,
+    )
+    parser.add_argument(
+        "--n_iterations",
+        dest="n_iterations",
+        action="store",
+        type=int,
+        required=False,
+        default=1,
+    )
+    return parser
+
+
+if __name__ == "__main__":
+    main(create_argparser().parse_args())
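A hypothetical way to drive this script programmatically; every path below is a placeholder, and the flags match the parser defined above:

# Hypothetical invocation of the attribute inference analysis; the folder
# and file paths here are placeholders, not paths shipped with the package.
from cehrgpt.analysis.privacy.attribute_inference import create_argparser, main

main(create_argparser().parse_args([
    "--training_data_folder", "/data/train_sequences",
    "--synthetic_data_folder", "/data/synthetic_sequences",
    "--output_folder", "/data/privacy_metrics",
    "--tokenizer_path", "/models/cehrgpt_tokenizer",
    "--attribute_config", "attribute_inference_config.yml",
    "--batched",                  # use the batched nearest-neighbor path
    "--num_of_samples", "5000",
    "--n_iterations", "5",
]))

Each iteration samples num_of_samples patients from both datasets and writes one timestamped parquet file of per-patient precision/recall/F1 scores under output_folder/attribute_inference.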