cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
- cehrgpt/data/sample_packing_sampler.py +181 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +35 -0
- cehrgpt/models/hf_cehrgpt.py +470 -106
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
- cehrgpt/runners/data_utils.py +358 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
- cehrgpt/runners/hyperparameter_search_util.py +10 -8
- cehrgpt/runners/sample_packing_trainer.py +185 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
cehrgpt/runners/hf_cehrgpt_pretrain_runner.py
@@ -1,8 +1,12 @@
 import os
+from functools import partial
 from typing import Optional, Union
 
+import numpy as np
 import torch
+import torch.distributed as dist
 from cehrbert.data_generators.hf_data_generator.meds_utils import (
+    CacheFileCollector,
     create_dataset_from_meds_reader,
 )
 from cehrbert.runners.hf_runner_argument_dataclass import (
@@ -16,22 +20,42 @@ from cehrbert.runners.runner_util import (
     load_parquet_as_dataset,
 )
 from datasets import Dataset, DatasetDict, IterableDatasetDict, load_from_disk
-from transformers import
+from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed
+from transformers.trainer_utils import is_main_process
 from transformers.utils import is_flash_attn_2_available, logging
 
 from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_pretraining_dataset
-from cehrgpt.data.hf_cehrgpt_dataset_collator import
+from cehrgpt.data.hf_cehrgpt_dataset_collator import (
+    CehrGptDataCollator,
+    SamplePackingCehrGptDataCollator,
+)
 from cehrgpt.data.hf_cehrgpt_dataset_mapping import MedToCehrGPTDatasetMapping
 from cehrgpt.models.config import CEHRGPTConfig
 from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
 from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+from cehrgpt.runners.data_utils import get_torch_dtype
 from cehrgpt.runners.gpt_runner_util import parse_runner_args
 from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
+from cehrgpt.runners.sample_packing_trainer import SamplePackingTrainer
 
 LOG = logging.get_logger("transformers")
 
 
+class CustomEarlyStoppingCallback(EarlyStoppingCallback):
+    def check_metric_value(self, args, state, control, metric_value):
+        # best_metric is set by code for load_best_model
+        operator = np.greater if args.greater_is_better else np.less
+        if state.best_metric is None or (
+            operator(metric_value, state.best_metric)
+            and abs(metric_value - state.best_metric) / state.best_metric
+            > self.early_stopping_threshold
+        ):
+            self.early_stopping_patience_counter = 0
+        else:
+            self.early_stopping_patience_counter += 1
+
+
 def tokenizer_exists(tokenizer_name_or_path: str) -> bool:
     # Try to load the pretrained tokenizer
     try:
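The `CustomEarlyStoppingCallback` added above overrides `check_metric_value` from `transformers.EarlyStoppingCallback`, which compares the absolute metric change against `early_stopping_threshold`; the override makes the test relative to the current best value, so the default threshold of 0.01 means "at least a 1% relative improvement". A minimal standalone sketch of that check (the `improved` helper and the sample numbers are illustrative, not part of the package):

```python
import numpy as np

def improved(metric_value, best_metric, greater_is_better=False, threshold=0.01):
    # Mirrors CustomEarlyStoppingCallback.check_metric_value: the change must
    # exceed `threshold` as a *fraction* of the current best metric, whereas
    # the stock transformers callback compares the absolute difference.
    operator = np.greater if greater_is_better else np.less
    return best_metric is None or (
        operator(metric_value, best_metric)
        and abs(metric_value - best_metric) / best_metric > threshold
    )

print(improved(0.985, 1.0))  # True: eval loss dropped by 1.5% relative
print(improved(0.995, 1.0))  # False: 0.5% is under the 1% threshold
```

With an absolute test, a metric already near zero could never move by more than the threshold; the relative test keeps the criterion meaningful at any metric scale.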
@@ -48,6 +72,36 @@ def load_and_create_tokenizer(
     cehrgpt_args: CehrGPTArguments,
     dataset: Optional[Union[Dataset, DatasetDict]] = None,
 ) -> CehrGptTokenizer:
+
+    concept_name_mapping = {}
+    allowed_motor_codes = list()
+    if cehrgpt_args.concept_dir:
+        import pandas as pd
+        from cehrbert_data.const.artificial_tokens import DEATH_TOKEN
+        from meds.schema import death_code
+
+        LOG.info("Loading concept data from disk at %s", cehrgpt_args.concept_dir)
+        concept_pd = pd.read_parquet(cehrgpt_args.concept_dir)
+        LOG.info(
+            "Creating concept name mapping and motor_time_to_event_codes from disk at %s",
+            cehrgpt_args.concept_dir,
+        )
+        for row in concept_pd.itertuples():
+            concept_name_mapping[str(getattr(row, "concept_id"))] = getattr(
+                row, "concept_name"
+            )
+            if (
+                cehrgpt_args.include_motor_time_to_event
+                and getattr(row, "domain_id")
+                in ["Condition", "Procedure", "Drug", "Visit"]
+                and getattr(row, "standard_concept") == "S"
+            ):
+                allowed_motor_codes.append(str(getattr(row, "concept_id")))
+        LOG.info(
+            "Adding death codes for MOTOR TTE predictions: %s",
+            [DEATH_TOKEN, death_code],
+        )
+        allowed_motor_codes.extend([DEATH_TOKEN, death_code])
     # Try to load the pretrained tokenizer
     tokenizer_abspath = os.path.expanduser(model_args.tokenizer_name_or_path)
     try:
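The new `concept_dir` branch above expects an OMOP-style `concept` table and keeps standard concepts from the Condition, Procedure, Drug, and Visit domains as candidate MOTOR time-to-event codes, plus the death tokens. A small illustration of the columns the loop reads, using a hypothetical three-row table (the concept IDs are standard OMOP identifiers):

```python
import pandas as pd

# Miniature stand-in for the parquet file at cehrgpt_args.concept_dir; only
# the columns the loop above touches are included.
concept_pd = pd.DataFrame(
    {
        "concept_id": [316866, 1112807, 9202],
        "concept_name": ["Hypertensive disorder", "Aspirin", "Outpatient Visit"],
        "domain_id": ["Condition", "Drug", "Visit"],
        "standard_concept": ["S", "S", "S"],
    }
)
allowed_motor_codes = [
    str(row.concept_id)
    for row in concept_pd.itertuples()
    if row.domain_id in ["Condition", "Procedure", "Drug", "Visit"]
    and row.standard_concept == "S"
]
print(allowed_motor_codes)  # ['316866', '1112807', '9202']
```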
@@ -59,13 +113,24 @@ def load_and_create_tokenizer(
             f"Failed to load the tokenizer from {tokenizer_abspath} with the error \n{e}\n"
             f"Tried to create the tokenizer, however the dataset is not provided."
         )
+        LOG.info("Started training the tokenizer ...")
         tokenizer = CehrGptTokenizer.train_tokenizer(
             dataset,
-
+            concept_name_mapping,
             data_args,
             PretrainedEmbeddings(cehrgpt_args.pretrained_embedding_path),
+            allowed_motor_codes if cehrgpt_args.include_motor_time_to_event else None,
+            (
+                cehrgpt_args.num_motor_tasks
+                if cehrgpt_args.include_motor_time_to_event
+                else None
+            ),
+            apply_entropy_filter=cehrgpt_args.apply_entropy_filter,
+            min_prevalence=cehrgpt_args.min_prevalence,
         )
+        LOG.info("Finished training the tokenizer ...")
         tokenizer.save_pretrained(tokenizer_abspath)
+        LOG.info("Saved the tokenizer to %s", tokenizer_abspath)
 
     return tokenizer
 
@@ -73,13 +138,12 @@ def load_and_create_tokenizer(
 def load_and_create_model(
     model_args: ModelArguments,
     cehrgpt_args: CehrGPTArguments,
-    training_args: TrainingArguments,
     tokenizer: CehrGptTokenizer,
 ) -> CEHRGPT2LMHeadModel:
     attn_implementation = (
         "flash_attention_2" if is_flash_attn_2_available() else "eager"
     )
-    torch_dtype =
+    torch_dtype = get_torch_dtype(model_args.torch_dtype)
     model_abspath = os.path.expanduser(model_args.model_name_or_path)
     if cehrgpt_args.continue_pretrain:
         try:
@@ -120,6 +184,9 @@ def load_and_create_model(
         pretrained_embedding_dim = tokenizer.pretrained_embeddings.shape[1]
     else:
         pretrained_embedding_dim = model_args.hidden_size
+
+    model_args_cehrgpt = model_args.as_dict()
+    model_args_cehrgpt.pop("attn_implementation")
     model_config = CEHRGPTConfig(
         vocab_size=tokenizer.vocab_size,
         value_vocab_size=tokenizer.value_vocab_size,
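Popping `"attn_implementation"` out of `model_args.as_dict()` before it is `**`-unpacked into `CEHRGPTConfig` (next hunk) avoids a duplicate keyword, since the config call also passes `attn_implementation=attn_implementation` explicitly. A minimal sketch of the failure mode being avoided (`build_config_kwargs` is illustrative, not a package function):

```python
def build_config_kwargs(model_args_dict, attn_implementation):
    # If "attn_implementation" stayed in the dict, the **-unpacking below
    # would raise: TypeError: got multiple values for keyword argument
    # 'attn_implementation', because it is also passed explicitly.
    kwargs = dict(model_args_dict)
    kwargs.pop("attn_implementation", None)
    return dict(attn_implementation=attn_implementation, **kwargs)

print(build_config_kwargs(
    {"hidden_size": 768, "attn_implementation": "eager"}, "flash_attention_2"
))
# {'attn_implementation': 'flash_attention_2', 'hidden_size': 768}
```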
@@ -131,15 +198,28 @@
         attn_implementation=attn_implementation,
         causal_sfm=cehrgpt_args.causal_sfm,
         demographics_size=cehrgpt_args.demographics_size,
+        next_token_prediction_loss_weight=cehrgpt_args.next_token_prediction_loss_weight,
         lab_token_penalty=cehrgpt_args.lab_token_penalty,
         lab_token_loss_weight=cehrgpt_args.lab_token_loss_weight,
+        value_prediction_loss_weight=cehrgpt_args.value_prediction_loss_weight,
         entropy_penalty=cehrgpt_args.entropy_penalty,
         entropy_penalty_alpha=cehrgpt_args.entropy_penalty_alpha,
         n_pretrained_embeddings_layers=cehrgpt_args.n_pretrained_embeddings_layers,
         use_pretrained_embeddings=len(tokenizer.pretrained_token_ids) > 0,
         pretrained_embedding_dim=pretrained_embedding_dim,
-
+        sample_packing_max_positions=(
+            cehrgpt_args.max_tokens_per_batch
+            if cehrgpt_args.sample_packing
+            else model_args.max_position_embeddings
+        ),
+        include_motor_time_to_event=cehrgpt_args.include_motor_time_to_event,
+        motor_tte_vocab_size=tokenizer.motor_tte_vocab_size,
+        motor_time_to_event_weight=cehrgpt_args.motor_time_to_event_weight,
+        motor_num_time_pieces=cehrgpt_args.motor_num_time_pieces,
+        ve_token_id=tokenizer.ve_token_id,
+        **model_args_cehrgpt,
     )
+
     model = CEHRGPT2LMHeadModel(model_config)
     if tokenizer.pretrained_token_ids:
         model.cehrgpt.update_pretrained_embeddings(
@@ -156,6 +236,11 @@ def load_and_create_model(
 def main():
     cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
 
+    if cehrgpt_args.sample_packing and data_args.streaming:
+        raise RuntimeError(
+            f"sample_packing is not supported when streaming is enabled, please set streaming to False"
+        )
+
     if data_args.streaming:
         # This is for disabling the warning message https://github.com/huggingface/transformers/issues/5486
         # This happens only when streaming is enabled
@@ -165,6 +250,8 @@ def main():
         training_args.dataloader_num_workers = 0
         training_args.dataloader_prefetch_factor = None
 
+    processed_dataset: Optional[DatasetDict] = None
+    cache_file_collector = CacheFileCollector()
     prepared_ds_path = generate_prepared_ds_path(data_args, model_args)
     if os.path.exists(os.path.join(data_args.data_folder, "dataset_dict.json")):
         LOG.info(f"Loading prepared dataset from disk at {data_args.data_folder}...")
@@ -200,118 +287,160 @@ def main():
         )
         cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
     else:
-        #
-
-
-
-
-
-
-
-            "Trying to load the MEDS extension from disk at %s...",
-            meds_extension_path,
+        # Only run tokenization and data transformation in the main process in torch distributed training
+        # otherwise the multiple processes will create tokenizers at the same time
+        if is_main_process(training_args.local_rank):
+            # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
+            if data_args.is_data_in_meds:
+                meds_extension_path = get_meds_extension_path(
+                    data_folder=data_args.data_folder,
+                    dataset_prepared_path=data_args.dataset_prepared_path,
                 )
-
-
-
-
-
+                try:
+                    LOG.info(
+                        "Trying to load the MEDS extension from disk at %s...",
+                        meds_extension_path,
+                    )
+                    dataset = load_from_disk(meds_extension_path)
+                    if data_args.streaming:
+                        if isinstance(dataset, DatasetDict):
+                            dataset = {
+                                k: v.to_iterable_dataset(
+                                    num_shards=training_args.dataloader_num_workers
+                                )
+                                for k, v in dataset.items()
+                            }
+                        else:
+                            dataset = dataset.to_iterable_dataset(
                                 num_shards=training_args.dataloader_num_workers
                             )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                except FileNotFoundError as e:
+                    LOG.warning(e)
+                    dataset = create_dataset_from_meds_reader(
+                        data_args=data_args,
+                        dataset_mappings=[
+                            MedToCehrGPTDatasetMapping(
+                                data_args=data_args,
+                                include_inpatient_hour_token=cehrgpt_args.include_inpatient_hour_token,
+                            )
+                        ],
+                        cache_file_collector=cache_file_collector,
+                    )
+                    if not data_args.streaming:
+                        dataset.save_to_disk(str(meds_extension_path))
+                        stats = dataset.cleanup_cache_files()
+                        LOG.info(
+                            "Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
+                            stats,
                         )
-
+                        # Clean up the files created from the data generator
+                        cache_file_collector.remove_cache_files()
+                        dataset = load_from_disk(str(meds_extension_path))
+            else:
+                # Load the dataset from the parquet files
+                dataset = load_parquet_as_dataset(
+                    os.path.expanduser(data_args.data_folder),
+                    split="train",
+                    streaming=data_args.streaming,
                 )
-
-
-
-
-                "Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
-                stats,
+            # If streaming is enabled, we need to manually split the data into train/val
+            if data_args.streaming and data_args.validation_split_num:
+                dataset = dataset.shuffle(
+                    buffer_size=10_000, seed=training_args.seed
                 )
-
-
-
-
-
-
-
+                train_set = dataset.skip(data_args.validation_split_num)
+                val_set = dataset.take(data_args.validation_split_num)
+                dataset = DatasetDict({"train": train_set, "validation": val_set})
+            elif data_args.validation_split_percentage:
+                dataset = dataset.train_test_split(
+                    test_size=data_args.validation_split_percentage,
+                    seed=training_args.seed,
+                )
+                dataset = DatasetDict(
+                    {"train": dataset["train"], "validation": dataset["test"]}
+                )
+            else:
+                raise RuntimeError(
+                    f"Can not split the data. If streaming is enabled, validation_split_num needs to be "
+                    f"defined, otherwise validation_split_percentage needs to be provided. "
+                    f"The current values are:\n"
+                    f"validation_split_percentage: {data_args.validation_split_percentage}\n"
+                    f"validation_split_num: {data_args.validation_split_num}\n"
+                    f"streaming: {data_args.streaming}"
+                )
+
+            # Create the CEHR-GPT tokenizer if it's not available in the output folder
+            cehrgpt_tokenizer = load_and_create_tokenizer(
+                data_args=data_args,
+                model_args=model_args,
+                cehrgpt_args=cehrgpt_args,
+                dataset=dataset,
             )
-        # If streaming is enabled, we need to manually split the data into train/val
-        if data_args.streaming and data_args.validation_split_num:
-            dataset = dataset.shuffle(buffer_size=10_000, seed=training_args.seed)
-            train_set = dataset.skip(data_args.validation_split_num)
-            val_set = dataset.take(data_args.validation_split_num)
-            dataset = DatasetDict({"train": train_set, "test": val_set})
-        elif data_args.validation_split_percentage:
-            dataset = dataset.train_test_split(
-                test_size=data_args.validation_split_percentage,
-                seed=training_args.seed,
-            )
-        else:
-            raise RuntimeError(
-                f"Can not split the data. If streaming is enabled, validation_split_num needs to be "
-                f"defined, otherwise validation_split_percentage needs to be provided. "
-                f"The current values are:\n"
-                f"validation_split_percentage: {data_args.validation_split_percentage}\n"
-                f"validation_split_num: {data_args.validation_split_num}\n"
-                f"streaming: {data_args.streaming}"
-            )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            # Retrain the tokenizer in case we want to pretrain the model further using different datasets
+            if cehrgpt_args.expand_tokenizer:
+                new_tokenizer_path = os.path.expanduser(training_args.output_dir)
+                try:
+                    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
+                        new_tokenizer_path
+                    )
+                except Exception:
+                    cehrgpt_tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
+                        cehrgpt_tokenizer=cehrgpt_tokenizer,
+                        dataset=dataset["train"],
+                        data_args=data_args,
+                        concept_name_mapping={},
+                        pretrained_concept_embedding_model=PretrainedEmbeddings(
+                            cehrgpt_args.pretrained_embedding_path
+                        ),
+                        apply_entropy_filter=cehrgpt_args.apply_entropy_filter,
+                        min_prevalence=cehrgpt_args.min_prevalence,
+                    )
+                cehrgpt_tokenizer.save_pretrained(
+                    os.path.expanduser(training_args.output_dir)
+                )
+
+            # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
+            if not data_args.streaming:
+                all_columns = dataset["train"].column_names
+                if "visit_concept_ids" in all_columns:
+                    dataset = dataset.remove_columns(["visit_concept_ids"])
+
+            # sort the patient features chronologically and tokenize the data
+            processed_dataset = create_cehrgpt_pretraining_dataset(
+                dataset=dataset,
+                cehrgpt_tokenizer=cehrgpt_tokenizer,
+                data_args=data_args,
+                cache_file_collector=cache_file_collector,
+            )
+            # only save the data to the disk if it is not streaming
+            if not data_args.streaming:
+                processed_dataset.save_to_disk(str(prepared_ds_path))
+                stats = processed_dataset.cleanup_cache_files()
+                LOG.info(
+                    "Clean up the cached files for the cehrgpt pretraining dataset: %s",
+                    stats,
                 )
+                cache_file_collector.remove_cache_files()
+
+        # After main-process-only operations, synchronize all processes to ensure consistency
+        if dist.is_available() and dist.is_initialized():
+            dist.barrier()
 
-    #
-
-
+        # Loading tokenizer in all processes in torch distributed training
+        tokenizer_name_or_path = os.path.expanduser(
+            training_args.output_dir
+            if cehrgpt_args.expand_tokenizer
+            else model_args.tokenizer_name_or_path
         )
-
+        cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_name_or_path)
+        # Load the dataset from disk again to in torch distributed training
         if not data_args.streaming:
-            processed_dataset.save_to_disk(str(prepared_ds_path))
-            stats = processed_dataset.cleanup_cache_files()
-            LOG.info(
-                "Clean up the cached files for the cehrgpt pretraining dataset: %s",
-                stats,
-            )
             processed_dataset = load_from_disk(str(prepared_ds_path))
 
+    if processed_dataset is None:
+        raise RuntimeError("The processed dataset cannot be None")
+
     def filter_func(examples):
         if cehrgpt_args.drop_long_sequences:
             return [
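The restructured block above follows the usual prepare-on-rank-0 pattern for torch distributed training: only the main process builds the tokenizer and the tokenized dataset (avoiding concurrent cache writes), every process then blocks on `dist.barrier()`, and afterwards all ranks reload the finished artifacts from disk. Reduced to a skeleton (the `prepare` and `load` callables are placeholders, not package functions):

```python
import torch.distributed as dist
from transformers.trainer_utils import is_main_process

def prepare_once_then_load_everywhere(local_rank, prepare, load):
    # Rank 0 (or local_rank == -1 in non-distributed runs) does the
    # expensive, cache-writing work exactly once ...
    if is_main_process(local_rank):
        prepare()
    # ... every process waits here so nobody reads half-written caches ...
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    # ... and then all ranks load the same artifacts from disk.
    return load()
```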
@@ -339,9 +468,11 @@ def main():
     else:
         processed_dataset = processed_dataset.filter(filter_func, **filter_args)
 
-    model = load_and_create_model(
-
-
+    model = load_and_create_model(model_args, cehrgpt_args, cehrgpt_tokenizer)
+
+    # Try to update motor tte vocab size if the new configuration is different from the existing one
+    if cehrgpt_args.include_motor_time_to_event:
+        model.update_motor_tte_vocab_size(cehrgpt_tokenizer.motor_tte_vocab_size)
 
     # Expand tokenizer to adapt to the new pretraining dataset
     if model.config.vocab_size < cehrgpt_tokenizer.vocab_size:
@@ -369,22 +500,67 @@
     # Set seed before initializing model.
     set_seed(training_args.seed)
 
-    if not data_args.streaming:
+    if not data_args.streaming and not cehrgpt_args.sample_packing:
         processed_dataset.set_format("pt")
 
-
+    callbacks = []
+    if cehrgpt_args.use_early_stopping:
+        callbacks.append(
+            CustomEarlyStoppingCallback(
+                model_args.early_stopping_patience,
+                cehrgpt_args.early_stopping_threshold,
+            )
+        )
+
+    if cehrgpt_args.sample_packing:
+        trainer_class = partial(
+            SamplePackingTrainer,
+            max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
+            max_position_embeddings=model_args.max_position_embeddings,
+            train_lengths=processed_dataset["train"]["num_of_concepts"],
+            validation_lengths=(
+                processed_dataset["validation"]
+                if "validation" in processed_dataset
+                else processed_dataset["test"]
+            )["num_of_concepts"],
+        )
+        training_args.per_device_train_batch_size = 1
+        training_args.per_device_eval_batch_size = 1
+        data_collator_fn = partial(
+            SamplePackingCehrGptDataCollator,
+            cehrgpt_args.max_tokens_per_batch,
+            model_args.max_position_embeddings,
+            add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
+        )
+    else:
+        trainer_class = Trainer
+        data_collator_fn = CehrGptDataCollator
+
+    trainer = trainer_class(
         model=model,
-        data_collator=
+        data_collator=data_collator_fn(
             tokenizer=cehrgpt_tokenizer,
-            max_length=
+            max_length=(
+                cehrgpt_args.max_tokens_per_batch
+                if cehrgpt_args.sample_packing
+                else model_args.max_position_embeddings
+            ),
             shuffle_records=data_args.shuffle_records,
             include_ttv_prediction=model_args.include_ttv_prediction,
             use_sub_time_tokenization=model_args.use_sub_time_tokenization,
             include_values=model_args.include_values,
+            include_motor_time_to_event=cehrgpt_args.include_motor_time_to_event,
+            motor_tte_vocab_size=model.config.motor_tte_vocab_size,
+            motor_num_time_pieces=cehrgpt_args.motor_num_time_pieces,
         ),
         train_dataset=processed_dataset["train"],
-        eval_dataset=
+        eval_dataset=(
+            processed_dataset["validation"]
+            if "validation" in processed_dataset
+            else processed_dataset["test"]
+        ),
         args=training_args,
+        callbacks=callbacks,
     )
 
     checkpoint = None
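Sample packing, wired up above via `SamplePackingTrainer` and `SamplePackingCehrGptDataCollator` with `per_device_train_batch_size`/`per_device_eval_batch_size` forced to 1, replaces fixed-size batches with a fixed token budget: whole sequences are grouped until `max_tokens_per_batch` is reached, so short patient sequences waste far less padding. A greedy sketch of the packing idea only (not the actual implementation in `cehrgpt/runners/sample_packing_trainer.py`):

```python
def pack_by_token_budget(lengths, max_tokens_per_batch):
    # Greedily group sequence indices until adding the next sequence would
    # overflow the token budget, then emit the group as one packed "batch".
    batch, used = [], 0
    for idx, n in enumerate(lengths):
        if used + n > max_tokens_per_batch and batch:
            yield batch
            batch, used = [], 0
        batch.append(idx)
        used += n
    if batch:
        yield batch

# lengths per sequence, as the trainer receives via train_lengths
# (the "num_of_concepts" column)
print(list(pack_by_token_budget([512, 300, 900, 200], 1024)))
# [[0, 1], [2], [3]]
```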
cehrgpt/runners/hf_gpt_runner_argument_dataclass.py
@@ -6,6 +6,12 @@ from typing import List, Optional
 class CehrGPTArguments:
     """Arguments pertaining to what data we are going to input our model for training and eval."""
 
+    tokenized_full_dataset_path: Optional[str] = dataclasses.field(
+        default=None,
+        metadata={
+            "help": "The path to the tokenized dataset created for the full population"
+        },
+    )
     include_inpatient_hour_token: Optional[bool] = dataclasses.field(
         default=True,
         metadata={"help": "Include inpatient hour token"},
@@ -115,6 +121,9 @@ class CehrGPTArguments:
             "help": "The lower bound of the learning rate range for hyperparameter tuning."
         },
     )
+    next_token_prediction_loss_weight: float = dataclasses.field(
+        default=1.0, metadata={"help": "The weight of the next token prediction loss"}
+    )
     lab_token_penalty: Optional[bool] = dataclasses.field(
         default=False,
         metadata={
@@ -125,6 +134,10 @@ class CehrGPTArguments:
         default=1.0,
         metadata={"help": "lab_token_loss_weight penalty co-efficient"},
     )
+    value_prediction_loss_weight: Optional[float] = dataclasses.field(
+        default=1.0,
+        metadata={"help": "The weight of the value prediction loss"},
+    )
     entropy_penalty: Optional[bool] = dataclasses.field(
         default=False,
         metadata={"help": "A flag to indicate whether we want to use entropy penalty."},
@@ -139,3 +152,80 @@ class CehrGPTArguments:
             "help": "The number of feed forward layers for transforming pretrained embeddings to internal embeddings"
         },
     )
+    meds_repartition: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to repartition the meds train tune sets"
+        },
+    )
+    use_early_stopping: Optional[bool] = dataclasses.field(
+        default=True,
+        metadata={"help": "A flag to indicate whether we want to use early stopping."},
+    )
+    early_stopping_threshold: Optional[float] = dataclasses.field(
+        default=0.01,
+        metadata={
+            "help": "A threshold to denote how much the specified metric must improve to satisfy early stopping conditions."
+        },
+    )
+    sample_packing: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to use sample packing for efficient training."
+        },
+    )
+    max_tokens_per_batch: int = dataclasses.field(
+        default=16384, metadata={"help": "Maximum number of tokens in each batch"}
+    )
+    add_end_token_in_sample_packing: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to add end token in sample packing"
+        },
+    )
+    include_motor_time_to_event: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to include the motor time to events"
+        },
+    )
+    num_motor_tasks: Optional[int] = dataclasses.field(
+        default=10000,
+        metadata={"help": "The number of max MOTOR tasks"},
+    )
+    motor_time_to_event_weight: Optional[float] = dataclasses.field(
+        default=1.0,
+        metadata={"help": "The MOTOR time to event loss weight"},
+    )
+    motor_num_time_pieces: Optional[int] = dataclasses.field(
+        default=8,
+        metadata={
+            "help": "The number of times each motor_num_time_pieces piece has to be"
+        },
+    )
+    concept_dir: Optional[str] = dataclasses.field(
+        default=None,
+        metadata={"help": "The directory where the concept data is stored."},
+    )
+    average_over_sequence: bool = dataclasses.field(
+        default=False,
+        metadata={"help": "Whether or not to average tokens per sequence"},
+    )
+    apply_entropy_filter: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={"help": "A flag to indicate whether we want to use entropy filter."},
+    )
+    min_prevalence: Optional[float] = dataclasses.field(
+        default=1 / 1000,
+        metadata={"help": "The min_prevalence to keep the concepts in the tokenizer"},
+    )
+    class_weights: Optional[List[int]] = dataclasses.field(
+        default=None,
+        metadata={"help": "The class weights for training"},
+    )
+    negative_sampling_probability: Optional[float] = dataclasses.field(
+        default=None,
+        metadata={
+            "help": "The probability of negative samples will be included in the training data"
+        },
+    )
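Assuming `parse_runner_args` routes `CehrGPTArguments` through transformers' `HfArgumentParser`, as the cehrbert runners do, each field above surfaces as a command-line flag. A hypothetical parse of the new flags (the argument values are illustrative):

```python
from transformers import HfArgumentParser

from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments

# Every dataclass field becomes a --flag; boolean fields accept true/false.
parser = HfArgumentParser(CehrGPTArguments)
(cehrgpt_args,) = parser.parse_args_into_dataclasses(
    args=[
        "--sample_packing", "true",
        "--max_tokens_per_batch", "16384",
        "--include_motor_time_to_event", "true",
        "--num_motor_tasks", "10000",
    ]
)
print(cehrgpt_args.sample_packing, cehrgpt_args.max_tokens_per_batch)
# True 16384
```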