cehrgpt-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. __init__.py +0 -0
  2. cehrgpt/__init__.py +0 -0
  3. cehrgpt/analysis/__init__.py +0 -0
  4. cehrgpt/analysis/privacy/__init__.py +0 -0
  5. cehrgpt/analysis/privacy/attribute_inference.py +275 -0
  6. cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
  7. cehrgpt/analysis/privacy/member_inference.py +172 -0
  8. cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
  9. cehrgpt/analysis/privacy/reid_inference.py +407 -0
  10. cehrgpt/analysis/privacy/utils.py +255 -0
  11. cehrgpt/cehrgpt_args.py +142 -0
  12. cehrgpt/data/__init__.py +0 -0
  13. cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
  14. cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
  15. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
  16. cehrgpt/generation/__init__.py +0 -0
  17. cehrgpt/generation/chatgpt_generation.py +106 -0
  18. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
  19. cehrgpt/generation/omop_converter_batch.py +644 -0
  20. cehrgpt/generation/omop_entity.py +515 -0
  21. cehrgpt/gpt_utils.py +331 -0
  22. cehrgpt/models/__init__.py +0 -0
  23. cehrgpt/models/config.py +205 -0
  24. cehrgpt/models/hf_cehrgpt.py +1817 -0
  25. cehrgpt/models/hf_modeling_outputs.py +158 -0
  26. cehrgpt/models/pretrained_embeddings.py +82 -0
  27. cehrgpt/models/special_tokens.py +30 -0
  28. cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
  29. cehrgpt/omop/__init__.py +0 -0
  30. cehrgpt/omop/condition_era.py +20 -0
  31. cehrgpt/omop/observation_period.py +43 -0
  32. cehrgpt/omop/omop_argparse.py +38 -0
  33. cehrgpt/omop/omop_table_builder.py +86 -0
  34. cehrgpt/omop/queries/__init__.py +0 -0
  35. cehrgpt/omop/queries/condition_era.py +86 -0
  36. cehrgpt/omop/queries/observation_period.py +135 -0
  37. cehrgpt/omop/sample_omop_tables.py +71 -0
  38. cehrgpt/runners/__init__.py +0 -0
  39. cehrgpt/runners/gpt_runner_util.py +99 -0
  40. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
  41. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
  42. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
  43. cehrgpt/runners/hyperparameter_search_util.py +223 -0
  44. cehrgpt/time_to_event/__init__.py +0 -0
  45. cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
  46. cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
  47. cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
  48. cehrgpt/time_to_event/time_to_event_model.py +226 -0
  49. cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
  50. cehrgpt/time_to_event/time_to_event_utils.py +55 -0
  51. cehrgpt/tools/__init__.py +0 -0
  52. cehrgpt/tools/ehrshot_benchmark.py +74 -0
  53. cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
  54. cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
  55. cehrgpt/tools/upload_omop_tables.py +108 -0
  56. cehrgpt-0.0.1.dist-info/LICENSE +21 -0
  57. cehrgpt-0.0.1.dist-info/METADATA +66 -0
  58. cehrgpt-0.0.1.dist-info/RECORD +60 -0
  59. cehrgpt-0.0.1.dist-info/WHEEL +5 -0
  60. cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
cehrgpt/omop/__init__.py (file without changes)
cehrgpt/omop/condition_era.py
@@ -0,0 +1,20 @@
+ from cehrgpt.omop.omop_argparse import create_omop_argparse
+ from cehrgpt.omop.omop_table_builder import OmopTableBuilder
+ from cehrgpt.omop.queries.condition_era import CONDITION_ERA_QUERY
+
+ CONDITION_ERA = "condition_era"
+
+
+ def main(args):
+     OmopTableBuilder.create_omop_query_builder(
+         input_folder=args.input_folder,
+         output_folder=args.output_folder,
+         continue_job=args.continue_job,
+         table_name=CONDITION_ERA,
+         query_template=CONDITION_ERA_QUERY,
+         dependency_list=["condition_occurrence"],
+     ).build()
+
+
+ if __name__ == "__main__":
+     main(create_omop_argparse().parse_args())
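condition_era.py is a thin driver: it hands CONDITION_ERA_QUERY to OmopTableBuilder, which runs it over the condition_occurrence parquet files and writes the resulting condition_era table under the output folder. A minimal sketch of driving it programmatically (the paths are illustrative; note that the shared parser in omop_argparse.py also marks --domain_table_list as required on the command line, even though this script never reads it):

    import argparse

    from cehrgpt.omop.condition_era import main

    # Hypothetical paths; continue_job=False forces a rebuild even if the
    # condition_era output already exists.
    args = argparse.Namespace(
        input_folder="/data/omop",
        output_folder="/data/omop_derived",
        continue_job=False,
    )
    main(args)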
cehrgpt/omop/observation_period.py
@@ -0,0 +1,43 @@
+ import glob
+ import os.path
+
+ from cehrgpt.omop.omop_argparse import create_omop_argparse
+ from cehrgpt.omop.omop_table_builder import OmopTableBuilder
+ from cehrgpt.omop.queries.observation_period import (
+     OBSERVATION_PERIOD_QUERY,
+     OBSERVATION_PERIOD_WITH_MEASUREMENT_QUERY,
+ )
+
+ OBSERVATION_PERIOD = "observation_period"
+
+
+ def main(args):
+     include_measurement = (
+         len(glob.glob(os.path.join(args.input_folder, "measurement", "*.parquet"))) > 0
+     )
+     dependency_list = [
+         "person",
+         "visit_occurrence",
+         "condition_occurrence",
+         "procedure_occurrence",
+         "drug_exposure",
+     ]
+     if include_measurement:
+         dependency_list.append("measurement")
+
+     OmopTableBuilder.create_omop_query_builder(
+         input_folder=args.input_folder,
+         output_folder=args.output_folder,
+         continue_job=args.continue_job,
+         table_name=OBSERVATION_PERIOD,
+         query_template=(
+             OBSERVATION_PERIOD_WITH_MEASUREMENT_QUERY
+             if include_measurement
+             else OBSERVATION_PERIOD_QUERY
+         ),
+         dependency_list=dependency_list,
+     ).build()
+
+
+ if __name__ == "__main__":
+     main(create_omop_argparse().parse_args())
cehrgpt/omop/omop_argparse.py
@@ -0,0 +1,38 @@
+ import argparse
+
+ from cehrbert_data.utils.spark_utils import validate_table_names
+
+
+ def create_omop_argparse():
+     parser = argparse.ArgumentParser(
+         description="Spark application for creating OMOP table"
+     )
+     parser.add_argument(
+         "--input_folder",
+         dest="input_folder",
+         action="store",
+         help="The path for your input_folder where the raw data is",
+         required=True,
+     )
+     parser.add_argument(
+         "--output_folder",
+         dest="output_folder",
+         action="store",
+         help="The path for your output_folder",
+         required=True,
+     )
+     parser.add_argument(
+         "--continue_job",
+         dest="continue_job",
+         action="store_true",
+     )
+     parser.add_argument(
+         "--domain_table_list",
+         dest="domain_table_list",
+         nargs="+",
+         action="store",
+         help="The list of domain tables you want to download",
+         type=validate_table_names,
+         required=True,
+     )
+     return parser
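A sketch of how the resulting parser behaves (paths and table names are illustrative, and this assumes validate_table_names returns the validated name); because nargs="+" is set, argparse applies type= to each value of --domain_table_list individually:

    from cehrgpt.omop.omop_argparse import create_omop_argparse

    parser = create_omop_argparse()
    args = parser.parse_args(
        [
            "--input_folder", "/data/omop",
            "--output_folder", "/data/output",
            "--domain_table_list", "person", "visit_occurrence",
        ]
    )
    print(args.domain_table_list)  # the validated table names, as a list
    print(args.continue_job)       # False unless --continue_job was passed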
cehrgpt/omop/omop_table_builder.py
@@ -0,0 +1,86 @@
+ from datetime import date
+ from typing import Any, Dict, List, Optional
+
+ import cehrbert_data.cohorts.spark_app_base
+ from cehrbert_data.cohorts.query_builder import QueryBuilder, QuerySpec
+ from cehrbert_data.cohorts.spark_app_base import BaseCohortBuilder
+ from cehrbert_data.utils.spark_utils import preprocess_domain_table
+
+
+ def custom_instantiate_dependencies(spark, input_folder, dependency_list):
+     dependency_dict = dict()
+     for domain_table_name in dependency_list:
+         table = preprocess_domain_table(spark, input_folder, domain_table_name)
+         table.createOrReplaceTempView(domain_table_name)
+         dependency_dict[domain_table_name] = table
+     return dependency_dict
+
+
+ # Monkeypatch the function
+ cehrbert_data.cohorts.spark_app_base.instantiate_dependencies = (
+     custom_instantiate_dependencies
+ )
+
+
+ class OmopTableBuilder(BaseCohortBuilder):
+     cohort_required_columns = []
+
+     def __init__(
+         self,
+         query_builder: QueryBuilder,
+         input_folder: str,
+         output_folder: str,
+         continue_job: bool = False,
+     ):
+         super().__init__(
+             query_builder=query_builder,
+             input_folder=input_folder,
+             output_folder=output_folder,
+             date_lower_bound="1900-01-01",
+             date_upper_bound=date.today().strftime("%Y-%m-%d"),
+             age_lower_bound=0,
+             age_upper_bound=200,
+             prior_observation_period=0,
+             post_observation_period=0,
+             continue_job=continue_job,
+         )
+
+     def build(self):
+         # Check whether the cohort has been generated
+         if self._continue_job and self.cohort_exists():
+             return self
+         cohort = self.create_cohort()
+         cohort.write.mode("overwrite").parquet(self._output_data_folder)
+         return self
+
+     @staticmethod
+     def create_omop_query_builder(
+         input_folder: str,
+         output_folder: str,
+         table_name: str,
+         query_template: str,
+         dependency_list: List[str],
+         query_parameters: Optional[Dict[str, Any]] = None,
+         continue_job: bool = False,
+     ):
+
+         if query_parameters is None:
+             query_parameters = dict()
+         query = QuerySpec(
+             table_name=table_name,
+             query_template=query_template,
+             parameters=query_parameters,
+         )
+         table_query_builder = QueryBuilder(
+             cohort_name=table_name,
+             dependency_list=dependency_list,
+             query=query,
+             ancestor_table_specs=[],
+         )
+
+         return OmopTableBuilder(
+             query_builder=table_query_builder,
+             input_folder=input_folder,
+             output_folder=output_folder,
+             continue_job=continue_job,
+         )
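Because create_omop_query_builder only needs a table name, a SQL template, and the list of tables the query reads (which are registered as Spark temp views by the monkeypatched dependency loader), the same machinery can materialize other derived tables. A hedged sketch under that assumption; the death_summary table, query, and paths below are hypothetical, not part of the package:

    from cehrgpt.omop.omop_table_builder import OmopTableBuilder

    # Hypothetical derived table: earliest recorded death date per person.
    DEATH_SUMMARY_QUERY = """
    SELECT person_id, MIN(death_date) AS death_date
    FROM death
    GROUP BY person_id
    """

    OmopTableBuilder.create_omop_query_builder(
        input_folder="/data/omop",
        output_folder="/data/omop_derived",
        table_name="death_summary",
        query_template=DEATH_SUMMARY_QUERY,
        dependency_list=["death"],
    ).build()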
cehrgpt/omop/queries/__init__.py (file without changes)
cehrgpt/omop/queries/condition_era.py
@@ -0,0 +1,86 @@
+ CONDITION_ERA_QUERY = """
+ -- Common Table Expression for cteConditionTarget
+ WITH cteConditionTarget AS (
+     SELECT
+         co.CONDITION_OCCURRENCE_ID,
+         co.PERSON_ID,
+         co.CONDITION_CONCEPT_ID,
+         co.CONDITION_TYPE_CONCEPT_ID,
+         co.CONDITION_START_DATE,
+         COALESCE(co.CONDITION_END_DATE, DATE_ADD(co.CONDITION_START_DATE, 1)) AS CONDITION_END_DATE
+     FROM CONDITION_OCCURRENCE co
+ ),
+
+ -- Common Table Expression for cteEndDates
+ cteEndDates AS (
+     SELECT
+         PERSON_ID,
+         CONDITION_CONCEPT_ID,
+         DATE_ADD(EVENT_DATE, -30) AS END_DATE
+     FROM (
+         SELECT
+             PERSON_ID,
+             CONDITION_CONCEPT_ID,
+             EVENT_DATE,
+             EVENT_TYPE,
+             MAX(START_ORDINAL) OVER (PARTITION BY PERSON_ID, CONDITION_CONCEPT_ID ORDER BY EVENT_DATE, EVENT_TYPE ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS START_ORDINAL,
+             ROW_NUMBER() OVER (PARTITION BY PERSON_ID, CONDITION_CONCEPT_ID ORDER BY EVENT_DATE, EVENT_TYPE) AS OVERALL_ORD
+         FROM (
+             -- Select the start dates, assigning a row number to each
+             SELECT
+                 PERSON_ID,
+                 CONDITION_CONCEPT_ID,
+                 CONDITION_START_DATE AS EVENT_DATE,
+                 -1 AS EVENT_TYPE,
+                 ROW_NUMBER() OVER (PARTITION BY PERSON_ID, CONDITION_CONCEPT_ID ORDER BY CONDITION_START_DATE) AS START_ORDINAL
+             FROM cteConditionTarget
+
+             UNION ALL
+
+             -- Pad the end dates by 30 to allow a grace period for overlapping ranges
+             SELECT
+                 PERSON_ID,
+                 CONDITION_CONCEPT_ID,
+                 DATE_ADD(CONDITION_END_DATE, 30) AS EVENT_DATE,
+                 1 AS EVENT_TYPE,
+                 NULL
+             FROM cteConditionTarget
+         ) RAWDATA
+     ) E
+     WHERE (2 * E.START_ORDINAL) - E.OVERALL_ORD = 0
+ ),
+
+ -- Common Table Expression for cteConditionEnds
+ cteConditionEnds AS (
+     SELECT
+         c.PERSON_ID,
+         c.CONDITION_CONCEPT_ID,
+         c.CONDITION_TYPE_CONCEPT_ID,
+         c.CONDITION_START_DATE,
+         MIN(e.END_DATE) AS ERA_END_DATE
+     FROM cteConditionTarget c
+     JOIN cteEndDates e ON c.PERSON_ID = e.PERSON_ID AND c.CONDITION_CONCEPT_ID = e.CONDITION_CONCEPT_ID AND e.END_DATE >= c.CONDITION_START_DATE
+     GROUP BY
+         c.PERSON_ID,
+         c.CONDITION_CONCEPT_ID,
+         c.CONDITION_TYPE_CONCEPT_ID,
+         c.CONDITION_START_DATE
+ )
+
+ SELECT
+     ROW_NUMBER() OVER (ORDER BY CONDITION_CONCEPT_ID) AS condition_era_id,
+     PERSON_ID,
+     CONDITION_CONCEPT_ID,
+     MIN(CONDITION_START_DATE) AS CONDITION_ERA_START_DATE,
+     ERA_END_DATE AS CONDITION_ERA_END_DATE,
+     COUNT(*) AS CONDITION_OCCURRENCE_COUNT
+ FROM cteConditionEnds
+ GROUP BY
+     PERSON_ID,
+     CONDITION_CONCEPT_ID,
+     CONDITION_TYPE_CONCEPT_ID,
+     ERA_END_DATE
+ ORDER BY
+     PERSON_ID,
+     CONDITION_CONCEPT_ID
+ """
cehrgpt/omop/queries/observation_period.py
@@ -0,0 +1,135 @@
+ OBSERVATION_PERIOD_QUERY = """
+ select
+     person_id,
+     observation_period_start_date,
+     case
+         when observation_period_end_date >= add_months(current_date(), -12) then current_date()
+         else observation_period_end_date
+     end as observation_period_end_date,
+     period_type_concept_id
+ from (
+     SELECT person_id,
+         MIN(observation_period_start_date) AS observation_period_start_date,
+         MAX(observation_period_end_date) AS observation_period_end_date,
+         44814725 as period_type_concept_id
+     FROM (
+         SELECT pt.person_id AS person_id,
+             MIN(vt.visit_start_date) AS observation_period_start_date,
+             MAX(vt.visit_end_date) AS observation_period_end_date
+         FROM person as pt
+         JOIN visit_occurrence as vt ON pt.person_id = vt.person_id
+         WHERE YEAR(vt.visit_start_date) >= 1985
+             AND vt.visit_start_date <= current_date() --set lower and upper bound to ignore spurious dates
+         GROUP BY pt.person_id
+
+         UNION
+
+         SELECT pt.person_id AS person_id,
+             MIN(co.condition_start_date) AS observation_period_start_date,
+             MAX(co.condition_start_date) AS observation_period_end_date
+         FROM person as pt
+         JOIN condition_occurrence as co ON pt.person_id = co.person_id
+         WHERE YEAR(co.condition_start_date) >= 1985
+             AND co.condition_start_date <= current_date() --set lower and upper bound to ignore spurious dates
+         GROUP BY pt.person_id
+
+         UNION
+
+         SELECT pt.person_id AS person_id,
+             MIN(po.procedure_date) AS observation_period_start_date,
+             MAX(po.procedure_date) AS observation_period_end_date
+         FROM person as pt
+         JOIN procedure_occurrence as po ON pt.person_id = po.person_id
+         WHERE YEAR(po.procedure_date) >= 1985
+             AND po.procedure_date <= current_date() --set lower and upper bound to ignore spurious dates
+         GROUP BY pt.person_id
+
+         UNION
+
+         SELECT pt.person_id AS person_id,
+             MIN(de.drug_exposure_start_date) AS observation_period_start_date,
+             MAX(de.drug_exposure_start_date) AS observation_period_end_date
+         FROM person as pt
+         JOIN drug_exposure as de ON pt.person_id = de.person_id
+         WHERE YEAR(de.drug_exposure_start_date) >= 1985
+             AND de.drug_exposure_start_date <= current_date() --set lower and upper bound to ignore spurious dates
+         GROUP BY pt.person_id
+     ) as x
+     WHERE x.observation_period_end_date IS NOT NULL
+     GROUP BY person_id
+ ) as z
+ """
+
+ OBSERVATION_PERIOD_WITH_MEASUREMENT_QUERY = """
+ select
+     person_id,
+     observation_period_start_date,
+     case
+         when observation_period_end_date >= add_months(current_date(), -12) then current_date()
+         else observation_period_end_date
+     end as observation_period_end_date,
+     period_type_concept_id
+ from (
+     SELECT person_id,
+         MIN(observation_period_start_date) AS observation_period_start_date,
+         MAX(observation_period_end_date) AS observation_period_end_date,
+         44814725 as period_type_concept_id
+     FROM (
+         SELECT pt.person_id AS person_id,
+             MIN(vt.visit_start_date) AS observation_period_start_date,
+             MAX(vt.visit_end_date) AS observation_period_end_date
+         FROM person as pt
+         JOIN visit_occurrence as vt ON pt.person_id = vt.person_id
+         WHERE YEAR(vt.visit_start_date) >= 1985
+             AND vt.visit_start_date <= current_date() --set lower and upper bound to ignore spurious dates
+         GROUP BY pt.person_id
+
+         UNION
+
+         SELECT pt.person_id AS person_id,
+             MIN(co.condition_start_date) AS observation_period_start_date,
+             MAX(co.condition_start_date) AS observation_period_end_date
+         FROM person as pt
+         JOIN condition_occurrence as co ON pt.person_id = co.person_id
+         WHERE YEAR(co.condition_start_date) >= 1985
+             AND co.condition_start_date <= current_date() --set lower and upper bound to ignore spurious dates
+         GROUP BY pt.person_id
+
+         UNION
+
+         SELECT pt.person_id AS person_id,
+             MIN(po.procedure_date) AS observation_period_start_date,
+             MAX(po.procedure_date) AS observation_period_end_date
+         FROM person as pt
+         JOIN procedure_occurrence as po ON pt.person_id = po.person_id
+         WHERE YEAR(po.procedure_date) >= 1985
+             AND po.procedure_date <= current_date() --set lower and upper bound to ignore spurious dates
+         GROUP BY pt.person_id
+
+         UNION
+
+         SELECT pt.person_id AS person_id,
+             MIN(de.drug_exposure_start_date) AS observation_period_start_date,
+             MAX(de.drug_exposure_start_date) AS observation_period_end_date
+         FROM person as pt
+         JOIN drug_exposure as de ON pt.person_id = de.person_id
+         WHERE YEAR(de.drug_exposure_start_date) >= 1985
+             AND de.drug_exposure_start_date <= current_date() --set lower and upper bound to ignore spurious dates
+         GROUP BY pt.person_id
+
+         UNION
+
+         SELECT
+             pt.person_id AS person_id,
+             MIN(m.measurement_date) AS observation_period_start_date,
+             MAX(m.measurement_date) AS observation_period_end_date
+         FROM person as pt
+         JOIN measurement as m ON pt.person_id = m.person_id
+         WHERE YEAR(m.measurement_date) >= 1985
+             AND m.measurement_date <= current_date() --set lower and upper bound to ignore spurious dates
+         GROUP BY pt.person_id
+     ) as x
+     WHERE x.observation_period_end_date IS NOT NULL
+     GROUP BY person_id
+ ) as z
+ """
cehrgpt/omop/sample_omop_tables.py
@@ -0,0 +1,71 @@
+ import argparse
+ import os
+
+ from pyspark.sql import SparkSession
+
+ # Define timestamp column for filtering based on the folder name
+ omop_tables = [
+     "person",
+     "death",
+     "visit_occurrence",
+     "condition_occurrence",
+     "procedure_occurrence",
+     "drug_exposure",
+     "measurement",
+     "observation",
+     "observation_period",
+     "condition_era",
+     "drug_era",
+ ]
+
+
+ # Main function to process the folders and upload tables
+ def main(args):
+     spark = (
+         SparkSession.builder.appName("Sample OMOP Tables")
+         .config("spark.sql.legacy.parquet.int96RebaseModeInRead", "CORRECTED")
+         .config("spark.sql.legacy.parquet.int96RebaseModeInWrite", "CORRECTED")
+         .config("spark.sql.legacy.parquet.datetimeRebaseModeInRead", "CORRECTED")
+         .config("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "CORRECTED")
+         .getOrCreate()
+     )
+     patient_sample = spark.read.parquet(args.person_sample)
+     for omop_table in omop_tables:
+         omop_table_input = os.path.join(args.omop_folder, omop_table)
+         omop_table_output = os.path.join(args.output_folder, omop_table)
+         # If the input does not exist, let's skip
+         if not os.path.exists(omop_table_input):
+             continue
+         # If the output already exists, and overwrite is not set to True, let's skip it
+         if os.path.exists(omop_table_output) and not args.overwrite:
+             continue
+         omop_dataframe = spark.read.parquet(omop_table_input)
+         sub_omop_dataframe = omop_dataframe.join(
+             patient_sample.select("person_id"), "person_id"
+         )
+         sub_omop_dataframe.write.mode("overwrite" if args.overwrite else None).parquet(
+             omop_table_output
+         )
+
+
+ # Argument parsing moved under __name__ == "__main__"
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Arguments for uploading OMOP tables")
+     parser.add_argument(
+         "--person_sample",
+         required=True,
+     )
+     parser.add_argument(
+         "--omop_folder",
+         required=True,
+     )
+     parser.add_argument(
+         "--output_folder",
+         required=True,
+     )
+     parser.add_argument(
+         "--overwrite",
+         action="store_true",
+     )
+     # Call the main function with parsed arguments
+     main(parser.parse_args())
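sample_omop_tables.py assumes a person-level sample already exists as parquet; each OMOP table is then restricted to those persons via an inner join on person_id. A hedged sketch of producing such a sample as a precursor step (the fraction, seed, and paths are illustrative, not from the package):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.appName("Draw person sample").getOrCreate()
    person = spark.read.parquet("/data/omop/person")
    # A 1% random sample of persons; sample_omop_tables.py will join every
    # other table against these person_ids.
    person.sample(fraction=0.01, seed=42).write.parquet("/data/person_sample")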
cehrgpt/runners/__init__.py (file without changes)
cehrgpt/runners/gpt_runner_util.py
@@ -0,0 +1,99 @@
+ import dataclasses
+ import os
+ import sys
+ from typing import Tuple
+
+ from cehrbert.runners.hf_runner_argument_dataclass import (
+     DataTrainingArguments,
+     ModelArguments,
+ )
+ from transformers import HfArgumentParser, TrainingArguments
+ from transformers.utils import logging
+ from trl.trainer.dpo_config import DPOConfig
+
+ from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
+
+ LOG = logging.get_logger("transformers")
+
+
+ def parse_dynamic_arguments(
+     argument_classes: Tuple[dataclasses.dataclass, ...] = (
+         DataTrainingArguments,
+         ModelArguments,
+         TrainingArguments,
+     )
+ ) -> Tuple:
+     """
+     Parses command-line arguments with extended flexibility, allowing for the inclusion of custom argument classes.
+
+     This function utilizes `HfArgumentParser` to parse arguments from command line input, JSON, or YAML files.
+     By default, it expects `ModelArguments`, `DataTrainingArguments`, and `TrainingArguments`, but it can be extended
+     with additional argument classes through the `argument_classes` parameter, making it suitable
+     for various custom setups.
+
+     Parameters:
+         argument_classes (Tuple[Type]): A tuple of argument classes to be parsed. Defaults to
+             `(ModelArguments, DataTrainingArguments, TrainingArguments)`. Additional argument classes can be specified
+             for greater flexibility in configuration.
+
+     Returns:
+         Tuple: A tuple of parsed arguments, one for each argument class provided. The order of the returned tuple
+             matches the order of the `argument_classes` parameter.
+
+     Raises:
+         FileNotFoundError: If the specified JSON or YAML file does not exist.
+         json.JSONDecodeError: If there is an error parsing a JSON file.
+         yaml.YAMLError: If there is an error parsing a YAML file.
+         Exception: For other issues that occur during argument parsing.
+
+     Example usage:
+         - Command-line: `python training_script.py --model_name_or_path bert-base-uncased --do_train`
+         - JSON file: `python training_script.py config.json`
+         - YAML file: `python training_script.py config.yaml`
+
+     Flexibility:
+         The function can be customized to include new argument classes as needed:
+
+         Example with a custom argument class:
+         ```python
+         class CustomArguments:
+             # Define custom arguments here
+             pass
+
+
+         custom_args = parse_extended_args(
+             (ModelArguments, DataTrainingArguments, TrainingArguments, CustomArguments)
+         )
+         ```
+         This example demonstrates how to include additional argument classes
+         beyond the defaults for a more tailored setup.
+     """
+     parser = HfArgumentParser(argument_classes)
+
+     # Check if input is a JSON or YAML file
+     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+         args = parser.parse_json_file(json_file=os.path.expanduser(sys.argv[1]))
+     elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
+         args = parser.parse_yaml_file(yaml_file=os.path.expanduser(sys.argv[1]))
+     else:
+         args = parser.parse_args_into_dataclasses()
+
+     return tuple(args)
+
+
+ def parse_runner_args() -> (
+     Tuple[CehrGPTArguments, DataTrainingArguments, ModelArguments, TrainingArguments]
+ ):
+     cehrgpt_args, data_args, model_args, training_args = parse_dynamic_arguments(
+         (CehrGPTArguments, DataTrainingArguments, ModelArguments, TrainingArguments)
+     )
+     return cehrgpt_args, data_args, model_args, training_args
+
+
+ def parse_dpo_runner_args() -> (
+     Tuple[CehrGPTArguments, DataTrainingArguments, ModelArguments, DPOConfig]
+ ):
+     cehrgpt_args, data_args, model_args, dpo_config = parse_dynamic_arguments(
+         (CehrGPTArguments, DataTrainingArguments, ModelArguments, DPOConfig)
+     )
+     return cehrgpt_args, data_args, model_args, dpo_config
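A minimal sketch of a runner entry point built on parse_runner_args (the script body is hypothetical; the argument dataclasses come from cehrbert and transformers, and parse_runner_args reads sys.argv, a JSON file, or a YAML file as described in the docstring above):

    from cehrgpt.runners.gpt_runner_util import parse_runner_args

    def main():
        cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
        # TrainingArguments carries the standard Hugging Face fields,
        # e.g. the output directory for checkpoints.
        print(training_args.output_dir)

    if __name__ == "__main__":
        main()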