cehrgpt 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60) hide show
  1. __init__.py +0 -0
  2. cehrgpt/__init__.py +0 -0
  3. cehrgpt/analysis/__init__.py +0 -0
  4. cehrgpt/analysis/privacy/__init__.py +0 -0
  5. cehrgpt/analysis/privacy/attribute_inference.py +275 -0
  6. cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
  7. cehrgpt/analysis/privacy/member_inference.py +172 -0
  8. cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
  9. cehrgpt/analysis/privacy/reid_inference.py +407 -0
  10. cehrgpt/analysis/privacy/utils.py +255 -0
  11. cehrgpt/cehrgpt_args.py +142 -0
  12. cehrgpt/data/__init__.py +0 -0
  13. cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
  14. cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
  15. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
  16. cehrgpt/generation/__init__.py +0 -0
  17. cehrgpt/generation/chatgpt_generation.py +106 -0
  18. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
  19. cehrgpt/generation/omop_converter_batch.py +644 -0
  20. cehrgpt/generation/omop_entity.py +515 -0
  21. cehrgpt/gpt_utils.py +331 -0
  22. cehrgpt/models/__init__.py +0 -0
  23. cehrgpt/models/config.py +205 -0
  24. cehrgpt/models/hf_cehrgpt.py +1817 -0
  25. cehrgpt/models/hf_modeling_outputs.py +158 -0
  26. cehrgpt/models/pretrained_embeddings.py +82 -0
  27. cehrgpt/models/special_tokens.py +30 -0
  28. cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
  29. cehrgpt/omop/__init__.py +0 -0
  30. cehrgpt/omop/condition_era.py +20 -0
  31. cehrgpt/omop/observation_period.py +43 -0
  32. cehrgpt/omop/omop_argparse.py +38 -0
  33. cehrgpt/omop/omop_table_builder.py +86 -0
  34. cehrgpt/omop/queries/__init__.py +0 -0
  35. cehrgpt/omop/queries/condition_era.py +86 -0
  36. cehrgpt/omop/queries/observation_period.py +135 -0
  37. cehrgpt/omop/sample_omop_tables.py +71 -0
  38. cehrgpt/runners/__init__.py +0 -0
  39. cehrgpt/runners/gpt_runner_util.py +99 -0
  40. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
  41. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
  42. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
  43. cehrgpt/runners/hyperparameter_search_util.py +223 -0
  44. cehrgpt/time_to_event/__init__.py +0 -0
  45. cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
  46. cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
  47. cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
  48. cehrgpt/time_to_event/time_to_event_model.py +226 -0
  49. cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
  50. cehrgpt/time_to_event/time_to_event_utils.py +55 -0
  51. cehrgpt/tools/__init__.py +0 -0
  52. cehrgpt/tools/ehrshot_benchmark.py +74 -0
  53. cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
  54. cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
  55. cehrgpt/tools/upload_omop_tables.py +108 -0
  56. cehrgpt-0.0.1.dist-info/LICENSE +21 -0
  57. cehrgpt-0.0.1.dist-info/METADATA +66 -0
  58. cehrgpt-0.0.1.dist-info/RECORD +60 -0
  59. cehrgpt-0.0.1.dist-info/WHEEL +5 -0
  60. cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,172 @@
1
+ import logging
2
+ import os
3
+ import random
4
+ from datetime import datetime
5
+
6
+ import pandas as pd
7
+ from sklearn import metrics
8
+
9
+ from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
10
+
11
+ from .utils import (
12
+ RANDOM_SEE,
13
+ create_demographics,
14
+ create_gender_encoder,
15
+ create_race_encoder,
16
+ create_vector_representations,
17
+ find_match,
18
+ scale_age,
19
+ )
20
+
21
# Root logging is configured at import time; every record is rendered as
# "<timestamp> - <logger name> - <level> - <message>".
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module logger used by main().
LOG = logging.getLogger("member_inference")

