cehrgpt 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +99 -88
- cehrgpt/data/sample_packing_sampler.py +151 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/models/config.py +10 -0
- cehrgpt/models/hf_cehrgpt.py +243 -73
- cehrgpt/models/tokenization_hf_cehrgpt.py +4 -0
- cehrgpt/runners/data_utils.py +243 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +152 -279
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +229 -105
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +42 -0
- cehrgpt/runners/hyperparameter_search_util.py +4 -1
- cehrgpt/runners/sample_packing_trainer.py +168 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +7 -5
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +28 -26
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,12 @@
|
|
1
1
|
import os
|
2
|
+
from functools import partial
|
2
3
|
from typing import Optional, Union
|
3
4
|
|
5
|
+
import numpy as np
|
4
6
|
import torch
|
7
|
+
import torch.distributed as dist
|
5
8
|
from cehrbert.data_generators.hf_data_generator.meds_utils import (
|
9
|
+
CacheFileCollector,
|
6
10
|
create_dataset_from_meds_reader,
|
7
11
|
)
|
8
12
|
from cehrbert.runners.hf_runner_argument_dataclass import (
|
@@ -16,11 +20,15 @@ from cehrbert.runners.runner_util import (
|
|
16
20
|
load_parquet_as_dataset,
|
17
21
|
)
|
18
22
|
from datasets import Dataset, DatasetDict, IterableDatasetDict, load_from_disk
|
19
|
-
from transformers import
|
23
|
+
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed
|
24
|
+
from transformers.trainer_utils import is_main_process
|
20
25
|
from transformers.utils import is_flash_attn_2_available, logging
|
21
26
|
|
22
27
|
from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_pretraining_dataset
|
23
|
-
from cehrgpt.data.hf_cehrgpt_dataset_collator import
|
28
|
+
from cehrgpt.data.hf_cehrgpt_dataset_collator import (
|
29
|
+
CehrGptDataCollator,
|
30
|
+
SamplePackingCehrGptDataCollator,
|
31
|
+
)
|
24
32
|
from cehrgpt.data.hf_cehrgpt_dataset_mapping import MedToCehrGPTDatasetMapping
|
25
33
|
from cehrgpt.models.config import CEHRGPTConfig
|
26
34
|
from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
|
@@ -28,10 +36,25 @@ from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
|
|
28
36
|
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
|
29
37
|
from cehrgpt.runners.gpt_runner_util import parse_runner_args
|
30
38
|
from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
|
39
|
+
from cehrgpt.runners.sample_packing_trainer import SamplePackingTrainer
|
31
40
|
|
32
41
|
LOG = logging.get_logger("transformers")
|
33
42
|
|
34
43
|
|
44
|
+
class CustomEarlyStoppingCallback(EarlyStoppingCallback):
    """Early-stopping callback that judges improvement on a relative scale.

    The patience counter is reset only when the new metric value is better
    than ``state.best_metric`` (direction chosen by ``args.greater_is_better``)
    AND the size of the change, expressed as a fraction of the magnitude of
    the best metric, exceeds ``self.early_stopping_threshold``.
    """

    def check_metric_value(self, args, state, control, metric_value):
        """Reset or increment ``self.early_stopping_patience_counter``.

        Args:
            args: Object exposing ``greater_is_better``; selects whether a
                larger or a smaller metric value counts as better.
            state: Object exposing ``best_metric`` (``None`` until a best
                value has been recorded).
            control: Unused here; kept for interface compatibility.
            metric_value: The metric value from the latest evaluation.
        """
        # best_metric is set by code for load_best_model
        best_metric = state.best_metric
        if best_metric is None:
            # Nothing to compare against yet, so start counting from zero.
            self.early_stopping_patience_counter = 0
            return

        operator = np.greater if args.greater_is_better else np.less

        # Normalise by the magnitude (not the signed value) of the best metric:
        # dividing by a negative best metric would make the ratio negative, so
        # it could never exceed the threshold and genuine improvements would
        # still increment the patience counter.
        improvement = abs(metric_value - best_metric)
        baseline = abs(best_metric)
        if baseline > 0:
            improvement = improvement / baseline
        # When the best metric is exactly 0 a relative change is undefined;
        # the absolute difference is compared against the threshold instead
        # of dividing by zero.

        if (
            operator(metric_value, best_metric)
            and improvement > self.early_stopping_threshold
        ):
            self.early_stopping_patience_counter = 0
        else:
            self.early_stopping_patience_counter += 1
|
56
|
+
|
57
|
+
|
35
58
|
def tokenizer_exists(tokenizer_name_or_path: str) -> bool:
|
36
59
|
# Try to load the pretrained tokenizer
|
37
60
|
try:
|
@@ -59,13 +82,16 @@ def load_and_create_tokenizer(
|
|
59
82
|
f"Failed to load the tokenizer from {tokenizer_abspath} with the error \n{e}\n"
|
60
83
|
f"Tried to create the tokenizer, however the dataset is not provided."
|
61
84
|
)
|
85
|
+
LOG.info("Started training the tokenizer ...")
|
62
86
|
tokenizer = CehrGptTokenizer.train_tokenizer(
|
63
87
|
dataset,
|
64
88
|
{},
|
65
89
|
data_args,
|
66
90
|
PretrainedEmbeddings(cehrgpt_args.pretrained_embedding_path),
|
67
91
|
)
|
92
|
+
LOG.info("Finished training the tokenizer ...")
|
68
93
|
tokenizer.save_pretrained(tokenizer_abspath)
|
94
|
+
LOG.info("Saved the tokenizer to %s", tokenizer_abspath)
|
69
95
|
|
70
96
|
return tokenizer
|
71
97
|
|
@@ -120,6 +146,7 @@ def load_and_create_model(
|
|
120
146
|
pretrained_embedding_dim = tokenizer.pretrained_embeddings.shape[1]
|
121
147
|
else:
|
122
148
|
pretrained_embedding_dim = model_args.hidden_size
|
149
|
+
|
123
150
|
model_config = CEHRGPTConfig(
|
124
151
|
vocab_size=tokenizer.vocab_size,
|
125
152
|
value_vocab_size=tokenizer.value_vocab_size,
|
@@ -131,15 +158,23 @@ def load_and_create_model(
|
|
131
158
|
attn_implementation=attn_implementation,
|
132
159
|
causal_sfm=cehrgpt_args.causal_sfm,
|
133
160
|
demographics_size=cehrgpt_args.demographics_size,
|
161
|
+
next_token_prediction_loss_weight=cehrgpt_args.next_token_prediction_loss_weight,
|
134
162
|
lab_token_penalty=cehrgpt_args.lab_token_penalty,
|
135
163
|
lab_token_loss_weight=cehrgpt_args.lab_token_loss_weight,
|
164
|
+
value_prediction_loss_weight=cehrgpt_args.value_prediction_loss_weight,
|
136
165
|
entropy_penalty=cehrgpt_args.entropy_penalty,
|
137
166
|
entropy_penalty_alpha=cehrgpt_args.entropy_penalty_alpha,
|
138
167
|
n_pretrained_embeddings_layers=cehrgpt_args.n_pretrained_embeddings_layers,
|
139
168
|
use_pretrained_embeddings=len(tokenizer.pretrained_token_ids) > 0,
|
140
169
|
pretrained_embedding_dim=pretrained_embedding_dim,
|
170
|
+
sample_packing_max_positions=(
|
171
|
+
cehrgpt_args.max_tokens_per_batch
|
172
|
+
if cehrgpt_args.sample_packing
|
173
|
+
else model_args.max_position_embeddings
|
174
|
+
),
|
141
175
|
**model_args.as_dict(),
|
142
176
|
)
|
177
|
+
|
143
178
|
model = CEHRGPT2LMHeadModel(model_config)
|
144
179
|
if tokenizer.pretrained_token_ids:
|
145
180
|
model.cehrgpt.update_pretrained_embeddings(
|
@@ -156,6 +191,11 @@ def load_and_create_model(
|
|
156
191
|
def main():
|
157
192
|
cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
|
158
193
|
|
194
|
+
if cehrgpt_args.sample_packing and data_args.streaming:
|
195
|
+
raise RuntimeError(
|
196
|
+
f"sample_packing is not supported when streaming is enabled, please set streaming to False"
|
197
|
+
)
|
198
|
+
|
159
199
|
if data_args.streaming:
|
160
200
|
# This is for disabling the warning message https://github.com/huggingface/transformers/issues/5486
|
161
201
|
# This happens only when streaming is enabled
|
@@ -165,6 +205,8 @@ def main():
|
|
165
205
|
training_args.dataloader_num_workers = 0
|
166
206
|
training_args.dataloader_prefetch_factor = None
|
167
207
|
|
208
|
+
processed_dataset: Optional[DatasetDict] = None
|
209
|
+
cache_file_collector = CacheFileCollector()
|
168
210
|
prepared_ds_path = generate_prepared_ds_path(data_args, model_args)
|
169
211
|
if os.path.exists(os.path.join(data_args.data_folder, "dataset_dict.json")):
|
170
212
|
LOG.info(f"Loading prepared dataset from disk at {data_args.data_folder}...")
|
@@ -200,118 +242,158 @@ def main():
|
|
200
242
|
)
|
201
243
|
cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
|
202
244
|
else:
|
203
|
-
#
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
"Trying to load the MEDS extension from disk at %s...",
|
212
|
-
meds_extension_path,
|
245
|
+
# Only run tokenization and data transformation in the main process in torch distributed training
|
246
|
+
# otherwise multiple processes would create tokenizers at the same time
|
247
|
+
if is_main_process(training_args.local_rank):
|
248
|
+
# If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
|
249
|
+
if data_args.is_data_in_meds:
|
250
|
+
meds_extension_path = get_meds_extension_path(
|
251
|
+
data_folder=data_args.data_folder,
|
252
|
+
dataset_prepared_path=data_args.dataset_prepared_path,
|
213
253
|
)
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
254
|
+
try:
|
255
|
+
LOG.info(
|
256
|
+
"Trying to load the MEDS extension from disk at %s...",
|
257
|
+
meds_extension_path,
|
258
|
+
)
|
259
|
+
dataset = load_from_disk(meds_extension_path)
|
260
|
+
if data_args.streaming:
|
261
|
+
if isinstance(dataset, DatasetDict):
|
262
|
+
dataset = {
|
263
|
+
k: v.to_iterable_dataset(
|
264
|
+
num_shards=training_args.dataloader_num_workers
|
265
|
+
)
|
266
|
+
for k, v in dataset.items()
|
267
|
+
}
|
268
|
+
else:
|
269
|
+
dataset = dataset.to_iterable_dataset(
|
219
270
|
num_shards=training_args.dataloader_num_workers
|
220
271
|
)
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
272
|
+
except FileNotFoundError as e:
|
273
|
+
LOG.warning(e)
|
274
|
+
dataset = create_dataset_from_meds_reader(
|
275
|
+
data_args=data_args,
|
276
|
+
dataset_mappings=[
|
277
|
+
MedToCehrGPTDatasetMapping(
|
278
|
+
data_args=data_args,
|
279
|
+
include_inpatient_hour_token=cehrgpt_args.include_inpatient_hour_token,
|
280
|
+
)
|
281
|
+
],
|
282
|
+
cache_file_collector=cache_file_collector,
|
283
|
+
)
|
284
|
+
if not data_args.streaming:
|
285
|
+
dataset.save_to_disk(str(meds_extension_path))
|
286
|
+
stats = dataset.cleanup_cache_files()
|
287
|
+
LOG.info(
|
288
|
+
"Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
|
289
|
+
stats,
|
236
290
|
)
|
237
|
-
|
291
|
+
# Clean up the files created from the data generator
|
292
|
+
cache_file_collector.remove_cache_files()
|
293
|
+
dataset = load_from_disk(str(meds_extension_path))
|
294
|
+
else:
|
295
|
+
# Load the dataset from the parquet files
|
296
|
+
dataset = load_parquet_as_dataset(
|
297
|
+
os.path.expanduser(data_args.data_folder),
|
298
|
+
split="train",
|
299
|
+
streaming=data_args.streaming,
|
238
300
|
)
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
"Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
|
244
|
-
stats,
|
301
|
+
# If streaming is enabled, we need to manually split the data into train/val
|
302
|
+
if data_args.streaming and data_args.validation_split_num:
|
303
|
+
dataset = dataset.shuffle(
|
304
|
+
buffer_size=10_000, seed=training_args.seed
|
245
305
|
)
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
306
|
+
train_set = dataset.skip(data_args.validation_split_num)
|
307
|
+
val_set = dataset.take(data_args.validation_split_num)
|
308
|
+
dataset = DatasetDict({"train": train_set, "validation": val_set})
|
309
|
+
elif data_args.validation_split_percentage:
|
310
|
+
dataset = dataset.train_test_split(
|
311
|
+
test_size=data_args.validation_split_percentage,
|
312
|
+
seed=training_args.seed,
|
313
|
+
)
|
314
|
+
dataset = DatasetDict(
|
315
|
+
{"train": dataset["train"], "validation": dataset["test"]}
|
316
|
+
)
|
317
|
+
else:
|
318
|
+
raise RuntimeError(
|
319
|
+
f"Can not split the data. If streaming is enabled, validation_split_num needs to be "
|
320
|
+
f"defined, otherwise validation_split_percentage needs to be provided. "
|
321
|
+
f"The current values are:\n"
|
322
|
+
f"validation_split_percentage: {data_args.validation_split_percentage}\n"
|
323
|
+
f"validation_split_num: {data_args.validation_split_num}\n"
|
324
|
+
f"streaming: {data_args.streaming}"
|
325
|
+
)
|
326
|
+
|
327
|
+
# Create the CEHR-GPT tokenizer if it's not available in the output folder
|
328
|
+
cehrgpt_tokenizer = load_and_create_tokenizer(
|
329
|
+
data_args=data_args,
|
330
|
+
model_args=model_args,
|
331
|
+
cehrgpt_args=cehrgpt_args,
|
332
|
+
dataset=dataset,
|
253
333
|
)
|
254
|
-
# If streaming is enabled, we need to manually split the data into train/val
|
255
|
-
if data_args.streaming and data_args.validation_split_num:
|
256
|
-
dataset = dataset.shuffle(buffer_size=10_000, seed=training_args.seed)
|
257
|
-
train_set = dataset.skip(data_args.validation_split_num)
|
258
|
-
val_set = dataset.take(data_args.validation_split_num)
|
259
|
-
dataset = DatasetDict({"train": train_set, "test": val_set})
|
260
|
-
elif data_args.validation_split_percentage:
|
261
|
-
dataset = dataset.train_test_split(
|
262
|
-
test_size=data_args.validation_split_percentage,
|
263
|
-
seed=training_args.seed,
|
264
|
-
)
|
265
|
-
else:
|
266
|
-
raise RuntimeError(
|
267
|
-
f"Can not split the data. If streaming is enabled, validation_split_num needs to be "
|
268
|
-
f"defined, otherwise validation_split_percentage needs to be provided. "
|
269
|
-
f"The current values are:\n"
|
270
|
-
f"validation_split_percentage: {data_args.validation_split_percentage}\n"
|
271
|
-
f"validation_split_num: {data_args.validation_split_num}\n"
|
272
|
-
f"streaming: {data_args.streaming}"
|
273
|
-
)
|
274
334
|
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
335
|
+
# Retrain the tokenizer in case we want to pretrain the model further using different datasets
|
336
|
+
if cehrgpt_args.expand_tokenizer:
|
337
|
+
new_tokenizer_path = os.path.expanduser(training_args.output_dir)
|
338
|
+
try:
|
339
|
+
cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
|
340
|
+
new_tokenizer_path
|
341
|
+
)
|
342
|
+
except Exception:
|
343
|
+
cehrgpt_tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
|
344
|
+
cehrgpt_tokenizer=cehrgpt_tokenizer,
|
345
|
+
dataset=dataset["train"],
|
346
|
+
data_args=data_args,
|
347
|
+
concept_name_mapping={},
|
348
|
+
pretrained_concept_embedding_model=PretrainedEmbeddings(
|
349
|
+
cehrgpt_args.pretrained_embedding_path
|
350
|
+
),
|
351
|
+
)
|
352
|
+
cehrgpt_tokenizer.save_pretrained(
|
353
|
+
os.path.expanduser(training_args.output_dir)
|
354
|
+
)
|
355
|
+
|
356
|
+
# TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
|
357
|
+
if not data_args.streaming:
|
358
|
+
all_columns = dataset["train"].column_names
|
359
|
+
if "visit_concept_ids" in all_columns:
|
360
|
+
dataset = dataset.remove_columns(["visit_concept_ids"])
|
361
|
+
|
362
|
+
# sort the patient features chronologically and tokenize the data
|
363
|
+
processed_dataset = create_cehrgpt_pretraining_dataset(
|
364
|
+
dataset=dataset,
|
365
|
+
cehrgpt_tokenizer=cehrgpt_tokenizer,
|
366
|
+
data_args=data_args,
|
367
|
+
cache_file_collector=cache_file_collector,
|
368
|
+
)
|
369
|
+
# only save the data to the disk if it is not streaming
|
370
|
+
if not data_args.streaming:
|
371
|
+
processed_dataset.save_to_disk(str(prepared_ds_path))
|
372
|
+
stats = processed_dataset.cleanup_cache_files()
|
373
|
+
LOG.info(
|
374
|
+
"Clean up the cached files for the cehrgpt pretraining dataset: %s",
|
375
|
+
stats,
|
299
376
|
)
|
377
|
+
cache_file_collector.remove_cache_files()
|
378
|
+
|
379
|
+
# After main-process-only operations, synchronize all processes to ensure consistency
|
380
|
+
if dist.is_available() and dist.is_initialized():
|
381
|
+
dist.barrier()
|
300
382
|
|
301
|
-
#
|
302
|
-
|
303
|
-
|
383
|
+
# Loading tokenizer in all processes in torch distributed training
|
384
|
+
tokenizer_name_or_path = os.path.expanduser(
|
385
|
+
training_args.output_dir
|
386
|
+
if cehrgpt_args.expand_tokenizer
|
387
|
+
else model_args.tokenizer_name_or_path
|
304
388
|
)
|
305
|
-
|
389
|
+
cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
|
390
|
+
# Load the dataset from disk again in torch distributed training
|
306
391
|
if not data_args.streaming:
|
307
|
-
processed_dataset.save_to_disk(str(prepared_ds_path))
|
308
|
-
stats = processed_dataset.cleanup_cache_files()
|
309
|
-
LOG.info(
|
310
|
-
"Clean up the cached files for the cehrgpt pretraining dataset: %s",
|
311
|
-
stats,
|
312
|
-
)
|
313
392
|
processed_dataset = load_from_disk(str(prepared_ds_path))
|
314
393
|
|
394
|
+
if processed_dataset is None:
|
395
|
+
raise RuntimeError("The processed dataset cannot be None")
|
396
|
+
|
315
397
|
def filter_func(examples):
|
316
398
|
if cehrgpt_args.drop_long_sequences:
|
317
399
|
return [
|
@@ -369,22 +451,64 @@ def main():
|
|
369
451
|
# Set seed before initializing model.
|
370
452
|
set_seed(training_args.seed)
|
371
453
|
|
372
|
-
if not data_args.streaming:
|
454
|
+
if not data_args.streaming and not cehrgpt_args.sample_packing:
|
373
455
|
processed_dataset.set_format("pt")
|
374
456
|
|
375
|
-
|
457
|
+
callbacks = []
|
458
|
+
if cehrgpt_args.use_early_stopping:
|
459
|
+
callbacks.append(
|
460
|
+
CustomEarlyStoppingCallback(
|
461
|
+
model_args.early_stopping_patience,
|
462
|
+
cehrgpt_args.early_stopping_threshold,
|
463
|
+
)
|
464
|
+
)
|
465
|
+
|
466
|
+
if cehrgpt_args.sample_packing:
|
467
|
+
trainer_class = partial(
|
468
|
+
SamplePackingTrainer,
|
469
|
+
max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
|
470
|
+
max_position_embeddings=model_args.max_position_embeddings,
|
471
|
+
train_lengths=processed_dataset["train"]["num_of_concepts"],
|
472
|
+
validation_lengths=(
|
473
|
+
processed_dataset["validation"]
|
474
|
+
if "validation" in processed_dataset
|
475
|
+
else processed_dataset["test"]
|
476
|
+
)["num_of_concepts"],
|
477
|
+
)
|
478
|
+
training_args.per_device_train_batch_size = 1
|
479
|
+
training_args.per_device_eval_batch_size = 1
|
480
|
+
data_collator_fn = partial(
|
481
|
+
SamplePackingCehrGptDataCollator,
|
482
|
+
cehrgpt_args.max_tokens_per_batch,
|
483
|
+
model_args.max_position_embeddings,
|
484
|
+
add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
|
485
|
+
)
|
486
|
+
else:
|
487
|
+
trainer_class = Trainer
|
488
|
+
data_collator_fn = CehrGptDataCollator
|
489
|
+
|
490
|
+
trainer = trainer_class(
|
376
491
|
model=model,
|
377
|
-
data_collator=
|
492
|
+
data_collator=data_collator_fn(
|
378
493
|
tokenizer=cehrgpt_tokenizer,
|
379
|
-
max_length=
|
494
|
+
max_length=(
|
495
|
+
cehrgpt_args.max_tokens_per_batch
|
496
|
+
if cehrgpt_args.sample_packing
|
497
|
+
else model_args.max_position_embeddings
|
498
|
+
),
|
380
499
|
shuffle_records=data_args.shuffle_records,
|
381
500
|
include_ttv_prediction=model_args.include_ttv_prediction,
|
382
501
|
use_sub_time_tokenization=model_args.use_sub_time_tokenization,
|
383
502
|
include_values=model_args.include_values,
|
384
503
|
),
|
385
504
|
train_dataset=processed_dataset["train"],
|
386
|
-
eval_dataset=
|
505
|
+
eval_dataset=(
|
506
|
+
processed_dataset["validation"]
|
507
|
+
if "validation" in processed_dataset
|
508
|
+
else processed_dataset["test"]
|
509
|
+
),
|
387
510
|
args=training_args,
|
511
|
+
callbacks=callbacks,
|
388
512
|
)
|
389
513
|
|
390
514
|
checkpoint = None
|
@@ -115,6 +115,9 @@ class CehrGPTArguments:
|
|
115
115
|
"help": "The lower bound of the learning rate range for hyperparameter tuning."
|
116
116
|
},
|
117
117
|
)
|
118
|
+
next_token_prediction_loss_weight: float = dataclasses.field(
|
119
|
+
default=1.0, metadata={"help": "The weight of the next token prediction loss"}
|
120
|
+
)
|
118
121
|
lab_token_penalty: Optional[bool] = dataclasses.field(
|
119
122
|
default=False,
|
120
123
|
metadata={
|
@@ -125,6 +128,10 @@ class CehrGPTArguments:
|
|
125
128
|
default=1.0,
|
126
129
|
metadata={"help": "lab_token_loss_weight penalty co-efficient"},
|
127
130
|
)
|
131
|
+
value_prediction_loss_weight: Optional[float] = dataclasses.field(
|
132
|
+
default=1.0,
|
133
|
+
metadata={"help": "The weight of the value prediction loss"},
|
134
|
+
)
|
128
135
|
entropy_penalty: Optional[bool] = dataclasses.field(
|
129
136
|
default=False,
|
130
137
|
metadata={"help": "A flag to indicate whether we want to use entropy penalty."},
|
@@ -139,3 +146,38 @@ class CehrGPTArguments:
|
|
139
146
|
"help": "The number of feed forward layers for transforming pretrained embeddings to internal embeddings"
|
140
147
|
},
|
141
148
|
)
|
149
|
+
meds_repartition: Optional[bool] = dataclasses.field(
|
150
|
+
default=False,
|
151
|
+
metadata={
|
152
|
+
"help": "A flag to indicate whether we want to repartition the meds train tune sets"
|
153
|
+
},
|
154
|
+
)
|
155
|
+
use_early_stopping: Optional[bool] = dataclasses.field(
|
156
|
+
default=True,
|
157
|
+
metadata={"help": "A flag to indicate whether we want to use early stopping."},
|
158
|
+
)
|
159
|
+
early_stopping_threshold: Optional[float] = dataclasses.field(
|
160
|
+
default=0.01,
|
161
|
+
metadata={
|
162
|
+
"help": "A threshold to denote how much the specified metric must improve to satisfy early stopping conditions."
|
163
|
+
},
|
164
|
+
)
|
165
|
+
sample_packing: Optional[bool] = dataclasses.field(
|
166
|
+
default=False,
|
167
|
+
metadata={
|
168
|
+
"help": "A flag to indicate whether we want to use sample packing for efficient training."
|
169
|
+
},
|
170
|
+
)
|
171
|
+
max_tokens_per_batch: int = dataclasses.field(
|
172
|
+
default=16384, metadata={"help": "Maximum number of tokens in each batch"}
|
173
|
+
)
|
174
|
+
add_end_token_in_sample_packing: Optional[bool] = dataclasses.field(
|
175
|
+
default=False,
|
176
|
+
metadata={
|
177
|
+
"help": "A flag to indicate whether we want to add end token in sample packing"
|
178
|
+
},
|
179
|
+
)
|
180
|
+
average_over_sequence: bool = dataclasses.field(
|
181
|
+
default=False,
|
182
|
+
metadata={"help": "Whether or not to average tokens per sequence"},
|
183
|
+
)
|
@@ -126,6 +126,7 @@ def sample_dataset(data: Dataset, percentage: float, seed: int) -> Dataset:
|
|
126
126
|
|
127
127
|
|
128
128
|
def perform_hyperparameter_search(
|
129
|
+
trainer_class,
|
129
130
|
model_init: Callable,
|
130
131
|
dataset: DatasetDict,
|
131
132
|
data_collator: CehrGptDataCollator,
|
@@ -142,6 +143,7 @@ def perform_hyperparameter_search(
|
|
142
143
|
After the search, it updates the provided `TrainingArguments` with the best hyperparameters found.
|
143
144
|
|
144
145
|
Args:
|
146
|
+
trainer_class: A Trainer or its subclass
|
145
147
|
model_init (Callable): A function to initialize the model, used for each hyperparameter trial.
|
146
148
|
dataset (DatasetDict): A Hugging Face DatasetDict containing "train" and "validation" datasets.
|
147
149
|
data_collator (CehrGptDataCollator): A data collator for processing batches.
|
@@ -157,6 +159,7 @@ def perform_hyperparameter_search(
|
|
157
159
|
Example:
|
158
160
|
```
|
159
161
|
best_training_args = perform_hyperparameter_search(
|
162
|
+
trainer_class=Trainer,
|
160
163
|
model_init=my_model_init,
|
161
164
|
dataset=my_dataset_dict,
|
162
165
|
data_collator=my_data_collator,
|
@@ -187,7 +190,7 @@ def perform_hyperparameter_search(
|
|
187
190
|
cehrgpt_args.hyperparameter_tuning_percentage,
|
188
191
|
training_args.seed,
|
189
192
|
)
|
190
|
-
hyperparam_trainer =
|
193
|
+
hyperparam_trainer = trainer_class(
|
191
194
|
model_init=model_init,
|
192
195
|
data_collator=data_collator,
|
193
196
|
train_dataset=sampled_train,
|