cehrgpt 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- __init__.py +0 -0
- cehrgpt/__init__.py +0 -0
- cehrgpt/analysis/__init__.py +0 -0
- cehrgpt/analysis/privacy/__init__.py +0 -0
- cehrgpt/analysis/privacy/attribute_inference.py +275 -0
- cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
- cehrgpt/analysis/privacy/member_inference.py +172 -0
- cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
- cehrgpt/analysis/privacy/reid_inference.py +407 -0
- cehrgpt/analysis/privacy/utils.py +255 -0
- cehrgpt/cehrgpt_args.py +142 -0
- cehrgpt/data/__init__.py +0 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
- cehrgpt/generation/__init__.py +0 -0
- cehrgpt/generation/chatgpt_generation.py +106 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
- cehrgpt/generation/omop_converter_batch.py +644 -0
- cehrgpt/generation/omop_entity.py +515 -0
- cehrgpt/gpt_utils.py +331 -0
- cehrgpt/models/__init__.py +0 -0
- cehrgpt/models/config.py +205 -0
- cehrgpt/models/hf_cehrgpt.py +1817 -0
- cehrgpt/models/hf_modeling_outputs.py +158 -0
- cehrgpt/models/pretrained_embeddings.py +82 -0
- cehrgpt/models/special_tokens.py +30 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
- cehrgpt/omop/__init__.py +0 -0
- cehrgpt/omop/condition_era.py +20 -0
- cehrgpt/omop/observation_period.py +43 -0
- cehrgpt/omop/omop_argparse.py +38 -0
- cehrgpt/omop/omop_table_builder.py +86 -0
- cehrgpt/omop/queries/__init__.py +0 -0
- cehrgpt/omop/queries/condition_era.py +86 -0
- cehrgpt/omop/queries/observation_period.py +135 -0
- cehrgpt/omop/sample_omop_tables.py +71 -0
- cehrgpt/runners/__init__.py +0 -0
- cehrgpt/runners/gpt_runner_util.py +99 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
- cehrgpt/runners/hyperparameter_search_util.py +223 -0
- cehrgpt/time_to_event/__init__.py +0 -0
- cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
- cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
- cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
- cehrgpt/time_to_event/time_to_event_model.py +226 -0
- cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
- cehrgpt/time_to_event/time_to_event_utils.py +55 -0
- cehrgpt/tools/__init__.py +0 -0
- cehrgpt/tools/ehrshot_benchmark.py +74 -0
- cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
- cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
- cehrgpt/tools/upload_omop_tables.py +108 -0
- cehrgpt-0.0.1.dist-info/LICENSE +21 -0
- cehrgpt-0.0.1.dist-info/METADATA +66 -0
- cehrgpt-0.0.1.dist-info/RECORD +60 -0
- cehrgpt-0.0.1.dist-info/WHEEL +5 -0
- cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
cehrgpt/runners/hyperparameter_search_util.py
@@ -0,0 +1,223 @@
+from functools import partial
+from typing import Callable, Tuple
+
+import optuna
+from cehrbert.runners.hf_runner_argument_dataclass import ModelArguments
+from datasets import Dataset, DatasetDict
+from transformers import (
+    EarlyStoppingCallback,
+    Trainer,
+    TrainerCallback,
+    TrainingArguments,
+)
+from transformers.utils import logging
+
+from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
+from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
+
+LOG = logging.get_logger("transformers")
+
+
+class OptunaMetricCallback(TrainerCallback):
+    """
+    A custom callback that stores the best metric in the evaluation metrics dictionary during training.
+
+    This callback monitors the training state and updates the metrics dictionary with the `best_metric`
+    (e.g., the lowest `eval_loss` or the highest accuracy) observed during training. It ensures that the
+    best metric value is preserved in the final evaluation results, even if early stopping occurs.
+
+    Attributes:
+        None
+
+    Methods:
+        on_evaluate(args, state, control, **kwargs):
+            Called during evaluation. Adds `state.best_metric` to `metrics` if it exists.
+
+    Example Usage:
+        ```
+        optuna_metric_callback = OptunaMetricCallback()
+        trainer = Trainer(
+            model=model,
+            args=training_args,
+            train_dataset=train_dataset,
+            eval_dataset=val_dataset,
+            callbacks=[optuna_metric_callback]
+        )
+        ```
+    """
+
+    def on_evaluate(self, args, state, control, **kwargs):
+        """
+        During evaluation, adds the best metric value to the metrics dictionary if it exists.
+
+        Args:
+            args: Training arguments.
+            state: Trainer state object that holds information about training progress.
+            control: Trainer control object to modify training behavior.
+            **kwargs: Additional keyword arguments, including `metrics`, which holds evaluation metrics.
+
+        Updates:
+            `metrics["optuna_best_metric"]`: Set to the smaller of `state.best_metric` and the
+            current `eval_loss`, or to `eval_loss` alone if no best metric has been recorded yet.
+        """
+        # Check whether a best metric is available and fold it into the metrics dict
+        metrics = kwargs.get("metrics", {})
+        if state.best_metric is not None:
+            metrics.update(
+                {"optuna_best_metric": min(state.best_metric, metrics["eval_loss"])}
+            )
+        else:
+            metrics.update({"optuna_best_metric": metrics["eval_loss"]})
+
+
+# Define the hyperparameter search space
+def hp_space(
+    trial: optuna.Trial,
+    lr_range: Tuple[float, float] = (1e-5, 5e-5),
+    batch_sizes=None,
+    weight_decays: Tuple[float, float] = (1e-4, 1e-2),
+    num_train_epochs: Tuple[int, int] = (10, 10),  # (low, high) range for suggest_int
+):
+    if batch_sizes is None:
+        batch_sizes = [4, 8]
+    return {
+        "learning_rate": trial.suggest_float("learning_rate", *lr_range, log=True),
+        "per_device_train_batch_size": trial.suggest_categorical(
+            "per_device_train_batch_size", batch_sizes
+        ),
+        "weight_decay": trial.suggest_float("weight_decay", *weight_decays, log=True),
+        "num_train_epochs": trial.suggest_int("num_train_epochs", *num_train_epochs),
+    }
+
+
+def sample_dataset(data: Dataset, percentage: float, seed: int) -> Dataset:
+    """
+    Samples a subset of the given dataset based on a specified percentage.
+
+    This function uses a random train-test split to select a subset of the dataset, returning a sample
+    that is approximately `percentage` of the total dataset size. It is useful for creating smaller
+    datasets for tasks such as hyperparameter tuning or quick testing.
+
+    Args:
+        data (Dataset): The input dataset to sample from.
+        percentage (float): The fraction of the dataset to sample, represented as a decimal
+            (e.g., 0.1 for 10%).
+        seed (int): A random seed for reproducibility in the sampling process.
+
+    Returns:
+        Dataset: A sampled subset of the input dataset containing `percentage` of the original data.
+
+    Example:
+        ```
+        sampled_data = sample_dataset(my_dataset, percentage=0.1, seed=42)
+        ```
+
+    Notes:
+        - The `train_test_split` method splits the dataset into "train" and "test" portions. This function
+          returns the "test" portion, which is the specified percentage of the dataset.
+        - Ensure that `percentage` is between 0 and 1 to avoid errors.
+    """
+    if percentage == 1.0:
+        return data
+
+    return data.train_test_split(
+        test_size=percentage,
+        seed=seed,
+    )["test"]
+
+
+def perform_hyperparameter_search(
+    model_init: Callable,
+    dataset: DatasetDict,
+    data_collator: CehrGptDataCollator,
+    training_args: TrainingArguments,
+    model_args: ModelArguments,
+    cehrgpt_args: CehrGPTArguments,
+) -> TrainingArguments:
+    """
+    Perform hyperparameter tuning for the CehrGPT model using Optuna with the Hugging Face Trainer.
+
+    This function initializes a Trainer with sampled training and validation sets, and performs
+    a hyperparameter search using Optuna. The search tunes the learning rate, batch size, weight decay,
+    and number of training epochs to optimize model performance based on a specified objective metric
+    (e.g., validation loss). After the search, it updates the provided `TrainingArguments` with the
+    best hyperparameters found.
+
+    Args:
+        model_init (Callable): A function to initialize the model, used for each hyperparameter trial.
+        dataset (DatasetDict): A Hugging Face DatasetDict containing "train" and "validation" datasets.
+        data_collator (CehrGptDataCollator): A data collator for processing batches.
+        training_args (TrainingArguments): Configuration for training parameters (e.g., epochs, evaluation strategy).
+        model_args (ModelArguments): Model configuration arguments, including early stopping parameters.
+        cehrgpt_args (CehrGPTArguments): Additional arguments specific to CehrGPT, including hyperparameter
+            tuning options such as the learning rate range, batch sizes, and tuning percentage.
+
+    Returns:
+        TrainingArguments: Updated `TrainingArguments` instance containing the best hyperparameters found
+        from the search.
+
+    Example:
+        ```
+        best_training_args = perform_hyperparameter_search(
+            model_init=my_model_init,
+            dataset=my_dataset_dict,
+            data_collator=my_data_collator,
+            training_args=initial_training_args,
+            model_args=model_args,
+            cehrgpt_args=cehrgpt_args
+        )
+        ```
+
+    Notes:
+        - If `cehrgpt_args.hyperparameter_tuning` is set to `True`, this function samples a portion of the
+          training and validation datasets for efficient tuning.
+        - An `EarlyStoppingCallback` is added to the Trainer with the patience configured in
+          `model_args.early_stopping_patience`.
+        - Optuna's `hyperparameter_search` is configured with the specified number of trials (`n_trials`)
+          and the learning rate and batch size ranges provided in `cehrgpt_args`.
+
+    Logging:
+        Logs the best hyperparameters found at the end of the search.
+    """
+    if cehrgpt_args.hyperparameter_tuning:
+        sampled_train = sample_dataset(
+            dataset["train"],
+            cehrgpt_args.hyperparameter_tuning_percentage,
+            training_args.seed,
+        )
+        sampled_val = sample_dataset(
+            dataset["validation"],
+            cehrgpt_args.hyperparameter_tuning_percentage,
+            training_args.seed,
+        )
+        hyperparam_trainer = Trainer(
+            model_init=model_init,
+            data_collator=data_collator,
+            train_dataset=sampled_train,
+            eval_dataset=sampled_val,
+            callbacks=[
+                EarlyStoppingCallback(model_args.early_stopping_patience),
+                OptunaMetricCallback(),
+            ],
+            args=training_args,
+        )
+        # Perform the hyperparameter search
+        best_trial = hyperparam_trainer.hyperparameter_search(
+            direction="minimize",
+            hp_space=partial(
+                hp_space,
+                lr_range=(cehrgpt_args.lr_low, cehrgpt_args.lr_high),
+                weight_decays=(
+                    cehrgpt_args.weight_decays_low,
+                    cehrgpt_args.weight_decays_high,
+                ),
+                batch_sizes=cehrgpt_args.hyperparameter_batch_sizes,
+                num_train_epochs=cehrgpt_args.hyperparameter_num_train_epochs,
+            ),
+            backend="optuna",
+            n_trials=cehrgpt_args.n_trials,
+            compute_objective=lambda m: m["optuna_best_metric"],
+        )
+        LOG.info("Best hyperparameters: %s", best_trial.hyperparameters)
+        # Update the training arguments with the best hyperparameters found
+        for k, v in best_trial.hyperparameters.items():
+            setattr(training_args, k, v)
+
+    return training_args
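
For orientation, the sketch below shows one way this utility could be wired into a finetuning run. It is a minimal, hypothetical example: the `TrainingArguments` values and the `my_*` names are illustrative assumptions, not defaults shipped with the package; only `perform_hyperparameter_search` and the argument fields it references (`lr_low`, `lr_high`, `n_trials`, `early_stopping_patience`, ...) come from the module above.

```python
# Hypothetical wiring for perform_hyperparameter_search; names prefixed with
# my_* are placeholders the caller must supply.
from transformers import TrainingArguments

from cehrgpt.runners.hyperparameter_search_util import perform_hyperparameter_search

training_args = TrainingArguments(
    output_dir="./hp_search",
    evaluation_strategy="epoch",      # evaluate once per epoch
    save_strategy="epoch",
    load_best_model_at_end=True,      # needed so state.best_metric is tracked
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    seed=42,
)

best_training_args = perform_hyperparameter_search(
    model_init=my_model_init,        # returns a fresh CEHRGPT2LMHeadModel per trial
    dataset=my_dataset_dict,         # DatasetDict with "train" and "validation" splits
    data_collator=my_data_collator,  # a CehrGptDataCollator instance
    training_args=training_args,
    model_args=my_model_args,        # provides early_stopping_patience
    cehrgpt_args=my_cehrgpt_args,    # provides lr_low/lr_high, n_trials, etc.
)
```

Note that `OptunaMetricCallback` reads `state.best_metric`, which the Trainer only maintains when `metric_for_best_model` is set as above.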
cehrgpt/time_to_event/config/next_visit_type_prediction.yaml
@@ -0,0 +1,8 @@
+task_name: "next_visit_type_prediction"
+outcome_events: [
+  '9202', '9203', '581477', '9201', '5083', '262', '38004250', '8883', '38004238', '38004251',
+  '38004222', '38004268', '38004228', '32693', '8971', '38004269', '38004193', '32036', '8782'
+]
+include_descendants: false
+future_visit_start: 0
+future_visit_end: 1
cehrgpt/time_to_event/time_to_event_model.py
@@ -0,0 +1,226 @@
+import math
+from collections import Counter
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from cehrbert_data.decorators.patient_event_decorator_base import time_month_token
+from transformers import GenerationConfig
+
+from cehrgpt.gpt_utils import (
+    extract_time_interval_in_days,
+    is_att_token,
+    is_visit_end,
+    is_visit_start,
+)
+from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
+from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+
+
+@dataclass
+class TimeToEvent:
+    average_time: float
+    median_time: float
+    standard_deviation: float
+    most_likely_time: str
+    num_of_simulations: int
+    time_intervals: List[int]
+    outcome_events: List[str]
+    time_interval_probability_table: List[Dict[str, Any]]
+
+
+def create_time_to_event(
+    time_event_tuples: List[Tuple[str, int]], num_of_simulations: int
+) -> TimeToEvent:
+    outcome_events, time_intervals = zip(*time_event_tuples)
+    time_buckets = [time_month_token(_) for _ in time_intervals]
+    time_bucket_counter = Counter(time_buckets)
+    most_common_item = time_bucket_counter.most_common(1)[0][0]
+    total_count = sum(time_bucket_counter.values())
+    # Generate the probability table
+    probability_table = {
+        item: count / total_count for item, count in time_bucket_counter.items()
+    }
+    sorted_probability_table = [
+        {"time_interval": k, "probability": v}
+        for k, v in sorted(probability_table.items(), key=lambda x: x[1], reverse=True)
+    ]
+    return TimeToEvent(
+        time_intervals=time_intervals,
+        outcome_events=outcome_events,
+        average_time=np.mean(time_intervals),
+        median_time=np.median(time_intervals),
+        standard_deviation=np.std(time_intervals),
+        most_likely_time=most_common_item,
+        num_of_simulations=num_of_simulations,
+        time_interval_probability_table=sorted_probability_table,
+    )
+
+
+class TimeToEventModel:
+    def __init__(
+        self,
+        tokenizer: CehrGptTokenizer,
+        model: CEHRGPT2LMHeadModel,
+        outcome_events: List[str],
+        generation_config: GenerationConfig,
+        device: torch.device = torch.device("cpu"),
+        batch_size: int = 32,
+    ):
+        self.tokenizer = tokenizer
+        self.model = model.eval()
+        self.generation_config = generation_config
+        self.outcome_events = outcome_events
+        self.device = device
+        self.batch_size = batch_size
+        self.max_sequence = model.config.n_positions
+
+    def is_outcome_event(self, token: str):
+        return token in self.outcome_events
+
+    def simulate(
+        self, partial_history: Union[np.ndarray, List[str]]
+    ) -> List[List[str]]:
+        sequence_is_demographics = len(partial_history) == 4 and partial_history[
+            0
+        ].startswith("year")
+        sequence_ends_ve = is_visit_end(partial_history[-1])
+
+        if not (sequence_is_demographics or sequence_ends_ve):
+            raise ValueError(
+                "Only two types of sequences are allowed: 1) the sequence contains "
+                "only demographics; 2) the sequence ends on a visit-end (VE) token."
+            )
+
+        token_ids = self.tokenizer.encode(partial_history)
+        prompt = torch.tensor(token_ids).unsqueeze(0).to(self.device)
+
+        simulated_sequences = []
+        num_iters = max(
+            math.ceil(self.generation_config.num_return_sequences / self.batch_size), 1
+        )
+        old_num_return_sequences = self.generation_config.num_return_sequences
+        self.generation_config.num_return_sequences = min(
+            self.batch_size, old_num_return_sequences
+        )
+        with torch.no_grad():
+            for _ in range(num_iters):
+                results = self.model.generate(
+                    inputs=prompt,
+                    generation_config=self.generation_config,
+                )
+                # Clear the cache
+                torch.cuda.empty_cache()
+                # Add the decoded sequences to the result list
+                simulated_sequences.extend(
+                    [
+                        self.tokenizer.decode(seq.cpu().numpy())
+                        for seq in results.sequences
+                    ]
+                )
+
+        self.generation_config.num_return_sequences = old_num_return_sequences
+        return simulated_sequences
+
+    def predict_time_to_events(
+        self,
+        partial_history: Union[np.ndarray, list],
+        future_visit_start: int = 0,
+        future_visit_end: int = -1,
+        prediction_window_start: int = 0,
+        prediction_window_end: int = 365,
+        debug: bool = False,
+        max_n_trial: int = 2,
+    ) -> Optional[TimeToEvent]:
+        patient_history_length = len(partial_history)
+        time_event_tuples = []
+        seqs_failed_to_convert = []
+        n_trial = 0
+        num_return_sequences = self.generation_config.num_return_sequences
+        max_new_tokens = self.generation_config.max_new_tokens
+        while (
+            len(time_event_tuples) < num_return_sequences  # the original request, not the mutated config
+            and n_trial < max_n_trial
+        ):
+            self.generation_config.num_return_sequences = num_return_sequences - len(
+                time_event_tuples
+            )
+            # self.generation_config.max_new_tokens = max_new_tokens * (n_trial + 1)
+            simulated_seqs = self.simulate(partial_history)
+            n_trial += 1
+            for seq in simulated_seqs:
+                visit_counter = 0
+                time_delta = 0
+                success = False
+                for next_token in seq[patient_history_length:]:
+                    visit_counter += int(is_visit_start(next_token))
+                    if (
+                        visit_counter > future_visit_end != -1  # chained: bound set (!= -1) and exceeded
+                        or time_delta > prediction_window_end != -1
+                    ):
+                        time_event_tuples.append(("0", time_delta))
+                        success = True
+                        break
+                    if is_att_token(next_token):
+                        time_delta += extract_time_interval_in_days(next_token)
+                    elif (
+                        visit_counter >= future_visit_start
+                        and time_delta >= prediction_window_start
+                    ) and self.is_outcome_event(next_token):
+                        time_event_tuples.append((next_token, time_delta))
+                        success = True
+                        break
+                if not success:
+                    # The generated sequence did not satisfy the criteria
+                    if future_visit_end != -1 or prediction_window_end != -1:
+                        seqs_failed_to_convert.append(seq[patient_history_length:])
+                    else:
+                        time_event_tuples.append(("0", time_delta))
+
+        self.generation_config.num_return_sequences = num_return_sequences
+        self.generation_config.max_new_tokens = max_new_tokens
+
+        if debug:
+            print(f"seqs_failed_to_convert: {seqs_failed_to_convert}")
+
+        # Count the occurrences of each time token for each concept
+        return (
+            create_time_to_event(time_event_tuples, len(time_event_tuples))
+            if len(time_event_tuples) > 0
+            else None
+        )
+
+    @staticmethod
+    def get_generation_config(
+        tokenizer: CehrGptTokenizer,
+        max_length: int,
+        num_return_sequences: int,
+        top_p: float = 1.0,
+        top_k: int = 300,
+        temperature: float = 1.0,
+        repetition_penalty: float = 1.0,
+        epsilon_cutoff: float = 0.0,
+        max_new_tokens: int = 128,
+    ) -> GenerationConfig:
+        return GenerationConfig(
+            max_length=max_length,
+            max_new_tokens=max_new_tokens,
+            num_return_sequences=num_return_sequences,
+            temperature=temperature,
+            repetition_penalty=repetition_penalty,
+            epsilon_cutoff=epsilon_cutoff,
+            top_p=top_p,
+            top_k=top_k,
+            bos_token_id=tokenizer.end_token_id,
+            eos_token_id=tokenizer.end_token_id,
+            pad_token_id=tokenizer.pad_token_id,
+            do_sample=True,
+            use_cache=True,
+            return_dict_in_generate=True,
+            output_attentions=False,
+            output_hidden_states=False,
+            output_scores=False,
+            renormalize_logits=True,
+        )
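
To make the flow concrete, here is a minimal, hypothetical usage sketch. The checkpoint paths, the choice of `outcome_events`, and the `demographic_prompt` token list are assumptions (including HF-style `from_pretrained` loading); the method calls mirror the class above.

```python
# Illustrative only: paths and the demographic prompt are placeholders.
import torch

from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
from cehrgpt.time_to_event.time_to_event_model import TimeToEventModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = CehrGptTokenizer.from_pretrained("path/to/tokenizer")
model = CEHRGPT2LMHeadModel.from_pretrained("path/to/model").to(device)

tte_model = TimeToEventModel(
    tokenizer=tokenizer,
    model=model,
    outcome_events=["9201"],  # e.g., an inpatient visit concept ID
    generation_config=TimeToEventModel.get_generation_config(
        tokenizer=tokenizer,
        max_length=model.config.n_positions,
        num_return_sequences=64,
    ),
    device=device,
)

# The prompt must be demographics-only (four tokens, the first starting with
# "year") or end on a visit-end token, per the validation in simulate().
tte = tte_model.predict_time_to_events(
    partial_history=demographic_prompt,
    future_visit_start=0,
    future_visit_end=1,        # look only at the next visit
    prediction_window_end=365,
)
if tte is not None:
    print(tte.most_likely_time, tte.time_interval_probability_table[:3])
```

The returned `TimeToEvent` aggregates the Monte Carlo simulations: the time intervals at which an outcome token appeared, their mean/median/standard deviation, and a probability table over month buckets.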