cehrgpt 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +279 -2
- cehrgpt/data/sample_packing_sampler.py +151 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/generation/omop_converter_batch.py +3 -0
- cehrgpt/models/config.py +10 -0
- cehrgpt/models/hf_cehrgpt.py +244 -73
- cehrgpt/models/tokenization_hf_cehrgpt.py +6 -2
- cehrgpt/runners/data_utils.py +243 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +154 -260
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +250 -90
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -0
- cehrgpt/runners/hyperparameter_search_util.py +4 -1
- cehrgpt/runners/sample_packing_trainer.py +168 -0
- cehrgpt/simulations/__init__.py +0 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/tools/generate_causal_patient_split_by_age.py +146 -0
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +57 -9
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +30 -18
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
The hunks below are the diff of `cehrgpt/runners/hf_cehrgpt_finetune_runner.py` (+154 −260). Removed lines whose content the upstream diff view elided or truncated are kept as bare `-` markers or cut-off text.

```diff
@@ -9,20 +9,16 @@ from pathlib import Path
 import numpy as np
 import pandas as pd
 import torch
-
-
-)
+import torch.distributed as dist
+from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
 from cehrbert.runners.hf_cehrbert_finetune_runner import compute_metrics
 from cehrbert.runners.hf_runner_argument_dataclass import (
-    DataTrainingArguments,
     FineTuneModelType,
     ModelArguments,
 )
 from cehrbert.runners.runner_util import (
     generate_prepared_ds_path,
     get_last_hf_checkpoint,
-    get_meds_extension_path,
-    load_parquet_as_dataset,
 )
 from datasets import DatasetDict, concatenate_datasets, load_from_disk
 from peft import LoraConfig, PeftModel, get_peft_model
```
```diff
@@ -38,11 +34,15 @@ from transformers import (
     TrainingArguments,
     set_seed,
 )
-from transformers.
+from transformers.trainer_utils import is_main_process
 from transformers.utils import is_flash_attn_2_available, logging
 
 from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_finetuning_dataset
-from cehrgpt.data.hf_cehrgpt_dataset_collator import
+from cehrgpt.data.hf_cehrgpt_dataset_collator import (
+    CehrGptDataCollator,
+    SamplePackingCehrGptDataCollator,
+)
+from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler
 from cehrgpt.models.hf_cehrgpt import (
     CEHRGPTConfig,
     CehrGptForClassification,
```
```diff
@@ -50,9 +50,12 @@ from cehrgpt.models.hf_cehrgpt import (
 )
 from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+from cehrgpt.runners.data_utils import prepare_finetune_dataset
 from cehrgpt.runners.gpt_runner_util import parse_runner_args
+from cehrgpt.runners.hf_cehrgpt_pretrain_runner import tokenizer_exists
 from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
 from cehrgpt.runners.hyperparameter_search_util import perform_hyperparameter_search
+from cehrgpt.runners.sample_packing_trainer import SamplePackingTrainer
 
 LOG = logging.get_logger("transformers")
 
```
```diff
@@ -155,140 +158,6 @@ def load_finetuned_model(
     raise ValueError(f"Can not load the finetuned model from {model_name_or_path}")
 
 
-def create_dataset_splits(data_args: DataTrainingArguments, seed: int):
-    """
-    Creates training, validation, and testing dataset splits based on specified splitting strategies.
-
-    This function splits a dataset into training, validation, and test sets, using either chronological,
-    patient-based, or random splitting strategies, depending on the parameters provided in `data_args`.
-
-    - **Chronological split**: Sorts by a specified date and splits based on historical and future data.
-    - **Patient-based split**: Splits by unique patient IDs to ensure that patients in each split are distinct.
-    - **Random split**: Performs a straightforward random split of the dataset.
-
-    If `data_args.test_data_folder` is provided, a test set is loaded directly from it. Otherwise,
-    the test set is created by further splitting the validation set based on `test_eval_ratio`.
-
-    Parameters:
-        data_args (DataTrainingArguments): A configuration object containing data-related arguments, including:
-            - `data_folder` (str): Path to the main dataset.
-            - `test_data_folder` (str, optional): Path to an optional test dataset.
-            - `chronological_split` (bool): Whether to split chronologically.
-            - `split_by_patient` (bool): Whether to split by unique patient IDs.
-            - `validation_split_percentage` (float): Percentage of data to use for validation.
-            - `test_eval_ratio` (float): Ratio of test to validation data when creating a test set from validation.
-            - `preprocessing_num_workers` (int): Number of processes for parallel data filtering.
-            - `preprocessing_batch_size` (int): Batch size for batched operations.
-        seed (int): Random seed for reproducibility of splits.
-
-    Returns:
-        Tuple[Dataset, Dataset, Dataset]: A tuple containing:
-            - `train_set` (Dataset): Training split of the dataset.
-            - `validation_set` (Dataset): Validation split of the dataset.
-            - `test_set` (Dataset): Test split of the dataset.
-
-    Raises:
-        FileNotFoundError: If `data_args.data_folder` or `data_args.test_data_folder` does not exist.
-        ValueError: If incompatible arguments are passed for splitting strategies.
-
-    Example Usage:
-        data_args = DataTrainingArguments(
-            data_folder="data/",
-            validation_split_percentage=0.1,
-            test_eval_ratio=0.2,
-            chronological_split=True
-        )
-        train_set, validation_set, test_set = create_dataset_splits(data_args, seed=42)
-    """
-    dataset = load_parquet_as_dataset(data_args.data_folder)
-    test_set = (
-        None
-        if not data_args.test_data_folder
-        else load_parquet_as_dataset(data_args.test_data_folder)
-    )
-
-    if data_args.chronological_split:
-        # Chronological split by sorting on `index_date`
-        dataset = dataset.sort("index_date")
-        total_size = len(dataset)
-        train_end = int((1 - data_args.validation_split_percentage) * total_size)
-
-        # Perform the split
-        train_set = dataset.select(range(0, train_end))
-        validation_set = dataset.select(range(train_end, total_size))
-
-        if test_set is None:
-            test_valid_split = validation_set.train_test_split(
-                test_size=data_args.test_eval_ratio, seed=seed
-            )
-            validation_set, test_set = (
-                test_valid_split["train"],
-                test_valid_split["test"],
-            )
-
-    elif data_args.split_by_patient:
-        # Patient-based split
-        LOG.info("Using the split_by_patient strategy")
-        unique_patient_ids = dataset.unique("person_id")
-        LOG.info(f"There are {len(unique_patient_ids)} patients in total")
-
-        np.random.seed(seed)
-        np.random.shuffle(unique_patient_ids)
-
-        train_end = int(
-            len(unique_patient_ids) * (1 - data_args.validation_split_percentage)
-        )
-        train_patient_ids = set(unique_patient_ids[:train_end])
-
-        if test_set is None:
-            validation_end = int(
-                train_end
-                + len(unique_patient_ids)
-                * data_args.validation_split_percentage
-                * data_args.test_eval_ratio
-            )
-            val_patient_ids = set(unique_patient_ids[train_end:validation_end])
-            test_patient_ids = set(unique_patient_ids[validation_end:])
-        else:
-            val_patient_ids, test_patient_ids = (
-                set(unique_patient_ids[train_end:]),
-                None,
-            )
-
-        # Helper function to apply patient-based filtering
-        def filter_by_patient_ids(patient_ids):
-            return dataset.filter(
-                lambda batch: [pid in patient_ids for pid in batch["person_id"]],
-                num_proc=data_args.preprocessing_num_workers,
-                batched=True,
-                batch_size=data_args.preprocessing_batch_size,
-            )
-
-        # Generate splits
-        train_set = filter_by_patient_ids(train_patient_ids)
-        validation_set = filter_by_patient_ids(val_patient_ids)
-        if test_set is None:
-            test_set = filter_by_patient_ids(test_patient_ids)
-
-    else:
-        # Random split
-        train_val = dataset.train_test_split(
-            test_size=data_args.validation_split_percentage, seed=seed
-        )
-        train_set, validation_set = train_val["train"], train_val["test"]
-
-        if test_set is None:
-            test_valid_split = validation_set.train_test_split(
-                test_size=data_args.test_eval_ratio, seed=seed
-            )
-            validation_set, test_set = (
-                test_valid_split["train"],
-                test_valid_split["test"],
-            )
-
-    return train_set, validation_set, test_set
-
-
 def model_init(
     model_args: ModelArguments,
     training_args: TrainingArguments,
```
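The `create_dataset_splits` helper removed above is superseded by `prepare_finetune_dataset` from the new `cehrgpt/runners/data_utils.py` (listed at the top with +243 lines), which hands the runner a ready-made `DatasetDict`, as the next hunks show. For readers comparing the two versions, here is a minimal, self-contained rendering of the old random-split branch against the Hugging Face `datasets` API; the `random_splits` name and the toy data are illustrative, not part of either release:

```python
from datasets import Dataset, DatasetDict

def random_splits(dataset: Dataset, validation_split_percentage: float,
                  test_eval_ratio: float, seed: int) -> DatasetDict:
    # Mirrors the removed random-split branch of create_dataset_splits.
    train_val = dataset.train_test_split(
        test_size=validation_split_percentage, seed=seed
    )
    # The test set is carved out of the validation split, as the old code
    # did whenever no test_data_folder was supplied.
    test_valid = train_val["test"].train_test_split(
        test_size=test_eval_ratio, seed=seed
    )
    return DatasetDict(
        {
            "train": train_val["train"],
            "validation": test_valid["train"],
            "test": test_valid["test"],
        }
    )

# 100 rows with validation_split_percentage=0.2 and test_eval_ratio=0.2
# yield an 80/16/4 split.
splits = random_splits(
    Dataset.from_dict({"person_id": list(range(100))}),
    validation_split_percentage=0.2, test_eval_ratio=0.2, seed=42,
)
print({k: len(v) for k, v in splits.items()})
```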
```diff
@@ -363,16 +232,16 @@ def main():
     prepared_ds_path = generate_prepared_ds_path(
         data_args, model_args, data_folder=data_args.cohort_folder
     )
-
+    cache_file_collector = CacheFileCollector()
     processed_dataset = None
     if any(prepared_ds_path.glob("*")):
         LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
         processed_dataset = load_from_disk(str(prepared_ds_path))
         LOG.info("Prepared dataset loaded from disk...")
         if cehrgpt_args.expand_tokenizer:
-
+            if tokenizer_exists(training_args.output_dir):
                 tokenizer = CehrGptTokenizer.from_pretrained(training_args.output_dir)
-
+            else:
                 LOG.warning(
                     f"CehrGptTokenizer must exist in {training_args.output_dir} "
                     f"when the dataset has been processed and expand_tokenizer is set to True. "
```
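`tokenizer_exists` (imported from `hf_cehrgpt_pretrain_runner` in the import hunk above) replaces the old `try`/`except` probe around `CehrGptTokenizer.from_pretrained`, making the "has a tokenizer already been trained here?" check explicit. Its implementation is not part of this diff; a plausible sketch, where the file-presence heuristic is purely an assumption:

```python
import os

def tokenizer_exists(tokenizer_name_or_path: str) -> bool:
    # Hypothetical sketch: treat the tokenizer as trained when serialized
    # artifacts are already present in the directory.
    return os.path.isdir(tokenizer_name_or_path) and any(
        name.endswith(".json") for name in os.listdir(tokenizer_name_or_path)
    )
```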
```diff
@@ -382,81 +251,77 @@ def main():
             shutil.rmtree(prepared_ds_path)
 
     if processed_dataset is None:
-
-
-
-            data_folder=data_args.cohort_folder,
-            dataset_prepared_path=data_args.dataset_prepared_path,
+        if is_main_process(training_args.local_rank):
+            final_splits = prepare_finetune_dataset(
+                data_args, training_args, cehrgpt_args, cache_file_collector
             )
-
-
-
-
-
-
-
-
-
-
-
-
-
-            else:
-                dataset = dataset.to_iterable_dataset(
-                    num_shards=training_args.dataloader_num_workers
+            if cehrgpt_args.expand_tokenizer:
+                new_tokenizer_path = os.path.expanduser(training_args.output_dir)
+                if tokenizer_exists(new_tokenizer_path):
+                    tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
+                else:
+                    # Try to use the defined pretrained embeddings if exists, Otherwise we default to the pretrained model
+                    # embedded in the pretrained model
+                    pretrained_concept_embedding_model = PretrainedEmbeddings(
+                        cehrgpt_args.pretrained_embedding_path
+                    )
+                    if not pretrained_concept_embedding_model.exists:
+                        pretrained_concept_embedding_model = (
+                            tokenizer.pretrained_concept_embedding_model
                         )
-
-
-
-
-
-
-            dataset.save_to_disk(meds_extension_path)
-            train_set = dataset["train"]
-            validation_set = dataset["validation"]
-            test_set = dataset["test"]
-        else:
-            train_set, validation_set, test_set = create_dataset_splits(
-                data_args=data_args, seed=training_args.seed
-            )
-        # Organize them into a single DatasetDict
-        final_splits = DatasetDict(
-            {"train": train_set, "validation": validation_set, "test": test_set}
-        )
-
-        if cehrgpt_args.expand_tokenizer:
-            new_tokenizer_path = os.path.expanduser(training_args.output_dir)
-            try:
-                tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
-            except Exception:
-                # Try to use the defined pretrained embeddings if exists,
-                # Otherwise we default to the pretrained model embedded in the pretrained model
-                pretrained_concept_embedding_model = PretrainedEmbeddings(
-                    cehrgpt_args.pretrained_embedding_path
-                )
-                if not pretrained_concept_embedding_model.exists:
-                    pretrained_concept_embedding_model = (
-                        tokenizer.pretrained_concept_embedding_model
+                    tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
+                        cehrgpt_tokenizer=tokenizer,
+                        dataset=final_splits["train"],
+                        data_args=data_args,
+                        concept_name_mapping={},
+                        pretrained_concept_embedding_model=pretrained_concept_embedding_model,
                     )
-
-
-
-
-
-
+                    tokenizer.save_pretrained(
+                        os.path.expanduser(training_args.output_dir)
+                    )
+
+            # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
+            if not data_args.streaming:
+                all_columns = final_splits["train"].column_names
+                if "visit_concept_ids" in all_columns:
+                    final_splits = final_splits.remove_columns(["visit_concept_ids"])
+
+            processed_dataset = create_cehrgpt_finetuning_dataset(
+                dataset=final_splits,
+                cehrgpt_tokenizer=tokenizer,
+                data_args=data_args,
+                cache_file_collector=cache_file_collector,
+            )
+            if not data_args.streaming:
+                processed_dataset.save_to_disk(str(prepared_ds_path))
+                stats = processed_dataset.cleanup_cache_files()
+                LOG.info(
+                    "Clean up the cached files for the cehrgpt finetuning dataset : %s",
+                    stats,
                 )
-            tokenizer.save_pretrained(os.path.expanduser(training_args.output_dir))
 
-
-
+        # Remove any cached files if there are any
+        cache_file_collector.remove_cache_files()
+
+    # After main-process-only operations, synchronize all processes to ensure consistency
+    if dist.is_available() and dist.is_initialized():
+        dist.barrier()
+
+    # Loading tokenizer in all processes in torch distributed training
+    tokenizer_name_or_path = os.path.expanduser(
+        training_args.output_dir
+        if cehrgpt_args.expand_tokenizer
+        else model_args.tokenizer_name_or_path
     )
-
-
+    tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
+    # Load the dataset from disk again to in torch distributed training
+    processed_dataset = load_from_disk(str(prepared_ds_path))
 
     # Set seed before initializing model.
     set_seed(training_args.seed)
 
-
+    if not data_args.streaming and not cehrgpt_args.sample_packing:
+        processed_dataset.set_format("pt")
 
     if cehrgpt_args.few_shot_predict:
         # At least we need two examples to have a validation set for early stopping
```
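The rewritten block above is the heart of this hunk: dataset preparation now runs only on the main process (`is_main_process(training_args.local_rank)`), all ranks then synchronize on `dist.barrier()`, and every rank reloads the tokenizer and the prepared dataset from disk. A condensed sketch of that prepare-once-share-everywhere pattern, using the same `transformers`/`datasets`/`torch.distributed` calls as the diff (the function name and toy dataset are placeholders):

```python
import torch.distributed as dist
from datasets import Dataset, DatasetDict, load_from_disk
from transformers.trainer_utils import is_main_process

def prepare_once_then_share(local_rank: int, prepared_ds_path: str) -> DatasetDict:
    # Rank 0 does the expensive preprocessing/caching work exactly once...
    if is_main_process(local_rank):
        processed = DatasetDict({"train": Dataset.from_dict({"x": [1, 2, 3]})})
        processed.save_to_disk(prepared_ds_path)
    # ...every rank waits here so nobody reads a half-written dataset...
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    # ...then all ranks (rank 0 included) load the identical copy from disk.
    return load_from_disk(prepared_ds_path)
```

In a single-process run the barrier is skipped (`dist.is_initialized()` is false), so the same code path works with and without `torchrun`.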
```diff
@@ -476,13 +341,40 @@ def main():
         config = CEHRGPTConfig.from_pretrained(model_args.model_name_or_path)
         if config.max_position_embeddings < model_args.max_position_embeddings:
             config.max_position_embeddings = model_args.max_position_embeddings
+
+    # persist this parameter in case this is overwritten by sample packing
+    per_device_eval_batch_size = training_args.per_device_eval_batch_size
+
+    if cehrgpt_args.sample_packing:
+        trainer_class = partial(
+            SamplePackingTrainer,
+            max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
+            max_position_embeddings=config.max_position_embeddings,
+            train_lengths=processed_dataset["train"]["num_of_concepts"],
+            validation_lengths=processed_dataset["validation"]["num_of_concepts"],
+        )
+        training_args.per_device_train_batch_size = 1
+        training_args.per_device_eval_batch_size = 1
+        data_collator_fn = partial(
+            SamplePackingCehrGptDataCollator,
+            cehrgpt_args.max_tokens_per_batch,
+            config.max_position_embeddings,
+        )
+    else:
+        trainer_class = Trainer
+        data_collator_fn = CehrGptDataCollator
+
     # We suppress the additional learning objectives in fine-tuning
-    data_collator =
+    data_collator = data_collator_fn(
         tokenizer=tokenizer,
         max_length=(
-
-            if
-            else
+            cehrgpt_args.max_tokens_per_batch
+            if cehrgpt_args.sample_packing
+            else (
+                config.max_position_embeddings - 1
+                if config.causal_sfm
+                else config.max_position_embeddings
+            )
         ),
         include_values=model_args.include_values,
         pretraining=False,
```
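Both the trainer class and the collator are now selected once and partially applied, so the later call sites (`data_collator = data_collator_fn(...)`, `trainer = trainer_class(...)`) are identical in the packed and unpacked paths. Note also the `max_length` expression: under `causal_sfm` the collator gets `max_position_embeddings - 1`, i.e. one position is held back. A toy illustration of the `functools.partial` wiring, with stand-in collator classes and made-up sizes:

```python
from functools import partial

class BaseCollator:
    def __init__(self, tokenizer=None, max_length=None):
        self.tokenizer, self.max_length = tokenizer, max_length

class PackingCollator(BaseCollator):
    def __init__(self, max_tokens_per_batch, max_position_embeddings, **kwargs):
        super().__init__(**kwargs)
        self.max_tokens_per_batch = max_tokens_per_batch
        self.max_position_embeddings = max_position_embeddings

sample_packing = True                                        # cehrgpt_args.sample_packing
max_tokens_per_batch, max_position_embeddings = 16384, 1024  # illustrative values

# Bind the packing-specific positional arguments now; the shared call site
# later supplies only the keyword arguments common to both collators.
data_collator_fn = (
    partial(PackingCollator, max_tokens_per_batch, max_position_embeddings)
    if sample_packing
    else BaseCollator
)
data_collator = data_collator_fn(tokenizer=None, max_length=max_tokens_per_batch)
print(type(data_collator).__name__)  # PackingCollator
```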
```diff
@@ -493,8 +385,8 @@
 
     if training_args.do_train:
         if cehrgpt_args.hyperparameter_tuning:
-            model_args.early_stopping_patience = LARGE_INTEGER
             training_args = perform_hyperparameter_search(
+                trainer_class,
                 partial(model_init, model_args, training_args, tokenizer),
                 processed_dataset,
                 data_collator,
```
```diff
@@ -502,13 +394,20 @@
                 model_args,
                 cehrgpt_args,
             )
+
+            if cehrgpt_args.retrain_with_full:
                 # Always retrain with the full set when hyperparameter tuning is set to true
                 retrain_with_full_set(
-
+                    trainer_class,
+                    model_args,
+                    training_args,
+                    tokenizer,
+                    processed_dataset,
+                    data_collator,
                 )
         else:
             # Initialize Trainer for final training on the combined train+val set
-            trainer =
+            trainer = trainer_class(
                 model=model_init(model_args, training_args, tokenizer),
                 data_collator=data_collator,
                 args=training_args,
```
```diff
@@ -531,45 +430,31 @@
         trainer.save_metrics("train", metrics)
         trainer.save_state()
 
-    # Retrain the model with full set using the num of epoches before earlying stopping
-    if cehrgpt_args.retrain_with_full:
-        update_num_epoch_before_early_stopping_callback = None
-        for callback in trainer.callback_handler.callbacks:
-            if isinstance(callback, UpdateNumEpochsBeforeEarlyStoppingCallback):
-                update_num_epoch_before_early_stopping_callback = callback
-
-        if update_num_epoch_before_early_stopping_callback is None:
-            raise RuntimeError(
-                f"{UpdateNumEpochsBeforeEarlyStoppingCallback} must be included as a callback!"
-            )
-        final_num_epochs = (
-            update_num_epoch_before_early_stopping_callback.num_epochs_before_early_stopping
-        )
-        training_args.num_train_epochs = final_num_epochs
-        LOG.info(
-            "Num Epochs before early stopping: %s",
-            training_args.num_train_epochs,
-        )
-        retrain_with_full_set(
-            model_args,
-            training_args,
-            tokenizer,
-            processed_dataset,
-            data_collator,
-        )
-
     if training_args.do_predict:
+        if cehrgpt_args.sample_packing:
+            batch_sampler = SamplePackingBatchSampler(
+                lengths=processed_dataset["test"]["num_of_concepts"],
+                max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
+                max_position_embeddings=config.max_position_embeddings,
+                drop_last=training_args.dataloader_drop_last,
+                seed=training_args.seed,
+            )
+            per_device_eval_batch_size = 1
+        else:
+            batch_sampler = None
         test_dataloader = DataLoader(
             dataset=processed_dataset["test"],
-            batch_size=
+            batch_size=per_device_eval_batch_size,
             num_workers=training_args.dataloader_num_workers,
             collate_fn=data_collator,
             pin_memory=training_args.dataloader_pin_memory,
+            batch_sampler=batch_sampler,
         )
         do_predict(test_dataloader, model_args, training_args, cehrgpt_args)
 
 
 def retrain_with_full_set(
+    trainer_class,
     model_args: ModelArguments,
     training_args: TrainingArguments,
     tokenizer: CehrGptTokenizer,
```
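For prediction, packing moves into the sampler: `SamplePackingBatchSampler` groups test examples by their `num_of_concepts` lengths so each batch stays under `max_tokens_per_batch`, and `per_device_eval_batch_size` is forced to 1 because PyTorch's `DataLoader` rejects any other `batch_size` when a `batch_sampler` is supplied. The sketch below shows only the greedy token-budget idea; the real class (new in `cehrgpt/data/sample_packing_sampler.py`) also handles the `seed`-based shuffling, `max_position_embeddings` capping, and `drop_last` seen in the call above:

```python
from typing import Iterator, List
from torch.utils.data import Sampler

class GreedyTokenBudgetSampler(Sampler[List[int]]):
    """Toy batch sampler: fill each batch until adding the next example
    would exceed the token budget. Illustrative only."""

    def __init__(self, lengths: List[int], max_tokens_per_batch: int):
        self.lengths = lengths
        self.max_tokens_per_batch = max_tokens_per_batch

    def __iter__(self) -> Iterator[List[int]]:
        batch, used = [], 0
        for idx, length in enumerate(self.lengths):
            if batch and used + length > self.max_tokens_per_batch:
                yield batch          # budget reached: emit and start a new batch
                batch, used = [], 0
            batch.append(idx)
            used += length
        if batch:
            yield batch              # trailing partial batch

    def __len__(self) -> int:
        return len(self.lengths)     # upper bound; actual count depends on packing

sampler = GreedyTokenBudgetSampler([60, 30, 50, 90, 10], max_tokens_per_batch=100)
print(list(sampler))  # [[0, 1], [2], [3, 4]]
```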
```diff
@@ -586,6 +471,7 @@ def retrain_with_full_set(
     and state information.
 
     Args:
+        trainer_class: Trainer or its subclass
         model_args (ModelArguments): Model configuration and hyperparameters.
         training_args (TrainingArguments): Training configuration, including output directory,
             evaluation strategy, and other training parameters.
```
```diff
@@ -607,7 +493,7 @@
     # Disable evaluation
     training_args.evaluation_strategy = "no"
     checkpoint = get_last_hf_checkpoint(training_args)
-    final_trainer =
+    final_trainer = trainer_class(
         model=model_init(model_args, training_args, tokenizer),
         data_collator=data_collator,
         args=training_args,
```
```diff
@@ -662,15 +548,15 @@
     test_losses = []
     with torch.no_grad():
         for index, batch in enumerate(tqdm(test_dataloader, desc="Predicting")):
-            person_ids = batch.pop("person_id").numpy().
-
-
-
-
-
-
-
-
+            person_ids = batch.pop("person_id").numpy().astype(int).squeeze()
+            if person_ids.ndim == 0:
+                person_ids = np.asarray([person_ids])
+
+            index_dates = batch.pop("index_date").numpy().squeeze()
+            if index_dates.ndim == 0:
+                index_dates = np.asarray([index_dates])
+            index_dates = list(map(datetime.fromtimestamp, index_dates.tolist()))
+
             batch = {k: v.to(device) for k, v in batch.items()}
             # Forward pass
             output = model(**batch, output_attentions=False, output_hidden_states=False)
```
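The new `ndim == 0` guards fix a real edge case: for a final batch of size one, `.squeeze()` collapses the array all the way to a 0-d scalar, after which `len()` and iteration fail. A short demonstration with NumPy:

```python
import numpy as np

person_ids = np.array([[42]]).squeeze()    # batch of one -> 0-d array
print(person_ids.ndim, person_ids.shape)   # 0 ()
# len(person_ids) here raises "TypeError: len() of unsized object"
if person_ids.ndim == 0:
    person_ids = np.asarray([person_ids])  # restore a one-element 1-D array
print(person_ids.shape)                    # (1,)
```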
```diff
@@ -678,17 +564,25 @@
 
             # Collect logits and labels for prediction
             logits = output.logits.float().cpu().numpy().squeeze()
+            if logits.ndim == 0:
+                logits = np.asarray([logits])
+            probabilities = sigmoid(logits)
+
             labels = (
-                batch["classifier_label"].float().cpu().numpy().
+                batch["classifier_label"].float().cpu().numpy().astype(bool).squeeze()
             )
-
+            if labels.ndim == 0:
+                labels = np.asarray([labels])
+
             # Save predictions to parquet file
             test_prediction_pd = pd.DataFrame(
                 {
                     "subject_id": person_ids,
                     "prediction_time": index_dates,
-                    "
-                    "
+                    "predicted_boolean_probability": probabilities,
+                    "predicted_boolean_value": pd.Series(
+                        [None] * len(person_ids), dtype=bool
+                    ),
                     "boolean_value": labels,
                 }
             )
```
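`probabilities = sigmoid(logits)` maps the classifier's raw logits to probabilities for the `predicted_boolean_probability` column that `compute_metrics` consumes below. The diff does not show where `sigmoid` comes from in this module; a NumPy equivalent for reference (`scipy.special.expit` is the numerically hardened alternative):

```python
import numpy as np

def sigmoid(x: np.ndarray) -> np.ndarray:
    # Logistic function: fine for ordinary logit magnitudes.
    return 1.0 / (1.0 + np.exp(-x))

logits = np.array([-2.0, 0.0, 3.0])
print(sigmoid(logits).round(4))  # [0.1192 0.5    0.9526]
```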
```diff
@@ -702,7 +596,7 @@
     # Compute metrics and save results
     metrics = compute_metrics(
         references=test_prediction_pd.boolean_value,
-        probs=test_prediction_pd.
+        probs=test_prediction_pd.predicted_boolean_probability,
     )
     metrics["test_loss"] = np.mean(test_losses)
 
```