cehrgpt-0.1.1-py3-none-any.whl → cehrgpt-0.1.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/htn_treatment_pathway.py +546 -0
- cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
- cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
- cehrgpt/data/cehrgpt_data_processor.py +549 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +286 -629
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +60 -14
- cehrgpt/generation/cehrgpt_conditional_generation.py +316 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +35 -15
- cehrgpt/generation/omop_converter_batch.py +11 -4
- cehrgpt/gpt_utils.py +73 -3
- cehrgpt/models/activations.py +27 -0
- cehrgpt/models/config.py +6 -2
- cehrgpt/models/gpt2.py +560 -0
- cehrgpt/models/hf_cehrgpt.py +193 -459
- cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
- cehrgpt/omop/ontology.py +154 -0
- cehrgpt/runners/data_utils.py +17 -6
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +33 -79
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +58 -34
- cehrgpt/runners/hyperparameter_search_util.py +180 -69
- cehrgpt/runners/sample_packing_trainer.py +11 -2
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +27 -31
- cehrgpt-0.1.3.dist-info/METADATA +238 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +33 -22
- cehrgpt-0.1.1.dist-info/METADATA +0 -115
- /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,7 @@
|
|
1
1
|
import dataclasses
|
2
|
-
from typing import List, Optional
|
2
|
+
from typing import List, Literal, Optional
|
3
|
+
|
4
|
+
from cehrgpt.models.gpt2 import ACT2FN
|
3
5
|
|
4
6
|
|
5
7
|
@dataclasses.dataclass
|
@@ -12,6 +14,14 @@ class CehrGPTArguments:
|
|
12
14
|
"help": "The path to the tokenized dataset created for the full population"
|
13
15
|
},
|
14
16
|
)
|
17
|
+
activation_function: Literal[tuple(ACT2FN.keys())] = dataclasses.field(
|
18
|
+
default="gelu_new",
|
19
|
+
metadata={"help": "The activation function to use"},
|
20
|
+
)
|
21
|
+
decoder_mlp: Literal["GPT2MLP", "LlamaMLP"] = dataclasses.field(
|
22
|
+
default="GPT2MLP",
|
23
|
+
metadata={"help": "The decoder MLP architecture"},
|
24
|
+
)
|
15
25
|
include_inpatient_hour_token: Optional[bool] = dataclasses.field(
|
16
26
|
default=True,
|
17
27
|
metadata={"help": "Include inpatient hour token"},
|
@@ -54,6 +64,14 @@ class CehrGPTArguments:
|
|
54
64
|
default=128,
|
55
65
|
metadata={"help": "The number of examples from the training set."},
|
56
66
|
)
|
67
|
+
hyperparameter_tuning: Optional[bool] = dataclasses.field(
|
68
|
+
default=False,
|
69
|
+
metadata={"help": "A flag to indicate if we want to do hyperparameter tuning."},
|
70
|
+
)
|
71
|
+
hyperparameter_tuning_is_grid: Optional[bool] = dataclasses.field(
|
72
|
+
default=True,
|
73
|
+
metadata={"help": "A flag to indicate if we want to do hyperparameter tuning."},
|
74
|
+
)
|
57
75
|
hyperparameter_tuning_percentage: Optional[float] = dataclasses.field(
|
58
76
|
default=0.1,
|
59
77
|
metadata={
|
@@ -66,10 +84,6 @@ class CehrGPTArguments:
|
|
66
84
|
"help": "The number of trails will be use for hyperparameter tuning."
|
67
85
|
},
|
68
86
|
)
|
69
|
-
hyperparameter_tuning: Optional[bool] = dataclasses.field(
|
70
|
-
default=False,
|
71
|
-
metadata={"help": "A flag to indicate if we want to do hyperparameter tuning."},
|
72
|
-
)
|
73
87
|
hyperparameter_batch_sizes: Optional[List[int]] = dataclasses.field(
|
74
88
|
default_factory=lambda: [4, 8, 16],
|
75
89
|
metadata={"help": "Hyperparameter search batch sizes"},
|
@@ -78,29 +92,13 @@ class CehrGPTArguments:
|
|
78
92
|
default_factory=lambda: [10],
|
79
93
|
metadata={"help": "Hyperparameter search num_train_epochs"},
|
80
94
|
)
|
81
|
-
|
82
|
-
|
83
|
-
metadata={
|
84
|
-
"help": "The lower bound of the learning rate range for hyperparameter tuning."
|
85
|
-
},
|
95
|
+
hyperparameter_learning_rates: Optional[List[int]] = dataclasses.field(
|
96
|
+
default_factory=lambda: [1e-5],
|
97
|
+
metadata={"help": "Hyperparameter search learning rates"},
|
86
98
|
)
|
87
|
-
|
88
|
-
|
89
|
-
metadata={
|
90
|
-
"help": "The upper bound of the learning rate range for hyperparameter tuning."
|
91
|
-
},
|
92
|
-
)
|
93
|
-
weight_decays_low: Optional[float] = dataclasses.field(
|
94
|
-
default=1e-3,
|
95
|
-
metadata={
|
96
|
-
"help": "The lower bound of the weight decays range for hyperparameter tuning."
|
97
|
-
},
|
98
|
-
)
|
99
|
-
weight_decays_high: Optional[float] = dataclasses.field(
|
100
|
-
default=1e-2,
|
101
|
-
metadata={
|
102
|
-
"help": "The upper bound of the weight decays range for hyperparameter tuning."
|
103
|
-
},
|
99
|
+
hyperparameter_weight_decays: Optional[List[int]] = dataclasses.field(
|
100
|
+
default_factory=lambda: [1e-2],
|
101
|
+
metadata={"help": "Hyperparameter search learning rates"},
|
104
102
|
)
|
105
103
|
causal_sfm: Optional[bool] = dataclasses.field(
|
106
104
|
default=False,
|
@@ -168,6 +166,16 @@ class CehrGPTArguments:
|
|
168
166
|
"help": "A threshold to denote how much the specified metric must improve to satisfy early stopping conditions."
|
169
167
|
},
|
170
168
|
)
|
169
|
+
inner_dim: Optional[int] = dataclasses.field(
|
170
|
+
default=None,
|
171
|
+
metadata={"help": "The dimensionality of the hidden layer"},
|
172
|
+
)
|
173
|
+
apply_rotary: Optional[bool] = dataclasses.field(
|
174
|
+
default=False,
|
175
|
+
metadata={
|
176
|
+
"help": "A flag to indicate whether we want to use rotary encoder layers"
|
177
|
+
},
|
178
|
+
)
|
171
179
|
sample_packing: Optional[bool] = dataclasses.field(
|
172
180
|
default=False,
|
173
181
|
metadata={
|
@@ -177,12 +185,6 @@ class CehrGPTArguments:
|
|
177
185
|
max_tokens_per_batch: int = dataclasses.field(
|
178
186
|
default=16384, metadata={"help": "Maximum number of tokens in each batch"}
|
179
187
|
)
|
180
|
-
add_end_token_in_sample_packing: Optional[bool] = dataclasses.field(
|
181
|
-
default=False,
|
182
|
-
metadata={
|
183
|
-
"help": "A flag to indicate whether we want to add end token in sample packing"
|
184
|
-
},
|
185
|
-
)
|
186
188
|
include_motor_time_to_event: Optional[bool] = dataclasses.field(
|
187
189
|
default=False,
|
188
190
|
metadata={
|
@@ -203,7 +205,17 @@ class CehrGPTArguments:
|
|
203
205
|
"help": "The number of times each motor_num_time_pieces piece has to be"
|
204
206
|
},
|
205
207
|
)
|
206
|
-
|
208
|
+
motor_use_ontology: Optional[bool] = dataclasses.field(
|
209
|
+
default=False,
|
210
|
+
metadata={
|
211
|
+
"help": "A flag to indicate whether we want to use motor_use_ontology"
|
212
|
+
},
|
213
|
+
)
|
214
|
+
motor_sampling_probability: Optional[float] = dataclasses.field(
|
215
|
+
default=0.0,
|
216
|
+
metadata={"help": "A flag to indicate whether we want to use sample packing"},
|
217
|
+
)
|
218
|
+
vocab_dir: Optional[str] = dataclasses.field(
|
207
219
|
default=None,
|
208
220
|
metadata={"help": "The directory where the concept data is stored."},
|
209
221
|
)
|
@@ -229,3 +241,15 @@ class CehrGPTArguments:
|
|
229
241
|
"help": "The probability of negative samples will be included in the training data"
|
230
242
|
},
|
231
243
|
)
|
244
|
+
num_of_trajectories_per_sample: Optional[int] = dataclasses.field(
|
245
|
+
default=1,
|
246
|
+
metadata={"help": "The number of trajectories per sample"},
|
247
|
+
)
|
248
|
+
generation_input_length: Optional[int] = dataclasses.field(
|
249
|
+
default=1024,
|
250
|
+
metadata={"help": "The length of the input sequence"},
|
251
|
+
)
|
252
|
+
generation_max_new_tokens: Optional[int] = dataclasses.field(
|
253
|
+
default=1024,
|
254
|
+
metadata={"help": "The maximum number of tokens in the generation sequence"},
|
255
|
+
)
|
@@ -1,5 +1,5 @@
|
|
1
1
|
from functools import partial
|
2
|
-
from typing import Callable, Tuple
|
2
|
+
from typing import Callable, List, Optional, Tuple, Union
|
3
3
|
|
4
4
|
import optuna
|
5
5
|
from cehrbert.runners.hf_runner_argument_dataclass import ModelArguments
|
@@ -64,28 +64,99 @@ class OptunaMetricCallback(TrainerCallback):
|
|
64
64
|
metrics.update({"optuna_best_metric": metrics["eval_loss"]})
|
65
65
|
|
66
66
|
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
67
|
+
def get_suggestion(
|
68
|
+
trial,
|
69
|
+
hyperparameter_name: str,
|
70
|
+
hyperparameters: List[Union[float, int]],
|
71
|
+
is_grid: bool = False,
|
72
|
+
) -> Union[float, int]:
|
73
|
+
"""
|
74
|
+
Get hyperparameter suggestion based on search mode.
|
75
|
+
|
76
|
+
Args:
|
77
|
+
trial: Optuna trial object
|
78
|
+
hyperparameter_name: Name of the hyperparameter
|
79
|
+
hyperparameters: List of hyperparameter values
|
80
|
+
is_grid: Whether to use grid search mode
|
81
|
+
|
82
|
+
Returns:
|
83
|
+
Suggested hyperparameter value
|
84
|
+
|
85
|
+
Raises:
|
86
|
+
RuntimeError: If Bayesian mode is used with incorrect number of bounds
|
87
|
+
"""
|
88
|
+
if is_grid:
|
89
|
+
return trial.suggest_categorical(hyperparameter_name, hyperparameters)
|
90
|
+
|
91
|
+
# For Bayesian optimization, we need exactly 2 values (lower and upper bounds)
|
92
|
+
if len(hyperparameters) != 2:
|
93
|
+
raise RuntimeError(
|
94
|
+
f"{hyperparameter_name} must contain exactly two values (lower and upper bound) "
|
95
|
+
f"for Bayesian Optimization, but {len(hyperparameters)} values were provided: {hyperparameters}"
|
96
|
+
)
|
97
|
+
|
98
|
+
# Ensure bounds are sorted
|
99
|
+
lower, upper = sorted(hyperparameters)
|
100
|
+
return trial.suggest_float(hyperparameter_name, lower, upper, log=True)
|
101
|
+
|
102
|
+
|
103
|
+
def hp_space(trial: optuna.Trial, cehrgpt_args: CehrGPTArguments):
|
104
|
+
"""
|
105
|
+
Define the hyperparameter search space.
|
106
|
+
|
107
|
+
Args:
|
108
|
+
trial: Optuna trial object
|
109
|
+
cehrgpt_args: CehrGPTArguments
|
110
|
+
Returns:
|
111
|
+
Dictionary of hyperparameter suggestions
|
112
|
+
"""
|
113
|
+
|
114
|
+
is_grid = cehrgpt_args.hyperparameter_tuning_is_grid
|
115
|
+
learning_rates = cehrgpt_args.hyperparameter_learning_rates
|
116
|
+
weight_decays = cehrgpt_args.hyperparameter_weight_decays
|
117
|
+
batch_sizes = cehrgpt_args.hyperparameter_batch_sizes
|
118
|
+
num_train_epochs = cehrgpt_args.hyperparameter_num_train_epochs
|
119
|
+
|
77
120
|
return {
|
78
|
-
"learning_rate":
|
121
|
+
"learning_rate": get_suggestion(
|
122
|
+
trial, "learning_rate", learning_rates, is_grid
|
123
|
+
),
|
79
124
|
"per_device_train_batch_size": trial.suggest_categorical(
|
80
125
|
"per_device_train_batch_size", batch_sizes
|
81
126
|
),
|
82
|
-
"weight_decay": trial
|
127
|
+
"weight_decay": get_suggestion(trial, "weight_decay", weight_decays, is_grid),
|
83
128
|
"num_train_epochs": trial.suggest_categorical(
|
84
129
|
"num_train_epochs", num_train_epochs
|
85
130
|
),
|
86
131
|
}
|
87
132
|
|
88
133
|
|
134
|
+
def create_grid_search_space(cehrgpt_args: CehrGPTArguments):
|
135
|
+
"""
|
136
|
+
Create the search space dictionary for GridSampler.
|
137
|
+
|
138
|
+
Args:
|
139
|
+
cehrgpt_args: CehrGPTArguments
|
140
|
+
|
141
|
+
Returns:
|
142
|
+
Dictionary defining the grid search space
|
143
|
+
"""
|
144
|
+
return {
|
145
|
+
"learning_rate": cehrgpt_args.hyperparameter_learning_rates,
|
146
|
+
"weight_decay": cehrgpt_args.hyperparameter_weight_decays,
|
147
|
+
"per_device_train_batch_size": cehrgpt_args.hyperparameter_batch_sizes,
|
148
|
+
"num_train_epochs": cehrgpt_args.hyperparameter_num_train_epochs,
|
149
|
+
}
|
150
|
+
|
151
|
+
|
152
|
+
def calculate_total_combinations(search_space: dict) -> int:
|
153
|
+
"""Calculate total number of combinations in grid search."""
|
154
|
+
total = 1
|
155
|
+
for values in search_space.values():
|
156
|
+
total *= len(values)
|
157
|
+
return total
|
158
|
+
|
159
|
+
|
89
160
|
def sample_dataset(data: Dataset, percentage: float, seed: int) -> Dataset:
|
90
161
|
"""
|
91
162
|
Samples a subset of the given dataset based on a specified percentage.
|
@@ -113,7 +184,7 @@ def sample_dataset(data: Dataset, percentage: float, seed: int) -> Dataset:
|
|
113
184
|
returns the "test" portion, which is the specified percentage of the dataset.
|
114
185
|
- Ensure that `percentage` is between 0 and 1 to avoid errors.
|
115
186
|
"""
|
116
|
-
if percentage
|
187
|
+
if percentage >= 1.0:
|
117
188
|
return data
|
118
189
|
|
119
190
|
return data.train_test_split(
|
@@ -130,14 +201,13 @@ def perform_hyperparameter_search(
|
|
130
201
|
training_args: TrainingArguments,
|
131
202
|
model_args: ModelArguments,
|
132
203
|
cehrgpt_args: CehrGPTArguments,
|
133
|
-
) -> TrainingArguments:
|
204
|
+
) -> Tuple[TrainingArguments, Optional[str]]:
|
134
205
|
"""
|
135
206
|
Perform hyperparameter tuning for the CehrGPT model using Optuna with the Hugging Face Trainer.
|
136
207
|
|
137
|
-
This function
|
138
|
-
|
139
|
-
|
140
|
-
After the search, it updates the provided `TrainingArguments` with the best hyperparameters found.
|
208
|
+
This function supports two modes:
|
209
|
+
1. Bayesian Optimization (TPE): Intelligently explores hyperparameter space using bounds
|
210
|
+
2. Grid Search: Exhaustively tests all combinations of discrete values
|
141
211
|
|
142
212
|
Args:
|
143
213
|
trainer_class: A Trainer or its subclass
|
@@ -147,15 +217,15 @@ def perform_hyperparameter_search(
|
|
147
217
|
training_args (TrainingArguments): Configuration for training parameters (e.g., epochs, evaluation strategy).
|
148
218
|
model_args (ModelArguments): Model configuration arguments, including early stopping parameters.
|
149
219
|
cehrgpt_args (CehrGPTArguments): Additional arguments specific to CehrGPT, including hyperparameter
|
150
|
-
tuning options
|
220
|
+
tuning options and search mode configuration.
|
151
221
|
|
152
222
|
Returns:
|
153
|
-
TrainingArguments: Updated
|
154
|
-
|
223
|
+
Tuple[TrainingArguments, Optional[str]]: Updated TrainingArguments with best hyperparameters
|
224
|
+
and optional run_id of the best trial.
|
155
225
|
|
156
226
|
Example:
|
157
227
|
```
|
158
|
-
best_training_args = perform_hyperparameter_search(
|
228
|
+
best_training_args, run_id = perform_hyperparameter_search(
|
159
229
|
trainer_class=Trainer,
|
160
230
|
model_init=my_model_init,
|
161
231
|
dataset=my_dataset_dict,
|
@@ -176,50 +246,91 @@ def perform_hyperparameter_search(
|
|
176
246
|
Logging:
|
177
247
|
Logs the best hyperparameters found at the end of the search.
|
178
248
|
"""
|
179
|
-
if cehrgpt_args.hyperparameter_tuning:
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
#
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
249
|
+
if not cehrgpt_args.hyperparameter_tuning:
|
250
|
+
return training_args, None
|
251
|
+
|
252
|
+
# Prepare hyperparameters based on mode
|
253
|
+
if (
|
254
|
+
cehrgpt_args.hyperparameter_tuning_is_grid
|
255
|
+
and cehrgpt_args.hyperparameter_tuning_is_grid
|
256
|
+
):
|
257
|
+
search_space = create_grid_search_space(cehrgpt_args)
|
258
|
+
total_combinations = calculate_total_combinations(search_space)
|
259
|
+
|
260
|
+
LOG.info(f"Grid search mode: Testing {total_combinations} combinations")
|
261
|
+
LOG.info(f"Search space: {search_space}")
|
262
|
+
|
263
|
+
# Adjust n_trials for grid search if not set appropriately
|
264
|
+
if cehrgpt_args.n_trials < total_combinations:
|
265
|
+
LOG.warning(
|
266
|
+
f"n_trials ({cehrgpt_args.n_trials}) is less than total combinations ({total_combinations}). "
|
267
|
+
f"Setting n_trials to {total_combinations} to test all combinations."
|
268
|
+
)
|
269
|
+
cehrgpt_args.n_trials = total_combinations
|
270
|
+
|
271
|
+
# Configure sampler based on search mode
|
272
|
+
sampler = optuna.samplers.GridSampler(search_space, seed=training_args.seed)
|
273
|
+
else:
|
274
|
+
LOG.info("Bayesian optimization mode (TPE)")
|
275
|
+
LOG.info(f"Learning rate bounds: {cehrgpt_args.hyperparameter_learning_rates}")
|
276
|
+
LOG.info(f"Weight decay bounds: {cehrgpt_args.hyperparameter_weight_decays}")
|
277
|
+
LOG.info(f"Batch sizes: {cehrgpt_args.hyperparameter_batch_sizes}")
|
278
|
+
LOG.info(f"Epochs: {cehrgpt_args.hyperparameter_num_train_epochs}")
|
279
|
+
# Configure the TPE sampler
|
280
|
+
sampler = optuna.samplers.TPESampler(seed=training_args.seed)
|
281
|
+
|
282
|
+
# Prepare datasets
|
283
|
+
save_total_limit_original = training_args.save_total_limit
|
284
|
+
training_args.save_total_limit = 1
|
285
|
+
|
286
|
+
sampled_train = sample_dataset(
|
287
|
+
dataset["train"],
|
288
|
+
cehrgpt_args.hyperparameter_tuning_percentage,
|
289
|
+
training_args.seed,
|
290
|
+
)
|
291
|
+
sampled_val = sample_dataset(
|
292
|
+
dataset["validation"],
|
293
|
+
cehrgpt_args.hyperparameter_tuning_percentage,
|
294
|
+
training_args.seed,
|
295
|
+
)
|
296
|
+
# Create trainer
|
297
|
+
hyperparam_trainer = trainer_class(
|
298
|
+
model_init=model_init,
|
299
|
+
data_collator=data_collator,
|
300
|
+
train_dataset=sampled_train,
|
301
|
+
eval_dataset=sampled_val,
|
302
|
+
callbacks=[
|
303
|
+
EarlyStoppingCallback(model_args.early_stopping_patience),
|
304
|
+
OptunaMetricCallback(),
|
305
|
+
],
|
306
|
+
args=training_args,
|
307
|
+
)
|
308
|
+
|
309
|
+
best_trial = hyperparam_trainer.hyperparameter_search(
|
310
|
+
direction="minimize",
|
311
|
+
hp_space=partial(
|
312
|
+
hp_space,
|
313
|
+
cehrgpt_args=cehrgpt_args,
|
314
|
+
),
|
315
|
+
backend="optuna",
|
316
|
+
n_trials=cehrgpt_args.n_trials,
|
317
|
+
compute_objective=lambda m: m["optuna_best_metric"],
|
318
|
+
sampler=sampler,
|
319
|
+
)
|
320
|
+
|
321
|
+
# Log results
|
322
|
+
LOG.info("=" * 50)
|
323
|
+
LOG.info("HYPERPARAMETER SEARCH COMPLETED")
|
324
|
+
LOG.info("=" * 50)
|
325
|
+
LOG.info(f"Best hyperparameters: {best_trial.hyperparameters}")
|
326
|
+
LOG.info(f"Best metric (eval_loss): {best_trial.objective}")
|
327
|
+
LOG.info(f"Best run_id: {best_trial.run_id}")
|
328
|
+
LOG.info("=" * 50)
|
329
|
+
|
330
|
+
# Restore original settings and update with best hyperparameters
|
331
|
+
training_args.save_total_limit = save_total_limit_original
|
332
|
+
for k, v in best_trial.hyperparameters.items():
|
333
|
+
setattr(training_args, k, v)
|
334
|
+
LOG.info(f"Updated training_args.{k} = {v}")
|
224
335
|
|
225
|
-
return training_args
|
336
|
+
return training_args, best_trial.run_id
|
@@ -1,9 +1,10 @@
|
|
1
1
|
from typing import Optional, Union
|
2
2
|
|
3
|
+
import torch
|
3
4
|
from datasets import Dataset
|
4
5
|
from torch.utils.data import DataLoader
|
5
6
|
from transformers import Trainer
|
6
|
-
from transformers.trainer_utils import has_length
|
7
|
+
from transformers.trainer_utils import has_length, seed_worker
|
7
8
|
from transformers.utils import import_utils, logging
|
8
9
|
|
9
10
|
from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler
|
@@ -62,7 +63,10 @@ class SamplePackingTrainer(Trainer):
|
|
62
63
|
if "num_of_concepts" in train_dataset.column_names:
|
63
64
|
lengths = train_dataset["num_of_concepts"]
|
64
65
|
else:
|
65
|
-
lengths = [
|
66
|
+
lengths = [
|
67
|
+
len(sample["input_ids"])
|
68
|
+
for sample in train_dataset.select_columns("input_ids")
|
69
|
+
]
|
66
70
|
|
67
71
|
LOG.info("Finished computing lengths for the train dataset")
|
68
72
|
else:
|
@@ -102,6 +106,11 @@ class SamplePackingTrainer(Trainer):
|
|
102
106
|
"persistent_workers": self.args.dataloader_persistent_workers,
|
103
107
|
"batch_sampler": batch_sampler,
|
104
108
|
}
|
109
|
+
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
110
|
+
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
111
|
+
dataloader_params["worker_init_fn"] = seed_worker
|
112
|
+
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
113
|
+
|
105
114
|
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
|
106
115
|
|
107
116
|
def get_eval_dataloader(
|
@@ -1,8 +1,8 @@
|
|
1
|
+
import datetime
|
1
2
|
import glob
|
2
3
|
import os
|
3
4
|
import shutil
|
4
5
|
import uuid
|
5
|
-
from datetime import datetime
|
6
6
|
from functools import partial
|
7
7
|
from pathlib import Path
|
8
8
|
from typing import Optional, Union
|
@@ -15,6 +15,9 @@ import torch.distributed as dist
|
|
15
15
|
from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
|
16
16
|
from cehrbert.runners.runner_util import generate_prepared_ds_path
|
17
17
|
from datasets import concatenate_datasets, load_from_disk
|
18
|
+
from torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook import (
|
19
|
+
batched_powerSGD_hook,
|
20
|
+
)
|
18
21
|
from torch.utils.data import DataLoader
|
19
22
|
from tqdm import tqdm
|
20
23
|
from transformers.trainer_utils import is_main_process
|
@@ -25,7 +28,6 @@ from cehrgpt.data.hf_cehrgpt_dataset_collator import (
|
|
25
28
|
CehrGptDataCollator,
|
26
29
|
SamplePackingCehrGptDataCollator,
|
27
30
|
)
|
28
|
-
from cehrgpt.data.hf_cehrgpt_dataset_mapping import ExtractTokenizedSequenceDataMapping
|
29
31
|
from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler
|
30
32
|
from cehrgpt.models.hf_cehrgpt import (
|
31
33
|
CEHRGPT2Model,
|
@@ -159,24 +161,7 @@ def main():
|
|
159
161
|
final_splits = prepare_finetune_dataset(
|
160
162
|
data_args, training_args, cehrgpt_args, cache_file_collector
|
161
163
|
)
|
162
|
-
|
163
|
-
new_tokenizer_path = os.path.expanduser(training_args.output_dir)
|
164
|
-
if tokenizer_exists(new_tokenizer_path):
|
165
|
-
cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
|
166
|
-
new_tokenizer_path
|
167
|
-
)
|
168
|
-
else:
|
169
|
-
cehrgpt_tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
|
170
|
-
cehrgpt_tokenizer=cehrgpt_tokenizer,
|
171
|
-
dataset=final_splits["train"],
|
172
|
-
data_args=data_args,
|
173
|
-
concept_name_mapping={},
|
174
|
-
)
|
175
|
-
cehrgpt_tokenizer.save_pretrained(
|
176
|
-
os.path.expanduser(training_args.output_dir)
|
177
|
-
)
|
178
|
-
|
179
|
-
# TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
|
164
|
+
# TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
|
180
165
|
if not data_args.streaming:
|
181
166
|
all_columns = final_splits["train"].column_names
|
182
167
|
if "visit_concept_ids" in all_columns:
|
@@ -238,10 +223,6 @@ def main():
|
|
238
223
|
len(processed_dataset["test"]),
|
239
224
|
)
|
240
225
|
|
241
|
-
LOG.info(f"cehrgpt_model.config.vocab_size: {cehrgpt_model.config.vocab_size}")
|
242
|
-
LOG.info(f"cehrgpt_tokenizer.vocab_size: {cehrgpt_tokenizer.vocab_size}")
|
243
|
-
if cehrgpt_model.config.vocab_size < cehrgpt_tokenizer.vocab_size:
|
244
|
-
cehrgpt_model.resize_token_embeddings(cehrgpt_tokenizer.vocab_size)
|
245
226
|
if (
|
246
227
|
cehrgpt_model.config.max_position_embeddings
|
247
228
|
< model_args.max_position_embeddings
|
@@ -264,7 +245,6 @@ def main():
|
|
264
245
|
SamplePackingCehrGptDataCollator,
|
265
246
|
cehrgpt_args.max_tokens_per_batch,
|
266
247
|
cehrgpt_model.config.max_position_embeddings,
|
267
|
-
add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
|
268
248
|
)
|
269
249
|
train_batch_sampler = SamplePackingBatchSampler(
|
270
250
|
lengths=train_set["num_of_concepts"],
|
@@ -339,10 +319,12 @@ def main():
|
|
339
319
|
for data_dir in [data_args.data_folder, data_args.test_data_folder]
|
340
320
|
]
|
341
321
|
)
|
342
|
-
|
343
|
-
demographics_df["index_date"] =
|
344
|
-
demographics_df["index_date"]
|
345
|
-
|
322
|
+
|
323
|
+
demographics_df["index_date"] = (
|
324
|
+
demographics_df["index_date"].dt.tz_localize("UTC")
|
325
|
+
- datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
|
326
|
+
).dt.total_seconds()
|
327
|
+
|
346
328
|
demographics_dict = {
|
347
329
|
(row["person_id"], row["index_date"]): {
|
348
330
|
"gender_concept_id": row["gender_concept_id"],
|
@@ -353,7 +335,7 @@ def main():
|
|
353
335
|
|
354
336
|
data_loaders = [("train", train_loader), ("test", test_dataloader)]
|
355
337
|
|
356
|
-
ve_token_id = cehrgpt_tokenizer.
|
338
|
+
ve_token_id = cehrgpt_tokenizer.ve_token_id
|
357
339
|
for split, data_loader in data_loaders:
|
358
340
|
# Ensure prediction folder exists
|
359
341
|
feature_output_folder = (
|
@@ -379,9 +361,16 @@ def main():
|
|
379
361
|
prediction_time_posix = batch.pop("index_date").numpy().squeeze()
|
380
362
|
if prediction_time_posix.ndim == 0:
|
381
363
|
prediction_time_posix = np.asarray([prediction_time_posix])
|
364
|
+
|
382
365
|
prediction_time = list(
|
383
|
-
map(
|
366
|
+
map(
|
367
|
+
lambda posix_time: datetime.datetime.utcfromtimestamp(
|
368
|
+
posix_time
|
369
|
+
).replace(tzinfo=None),
|
370
|
+
prediction_time_posix,
|
371
|
+
)
|
384
372
|
)
|
373
|
+
|
385
374
|
labels = (
|
386
375
|
batch.pop("classifier_label")
|
387
376
|
.float()
|
@@ -393,6 +382,13 @@ def main():
|
|
393
382
|
if labels.ndim == 0:
|
394
383
|
labels = np.asarray([labels])
|
395
384
|
|
385
|
+
# Right now the model does not support this column, we need to pop it
|
386
|
+
if "epoch_times" in batch:
|
387
|
+
batch.pop("epoch_times")
|
388
|
+
|
389
|
+
if "ages" in batch:
|
390
|
+
batch.pop("ages")
|
391
|
+
|
396
392
|
batch = {k: v.to(device) for k, v in batch.items()}
|
397
393
|
# Forward pass
|
398
394
|
cehrgpt_output = cehrgpt_model(
|