cehrgpt 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/htn_treatment_pathway.py +546 -0
- cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
- cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
- cehrgpt/data/cehrgpt_data_processor.py +549 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +286 -629
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +60 -14
- cehrgpt/generation/cehrgpt_conditional_generation.py +316 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +35 -15
- cehrgpt/generation/omop_converter_batch.py +11 -4
- cehrgpt/gpt_utils.py +73 -3
- cehrgpt/models/activations.py +27 -0
- cehrgpt/models/config.py +6 -2
- cehrgpt/models/gpt2.py +560 -0
- cehrgpt/models/hf_cehrgpt.py +193 -459
- cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
- cehrgpt/omop/ontology.py +154 -0
- cehrgpt/runners/data_utils.py +17 -6
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +33 -79
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +58 -34
- cehrgpt/runners/hyperparameter_search_util.py +180 -69
- cehrgpt/runners/sample_packing_trainer.py +11 -2
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +27 -31
- cehrgpt-0.1.3.dist-info/METADATA +238 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +33 -22
- cehrgpt-0.1.1.dist-info/METADATA +0 -115
- /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,631 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
"""
|
3
|
+
Treatment Pathways Study Script.
|
4
|
+
|
5
|
+
Analyzes treatment pathways using OMOP CDM data - generalized version of the original script
|
6
|
+
"""
|
7
|
+
|
8
|
+
import argparse
|
9
|
+
import os
|
10
|
+
import sys
|
11
|
+
|
12
|
+
from pyspark.sql import SparkSession
|
13
|
+
from pyspark.sql import functions as f
|
14
|
+
|
15
|
+
|
16
|
+
def parse_arguments():
    """Build the CLI for the treatment-pathways study and return parsed args."""
    parser = argparse.ArgumentParser(
        description="Analyze treatment pathways using OMOP CDM data"
    )

    # Table-driven option registration: (flags, keyword options) pairs.
    option_specs = [
        (
            ("--omop-folder",),
            {
                "required": True,
                "help": "Path to OMOP CDM data folder containing parquet files",
            },
        ),
        (
            ("--output-folder",),
            {"required": True, "help": "Output folder for results"},
        ),
        (
            ("--drug-concepts",),
            {
                "required": True,
                "help": "Comma-separated list of drug ancestor concept IDs (e.g., '21600381,21601461,21601560')",
            },
        ),
        (
            ("--target-conditions",),
            {
                "required": True,
                "help": "Comma-separated list of target condition ancestor concept IDs (e.g., '316866')",
            },
        ),
        (
            ("--exclusion-conditions",),
            {
                "help": "Comma-separated list of exclusion condition ancestor concept IDs (e.g., '444094'). If not provided, no exclusion conditions will be applied.",
            },
        ),
        (
            ("--app-name",),
            {
                "default": "Treatment_Pathways",
                "help": "Spark application name (default: Treatment_Pathways)",
            },
        ),
        (
            ("--spark-master",),
            {
                "default": "local[*]",
                "help": "Spark master URL (default: local[*])",
            },
        ),
        (
            ("--study-name",),
            {
                "default": "HTN",
                "help": "Study name prefix for cohort tables (default: HTN)",
            },
        ),
        (
            ("--save-cohort",),
            {
                "action": "store_true",
                "default": False,
                "help": "Save cohort tables as parquet files",
            },
        ),
    ]

    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    return parser.parse_args()
|
75
|
+
|
76
|
+
|
77
|
+
def parse_concept_ids(concept_string):
    """Parse a comma-separated string of OMOP concept IDs into a list of ints.

    Args:
        concept_string: Comma-separated concept IDs (e.g., "316866, 444094"),
            or None/empty.

    Returns:
        List of ints; empty list for None/empty input.

    Blank segments (e.g., from a trailing comma such as "1,2,") are skipped
    instead of raising ValueError, so slightly sloppy CLI input still parses.
    """
    if not concept_string:
        return []
    return [
        int(segment)
        for segment in (part.strip() for part in concept_string.split(","))
        if segment
    ]
