cehrgpt 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +1 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +454 -68
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +232 -17
- cehrgpt/data/sample_packing_sampler.py +36 -6
- cehrgpt/generation/cehrgpt_conditional_generation.py +314 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +15 -3
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +25 -0
- cehrgpt/models/hf_cehrgpt.py +244 -39
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +354 -71
- cehrgpt/runners/data_utils.py +131 -5
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +84 -51
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +59 -7
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +60 -0
- cehrgpt/runners/hyperparameter_search_util.py +6 -7
- cehrgpt/runners/sample_packing_trainer.py +17 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +80 -62
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/METADATA +102 -7
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/RECORD +29 -26
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/WHEEL +1 -1
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/top_level.txt +0 -0
cehrgpt/runners/data_utils.py
CHANGED
@@ -1,4 +1,10 @@
+import os
+from datetime import datetime
+from typing import Dict, List, Optional, Union
+
 import numpy as np
+import polars as pl
+import torch
 from cehrbert.data_generators.hf_data_generator.cache_util import CacheFileCollector
 from cehrbert.data_generators.hf_data_generator.meds_utils import (
     create_dataset_from_meds_reader,
@@ -12,17 +18,36 @@ from datasets import DatasetDict, concatenate_datasets, load_from_disk
 from transformers import TrainingArguments
 from transformers.utils import logging

-from cehrgpt.data.hf_cehrgpt_dataset_mapping import
+from cehrgpt.data.hf_cehrgpt_dataset_mapping import (
+    ExtractTokenizedSequenceDataMapping,
+    MedToCehrGPTDatasetMapping,
+)
 from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments

 LOG = logging.get_logger("transformers")


+def get_torch_dtype(torch_dtype: Optional[str] = None) -> Union[torch.dtype, str]:
+    if torch_dtype and hasattr(torch, torch_dtype):
+        return getattr(torch, torch_dtype)
+    return torch.float
+
+
+def data_collate_fn(features, model_type: torch.dtype, collator):
+    batch = collator(features)
+    if model_type != torch.float32:
+        for key, value in batch.items():
+            # Only convert float32 tensors to bfloat16
+            if isinstance(value, torch.Tensor) and value.dtype == torch.float32:
+                batch[key] = value.to(model_type)
+    return batch
+
+
 def prepare_finetune_dataset(
     data_args: DataTrainingArguments,
     training_args: TrainingArguments,
     cehrgpt_args: CehrGPTArguments,
-    cache_file_collector: CacheFileCollector,
+    cache_file_collector: Optional[CacheFileCollector] = None,
 ) -> DatasetDict:
     # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
     if data_args.is_data_in_meds:
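
For context, a minimal sketch (not part of this diff) of how the new get_torch_dtype and data_collate_fn helpers could be wired into a PyTorch DataLoader; build_loader and its arguments are hypothetical, only the two imported helpers come from the change above:

    from functools import partial

    import torch
    from torch.utils.data import DataLoader

    from cehrgpt.runners.data_utils import data_collate_fn, get_torch_dtype

    def build_loader(dataset, collator, torch_dtype_name="bfloat16", batch_size=4):
        # get_torch_dtype resolves a dtype name such as "bfloat16", falling back to torch.float
        dtype = get_torch_dtype(torch_dtype_name)
        # data_collate_fn runs the collator first, then casts float32 tensors to the requested dtype
        return DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=partial(data_collate_fn, model_type=dtype, collator=collator),
        )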
@@ -66,8 +91,9 @@ def prepare_finetune_dataset(
             "Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
             stats,
         )
-
-
+        if cache_file_collector:
+            # Clean up the files created from the data generator
+            cache_file_collector.remove_cache_files()
         dataset = load_from_disk(str(meds_extension_path))

         train_set = dataset["train"]
@@ -219,7 +245,7 @@ def create_dataset_splits(data_args: DataTrainingArguments, seed: int):
     )

     # Generate splits
-    train_set = filter_by_patient_ids(train_patient_ids)
+    train_set = filter_by_patient_ids(train_patient_ids).shuffle(seed=seed)
     validation_set = filter_by_patient_ids(val_patient_ids)
     if test_set is None:
         test_set = filter_by_patient_ids(test_patient_ids)
@@ -241,3 +267,103 @@ def create_dataset_splits(data_args: DataTrainingArguments, seed: int):
     )

     return train_set, validation_set, test_set
+
+
+def extract_cohort_sequences(
+    data_args: DataTrainingArguments,
+    cehrgpt_args: CehrGPTArguments,
+    cache_file_collector: Optional[CacheFileCollector] = None,
+) -> DatasetDict:
+    """
+    Extracts and processes cohort-specific tokenized sequences from a pre-tokenized dataset,.
+
+    based on the provided cohort Parquet files and observation window constraints.
+
+    This function performs the following steps:
+    1. Loads cohort definitions from Parquet files located in `data_args.cohort_folder`.
+    2. Renames relevant columns if the data originates from a Meds format.
+    3. Filters a pre-tokenized dataset (loaded from `cehrgpt_args.tokenized_full_dataset_path`)
+       to include only patients present in the cohort.
+    4. Aggregates each person's index date and label into a mapping.
+    5. Checks for consistency to ensure all cohort person_ids are present in the tokenized dataset.
+    6. Applies a transformation (`ExtractTokenizedSequenceDataMapping`) to generate
+       observation-window-constrained patient sequences.
+    7. Caches both the filtered and processed datasets using the provided `cache_file_collector`.
+
+    Args:
+        data_args (DataTrainingArguments): Configuration parameters for data processing,
+            including cohort folder, observation window, batch size, and parallelism.
+        cehrgpt_args (CehrGPTArguments): Contains paths to pre-tokenized datasets and CEHR-GPT-specific arguments.
+        cache_file_collector (CacheFileCollector): Utility to register and manage dataset cache files.
+
+    Returns:
+        DatasetDict: A Hugging Face `DatasetDict` containing the processed datasets (e.g., train/validation/test),
+            where each entry includes sequences filtered and truncated by the observation window.
+
+    Raises:
+        RuntimeError: If any `person_id` in the cohort is missing from the tokenized dataset.
+    """
+
+    cohort = pl.read_parquet(os.path.join(data_args.cohort_folder, "*.parquet"))
+    if data_args.is_data_in_meds:
+        cohort = cohort.rename(
+            mapping={
+                "prediction_time": "index_date",
+                "subject_id": "person_id",
+                "boolean_value": "label",
+            }
+        )
+    all_person_ids = cohort["person_id"].unique().to_list()
+    # In case the label column does not exist, we add a fake column to the dataframe so subsequent process can work
+    if "label" not in cohort.columns:
+        cohort = cohort.with_columns(
+            pl.Series(
+                name="label", values=np.zeros_like(cohort["person_id"].to_numpy())
+            )
+        )
+
+    # data_args.observation_window
+    tokenized_dataset = load_from_disk(cehrgpt_args.tokenized_full_dataset_path)
+    filtered_tokenized_dataset = tokenized_dataset.filter(
+        lambda batch: [person_id in all_person_ids for person_id in batch["person_id"]],
+        batched=True,
+        batch_size=data_args.preprocessing_batch_size,
+        num_proc=data_args.preprocessing_num_workers,
+    )
+    person_index_date_agg = cohort.group_by("person_id").agg(
+        pl.struct("index_date", "label").alias("index_date_label")
+    )
+    # Convert to dictionary
+    person_index_date_map: Dict[int, List[datetime]] = dict(
+        zip(
+            person_index_date_agg["person_id"].to_list(),
+            person_index_date_agg["index_date_label"].to_list(),
+        )
+    )
+    LOG.info(f"person_index_date_agg: {person_index_date_agg}")
+    tokenized_person_ids = []
+    for _, dataset in filtered_tokenized_dataset.items():
+        tokenized_person_ids.extend(dataset["person_id"])
+    missing_person_ids = [
+        person_id
+        for person_id in person_index_date_map.keys()
+        if person_id not in tokenized_person_ids
+    ]
+    if missing_person_ids:
+        raise RuntimeError(
+            f"There are {len(missing_person_ids)} missing in the tokenized dataset. "
+            f"The list contains: {missing_person_ids}"
+        )
+    processed_dataset = filtered_tokenized_dataset.map(
+        ExtractTokenizedSequenceDataMapping(
+            person_index_date_map, data_args.observation_window
+        ).batch_transform,
+        batched=True,
+        batch_size=data_args.preprocessing_batch_size,
+        num_proc=data_args.preprocessing_num_workers,
+        remove_columns=filtered_tokenized_dataset["train"].column_names,
+    )
+    if cache_file_collector:
+        cache_file_collector.add_cache_files(filtered_tokenized_dataset)
+        cache_file_collector.add_cache_files(processed_dataset)
+    return processed_dataset
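
As a self-contained illustration of the polars aggregation used in extract_cohort_sequences, the toy example below (made-up person_ids and dates) shows how each person_id ends up mapped to a list of {index_date, label} structs:

    from datetime import datetime

    import polars as pl

    cohort = pl.DataFrame(
        {
            "person_id": [1, 1, 2],
            "index_date": [
                datetime(2020, 1, 1),
                datetime(2021, 6, 1),
                datetime(2019, 3, 15),
            ],
            "label": [0, 1, 0],
        }
    )
    # Same group_by/struct pattern as in the function above
    agg = cohort.group_by("person_id").agg(
        pl.struct("index_date", "label").alias("index_date_label")
    )
    person_index_date_map = dict(
        zip(agg["person_id"].to_list(), agg["index_date_label"].to_list())
    )
    # e.g. {1: [{'index_date': ..., 'label': 0}, {'index_date': ..., 'label': 1}], 2: [...]}
    print(person_index_date_map)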
cehrgpt/runners/hf_cehrgpt_finetune_runner.py
CHANGED
@@ -50,7 +50,11 @@ from cehrgpt.models.hf_cehrgpt import (
 )
 from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
-from cehrgpt.runners.data_utils import
+from cehrgpt.runners.data_utils import (
+    extract_cohort_sequences,
+    get_torch_dtype,
+    prepare_finetune_dataset,
+)
 from cehrgpt.runners.gpt_runner_util import parse_runner_args
 from cehrgpt.runners.hf_cehrgpt_pretrain_runner import tokenizer_exists
 from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
@@ -142,11 +146,10 @@ def load_finetuned_model(
         raise ValueError(
             f"finetune_model_type can be one of the following types {FineTuneModelType.POOLING.value}"
         )
-
     attn_implementation = (
         "flash_attention_2" if is_flash_attn_2_available() else "eager"
     )
-    torch_dtype =
+    torch_dtype = get_torch_dtype(model_args.torch_dtype)
     # Try to create a new model based on the base model
     try:
         return finetune_model_cls.from_pretrained(
@@ -161,11 +164,22 @@ def load_finetuned_model(
 def model_init(
     model_args: ModelArguments,
     training_args: TrainingArguments,
+    cehrgpt_args: CehrGPTArguments,
     tokenizer: CehrGptTokenizer,
 ):
     model = load_finetuned_model(
         model_args, training_args, model_args.model_name_or_path
     )
+
+    if cehrgpt_args.class_weights:
+        model.config.class_weights = cehrgpt_args.class_weights
+        LOG.info(f"Setting class_weights to {model.config.class_weights}")
+
+    # Enable position embeddings when position embeddings are disabled in pre-training
+    if not model_args.exclude_position_ids and model.cehrgpt.exclude_position_ids:
+        LOG.info(f"Enable the position_embeddings")
+        model.cehrgpt.enable_position_embeddings()
+
     if model.config.max_position_embeddings < model_args.max_position_embeddings:
         LOG.info(
             f"Increase model.config.max_position_embeddings to {model_args.max_position_embeddings}"
@@ -175,9 +189,6 @@ def model_init(
     # Enable include_values when include_values is set to be False during pre-training
     if model_args.include_values and not model.cehrgpt.include_values:
         model.cehrgpt.include_values = True
-    # Enable position embeddings when position embeddings are disabled in pre-training
-    if not model_args.exclude_position_ids and model.cehrgpt.exclude_position_ids:
-        model.cehrgpt.exclude_position_ids = False
     # Expand tokenizer to adapt to the finetuning dataset
     if model.config.vocab_size < tokenizer.vocab_size:
         model.resize_token_embeddings(tokenizer.vocab_size)
@@ -195,6 +206,7 @@ def model_init(
         model.cehrgpt.update_pretrained_embeddings(
             tokenizer.pretrained_token_ids, tokenizer.pretrained_embeddings
         )
+
     # Expand value tokenizer to adapt to the fine-tuning dataset
     if model.config.include_values:
         if model.config.value_vocab_size < tokenizer.value_vocab_size:
@@ -252,46 +264,55 @@ def main():

     if processed_dataset is None:
         if is_main_process(training_args.local_rank):
-
-
-
-
-
-
-
-
-
-
-
-
-            )
-
-
-
+            # If the full dataset has been tokenized, we don't want to tokenize the cohort containing
+            # the subset of the data. We should slice out the portion of the tokenized sequences for each sample
+            if cehrgpt_args.tokenized_full_dataset_path is not None:
+                processed_dataset = extract_cohort_sequences(
+                    data_args, cehrgpt_args, cache_file_collector
+                )
+            else:
+                final_splits = prepare_finetune_dataset(
+                    data_args, training_args, cehrgpt_args, cache_file_collector
+                )
+                if cehrgpt_args.expand_tokenizer:
+                    new_tokenizer_path = os.path.expanduser(training_args.output_dir)
+                    if tokenizer_exists(new_tokenizer_path):
+                        tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
+                    else:
+                        # Try to use the defined pretrained embeddings if exists, Otherwise we default to the pretrained model
+                        # embedded in the pretrained model
+                        pretrained_concept_embedding_model = PretrainedEmbeddings(
+                            cehrgpt_args.pretrained_embedding_path
                         )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                        if not pretrained_concept_embedding_model.exists:
+                            pretrained_concept_embedding_model = (
+                                tokenizer.pretrained_concept_embedding_model
+                            )
+                    tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
+                        cehrgpt_tokenizer=tokenizer,
+                        dataset=final_splits["train"],
+                        data_args=data_args,
+                        concept_name_mapping={},
+                        pretrained_concept_embedding_model=pretrained_concept_embedding_model,
+                    )
+                    tokenizer.save_pretrained(
+                        os.path.expanduser(training_args.output_dir)
+                    )
+
+                # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
+                if not data_args.streaming:
+                    all_columns = final_splits["train"].column_names
+                    if "visit_concept_ids" in all_columns:
+                        final_splits = final_splits.remove_columns(
+                            ["visit_concept_ids"]
+                        )
+
+                processed_dataset = create_cehrgpt_finetuning_dataset(
+                    dataset=final_splits,
+                    cehrgpt_tokenizer=tokenizer,
+                    data_args=data_args,
+                    cache_file_collector=cache_file_collector,
+                )
             if not data_args.streaming:
                 processed_dataset.save_to_disk(str(prepared_ds_path))
                 stats = processed_dataset.cleanup_cache_files()
@@ -350,8 +371,7 @@ def main():
             SamplePackingTrainer,
             max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
             max_position_embeddings=config.max_position_embeddings,
-
-            validation_lengths=processed_dataset["validation"]["num_of_concepts"],
+            negative_sampling_probability=cehrgpt_args.negative_sampling_probability,
         )
         training_args.per_device_train_batch_size = 1
         training_args.per_device_eval_batch_size = 1
@@ -359,6 +379,7 @@ def main():
             SamplePackingCehrGptDataCollator,
             cehrgpt_args.max_tokens_per_batch,
             config.max_position_embeddings,
+            add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
         )
     else:
         trainer_class = Trainer
@@ -381,13 +402,14 @@ def main():
         include_ttv_prediction=False,
         use_sub_time_tokenization=False,
         include_demographics=cehrgpt_args.include_demographics,
+        add_linear_prob_token=True,
     )

     if training_args.do_train:
         if cehrgpt_args.hyperparameter_tuning:
             training_args = perform_hyperparameter_search(
                 trainer_class,
-                partial(model_init, model_args, training_args, tokenizer),
+                partial(model_init, model_args, training_args, cehrgpt_args, tokenizer),
                 processed_dataset,
                 data_collator,
                 training_args,
@@ -401,6 +423,7 @@ def main():
                 trainer_class,
                 model_args,
                 training_args,
+                cehrgpt_args,
                 tokenizer,
                 processed_dataset,
                 data_collator,
@@ -408,7 +431,7 @@ def main():
         else:
             # Initialize Trainer for final training on the combined train+val set
             trainer = trainer_class(
-                model=model_init(model_args, training_args, tokenizer),
+                model=model_init(model_args, training_args, cehrgpt_args, tokenizer),
                 data_collator=data_collator,
                 args=training_args,
                 train_dataset=processed_dataset["train"],
@@ -457,6 +480,7 @@ def retrain_with_full_set(
     trainer_class,
     model_args: ModelArguments,
     training_args: TrainingArguments,
+    cehrgpt_args: CehrGPTArguments,
     tokenizer: CehrGptTokenizer,
     dataset: DatasetDict,
     data_collator: CehrGptDataCollator,
@@ -475,6 +499,7 @@ def retrain_with_full_set(
         model_args (ModelArguments): Model configuration and hyperparameters.
         training_args (TrainingArguments): Training configuration, including output directory,
             evaluation strategy, and other training parameters.
+        cehrgpt_args (CehrGPTArguments): CehrGPT specific parameters.
         tokenizer (CehrGptTokenizer): Tokenizer instance specific to CEHR-GPT.
         dataset (DatasetDict): A dictionary containing the 'train' and 'validation' datasets.
         data_collator (CehrGptDataCollator): Data collator for handling data batching and tokenization.
@@ -494,7 +519,7 @@ def retrain_with_full_set(
     training_args.evaluation_strategy = "no"
     checkpoint = get_last_hf_checkpoint(training_args)
     final_trainer = trainer_class(
-        model=model_init(model_args, training_args, tokenizer),
+        model=model_init(model_args, training_args, cehrgpt_args, tokenizer),
         data_collator=data_collator,
         args=training_args,
         train_dataset=full_dataset,
@@ -555,7 +580,15 @@ def do_predict(
         index_dates = batch.pop("index_date").numpy().squeeze()
         if index_dates.ndim == 0:
             index_dates = np.asarray([index_dates])
-
+
+        index_dates = list(
+            map(
+                lambda posix_time: datetime.utcfromtimestamp(posix_time).replace(
+                    tzinfo=None
+                ),
+                index_dates.tolist(),
+            )
+        )

         batch = {k: v.to(device) for k, v in batch.items()}
         # Forward pass
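
The do_predict change above converts the pooled POSIX timestamps back into timezone-naive datetimes; a standalone illustration with made-up timestamps:

    from datetime import datetime

    import numpy as np

    index_dates = np.asarray([1577836800.0, 1609459200.0])  # 2020-01-01 and 2021-01-01 UTC
    converted = [
        datetime.utcfromtimestamp(t).replace(tzinfo=None) for t in index_dates.tolist()
    ]
    print(converted)  # [datetime.datetime(2020, 1, 1, 0, 0), datetime.datetime(2021, 1, 1, 0, 0)]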
cehrgpt/runners/hf_cehrgpt_pretrain_runner.py
CHANGED
@@ -34,6 +34,7 @@ from cehrgpt.models.config import CEHRGPTConfig
 from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
 from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+from cehrgpt.runners.data_utils import get_torch_dtype
 from cehrgpt.runners.gpt_runner_util import parse_runner_args
 from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
 from cehrgpt.runners.sample_packing_trainer import SamplePackingTrainer
@@ -71,6 +72,36 @@ def load_and_create_tokenizer(
     cehrgpt_args: CehrGPTArguments,
     dataset: Optional[Union[Dataset, DatasetDict]] = None,
 ) -> CehrGptTokenizer:
+
+    concept_name_mapping = {}
+    allowed_motor_codes = list()
+    if cehrgpt_args.concept_dir:
+        import pandas as pd
+        from cehrbert_data.const.artificial_tokens import DEATH_TOKEN
+        from meds.schema import death_code
+
+        LOG.info("Loading concept data from disk at %s", cehrgpt_args.concept_dir)
+        concept_pd = pd.read_parquet(cehrgpt_args.concept_dir)
+        LOG.info(
+            "Creating concept name mapping and motor_time_to_event_codes from disk at %s",
+            cehrgpt_args.concept_dir,
+        )
+        for row in concept_pd.itertuples():
+            concept_name_mapping[str(getattr(row, "concept_id"))] = getattr(
+                row, "concept_name"
+            )
+            if (
+                cehrgpt_args.include_motor_time_to_event
+                and getattr(row, "domain_id")
+                in ["Condition", "Procedure", "Drug", "Visit"]
+                and getattr(row, "standard_concept") == "S"
+            ):
+                allowed_motor_codes.append(str(getattr(row, "concept_id")))
+        LOG.info(
+            "Adding death codes for MOTOR TTE predictions: %s",
+            [DEATH_TOKEN, death_code],
+        )
+        allowed_motor_codes.extend([DEATH_TOKEN, death_code])
     # Try to load the pretrained tokenizer
     tokenizer_abspath = os.path.expanduser(model_args.tokenizer_name_or_path)
     try:
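
A self-contained sketch of the concept-table scan above, using a toy OMOP-style concept DataFrame in place of cehrgpt_args.concept_dir (the concept rows are illustrative):

    import pandas as pd

    concept_pd = pd.DataFrame(
        {
            "concept_id": [313217, 4329847],
            "concept_name": ["Atrial fibrillation", "Myocardial infarction"],
            "domain_id": ["Condition", "Condition"],
            "standard_concept": ["S", "S"],
        }
    )
    concept_name_mapping = {}
    allowed_motor_codes = []
    for row in concept_pd.itertuples():
        concept_name_mapping[str(getattr(row, "concept_id"))] = getattr(row, "concept_name")
        # Standard clinical-event concepts become candidate MOTOR time-to-event codes
        if (
            getattr(row, "domain_id") in ["Condition", "Procedure", "Drug", "Visit"]
            and getattr(row, "standard_concept") == "S"
        ):
            allowed_motor_codes.append(str(getattr(row, "concept_id")))
    print(concept_name_mapping)
    print(allowed_motor_codes)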
@@ -85,9 +116,17 @@ def load_and_create_tokenizer(
         LOG.info("Started training the tokenizer ...")
         tokenizer = CehrGptTokenizer.train_tokenizer(
             dataset,
-
+            concept_name_mapping,
             data_args,
             PretrainedEmbeddings(cehrgpt_args.pretrained_embedding_path),
+            allowed_motor_codes if cehrgpt_args.include_motor_time_to_event else None,
+            (
+                cehrgpt_args.num_motor_tasks
+                if cehrgpt_args.include_motor_time_to_event
+                else None
+            ),
+            apply_entropy_filter=cehrgpt_args.apply_entropy_filter,
+            min_prevalence=cehrgpt_args.min_prevalence,
         )
         LOG.info("Finished training the tokenizer ...")
         tokenizer.save_pretrained(tokenizer_abspath)
@@ -99,13 +138,12 @@ def load_and_create_tokenizer(
 def load_and_create_model(
     model_args: ModelArguments,
     cehrgpt_args: CehrGPTArguments,
-    training_args: TrainingArguments,
     tokenizer: CehrGptTokenizer,
 ) -> CEHRGPT2LMHeadModel:
     attn_implementation = (
         "flash_attention_2" if is_flash_attn_2_available() else "eager"
     )
-    torch_dtype =
+    torch_dtype = get_torch_dtype(model_args.torch_dtype)
     model_abspath = os.path.expanduser(model_args.model_name_or_path)
     if cehrgpt_args.continue_pretrain:
         try:
@@ -147,6 +185,8 @@ def load_and_create_model(
         else:
             pretrained_embedding_dim = model_args.hidden_size

+        model_args_cehrgpt = model_args.as_dict()
+        model_args_cehrgpt.pop("attn_implementation")
         model_config = CEHRGPTConfig(
             vocab_size=tokenizer.vocab_size,
             value_vocab_size=tokenizer.value_vocab_size,
@@ -172,7 +212,12 @@ def load_and_create_model(
                 if cehrgpt_args.sample_packing
                 else model_args.max_position_embeddings
             ),
-
+            include_motor_time_to_event=cehrgpt_args.include_motor_time_to_event,
+            motor_tte_vocab_size=tokenizer.motor_tte_vocab_size,
+            motor_time_to_event_weight=cehrgpt_args.motor_time_to_event_weight,
+            motor_num_time_pieces=cehrgpt_args.motor_num_time_pieces,
+            ve_token_id=tokenizer.ve_token_id,
+            **model_args_cehrgpt,
         )

         model = CEHRGPT2LMHeadModel(model_config)
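
The **model_args_cehrgpt splat above follows a common pattern: dump the model-argument dataclass to a dict, drop the keys the config constructor does not accept, and pass the rest as keyword arguments. A generic sketch with a hypothetical mini-dataclass (the real code uses ModelArguments.as_dict()):

    import dataclasses

    @dataclasses.dataclass
    class DummyModelArgs:
        hidden_size: int = 768
        attn_implementation: str = "eager"

    args = DummyModelArgs()
    kwargs = dataclasses.asdict(args)
    # attn_implementation is handled separately when the model itself is instantiated
    kwargs.pop("attn_implementation")
    print(kwargs)  # {'hidden_size': 768}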
@@ -348,6 +393,8 @@ def main():
                 pretrained_concept_embedding_model=PretrainedEmbeddings(
                     cehrgpt_args.pretrained_embedding_path
                 ),
+                apply_entropy_filter=cehrgpt_args.apply_entropy_filter,
+                min_prevalence=cehrgpt_args.min_prevalence,
             )
             cehrgpt_tokenizer.save_pretrained(
                 os.path.expanduser(training_args.output_dir)
@@ -421,9 +468,11 @@ def main():
     else:
         processed_dataset = processed_dataset.filter(filter_func, **filter_args)

-    model = load_and_create_model(
-
-
+    model = load_and_create_model(model_args, cehrgpt_args, cehrgpt_tokenizer)
+
+    # Try to update motor tte vocab size if the new configuration is different from the existing one
+    if cehrgpt_args.include_motor_time_to_event:
+        model.update_motor_tte_vocab_size(cehrgpt_tokenizer.motor_tte_vocab_size)

     # Expand tokenizer to adapt to the new pretraining dataset
     if model.config.vocab_size < cehrgpt_tokenizer.vocab_size:
@@ -500,6 +549,9 @@ def main():
             include_ttv_prediction=model_args.include_ttv_prediction,
             use_sub_time_tokenization=model_args.use_sub_time_tokenization,
             include_values=model_args.include_values,
+            include_motor_time_to_event=cehrgpt_args.include_motor_time_to_event,
+            motor_tte_vocab_size=model.config.motor_tte_vocab_size,
+            motor_num_time_pieces=cehrgpt_args.motor_num_time_pieces,
         ),
         train_dataset=processed_dataset["train"],
         eval_dataset=(
cehrgpt/runners/hf_gpt_runner_argument_dataclass.py
CHANGED
@@ -6,6 +6,12 @@ from typing import List, Optional
 class CehrGPTArguments:
     """Arguments pertaining to what data we are going to input our model for training and eval."""

+    tokenized_full_dataset_path: Optional[str] = dataclasses.field(
+        default=None,
+        metadata={
+            "help": "The path to the tokenized dataset created for the full population"
+        },
+    )
     include_inpatient_hour_token: Optional[bool] = dataclasses.field(
         default=True,
         metadata={"help": "Include inpatient hour token"},
@@ -177,7 +183,61 @@ class CehrGPTArguments:
             "help": "A flag to indicate whether we want to add end token in sample packing"
         },
     )
+    include_motor_time_to_event: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to include the motor time to events"
+        },
+    )
+    num_motor_tasks: Optional[int] = dataclasses.field(
+        default=10000,
+        metadata={"help": "The number of max MOTOR tasks"},
+    )
+    motor_time_to_event_weight: Optional[float] = dataclasses.field(
+        default=1.0,
+        metadata={"help": "The MOTOR time to event loss weight"},
+    )
+    motor_num_time_pieces: Optional[int] = dataclasses.field(
+        default=8,
+        metadata={
+            "help": "The number of times each motor_num_time_pieces piece has to be"
+        },
+    )
+    concept_dir: Optional[str] = dataclasses.field(
+        default=None,
+        metadata={"help": "The directory where the concept data is stored."},
+    )
     average_over_sequence: bool = dataclasses.field(
         default=False,
         metadata={"help": "Whether or not to average tokens per sequence"},
     )
+    apply_entropy_filter: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={"help": "A flag to indicate whether we want to use entropy filter."},
+    )
+    min_prevalence: Optional[float] = dataclasses.field(
+        default=1 / 1000,
+        metadata={"help": "The min_prevalence to keep the concepts in the tokenizer"},
+    )
+    class_weights: Optional[List[int]] = dataclasses.field(
+        default=None,
+        metadata={"help": "The class weights for training"},
+    )
+    negative_sampling_probability: Optional[float] = dataclasses.field(
+        default=None,
+        metadata={
+            "help": "The probability of negative samples will be included in the training data"
+        },
+    )
+    num_of_trajectories_per_sample: Optional[int] = dataclasses.field(
+        default=1,
+        metadata={"help": "The number of trajectories per sample"},
+    )
+    generation_input_length: Optional[int] = dataclasses.field(
+        default=1024,
+        metadata={"help": "The length of the input sequence"},
+    )
+    generation_max_new_tokens: Optional[int] = dataclasses.field(
+        default=1024,
+        metadata={"help": "The maximum number of tokens in the generation sequence"},
+    )
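
Dataclass fields like those above are typically exposed as CLI flags via transformers.HfArgumentParser. A standalone sketch with a hypothetical mini-dataclass mirroring two of the new fields (not the real CehrGPTArguments):

    import dataclasses
    from typing import Optional

    from transformers import HfArgumentParser

    @dataclasses.dataclass
    class MotorArgs:
        include_motor_time_to_event: Optional[bool] = dataclasses.field(
            default=False, metadata={"help": "Include MOTOR time-to-event pretraining"}
        )
        num_motor_tasks: Optional[int] = dataclasses.field(
            default=10000, metadata={"help": "The number of max MOTOR tasks"}
        )

    parser = HfArgumentParser(MotorArgs)
    (motor_args,) = parser.parse_args_into_dataclasses(
        ["--include_motor_time_to_event", "true", "--num_motor_tasks", "5000"]
    )
    print(motor_args)  # MotorArgs(include_motor_time_to_event=True, num_motor_tasks=5000)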
cehrgpt/runners/hyperparameter_search_util.py
CHANGED
@@ -4,12 +4,7 @@ from typing import Callable, Tuple
 import optuna
 from cehrbert.runners.hf_runner_argument_dataclass import ModelArguments
 from datasets import Dataset, DatasetDict
-from transformers import
-    EarlyStoppingCallback,
-    Trainer,
-    TrainerCallback,
-    TrainingArguments,
-)
+from transformers import EarlyStoppingCallback, TrainerCallback, TrainingArguments
 from transformers.utils import logging

 from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
@@ -85,7 +80,9 @@ def hp_space(
             "per_device_train_batch_size", batch_sizes
         ),
         "weight_decay": trial.suggest_float("weight_decay", *weight_decays, log=True),
-        "num_train_epochs": trial.
+        "num_train_epochs": trial.suggest_categorical(
+            "num_train_epochs", num_train_epochs
+        ),
     }


@@ -217,6 +214,8 @@ def perform_hyperparameter_search(
             backend="optuna",
             n_trials=cehrgpt_args.n_trials,
             compute_objective=lambda m: m["optuna_best_metric"],
+            # Ensure reproducibility
+            sampler=optuna.samplers.TPESampler(seed=training_args.seed),
         )
         LOG.info("Best hyperparameters: %s", best_trial.hyperparameters)
         # Update training arguments with best hyperparameters and set epochs based on adjusted effective epochs
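
For reference, a standalone optuna sketch (illustrative names and ranges, not the project's actual search space) showing suggest_categorical for epochs plus a seeded TPESampler like the one added above for reproducibility:

    import optuna

    def objective(trial: optuna.Trial) -> float:
        weight_decay = trial.suggest_float("weight_decay", 1e-4, 1e-1, log=True)
        num_train_epochs = trial.suggest_categorical("num_train_epochs", [5, 10, 20])
        # Stand-in for the validation metric the real runner optimizes
        return weight_decay * num_train_epochs

    study = optuna.create_study(sampler=optuna.samplers.TPESampler(seed=42))
    study.optimize(objective, n_trials=5)
    print(study.best_params)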
|