cehrgpt 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- __init__.py +0 -0
- cehrgpt/__init__.py +0 -0
- cehrgpt/analysis/__init__.py +0 -0
- cehrgpt/analysis/privacy/__init__.py +0 -0
- cehrgpt/analysis/privacy/attribute_inference.py +275 -0
- cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
- cehrgpt/analysis/privacy/member_inference.py +172 -0
- cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
- cehrgpt/analysis/privacy/reid_inference.py +407 -0
- cehrgpt/analysis/privacy/utils.py +255 -0
- cehrgpt/cehrgpt_args.py +142 -0
- cehrgpt/data/__init__.py +0 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
- cehrgpt/generation/__init__.py +0 -0
- cehrgpt/generation/chatgpt_generation.py +106 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
- cehrgpt/generation/omop_converter_batch.py +644 -0
- cehrgpt/generation/omop_entity.py +515 -0
- cehrgpt/gpt_utils.py +331 -0
- cehrgpt/models/__init__.py +0 -0
- cehrgpt/models/config.py +205 -0
- cehrgpt/models/hf_cehrgpt.py +1817 -0
- cehrgpt/models/hf_modeling_outputs.py +158 -0
- cehrgpt/models/pretrained_embeddings.py +82 -0
- cehrgpt/models/special_tokens.py +30 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
- cehrgpt/omop/__init__.py +0 -0
- cehrgpt/omop/condition_era.py +20 -0
- cehrgpt/omop/observation_period.py +43 -0
- cehrgpt/omop/omop_argparse.py +38 -0
- cehrgpt/omop/omop_table_builder.py +86 -0
- cehrgpt/omop/queries/__init__.py +0 -0
- cehrgpt/omop/queries/condition_era.py +86 -0
- cehrgpt/omop/queries/observation_period.py +135 -0
- cehrgpt/omop/sample_omop_tables.py +71 -0
- cehrgpt/runners/__init__.py +0 -0
- cehrgpt/runners/gpt_runner_util.py +99 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
- cehrgpt/runners/hyperparameter_search_util.py +223 -0
- cehrgpt/time_to_event/__init__.py +0 -0
- cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
- cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
- cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
- cehrgpt/time_to_event/time_to_event_model.py +226 -0
- cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
- cehrgpt/time_to_event/time_to_event_utils.py +55 -0
- cehrgpt/tools/__init__.py +0 -0
- cehrgpt/tools/ehrshot_benchmark.py +74 -0
- cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
- cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
- cehrgpt/tools/upload_omop_tables.py +108 -0
- cehrgpt-0.0.1.dist-info/LICENSE +21 -0
- cehrgpt-0.0.1.dist-info/METADATA +66 -0
- cehrgpt-0.0.1.dist-info/RECORD +60 -0
- cehrgpt-0.0.1.dist-info/WHEEL +5 -0
- cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
cehrgpt/cehrgpt_args.py
ADDED
@@ -0,0 +1,142 @@
import argparse
from enum import Enum


class SamplingStrategy(Enum):
    TopKStrategy = "TopKStrategy"
    TopPStrategy = "TopPStrategy"
    TopMixStrategy = "TopMixStrategy"


def create_inference_base_arg_parser(
    description: str = "Base arguments for cehr-gpt inference",
):
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        "--tokenizer_folder",
        dest="tokenizer_folder",
        action="store",
        help="The path for your model_folder",
        required=True,
    )
    parser.add_argument(
        "--model_folder",
        dest="model_folder",
        action="store",
        help="The path for your model_folder",
        required=True,
    )
    parser.add_argument(
        "--output_folder",
        dest="output_folder",
        action="store",
        help="The path for your generated data",
        required=True,
    )
    parser.add_argument(
        "--batch_size",
        dest="batch_size",
        action="store",
        type=int,
        help="batch_size",
        required=True,
    )
    parser.add_argument(
        "--buffer_size",
        dest="buffer_size",
        action="store",
        type=int,
        default=100,
        help="buffer_size",
        required=False,
    )
    parser.add_argument(
        "--context_window",
        dest="context_window",
        action="store",
        type=int,
        help="The context window of the gpt model",
        required=True,
    )
    parser.add_argument(
        "--min_num_of_concepts",
        dest="min_num_of_concepts",
        action="store",
        type=int,
        default=1,
        required=False,
    )
    parser.add_argument(
        "--sampling_strategy",
        dest="sampling_strategy",
        action="store",
        choices=[e.value for e in SamplingStrategy],
        help="Pick the sampling strategy from the three options top_k, top_p and top_mix",
        required=True,
    )
    parser.add_argument(
        "--top_k",
        dest="top_k",
        action="store",
        default=100,
        type=int,
        help="The number of top concepts to sample",
        required=False,
    )
    parser.add_argument(
        "--top_p",
        dest="top_p",
        action="store",
        default=1.0,
        type=float,
        help="The accumulative probability of top concepts to sample",
        required=False,
    )
    parser.add_argument(
        "--temperature",
        dest="temperature",
        action="store",
        default=1.0,
        type=float,
        help="The temperature parameter for softmax",
        required=False,
    )
    parser.add_argument(
        "--repetition_penalty",
        dest="repetition_penalty",
        action="store",
        default=1.0,
        type=float,
        help="The repetition penalty during decoding",
        required=False,
    )
    parser.add_argument(
        "--num_beams",
        dest="num_beams",
        action="store",
        default=1,
        type=int,
        required=False,
    )
    parser.add_argument(
        "--num_beam_groups",
        dest="num_beam_groups",
        action="store",
        default=1,
        type=int,
        required=False,
    )
    parser.add_argument(
        "--epsilon_cutoff",
        dest="epsilon_cutoff",
        action="store",
        default=0.0,
        type=float,
        required=False,
    )
    parser.add_argument(
        "--use_bfloat16",
        dest="use_bfloat16",
        action="store_true",
    )
    return parser
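A minimal usage sketch, not part of the packaged files, showing how a generation script might extend and consume this base parser. The --num_of_patients flag, the folder paths, and the argument values are hypothetical and only for illustration.

# Hypothetical usage of create_inference_base_arg_parser (illustration only).
from cehrgpt.cehrgpt_args import SamplingStrategy, create_inference_base_arg_parser

parser = create_inference_base_arg_parser(description="Generate synthetic patient sequences")
# A script-specific flag layered on top of the shared base arguments (hypothetical).
parser.add_argument("--num_of_patients", dest="num_of_patients", type=int, default=10)

# Placeholder paths and values; a real invocation would take these from the command line.
args = parser.parse_args(
    [
        "--tokenizer_folder", "/path/to/tokenizer",
        "--model_folder", "/path/to/model",
        "--output_folder", "/path/to/output",
        "--batch_size", "16",
        "--context_window", "1024",
        "--sampling_strategy", SamplingStrategy.TopPStrategy.value,
        "--top_p", "0.95",
    ]
)
print(args.sampling_strategy, args.top_p, args.use_bfloat16)  # TopPStrategy 0.95 False

Centralizing the shared flags in one factory keeps the inference scripts consistent about sampling options (top_k, top_p, temperature, repetition_penalty) while letting each script add its own arguments.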
cehrgpt/data/__init__.py
ADDED
File without changes
cehrgpt/data/hf_cehrgpt_dataset.py
ADDED
@@ -0,0 +1,80 @@
from typing import Union

from cehrbert.data_generators.hf_data_generator.hf_dataset import (
    FINETUNING_COLUMNS,
    apply_cehrbert_dataset_mapping,
)
from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
from datasets import Dataset, DatasetDict

from cehrgpt.data.hf_cehrgpt_dataset_mapping import (
    HFCehrGptTokenizationMapping,
    HFFineTuningMapping,
)
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer

CEHRGPT_COLUMNS = [
    "person_id",
    "concept_ids",
    "concept_values",
    "concept_value_masks",
    "num_of_concepts",
    "num_of_visits",
    "values",
    "value_indicators",
]

TRANSFORMER_COLUMNS = ["input_ids"]


def create_cehrgpt_pretraining_dataset(
    dataset: Union[Dataset, DatasetDict],
    cehrgpt_tokenizer: CehrGptTokenizer,
    data_args: DataTrainingArguments,
) -> Dataset:
    required_columns = TRANSFORMER_COLUMNS + CEHRGPT_COLUMNS
    dataset = apply_cehrbert_dataset_mapping(
        dataset,
        HFCehrGptTokenizationMapping(cehrgpt_tokenizer),
        num_proc=data_args.preprocessing_num_workers,
        batch_size=data_args.preprocessing_batch_size,
        streaming=data_args.streaming,
    )

    if not data_args.streaming:
        if isinstance(dataset, DatasetDict):
            all_columns = dataset["train"].column_names
        else:
            all_columns = dataset.column_names
        columns_to_remove = [_ for _ in all_columns if _ not in required_columns]
        dataset = dataset.remove_columns(columns_to_remove)

    return dataset


def create_cehrgpt_finetuning_dataset(
    dataset: Union[Dataset, DatasetDict],
    cehrgpt_tokenizer: CehrGptTokenizer,
    data_args: DataTrainingArguments,
) -> Dataset:
    required_columns = TRANSFORMER_COLUMNS + CEHRGPT_COLUMNS + FINETUNING_COLUMNS
    mapping_functions = [
        HFFineTuningMapping(cehrgpt_tokenizer),
    ]
    for mapping_function in mapping_functions:
        dataset = apply_cehrbert_dataset_mapping(
            dataset,
            mapping_function,
            num_proc=data_args.preprocessing_num_workers,
            batch_size=data_args.preprocessing_batch_size,
            streaming=data_args.streaming,
        )

    if not data_args.streaming:
        if isinstance(dataset, DatasetDict):
            all_columns = dataset["train"].column_names
        else:
            all_columns = dataset.column_names
        columns_to_remove = [_ for _ in all_columns if _ not in required_columns]
        dataset = dataset.remove_columns(columns_to_remove)
    return dataset
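A minimal sketch, not from the package, of how create_cehrgpt_pretraining_dataset might be wired up. The CehrGptTokenizer.from_pretrained loader, the dataset paths, and the _ToyDataArgs stand-in (which mimics only the three DataTrainingArguments attributes the function actually reads) are assumptions for illustration.

# Hypothetical wiring of create_cehrgpt_pretraining_dataset (illustration only).
from dataclasses import dataclass

from datasets import Dataset

from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_pretraining_dataset
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer


@dataclass
class _ToyDataArgs:
    # Stand-in for cehrbert's DataTrainingArguments; it exposes only the three
    # attributes that create_cehrgpt_pretraining_dataset actually reads.
    preprocessing_num_workers: int = 1
    preprocessing_batch_size: int = 1000
    streaming: bool = False


# Assumption: a trained tokenizer saved to disk; the from_pretrained-style loader
# name is not confirmed by this diff.
cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained("/path/to/tokenizer")

# Assumption: a patient-sequence dataset (with concept_ids, concept_values, etc.)
# previously saved with Dataset.save_to_disk by the upstream cehrbert tooling.
patient_sequences = Dataset.load_from_disk("/path/to/patient_sequence_dataset")

tokenized = create_cehrgpt_pretraining_dataset(
    dataset=patient_sequences,
    cehrgpt_tokenizer=cehrgpt_tokenizer,
    data_args=_ToyDataArgs(),
)
print(tokenized.column_names)  # trimmed to input_ids plus the CEHRGPT_COLUMNS

The finetuning variant follows the same pattern but applies HFFineTuningMapping and additionally retains the FINETUNING_COLUMNS defined by cehrbert.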