|
82
|
+
|
83
|
+
|
84
|
+
def create_drug_concept_mapping(spark, drug_concepts):
    """Register the `drug_concept` temp view: ancestor->descendant pairs for the study drugs.

    The view is cached because every downstream cohort query joins against it.
    """
    concept_csv = ",".join(str(cid) for cid in drug_concepts)

    query = f"""
    SELECT DISTINCT
        ancestor_concept_id,
        descendant_concept_id
    FROM
    (
        SELECT
            ancestor_concept_id,
            descendant_concept_id
        FROM concept_ancestor AS ca
        WHERE ca.ancestor_concept_id IN ({concept_csv})
    ) a
    """
    mapping = spark.sql(query)
    mapping.cache()
    mapping.createOrReplaceTempView("drug_concept")
|
105
|
+
|
106
|
+
|
107
|
+
def create_htn_index_cohort(spark, drug_concepts, study_name):
    """Register the `{study_name}_index_cohort` temp view.

    For each person, selects the first study-drug exposure (restricted to
    exposures tied to a visit) as the index date, requires at least 365 days
    of observation before and 1095 days after the index date, and computes a
    cohort end date from a drug-era construction (exposures merged with a
    31-day persistence gap; see the event-ordinal/STARTS arithmetic below).
    The view schema: person_id, INDEX_DATE, COHORT_END_DATE,
    observation_period_start_date, observation_period_end_date.
    """
    # Interpolated into the SQL IN-lists below; callers pass validated ints.
    drug_concept_ids = ",".join(map(str, drug_concepts))

    # NOTE(review): this is a classic "era-building" query — the UNION ALL of
    # start events (EVENT_TYPE=1) and end events shifted +31 days
    # (EVENT_TYPE=0), with `2 * STARTS - EVENT_ORDINAL = 0` detecting points
    # where starts and ends balance (era boundaries). Statement order and the
    # window-function framing are load-bearing; do not restructure casually.
    htn_index_cohort = spark.sql(
        f"""
    SELECT person_id, INDEX_DATE, COHORT_END_DATE, observation_period_start_date, observation_period_end_date
    FROM (
        SELECT ot.PERSON_ID, ot.INDEX_DATE, MIN(e.END_DATE) as COHORT_END_DATE, ot.OBSERVATION_PERIOD_START_DATE, ot.OBSERVATION_PERIOD_END_DATE,
            ROW_NUMBER() OVER (PARTITION BY ot.PERSON_ID ORDER BY ot.INDEX_DATE) as RowNumber
        FROM (
            SELECT dt.PERSON_ID, dt.DRUG_EXPOSURE_START_DATE as index_date, op.OBSERVATION_PERIOD_START_DATE, op.OBSERVATION_PERIOD_END_DATE
            FROM (
                SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID, de.DRUG_EXPOSURE_START_DATE
                FROM (
                    SELECT d.PERSON_ID, d.DRUG_CONCEPT_ID, d.DRUG_EXPOSURE_START_DATE,
                        COALESCE(d.DRUG_EXPOSURE_END_DATE, DATE_ADD(d.DRUG_EXPOSURE_START_DATE, d.DAYS_SUPPLY), DATE_ADD(d.DRUG_EXPOSURE_START_DATE, 1)) as DRUG_EXPOSURE_END_DATE,
                        ROW_NUMBER() OVER (PARTITION BY d.PERSON_ID ORDER BY d.DRUG_EXPOSURE_START_DATE) as RowNumber
                    FROM (SELECT * FROM DRUG_EXPOSURE WHERE visit_occurrence_id IS NOT NULL) d
                    JOIN drug_concept ca
                        ON d.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({drug_concept_ids})
                ) de
                JOIN PERSON p ON p.PERSON_ID = de.PERSON_ID
                WHERE de.RowNumber = 1
            ) dt
            JOIN observation_period op
                ON op.PERSON_ID = dt.PERSON_ID AND (dt.DRUG_EXPOSURE_START_DATE BETWEEN op.OBSERVATION_PERIOD_START_DATE AND op.OBSERVATION_PERIOD_END_DATE)
            WHERE DATE_ADD(op.OBSERVATION_PERIOD_START_DATE, 365) <= dt.DRUG_EXPOSURE_START_DATE
                AND DATE_ADD(dt.DRUG_EXPOSURE_START_DATE, 1095) <= op.OBSERVATION_PERIOD_END_DATE
        ) ot
        JOIN (
            SELECT PERSON_ID, DATE_ADD(EVENT_DATE, -31) as END_DATE
            FROM (
                SELECT PERSON_ID, EVENT_DATE, EVENT_TYPE, START_ORDINAL,
                    ROW_NUMBER() OVER (PARTITION BY PERSON_ID ORDER BY EVENT_DATE, EVENT_TYPE) AS EVENT_ORDINAL,
                    MAX(START_ORDINAL) OVER (PARTITION BY PERSON_ID ORDER BY EVENT_DATE, EVENT_TYPE ROWS UNBOUNDED PRECEDING) as STARTS
                FROM (
                    SELECT PERSON_ID, DRUG_EXPOSURE_START_DATE AS EVENT_DATE, 1 as EVENT_TYPE,
                        ROW_NUMBER() OVER (PARTITION BY PERSON_ID ORDER BY DRUG_EXPOSURE_START_DATE) as START_ORDINAL
                    FROM (
                        SELECT d.PERSON_ID, d.DRUG_CONCEPT_ID, d.DRUG_EXPOSURE_START_DATE,
                            COALESCE(d.DRUG_EXPOSURE_END_DATE, DATE_ADD(d.DRUG_EXPOSURE_START_DATE, d.DAYS_SUPPLY), DATE_ADD(d.DRUG_EXPOSURE_START_DATE, 1)) as DRUG_EXPOSURE_END_DATE,
                            ROW_NUMBER() OVER (PARTITION BY d.PERSON_ID ORDER BY d.DRUG_EXPOSURE_START_DATE) as RowNumber
                        FROM (SELECT * FROM DRUG_EXPOSURE WHERE visit_occurrence_id IS NOT NULL) d
                        JOIN drug_concept ca
                            ON d.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({drug_concept_ids})
                    ) cteExposureData
                    UNION ALL
                    SELECT PERSON_ID, DATE_ADD(DRUG_EXPOSURE_END_DATE, 31), 0 as EVENT_TYPE, NULL
                    FROM (
                        SELECT d.PERSON_ID, d.DRUG_CONCEPT_ID, d.DRUG_EXPOSURE_START_DATE,
                            COALESCE(d.DRUG_EXPOSURE_END_DATE, DATE_ADD(d.DRUG_EXPOSURE_START_DATE, d.DAYS_SUPPLY), DATE_ADD(d.DRUG_EXPOSURE_START_DATE, 1)) as DRUG_EXPOSURE_END_DATE,
                            ROW_NUMBER() OVER (PARTITION BY d.PERSON_ID ORDER BY d.DRUG_EXPOSURE_START_DATE) as RowNumber
                        FROM (SELECT * FROM DRUG_EXPOSURE WHERE visit_occurrence_id IS NOT NULL) d
                        JOIN drug_concept ca
                            ON d.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({drug_concept_ids})
                    ) cteExposureData
                ) RAWDATA
            ) E
            WHERE 2 * E.STARTS - E.EVENT_ORDINAL = 0
        ) e ON e.PERSON_ID = ot.PERSON_ID AND e.END_DATE >= ot.INDEX_DATE
        GROUP BY ot.PERSON_ID, ot.INDEX_DATE, ot.OBSERVATION_PERIOD_START_DATE, ot.OBSERVATION_PERIOD_END_DATE
    ) r
    WHERE r.RowNumber = 1
    """
    )

    # Cached: every T0-T9/E0 query joins against this view.
    htn_index_cohort.cache()
    htn_index_cohort.createOrReplaceTempView(f"{study_name}_index_cohort")
