cehrgpt 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- __init__.py +0 -0
- cehrgpt/__init__.py +0 -0
- cehrgpt/analysis/__init__.py +0 -0
- cehrgpt/analysis/privacy/__init__.py +0 -0
- cehrgpt/analysis/privacy/attribute_inference.py +275 -0
- cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
- cehrgpt/analysis/privacy/member_inference.py +172 -0
- cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
- cehrgpt/analysis/privacy/reid_inference.py +407 -0
- cehrgpt/analysis/privacy/utils.py +255 -0
- cehrgpt/cehrgpt_args.py +142 -0
- cehrgpt/data/__init__.py +0 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
- cehrgpt/generation/__init__.py +0 -0
- cehrgpt/generation/chatgpt_generation.py +106 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
- cehrgpt/generation/omop_converter_batch.py +644 -0
- cehrgpt/generation/omop_entity.py +515 -0
- cehrgpt/gpt_utils.py +331 -0
- cehrgpt/models/__init__.py +0 -0
- cehrgpt/models/config.py +205 -0
- cehrgpt/models/hf_cehrgpt.py +1817 -0
- cehrgpt/models/hf_modeling_outputs.py +158 -0
- cehrgpt/models/pretrained_embeddings.py +82 -0
- cehrgpt/models/special_tokens.py +30 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
- cehrgpt/omop/__init__.py +0 -0
- cehrgpt/omop/condition_era.py +20 -0
- cehrgpt/omop/observation_period.py +43 -0
- cehrgpt/omop/omop_argparse.py +38 -0
- cehrgpt/omop/omop_table_builder.py +86 -0
- cehrgpt/omop/queries/__init__.py +0 -0
- cehrgpt/omop/queries/condition_era.py +86 -0
- cehrgpt/omop/queries/observation_period.py +135 -0
- cehrgpt/omop/sample_omop_tables.py +71 -0
- cehrgpt/runners/__init__.py +0 -0
- cehrgpt/runners/gpt_runner_util.py +99 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
- cehrgpt/runners/hyperparameter_search_util.py +223 -0
- cehrgpt/time_to_event/__init__.py +0 -0
- cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
- cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
- cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
- cehrgpt/time_to_event/time_to_event_model.py +226 -0
- cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
- cehrgpt/time_to_event/time_to_event_utils.py +55 -0
- cehrgpt/tools/__init__.py +0 -0
- cehrgpt/tools/ehrshot_benchmark.py +74 -0
- cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
- cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
- cehrgpt/tools/upload_omop_tables.py +108 -0
- cehrgpt-0.0.1.dist-info/LICENSE +21 -0
- cehrgpt-0.0.1.dist-info/METADATA +66 -0
- cehrgpt-0.0.1.dist-info/RECORD +60 -0
- cehrgpt-0.0.1.dist-info/WHEEL +5 -0
- cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
cehrgpt/omop/__init__.py
ADDED
File without changes
|
@@ -0,0 +1,20 @@
|
|
1
|
+
from cehrgpt.omop.omop_argparse import create_omop_argparse
|
2
|
+
from cehrgpt.omop.omop_table_builder import OmopTableBuilder
|
3
|
+
from cehrgpt.omop.queries.condition_era import CONDITION_ERA_QUERY
|
4
|
+
|
5
|
+
# Table name handed to OmopTableBuilder as both the query's table name and the cohort name.
CONDITION_ERA = "condition_era"


def main(args):
    """Build the OMOP ``condition_era`` table from ``condition_occurrence``.

    Args:
        args: Namespace produced by ``create_omop_argparse().parse_args()``;
            ``input_folder``, ``output_folder`` and ``continue_job`` are read.
    """
    builder = OmopTableBuilder.create_omop_query_builder(
        table_name=CONDITION_ERA,
        query_template=CONDITION_ERA_QUERY,
        dependency_list=["condition_occurrence"],
        input_folder=args.input_folder,
        output_folder=args.output_folder,
        continue_job=args.continue_job,
    )
    builder.build()


if __name__ == "__main__":
    cli_args = create_omop_argparse().parse_args()
    main(cli_args)