# Distance cut-off used by main(): a match whose distance is strictly below
# this value is predicted to be a member of the training set.
THRESHOLD = 5
26
+
27
+
28
def main(args):
    """Run the membership-inference evaluation and save per-iteration metrics.

    For every iteration, ``args.num_of_samples`` rows are sampled from each of
    the training, evaluation and synthetic datasets and converted to vectors.
    The distances returned by ``find_match(synthetic, train)`` are labelled 1
    and those returned by ``find_match(synthetic, evaluation)`` are labelled
    0; a distance strictly below ``THRESHOLD`` is predicted as label 1.
    Accuracy, precision, recall and F1 of those predictions are collected per
    iteration and written to a timestamped parquet file in
    ``args.output_folder``.

    Args:
        args: Parsed command-line namespace providing ``tokenizer_path``,
            ``training_data_folder``, ``evaluation_data_folder``,
            ``synthetic_data_folder``, ``output_folder``, ``num_of_samples``
            and ``n_iterations`` (see ``create_argparser``).
    """
    # Create the destination before the expensive work so the final write
    # cannot fail merely because the folder is missing.  An empty path means
    # "current directory" for os.path.join and needs no creation.
    if args.output_folder:
        os.makedirs(args.output_folder, exist_ok=True)

    LOG.info(f"Started loading tokenizer at {args.tokenizer_path}")
    concept_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_path)

    LOG.info(f"Started loading training data at {args.training_data_folder}")
    train_data = pd.read_parquet(args.training_data_folder)

    LOG.info(f"Started loading evaluation data at {args.evaluation_data_folder}")
    evaluation_data = pd.read_parquet(args.evaluation_data_folder)

    LOG.info(f"Started loading synthetic_data at {args.synthetic_data_folder}")
    synthetic_data = pd.read_parquet(args.synthetic_data_folder)

    LOG.info(
        "Started extracting the demographic information from the patient sequences"
    )
    train_data = create_demographics(train_data)
    evaluation_data = create_demographics(evaluation_data)
    synthetic_data = create_demographics(synthetic_data)

    LOG.info("Started rescaling age columns")
    train_data = scale_age(train_data)
    evaluation_data = scale_age(evaluation_data)
    synthetic_data = scale_age(synthetic_data)

    LOG.info("Started encoding gender")
    gender_encoder = create_gender_encoder(train_data, evaluation_data, synthetic_data)
    LOG.info("Completed encoding gender")

    LOG.info("Started encoding race")
    race_encoder = create_race_encoder(train_data, evaluation_data, synthetic_data)
    LOG.info("Completed encoding race")

    random.seed(RANDOM_SEE)
    all_results = []
    for i in range(1, args.n_iterations + 1):
        LOG.info(f"Iteration {i} Started")
        # DataFrame.sample draws from NumPy's random state, which
        # random.seed() does not touch.  Deriving an explicit 32-bit seed from
        # the seeded ``random`` module makes the sampled rows repeatable
        # across runs while still differing between iterations and datasets.
        train_data_sample = train_data.sample(
            args.num_of_samples, random_state=random.getrandbits(32)
        )
        evaluation_data_sample = evaluation_data.sample(
            args.num_of_samples, random_state=random.getrandbits(32)
        )
        synthetic_data_sample = synthetic_data.sample(
            args.num_of_samples, random_state=random.getrandbits(32)
        )
        LOG.info(f"Iteration {i}: Started creating train vectors")
        train_vectors = create_vector_representations(
            train_data_sample, concept_tokenizer, gender_encoder, race_encoder
        )
        LOG.info(f"Iteration {i}: Started creating evaluation vectors")
        evaluation_vectors = create_vector_representations(
            evaluation_data_sample, concept_tokenizer, gender_encoder, race_encoder
        )
        LOG.info(f"Iteration {i}: Started creating synthetic vectors")
        synthetic_vectors = create_vector_representations(
            synthetic_data_sample, concept_tokenizer, gender_encoder, race_encoder
        )
        LOG.info(
            f"Iteration {i}: Started calculating the distances between synthetic and training vectors"
        )
        synthetic_train_dist = find_match(synthetic_vectors, train_vectors)
        synthetic_evaluation_dist = find_match(synthetic_vectors, evaluation_vectors)

        # Ground truth: 1 for distances measured against the training sample,
        # 0 for distances measured against the held-out evaluation sample.
        dist_metrics = [(dist, 1) for dist in synthetic_train_dist]
        dist_metrics.extend((dist, 0) for dist in synthetic_evaluation_dist)

        metrics_pd = pd.DataFrame(dist_metrics, columns=["dist", "label"])
        # The attack predicts "member" whenever the distance is below THRESHOLD.
        metrics_pd["pred"] = (metrics_pd.dist < THRESHOLD).astype(int)

        results = {
            "Iteration": i,
            "Accuracy": metrics.accuracy_score(metrics_pd.label, metrics_pd.pred),
            "Precision": metrics.precision_score(metrics_pd.label, metrics_pd.pred),
            "Recall": metrics.recall_score(metrics_pd.label, metrics_pd.pred),
            "F1": metrics.f1_score(metrics_pd.label, metrics_pd.pred),
        }
        all_results.append(results)
        LOG.info(f"Iteration {i}: Privacy loss {results}")

    # The timestamp in the file name keeps successive runs from overwriting
    # each other.
    current_time = datetime.now().strftime("%m-%d-%Y-%H-%M-%S")
    pd.DataFrame(
        all_results, columns=["Iteration", "Accuracy", "Precision", "Recall", "F1"]
    ).to_parquet(
        os.path.join(args.output_folder, f"membership_inference_{current_time}.parquet")
    )