|
176
|
+
|
177
|
+
|
178
|
+
def create_htn_e0(spark, exclusion_conditions, study_name):
    """Register the `{study_name}_E0` temp view: patients without exclusion conditions.

    Args:
        spark: Active SparkSession with the OMOP temp views registered.
        exclusion_conditions: List of exclusion condition ancestor concept IDs;
            when empty/None, no exclusion is applied and the whole index
            cohort passes through.
        study_name: Prefix for the cohort temp-view names.
    """
    if not exclusion_conditions:
        # No exclusion conditions: every index-cohort patient qualifies.
        HTN_E0 = spark.sql(
            f"""
        SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
        FROM {study_name}_index_cohort ip
        """
        )
    else:
        exclusion_concept_ids = ",".join(map(str, exclusion_conditions))
        # BUG FIX: the original joined the `drug_concept` view here, but that
        # view only contains descendants of the *drug* ancestor concepts, so
        # condition exclusion IDs could never match and the exclusion was a
        # no-op. Join the full `concept_ancestor` table instead (mirroring
        # how create_htn_t1 resolves condition concepts).
        HTN_E0 = spark.sql(
            f"""
        SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
        FROM {study_name}_index_cohort ip
        LEFT JOIN (
            SELECT co.PERSON_ID, co.CONDITION_CONCEPT_ID
            FROM condition_occurrence co
            JOIN {study_name}_index_cohort ip ON co.PERSON_ID = ip.PERSON_ID
            JOIN concept_ancestor ca ON co.CONDITION_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({exclusion_concept_ids})
            WHERE co.CONDITION_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
        ) dt ON dt.PERSON_ID = ip.PERSON_ID
        GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
        HAVING COUNT(dt.CONDITION_CONCEPT_ID) <= 0
        """
        )

    HTN_E0.cache()
    HTN_E0.createOrReplaceTempView(f"{study_name}_E0")
