cehrgpt 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60) hide show
  1. __init__.py +0 -0
  2. cehrgpt/__init__.py +0 -0
  3. cehrgpt/analysis/__init__.py +0 -0
  4. cehrgpt/analysis/privacy/__init__.py +0 -0
  5. cehrgpt/analysis/privacy/attribute_inference.py +275 -0
  6. cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
  7. cehrgpt/analysis/privacy/member_inference.py +172 -0
  8. cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
  9. cehrgpt/analysis/privacy/reid_inference.py +407 -0
  10. cehrgpt/analysis/privacy/utils.py +255 -0
  11. cehrgpt/cehrgpt_args.py +142 -0
  12. cehrgpt/data/__init__.py +0 -0
  13. cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
  14. cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
  15. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
  16. cehrgpt/generation/__init__.py +0 -0
  17. cehrgpt/generation/chatgpt_generation.py +106 -0
  18. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
  19. cehrgpt/generation/omop_converter_batch.py +644 -0
  20. cehrgpt/generation/omop_entity.py +515 -0
  21. cehrgpt/gpt_utils.py +331 -0
  22. cehrgpt/models/__init__.py +0 -0
  23. cehrgpt/models/config.py +205 -0
  24. cehrgpt/models/hf_cehrgpt.py +1817 -0
  25. cehrgpt/models/hf_modeling_outputs.py +158 -0
  26. cehrgpt/models/pretrained_embeddings.py +82 -0
  27. cehrgpt/models/special_tokens.py +30 -0
  28. cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
  29. cehrgpt/omop/__init__.py +0 -0
  30. cehrgpt/omop/condition_era.py +20 -0
  31. cehrgpt/omop/observation_period.py +43 -0
  32. cehrgpt/omop/omop_argparse.py +38 -0
  33. cehrgpt/omop/omop_table_builder.py +86 -0
  34. cehrgpt/omop/queries/__init__.py +0 -0
  35. cehrgpt/omop/queries/condition_era.py +86 -0
  36. cehrgpt/omop/queries/observation_period.py +135 -0
  37. cehrgpt/omop/sample_omop_tables.py +71 -0
  38. cehrgpt/runners/__init__.py +0 -0
  39. cehrgpt/runners/gpt_runner_util.py +99 -0
  40. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
  41. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
  42. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
  43. cehrgpt/runners/hyperparameter_search_util.py +223 -0
  44. cehrgpt/time_to_event/__init__.py +0 -0
  45. cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
  46. cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
  47. cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
  48. cehrgpt/time_to_event/time_to_event_model.py +226 -0
  49. cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
  50. cehrgpt/time_to_event/time_to_event_utils.py +55 -0
  51. cehrgpt/tools/__init__.py +0 -0
  52. cehrgpt/tools/ehrshot_benchmark.py +74 -0
  53. cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
  54. cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
  55. cehrgpt/tools/upload_omop_tables.py +108 -0
  56. cehrgpt-0.0.1.dist-info/LICENSE +21 -0
  57. cehrgpt-0.0.1.dist-info/METADATA +66 -0
  58. cehrgpt-0.0.1.dist-info/RECORD +60 -0
  59. cehrgpt-0.0.1.dist-info/WHEEL +5 -0
  60. cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,142 @@
1
+ import argparse
2
+ from enum import Enum
3
+
4
+
5
class SamplingStrategy(Enum):
    """Sampling strategies selectable through ``--sampling_strategy``.

    The parser below uses each member's ``value`` verbatim as an accepted
    command-line choice, so the values must stay stable.
    """

    TopKStrategy = "TopKStrategy"
    TopPStrategy = "TopPStrategy"
    TopMixStrategy = "TopMixStrategy"


