cehrgpt-0.0.2-py3-none-any.whl → cehrgpt-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
  2. cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
  3. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +99 -88
  4. cehrgpt/data/sample_packing_sampler.py +151 -0
  5. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  6. cehrgpt/models/config.py +10 -0
  7. cehrgpt/models/hf_cehrgpt.py +243 -73
  8. cehrgpt/models/tokenization_hf_cehrgpt.py +4 -0
  9. cehrgpt/runners/data_utils.py +243 -0
  10. cehrgpt/runners/gpt_runner_util.py +0 -10
  11. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +152 -279
  12. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +229 -105
  13. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +42 -0
  14. cehrgpt/runners/hyperparameter_search_util.py +4 -1
  15. cehrgpt/runners/sample_packing_trainer.py +168 -0
  16. cehrgpt/simulations/generate_plots.py +95 -0
  17. cehrgpt/simulations/run_simulation.sh +24 -0
  18. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  19. cehrgpt/simulations/time_token_simulation.py +177 -0
  20. cehrgpt/tools/linear_prob/__init__.py +0 -0
  21. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
  22. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  23. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +7 -5
  24. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +28 -26
  25. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
  26. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  27. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  28. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  29. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  30. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  31. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  32. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  33. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  34. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  35. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
  36. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
cehrgpt/runners/data_utils.py (new file)
@@ -0,0 +1,243 @@
+ import numpy as np
+ from cehrbert.data_generators.hf_data_generator.cache_util import CacheFileCollector
+ from cehrbert.data_generators.hf_data_generator.meds_utils import (
+     create_dataset_from_meds_reader,
+ )
+ from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
+ from cehrbert.runners.runner_util import (
+     get_meds_extension_path,
+     load_parquet_as_dataset,
+ )
+ from datasets import DatasetDict, concatenate_datasets, load_from_disk
+ from transformers import TrainingArguments
+ from transformers.utils import logging
+
+ from cehrgpt.data.hf_cehrgpt_dataset_mapping import MedToCehrGPTDatasetMapping
+ from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
+
+ LOG = logging.get_logger("transformers")
+
+
+ def prepare_finetune_dataset(
+     data_args: DataTrainingArguments,
+     training_args: TrainingArguments,
+     cehrgpt_args: CehrGPTArguments,
+     cache_file_collector: CacheFileCollector,
+ ) -> DatasetDict:
+     # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
+     if data_args.is_data_in_meds:
+         meds_extension_path = get_meds_extension_path(
+             data_folder=data_args.cohort_folder,
+             dataset_prepared_path=data_args.dataset_prepared_path,
+         )
+         try:
+             LOG.info(
+                 f"Trying to load the MEDS extension from disk at {meds_extension_path}..."
+             )
+             dataset = load_from_disk(meds_extension_path)
+             if data_args.streaming:
+                 if isinstance(dataset, DatasetDict):
+                     dataset = {
+                         k: v.to_iterable_dataset(
+                             num_shards=training_args.dataloader_num_workers
+                         )
+                         for k, v in dataset.items()
+                     }
+                 else:
+                     dataset = dataset.to_iterable_dataset(
+                         num_shards=training_args.dataloader_num_workers
+                     )
+         except Exception as e:
+             LOG.warning(e)
+             dataset = create_dataset_from_meds_reader(
+                 data_args=data_args,
+                 dataset_mappings=[
+                     MedToCehrGPTDatasetMapping(
+                         data_args=data_args,
+                         include_inpatient_hour_token=cehrgpt_args.include_inpatient_hour_token,
+                     )
+                 ],
+                 cache_file_collector=cache_file_collector,
+             )
+             if not data_args.streaming:
+                 dataset.save_to_disk(str(meds_extension_path))
+                 stats = dataset.cleanup_cache_files()
+                 LOG.info(
+                     "Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
+                     stats,
+                 )
+                 # Clean up the files created from the data generator
+                 cache_file_collector.remove_cache_files()
+                 dataset = load_from_disk(str(meds_extension_path))
+
+         train_set = dataset["train"]
+         validation_set = dataset["validation"]
+         test_set = dataset["test"]
+
+         if cehrgpt_args.meds_repartition:
+             train_val_set = concatenate_datasets([train_set, validation_set])
+             if data_args.streaming and data_args.validation_split_num:
+                 train_val_set = train_val_set.shuffle(
+                     buffer_size=10_000, seed=training_args.seed
+                 )
+                 train_set = train_val_set.skip(data_args.validation_split_num)
+                 validation_set = train_val_set.take(data_args.validation_split_num)
+             elif data_args.validation_split_percentage:
+                 dataset = train_val_set.train_test_split(
+                     test_size=data_args.validation_split_percentage,
+                     seed=training_args.seed,
+                 )
+                 train_set = dataset["train"]
+                 validation_set = dataset["test"]
+             else:
+                 raise RuntimeError(
+                     f"Can not split the data. If streaming is enabled, validation_split_num needs to be "
+                     f"defined, otherwise validation_split_percentage needs to be provided. "
+                     f"The current values are:\n"
+                     f"validation_split_percentage: {data_args.validation_split_percentage}\n"
+                     f"validation_split_num: {data_args.validation_split_num}\n"
+                     f"streaming: {data_args.streaming}"
+                 )
+     else:
+         train_set, validation_set, test_set = create_dataset_splits(
+             data_args=data_args, seed=training_args.seed
+         )
+     # Organize them into a single DatasetDict
+     final_splits = DatasetDict(
+         {"train": train_set, "validation": validation_set, "test": test_set}
+     )
+     return final_splits
+
+
+ def create_dataset_splits(data_args: DataTrainingArguments, seed: int):
+     """
+     Creates training, validation, and testing dataset splits based on specified splitting strategies.
+
+     This function splits a dataset into training, validation, and test sets, using either chronological,
+     patient-based, or random splitting strategies, depending on the parameters provided in `data_args`.
+
+     - **Chronological split**: Sorts by a specified date and splits based on historical and future data.
+     - **Patient-based split**: Splits by unique patient IDs to ensure that patients in each split are distinct.
+     - **Random split**: Performs a straightforward random split of the dataset.
+
+     If `data_args.test_data_folder` is provided, a test set is loaded directly from it. Otherwise,
+     the test set is created by further splitting the validation set based on `test_eval_ratio`.
+
+     Parameters:
+         data_args (DataTrainingArguments): A configuration object containing data-related arguments, including:
+             - `data_folder` (str): Path to the main dataset.
+             - `test_data_folder` (str, optional): Path to an optional test dataset.
+             - `chronological_split` (bool): Whether to split chronologically.
+             - `split_by_patient` (bool): Whether to split by unique patient IDs.
+             - `validation_split_percentage` (float): Percentage of data to use for validation.
+             - `test_eval_ratio` (float): Ratio of test to validation data when creating a test set from validation.
+             - `preprocessing_num_workers` (int): Number of processes for parallel data filtering.
+             - `preprocessing_batch_size` (int): Batch size for batched operations.
+         seed (int): Random seed for reproducibility of splits.
+
+     Returns:
+         Tuple[Dataset, Dataset, Dataset]: A tuple containing:
+             - `train_set` (Dataset): Training split of the dataset.
+             - `validation_set` (Dataset): Validation split of the dataset.
+             - `test_set` (Dataset): Test split of the dataset.
+
+     Raises:
+         FileNotFoundError: If `data_args.data_folder` or `data_args.test_data_folder` does not exist.
+         ValueError: If incompatible arguments are passed for splitting strategies.
+
+     Example Usage:
+         data_args = DataTrainingArguments(
+             data_folder="data/",
+             validation_split_percentage=0.1,
+             test_eval_ratio=0.2,
+             chronological_split=True
+         )
+         train_set, validation_set, test_set = create_dataset_splits(data_args, seed=42)
+     """
+     dataset = load_parquet_as_dataset(data_args.data_folder)
+     test_set = (
+         None
+         if not data_args.test_data_folder
+         else load_parquet_as_dataset(data_args.test_data_folder)
+     )
+
+     if data_args.chronological_split:
+         # Chronological split by sorting on `index_date`
+         dataset = dataset.sort("index_date")
+         total_size = len(dataset)
+         train_end = int((1 - data_args.validation_split_percentage) * total_size)
+
+         # Perform the split
+         train_set = dataset.select(range(0, train_end))
+         validation_set = dataset.select(range(train_end, total_size))
+
+         if test_set is None:
+             test_valid_split = validation_set.train_test_split(
+                 test_size=data_args.test_eval_ratio, seed=seed
+             )
+             validation_set, test_set = (
+                 test_valid_split["train"],
+                 test_valid_split["test"],
+             )
+
+     elif data_args.split_by_patient:
+         # Patient-based split
+         LOG.info("Using the split_by_patient strategy")
+         unique_patient_ids = dataset.unique("person_id")
+         LOG.info(f"There are {len(unique_patient_ids)} patients in total")
+
+         np.random.seed(seed)
+         np.random.shuffle(unique_patient_ids)
+
+         train_end = int(
+             len(unique_patient_ids) * (1 - data_args.validation_split_percentage)
+         )
+         train_patient_ids = set(unique_patient_ids[:train_end])
+
+         if test_set is None:
+             validation_end = int(
+                 train_end
+                 + len(unique_patient_ids)
+                 * data_args.validation_split_percentage
+                 * data_args.test_eval_ratio
+             )
+             val_patient_ids = set(unique_patient_ids[train_end:validation_end])
+             test_patient_ids = set(unique_patient_ids[validation_end:])
+         else:
+             val_patient_ids, test_patient_ids = (
+                 set(unique_patient_ids[train_end:]),
+                 None,
+             )
+
+         # Helper function to apply patient-based filtering
+         def filter_by_patient_ids(patient_ids):
+             return dataset.filter(
+                 lambda batch: [pid in patient_ids for pid in batch["person_id"]],
+                 num_proc=data_args.preprocessing_num_workers,
+                 batched=True,
+                 batch_size=data_args.preprocessing_batch_size,
+             )
+
+         # Generate splits
+         train_set = filter_by_patient_ids(train_patient_ids)
+         validation_set = filter_by_patient_ids(val_patient_ids)
+         if test_set is None:
+             test_set = filter_by_patient_ids(test_patient_ids)
+
+     else:
+         # Random split
+         train_val = dataset.train_test_split(
+             test_size=data_args.validation_split_percentage, seed=seed
+         )
+         train_set, validation_set = train_val["train"], train_val["test"]
+
+         if test_set is None:
+             test_valid_split = validation_set.train_test_split(
+                 test_size=data_args.test_eval_ratio, seed=seed
+             )
+             validation_set, test_set = (
+                 test_valid_split["train"],
+                 test_valid_split["test"],
+             )
+
+     return train_set, validation_set, test_set
cehrgpt/runners/gpt_runner_util.py
@@ -9,7 +9,6 @@ from cehrbert.runners.hf_runner_argument_dataclass import (
  )
  from transformers import HfArgumentParser, TrainingArguments
  from transformers.utils import logging
- from trl.trainer.dpo_config import DPOConfig

  from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments

@@ -88,12 +87,3 @@ def parse_runner_args() -> (
          (CehrGPTArguments, DataTrainingArguments, ModelArguments, TrainingArguments)
      )
      return cehrgpt_args, data_args, model_args, training_args
-
-
- def parse_dpo_runner_args() -> (
-     Tuple[CehrGPTArguments, DataTrainingArguments, ModelArguments, DPOConfig]
- ):
-     cehrgpt_args, data_args, model_args, dpo_config = parse_dynamic_arguments(
-         (CehrGPTArguments, DataTrainingArguments, ModelArguments, DPOConfig)
-     )
-     return cehrgpt_args, data_args, model_args, dpo_config