109
+
110
+
111
def create_argparser():
    """Build the command-line parser for the membership-inference analysis.

    Returns:
        argparse.ArgumentParser: Parser exposing the data/tokenizer/output
        paths (all required) plus the optional ``--num_of_samples``
        (default 5000) and ``--n_iterations`` (default 1), both of which must
        be positive integers.
    """
    import argparse

    def _positive_int(value):
        """Parse *value* as an int and reject anything smaller than 1."""
        try:
            number = int(value)
        except ValueError:
            raise argparse.ArgumentTypeError(
                f"invalid int value: {value!r}"
            ) from None
        if number < 1:
            raise argparse.ArgumentTypeError(
                f"must be a positive integer, got {number}"
            )
        return number

    parser = argparse.ArgumentParser(
        description="Membership Inference Analysis Arguments"
    )
    parser.add_argument(
        "--training_data_folder",
        dest="training_data_folder",
        action="store",
        help="The path for where the training data folder",
        required=True,
    )
    parser.add_argument(
        "--evaluation_data_folder",
        dest="evaluation_data_folder",
        action="store",
        help="The path for where the evaluation data folder",
        required=True,
    )
    parser.add_argument(
        "--synthetic_data_folder",
        dest="synthetic_data_folder",
        action="store",
        help="The path for where the synthetic data folder",
        required=True,
    )
    parser.add_argument(
        "--output_folder",
        dest="output_folder",
        action="store",
        help="The output folder that stores the metrics",
        required=True,
    )
    parser.add_argument(
        "--tokenizer_path",
        dest="tokenizer_path",
        action="store",
        help="The path to ConceptTokenizer",
        required=True,
    )
    # Zero or negative values would only surface as errors (or empty results)
    # deep inside main(), so they are rejected at parse time instead.
    parser.add_argument(
        "--num_of_samples",
        dest="num_of_samples",
        action="store",
        type=_positive_int,
        help="The number of rows sampled from each dataset per iteration",
        required=False,
        default=5000,
    )
    parser.add_argument(
        "--n_iterations",
        dest="n_iterations",
        action="store",
        type=_positive_int,
        help="The number of sampling iterations to run",
        required=False,
        default=1,
    )
    return parser
169
+
170
+
171
# Script entry point: parse the command line, then run the analysis.
if __name__ == "__main__":
    cli_args = create_argparser().parse_args()
    main(cli_args)
@@ -0,0 +1,189 @@
1
+ import logging
2
+ import os
3
+ import random
4
+ from datetime import datetime
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+
9
+ from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
10
+
11
+ from .utils import (
12
+ RANDOM_SEE,
13
+ create_demographics,
14
+ create_gender_encoder,
15
+ create_race_encoder,
16
+ create_vector_representations,
17
+ find_match,
18
+ find_match_self,
19
+ scale_age,
20
+ )
21
+
22
# Record layout: "<timestamp> - <logger name> - <level> - <message>".
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

# The layout has to be handed to basicConfig to take effect: constructing a
# Formatter on its own does not attach it to any handler, which would leave
# records in logging's default layout.
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
LOG = logging.getLogger("NNAA")

