cehrgpt 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. cehrgpt/analysis/htn_treatment_pathway.py +546 -0
  2. cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
  3. cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
  4. cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
  5. cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
  6. cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
  7. cehrgpt/data/cehrgpt_data_processor.py +549 -0
  8. cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
  9. cehrgpt/data/hf_cehrgpt_dataset_collator.py +285 -652
  10. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +38 -5
  11. cehrgpt/generation/cehrgpt_conditional_generation.py +2 -0
  12. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +20 -12
  13. cehrgpt/generation/omop_converter_batch.py +11 -4
  14. cehrgpt/gpt_utils.py +73 -3
  15. cehrgpt/models/activations.py +27 -0
  16. cehrgpt/models/config.py +6 -2
  17. cehrgpt/models/gpt2.py +560 -0
  18. cehrgpt/models/hf_cehrgpt.py +183 -460
  19. cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
  20. cehrgpt/omop/ontology.py +154 -0
  21. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -78
  22. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
  23. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -34
  24. cehrgpt/runners/hyperparameter_search_util.py +180 -69
  25. cehrgpt/runners/sample_packing_trainer.py +11 -2
  26. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +8 -2
  27. cehrgpt-0.1.3.dist-info/METADATA +238 -0
  28. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +32 -22
  29. cehrgpt-0.1.2.dist-info/METADATA +0 -209
  30. /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
  31. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
  32. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
  33. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
cehrgpt/analysis/htn_treatment_pathway.py
@@ -0,0 +1,546 @@
+ #!/usr/bin/env python
+ """
+ Treatment Pathways Study Script.
+
+ Analyzes hypertension treatment pathways using OMOP CDM data.
+ """
+
+ import argparse
+ import os
+ import sys
+
+ from pyspark.sql import SparkSession
+ from pyspark.sql import functions as f
+
+
+ def parse_arguments():
+     """Parse command line arguments."""
+     parser = argparse.ArgumentParser(
+         description="Analyze hypertension treatment pathways using OMOP CDM data"
+     )
+
+     parser.add_argument(
+         "--omop-folder",
+         required=True,
+         help="Path to OMOP CDM data folder containing parquet files",
+     )
+
+     parser.add_argument(
+         "--output-folder", required=True, help="Output folder for results"
+     )
+
+     parser.add_argument(
+         "--app-name",
+         default="HTN_Treatment_Pathways",
+         help="Spark application name (default: HTN_Treatment_Pathways)",
+     )
+
+     parser.add_argument(
+         "--spark-master",
+         default="local[*]",
+         help="Spark master URL (default: local[*])",
+     )
+
+     return parser.parse_args()
+
+
+ def create_drug_concept_mapping(spark):
+     """Create drug concept mapping for HTN medications - exact query from original."""
+     drug_concept = spark.sql(
+         """
+         SELECT DISTINCT
+             ancestor_concept_id,
+             descendant_concept_id
+         FROM
+         (
+             SELECT
+                 ancestor_concept_id,
+                 descendant_concept_id
+             FROM concept_ancestor AS ca
+             WHERE ca.ancestor_concept_id IN (21600381,21601461,21601560,21601664,21601744,21601782)
+         ) a
+         """
+     )
+     drug_concept.cache()
+     drug_concept.createOrReplaceTempView("drug_concept")
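The six ancestor IDs expanded here (21600381, 21601461, 21601560, 21601664, 21601744, 21601782) sit in the ID range OMOP uses for ATC drug-class concepts, so `drug_concept` ends up mapping each antihypertensive class to all of its descendant drug concepts. A minimal sketch of how the registered view might be sanity-checked after `create_drug_concept_mapping(spark)` runs, assuming the `concept_ancestor` temp view from `main` is already in place (this check is illustrative, not part of the package):

```python
# Hypothetical sanity check for the "drug_concept" temp view registered above:
# count how many descendant drug concepts each ancestor class expands to.
spark.sql(
    """
    SELECT ancestor_concept_id, COUNT(DISTINCT descendant_concept_id) AS n_descendants
    FROM drug_concept
    GROUP BY ancestor_concept_id
    ORDER BY ancestor_concept_id
    """
).show()
```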
+
+
+ def create_htn_index_cohort(spark):
+     """Create HTN index cohort - exact query from original."""
+     htn_index_cohort = spark.sql(
+         """
+         SELECT person_id, INDEX_DATE, COHORT_END_DATE, observation_period_start_date, observation_period_end_date
+         FROM (
+             SELECT ot.PERSON_ID, ot.INDEX_DATE, MIN(e.END_DATE) as COHORT_END_DATE, ot.OBSERVATION_PERIOD_START_DATE, ot.OBSERVATION_PERIOD_END_DATE,
+                 ROW_NUMBER() OVER (PARTITION BY ot.PERSON_ID ORDER BY ot.INDEX_DATE) as RowNumber
+             FROM (
+                 SELECT dt.PERSON_ID, dt.DRUG_EXPOSURE_START_DATE as index_date, op.OBSERVATION_PERIOD_START_DATE, op.OBSERVATION_PERIOD_END_DATE
+                 FROM (
+                     SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID, de.DRUG_EXPOSURE_START_DATE
+                     FROM (
+                         SELECT d.PERSON_ID, d.DRUG_CONCEPT_ID, d.DRUG_EXPOSURE_START_DATE,
+                             COALESCE(d.DRUG_EXPOSURE_END_DATE, DATE_ADD(d.DRUG_EXPOSURE_START_DATE, d.DAYS_SUPPLY), DATE_ADD(d.DRUG_EXPOSURE_START_DATE, 1)) as DRUG_EXPOSURE_END_DATE,
+                             ROW_NUMBER() OVER (PARTITION BY d.PERSON_ID ORDER BY d.DRUG_EXPOSURE_START_DATE) as RowNumber
+                         FROM (SELECT * FROM DRUG_EXPOSURE WHERE visit_occurrence_id IS NOT NULL) d
+                         JOIN drug_concept ca
+                             ON d.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN (21600381, 21601461, 21601560, 21601664, 21601744, 21601782)
+                     ) de
+                     JOIN PERSON p ON p.PERSON_ID = de.PERSON_ID
+                     WHERE de.RowNumber = 1
+                 ) dt
+                 JOIN observation_period op
+                     ON op.PERSON_ID = dt.PERSON_ID AND (dt.DRUG_EXPOSURE_START_DATE BETWEEN op.OBSERVATION_PERIOD_START_DATE AND op.OBSERVATION_PERIOD_END_DATE)
+                 WHERE DATE_ADD(op.OBSERVATION_PERIOD_START_DATE, 365) <= dt.DRUG_EXPOSURE_START_DATE
+                     AND DATE_ADD(dt.DRUG_EXPOSURE_START_DATE, 1095) <= op.OBSERVATION_PERIOD_END_DATE
+             ) ot
+             JOIN (
+                 SELECT PERSON_ID, DATE_ADD(EVENT_DATE, -31) as END_DATE
+                 FROM (
+                     SELECT PERSON_ID, EVENT_DATE, EVENT_TYPE, START_ORDINAL,
+                         ROW_NUMBER() OVER (PARTITION BY PERSON_ID ORDER BY EVENT_DATE, EVENT_TYPE) AS EVENT_ORDINAL,
+                         MAX(START_ORDINAL) OVER (PARTITION BY PERSON_ID ORDER BY EVENT_DATE, EVENT_TYPE ROWS UNBOUNDED PRECEDING) as STARTS
+                     FROM (
+                         SELECT PERSON_ID, DRUG_EXPOSURE_START_DATE AS EVENT_DATE, 1 as EVENT_TYPE,
+                             ROW_NUMBER() OVER (PARTITION BY PERSON_ID ORDER BY DRUG_EXPOSURE_START_DATE) as START_ORDINAL
+                         FROM (
+                             SELECT d.PERSON_ID, d.DRUG_CONCEPT_ID, d.DRUG_EXPOSURE_START_DATE,
+                                 COALESCE(d.DRUG_EXPOSURE_END_DATE, DATE_ADD(d.DRUG_EXPOSURE_START_DATE, d.DAYS_SUPPLY), DATE_ADD(d.DRUG_EXPOSURE_START_DATE, 1)) as DRUG_EXPOSURE_END_DATE,
+                                 ROW_NUMBER() OVER (PARTITION BY d.PERSON_ID ORDER BY DRUG_EXPOSURE_START_DATE) as RowNumber
+                             FROM (SELECT * FROM DRUG_EXPOSURE WHERE visit_occurrence_id IS NOT NULL) d
+                             JOIN drug_concept ca
+                                 ON d.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN (21600381, 21601461, 21601560, 21601664, 21601744, 21601782)
+                         ) cteExposureData
+                         UNION ALL
+                         SELECT PERSON_ID, DATE_ADD(DRUG_EXPOSURE_END_DATE, 31), 0 as EVENT_TYPE, NULL
+                         FROM (
+                             SELECT d.PERSON_ID, d.DRUG_CONCEPT_ID, d.DRUG_EXPOSURE_START_DATE,
+                                 COALESCE(d.DRUG_EXPOSURE_END_DATE, DATE_ADD(d.DRUG_EXPOSURE_START_DATE, d.DAYS_SUPPLY), DATE_ADD(d.DRUG_EXPOSURE_START_DATE, 1)) as DRUG_EXPOSURE_END_DATE,
+                                 ROW_NUMBER() OVER (PARTITION BY d.PERSON_ID ORDER BY DRUG_EXPOSURE_START_DATE) as RowNumber
+                             FROM (SELECT * FROM DRUG_EXPOSURE WHERE visit_occurrence_id IS NOT NULL) d
+                             JOIN drug_concept ca
+                                 ON d.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN (21600381, 21601461, 21601560, 21601664, 21601744, 21601782)
+                         ) cteExposureData
+                     ) RAWDATA
+                 ) E
+                 WHERE 2 * E.STARTS - E.EVENT_ORDINAL = 0
+             ) e ON e.PERSON_ID = ot.PERSON_ID AND e.END_DATE >= ot.INDEX_DATE
+             GROUP BY ot.PERSON_ID, ot.INDEX_DATE, ot.OBSERVATION_PERIOD_START_DATE, ot.OBSERVATION_PERIOD_END_DATE
+         ) r
+         WHERE r.RowNumber = 1
+         """
+     )
+
+     htn_index_cohort.cache()
+     htn_index_cohort.createOrReplaceTempView("htn_index_cohort")
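The inner `E` subquery is the standard OHDSI era-building device: each exposure contributes a start event plus an end event padded by 31 days, events are swept in date order (on ties, ends sort before starts because `EVENT_TYPE` 0 < 1), and the predicate `2 * E.STARTS - E.EVENT_ORDINAL = 0` picks out the positions where the number of starts seen equals the number of ends seen, i.e. where every open exposure has closed and a merged treatment era ends (`DATE_ADD(EVENT_DATE, -31)` then strips the pad back off). A pure-Python sketch of the same merge logic on hypothetical `(start, end)` exposure pairs, for intuition only:

```python
from datetime import date, timedelta

def merge_eras(exposures, gap_days=31):
    """Merge (start, end) exposure intervals separated by gaps <= gap_days.

    Mirrors the SQL: pad each end by gap_days, sweep events in date order
    (ends before starts on ties), and close an era whenever the count of
    open intervals returns to zero.
    """
    pad = timedelta(days=gap_days)
    events = []
    for start, end in exposures:
        events.append((start, 1))        # start event (EVENT_TYPE = 1)
        events.append((end + pad, -1))   # padded end event (EVENT_TYPE = 0)
    events.sort(key=lambda e: (e[0], e[1]))  # -1 sorts before 1 on ties
    eras, open_count, era_start = [], 0, None
    for when, kind in events:
        if kind == 1:
            if open_count == 0:
                era_start = when
            open_count += 1
        else:
            open_count -= 1
            if open_count == 0:
                eras.append((era_start, when - pad))  # strip the pad back off
    return eras

print(merge_eras([(date(2020, 1, 1), date(2020, 2, 1)),
                  (date(2020, 2, 20), date(2020, 3, 15)),   # gap <= 31d: merged
                  (date(2020, 6, 1), date(2020, 6, 30))]))  # starts a new era
```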
+
+
+ def create_htn_e0(spark):
+     """Create HTN_E0 - exact query from original."""
+     # Condition exclusions resolve through concept_ancestor; the drug_concept
+     # view only holds descendants of the six drug-class ancestors.
+     HTN_E0 = spark.sql(
+         """
+         SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         FROM htn_index_cohort ip
+         LEFT JOIN (
+             SELECT co.PERSON_ID, co.CONDITION_CONCEPT_ID
+             FROM condition_occurrence co
+             JOIN htn_index_cohort ip ON co.PERSON_ID = ip.PERSON_ID
+             JOIN concept_ancestor ca ON co.CONDITION_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN (444094)
+             WHERE co.CONDITION_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+         ) dt ON dt.PERSON_ID = ip.PERSON_ID
+         GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         HAVING COUNT(dt.CONDITION_CONCEPT_ID) <= 0
+         """
+     )
+
+     HTN_E0.cache()
+     HTN_E0.createOrReplaceTempView("HTN_E0")
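`HTN_E0` and the `HTN_T0` criterion below both use the `LEFT JOIN ... GROUP BY ... HAVING COUNT(...) <= 0` idiom, which is simply an anti-join: keep cohort rows with no matching record in the exclusion subquery. A sketch of the same criterion written with the DataFrame API's `left_anti` join, assuming the same Spark session and the temp views registered above (illustrative, not part of the package):

```python
# Hypothetical anti-join equivalent of HTN_E0, using the views registered above.
cohort = spark.table("htn_index_cohort")
excluded = spark.sql(
    """
    SELECT DISTINCT co.PERSON_ID
    FROM condition_occurrence co
    JOIN htn_index_cohort ip ON co.PERSON_ID = ip.PERSON_ID
    JOIN concept_ancestor ca
        ON co.CONDITION_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID
        AND ca.ANCESTOR_CONCEPT_ID IN (444094)
    WHERE co.CONDITION_START_DATE
        BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
    """
)
# left_anti keeps only cohort rows with no match in `excluded`
htn_e0 = cohort.join(excluded, on="PERSON_ID", how="left_anti").select(
    "PERSON_ID", "INDEX_DATE", "COHORT_END_DATE"
)
```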
+
+
+ def create_htn_t0(spark):
+     """Create HTN_T0 - exact query from original."""
+     HTN_T0 = spark.sql(
+         """
+         SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         FROM htn_index_cohort ip
+         LEFT JOIN (
+             SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
+             FROM (SELECT * FROM DRUG_EXPOSURE WHERE visit_occurrence_id IS NOT NULL) de
+             JOIN htn_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
+             JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN (21600381, 21601461, 21601560, 21601664, 21601744, 21601782)
+             WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+                 AND de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND DATE_ADD(ip.INDEX_DATE, -1)
+         ) dt ON dt.PERSON_ID = ip.PERSON_ID
+         GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         HAVING COUNT(dt.DRUG_CONCEPT_ID) <= 0
+         """
+     )
+
+     HTN_T0.createOrReplaceTempView("HTN_T0")
+
+
+ def create_htn_t1(spark):
+     """Create HTN_T1 - exact query from original."""
+     HTN_T1 = spark.sql(
+         """
+         SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         FROM htn_index_cohort ip
+         LEFT JOIN (
+             SELECT ce.PERSON_ID, ce.CONDITION_CONCEPT_ID
+             FROM CONDITION_ERA ce
+             JOIN htn_index_cohort ip ON ce.PERSON_ID = ip.PERSON_ID
+             JOIN concept_ancestor ca ON ce.CONDITION_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN (316866)
+             WHERE ce.CONDITION_ERA_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+         ) ct ON ct.PERSON_ID = ip.PERSON_ID
+         GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         HAVING COUNT(ct.CONDITION_CONCEPT_ID) >= 1
+         """
+     )
+
+     HTN_T1.createOrReplaceTempView("HTN_T1")
+
+
+ def create_htn_t2(spark):
+     """Create HTN_T2 - exact query from original."""
+     HTN_T2 = spark.sql(
+         """
+         SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         FROM htn_index_cohort ip
+         LEFT JOIN (
+             SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
+             FROM (
+                 SELECT *
+                 FROM DRUG_EXPOSURE
+                 WHERE visit_occurrence_id IS NOT NULL
+             ) de
+             JOIN htn_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
+             JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN (21600381, 21601461, 21601560, 21601664, 21601744, 21601782)
+             WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+                 AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 121) AND DATE_ADD(ip.INDEX_DATE, 240)
+         ) dt ON dt.PERSON_ID = ip.PERSON_ID
+         GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
+         """
+     )
+
+     HTN_T2.createOrReplaceTempView("HTN_T2")
+
+
+ def create_htn_t3(spark):
+     """Create HTN_T3 - exact query from original."""
+     HTN_T3 = spark.sql(
+         """
+         SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         FROM htn_index_cohort ip
+         LEFT JOIN (
+             SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
+             FROM (
+                 SELECT *
+                 FROM DRUG_EXPOSURE
+                 WHERE visit_occurrence_id IS NOT NULL
+             ) de
+             JOIN htn_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
+             JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN (21600381, 21601461, 21601560, 21601664, 21601744, 21601782)
+             WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+                 AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 241) AND DATE_ADD(ip.INDEX_DATE, 360)
+         ) dt ON dt.PERSON_ID = ip.PERSON_ID
+         GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
+         """
+     )
+
+     HTN_T3.createOrReplaceTempView("HTN_T3")
+
+
+ def create_htn_t4(spark):
+     """Create HTN_T4 - exact query from original."""
+     HTN_T4 = spark.sql(
+         """
+         SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         FROM htn_index_cohort ip
+         LEFT JOIN (
+             SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
+             FROM (
+                 SELECT *
+                 FROM DRUG_EXPOSURE
+                 WHERE visit_occurrence_id IS NOT NULL
+             ) de
+             JOIN htn_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
+             JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN (21600381, 21601461, 21601560, 21601664, 21601744, 21601782)
+             WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+                 AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 361) AND DATE_ADD(ip.INDEX_DATE, 480)
+         ) dt ON dt.PERSON_ID = ip.PERSON_ID
+         GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
+         """
+     )
+
+     HTN_T4.createOrReplaceTempView("HTN_T4")
+
+
+ def create_htn_t5(spark):
+     """Create HTN_T5 - exact query from original."""
+     HTN_T5 = spark.sql(
+         """
+         SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         FROM htn_index_cohort ip
+         LEFT JOIN (
+             SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
+             FROM (
+                 SELECT *
+                 FROM DRUG_EXPOSURE
+                 WHERE visit_occurrence_id IS NOT NULL
+             ) de
+             JOIN htn_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
+             JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN (21600381, 21601461, 21601560, 21601664, 21601744, 21601782)
+             WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+                 AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 481) AND DATE_ADD(ip.INDEX_DATE, 600)
+         ) dt ON dt.PERSON_ID = ip.PERSON_ID
+         GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
+         """
+     )
+
+     HTN_T5.createOrReplaceTempView("HTN_T5")
+
+
+ def create_htn_t6(spark):
+     """Create HTN_T6 - exact query from original."""
+     HTN_T6 = spark.sql(
+         """
+         SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         FROM htn_index_cohort ip
+         LEFT JOIN (
+             SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
+             FROM (
+                 SELECT *
+                 FROM DRUG_EXPOSURE
+                 WHERE visit_occurrence_id IS NOT NULL
+             ) de
+             JOIN htn_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
+             JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN (21600381, 21601461, 21601560, 21601664, 21601744, 21601782)
+             WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+                 AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 601) AND DATE_ADD(ip.INDEX_DATE, 720)
+         ) dt ON dt.PERSON_ID = ip.PERSON_ID
+         GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
+         """
+     )
+
+     HTN_T6.createOrReplaceTempView("HTN_T6")
+
+
+ def create_htn_t7(spark):
+     """Create HTN_T7 - exact query from original."""
+     HTN_T7 = spark.sql(
+         """
+         SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         FROM htn_index_cohort ip
+         LEFT JOIN (
+             SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
+             FROM (
+                 SELECT *
+                 FROM DRUG_EXPOSURE
+                 WHERE visit_occurrence_id IS NOT NULL
+             ) de
+             JOIN htn_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
+             JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN (21600381, 21601461, 21601560, 21601664, 21601744, 21601782)
+             WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+                 AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 721) AND DATE_ADD(ip.INDEX_DATE, 840)
+         ) dt ON dt.PERSON_ID = ip.PERSON_ID
+         GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
+         """
+     )
+
+     HTN_T7.createOrReplaceTempView("HTN_T7")
+
+
+ def create_htn_t8(spark):
+     """Create HTN_T8 - exact query from original."""
+     HTN_T8 = spark.sql(
+         """
+         SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         FROM htn_index_cohort ip
+         LEFT JOIN (
+             SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
+             FROM (
+                 SELECT *
+                 FROM DRUG_EXPOSURE
+                 WHERE visit_occurrence_id IS NOT NULL
+             ) de
+             JOIN htn_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
+             JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN (21600381, 21601461, 21601560, 21601664, 21601744, 21601782)
+             WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+                 AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 841) AND DATE_ADD(ip.INDEX_DATE, 960)
+         ) dt ON dt.PERSON_ID = ip.PERSON_ID
+         GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
+         """
+     )
+
+     HTN_T8.createOrReplaceTempView("HTN_T8")
+
+
+ def create_htn_t9(spark):
+     """Create HTN_T9 - exact query from original."""
+     HTN_T9 = spark.sql(
+         """
+         SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         FROM htn_index_cohort ip
+         LEFT JOIN (
+             SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
+             FROM (
+                 SELECT *
+                 FROM DRUG_EXPOSURE
+                 WHERE visit_occurrence_id IS NOT NULL
+             ) de
+             JOIN htn_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
+             JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID AND ca.ANCESTOR_CONCEPT_ID IN (21600381, 21601461, 21601560, 21601664, 21601744, 21601782)
+             WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
+                 AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, 961) AND DATE_ADD(ip.INDEX_DATE, 1080)
+         ) dt ON dt.PERSON_ID = ip.PERSON_ID
+         GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
+         HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
+         """
+     )
+
+     HTN_T9.createOrReplaceTempView("HTN_T9")
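HTN_T2 through HTN_T9 are identical except for the post-index day window, which slides in 120-day steps from day 121 out to day 1080. A sketch of how the eight functions could be generated from a single parameterized helper (the `create_htn_t_window` name is hypothetical, not part of the package; the redundant `ANCESTOR_CONCEPT_ID IN (...)` filter is dropped because the `drug_concept` view already contains only those ancestors):

```python
# Hypothetical helper: one query template for the eight 120-day windows.
def create_htn_t_window(spark, view_name, start_day, end_day):
    cohort = spark.sql(
        f"""
        SELECT ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
        FROM htn_index_cohort ip
        LEFT JOIN (
            SELECT de.PERSON_ID, de.DRUG_CONCEPT_ID
            FROM (SELECT * FROM DRUG_EXPOSURE WHERE visit_occurrence_id IS NOT NULL) de
            JOIN htn_index_cohort ip ON de.PERSON_ID = ip.PERSON_ID
            JOIN drug_concept ca ON de.DRUG_CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID
            WHERE de.DRUG_EXPOSURE_START_DATE BETWEEN ip.OBSERVATION_PERIOD_START_DATE AND ip.OBSERVATION_PERIOD_END_DATE
                AND de.DRUG_EXPOSURE_START_DATE BETWEEN DATE_ADD(ip.INDEX_DATE, {start_day}) AND DATE_ADD(ip.INDEX_DATE, {end_day})
        ) dt ON dt.PERSON_ID = ip.PERSON_ID
        GROUP BY ip.PERSON_ID, ip.INDEX_DATE, ip.COHORT_END_DATE
        HAVING COUNT(dt.DRUG_CONCEPT_ID) >= 1
        """
    )
    cohort.createOrReplaceTempView(view_name)

# HTN_T2 .. HTN_T9: windows [121, 240], [241, 360], ..., [961, 1080]
for i in range(2, 10):
    start_day = 121 + (i - 2) * 120
    create_htn_t_window(spark, f"HTN_T{i}", start_day, start_day + 119)
```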
+
+
+ def create_htn_match_cohort(spark):
+     """Create HTN_MatchCohort - exact query from original."""
+     HTN_MatchCohort = spark.sql(
+         """
+         SELECT c.person_id, c.index_date, c.cohort_end_date, c.observation_period_start_date, c.observation_period_end_date
+         FROM HTN_Index_Cohort C
+         INNER JOIN (
+             SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID
+             FROM (
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM HTN_E0
+                 INTERSECT
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM HTN_T0
+                 INTERSECT
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM HTN_T1
+                 INTERSECT
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM HTN_T2
+                 INTERSECT
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM HTN_T3
+                 INTERSECT
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM HTN_T4
+                 INTERSECT
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM HTN_T5
+                 INTERSECT
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM HTN_T6
+                 INTERSECT
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM HTN_T7
+                 INTERSECT
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM HTN_T8
+                 INTERSECT
+                 SELECT INDEX_DATE, COHORT_END_DATE, PERSON_ID FROM HTN_T9
+             ) TopGroup
+         ) I
+             ON C.PERSON_ID = I.PERSON_ID
+                 AND c.index_date = i.index_date
+         """
+     )
+
+     return HTN_MatchCohort
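The INTERSECT ladder keeps only persons (with matching index dates) who satisfy the exclusion criterion and all ten T-criteria at once. The same intersection can be folded programmatically over the registered views; a sketch with the DataFrame API, assuming the same session and views (illustrative only):

```python
from functools import reduce

# Hypothetical DataFrame-API fold over the criterion views registered above.
criterion_views = ["HTN_E0"] + [f"HTN_T{i}" for i in range(10)]
keys = [
    spark.table(v).select("INDEX_DATE", "COHORT_END_DATE", "PERSON_ID")
    for v in criterion_views
]
matched_keys = reduce(lambda a, b: a.intersect(b), keys)
htn_match_cohort = spark.table("htn_index_cohort").join(
    matched_keys.select("PERSON_ID", "INDEX_DATE"), ["PERSON_ID", "INDEX_DATE"]
)
```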
+
+
+ def main():
+     args = parse_arguments()
+
+     # Initialize Spark using the CLI-provided app name and master URL
+     spark = (
+         SparkSession.builder.appName(args.app_name)
+         .master(args.spark_master)
+         .getOrCreate()
+     )
+
+     try:
+         # Load the source OMOP tables
+         person = spark.read.parquet(os.path.join(args.omop_folder, "person"))
+         visit_occurrence = spark.read.parquet(
+             os.path.join(args.omop_folder, "visit_occurrence")
+         )
+         condition_occurrence = spark.read.parquet(
+             os.path.join(args.omop_folder, "condition_occurrence")
+         )
+         procedure_occurrence = spark.read.parquet(
+             os.path.join(args.omop_folder, "procedure_occurrence")
+         )
+         drug_exposure = spark.read.parquet(
+             os.path.join(args.omop_folder, "drug_exposure")
+         )
+         observation_period = spark.read.parquet(
+             os.path.join(args.omop_folder, "observation_period")
+         )
+         condition_era = spark.read.parquet(
+             os.path.join(args.omop_folder, "condition_era")
+         )
+
+         print(f"person: {person.select('person_id').distinct().count()}")
+         print(f"visit_occurrence: {visit_occurrence.count()}")
+         print(f"condition_occurrence: {condition_occurrence.count()}")
+         print(f"procedure_occurrence: {procedure_occurrence.count()}")
+         print(f"drug_exposure: {drug_exposure.count()}")
+         print(f"observation_period: {observation_period.count()}")
+
+         concept = spark.read.parquet(os.path.join(args.omop_folder, "concept"))
+         concept_ancestor = spark.read.parquet(
+             os.path.join(args.omop_folder, "concept_ancestor")
+         )
+
+         # Create temporary views
+         person.createOrReplaceTempView("person")
+         visit_occurrence.createOrReplaceTempView("visit_occurrence")
+         condition_occurrence.createOrReplaceTempView("condition_occurrence")
+         procedure_occurrence.createOrReplaceTempView("procedure_occurrence")
+         drug_exposure.createOrReplaceTempView("drug_exposure")
+         observation_period.createOrReplaceTempView("observation_period")
+         condition_era.createOrReplaceTempView("condition_era")
+
+         concept_ancestor.createOrReplaceTempView("concept_ancestor")
+         concept.createOrReplaceTempView("concept")
+
+         # Create drug concept mapping
+         create_drug_concept_mapping(spark)
+
+         # Create HTN index cohort
+         create_htn_index_cohort(spark)
+
+         # Create all cohorts in sequence
+         create_htn_e0(spark)
+         create_htn_t0(spark)
+         create_htn_t1(spark)
+         create_htn_t2(spark)
+         create_htn_t3(spark)
+         create_htn_t4(spark)
+         create_htn_t5(spark)
+         create_htn_t6(spark)
+         create_htn_t7(spark)
+         create_htn_t8(spark)
+         create_htn_t9(spark)
+
+         # Create final cohort
+         htn_match_cohort = create_htn_match_cohort(spark)
+
+         # Save results
+         os.makedirs(args.output_folder, exist_ok=True)
+         output_path = os.path.join(args.output_folder, "htn_match_cohort")
+         htn_match_cohort.write.mode("overwrite").parquet(output_path)
+
+         # Read back and count
+         final_count = spark.read.parquet(output_path).count()
+         print(f"Final cohort count: {final_count}")
+
+         print("Analysis completed successfully!")
+
+     except Exception as e:
+         print(f"Error: {e}")
+         sys.exit(1)
+     finally:
+         spark.stop()
+
+
+ if __name__ == "__main__":
+     main()
cehrgpt/analysis/treatment_pathway/__init__.py: file without content changes (new empty __init__.py)
cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py
@@ -0,0 +1,94 @@
+ #!/usr/bin/env python
+ """
+ Depression Treatment Pathways Study Script.
+
+ Specialized script for analyzing depression treatment pathways using OMOP CDM data.
+ Uses the generalized treatment_pathway module with depression-specific parameters.
+ """
+
+ import argparse
+ import sys
+
+ from cehrgpt.analysis.treatment_pathway.treatment_pathway import main as treatment_main
+
+ # Depression-specific configuration
+ DEPRESSION_CONFIG = {
+     "study_name": "DEPRESSION",
+     "drug_concepts": [21604686, 21500526],
+     "target_conditions": [440383],
+     "exclusion_conditions": [444094, 432876, 435783],
+ }
+
+
+ def parse_depression_arguments():
+     """Parse command line arguments specific to the depression study."""
+     parser = argparse.ArgumentParser(
+         description="Analyze depression treatment pathways using OMOP CDM data"
+     )
+
+     parser.add_argument(
+         "--omop-folder",
+         required=True,
+         help="Path to OMOP CDM data folder containing parquet files",
+     )
+
+     parser.add_argument(
+         "--output-folder", required=True, help="Output folder for results"
+     )
+
+     parser.add_argument(
+         "--app-name",
+         default="Depression_Treatment_Pathways",
+         help="Spark application name (default: Depression_Treatment_Pathways)",
+     )
+
+     parser.add_argument(
+         "--spark-master",
+         default="local[*]",
+         help="Spark master URL (default: local[*])",
+     )
+
+     return parser.parse_args()
+
+
+ def main():
+     """Overwrite sys.argv and delegate to the generalized treatment pathway main."""
+     depression_args = parse_depression_arguments()
+
+     print("Depression Treatment Pathways Analysis")
+     print("=" * 50)
+     print("Study Configuration:")
+     print(f"  - Study Name: {DEPRESSION_CONFIG['study_name']}")
+     print(f"  - Drug Concepts: {DEPRESSION_CONFIG['drug_concepts']}")
+     print(f"  - Target Conditions: {DEPRESSION_CONFIG['target_conditions']}")
+     print(f"  - Exclusion Conditions: {DEPRESSION_CONFIG['exclusion_conditions']}")
+     print(f"  - OMOP Folder: {depression_args.omop_folder}")
+     print(f"  - Output Folder: {depression_args.output_folder}")
+     print("=" * 50)
+
+     # Overwrite sys.argv with the generalized treatment pathway arguments
+     sys.argv = [
+         "treatment_pathway.py",
+         "--omop-folder",
+         depression_args.omop_folder,
+         "--output-folder",
+         depression_args.output_folder,
+         "--drug-concepts",
+         ",".join(map(str, DEPRESSION_CONFIG["drug_concepts"])),
+         "--target-conditions",
+         ",".join(map(str, DEPRESSION_CONFIG["target_conditions"])),
+         "--exclusion-conditions",
+         ",".join(map(str, DEPRESSION_CONFIG["exclusion_conditions"])),
+         "--app-name",
+         depression_args.app_name,
+         "--spark-master",
+         depression_args.spark_master,
+         "--study-name",
+         DEPRESSION_CONFIG["study_name"],
+     ]
+     # Parse arguments using the generalized script's parser and call its main function
+     treatment_main()
+
+
+ if __name__ == "__main__":
+     main()
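The sys.argv override is a simple way to reuse an argparse-based entry point without refactoring it to accept parameters. A self-contained sketch of the pattern (the `inner_main` and `wrapper` names are hypothetical, not from the package; unlike the script above, the sketch also restores the original argv afterwards):

```python
import argparse
import sys

def inner_main():
    # An existing entry point that insists on parsing sys.argv itself.
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", required=True)
    args = parser.parse_args()
    print(f"hello, {args.name}")

def wrapper():
    # Replace argv (position 0 holds a fake program name), then delegate.
    old_argv = sys.argv
    try:
        sys.argv = ["inner_main.py", "--name", "depression-study"]
        inner_main()
    finally:
        sys.argv = old_argv  # restore, so later code sees the real argv

if __name__ == "__main__":
    wrapper()
```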