cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. cehrgpt/analysis/irregularity.py +36 -0
  2. cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
  3. cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
  4. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
  5. cehrgpt/data/sample_packing_sampler.py +181 -0
  6. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  7. cehrgpt/generation/omop_converter_batch.py +32 -2
  8. cehrgpt/gpt_utils.py +20 -2
  9. cehrgpt/models/config.py +35 -0
  10. cehrgpt/models/hf_cehrgpt.py +470 -106
  11. cehrgpt/models/hf_modeling_outputs.py +1 -0
  12. cehrgpt/models/special_tokens.py +1 -0
  13. cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
  14. cehrgpt/runners/data_utils.py +358 -0
  15. cehrgpt/runners/gpt_runner_util.py +0 -10
  16. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
  17. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
  18. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
  19. cehrgpt/runners/hyperparameter_search_util.py +10 -8
  20. cehrgpt/runners/sample_packing_trainer.py +185 -0
  21. cehrgpt/simulations/generate_plots.py +95 -0
  22. cehrgpt/simulations/run_simulation.sh +24 -0
  23. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  24. cehrgpt/simulations/time_token_simulation.py +177 -0
  25. cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  26. cehrgpt/time_to_event/time_to_event_model.py +2 -13
  27. cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  28. cehrgpt/tools/linear_prob/__init__.py +0 -0
  29. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
  30. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  31. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
  32. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
  33. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
  34. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  35. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  36. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  37. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  38. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  39. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  40. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  41. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  42. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  43. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
  44. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
cehrgpt/runners/data_utils.py (added)
@@ -0,0 +1,358 @@
+ import os
+ from datetime import datetime
+ from typing import Dict, List, Optional, Union
+
+ import numpy as np
+ import polars as pl
+ import torch
+ from cehrbert.data_generators.hf_data_generator.cache_util import CacheFileCollector
+ from cehrbert.data_generators.hf_data_generator.meds_utils import (
+     create_dataset_from_meds_reader,
+ )
+ from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
+ from cehrbert.runners.runner_util import (
+     get_meds_extension_path,
+     load_parquet_as_dataset,
+ )
+ from datasets import DatasetDict, concatenate_datasets, load_from_disk
+ from transformers import TrainingArguments
+ from transformers.utils import logging
+
+ from cehrgpt.data.hf_cehrgpt_dataset_mapping import (
+     ExtractTokenizedSequenceDataMapping,
+     MedToCehrGPTDatasetMapping,
+ )
+ from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
+
+ LOG = logging.get_logger("transformers")
+
+
+ def get_torch_dtype(torch_dtype: Optional[str] = None) -> Union[torch.dtype, str]:
+     if torch_dtype and hasattr(torch, torch_dtype):
+         return getattr(torch, torch_dtype)
+     return torch.float
+
+
+ def data_collate_fn(features, model_type: torch.dtype, collator):
+     batch = collator(features)
+     if model_type != torch.float32:
+         for key, value in batch.items():
+             # Only convert float32 tensors to the target dtype
+             if isinstance(value, torch.Tensor) and value.dtype == torch.float32:
+                 batch[key] = value.to(model_type)
+     return batch
+
+
+ def prepare_finetune_dataset(
+     data_args: DataTrainingArguments,
+     training_args: TrainingArguments,
+     cehrgpt_args: CehrGPTArguments,
+     cache_file_collector: CacheFileCollector,
+ ) -> DatasetDict:
+     # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
+     if data_args.is_data_in_meds:
+         meds_extension_path = get_meds_extension_path(
+             data_folder=data_args.cohort_folder,
+             dataset_prepared_path=data_args.dataset_prepared_path,
+         )
+         try:
+             LOG.info(
+                 f"Trying to load the MEDS extension from disk at {meds_extension_path}..."
+             )
+             dataset = load_from_disk(meds_extension_path)
+             if data_args.streaming:
+                 if isinstance(dataset, DatasetDict):
+                     dataset = {
+                         k: v.to_iterable_dataset(
+                             num_shards=training_args.dataloader_num_workers
+                         )
+                         for k, v in dataset.items()
+                     }
+                 else:
+                     dataset = dataset.to_iterable_dataset(
+                         num_shards=training_args.dataloader_num_workers
+                     )
+         except Exception as e:
+             LOG.warning(e)
+             dataset = create_dataset_from_meds_reader(
+                 data_args=data_args,
+                 dataset_mappings=[
+                     MedToCehrGPTDatasetMapping(
+                         data_args=data_args,
+                         include_inpatient_hour_token=cehrgpt_args.include_inpatient_hour_token,
+                     )
+                 ],
+                 cache_file_collector=cache_file_collector,
+             )
+             if not data_args.streaming:
+                 dataset.save_to_disk(str(meds_extension_path))
+                 stats = dataset.cleanup_cache_files()
+                 LOG.info(
+                     "Cleaned up the cached files for the cehrgpt dataset transformed from MEDS: %s",
+                     stats,
+                 )
+                 # Clean up the files created from the data generator
+                 cache_file_collector.remove_cache_files()
+                 dataset = load_from_disk(str(meds_extension_path))
+
+         train_set = dataset["train"]
+         validation_set = dataset["validation"]
+         test_set = dataset["test"]
+
+         if cehrgpt_args.meds_repartition:
+             train_val_set = concatenate_datasets([train_set, validation_set])
+             if data_args.streaming and data_args.validation_split_num:
+                 train_val_set = train_val_set.shuffle(
+                     buffer_size=10_000, seed=training_args.seed
+                 )
+                 train_set = train_val_set.skip(data_args.validation_split_num)
+                 validation_set = train_val_set.take(data_args.validation_split_num)
+             elif data_args.validation_split_percentage:
+                 dataset = train_val_set.train_test_split(
+                     test_size=data_args.validation_split_percentage,
+                     seed=training_args.seed,
+                 )
+                 train_set = dataset["train"]
+                 validation_set = dataset["test"]
+             else:
+                 raise RuntimeError(
+                     f"Cannot split the data. If streaming is enabled, validation_split_num needs to be "
+                     f"defined, otherwise validation_split_percentage needs to be provided. "
+                     f"The current values are:\n"
+                     f"validation_split_percentage: {data_args.validation_split_percentage}\n"
+                     f"validation_split_num: {data_args.validation_split_num}\n"
+                     f"streaming: {data_args.streaming}"
+                 )
+     else:
+         train_set, validation_set, test_set = create_dataset_splits(
+             data_args=data_args, seed=training_args.seed
+         )
+     # Organize them into a single DatasetDict
+     final_splits = DatasetDict(
+         {"train": train_set, "validation": validation_set, "test": test_set}
+     )
+     return final_splits
+
+
+ def create_dataset_splits(data_args: DataTrainingArguments, seed: int):
+     """
+     Creates training, validation, and testing dataset splits based on specified splitting strategies.
+
+     This function splits a dataset into training, validation, and test sets, using either chronological,
+     patient-based, or random splitting strategies, depending on the parameters provided in `data_args`.
+
+     - **Chronological split**: Sorts by a specified date and splits based on historical and future data.
+     - **Patient-based split**: Splits by unique patient IDs to ensure that patients in each split are distinct.
+     - **Random split**: Performs a straightforward random split of the dataset.
+
+     If `data_args.test_data_folder` is provided, a test set is loaded directly from it. Otherwise,
+     the test set is created by further splitting the validation set based on `test_eval_ratio`.
+
+     Parameters:
+         data_args (DataTrainingArguments): A configuration object containing data-related arguments, including:
+             - `data_folder` (str): Path to the main dataset.
+             - `test_data_folder` (str, optional): Path to an optional test dataset.
+             - `chronological_split` (bool): Whether to split chronologically.
+             - `split_by_patient` (bool): Whether to split by unique patient IDs.
+             - `validation_split_percentage` (float): Percentage of data to use for validation.
+             - `test_eval_ratio` (float): Ratio of test to validation data when creating a test set from validation.
+             - `preprocessing_num_workers` (int): Number of processes for parallel data filtering.
+             - `preprocessing_batch_size` (int): Batch size for batched operations.
+         seed (int): Random seed for reproducibility of splits.
+
+     Returns:
+         Tuple[Dataset, Dataset, Dataset]: A tuple containing:
+             - `train_set` (Dataset): Training split of the dataset.
+             - `validation_set` (Dataset): Validation split of the dataset.
+             - `test_set` (Dataset): Test split of the dataset.
+
+     Raises:
+         FileNotFoundError: If `data_args.data_folder` or `data_args.test_data_folder` does not exist.
+         ValueError: If incompatible arguments are passed for splitting strategies.
+
+     Example Usage:
+         data_args = DataTrainingArguments(
+             data_folder="data/",
+             validation_split_percentage=0.1,
+             test_eval_ratio=0.2,
+             chronological_split=True
+         )
+         train_set, validation_set, test_set = create_dataset_splits(data_args, seed=42)
+     """
+     dataset = load_parquet_as_dataset(data_args.data_folder)
+     test_set = (
+         None
+         if not data_args.test_data_folder
+         else load_parquet_as_dataset(data_args.test_data_folder)
+     )
+
+     if data_args.chronological_split:
+         # Chronological split by sorting on `index_date`
+         dataset = dataset.sort("index_date")
+         total_size = len(dataset)
+         train_end = int((1 - data_args.validation_split_percentage) * total_size)
+
+         # Perform the split
+         train_set = dataset.select(range(0, train_end))
+         validation_set = dataset.select(range(train_end, total_size))
+
+         if test_set is None:
+             test_valid_split = validation_set.train_test_split(
+                 test_size=data_args.test_eval_ratio, seed=seed
+             )
+             validation_set, test_set = (
+                 test_valid_split["train"],
+                 test_valid_split["test"],
+             )
+
+     elif data_args.split_by_patient:
+         # Patient-based split
+         LOG.info("Using the split_by_patient strategy")
+         unique_patient_ids = dataset.unique("person_id")
+         LOG.info(f"There are {len(unique_patient_ids)} patients in total")
+
+         np.random.seed(seed)
+         np.random.shuffle(unique_patient_ids)
+
+         train_end = int(
+             len(unique_patient_ids) * (1 - data_args.validation_split_percentage)
+         )
+         train_patient_ids = set(unique_patient_ids[:train_end])
+
+         if test_set is None:
+             validation_end = int(
+                 train_end
+                 + len(unique_patient_ids)
+                 * data_args.validation_split_percentage
+                 * data_args.test_eval_ratio
+             )
+             val_patient_ids = set(unique_patient_ids[train_end:validation_end])
+             test_patient_ids = set(unique_patient_ids[validation_end:])
+         else:
+             val_patient_ids, test_patient_ids = (
+                 set(unique_patient_ids[train_end:]),
+                 None,
+             )
+
+         # Helper function to apply patient-based filtering
+         def filter_by_patient_ids(patient_ids):
+             return dataset.filter(
+                 lambda batch: [pid in patient_ids for pid in batch["person_id"]],
+                 num_proc=data_args.preprocessing_num_workers,
+                 batched=True,
+                 batch_size=data_args.preprocessing_batch_size,
+             )
+
+         # Generate splits
+         train_set = filter_by_patient_ids(train_patient_ids).shuffle(seed=seed)
+         validation_set = filter_by_patient_ids(val_patient_ids)
+         if test_set is None:
+             test_set = filter_by_patient_ids(test_patient_ids)
+
+     else:
+         # Random split
+         train_val = dataset.train_test_split(
+             test_size=data_args.validation_split_percentage, seed=seed
+         )
+         train_set, validation_set = train_val["train"], train_val["test"]
+
+         if test_set is None:
+             test_valid_split = validation_set.train_test_split(
+                 test_size=data_args.test_eval_ratio, seed=seed
+             )
+             validation_set, test_set = (
+                 test_valid_split["train"],
+                 test_valid_split["test"],
+             )
+
+     return train_set, validation_set, test_set
+
+
+ def extract_cohort_sequences(
+     data_args: DataTrainingArguments,
+     cehrgpt_args: CehrGPTArguments,
+     cache_file_collector: CacheFileCollector,
+ ) -> DatasetDict:
+     """
+     Extracts and processes cohort-specific tokenized sequences from a pre-tokenized dataset.
+
+     The extraction is based on the provided cohort Parquet files and observation window constraints.
+
+     This function performs the following steps:
+     1. Loads cohort definitions from Parquet files located in `data_args.cohort_folder`.
+     2. Renames relevant columns if the data originates from the MEDS format.
+     3. Filters a pre-tokenized dataset (loaded from `cehrgpt_args.tokenized_full_dataset_path`)
+        to include only patients present in the cohort.
+     4. Aggregates each person's index date and label into a mapping.
+     5. Checks for consistency to ensure all cohort person_ids are present in the tokenized dataset.
+     6. Applies a transformation (`ExtractTokenizedSequenceDataMapping`) to generate
+        observation-window-constrained patient sequences.
+     7. Caches both the filtered and processed datasets using the provided `cache_file_collector`.
+
+     Args:
+         data_args (DataTrainingArguments): Configuration parameters for data processing,
+             including cohort folder, observation window, batch size, and parallelism.
+         cehrgpt_args (CehrGPTArguments): Contains paths to pre-tokenized datasets and CEHR-GPT-specific arguments.
+         cache_file_collector (CacheFileCollector): Utility to register and manage dataset cache files.
+
+     Returns:
+         DatasetDict: A Hugging Face `DatasetDict` containing the processed datasets (e.g., train/validation/test),
+             where each entry includes sequences filtered and truncated by the observation window.
+
+     Raises:
+         RuntimeError: If any `person_id` in the cohort is missing from the tokenized dataset.
+     """
+
+     cohort = pl.read_parquet(os.path.join(data_args.cohort_folder, "*.parquet"))
+     if data_args.is_data_in_meds:
+         cohort = cohort.rename(
+             mapping={
+                 "prediction_time": "index_date",
+                 "subject_id": "person_id",
+             }
+         )
+     all_person_ids = cohort["person_id"].unique().to_list()
+     # data_args.observation_window
+     tokenized_dataset = load_from_disk(cehrgpt_args.tokenized_full_dataset_path)
+     filtered_tokenized_dataset = tokenized_dataset.filter(
+         lambda batch: [person_id in all_person_ids for person_id in batch["person_id"]],
+         batched=True,
+         batch_size=data_args.preprocessing_batch_size,
+         num_proc=data_args.preprocessing_num_workers,
+     )
+     person_index_date_agg = cohort.group_by("person_id").agg(
+         pl.struct("index_date", "label").alias("index_date_label")
+     )
+     # Convert to dictionary
+     person_index_date_map: Dict[int, List[datetime]] = dict(
+         zip(
+             person_index_date_agg["person_id"].to_list(),
+             person_index_date_agg["index_date_label"].to_list(),
+         )
+     )
+     LOG.info(f"person_index_date_agg: {person_index_date_agg}")
+     tokenized_person_ids = []
+     for _, dataset in filtered_tokenized_dataset.items():
+         tokenized_person_ids.extend(dataset["person_id"])
+     missing_person_ids = [
+         person_id
+         for person_id in person_index_date_map.keys()
+         if person_id not in tokenized_person_ids
+     ]
+     if missing_person_ids:
+         raise RuntimeError(
+             f"There are {len(missing_person_ids)} person_ids missing in the tokenized dataset. "
+             f"The list contains: {missing_person_ids}"
+         )
+     processed_dataset = filtered_tokenized_dataset.map(
+         ExtractTokenizedSequenceDataMapping(
+             person_index_date_map, data_args.observation_window
+         ).batch_transform,
+         batched=True,
+         batch_size=data_args.preprocessing_batch_size,
+         num_proc=data_args.preprocessing_num_workers,
+         remove_columns=filtered_tokenized_dataset["train"].column_names,
+     )
+     cache_file_collector.add_cache_files(filtered_tokenized_dataset)
+     cache_file_collector.add_cache_files(processed_dataset)
+     return processed_dataset
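
The new data_utils.py module is what the updated runners lean on for dataset preparation and dtype handling. The sketch below is illustrative only and is not code from this release: it assumes data_args, training_args, and cehrgpt_args have already been parsed (for example via parse_runner_args()) and that collator is an existing CEHR-GPT data collator instance, and it wires get_torch_dtype, prepare_finetune_dataset, and data_collate_fn together the way a fine-tune runner might.

# Illustrative sketch only -- not part of the released package.
# Assumes `data_args`, `training_args`, `cehrgpt_args`, and `collator` already exist.
import functools

from cehrbert.data_generators.hf_data_generator.cache_util import CacheFileCollector

from cehrgpt.runners.data_utils import (
    data_collate_fn,
    get_torch_dtype,
    prepare_finetune_dataset,
)

cache_file_collector = CacheFileCollector()

# Resolve a dtype name such as "bfloat16" to a torch dtype; falls back to torch.float.
model_dtype = get_torch_dtype("bfloat16")

# Build the train/validation/test splits (MEDS-aware, with optional repartitioning).
dataset_splits = prepare_finetune_dataset(
    data_args=data_args,
    training_args=training_args,
    cehrgpt_args=cehrgpt_args,
    cache_file_collector=cache_file_collector,
)

# Wrap the existing collator so float32 tensors are cast to the model dtype.
collate_fn = functools.partial(
    data_collate_fn, model_type=model_dtype, collator=collator
)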
cehrgpt/runners/gpt_runner_util.py
@@ -9,7 +9,6 @@ from cehrbert.runners.hf_runner_argument_dataclass import (
  )
  from transformers import HfArgumentParser, TrainingArguments
  from transformers.utils import logging
- from trl.trainer.dpo_config import DPOConfig

  from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments

@@ -88,12 +87,3 @@ def parse_runner_args() -> (
          (CehrGPTArguments, DataTrainingArguments, ModelArguments, TrainingArguments)
      )
      return cehrgpt_args, data_args, model_args, training_args
-
-
- def parse_dpo_runner_args() -> (
-     Tuple[CehrGPTArguments, DataTrainingArguments, ModelArguments, DPOConfig]
- ):
-     cehrgpt_args, data_args, model_args, dpo_config = parse_dynamic_arguments(
-         (CehrGPTArguments, DataTrainingArguments, ModelArguments, DPOConfig)
-     )
-     return cehrgpt_args, data_args, model_args, dpo_config
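
With the DPO-specific entry point removed, parse_runner_args is the remaining argument parser in gpt_runner_util.py. A minimal sketch of how a runner entry point might call it follows; this is illustrative only and not taken from this release.

# Illustrative sketch only -- not part of the released package.
from cehrgpt.runners.gpt_runner_util import parse_runner_args


def main():
    # The four argument groups returned here feed tokenizer, model, and Trainer setup.
    cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
    print(training_args.output_dir)


if __name__ == "__main__":
    main()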