cehrgpt 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- __init__.py +0 -0
- cehrgpt/__init__.py +0 -0
- cehrgpt/analysis/__init__.py +0 -0
- cehrgpt/analysis/privacy/__init__.py +0 -0
- cehrgpt/analysis/privacy/attribute_inference.py +275 -0
- cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
- cehrgpt/analysis/privacy/member_inference.py +172 -0
- cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
- cehrgpt/analysis/privacy/reid_inference.py +407 -0
- cehrgpt/analysis/privacy/utils.py +255 -0
- cehrgpt/cehrgpt_args.py +142 -0
- cehrgpt/data/__init__.py +0 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
- cehrgpt/generation/__init__.py +0 -0
- cehrgpt/generation/chatgpt_generation.py +106 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
- cehrgpt/generation/omop_converter_batch.py +644 -0
- cehrgpt/generation/omop_entity.py +515 -0
- cehrgpt/gpt_utils.py +331 -0
- cehrgpt/models/__init__.py +0 -0
- cehrgpt/models/config.py +205 -0
- cehrgpt/models/hf_cehrgpt.py +1817 -0
- cehrgpt/models/hf_modeling_outputs.py +158 -0
- cehrgpt/models/pretrained_embeddings.py +82 -0
- cehrgpt/models/special_tokens.py +30 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
- cehrgpt/omop/__init__.py +0 -0
- cehrgpt/omop/condition_era.py +20 -0
- cehrgpt/omop/observation_period.py +43 -0
- cehrgpt/omop/omop_argparse.py +38 -0
- cehrgpt/omop/omop_table_builder.py +86 -0
- cehrgpt/omop/queries/__init__.py +0 -0
- cehrgpt/omop/queries/condition_era.py +86 -0
- cehrgpt/omop/queries/observation_period.py +135 -0
- cehrgpt/omop/sample_omop_tables.py +71 -0
- cehrgpt/runners/__init__.py +0 -0
- cehrgpt/runners/gpt_runner_util.py +99 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
- cehrgpt/runners/hyperparameter_search_util.py +223 -0
- cehrgpt/time_to_event/__init__.py +0 -0
- cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
- cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
- cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
- cehrgpt/time_to_event/time_to_event_model.py +226 -0
- cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
- cehrgpt/time_to_event/time_to_event_utils.py +55 -0
- cehrgpt/tools/__init__.py +0 -0
- cehrgpt/tools/ehrshot_benchmark.py +74 -0
- cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
- cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
- cehrgpt/tools/upload_omop_tables.py +108 -0
- cehrgpt-0.0.1.dist-info/LICENSE +21 -0
- cehrgpt-0.0.1.dist-info/METADATA +66 -0
- cehrgpt-0.0.1.dist-info/RECORD +60 -0
- cehrgpt-0.0.1.dist-info/WHEEL +5 -0
- cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
cehrgpt/generation/omop_converter_batch.py
@@ -0,0 +1,644 @@
import argparse
import datetime
import glob
import os
import uuid
from datetime import timedelta
from multiprocessing import Pool
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
import pyarrow.parquet as pq
from tqdm import tqdm

from cehrgpt.generation.omop_entity import (
    ConditionOccurrence,
    Death,
    DrugExposure,
    Measurement,
    OmopEntity,
    Person,
    ProcedureOccurrence,
    VisitOccurrence,
)
from cehrgpt.gpt_utils import (
    extract_time_interval_in_days,
    generate_artificial_time_tokens,
    is_inpatient_att_token,
    is_visit_end,
    is_visit_start,
)
from cehrgpt.models.tokenization_hf_cehrgpt import END_TOKEN

# TODO: move these to cehrbert_data
STOP_TOKENS = ["VE", "[VE]", END_TOKEN]

CURRENT_PATH = Path(__file__).parent
START_TOKEN_SIZE = 4
ATT_TIME_TOKENS = generate_artificial_time_tokens()
TABLE_LIST = [
    "person",
    "visit_occurrence",
    "condition_occurrence",
    "procedure_occurrence",
    "drug_exposure",
    "death",
    "measurement",
]
DISCHARGE_CONCEPT_LIST = [4216643, 4021968, 4146681, 4161979]
OOV_CONCEPT_MAP = {
    1525734: "Drug",
    779414: "Drug",
    722117: "Drug",
    722118: "Drug",
    722119: "Drug",
    905420: "Drug",
    1525543: "Drug",
}


def create_folder_if_not_exists(output_folder, table_name):
    if not os.path.isdir(Path(output_folder) / table_name):
        os.mkdir(Path(output_folder) / table_name)


def generate_omop_concept_domain(concept_parquet) -> Dict[int, str]:
    """
    Generate a dictionary of concept_id to domain_id.

    :param concept_parquet: concept dataframe read from parquet file
    :return: dictionary of concept_id to domain_id
    """
    domain_dict = {}
    for i in concept_parquet.itertuples():
        domain_dict[i.concept_id] = i.domain_id
    return domain_dict
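

# Illustrative sketch (not part of the original file): given a toy concept
# table, the function above produces a plain {concept_id: domain_id} lookup.
#
#     concepts = pd.DataFrame(
#         {"concept_id": [9201, 320128], "domain_id": ["Visit", "Condition"]}
#     )
#     generate_omop_concept_domain(concepts)
#     # -> {9201: "Visit", 320128: "Condition"}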


def generate_lab_stats_mapping(
    all_lab_stats: Optional[List[Dict[str, Any]]]
) -> Dict[int, Dict[str, Any]]:
    lab_stats_mapping = {}
    if all_lab_stats is not None:
        for lab_stats in all_lab_stats:
            # TODO: the numeric check will not hold true if we concatenate
            # the concept with the corresponding unit concept
            if lab_stats["concept_id"].isnumeric():
                concept_id = int(lab_stats["concept_id"])
                count = lab_stats["count"]
                # Keep the stats entry backed by the largest count
                if (
                    concept_id not in lab_stats_mapping
                    or count > lab_stats_mapping[concept_id]["count"]
                ):
                    lab_stats_mapping[concept_id] = {
                        "mean": lab_stats["mean"],
                        "std": lab_stats["std"],
                        "count": lab_stats["count"],
                    }
    return lab_stats_mapping
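

# A sketch of the expected input/output (values are illustrative):
#
#     generate_lab_stats_mapping(
#         [{"concept_id": "3004249", "mean": 120.0, "std": 15.0, "count": 10}]
#     )
#     # -> {3004249: {"mean": 120.0, "std": 15.0, "count": 10}}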


def append_to_dict(
    export_dict: Dict[str, Dict[int, OmopEntity]],
    omop_entity: OmopEntity,
    entity_id: int,
):
    if omop_entity.get_table_name() not in export_dict:
        export_dict[omop_entity.get_table_name()] = {}
    export_dict[omop_entity.get_table_name()][entity_id] = omop_entity


def delete_bad_sequence(
    export_dict: Dict[str, Dict[int, OmopEntity]],
    id_mappings: Dict[str, Dict[int, int]],
    person_id: int,
):
    for table_name, id_mapping in id_mappings.items():
        omop_id_mapping = np.array(list(id_mapping.keys()))
        person_id_mapping = np.array(list(id_mapping.values()))
        ids_to_delete = omop_id_mapping[np.where(person_id_mapping == person_id)]
        for entity_id in ids_to_delete:
            export_dict[table_name].pop(entity_id)
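

# Shapes of the two bookkeeping dicts used above (values are illustrative):
#
#     export_dict = {"person": {1001: Person(...)},
#                    "condition_occurrence": {2001: ConditionOccurrence(...)}}
#     id_mappings = {"person": {1001: 1001},
#                    "condition_occurrence": {2001: 1001}}
#
# so deleting person 1001 walks every table and pops each entity whose id
# maps back to that person.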


def export_and_clear(
    output_folder: str,
    export_dict: Dict[str, Dict[int, OmopEntity]],
    export_error: Dict[str, List[str]],
    id_mappings_dict: Dict[str, Dict[int, int]],
    pt_seq_dict: Dict[int, str],
    is_parquet: bool = True,
):
    for table_name, records_to_export in export_dict.items():
        records_in_json = []
        # If there is no omop_entity for this table, skip it
        if len(export_dict[table_name]) == 0:
            continue

        for entity_id, omop_entity in export_dict[table_name].items():
            try:
                records_in_json.append(omop_entity.export_as_json())
            except AttributeError:
                # Append the patient sequence to the export error list using pt_seq_dict
                if table_name not in export_error:
                    export_error[table_name] = []
                person_id = id_mappings_dict[table_name][entity_id]
                export_error[table_name].append(pt_seq_dict[person_id])
                continue
        schema = next(iter(records_to_export.items()))[1].get_schema()
        output_folder_path = Path(output_folder)
        # Match the file extension to the output format
        file_extension = "parquet" if is_parquet else "csv"
        file_path = output_folder_path / table_name / f"{uuid.uuid4()}.{file_extension}"
        table_df = pd.DataFrame(records_in_json, columns=schema)

        if is_parquet:
            table_df.to_parquet(file_path)
        else:
            table_df.to_csv(file_path, header=schema, index=False)

        export_dict[table_name].clear()
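

# Each flush writes one file per table, named with a fresh uuid4, e.g.
# (illustrative layout):
#
#     <output_folder>/person/7f9c....parquet
#     <output_folder>/visit_occurrence/a31e....parquet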


def _is_none(x):
    return x is None or np.isnan(x)


def get_num_records(parquet_files: List[str]):
    total = 0
    for file_path in parquet_files:
        parquet_file = pq.ParquetFile(file_path)
        total += parquet_file.metadata.num_rows
    return total


def record_generator(parquet_files):
    for file_path in parquet_files:
        df = pd.read_parquet(file_path)
        for record in df.itertuples():
            yield record


def gpt_to_omop_converter_batch(
    const: int,
    patient_sequence_parquet_files: List[str],
    domain_map: Dict[int, str],
    output_folder: str,
    buffer_size: int,
    use_original_person_id: bool,
):
    omop_export_dict = {}
    error_dict = {}
    export_error = {}
    id_mappings_dict = {}
    pt_seq_dict = {}

    for tb in TABLE_LIST:
        create_folder_if_not_exists(output_folder, tb)
        id_mappings_dict[tb] = {}

    visit_occurrence_id: int = const + 1
    condition_occurrence_id: int = const + 1
    procedure_occurrence_id: int = const + 1
    drug_exposure_id: int = const + 1
    measurement_id: int = const + 1

    # Default the person_id
    person_id: int = const + 1

    patient_record_generator = record_generator(patient_sequence_parquet_files)
    total_record = get_num_records(patient_sequence_parquet_files)

    for index, record in tqdm(enumerate(patient_record_generator), total=total_record):
        bad_sequence = False
        # If use_original_person_id is set to true, we retrieve it from the record.
        # If person_id does not exist in the record, we use the default person_id.
        if use_original_person_id:
            person_id = getattr(record, "person_id", person_id)

        # Retrieve the token sequence and the value-related columns from the record
        concept_ids = getattr(record, "concept_ids")
        is_numeric_types = getattr(record, "is_numeric_types", None)
        number_as_values = getattr(record, "number_as_values", None)
        concept_as_values = getattr(record, "concept_as_values", None)
        units = getattr(record, "units", None)

        # Skip the start token if it is the first token
        if "start" in concept_ids[0].lower():
            concept_ids = concept_ids[1:]
            if is_numeric_types is not None:
                is_numeric_types = is_numeric_types[1:]
            if number_as_values is not None:
                number_as_values = number_as_values[1:]
            if concept_as_values is not None:
                concept_as_values = concept_as_values[1:]
            if units is not None:
                units = units[1:]

        clinical_events = concept_ids[START_TOKEN_SIZE:]
        # Skip sequences whose length is 0
        if len(clinical_events) == 0:
            continue
        # Skip patients whose last token is not a valid end token
        if clinical_events[-1] not in STOP_TOKENS:
            continue

        is_numeric_types = (
            is_numeric_types[START_TOKEN_SIZE:]
            if is_numeric_types is not None
            else None
        )
        number_as_values = (
            number_as_values[START_TOKEN_SIZE:]
            if number_as_values is not None
            else None
        )
        concept_as_values = (
            concept_as_values[START_TOKEN_SIZE:]
            if concept_as_values is not None
            else None
        )
        units = units[START_TOKEN_SIZE:] if units is not None else None

        # TODO: need to decode if the input is tokenized
        [start_year, start_age, start_gender, start_race] = concept_ids[
            0:START_TOKEN_SIZE
        ]
        if "year" not in start_year.lower():
            continue

        try:
            start_year = start_year.split(":")[1]
            start_age = start_age.split(":")[1]
            birth_year = int(start_year) - int(start_age)
        except Exception as e:
            print(
                f"Failed to convert {concept_ids[0:START_TOKEN_SIZE]} due to {e}, skipping to the next record"
            )
            continue

        # Skip patients whose birth year is either before 1900 or after the current year
        if int(birth_year) < 1900 or int(birth_year) > datetime.date.today().year:
            continue

        p = Person(person_id, start_gender, birth_year, start_race)
        append_to_dict(omop_export_dict, p, person_id)
        id_mappings_dict["person"][person_id] = person_id
        pt_seq_dict[person_id] = " ".join(concept_ids)
        discharged_to_concept_id = 0
        date_cursor = datetime.datetime(year=int(start_year), month=1, day=1)
        vo = None
        inpatient_visit_indicator = False

        for event_idx, event in enumerate(clinical_events):
            # For bad sequences, we don't proceed further and break from the for loop
            if bad_sequence:
                break
            if is_visit_start(event):
                if event_idx == len(clinical_events) - 1:
                    break
                elif clinical_events[event_idx + 1] == "[DEATH]":
                    # If the [DEATH] token is not placed at the end of the sequence, this is a bad sequence
                    if event_idx + 2 != len(clinical_events) - 1:
                        bad_sequence = True
                        break
                    death = Death(p, date_cursor.date())
                    append_to_dict(omop_export_dict, death, person_id)
                    id_mappings_dict["death"][person_id] = person_id
                else:
                    try:
                        visit_concept_id = int(clinical_events[event_idx + 1])
                        inpatient_visit_indicator = visit_concept_id in [
                            9201,
                            262,
                            8971,
                            8920,
                        ]
                        if visit_concept_id in domain_map:
                            if (
                                domain_map[visit_concept_id] != "Visit"
                                and visit_concept_id != 0
                            ):
                                bad_sequence = True
                                break
                        else:
                            bad_sequence = True
                            break

                    except (IndexError, ValueError):
                        error_dict[person_id] = {}
                        error_dict[person_id]["concept_ids"] = " ".join(concept_ids)
                        error_dict[person_id]["error"] = "Wrong visit concept id"
                        bad_sequence = True
                        continue

                    vo = VisitOccurrence(
                        visit_occurrence_id, visit_concept_id, date_cursor, p
                    )
                    append_to_dict(omop_export_dict, vo, visit_occurrence_id)
                    id_mappings_dict["visit_occurrence"][
                        visit_occurrence_id
                    ] = person_id
                    visit_occurrence_id += 1
            elif event in ATT_TIME_TOKENS:
                if event[0] == "D":
                    att_date_delta = int(event[1:])
                elif event[0] == "W":
                    att_date_delta = int(event[1:]) * 7
                elif event[0] == "M":
                    att_date_delta = int(event[1:]) * 30
                elif event == "LT":
                    att_date_delta = 365 * 3
                else:
                    att_date_delta = 0
                # Between visits, the date delta is simply calculated as the date difference
                date_cursor = date_cursor.replace(
                    hour=0, minute=0, second=0, microsecond=0
                )
                date_cursor = date_cursor + timedelta(days=att_date_delta)
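                # Illustrative decoding of the ATT tokens handled above:
                #   "D7" -> 7 days, "W2" -> 14 days, "M3" -> 90 days,
                #   "LT" -> 1095 days (any long-term gap collapses to 3 years)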
            elif inpatient_visit_indicator and is_inpatient_att_token(event):
                inpatient_time_span_in_days = extract_time_interval_in_days(event)
                # Reset the date cursor to the start of the day before adding the number of days parsed out of the token
                date_cursor = date_cursor.replace(hour=0, minute=0, second=0)
                date_cursor = date_cursor + timedelta(days=inpatient_time_span_in_days)
            elif inpatient_visit_indicator and event.startswith("i-H"):
                # Handle hour tokens differently from the day tokens.
                # Inpatient hour tokens are constructed so that the sum of consecutive
                # hour tokens cannot exceed the current day, so the date_cursor is
                # bounded by a theoretical upper limit
                upper_bound = date_cursor.replace(
                    hour=0, minute=0, second=0
                ) + timedelta(hours=23, minutes=59, seconds=59)
                hour_delta = int(event[3:])
                date_cursor = date_cursor + timedelta(hours=hour_delta)
                if date_cursor > upper_bound:
                    date_cursor = upper_bound
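                # Worked example of the bound (illustrative): starting from
                # 2015-03-02 10:00, token "i-H20" would advance the cursor to
                # 2015-03-03 06:00, past the same-day limit of 23:59:59, so it
                # is clamped back to 2015-03-02 23:59:59.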
            elif is_visit_end(event):
                if vo is None:
                    bad_sequence = True
                    break
                # If it's a VE token, nothing needs to be updated because it just means the visit ended
                if inpatient_visit_indicator:
                    vo.set_discharged_to_concept_id(discharged_to_concept_id)
                    vo.set_visit_end_date(date_cursor)
                    # If discharged_to_concept_id indicates the patient died, create the death record
                    if discharged_to_concept_id == 4216643:
                        death = Death(
                            p, date_cursor.date(), death_type_concept_id=32823
                        )
                        append_to_dict(omop_export_dict, death, person_id)
                        id_mappings_dict["death"][person_id] = person_id
                        # If a death record is generated, we need to stop the sequence conversion
                        break
                else:
                    pass
            elif event in [
                "START",
                start_year,
                start_age,
                start_gender,
                start_race,
                "[DEATH]",
            ]:
                # If it's a start token, skip it
                pass
            else:
                try:
                    concept_id = int(event)
                    if (
                        concept_id not in domain_map
                        and concept_id not in OOV_CONCEPT_MAP
                    ):
                        error_dict[person_id] = {}
                        error_dict[person_id]["concept_ids"] = " ".join(concept_ids)
                        error_dict[person_id][
                            "error"
                        ] = f"No concept id found: {concept_id}"
                        bad_sequence = True
                        continue
                    else:
                        # The concept 'Patient Died' (4216643) can only occur in the
                        # discharged_to_concept_id field, which means the current visit
                        # has to be an inpatient visit and the concept can only appear
                        # in the second-to-last position of the sequence
                        if concept_id == 4216643:
                            # If the current visit is not inpatient, reject the sequence
                            if not inpatient_visit_indicator:
                                bad_sequence = True
                                continue
                            # # If the current token is not the second last one of the sequence, reject because
                            # # death can only appear at the end of the sequence
                            # if idx + 1 != len(tokens_generated) - 1:
                            #     bad_sequence = True
                            #     continue
                            # We also enforce the rule that the sequence has to end on a VE token
                            if event_idx + 1 < len(
                                clinical_events
                            ) and not is_visit_end(clinical_events[event_idx + 1]):
                                bad_sequence = True
                                continue

                        if concept_id in domain_map:
                            domain = domain_map[concept_id]
                        elif concept_id in OOV_CONCEPT_MAP:
                            domain = OOV_CONCEPT_MAP[concept_id]
                        else:
                            domain = None

                        if domain == "Visit" or concept_id in DISCHARGE_CONCEPT_LIST:
                            discharged_to_concept_id = concept_id
                        elif domain == "Condition":
                            co = ConditionOccurrence(
                                condition_occurrence_id, concept_id, vo, date_cursor
                            )
                            append_to_dict(
                                omop_export_dict, co, condition_occurrence_id
                            )
                            id_mappings_dict["condition_occurrence"][
                                condition_occurrence_id
                            ] = person_id
                            condition_occurrence_id += 1
                        elif domain == "Procedure":
                            po = ProcedureOccurrence(
                                procedure_occurrence_id, concept_id, vo, date_cursor
                            )
                            append_to_dict(
                                omop_export_dict, po, procedure_occurrence_id
                            )
                            id_mappings_dict["procedure_occurrence"][
                                procedure_occurrence_id
                            ] = person_id
                            procedure_occurrence_id += 1
                        elif domain == "Drug":
                            de = DrugExposure(
                                drug_exposure_id, concept_id, vo, date_cursor
                            )
                            append_to_dict(omop_export_dict, de, drug_exposure_id)
                            id_mappings_dict["drug_exposure"][
                                drug_exposure_id
                            ] = person_id
                            drug_exposure_id += 1
                        elif domain == "Measurement":
                            number_as_value = (
                                number_as_values[event_idx]
                                if number_as_values is not None
                                else None
                            )
                            concept_as_value = (
                                concept_as_values[event_idx]
                                if concept_as_values is not None
                                else None
                            )
                            is_numeric_type = (
                                is_numeric_types[event_idx]
                                if is_numeric_types is not None
                                else None
                            )
                            unit = units[event_idx] if units is not None else None
                            m = Measurement(
                                measurement_id,
                                measurement_concept_id=concept_id,
                                is_numeric_type=is_numeric_type,
                                value_as_number=number_as_value,
                                value_as_concept_id=concept_as_value,
                                visit_occurrence=vo,
                                measurement_datetime=date_cursor,
                                unit_source_value=unit,
                            )
                            append_to_dict(omop_export_dict, m, measurement_id)
                            id_mappings_dict["measurement"][measurement_id] = person_id
                            measurement_id += 1

                except ValueError:
                    error_dict[person_id] = {}
                    error_dict[person_id]["concept_ids"] = " ".join(concept_ids)
                    error_dict[person_id]["error"] = f"Wrong concept id: {event}"
                    bad_sequence = True
                    continue
        if bad_sequence:
            delete_bad_sequence(omop_export_dict, id_mappings_dict, person_id)

        if not use_original_person_id:
            person_id += 1

        if index != 0 and index % buffer_size == 0:
            export_and_clear(
                output_folder,
                omop_export_dict,
                export_error,
                id_mappings_dict,
                pt_seq_dict,
            )

    # Final flush to disk if there are still records in the cache
    export_and_clear(
        output_folder, omop_export_dict, export_error, id_mappings_dict, pt_seq_dict
    )

    with open(Path(output_folder) / "concept_errors.txt", "w") as f:
        error_dict["total"] = len(error_dict)
        f.write(str(error_dict))
    with open(Path(output_folder) / "export_errors.txt", "w") as f:
        total = 0
        for k, v in export_error.items():
            total += len(v)
        export_error["total"] = total
        f.write(str(export_error))
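

# A toy patient sequence of the shape the converter above expects (a sketch;
# the exact visit start/end token spellings are defined by is_visit_start and
# is_visit_end in cehrgpt.gpt_utils, so treat [VS]/[VE] below as illustrative):
#
#     year:2015 age:40 8507 8527           <- 4 start tokens: year, age, gender, race
#     [VS] 9201 320128 i-D3 4216643 [VE]   <- one inpatient visit ending in death
#
# Demographics come first, visits are delimited by start/end tokens, and ATT
# tokens such as D7 / W2 / M3 / LT advance the date cursor between visits.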


def main(args):
    all_parquet_files = glob.glob(
        os.path.join(args.patient_sequence_path, "*parquet"), recursive=True
    )
    if len(all_parquet_files) == 0:
        raise RuntimeError(f"No parquet files found in {args.patient_sequence_path}")

    print(
        f"There are in total {len(all_parquet_files)} parquet files detected in {args.patient_sequence_path}."
    )
    if not os.path.exists(args.output_folder):
        Path(args.output_folder).mkdir(parents=True, exist_ok=True)

    batched_parquet_files = np.array_split(all_parquet_files, args.cpu_cores)
    concept_pd = pd.read_parquet(args.concept_path)
    domain_map = generate_omop_concept_domain(concept_pd)

    pool_tuples = []
    # TODO: need to make this dynamic
    # Each worker gets its own id offset (const * i) so ids generated in parallel never collide
    const = 10000000
    for i in range(1, args.cpu_cores + 1):
        pool_tuples.append(
            (
                const * i,
                batched_parquet_files[i - 1],
                domain_map,
                args.output_folder,
                args.buffer_size,
                args.use_original_person_id,
            )
        )

    with Pool(processes=args.cpu_cores) as p:
        p.starmap(gpt_to_omop_converter_batch, pool_tuples)
        p.close()
        p.join()

    print("Done")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Arguments for converting patient sequences to OMOP"
    )
    parser.add_argument(
        "--output_folder",
        dest="output_folder",
        action="store",
        help="The path to the output folder",
        required=True,
    )
    parser.add_argument(
        "--concept_path",
        dest="concept_path",
        action="store",
        help="The path to the OMOP concept table",
        required=True,
    )
    parser.add_argument(
        "--buffer_size",
        dest="buffer_size",
        action="store",
        type=int,
        help="The number of patient sequences to process before flushing to disk",
        required=False,
        default=1024,
    )
    parser.add_argument(
        "--patient_sequence_path",
        dest="patient_sequence_path",
        action="store",
        help="The path to the patient sequence parquet files",
        required=True,
    )
    parser.add_argument(
        "--cpu_cores",
        dest="cpu_cores",
        type=int,
        action="store",
        help="The number of cpu cores to use for multiprocessing",
        required=False,
        default=1,
    )
    parser.add_argument(
        "--use_original_person_id",
        dest="use_original_person_id",
        action="store_true",
        help="Whether or not to use the original person id",
    )

    main(parser.parse_args())
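

# Example invocation (paths are illustrative):
#
#     python -m cehrgpt.generation.omop_converter_batch \
#         --patient_sequence_path /data/synthetic_sequences \
#         --concept_path /data/omop/concept \
#         --output_folder /data/omop_export \
#         --cpu_cores 4 \
#         --buffer_size 1024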