cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
- cehrgpt/data/sample_packing_sampler.py +181 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +35 -0
- cehrgpt/models/hf_cehrgpt.py +470 -106
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
- cehrgpt/runners/data_utils.py +358 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
- cehrgpt/runners/hyperparameter_search_util.py +10 -8
- cehrgpt/runners/sample_packing_trainer.py +185 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
@@ -9,20 +9,16 @@ from pathlib import Path
 import numpy as np
 import pandas as pd
 import torch
-
-
-)
+import torch.distributed as dist
+from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
 from cehrbert.runners.hf_cehrbert_finetune_runner import compute_metrics
 from cehrbert.runners.hf_runner_argument_dataclass import (
-    DataTrainingArguments,
     FineTuneModelType,
     ModelArguments,
 )
 from cehrbert.runners.runner_util import (
     generate_prepared_ds_path,
     get_last_hf_checkpoint,
-    get_meds_extension_path,
-    load_parquet_as_dataset,
 )
 from datasets import DatasetDict, concatenate_datasets, load_from_disk
 from peft import LoraConfig, PeftModel, get_peft_model
@@ -38,12 +34,15 @@ from transformers import (
     TrainingArguments,
     set_seed,
 )
-from transformers.
+from transformers.trainer_utils import is_main_process
 from transformers.utils import is_flash_attn_2_available, logging

 from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_finetuning_dataset
-from cehrgpt.data.hf_cehrgpt_dataset_collator import
-
+from cehrgpt.data.hf_cehrgpt_dataset_collator import (
+    CehrGptDataCollator,
+    SamplePackingCehrGptDataCollator,
+)
+from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler
 from cehrgpt.models.hf_cehrgpt import (
     CEHRGPTConfig,
     CehrGptForClassification,
@@ -51,9 +50,16 @@ from cehrgpt.models.hf_cehrgpt import (
 )
 from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+from cehrgpt.runners.data_utils import (
+    extract_cohort_sequences,
+    get_torch_dtype,
+    prepare_finetune_dataset,
+)
 from cehrgpt.runners.gpt_runner_util import parse_runner_args
+from cehrgpt.runners.hf_cehrgpt_pretrain_runner import tokenizer_exists
 from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
 from cehrgpt.runners.hyperparameter_search_util import perform_hyperparameter_search
+from cehrgpt.runners.sample_packing_trainer import SamplePackingTrainer

 LOG = logging.get_logger("transformers")

@@ -140,11 +146,10 @@ def load_finetuned_model(
         raise ValueError(
             f"finetune_model_type can be one of the following types {FineTuneModelType.POOLING.value}"
         )
-
     attn_implementation = (
         "flash_attention_2" if is_flash_attn_2_available() else "eager"
     )
-    torch_dtype =
+    torch_dtype = get_torch_dtype(model_args.torch_dtype)
     # Try to create a new model based on the base model
     try:
         return finetune_model_cls.from_pretrained(
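The incomplete `torch_dtype =` assignment is replaced by a call to `get_torch_dtype(model_args.torch_dtype)` from the new `cehrgpt.runners.data_utils` module. As a rough sketch of what such a helper typically does (a hypothetical implementation, not necessarily the one shipped in 0.1.1), a string such as `"bfloat16"` is resolved to the corresponding `torch.dtype` with a safe fallback:

```python
import torch


def get_torch_dtype(torch_dtype: str = None) -> torch.dtype:
    # Hypothetical sketch: resolve a dtype name such as "bfloat16" or "float16"
    # to the torch dtype object, falling back to float32 when unset or unknown.
    if torch_dtype and hasattr(torch, torch_dtype):
        resolved = getattr(torch, torch_dtype)
        if isinstance(resolved, torch.dtype):
            return resolved
    return torch.float32
```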
@@ -156,148 +161,25 @@ def load_finetuned_model(
         raise ValueError(f"Can not load the finetuned model from {model_name_or_path}")


-def create_dataset_splits(data_args: DataTrainingArguments, seed: int):
-    """
-    Creates training, validation, and testing dataset splits based on specified splitting strategies.
-
-    This function splits a dataset into training, validation, and test sets, using either chronological,
-    patient-based, or random splitting strategies, depending on the parameters provided in `data_args`.
-
-    - **Chronological split**: Sorts by a specified date and splits based on historical and future data.
-    - **Patient-based split**: Splits by unique patient IDs to ensure that patients in each split are distinct.
-    - **Random split**: Performs a straightforward random split of the dataset.
-
-    If `data_args.test_data_folder` is provided, a test set is loaded directly from it. Otherwise,
-    the test set is created by further splitting the validation set based on `test_eval_ratio`.
-
-    Parameters:
-        data_args (DataTrainingArguments): A configuration object containing data-related arguments, including:
-            - `data_folder` (str): Path to the main dataset.
-            - `test_data_folder` (str, optional): Path to an optional test dataset.
-            - `chronological_split` (bool): Whether to split chronologically.
-            - `split_by_patient` (bool): Whether to split by unique patient IDs.
-            - `validation_split_percentage` (float): Percentage of data to use for validation.
-            - `test_eval_ratio` (float): Ratio of test to validation data when creating a test set from validation.
-            - `preprocessing_num_workers` (int): Number of processes for parallel data filtering.
-            - `preprocessing_batch_size` (int): Batch size for batched operations.
-        seed (int): Random seed for reproducibility of splits.
-
-    Returns:
-        Tuple[Dataset, Dataset, Dataset]: A tuple containing:
-            - `train_set` (Dataset): Training split of the dataset.
-            - `validation_set` (Dataset): Validation split of the dataset.
-            - `test_set` (Dataset): Test split of the dataset.
-
-    Raises:
-        FileNotFoundError: If `data_args.data_folder` or `data_args.test_data_folder` does not exist.
-        ValueError: If incompatible arguments are passed for splitting strategies.
-
-    Example Usage:
-        data_args = DataTrainingArguments(
-            data_folder="data/",
-            validation_split_percentage=0.1,
-            test_eval_ratio=0.2,
-            chronological_split=True
-        )
-        train_set, validation_set, test_set = create_dataset_splits(data_args, seed=42)
-    """
-    dataset = load_parquet_as_dataset(data_args.data_folder)
-    test_set = (
-        None
-        if not data_args.test_data_folder
-        else load_parquet_as_dataset(data_args.test_data_folder)
-    )
-
-    if data_args.chronological_split:
-        # Chronological split by sorting on `index_date`
-        dataset = dataset.sort("index_date")
-        total_size = len(dataset)
-        train_end = int((1 - data_args.validation_split_percentage) * total_size)
-
-        # Perform the split
-        train_set = dataset.select(range(0, train_end))
-        validation_set = dataset.select(range(train_end, total_size))
-
-        if test_set is None:
-            test_valid_split = validation_set.train_test_split(
-                test_size=data_args.test_eval_ratio, seed=seed
-            )
-            validation_set, test_set = (
-                test_valid_split["train"],
-                test_valid_split["test"],
-            )
-
-    elif data_args.split_by_patient:
-        # Patient-based split
-        LOG.info("Using the split_by_patient strategy")
-        unique_patient_ids = dataset.unique("person_id")
-        LOG.info(f"There are {len(unique_patient_ids)} patients in total")
-
-        np.random.seed(seed)
-        np.random.shuffle(unique_patient_ids)
-
-        train_end = int(
-            len(unique_patient_ids) * (1 - data_args.validation_split_percentage)
-        )
-        train_patient_ids = set(unique_patient_ids[:train_end])
-
-        if test_set is None:
-            validation_end = int(
-                train_end
-                + len(unique_patient_ids)
-                * data_args.validation_split_percentage
-                * data_args.test_eval_ratio
-            )
-            val_patient_ids = set(unique_patient_ids[train_end:validation_end])
-            test_patient_ids = set(unique_patient_ids[validation_end:])
-        else:
-            val_patient_ids, test_patient_ids = (
-                set(unique_patient_ids[train_end:]),
-                None,
-            )
-
-        # Helper function to apply patient-based filtering
-        def filter_by_patient_ids(patient_ids):
-            return dataset.filter(
-                lambda batch: [pid in patient_ids for pid in batch["person_id"]],
-                num_proc=data_args.preprocessing_num_workers,
-                batched=True,
-                batch_size=data_args.preprocessing_batch_size,
-            )
-
-        # Generate splits
-        train_set = filter_by_patient_ids(train_patient_ids)
-        validation_set = filter_by_patient_ids(val_patient_ids)
-        if test_set is None:
-            test_set = filter_by_patient_ids(test_patient_ids)
-
-    else:
-        # Random split
-        train_val = dataset.train_test_split(
-            test_size=data_args.validation_split_percentage, seed=seed
-        )
-        train_set, validation_set = train_val["train"], train_val["test"]
-
-        if test_set is None:
-            test_valid_split = validation_set.train_test_split(
-                test_size=data_args.test_eval_ratio, seed=seed
-            )
-            validation_set, test_set = (
-                test_valid_split["train"],
-                test_valid_split["test"],
-            )
-
-    return train_set, validation_set, test_set
-
-
 def model_init(
     model_args: ModelArguments,
     training_args: TrainingArguments,
+    cehrgpt_args: CehrGPTArguments,
     tokenizer: CehrGptTokenizer,
 ):
     model = load_finetuned_model(
         model_args, training_args, model_args.model_name_or_path
     )
+
+    if cehrgpt_args.class_weights:
+        model.config.class_weights = cehrgpt_args.class_weights
+        LOG.info(f"Setting class_weights to {model.config.class_weights}")
+
+    # Enable position embeddings when position embeddings are disabled in pre-training
+    if not model_args.exclude_position_ids and model.cehrgpt.exclude_position_ids:
+        LOG.info(f"Enable the position_embeddings")
+        model.cehrgpt.enable_position_embeddings()
+
     if model.config.max_position_embeddings < model_args.max_position_embeddings:
         LOG.info(
             f"Increase model.config.max_position_embeddings to {model_args.max_position_embeddings}"
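The removed `create_dataset_splits` helper (its docstring is kept above for reference) supported chronological, patient-based, and random splits; that responsibility now lives in `prepare_finetune_dataset` from `cehrgpt.runners.data_utils`. For readers unfamiliar with the patient-based strategy, a minimal self-contained sketch of the ID partitioning it performed (same arithmetic as the removed code, plain NumPy, no `datasets` dependency; `split_person_ids` is an illustrative name, not a function in the package):

```python
import numpy as np


def split_person_ids(person_ids, validation_split_percentage, test_eval_ratio, seed=42):
    """Partition unique patient IDs into disjoint train/validation/test sets."""
    unique_ids = np.unique(np.asarray(list(person_ids)))
    rng = np.random.default_rng(seed)
    rng.shuffle(unique_ids)

    train_end = int(len(unique_ids) * (1 - validation_split_percentage))
    # Mirrors the removed code: the held-out IDs are divided between validation
    # and test according to test_eval_ratio.
    validation_end = train_end + int(
        len(unique_ids) * validation_split_percentage * test_eval_ratio
    )
    return (
        set(unique_ids[:train_end]),
        set(unique_ids[train_end:validation_end]),
        set(unique_ids[validation_end:]),
    )


train_ids, val_ids, test_ids = split_person_ids(range(1000), 0.2, 0.5)
assert train_ids.isdisjoint(val_ids) and val_ids.isdisjoint(test_ids)
```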
@@ -307,9 +189,6 @@ def model_init(
     # Enable include_values when include_values is set to be False during pre-training
     if model_args.include_values and not model.cehrgpt.include_values:
         model.cehrgpt.include_values = True
-    # Enable position embeddings when position embeddings are disabled in pre-training
-    if not model_args.exclude_position_ids and model.cehrgpt.exclude_position_ids:
-        model.cehrgpt.exclude_position_ids = False
     # Expand tokenizer to adapt to the finetuning dataset
     if model.config.vocab_size < tokenizer.vocab_size:
         model.resize_token_embeddings(tokenizer.vocab_size)
@@ -327,6 +206,7 @@ def model_init(
         model.cehrgpt.update_pretrained_embeddings(
             tokenizer.pretrained_token_ids, tokenizer.pretrained_embeddings
         )
+
     # Expand value tokenizer to adapt to the fine-tuning dataset
     if model.config.include_values:
         if model.config.value_vocab_size < tokenizer.value_vocab_size:
@@ -364,16 +244,16 @@ def main():
     prepared_ds_path = generate_prepared_ds_path(
         data_args, model_args, data_folder=data_args.cohort_folder
     )
-
+    cache_file_collector = CacheFileCollector()
     processed_dataset = None
     if any(prepared_ds_path.glob("*")):
         LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
         processed_dataset = load_from_disk(str(prepared_ds_path))
         LOG.info("Prepared dataset loaded from disk...")
         if cehrgpt_args.expand_tokenizer:
-
+            if tokenizer_exists(training_args.output_dir):
                 tokenizer = CehrGptTokenizer.from_pretrained(training_args.output_dir)
-
+            else:
                 LOG.warning(
                     f"CehrGptTokenizer must exist in {training_args.output_dir} "
                     f"when the dataset has been processed and expand_tokenizer is set to True. "
@@ -383,101 +263,86 @@ def main():
                 shutil.rmtree(prepared_ds_path)

     if processed_dataset is None:
-
-
-
-
-
-
-        try:
-            LOG.info(
-                f"Trying to load the MEDS extension from disk at {meds_extension_path}..."
+        if is_main_process(training_args.local_rank):
+            # If the full dataset has been tokenized, we don't want to tokenize the cohort containing
+            # the subset of the data. We should slice out the portion of the tokenized sequences for each sample
+            if cehrgpt_args.tokenized_full_dataset_path is not None:
+                processed_dataset = extract_cohort_sequences(
+                    data_args, cehrgpt_args, cache_file_collector
                 )
-
-
-
-
-
-
-
-
-            }
+            else:
+                final_splits = prepare_finetune_dataset(
+                    data_args, training_args, cehrgpt_args, cache_file_collector
+                )
+                if cehrgpt_args.expand_tokenizer:
+                    new_tokenizer_path = os.path.expanduser(training_args.output_dir)
+                    if tokenizer_exists(new_tokenizer_path):
+                        tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
                     else:
-
-
+                        # Try to use the defined pretrained embeddings if exists, Otherwise we default to the pretrained model
+                        # embedded in the pretrained model
+                        pretrained_concept_embedding_model = PretrainedEmbeddings(
+                            cehrgpt_args.pretrained_embedding_path
                         )
-
-
-
-
-
-
+                        if not pretrained_concept_embedding_model.exists:
+                            pretrained_concept_embedding_model = (
+                                tokenizer.pretrained_concept_embedding_model
+                            )
+                        tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
+                            cehrgpt_tokenizer=tokenizer,
+                            dataset=final_splits["train"],
                             data_args=data_args,
-
-
+                            concept_name_mapping={},
+                            pretrained_concept_embedding_model=pretrained_concept_embedding_model,
                         )
-
-
+                        tokenizer.save_pretrained(
+                            os.path.expanduser(training_args.output_dir)
+                        )
+
+                # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
                 if not data_args.streaming:
-
-
-
-
-
-                    )
-                    dataset = load_from_disk(str(meds_extension_path))
-
-                    train_set = dataset["train"]
-                    validation_set = dataset["validation"]
-                    test_set = dataset["test"]
-                else:
-                    train_set, validation_set, test_set = create_dataset_splits(
-                        data_args=data_args, seed=training_args.seed
-                    )
-                # Organize them into a single DatasetDict
-                final_splits = DatasetDict(
-                    {"train": train_set, "validation": validation_set, "test": test_set}
-                )
+                    all_columns = final_splits["train"].column_names
+                    if "visit_concept_ids" in all_columns:
+                        final_splits = final_splits.remove_columns(
+                            ["visit_concept_ids"]
+                        )

-
-
-                try:
-                    tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
-                except Exception:
-                    # Try to use the defined pretrained embeddings if exists,
-                    # Otherwise we default to the pretrained model embedded in the pretrained model
-                    pretrained_concept_embedding_model = PretrainedEmbeddings(
-                        cehrgpt_args.pretrained_embedding_path
-                    )
-                    if not pretrained_concept_embedding_model.exists:
-                        pretrained_concept_embedding_model = (
-                            tokenizer.pretrained_concept_embedding_model
-                        )
-                    tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
+                processed_dataset = create_cehrgpt_finetuning_dataset(
+                    dataset=final_splits,
                     cehrgpt_tokenizer=tokenizer,
-                    dataset=final_splits["train"],
                     data_args=data_args,
-
-                    pretrained_concept_embedding_model=pretrained_concept_embedding_model,
+                    cache_file_collector=cache_file_collector,
                 )
-
+                if not data_args.streaming:
+                    processed_dataset.save_to_disk(str(prepared_ds_path))
+                    stats = processed_dataset.cleanup_cache_files()
+                    LOG.info(
+                        "Clean up the cached files for the cehrgpt finetuning dataset : %s",
+                        stats,
+                    )
+
+            # Remove any cached files if there are any
+            cache_file_collector.remove_cache_files()

-
-
+        # After main-process-only operations, synchronize all processes to ensure consistency
+        if dist.is_available() and dist.is_initialized():
+            dist.barrier()
+
+        # Loading tokenizer in all processes in torch distributed training
+        tokenizer_name_or_path = os.path.expanduser(
+            training_args.output_dir
+            if cehrgpt_args.expand_tokenizer
+            else model_args.tokenizer_name_or_path
         )
-
-
-
-        LOG.info(
-            "Clean up the cached files for the cehrgpt finetuning dataset : %s",
-            stats,
-        )
-        processed_dataset = load_from_disk(str(prepared_ds_path))
+        tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
+        # Load the dataset from disk again to in torch distributed training
+        processed_dataset = load_from_disk(str(prepared_ds_path))

     # Set seed before initializing model.
     set_seed(training_args.seed)

-
+    if not data_args.streaming and not cehrgpt_args.sample_packing:
+        processed_dataset.set_format("pt")

     if cehrgpt_args.few_shot_predict:
         # At least we need two examples to have a validation set for early stopping
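The rewritten block above gates dataset preparation behind `is_main_process` and then calls `dist.barrier()` so that the other ranks only read the prepared dataset after rank 0 has finished writing it. Stripped of the cehrgpt specifics, the pattern looks roughly like this (`prepare_fn` and `load_fn` are hypothetical stand-ins):

```python
import torch.distributed as dist


def prepare_then_load(prepared_ds_path, is_main, prepare_fn, load_fn):
    # Only the main process runs the expensive tokenization/caching step.
    if is_main:
        prepare_fn(prepared_ds_path)
    # Every rank waits here so nobody reads a half-written dataset directory.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    # All ranks (including rank 0) then load the same on-disk artifact.
    return load_fn(prepared_ds_path)
```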
@@ -497,40 +362,76 @@ def main():
     config = CEHRGPTConfig.from_pretrained(model_args.model_name_or_path)
     if config.max_position_embeddings < model_args.max_position_embeddings:
         config.max_position_embeddings = model_args.max_position_embeddings
+
+    # persist this parameter in case this is overwritten by sample packing
+    per_device_eval_batch_size = training_args.per_device_eval_batch_size
+
+    if cehrgpt_args.sample_packing:
+        trainer_class = partial(
+            SamplePackingTrainer,
+            max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
+            max_position_embeddings=config.max_position_embeddings,
+            negative_sampling_probability=cehrgpt_args.negative_sampling_probability,
+        )
+        training_args.per_device_train_batch_size = 1
+        training_args.per_device_eval_batch_size = 1
+        data_collator_fn = partial(
+            SamplePackingCehrGptDataCollator,
+            cehrgpt_args.max_tokens_per_batch,
+            config.max_position_embeddings,
+            add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
+        )
+    else:
+        trainer_class = Trainer
+        data_collator_fn = CehrGptDataCollator
+
     # We suppress the additional learning objectives in fine-tuning
-    data_collator =
+    data_collator = data_collator_fn(
         tokenizer=tokenizer,
         max_length=(
-
-            if
-            else
+            cehrgpt_args.max_tokens_per_batch
+            if cehrgpt_args.sample_packing
+            else (
+                config.max_position_embeddings - 1
+                if config.causal_sfm
+                else config.max_position_embeddings
+            )
         ),
         include_values=model_args.include_values,
         pretraining=False,
         include_ttv_prediction=False,
         use_sub_time_tokenization=False,
         include_demographics=cehrgpt_args.include_demographics,
+        add_linear_prob_token=True,
     )

     if training_args.do_train:
         if cehrgpt_args.hyperparameter_tuning:
-            model_args.early_stopping_patience = LARGE_INTEGER
             training_args = perform_hyperparameter_search(
-
+                trainer_class,
+                partial(model_init, model_args, training_args, cehrgpt_args, tokenizer),
                 processed_dataset,
                 data_collator,
                 training_args,
                 model_args,
                 cehrgpt_args,
             )
+
+            if cehrgpt_args.retrain_with_full:
                 # Always retrain with the full set when hyperparameter tuning is set to true
                 retrain_with_full_set(
-
+                    trainer_class,
+                    model_args,
+                    training_args,
+                    cehrgpt_args,
+                    tokenizer,
+                    processed_dataset,
+                    data_collator,
                 )
         else:
             # Initialize Trainer for final training on the combined train+val set
-            trainer =
-                model=model_init(model_args, training_args, tokenizer),
+            trainer = trainer_class(
+                model=model_init(model_args, training_args, cehrgpt_args, tokenizer),
                 data_collator=data_collator,
                 args=training_args,
                 train_dataset=processed_dataset["train"],
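Both branches above pre-bind constructor arguments with `functools.partial`, so the rest of `main()` can call `trainer_class(...)` and `data_collator_fn(...)` identically whether or not sample packing is enabled. A toy illustration of that pattern with stand-in classes (not the real `Trainer` or collator signatures):

```python
from functools import partial


class PlainCollator:
    def __init__(self, tokenizer, max_length):
        self.tokenizer, self.max_length = tokenizer, max_length


class PackingCollator(PlainCollator):
    def __init__(self, max_tokens_per_batch, max_position_embeddings, tokenizer, max_length):
        super().__init__(tokenizer, max_length)
        self.max_tokens_per_batch = max_tokens_per_batch
        self.max_position_embeddings = max_position_embeddings


use_sample_packing = True
collator_fn = partial(PackingCollator, 16384, 1024) if use_sample_packing else PlainCollator
# Downstream code stays identical for both branches:
collator = collator_fn(tokenizer="tok", max_length=1024)
```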
@@ -552,47 +453,34 @@ def main():
             trainer.save_metrics("train", metrics)
             trainer.save_state()

-            # Retrain the model with full set using the num of epoches before earlying stopping
-            if cehrgpt_args.retrain_with_full:
-                update_num_epoch_before_early_stopping_callback = None
-                for callback in trainer.callback_handler.callbacks:
-                    if isinstance(callback, UpdateNumEpochsBeforeEarlyStoppingCallback):
-                        update_num_epoch_before_early_stopping_callback = callback
-
-                if update_num_epoch_before_early_stopping_callback is None:
-                    raise RuntimeError(
-                        f"{UpdateNumEpochsBeforeEarlyStoppingCallback} must be included as a callback!"
-                    )
-                final_num_epochs = (
-                    update_num_epoch_before_early_stopping_callback.num_epochs_before_early_stopping
-                )
-                training_args.num_train_epochs = final_num_epochs
-                LOG.info(
-                    "Num Epochs before early stopping: %s",
-                    training_args.num_train_epochs,
-                )
-                retrain_with_full_set(
-                    model_args,
-                    training_args,
-                    tokenizer,
-                    processed_dataset,
-                    data_collator,
-                )
-
     if training_args.do_predict:
+        if cehrgpt_args.sample_packing:
+            batch_sampler = SamplePackingBatchSampler(
+                lengths=processed_dataset["test"]["num_of_concepts"],
+                max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
+                max_position_embeddings=config.max_position_embeddings,
+                drop_last=training_args.dataloader_drop_last,
+                seed=training_args.seed,
+            )
+            per_device_eval_batch_size = 1
+        else:
+            batch_sampler = None
         test_dataloader = DataLoader(
             dataset=processed_dataset["test"],
-            batch_size=
+            batch_size=per_device_eval_batch_size,
             num_workers=training_args.dataloader_num_workers,
             collate_fn=data_collator,
             pin_memory=training_args.dataloader_pin_memory,
+            batch_sampler=batch_sampler,
         )
         do_predict(test_dataloader, model_args, training_args, cehrgpt_args)


 def retrain_with_full_set(
+    trainer_class,
     model_args: ModelArguments,
     training_args: TrainingArguments,
+    cehrgpt_args: CehrGPTArguments,
     tokenizer: CehrGptTokenizer,
     dataset: DatasetDict,
     data_collator: CehrGptDataCollator,
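The new `SamplePackingBatchSampler` builds prediction batches from the `num_of_concepts` lengths so that each batch stays under a token budget (which is why `batch_size` drops to 1 and the sampler takes over). A simplified greedy sketch of the idea, not the actual class in `cehrgpt/data/sample_packing_sampler.py`, whose behavior may differ:

```python
def greedy_pack_batches(lengths, max_tokens_per_batch, max_position_embeddings):
    """Yield lists of example indices whose combined (clipped) lengths fit the token budget."""
    batch, batch_tokens = [], 0
    for idx, length in enumerate(lengths):
        length = min(length, max_position_embeddings)  # a single sample never exceeds the context size
        if batch and batch_tokens + length > max_tokens_per_batch:
            yield batch
            batch, batch_tokens = [], 0
        batch.append(idx)
        batch_tokens += length
    if batch:
        yield batch


print(list(greedy_pack_batches([512, 300, 900, 120], 1024, 1024)))
# [[0, 1], [2, 3]]
```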
@@ -607,9 +495,11 @@ def retrain_with_full_set(
     and state information.

     Args:
+        trainer_class: Trainer or its subclass
         model_args (ModelArguments): Model configuration and hyperparameters.
         training_args (TrainingArguments): Training configuration, including output directory,
             evaluation strategy, and other training parameters.
+        cehrgpt_args (CehrGPTArguments): CehrGPT specific parameters.
         tokenizer (CehrGptTokenizer): Tokenizer instance specific to CEHR-GPT.
         dataset (DatasetDict): A dictionary containing the 'train' and 'validation' datasets.
         data_collator (CehrGptDataCollator): Data collator for handling data batching and tokenization.
@@ -628,8 +518,8 @@ def retrain_with_full_set(
     # Disable evaluation
     training_args.evaluation_strategy = "no"
     checkpoint = get_last_hf_checkpoint(training_args)
-    final_trainer =
-        model=model_init(model_args, training_args, tokenizer),
+    final_trainer = trainer_class(
+        model=model_init(model_args, training_args, cehrgpt_args, tokenizer),
         data_collator=data_collator,
         args=training_args,
         train_dataset=full_dataset,
@@ -683,15 +573,15 @@ def do_predict(
     test_losses = []
     with torch.no_grad():
         for index, batch in enumerate(tqdm(test_dataloader, desc="Predicting")):
-            person_ids = batch.pop("person_id").numpy().
-
-
-
-
-
-
-
-
+            person_ids = batch.pop("person_id").numpy().astype(int).squeeze()
+            if person_ids.ndim == 0:
+                person_ids = np.asarray([person_ids])
+
+            index_dates = batch.pop("index_date").numpy().squeeze()
+            if index_dates.ndim == 0:
+                index_dates = np.asarray([index_dates])
+            index_dates = list(map(datetime.fromtimestamp, index_dates.tolist()))
+
             batch = {k: v.to(device) for k, v in batch.items()}
             # Forward pass
             output = model(**batch, output_attentions=False, output_hidden_states=False)
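The `ndim == 0` guards added above cover the edge case where a batch contains a single example: `.squeeze()` then collapses the array to a 0-d scalar, which would break the per-example handling and the DataFrame construction below. A short NumPy illustration:

```python
import numpy as np

person_ids = np.array([[42]]).astype(int).squeeze()  # batch of size one -> 0-d array
assert person_ids.ndim == 0
person_ids = np.asarray([person_ids])                # restore a length-1, 1-d array
assert person_ids.shape == (1,)
```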
@@ -699,17 +589,25 @@ def do_predict(

             # Collect logits and labels for prediction
             logits = output.logits.float().cpu().numpy().squeeze()
+            if logits.ndim == 0:
+                logits = np.asarray([logits])
+            probabilities = sigmoid(logits)
+
             labels = (
-                batch["classifier_label"].float().cpu().numpy().
+                batch["classifier_label"].float().cpu().numpy().astype(bool).squeeze()
             )
-
+            if labels.ndim == 0:
+                labels = np.asarray([labels])
+
             # Save predictions to parquet file
             test_prediction_pd = pd.DataFrame(
                 {
                     "subject_id": person_ids,
                     "prediction_time": index_dates,
-                    "
-                    "
+                    "predicted_boolean_probability": probabilities,
+                    "predicted_boolean_value": pd.Series(
+                        [None] * len(person_ids), dtype=bool
+                    ),
                     "boolean_value": labels,
                 }
             )
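The classifier logits are now converted to probabilities with a sigmoid before being written to the per-batch prediction frame. A minimal version of that conversion (assuming a plain NumPy `sigmoid`; the runner imports its own implementation):

```python
import numpy as np
import pandas as pd


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


logits = np.array([-2.0, 0.0, 1.5])
frame = pd.DataFrame(
    {
        "predicted_boolean_probability": sigmoid(logits),
        "boolean_value": np.array([False, True, True]),
    }
)
print(frame.round(3))
```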
@@ -723,7 +621,7 @@ def do_predict(
     # Compute metrics and save results
     metrics = compute_metrics(
         references=test_prediction_pd.boolean_value,
-        probs=test_prediction_pd.
+        probs=test_prediction_pd.predicted_boolean_probability,
     )
     metrics["test_loss"] = np.mean(test_losses)