def create_inference_base_arg_parser(
    description: str = "Base arguments for cehr-gpt inference",
) -> argparse.ArgumentParser:
    """Build the argument parser with the arguments shared by cehr-gpt inference.

    Callers may add further arguments to the returned parser before parsing.

    Args:
        description: Text shown at the top of ``--help`` output.

    Returns:
        An ``argparse.ArgumentParser``. Required arguments are
        ``--tokenizer_folder``, ``--model_folder``, ``--output_folder``,
        ``--batch_size``, ``--context_window`` and ``--sampling_strategy``;
        every other option has a default.
    """
    # Accepted values for --sampling_strategy, in enum declaration order.
    # The help text is derived from the same list so it cannot drift.
    strategy_choices = [strategy.value for strategy in SamplingStrategy]

    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        "--tokenizer_folder",
        dest="tokenizer_folder",
        action="store",
        help="The path for your tokenizer_folder",
        required=True,
    )
    parser.add_argument(
        "--model_folder",
        dest="model_folder",
        action="store",
        help="The path for your model_folder",
        required=True,
    )
    parser.add_argument(
        "--output_folder",
        dest="output_folder",
        action="store",
        help="The path for your generated data",
        required=True,
    )
    parser.add_argument(
        "--batch_size",
        dest="batch_size",
        action="store",
        type=int,
        help="batch_size",
        required=True,
    )
    parser.add_argument(
        "--buffer_size",
        dest="buffer_size",
        action="store",
        type=int,
        default=100,
        help="buffer_size",
        required=False,
    )
    parser.add_argument(
        "--context_window",
        dest="context_window",
        action="store",
        type=int,
        help="The context window of the gpt model",
        required=True,
    )
    parser.add_argument(
        "--min_num_of_concepts",
        dest="min_num_of_concepts",
        action="store",
        type=int,
        default=1,
        required=False,
    )
    parser.add_argument(
        "--sampling_strategy",
        dest="sampling_strategy",
        action="store",
        choices=strategy_choices,
        help="Pick the sampling strategy from the options: "
        + ", ".join(strategy_choices),
        required=True,
    )
    parser.add_argument(
        "--top_k",
        dest="top_k",
        action="store",
        default=100,
        type=int,
        help="The number of top concepts to sample",
        required=False,
    )
    parser.add_argument(
        "--top_p",
        dest="top_p",
        action="store",
        default=1.0,
        type=float,
        help="The accumulative probability of top concepts to sample",
        required=False,
    )
    parser.add_argument(
        "--temperature",
        dest="temperature",
        action="store",
        default=1.0,
        type=float,
        help="The temperature parameter for softmax",
        required=False,
    )
    parser.add_argument(
        "--repetition_penalty",
        dest="repetition_penalty",
        action="store",
        default=1.0,
        type=float,
        help="The repetition penalty during decoding",
        required=False,
    )
    parser.add_argument(
        "--num_beams",
        dest="num_beams",
        action="store",
        default=1,
        type=int,
        required=False,
    )
    parser.add_argument(
        "--num_beam_groups",
        dest="num_beam_groups",
        action="store",
        default=1,
        type=int,
        required=False,
    )
    parser.add_argument(
        "--epsilon_cutoff",
        dest="epsilon_cutoff",
        action="store",
        default=0.0,
        type=float,
        required=False,
    )
    # Boolean flag: False unless --use_bfloat16 is passed.
    parser.add_argument(
        "--use_bfloat16",
        dest="use_bfloat16",
        action="store_true",
    )
    return parser
File without changes
@@ -0,0 +1,80 @@
1
+ from typing import Union
2
+
3
+ from cehrbert.data_generators.hf_data_generator.hf_dataset import (
4
+ FINETUNING_COLUMNS,
5
+ apply_cehrbert_dataset_mapping,
6
+ )
7
+ from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
8
+ from datasets import Dataset, DatasetDict
9
+
10
+ from cehrgpt.data.hf_cehrgpt_dataset_mapping import (
11
+ HFCehrGptTokenizationMapping,
12
+ HFFineTuningMapping,
13
+ )
14
+ from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
15
+
16
# Model-input column(s) kept by the dataset builders in this module when they
# prune unused columns from a non-streaming dataset.
TRANSFORMER_COLUMNS = ["input_ids"]

