cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
- cehrgpt/data/sample_packing_sampler.py +181 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +35 -0
- cehrgpt/models/hf_cehrgpt.py +470 -106
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
- cehrgpt/runners/data_utils.py +358 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
- cehrgpt/runners/hyperparameter_search_util.py +10 -8
- cehrgpt/runners/sample_packing_trainer.py +185 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
cehrgpt/data/sample_packing_sampler.py
ADDED
@@ -0,0 +1,181 @@
+from typing import Iterator, List, Optional
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.utils.data import Sampler
+from transformers import logging
+
+LOG = logging.get_logger("transformers")
+
+
+class SamplePlacerHolder:
+    def __init__(self):
+        self.epoch = 0
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
+
+class SamplePackingBatchSampler(Sampler[List[int]]):
+    """
+    A batch sampler that creates batches by packing samples together.
+
+    to maximize GPU utilization, ensuring the total tokens per batch
+    doesn't exceed max_tokens.
+    """
+
+    def __init__(
+        self,
+        lengths: List[int],
+        max_tokens_per_batch: int,
+        max_position_embeddings: int,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        seed: int = 0,
+        drop_last: bool = False,
+        negative_sampling_probability: Optional[float] = None,
+        labels: Optional[List[int]] = None,
+    ):
+        """
+        Args:
+
+            lengths: List of sequence lengths for each sample
+            max_tokens: Maximum number of tokens in a batch
+            drop_last: Whether to drop the last incomplete batch
+        """
+        super().__init__()
+
+        if num_replicas is None:
+            if dist.is_available() and dist.is_initialized():
+                num_replicas = dist.get_world_size()
+                LOG.info(
+                    "torch.distributed is initialized and there are %s of replicas",
+                    num_replicas,
+                )
+            else:
+                num_replicas = 1
+                LOG.info(
+                    "torch.dist is not initialized and therefore default to 1 for num_replicas"
+                )
+
+        if rank is None:
+            if dist.is_available() and dist.is_initialized():
+                rank = dist.get_rank()
+                LOG.info(
+                    "torch.distributed is initialized and the current rank is %s", rank
+                )
+            else:
+                rank = 0
+                LOG.info(
+                    "torch.distributed is not initialized and therefore default to 0 for rank"
+                )
+
+        if not (0 <= rank < num_replicas):
+            raise ValueError(
+                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
+            )
+
+        if negative_sampling_probability is not None and labels is None:
+            raise ValueError(
+                f"When the negative sampling probability is provide, the labels must be provided as well"
+            )
+
+        self.lengths = lengths
+        self.max_tokens_per_batch = max_tokens_per_batch
+        self.max_position_embeddings = max_position_embeddings
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.seed = seed
+        self.drop_last = drop_last
+        self.negative_sampling_probability = negative_sampling_probability
+        self.labels = labels
+        # Trainer https://github.com/huggingface/transformers/blame/main/src/transformers/trainer.py#L2470
+        # http://github.com/huggingface/accelerate/blob/v0.31.0/src/accelerate/data_loader.py#L482
+        # the huggingface trainer will call the accelerate.data_loader.DataLoaderShard.set_epoch,
+        # which will call batch_sampler.sample.set_epoch
+        self.sampler = SamplePlacerHolder()
+
+    def __iter__(self) -> Iterator[List[int]]:
+
+        # deterministically shuffle based on epoch and seed
+        g = torch.Generator()
+        g.manual_seed(self.seed + self.sampler.epoch)
+        indices = torch.randperm(len(self.lengths), generator=g).tolist()
+
+        # Partition indices for this rank
+        indices = indices[self.rank :: self.num_replicas]
+
+        batch = []
+        current_batch_tokens = 0
+
+        for idx in indices:
+            # There is a chance to skip the negative samples to account for the class imbalance
+            # in the fine-tuning dataset
+            if self.negative_sampling_probability:
+                if (
+                    np.random.random() > self.negative_sampling_probability
+                    and self.labels[idx] == 0
+                ):
+                    continue
+            # We take the minimum of the two because each sequence will be truncated to fit
+            # the context window of the model
+            sample_length = min(self.lengths[idx], self.max_position_embeddings)
+            # If adding this sample would exceed max_tokens_per_batch, yield the current batch
+            if (
+                current_batch_tokens + sample_length + 2 > self.max_tokens_per_batch
+                and batch
+            ):
+                yield batch
+                batch = []
+                current_batch_tokens = 0
+
+            # Add the sample to the current batch
+            batch.append(idx)
+            # plus extract one for the [END] and [PAD] tokens to separate samples
+            current_batch_tokens += sample_length + 2
+
+        # Yield the last batch if it's not empty and we're not dropping it
+        if batch and not self.drop_last:
+            yield batch
+
+    def __len__(self) -> int:
+        """
+        Estimates the number of batches that will be generated.
+
+        This is an approximation since the exact number depends on the specific
+        sequence lengths and their order.
+        """
+        if len(self.lengths) == 0:
+            return 0
+
+        # There is a chance to skip the negative samples to account for the class imbalance
+        # in the fine-tuning dataset
+        if self.negative_sampling_probability:
+            truncated_lengths = []
+            for length, label in zip(self.lengths, self.labels):
+                if (
+                    np.random.random() > self.negative_sampling_probability
+                    and label == 0
+                ):
+                    continue
+                truncated_lengths.append(length)
+        else:
+            # We need to truncate the lengths due to the context window limit imposed by the model
+            truncated_lengths = [
+                min(self.max_position_embeddings, length + 2) for length in self.lengths
+            ]
+
+        # Calculate average sequence length
+        avg_seq_length = sum(truncated_lengths) // len(truncated_lengths)
+
+        # Estimate average number of sequences per batch
+        seqs_per_batch = self.max_tokens_per_batch // avg_seq_length
+
+        # Estimate total number of batches
+        if self.drop_last:
+            # If dropping last incomplete batch
+            return len(truncated_lengths) // seqs_per_batch
+        else:
+            # If keeping last incomplete batch, ensure at least 1 batch
+            return max(1, len(truncated_lengths) // seqs_per_batch)
cehrgpt/generation/generate_batch_hf_gpt_sequence.py
CHANGED
@@ -93,9 +93,9 @@ def generate_single_batch(
         temperature=temperature,
         top_p=top_p,
         top_k=top_k,
-        bos_token_id=
-        eos_token_id=
-        pad_token_id=
+        bos_token_id=model.generation_config.bos_token_id,
+        eos_token_id=model.generation_config.eos_token_id,
+        pad_token_id=model.generation_config.pad_token_id,
         do_sample=True,
         use_cache=True,
         return_dict_in_generate=True,
@@ -150,15 +150,11 @@ def main(args):
             attn_implementation=(
                 "flash_attention_2" if is_flash_attn_2_available() else "eager"
             ),
-            torch_dtype=(
-                torch.bfloat16
-                if is_flash_attn_2_available() and args.use_bfloat16
-                else torch.float32
-            ),
         )
         .eval()
         .to(device)
     )
+
     cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
     cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
     cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id
@@ -192,6 +188,7 @@ def main(args):
     LOG.info(f"Top P {args.top_p}")
     LOG.info(f"Top K {args.top_k}")
     LOG.info(f"Loading demographic_info at {args.demographic_data_path}")
+    LOG.info(f"MEDS format: {args.meds_format}")

     dataset = load_parquet_as_dataset(args.demographic_data_path)
     total_rows = len(dataset)
@@ -199,6 +196,7 @@ def main(args):
     num_of_batches = args.num_of_patients // args.batch_size + 1
     sequence_to_flush = []
     current_person_id = 1
+    prompt_size = 2 if args.meds_format else START_TOKEN_SIZE
     for i in range(num_of_batches):
         LOG.info(f"{datetime.datetime.now()}: Batch {i} started")

@@ -215,7 +213,7 @@ def main(args):
                 <= max_seq_allowed
             ):
                 random_prompts.append(
-                    cehrgpt_tokenizer.encode(row["concept_ids"][:
+                    cehrgpt_tokenizer.encode(row["concept_ids"][:prompt_size])
                 )
                 iter += 1
             if not random_prompts and iter > 10:
@@ -326,6 +324,11 @@ def create_arg_parser():
         dest="drop_long_sequences",
         action="store_true",
     )
+    base_arg_parser.add_argument(
+        "--meds_format",
+        dest="meds_format",
+        action="store_true",
+    )
     return base_arg_parser

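As a quick illustration (sketch, not package code): with the new --meds_format flag the generation prompt is only the first two tokens of each sampled sequence; otherwise the full demographic prompt of START_TOKEN_SIZE tokens is used. The args, row, and tokenizer objects below are placeholders.

# Sketch of the prompt sizing introduced above.
prompt_size = 2 if args.meds_format else START_TOKEN_SIZE
prompt_token_ids = cehrgpt_tokenizer.encode(row["concept_ids"][:prompt_size])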
cehrgpt/generation/omop_converter_batch.py
CHANGED
@@ -60,6 +60,24 @@ OOV_CONCEPT_MAP = {
 }


+def extract_gender_concept_id(gender_token: str) -> int:
+    if gender_token.startswith("Gender/"):
+        return int(gender_token[len("Gender/") :])
+    elif gender_token.isnumeric():
+        return int(gender_token)
+    else:
+        return 0
+
+
+def extract_race_concept_id(race_token: str) -> int:
+    if race_token.startswith("Race/"):
+        return int(race_token[len("Race/") :])
+    elif race_token.isnumeric():
+        return int(race_token)
+    else:
+        return 0
+
+
 def create_folder_if_not_exists(output_folder, table_name):
     if not os.path.isdir(Path(output_folder) / table_name):
         os.mkdir(Path(output_folder) / table_name)
@@ -288,7 +306,13 @@ def gpt_to_omop_converter_batch(
         if int(birth_year) < 1900 or int(birth_year) > datetime.date.today().year:
             continue

-        p = Person(
+        p = Person(
+            person_id=person_id,
+            gender_concept_id=extract_gender_concept_id(start_gender),
+            year_of_birth=birth_year,
+            race_concept_id=extract_race_concept_id(start_race),
+        )
+
         append_to_dict(omop_export_dict, p, person_id)
         id_mappings_dict["person"][person_id] = person_id
         pt_seq_dict[person_id] = " ".join(concept_ids)
@@ -316,7 +340,12 @@ def gpt_to_omop_converter_batch(
                 id_mappings_dict["death"][person_id] = person_id
             else:
                 try:
-
+                    if clinical_events[event_idx + 1].startswith("Visit/"):
+                        visit_concept_id = int(
+                            clinical_events[event_idx + 1][len("Visit/") :]
+                        )
+                    else:
+                        visit_concept_id = int(clinical_events[event_idx + 1])
                     inpatient_visit_indicator = visit_concept_id in [
                         9201,
                         262,
@@ -349,6 +378,7 @@ def gpt_to_omop_converter_batch(
                         visit_occurrence_id
                     ] = person_id
                     visit_occurrence_id += 1
+
                 elif event in ATT_TIME_TOKENS:
                     if event[0] == "D":
                         att_date_delta = int(event[1:])
cehrgpt/gpt_utils.py
CHANGED
@@ -11,6 +11,7 @@ from cehrgpt.models.special_tokens import (
 )

 # Regular expression pattern to match inpatient attendance tokens
+MEDS_CODE_PATTERN = re.compile(r".*/.*")
 INPATIENT_ATT_PATTERN = re.compile(r"(?:VS-|i-)D(\d+)(?:-VE)?")
 DEMOGRAPHIC_PROMPT_SIZE = 4

@@ -194,8 +195,12 @@ def get_cehrgpt_output_folder(args, cehrgpt_tokenizer) -> str:
     return folder_name


-def is_clinical_event(token: str) -> bool:
-
+def is_clinical_event(token: str, meds: bool = False) -> bool:
+    if token.isnumeric():
+        return True
+    if meds:
+        return bool(MEDS_CODE_PATTERN.match(token))
+    return False


 def is_visit_start(token: str):
@@ -212,6 +217,18 @@ def is_visit_end(token: str) -> bool:
     return token in ["VE", "[VE]"]


+def is_inpatient_hour_token(token: str) -> bool:
+    return token.startswith("i-H")
+
+
+def extract_time_interval_in_hours(token: str) -> int:
+    try:
+        hour = int(token[3:])
+        return hour
+    except ValueError:
+        return 0
+
+
 def is_att_token(token: str):
     """
     Check if the token is an attention token.
@@ -251,6 +268,7 @@ def is_artificial_token(token: str) -> bool:
         return True
     if token == END_TOKEN:
         return True
+
     return False

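A few illustrative calls to the helpers added above (token values are examples only, not from the diff):

is_clinical_event("320128")                   # True  - bare numeric concept id
is_clinical_event("LOINC/8867-4", meds=True)  # True  - matches MEDS_CODE_PATTERN ".*/.*"
is_clinical_event("VS", meds=True)            # False - not numeric and no "/" in the token
is_inpatient_hour_token("i-H12")              # True
extract_time_interval_in_hours("i-H12")       # 12 - parses the digits after "i-H"
extract_time_interval_in_hours("D7")          # 0  - non-numeric remainder falls back to 0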
cehrgpt/models/config.py
CHANGED
@@ -121,6 +121,7 @@ class CEHRGPTConfig(PretrainedConfig):
         bos_token_id=50256,
         eos_token_id=50256,
         lab_token_ids=None,
+        ve_token_id=None,
         scale_attn_by_inverse_layer_idx=False,
         reorder_and_upcast_attn=False,
         exclude_position_ids=False,
@@ -128,19 +129,27 @@ class CEHRGPTConfig(PretrainedConfig):
         value_vocab_size=None,
         include_ttv_prediction=False,
         use_sub_time_tokenization=True,
+        include_motor_time_to_event=True,
+        motor_tte_vocab_size=None,
+        motor_time_to_event_weight=1.0,
+        motor_num_time_pieces=16,
         token_to_time_token_mapping: Dict[int, List] = None,
         use_pretrained_embeddings=False,
         n_pretrained_embeddings_layers=2,
         pretrained_embedding_dim=768,
         pretrained_token_ids: List[int] = None,
+        next_token_prediction_loss_weight=1.0,
         time_token_loss_weight=1.0,
         time_to_visit_loss_weight=1.0,
         causal_sfm=False,
         demographics_size=4,
         lab_token_penalty=False,
         lab_token_loss_weight=0.9,
+        value_prediction_loss_weight=1.0,
         entropy_penalty=False,
         entropy_penalty_alpha=0.01,
+        sample_packing_max_positions=None,
+        class_weights=None,
         **kwargs,
     ):
         if token_to_time_token_mapping is None:
@@ -150,6 +159,11 @@ class CEHRGPTConfig(PretrainedConfig):
         self.vocab_size = vocab_size
         self.time_token_vocab_size = time_token_vocab_size
         self.n_positions = n_positions
+        self.sample_packing_max_positions = (
+            sample_packing_max_positions
+            if sample_packing_max_positions
+            else n_positions
+        )
         self.n_embd = n_embd
         self.n_layer = n_layer
         self.n_head = n_head
@@ -178,11 +192,28 @@ class CEHRGPTConfig(PretrainedConfig):
         self.include_values = include_values
         self.value_vocab_size = value_vocab_size

+        self.next_token_prediction_loss_weight = next_token_prediction_loss_weight
         self.include_ttv_prediction = include_ttv_prediction
         self.use_sub_time_tokenization = use_sub_time_tokenization
         self._token_to_time_token_mapping = token_to_time_token_mapping
         self.time_token_loss_weight = time_token_loss_weight
         self.time_to_visit_loss_weight = time_to_visit_loss_weight
+
+        # MOTOR TTE configuration
+        self.motor_tte_vocab_size = motor_tte_vocab_size
+        self.include_motor_time_to_event = (
+            include_motor_time_to_event
+            and self.motor_tte_vocab_size
+            and self.motor_tte_vocab_size > 0
+        )
+        if self.include_motor_time_to_event and not ve_token_id:
+            raise RuntimeError(
+                f"ve_token_id must be provided when include_motor_time_to_event is True"
+            )
+        self.ve_token_id = ve_token_id
+        self.motor_time_to_event_weight = motor_time_to_event_weight
+        self.motor_num_time_pieces = motor_num_time_pieces
+
         self.causal_sfm = causal_sfm
         self.demographics_size = demographics_size
         self.use_pretrained_embeddings = use_pretrained_embeddings
@@ -195,6 +226,10 @@ class CEHRGPTConfig(PretrainedConfig):
         self.lab_token_loss_weight = lab_token_loss_weight
         self.entropy_penalty = entropy_penalty
         self.entropy_penalty_alpha = entropy_penalty_alpha
+        self.value_prediction_loss_weight = value_prediction_loss_weight
+
+        # Class weights for fine-tuning
+        self.class_weights = class_weights

         kwargs["tie_word_embeddings"] = not use_pretrained_embeddings

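For context, a hedged sketch of constructing the config with the new fields (the values are illustrative; every other argument keeps its default):

# Enabling the MOTOR time-to-event head requires a positive motor_tte_vocab_size
# and a ve_token_id; sample_packing_max_positions falls back to n_positions when omitted.
config = CEHRGPTConfig(
    motor_tte_vocab_size=512,
    ve_token_id=5,
    motor_num_time_pieces=16,
    sample_packing_max_positions=8192,
    class_weights=[1.0, 4.0],  # optional class weights for fine-tuning
)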