cehrgpt 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33):
  1. cehrgpt/analysis/htn_treatment_pathway.py +546 -0
  2. cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
  3. cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
  4. cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
  5. cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
  6. cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
  7. cehrgpt/data/cehrgpt_data_processor.py +549 -0
  8. cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
  9. cehrgpt/data/hf_cehrgpt_dataset_collator.py +285 -652
  10. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +38 -5
  11. cehrgpt/generation/cehrgpt_conditional_generation.py +2 -0
  12. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +20 -12
  13. cehrgpt/generation/omop_converter_batch.py +11 -4
  14. cehrgpt/gpt_utils.py +73 -3
  15. cehrgpt/models/activations.py +27 -0
  16. cehrgpt/models/config.py +6 -2
  17. cehrgpt/models/gpt2.py +560 -0
  18. cehrgpt/models/hf_cehrgpt.py +183 -460
  19. cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
  20. cehrgpt/omop/ontology.py +154 -0
  21. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -78
  22. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
  23. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -34
  24. cehrgpt/runners/hyperparameter_search_util.py +180 -69
  25. cehrgpt/runners/sample_packing_trainer.py +11 -2
  26. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +8 -2
  27. cehrgpt-0.1.4.dist-info/METADATA +238 -0
  28. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/RECORD +32 -22
  29. cehrgpt-0.1.2.dist-info/METADATA +0 -209
  30. /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
  31. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/WHEEL +0 -0
  32. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/licenses/LICENSE +0 -0
  33. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,94 @@
1
+ #!/usr/bin/env python
2
+ """
3
+ Diabetes Treatment Pathways Study Script.
4
+
5
+ Specialized script for analyzing diabetes treatment pathways using OMOP CDM data.
6
+ Uses the generalized treatment_pathway module with diabetes-specific parameters.
7
+ """
8
+
9
+ import argparse
10
+ import sys
11
+
12
+ from cehrgpt.analysis.treatment_pathway.treatment_pathway import main as treatment_main
13
+
14
# Fixed parameters of the diabetes study. main() forwards these to the
# generalized treatment pathway script as comma-separated CLI values.
# NOTE(review): the integers are presumably OMOP concept IDs — confirm
# against the vocabulary before changing them.
DIABETES_CONFIG = dict(
    study_name="DIABETES",
    drug_concepts=[21600712, 21500148],
    target_conditions=[201820],
    exclusion_conditions=[444094, 35506621],
)
21
+
22
+
23
def parse_diabetes_arguments(argv=None) -> argparse.Namespace:
    """Parse command line arguments specific to the diabetes study.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, which matches the
            previous behavior of this function.

    Returns:
        The parsed namespace with ``omop_folder``, ``output_folder``,
        ``app_name`` and ``spark_master`` attributes.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            an unknown option is supplied.
    """
    parser = argparse.ArgumentParser(
        description="Analyze diabetes treatment pathways using OMOP CDM data"
    )

    parser.add_argument(
        "--omop-folder",
        required=True,
        help="Path to OMOP CDM data folder containing parquet files",
    )

    parser.add_argument(
        "--output-folder", required=True, help="Output folder for results"
    )

    parser.add_argument(
        "--app-name",
        default="Diabetes_Treatment_Pathways",
        help="Spark application name (default: Diabetes_Treatment_Pathways)",
    )

    parser.add_argument(
        "--spark-master",
        default="local[*]",
        help="Spark master URL (default: local[*])",
    )

    # Passing argv through lets callers and tests supply arguments directly
    # instead of mutating the process-wide sys.argv.
    return parser.parse_args(argv)
52
+
53
+
54
def main():
    """Run the diabetes study through the generalized treatment pathway script.

    Parses the diabetes-specific command line, prints the study
    configuration, then temporarily replaces ``sys.argv`` with the argument
    vector built from ``DIABETES_CONFIG`` and the parsed options before
    calling the generalized ``treatment_pathway`` main function. The
    previous ``sys.argv`` is restored afterwards, even if that call raises
    or exits.
    """
    diabetes_args = parse_diabetes_arguments()

    print("Diabetes Treatment Pathways Analysis")
    print("=" * 50)
    print("Study Configuration:")
    print(f" - Study Name: {DIABETES_CONFIG['study_name']}")
    print(f" - Drug Concepts: {DIABETES_CONFIG['drug_concepts']}")
    print(f" - Target Conditions: {DIABETES_CONFIG['target_conditions']}")
    print(f" - Exclusion Conditions: {DIABETES_CONFIG['exclusion_conditions']}")
    print(f" - OMOP Folder: {diabetes_args.omop_folder}")
    print(f" - Output Folder: {diabetes_args.output_folder}")
    print("=" * 50)

    # The generalized script reads its options from sys.argv, so swap in an
    # argument vector for it and put the caller's one back when it is done.
    original_argv = sys.argv
    sys.argv = [
        "treatment_pathway.py",
        "--omop-folder",
        diabetes_args.omop_folder,
        "--output-folder",
        diabetes_args.output_folder,
        "--drug-concepts",
        ",".join(map(str, DIABETES_CONFIG["drug_concepts"])),
        "--target-conditions",
        ",".join(map(str, DIABETES_CONFIG["target_conditions"])),
        "--exclusion-conditions",
        ",".join(map(str, DIABETES_CONFIG["exclusion_conditions"])),
        "--app-name",
        diabetes_args.app_name,
        "--spark-master",
        diabetes_args.spark_master,
        "--study-name",
        DIABETES_CONFIG["study_name"],
    ]
    try:
        # Parse arguments using the generalized script's parser and run it.
        treatment_main()
    finally:
        sys.argv = original_argv