|
208
|
+
|
209
|
+
|
210
|
+
def create_htn_t0(spark, drug_concepts, study_name):
    """Register `{study_name}_T0`: index patients with NO study-drug exposure before the index date."""
    concept_csv = ",".join(str(cid) for cid in drug_concepts)

    query = f"""
    SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
    FROM {study_name}_index_cohort ip
    LEFT JOIN (
        SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
        FROM (SELECT * FROM DRUG_EXPOSURE WHERE visit_occurrence_id IS NOT NULL) de
        JOIN {study_name}_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
        JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({concept_csv})
        WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
        AND de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND DATE_ADD(ip.INDEX_DATE, -1)
    ) dt ON dt.PERSON_ID = ip.PERSON_ID
    GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
    HAVING COUNT(dt.DRUG_CONCEPT_ID) <= 0
    """
    cohort = spark.sql(query)
    cohort.createOrReplaceTempView(f"{study_name}_T0")
|
232
|
+
|
233
|
+
|
234
|
+
def create_htn_t1(spark, target_conditions, study_name):
    """Register `{study_name}_T1`: index patients with at least one target-condition era during observation."""
    concept_csv = ",".join(str(cid) for cid in target_conditions)

    query = f"""
    SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
    FROM {study_name}_index_cohort ip
    LEFT JOIN (
        SELECT ce.PERSON_ID, ce.CONDITION_CONCEPT_ID
        FROM CONDITION_ERA ce
        JOIN {study_name}_index_cohort ip ON ce.PERSON_ID = ip.PERSON_ID
        JOIN concept_ancestor ca ON ce.CONDITION_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({concept_csv})
        WHERE ce.CONDITION_ERA_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
    ) ct ON ct.PERSON_ID = ip.PERSON_ID
    GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
    HAVING COUNT(ct.CONDITION_CONCEPT_ID) >= 1
    """
    cohort = spark.sql(query)
    cohort.createOrReplaceTempView(f"{study_name}_T1")
|
255
|
+
|
256
|
+
|
257
|
+
def create_htn_t2(spark, drug_concepts, study_name):
    """Register `{study_name}_T2`: index patients with a study-drug exposure 121-240 days after index."""
    concept_csv = ",".join(str(cid) for cid in drug_concepts)

    query = f"""
    SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
    FROM {study_name}_index_cohort ip
    LEFT JOIN (
        SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
        FROM (
            SELECT *
            FROM DRUG_EXPOSURE
            WHERE visit_occurrence_id IS NOT NULL
        ) de
        JOIN {study_name}_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
        JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({concept_csv})
        WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
        AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 121) AND DATE_ADD(ip.INDEX_DATE, 240)
    ) dt ON dt.PERSON_ID = ip.PERSON_ID
    GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
    HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
    """
    cohort = spark.sql(query)
    cohort.createOrReplaceTempView(f"{study_name}_T2")
