cehrgpt 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. cehrgpt/analysis/htn_treatment_pathway.py +546 -0
  2. cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
  3. cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
  4. cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
  5. cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
  6. cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
  7. cehrgpt/data/cehrgpt_data_processor.py +549 -0
  8. cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
  9. cehrgpt/data/hf_cehrgpt_dataset_collator.py +285 -652
  10. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +38 -5
  11. cehrgpt/generation/cehrgpt_conditional_generation.py +2 -0
  12. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +20 -12
  13. cehrgpt/generation/omop_converter_batch.py +11 -4
  14. cehrgpt/gpt_utils.py +73 -3
  15. cehrgpt/models/activations.py +27 -0
  16. cehrgpt/models/config.py +6 -2
  17. cehrgpt/models/gpt2.py +560 -0
  18. cehrgpt/models/hf_cehrgpt.py +183 -460
  19. cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
  20. cehrgpt/omop/ontology.py +154 -0
  21. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -78
  22. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
  23. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -34
  24. cehrgpt/runners/hyperparameter_search_util.py +180 -69
  25. cehrgpt/runners/sample_packing_trainer.py +11 -2
  26. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +8 -2
  27. cehrgpt-0.1.3.dist-info/METADATA +238 -0
  28. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +32 -22
  29. cehrgpt-0.1.2.dist-info/METADATA +0 -209
  30. /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
  31. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
  32. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
  33. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
cehrgpt/analysis/treatment_pathway/treatment_pathway.py
@@ -0,0 +1,631 @@
+ #!/usr/bin/env python
+ """
+ Treatment Pathways Study Script.
+
+ Analyzes treatment pathways using OMOP CDM data; a generalized version of the original HTN-specific script.
+ """
+
+ import argparse
+ import os
+ import sys
+
+ from pyspark.sql import SparkSession
+ from pyspark.sql import functions as f
+
+
+ def parse_arguments():
+     """Parse command line arguments."""
+     parser = argparse.ArgumentParser(
+         description="Analyze treatment pathways using OMOP CDM data"
+     )
+
+     parser.add_argument(
+         "--omop-folder",
+         required=True,
+         help="Path to OMOP CDM data folder containing parquet files",
+     )
+
+     parser.add_argument(
+         "--output-folder", required=True, help="Output folder for results"
+     )
+
+     parser.add_argument(
+         "--drug-concepts",
+         required=True,
+         help="Comma-separated list of drug ancestor concept IDs (e.g., '21600381,21601461,21601560')",
+     )
+
+     parser.add_argument(
+         "--target-conditions",
+         required=True,
+         help="Comma-separated list of target condition ancestor concept IDs (e.g., '316866')",
+     )
+
+     parser.add_argument(
+         "--exclusion-conditions",
+         help="Comma-separated list of exclusion condition ancestor concept IDs (e.g., '444094'). If not provided, no exclusion conditions will be applied.",
+     )
+
+     parser.add_argument(
+         "--app-name",
+         default="Treatment_Pathways",
+         help="Spark application name (default: Treatment_Pathways)",
+     )
+
+     parser.add_argument(
+         "--spark-master",
+         default="local[*]",
+         help="Spark master URL (default: local[*])",
+     )
+
+     parser.add_argument(
+         "--study-name",
+         default="HTN",
+         help="Study name prefix for cohort tables (default: HTN)",
+     )
+
+     parser.add_argument(
+         "--save-cohort",
+         action="store_true",
+         default=False,
+         help="Save cohort tables as parquet files",
+     )
+
+     return parser.parse_args()
+
+
+ def parse_concept_ids(concept_string):
+     """Parse comma-separated concept IDs into a list of integers."""
+     if not concept_string:
+         return []
+     return [int(x.strip()) for x in concept_string.split(",")]
+
+
+ def create_drug_concept_mapping(spark, drug_concepts):
+     """Create drug concept mapping for medications - exact query from original."""
+     drug_concept_ids = ",".join(map(str, drug_concepts))
+
+     drug_concept = spark.sql(
+         f"""
+         SELECT DISTINCT
+             ancestor_concept_id,
+             descendant_concept_id
+         FROM
+         (
+             SELECT
+                 ancestor_concept_id,
+                 descendant_concept_id
+             FROM concept_ancestor AS ca
+             WHERE ca.ancestor_concept_id IN ({drug_concept_ids})
+         ) a
+         """
+     )
+     drug_concept.cache()
+     drug_concept.createOrReplaceTempView("drug_concept")
+
+
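+ # The query below takes each person's first qualifying drug exposure as the
+ # index date, requires at least 365 days of observation before it and at
+ # least 1095 days after it, and sets COHORT_END_DATE to the end of the first
+ # drug era at or after the index date (eras built with a 31-day persistence
+ # window).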
+ def create_htn_index_cohort(spark, drug_concepts, study_name):
+     """Create HTN index cohort - exact query from original."""
+     drug_concept_ids = ",".join(map(str, drug_concepts))
+
+     htn_index_cohort = spark.sql(
+         f"""
+         SELECT person_id, INDEX_DATE, COHORT_END_DATE, observation_period_start_date, observation_period_end_date
+         FROM (
+             SELECT ot.PERSON_ID, ot.INDEX_DATE, MIN(e.END_DATE) as COHORT_END_DATE, ot.OBSERVATION_PERIOD_START_DATE, ot.OBSERVATION_PERIOD_END_DATE,
+                 ROW_NUMBER() OVER (PARTITION BY ot.PERSON_ID ORDER BY ot.INDEX_DATE) as RowNumber
+             FROM (
+                 SELECT dt.PERSON_ID, dt.DRUG_EXPOSURE_START_DATE as index_date, op.OBSERVATION_PERIOD_START_DATE, op.OBSERVATION_PERIOD_END_DATE
+                 FROM (
+                     SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID, de.DRUG_EXPOSURE_START_DATE
+                     FROM (
+                         SELECT d.PERSON_ID, d.DRUG_CONCEPT_ID, d.DRUG_EXPOSURE_START_DATE,
+                             COALESCE(d.DRUG_EXPOSURE_END_DATE, DATE_ADD(d.DRUG_EXPOSURE_START_DATE, d.DAYS_SUPPLY), DATE_ADD(d.DRUG_EXPOSURE_START_DATE, 1)) as DRUG_EXPOSURE_END_DATE,
+                             ROW_NUMBER() OVER (PARTITION BY d.PERSON_ID ORDER BY d.DRUG_EXPOSURE_START_DATE) as RowNumber
+                         FROM (SELECT * FROM DRUG_EXPOSURE WHERE visit_occurrence_id IS NOT NULL) d
+                         JOIN drug_concept ca
+                             ON d.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({drug_concept_ids})
+                     ) de
+                     JOIN PERSON p ON p.PERSON_ID = de.PERSON_ID
+                     WHERE de.RowNumber = 1
+                 ) dt
+                 JOIN observation_period op
+                     ON op.PERSON_ID = dt.PERSON_ID AND (dt.DRUG_EXPOSURE_START_DATE BETWEEN op.OBSERVATION_PERIOD_START_DATE AND op.OBSERVATION_PERIOD_END_DATE)
+                 WHERE DATE_ADD(op.OBSERVATION_PERIOD_START_DATE, 365) <= dt.DRUG_EXPOSURE_START_DATE
+                     AND DATE_ADD(dt.DRUG_EXPOSURE_START_DATE, 1095) <= op.OBSERVATION_PERIOD_END_DATE
+             ) ot
+             JOIN (
+                 SELECT PERSON_ID, DATE_ADD(EVENT_DATE, -31) as END_DATE
+                 FROM (
+                     SELECT PERSON_ID, EVENT_DATE, EVENT_TYPE, START_ORDINAL,
+                         ROW_NUMBER() OVER (PARTITION BY PERSON_ID ORDER BY EVENT_DATE, EVENT_TYPE) AS EVENT_ORDINAL,
+                         MAX(START_ORDINAL) OVER (PARTITION BY PERSON_ID ORDER BY EVENT_DATE, EVENT_TYPE ROWS UNBOUNDED PRECEDING) as STARTS
+                     FROM (
+                         SELECT PERSON_ID, DRUG_EXPOSURE_START_DATE AS EVENT_DATE, 1 as EVENT_TYPE,
+                             ROW_NUMBER() OVER (PARTITION BY PERSON_ID ORDER BY DRUG_EXPOSURE_START_DATE) as START_ORDINAL
+                         FROM (
+                             SELECT d.PERSON_ID, d.DRUG_CONCEPT_ID, d.DRUG_EXPOSURE_START_DATE,
+                                 COALESCE(d.DRUG_EXPOSURE_END_DATE, DATE_ADD(d.DRUG_EXPOSURE_START_DATE, d.DAYS_SUPPLY), DATE_ADD(d.DRUG_EXPOSURE_START_DATE, 1)) as DRUG_EXPOSURE_END_DATE,
+                                 ROW_NUMBER() OVER (PARTITION BY d.PERSON_ID ORDER BY d.DRUG_EXPOSURE_START_DATE) as RowNumber
+                             FROM (SELECT * FROM DRUG_EXPOSURE WHERE visit_occurrence_id IS NOT NULL) d
+                             JOIN drug_concept ca
+                                 ON d.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({drug_concept_ids})
+                         ) cteExposureData
+                         UNION ALL
+                         SELECT PERSON_ID, DATE_ADD(DRUG_EXPOSURE_END_DATE, 31), 0 as EVENT_TYPE, NULL
+                         FROM (
+                             SELECT d.PERSON_ID, d.DRUG_CONCEPT_ID, d.DRUG_EXPOSURE_START_DATE,
+                                 COALESCE(d.DRUG_EXPOSURE_END_DATE, DATE_ADD(d.DRUG_EXPOSURE_START_DATE, d.DAYS_SUPPLY), DATE_ADD(d.DRUG_EXPOSURE_START_DATE, 1)) as DRUG_EXPOSURE_END_DATE,
+                                 ROW_NUMBER() OVER (PARTITION BY d.PERSON_ID ORDER BY d.DRUG_EXPOSURE_START_DATE) as RowNumber
+                             FROM (SELECT * FROM DRUG_EXPOSURE WHERE visit_occurrence_id IS NOT NULL) d
+                             JOIN drug_concept ca
+                                 ON d.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({drug_concept_ids})
+                         ) cteExposureData
+                     ) RAWDATA
+                 ) E
+                 WHERE 2 * E.STARTS - E.EVENT_ORDINAL = 0
+             ) e ON e.PERSON_ID = ot.PERSON_ID AND e.END_DATE >= ot.INDEX_DATE
+             GROUP BY ot.PERSON_ID, ot.INDEX_DATE, ot.OBSERVATION_PERIOD_START_DATE, ot.OBSERVATION_PERIOD_END_DATE
+         ) r
+         WHERE r.RowNumber = 1
+         """
+     )
+
+     htn_index_cohort.cache()
+     htn_index_cohort.createOrReplaceTempView(f"{study_name}_index_cohort")
+
+
+ def create_htn_e0(spark, exclusion_conditions, study_name):
+     """Create HTN_E0 - exact query from original."""
+     if not exclusion_conditions:
+         # If no exclusion conditions, return all patients from index cohort
+         HTN_E0 = spark.sql(
+             f"""
+             SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+             FROM {study_name}_index_cohort ip
+             """
+         )
+     else:
+         exclusion_concept_ids = ",".join(map(str, exclusion_conditions))
+         HTN_E0 = spark.sql(
+             f"""
+             SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+             FROM {study_name}_index_cohort ip
+             LEFT JOIN (
+                 SELECT co.PERSON_ID, co.CONDITION_CONCEPT_ID
+                 FROM condition_occurrence co
+                 JOIN {study_name}_index_cohort ip ON co.PERSON_ID = ip.PERSON_ID
+                 JOIN concept_ancestor ca ON co.CONDITION_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({exclusion_concept_ids})
+                 WHERE co.CONDITION_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+             ) dt ON dt.PERSON_ID = ip.PERSON_ID
+             GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+             HAVING COUNT(dt.CONDITION_CONCEPT_ID) <= 0
+             """
+         )
+
+     HTN_E0.cache()
+     HTN_E0.createOrReplaceTempView(f"{study_name}_E0")
+
+
+ def create_htn_t0(spark, drug_concepts, study_name):
+     """Create HTN_T0 - exact query from original."""
+     drug_concept_ids = ",".join(map(str, drug_concepts))
+
+     HTN_T0 = spark.sql(
+         f"""
+         SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         FROM {study_name}_index_cohort ip
+         LEFT JOIN (
+             SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
+             FROM (SELECT * FROM DRUG_EXPOSURE WHERE visit_occurrence_id IS NOT NULL) de
+             JOIN {study_name}_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
+             JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({drug_concept_ids})
+             WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+                 AND de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND DATE_ADD(ip.INDEX_DATE, -1)
+         ) dt ON dt.PERSON_ID = ip.PERSON_ID
+         GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         HAVING COUNT(dt.DRUG_CONCEPT_ID) <= 0
+         """
+     )
+
+     HTN_T0.createOrReplaceTempView(f"{study_name}_T0")
+
+
+ def create_htn_t1(spark, target_conditions, study_name):
+     """Create HTN_T1 - exact query from original."""
+     target_concept_ids = ",".join(map(str, target_conditions))
+
+     HTN_T1 = spark.sql(
+         f"""
+         SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         FROM {study_name}_index_cohort ip
+         LEFT JOIN (
+             SELECT ce.PERSON_ID, ce.CONDITION_CONCEPT_ID
+             FROM CONDITION_ERA ce
+             JOIN {study_name}_index_cohort ip ON ce.PERSON_ID = ip.PERSON_ID
+             JOIN concept_ancestor ca ON ce.CONDITION_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({target_concept_ids})
+             WHERE ce.CONDITION_ERA_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+         ) ct ON ct.PERSON_ID = ip.PERSON_ID
+         GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         HAVING COUNT(ct.CONDITION_CONCEPT_ID) >= 1
+         """
+     )
+
+     HTN_T1.createOrReplaceTempView(f"{study_name}_T1")
+
+
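+ # T2 through T9 below are identical except for the exposure window: each
+ # requires at least one qualifying drug exposure in a successive 120-day
+ # window after the index date (days 121-240, 241-360, ..., 961-1080).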
+ def create_htn_t2(spark, drug_concepts, study_name):
+     """Create HTN_T2 - exact query from original."""
+     drug_concept_ids = ",".join(map(str, drug_concepts))
+
+     HTN_T2 = spark.sql(
+         f"""
+         SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         FROM {study_name}_index_cohort ip
+         LEFT JOIN (
+             SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
+             FROM (
+                 SELECT *
+                 FROM DRUG_EXPOSURE
+                 WHERE visit_occurrence_id IS NOT NULL
+             ) de
+             JOIN {study_name}_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
+             JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({drug_concept_ids})
+             WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+                 AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 121) AND DATE_ADD(ip.INDEX_DATE, 240)
+         ) dt ON dt.PERSON_ID = ip.PERSON_ID
+         GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
+         """
+     )
+
+     HTN_T2.createOrReplaceTempView(f"{study_name}_T2")
+
+
+ def create_htn_t3(spark, drug_concepts, study_name):
+     """Create HTN_T3 - exact query from original."""
+     drug_concept_ids = ",".join(map(str, drug_concepts))
+
+     HTN_T3 = spark.sql(
+         f"""
+         SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         FROM {study_name}_index_cohort ip
+         LEFT JOIN (
+             SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
+             FROM (
+                 SELECT *
+                 FROM DRUG_EXPOSURE
+                 WHERE visit_occurrence_id IS NOT NULL
+             ) de
+             JOIN {study_name}_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
+             JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({drug_concept_ids})
+             WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+                 AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 241) AND DATE_ADD(ip.INDEX_DATE, 360)
+         ) dt ON dt.PERSON_ID = ip.PERSON_ID
+         GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
+         """
+     )
+
+     HTN_T3.createOrReplaceTempView(f"{study_name}_T3")
+
+
+ def create_htn_t4(spark, drug_concepts, study_name):
+     """Create HTN_T4 - exact query from original."""
+     drug_concept_ids = ",".join(map(str, drug_concepts))
+
+     HTN_T4 = spark.sql(
+         f"""
+         SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         FROM {study_name}_index_cohort ip
+         LEFT JOIN (
+             SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
+             FROM (
+                 SELECT *
+                 FROM DRUG_EXPOSURE
+                 WHERE visit_occurrence_id IS NOT NULL
+             ) de
+             JOIN {study_name}_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
+             JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({drug_concept_ids})
+             WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+                 AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 361) AND DATE_ADD(ip.INDEX_DATE, 480)
+         ) dt ON dt.PERSON_ID = ip.PERSON_ID
+         GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
+         """
+     )
+
+     HTN_T4.createOrReplaceTempView(f"{study_name}_T4")
+
+
+ def create_htn_t5(spark, drug_concepts, study_name):
+     """Create HTN_T5 - exact query from original."""
+     drug_concept_ids = ",".join(map(str, drug_concepts))
+
+     HTN_T5 = spark.sql(
+         f"""
+         SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         FROM {study_name}_index_cohort ip
+         LEFT JOIN (
+             SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
+             FROM (
+                 SELECT *
+                 FROM DRUG_EXPOSURE
+                 WHERE visit_occurrence_id IS NOT NULL
+             ) de
+             JOIN {study_name}_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
+             JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({drug_concept_ids})
+             WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+                 AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 481) AND DATE_ADD(ip.INDEX_DATE, 600)
+         ) dt ON dt.PERSON_ID = ip.PERSON_ID
+         GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
+         """
+     )
+
+     HTN_T5.createOrReplaceTempView(f"{study_name}_T5")
+
+
+ def create_htn_t6(spark, drug_concepts, study_name):
+     """Create HTN_T6 - exact query from original."""
+     drug_concept_ids = ",".join(map(str, drug_concepts))
+
+     HTN_T6 = spark.sql(
+         f"""
+         SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         FROM {study_name}_index_cohort ip
+         LEFT JOIN (
+             SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
+             FROM (
+                 SELECT *
+                 FROM DRUG_EXPOSURE
+                 WHERE visit_occurrence_id IS NOT NULL
+             ) de
+             JOIN {study_name}_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
+             JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({drug_concept_ids})
+             WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+                 AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 601) AND DATE_ADD(ip.INDEX_DATE, 720)
+         ) dt ON dt.PERSON_ID = ip.PERSON_ID
+         GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
+         """
+     )
+
+     HTN_T6.createOrReplaceTempView(f"{study_name}_T6")
+
+
+ def create_htn_t7(spark, drug_concepts, study_name):
+     """Create HTN_T7 - exact query from original."""
+     drug_concept_ids = ",".join(map(str, drug_concepts))
+
+     HTN_T7 = spark.sql(
+         f"""
+         SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         FROM {study_name}_index_cohort ip
+         LEFT JOIN (
+             SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
+             FROM (
+                 SELECT *
+                 FROM DRUG_EXPOSURE
+                 WHERE visit_occurrence_id IS NOT NULL
+             ) de
+             JOIN {study_name}_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
+             JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({drug_concept_ids})
+             WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+                 AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 721) AND DATE_ADD(ip.INDEX_DATE, 840)
+         ) dt ON dt.PERSON_ID = ip.PERSON_ID
+         GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
+         """
+     )
+
+     HTN_T7.createOrReplaceTempView(f"{study_name}_T7")
+
+
+ def create_htn_t8(spark, drug_concepts, study_name):
+     """Create HTN_T8 - exact query from original."""
+     drug_concept_ids = ",".join(map(str, drug_concepts))
+
+     HTN_T8 = spark.sql(
+         f"""
+         SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         FROM {study_name}_index_cohort ip
+         LEFT JOIN (
+             SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
+             FROM (
+                 SELECT *
+                 FROM DRUG_EXPOSURE
+                 WHERE visit_occurrence_id IS NOT NULL
+             ) de
+             JOIN {study_name}_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
+             JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({drug_concept_ids})
+             WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+                 AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 841) AND DATE_ADD(ip.INDEX_DATE, 960)
+         ) dt ON dt.PERSON_ID = ip.PERSON_ID
+         GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
+         """
+     )
+
+     HTN_T8.createOrReplaceTempView(f"{study_name}_T8")
+
+
+ def create_htn_t9(spark, drug_concepts, study_name):
+     """Create HTN_T9 - exact query from original."""
+     drug_concept_ids = ",".join(map(str, drug_concepts))
+
+     HTN_T9 = spark.sql(
+         f"""
+         SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         FROM {study_name}_index_cohort ip
+         LEFT JOIN (
+             SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
+             FROM (
+                 SELECT *
+                 FROM DRUG_EXPOSURE
+                 WHERE visit_occurrence_id IS NOT NULL
+             ) de
+             JOIN {study_name}_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
+             JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN ({drug_concept_ids})
+             WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+                 AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 961) AND DATE_ADD(ip.INDEX_DATE, 1080)
+         ) dt ON dt.PERSON_ID = ip.PERSON_ID
+         GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
+         """
+     )
+
+     HTN_T9.createOrReplaceTempView(f"{study_name}_T9")
+
+
+ def create_htn_match_cohort(spark, study_name):
+     """Create HTN_MatchCohort - exact query from original."""
+     HTN_MatchCohort = spark.sql(
+         f"""
+         SELECT c.person_id, c.index_date, c.cohort_end_date, c.observation_period_start_date, c.observation_period_end_date
+         FROM {study_name}_index_cohort C
+         INNER JOIN (
+             SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID
+             FROM (
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_E0
+                 INTERSECT
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_T0
+                 INTERSECT
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_T1
+                 INTERSECT
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_T2
+                 INTERSECT
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_T3
+                 INTERSECT
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_T4
+                 INTERSECT
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_T5
+                 INTERSECT
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_T6
+                 INTERSECT
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_T7
+                 INTERSECT
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_T8
+                 INTERSECT
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM {study_name}_T9
+             ) TopGroup
+         ) I
+             ON C.PERSON_ID = I.PERSON_ID
+             AND c.index_date = i.index_date
+         """
+     )
+
+     return HTN_MatchCohort
+
+
+ def main():
+     args = parse_arguments()
+     # Parse concept IDs
+     drug_concepts = parse_concept_ids(args.drug_concepts)
+     target_conditions = parse_concept_ids(args.target_conditions)
+     exclusion_conditions = (
+         parse_concept_ids(args.exclusion_conditions)
+         if args.exclusion_conditions
+         else []
+     )
+
+     print(f"Study name: {args.study_name}")
+     print(f"Drug concepts: {drug_concepts}")
+     print(f"Target conditions: {target_conditions}")
+     print(f"Exclusion conditions: {exclusion_conditions}")
+
+     # Initialize Spark
+     spark = SparkSession.builder.master(args.spark_master).appName(args.app_name).getOrCreate()
+
+     try:
+         # Load the source OMOP tables
+         person = spark.read.parquet(os.path.join(args.omop_folder, "person"))
+         visit_occurrence = spark.read.parquet(
+             os.path.join(args.omop_folder, "visit_occurrence")
+         )
+         condition_occurrence = spark.read.parquet(
+             os.path.join(args.omop_folder, "condition_occurrence")
+         )
+         procedure_occurrence = spark.read.parquet(
+             os.path.join(args.omop_folder, "procedure_occurrence")
+         )
+         drug_exposure = spark.read.parquet(
+             os.path.join(args.omop_folder, "drug_exposure")
+         )
+         observation_period = spark.read.parquet(
+             os.path.join(args.omop_folder, "observation_period")
+         )
+         condition_era = spark.read.parquet(
+             os.path.join(args.omop_folder, "condition_era")
+         )
+
+         print(f"person: {person.select('person_id').distinct().count()}")
+         print(f"visit_occurrence: {visit_occurrence.count()}")
+         print(f"condition_occurrence: {condition_occurrence.count()}")
+         print(f"procedure_occurrence: {procedure_occurrence.count()}")
+         print(f"drug_exposure: {drug_exposure.count()}")
+         print(f"observation_period: {observation_period.count()}")
+
+         concept = spark.read.parquet(os.path.join(args.omop_folder, "concept"))
+         concept_ancestor = spark.read.parquet(
+             os.path.join(args.omop_folder, "concept_ancestor")
+         )
+
+         # Create temporary views
+         person.createOrReplaceTempView("person")
+         visit_occurrence.createOrReplaceTempView("visit_occurrence")
+         condition_occurrence.createOrReplaceTempView("condition_occurrence")
+         procedure_occurrence.createOrReplaceTempView("procedure_occurrence")
+         drug_exposure.createOrReplaceTempView("drug_exposure")
+         observation_period.createOrReplaceTempView("observation_period")
+         condition_era.createOrReplaceTempView("condition_era")
+
+         concept_ancestor.createOrReplaceTempView("concept_ancestor")
+         concept.createOrReplaceTempView("concept")
+
+         # Create drug concept mapping
+         create_drug_concept_mapping(spark, drug_concepts)
+
+         # Create HTN index cohort
+         create_htn_index_cohort(spark, drug_concepts, args.study_name)
+
+         # Create all cohorts in sequence - keeping exact original function calls
+         create_htn_e0(spark, exclusion_conditions, args.study_name)
+         create_htn_t0(spark, drug_concepts, args.study_name)
+         create_htn_t1(spark, target_conditions, args.study_name)
+         create_htn_t2(spark, drug_concepts, args.study_name)
+         create_htn_t3(spark, drug_concepts, args.study_name)
+         create_htn_t4(spark, drug_concepts, args.study_name)
+         create_htn_t5(spark, drug_concepts, args.study_name)
+         create_htn_t6(spark, drug_concepts, args.study_name)
+         create_htn_t7(spark, drug_concepts, args.study_name)
+         create_htn_t8(spark, drug_concepts, args.study_name)
+         create_htn_t9(spark, drug_concepts, args.study_name)
+
+         # Create final cohort
+         htn_match_cohort = create_htn_match_cohort(spark, args.study_name)
+
+         if args.save_cohort:
+             # Save results
+             if not os.path.exists(args.output_folder):
+                 os.makedirs(args.output_folder)
+
+             output_path = os.path.join(args.output_folder, "htn_match_cohort")
+             htn_match_cohort.write.mode("overwrite").parquet(output_path)
+
+             # Read back and count
+             htn_match_cohort = spark.read.parquet(output_path)
+             final_count = htn_match_cohort.count()
+             print(f"Final cohort count: {final_count}")
+
+         print("Analysis completed successfully!")
+
+     except Exception as e:
+         print(f"Error: {e}")
+         sys.exit(1)
+     finally:
+         spark.stop()
+
+
+ if __name__ == "__main__":
+     main()
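
As a rough usage sketch (not part of the diff), the new script could be invoked with the example concept IDs taken from its own argparse help strings; the folder paths below are placeholders:

    python cehrgpt/analysis/treatment_pathway/treatment_pathway.py \
        --omop-folder /path/to/omop \
        --output-folder /path/to/output \
        --drug-concepts 21600381,21601461,21601560 \
        --target-conditions 316866 \
        --exclusion-conditions 444094 \
        --study-name HTN \
        --save-cohort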