cehrgpt-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- __init__.py +0 -0
- cehrgpt/__init__.py +0 -0
- cehrgpt/analysis/__init__.py +0 -0
- cehrgpt/analysis/privacy/__init__.py +0 -0
- cehrgpt/analysis/privacy/attribute_inference.py +275 -0
- cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
- cehrgpt/analysis/privacy/member_inference.py +172 -0
- cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
- cehrgpt/analysis/privacy/reid_inference.py +407 -0
- cehrgpt/analysis/privacy/utils.py +255 -0
- cehrgpt/cehrgpt_args.py +142 -0
- cehrgpt/data/__init__.py +0 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
- cehrgpt/generation/__init__.py +0 -0
- cehrgpt/generation/chatgpt_generation.py +106 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
- cehrgpt/generation/omop_converter_batch.py +644 -0
- cehrgpt/generation/omop_entity.py +515 -0
- cehrgpt/gpt_utils.py +331 -0
- cehrgpt/models/__init__.py +0 -0
- cehrgpt/models/config.py +205 -0
- cehrgpt/models/hf_cehrgpt.py +1817 -0
- cehrgpt/models/hf_modeling_outputs.py +158 -0
- cehrgpt/models/pretrained_embeddings.py +82 -0
- cehrgpt/models/special_tokens.py +30 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
- cehrgpt/omop/__init__.py +0 -0
- cehrgpt/omop/condition_era.py +20 -0
- cehrgpt/omop/observation_period.py +43 -0
- cehrgpt/omop/omop_argparse.py +38 -0
- cehrgpt/omop/omop_table_builder.py +86 -0
- cehrgpt/omop/queries/__init__.py +0 -0
- cehrgpt/omop/queries/condition_era.py +86 -0
- cehrgpt/omop/queries/observation_period.py +135 -0
- cehrgpt/omop/sample_omop_tables.py +71 -0
- cehrgpt/runners/__init__.py +0 -0
- cehrgpt/runners/gpt_runner_util.py +99 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
- cehrgpt/runners/hyperparameter_search_util.py +223 -0
- cehrgpt/time_to_event/__init__.py +0 -0
- cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
- cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
- cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
- cehrgpt/time_to_event/time_to_event_model.py +226 -0
- cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
- cehrgpt/time_to_event/time_to_event_utils.py +55 -0
- cehrgpt/tools/__init__.py +0 -0
- cehrgpt/tools/ehrshot_benchmark.py +74 -0
- cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
- cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
- cehrgpt/tools/upload_omop_tables.py +108 -0
- cehrgpt-0.0.1.dist-info/LICENSE +21 -0
- cehrgpt-0.0.1.dist-info/METADATA +66 -0
- cehrgpt-0.0.1.dist-info/RECORD +60 -0
- cehrgpt-0.0.1.dist-info/WHEEL +5 -0
- cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
cehrgpt/analysis/privacy/reid_inference.py
@@ -0,0 +1,407 @@
import os

from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.sql import types as t
from tqdm import tqdm

# This number came from the tokenizer: essentially all numeric tokens that
# represent valid OMOP concept ids.
VOCAB_SIZE = 17870
L = 0.001
SENSITIVE_MATCH_THRESHOLD = int(VOCAB_SIZE * L)

COMMON_ATTRIBUTES = [
    "320128",
    "200219",
    "77670",
    "432867",
    "254761",
    "312437",
    "378253",
]
MAX_LEVELS = [1, 0, 0] + [1] * len(COMMON_ATTRIBUTES)
# The age groups are generated using the following bases, e.g. 1 indicates using the age as is.
AGE_GENERALIZATION_LEVELS = [10, 1]


def generate_lattice_bfs(top_gen_levels):  # BFS
    """Came from https://github.com/yy6linda/synthetic-ehr-benchmarking/blob/main/privacy_evaluation/Synthetic_risk_model_reid.py#L81."""
    visited = [top_gen_levels]
    queue = [top_gen_levels]
    lattice = []
    while queue:
        gen_levels = queue.pop(0)
        lattice.append(gen_levels)
        for i in range(len(gen_levels)):
            if gen_levels[i] != 0:
                gen_levels_new = gen_levels.copy()
                gen_levels_new[i] -= 1
                if gen_levels_new not in visited:
                    visited.append(gen_levels_new)
                    queue.append(gen_levels_new)
    return lattice

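# Example (illustrative, hypothetical input): the BFS enumerates every
# generalization configuration reachable from the top levels by repeatedly
# decrementing one non-zero level, so
#   generate_lattice_bfs([1, 0, 1])
#   -> [[1, 0, 1], [0, 0, 1], [1, 0, 0], [0, 0, 0]]
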
def update_dataset(reid_data, real_sample, synthetic_reid_data, config):
    reid_data_dup = reid_data.alias("reid_data_dup")
    real_sample_dup = real_sample.alias("real_sample_dup")
    synthetic_reid_data_dup = synthetic_reid_data.alias("synthetic_reid_data_dup")
    common_attributes_to_remove = []

    for i in range(len(config)):
        # age group
        if i == 0:
            age_level_index = config[i]
            age_group_base = AGE_GENERALIZATION_LEVELS[age_level_index]
            reid_data_dup = reid_data_dup.withColumn(
                "age", f.ceil(f.col("age").cast("int") / age_group_base)
            )
            synthetic_reid_data_dup = synthetic_reid_data_dup.withColumn(
                "age", f.ceil(f.col("age").cast("int") / age_group_base)
            )
            real_sample_dup = real_sample_dup.withColumn(
                "age", f.ceil(f.col("age").cast("int") / age_group_base)
            )
        elif i in [1, 2]:
            # gender and race are not generalized
            continue
        else:
            # this indicates that the common attribute should be removed
            if config[i] == 0:
                common_attributes_to_remove.append(COMMON_ATTRIBUTES[i - 3])

    def remove_common_attributes(concept_ids):
        comm_atts = sorted(set(concept_ids) - set(common_attributes_to_remove))
        if comm_atts:
            return "-".join(comm_atts)
        else:
            return "empty"

    extract_common_attributes_udf = f.udf(
        lambda concept_ids: remove_common_attributes(concept_ids), t.StringType()
    )

    reid_data_dup = reid_data_dup.withColumn(
        "common_attributes", extract_common_attributes_udf("common_attributes")
    )

    real_sample_dup = real_sample_dup.withColumn(
        "common_attributes", extract_common_attributes_udf("common_attributes")
    )

    synthetic_reid_data_dup = synthetic_reid_data_dup.withColumn(
        "common_attributes", extract_common_attributes_udf("common_attributes")
    )
    return reid_data_dup, real_sample_dup, synthetic_reid_data_dup

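# Example (illustrative): with config == [0, 0, 0, 1, 0, 1, 1, 1, 1, 1], index 0
# selects AGE_GENERALIZATION_LEVELS[0] == 10, so ages are bucketed into 10-year
# groups; indices 1-2 (gender, race) are never generalized; and every index
# i >= 3 with config[i] == 0 drops COMMON_ATTRIBUTES[i - 3] ("200219" here)
# from the joined quasi-identifier string.
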
def calculate_reid_risk_score(
    real_sample_dup,
    reid_data_dup,
    synthetic_reid_data_dup,
    lower_n,
    cap_n,
    lambda_val=0.23,
    num_salts=20,
):
    # Assign each real record a random salt; the synthetic data is exploded across
    # all salt values below, so the salt only partitions the expensive join.
    real_sample_dup = real_sample_dup.withColumn(
        "salt", (f.rand() * num_salts).cast("int")
    )

    reid_data_stats = reid_data_dup.groupby(
        "age", "gender", "race", "common_attributes"
    ).count()

    real_to_population_matches = real_sample_dup.join(
        reid_data_stats,
        (real_sample_dup["age"] == reid_data_stats["age"])
        & (real_sample_dup["gender"] == reid_data_stats["gender"])
        & (real_sample_dup["race"] == reid_data_stats["race"])
        & (
            real_sample_dup["common_attributes"] == reid_data_stats["common_attributes"]
        ),
    ).select("person_id", "count")

    # Alias the DataFrame for the self join
    real_sample_stats = real_sample_dup.groupby(
        "age", "gender", "race", "common_attributes"
    ).count()

    real_to_real_matches = real_sample_dup.join(
        real_sample_stats,
        (real_sample_dup["age"] == real_sample_stats["age"])
        & (real_sample_dup["gender"] == real_sample_stats["gender"])
        & (real_sample_dup["race"] == real_sample_stats["race"])
        & (
            real_sample_dup["common_attributes"]
            == real_sample_stats["common_attributes"]
        ),
    ).select("person_id", "count")

    reid_data_step_one = (
        reid_data_dup.join(real_to_population_matches, "person_id")
        .withColumnRenamed("count", "upper_F_s")
        .join(real_to_real_matches, "person_id")
        .withColumnRenamed("count", "lower_f_s")
    )

    synthetic_reid_data_dup = synthetic_reid_data_dup.withColumn(
        "salt", f.explode(f.array([f.lit(x) for x in range(num_salts)]))
    )

    real_sample_dup = real_sample_dup.where(
        f.size("sensitive_attributes") >= SENSITIVE_MATCH_THRESHOLD
    )
    synthetic_reid_data_dup = synthetic_reid_data_dup.where(
        f.size("sensitive_attributes") >= SENSITIVE_MATCH_THRESHOLD
    )

    real_to_synthetic = (
        real_sample_dup.join(
            synthetic_reid_data_dup,
            (real_sample_dup["age"] == synthetic_reid_data_dup["age"])
            & (real_sample_dup["gender"] == synthetic_reid_data_dup["gender"])
            & (real_sample_dup["race"] == synthetic_reid_data_dup["race"])
            & (
                real_sample_dup["common_attributes"]
                == synthetic_reid_data_dup["common_attributes"]
            )
            & (real_sample_dup["salt"] == synthetic_reid_data_dup["salt"]),
        )
        .select(
            real_sample_dup["person_id"],
            real_sample_dup["sensitive_attributes"].alias("real_sensitive_attributes"),
            synthetic_reid_data_dup["sensitive_attributes"].alias(
                "synthetic_sensitive_attributes"
            ),
        )
        .withColumn(
            "n_of_sensitive_matches",
            f.size(
                f.array_intersect(
                    "real_sensitive_attributes", "synthetic_sensitive_attributes"
                )
            ),
        )
        .drop("real_sensitive_attributes", "synthetic_sensitive_attributes")
    )

    real_to_synthetic = (
        real_to_synthetic.groupby("person_id")
        .agg(f.max("n_of_sensitive_matches").alias("n_of_sensitive_matches"))
        .withColumn(
            "new_info",
            (f.col("n_of_sensitive_matches") > f.lit(SENSITIVE_MATCH_THRESHOLD)).cast(
                "int"
            ),
        )
    )

    reid_data_step_two = reid_data_step_one.join(
        real_to_synthetic, "person_id", how="left_outer"
    ).withColumn("new_info", f.coalesce(real_to_synthetic["new_info"], f.lit(0)))

    return (
        reid_data_step_two.withColumn(
            "A_term",
            1 / f.col("lower_f_s") * f.lit((1 + lambda_val) / 2) * f.col("new_info"),
        )
        .withColumn(
            "B_term",
            1 / f.col("upper_F_s") * f.lit((1 + lambda_val) / 2) * f.col("new_info"),
        )
        .select(
            (f.sum("A_term") / f.lit(cap_n)).alias("A_term"),
            (f.sum("B_term") / f.lit(lower_n)).alias("B_term"),
        )
    )

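# The single-row result above carries the two components of the
# re-identification risk, read directly off the code (and apparently following
# the synthetic-ehr-benchmarking formulation referenced earlier):
#   A_term = (1 / cap_n)   * sum_s (1 / lower_f_s) * ((1 + lambda_val) / 2) * new_info_s
#   B_term = (1 / lower_n) * sum_s (1 / upper_F_s) * ((1 + lambda_val) / 2) * new_info_s
# where lower_f_s counts quasi-identifier matches inside the real sample,
# upper_F_s counts matches in the combined population, and new_info_s is 1 only
# when the record's best synthetic match shares more than
# SENSITIVE_MATCH_THRESHOLD sensitive concepts.
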
def main(args):
    all_configs = generate_lattice_bfs(MAX_LEVELS)

    N = None
    n = None
    excluded_sensitive_attributes = None

    for config in tqdm(all_configs, total=len(all_configs)):

        if sum(config) == 0:
            continue

        experiment_output_name = os.path.join(
            args.output_folder, f'{"".join(map(str, config))}.parquet'
        )
        if os.path.exists(experiment_output_name):
            continue

        spark = SparkSession.builder.appName(
            f'Generate REID {"".join(map(str, config))}'
        ).getOrCreate()

        @f.udf(t.ArrayType(t.StringType()))
        def extract_omop_concepts_udf(concept_ids):
            return sorted(set(_ for _ in concept_ids if str.isnumeric(_)))

        def extract_common_attributes(concept_ids):
            common_atts = set(c for c in concept_ids if c in COMMON_ATTRIBUTES)
            return list(common_atts)

        extract_common_attributes_udf = f.udf(
            lambda concept_ids: extract_common_attributes(concept_ids),
            t.ArrayType(t.StringType()),
        )

        def extract_sensitive_attributes(concept_ids):
            return sorted(set(c for c in concept_ids if c not in COMMON_ATTRIBUTES))

        extract_sensitive_attributes_udf = f.udf(
            lambda concept_ids: extract_sensitive_attributes(concept_ids),
            t.ArrayType(t.StringType()),
        )

        training_data = spark.read.parquet(args.training_data_folder)
        evaluation_data = spark.read.parquet(args.evaluation_data_folder)
        training_data = training_data.select(
            "concept_ids", "num_of_concepts", "person_id"
        ).withColumn("is_real", f.lit(True))
        evaluation_data = evaluation_data.select(
            "concept_ids", "num_of_concepts", "person_id"
        ).withColumn("is_real", f.lit(False))
        data = training_data.unionByName(evaluation_data)

        # The leading tokens of each sequence carry the demographics:
        # start year, age ("age:NN"), gender concept, and race concept.
        reid_data = (
            data.withColumn("age", f.col("concept_ids")[1])
            .withColumn("age", f.split("age", ":")[1])
            .withColumn("gender", f.col("concept_ids")[2])
            .withColumn("race", f.col("concept_ids")[3])
            .withColumn("concept_ids", extract_omop_concepts_udf("concept_ids"))
            .withColumn(
                "common_attributes", extract_common_attributes_udf("concept_ids")
            )
            .withColumn(
                "sensitive_attributes", extract_sensitive_attributes_udf("concept_ids")
            )
            .drop("concept_ids")
        )

        real_sample = reid_data.where("is_real")

        if not N:
            N = reid_data.count()
        if not n:
            n = real_sample.count()

        # Exclude sensitive attributes that appear in at least half of the real
        # sample (majority attributes) from the sensitive-attribute matching.
        if not excluded_sensitive_attributes:
            excluded_sensitive_attributes_df = (
                real_sample.select(
                    f.explode("sensitive_attributes").alias("sensitive_attribute")
                )
                .groupby("sensitive_attribute")
                .count()
                .withColumn("sensitive_attribute_prevalence", f.col("count") / n)
                .withColumn(
                    "is_majority", f.col("sensitive_attribute_prevalence") >= 0.5
                )
                .where("is_majority")
            )
            excluded_sensitive_attributes = [
                row.sensitive_attribute
                for row in excluded_sensitive_attributes_df.select(
                    "sensitive_attribute"
                ).collect()
            ]

        def filter_sensitive_attributes(concept_ids):
            return [c for c in concept_ids if c not in excluded_sensitive_attributes]

        filter_sensitive_attributes_udf = f.udf(
            lambda concept_ids: filter_sensitive_attributes(concept_ids),
            t.ArrayType(t.StringType()),
        )

        real_sample = real_sample.withColumn(
            "sensitive_attributes",
            filter_sensitive_attributes_udf("sensitive_attributes"),
        )

        synthetic_data = spark.read.parquet(args.synthetic_data_folder)

        synthetic_reid_data = (
            synthetic_data.withColumn("age", f.col("concept_ids")[1])
            .withColumn("age", f.split("age", ":")[1])
            .withColumn("gender", f.col("concept_ids")[2])
            .withColumn("race", f.col("concept_ids")[3])
            .withColumn("concept_ids", extract_omop_concepts_udf("concept_ids"))
            .withColumn(
                "common_attributes", extract_common_attributes_udf("concept_ids")
            )
            .withColumn(
                "sensitive_attributes", extract_sensitive_attributes_udf("concept_ids")
            )
            .drop("concept_ids")
        )

        reid_data_dup, real_sample_dup, synthetic_reid_data_dup = update_dataset(
            reid_data, real_sample, synthetic_reid_data, config=config
        )
        experiment_risk_score = calculate_reid_risk_score(
            real_sample_dup,
            reid_data_dup,
            synthetic_reid_data_dup,
            lower_n=n,
            cap_n=N,
            lambda_val=0.23,
            num_salts=args.num_salts,
        )
        experiment_risk_score.toPandas().to_parquet(experiment_output_name)

        spark.stop()

def create_argparser():
    import argparse

    parser = argparse.ArgumentParser(
        description="Arguments for re-identification risk evaluation"
    )
    parser.add_argument(
        "--training_data_folder",
        dest="training_data_folder",
        action="store",
        required=True,
    )
    parser.add_argument(
        "--evaluation_data_folder",
        dest="evaluation_data_folder",
        action="store",
        required=True,
    )
    parser.add_argument(
        "--synthetic_data_folder",
        dest="synthetic_data_folder",
        action="store",
        required=True,
    )
    parser.add_argument(
        "--output_folder",
        dest="output_folder",
        action="store",
        help="The output folder for storing the results",
        required=True,
    )
    parser.add_argument(
        "--num_salts",
        dest="num_salts",
        action="store",
        type=int,
        default=40,
        required=False,
    )
    return parser


if __name__ == "__main__":
    main(create_argparser().parse_args())
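A minimal invocation sketch for the script above, inferred from its argparse definitions (the paths are placeholders, not values shipped with the package):

    python -m cehrgpt.analysis.privacy.reid_inference \
        --training_data_folder /data/train \
        --evaluation_data_folder /data/eval \
        --synthetic_data_folder /data/synthetic \
        --output_folder /data/reid_results \
        --num_salts 40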
cehrgpt/analysis/privacy/utils.py
@@ -0,0 +1,255 @@
import logging

import numpy as np
from sklearn.preprocessing import OneHotEncoder
from tqdm import tqdm

LOG = logging.getLogger(__name__)

RANDOM_SEED = 42
NUM_OF_GENDERS = 3
NUM_OF_RACES = 21


def get_demographics(concept_ids):
    year_token, age_token, gender, race = concept_ids[0:4]
    try:
        year = int(year_token[5:])
    except ValueError:
        LOG.error(
            f"{year_token[5:]} cannot be converted to an integer, use the default value 1900"
        )
        year = 1900

    try:
        age = int(age_token[4:])
    except ValueError:
        LOG.error(
            f"{age_token[4:]} cannot be converted to an integer, use the default value -1"
        )
        age = -1

    return year, age, gender, race

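# Example (illustrative token layout, inferred from the slicing above): a
# sequence beginning ["year:2021", "age:45", "8532", "8527", ...] yields
#   get_demographics(concept_ids) -> (2021, 45, "8532", "8527")
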
def create_race_encoder(
    train_data_sample, evaluation_data_sample, synthetic_data_sample
):
    race_encoder = OneHotEncoder()
    all_unique_races = np.unique(
        np.concatenate(
            [
                train_data_sample.race.unique(),
                evaluation_data_sample.race.unique(),
                synthetic_data_sample.race.unique(),
            ],
            axis=0,
        )
    )
    race_encoder.fit(all_unique_races[:, np.newaxis])
    return race_encoder


def create_gender_encoder(
    train_data_sample, evaluation_data_sample, synthetic_data_sample
):
    gender_encoder = OneHotEncoder()
    all_unique_genders = np.unique(
        np.concatenate(
            [
                train_data_sample.gender.unique(),
                evaluation_data_sample.gender.unique(),
                synthetic_data_sample.gender.unique(),
            ],
            axis=0,
        )
    )
    gender_encoder.fit(all_unique_genders[:, np.newaxis])
    return gender_encoder


def extract_medical_concepts(concept_ids):
    concept_ids = [_ for _ in concept_ids[4:] if str.isnumeric(_)]
    return list(set(concept_ids))


def create_binary_format(concept_ids, concept_tokenizer):
    indices = np.array(concept_tokenizer.encode(concept_ids)).flatten().astype(int)
    embeddings = np.zeros(concept_tokenizer.vocab_size)
    embeddings.put(indices, 1)
    return embeddings

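# Example (illustrative, hypothetical tokenizer): if vocab_size == 6 and
# encode(["320128", "200219"]) returns [[2], [5]], then
#   create_binary_format(["320128", "200219"], concept_tokenizer)
#   -> array([0., 0., 1., 0., 0., 1.])
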
def extract_common_sensitive_concepts(
    dataset, concept_tokenizer, common_attributes, sensitive_attributes
):
    common_embeddings_list = []
    sensitive_embeddings_list = []
    for _, pat_seq in dataset.concept_ids.items():
        concept_ids = extract_medical_concepts(pat_seq)
        common_concept_ids = [_ for _ in concept_ids if _ in common_attributes]
        sensitive_concept_ids = [_ for _ in concept_ids if _ in sensitive_attributes]

        common_embeddings_list.append(
            create_binary_format(common_concept_ids, concept_tokenizer)
        )
        sensitive_embeddings_list.append(
            create_binary_format(sensitive_concept_ids, concept_tokenizer)
        )
    return np.array(common_embeddings_list), np.array(sensitive_embeddings_list)


def transform_concepts(dataset, concept_tokenizer):
    embedding_list = []
    for _, pat_seq in dataset.concept_ids.items():
        embedding_list.append(
            create_binary_format(extract_medical_concepts(pat_seq), concept_tokenizer)
        )
    return np.asarray(embedding_list)


def scale_age(dataset):
    ages = dataset.concept_ids.apply(
        lambda concept_list: get_demographics(concept_list)[1]
    )
    # Drop records whose age token could not be parsed (age == -1) and work on a
    # copy so the new column does not mutate a view of the original frame.
    dataset = dataset[ages >= 0].copy()
    ages = ages[ages >= 0]
    max_age = ages.max()
    dataset["scaled_age"] = ages / max_age
    return dataset


def create_demographics(dataset):
    genders = dataset.concept_ids.apply(
        lambda concept_list: get_demographics(concept_list)[2]
    )
    races = dataset.concept_ids.apply(
        lambda concept_list: get_demographics(concept_list)[3]
    )
    dataset["gender"] = genders
    dataset["race"] = races
    return dataset


def create_vector_representations(
    dataset, concept_tokenizer, gender_encoder, race_encoder
):
    concept_vectors = transform_concepts(dataset, concept_tokenizer)
    gender_vectors = gender_encoder.transform(
        dataset.gender.to_numpy()[:, np.newaxis]
    ).todense()
    race_vectors = race_encoder.transform(
        dataset.race.to_numpy()[:, np.newaxis]
    ).todense()
    age_vectors = dataset.scaled_age.to_numpy()[:, np.newaxis]

    pat_vectors = np.concatenate(
        [age_vectors, gender_vectors, race_vectors, concept_vectors], axis=-1
    )

    return np.asarray(pat_vectors)


def create_vector_representations_for_attribute(
    dataset,
    concept_tokenizer,
    gender_encoder,
    race_encoder,
    common_attributes,
    sensitive_attributes,
):
    common_concept_vectors, sensitive_concept_vectors = (
        extract_common_sensitive_concepts(
            dataset, concept_tokenizer, common_attributes, sensitive_attributes
        )
    )
    gender_vectors = gender_encoder.transform(
        dataset.gender.to_numpy()[:, np.newaxis]
    ).todense()
    race_vectors = race_encoder.transform(
        dataset.race.to_numpy()[:, np.newaxis]
    ).todense()
    age_vectors = dataset.scaled_age.to_numpy()[:, np.newaxis]

    common_pat_vectors = np.concatenate(
        [age_vectors, gender_vectors, race_vectors, common_concept_vectors], axis=-1
    )

    return np.asarray(common_pat_vectors), np.asarray(sensitive_concept_vectors)

def batched_pairwise_euclidean_distance_indices(A, B, batch_size, self_exclude=False):
    # Initialize arrays to hold the minimum distances and indices for each point in A
    min_distances = np.full((A.shape[0],), np.inf)
    min_indices = np.full((A.shape[0],), -1, dtype=int)

    # Iterate over A in batches
    for i in tqdm(range(0, A.shape[0], batch_size), total=A.shape[0] // batch_size + 1):
        end_i = i + batch_size
        A_batch = A[i:end_i]

        # Adjust the identity matrix size based on the actual batch size
        actual_batch_size = A_batch.shape[0]

        # Iterate over B in batches
        for j in range(0, B.shape[0], batch_size):
            end_j = j + batch_size
            B_batch = B[j:end_j]

            # Compute distances between the current batches of A and B
            distances = np.sqrt(
                np.sum(
                    (A_batch[:, np.newaxis, :] - B_batch[np.newaxis, :, :]) ** 2, axis=2
                )
            )

            # Apply the identity matrix to exclude self-matches if required
            if self_exclude and i == j:
                identity_matrix = np.eye(actual_batch_size) * 10e8
                distances += identity_matrix

            # Find the minimum distance and corresponding indices for the A batch
            min_batch_indices = np.argmin(distances, axis=1) + j
            min_batch_distances = np.min(distances, axis=1)

            # Update the minimum distances and indices if the current batch distances are smaller
            update_mask = min_batch_distances < min_distances[i:end_i]
            min_distances[i:end_i][update_mask] = min_batch_distances[update_mask]
            min_indices[i:end_i][update_mask] = min_batch_indices[update_mask]

    return min_indices

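# Example (illustrative): nearest real neighbor of every synthetic vector,
# computed in memory-bounded chunks instead of one full (n_a, n_b) matrix:
#   idx = batched_pairwise_euclidean_distance_indices(
#       synthetic_vectors, real_vectors, batch_size=1024
#   )
#   # idx[i] is the row of real_vectors closest to synthetic_vectors[i]
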
def find_match(source, target, return_index: bool = False):
    # Squared Euclidean distances via the expansion ||t||^2 + ||s||^2 - 2 t.s;
    # squared distances preserve the ordering of the true distances.
    a = np.sum(target**2, axis=1).reshape(target.shape[0], 1) + np.sum(
        source.T**2, axis=0
    )
    b = np.dot(target, source.T) * 2
    distance_matrix = a - b
    return (
        np.argmin(distance_matrix, axis=0)
        if return_index
        else np.min(distance_matrix, axis=0)
    )


def find_match_self(source, target, return_index: bool = False):
    a = np.sum(target**2, axis=1).reshape(target.shape[0], 1) + np.sum(
        source.T**2, axis=0
    )
    b = np.dot(target, source.T) * 2
    distance_matrix = a - b
    n_col = np.shape(distance_matrix)[1]

    if return_index:
        min_indices = np.zeros(n_col, dtype=int)
        for i in range(n_col):
            sorted_indices = np.argsort(distance_matrix[:, i])
            min_indices[i] = sorted_indices[1]  # Get index of second smallest value
        return min_indices
    else:
        min_distance = np.zeros(n_col)
        for i in range(n_col):
            sorted_column = np.sort(distance_matrix[:, i])
            min_distance[i] = sorted_column[1]  # Skip the smallest (self) distance
        return min_distance
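A small sanity check for the two matchers above (a sketch with made-up vectors; both functions compare squared distances, which preserves nearest-neighbor ordering):

    import numpy as np

    source = np.array([[0.0, 0.0], [1.0, 1.0]])
    target = np.array([[0.1, 0.0], [0.9, 1.0], [5.0, 5.0]])

    # For each source row, the index of the nearest target row.
    find_match(source, target, return_index=True)       # -> array([0, 1])

    # For each source row, the index of the second-nearest row, which skips
    # the trivial self-match when a dataset is compared against itself.
    find_match_self(source, source, return_index=True)  # -> array([1, 0])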