yaib 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- icu_benchmarks/__init__.py +6 -0
- icu_benchmarks/contants.py +7 -0
- icu_benchmarks/cross_validation.py +120 -0
- icu_benchmarks/run.py +176 -0
- icu_benchmarks/run_utils.py +254 -0
- icu_benchmarks/wandb_utils.py +61 -0
- yaib-0.3.1.dist-info/LICENSE +25 -0
- yaib-0.3.1.dist-info/METADATA +310 -0
- yaib-0.3.1.dist-info/RECORD +12 -0
- yaib-0.3.1.dist-info/WHEEL +6 -0
- yaib-0.3.1.dist-info/entry_points.txt +2 -0
- yaib-0.3.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
import logging
|
|
4
|
+
import gin
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from pytorch_lightning import seed_everything
|
|
7
|
+
|
|
8
|
+
from icu_benchmarks.wandb_utils import wandb_log
|
|
9
|
+
from icu_benchmarks.run_utils import aggregate_results
|
|
10
|
+
from icu_benchmarks.data.split_process_data import preprocess_data
|
|
11
|
+
from icu_benchmarks.models.train import train_common
|
|
12
|
+
from icu_benchmarks.models.utils import JsonResultLoggingEncoder
|
|
13
|
+
from icu_benchmarks.run_utils import log_full_line
|
|
14
|
+
from icu_benchmarks.contants import RunMode
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@gin.configurable
def execute_repeated_cv(
    data_dir: Path,
    log_dir: Path,
    seed: int,
    load_weights: bool = False,
    source_dir: Path = None,
    cv_repetitions: int = 5,
    cv_repetitions_to_train: int = None,
    cv_folds: int = 5,
    cv_folds_to_train: int = None,
    reproducible: bool = True,
    debug: bool = False,
    generate_cache: bool = False,
    load_cache: bool = False,
    test_on: str = "test",
    mode: str = RunMode.classification,
    pretrained_imputation_model: object = None,
    cpu: bool = False,
    verbose: bool = False,
    wandb: bool = False,
) -> float:
    """Run repeated k-fold cross-validation, preprocessing the data and training one model per fold.

    Every (repetition, fold) pair gets its own ``log_dir/repetition_<r>/fold_<f>`` directory. A
    ``durations.json`` with the preprocessing and training wall-clock times is written there.
    From the third completed fold onward, results under ``log_dir`` are re-aggregated after
    every fold.

    Args:
        data_dir: Path to the data directory.
        log_dir: Path to the log directory.
        seed: Random seed.
        load_weights: Whether to load weights from source_dir.
        source_dir: Path to the source directory.
        cv_repetitions: Number of cross-validation repetitions the data is split into.
        cv_repetitions_to_train: Number of repetitions actually trained. A falsy value (None or 0) means all.
        cv_folds: Number of folds for cross-validation.
        cv_folds_to_train: Number of folds actually trained per repetition. A falsy value (None or 0) means all.
        reproducible: Whether to make torch reproducible.
        debug: Whether to load less data and enable more logging.
        generate_cache: Whether to generate and save cache.
        load_cache: Whether to load previously cached data.
        test_on: Dataset to test on. Can be "test" or "val" (e.g. for hyperparameter tuning).
        mode: Run mode. Can be one of the values of RunMode.
        pretrained_imputation_model: Pretrained imputation model handed to preprocessing, or None.
        cpu: Whether to run on CPU.
        verbose: Enable detailed logging.
        wandb: Whether to log the running fold index to wandb and enable wandb in training.

    Returns:
        The mean of the values returned by ``train_common`` over all trained folds (the average loss).
    """
    repetitions = cv_repetitions_to_train or cv_repetitions
    folds = cv_folds_to_train or cv_folds
    total_loss = 0

    seed_everything(seed, reproducible)
    for repetition in range(repetitions):
        for fold_index in range(folds):
            preprocess_start = datetime.now()
            data = preprocess_data(
                data_dir,
                seed=seed,
                debug=debug,
                load_cache=load_cache,
                generate_cache=generate_cache,
                cv_repetitions=cv_repetitions,
                repetition_index=repetition,
                cv_folds=cv_folds,
                fold_index=fold_index,
                pretrained_imputation_model=pretrained_imputation_model,
                runmode=mode,
            )

            fold_dir = log_dir / f"repetition_{repetition}" / f"fold_{fold_index}"
            fold_dir.mkdir(parents=True, exist_ok=True)
            preprocess_time = datetime.now() - preprocess_start

            train_start = datetime.now()
            total_loss += train_common(
                data,
                log_dir=fold_dir,
                load_weights=load_weights,
                source_dir=source_dir,
                reproducible=reproducible,
                test_on=test_on,
                mode=mode,
                cpu=cpu,
                verbose=verbose,
                use_wandb=wandb,
            )
            train_time = datetime.now() - train_start

            log_full_line(
                f"FINISHED FOLD {fold_index}| PREPROCESSING DURATION {preprocess_time}| TRAINING DURATION {train_time}",
                level=logging.INFO,
            )
            durations = {"preprocessing_duration": preprocess_time, "train_duration": train_time}

            with (fold_dir / "durations.json").open("w") as handle:
                json.dump(durations, handle, cls=JsonResultLoggingEncoder)

            # Flat index of this fold across all repetitions trained so far.
            iteration = repetition * folds + fold_index
            if wandb:
                wandb_log({"Iteration": iteration})
            if iteration > 1:
                aggregate_results(log_dir)
        log_full_line(f"FINISHED CV REPETITION {repetition}", level=logging.INFO, char="=", num_newlines=3)

    return total_loss / (repetitions * folds)