|
283
|
+
|
284
|
+
|
285
|
+
def create_htn_t3(spark, drug_concepts, study_name):
    """Register `{study_name}_T3`: index patients with a study-drug exposure 241-360 days after index."""
    concept_csv = ",".join(str(cid) for cid in drug_concepts)

    query = f"""
    SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
    FROM {study_name}_index_cohort ip
    LEFT JOIN (
        SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
        FROM (
            SELECT *
            FROM DRUG_EXPOSURE
            WHERE visit_occurrence_id IS NOT NULL
        ) de
        JOIN {study_name}_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
        JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({concept_csv})
        WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
        AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 241) AND DATE_ADD(ip.INDEX_DATE, 360)
    ) dt ON dt.PERSON_ID = ip.PERSON_ID
    GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
    HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
    """
    cohort = spark.sql(query)
    cohort.createOrReplaceTempView(f"{study_name}_T3")
|
311
|
+
|
312
|
+
|
313
|
+
def create_htn_t4(spark, drug_concepts, study_name):
    """Register `{study_name}_T4`: index patients with a study-drug exposure 361-480 days after index."""
    concept_csv = ",".join(str(cid) for cid in drug_concepts)

    query = f"""
    SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
    FROM {study_name}_index_cohort ip
    LEFT JOIN (
        SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
        FROM (
            SELECT *
            FROM DRUG_EXPOSURE
            WHERE visit_occurrence_id IS NOT NULL
        ) de
        JOIN {study_name}_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
        JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({concept_csv})
        WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
        AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 361) AND DATE_ADD(ip.INDEX_DATE, 480)
    ) dt ON dt.PERSON_ID = ip.PERSON_ID
    GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
    HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
    """
    cohort = spark.sql(query)
    cohort.createOrReplaceTempView(f"{study_name}_T4")
|
339
|
+
|
340
|
+
|
341
|
+
def create_htn_t5(spark, drug_concepts, study_name):
    """Register `{study_name}_T5`: index patients with a study-drug exposure 481-600 days after index."""
    concept_csv = ",".join(str(cid) for cid in drug_concepts)

    query = f"""
    SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
    FROM {study_name}_index_cohort ip
    LEFT JOIN (
        SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
        FROM (
            SELECT *
            FROM DRUG_EXPOSURE
            WHERE visit_occurrence_id IS NOT NULL
        ) de
        JOIN {study_name}_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
        JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({concept_csv})
        WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
        AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 481) AND DATE_ADD(ip.INDEX_DATE, 600)
    ) dt ON dt.PERSON_ID = ip.PERSON_ID
    GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
    HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
    """
    cohort = spark.sql(query)
    cohort.createOrReplaceTempView(f"{study_name}_T5")
|
367
|
+
|
368
|
+
|
369
|
+
def create_htn_t6(spark, drug_concepts, study_name):
    """Register `{study_name}_T6`: index patients with a study-drug exposure 601-720 days after index."""
    concept_csv = ",".join(str(cid) for cid in drug_concepts)

    query = f"""
    SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
    FROM {study_name}_index_cohort ip
    LEFT JOIN (
        SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
        FROM (
            SELECT *
            FROM DRUG_EXPOSURE
            WHERE visit_occurrence_id IS NOT NULL
        ) de
        JOIN {study_name}_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
        JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({concept_csv})
        WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
        AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 601) AND DATE_ADD(ip.INDEX_DATE, 720)
    ) dt ON dt.PERSON_ID = ip.PERSON_ID
    GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
    HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
    """
    cohort = spark.sql(query)
    cohort.createOrReplaceTempView(f"{study_name}_T6")
|
395
|
+
|
396
|
+
|
397
|
+
def create_htn_t7(spark, drug_concepts, study_name):
    """Register `{study_name}_T7`: index patients with a study-drug exposure 721-840 days after index."""
    concept_csv = ",".join(str(cid) for cid in drug_concepts)

    query = f"""
    SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
    FROM {study_name}_index_cohort ip
    LEFT JOIN (
        SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
        FROM (
            SELECT *
            FROM DRUG_EXPOSURE
            WHERE visit_occurrence_id IS NOT NULL
        ) de
        JOIN {study_name}_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
        JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({concept_csv})
        WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
        AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 721) AND DATE_ADD(ip.INDEX_DATE, 840)
    ) dt ON dt.PERSON_ID = ip.PERSON_ID
    GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
    HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
    """
    cohort = spark.sql(query)
    cohort.createOrReplaceTempView(f"{study_name}_T7")