if __name__ == "__main__":
    main()
@@ -0,0 +1,94 @@
1
+ #!/usr/bin/env python
2
+ """
3
+ HTN Treatment Pathways Study Script.
4
+
5
+ Specialized script for analyzing hypertension treatment pathways using OMOP CDM data.
6
+ Uses the generalized treatment_pathway module with HTN-specific parameters.
7
+ """
8
+
9
+ import argparse
10
+ import sys
11
+
12
+ from cehrgpt.analysis.treatment_pathway.treatment_pathway import main as treatment_main
13
+
14
# Fixed parameters of the hypertension (HTN) study. main() forwards these to
# the generalized treatment pathway script as comma-separated CLI values.
# NOTE(review): the integers are presumably OMOP concept IDs — confirm
# against the vocabulary before changing them.
HTN_CONFIG = dict(
    study_name="HTN",
    drug_concepts=[
        21600381,
        21601461,
        21601560,
        21601664,
        21601744,
        21601782,
    ],
    target_conditions=[316866],
    exclusion_conditions=[444094],
)
21
+
22
+
23
def parse_htn_arguments(argv=None) -> argparse.Namespace:
    """Parse command line arguments specific to the HTN study.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, which matches the
            previous behavior of this function.

    Returns:
        The parsed namespace with ``omop_folder``, ``output_folder``,
        ``app_name`` and ``spark_master`` attributes.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            an unknown option is supplied.
    """
    parser = argparse.ArgumentParser(
        description="Analyze hypertension treatment pathways using OMOP CDM data"
    )

    parser.add_argument(
        "--omop-folder",
        required=True,
        help="Path to OMOP CDM data folder containing parquet files",
    )

    parser.add_argument(
        "--output-folder", required=True, help="Output folder for results"
    )

    parser.add_argument(
        "--app-name",
        default="HTN_Treatment_Pathways",
        help="Spark application name (default: HTN_Treatment_Pathways)",
    )

    parser.add_argument(
        "--spark-master",
        default="local[*]",
        help="Spark master URL (default: local[*])",
    )

    # Passing argv through lets callers and tests supply arguments directly
    # instead of mutating the process-wide sys.argv.
    return parser.parse_args(argv)
52
+
53
+
54
def main():
    """Run the HTN study through the generalized treatment pathway script.

    Parses the HTN-specific command line, prints the study configuration,
    then temporarily replaces ``sys.argv`` with the argument vector built
    from ``HTN_CONFIG`` and the parsed options before calling the
    generalized ``treatment_pathway`` main function. The previous
    ``sys.argv`` is restored afterwards, even if that call raises or exits.
    """
    htn_args = parse_htn_arguments()

    print("HTN Treatment Pathways Analysis")
    print("=" * 50)
    print("Study Configuration:")
    print(f" - Study Name: {HTN_CONFIG['study_name']}")
    print(f" - Drug Concepts: {HTN_CONFIG['drug_concepts']}")
    print(f" - Target Conditions: {HTN_CONFIG['target_conditions']}")
    print(f" - Exclusion Conditions: {HTN_CONFIG['exclusion_conditions']}")
    print(f" - OMOP Folder: {htn_args.omop_folder}")
    print(f" - Output Folder: {htn_args.output_folder}")
    print("=" * 50)

    # The generalized script reads its options from sys.argv, so swap in an
    # argument vector for it and put the caller's one back when it is done.
    original_argv = sys.argv
    sys.argv = [
        "treatment_pathway.py",
        "--omop-folder",
        htn_args.omop_folder,
        "--output-folder",
        htn_args.output_folder,
        "--drug-concepts",
        ",".join(map(str, HTN_CONFIG["drug_concepts"])),
        "--target-conditions",
        ",".join(map(str, HTN_CONFIG["target_conditions"])),
        "--exclusion-conditions",
        ",".join(map(str, HTN_CONFIG["exclusion_conditions"])),
        "--app-name",
        htn_args.app_name,
        "--spark-master",
        htn_args.spark_master,
        "--study-name",
        HTN_CONFIG["study_name"],
    ]
    try:
        # Parse arguments using the generalized script's parser and run it.
        treatment_main()
    finally:
        sys.argv = original_argv


if __name__ == "__main__":
    main()