cehrgpt 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +279 -2
- cehrgpt/data/sample_packing_sampler.py +151 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/generation/omop_converter_batch.py +3 -0
- cehrgpt/models/config.py +10 -0
- cehrgpt/models/hf_cehrgpt.py +244 -73
- cehrgpt/models/tokenization_hf_cehrgpt.py +6 -2
- cehrgpt/runners/data_utils.py +243 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +154 -260
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +250 -90
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -0
- cehrgpt/runners/hyperparameter_search_util.py +4 -1
- cehrgpt/runners/sample_packing_trainer.py +168 -0
- cehrgpt/simulations/__init__.py +0 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/tools/generate_causal_patient_split_by_age.py +146 -0
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +57 -9
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +30 -18
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
cehrgpt/runners/data_utils.py
@@ -0,0 +1,243 @@
+import numpy as np
+from cehrbert.data_generators.hf_data_generator.cache_util import CacheFileCollector
+from cehrbert.data_generators.hf_data_generator.meds_utils import (
+    create_dataset_from_meds_reader,
+)
+from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
+from cehrbert.runners.runner_util import (
+    get_meds_extension_path,
+    load_parquet_as_dataset,
+)
+from datasets import DatasetDict, concatenate_datasets, load_from_disk
+from transformers import TrainingArguments
+from transformers.utils import logging
+
+from cehrgpt.data.hf_cehrgpt_dataset_mapping import MedToCehrGPTDatasetMapping
+from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
+
+LOG = logging.get_logger("transformers")
+
+
+def prepare_finetune_dataset(
+    data_args: DataTrainingArguments,
+    training_args: TrainingArguments,
+    cehrgpt_args: CehrGPTArguments,
+    cache_file_collector: CacheFileCollector,
+) -> DatasetDict:
+    # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
+    if data_args.is_data_in_meds:
+        meds_extension_path = get_meds_extension_path(
+            data_folder=data_args.cohort_folder,
+            dataset_prepared_path=data_args.dataset_prepared_path,
+        )
+        try:
+            LOG.info(
+                f"Trying to load the MEDS extension from disk at {meds_extension_path}..."
+            )
+            dataset = load_from_disk(meds_extension_path)
+            if data_args.streaming:
+                if isinstance(dataset, DatasetDict):
+                    dataset = {
+                        k: v.to_iterable_dataset(
+                            num_shards=training_args.dataloader_num_workers
+                        )
+                        for k, v in dataset.items()
+                    }
+                else:
+                    dataset = dataset.to_iterable_dataset(
+                        num_shards=training_args.dataloader_num_workers
+                    )
+        except Exception as e:
+            LOG.warning(e)
+            dataset = create_dataset_from_meds_reader(
+                data_args=data_args,
+                dataset_mappings=[
+                    MedToCehrGPTDatasetMapping(
+                        data_args=data_args,
+                        include_inpatient_hour_token=cehrgpt_args.include_inpatient_hour_token,
+                    )
+                ],
+                cache_file_collector=cache_file_collector,
+            )
+            if not data_args.streaming:
+                dataset.save_to_disk(str(meds_extension_path))
+                stats = dataset.cleanup_cache_files()
+                LOG.info(
+                    "Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
+                    stats,
+                )
+                # Clean up the files created from the data generator
+                cache_file_collector.remove_cache_files()
+                dataset = load_from_disk(str(meds_extension_path))
+
+        train_set = dataset["train"]
+        validation_set = dataset["validation"]
+        test_set = dataset["test"]
+
+        if cehrgpt_args.meds_repartition:
+            train_val_set = concatenate_datasets([train_set, validation_set])
+            if data_args.streaming and data_args.validation_split_num:
+                train_val_set = train_val_set.shuffle(
+                    buffer_size=10_000, seed=training_args.seed
+                )
+                train_set = train_val_set.skip(data_args.validation_split_num)
+                validation_set = train_val_set.take(data_args.validation_split_num)
+            elif data_args.validation_split_percentage:
+                dataset = train_val_set.train_test_split(
+                    test_size=data_args.validation_split_percentage,
+                    seed=training_args.seed,
+                )
+                train_set = dataset["train"]
+                validation_set = dataset["test"]
+            else:
+                raise RuntimeError(
+                    f"Can not split the data. If streaming is enabled, validation_split_num needs to be "
+                    f"defined, otherwise validation_split_percentage needs to be provided. "
+                    f"The current values are:\n"
+                    f"validation_split_percentage: {data_args.validation_split_percentage}\n"
+                    f"validation_split_num: {data_args.validation_split_num}\n"
+                    f"streaming: {data_args.streaming}"
+                )
+    else:
+        train_set, validation_set, test_set = create_dataset_splits(
+            data_args=data_args, seed=training_args.seed
+        )
+    # Organize them into a single DatasetDict
+    final_splits = DatasetDict(
+        {"train": train_set, "validation": validation_set, "test": test_set}
+    )
+    return final_splits
+
+
+def create_dataset_splits(data_args: DataTrainingArguments, seed: int):
+    """
+    Creates training, validation, and testing dataset splits based on specified splitting strategies.
+
+    This function splits a dataset into training, validation, and test sets, using either chronological,
+    patient-based, or random splitting strategies, depending on the parameters provided in `data_args`.
+
+    - **Chronological split**: Sorts by a specified date and splits based on historical and future data.
+    - **Patient-based split**: Splits by unique patient IDs to ensure that patients in each split are distinct.
+    - **Random split**: Performs a straightforward random split of the dataset.
+
+    If `data_args.test_data_folder` is provided, a test set is loaded directly from it. Otherwise,
+    the test set is created by further splitting the validation set based on `test_eval_ratio`.
+
+    Parameters:
+        data_args (DataTrainingArguments): A configuration object containing data-related arguments, including:
+            - `data_folder` (str): Path to the main dataset.
+            - `test_data_folder` (str, optional): Path to an optional test dataset.
+            - `chronological_split` (bool): Whether to split chronologically.
+            - `split_by_patient` (bool): Whether to split by unique patient IDs.
+            - `validation_split_percentage` (float): Percentage of data to use for validation.
+            - `test_eval_ratio` (float): Ratio of test to validation data when creating a test set from validation.
+            - `preprocessing_num_workers` (int): Number of processes for parallel data filtering.
+            - `preprocessing_batch_size` (int): Batch size for batched operations.
+        seed (int): Random seed for reproducibility of splits.
+
+    Returns:
+        Tuple[Dataset, Dataset, Dataset]: A tuple containing:
+            - `train_set` (Dataset): Training split of the dataset.
+            - `validation_set` (Dataset): Validation split of the dataset.
+            - `test_set` (Dataset): Test split of the dataset.
+
+    Raises:
+        FileNotFoundError: If `data_args.data_folder` or `data_args.test_data_folder` does not exist.
+        ValueError: If incompatible arguments are passed for splitting strategies.
+
+    Example Usage:
+        data_args = DataTrainingArguments(
+            data_folder="data/",
+            validation_split_percentage=0.1,
+            test_eval_ratio=0.2,
+            chronological_split=True
+        )
+        train_set, validation_set, test_set = create_dataset_splits(data_args, seed=42)
+    """
+    dataset = load_parquet_as_dataset(data_args.data_folder)
+    test_set = (
+        None
+        if not data_args.test_data_folder
+        else load_parquet_as_dataset(data_args.test_data_folder)
+    )
+
+    if data_args.chronological_split:
+        # Chronological split by sorting on `index_date`
+        dataset = dataset.sort("index_date")
+        total_size = len(dataset)
+        train_end = int((1 - data_args.validation_split_percentage) * total_size)
+
+        # Perform the split
+        train_set = dataset.select(range(0, train_end))
+        validation_set = dataset.select(range(train_end, total_size))
+
+        if test_set is None:
+            test_valid_split = validation_set.train_test_split(
+                test_size=data_args.test_eval_ratio, seed=seed
+            )
+            validation_set, test_set = (
+                test_valid_split["train"],
+                test_valid_split["test"],
+            )
+
+    elif data_args.split_by_patient:
+        # Patient-based split
+        LOG.info("Using the split_by_patient strategy")
+        unique_patient_ids = dataset.unique("person_id")
+        LOG.info(f"There are {len(unique_patient_ids)} patients in total")
+
+        np.random.seed(seed)
+        np.random.shuffle(unique_patient_ids)
+
+        train_end = int(
+            len(unique_patient_ids) * (1 - data_args.validation_split_percentage)
+        )
+        train_patient_ids = set(unique_patient_ids[:train_end])
+
+        if test_set is None:
+            validation_end = int(
+                train_end
+                + len(unique_patient_ids)
+                * data_args.validation_split_percentage
+                * data_args.test_eval_ratio
+            )
+            val_patient_ids = set(unique_patient_ids[train_end:validation_end])
+            test_patient_ids = set(unique_patient_ids[validation_end:])
+        else:
+            val_patient_ids, test_patient_ids = (
+                set(unique_patient_ids[train_end:]),
+                None,
+            )
+
+        # Helper function to apply patient-based filtering
+        def filter_by_patient_ids(patient_ids):
+            return dataset.filter(
+                lambda batch: [pid in patient_ids for pid in batch["person_id"]],
+                num_proc=data_args.preprocessing_num_workers,
+                batched=True,
+                batch_size=data_args.preprocessing_batch_size,
+            )
+
+        # Generate splits
+        train_set = filter_by_patient_ids(train_patient_ids)
+        validation_set = filter_by_patient_ids(val_patient_ids)
+        if test_set is None:
+            test_set = filter_by_patient_ids(test_patient_ids)
+
+    else:
+        # Random split
+        train_val = dataset.train_test_split(
+            test_size=data_args.validation_split_percentage, seed=seed
+        )
+        train_set, validation_set = train_val["train"], train_val["test"]
+
+        if test_set is None:
+            test_valid_split = validation_set.train_test_split(
+                test_size=data_args.test_eval_ratio, seed=seed
+            )
+            validation_set, test_set = (
+                test_valid_split["train"],
+                test_valid_split["test"],
+            )
+
+    return train_set, validation_set, test_set
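For orientation, the following is a minimal, hypothetical sketch of how the new prepare_finetune_dataset entry point in cehrgpt/runners/data_utils.py might be driven from a runner script. The HfArgumentParser wiring and the zero-argument CacheFileCollector() construction are illustrative assumptions; only the function signature and the import paths are taken from the diff above.

# Hypothetical usage sketch -- not part of this diff.
# Assumes CacheFileCollector can be constructed without arguments.
from cehrbert.data_generators.hf_data_generator.cache_util import CacheFileCollector
from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
from transformers import HfArgumentParser, TrainingArguments

from cehrgpt.runners.data_utils import prepare_finetune_dataset
from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments

# Parse the three argument dataclasses from the command line.
parser = HfArgumentParser((CehrGPTArguments, DataTrainingArguments, TrainingArguments))
cehrgpt_args, data_args, training_args = parser.parse_args_into_dataclasses()

# Build the train/validation/test DatasetDict, collecting any cache files created along the way.
cache_file_collector = CacheFileCollector()
splits = prepare_finetune_dataset(
    data_args=data_args,
    training_args=training_args,
    cehrgpt_args=cehrgpt_args,
    cache_file_collector=cache_file_collector,
)
print(splits)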
cehrgpt/runners/gpt_runner_util.py
@@ -9,7 +9,6 @@ from cehrbert.runners.hf_runner_argument_dataclass import (
 )
 from transformers import HfArgumentParser, TrainingArguments
 from transformers.utils import logging
-from trl.trainer.dpo_config import DPOConfig
 
 from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
 
@@ -88,12 +87,3 @@ def parse_runner_args() -> (
         (CehrGPTArguments, DataTrainingArguments, ModelArguments, TrainingArguments)
     )
     return cehrgpt_args, data_args, model_args, training_args
-
-
-def parse_dpo_runner_args() -> (
-    Tuple[CehrGPTArguments, DataTrainingArguments, ModelArguments, DPOConfig]
-):
-    cehrgpt_args, data_args, model_args, dpo_config = parse_dynamic_arguments(
-        (CehrGPTArguments, DataTrainingArguments, ModelArguments, DPOConfig)
-    )
-    return cehrgpt_args, data_args, model_args, dpo_config