cehrgpt-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- __init__.py +0 -0
- cehrgpt/__init__.py +0 -0
- cehrgpt/analysis/__init__.py +0 -0
- cehrgpt/analysis/privacy/__init__.py +0 -0
- cehrgpt/analysis/privacy/attribute_inference.py +275 -0
- cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
- cehrgpt/analysis/privacy/member_inference.py +172 -0
- cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
- cehrgpt/analysis/privacy/reid_inference.py +407 -0
- cehrgpt/analysis/privacy/utils.py +255 -0
- cehrgpt/cehrgpt_args.py +142 -0
- cehrgpt/data/__init__.py +0 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
- cehrgpt/generation/__init__.py +0 -0
- cehrgpt/generation/chatgpt_generation.py +106 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
- cehrgpt/generation/omop_converter_batch.py +644 -0
- cehrgpt/generation/omop_entity.py +515 -0
- cehrgpt/gpt_utils.py +331 -0
- cehrgpt/models/__init__.py +0 -0
- cehrgpt/models/config.py +205 -0
- cehrgpt/models/hf_cehrgpt.py +1817 -0
- cehrgpt/models/hf_modeling_outputs.py +158 -0
- cehrgpt/models/pretrained_embeddings.py +82 -0
- cehrgpt/models/special_tokens.py +30 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
- cehrgpt/omop/__init__.py +0 -0
- cehrgpt/omop/condition_era.py +20 -0
- cehrgpt/omop/observation_period.py +43 -0
- cehrgpt/omop/omop_argparse.py +38 -0
- cehrgpt/omop/omop_table_builder.py +86 -0
- cehrgpt/omop/queries/__init__.py +0 -0
- cehrgpt/omop/queries/condition_era.py +86 -0
- cehrgpt/omop/queries/observation_period.py +135 -0
- cehrgpt/omop/sample_omop_tables.py +71 -0
- cehrgpt/runners/__init__.py +0 -0
- cehrgpt/runners/gpt_runner_util.py +99 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
- cehrgpt/runners/hyperparameter_search_util.py +223 -0
- cehrgpt/time_to_event/__init__.py +0 -0
- cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
- cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
- cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
- cehrgpt/time_to_event/time_to_event_model.py +226 -0
- cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
- cehrgpt/time_to_event/time_to_event_utils.py +55 -0
- cehrgpt/tools/__init__.py +0 -0
- cehrgpt/tools/ehrshot_benchmark.py +74 -0
- cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
- cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
- cehrgpt/tools/upload_omop_tables.py +108 -0
- cehrgpt-0.0.1.dist-info/LICENSE +21 -0
- cehrgpt-0.0.1.dist-info/METADATA +66 -0
- cehrgpt-0.0.1.dist-info/RECORD +60 -0
- cehrgpt-0.0.1.dist-info/WHEEL +5 -0
- cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
cehrgpt/analysis/privacy/member_inference.py
@@ -0,0 +1,172 @@
+import logging
+import os
+import random
+from datetime import datetime
+
+import pandas as pd
+from sklearn import metrics
+
+from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+
+from .utils import (
+    RANDOM_SEE,
+    create_demographics,
+    create_gender_encoder,
+    create_race_encoder,
+    create_vector_representations,
+    find_match,
+    scale_age,
+)
+
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+LOG = logging.getLogger("member_inference")
+THRESHOLD = 5
+
+
+def main(args):
+    LOG.info(f"Started loading tokenizer at {args.tokenizer_path}")
+    concept_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_path)
+
+    LOG.info(f"Started loading training data at {args.training_data_folder}")
+    train_data = pd.read_parquet(args.training_data_folder)
+
+    LOG.info(f"Started loading evaluation data at {args.evaluation_data_folder}")
+    evaluation_data = pd.read_parquet(args.evaluation_data_folder)
+
+    LOG.info(f"Started loading synthetic_data at {args.synthetic_data_folder}")
+    synthetic_data = pd.read_parquet(args.synthetic_data_folder)
+
+    LOG.info(
+        "Started extracting the demographic information from the patient sequences"
+    )
+    train_data = create_demographics(train_data)
+    evaluation_data = create_demographics(evaluation_data)
+    synthetic_data = create_demographics(synthetic_data)
+
+    LOG.info("Started rescaling age columns")
+    train_data = scale_age(train_data)
+    evaluation_data = scale_age(evaluation_data)
+    synthetic_data = scale_age(synthetic_data)
+
+    LOG.info("Started encoding gender")
+    gender_encoder = create_gender_encoder(train_data, evaluation_data, synthetic_data)
+    LOG.info("Completed encoding gender")
+
+    LOG.info("Started encoding race")
+    race_encoder = create_race_encoder(train_data, evaluation_data, synthetic_data)
+    LOG.info("Completed encoding race")
+
+    random.seed(RANDOM_SEE)
+    all_results = []
+    for i in range(1, args.n_iterations + 1):
+        dist_metrics = []
+        LOG.info(f"Iteration {i} Started")
+        train_data_sample = train_data.sample(args.num_of_samples)
+        evaluation_data_sample = evaluation_data.sample(args.num_of_samples)
+        synthetic_data_sample = synthetic_data.sample(args.num_of_samples)
+        LOG.info(f"Iteration {i}: Started creating train vectors")
+        train_vectors = create_vector_representations(
+            train_data_sample, concept_tokenizer, gender_encoder, race_encoder
+        )
+        LOG.info(f"Iteration {i}: Started creating evaluation vectors")
+        evaluation_vectors = create_vector_representations(
+            evaluation_data_sample, concept_tokenizer, gender_encoder, race_encoder
+        )
+        LOG.info(f"Iteration {i}: Started creating synthetic vectors")
+        synthetic_vectors = create_vector_representations(
+            synthetic_data_sample, concept_tokenizer, gender_encoder, race_encoder
+        )
+        LOG.info(
+            f"Iteration {i}: Started calculating the distances between synthetic and training vectors"
+        )
+        synthetic_train_dist = find_match(synthetic_vectors, train_vectors)
+        synthetic_evaluation_dist = find_match(synthetic_vectors, evaluation_vectors)
+
+        dist_metrics.extend([(_, 1) for _ in synthetic_train_dist])
+        dist_metrics.extend([(_, 0) for _ in synthetic_evaluation_dist])
+
+        metrics_pd = pd.DataFrame(dist_metrics, columns=["dist", "label"])
+        metrics_pd["pred"] = (metrics_pd.dist < THRESHOLD).astype(int)
+
+        results = {
+            "Iteration": i,
+            "Accuracy": metrics.accuracy_score(metrics_pd.label, metrics_pd.pred),
+            "Precision": metrics.precision_score(metrics_pd.label, metrics_pd.pred),
+            "Recall": metrics.recall_score(metrics_pd.label, metrics_pd.pred),
+            "F1": metrics.f1_score(metrics_pd.label, metrics_pd.pred),
+        }
+        all_results.append(results)
+        LOG.info(f"Iteration {i}: Privacy loss {results}")
+
+    current_time = datetime.now().strftime("%m-%d-%Y-%H-%M-%S")
+    pd.DataFrame(
+        all_results, columns=["Iteration", "Accuracy", "Precision", "Recall", "F1"]
+    ).to_parquet(
+        os.path.join(args.output_folder, f"membership_inference_{current_time}.parquet")
+    )
+
+
+def create_argparser():
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description="Membership Inference Analysis Arguments"
+    )
+    parser.add_argument(
+        "--training_data_folder",
+        dest="training_data_folder",
+        action="store",
+        help="The path for where the training data folder",
+        required=True,
+    )
+    parser.add_argument(
+        "--evaluation_data_folder",
+        dest="evaluation_data_folder",
+        action="store",
+        help="The path for where the evaluation data folder",
+        required=True,
+    )
+    parser.add_argument(
+        "--synthetic_data_folder",
+        dest="synthetic_data_folder",
+        action="store",
+        help="The path for where the synthetic data folder",
+        required=True,
+    )
+    parser.add_argument(
+        "--output_folder",
+        dest="output_folder",
+        action="store",
+        help="The output folder that stores the metrics",
+        required=True,
+    )
+    parser.add_argument(
+        "--tokenizer_path",
+        dest="tokenizer_path",
+        action="store",
+        help="The path to ConceptTokenizer",
+        required=True,
+    )
+    parser.add_argument(
+        "--num_of_samples",
+        dest="num_of_samples",
+        action="store",
+        type=int,
+        required=False,
+        default=5000,
+    )
+    parser.add_argument(
+        "--n_iterations",
+        dest="n_iterations",
+        action="store",
+        type=int,
+        required=False,
+        default=1,
+    )
+    return parser
+
+
+if __name__ == "__main__":
+    main(create_argparser().parse_args())
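The membership inference script above reduces to a distance-threshold rule: `find_match` appears to return, for each synthetic vector, the distance to its nearest neighbor in the compared set (its implementation lives in the `utils.py` listed above, not in this hunk), and a synthetic record is predicted to be a training-set member when that distance is below `THRESHOLD`. Below is a minimal, self-contained sketch of that decision rule, with a scikit-learn `NearestNeighbors` stand-in for `find_match`; the stand-in is an assumption, not the package's implementation.

# Illustrative sketch only: utils.find_match is not part of this hunk, so this
# NearestNeighbors-based stand-in is an assumption about its behavior.
import numpy as np
from sklearn.neighbors import NearestNeighbors

THRESHOLD = 5  # same cut-off used by member_inference.py


def nearest_distances(queries: np.ndarray, references: np.ndarray) -> np.ndarray:
    """Distance from each query vector to its nearest reference vector."""
    nn = NearestNeighbors(n_neighbors=1).fit(references)
    distances, _ = nn.kneighbors(queries)
    return distances.ravel()


rng = np.random.default_rng(0)
synthetic_vectors = rng.normal(size=(100, 16))
train_vectors = rng.normal(size=(100, 16))

# A synthetic record is flagged as a training-set "member" when its nearest
# training vector lies closer than THRESHOLD.
predictions = (nearest_distances(synthetic_vectors, train_vectors) < THRESHOLD).astype(int)
print(predictions.mean())  # fraction of synthetic records flagged as members

With the package installed, the script itself can be invoked as a module, e.g. `python -m cehrgpt.analysis.privacy.member_inference --training_data_folder ... --evaluation_data_folder ... --synthetic_data_folder ... --output_folder ... --tokenizer_path ...`, with `--num_of_samples` and `--n_iterations` optional.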
cehrgpt/analysis/privacy/nearest_neighbor_inference.py
@@ -0,0 +1,189 @@
+import logging
+import os
+import random
+from datetime import datetime
+
+import numpy as np
+import pandas as pd
+
+from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+
+from .utils import (
+    RANDOM_SEE,
+    create_demographics,
+    create_gender_encoder,
+    create_race_encoder,
+    create_vector_representations,
+    find_match,
+    find_match_self,
+    scale_age,
+)
+
+logging.basicConfig(level=logging.INFO)
+LOG = logging.getLogger("NNAA")
+formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+
+
+def main(args):
+    LOG.info(f"Started loading tokenizer at {args.concept_tokenizer_path}")
+    concept_tokenizer = CehrGptTokenizer.from_pretrained(args.concept_tokenizer_path)
+
+    LOG.info(f"Started loading training data at {args.training_data_folder}")
+    train_data = pd.read_parquet(args.training_data_folder)
+
+    LOG.info(f"Started loading evaluation data at {args.evaluation_data_folder}")
+    evaluation_data = pd.read_parquet(args.evaluation_data_folder)
+
+    LOG.info(f"Started loading synthetic_data at {args.synthetic_data_folder}")
+    synthetic_data = pd.read_parquet(args.synthetic_data_folder)
+
+    LOG.info(
+        "Started extracting the demographic information from the patient sequences"
+    )
+    train_data = create_demographics(train_data)
+    evaluation_data = create_demographics(evaluation_data)
+    synthetic_data = create_demographics(synthetic_data)
+
+    LOG.info("Started rescaling age columns")
+    train_data = scale_age(train_data)
+    evaluation_data = scale_age(evaluation_data)
+    synthetic_data = scale_age(synthetic_data)
+
+    LOG.info("Started encoding gender")
+    gender_encoder = create_gender_encoder(train_data, evaluation_data, synthetic_data)
+    LOG.info("Completed encoding gender")
+
+    LOG.info("Started encoding race")
+    race_encoder = create_race_encoder(train_data, evaluation_data, synthetic_data)
+    LOG.info("Completed encoding race")
+
+    random.seed(RANDOM_SEE)
+    metrics = []

+    for i in range(args.n_iterations):
+        LOG.info(f"Iteration {i} Started")
+        train_data_sample = train_data.sample(args.num_of_samples)
+        evaluation_data_sample = evaluation_data.sample(args.num_of_samples)
+        synthetic_data_sample = synthetic_data.sample(args.num_of_samples)
+        LOG.info(f"Iteration {i}: Started creating train vectors")
+        train_vectors = create_vector_representations(
+            train_data_sample, concept_tokenizer, gender_encoder, race_encoder
+        )
+        LOG.info(f"Iteration {i}: Started creating evaluation vectors")
+        evaluation_vectors = create_vector_representations(
+            evaluation_data_sample, concept_tokenizer, gender_encoder, race_encoder
+        )
+        LOG.info(f"Iteration {i}: Started creating synthetic vectors")
+        synthetic_vectors = create_vector_representations(
+            synthetic_data_sample, concept_tokenizer, gender_encoder, race_encoder
+        )
+        LOG.info(
+            f"Iteration {i}: Started calculating the distances between synthetic and training vectors"
+        )
+        distance_train_TS = find_match(train_vectors, synthetic_vectors)
+        distance_train_ST = find_match(synthetic_vectors, train_vectors)
+        distance_train_TT = find_match_self(train_vectors, train_vectors)
+        distance_train_SS = find_match_self(synthetic_vectors, synthetic_vectors)
+
+        aa_train = (
+            (
+                np.sum(distance_train_TS > distance_train_TT)
+                + np.sum(distance_train_ST > distance_train_SS)
+            )
+            / args.num_of_samples
+            / 2
+        )
+
+        LOG.info(
+            f"Iteration {i}: Started calculating the distances between synthetic and evaluation vectors"
+        )
+        distance_test_TS = find_match(evaluation_vectors, synthetic_vectors)
+        distance_test_ST = find_match(synthetic_vectors, evaluation_vectors)
+        distance_test_TT = find_match_self(evaluation_vectors, evaluation_vectors)
+        distance_test_SS = find_match_self(synthetic_vectors, synthetic_vectors)
+
+        aa_test = (
+            (
+                np.sum(distance_test_TS > distance_test_TT)
+                + np.sum(distance_test_ST > distance_test_SS)
+            )
+            / args.num_of_samples
+            / 2
+        )
+
+        privacy_loss = aa_test - aa_train
+        metrics.append(privacy_loss)
+        LOG.info(f"Iteration {i}: Privacy loss {privacy_loss}")
+
+    results = {"NNAAE": metrics}
+
+    current_time = datetime.now().strftime("%m-%d-%Y-%H-%M-%S")
+    pd.DataFrame([results], columns=["NNAAE"]).to_parquet(
+        os.path.join(
+            args.metrics_folder, f"nearest_neighbor_inference_{current_time}.parquet"
+        )
+    )
+
+
+def create_argparser():
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description="Nearest Neighbor Inference Analysis Arguments using the GPT model"
+    )
+    parser.add_argument(
+        "--training_data_folder",
+        dest="training_data_folder",
+        action="store",
+        help="The path for where the training data folder",
+        required=True,
+    )
+    parser.add_argument(
+        "--evaluation_data_folder",
+        dest="evaluation_data_folder",
+        action="store",
+        help="The path for where the evaluation data folder",
+        required=True,
+    )
+    parser.add_argument(
+        "--synthetic_data_folder",
+        dest="synthetic_data_folder",
+        action="store",
+        help="The path for where the synthetic data folder",
+        required=True,
+    )
+    parser.add_argument(
+        "--concept_tokenizer_path",
+        dest="concept_tokenizer_path",
+        action="store",
+        help="The path for where the concept tokenizer is located",
+        required=True,
+    )
+    parser.add_argument(
+        "--num_of_samples",
+        dest="num_of_samples",
+        action="store",
+        type=int,
+        required=False,
+        default=5000,
+    )
+    parser.add_argument(
+        "--n_iterations",
+        dest="n_iterations",
+        action="store",
+        type=int,
+        required=False,
+        default=1,
+    )
+    parser.add_argument(
+        "--metrics_folder",
+        dest="metrics_folder",
+        action="store",
+        help="The folder that stores the metrics",
+        required=False,
+    )
+    return parser
+
+
+if __name__ == "__main__":
+    main(create_argparser().parse_args())
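The quantity computed in this script is the nearest neighbor adversarial accuracy (NNAA): for each set it counts how often a record's nearest neighbor in the other set is farther away than its nearest neighbor (excluding itself) within its own set, and the reported privacy loss is the gap `aa_test - aa_train`. Values of `aa_train` near 0.5 and a privacy loss near zero suggest the synthetic records sit no closer to the training data than to held-out data. The sketch below works through the `aa_train` term numerically; it assumes `find_match(A, B)` returns nearest cross-set distances and `find_match_self(A, A)` returns nearest non-self within-set distances, since both helpers are defined in `utils.py` rather than in this hunk.

# Illustrative sketch of the adversarial-accuracy term aa_train; nearest_cross
# and nearest_self are assumed stand-ins for the package's find_match /
# find_match_self helpers, which are not shown in this diff.
import numpy as np
from scipy.spatial.distance import cdist


def nearest_cross(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Distance from each row of a to its nearest row in b."""
    return cdist(a, b).min(axis=1)


def nearest_self(a: np.ndarray) -> np.ndarray:
    """Distance from each row of a to its nearest *other* row in a."""
    d = cdist(a, a)
    np.fill_diagonal(d, np.inf)  # a point is not its own neighbor
    return d.min(axis=1)


rng = np.random.default_rng(0)
train_vectors = rng.normal(size=(200, 16))
synthetic_vectors = rng.normal(size=(200, 16))

num_of_samples = len(train_vectors)
aa_train = (
    np.sum(nearest_cross(train_vectors, synthetic_vectors) > nearest_self(train_vectors))
    + np.sum(nearest_cross(synthetic_vectors, train_vectors) > nearest_self(synthetic_vectors))
) / num_of_samples / 2
print(aa_train)  # ~0.5 when synthetic and training vectors are statistically alike

The same computation with the evaluation vectors gives `aa_test`, and the script stores the per-iteration differences under the `NNAAE` column of the output parquet file.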