cehrgpt 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +99 -88
- cehrgpt/data/sample_packing_sampler.py +151 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/models/config.py +10 -0
- cehrgpt/models/hf_cehrgpt.py +243 -73
- cehrgpt/models/tokenization_hf_cehrgpt.py +4 -0
- cehrgpt/runners/data_utils.py +243 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +152 -279
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +229 -105
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +42 -0
- cehrgpt/runners/hyperparameter_search_util.py +4 -1
- cehrgpt/runners/sample_packing_trainer.py +168 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +7 -5
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +28 -26
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
The hunks below come from cehrgpt/runners/hf_cehrgpt_finetune_runner.py (+152 −279), the largest single change in this release. Lines shown as `…` are removals whose content was not captured in the source diff view.

```diff
@@ -9,20 +9,16 @@ from pathlib import Path
 import numpy as np
 import pandas as pd
 import torch
-…
-…
-)
+import torch.distributed as dist
+from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
 from cehrbert.runners.hf_cehrbert_finetune_runner import compute_metrics
 from cehrbert.runners.hf_runner_argument_dataclass import (
-    DataTrainingArguments,
     FineTuneModelType,
     ModelArguments,
 )
 from cehrbert.runners.runner_util import (
     generate_prepared_ds_path,
     get_last_hf_checkpoint,
-    get_meds_extension_path,
-    load_parquet_as_dataset,
 )
 from datasets import DatasetDict, concatenate_datasets, load_from_disk
 from peft import LoraConfig, PeftModel, get_peft_model
```
```diff
@@ -38,12 +34,15 @@ from transformers import (
     TrainingArguments,
     set_seed,
 )
-from transformers.
+from transformers.trainer_utils import is_main_process
 from transformers.utils import is_flash_attn_2_available, logging
 
 from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_finetuning_dataset
-from cehrgpt.data.hf_cehrgpt_dataset_collator import
-…
+from cehrgpt.data.hf_cehrgpt_dataset_collator import (
+    CehrGptDataCollator,
+    SamplePackingCehrGptDataCollator,
+)
+from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler
 from cehrgpt.models.hf_cehrgpt import (
     CEHRGPTConfig,
     CehrGptForClassification,
```
```diff
@@ -51,9 +50,12 @@ from cehrgpt.models.hf_cehrgpt import (
 )
 from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+from cehrgpt.runners.data_utils import prepare_finetune_dataset
 from cehrgpt.runners.gpt_runner_util import parse_runner_args
+from cehrgpt.runners.hf_cehrgpt_pretrain_runner import tokenizer_exists
 from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
 from cehrgpt.runners.hyperparameter_search_util import perform_hyperparameter_search
+from cehrgpt.runners.sample_packing_trainer import SamplePackingTrainer
 
 LOG = logging.get_logger("transformers")
 
```
```diff
@@ -156,140 +158,6 @@ def load_finetuned_model(
     raise ValueError(f"Can not load the finetuned model from {model_name_or_path}")
 
 
-def create_dataset_splits(data_args: DataTrainingArguments, seed: int):
-    """
-    Creates training, validation, and testing dataset splits based on specified splitting strategies.
-
-    This function splits a dataset into training, validation, and test sets, using either chronological,
-    patient-based, or random splitting strategies, depending on the parameters provided in `data_args`.
-
-    - **Chronological split**: Sorts by a specified date and splits based on historical and future data.
-    - **Patient-based split**: Splits by unique patient IDs to ensure that patients in each split are distinct.
-    - **Random split**: Performs a straightforward random split of the dataset.
-
-    If `data_args.test_data_folder` is provided, a test set is loaded directly from it. Otherwise,
-    the test set is created by further splitting the validation set based on `test_eval_ratio`.
-
-    Parameters:
-        data_args (DataTrainingArguments): A configuration object containing data-related arguments, including:
-            - `data_folder` (str): Path to the main dataset.
-            - `test_data_folder` (str, optional): Path to an optional test dataset.
-            - `chronological_split` (bool): Whether to split chronologically.
-            - `split_by_patient` (bool): Whether to split by unique patient IDs.
-            - `validation_split_percentage` (float): Percentage of data to use for validation.
-            - `test_eval_ratio` (float): Ratio of test to validation data when creating a test set from validation.
-            - `preprocessing_num_workers` (int): Number of processes for parallel data filtering.
-            - `preprocessing_batch_size` (int): Batch size for batched operations.
-        seed (int): Random seed for reproducibility of splits.
-
-    Returns:
-        Tuple[Dataset, Dataset, Dataset]: A tuple containing:
-            - `train_set` (Dataset): Training split of the dataset.
-            - `validation_set` (Dataset): Validation split of the dataset.
-            - `test_set` (Dataset): Test split of the dataset.
-
-    Raises:
-        FileNotFoundError: If `data_args.data_folder` or `data_args.test_data_folder` does not exist.
-        ValueError: If incompatible arguments are passed for splitting strategies.
-
-    Example Usage:
-        data_args = DataTrainingArguments(
-            data_folder="data/",
-            validation_split_percentage=0.1,
-            test_eval_ratio=0.2,
-            chronological_split=True
-        )
-        train_set, validation_set, test_set = create_dataset_splits(data_args, seed=42)
-    """
-    dataset = load_parquet_as_dataset(data_args.data_folder)
-    test_set = (
-        None
-        if not data_args.test_data_folder
-        else load_parquet_as_dataset(data_args.test_data_folder)
-    )
-
-    if data_args.chronological_split:
-        # Chronological split by sorting on `index_date`
-        dataset = dataset.sort("index_date")
-        total_size = len(dataset)
-        train_end = int((1 - data_args.validation_split_percentage) * total_size)
-
-        # Perform the split
-        train_set = dataset.select(range(0, train_end))
-        validation_set = dataset.select(range(train_end, total_size))
-
-        if test_set is None:
-            test_valid_split = validation_set.train_test_split(
-                test_size=data_args.test_eval_ratio, seed=seed
-            )
-            validation_set, test_set = (
-                test_valid_split["train"],
-                test_valid_split["test"],
-            )
-
-    elif data_args.split_by_patient:
-        # Patient-based split
-        LOG.info("Using the split_by_patient strategy")
-        unique_patient_ids = dataset.unique("person_id")
-        LOG.info(f"There are {len(unique_patient_ids)} patients in total")
-
-        np.random.seed(seed)
-        np.random.shuffle(unique_patient_ids)
-
-        train_end = int(
-            len(unique_patient_ids) * (1 - data_args.validation_split_percentage)
-        )
-        train_patient_ids = set(unique_patient_ids[:train_end])
-
-        if test_set is None:
-            validation_end = int(
-                train_end
-                + len(unique_patient_ids)
-                * data_args.validation_split_percentage
-                * data_args.test_eval_ratio
-            )
-            val_patient_ids = set(unique_patient_ids[train_end:validation_end])
-            test_patient_ids = set(unique_patient_ids[validation_end:])
-        else:
-            val_patient_ids, test_patient_ids = (
-                set(unique_patient_ids[train_end:]),
-                None,
-            )
-
-        # Helper function to apply patient-based filtering
-        def filter_by_patient_ids(patient_ids):
-            return dataset.filter(
-                lambda batch: [pid in patient_ids for pid in batch["person_id"]],
-                num_proc=data_args.preprocessing_num_workers,
-                batched=True,
-                batch_size=data_args.preprocessing_batch_size,
-            )
-
-        # Generate splits
-        train_set = filter_by_patient_ids(train_patient_ids)
-        validation_set = filter_by_patient_ids(val_patient_ids)
-        if test_set is None:
-            test_set = filter_by_patient_ids(test_patient_ids)
-
-    else:
-        # Random split
-        train_val = dataset.train_test_split(
-            test_size=data_args.validation_split_percentage, seed=seed
-        )
-        train_set, validation_set = train_val["train"], train_val["test"]
-
-        if test_set is None:
-            test_valid_split = validation_set.train_test_split(
-                test_size=data_args.test_eval_ratio, seed=seed
-            )
-            validation_set, test_set = (
-                test_valid_split["train"],
-                test_valid_split["test"],
-            )
-
-    return train_set, validation_set, test_set
-
-
 def model_init(
     model_args: ModelArguments,
     training_args: TrainingArguments,
```
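
The hunk above deletes `create_dataset_splits`; judging by the new `prepare_finetune_dataset` import and the new `cehrgpt/runners/data_utils.py` (+243 lines) in the file list, that responsibility appears to have moved there. The clinically important piece of the deleted code is the patient-based split, which keeps all of a patient's records on one side of the split to avoid leakage. A minimal standalone sketch of that idea (plain Python and NumPy, not the `datasets.Dataset` API used above):

```python
import numpy as np

def split_patient_ids(person_ids, validation_split_percentage, seed):
    """Shuffle unique patient ids and cut once, so no patient spans two splits."""
    unique_ids = np.unique(person_ids)
    rng = np.random.default_rng(seed)
    rng.shuffle(unique_ids)
    train_end = int(len(unique_ids) * (1 - validation_split_percentage))
    return set(unique_ids[:train_end]), set(unique_ids[train_end:])

# Rows for the same patient (e.g. id 3) can never land in both splits.
train_ids, val_ids = split_patient_ids([1, 1, 2, 3, 3, 4], 0.25, seed=42)
assert train_ids.isdisjoint(val_ids)
```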
```diff
@@ -364,16 +232,16 @@ def main():
     prepared_ds_path = generate_prepared_ds_path(
         data_args, model_args, data_folder=data_args.cohort_folder
     )
-
+    cache_file_collector = CacheFileCollector()
     processed_dataset = None
     if any(prepared_ds_path.glob("*")):
         LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
         processed_dataset = load_from_disk(str(prepared_ds_path))
         LOG.info("Prepared dataset loaded from disk...")
         if cehrgpt_args.expand_tokenizer:
-            try:
+            if tokenizer_exists(training_args.output_dir):
                 tokenizer = CehrGptTokenizer.from_pretrained(training_args.output_dir)
-            except Exception:
+            else:
                 LOG.warning(
                     f"CehrGptTokenizer must exist in {training_args.output_dir} "
                     f"when the dataset has been processed and expand_tokenizer is set to True. "
```
```diff
@@ -383,101 +251,77 @@ def main():
             shutil.rmtree(prepared_ds_path)
 
     if processed_dataset is None:
-…
-                data_folder=data_args.cohort_folder,
-                dataset_prepared_path=data_args.dataset_prepared_path,
+        if is_main_process(training_args.local_rank):
+            final_splits = prepare_finetune_dataset(
+                data_args, training_args, cehrgpt_args, cache_file_collector
             )
-…
-            else:
-                dataset = dataset.to_iterable_dataset(
-                    num_shards=training_args.dataloader_num_workers
-                )
-            except Exception as e:
-                LOG.exception(e)
-                dataset = create_dataset_from_meds_reader(
-                    data_args=data_args,
-                    dataset_mappings=[
-                        MedToCehrGPTDatasetMapping(
-                            data_args=data_args,
-                            is_pretraining=False,
-                            include_inpatient_hour_token=cehrgpt_args.include_inpatient_hour_token,
+            if cehrgpt_args.expand_tokenizer:
+                new_tokenizer_path = os.path.expanduser(training_args.output_dir)
+                if tokenizer_exists(new_tokenizer_path):
+                    tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
+                else:
+                    # Try to use the defined pretrained embeddings if exists, Otherwise we default to the pretrained model
+                    # embedded in the pretrained model
+                    pretrained_concept_embedding_model = PretrainedEmbeddings(
+                        cehrgpt_args.pretrained_embedding_path
+                    )
+                    if not pretrained_concept_embedding_model.exists:
+                        pretrained_concept_embedding_model = (
+                            tokenizer.pretrained_concept_embedding_model
                         )
-…
+                    tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
+                        cehrgpt_tokenizer=tokenizer,
+                        dataset=final_splits["train"],
+                        data_args=data_args,
+                        concept_name_mapping={},
+                        pretrained_concept_embedding_model=pretrained_concept_embedding_model,
+                    )
+                    tokenizer.save_pretrained(
+                        os.path.expanduser(training_args.output_dir)
                     )
-                dataset = load_from_disk(str(meds_extension_path))
 
-…
+            # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
+            if not data_args.streaming:
+                all_columns = final_splits["train"].column_names
+                if "visit_concept_ids" in all_columns:
+                    final_splits = final_splits.remove_columns(["visit_concept_ids"])
+
+            processed_dataset = create_cehrgpt_finetuning_dataset(
+                dataset=final_splits,
+                cehrgpt_tokenizer=tokenizer,
+                data_args=data_args,
+                cache_file_collector=cache_file_collector,
             )
-…
-            new_tokenizer_path = os.path.expanduser(training_args.output_dir)
-            try:
-                tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
-            except Exception:
-                # Try to use the defined pretrained embeddings if exists,
-                # Otherwise we default to the pretrained model embedded in the pretrained model
-                pretrained_concept_embedding_model = PretrainedEmbeddings(
-                    cehrgpt_args.pretrained_embedding_path
-                )
-                if not pretrained_concept_embedding_model.exists:
-                    pretrained_concept_embedding_model = (
-                        tokenizer.pretrained_concept_embedding_model
-                    )
-                tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
-                    cehrgpt_tokenizer=tokenizer,
-                    dataset=final_splits["train"],
-                    data_args=data_args,
-                    concept_name_mapping={},
-                    pretrained_concept_embedding_model=pretrained_concept_embedding_model,
+            if not data_args.streaming:
+                processed_dataset.save_to_disk(str(prepared_ds_path))
+                stats = processed_dataset.cleanup_cache_files()
+                LOG.info(
+                    "Clean up the cached files for the cehrgpt finetuning dataset : %s",
+                    stats,
                 )
-            tokenizer.save_pretrained(os.path.expanduser(training_args.output_dir))
 
-…
+            # Remove any cached files if there are any
+            cache_file_collector.remove_cache_files()
+
+    # After main-process-only operations, synchronize all processes to ensure consistency
+    if dist.is_available() and dist.is_initialized():
+        dist.barrier()
+
+    # Loading tokenizer in all processes in torch distributed training
+    tokenizer_name_or_path = os.path.expanduser(
+        training_args.output_dir
+        if cehrgpt_args.expand_tokenizer
+        else model_args.tokenizer_name_or_path
     )
-…
-        LOG.info(
-            "Clean up the cached files for the cehrgpt finetuning dataset : %s",
-            stats,
-        )
-        processed_dataset = load_from_disk(str(prepared_ds_path))
+    tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
+    # Load the dataset from disk again to in torch distributed training
+    processed_dataset = load_from_disk(str(prepared_ds_path))
 
     # Set seed before initializing model.
     set_seed(training_args.seed)
 
-…
+    if not data_args.streaming and not cehrgpt_args.sample_packing:
+        processed_dataset.set_format("pt")
 
     if cehrgpt_args.few_shot_predict:
         # At least we need two examples to have a validation set for early stopping
```
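
This hunk is the heart of the distributed-training fix: dataset preparation now runs only on the main process, `dist.barrier()` holds the other ranks until the cache is written, and every rank then reloads the same on-disk artifacts. A minimal sketch of the pattern, with hypothetical `prepare_fn`/`load_fn` callables standing in for the tokenization and `load_from_disk` steps above:

```python
import torch.distributed as dist

def prepare_once_then_share(local_rank: int, prepare_fn, load_fn, cache_dir: str):
    # HF convention: a local_rank of -1 or 0 means "main process"
    if local_rank in (-1, 0):
        prepare_fn(cache_dir)  # expensive tokenization/caching, run exactly once
    # Every rank waits here until rank 0 has finished writing the cache
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    # All ranks (rank 0 included) read back the identical on-disk dataset
    return load_fn(cache_dir)
```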
```diff
@@ -497,13 +341,40 @@ def main():
     config = CEHRGPTConfig.from_pretrained(model_args.model_name_or_path)
     if config.max_position_embeddings < model_args.max_position_embeddings:
         config.max_position_embeddings = model_args.max_position_embeddings
+
+    # persist this parameter in case this is overwritten by sample packing
+    per_device_eval_batch_size = training_args.per_device_eval_batch_size
+
+    if cehrgpt_args.sample_packing:
+        trainer_class = partial(
+            SamplePackingTrainer,
+            max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
+            max_position_embeddings=config.max_position_embeddings,
+            train_lengths=processed_dataset["train"]["num_of_concepts"],
+            validation_lengths=processed_dataset["validation"]["num_of_concepts"],
+        )
+        training_args.per_device_train_batch_size = 1
+        training_args.per_device_eval_batch_size = 1
+        data_collator_fn = partial(
+            SamplePackingCehrGptDataCollator,
+            cehrgpt_args.max_tokens_per_batch,
+            config.max_position_embeddings,
+        )
+    else:
+        trainer_class = Trainer
+        data_collator_fn = CehrGptDataCollator
+
     # We suppress the additional learning objectives in fine-tuning
-    data_collator = CehrGptDataCollator(
+    data_collator = data_collator_fn(
         tokenizer=tokenizer,
         max_length=(
-            config.max_position_embeddings - 1
-            if config.causal_sfm
-            else config.max_position_embeddings
+            cehrgpt_args.max_tokens_per_batch
+            if cehrgpt_args.sample_packing
+            else (
+                config.max_position_embeddings - 1
+                if config.causal_sfm
+                else config.max_position_embeddings
+            )
         ),
         include_values=model_args.include_values,
         pretraining=False,
```
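
The new sample-packing branch forces the per-device batch sizes to 1 because, under packing, one "example" handed to the model is a whole batch of concatenated sequences capped at `max_tokens_per_batch` tokens. A simplified sketch of greedy length-based packing; the shipped logic lives in the new `cehrgpt/data/sample_packing_sampler.py` and may differ in its details:

```python
from typing import List

def pack_greedily(lengths: List[int], max_tokens_per_batch: int) -> List[List[int]]:
    """Group example indices so each group's total token count fits the budget."""
    batches: List[List[int]] = []
    current: List[int] = []
    current_tokens = 0
    for idx, length in enumerate(lengths):
        if current and current_tokens + length > max_tokens_per_batch:
            batches.append(current)  # budget exceeded: start a new packed batch
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += length
    if current:
        batches.append(current)
    return batches

# pack_greedily([512, 300, 200, 900], 1024) -> [[0, 1, 2], [3]]
```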
```diff
@@ -514,8 +385,8 @@ def main():
 
     if training_args.do_train:
         if cehrgpt_args.hyperparameter_tuning:
-            model_args.early_stopping_patience = LARGE_INTEGER
             training_args = perform_hyperparameter_search(
+                trainer_class,
                 partial(model_init, model_args, training_args, tokenizer),
                 processed_dataset,
                 data_collator,
```
```diff
@@ -523,13 +394,20 @@ def main():
                 model_args,
                 cehrgpt_args,
             )
+
+            if cehrgpt_args.retrain_with_full:
                 # Always retrain with the full set when hyperparameter tuning is set to true
                 retrain_with_full_set(
-…
+                    trainer_class,
+                    model_args,
+                    training_args,
+                    tokenizer,
+                    processed_dataset,
+                    data_collator,
                 )
         else:
             # Initialize Trainer for final training on the combined train+val set
-            trainer = Trainer(
+            trainer = trainer_class(
                 model=model_init(model_args, training_args, tokenizer),
                 data_collator=data_collator,
                 args=training_args,
```
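
Note the factory pattern: both branches leave `trainer_class` callable with the ordinary `Trainer(...)` keyword arguments, with `functools.partial` pre-binding the packing-specific parameters, so every construction site (`trainer = trainer_class(...)`, `final_trainer = trainer_class(...)`, the hyperparameter search) stays identical. A toy illustration with stand-in classes, not the real transformers/cehrgpt implementations:

```python
from functools import partial

class Trainer:  # stand-in for transformers.Trainer
    def __init__(self, model=None, args=None, **kwargs):
        self.model, self.args = model, args

class SamplePackingTrainer(Trainer):  # stand-in subclass with extra knobs
    def __init__(self, *args, max_tokens_per_batch: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_tokens_per_batch = max_tokens_per_batch

# Pre-bind the packing arguments once...
trainer_class = partial(SamplePackingTrainer, max_tokens_per_batch=16384)
# ...so the call site looks exactly like a plain Trainer construction:
trainer = trainer_class(model="my_model", args="my_args")
```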
```diff
@@ -552,45 +430,31 @@ def main():
         trainer.save_metrics("train", metrics)
         trainer.save_state()
 
-        # Retrain the model with full set using the num of epoches before earlying stopping
-        if cehrgpt_args.retrain_with_full:
-            update_num_epoch_before_early_stopping_callback = None
-            for callback in trainer.callback_handler.callbacks:
-                if isinstance(callback, UpdateNumEpochsBeforeEarlyStoppingCallback):
-                    update_num_epoch_before_early_stopping_callback = callback
-
-            if update_num_epoch_before_early_stopping_callback is None:
-                raise RuntimeError(
-                    f"{UpdateNumEpochsBeforeEarlyStoppingCallback} must be included as a callback!"
-                )
-            final_num_epochs = (
-                update_num_epoch_before_early_stopping_callback.num_epochs_before_early_stopping
-            )
-            training_args.num_train_epochs = final_num_epochs
-            LOG.info(
-                "Num Epochs before early stopping: %s",
-                training_args.num_train_epochs,
-            )
-            retrain_with_full_set(
-                model_args,
-                training_args,
-                tokenizer,
-                processed_dataset,
-                data_collator,
-            )
-
     if training_args.do_predict:
+        if cehrgpt_args.sample_packing:
+            batch_sampler = SamplePackingBatchSampler(
+                lengths=processed_dataset["test"]["num_of_concepts"],
+                max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
+                max_position_embeddings=config.max_position_embeddings,
+                drop_last=training_args.dataloader_drop_last,
+                seed=training_args.seed,
+            )
+            per_device_eval_batch_size = 1
+        else:
+            batch_sampler = None
         test_dataloader = DataLoader(
             dataset=processed_dataset["test"],
-            batch_size=
+            batch_size=per_device_eval_batch_size,
             num_workers=training_args.dataloader_num_workers,
             collate_fn=data_collator,
             pin_memory=training_args.dataloader_pin_memory,
+            batch_sampler=batch_sampler,
         )
         do_predict(test_dataloader, model_args, training_args, cehrgpt_args)
 
 
 def retrain_with_full_set(
+    trainer_class,
     model_args: ModelArguments,
     training_args: TrainingArguments,
     tokenizer: CehrGptTokenizer,
```
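
Resetting `per_device_eval_batch_size` to 1 in the packing branch is required, not cosmetic: PyTorch's `DataLoader` treats `batch_sampler` as mutually exclusive with `batch_size`, and only the default value of 1 is accepted alongside it, since the sampler alone decides batch composition. A self-contained demonstration (the index groups stand in for what `SamplePackingBatchSampler` would produce):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10).float())
packed_batches = [[0, 1, 2], [3, 4], [5, 6, 7, 8, 9]]  # precomputed index groups

# batch_size must stay at its default of 1 when batch_sampler is supplied;
# any other value makes DataLoader raise a ValueError at construction time.
loader = DataLoader(dataset, batch_size=1, batch_sampler=packed_batches)
for (features,) in loader:
    print(features.shape)  # variable batch sizes: (3,), (2,), (5,)
```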
```diff
@@ -607,6 +471,7 @@ def retrain_with_full_set(
         and state information.
 
     Args:
+        trainer_class: Trainer or its subclass
         model_args (ModelArguments): Model configuration and hyperparameters.
         training_args (TrainingArguments): Training configuration, including output directory,
             evaluation strategy, and other training parameters.
```
```diff
@@ -628,7 +493,7 @@ def retrain_with_full_set(
     # Disable evaluation
     training_args.evaluation_strategy = "no"
     checkpoint = get_last_hf_checkpoint(training_args)
-    final_trainer = Trainer(
+    final_trainer = trainer_class(
         model=model_init(model_args, training_args, tokenizer),
         data_collator=data_collator,
         args=training_args,
```
```diff
@@ -683,15 +548,15 @@ def do_predict(
     test_losses = []
     with torch.no_grad():
         for index, batch in enumerate(tqdm(test_dataloader, desc="Predicting")):
-            person_ids = batch.pop("person_id").numpy().
-…
+            person_ids = batch.pop("person_id").numpy().astype(int).squeeze()
+            if person_ids.ndim == 0:
+                person_ids = np.asarray([person_ids])
+
+            index_dates = batch.pop("index_date").numpy().squeeze()
+            if index_dates.ndim == 0:
+                index_dates = np.asarray([index_dates])
+            index_dates = list(map(datetime.fromtimestamp, index_dates.tolist()))
+
             batch = {k: v.to(device) for k, v in batch.items()}
             # Forward pass
             output = model(**batch, output_attentions=False, output_hidden_states=False)
```
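
The new `ndim == 0` guards handle the batch-of-one case: `.squeeze()` collapses a length-1 axis all the way down to a 0-d scalar array, which later breaks `len()`, iteration, and `pd.DataFrame` construction. A quick illustration; `np.atleast_1d` is the equivalent idiom:

```python
import numpy as np

batch_of_one = np.array([[7]])       # e.g. person_id tensor for a batch of 1
squeezed = batch_of_one.squeeze()    # shape () -- 0-d, not a length-1 array
assert squeezed.ndim == 0

# The guard used in the diff, and its idiomatic one-line equivalent:
guarded = np.asarray([squeezed]) if squeezed.ndim == 0 else squeezed
assert np.array_equal(guarded, np.atleast_1d(squeezed))
assert len(guarded) == 1             # now safe for pd.DataFrame / iteration
```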
```diff
@@ -699,17 +564,25 @@ def do_predict(
 
             # Collect logits and labels for prediction
             logits = output.logits.float().cpu().numpy().squeeze()
+            if logits.ndim == 0:
+                logits = np.asarray([logits])
+            probabilities = sigmoid(logits)
+
             labels = (
-                batch["classifier_label"].float().cpu().numpy().
+                batch["classifier_label"].float().cpu().numpy().astype(bool).squeeze()
             )
-…
+            if labels.ndim == 0:
+                labels = np.asarray([labels])
+
             # Save predictions to parquet file
             test_prediction_pd = pd.DataFrame(
                 {
                     "subject_id": person_ids,
                     "prediction_time": index_dates,
-                    "
-                    "
+                    "predicted_boolean_probability": probabilities,
+                    "predicted_boolean_value": pd.Series(
+                        [None] * len(person_ids), dtype=bool
+                    ),
                     "boolean_value": labels,
                 }
             )
```
```diff
@@ -723,7 +596,7 @@ def do_predict(
     # Compute metrics and save results
     metrics = compute_metrics(
         references=test_prediction_pd.boolean_value,
-        probs=test_prediction_pd.
+        probs=test_prediction_pd.predicted_boolean_probability,
     )
     metrics["test_loss"] = np.mean(test_losses)
 
```