# Retained as a module attribute so existing references to ``formatter``
# keep working.
formatter = logging.Formatter(_LOG_FORMAT)
25
+
26
+
27
def main(args):
    """Compute the nearest-neighbour privacy loss and save it as parquet.

    For every iteration, ``args.num_of_samples`` rows are sampled from each of
    the training, evaluation and synthetic datasets and converted to vectors.
    For the training sample the score is::

        aa_train = (count(TS > TT) + count(ST > SS)) / num_of_samples / 2

    where ``TS``/``ST`` come from ``find_match`` between the real and
    synthetic vectors (in both directions) and ``TT``/``SS`` come from
    ``find_match_self`` within each set.  ``aa_test`` is computed the same way
    with the evaluation sample, and the iteration's privacy loss is
    ``aa_test - aa_train``.  The list of losses is written as a single row
    under the column ``"NNAAE"`` to a timestamped parquet file in
    ``args.metrics_folder``.

    Args:
        args: Parsed command-line namespace providing
            ``concept_tokenizer_path``, ``training_data_folder``,
            ``evaluation_data_folder``, ``synthetic_data_folder``,
            ``metrics_folder``, ``num_of_samples`` and ``n_iterations``
            (see ``create_argparser``).
    """
    # Create the destination before the expensive work so the final write
    # cannot fail merely because the folder is missing.  A falsy value is
    # left alone here and handled by the write itself.
    if args.metrics_folder:
        os.makedirs(args.metrics_folder, exist_ok=True)

    LOG.info(f"Started loading tokenizer at {args.concept_tokenizer_path}")
    concept_tokenizer = CehrGptTokenizer.from_pretrained(args.concept_tokenizer_path)

    LOG.info(f"Started loading training data at {args.training_data_folder}")
    train_data = pd.read_parquet(args.training_data_folder)

    LOG.info(f"Started loading evaluation data at {args.evaluation_data_folder}")
    evaluation_data = pd.read_parquet(args.evaluation_data_folder)

    LOG.info(f"Started loading synthetic_data at {args.synthetic_data_folder}")
    synthetic_data = pd.read_parquet(args.synthetic_data_folder)

    LOG.info(
        "Started extracting the demographic information from the patient sequences"
    )
    train_data = create_demographics(train_data)
    evaluation_data = create_demographics(evaluation_data)
    synthetic_data = create_demographics(synthetic_data)

    LOG.info("Started rescaling age columns")
    train_data = scale_age(train_data)
    evaluation_data = scale_age(evaluation_data)
    synthetic_data = scale_age(synthetic_data)

    LOG.info("Started encoding gender")
    gender_encoder = create_gender_encoder(train_data, evaluation_data, synthetic_data)
    LOG.info("Completed encoding gender")

    LOG.info("Started encoding race")
    race_encoder = create_race_encoder(train_data, evaluation_data, synthetic_data)
    LOG.info("Completed encoding race")

    random.seed(RANDOM_SEE)
    metrics = []

    for i in range(args.n_iterations):
        LOG.info(f"Iteration {i} Started")
        # DataFrame.sample draws from NumPy's random state, which
        # random.seed() does not touch.  Deriving an explicit 32-bit seed from
        # the seeded ``random`` module makes the sampled rows repeatable
        # across runs while still differing between iterations and datasets.
        train_data_sample = train_data.sample(
            args.num_of_samples, random_state=random.getrandbits(32)
        )
        evaluation_data_sample = evaluation_data.sample(
            args.num_of_samples, random_state=random.getrandbits(32)
        )
        synthetic_data_sample = synthetic_data.sample(
            args.num_of_samples, random_state=random.getrandbits(32)
        )
        LOG.info(f"Iteration {i}: Started creating train vectors")
        train_vectors = create_vector_representations(
            train_data_sample, concept_tokenizer, gender_encoder, race_encoder
        )
        LOG.info(f"Iteration {i}: Started creating evaluation vectors")
        evaluation_vectors = create_vector_representations(
            evaluation_data_sample, concept_tokenizer, gender_encoder, race_encoder
        )
        LOG.info(f"Iteration {i}: Started creating synthetic vectors")
        synthetic_vectors = create_vector_representations(
            synthetic_data_sample, concept_tokenizer, gender_encoder, race_encoder
        )
        LOG.info(
            f"Iteration {i}: Started calculating the distances between synthetic and training vectors"
        )
        distance_train_TS = find_match(train_vectors, synthetic_vectors)
        distance_train_ST = find_match(synthetic_vectors, train_vectors)
        distance_train_TT = find_match_self(train_vectors, train_vectors)
        distance_train_SS = find_match_self(synthetic_vectors, synthetic_vectors)

        # Average, over both directions, of the fraction of records whose
        # cross-set distance exceeds their within-set distance.
        aa_train = (
            (
                np.sum(distance_train_TS > distance_train_TT)
                + np.sum(distance_train_ST > distance_train_SS)
            )
            / args.num_of_samples
            / 2
        )

        LOG.info(
            f"Iteration {i}: Started calculating the distances between synthetic and evaluation vectors"
        )
        distance_test_TS = find_match(evaluation_vectors, synthetic_vectors)
        distance_test_ST = find_match(synthetic_vectors, evaluation_vectors)
        distance_test_TT = find_match_self(evaluation_vectors, evaluation_vectors)
        distance_test_SS = find_match_self(synthetic_vectors, synthetic_vectors)

        aa_test = (
            (
                np.sum(distance_test_TS > distance_test_TT)
                + np.sum(distance_test_ST > distance_test_SS)
            )
            / args.num_of_samples
            / 2
        )

        privacy_loss = aa_test - aa_train
        metrics.append(privacy_loss)
        LOG.info(f"Iteration {i}: Privacy loss {privacy_loss}")

    # One row whose "NNAAE" cell holds the list of per-iteration losses.
    results = {"NNAAE": metrics}

    # The timestamp in the file name keeps successive runs from overwriting
    # each other.
    current_time = datetime.now().strftime("%m-%d-%Y-%H-%M-%S")
    pd.DataFrame([results], columns=["NNAAE"]).to_parquet(
        os.path.join(
            args.metrics_folder, f"nearest_neighbor_inference_{current_time}.parquet"
        )
    )
