cehrgpt-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. __init__.py +0 -0
  2. cehrgpt/__init__.py +0 -0
  3. cehrgpt/analysis/__init__.py +0 -0
  4. cehrgpt/analysis/privacy/__init__.py +0 -0
  5. cehrgpt/analysis/privacy/attribute_inference.py +275 -0
  6. cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
  7. cehrgpt/analysis/privacy/member_inference.py +172 -0
  8. cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
  9. cehrgpt/analysis/privacy/reid_inference.py +407 -0
  10. cehrgpt/analysis/privacy/utils.py +255 -0
  11. cehrgpt/cehrgpt_args.py +142 -0
  12. cehrgpt/data/__init__.py +0 -0
  13. cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
  14. cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
  15. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
  16. cehrgpt/generation/__init__.py +0 -0
  17. cehrgpt/generation/chatgpt_generation.py +106 -0
  18. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
  19. cehrgpt/generation/omop_converter_batch.py +644 -0
  20. cehrgpt/generation/omop_entity.py +515 -0
  21. cehrgpt/gpt_utils.py +331 -0
  22. cehrgpt/models/__init__.py +0 -0
  23. cehrgpt/models/config.py +205 -0
  24. cehrgpt/models/hf_cehrgpt.py +1817 -0
  25. cehrgpt/models/hf_modeling_outputs.py +158 -0
  26. cehrgpt/models/pretrained_embeddings.py +82 -0
  27. cehrgpt/models/special_tokens.py +30 -0
  28. cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
  29. cehrgpt/omop/__init__.py +0 -0
  30. cehrgpt/omop/condition_era.py +20 -0
  31. cehrgpt/omop/observation_period.py +43 -0
  32. cehrgpt/omop/omop_argparse.py +38 -0
  33. cehrgpt/omop/omop_table_builder.py +86 -0
  34. cehrgpt/omop/queries/__init__.py +0 -0
  35. cehrgpt/omop/queries/condition_era.py +86 -0
  36. cehrgpt/omop/queries/observation_period.py +135 -0
  37. cehrgpt/omop/sample_omop_tables.py +71 -0
  38. cehrgpt/runners/__init__.py +0 -0
  39. cehrgpt/runners/gpt_runner_util.py +99 -0
  40. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
  41. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
  42. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
  43. cehrgpt/runners/hyperparameter_search_util.py +223 -0
  44. cehrgpt/time_to_event/__init__.py +0 -0
  45. cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
  46. cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
  47. cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
  48. cehrgpt/time_to_event/time_to_event_model.py +226 -0
  49. cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
  50. cehrgpt/time_to_event/time_to_event_utils.py +55 -0
  51. cehrgpt/tools/__init__.py +0 -0
  52. cehrgpt/tools/ehrshot_benchmark.py +74 -0
  53. cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
  54. cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
  55. cehrgpt/tools/upload_omop_tables.py +108 -0
  56. cehrgpt-0.0.1.dist-info/LICENSE +21 -0
  57. cehrgpt-0.0.1.dist-info/METADATA +66 -0
  58. cehrgpt-0.0.1.dist-info/RECORD +60 -0
  59. cehrgpt-0.0.1.dist-info/WHEEL +5 -0
  60. cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
__init__.py ADDED
File without changes
cehrgpt/__init__.py ADDED
File without changes
cehrgpt/analysis/__init__.py ADDED
File without changes
cehrgpt/analysis/privacy/__init__.py ADDED
File without changes
cehrgpt/analysis/privacy/attribute_inference.py ADDED
@@ -0,0 +1,275 @@
+ import logging
+ import os
+ import random
+ import sys
+ from datetime import datetime
+
+ import numpy as np
+ import pandas as pd
+ import yaml
+
+ from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+
+ from .utils import (
+     batched_pairwise_euclidean_distance_indices,
+     create_demographics,
+     create_gender_encoder,
+     create_race_encoder,
+     create_vector_representations_for_attribute,
+     find_match,
+     find_match_self,
+     scale_age,
+ )
+
+ RANDOM_SEED = 42
+ logging.basicConfig(
+     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+ )
+ LOG = logging.getLogger("attribute_inference")
+
+
+ def safe_divide(numerator_vector, denominator_vector):
+     return np.where(denominator_vector > 0, numerator_vector / denominator_vector, 0)
+
+
+ def cal_f1_score(vector_a, vector_b, index_matrix):
+     # vector_a is the train data; vector_b is the synthetic data (or the train data itself)
+     shared_vector = np.logical_and(
+         vector_a[: len(index_matrix)], vector_b[index_matrix]
+     ).astype(int)
+     shared_vector_sum = np.sum(shared_vector, axis=1)
+
+     precision = safe_divide(shared_vector_sum, np.sum(vector_b, axis=1))
+     recall = safe_divide(shared_vector_sum, np.sum(vector_a, axis=1))
+
+     f1 = safe_divide(2 * recall * precision, recall + precision)
+     return f1, precision, recall
+
+
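A minimal sketch of what `cal_f1_score` computes, assuming the wheel is installed; the toy arrays and positional matching below are illustrative, not taken from the package:

import numpy as np

from cehrgpt.analysis.privacy.attribute_inference import cal_f1_score

# Toy binary sensitive-attribute vectors: rows are patients, columns are attributes.
train = np.array([[1, 0, 1],
                  [0, 1, 1]])
synthetic = np.array([[1, 1, 0],
                      [0, 1, 1]])
# Match train row i to synthetic row idx[i] (here, positionally).
idx = np.array([0, 1])

f1, precision, recall = cal_f1_score(train, synthetic, idx)
# precision == recall == f1 == [0.5, 1.0]:
# train patient 0 shares 1 of its 2 attributes with its matched synthetic row;
# train patient 1 shares both of its attributes.
print(f1, precision, recall)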
+ def main(args):
+     try:
+         with open(args.attribute_config, "r") as file:
+             data = yaml.safe_load(file)
+             common_attributes = data.get("common_attributes")
+             sensitive_attributes = data.get("sensitive_attributes")
+     except (FileNotFoundError, PermissionError, OSError) as e:
+         sys.exit(e)
+
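The shipped `attribute_inference_config.yml` (8,975 lines in this release) supplies these two keys. A hypothetical, minimal config with placeholder entries might look like:

# Hypothetical minimal example; real entries are concept codes from the tokenizer's vocabulary.
common_attributes:
  - "concept_a"
  - "concept_b"
sensitive_attributes:
  - "concept_c"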
61
+ attribute_inference_folder = os.path.join(args.output_folder, "attribute_inference")
62
+ if not os.path.exists(attribute_inference_folder):
63
+ LOG.info(
64
+ f"Creating the attribute_inference output folder at {attribute_inference_folder}"
65
+ )
66
+ os.makedirs(attribute_inference_folder, exist_ok=True)
67
+
68
+ LOG.info(f"Started loading tokenizer at {args.tokenizer_path}")
69
+ concept_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_path)
70
+
71
+ LOG.info(f"Started loading training data at {args.training_data_folder}")
72
+ train_data = pd.read_parquet(args.training_data_folder)
73
+
74
+ LOG.info(f"Started loading synthetic_data at {args.synthetic_data_folder}")
75
+ synthetic_data = pd.read_parquet(args.synthetic_data_folder)
76
+
77
+ LOG.info(
78
+ "Started extracting the demographic information from the patient sequences"
79
+ )
80
+ train_data = create_demographics(train_data)
81
+ synthetic_data = create_demographics(synthetic_data)
82
+
83
+ LOG.info("Started rescaling age columns")
84
+ train_data = scale_age(train_data)
85
+ synthetic_data = scale_age(synthetic_data)
86
+
87
+ LOG.info("Started encoding gender")
88
+ gender_encoder = create_gender_encoder(
89
+ train_data,
90
+ # TODO need to change this function to be generic
91
+ train_data[:10],
92
+ synthetic_data,
93
+ )
94
+ LOG.info("Completed encoding gender")
95
+
96
+ LOG.info("Started encoding race")
97
+ race_encoder = create_race_encoder(
98
+ train_data,
99
+ # TODO need to change this function to be generic
100
+ train_data[:10],
101
+ synthetic_data,
102
+ )
103
+ LOG.info("Completed encoding race")
104
+
105
+ random.seed(RANDOM_SEE)
+     for i in range(1, args.n_iterations + 1):
+         LOG.info(f"Iteration {i}: Started creating data samples")
+         train_data_sample = train_data.sample(args.num_of_samples)
+         synthetic_data_sample = synthetic_data.sample(args.num_of_samples)
+         LOG.info(f"Iteration {i}: Started creating train sample vectors")
+         train_common_vectors, train_sensitive_vectors = (
+             create_vector_representations_for_attribute(
+                 train_data_sample,
+                 concept_tokenizer,
+                 gender_encoder,
+                 race_encoder,
+                 common_attributes=common_attributes,
+                 sensitive_attributes=sensitive_attributes,
+             )
+         )
+
+         LOG.info(f"Iteration {i}: Started creating synthetic vectors")
+         synthetic_common_vectors, synthetic_sensitive_vectors = (
+             create_vector_representations_for_attribute(
+                 synthetic_data_sample,
+                 concept_tokenizer,
+                 gender_encoder,
+                 race_encoder,
+                 common_attributes=common_attributes,
+                 sensitive_attributes=sensitive_attributes,
+             )
+         )
+
+         LOG.info(
+             "Started calculating the distances between synthetic and training vectors"
+         )
+         if args.batched:
+             train_synthetic_index = batched_pairwise_euclidean_distance_indices(
+                 train_common_vectors,
+                 synthetic_common_vectors,
+                 batch_size=args.batch_size,
+             )
+             train_train_index = batched_pairwise_euclidean_distance_indices(
+                 train_common_vectors,
+                 train_common_vectors,
+                 batch_size=args.batch_size,
+                 self_exclude=True,
+             )
+         else:
+             train_synthetic_index = find_match(
+                 train_common_vectors, synthetic_common_vectors, return_index=True
+             )
+             train_train_index = find_match_self(
+                 train_common_vectors, train_common_vectors, return_index=True
+             )
+
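The matching helpers come from `cehrgpt/analysis/privacy/utils.py`, which is not shown in this hunk. As a rough mental model only (an assumption, not the package's actual implementation), a batched pairwise-Euclidean nearest-neighbour index could be computed like this:

import numpy as np

def nearest_neighbor_indices(a, b, batch_size=1000, self_exclude=False):
    # Hypothetical sketch: for each row of `a`, return the index of the
    # nearest row of `b` by Euclidean distance, processing `a` in batches
    # to bound the memory used by the broadcasted distance matrix.
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    indices = []
    for start in range(0, len(a), batch_size):
        batch = a[start : start + batch_size]
        # (len(batch), len(b)) matrix of squared distances via broadcasting
        d2 = ((batch[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
        if self_exclude:
            # Only meaningful when a and b are the same matrix: mask each
            # row's distance to itself so a point cannot match itself.
            d2[np.arange(len(batch)), np.arange(start, start + len(batch))] = np.inf
        indices.append(d2.argmin(axis=1))
    return np.concatenate(indices)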
+         f1_syn_train, precision_syn_train, recall_syn_train = cal_f1_score(
+             train_sensitive_vectors, synthetic_sensitive_vectors, train_synthetic_index
+         )
+         f1_train_train, precision_train_train, recall_train_train = cal_f1_score(
+             train_sensitive_vectors, train_sensitive_vectors, train_train_index
+         )
+
+         results = {
+             "Precision Synthetic Train": precision_syn_train,
+             "Recall Synthetic Train": recall_syn_train,
+             "F1 Synthetic Train": f1_syn_train,
+             "Precision Train Train": precision_train_train,
+             "Recall Train Train": recall_train_train,
+             "F1 Train Train": f1_train_train,
+         }
+         LOG.info(
+             f"Attribute Inference: Average Precision Synthetic Train: {np.mean(precision_syn_train)} \n"
+             f"Attribute Inference: Average Recall Synthetic Train: {np.mean(recall_syn_train)} \n"
+             f"Attribute Inference: Average F1 Synthetic Train: {np.mean(f1_syn_train)} \n"
+             f"Attribute Inference: Average Precision Train Train: {np.mean(precision_train_train)} \n"
+             f"Attribute Inference: Average Recall Train Train: {np.mean(recall_train_train)} \n"
+             f"Attribute Inference: Average F1 Train Train: {np.mean(f1_train_train)}"
+         )
+         current_time = datetime.now().strftime("%m-%d-%Y-%H-%M-%S")
+         pd.DataFrame(
+             results,
+             columns=[
+                 "Precision Synthetic Train",
+                 "Recall Synthetic Train",
+                 "F1 Synthetic Train",
+                 "Precision Train Train",
+                 "Recall Train Train",
+                 "F1 Train Train",
+             ],
+         ).to_parquet(
+             os.path.join(
+                 attribute_inference_folder,
+                 f"attribute_inference_{current_time}.parquet",
+             )
+         )
+
+
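Each iteration writes the per-sample metrics to a timestamped parquet file under `attribute_inference/`. A hypothetical way to inspect one run (the filename timestamp below is illustrative):

import pandas as pd

df = pd.read_parquet(
    "attribute_inference/attribute_inference_01-01-2025-00-00-00.parquet"
)
# Compare the synthetic-vs-train attack against the train-vs-train baseline.
print(df[["F1 Synthetic Train", "F1 Train Train"]].describe())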
+ def create_argparser():
+     import argparse
+
+     parser = argparse.ArgumentParser(
+         description="Attribute Inference Analysis Arguments"
+     )
+     parser.add_argument(
+         "--training_data_folder",
+         dest="training_data_folder",
+         action="store",
+         help="The path to the training data folder",
+         required=True,
+     )
+     parser.add_argument(
+         "--synthetic_data_folder",
+         dest="synthetic_data_folder",
+         action="store",
+         help="The path to the synthetic data folder",
+         required=True,
+     )
+     parser.add_argument(
+         "--output_folder",
+         dest="output_folder",
+         action="store",
+         help="The output folder that stores the metrics",
+         required=True,
+     )
+     parser.add_argument(
+         "--tokenizer_path",
+         dest="tokenizer_path",
+         action="store",
+         help="The path to the ConceptTokenizer",
+         required=True,
+     )
+     parser.add_argument(
+         "--attribute_config",
+         dest="attribute_config",
+         action="store",
+         help="The configuration yaml file for common and sensitive attributes",
+         required=True,
+     )
+     parser.add_argument(
+         "--batch_size",
+         dest="batch_size",
+         action="store",
+         type=int,
+         default=1000,
+         help="The batch size of the matching algorithm",
+         required=False,
+     )
+     parser.add_argument(
+         "--batched",
+         dest="batched",
+         action="store_true",
+         help="Indicate whether we want to use the batched matrix operation",
+     )
+     parser.add_argument(
+         "--num_of_samples",
+         dest="num_of_samples",
+         action="store",
+         type=int,
+         required=False,
+         default=5000,
+     )
+     parser.add_argument(
+         "--n_iterations",
+         dest="n_iterations",
+         action="store",
+         type=int,
+         required=False,
+         default=1,
+     )
+     return parser
+
+
+ if __name__ == "__main__":
+     main(create_argparser().parse_args())
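Taken together, a hypothetical programmatic invocation of the script, assuming the wheel is installed (all paths below are placeholders):

from cehrgpt.analysis.privacy.attribute_inference import create_argparser, main

# Placeholder paths; only the flag names come from the argparser above.
args = create_argparser().parse_args(
    [
        "--training_data_folder", "/data/train",
        "--synthetic_data_folder", "/data/synthetic",
        "--output_folder", "/data/metrics",
        "--tokenizer_path", "/models/cehrgpt_tokenizer",
        "--attribute_config", "attribute_inference_config.yml",
        "--batched",
        "--num_of_samples", "5000",
        "--n_iterations", "1",
    ]
)
main(args)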