|
@@ -0,0 +1,43 @@
|
|
1
|
+
import glob
|
2
|
+
import os.path
|
3
|
+
|
4
|
+
from cehrgpt.omop.omop_argparse import create_omop_argparse
|
5
|
+
from cehrgpt.omop.omop_table_builder import OmopTableBuilder
|
6
|
+
from cehrgpt.omop.queries.observation_period import (
|
7
|
+
OBSERVATION_PERIOD_QUERY,
|
8
|
+
OBSERVATION_PERIOD_WITH_MEASUREMENT_QUERY,
|
9
|
+
)
|
10
|
+
|
11
|
+
# Table name handed to OmopTableBuilder as both the query's table name and the cohort name.
OBSERVATION_PERIOD = "observation_period"


def main(args):
    """Build the OMOP ``observation_period`` table from the clinical event tables.

    The measurement table is only used when ``<input_folder>/measurement``
    contains at least one ``*.parquet`` file. In that case both the query
    template and the dependency list switch to the measurement-aware variants.

    Args:
        args: Namespace produced by ``create_omop_argparse().parse_args()``;
            ``input_folder``, ``output_folder`` and ``continue_job`` are read.
    """
    measurement_pattern = os.path.join(args.input_folder, "measurement", "*.parquet")
    include_measurement = bool(glob.glob(measurement_pattern))

    required_tables = [
        "person",
        "visit_occurrence",
        "condition_occurrence",
        "procedure_occurrence",
        "drug_exposure",
    ]
    if include_measurement:
        required_tables.append("measurement")
        template = OBSERVATION_PERIOD_WITH_MEASUREMENT_QUERY
    else:
        template = OBSERVATION_PERIOD_QUERY

    builder = OmopTableBuilder.create_omop_query_builder(
        input_folder=args.input_folder,
        output_folder=args.output_folder,
        continue_job=args.continue_job,
        table_name=OBSERVATION_PERIOD,
        query_template=template,
        dependency_list=required_tables,
    )
    builder.build()


if __name__ == "__main__":
    cli_args = create_omop_argparse().parse_args()
    main(cli_args)
|
@@ -0,0 +1,38 @@
|
|
1
|
+
import argparse
|
2
|
+
|
3
|
+
from cehrbert_data.utils.spark_utils import validate_table_names
|
4
|
+
|
5
|
+
|
6
|
+
def create_omop_argparse():
    """Build the command-line parser shared by the OMOP table-builder scripts.

    Returns:
        argparse.ArgumentParser with the following options:

        - ``--input_folder`` and ``--output_folder``: required.
        - ``--continue_job``: a boolean flag.
        - ``--domain_table_list``: optional; one or more table names, each
          passed through cehrbert_data's ``validate_table_names``.
    """
    parser = argparse.ArgumentParser(
        description="Spark application for creating OMOP table"
    )
    parser.add_argument(
        "--input_folder",
        dest="input_folder",
        action="store",
        help="The path for your input_folder where the raw data is",
        required=True,
    )
    parser.add_argument(
        "--output_folder",
        dest="output_folder",
        action="store",
        help="The path for your output_folder",
        required=True,
    )
    parser.add_argument(
        "--continue_job",
        dest="continue_job",
        action="store_true",
        help="Skip rebuilding the table when existing output is found",
    )
    # Optional on purpose: the condition_era and observation_period scripts never
    # read this value, so making it required only forced callers to pass a dummy list.
    # Callers that do pass it keep working exactly as before.
    parser.add_argument(
        "--domain_table_list",
        dest="domain_table_list",
        nargs="+",
        action="store",
        help="Optional list of OMOP domain table names",
        type=validate_table_names,
        required=False,
        default=None,
    )
    return parser
|
@@ -0,0 +1,86 @@
|
|
1
|
+
from datetime import date
|
2
|
+
from typing import Any, Dict, List, Optional
|
3
|
+
|
4
|
+
import cehrbert_data.cohorts.spark_app_base
|
5
|
+
from cehrbert_data.cohorts.query_builder import QueryBuilder, QuerySpec
|
6
|
+
from cehrbert_data.cohorts.spark_app_base import BaseCohortBuilder
|
7
|
+
from cehrbert_data.utils.spark_utils import preprocess_domain_table
|
8
|
+
|
9
|
+
|
10
|
+
def custom_instantiate_dependencies(spark, input_folder, dependency_list):
    """Load every dependency table and expose it as a Spark temp view.

    Each name in ``dependency_list`` is loaded with ``preprocess_domain_table``
    and registered via ``createOrReplaceTempView`` under that same name.

    Args:
        spark: Spark session forwarded to ``preprocess_domain_table``.
        input_folder: Root folder forwarded to ``preprocess_domain_table``.
        dependency_list: Table names to load, processed in order.

    Returns:
        dict mapping each table name to the DataFrame that was registered for it.
    """

    def _load_and_register(table_name):
        frame = preprocess_domain_table(spark, input_folder, table_name)
        # Registering under the table's own name lets SQL templates refer to it directly.
        frame.createOrReplaceTempView(table_name)
        return frame

    return {table_name: _load_and_register(table_name) for table_name in dependency_list}