|
icu_benchmarks/run.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
|
|
4
|
+
import gin
|
|
5
|
+
import logging
|
|
6
|
+
import sys
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
import importlib.util
|
|
9
|
+
|
|
10
|
+
import torch.cuda
|
|
11
|
+
|
|
12
|
+
from icu_benchmarks.wandb_utils import update_wandb_config, apply_wandb_sweep, set_wandb_run_name
|
|
13
|
+
from icu_benchmarks.tuning.hyperparameters import choose_and_bind_hyperparameters
|
|
14
|
+
from scripts.plotting.utils import plot_aggregated_results
|
|
15
|
+
from icu_benchmarks.cross_validation import execute_repeated_cv
|
|
16
|
+
from icu_benchmarks.run_utils import (
|
|
17
|
+
build_parser,
|
|
18
|
+
create_run_dir,
|
|
19
|
+
aggregate_results,
|
|
20
|
+
log_full_line,
|
|
21
|
+
load_pretrained_imputation_model,
|
|
22
|
+
setup_logging,
|
|
23
|
+
)
|
|
24
|
+
from icu_benchmarks.contants import RunMode
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@gin.configurable("Run")
def get_mode(mode=gin.REQUIRED):
    """Return the gin-configured run mode as a ``RunMode`` member.

    The value comes from the ``Run.mode`` gin binding. ``main()`` parses the task gin file
    before calling this.

    Args:
        mode: Raw run-mode value bound through gin. It must be a valid ``RunMode`` value.
            ``gin.REQUIRED`` makes gin report a clear error when the binding is missing.

    Returns:
        The corresponding ``RunMode`` member.

    Raises:
        ValueError: If ``mode`` is not a valid ``RunMode`` value.
    """
    # The enum constructor already validates the value. A separate ``assert`` would be
    # redundant and is stripped entirely under ``python -O``.
    return RunMode(mode)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def main(my_args=tuple(sys.argv[1:])):
    """Command-line entry point: parse arguments, set up logging/gin, and run repeated cross-validation.

    ``train``:
        Parses the model/task gin files (or the experiment gin file) plus ``--hyperparams``.
        Creates a new run directory and calls ``choose_and_bind_hyperparameters``, with
        ``--tune`` passed through.

    ``evaluate``:
        Parses ``train_config.gin`` from ``--source-dir``. Runs with ``load_weights=True``
        into a ``from_<source-name>`` run directory.

    Afterwards the results are aggregated and, with ``--plot``, plotted.

    Args:
        my_args: Command-line arguments to parse. The default is captured from ``sys.argv``
            once, when this module is imported.
    """
    args, _ = build_parser().parse_known_args(my_args)

    # Set arguments for wandb sweep
    if args.wandb_sweep:
        args = apply_wandb_sweep(args)

    # Initialize loggers
    log_format = "%(asctime)s - %(levelname)s - %(name)s : %(message)s"
    date_format = "%Y-%m-%d %H:%M:%S"
    verbose = args.verbose
    setup_logging(date_format, log_format, verbose)

    # Load weights if in evaluation mode
    load_weights = args.command == "evaluate"
    data_dir = Path(args.data_dir)

    # Get arguments
    name = args.name
    task = args.task
    model = args.model
    # ``--reproducible`` is only registered on the ``train`` sub-parser, so the ``evaluate``
    # namespace has no such attribute. Fall back to the train default (True) instead of
    # raising AttributeError on every evaluation run.
    reproducible = getattr(args, "reproducible", True)

    # Set experiment name
    if name is None:
        name = data_dir.name
    logging.info(f"Running experiment {name}.")

    # Load task config
    gin.parse_config_file(f"configs/tasks/{task}.gin")

    mode = get_mode()

    if args.wandb_sweep:
        run_name = f"{mode}_{model}_{name}"
        set_wandb_run_name(run_name)

    logging.info(f"Task mode: {mode}.")
    experiment = args.experiment

    pretrained_imputation_model = load_pretrained_imputation_model(args.pretrained_imputation)

    # Log imputation model to wandb
    update_wandb_config(
        {
            "pretrained_imputation_model": pretrained_imputation_model.__class__.__name__
            if pretrained_imputation_model is not None
            else "None"
        }
    )
    source_dir = None
    log_dir_name = args.log_dir / name
    log_dir = (
        (log_dir_name / experiment)
        if experiment
        else (log_dir_name / (args.task_name if args.task_name is not None else args.task) / model)
    )
    if torch.cuda.is_available():
        # Dedicated loop variable so the experiment ``name`` is not overwritten by a GPU index.
        for gpu_index in range(torch.cuda.device_count()):
            log_full_line(f"Available GPU {gpu_index}: {torch.cuda.get_device_name(gpu_index)}", level=logging.INFO)
    else:
        log_full_line(
            "No GPUs available: please check your device and Torch,Cuda installation if unintended.", level=logging.WARNING
        )

    log_full_line(f"Logging to {log_dir}.", logging.INFO)

    if args.preprocessor:
        # Import the custom preprocessor supplied by the user. Import failures are logged
        # and the run continues.
        log_full_line(f"Importing custom preprocessor from {args.preprocessor}.", logging.INFO)
        try:
            spec = importlib.util.spec_from_file_location("CustomPreprocessor", args.preprocessor)
            module = importlib.util.module_from_spec(spec)
            sys.modules["preprocessor"] = module
            spec.loader.exec_module(module)
            gin.bind_parameter("preprocess.preprocessor", module.CustomPreprocessor)
        except Exception as e:
            logging.error(f"Could not import custom preprocessor from {args.preprocessor}: {e}")

    if load_weights:
        # Evaluate: reuse the training gin config stored in the source directory.
        log_dir /= f"from_{args.source_name}"
        run_dir = create_run_dir(log_dir)
        source_dir = args.source_dir
        gin.parse_config_file(source_dir / "train_config.gin")
    else:
        # Train
        checkpoint = log_dir / args.checkpoint if args.checkpoint else None
        model_path = (
            Path("configs") / ("imputation_models" if mode == RunMode.imputation else "prediction_models") / f"{model}.gin"
        )
        gin_config_files = (
            [Path(f"configs/experiments/{args.experiment}.gin")]
            if args.experiment
            else [model_path, Path(f"configs/tasks/{task}.gin")]
        )
        gin.parse_config_files_and_bindings(gin_config_files, args.hyperparams, finalize_config=False)
        log_full_line(f"Data directory: {data_dir.resolve()}", level=logging.INFO)
        run_dir = create_run_dir(log_dir)
        choose_and_bind_hyperparameters(
            args.tune,
            data_dir,
            run_dir,
            args.seed,
            run_mode=mode,
            checkpoint=checkpoint,
            debug=args.debug,
            generate_cache=args.generate_cache,
            load_cache=args.load_cache,
            verbose=args.verbose,
        )

    log_full_line(f"Logging to {run_dir.resolve()}", level=logging.INFO)
    log_full_line("STARTING TRAINING", level=logging.INFO, char="=", num_newlines=3)
    start_time = datetime.now()
    execute_repeated_cv(
        data_dir,
        run_dir,
        args.seed,
        load_weights=load_weights,
        source_dir=source_dir,
        reproducible=reproducible,
        debug=args.debug,
        verbose=args.verbose,
        load_cache=args.load_cache,
        generate_cache=args.generate_cache,
        mode=mode,
        pretrained_imputation_model=pretrained_imputation_model,
        cpu=args.cpu,
        wandb=args.wandb_sweep,
    )

    log_full_line("FINISHED TRAINING", level=logging.INFO, char="=", num_newlines=3)
    execution_time = datetime.now() - start_time
    log_full_line(f"DURATION: {execution_time}", level=logging.INFO, char="")
    aggregate_results(run_dir, execution_time)
    if args.plot:
        plot_aggregated_results(run_dir, "aggregated_test_metrics.json")
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
# Script entry point: run the CLI when this module is executed directly.
# (A mid-module string literal is a no-op expression, not a module docstring.)
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from math import sqrt
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import json
|
|
6
|
+
from argparse import ArgumentParser, BooleanOptionalAction
|
|
7
|
+
from datetime import datetime, timedelta
|
|
8
|
+
import logging
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
import scipy.stats as stats
|
|
11
|
+
import shutil
|
|
12
|
+
from statistics import mean, pstdev
|
|
13
|
+
from icu_benchmarks.models.utils import JsonResultLoggingEncoder
|
|
14
|
+
from icu_benchmarks.wandb_utils import wandb_log
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def build_parser() -> ArgumentParser:
    """Construct the YAIB command-line parser.

    A mandatory sub-command is stored in ``command``:

    * ``train``: preprocess features and train a model. Adds the reproducibility,
      hyperparameter, tuning and checkpoint options.
    * ``evaluate``: evaluate a trained model. Adds the source name and source directory.

    Both sub-commands inherit the "General arguments" group from a help-less parent parser.

    Returns:
        The fully configured ArgumentParser.
    """

    def add_switch(group, *flags, default, text):
        # Optional ``--flag/--no-flag`` boolean with an explicit default value.
        group.add_argument(*flags, required=False, default=default, action=BooleanOptionalAction, help=text)

    # Arguments shared by every sub-command live on a parent parser without its own -h.
    shared = ArgumentParser(add_help=False)
    general = shared.add_argument_group("General arguments")
    general.add_argument("-d", "--data-dir", required=True, type=Path, help="Path to the parquet data directory.")
    general.add_argument("-t", "--task", default="BinaryClassification", required=True, help="Name of the task gin.")
    general.add_argument("-n", "--name", required=False, help="Name of the (target) dataset.")
    general.add_argument("-tn", "--task-name", required=False, help="Name of the task, used for naming experiments.")
    general.add_argument("-m", "--model", default="LGBMClassifier", required=False, help="Name of the model gin.")
    general.add_argument("-e", "--experiment", required=False, help="Name of the experiment gin.")
    general.add_argument(
        "-l", "--log-dir", required=False, default=Path("../yaib_logs/"), type=Path, help="Log directory with model weights."
    )
    general.add_argument(
        "-s", "--seed", required=False, default=1234, type=int, help="Random seed for processing, tuning and training."
    )
    add_switch(general, "-v", "--verbose", default=False, text="Whether to use verbose logging. Disable for clean logs.")
    add_switch(general, "--cpu", default=False, text="Set to use CPU.")
    add_switch(general, "-db", "--debug", default=False, text="Set to load less data.")
    add_switch(general, "-lc", "--load_cache", default=False, text="Set to load generated data cache.")
    add_switch(general, "-gc", "--generate_cache", default=False, text="Set to generate data cache.")
    general.add_argument("-p", "--preprocessor", required=False, type=Path, help="Load custom preprocessor from file.")
    # No explicit default here, so the parsed value is None unless the flag is given.
    general.add_argument("-pl", "--plot", required=False, action=BooleanOptionalAction, help="Generate common plots.")
    general.add_argument(
        "-wd", "--wandb-sweep", required=False, action="store_true", help="Activates wandb hyper parameter sweep."
    )
    general.add_argument(
        "-imp", "--pretrained-imputation", required=False, type=str, help="Path to pretrained imputation model."
    )

    root = ArgumentParser(description="Benchmark lib for processing and evaluation of deep learning models on ICU data")
    commands = root.add_subparsers(title="Commands", dest="command", required=True)

    # "train": preprocessing + model training.
    train = commands.add_parser("train", help="Preprocess features and train model.", parents=[shared])
    add_switch(train, "--reproducible", default=True, text="Make torch reproducible.")
    train.add_argument("-hp", "--hyperparams", required=False, nargs="+", help="Hyperparameters for model.")
    add_switch(train, "--tune", default=False, text="Find best hyperparameters.")
    train.add_argument("--checkpoint", required=False, type=Path, help="Use previous checkpoint.")

    # "evaluate": run a trained model on (possibly different) data.
    evaluation = commands.add_parser("evaluate", help="Evaluate trained model on data.", parents=[shared])
    evaluation.add_argument("-sn", "--source-name", required=True, type=Path, help="Name of the source dataset.")
    evaluation.add_argument("--source-dir", required=True, type=Path, help="Directory containing gin and model weights.")

    return root
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def create_run_dir(log_dir: Path, randomly_searched_params: str = None) -> Path:
    """Create a fresh run directory named after the current local time.

    The directory name uses the format ``%Y-%m-%dT%H-%M-%S`` and missing parents are created.
    ``mkdir`` is called without ``exist_ok``, so a second call within the same second for the
    same ``log_dir`` raises ``FileExistsError``.

    Args:
        log_dir: Parent directory to create the run directory in.
        randomly_searched_params: Optional file name. When truthy, an empty marker file with this
            name is created in the run directory (used to record randomly searched parameters).

    Returns:
        Path to the newly created run directory.
    """
    timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    run_dir = log_dir / timestamp
    run_dir.mkdir(parents=True)
    if not randomly_searched_params:
        return run_dir
    marker = run_dir / randomly_searched_params
    marker.touch()
    return run_dir
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def aggregate_results(log_dir: Path, execution_time: timedelta = None):
    """Aggregate per-fold results under ``log_dir`` and write them to JSON files.

    Expected layout is ``log_dir/<repetition>/<fold>/``. For each fold directory, metrics are read
    from ``test_metrics.json`` (or, if that is absent, ``val_metrics.csv``) and merged with
    ``durations.json`` when present.

    Two files are written into ``log_dir``:

    * ``aggregated_test_metrics.json``: the raw nested ``{repetition: {fold: metrics}}`` dict.
    * ``accumulated_test_metrics.json``: one dict with these entries:
        - ``avg``: the per-metric mean over all folds.
        - ``std``: the standard error; see the note in the body.
        - ``CI_0.95``: the 95% Student-t confidence interval.
        - ``execution_time``: the run time in seconds (0.0 if not given).

    The accumulated metrics are also logged and passed to ``wandb_log``.

    Args:
        log_dir: Path to the log directory (one sub-directory per repetition).
        execution_time: Overall execution time.
    """
    aggregated = {}
    for repetition in log_dir.iterdir():
        if not repetition.is_dir():
            continue
        aggregated[repetition.name] = {}
        for fold_iter in repetition.iterdir():
            # Only fold directories hold results. Skip stray files (e.g. ``.DS_Store``)
            # instead of recording them as empty folds.
            if not fold_iter.is_dir():
                continue
            fold_result = aggregated[repetition.name][fold_iter.name] = {}
            if (fold_iter / "test_metrics.json").is_file():
                with open(fold_iter / "test_metrics.json", "r") as f:
                    fold_result.update(json.load(f))
            elif (fold_iter / "val_metrics.csv").is_file():
                # NOTE(review): this parses a ``.csv`` file as JSON. That only works if the writer
                # actually stores JSON under this name. Confirm against the code that writes it.
                with open(fold_iter / "val_metrics.csv", "r") as f:
                    fold_result.update(json.load(f))
            # Add durations to metrics
            if (fold_iter / "durations.json").is_file():
                with open(fold_iter / "durations.json", "r") as f:
                    fold_result.update(json.load(f))

    # Collect every numeric score per metric across all repetitions and folds.
    scores_per_metric = {}
    for folds in aggregated.values():
        for result in folds.values():
            for metric, score in result.items():
                if isinstance(score, (float, int)):
                    scores_per_metric.setdefault(metric, []).append(score)

    averaged_scores = {metric: mean(scores) for metric, scores in scores_per_metric.items()}

    # NOTE: despite the "std" key, this is the standard error of the mean. It is computed from the
    # population standard deviation as pstdev / sqrt(n). Kept as-is for output compatibility.
    std_scores = {metric: pstdev(scores) / sqrt(len(scores)) for metric, scores in scores_per_metric.items()}

    # 95% confidence interval from Student's t distribution with n - 1 degrees of freedom.
    # Presumably (nan, nan) when a metric has only a single value.
    confidence_interval = {
        metric: stats.t.interval(0.95, len(scores) - 1, loc=mean(scores), scale=stats.sem(scores))
        for metric, scores in scores_per_metric.items()
    }

    accumulated_metrics = {
        "avg": averaged_scores,
        "std": std_scores,
        "CI_0.95": confidence_interval,
        "execution_time": execution_time.total_seconds() if execution_time is not None else 0.0,
    }

    with open(log_dir / "aggregated_test_metrics.json", "w") as f:
        json.dump(aggregated, f, cls=JsonResultLoggingEncoder)

    with open(log_dir / "accumulated_test_metrics.json", "w") as f:
        json.dump(accumulated_metrics, f, cls=JsonResultLoggingEncoder)

    logging.info(f"Accumulated results: {accumulated_metrics}")

    # Round-trip through the custom encoder so wandb receives plain JSON-compatible values.
    wandb_log(json.loads(json.dumps(accumulated_metrics, cls=JsonResultLoggingEncoder)))
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def log_full_line(msg: str, level: int = logging.INFO, char: str = "-", num_newlines: int = 0):
    """Log ``msg`` centered in a line padded with ``char`` to the terminal width.

    The padded width is the terminal width (fallback 80 columns) minus
    ``len(level name) + 28`` reserved characters. The reserved characters presumably cover
    the log-record prefix. On terminals narrower than that, the width is clamped to 0, so
    the message is logged unpadded instead of raising.

    Args:
        msg: Message to log.
        level: Logging level.
        char: Single fill character for the padding. An empty string pads with spaces.
        num_newlines: Number of newlines to append.
    """
    terminal_size = shutil.get_terminal_size((80, 20))
    reserved_chars = len(logging.getLevelName(level)) + 28
    # A negative width would yield a format spec such as "-^-5". There the second "-" is
    # parsed as a sign, which str.format rejects for strings with ValueError.
    width = max(terminal_size.columns - reserved_chars, 0)
    logging.log(
        level,
        "{0:{char}^{width}}{1}".format(msg, "\n" * num_newlines, char=char, width=width),
    )
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def load_pretrained_imputation_model(use_pretrained_imputation):
    """Load a pretrained imputation model from a checkpoint file.

    The checkpoint is either:

    * a dict with ``class``, ``hyper_parameters``, ``trained_columns`` and ``state_dict`` keys,
      from which the model is rebuilt; or
    * a pickled model object, which is used as-is.

    The model is then moved to CUDA when available, otherwise to CPU.

    Args:
        use_pretrained_imputation: Path to the pretrained imputation model, or None.

    Returns:
        The loaded model, or None if no path was given or the path does not exist.

    Note:
        ``torch.load`` unpickles the file and can execute arbitrary code. Only load
        checkpoints from trusted sources.
    """
    if use_pretrained_imputation is not None and not Path(use_pretrained_imputation).exists():
        logging.warning("The specified pretrained imputation model does not exist.")
        use_pretrained_imputation = None

    if use_pretrained_imputation is None:
        return None

    logging.info(f"Using pretrained imputation from {use_pretrained_imputation}")
    # SECURITY: pickle-based load; never point this at an untrusted file.
    checkpoint = torch.load(use_pretrained_imputation, map_location=torch.device("cpu"))
    if isinstance(checkpoint, dict):
        # Rebuild the model from its class and constructor kwargs, then restore the
        # trained columns and weights.
        imputation_model_class = checkpoint["class"]
        pretrained_imputation_model = imputation_model_class(**checkpoint["hyper_parameters"])
        pretrained_imputation_model.set_trained_columns(checkpoint["trained_columns"])
        pretrained_imputation_model.load_state_dict(checkpoint["state_dict"])
    else:
        pretrained_imputation_model = checkpoint
    pretrained_imputation_model = pretrained_imputation_model.to("cuda" if torch.cuda.is_available() else "cpu")
    try:
        # Record the device of the first parameter on the model. This is best-effort: models
        # without parameters (or without a settable ``device``) are only logged at debug level.
        logging.info(f"imputation model device: {next(pretrained_imputation_model.parameters()).device}")
        pretrained_imputation_model.device = next(pretrained_imputation_model.parameters()).device
    except Exception as e:
        logging.debug(f"Could not set device of imputation model: {e}")

    return pretrained_imputation_model
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def setup_logging(date_format, log_format, verbose):
    """
    Set up all loggers to use the same format and date format.

    Steps:

    * Configure the root logger via ``logging.basicConfig``.
    * Apply the same formatter to every handler already attached to the
      ``pytorch_lightning`` and ``lightning_fabric`` loggers.
    * Set the levels of the root logger and both of those loggers.

    Args:
        date_format: Format for the date.
        log_format: Format for the log.
        verbose: Whether to log debug messages. If true, the levels become DEBUG and default
            warning filtering is restored. Otherwise the levels become INFO and all Python
            warnings are ignored.
    """
    logging.basicConfig(format=log_format, datefmt=date_format)
    loggers = ["pytorch_lightning", "lightning_fabric"]
    for logger in loggers:
        # A logger may have no handlers yet (e.g. if its library has not attached one).
        # Indexing ``handlers[0]`` would raise IndexError, so format whatever handlers exist.
        for handler in logging.getLogger(logger).handlers:
            handler.setFormatter(logging.Formatter(log_format, datefmt=date_format))

    level = logging.DEBUG if verbose else logging.INFO
    logging.getLogger().setLevel(level)
    for logger in loggers:
        logging.getLogger(logger).setLevel(level)
    warnings.filterwarnings("default" if verbose else "ignore")
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from argparse import Namespace
|
|
2
|
+
import logging
|
|
3
|
+
import wandb
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def wandb_running() -> bool:
    """Return True when a wandb run is active, i.e. ``wandb.run`` has been set."""
    active_run = wandb.run
    return active_run is not None
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def update_wandb_config(config: dict) -> None:
    """Merge ``config`` into the active wandb run's config. This is a no-op when no run is active.

    The requested update is always logged at DEBUG level, whether or not wandb is running.

    Args:
        config (dict): Key/value pairs to add to the wandb config.
    """
    logging.debug(f"Updating Wandb config: {config}")
    if not wandb_running():
        return
    wandb.config.update(config)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def apply_wandb_sweep(args: Namespace) -> Namespace:
    """Apply the wandb sweep configuration to the parsed arguments.

    Calls ``wandb.init()`` and copies every sweep parameter onto ``args`` as an attribute. Each
    parameter is also appended to ``args.hyperparams`` as a ``key=value`` binding string, with
    string values wrapped in single quotes.

    Args:
        args (Namespace): Parsed command-line arguments.

    Returns:
        Namespace: The same ``args`` object with the sweep configuration applied.
    """
    wandb.init()
    sweep_config = wandb.config
    vars(args).update(sweep_config)
    # ``hyperparams`` is only defined by the ``train`` sub-command. Use getattr so a namespace
    # without it (e.g. ``evaluate``) gets an empty list instead of raising AttributeError.
    if getattr(args, "hyperparams", None) is None:
        args.hyperparams = []
    for key, value in sweep_config.items():
        binding = ("'" + value + "'") if isinstance(value, str) else str(value)
        args.hyperparams.append(f"{key}=" + binding)
    logging.info(f"hyperparams after loading sweep config: {args.hyperparams}")
    return args
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def wandb_log(log_dict):
    """Send a dict of metrics to the active wandb run. This is a no-op when no run is active.

    Args:
        log_dict (dict): Metrics to log.
    """
    if not wandb_running():
        return
    wandb.log(log_dict)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def set_wandb_run_name(run_name):
    """Rename the active wandb run and store the name in its config. This is a no-op when no run is active.

    Args:
        run_name (str): Name of the current run.
    """
    if wandb_running():
        wandb.config.update({"run-name": run_name})
        # Attribute changes on the run are persisted by wandb. The former no-argument
        # ``wandb.run.save()`` call is documented as deprecated for this purpose
        # (verify against the pinned wandb version) and has been dropped.
        wandb.run.name = run_name
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
|
|
2
|
+
|
|
3
|
+
MIT License
|
|
4
|
+
|
|
5
|
+
Copyright (c) 2023, Robin van de Water, Hendrik Schmidt, Patrick Rockenschaub
|
|
6
|
+
Copyright (c) 2021, ETH Zurich, Biomedical Informatics Group; ratschlab
|
|
7
|
+
|
|
8
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
9
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
10
|
+
in the Software without restriction, including without limitation the rights
|
|
11
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
12
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
13
|
+
furnished to do so, subject to the following conditions:
|
|
14
|
+
|
|
15
|
+
The above copyright notice and this permission notice shall be included in all
|
|
16
|
+
copies or substantial portions of the Software.
|
|
17
|
+
|
|
18
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
19
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
20
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
21
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
22
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
23
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
24
|
+
SOFTWARE.
|
|
25
|
+
|
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: yaib
|
|
3
|
+
Version: 0.3.1
|
|
4
|
+
Summary: Yet Another ICU Benchmark is a holistic framework for the automation of the development of clinical prediction models on ICU data. Users can create custom datasets, cohorts, prediction tasks, endpoints, and models.
|
|
5
|
+
Home-page: https://github.com/rvandewater/YAIB
|
|
6
|
+
Author: Robin van de Water
|
|
7
|
+
Author-email: robin.vandewater@hpi.de
|
|
8
|
+
License: MIT license
|
|
9
|
+
Keywords: benchmark mimic-iii eicu hirid clinical-ml machine-learning benchmark time-series mimic-iv patient-monitoring amsterdamumcdb clinical-data ehr icu ricu pyicu
|
|
10
|
+
Classifier: Development Status :: 4 - Beta
|
|
11
|
+
Classifier: Intended Audience :: Developers
|
|
12
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
13
|
+
Classifier: Natural Language :: English
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
License-File: LICENSE
|
|
17
|
+
Requires-Dist: black (==23.3.0)
|
|
18
|
+
Requires-Dist: coverage (==7.2.3)
|
|
19
|
+
Requires-Dist: flake8 (==5.0.4)
|
|
20
|
+
Requires-Dist: matplotlib (==3.7.1)
|
|
21
|
+
Requires-Dist: gin-config (==0.5.0)
|
|
22
|
+
Requires-Dist: pytorch-ignite (==0.4.11)
|
|
23
|
+
Requires-Dist: torch (==2.0.1)
|
|
24
|
+
Requires-Dist: pytorch-cuda (==11.8)
|
|
25
|
+
Requires-Dist: lightgbm (==3.3.5)
|
|
26
|
+
Requires-Dist: numpy (==1.24.3)
|
|
27
|
+
Requires-Dist: pandas (==2.0.0)
|
|
28
|
+
Requires-Dist: pyarrow (==11.0.0)
|
|
29
|
+
Requires-Dist: pytest (==7.3.1)
|
|
30
|
+
Requires-Dist: scikit-learn (==1.2.2)
|
|
31
|
+
Requires-Dist: tensorboard (==2.12.2)
|
|
32
|
+
Requires-Dist: tqdm (==4.64.1)
|
|
33
|
+
Requires-Dist: pytorch-lightning (==2.0.3)
|
|
34
|
+
Requires-Dist: wandb (==0.15.4)
|
|
35
|
+
Requires-Dist: pip (==23.1)
|
|
36
|
+
Requires-Dist: einops (==0.6.1)
|
|
37
|
+
Requires-Dist: hydra-core (==1.3)
|
|
38
|
+
Requires-Dist: recipies (==0.1.1)
|
|
39
|
+
Provides-Extra: mps
|
|
40
|
+
Requires-Dist: mkl (<2022) ; extra == 'mps'
|
|
41
|
+
|
|
42
|
+

