cehrgpt 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +279 -2
- cehrgpt/data/sample_packing_sampler.py +151 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/generation/omop_converter_batch.py +3 -0
- cehrgpt/models/config.py +10 -0
- cehrgpt/models/hf_cehrgpt.py +244 -73
- cehrgpt/models/tokenization_hf_cehrgpt.py +6 -2
- cehrgpt/runners/data_utils.py +243 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +154 -260
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +250 -90
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -0
- cehrgpt/runners/hyperparameter_search_util.py +4 -1
- cehrgpt/runners/sample_packing_trainer.py +168 -0
- cehrgpt/simulations/__init__.py +0 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/tools/generate_causal_patient_split_by_age.py +146 -0
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +57 -9
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +30 -18
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
cehrgpt/runners/hf_cehrgpt_pretrain_runner.py

@@ -1,8 +1,12 @@
 import os
+from functools import partial
 from typing import Optional, Union

+import numpy as np
 import torch
+import torch.distributed as dist
 from cehrbert.data_generators.hf_data_generator.meds_utils import (
+    CacheFileCollector,
     create_dataset_from_meds_reader,
 )
 from cehrbert.runners.hf_runner_argument_dataclass import (
@@ -16,21 +20,41 @@ from cehrbert.runners.runner_util import (
     load_parquet_as_dataset,
 )
 from datasets import Dataset, DatasetDict, IterableDatasetDict, load_from_disk
-from transformers import …
+from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed
+from transformers.trainer_utils import is_main_process
 from transformers.utils import is_flash_attn_2_available, logging

 from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_pretraining_dataset
-from cehrgpt.data.hf_cehrgpt_dataset_collator import …
+from cehrgpt.data.hf_cehrgpt_dataset_collator import (
+    CehrGptDataCollator,
+    SamplePackingCehrGptDataCollator,
+)
+from cehrgpt.data.hf_cehrgpt_dataset_mapping import MedToCehrGPTDatasetMapping
 from cehrgpt.models.config import CEHRGPTConfig
 from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
 from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
 from cehrgpt.runners.gpt_runner_util import parse_runner_args
-from …
+from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
+from cehrgpt.runners.sample_packing_trainer import SamplePackingTrainer

 LOG = logging.get_logger("transformers")


+class CustomEarlyStoppingCallback(EarlyStoppingCallback):
+    def check_metric_value(self, args, state, control, metric_value):
+        # best_metric is set by code for load_best_model
+        operator = np.greater if args.greater_is_better else np.less
+        if state.best_metric is None or (
+            operator(metric_value, state.best_metric)
+            and abs(metric_value - state.best_metric) / state.best_metric
+            > self.early_stopping_threshold
+        ):
+            self.early_stopping_patience_counter = 0
+        else:
+            self.early_stopping_patience_counter += 1
+
+
 def tokenizer_exists(tokenizer_name_or_path: str) -> bool:
     # Try to load the pretrained tokenizer
     try:
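The new `CustomEarlyStoppingCallback` replaces the stock `EarlyStoppingCallback` comparison, which uses an absolute delta, with a relative one: patience resets only when the metric beats the current best by more than `early_stopping_threshold` as a fraction of that best. A standalone sketch of the rule (function name and defaults are illustrative):

```python
import numpy as np

def resets_patience(metric_value, best_metric, threshold=0.01, greater_is_better=False):
    # Mirrors CustomEarlyStoppingCallback.check_metric_value: the improvement
    # must exceed `threshold` *relative* to the current best, not an absolute delta.
    operator = np.greater if greater_is_better else np.less
    if best_metric is None:
        return True
    return bool(
        operator(metric_value, best_metric)
        and abs(metric_value - best_metric) / best_metric > threshold
    )

print(resets_patience(0.95, 1.00))   # True: eval loss improved by 5% (> 1% threshold)
print(resets_patience(0.995, 1.00))  # False: 0.5% improvement, patience counter grows
```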
@@ -58,13 +82,16 @@ def load_and_create_tokenizer(
                 f"Failed to load the tokenizer from {tokenizer_abspath} with the error \n{e}\n"
                 f"Tried to create the tokenizer, however the dataset is not provided."
             )
+        LOG.info("Started training the tokenizer ...")
         tokenizer = CehrGptTokenizer.train_tokenizer(
             dataset,
             {},
             data_args,
             PretrainedEmbeddings(cehrgpt_args.pretrained_embedding_path),
         )
+        LOG.info("Finished training the tokenizer ...")
         tokenizer.save_pretrained(tokenizer_abspath)
+        LOG.info("Saved the tokenizer to %s", tokenizer_abspath)

     return tokenizer

@@ -82,11 +109,25 @@ def load_and_create_model(
     model_abspath = os.path.expanduser(model_args.model_name_or_path)
     if cehrgpt_args.continue_pretrain:
         try:
-            …
+            pretrained_model = CEHRGPT2LMHeadModel.from_pretrained(
                 model_abspath,
                 attn_implementation=attn_implementation,
                 torch_dtype=torch_dtype,
             )
+            if (
+                pretrained_model.config.max_position_embeddings
+                < model_args.max_position_embeddings
+            ):
+                LOG.info(
+                    f"Increase model.config.max_position_embeddings to {model_args.max_position_embeddings}"
+                )
+                pretrained_model.config.max_position_embeddings = (
+                    model_args.max_position_embeddings
+                )
+                pretrained_model.resize_position_embeddings(
+                    model_args.max_position_embeddings
+                )
+            return pretrained_model
         except Exception as e:
             LOG.error(
                 f"When continue_pretrain is set to True, it assumes that CEHR-GPT has been trained "
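When `continue_pretrain` requests a longer context than the checkpoint was trained with, the runner now bumps `max_position_embeddings` and resizes the position embeddings. A sketch of what such a resize typically does (illustrative only; not the actual `resize_position_embeddings` implementation in `hf_cehrgpt.py`):

```python
import torch
import torch.nn as nn

def grow_position_embeddings(old: nn.Embedding, new_max_positions: int) -> nn.Embedding:
    # Allocate a larger table; rows beyond the old length start randomly
    # initialized, while the trained positions are copied into the prefix.
    new = nn.Embedding(new_max_positions, old.embedding_dim)
    with torch.no_grad():
        new.weight[: old.num_embeddings] = old.weight
    return new

wpe = grow_position_embeddings(nn.Embedding(1024, 768), 2048)
print(wpe.weight.shape)  # torch.Size([2048, 768])
```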
@@ -94,7 +135,7 @@ def load_and_create_model(
             )
             raise e
     try:
-        model_config = …
+        model_config = CEHRGPTConfig.from_pretrained(
             model_abspath, attn_implementation=attn_implementation
         )
     except Exception as e:
@@ -105,6 +146,7 @@ def load_and_create_model(
             pretrained_embedding_dim = tokenizer.pretrained_embeddings.shape[1]
         else:
             pretrained_embedding_dim = model_args.hidden_size
+
         model_config = CEHRGPTConfig(
             vocab_size=tokenizer.vocab_size,
             value_vocab_size=tokenizer.value_vocab_size,
@@ -116,15 +158,23 @@ def load_and_create_model(
             attn_implementation=attn_implementation,
             causal_sfm=cehrgpt_args.causal_sfm,
             demographics_size=cehrgpt_args.demographics_size,
+            next_token_prediction_loss_weight=cehrgpt_args.next_token_prediction_loss_weight,
             lab_token_penalty=cehrgpt_args.lab_token_penalty,
             lab_token_loss_weight=cehrgpt_args.lab_token_loss_weight,
+            value_prediction_loss_weight=cehrgpt_args.value_prediction_loss_weight,
             entropy_penalty=cehrgpt_args.entropy_penalty,
             entropy_penalty_alpha=cehrgpt_args.entropy_penalty_alpha,
             n_pretrained_embeddings_layers=cehrgpt_args.n_pretrained_embeddings_layers,
             use_pretrained_embeddings=len(tokenizer.pretrained_token_ids) > 0,
             pretrained_embedding_dim=pretrained_embedding_dim,
+            sample_packing_max_positions=(
+                cehrgpt_args.max_tokens_per_batch
+                if cehrgpt_args.sample_packing
+                else model_args.max_position_embeddings
+            ),
             **model_args.as_dict(),
         )
+
     model = CEHRGPT2LMHeadModel(model_config)
     if tokenizer.pretrained_token_ids:
         model.cehrgpt.update_pretrained_embeddings(
@@ -141,6 +191,11 @@ def load_and_create_model(
 def main():
     cehrgpt_args, data_args, model_args, training_args = parse_runner_args()

+    if cehrgpt_args.sample_packing and data_args.streaming:
+        raise RuntimeError(
+            f"sample_packing is not supported when streaming is enabled, please set streaming to False"
+        )
+
     if data_args.streaming:
         # This is for disabling the warning message https://github.com/huggingface/transformers/issues/5486
         # This happens only when streaming is enabled
@@ -148,8 +203,10 @@ def main():
         # The iterable dataset doesn't have sharding implemented, so the number of works has to be set to 0
         # Otherwise the trainer will throw an error
         training_args.dataloader_num_workers = 0
-        training_args.dataloader_prefetch_factor = …
+        training_args.dataloader_prefetch_factor = None

+    processed_dataset: Optional[DatasetDict] = None
+    cache_file_collector = CacheFileCollector()
     prepared_ds_path = generate_prepared_ds_path(data_args, model_args)
     if os.path.exists(os.path.join(data_args.data_folder, "dataset_dict.json")):
         LOG.info(f"Loading prepared dataset from disk at {data_args.data_folder}...")
@@ -185,96 +242,157 @@ def main():
         )
         cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
     else:
-        # …
-        …
-                "Trying to load the MEDS extension from disk at %s...",
-                meds_extension_path,
+        # Only run tokenization and data transformation in the main process in torch distributed training
+        # otherwise the multiple processes will create tokenizers at the same time
+        if is_main_process(training_args.local_rank):
+            # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
+            if data_args.is_data_in_meds:
+                meds_extension_path = get_meds_extension_path(
+                    data_folder=data_args.data_folder,
+                    dataset_prepared_path=data_args.dataset_prepared_path,
                 )
-        …
+                try:
+                    LOG.info(
+                        "Trying to load the MEDS extension from disk at %s...",
+                        meds_extension_path,
+                    )
+                    dataset = load_from_disk(meds_extension_path)
+                    if data_args.streaming:
+                        if isinstance(dataset, DatasetDict):
+                            dataset = {
+                                k: v.to_iterable_dataset(
+                                    num_shards=training_args.dataloader_num_workers
+                                )
+                                for k, v in dataset.items()
+                            }
+                        else:
+                            dataset = dataset.to_iterable_dataset(
                                 num_shards=training_args.dataloader_num_workers
                             )
-        …
+                except FileNotFoundError as e:
+                    LOG.warning(e)
+                    dataset = create_dataset_from_meds_reader(
+                        data_args=data_args,
+                        dataset_mappings=[
+                            MedToCehrGPTDatasetMapping(
+                                data_args=data_args,
+                                include_inpatient_hour_token=cehrgpt_args.include_inpatient_hour_token,
+                            )
+                        ],
+                        cache_file_collector=cache_file_collector,
+                    )
+                    if not data_args.streaming:
+                        dataset.save_to_disk(str(meds_extension_path))
+                        stats = dataset.cleanup_cache_files()
+                        LOG.info(
+                            "Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
+                            stats,
                         )
-        …
-            data_args, is_pretraining=True
-        )
-        if not data_args.streaming:
-            dataset.save_to_disk(meds_extension_path)
-        else:
-            # Load the dataset from the parquet files
-            dataset = load_parquet_as_dataset(
-                data_args.data_folder, split="train", streaming=data_args.streaming
-            )
-            # If streaming is enabled, we need to manually split the data into train/val
-            if data_args.streaming and data_args.validation_split_num:
-                dataset = dataset.shuffle(buffer_size=10_000, seed=training_args.seed)
-                train_set = dataset.skip(data_args.validation_split_num)
-                val_set = dataset.take(data_args.validation_split_num)
-                dataset = DatasetDict({"train": train_set, "test": val_set})
-            elif data_args.validation_split_percentage:
-                dataset = dataset.train_test_split(
-                    test_size=data_args.validation_split_percentage,
-                    seed=training_args.seed,
-                )
+                        # Clean up the files created from the data generator
+                        cache_file_collector.remove_cache_files()
+                        dataset = load_from_disk(str(meds_extension_path))
             else:
-                …
-                    f"validation_split_num: {data_args.validation_split_num}\n"
-                    f"streaming: {data_args.streaming}"
+                # Load the dataset from the parquet files
+                dataset = load_parquet_as_dataset(
+                    os.path.expanduser(data_args.data_folder),
+                    split="train",
+                    streaming=data_args.streaming,
                 )
+                # If streaming is enabled, we need to manually split the data into train/val
+                if data_args.streaming and data_args.validation_split_num:
+                    dataset = dataset.shuffle(
+                        buffer_size=10_000, seed=training_args.seed
+                    )
+                    train_set = dataset.skip(data_args.validation_split_num)
+                    val_set = dataset.take(data_args.validation_split_num)
+                    dataset = DatasetDict({"train": train_set, "validation": val_set})
+                elif data_args.validation_split_percentage:
+                    dataset = dataset.train_test_split(
+                        test_size=data_args.validation_split_percentage,
+                        seed=training_args.seed,
+                    )
+                    dataset = DatasetDict(
+                        {"train": dataset["train"], "validation": dataset["test"]}
+                    )
+                else:
+                    raise RuntimeError(
+                        f"Can not split the data. If streaming is enabled, validation_split_num needs to be "
+                        f"defined, otherwise validation_split_percentage needs to be provided. "
+                        f"The current values are:\n"
+                        f"validation_split_percentage: {data_args.validation_split_percentage}\n"
+                        f"validation_split_num: {data_args.validation_split_num}\n"
+                        f"streaming: {data_args.streaming}"
+                    )

-        …
+            # Create the CEHR-GPT tokenizer if it's not available in the output folder
+            cehrgpt_tokenizer = load_and_create_tokenizer(
+                data_args=data_args,
+                model_args=model_args,
+                cehrgpt_args=cehrgpt_args,
+                dataset=dataset,
+            )
+
+            # Retrain the tokenizer in case we want to pretrain the model further using different datasets
+            if cehrgpt_args.expand_tokenizer:
+                new_tokenizer_path = os.path.expanduser(training_args.output_dir)
+                try:
+                    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
+                        new_tokenizer_path
+                    )
+                except Exception:
+                    cehrgpt_tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
+                        cehrgpt_tokenizer=cehrgpt_tokenizer,
+                        dataset=dataset["train"],
+                        data_args=data_args,
+                        concept_name_mapping={},
+                        pretrained_concept_embedding_model=PretrainedEmbeddings(
+                            cehrgpt_args.pretrained_embedding_path
+                        ),
+                    )
+                    cehrgpt_tokenizer.save_pretrained(
+                        os.path.expanduser(training_args.output_dir)
+                    )
+
+            # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
+            if not data_args.streaming:
+                all_columns = dataset["train"].column_names
+                if "visit_concept_ids" in all_columns:
+                    dataset = dataset.remove_columns(["visit_concept_ids"])
+
+            # sort the patient features chronologically and tokenize the data
+            processed_dataset = create_cehrgpt_pretraining_dataset(
+                dataset=dataset,
+                cehrgpt_tokenizer=cehrgpt_tokenizer,
+                data_args=data_args,
+                cache_file_collector=cache_file_collector,
+            )
+            # only save the data to the disk if it is not streaming
+            if not data_args.streaming:
+                processed_dataset.save_to_disk(str(prepared_ds_path))
+                stats = processed_dataset.cleanup_cache_files()
+                LOG.info(
+                    "Clean up the cached files for the cehrgpt pretraining dataset: %s",
+                    stats,
                 )
+                cache_file_collector.remove_cache_files()
+
+        # After main-process-only operations, synchronize all processes to ensure consistency
+        if dist.is_available() and dist.is_initialized():
+            dist.barrier()

-        # …
+        # Loading tokenizer in all processes in torch distributed training
+        tokenizer_name_or_path = os.path.expanduser(
+            training_args.output_dir
+            if cehrgpt_args.expand_tokenizer
+            else model_args.tokenizer_name_or_path
         )
-        …
+        cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
+        # Load the dataset from disk again to in torch distributed training
     if not data_args.streaming:
-        processed_dataset …
+        processed_dataset = load_from_disk(str(prepared_ds_path))
+
+    if processed_dataset is None:
+        raise RuntimeError("The processed dataset cannot be None")

     def filter_func(examples):
         if cehrgpt_args.drop_long_sequences:
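Dataset preparation is now wrapped in an `is_main_process` guard and followed by `dist.barrier()`, so under distributed training rank 0 tokenizes and writes to disk exactly once while the other ranks wait, then every rank loads the shared artifacts. The pattern, reduced to its skeleton (`prepare_fn` is a placeholder, not a function in the package):

```python
import torch.distributed as dist
from transformers.trainer_utils import is_main_process

def prepare_once(local_rank: int, prepare_fn) -> None:
    # Rank 0 does the expensive, disk-writing work exactly once.
    if is_main_process(local_rank):
        prepare_fn()  # e.g. tokenize, transform, save_to_disk(...)
    # Every rank waits here; afterwards all ranks can safely load_from_disk(...).
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
```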
@@ -333,22 +451,64 @@ def main():
     # Set seed before initializing model.
     set_seed(training_args.seed)

-    if not data_args.streaming:
+    if not data_args.streaming and not cehrgpt_args.sample_packing:
         processed_dataset.set_format("pt")

-    …
+    callbacks = []
+    if cehrgpt_args.use_early_stopping:
+        callbacks.append(
+            CustomEarlyStoppingCallback(
+                model_args.early_stopping_patience,
+                cehrgpt_args.early_stopping_threshold,
+            )
+        )
+
+    if cehrgpt_args.sample_packing:
+        trainer_class = partial(
+            SamplePackingTrainer,
+            max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
+            max_position_embeddings=model_args.max_position_embeddings,
+            train_lengths=processed_dataset["train"]["num_of_concepts"],
+            validation_lengths=(
+                processed_dataset["validation"]
+                if "validation" in processed_dataset
+                else processed_dataset["test"]
+            )["num_of_concepts"],
+        )
+        training_args.per_device_train_batch_size = 1
+        training_args.per_device_eval_batch_size = 1
+        data_collator_fn = partial(
+            SamplePackingCehrGptDataCollator,
+            cehrgpt_args.max_tokens_per_batch,
+            model_args.max_position_embeddings,
+            add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
+        )
+    else:
+        trainer_class = Trainer
+        data_collator_fn = CehrGptDataCollator
+
+    trainer = trainer_class(
         model=model,
-        data_collator= …
+        data_collator=data_collator_fn(
             tokenizer=cehrgpt_tokenizer,
-            max_length= …
+            max_length=(
+                cehrgpt_args.max_tokens_per_batch
+                if cehrgpt_args.sample_packing
+                else model_args.max_position_embeddings
+            ),
             shuffle_records=data_args.shuffle_records,
             include_ttv_prediction=model_args.include_ttv_prediction,
             use_sub_time_tokenization=model_args.use_sub_time_tokenization,
             include_values=model_args.include_values,
         ),
         train_dataset=processed_dataset["train"],
-        eval_dataset= …
+        eval_dataset=(
+            processed_dataset["validation"]
+            if "validation" in processed_dataset
+            else processed_dataset["test"]
+        ),
         args=training_args,
+        callbacks=callbacks,
     )

     checkpoint = None
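With `sample_packing` enabled, `per_device_train_batch_size` and `per_device_eval_batch_size` are pinned to 1 and batch size is instead governed by a token budget: `SamplePackingTrainer` groups sequences using their `num_of_concepts` lengths so each packed batch stays near `max_tokens_per_batch`. A toy illustration of greedy token-budget binning (the general idea only, not the actual sampler in `cehrgpt/data/sample_packing_sampler.py`):

```python
from typing import Iterator, List

def pack_by_token_budget(lengths: List[int], max_tokens_per_batch: int) -> Iterator[List[int]]:
    # Greedily fill each batch with sample indices until the next sequence
    # would overflow the token budget, then start a new batch.
    batch: List[int] = []
    budget = 0
    for idx, n in enumerate(lengths):
        if batch and budget + n > max_tokens_per_batch:
            yield batch
            batch, budget = [], 0
        batch.append(idx)
        budget += n
    if batch:
        yield batch

print(list(pack_by_token_budget([512, 900, 7000, 300, 8200], 8192)))
# [[0, 1], [2, 3], [4]] -- a sequence longer than the budget still gets its own batch
```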
cehrgpt/runners/hf_gpt_runner_argument_dataclass.py

@@ -6,6 +6,10 @@ from typing import List, Optional
 class CehrGPTArguments:
     """Arguments pertaining to what data we are going to input our model for training and eval."""

+    include_inpatient_hour_token: Optional[bool] = dataclasses.field(
+        default=True,
+        metadata={"help": "Include inpatient hour token"},
+    )
     include_demographics: Optional[bool] = dataclasses.field(
         default=False,
         metadata={
@@ -111,6 +115,9 @@ class CehrGPTArguments:
             "help": "The lower bound of the learning rate range for hyperparameter tuning."
         },
     )
+    next_token_prediction_loss_weight: float = dataclasses.field(
+        default=1.0, metadata={"help": "The weight of the next token prediction loss"}
+    )
     lab_token_penalty: Optional[bool] = dataclasses.field(
         default=False,
         metadata={
@@ -121,6 +128,10 @@ class CehrGPTArguments:
         default=1.0,
         metadata={"help": "lab_token_loss_weight penalty co-efficient"},
     )
+    value_prediction_loss_weight: Optional[float] = dataclasses.field(
+        default=1.0,
+        metadata={"help": "The weight of the value prediction loss"},
+    )
     entropy_penalty: Optional[bool] = dataclasses.field(
         default=False,
         metadata={"help": "A flag to indicate whether we want to use entropy penalty."},
@@ -135,3 +146,38 @@ class CehrGPTArguments:
             "help": "The number of feed forward layers for transforming pretrained embeddings to internal embeddings"
         },
     )
+    meds_repartition: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to repartition the meds train tune sets"
+        },
+    )
+    use_early_stopping: Optional[bool] = dataclasses.field(
+        default=True,
+        metadata={"help": "A flag to indicate whether we want to use early stopping."},
+    )
+    early_stopping_threshold: Optional[float] = dataclasses.field(
+        default=0.01,
+        metadata={
+            "help": "A threshold to denote how much the specified metric must improve to satisfy early stopping conditions."
+        },
+    )
+    sample_packing: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to use sample packing for efficient training."
+        },
+    )
+    max_tokens_per_batch: int = dataclasses.field(
+        default=16384, metadata={"help": "Maximum number of tokens in each batch"}
+    )
+    add_end_token_in_sample_packing: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to add end token in sample packing"
+        },
+    )
+    average_over_sequence: bool = dataclasses.field(
+        default=False,
+        metadata={"help": "Whether or not to average tokens per sequence"},
+    )
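Assuming `CehrGPTArguments` is a standard `@dataclass` (the decorator sits outside this hunk), the new fields compose like this for a packed run with early stopping; the values below simply spell out the defaults shown above, not a recommendation:

```python
from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments

args = CehrGPTArguments(
    sample_packing=True,             # batch by token budget instead of sample count
    max_tokens_per_batch=16384,      # the budget enforced by SamplePackingTrainer
    add_end_token_in_sample_packing=False,
    use_early_stopping=True,
    early_stopping_threshold=0.01,   # require >1% relative metric improvement
)
print(args.sample_packing, args.max_tokens_per_batch)
```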
cehrgpt/runners/hyperparameter_search_util.py

@@ -126,6 +126,7 @@ def sample_dataset(data: Dataset, percentage: float, seed: int) -> Dataset:


 def perform_hyperparameter_search(
+    trainer_class,
     model_init: Callable,
     dataset: DatasetDict,
     data_collator: CehrGptDataCollator,
@@ -142,6 +143,7 @@ def perform_hyperparameter_search(
     After the search, it updates the provided `TrainingArguments` with the best hyperparameters found.

     Args:
+        trainer_class: A Trainer or its subclass
         model_init (Callable): A function to initialize the model, used for each hyperparameter trial.
         dataset (DatasetDict): A Hugging Face DatasetDict containing "train" and "validation" datasets.
         data_collator (CehrGptDataCollator): A data collator for processing batches.
@@ -157,6 +159,7 @@ def perform_hyperparameter_search(
     Example:
         ```
        best_training_args = perform_hyperparameter_search(
+            trainer_class=Trainer,
            model_init=my_model_init,
            dataset=my_dataset_dict,
            data_collator=my_data_collator,
@@ -187,7 +190,7 @@ def perform_hyperparameter_search(
         cehrgpt_args.hyperparameter_tuning_percentage,
         training_args.seed,
     )
-    hyperparam_trainer = …
+    hyperparam_trainer = trainer_class(
         model_init=model_init,
         data_collator=data_collator,
         train_dataset=sampled_train,
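Because `trainer_class` is now injected, hyperparameter search can run with either the plain HF `Trainer` (as in the docstring example) or a preconfigured packing trainer; a sketch mirroring the pretrain runner's wiring above:

```python
from functools import partial
from transformers import Trainer
from cehrgpt.runners.sample_packing_trainer import SamplePackingTrainer

# Plain HF trainer, as in the docstring example:
plain_trainer_class = Trainer
# Or a packing trainer preconfigured with its token budget:
packed_trainer_class = partial(SamplePackingTrainer, max_tokens_per_batch=16384)
```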