cehrgpt 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. __init__.py +0 -0
  2. cehrgpt/__init__.py +0 -0
  3. cehrgpt/analysis/__init__.py +0 -0
  4. cehrgpt/analysis/privacy/__init__.py +0 -0
  5. cehrgpt/analysis/privacy/attribute_inference.py +275 -0
  6. cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
  7. cehrgpt/analysis/privacy/member_inference.py +172 -0
  8. cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
  9. cehrgpt/analysis/privacy/reid_inference.py +407 -0
  10. cehrgpt/analysis/privacy/utils.py +255 -0
  11. cehrgpt/cehrgpt_args.py +142 -0
  12. cehrgpt/data/__init__.py +0 -0
  13. cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
  14. cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
  15. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
  16. cehrgpt/generation/__init__.py +0 -0
  17. cehrgpt/generation/chatgpt_generation.py +106 -0
  18. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
  19. cehrgpt/generation/omop_converter_batch.py +644 -0
  20. cehrgpt/generation/omop_entity.py +515 -0
  21. cehrgpt/gpt_utils.py +331 -0
  22. cehrgpt/models/__init__.py +0 -0
  23. cehrgpt/models/config.py +205 -0
  24. cehrgpt/models/hf_cehrgpt.py +1817 -0
  25. cehrgpt/models/hf_modeling_outputs.py +158 -0
  26. cehrgpt/models/pretrained_embeddings.py +82 -0
  27. cehrgpt/models/special_tokens.py +30 -0
  28. cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
  29. cehrgpt/omop/__init__.py +0 -0
  30. cehrgpt/omop/condition_era.py +20 -0
  31. cehrgpt/omop/observation_period.py +43 -0
  32. cehrgpt/omop/omop_argparse.py +38 -0
  33. cehrgpt/omop/omop_table_builder.py +86 -0
  34. cehrgpt/omop/queries/__init__.py +0 -0
  35. cehrgpt/omop/queries/condition_era.py +86 -0
  36. cehrgpt/omop/queries/observation_period.py +135 -0
  37. cehrgpt/omop/sample_omop_tables.py +71 -0
  38. cehrgpt/runners/__init__.py +0 -0
  39. cehrgpt/runners/gpt_runner_util.py +99 -0
  40. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
  41. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
  42. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
  43. cehrgpt/runners/hyperparameter_search_util.py +223 -0
  44. cehrgpt/time_to_event/__init__.py +0 -0
  45. cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
  46. cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
  47. cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
  48. cehrgpt/time_to_event/time_to_event_model.py +226 -0
  49. cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
  50. cehrgpt/time_to_event/time_to_event_utils.py +55 -0
  51. cehrgpt/tools/__init__.py +0 -0
  52. cehrgpt/tools/ehrshot_benchmark.py +74 -0
  53. cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
  54. cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
  55. cehrgpt/tools/upload_omop_tables.py +108 -0
  56. cehrgpt-0.0.1.dist-info/LICENSE +21 -0
  57. cehrgpt-0.0.1.dist-info/METADATA +66 -0
  58. cehrgpt-0.0.1.dist-info/RECORD +60 -0
  59. cehrgpt-0.0.1.dist-info/WHEEL +5 -0
  60. cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
cehrgpt/generation/omop_converter_batch.py
@@ -0,0 +1,644 @@
+ import argparse
+ import datetime
+ import glob
+ import os
+ import uuid
+ from datetime import timedelta
+ from multiprocessing import Pool
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+ import numpy as np
+ import pandas as pd
+ import pyarrow.parquet as pq
+ from tqdm import tqdm
+
+ from cehrgpt.generation.omop_entity import (
+     ConditionOccurrence,
+     Death,
+     DrugExposure,
+     Measurement,
+     OmopEntity,
+     Person,
+     ProcedureOccurrence,
+     VisitOccurrence,
+ )
+ from cehrgpt.gpt_utils import (
+     extract_time_interval_in_days,
+     generate_artificial_time_tokens,
+     is_inpatient_att_token,
+     is_visit_end,
+     is_visit_start,
+ )
+ from cehrgpt.models.tokenization_hf_cehrgpt import END_TOKEN
+
+ # TODO: move these to cehrbert_data
+ STOP_TOKENS = ["VE", "[VE]", END_TOKEN]
+
+ CURRENT_PATH = Path(__file__).parent
+ START_TOKEN_SIZE = 4
+ ATT_TIME_TOKENS = generate_artificial_time_tokens()
+ TABLE_LIST = [
+     "person",
+     "visit_occurrence",
+     "condition_occurrence",
+     "procedure_occurrence",
+     "drug_exposure",
+     "death",
+     "measurement",
+ ]
+ DISCHARGE_CONCEPT_LIST = [4216643, 4021968, 4146681, 4161979]
+ OOV_CONCEPT_MAP = {
+     1525734: "Drug",
+     779414: "Drug",
+     722117: "Drug",
+     722118: "Drug",
+     722119: "Drug",
+     905420: "Drug",
+     1525543: "Drug",
+ }
+
+
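+ # For orientation, an illustrative patient sequence in the shape this converter
+ # consumes: four demographic start tokens, then visit blocks delimited by visit
+ # start/end tokens, with ATT time tokens in between, e.g.
+ #
+ #   year:2015 age:47 8507 8527 [VS] 9202 320128 [VE] M3 [VS] 9201 i-H12 4216643 [VE]
+ #
+ # The token values above are hypothetical; the exact visit start/end spellings
+ # are whatever is_visit_start / is_visit_end in cehrgpt.gpt_utils accept.
+
+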
+ def create_folder_if_not_exists(output_folder, table_name):
+     if not os.path.isdir(Path(output_folder) / table_name):
+         os.mkdir(Path(output_folder) / table_name)
+
+
+ def generate_omop_concept_domain(concept_parquet) -> Dict[int, str]:
+     """
+     Generate a dictionary of concept_id to domain_id.
+
+     :param concept_parquet: concept dataframe read from parquet file
+     :return: dictionary of concept_id to domain_id
+     """
+     domain_dict = {}
+     for i in concept_parquet.itertuples():
+         domain_dict[i.concept_id] = i.domain_id
+     return domain_dict
+
+
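+ # The concept dataframe is assumed to carry at least the standard OMOP
+ # concept_id and domain_id columns, e.g. 9201 ("Inpatient Visit") -> "Visit",
+ # so the converter can check each generated concept against its domain.
+
+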
+ def generate_lab_stats_mapping(
+     all_lab_stats: Optional[List[Dict[str, Any]]]
+ ) -> Dict[int, Dict[str, Any]]:
+     lab_stats_mapping = {}
+     if all_lab_stats is not None:
+         for lab_stats in all_lab_stats:
+             # TODO: the numeric check will not hold true if we concatenate
+             #  the concept with the corresponding unit concept
+             if lab_stats["concept_id"].isnumeric():
+                 concept_id = int(lab_stats["concept_id"])
+                 count = lab_stats["count"]
+                 # Keep the stats entry backed by the largest observation count
+                 if concept_id in lab_stats_mapping:
+                     if count > lab_stats_mapping[concept_id]["count"]:
+                         lab_stats_mapping[concept_id] = {
+                             "mean": lab_stats["mean"],
+                             "std": lab_stats["std"],
+                             "count": lab_stats["count"],
+                         }
+                 else:
+                     lab_stats_mapping[concept_id] = {
+                         "mean": lab_stats["mean"],
+                         "std": lab_stats["std"],
+                         "count": lab_stats["count"],
+                     }
+     return lab_stats_mapping
+
+
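+ # Each element of all_lab_stats is assumed to look roughly like
+ # {"concept_id": "3004249", "mean": 120.0, "std": 15.0, "count": 532};
+ # the field names come from the lookups above, the values are hypothetical.
+
+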
+ def append_to_dict(
+     export_dict: Dict[str, Dict[int, OmopEntity]],
+     omop_entity: OmopEntity,
+     entity_id: int,
+ ):
+     if omop_entity.get_table_name() not in export_dict:
+         export_dict[omop_entity.get_table_name()] = {}
+     export_dict[omop_entity.get_table_name()][entity_id] = omop_entity
+
+
+ def delete_bad_sequence(
+     export_dict: Dict[str, Dict[int, OmopEntity]],
+     id_mappings: Dict[str, Dict[int, int]],
+     person_id: int,
+ ):
+     for table_name, id_mapping in id_mappings.items():
+         omop_id_mapping = np.array(list(id_mapping.keys()))
+         person_id_mapping = np.array(list(id_mapping.values()))
+         ids_to_delete = omop_id_mapping[np.where(person_id_mapping == person_id)]
+         for entity_id in ids_to_delete:
+             export_dict[table_name].pop(entity_id)
+
+
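+ # Design note: id_mappings records, per table, every generated OMOP entity id
+ # against the person_id it belongs to, so a sequence later flagged as bad can
+ # be rolled back here before its rows are ever flushed to disk.
+
+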
+ def export_and_clear(
+     output_folder: str,
+     export_dict: Dict[str, Dict[int, OmopEntity]],
+     export_error: Dict[str, List[str]],
+     id_mappings_dict: Dict[str, Dict[int, int]],
+     pt_seq_dict: Dict[int, str],
+     is_parquet: bool = True,
+ ):
+     for table_name, records_to_export in export_dict.items():
+         # If there are no omop_entities for this table, skip it
+         if len(records_to_export) == 0:
+             continue
+
+         records_in_json = []
+         for entity_id, omop_entity in records_to_export.items():
+             try:
+                 records_in_json.append(omop_entity.export_as_json())
+             except AttributeError:
+                 # Append the patient sequence to the export error list using pt_seq_dict
+                 if table_name not in export_error:
+                     export_error[table_name] = []
+                 person_id = id_mappings_dict[table_name][entity_id]
+                 export_error[table_name].append(pt_seq_dict[person_id])
+                 continue
+         schema = next(iter(records_to_export.values())).get_schema()
+         output_folder_path = Path(output_folder)
+         # Match the file extension to the output format
+         file_extension = "parquet" if is_parquet else "csv"
+         file_path = output_folder_path / table_name / f"{uuid.uuid4()}.{file_extension}"
+         table_df = pd.DataFrame(records_in_json, columns=schema)
+
+         if is_parquet:
+             table_df.to_parquet(file_path)
+         else:
+             table_df.to_csv(file_path, header=schema, index=False)
+
+         records_to_export.clear()
+
+
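+ # Each flush writes a fresh uuid4-named file into the table's subfolder, so
+ # repeated flushes (and parallel workers) never overwrite each other's output.
+
+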
+ def _is_none(x):
+     return x is None or np.isnan(x)
+
+
+ def get_num_records(parquet_files: List[str]):
+     total = 0
+     for file_path in parquet_files:
+         parquet_file = pq.ParquetFile(file_path)
+         total += parquet_file.metadata.num_rows
+     return total
+
+
+ def record_generator(parquet_files):
+     for file_path in parquet_files:
+         df = pd.read_parquet(file_path)
+         for record in df.itertuples():
+             yield record
+
+
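+ # get_num_records only touches the parquet footers (metadata.num_rows), and
+ # record_generator streams one file's rows at a time, so the progress-bar total
+ # is cheap to compute and memory stays bounded by a single parquet file.
+
+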
+ def gpt_to_omop_converter_batch(
+     const: int,
+     patient_sequence_parquet_files: List[str],
+     domain_map: Dict[int, str],
+     output_folder: str,
+     buffer_size: int,
+     use_original_person_id: bool,
+ ):
+     omop_export_dict = {}
+     error_dict = {}
+     export_error = {}
+     id_mappings_dict = {}
+     pt_seq_dict = {}
+
+     for tb in TABLE_LIST:
+         create_folder_if_not_exists(output_folder, tb)
+         id_mappings_dict[tb] = {}
+
+     visit_occurrence_id: int = const + 1
+     condition_occurrence_id: int = const + 1
+     procedure_occurrence_id: int = const + 1
+     drug_exposure_id: int = const + 1
+     measurement_id: int = const + 1
+
+     # Default the person_id
+     person_id: int = const + 1
+
+     patient_record_generator = record_generator(patient_sequence_parquet_files)
+     total_record = get_num_records(patient_sequence_parquet_files)
+
+     for index, record in tqdm(enumerate(patient_record_generator), total=total_record):
+         bad_sequence = False
+         # If use_original_person_id is set to true, we retrieve it from the record.
+         # If person_id does not exist in the record, we fall back to the default person_id.
+         if use_original_person_id:
+             person_id = getattr(record, "person_id", person_id)
+
+         # Retrieve the sequence columns from the record
+         concept_ids = getattr(record, "concept_ids")
+         is_numeric_types = getattr(record, "is_numeric_types", None)
+         number_as_values = getattr(record, "number_as_values", None)
+         concept_as_values = getattr(record, "concept_as_values", None)
+         units = getattr(record, "units", None)
+
+         # Skip the start token if it is the first token
+         if "start" in concept_ids[0].lower():
+             concept_ids = concept_ids[1:]
+             if is_numeric_types is not None:
+                 is_numeric_types = is_numeric_types[1:]
+             if number_as_values is not None:
+                 number_as_values = number_as_values[1:]
+             if concept_as_values is not None:
+                 concept_as_values = concept_as_values[1:]
+             if units is not None:
+                 units = units[1:]
+
+         clinical_events = concept_ids[START_TOKEN_SIZE:]
+         # Skip empty sequences
+         if len(clinical_events) == 0:
+             continue
+         # Skip patients whose last token is not a valid end token
+         if clinical_events[-1] not in STOP_TOKENS:
+             continue
+
+         is_numeric_types = (
+             is_numeric_types[START_TOKEN_SIZE:]
+             if is_numeric_types is not None
+             else None
+         )
+         number_as_values = (
+             number_as_values[START_TOKEN_SIZE:]
+             if number_as_values is not None
+             else None
+         )
+         concept_as_values = (
+             concept_as_values[START_TOKEN_SIZE:]
+             if concept_as_values is not None
+             else None
+         )
+         units = units[START_TOKEN_SIZE:] if units is not None else None
+
+         # TODO: Need to decode if the input is tokenized
+         start_year, start_age, start_gender, start_race = concept_ids[
+             0:START_TOKEN_SIZE
+         ]
+         if "year" not in start_year.lower():
+             continue
+
+         try:
+             start_year = start_year.split(":")[1]
+             start_age = start_age.split(":")[1]
+             birth_year = int(start_year) - int(start_age)
+         except Exception as e:
+             print(
+                 f"Failed to convert {concept_ids[0:START_TOKEN_SIZE]} due to {e}, skipping to the next record"
+             )
+             continue
+
+         # Skip patients whose birth year is before 1900 or after the current year
+         if birth_year < 1900 or birth_year > datetime.date.today().year:
+             continue
+
+         p = Person(person_id, start_gender, birth_year, start_race)
+         append_to_dict(omop_export_dict, p, person_id)
+         id_mappings_dict["person"][person_id] = person_id
+         pt_seq_dict[person_id] = " ".join(concept_ids)
+         discharged_to_concept_id = 0
+         date_cursor = datetime.datetime(year=int(start_year), month=1, day=1)
+         vo = None
+         inpatient_visit_indicator = False
+
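+         # E.g. the (hypothetical) start tokens ["year:2015", "age:47", "8507", "8527"]
+         # yield birth_year = 2015 - 47 = 1968, with the date cursor initialized
+         # to 2015-01-01 before any ATT time tokens advance it.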
+         for event_idx, event in enumerate(clinical_events):
+             # For bad sequences, we don't proceed further and break out of the loop
+             if bad_sequence:
+                 break
+             if is_visit_start(event):
+                 if event_idx == len(clinical_events) - 1:
+                     break
+                 elif clinical_events[event_idx + 1] == "[DEATH]":
+                     # If the [DEATH] token is not placed at the end of the sequence, this is a bad sequence
+                     if event_idx + 2 != len(clinical_events) - 1:
+                         bad_sequence = True
+                         break
+                     death = Death(p, date_cursor.date())
+                     append_to_dict(omop_export_dict, death, person_id)
+                     id_mappings_dict["death"][person_id] = person_id
+                 else:
+                     try:
+                         visit_concept_id = int(clinical_events[event_idx + 1])
+                         # Inpatient-type visit concept ids
+                         inpatient_visit_indicator = visit_concept_id in [
+                             9201,
+                             262,
+                             8971,
+                             8920,
+                         ]
+                         if visit_concept_id in domain_map:
+                             if (
+                                 domain_map[visit_concept_id] != "Visit"
+                                 and visit_concept_id != 0
+                             ):
+                                 bad_sequence = True
+                                 break
+                         else:
+                             bad_sequence = True
+                             break
+
+                     except (IndexError, ValueError):
+                         error_dict[person_id] = {}
+                         error_dict[person_id]["concept_ids"] = " ".join(concept_ids)
+                         error_dict[person_id]["error"] = "Wrong visit concept id"
+                         bad_sequence = True
+                         continue
+
+                     vo = VisitOccurrence(
+                         visit_occurrence_id, visit_concept_id, date_cursor, p
+                     )
+                     append_to_dict(omop_export_dict, vo, visit_occurrence_id)
+                     id_mappings_dict["visit_occurrence"][
+                         visit_occurrence_id
+                     ] = person_id
+                     visit_occurrence_id += 1
+             elif event in ATT_TIME_TOKENS:
+                 if event[0] == "D":
+                     att_date_delta = int(event[1:])
+                 elif event[0] == "W":
+                     att_date_delta = int(event[1:]) * 7
+                 elif event[0] == "M":
+                     att_date_delta = int(event[1:]) * 30
+                 elif event == "LT":
+                     att_date_delta = 365 * 3
+                 else:
+                     att_date_delta = 0
+                 # Between visits, the date delta is simply added to the date cursor
+                 date_cursor = date_cursor.replace(
+                     hour=0, minute=0, second=0, microsecond=0
+                 )
+                 date_cursor = date_cursor + timedelta(days=att_date_delta)
+             elif inpatient_visit_indicator and is_inpatient_att_token(event):
+                 inpatient_time_span_in_days = extract_time_interval_in_days(event)
+                 # Reset the date cursor to the start of the day before adding the number of days parsed out of the token
+                 date_cursor = date_cursor.replace(hour=0, minute=0, second=0)
+                 date_cursor = date_cursor + timedelta(days=inpatient_time_span_in_days)
+             elif inpatient_visit_indicator and event.startswith("i-H"):
+                 # Handle hour tokens differently than the day tokens.
+                 # Inpatient hour tokens are constructed so that the sum of
+                 # consecutive hour tokens cannot exceed the current day, so the
+                 # date cursor is bounded by a theoretical upper limit
+                 upper_bound = date_cursor.replace(
+                     hour=0, minute=0, second=0
+                 ) + timedelta(hours=23, minutes=59, seconds=59)
+                 hour_delta = int(event[3:])
+                 date_cursor = date_cursor + timedelta(hours=hour_delta)
+                 if date_cursor > upper_bound:
+                     date_cursor = upper_bound
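+             # ATT token examples, per the arithmetic above: "D7" -> +7 days,
+             # "W2" -> +14 days, "M3" -> +90 days, "LT" -> +1095 days (3 years);
+             # inside an inpatient visit, "i-H6" -> +6 hours, capped at 23:59:59
+             # of the current day.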
+             elif is_visit_end(event):
+                 if vo is None:
+                     bad_sequence = True
+                     break
+                 # If it's a VE token, nothing needs to be updated because it just means the visit ended
+                 if inpatient_visit_indicator:
+                     vo.set_discharged_to_concept_id(discharged_to_concept_id)
+                     vo.set_visit_end_date(date_cursor)
+                     # If the patient was discharged to "Patient Died" (4216643), create a death record
+                     if discharged_to_concept_id == 4216643:
+                         death = Death(
+                             p, date_cursor.date(), death_type_concept_id=32823
+                         )
+                         append_to_dict(omop_export_dict, death, person_id)
+                         id_mappings_dict["death"][person_id] = person_id
+                         # Once a death record is generated, stop the sequence conversion
+                         break
+             elif event in [
+                 "START",
+                 start_year,
+                 start_age,
+                 start_gender,
+                 start_race,
+                 "[DEATH]",
+             ]:
+                 # If it's a start token, skip it
+                 pass
+             else:
+                 try:
+                     concept_id = int(event)
+                     if (
+                         concept_id not in domain_map
+                         and concept_id not in OOV_CONCEPT_MAP
+                     ):
+                         error_dict[person_id] = {}
+                         error_dict[person_id]["concept_ids"] = " ".join(concept_ids)
+                         error_dict[person_id][
+                             "error"
+                         ] = f"No concept id found: {concept_id}"
+                         bad_sequence = True
+                         continue
+                     else:
+                         # If the current concept_id is "Patient Died", it can only occur in the
+                         # discharged_to_concept_id field, which means the current visit has to be
+                         # an inpatient visit and this concept_id can only appear in the
+                         # second-to-last position
+                         if concept_id == 4216643:
+                             # If the current visit is not inpatient, reject the sequence
+                             if not inpatient_visit_indicator:
+                                 bad_sequence = True
+                                 continue
+                             # # If the current token is not the second last one of the sequence, reject because
+                             # # death can only appear at the end of the sequence
+                             # if idx + 1 != len(tokens_generated) - 1:
+                             #     bad_sequence = True
+                             #     continue
+                             # We also enforce the rule that the sequence has to end on a VE token
+                             if event_idx + 1 < len(
+                                 clinical_events
+                             ) and not is_visit_end(clinical_events[event_idx + 1]):
+                                 bad_sequence = True
+                                 continue
+
+                         if concept_id in domain_map:
+                             domain = domain_map[concept_id]
+                         elif concept_id in OOV_CONCEPT_MAP:
+                             domain = OOV_CONCEPT_MAP[concept_id]
+                         else:
+                             domain = None
+
+                         if domain == "Visit" or concept_id in DISCHARGE_CONCEPT_LIST:
+                             discharged_to_concept_id = concept_id
+                         elif domain == "Condition":
+                             co = ConditionOccurrence(
+                                 condition_occurrence_id, concept_id, vo, date_cursor
+                             )
+                             append_to_dict(
+                                 omop_export_dict, co, condition_occurrence_id
+                             )
+                             id_mappings_dict["condition_occurrence"][
+                                 condition_occurrence_id
+                             ] = person_id
+                             condition_occurrence_id += 1
+                         elif domain == "Procedure":
+                             po = ProcedureOccurrence(
+                                 procedure_occurrence_id, concept_id, vo, date_cursor
+                             )
+                             append_to_dict(
+                                 omop_export_dict, po, procedure_occurrence_id
+                             )
+                             id_mappings_dict["procedure_occurrence"][
+                                 procedure_occurrence_id
+                             ] = person_id
+                             procedure_occurrence_id += 1
+                         elif domain == "Drug":
+                             de = DrugExposure(
+                                 drug_exposure_id, concept_id, vo, date_cursor
+                             )
+                             append_to_dict(omop_export_dict, de, drug_exposure_id)
+                             id_mappings_dict["drug_exposure"][
+                                 drug_exposure_id
+                             ] = person_id
+                             drug_exposure_id += 1
+                         elif domain == "Measurement":
+                             number_as_value = (
+                                 number_as_values[event_idx]
+                                 if number_as_values is not None
+                                 else None
+                             )
+                             concept_as_value = (
+                                 concept_as_values[event_idx]
+                                 if concept_as_values is not None
+                                 else None
+                             )
+                             is_numeric_type = (
+                                 is_numeric_types[event_idx]
+                                 if is_numeric_types is not None
+                                 else None
+                             )
+                             unit = units[event_idx] if units is not None else None
+                             m = Measurement(
+                                 measurement_id,
+                                 measurement_concept_id=concept_id,
+                                 is_numeric_type=is_numeric_type,
+                                 value_as_number=number_as_value,
+                                 value_as_concept_id=concept_as_value,
+                                 visit_occurrence=vo,
+                                 measurement_datetime=date_cursor,
+                                 unit_source_value=unit,
+                             )
+                             append_to_dict(omop_export_dict, m, measurement_id)
+                             id_mappings_dict["measurement"][measurement_id] = person_id
+                             measurement_id += 1
+
+                 except ValueError:
+                     error_dict[person_id] = {}
+                     error_dict[person_id]["concept_ids"] = " ".join(concept_ids)
+                     error_dict[person_id]["error"] = f"Wrong concept id: {event}"
+                     bad_sequence = True
+                     continue
+         if bad_sequence:
+             delete_bad_sequence(omop_export_dict, id_mappings_dict, person_id)
+
+         if not use_original_person_id:
+             person_id += 1
+
+         if index != 0 and index % buffer_size == 0:
+             export_and_clear(
+                 output_folder,
+                 omop_export_dict,
+                 export_error,
+                 id_mappings_dict,
+                 pt_seq_dict,
+             )
+
+     # Final flush to disk if there are still records in the cache
+     export_and_clear(
+         output_folder, omop_export_dict, export_error, id_mappings_dict, pt_seq_dict
+     )
+
+     with open(Path(output_folder) / "concept_errors.txt", "w") as f:
+         error_dict["total"] = len(error_dict)
+         f.write(str(error_dict))
+     with open(Path(output_folder) / "export_errors.txt", "w") as f:
+         total = 0
+         for k, v in export_error.items():
+             total += len(v)
+         export_error["total"] = total
+         f.write(str(export_error))
+
+
+ def main(args):
+     all_parquet_files = glob.glob(
+         os.path.join(args.patient_sequence_path, "*parquet"), recursive=True
+     )
+     if len(all_parquet_files) == 0:
+         raise RuntimeError(f"No parquet files found in {args.patient_sequence_path}")
+
+     print(
+         f"Detected a total of {len(all_parquet_files)} parquet files in {args.patient_sequence_path}."
+     )
+     if not os.path.exists(args.output_folder):
+         Path(args.output_folder).mkdir(parents=True, exist_ok=True)
+
+     batched_parquet_files = np.array_split(all_parquet_files, args.cpu_cores)
+     concept_pd = pd.read_parquet(args.concept_path)
+     domain_map = generate_omop_concept_domain(concept_pd)
+
+     pool_tuples = []
+     # TODO: Need to make this dynamic
+     const = 10000000
+     for i in range(1, args.cpu_cores + 1):
+         pool_tuples.append(
+             (
+                 const * i,
+                 batched_parquet_files[i - 1],
+                 domain_map,
+                 args.output_folder,
+                 args.buffer_size,
+                 args.use_original_person_id,
+             )
+         )
+
+     with Pool(processes=args.cpu_cores) as p:
+         p.starmap(gpt_to_omop_converter_batch, pool_tuples)
+         p.close()
+         p.join()
+
+     print("Done")
+
+
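+ # Worker i receives the id offset const * i, so as long as each worker creates
+ # fewer than 10,000,000 entities per table, the OMOP ids generated by different
+ # processes can never collide (hence the TODO about making const dynamic).
+
+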
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(
+         description="Arguments for converting patient sequences to OMOP"
+     )
+     parser.add_argument(
+         "--output_folder",
+         dest="output_folder",
+         action="store",
+         help="The path to the output folder",
+         required=True,
+     )
+     parser.add_argument(
+         "--concept_path",
+         dest="concept_path",
+         action="store",
+         help="The path to the OMOP concept table in parquet format",
+         required=True,
+     )
+     parser.add_argument(
+         "--buffer_size",
+         dest="buffer_size",
+         action="store",
+         type=int,
+         help="The number of records to buffer in memory before flushing to disk",
+         required=False,
+         default=1024,
+     )
+     parser.add_argument(
+         "--patient_sequence_path",
+         dest="patient_sequence_path",
+         action="store",
+         help="The path to the patient sequence parquet files",
+         required=True,
+     )
+     parser.add_argument(
+         "--cpu_cores",
+         dest="cpu_cores",
+         type=int,
+         action="store",
+         help="The number of CPU cores to use for multiprocessing",
+         required=False,
+         default=1,
+     )
+     parser.add_argument(
+         "--use_original_person_id",
+         dest="use_original_person_id",
+         action="store_true",
+         help="Whether or not to use the original person id",
+     )
+
+     main(parser.parse_args())
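+
+ # Illustrative invocation (the paths are hypothetical; the flags match the
+ # argparse definitions above):
+ #
+ #   python -m cehrgpt.generation.omop_converter_batch \
+ #       --patient_sequence_path /data/synthetic_sequences \
+ #       --concept_path /data/omop/concept \
+ #       --output_folder /data/omop_export \
+ #       --cpu_cores 4 \
+ #       --buffer_size 1024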