|
|
43
|
+
|
|
44
|
+
# 🧪 Yet Another ICU Benchmark
|
|
45
|
+
|
|
46
|
+
[](https://github.com/rvandewater/YAIB/actions/workflows/ci.yml)
|
|
47
|
+
[](https://github.com/psf/black)
|
|
48
|
+

|
|
49
|
+
[](http://arxiv.org/abs/2306.05109)
|
|
50
|
+
[](https://pypi.python.org/pypi/yaib/)
|
|
51
|
+
[](LICENSE)
|
|
52
|
+
|
|
53
|
+
[//]: # (TODO: add coverage once we have some tests )
|
|
54
|
+
|
|
55
|
+
Yet another ICU benchmark (YAIB) provides a framework for conducting clinical machine learning experiments on Intensive Care Unit
|
|
56
|
+
(ICU) EHR data.
|
|
57
|
+
|
|
58
|
+
We support the following datasets out of the box:
|
|
59
|
+
|
|
60
|
+
| **Dataset** | [MIMIC-III](https://physionet.org/content/mimiciii/) / [IV](https://physionet.org/content/mimiciv/) | [eICU-CRD](https://physionet.org/content/eicu-crd/) | [HiRID](https://physionet.org/content/hirid/1.1.1/) | [AUMCdb](https://doi.org/10.17026/dans-22u-f8vd) |
|
|
61
|
+
|-----------------------------|-----------------------------------------------------------------------------------------------------|-----------------------------------------------------|-----------------------------------------------------|--------------------------------------------------|
|
|
62
|
+
| **Admissions** | 40k / 73k | 200k | 33k | 23k |
|
|
63
|
+
| **Version** | v1.4 / v2.2 | v2.0 | v1.1.1 | v1.0.2 |
|
|
64
|
+
| **Frequency** (time-series) | 1 hour | 5 minutes | 2 / 5 minutes | up to 1 minute |
|
|
65
|
+
| **Originally published** | 2015 / 2020 | 2017 | 2020 | 2019 |
|
|
66
|
+
| **Origin** | USA | USA | Switzerland | Netherlands |
|
|
67
|
+
|
|
68
|
+
New datasets can also be added. We are currently working on a package to make this process as smooth as possible.
|
|
69
|
+
The benchmark is designed for operating on preprocessed parquet files.
|
|
70
|
+
<!-- We refer to PyICU (in development)
|
|
71
|
+
or [ricu package](https://github.com/eth-mds/ricu) for generating these parquet files for particular cohorts and endpoints. -->
|
|
72
|
+
|
|
73
|
+
We provide five common tasks for clinical prediction by default:
|
|
74
|
+
|
|
75
|
+
| No | Task | Frequency | Type |
|
|
76
|
+
|-----|---------------------------|---------------------------|-----------------------|
|
|
77
|
+
| 1 | ICU Mortality | Once per Stay (after 24H) | Binary Classification |
|
|
78
|
+
| 2 | Acute Kidney Injury (AKI) | Hourly (within 6H) | Binary Classification |
|
|
79
|
+
| 3 | Sepsis | Hourly (within 6H) | Binary Classification |
|
|
80
|
+
| 4   | Kidney Function (KF)      | Once per stay             | Regression            |
|
|
81
|
+
| 5 | Length of Stay (LoS) | Hourly (within 7D) | Regression |
|
|
82
|
+
|
|
83
|
+
New tasks can be easily added.
|
|
84
|
+
For the purposes of getting started right away, we include the eICU and MIMIC-III demo datasets in our repository.
|
|
85
|
+
|
|
86
|
+
The following repositories may be relevant as well:
|
|
87
|
+
|
|
88
|
+
- [YAIB-cohorts](https://github.com/rvandewater/YAIB-cohorts): Cohort generation for YAIB.
|
|
89
|
+
- [YAIB-models](https://github.com/rvandewater/YAIB-models): Pretrained models for YAIB.
|
|
90
|
+
- [ReciPys](https://github.com/rvandewater/ReciPys): Preprocessing package for YAIB pipelines.
|
|
91
|
+
|
|
92
|
+
For all YAIB related repositories, please see: https://github.com/stars/rvandewater/lists/yaib.
|
|
93
|
+
# 📄Paper
|
|
94
|
+
|
|
95
|
+
To reproduce the benchmarks in our paper, please refer to the [ML reproducibility document](PAPER.md).
|
|
96
|
+
If you use this code in your research, please cite the following publication:
|
|
97
|
+
|
|
98
|
+
```
|
|
99
|
+
@article{vandewaterYetAnotherICUBenchmark2023,
|
|
100
|
+
title = {Yet Another ICU Benchmark: A Flexible Multi-Center Framework for Clinical ML},
|
|
101
|
+
shorttitle = {Yet Another ICU Benchmark},
|
|
102
|
+
url = {http://arxiv.org/abs/2306.05109},
|
|
103
|
+
language = {en},
|
|
104
|
+
urldate = {2023-06-09},
|
|
105
|
+
publisher = {arXiv},
|
|
106
|
+
author = {Robin van de Water and Hendrik Schmidt and Paul Elbers and Patrick Thoral and Bert Arnrich and Patrick Rockenschaub},
|
|
107
|
+
month = jun,
|
|
108
|
+
year = {2023},
|
|
109
|
+
note = {arXiv:2306.05109 [cs]},
|
|
110
|
+
keywords = {Computer Science - Machine Learning},
|
|
111
|
+
}
|
|
112
|
+
```
|
|
113
|
+
|
|
114
|
+
This paper can also be found on arXiv: [2306.05109](https://arxiv.org/abs/2306.05109).
|
|
115
|
+
|
|
116
|
+
# 💿Installation
|
|
117
|
+
YAIB is currently best installed from source; however, we also offer an early PyPI release.
|
|
118
|
+
|
|
119
|
+
## Installation from source
|
|
120
|
+
First, we clone this repository using git:
|
|
121
|
+
````
|
|
122
|
+
git clone https://github.com/rvandewater/YAIB.git
|
|
123
|
+
````
|
|
124
|
+
Please note the branch. The newest features and fixes are available at the development branch:
|
|
125
|
+
````
|
|
126
|
+
git checkout development
|
|
127
|
+
````
|
|
128
|
+
YAIB can be installed using a conda environment (preferred) or pip. Below are the three CLI commands to install YAIB
|
|
129
|
+
using **conda**.
|
|
130
|
+
|
|
131
|
+
The first command will install an environment based on Python 3.10.
|
|
132
|
+
|
|
133
|
+
```
|
|
134
|
+
conda env update -f <environment.yml|environment_mps.yml>
|
|
135
|
+
```
|
|
136
|
+
|
|
137
|
+
> Use `environment.yml` on x86 hardware and `environment_mps.yml` on Macs with Metal Performance Shaders.
|
|
138
|
+
|
|
139
|
+
We then activate the environment and install a package called `icu-benchmarks`, after which YAIB should be operational.
|
|
140
|
+
|
|
141
|
+
```
|
|
142
|
+
conda activate yaib
|
|
143
|
+
pip install -e .
|
|
144
|
+
```
|
|
145
|
+
|
|
146
|
+
If you want to install the icu-benchmarks package with **pip**, execute the command below:
|
|
147
|
+
|
|
148
|
+
```
|
|
149
|
+
pip install torch numpy && pip install -e .
|
|
150
|
+
```
|
|
151
|
+
After installation, please check whether your PyTorch version works with CUDA (if available) to ensure the best performance.
|
|
152
|
+
YAIB will automatically list available processors at initialization in its log files.
|
|
153
|
+
|
|
154
|
+
# 👩💻Usage
|
|
155
|
+
|
|
156
|
+
Please refer to [our wiki](https://github.com/rvandewater/YAIB/wiki) for detailed information on how to use YAIB.
|
|
157
|
+
|
|
158
|
+
## Quickstart 🚀 (demo data)
|
|
159
|
+
|
|
160
|
+
In the folder `demo_data` we provide processed publicly available demo datasets from eICU and MIMIC with the necessary labels
|
|
161
|
+
for `Mortality at 24h`, `Sepsis`, `Acute Kidney Injury`, `Kidney Function`, and `Length of Stay`.
|
|
162
|
+
|
|
163
|
+
If you do not yet have access to the ICU datasets, you can run the following command to train models for the included demo
|
|
164
|
+
cohorts:
|
|
165
|
+
|
|
166
|
+
```
|
|
167
|
+
wandb sweep --verbose experiments/demo_benchmark_classification.yml
|
|
168
|
+
wandb sweep --verbose experiments/demo_benchmark_regression.yml
|
|
169
|
+
```
|
|
170
|
+
|
|
171
|
+
```train
|
|
172
|
+
wandb agent <sweep_id>
|
|
173
|
+
```
|
|
174
|
+
|
|
175
|
+
> Tip: You can choose to run each of the configurations on a SLURM cluster instance by `wandb agent --count 1 <sweep_id>`
|
|
176
|
+
|
|
177
|
+
> Note: You will need to have a wandb account and be logged in to run the above commands.
|
|
178
|
+
|
|
179
|
+
## Getting the datasets
|
|
180
|
+
|
|
181
|
+
HiRID, eICU, and MIMIC IV can be accessed through [PhysioNet](https://physionet.org/). A guide to this process can be
|
|
182
|
+
found [here](https://eicu-crd.mit.edu/gettingstarted/access/).
|
|
183
|
+
AUMCdb can be accessed through a separate access [procedure](https://github.com/AmsterdamUMC/AmsterdamUMCdb). We do not have
|
|
184
|
+
involvement in the access procedure and cannot answer any requests for data access.
|
|
185
|
+
|
|
186
|
+
## Cohort creation
|
|
187
|
+
|
|
188
|
+
Since the datasets were created independently of each other, they do not share the same data structure or data identifiers. In
|
|
189
|
+
order to make them interoperable, use the preprocessing utilities
|
|
190
|
+
provided by the [ricu package](https://github.com/eth-mds/ricu).
|
|
191
|
+
Ricu pre-defines a large number of clinical concepts and how to load them from a given dataset, providing a common interface to
|
|
192
|
+
the data that is used in this
|
|
193
|
+
benchmark. Please refer to our [cohort definition](https://github.com/rvandewater/YAIB-cohorts) code for generating the cohorts
|
|
194
|
+
using our python interface for ricu.
|
|
195
|
+
After this, you can run the benchmark once you have gained access to the datasets.
|
|
196
|
+
|
|
197
|
+
# 👟 Running YAIB
|
|
198
|
+
|
|
199
|
+
## Preprocessing and Training
|
|
200
|
+
|
|
201
|
+
The following command will run training and evaluation on the MIMIC demo dataset for (Binary) mortality prediction at 24h with
|
|
202
|
+
the
|
|
203
|
+
LGBMClassifier. Child samples are reduced due to the small amount of training data. We generate the cache and, if available,
|
|
204
|
+
load
|
|
205
|
+
existing cache files.
|
|
206
|
+
|
|
207
|
+
```
|
|
208
|
+
icu-benchmarks train \
|
|
209
|
+
-d demo_data/mortality24/mimic_demo \
|
|
210
|
+
-n mimic_demo \
|
|
211
|
+
-t BinaryClassification \
|
|
212
|
+
-tn Mortality24 \
|
|
213
|
+
-m LGBMClassifier \
|
|
214
|
+
-hp LGBMClassifier.min_child_samples=10 \
|
|
215
|
+
--generate_cache \
|
|
216
|
+
--load_cache \
|
|
217
|
+
--seed 2222 \
|
|
218
|
+
-s 2222 \
|
|
219
|
+
-l ../yaib_logs/ \
|
|
220
|
+
--tune
|
|
221
|
+
```
|
|
222
|
+
|
|
223
|
+
> For a list of available flags, run `icu-benchmarks train -h`.
|
|
224
|
+
|
|
225
|
+
> Run with `PYTORCH_ENABLE_MPS_FALLBACK=1` on Macs with Metal Performance Shaders.
|
|
226
|
+
|
|
227
|
+
[//]: # (> Please note that, for Windows based systems, paths need to be formatted differently, e.g: ` r"\..\data\mortality_seq\hirid"`.)
|
|
228
|
+
> For Windows-based systems, the line-continuation character (\\) needs to be replaced by (^) (Command Prompt) or (`) (PowerShell)
|
|
229
|
+
> respectively.
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
Alternatively, the easiest method to train all the models in the paper is to run these commands from the directory root:
|
|
233
|
+
|
|
234
|
+
```train
|
|
235
|
+
wandb sweep --verbose experiments/benchmark_classification.yml
|
|
236
|
+
wandb sweep --verbose experiments/benchmark_regression.yml
|
|
237
|
+
```
|
|
238
|
+
|
|
239
|
+
This will create two hyperparameter sweeps for WandB for the classification and regression tasks.
|
|
240
|
+
This configuration will train all the models in the paper. You can then run the following command to train the models:
|
|
241
|
+
|
|
242
|
+
```train
|
|
243
|
+
wandb agent <sweep_id>
|
|
244
|
+
```
|
|
245
|
+
|
|
246
|
+
> Tip: You can choose to run each of the configurations on a SLURM cluster instance by `wandb agent --count 1 <sweep_id>`
|
|
247
|
+
|
|
248
|
+
> Note: You will need to have a wandb account and be logged in to run the above commands.
|
|
249
|
+
|
|
250
|
+
## Evaluate
|
|
251
|
+
|
|
252
|
+
It is possible to evaluate a model trained on another dataset. In this case, the source dataset is the demo data from MIMIC and
|
|
253
|
+
the target is the eICU demo:
|
|
254
|
+
|
|
255
|
+
```
|
|
256
|
+
icu-benchmarks evaluate \
|
|
257
|
+
-d demo_data/mortality24/eicu_demo \
|
|
258
|
+
-n eicu_demo \
|
|
259
|
+
-t BinaryClassification \
|
|
260
|
+
-tn Mortality24 \
|
|
261
|
+
-m LGBMClassifier \
|
|
262
|
+
--generate_cache \
|
|
263
|
+
--load_cache \
|
|
264
|
+
-s 2222 \
|
|
265
|
+
-l ../yaib_logs \
|
|
266
|
+
-sn mimic \
|
|
267
|
+
--source-dir ../yaib_logs/mimic_demo/Mortality24/LGBMClassifier/2022-12-12T15-24-46/fold_0
|
|
268
|
+
```
|
|
269
|
+
|
|
270
|
+
## Models
|
|
271
|
+
|
|
272
|
+
We provide several existing machine learning models that are commonly used for multivariate time-series data.
|
|
273
|
+
`pytorch` is used for the deep learning models, `lightgbm` for the boosted tree approaches, and `sklearn` for other classical
|
|
274
|
+
machine learning models.
|
|
275
|
+
The benchmark provides (among others) the following built-in models:
|
|
276
|
+
|
|
277
|
+
- [Logistic Regression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html?highlight=logistic+regression):
|
|
278
|
+
Standard regression approach.
|
|
279
|
+
- [Elastic Net](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ElasticNet.html): Linear regression with
|
|
280
|
+
combined L1 and L2 priors as regularizer.
|
|
281
|
+
- [LightGBM](https://proceedings.neurips.cc/paper/2017/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf): Efficient gradient
|
|
282
|
+
boosting trees.
|
|
283
|
+
- [Long Short-term Memory (LSTM)](https://ieeexplore.ieee.org/document/818041): The most commonly used type of Recurrent Neural
|
|
284
|
+
Networks for long sequences.
|
|
285
|
+
- [Gated Recurrent Unit (GRU)](https://arxiv.org/abs/1406.1078): An extension to LSTM that showed
|
|
286
|
+
improvements ([paper](https://arxiv.org/abs/1412.3555)).
|
|
287
|
+
- [Temporal Convolutional Networks (TCN)](https://arxiv.org/pdf/1803.01271): 1D convolution approach to sequence data. By
|
|
288
|
+
using dilated convolution to extend the receptive field of the network it has shown great performance on long-term
|
|
289
|
+
dependencies.
|
|
290
|
+
- [Transformers](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf): The most common Attention
|
|
291
|
+
based approach.
|
|
292
|
+
|
|
293
|
+
# 🛠️ Development
|
|
294
|
+
|
|
295
|
+
To adapt YAIB to your own use case, you can use
|
|
296
|
+
the [development information](https://github.com/rvandewater/YAIB/wiki/Contribution-and-development) page as a reference.
|
|
297
|
+
We appreciate contributions to the project. Please read the [contribution guidelines](CONTRIBUTING.MD) before submitting a pull
|
|
298
|
+
request.
|
|
299
|
+
|
|
300
|
+
# Acknowledgements
|
|
301
|
+
|
|
302
|
+
We do not own any of the datasets used in this benchmark. This project uses heavily adapted components of
|
|
303
|
+
the [HiRID benchmark](https://github.com/ratschlab/HIRID-ICU-Benchmark/). We thank the authors for providing this codebase and
|
|
304
|
+
encourage further development to benefit the scientific community. The demo datasets have been released under
|
|
305
|
+
an [Open Data Commons Open Database License (ODbL)](https://opendatacommons.org/licenses/odbl/1-0/).
|
|
306
|
+
|
|
307
|
+
# License
|
|
308
|
+
|
|
309
|
+
This source code is released under the MIT license, included [here](LICENSE). We do not own any of the datasets used or
|
|
310
|
+
included in this repository.
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
icu_benchmarks/__init__.py,sha256=pBMRirXTJQo-EO2fGRq6eOp7aR7f5yqa6iYkH_QVObI,159
|
|
2
|
+
icu_benchmarks/contants.py,sha256=_YB08CQmENCUnRWTIldaXzNKYeRlRo98qZi7qo6ZYck,155
|
|
3
|
+
icu_benchmarks/cross_validation.py,sha256=_DXeu-V4qTEyb-V-4ci3hrpNH0BZfSnvMpbyJh8y5RY,4875
|
|
4
|
+
icu_benchmarks/run.py,sha256=HrJC3Pj40Y9TDEaMCp13xgg4w5iou5SFcdzhkhmQpKI,6061
|
|
5
|
+
icu_benchmarks/run_utils.py,sha256=d0LsHBepoFqBHFQNgVMbgHyGXwIjCq2_lp6chFRaEHU,11821
|
|
6
|
+
icu_benchmarks/wandb_utils.py,sha256=zQOkXlI64TXn3ylbrxm9kbMc1TbNx2kVyNRNGmerBsU,1637
|
|
7
|
+
yaib-0.3.1.dist-info/LICENSE,sha256=aQO49L5qimHt0UxPTlRLicU5cSVhdUnm7k1Ec98-r8E,1215
|
|
8
|
+
yaib-0.3.1.dist-info/METADATA,sha256=9NpyKdl0p8qdgjxwcRU7DNaiHBRFTWquKaqk3Pe0KDw,15674
|
|
9
|
+
yaib-0.3.1.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
|
10
|
+
yaib-0.3.1.dist-info/entry_points.txt,sha256=2FvHz3-zDtaqFJXungdMpL9sWlL8roqgPttNzmKr8kQ,59
|
|
11
|
+
yaib-0.3.1.dist-info/top_level.txt,sha256=tMNfKvvlPACIbabj4GoWKS3Qd3sukUaiq7AcbqTXaiM,15
|
|
12
|
+
yaib-0.3.1.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
icu_benchmarks
|