cehrgpt-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. __init__.py +0 -0
  2. cehrgpt/__init__.py +0 -0
  3. cehrgpt/analysis/__init__.py +0 -0
  4. cehrgpt/analysis/privacy/__init__.py +0 -0
  5. cehrgpt/analysis/privacy/attribute_inference.py +275 -0
  6. cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
  7. cehrgpt/analysis/privacy/member_inference.py +172 -0
  8. cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
  9. cehrgpt/analysis/privacy/reid_inference.py +407 -0
  10. cehrgpt/analysis/privacy/utils.py +255 -0
  11. cehrgpt/cehrgpt_args.py +142 -0
  12. cehrgpt/data/__init__.py +0 -0
  13. cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
  14. cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
  15. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
  16. cehrgpt/generation/__init__.py +0 -0
  17. cehrgpt/generation/chatgpt_generation.py +106 -0
  18. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
  19. cehrgpt/generation/omop_converter_batch.py +644 -0
  20. cehrgpt/generation/omop_entity.py +515 -0
  21. cehrgpt/gpt_utils.py +331 -0
  22. cehrgpt/models/__init__.py +0 -0
  23. cehrgpt/models/config.py +205 -0
  24. cehrgpt/models/hf_cehrgpt.py +1817 -0
  25. cehrgpt/models/hf_modeling_outputs.py +158 -0
  26. cehrgpt/models/pretrained_embeddings.py +82 -0
  27. cehrgpt/models/special_tokens.py +30 -0
  28. cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
  29. cehrgpt/omop/__init__.py +0 -0
  30. cehrgpt/omop/condition_era.py +20 -0
  31. cehrgpt/omop/observation_period.py +43 -0
  32. cehrgpt/omop/omop_argparse.py +38 -0
  33. cehrgpt/omop/omop_table_builder.py +86 -0
  34. cehrgpt/omop/queries/__init__.py +0 -0
  35. cehrgpt/omop/queries/condition_era.py +86 -0
  36. cehrgpt/omop/queries/observation_period.py +135 -0
  37. cehrgpt/omop/sample_omop_tables.py +71 -0
  38. cehrgpt/runners/__init__.py +0 -0
  39. cehrgpt/runners/gpt_runner_util.py +99 -0
  40. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
  41. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
  42. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
  43. cehrgpt/runners/hyperparameter_search_util.py +223 -0
  44. cehrgpt/time_to_event/__init__.py +0 -0
  45. cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
  46. cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
  47. cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
  48. cehrgpt/time_to_event/time_to_event_model.py +226 -0
  49. cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
  50. cehrgpt/time_to_event/time_to_event_utils.py +55 -0
  51. cehrgpt/tools/__init__.py +0 -0
  52. cehrgpt/tools/ehrshot_benchmark.py +74 -0
  53. cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
  54. cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
  55. cehrgpt/tools/upload_omop_tables.py +108 -0
  56. cehrgpt-0.0.1.dist-info/LICENSE +21 -0
  57. cehrgpt-0.0.1.dist-info/METADATA +66 -0
  58. cehrgpt-0.0.1.dist-info/RECORD +60 -0
  59. cehrgpt-0.0.1.dist-info/WHEEL +5 -0
  60. cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
cehrgpt/analysis/privacy/reid_inference.py
@@ -0,0 +1,407 @@
+ import os
+
+ from pyspark.sql import SparkSession
+ from pyspark.sql import functions as f
+ from pyspark.sql import types as t
+ from tqdm import tqdm
+
+ VOCAB_SIZE = 17870
+ L = 0.001
+ SENSITIVE_MATCH_THRESHOLD = int(VOCAB_SIZE * L)
+
+ # VOCAB_SIZE came from the tokenizer: it counts the numeric tokens that represent valid OMOP concept ids
+ COMMON_ATTRIBUTES = [
+     "320128",
+     "200219",
+     "77670",
+     "432867",
+     "254761",
+     "312437",
+     "378253",
+ ]
+ MAX_LEVELS = [1, 0, 0] + [1] * len(COMMON_ATTRIBUTES)
+ # The age groups are generated using the following bases, e.g. a base of 1 uses the age as is
+ AGE_GENERALIZATION_LEVELS = [10, 1]
+
+
+ def generate_lattice_bfs(top_gen_levels):  # BFS
+     """Came from https://github.com/yy6linda/synthetic-ehr-benchmarking/blob/main/privacy_evaluation/Synthetic_risk_model_reid.py#L81."""
+     visited = [top_gen_levels]
+     queue = [top_gen_levels]
+     lattice = []
+     while queue:
+         gen_levels = queue.pop(0)
+         lattice.append(gen_levels)
+         for i in range(len(gen_levels)):
+             if gen_levels[i] != 0:
+                 gen_levels_new = gen_levels.copy()
+                 gen_levels_new[i] -= 1
+                 if gen_levels_new not in visited:
+                     visited.append(gen_levels_new)
+                     queue.append(gen_levels_new)
+     return lattice
+
+
+ def update_dataset(reid_data, real_sample, synthetic_reid_data, config):
+     reid_data_dup = reid_data.alias("reid_data_dup")
+     real_sample_dup = real_sample.alias("real_sample_dup")
+     synthetic_reid_data_dup = synthetic_reid_data.alias("synthetic_reid_data_dup")
+     common_attributes_to_remove = []
+
+     for i in range(len(config)):
+         # age group
+         if i == 0:
+             age_level_index = config[i]
+             age_group_base = AGE_GENERALIZATION_LEVELS[age_level_index]
+             reid_data_dup = reid_data_dup.withColumn(
+                 "age", f.ceil(f.col("age").cast("int") / age_group_base)
+             )
+             synthetic_reid_data_dup = synthetic_reid_data_dup.withColumn(
+                 "age", f.ceil(f.col("age").cast("int") / age_group_base)
+             )
+             real_sample_dup = real_sample_dup.withColumn(
+                 "age", f.ceil(f.col("age").cast("int") / age_group_base)
+             )
+         elif i in [1, 2]:
+             # gender and race are not generalized
+             continue
+         else:
+             # this indicates that the common attribute should be removed
+             if config[i] == 0:
+                 common_attributes_to_remove.append(COMMON_ATTRIBUTES[i - 3])
+
+     def remove_common_attributes(concept_ids):
+         comm_atts = sorted(set(concept_ids) - set(common_attributes_to_remove))
+         if comm_atts:
+             return "-".join(comm_atts)
+         else:
+             return "empty"
+
+     extract_common_attributes_udf = f.udf(
+         lambda concept_ids: remove_common_attributes(concept_ids), t.StringType()
+     )
+
+     reid_data_dup = reid_data_dup.withColumn(
+         "common_attributes", extract_common_attributes_udf("common_attributes")
+     )
+
+     real_sample_dup = real_sample_dup.withColumn(
+         "common_attributes", extract_common_attributes_udf("common_attributes")
+     )
+
+     synthetic_reid_data_dup = synthetic_reid_data_dup.withColumn(
+         "common_attributes", extract_common_attributes_udf("common_attributes")
+     )
+     return reid_data_dup, real_sample_dup, synthetic_reid_data_dup
+
+
+ def calculate_reid_risk_score(
+     real_sample_dup,
+     reid_data_dup,
+     synthetic_reid_data_dup,
+     lower_n,
+     cap_n,
+     lambda_val=0.23,
+     num_salts=20,
+ ):
+     real_sample_dup = real_sample_dup.withColumn(
+         "salt", (f.rand() * num_salts).cast("int")
+     )
+
+     reid_data_stats = reid_data_dup.groupby(
+         "age", "gender", "race", "common_attributes"
+     ).count()
+
+     real_to_population_matches = real_sample_dup.join(
+         reid_data_stats,
+         (real_sample_dup["age"] == reid_data_stats["age"])
+         & (real_sample_dup["gender"] == reid_data_stats["gender"])
+         & (real_sample_dup["race"] == reid_data_stats["race"])
+         & (
+             real_sample_dup["common_attributes"] == reid_data_stats["common_attributes"]
+         ),
+     ).select("person_id", "count")
+
+     # Alias the DataFrame for self join
+     real_sample_stats = real_sample_dup.groupby(
+         "age", "gender", "race", "common_attributes"
+     ).count()
+
+     real_to_real_matches = real_sample_dup.join(
+         real_sample_stats,
+         (real_sample_dup["age"] == real_sample_stats["age"])
+         & (real_sample_dup["gender"] == real_sample_stats["gender"])
+         & (real_sample_dup["race"] == real_sample_stats["race"])
+         & (
+             real_sample_dup["common_attributes"]
+             == real_sample_stats["common_attributes"]
+         ),
+     ).select("person_id", "count")
+
+     reid_data_step_one = (
+         reid_data_dup.join(real_to_population_matches, "person_id")
+         .withColumnRenamed("count", "upper_F_s")
+         .join(real_to_real_matches, "person_id")
+         .withColumnRenamed("count", "lower_f_s")
+     )
+
+     synthetic_reid_data_dup = synthetic_reid_data_dup.withColumn(
+         "salt", f.explode(f.array([f.lit(x) for x in range(num_salts)]))
+     )
+
+     real_sample_dup = real_sample_dup.where(
+         f.size("sensitive_attributes") >= SENSITIVE_MATCH_THRESHOLD
+     )
+     synthetic_reid_data_dup = synthetic_reid_data_dup.where(
+         f.size("sensitive_attributes") >= SENSITIVE_MATCH_THRESHOLD
+     )
+
+     real_to_synthetic = (
+         real_sample_dup.join(
+             synthetic_reid_data_dup,
+             (real_sample_dup["age"] == synthetic_reid_data_dup["age"])
+             & (real_sample_dup["gender"] == synthetic_reid_data_dup["gender"])
+             & (real_sample_dup["race"] == synthetic_reid_data_dup["race"])
+             & (
+                 real_sample_dup["common_attributes"]
+                 == synthetic_reid_data_dup["common_attributes"]
+             )
+             & (real_sample_dup["salt"] == synthetic_reid_data_dup["salt"]),
+         )
+         .select(
+             real_sample_dup["person_id"],
+             real_sample_dup["sensitive_attributes"].alias("real_sensitive_attributes"),
+             synthetic_reid_data_dup["sensitive_attributes"].alias(
+                 "synthetic_sensitive_attributes"
+             ),
+         )
+         .withColumn(
+             "n_of_sensitive_matches",
+             f.size(
+                 f.array_intersect(
+                     "real_sensitive_attributes", "synthetic_sensitive_attributes"
+                 )
+             ),
+         )
+         .drop("real_sensitive_attributes", "synthetic_sensitive_attributes")
+     )
+
+     real_to_synthetic = (
+         real_to_synthetic.groupby("person_id")
+         .agg(f.max("n_of_sensitive_matches").alias("n_of_sensitive_matches"))
+         .withColumn(
+             "new_info",
+             (f.col("n_of_sensitive_matches") > f.lit(SENSITIVE_MATCH_THRESHOLD)).cast(
+                 "int"
+             ),
+         )
+     )
+
+     reid_data_step_two = reid_data_step_one.join(
+         real_to_synthetic, "person_id", how="left_outer"
+     ).withColumn("new_info", f.coalesce(real_to_synthetic["new_info"], f.lit(0)))
+
+     return (
+         reid_data_step_two.withColumn(
+             "A_term",
+             1 / f.col("lower_f_s") * f.lit((1 + lambda_val) / 2) * f.col("new_info"),
+         )
+         .withColumn(
+             "B_term",
+             1 / f.col("upper_f_s") * f.lit((1 + lambda_val) / 2) * f.col("new_info"),
+         )
+         .select(
+             (f.sum("A_term") / f.lit(cap_n)).alias("A_term"),
+             (f.sum("B_term") / f.lit(lower_n)).alias("B_term"),
+         )
+     )
+
+
+ def main(args):
+     all_configs = generate_lattice_bfs(MAX_LEVELS)
+
+     N = None
+     n = None
+     excluded_sensitive_attributes = None
+
+     for config in tqdm(all_configs, total=len(all_configs)):
+
+         spark = SparkSession.builder.appName(
+             f'Generate REID {"".join(map(str, config))}'
+         ).getOrCreate()
+
+         if sum(config) == 0:
+             continue
+
+         experiment_output_name = os.path.join(
+             args.output_folder, f'{"".join(map(str, config))}.parquet'
+         )
+         if os.path.exists(experiment_output_name):
+             continue
+
+         @f.udf(t.ArrayType(t.StringType()))
+         def extract_omop_concepts_udf(concept_ids):
+             return list(sorted(set([_ for _ in concept_ids if str.isnumeric(_)])))
+
+         def extract_common_attributes(concept_ids):
+             common_atts = set([c for c in concept_ids if c in COMMON_ATTRIBUTES])
+             return list(common_atts)
+
+         extract_common_attributes_udf = f.udf(
+             lambda concept_ids: extract_common_attributes(concept_ids),
+             t.ArrayType(t.StringType()),
+         )
+
+         def extract_sensitive_attributes(concept_ids):
+             return list(
+                 sorted(set([c for c in concept_ids if c not in COMMON_ATTRIBUTES]))
+             )
+
+         extract_sensitive_attributes_udf = f.udf(
+             lambda concept_ids: extract_sensitive_attributes(concept_ids),
+             t.ArrayType(t.StringType()),
+         )
+
+         training_data = spark.read.parquet(args.training_data_folder)
+         evaluation_data = spark.read.parquet(args.evaluation_data_folder)
+         training_data = training_data.select(
+             "concept_ids", "num_of_concepts", "person_id"
+         ).withColumn("is_real", f.lit(True))
+         evaluation_data = evaluation_data.select(
+             "concept_ids", "num_of_concepts", "person_id"
+         ).withColumn("is_real", f.lit(False))
+         data = training_data.unionByName(evaluation_data)
+
+         reid_data = (
+             data.withColumn("age", f.col("concept_ids")[1])
+             .withColumn("age", f.split("age", ":")[1])
+             .withColumn("gender", f.col("concept_ids")[2])
+             .withColumn("race", f.col("concept_ids")[3])
+             .withColumn("concept_ids", extract_omop_concepts_udf("concept_ids"))
+             .withColumn(
+                 "common_attributes", extract_common_attributes_udf("concept_ids")
+             )
+             .withColumn(
+                 "sensitive_attributes", extract_sensitive_attributes_udf("concept_ids")
+             )
+             .drop("concept_ids")
+         )
+
+         real_sample = reid_data.where("is_real")
+
+         if not N:
+             N = reid_data.count()
+         if not n:
+             n = real_sample.count()
+
+         if not excluded_sensitive_attributes:
+             excluded_sensitive_attributes_df = (
+                 real_sample.select(
+                     f.explode("sensitive_attributes").alias("sensitive_attribute")
+                 )
+                 .groupby("sensitive_attribute")
+                 .count()
+                 .withColumn("sensitive_attribute_prevalence", f.col("count") / n)
+                 .withColumn(
+                     "is_majority", f.col("sensitive_attribute_prevalence") >= 0.5
+                 )
+                 .where("is_majority")
+             )
+             excluded_sensitive_attributes = [
+                 row.sensitive_attribute
+                 for row in excluded_sensitive_attributes_df.select(
+                     "sensitive_attribute"
+                 ).collect()
+             ]
+
+         def filter_sensitive_attributes(concept_ids):
+             return [c for c in concept_ids if c not in excluded_sensitive_attributes]
+
+         filter_sensitive_attributes_udf = f.udf(
+             lambda concept_ids: filter_sensitive_attributes(concept_ids),
+             t.ArrayType(t.StringType()),
+         )
+
+         real_sample = real_sample.withColumn(
+             "sensitive_attributes",
+             filter_sensitive_attributes_udf("sensitive_attributes"),
+         )
+
+         synthetic_data = spark.read.parquet(args.synthetic_data_folder)
+
+         synthetic_reid_data = (
+             synthetic_data.withColumn("age", f.col("concept_ids")[1])
+             .withColumn("age", f.split("age", ":")[1])
+             .withColumn("gender", f.col("concept_ids")[2])
+             .withColumn("race", f.col("concept_ids")[3])
+             .withColumn("concept_ids", extract_omop_concepts_udf("concept_ids"))
+             .withColumn(
+                 "common_attributes", extract_common_attributes_udf("concept_ids")
+             )
+             .withColumn(
+                 "sensitive_attributes", extract_sensitive_attributes_udf("concept_ids")
+             )
+             .drop("concept_ids")
+         )
+
+         reid_data_dup, real_sample_dup, synthetic_reid_data_dup = update_dataset(
+             reid_data, real_sample, synthetic_reid_data, config=config
+         )
+         experiment_risk_score = calculate_reid_risk_score(
+             real_sample_dup,
+             reid_data_dup,
+             synthetic_reid_data_dup,
+             lower_n=n,
+             cap_n=N,
+             lambda_val=0.23,
+             num_salts=args.num_salts,
+         )
+         experiment_risk_score.toPandas().to_parquet(experiment_output_name)
+
+         spark.stop()
+
+
+ def create_argparser():
+     import argparse
+
+     parser = argparse.ArgumentParser(
+         description="Arguments for re-identification risk evaluation"
+     )
+     parser.add_argument(
+         "--training_data_folder",
+         dest="training_data_folder",
+         action="store",
+         required=True,
+     )
+     parser.add_argument(
+         "--evaluation_data_folder",
+         dest="evaluation_data_folder",
+         action="store",
+         required=True,
+     )
+     parser.add_argument(
+         "--synthetic_data_folder",
+         dest="synthetic_data_folder",
+         action="store",
+         required=True,
+     )
+     parser.add_argument(
+         "--output_folder",
+         dest="output_folder",
+         action="store",
+         help="The output folder for storing the results",
+         required=True,
+     )
+     parser.add_argument(
+         "--num_salts",
+         dest="num_salts",
+         action="store",
+         type=int,
+         default=40,
+         required=False,
+     )
+     return parser
+
+
+ if __name__ == "__main__":
+     main(create_argparser().parse_args())
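
For orientation, generate_lattice_bfs walks the generalization lattice top-down, lowering one level at a time until every combination has been visited, and main() then runs one Spark job per configuration, skipping the all-zero configuration and any configuration whose parquet output already exists. The snippet below is a minimal illustration of the lattice walk using a shortened four-element level vector rather than the full MAX_LEVELS; it is not part of the wheel.

    # Illustration only: a reduced level vector instead of the full MAX_LEVELS.
    from cehrgpt.analysis.privacy.reid_inference import generate_lattice_bfs

    lattice = generate_lattice_bfs([1, 0, 0, 1])
    # BFS from the most generalized configuration, lowering one level at a time:
    # [[1, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 0, 0]]
    print(lattice)
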
cehrgpt/analysis/privacy/utils.py
@@ -0,0 +1,255 @@
+ import logging
+
+ import numpy as np
+ from sklearn.preprocessing import OneHotEncoder
+ from tqdm import tqdm
+
+ LOG = logging.getLogger(__name__)
+
+ RANDOM_SEE = 42
+ NUM_OF_GENDERS = 3
+ NUM_OF_RACES = 21
+
+
+ def get_demographics(concept_ids):
+     year_token, age_token, gender, race = concept_ids[0:4]
+     try:
+         year = int(year_token[5:])
+     except ValueError:
+         LOG.error(
+             f"{year_token[5:]} cannot be converted to an integer, using the default value 1900"
+         )
+         year = 1900
+
+     try:
+         age = int(age_token[4:])
+     except ValueError:
+         LOG.error(
+             f"{age_token[4:]} cannot be converted to an integer, using the default value -1"
+         )
+         age = -1
+
+     return year, age, gender, race
+
+
+ def create_race_encoder(
+     train_data_sample, evaluation_data_sample, synthetic_data_sample
+ ):
+     race_encoder = OneHotEncoder()
+     all_unique_races = np.unique(
+         np.concatenate(
+             [
+                 train_data_sample.race.unique(),
+                 evaluation_data_sample.race.unique(),
+                 synthetic_data_sample.race.unique(),
+             ],
+             axis=0,
+         )
+     )
+     race_encoder.fit(all_unique_races[:, np.newaxis])
+     return race_encoder
+
+
+ def create_gender_encoder(
+     train_data_sample, evaluation_data_sample, synthetic_data_sample
+ ):
+     gender_encoder = OneHotEncoder()
+     all_unique_genders = np.unique(
+         np.concatenate(
+             [
+                 train_data_sample.gender.unique(),
+                 evaluation_data_sample.gender.unique(),
+                 synthetic_data_sample.gender.unique(),
+             ],
+             axis=0,
+         )
+     )
+     gender_encoder.fit(all_unique_genders[:, np.newaxis])
+     return gender_encoder
+
+
+ def extract_medical_concepts(concept_ids):
+     concept_ids = [_ for _ in concept_ids[4:] if str.isnumeric(_)]
+     return list(set(concept_ids))
+
+
+ def create_binary_format(concept_ids, concept_tokenizer):
+     indices = np.array(concept_tokenizer.encode(concept_ids)).flatten().astype(int)
+     embeddings = np.zeros(concept_tokenizer.vocab_size)
+     embeddings.put(indices, 1)
+     return embeddings
+
+
+ def extract_common_sensitive_concepts(
+     dataset, concept_tokenizer, common_attributes, sensitive_attributes
+ ):
+     common_embeddings_list = []
+     sensitive_embeddings_list = []
+     for _, pat_seq in dataset.concept_ids.items():
+         concept_ids = extract_medical_concepts(pat_seq)
+         common_concept_ids = [_ for _ in concept_ids if _ in common_attributes]
+         sensitive_concept_ids = [_ for _ in concept_ids if _ in sensitive_attributes]
+
+         common_embeddings_list.append(
+             create_binary_format(common_concept_ids, concept_tokenizer)
+         )
+         sensitive_embeddings_list.append(
+             create_binary_format(sensitive_concept_ids, concept_tokenizer)
+         )
+     return np.array(common_embeddings_list), np.array(sensitive_embeddings_list)
+
+
+ def transform_concepts(dataset, concept_tokenizer):
+     embedding_list = []
+     for _, pat_seq in dataset.concept_ids.items():
+         embedding_list.append(
+             create_binary_format(extract_medical_concepts(pat_seq), concept_tokenizer)
+         )
+     return np.asarray(embedding_list)
+
+
+ def scale_age(dataset):
+     ages = dataset.concept_ids.apply(
+         lambda concept_list: get_demographics(concept_list)[1]
+     )
+     dataset = dataset[ages >= 0]
+     ages = ages[ages >= 0]
+     max_age = ages.max()
+     dataset["scaled_age"] = ages / max_age
+     return dataset
+
+
+ def create_demographics(dataset):
+     genders = dataset.concept_ids.apply(
+         lambda concept_list: get_demographics(concept_list)[2]
+     )
+     races = dataset.concept_ids.apply(
+         lambda concept_list: get_demographics(concept_list)[3]
+     )
+     dataset["gender"] = genders
+     dataset["race"] = races
+     return dataset
+
+
+ def create_vector_representations(
+     dataset, concept_tokenizer, gender_encoder, race_encoder
+ ):
+     concept_vectors = transform_concepts(dataset, concept_tokenizer)
+     gender_vectors = gender_encoder.transform(
+         dataset.gender.to_numpy()[:, np.newaxis]
+     ).todense()
+     race_vectors = race_encoder.transform(
+         dataset.race.to_numpy()[:, np.newaxis]
+     ).todense()
+     age_vectors = dataset.scaled_age.to_numpy()[:, np.newaxis]
+
+     pat_vectors = np.concatenate(
+         [age_vectors, gender_vectors, race_vectors, concept_vectors], axis=-1
+     )
+
+     return np.asarray(pat_vectors)
+
+
+ def create_vector_representations_for_attribute(
+     dataset,
+     concept_tokenizer,
+     gender_encoder,
+     race_encoder,
+     common_attributes,
+     sensitive_attributes,
+ ):
+     common_concept_vectors, sensitive_concept_vectors = (
+         extract_common_sensitive_concepts(
+             dataset, concept_tokenizer, common_attributes, sensitive_attributes
+         )
+     )
+     gender_vectors = gender_encoder.transform(
+         dataset.gender.to_numpy()[:, np.newaxis]
+     ).todense()
+     race_vectors = race_encoder.transform(
+         dataset.race.to_numpy()[:, np.newaxis]
+     ).todense()
+     age_vectors = dataset.scaled_age.to_numpy()[:, np.newaxis]
+
+     common_pat_vectors = np.concatenate(
+         [age_vectors, gender_vectors, race_vectors, common_concept_vectors], axis=-1
+     )
+
+     return np.asarray(common_pat_vectors), np.asarray(sensitive_concept_vectors)
+
+
+ def batched_pairwise_euclidean_distance_indices(A, B, batch_size, self_exclude=False):
+     # Initialize arrays to hold the minimum distances and indices for each point in A
+     min_distances = np.full((A.shape[0],), np.inf)
+     min_indices = np.full((A.shape[0],), -1, dtype=int)
+
+     # Iterate over A in batches
+     for i in tqdm(range(0, A.shape[0], batch_size), total=A.shape[0] // batch_size + 1):
+         end_i = i + batch_size
+         A_batch = A[i:end_i]
+
+         # Adjust the identity matrix size based on the actual batch size
+         actual_batch_size = A_batch.shape[0]
+
+         # Iterate over B in batches
+         for j in range(0, B.shape[0], batch_size):
+             end_j = j + batch_size
+             B_batch = B[j:end_j]
+
+             # Compute distances between the current batches of A and B
+             distances = np.sqrt(
+                 np.sum(
+                     (A_batch[:, np.newaxis, :] - B_batch[np.newaxis, :, :]) ** 2, axis=2
+                 )
+             )
+
+             # Apply the identity matrix to exclude self-matches if required
+             if self_exclude and i == j:
+                 identity_matrix = np.eye(actual_batch_size) * 10e8
+                 distances += identity_matrix
+
+             # Find the minimum distance and corresponding indices for the A batch
+             min_batch_indices = np.argmin(distances, axis=1) + j
+             min_batch_distances = np.min(distances, axis=1)
+
+             # Update the minimum distances and indices if the current batch distances are smaller
+             update_mask = min_batch_distances < min_distances[i:end_i]
+             min_distances[i:end_i][update_mask] = min_batch_distances[update_mask]
+             min_indices[i:end_i][update_mask] = min_batch_indices[update_mask]
+
+     return min_indices
+
+
+ def find_match(source, target, return_index: bool = False):
+     a = np.sum(target**2, axis=1).reshape(target.shape[0], 1) + np.sum(
+         source.T**2, axis=0
+     )
+     b = np.dot(target, source.T) * 2
+     distance_matrix = a - b
+     return (
+         np.argmin(distance_matrix, axis=0)
+         if return_index
+         else np.min(distance_matrix, axis=0)
+     )
+
+
+ def find_match_self(source, target, return_index: bool = False):
+     a = np.sum(target**2, axis=1).reshape(target.shape[0], 1) + np.sum(
+         source.T**2, axis=0
+     )
+     b = np.dot(target, source.T) * 2
+     distance_matrix = a - b
+     n_col = np.shape(distance_matrix)[1]
+
+     if return_index:
+         min_indices = np.zeros(n_col, dtype=int)
+         for i in range(n_col):
+             sorted_indices = np.argsort(distance_matrix[:, i])
+             min_indices[i] = sorted_indices[1]  # Get index of second smallest value
+         return min_indices
+     else:
+         min_distance = np.zeros(n_col)
+         for i in range(n_col):
+             sorted_column = np.sort(distance_matrix[:, i])
+             min_distance[i] = sorted_column[1]
+         return min_distance
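
A note on the two nearest-neighbor helpers above, with a small self-contained check (an illustration, not part of the wheel): find_match and find_match_self build distance_matrix from the expansion ||t - s||^2 = ||t||^2 + ||s||^2 - 2 t.s, so its entries are squared Euclidean distances rather than distances, and find_match_self takes the second-smallest entry per column to skip the trivial zero-distance self-match when source and target come from the same set.

    import numpy as np

    # Minimal check of the squared-distance expansion used by find_match / find_match_self.
    rng = np.random.default_rng(0)
    source = rng.normal(size=(5, 3))  # e.g. synthetic patient vectors
    target = rng.normal(size=(4, 3))  # e.g. real patient vectors

    # ||t - s||^2 = ||t||^2 + ||s||^2 - 2 * t.s, computed without an explicit loop
    a = np.sum(target**2, axis=1).reshape(target.shape[0], 1) + np.sum(source.T**2, axis=0)
    b = np.dot(target, source.T) * 2
    squared_distances = a - b  # shape (4, 5): one row per target, one column per source

    # The direct pairwise computation agrees up to floating-point error.
    direct = np.sum((target[:, None, :] - source[None, :, :]) ** 2, axis=2)
    assert np.allclose(squared_distances, direct)

    # np.min(squared_distances, axis=0), as in find_match, gives each source row's
    # squared distance to its nearest target.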