# CEHR-GPT patient-sequence columns that are likewise kept during pruning.
CEHRGPT_COLUMNS = [
    "person_id",
    "concept_ids",
    "concept_values",
    "concept_value_masks",
    "num_of_concepts",
    "num_of_visits",
    "values",
    "value_indicators",
]
28
+
29
+
30
def create_cehrgpt_pretraining_dataset(
    dataset: Union[Dataset, DatasetDict],
    cehrgpt_tokenizer: CehrGptTokenizer,
    data_args: DataTrainingArguments,
) -> Union[Dataset, DatasetDict]:
    """Tokenize ``dataset`` for CEHR-GPT pretraining and prune unused columns.

    The dataset is passed through ``HFCehrGptTokenizationMapping`` via
    ``apply_cehrbert_dataset_mapping``. The worker count, batch size and
    streaming flag come from ``data_args``. When ``data_args.streaming`` is
    false, every column not listed in ``TRANSFORMER_COLUMNS`` or
    ``CEHRGPT_COLUMNS`` is removed afterwards.

    Args:
        dataset: A single ``Dataset`` or a ``DatasetDict`` of splits.
        cehrgpt_tokenizer: Tokenizer handed to the tokenization mapping.
        data_args: Supplies ``preprocessing_num_workers``,
            ``preprocessing_batch_size`` and ``streaming``.

    Returns:
        The mapped dataset, as a ``Dataset`` or ``DatasetDict`` depending on
        what the mapping helper returns. Columns are pruned only in the
        non-streaming case.
    """
    required_columns = {*TRANSFORMER_COLUMNS, *CEHRGPT_COLUMNS}
    dataset = apply_cehrbert_dataset_mapping(
        dataset,
        HFCehrGptTokenizationMapping(cehrgpt_tokenizer),
        num_proc=data_args.preprocessing_num_workers,
        batch_size=data_args.preprocessing_batch_size,
        streaming=data_args.streaming,
    )

    # Streaming datasets are returned without column pruning.
    if not data_args.streaming:
        if isinstance(dataset, DatasetDict):
            # Use the "train" split's schema when present. Otherwise use any
            # split, so a DatasetDict without "train" no longer raises KeyError.
            if "train" in dataset:
                reference_split = dataset["train"]
            else:
                reference_split = next(iter(dataset.values()), None)
            all_columns = (
                reference_split.column_names if reference_split is not None else []
            )
        else:
            all_columns = dataset.column_names
        columns_to_remove = [
            column for column in all_columns if column not in required_columns
        ]
        if columns_to_remove:
            dataset = dataset.remove_columns(columns_to_remove)

    return dataset
53
+
54
+
55
def create_cehrgpt_finetuning_dataset(
    dataset: Union[Dataset, DatasetDict],
    cehrgpt_tokenizer: CehrGptTokenizer,
    data_args: DataTrainingArguments,
) -> Union[Dataset, DatasetDict]:
    """Map ``dataset`` for CEHR-GPT fine-tuning and prune unused columns.

    Each mapping in ``mapping_functions`` (currently only
    ``HFFineTuningMapping``) is applied in order via
    ``apply_cehrbert_dataset_mapping``. The worker count, batch size and
    streaming flag come from ``data_args``. When ``data_args.streaming`` is
    false, every column not listed in ``TRANSFORMER_COLUMNS``,
    ``CEHRGPT_COLUMNS`` or ``FINETUNING_COLUMNS`` is removed afterwards.

    Args:
        dataset: A single ``Dataset`` or a ``DatasetDict`` of splits.
        cehrgpt_tokenizer: Tokenizer handed to the fine-tuning mapping.
        data_args: Supplies ``preprocessing_num_workers``,
            ``preprocessing_batch_size`` and ``streaming``.

    Returns:
        The mapped dataset, as a ``Dataset`` or ``DatasetDict`` depending on
        what the mapping helper returns. Columns are pruned only in the
        non-streaming case.
    """
    # Set unpacking works whether FINETUNING_COLUMNS is a list, a tuple or
    # another iterable, and gives O(1) membership tests below.
    required_columns = {*TRANSFORMER_COLUMNS, *CEHRGPT_COLUMNS, *FINETUNING_COLUMNS}
    # Applied in list order; each stage receives the previous stage's output.
    mapping_functions = [
        HFFineTuningMapping(cehrgpt_tokenizer),
    ]
    for mapping_function in mapping_functions:
        dataset = apply_cehrbert_dataset_mapping(
            dataset,
            mapping_function,
            num_proc=data_args.preprocessing_num_workers,
            batch_size=data_args.preprocessing_batch_size,
            streaming=data_args.streaming,
        )

    # Streaming datasets are returned without column pruning.
    if not data_args.streaming:
        if isinstance(dataset, DatasetDict):
            # Use the "train" split's schema when present. Otherwise use any
            # split, so a DatasetDict without "train" no longer raises KeyError.
            if "train" in dataset:
                reference_split = dataset["train"]
            else:
                reference_split = next(iter(dataset.values()), None)
            all_columns = (
                reference_split.column_names if reference_split is not None else []
            )
        else:
            all_columns = dataset.column_names
        columns_to_remove = [
            column for column in all_columns if column not in required_columns
        ]
        if columns_to_remove:
            dataset = dataset.remove_columns(columns_to_remove)
    return dataset