126
+
127
+
128
def create_argparser():
    """Build the command-line parser for the nearest-neighbour analysis.

    Returns:
        argparse.ArgumentParser: Parser exposing the data, tokenizer and
        metrics-folder paths (all required) plus the optional
        ``--num_of_samples`` (default 5000) and ``--n_iterations``
        (default 1), both of which must be positive integers.
    """
    import argparse

    def _positive_int(value):
        """Parse *value* as an int and reject anything smaller than 1."""
        try:
            number = int(value)
        except ValueError:
            raise argparse.ArgumentTypeError(
                f"invalid int value: {value!r}"
            ) from None
        if number < 1:
            raise argparse.ArgumentTypeError(
                f"must be a positive integer, got {number}"
            )
        return number

    parser = argparse.ArgumentParser(
        description="Nearest Neighbor Inference Analysis Arguments using the GPT model"
    )
    parser.add_argument(
        "--training_data_folder",
        dest="training_data_folder",
        action="store",
        help="The path for where the training data folder",
        required=True,
    )
    parser.add_argument(
        "--evaluation_data_folder",
        dest="evaluation_data_folder",
        action="store",
        help="The path for where the evaluation data folder",
        required=True,
    )
    parser.add_argument(
        "--synthetic_data_folder",
        dest="synthetic_data_folder",
        action="store",
        help="The path for where the synthetic data folder",
        required=True,
    )
    parser.add_argument(
        "--concept_tokenizer_path",
        dest="concept_tokenizer_path",
        action="store",
        help="The path for where the concept tokenizer is located",
        required=True,
    )
    # num_of_samples is used as a divisor and as a sample size in main(), and
    # n_iterations as a loop count, so non-positive values are rejected here.
    parser.add_argument(
        "--num_of_samples",
        dest="num_of_samples",
        action="store",
        type=_positive_int,
        help="The number of rows sampled from each dataset per iteration",
        required=False,
        default=5000,
    )
    parser.add_argument(
        "--n_iterations",
        dest="n_iterations",
        action="store",
        type=_positive_int,
        help="The number of sampling iterations to run",
        required=False,
        default=1,
    )
    # main() always joins this folder with the output file name; with no
    # value it would fail with a TypeError only after every iteration had
    # been computed, so the option is mandatory.
    parser.add_argument(
        "--metrics_folder",
        dest="metrics_folder",
        action="store",
        help="The folder that stores the metrics",
        required=True,
    )
    return parser
186
+
187
+
188
# Script entry point: parse the command line, then run the analysis.
if __name__ == "__main__":
    cli_args = create_argparser().parse_args()
    main(cli_args)