|
17
|
+
|
18
|
+
|
19
|
+
# Monkeypatch the function: rebind ``instantiate_dependencies`` in cehrbert_data's
# ``spark_app_base`` module to ``custom_instantiate_dependencies`` defined above.
# This runs once, when this module is imported. It changes that module attribute
# for every user in the process, not only for OmopTableBuilder.
# NOTE(review): this presumably works because BaseCohortBuilder looks up
# ``instantiate_dependencies`` through spark_app_base's module globals at call time.
# If that function were imported by name into another module, this patch would not
# reach that copy. Confirm against cehrbert_data.
cehrbert_data.cohorts.spark_app_base.instantiate_dependencies = (
    custom_instantiate_dependencies
)
|
23
|
+
|
24
|
+
|
25
|
+
class OmopTableBuilder(BaseCohortBuilder):
    """Materialise a single OMOP table by running one SQL template.

    Reuses ``BaseCohortBuilder`` with fixed, wide settings:

    - dates from 1900-01-01 to today;
    - ages 0-200;
    - prior/post observation periods of 0.

    ``build`` writes the DataFrame returned by ``create_cohort()`` as parquet
    to the builder's output folder.
    """

    # Left empty so no particular output columns are demanded of the query
    # (presumably checked by BaseCohortBuilder -- confirm there).
    cohort_required_columns = []

    def __init__(
        self,
        query_builder: QueryBuilder,
        input_folder: str,
        output_folder: str,
        continue_job: bool = False,
    ):
        # Fixed, wide bounds are passed to the base cohort builder; the upper date
        # bound is the current date at construction time.
        today_str = date.today().strftime("%Y-%m-%d")
        super().__init__(
            query_builder=query_builder,
            input_folder=input_folder,
            output_folder=output_folder,
            continue_job=continue_job,
            date_lower_bound="1900-01-01",
            date_upper_bound=today_str,
            age_lower_bound=0,
            age_upper_bound=200,
            prior_observation_period=0,
            post_observation_period=0,
        )

    def build(self):
        """Create the table and write it as parquet (overwrite mode).

        When ``continue_job`` is set and ``cohort_exists()`` reports existing
        output, nothing is recomputed.

        Returns:
            ``self``, so calls can be chained.
        """
        # cohort_exists() is only consulted when continue_job is truthy.
        reuse_existing = self._continue_job and self.cohort_exists()
        if not reuse_existing:
            table = self.create_cohort()
            table.write.mode("overwrite").parquet(self._output_data_folder)
        return self

    @staticmethod
    def create_omop_query_builder(
        input_folder: str,
        output_folder: str,
        table_name: str,
        query_template: str,
        dependency_list: List[str],
        query_parameters: Optional[Dict[str, Any]] = None,
        continue_job: bool = False,
    ):
        """Assemble an ``OmopTableBuilder`` around a single ``QuerySpec``.

        Args:
            input_folder: Folder holding the source OMOP tables.
            output_folder: Folder the built table is written to.
            table_name: Used as both the ``QuerySpec`` table name and the
                ``QueryBuilder`` cohort name.
            query_template: SQL template to run.
            dependency_list: Source tables the query needs loaded.
            query_parameters: Template parameters; ``None`` becomes an empty dict.
            continue_job: Forwarded to the builder (see ``build``).

        Returns:
            A configured ``OmopTableBuilder``; call ``build()`` to run it.
        """
        spec = QuerySpec(
            table_name=table_name,
            query_template=query_template,
            parameters={} if query_parameters is None else query_parameters,
        )
        return OmopTableBuilder(
            query_builder=QueryBuilder(
                cohort_name=table_name,
                dependency_list=dependency_list,
                query=spec,
                ancestor_table_specs=[],
            ),
            input_folder=input_folder,
            output_folder=output_folder,
            continue_job=continue_job,
        )