|
423
|
+
|
424
|
+
|
425
|
+
def create_htn_t8(spark, drug_concepts, study_name):
    """Register `{study_name}_T8`: index patients with a study-drug exposure 841-960 days after index."""
    concept_csv = ",".join(str(cid) for cid in drug_concepts)

    query = f"""
    SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
    FROM {study_name}_index_cohort ip
    LEFT JOIN (
        SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
        FROM (
            SELECT *
            FROM DRUG_EXPOSURE
            WHERE visit_occurrence_id IS NOT NULL
        ) de
        JOIN {study_name}_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
        JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({concept_csv})
        WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
        AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 841) AND DATE_ADD(ip.INDEX_DATE, 960)
    ) dt ON dt.PERSON_ID = ip.PERSON_ID
    GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
    HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
    """
    cohort = spark.sql(query)
    cohort.createOrReplaceTempView(f"{study_name}_T8")
|
451
|
+
|
452
|
+
|
453
|
+
def create_htn_t9(spark, drug_concepts, study_name):
    """Register `{study_name}_T9`: index patients with a study-drug exposure 961-1080 days after index."""
    concept_csv = ",".join(str(cid) for cid in drug_concepts)

    query = f"""
    SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
    FROM {study_name}_index_cohort ip
    LEFT JOIN (
        SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
        FROM (
            SELECT *
            FROM DRUG_EXPOSURE
            WHERE visit_occurrence_id IS NOT NULL
        ) de
        JOIN {study_name}_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
        JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({concept_csv})
        WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
        AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 961) AND DATE_ADD(ip.INDEX_DATE, 1080)
    ) dt ON dt.PERSON_ID = ip.PERSON_ID
    GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
    HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
    """
    cohort = spark.sql(query)
    cohort.createOrReplaceTempView(f"{study_name}_T9")
|
479
|
+
|
480
|
+
|
481
|
+
def create_htn_match_cohort(spark, study_name):
    """Intersect E0 and T0-T9 against the index cohort and return the matched DataFrame.

    A patient/index-date pair is kept only if it satisfies every inclusion
    window (T1-T9), the washout rule (T0) and the exclusion rule (E0).
    """
    query = f"""
    SELECT c.person_id, c.index_date, c.cohort_end_date, c.observation_period_start_date, c.observation_period_end_date
    FROM {study_name}_index_cohort C
    INNER JOIN (
        SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID
        FROM (
            SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_E0
            INTERSECT
            SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_T0
            INTERSECT
            SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_T1
            INTERSECT
            SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_T2
            INTERSECT
            SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_T3
            INTERSECT
            SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_T4
            INTERSECT
            SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_T5
            INTERSECT
            SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_T6
            INTERSECT
            SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_T7
            INTERSECT
            SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_T8
            INTERSECT
            SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_T9
        ) TopGroup
    ) I
    ON C.PERSON_ID = I.PERSON_ID
    AND c.index_date = i.index_date
    """
    matched = spark.sql(query)
    return matched
