cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
- cehrgpt/data/sample_packing_sampler.py +181 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +35 -0
- cehrgpt/models/hf_cehrgpt.py +470 -106
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
- cehrgpt/runners/data_utils.py +358 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
- cehrgpt/runners/hyperparameter_search_util.py +10 -8
- cehrgpt/runners/sample_packing_trainer.py +185 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
cehrgpt/runners/data_utils.py
@@ -0,0 +1,358 @@
+import os
+from datetime import datetime
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import polars as pl
+import torch
+from cehrbert.data_generators.hf_data_generator.cache_util import CacheFileCollector
+from cehrbert.data_generators.hf_data_generator.meds_utils import (
+    create_dataset_from_meds_reader,
+)
+from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
+from cehrbert.runners.runner_util import (
+    get_meds_extension_path,
+    load_parquet_as_dataset,
+)
+from datasets import DatasetDict, concatenate_datasets, load_from_disk
+from transformers import TrainingArguments
+from transformers.utils import logging
+
+from cehrgpt.data.hf_cehrgpt_dataset_mapping import (
+    ExtractTokenizedSequenceDataMapping,
+    MedToCehrGPTDatasetMapping,
+)
+from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
+
+LOG = logging.get_logger("transformers")
+
+
+def get_torch_dtype(torch_dtype: Optional[str] = None) -> Union[torch.dtype, str]:
+    if torch_dtype and hasattr(torch, torch_dtype):
+        return getattr(torch, torch_dtype)
+    return torch.float
+
+
+def data_collate_fn(features, model_type: torch.dtype, collator):
+    batch = collator(features)
+    if model_type != torch.float32:
+        for key, value in batch.items():
+            # Only convert float32 tensors to bfloat16
+            if isinstance(value, torch.Tensor) and value.dtype == torch.float32:
+                batch[key] = value.to(model_type)
+    return batch
+
+
+def prepare_finetune_dataset(
+    data_args: DataTrainingArguments,
+    training_args: TrainingArguments,
+    cehrgpt_args: CehrGPTArguments,
+    cache_file_collector: CacheFileCollector,
+) -> DatasetDict:
+    # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
+    if data_args.is_data_in_meds:
+        meds_extension_path = get_meds_extension_path(
+            data_folder=data_args.cohort_folder,
+            dataset_prepared_path=data_args.dataset_prepared_path,
+        )
+        try:
+            LOG.info(
+                f"Trying to load the MEDS extension from disk at {meds_extension_path}..."
+            )
+            dataset = load_from_disk(meds_extension_path)
+            if data_args.streaming:
+                if isinstance(dataset, DatasetDict):
+                    dataset = {
+                        k: v.to_iterable_dataset(
+                            num_shards=training_args.dataloader_num_workers
+                        )
+                        for k, v in dataset.items()
+                    }
+                else:
+                    dataset = dataset.to_iterable_dataset(
+                        num_shards=training_args.dataloader_num_workers
+                    )
+        except Exception as e:
+            LOG.warning(e)
+            dataset = create_dataset_from_meds_reader(
+                data_args=data_args,
+                dataset_mappings=[
+                    MedToCehrGPTDatasetMapping(
+                        data_args=data_args,
+                        include_inpatient_hour_token=cehrgpt_args.include_inpatient_hour_token,
+                    )
+                ],
+                cache_file_collector=cache_file_collector,
+            )
+            if not data_args.streaming:
+                dataset.save_to_disk(str(meds_extension_path))
+                stats = dataset.cleanup_cache_files()
+                LOG.info(
+                    "Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
+                    stats,
+                )
+                # Clean up the files created from the data generator
+                cache_file_collector.remove_cache_files()
+                dataset = load_from_disk(str(meds_extension_path))
+
+        train_set = dataset["train"]
+        validation_set = dataset["validation"]
+        test_set = dataset["test"]
+
+        if cehrgpt_args.meds_repartition:
+            train_val_set = concatenate_datasets([train_set, validation_set])
+            if data_args.streaming and data_args.validation_split_num:
+                train_val_set = train_val_set.shuffle(
+                    buffer_size=10_000, seed=training_args.seed
+                )
+                train_set = train_val_set.skip(data_args.validation_split_num)
+                validation_set = train_val_set.take(data_args.validation_split_num)
+            elif data_args.validation_split_percentage:
+                dataset = train_val_set.train_test_split(
+                    test_size=data_args.validation_split_percentage,
+                    seed=training_args.seed,
+                )
+                train_set = dataset["train"]
+                validation_set = dataset["test"]
+            else:
+                raise RuntimeError(
+                    f"Can not split the data. If streaming is enabled, validation_split_num needs to be "
+                    f"defined, otherwise validation_split_percentage needs to be provided. "
+                    f"The current values are:\n"
+                    f"validation_split_percentage: {data_args.validation_split_percentage}\n"
+                    f"validation_split_num: {data_args.validation_split_num}\n"
+                    f"streaming: {data_args.streaming}"
+                )
+    else:
+        train_set, validation_set, test_set = create_dataset_splits(
+            data_args=data_args, seed=training_args.seed
+        )
+    # Organize them into a single DatasetDict
+    final_splits = DatasetDict(
+        {"train": train_set, "validation": validation_set, "test": test_set}
+    )
+    return final_splits
+
+
+def create_dataset_splits(data_args: DataTrainingArguments, seed: int):
+    """
+    Creates training, validation, and testing dataset splits based on specified splitting strategies.
+
+    This function splits a dataset into training, validation, and test sets, using either chronological,
+    patient-based, or random splitting strategies, depending on the parameters provided in `data_args`.
+
+    - **Chronological split**: Sorts by a specified date and splits based on historical and future data.
+    - **Patient-based split**: Splits by unique patient IDs to ensure that patients in each split are distinct.
+    - **Random split**: Performs a straightforward random split of the dataset.
+
+    If `data_args.test_data_folder` is provided, a test set is loaded directly from it. Otherwise,
+    the test set is created by further splitting the validation set based on `test_eval_ratio`.
+
+    Parameters:
+        data_args (DataTrainingArguments): A configuration object containing data-related arguments, including:
+            - `data_folder` (str): Path to the main dataset.
+            - `test_data_folder` (str, optional): Path to an optional test dataset.
+            - `chronological_split` (bool): Whether to split chronologically.
+            - `split_by_patient` (bool): Whether to split by unique patient IDs.
+            - `validation_split_percentage` (float): Percentage of data to use for validation.
+            - `test_eval_ratio` (float): Ratio of test to validation data when creating a test set from validation.
+            - `preprocessing_num_workers` (int): Number of processes for parallel data filtering.
+            - `preprocessing_batch_size` (int): Batch size for batched operations.
+        seed (int): Random seed for reproducibility of splits.
+
+    Returns:
+        Tuple[Dataset, Dataset, Dataset]: A tuple containing:
+            - `train_set` (Dataset): Training split of the dataset.
+            - `validation_set` (Dataset): Validation split of the dataset.
+            - `test_set` (Dataset): Test split of the dataset.
+
+    Raises:
+        FileNotFoundError: If `data_args.data_folder` or `data_args.test_data_folder` does not exist.
+        ValueError: If incompatible arguments are passed for splitting strategies.
+
+    Example Usage:
+        data_args = DataTrainingArguments(
+            data_folder="data/",
+            validation_split_percentage=0.1,
+            test_eval_ratio=0.2,
+            chronological_split=True
+        )
+        train_set, validation_set, test_set = create_dataset_splits(data_args, seed=42)
+    """
+    dataset = load_parquet_as_dataset(data_args.data_folder)
+    test_set = (
+        None
+        if not data_args.test_data_folder
+        else load_parquet_as_dataset(data_args.test_data_folder)
+    )
+
+    if data_args.chronological_split:
+        # Chronological split by sorting on `index_date`
+        dataset = dataset.sort("index_date")
+        total_size = len(dataset)
+        train_end = int((1 - data_args.validation_split_percentage) * total_size)
+
+        # Perform the split
+        train_set = dataset.select(range(0, train_end))
+        validation_set = dataset.select(range(train_end, total_size))
+
+        if test_set is None:
+            test_valid_split = validation_set.train_test_split(
+                test_size=data_args.test_eval_ratio, seed=seed
+            )
+            validation_set, test_set = (
+                test_valid_split["train"],
+                test_valid_split["test"],
+            )
+
+    elif data_args.split_by_patient:
+        # Patient-based split
+        LOG.info("Using the split_by_patient strategy")
+        unique_patient_ids = dataset.unique("person_id")
+        LOG.info(f"There are {len(unique_patient_ids)} patients in total")
+
+        np.random.seed(seed)
+        np.random.shuffle(unique_patient_ids)
+
+        train_end = int(
+            len(unique_patient_ids) * (1 - data_args.validation_split_percentage)
+        )
+        train_patient_ids = set(unique_patient_ids[:train_end])
+
+        if test_set is None:
+            validation_end = int(
+                train_end
+                + len(unique_patient_ids)
+                * data_args.validation_split_percentage
+                * data_args.test_eval_ratio
+            )
+            val_patient_ids = set(unique_patient_ids[train_end:validation_end])
+            test_patient_ids = set(unique_patient_ids[validation_end:])
+        else:
+            val_patient_ids, test_patient_ids = (
+                set(unique_patient_ids[train_end:]),
+                None,
+            )
+
+        # Helper function to apply patient-based filtering
+        def filter_by_patient_ids(patient_ids):
+            return dataset.filter(
+                lambda batch: [pid in patient_ids for pid in batch["person_id"]],
+                num_proc=data_args.preprocessing_num_workers,
+                batched=True,
+                batch_size=data_args.preprocessing_batch_size,
+            )
+
+        # Generate splits
+        train_set = filter_by_patient_ids(train_patient_ids).shuffle(seed=seed)
+        validation_set = filter_by_patient_ids(val_patient_ids)
+        if test_set is None:
+            test_set = filter_by_patient_ids(test_patient_ids)
+
+    else:
+        # Random split
+        train_val = dataset.train_test_split(
+            test_size=data_args.validation_split_percentage, seed=seed
+        )
+        train_set, validation_set = train_val["train"], train_val["test"]
+
+        if test_set is None:
+            test_valid_split = validation_set.train_test_split(
+                test_size=data_args.test_eval_ratio, seed=seed
+            )
+            validation_set, test_set = (
+                test_valid_split["train"],
+                test_valid_split["test"],
+            )
+
+    return train_set, validation_set, test_set
+
+
+def extract_cohort_sequences(
+    data_args: DataTrainingArguments,
+    cehrgpt_args: CehrGPTArguments,
+    cache_file_collector: CacheFileCollector,
+) -> DatasetDict:
+    """
+    Extracts and processes cohort-specific tokenized sequences from a pre-tokenized dataset,.
+
+    based on the provided cohort Parquet files and observation window constraints.
+
+    This function performs the following steps:
+    1. Loads cohort definitions from Parquet files located in `data_args.cohort_folder`.
+    2. Renames relevant columns if the data originates from a Meds format.
+    3. Filters a pre-tokenized dataset (loaded from `cehrgpt_args.tokenized_full_dataset_path`)
+       to include only patients present in the cohort.
+    4. Aggregates each person's index date and label into a mapping.
+    5. Checks for consistency to ensure all cohort person_ids are present in the tokenized dataset.
+    6. Applies a transformation (`ExtractTokenizedSequenceDataMapping`) to generate
+       observation-window-constrained patient sequences.
+    7. Caches both the filtered and processed datasets using the provided `cache_file_collector`.
+
+    Args:
+        data_args (DataTrainingArguments): Configuration parameters for data processing,
+            including cohort folder, observation window, batch size, and parallelism.
+        cehrgpt_args (CehrGPTArguments): Contains paths to pre-tokenized datasets and CEHR-GPT-specific arguments.
+        cache_file_collector (CacheFileCollector): Utility to register and manage dataset cache files.
+
+    Returns:
+        DatasetDict: A Hugging Face `DatasetDict` containing the processed datasets (e.g., train/validation/test),
+            where each entry includes sequences filtered and truncated by the observation window.
+
+    Raises:
+        RuntimeError: If any `person_id` in the cohort is missing from the tokenized dataset.
+    """
+
+    cohort = pl.read_parquet(os.path.join(data_args.cohort_folder, "*.parquet"))
+    if data_args.is_data_in_meds:
+        cohort = cohort.rename(
+            mapping={
+                "prediction_time": "index_date",
+                "subject_id": "person_id",
+            }
+        )
+    all_person_ids = cohort["person_id"].unique().to_list()
+    # data_args.observation_window
+    tokenized_dataset = load_from_disk(cehrgpt_args.tokenized_full_dataset_path)
+    filtered_tokenized_dataset = tokenized_dataset.filter(
+        lambda batch: [person_id in all_person_ids for person_id in batch["person_id"]],
+        batched=True,
+        batch_size=data_args.preprocessing_batch_size,
+        num_proc=data_args.preprocessing_num_workers,
+    )
+    person_index_date_agg = cohort.group_by("person_id").agg(
+        pl.struct("index_date", "label").alias("index_date_label")
+    )
+    # Convert to dictionary
+    person_index_date_map: Dict[int, List[datetime]] = dict(
+        zip(
+            person_index_date_agg["person_id"].to_list(),
+            person_index_date_agg["index_date_label"].to_list(),
+        )
+    )
+    LOG.info(f"person_index_date_agg: {person_index_date_agg}")
+    tokenized_person_ids = []
+    for _, dataset in filtered_tokenized_dataset.items():
+        tokenized_person_ids.extend(dataset["person_id"])
+    missing_person_ids = [
+        person_id
+        for person_id in person_index_date_map.keys()
+        if person_id not in tokenized_person_ids
+    ]
+    if missing_person_ids:
+        raise RuntimeError(
+            f"There are {len(missing_person_ids)} missing in the tokenized dataset. "
+            f"The list contains: {missing_person_ids}"
+        )
+    processed_dataset = filtered_tokenized_dataset.map(
+        ExtractTokenizedSequenceDataMapping(
+            person_index_date_map, data_args.observation_window
+        ).batch_transform,
+        batched=True,
+        batch_size=data_args.preprocessing_batch_size,
+        num_proc=data_args.preprocessing_num_workers,
+        remove_columns=filtered_tokenized_dataset["train"].column_names,
+    )
+    cache_file_collector.add_cache_files(filtered_tokenized_dataset)
+    cache_file_collector.add_cache_files(processed_dataset)
+    return processed_dataset
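
For orientation, here is a minimal sketch of how the new helpers in cehrgpt/runners/data_utils.py might be driven from a fine-tuning script. It is illustrative only: the DataTrainingArguments fields set below follow the docstring above, while constructing CehrGPTArguments, TrainingArguments, and CacheFileCollector with these particular values is an assumption, not code shipped in this wheel.

    # Hypothetical usage sketch (not part of the package).
    from cehrbert.data_generators.hf_data_generator.cache_util import CacheFileCollector
    from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
    from transformers import TrainingArguments

    from cehrgpt.runners.data_utils import prepare_finetune_dataset
    from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments

    # Illustrative argument values; real runs populate these via HfArgumentParser.
    data_args = DataTrainingArguments(
        data_folder="data/",                # main parquet dataset
        dataset_prepared_path="prepared/",  # cache location for prepared datasets
        validation_split_percentage=0.1,
        test_eval_ratio=0.2,
        split_by_patient=True,              # keep each patient in a single split
    )
    training_args = TrainingArguments(output_dir="output", seed=42)
    cehrgpt_args = CehrGPTArguments()
    cache_file_collector = CacheFileCollector()

    splits = prepare_finetune_dataset(
        data_args=data_args,
        training_args=training_args,
        cehrgpt_args=cehrgpt_args,
        cache_file_collector=cache_file_collector,
    )
    # splits is a DatasetDict with "train", "validation", and "test" entries.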
cehrgpt/runners/gpt_runner_util.py
@@ -9,7 +9,6 @@ from cehrbert.runners.hf_runner_argument_dataclass import (
 )
 from transformers import HfArgumentParser, TrainingArguments
 from transformers.utils import logging
-from trl.trainer.dpo_config import DPOConfig
 
 from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
 
@@ -88,12 +87,3 @@ def parse_runner_args() -> (
         (CehrGPTArguments, DataTrainingArguments, ModelArguments, TrainingArguments)
     )
     return cehrgpt_args, data_args, model_args, training_args
-
-
-def parse_dpo_runner_args() -> (
-    Tuple[CehrGPTArguments, DataTrainingArguments, ModelArguments, DPOConfig]
-):
-    cehrgpt_args, data_args, model_args, dpo_config = parse_dynamic_arguments(
-        (CehrGPTArguments, DataTrainingArguments, ModelArguments, DPOConfig)
-    )
-    return cehrgpt_args, data_args, model_args, dpo_config