|
File without changes
|
@@ -0,0 +1,86 @@
|
|
1
|
+
# Spark SQL deriving OMOP CONDITION_ERA rows from CONDITION_OCCURRENCE
# (structure follows the OHDSI condition-era build script).
#
# Merging rules:
# - Occurrences of the same condition concept for the same person are merged into
#   one era when their ranges overlap after padding each end date by 30 days.
# - The padding is removed again from the resulting era end date.
# - A missing condition end date defaults to start date + 1 day.
# - CONDITION_OCCURRENCE_COUNT is the number of occurrences merged into the era.
#
# Eras are keyed only by (person, concept, era end date):
# - cteConditionEnds groups by CONDITION_OCCURRENCE_ID, so every occurrence stays
#   its own row and occurrences that share a start date are all counted.
# - The condition type concept is deliberately not a grouping key. Occurrences of
#   one era with different CONDITION_TYPE_CONCEPT_IDs therefore do not split it
#   into duplicate eras with the same end date.
#
# condition_era_id is numbered over the unique (person, concept, era end) key, so
# ids are deterministic for a given input.
CONDITION_ERA_QUERY = """
-- One row per condition occurrence; a missing end date defaults to start + 1 day
WITH cteConditionTarget AS (
    SELECT
        co.CONDITION_OCCURRENCE_ID,
        co.PERSON_ID,
        co.CONDITION_CONCEPT_ID,
        co.CONDITION_START_DATE,
        COALESCE(co.CONDITION_END_DATE, DATE_ADD(co.CONDITION_START_DATE, 1)) AS CONDITION_END_DATE
    FROM CONDITION_OCCURRENCE co
),

-- Era end dates: points where every started range has ended (with 30-day padding)
cteEndDates AS (
    SELECT
        PERSON_ID,
        CONDITION_CONCEPT_ID,
        DATE_ADD(EVENT_DATE, -30) AS END_DATE
    FROM (
        SELECT
            PERSON_ID,
            CONDITION_CONCEPT_ID,
            EVENT_DATE,
            EVENT_TYPE,
            MAX(START_ORDINAL) OVER (PARTITION BY PERSON_ID, CONDITION_CONCEPT_ID ORDER BY EVENT_DATE, EVENT_TYPE ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS START_ORDINAL,
            ROW_NUMBER() OVER (PARTITION BY PERSON_ID, CONDITION_CONCEPT_ID ORDER BY EVENT_DATE, EVENT_TYPE) AS OVERALL_ORD
        FROM (
            -- Select the start dates, assigning a row number to each
            SELECT
                PERSON_ID,
                CONDITION_CONCEPT_ID,
                CONDITION_START_DATE AS EVENT_DATE,
                -1 AS EVENT_TYPE,
                ROW_NUMBER() OVER (PARTITION BY PERSON_ID, CONDITION_CONCEPT_ID ORDER BY CONDITION_START_DATE) AS START_ORDINAL
            FROM cteConditionTarget

            UNION ALL

            -- Pad the end dates by 30 to allow a grace period for overlapping ranges
            SELECT
                PERSON_ID,
                CONDITION_CONCEPT_ID,
                DATE_ADD(CONDITION_END_DATE, 30) AS EVENT_DATE,
                1 AS EVENT_TYPE,
                NULL
            FROM cteConditionTarget
        ) RAWDATA
    ) E
    WHERE (2 * E.START_ORDINAL) - E.OVERALL_ORD = 0
),

-- Each occurrence (one row per CONDITION_OCCURRENCE_ID) mapped to the first era end on or after its start
cteConditionEnds AS (
    SELECT
        c.PERSON_ID,
        c.CONDITION_CONCEPT_ID,
        c.CONDITION_START_DATE,
        MIN(e.END_DATE) AS ERA_END_DATE
    FROM cteConditionTarget c
    JOIN cteEndDates e
        ON c.PERSON_ID = e.PERSON_ID
        AND c.CONDITION_CONCEPT_ID = e.CONDITION_CONCEPT_ID
        AND e.END_DATE >= c.CONDITION_START_DATE
    GROUP BY
        c.CONDITION_OCCURRENCE_ID,
        c.PERSON_ID,
        c.CONDITION_CONCEPT_ID,
        c.CONDITION_START_DATE
)

SELECT
    ROW_NUMBER() OVER (ORDER BY PERSON_ID, CONDITION_CONCEPT_ID, ERA_END_DATE) AS condition_era_id,
    PERSON_ID,
    CONDITION_CONCEPT_ID,
    MIN(CONDITION_START_DATE) AS CONDITION_ERA_START_DATE,
    ERA_END_DATE AS CONDITION_ERA_END_DATE,
    COUNT(*) AS CONDITION_OCCURRENCE_COUNT
FROM cteConditionEnds
GROUP BY
    PERSON_ID,
    CONDITION_CONCEPT_ID,
    ERA_END_DATE
ORDER BY
    PERSON_ID,
    CONDITION_CONCEPT_ID
"""
|
@@ -0,0 +1,135 @@
|
|
1
|
+
OBSERVATION_PERIOD_QUERY = """
|
2
|
+
select
|
3
|
+
person_id,
|
4
|
+
observation_period_start_date,
|
5
|
+
case
|
6
|
+
when observation_period_end_date >= add_months(current_date(), -12) then current_date()
|
7
|
+
else observation_period_end_date
|
8
|
+
end as observation_period_end_date,
|
9
|
+
period_type_concept_id
|
10
|
+
from (
|
11
|
+
SELECT person_id,
|
12
|
+
MIN(observation_period_start_date) AS observation_period_start_date,
|
13
|
+
MAX(observation_period_end_date) AS observation_period_end_date,
|
14
|
+
44814725 as period_type_concept_id
|
15
|
+
FROM (
|
16
|
+
SELECT pt.person_id AS person_id,
|
17
|
+
MIN(vt.visit_start_date) AS observation_period_start_date,
|
18
|
+
MAX(vt.visit_end_date) AS observation_period_end_date
|
19
|
+
FROM person as pt
|
20
|
+
JOIN visit_occurrence as vt ON pt.person_id = vt.person_id
|
21
|
+
WHERE YEAR(vt.visit_start_date) >= 1985
|
22
|
+
AND vt.visit_start_date <= current_date() --set lower and upper bound to ignore spurious dates
|
23
|
+
GROUP BY pt.person_id
|
24
|
+
|
25
|
+
UNION
|
26
|
+
|
27
|
+
SELECT pt.person_id AS person_id,
|
28
|
+
MIN(co.condition_start_date) AS observation_period_start_date,
|
29
|
+
MAX(co.condition_start_date) AS observation_period_end_date
|
30
|
+
FROM person as pt
|
31
|
+
JOIN condition_occurrence as co ON pt.person_id = co.person_id
|
32
|
+
WHERE YEAR(co.condition_start_date) >= 1985
|
33
|
+
AND co.condition_start_date <= current_date() --set lower and upper bound to ignore spurious dates
|
34
|
+
GROUP BY pt.person_id
|
35
|
+
|
36
|
+
UNION
|
37
|
+
|
38
|
+
SELECT pt.person_id AS person_id,
|
39
|
+
MIN(po.procedure_date) AS observation_period_start_date,
|
40
|
+
MAX(po.procedure_date) AS observation_period_end_date
|
41
|
+
FROM person as pt
|
42
|
+
JOIN procedure_occurrence as po ON pt.person_id = po.person_id
|
43
|
+
WHERE YEAR(po.procedure_date) >= 1985
|
44
|
+
AND po.procedure_date <= current_date() --set lower and upper bound to ignore spurious dates
|
45
|
+
GROUP BY pt.person_id
|
46
|
+
|
47
|
+
UNION
|
48
|
+
|
49
|
+
SELECT pt.person_id AS person_id,
|
50
|
+
MIN(de.drug_exposure_start_date) AS observation_period_start_date,
|
51
|
+
MAX(de.drug_exposure_start_date) AS observation_period_end_date
|
52
|
+
FROM person as pt
|
53
|
+
JOIN drug_exposure as de ON pt.person_id = de.person_id
|
54
|
+
WHERE YEAR(de.drug_exposure_start_date) >= 1985
|
55
|
+
AND de.drug_exposure_start_date <= current_date() --set lower and upper bound to ignore spurious dates
|
56
|
+
GROUP BY pt.person_id
|
57
|
+
) as x
|
58
|
+
WHERE x.observation_period_end_date IS NOT NULL
|
59
|
+
GROUP BY person_id
|
60
|
+
) as z
|
61
|
+
"""
|
62
|
+
|
63
|
+
OBSERVATION_PERIOD_WITH_MEASUREMENT_QUERY = """
|
64
|
+
select
|
65
|
+
person_id,
|
66
|
+
observation_period_start_date,
|
67
|
+
case
|
68
|
+
when observation_period_end_date >= add_months(current_date(), -12) then current_date()
|
69
|
+
else observation_period_end_date
|
70
|
+
end as observation_period_end_date,
|
71
|
+
period_type_concept_id
|
72
|
+
from (
|
73
|
+
SELECT person_id,
|
74
|
+
MIN(observation_period_start_date) AS observation_period_start_date,
|
75
|
+
MAX(observation_period_end_date) AS observation_period_end_date,
|
76
|
+
44814725 as period_type_concept_id
|
77
|
+
FROM (
|
78
|
+
SELECT pt.person_id AS person_id,
|
79
|
+
MIN(vt.visit_start_date) AS observation_period_start_date,
|
80
|
+
MAX(vt.visit_end_date) AS observation_period_end_date
|
81
|
+
FROM person as pt
|
82
|
+
JOIN visit_occurrence as vt ON pt.person_id = vt.person_id
|
83
|
+
WHERE YEAR(vt.visit_start_date) >= 1985
|
84
|
+
AND vt.visit_start_date <= current_date() --set lower and upper bound to ignore spurious dates
|
85
|
+
GROUP BY pt.person_id
|
86
|
+
|
87
|
+
UNION
|
88
|
+
|
89
|
+
SELECT pt.person_id AS person_id,
|
90
|
+
MIN(co.condition_start_date) AS observation_period_start_date,
|
91
|
+
MAX(co.condition_start_date) AS observation_period_end_date
|
92
|
+
FROM person as pt
|
93
|
+
JOIN condition_occurrence as co ON pt.person_id = co.person_id
|
94
|
+
WHERE YEAR(co.condition_start_date) >= 1985
|
95
|
+
AND co.condition_start_date <= current_date() --set lower and upper bound to ignore spurious dates
|
96
|
+
GROUP BY pt.person_id
|
97
|
+
|
98
|
+
UNION
|
99
|
+
|
100
|
+
SELECT pt.person_id AS person_id,
|
101
|
+
MIN(po.procedure_date) AS observation_period_start_date,
|
102
|
+
MAX(po.procedure_date) AS observation_period_end_date
|
103
|
+
FROM person as pt
|
104
|
+
JOIN procedure_occurrence as po ON pt.person_id = po.person_id
|
105
|
+
WHERE YEAR(po.procedure_date) >= 1985
|
106
|
+
AND po.procedure_date <= current_date() --set lower and upper bound to ignore spurious dates
|
107
|
+
GROUP BY pt.person_id
|
108
|
+
|
109
|
+
UNION
|
110
|
+
|
111
|
+
SELECT pt.person_id AS person_id,
|
112
|
+
MIN(de.drug_exposure_start_date) AS observation_period_start_date,
|
113
|
+
MAX(de.drug_exposure_start_date) AS observation_period_end_date
|
114
|
+
FROM person as pt
|
115
|
+
JOIN drug_exposure as de ON pt.person_id = de.person_id
|
116
|
+
WHERE YEAR(de.drug_exposure_start_date) >= 1985
|
117
|
+
AND de.drug_exposure_start_date <= current_date() --set lower and upper bound to ignore spurious dates
|
118
|
+
GROUP BY pt.person_id
|
119
|
+
|
120
|
+
UNION
|
121
|
+
|
122
|
+
SELECT
|
123
|
+
pt.person_id AS person_id,
|
124
|
+
MIN(m.measurement_date) AS observation_period_start_date,
|
125
|
+
MAX(m.measurement_date) AS observation_period_end_date
|
126
|
+
FROM person as pt
|
127
|
+
JOIN measurement as m ON pt.person_id = m.person_id
|
128
|
+
WHERE YEAR(m.measurement_date) >= 1985
|
129
|
+
AND m.measurement_date <= current_date() --set lower and upper bound to ignore spurious dates
|
130
|
+
GROUP BY pt.person_id
|
131
|
+
) as x
|
132
|
+
WHERE x.observation_period_end_date IS NOT NULL
|
133
|
+
GROUP BY person_id
|
134
|
+
) as z
|
135
|
+
"""
|
@@ -0,0 +1,71 @@
|
|
1
|
+
import argparse
|
2
|
+
import os
|
3
|
+
|
4
|
+
from pyspark.sql import SparkSession
|
5
|
+
|
6
|
+
# OMOP tables to subset, processed in this order. main() looks for each one as a
# sub-folder of --omop_folder and silently skips tables whose folder is missing.
omop_tables = [
    "person",
    "death",
    "visit_occurrence",
    "condition_occurrence",
    "procedure_occurrence",
    "drug_exposure",
    "measurement",
    "observation",
    "observation_period",
    "condition_era",
    "drug_era",
]
|
20
|
+
|
21
|
+
|
22
|
+
# Entry point: restrict every OMOP table to the persons listed in the sample.
def main(args):
    """Copy each table in ``omop_tables``, keeping only rows of sampled persons.

    For every table folder found under ``args.omop_folder``, rows whose
    ``person_id`` appears in the ``args.person_sample`` parquet are written as
    parquet to the same-named folder under ``args.output_folder``.

    Tables are skipped when their input folder is missing, or when their output
    folder already exists and ``args.overwrite`` is not set.
    NOTE: ``os.path.exists`` only sees local (or locally mounted) paths.

    Args:
        args: Namespace with ``person_sample``, ``omop_folder`` and
            ``output_folder`` (paths) and ``overwrite`` (bool).
    """
    spark = (
        SparkSession.builder.appName("Sample OMOP Tables")
        .config("spark.sql.legacy.parquet.int96RebaseModeInRead", "CORRECTED")
        .config("spark.sql.legacy.parquet.int96RebaseModeInWrite", "CORRECTED")
        .config("spark.sql.legacy.parquet.datetimeRebaseModeInRead", "CORRECTED")
        .config("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "CORRECTED")
        .getOrCreate()
    )
    # Only the id column is needed for filtering; computed once for all tables.
    sample_person_ids = spark.read.parquet(args.person_sample).select("person_id")
    for omop_table in omop_tables:
        omop_table_input = os.path.join(args.omop_folder, omop_table)
        omop_table_output = os.path.join(args.output_folder, omop_table)
        # If the input does not exist, let's skip
        if not os.path.exists(omop_table_input):
            continue
        # If the output already exists and overwrite is not set, let's skip it
        if os.path.exists(omop_table_output) and not args.overwrite:
            continue
        omop_dataframe = spark.read.parquet(omop_table_input)
        # A left-semi join keeps each table row at most once, even if a person_id
        # appears several times in the sample. The previous inner join duplicated
        # such rows. The output columns (person_id first) are unchanged.
        sub_omop_dataframe = omop_dataframe.join(
            sample_person_ids, "person_id", "left_semi"
        )
        writer = sub_omop_dataframe.write
        if args.overwrite:
            writer = writer.mode("overwrite")
        writer.parquet(omop_table_output)
|
49
|
+
|
50
|
+
|
51
|
+
# Command-line entry point; parsing lives here so importing this module has no side effects.
if __name__ == "__main__":
    # The description previously said "uploading", copied from another script;
    # this script samples OMOP tables down to a subset of persons.
    parser = argparse.ArgumentParser(
        description="Arguments for sampling OMOP tables to a subset of persons"
    )
    parser.add_argument(
        "--person_sample",
        required=True,
        help="Parquet path with a person_id column listing the persons to keep",
    )
    parser.add_argument(
        "--omop_folder",
        required=True,
        help="Folder containing one parquet sub-folder per OMOP table",
    )
    parser.add_argument(
        "--output_folder",
        required=True,
        help="Folder where the sampled OMOP tables are written",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite tables that already exist in the output folder",
    )
    # Call the main function with parsed arguments
    main(parser.parse_args())
|
File without changes
|
@@ -0,0 +1,99 @@
|
|
1
|
+
import dataclasses
|
2
|
+
import os
|
3
|
+
import sys
|
4
|
+
from typing import Tuple
|
5
|
+
|
6
|
+
from cehrbert.runners.hf_runner_argument_dataclass import (
|
7
|
+
DataTrainingArguments,
|
8
|
+
ModelArguments,
|
9
|
+
)
|
10
|
+
from transformers import HfArgumentParser, TrainingArguments
|
11
|
+
from transformers.utils import logging
|
12
|
+
from trl.trainer.dpo_config import DPOConfig
|
13
|
+
|
14
|
+
from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
|
15
|
+
|
16
|
+
# Logger obtained from transformers.utils.logging under the "transformers" name;
# it is not referenced by the functions in this module.
LOG = logging.get_logger("transformers")
|
17
|
+
|
18
|
+
|
19
|
+
def parse_dynamic_arguments(
    argument_classes: Tuple[type, ...] = (
        DataTrainingArguments,
        ModelArguments,
        TrainingArguments,
    )
) -> Tuple:
    """
    Parse command-line arguments into one instance per dataclass in ``argument_classes``.

    ``HfArgumentParser`` reads the arguments from one of three sources:

    - a single ``.json`` argument: parsed with ``parse_json_file``;
    - a single ``.yaml`` / ``.yml`` argument: parsed with ``parse_yaml_file``;
    - anything else: parsed from the command line with ``parse_args_into_dataclasses``.

    The extension check is case-insensitive, and ``~`` in a config path is expanded.

    Parameters:
        argument_classes (Tuple[type, ...]): Dataclass types to populate. Defaults to
            ``(DataTrainingArguments, ModelArguments, TrainingArguments)``. Pass a longer
            tuple to include custom argument classes.

    Returns:
        Tuple: One parsed instance per entry of ``argument_classes``, in the same order.

    Raises:
        Whatever ``HfArgumentParser`` raises for an unreadable or malformed config file,
        or for invalid command-line arguments.

    Example usage:
        - Command line: ``python training_script.py --model_name_or_path bert-base-uncased --do_train``
        - JSON file: ``python training_script.py config.json``
        - YAML file: ``python training_script.py config.yaml``

    Example with a custom argument class::

        @dataclasses.dataclass
        class CustomArguments:
            ...  # define custom arguments here

        data_args, model_args, training_args, custom_args = parse_dynamic_arguments(
            (DataTrainingArguments, ModelArguments, TrainingArguments, CustomArguments)
        )
    """
    parser = HfArgumentParser(argument_classes)

    # A single positional argument naming a config file replaces command-line parsing.
    config_path = sys.argv[1] if len(sys.argv) == 2 else None
    lowered = config_path.lower() if config_path else ""
    if lowered.endswith(".json"):
        args = parser.parse_json_file(json_file=os.path.expanduser(config_path))
    elif lowered.endswith((".yaml", ".yml")):
        # .yml was previously not recognised and fell through to CLI parsing.
        args = parser.parse_yaml_file(yaml_file=os.path.expanduser(config_path))
    else:
        args = parser.parse_args_into_dataclasses()

    return tuple(args)
|
82
|
+
|
83
|
+
|
84
|
+
def parse_runner_args() -> (
    Tuple[CehrGPTArguments, DataTrainingArguments, ModelArguments, TrainingArguments]
):
    """Parse the CEHR-GPT, data, model and Hugging Face training arguments.

    Returns:
        ``(cehrgpt_args, data_args, model_args, training_args)`` in that order.
    """
    argument_types = (
        CehrGPTArguments,
        DataTrainingArguments,
        ModelArguments,
        TrainingArguments,
    )
    # Unpacking checks that exactly four parsed objects came back.
    cehrgpt_args, data_args, model_args, training_args = parse_dynamic_arguments(
        argument_types
    )
    return (cehrgpt_args, data_args, model_args, training_args)
|
91
|
+
|
92
|
+
|
93
|
+
def parse_dpo_runner_args() -> (
    Tuple[CehrGPTArguments, DataTrainingArguments, ModelArguments, DPOConfig]
):
    """Parse the CEHR-GPT, data and model arguments plus a trl ``DPOConfig``.

    Returns:
        ``(cehrgpt_args, data_args, model_args, dpo_config)`` in that order.
    """
    argument_types = (
        CehrGPTArguments,
        DataTrainingArguments,
        ModelArguments,
        DPOConfig,
    )
    # Unpacking checks that exactly four parsed objects came back.
    cehrgpt_args, data_args, model_args, dpo_config = parse_dynamic_arguments(
        argument_types
    )
    return (cehrgpt_args, data_args, model_args, dpo_config)
|