|
519
|
+
|
520
|
+
|
521
|
+
def main():
    """Load OMOP parquet tables, build the treatment-pathway cohorts, optionally save.

    Exits with status 1 on any error during the Spark stage; the SparkSession
    is always stopped.
    """
    args = parse_arguments()
    # Parse concept IDs
    drug_concepts = parse_concept_ids(args.drug_concepts)
    target_conditions = parse_concept_ids(args.target_conditions)
    exclusion_conditions = (
        parse_concept_ids(args.exclusion_conditions)
        if args.exclusion_conditions
        else []
    )

    print(f"Study name: {args.study_name}")
    print(f"Drug concepts: {drug_concepts}")
    print(f"Target conditions: {target_conditions}")
    print(f"Exclusion conditions: {exclusion_conditions}")

    # Initialize Spark.
    # BUG FIX: --spark-master was parsed but never applied; honor it here so
    # the documented default "local[*]" (and any user override) takes effect.
    spark = (
        SparkSession.builder.appName(args.app_name)
        .master(args.spark_master)
        .getOrCreate()
    )

    try:
        # Load the source OMOP tables
        person = spark.read.parquet(os.path.join(args.omop_folder, "person"))
        visit_occurrence = spark.read.parquet(
            os.path.join(args.omop_folder, "visit_occurrence")
        )
        condition_occurrence = spark.read.parquet(
            os.path.join(args.omop_folder, "condition_occurrence")
        )
        procedure_occurrence = spark.read.parquet(
            os.path.join(args.omop_folder, "procedure_occurrence")
        )
        drug_exposure = spark.read.parquet(
            os.path.join(args.omop_folder, "drug_exposure")
        )
        observation_period = spark.read.parquet(
            os.path.join(args.omop_folder, "observation_period")
        )
        condition_era = spark.read.parquet(
            os.path.join(args.omop_folder, "condition_era")
        )

        print(f"person: {person.select('person_id').distinct().count()}")
        print(f"visit_occurrence: {visit_occurrence.count()}")
        print(f"condition_occurrence: {condition_occurrence.count()}")
        print(f"procedure_occurrence: {procedure_occurrence.count()}")
        print(f"drug_exposure: {drug_exposure.count()}")
        print(f"observation_period: {observation_period.count()}")

        concept = spark.read.parquet(os.path.join(args.omop_folder, "concept"))
        concept_ancestor = spark.read.parquet(
            os.path.join(args.omop_folder, "concept_ancestor")
        )

        # Create temporary views so the cohort builders can use spark.sql
        person.createOrReplaceTempView("person")
        visit_occurrence.createOrReplaceTempView("visit_occurrence")
        condition_occurrence.createOrReplaceTempView("condition_occurrence")
        procedure_occurrence.createOrReplaceTempView("procedure_occurrence")
        drug_exposure.createOrReplaceTempView("drug_exposure")
        observation_period.createOrReplaceTempView("observation_period")
        condition_era.createOrReplaceTempView("condition_era")

        concept_ancestor.createOrReplaceTempView("concept_ancestor")
        concept.createOrReplaceTempView("concept")

        # Create drug concept mapping (registers the `drug_concept` view)
        create_drug_concept_mapping(spark, drug_concepts)

        # Create HTN index cohort (registers `{study_name}_index_cohort`)
        create_htn_index_cohort(spark, drug_concepts, args.study_name)

        # Create all cohorts in sequence - keeping exact original function calls
        create_htn_e0(spark, exclusion_conditions, args.study_name)
        create_htn_t0(spark, drug_concepts, args.study_name)
        create_htn_t1(spark, target_conditions, args.study_name)
        create_htn_t2(spark, drug_concepts, args.study_name)
        create_htn_t3(spark, drug_concepts, args.study_name)
        create_htn_t4(spark, drug_concepts, args.study_name)
        create_htn_t5(spark, drug_concepts, args.study_name)
        create_htn_t6(spark, drug_concepts, args.study_name)
        create_htn_t7(spark, drug_concepts, args.study_name)
        create_htn_t8(spark, drug_concepts, args.study_name)
        create_htn_t9(spark, drug_concepts, args.study_name)

        # Create final cohort
        htn_match_cohort = create_htn_match_cohort(spark, args.study_name)

        if args.save_cohort:
            # Save results. exist_ok avoids the check-then-create race of the
            # original os.path.exists/os.makedirs pair.
            os.makedirs(args.output_folder, exist_ok=True)

            output_path = os.path.join(args.output_folder, "htn_match_cohort")
            htn_match_cohort.write.mode("overwrite").parquet(output_path)

            # Read back and count
            htn_match_cohort = spark.read.parquet(output_path)
            final_count = htn_match_cohort.count()
            print(f"Final cohort count: {final_count}")

        print("Analysis completed successfully!")

    except Exception as e:
        # Top-level boundary: report and exit non-zero.
        print(f"Error: {e}")
        sys.exit(1)
    finally:
        spark.stop()
|
628
|
+
|
629
|
+
|
630
|
+
# Script entry point: run the study only when executed directly, not on import.
if __name__ == "__main__":
    main()
|