cehrgpt 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in the supported public registries. It is provided for informational purposes only.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +1 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +454 -68
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +232 -17
- cehrgpt/data/sample_packing_sampler.py +36 -6
- cehrgpt/generation/cehrgpt_conditional_generation.py +314 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +15 -3
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +25 -0
- cehrgpt/models/hf_cehrgpt.py +244 -39
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +354 -71
- cehrgpt/runners/data_utils.py +131 -5
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +84 -51
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +59 -7
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +60 -0
- cehrgpt/runners/hyperparameter_search_util.py +6 -7
- cehrgpt/runners/sample_packing_trainer.py +17 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +80 -62
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/METADATA +102 -7
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/RECORD +29 -26
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/WHEEL +1 -1
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/top_level.txt +0 -0
cehrgpt/generation/cehrgpt_conditional_generation.py
ADDED
@@ -0,0 +1,314 @@
+import datetime
+import os
+import random
+import shutil
+from pathlib import Path
+from typing import Any, Dict
+
+import numpy as np
+import polars as pl
+import torch
+import torch.distributed as dist
+from cehrbert.runners.runner_util import generate_prepared_ds_path
+from datasets import load_from_disk
+from meds import held_out_split, train_split, tuning_split
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers.trainer_utils import is_main_process
+from transformers.utils import is_flash_attn_2_available, logging
+
+from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_finetuning_dataset
+from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
+from cehrgpt.generation.generate_batch_hf_gpt_sequence import (
+    generate_single_batch,
+    normalize_value,
+)
+from cehrgpt.gpt_utils import (
+    extract_time_interval_in_days,
+    extract_time_interval_in_hours,
+    is_att_token,
+    is_inpatient_hour_token,
+    is_visit_end,
+    is_visit_start,
+)
+from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
+from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+from cehrgpt.runners.data_utils import (
+    extract_cohort_sequences,
+    prepare_finetune_dataset,
+)
+from cehrgpt.runners.gpt_runner_util import parse_runner_args
+from cehrgpt.runners.hf_cehrgpt_pretrain_runner import tokenizer_exists
+
+LOG = logging.get_logger("transformers")
+
+
+def map_data_split_name(split: str) -> str:
+    if split == "train":
+        return train_split
+    elif split == "validation":
+        return tuning_split
+    elif split == "test":
+        return held_out_split
+    raise ValueError(f"Unknown split: {split}")
+
+
+def seed_all(seed: int = 42):
+    """Set seed for Python, NumPy, and PyTorch (CPU & CUDA)."""
+    random.seed(seed)  # Python random
+    np.random.seed(seed)  # NumPy
+    torch.manual_seed(seed)  # PyTorch CPU
+    torch.cuda.manual_seed(seed)  # Current GPU
+    torch.cuda.manual_seed_all(seed)  # All GPUs
+
+    # For reproducibility in dataloader workers
+    os.environ["PYTHONHASHSEED"] = str(seed)
+
+
+def generate_trajectories_per_batch(
+    batch: Dict[str, Any],
+    cehrgpt_tokenizer: CehrGptTokenizer,
+    cehrgpt_model: CEHRGPT2LMHeadModel,
+    device,
+    data_output_path: Path,
+    max_length: int,
+):
+    subject_ids = batch["person_id"].squeeze().detach().cpu().tolist()
+    prediction_times = batch["index_date"].squeeze().detach().cpu().tolist()
+    batched_epoch_times = batch["epoch_times"].detach().cpu().tolist()
+    batched_input_ids = batch["input_ids"]
+    batched_value_indicators = batch["value_indicators"]
+    batched_values = batch["values"]
+    # Make sure the batch does not exceed batch_size
+    batch_sequences = generate_single_batch(
+        cehrgpt_model,
+        cehrgpt_tokenizer,
+        batched_input_ids,
+        values=batched_values,
+        value_indicators=batched_value_indicators,
+        max_length=max_length,
+        top_p=1.0,
+        top_k=cehrgpt_tokenizer.vocab_size,
+        device=device,
+    )
+    # Clear the cache
+    torch.cuda.empty_cache()
+
+    trajectories = []
+    for sample_i, (concept_ids, value_indicators, values) in enumerate(
+        zip(
+            batch_sequences["sequences"],
+            batch_sequences["value_indicators"],
+            batch_sequences["values"],
+        )
+    ):
+        (
+            concept_ids,
+            is_numeric_types,
+            number_as_values,
+            concept_as_values,
+            units,
+        ) = normalize_value(concept_ids, values, cehrgpt_tokenizer)
+
+        epoch_times = batched_epoch_times[sample_i]
+        input_length = len(epoch_times)
+        # Getting the last observed event time from the token before the prediction time
+        window_last_observed = epoch_times[input_length - 1]
+        current_cursor = epoch_times[-1]
+        generated_epoch_times = []
+        valid_indices = []
+
+        for i in range(input_length, len(concept_ids)):
+            concept_id = concept_ids[i]
+            # We use the left padding strategy in the data collator
+            if concept_id in [cehrgpt_tokenizer.pad_token, cehrgpt_tokenizer.end_token]:
+                continue
+            # We need to construct the time stamp
+            if is_att_token(concept_id):
+                current_cursor += extract_time_interval_in_days(concept_id) * 24 * 3600
+            elif is_inpatient_hour_token(concept_id):
+                current_cursor += extract_time_interval_in_hours(concept_id) * 3600
+            elif is_visit_start(concept_id) or is_visit_end(concept_id):
+                continue
+            else:
+                valid_indices.append(i)
+                generated_epoch_times.append(
+                    datetime.datetime.utcfromtimestamp(current_cursor).replace(
+                        tzinfo=None
+                    )
+                )
+
+        trajectories.append(
+            {
+                "subject_id": subject_ids[sample_i],
+                "prediction_time": datetime.datetime.utcfromtimestamp(
+                    prediction_times[sample_i]
+                ).replace(tzinfo=None),
+                "window_last_observed_time": datetime.datetime.utcfromtimestamp(
+                    window_last_observed
+                ).replace(tzinfo=None),
+                "times": generated_epoch_times,
+                "concept_ids": np.asarray(concept_ids)[valid_indices].tolist(),
+                "numeric_values": np.asarray(number_as_values)[valid_indices].tolist(),
+                "text_value": np.asarray(concept_as_values)[valid_indices].tolist(),
+                "units": np.asarray(units)[valid_indices].tolist(),
+            }
+        )
+
+    trajectories = (
+        pl.DataFrame(trajectories)
+        .explode(["times", "concept_ids", "numeric_values", "text_value", "units"])
+        .rename(
+            {
+                "times": "time",
+                "concept_ids": "code",
+                "numeric_values": "numeric_value",
+                "units": "unit",
+            }
+        )
+        .select(
+            "subject_id",
+            "prediction_time",
+            "window_last_observed_time",
+            "time",
+            "code",
+            "numeric_value",
+            "text_value",
+            "unit",
+        )
+    )
+    trajectories.write_parquet(data_output_path)
+
+
+def main():
+    cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
+        model_args.tokenizer_name_or_path
+    )
+    cehrgpt_model = (
+        CEHRGPT2LMHeadModel.from_pretrained(
+            model_args.model_name_or_path,
+            attn_implementation=(
+                "flash_attention_2" if is_flash_attn_2_available() else "eager"
+            ),
+        )
+        .eval()
+        .to(device)
+    )
+    cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
+    cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
+    cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id
+
+    if not os.path.exists(training_args.output_dir):
+        os.makedirs(training_args.output_dir)
+
+    prepared_ds_path = generate_prepared_ds_path(
+        data_args, model_args, data_folder=data_args.cohort_folder
+    )
+
+    processed_dataset = None
+    if any(prepared_ds_path.glob("*")):
+        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
+        processed_dataset = load_from_disk(str(prepared_ds_path))
+        LOG.info("Prepared dataset loaded from disk...")
+        if cehrgpt_args.expand_tokenizer:
+            if tokenizer_exists(training_args.output_dir):
+                cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
+                    training_args.output_dir
+                )
+            else:
+                LOG.warning(
+                    f"CehrGptTokenizer must exist in {training_args.output_dir} "
+                    f"when the dataset has been processed and expand_tokenizer is set to True. "
+                    f"Please delete the processed dataset at {prepared_ds_path}."
+                )
+                processed_dataset = None
+                shutil.rmtree(prepared_ds_path)
+
+    if processed_dataset is None and is_main_process(training_args.local_rank):
+        # If the full dataset has been tokenized, we don't want to tokenize the cohort containing
+        # the subset of the data. We should slice out the portion of the tokenized sequences for each sample
+        if cehrgpt_args.tokenized_full_dataset_path is not None:
+            processed_dataset = extract_cohort_sequences(data_args, cehrgpt_args)
+        else:
+            # Organize them into a single DatasetDict
+            final_splits = prepare_finetune_dataset(
+                data_args, training_args, cehrgpt_args
+            )
+            # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
+            if not data_args.streaming:
+                all_columns = final_splits["train"].column_names
+                if "visit_concept_ids" in all_columns:
+                    final_splits = final_splits.remove_columns(["visit_concept_ids"])
+
+            processed_dataset = create_cehrgpt_finetuning_dataset(
+                dataset=final_splits,
+                cehrgpt_tokenizer=cehrgpt_tokenizer,
+                data_args=data_args,
+            )
+            if not data_args.streaming:
+                processed_dataset.save_to_disk(prepared_ds_path)
+                processed_dataset.cleanup_cache_files()
+
+    # After main-process-only operations, synchronize all processes to ensure consistency
+    if dist.is_available() and dist.is_initialized():
+        dist.barrier()
+
+    # We suppress the additional learning objectives in fine-tuning
+    data_collator = CehrGptDataCollator(
+        tokenizer=cehrgpt_tokenizer,
+        max_length=cehrgpt_args.generation_input_length,
+        include_values=cehrgpt_model.config.include_values,
+        pretraining=False,
+        include_ttv_prediction=False,
+        use_sub_time_tokenization=False,
+        include_demographics=False,
+        add_linear_prob_token=False,
+    )
+
+    LOG.info(
+        "Generating %s trajectories per sample",
+        cehrgpt_args.num_of_trajectories_per_sample,
+    )
+    for sample_i in range(cehrgpt_args.num_of_trajectories_per_sample):
+        for split, dataset in processed_dataset.items():
+            meds_split = map_data_split_name(split)
+            dataloader = DataLoader(
+                dataset=dataset,
+                batch_size=training_args.per_device_eval_batch_size,
+                num_workers=training_args.dataloader_num_workers,
+                collate_fn=data_collator,
+                pin_memory=training_args.dataloader_pin_memory,
+            )
+            sample_output_dir = (
+                Path(training_args.output_dir) / meds_split / f"{sample_i}"
+            )
+            sample_output_dir.mkdir(exist_ok=True, parents=True)
+            for batch_i, batch in tqdm(
+                enumerate(dataloader),
+                desc=f"Generating Trajectories for split {meds_split} with trajectory {sample_i + 1}",
+            ):
+                output_parquet_file = sample_output_dir / f"{batch_i}.parquet"
+                if output_parquet_file.exists():
+                    LOG.info("%s already exists, skip...", output_parquet_file)
+                    continue
+
+                generate_trajectories_per_batch(
+                    batch,
+                    cehrgpt_tokenizer,
+                    cehrgpt_model,
+                    device,
+                    sample_output_dir / f"{batch_i}.parquet",
+                    cehrgpt_args.generation_max_new_tokens
+                    + cehrgpt_args.generation_input_length,
+                )
+
+
+if __name__ == "__main__":
+    # ✅ Call first thing inside main()
+    seed_all(42)
+    main()
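Note on the timestamp logic in the new script: generated continuations carry only relative time tokens, so the loop advances an epoch-seconds cursor from the last observed event and materializes a wall-clock timestamp whenever a clinical concept is emitted. Below is a minimal sketch of that bookkeeping using the gpt_utils helpers changed later in this diff; the token strings are illustrative and assume the "D<n>" day and "i-H<n>" hour token scheme visible elsewhere in this diff.

# Illustrative sketch; not part of the released package.
import datetime

from cehrgpt.gpt_utils import (
    extract_time_interval_in_days,
    extract_time_interval_in_hours,
    is_att_token,
    is_inpatient_hour_token,
)

# Hypothetical continuation: a 7-day gap, a concept, a 2-hour inpatient gap, a concept.
generated = ["D7", "320128", "i-H2", "4329847"]
cursor = datetime.datetime(2024, 1, 1).timestamp()  # last observed event time

for token in generated:
    if is_att_token(token):
        cursor += extract_time_interval_in_days(token) * 24 * 3600
    elif is_inpatient_hour_token(token):
        cursor += extract_time_interval_in_hours(token) * 3600
    else:
        # Clinical events land at the current cursor position.
        print(token, datetime.datetime.utcfromtimestamp(cursor))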
cehrgpt/generation/generate_batch_hf_gpt_sequence.py
CHANGED
@@ -74,7 +74,10 @@ def generate_single_batch(
     model: CEHRGPT2LMHeadModel,
     tokenizer: CehrGptTokenizer,
     prompts: List[List[int]],
-
+    max_length: int,
+    values: Optional[torch.Tensor] = None,
+    value_indicators: Optional[torch.Tensor] = None,
+    max_new_tokens: Optional[int] = None,
     mini_num_of_concepts=1,
     top_p=0.95,
     top_k=50,
@@ -88,7 +91,8 @@ def generate_single_batch(
     with torch.no_grad():
        generation_config = GenerationConfig(
            repetition_penalty=repetition_penalty,
-
+            max_new_tokens=max_new_tokens,
+            max_length=max_length,
            min_length=mini_num_of_concepts,
            temperature=temperature,
            top_p=top_p,
@@ -107,9 +111,17 @@ def generate_single_batch(
            num_beam_groups=num_beam_groups,
            epsilon_cutoff=epsilon_cutoff,
        )
+
        batched_prompts = torch.tensor(prompts).to(device)
+        if values is not None:
+            values = values.to(device)
+        if value_indicators is not None:
+            value_indicators = value_indicators.to(device)
+
        results = model.generate(
            inputs=batched_prompts,
+            values=values,
+            value_indicators=value_indicators,
            generation_config=generation_config,
            lab_token_ids=tokenizer.lab_token_ids,
        )
@@ -226,7 +238,7 @@ def main(args):
            cehrgpt_model,
            cehrgpt_tokenizer,
            random_prompts[: args.batch_size],
-
+            max_length=args.context_window,
            mini_num_of_concepts=args.min_num_of_concepts,
            top_p=args.top_p,
            top_k=args.top_k,
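The reworked generate_single_batch signature lets callers pass the conditioning tensors (values, value_indicators) produced by the data collator and cap output via max_length and/or max_new_tokens. A hedged sketch of a call site follows; the checkpoint paths, token ids, and tensor shapes/dtypes are assumptions, not taken from the package.

# Illustrative sketch; paths, token ids, and tensor shapes are assumed.
import torch

from cehrgpt.generation.generate_batch_hf_gpt_sequence import generate_single_batch
from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer

tokenizer = CehrGptTokenizer.from_pretrained("path/to/tokenizer")
model = CEHRGPT2LMHeadModel.from_pretrained("path/to/model").eval()

prompts = [[101, 2045, 87, 5310]]  # token ids for one demographic prompt
values = torch.zeros(1, 4)                              # optional lab-value conditioning
value_indicators = torch.zeros(1, 4, dtype=torch.bool)  # marks positions that carry values

batch = generate_single_batch(
    model,
    tokenizer,
    prompts,
    max_length=1024,        # hard cap on total sequence length
    values=values,
    value_indicators=value_indicators,
    max_new_tokens=512,     # cap on the generated suffix alone
    top_p=0.95,
    top_k=50,
    device="cpu",
)
print(batch["sequences"][0])  # generated concept tokens, per the caller in the new script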
cehrgpt/generation/omop_converter_batch.py
CHANGED
@@ -60,6 +60,24 @@ OOV_CONCEPT_MAP = {
 }
 
 
+def extract_gender_concept_id(gender_token: str) -> int:
+    if gender_token.startswith("Gender/"):
+        return int(gender_token[len("Gender/") :])
+    elif gender_token.isnumeric():
+        return int(gender_token)
+    else:
+        return 0
+
+
+def extract_race_concept_id(race_token: str) -> int:
+    if race_token.startswith("Race/"):
+        return int(race_token[len("Race/") :])
+    elif race_token.isnumeric():
+        return int(race_token)
+    else:
+        return 0
+
+
 def create_folder_if_not_exists(output_folder, table_name):
     if not os.path.isdir(Path(output_folder) / table_name):
         os.mkdir(Path(output_folder) / table_name)
@@ -288,7 +306,13 @@ def gpt_to_omop_converter_batch(
         if int(birth_year) < 1900 or int(birth_year) > datetime.date.today().year:
             continue
 
-        p = Person(
+        p = Person(
+            person_id=person_id,
+            gender_concept_id=extract_gender_concept_id(start_gender),
+            year_of_birth=birth_year,
+            race_concept_id=extract_race_concept_id(start_race),
+        )
+
         append_to_dict(omop_export_dict, p, person_id)
         id_mappings_dict["person"][person_id] = person_id
         pt_seq_dict[person_id] = " ".join(concept_ids)
@@ -316,7 +340,12 @@ def gpt_to_omop_converter_batch(
                 id_mappings_dict["death"][person_id] = person_id
             else:
                 try:
-
+                    if clinical_events[event_idx + 1].startswith("Visit/"):
+                        visit_concept_id = int(
+                            clinical_events[event_idx + 1][len("Visit/") :]
+                        )
+                    else:
+                        visit_concept_id = int(clinical_events[event_idx + 1])
                     inpatient_visit_indicator = visit_concept_id in [
                         9201,
                         262,
@@ -349,6 +378,7 @@ def gpt_to_omop_converter_batch(
                         visit_occurrence_id
                     ] = person_id
                    visit_occurrence_id += 1
+
                elif event in ATT_TIME_TOKENS:
                    if event[0] == "D":
                        att_date_delta = int(event[1:])
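The two new helpers make the converter tolerant of both bare OMOP concept ids and MEDS-style prefixed tokens when rebuilding the person table (and the Visit/ branch above does the same for visit codes). Their behavior on the three input shapes they handle follows directly from the code in this hunk:

# Expected behavior of the helpers shown above; concept ids are illustrative.
from cehrgpt.generation.omop_converter_batch import (
    extract_gender_concept_id,
    extract_race_concept_id,
)

assert extract_gender_concept_id("Gender/8507") == 8507  # MEDS-style token
assert extract_gender_concept_id("8532") == 8532         # bare numeric string
assert extract_gender_concept_id("UNKNOWN") == 0         # fallback
assert extract_race_concept_id("Race/8527") == 8527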
cehrgpt/gpt_utils.py
CHANGED
@@ -11,6 +11,7 @@ from cehrgpt.models.special_tokens import (
 )
 
 # Regular expression pattern to match inpatient attendance tokens
+MEDS_CODE_PATTERN = re.compile(r".*/.*")
 INPATIENT_ATT_PATTERN = re.compile(r"(?:VS-|i-)D(\d+)(?:-VE)?")
 DEMOGRAPHIC_PROMPT_SIZE = 4
 
@@ -194,8 +195,12 @@ def get_cehrgpt_output_folder(args, cehrgpt_tokenizer) -> str:
     return folder_name
 
 
-def is_clinical_event(token: str) -> bool:
-
+def is_clinical_event(token: str, meds: bool = False) -> bool:
+    if token.isnumeric():
+        return True
+    if meds:
+        return bool(MEDS_CODE_PATTERN.match(token))
+    return False
 
 
 def is_visit_start(token: str):
@@ -212,6 +217,18 @@ def is_visit_end(token: str) -> bool:
     return token in ["VE", "[VE]"]
 
 
+def is_inpatient_hour_token(token: str) -> bool:
+    return token.startswith("i-H")
+
+
+def extract_time_interval_in_hours(token: str) -> int:
+    try:
+        hour = int(token[3:])
+        return hour
+    except ValueError:
+        return 0
+
+
 def is_att_token(token: str):
     """
     Check if the token is an attention token.
@@ -251,6 +268,7 @@ def is_artificial_token(token: str) -> bool:
         return True
     if token == END_TOKEN:
         return True
+
     return False
 
 
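With the new meds flag, is_clinical_event also accepts slash-delimited MEDS codes, while the default path still treats only bare numeric OMOP ids as clinical events; the new hour helpers cover "i-H<n>" inpatient tokens. Expected behavior per the code above (the example codes are illustrative):

# Expected behavior of the helpers shown in these hunks.
from cehrgpt.gpt_utils import (
    extract_time_interval_in_hours,
    is_clinical_event,
    is_inpatient_hour_token,
)

assert is_clinical_event("320128") is True               # numeric OMOP id
assert is_clinical_event("LAB/1234") is False            # slash code, meds off
assert is_clinical_event("LAB/1234", meds=True) is True  # slash code, meds on
assert is_inpatient_hour_token("i-H6") is True
assert extract_time_interval_in_hours("i-H6") == 6       # int(token[3:])
assert extract_time_interval_in_hours("i-Hxx") == 0      # ValueError fallback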
cehrgpt/models/config.py
CHANGED
@@ -121,6 +121,7 @@ class CEHRGPTConfig(PretrainedConfig):
         bos_token_id=50256,
         eos_token_id=50256,
         lab_token_ids=None,
+        ve_token_id=None,
         scale_attn_by_inverse_layer_idx=False,
         reorder_and_upcast_attn=False,
         exclude_position_ids=False,
@@ -128,6 +129,10 @@ class CEHRGPTConfig(PretrainedConfig):
         value_vocab_size=None,
         include_ttv_prediction=False,
         use_sub_time_tokenization=True,
+        include_motor_time_to_event=True,
+        motor_tte_vocab_size=None,
+        motor_time_to_event_weight=1.0,
+        motor_num_time_pieces=16,
         token_to_time_token_mapping: Dict[int, List] = None,
         use_pretrained_embeddings=False,
         n_pretrained_embeddings_layers=2,
@@ -144,6 +149,7 @@ class CEHRGPTConfig(PretrainedConfig):
         entropy_penalty=False,
         entropy_penalty_alpha=0.01,
         sample_packing_max_positions=None,
+        class_weights=None,
         **kwargs,
     ):
         if token_to_time_token_mapping is None:
@@ -192,6 +198,22 @@ class CEHRGPTConfig(PretrainedConfig):
         self._token_to_time_token_mapping = token_to_time_token_mapping
         self.time_token_loss_weight = time_token_loss_weight
         self.time_to_visit_loss_weight = time_to_visit_loss_weight
+
+        # MOTOR TTE configuration
+        self.motor_tte_vocab_size = motor_tte_vocab_size
+        self.include_motor_time_to_event = (
+            include_motor_time_to_event
+            and self.motor_tte_vocab_size
+            and self.motor_tte_vocab_size > 0
+        )
+        if self.include_motor_time_to_event and not ve_token_id:
+            raise RuntimeError(
+                f"ve_token_id must be provided when include_motor_time_to_event is True"
+            )
+        self.ve_token_id = ve_token_id
+        self.motor_time_to_event_weight = motor_time_to_event_weight
+        self.motor_num_time_pieces = motor_num_time_pieces
+
         self.causal_sfm = causal_sfm
         self.demographics_size = demographics_size
         self.use_pretrained_embeddings = use_pretrained_embeddings
@@ -206,6 +228,9 @@ class CEHRGPTConfig(PretrainedConfig):
         self.entropy_penalty_alpha = entropy_penalty_alpha
         self.value_prediction_loss_weight = value_prediction_loss_weight
 
+        # Class weights for fine-tuning
+        self.class_weights = class_weights
+
         kwargs["tie_word_embeddings"] = not use_pretrained_embeddings
 
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
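Per the constructor logic above, the MOTOR time-to-event objective is only enabled when a positive motor_tte_vocab_size is supplied, and a ve_token_id is then mandatory. A sketch with illustrative values (all other constructor arguments are left at their defaults):

# Illustrative values; only the new 0.1.2 arguments are shown.
from cehrgpt.models.config import CEHRGPTConfig

config = CEHRGPTConfig(
    include_motor_time_to_event=True,
    motor_tte_vocab_size=512,   # enables the MOTOR TTE head
    motor_num_time_pieces=16,   # number of time pieces for the TTE head (default)
    ve_token_id=4,              # omitting this raises RuntimeError
    class_weights=[1.0, 5.0],   # new: per-class weights for fine-tuning
)
assert config.include_motor_time_to_event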