yaib 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,6 @@
1
# -*- coding: utf-8 -*-
"""Top-level package for YAIB."""

__author__ = "Robin van de Water"
__email__ = "robin.vandewater@hpi.de"
# Keep in sync with the released distribution: the package metadata declares version 0.3.1.
__version__ = "0.3.1"
@@ -0,0 +1,7 @@
1
+ from enum import Enum
2
+
3
+
4
class RunMode(str, Enum):
    """Kinds of task a run can perform.

    Members are also ``str`` instances, so each one compares equal to its plain string value.
    """

    classification = "Classification"
    imputation = "Imputation"
    regression = "Regression"
@@ -0,0 +1,120 @@
1
+ import json
2
+ from datetime import datetime
3
+ import logging
4
+ import gin
5
+ from pathlib import Path
6
+ from pytorch_lightning import seed_everything
7
+
8
+ from icu_benchmarks.wandb_utils import wandb_log
9
+ from icu_benchmarks.run_utils import aggregate_results
10
+ from icu_benchmarks.data.split_process_data import preprocess_data
11
+ from icu_benchmarks.models.train import train_common
12
+ from icu_benchmarks.models.utils import JsonResultLoggingEncoder
13
+ from icu_benchmarks.run_utils import log_full_line
14
+ from icu_benchmarks.contants import RunMode
15
+
16
+
17
@gin.configurable
def execute_repeated_cv(
    data_dir: Path,
    log_dir: Path,
    seed: int,
    load_weights: bool = False,
    source_dir: Path = None,
    cv_repetitions: int = 5,
    cv_repetitions_to_train: int = None,
    cv_folds: int = 5,
    cv_folds_to_train: int = None,
    reproducible: bool = True,
    debug: bool = False,
    generate_cache: bool = False,
    load_cache: bool = False,
    test_on: str = "test",
    mode: str = RunMode.classification,
    pretrained_imputation_model: object = None,
    cpu: bool = False,
    verbose: bool = False,
    wandb: bool = False,
) -> float:
    """Run repeated cross validation: preprocess and train once per (repetition, fold) pair.

    For every pair the data is preprocessed, a model is trained via ``train_common``, and both
    durations are written to ``durations.json`` inside ``log_dir/repetition_<r>/fold_<f>``.

    Args:
        data_dir: Path to the data directory.
        log_dir: Path to the log directory; one sub-directory is created per repetition and fold.
        seed: Random seed, passed to ``seed_everything`` and to preprocessing.
        load_weights: Whether to load weights from source_dir.
        source_dir: Path to the source directory.
        cv_repetitions: Amount of cross validation repetitions.
        cv_repetitions_to_train: Amount of repetitions to train. If falsy (None or 0), all repetitions are trained.
        cv_folds: Number of folds for cross validation.
        cv_folds_to_train: Number of folds to train. If falsy (None or 0), all folds are trained.
        reproducible: Whether to make torch reproducible. Forwarded to ``seed_everything`` and ``train_common``.
        debug: Whether to load less data and enable more logging.
        generate_cache: Whether to generate and save cache.
        load_cache: Whether to load previously cached data.
        test_on: Dataset to test on. Can be "test" or "val" (e.g. for hyperparameter tuning).
        mode: Run mode. Can be one of the values of RunMode.
        pretrained_imputation_model: Use a pretrained imputation model.
        cpu: Whether to run on CPU.
        verbose: Enable detailed logging.
        wandb: Whether to report the iteration counter via ``wandb_log``; also forwarded to
            ``train_common`` as ``use_wandb``.

    Returns:
        The average loss of all trained folds.
    """
    # A falsy count means "train on everything that exists".
    cv_repetitions_to_train = cv_repetitions_to_train or cv_repetitions
    cv_folds_to_train = cv_folds_to_train or cv_folds

    seed_everything(seed, reproducible)

    total_loss = 0
    for rep_idx in range(cv_repetitions_to_train):
        for fold_idx in range(cv_folds_to_train):
            preprocess_started = datetime.now()
            data = preprocess_data(
                data_dir,
                seed=seed,
                debug=debug,
                load_cache=load_cache,
                generate_cache=generate_cache,
                cv_repetitions=cv_repetitions,
                repetition_index=rep_idx,
                cv_folds=cv_folds,
                fold_index=fold_idx,
                pretrained_imputation_model=pretrained_imputation_model,
                runmode=mode,
            )

            fold_log_dir = log_dir / f"repetition_{rep_idx}" / f"fold_{fold_idx}"
            fold_log_dir.mkdir(parents=True, exist_ok=True)
            # The preprocessing duration includes creating the fold directory.
            preprocess_time = datetime.now() - preprocess_started

            train_started = datetime.now()
            fold_loss = train_common(
                data,
                log_dir=fold_log_dir,
                load_weights=load_weights,
                source_dir=source_dir,
                reproducible=reproducible,
                test_on=test_on,
                mode=mode,
                cpu=cpu,
                verbose=verbose,
                use_wandb=wandb,
            )
            total_loss += fold_loss
            train_time = datetime.now() - train_started

            log_full_line(
                f"FINISHED FOLD {fold_idx}| PREPROCESSING DURATION {preprocess_time}| TRAINING DURATION {train_time}",
                level=logging.INFO,
            )

            with open(fold_log_dir / "durations.json", "w") as f:
                json.dump(
                    {"preprocessing_duration": preprocess_time, "train_duration": train_time},
                    f,
                    cls=JsonResultLoggingEncoder,
                )

            # Running index of the current fold across all repetitions.
            iteration = rep_idx * cv_folds_to_train + fold_idx
            if wandb:
                wandb_log({"Iteration": iteration})
            # Intermediate aggregation only starts from the third trained fold onwards.
            if iteration > 1:
                aggregate_results(log_dir)
        log_full_line(f"FINISHED CV REPETITION {rep_idx}", level=logging.INFO, char="=", num_newlines=3)

    return total_loss / (cv_repetitions_to_train * cv_folds_to_train)
icu_benchmarks/run.py ADDED
@@ -0,0 +1,176 @@
1
+ # -*- coding: utf-8 -*-
2
+ from datetime import datetime
3
+
4
+ import gin
5
+ import logging
6
+ import sys
7
+ from pathlib import Path
8
+ import importlib.util
9
+
10
+ import torch.cuda
11
+
12
+ from icu_benchmarks.wandb_utils import update_wandb_config, apply_wandb_sweep, set_wandb_run_name
13
+ from icu_benchmarks.tuning.hyperparameters import choose_and_bind_hyperparameters
14
+ from scripts.plotting.utils import plot_aggregated_results
15
+ from icu_benchmarks.cross_validation import execute_repeated_cv
16
+ from icu_benchmarks.run_utils import (
17
+ build_parser,
18
+ create_run_dir,
19
+ aggregate_results,
20
+ log_full_line,
21
+ load_pretrained_imputation_model,
22
+ setup_logging,
23
+ )
24
+ from icu_benchmarks.contants import RunMode
25
+
26
+
27
@gin.configurable("Run")
def get_mode(mode=gin.REQUIRED):
    """Return the run mode supplied through the gin configuration as a ``RunMode`` member.

    Args:
        mode: Value bound to this configurable (registered under the name ``Run``). It is
            declared with ``gin.REQUIRED`` as its default so that gin itself reports a missing
            binding, instead of the call failing on a missing positional argument.

    Returns:
        The ``RunMode`` member whose value equals ``mode``.

    Raises:
        ValueError: If ``mode`` is not one of the ``RunMode`` values.
    """
    # The enum constructor already rejects unknown values; a separate ``assert`` added nothing
    # and would be stripped when Python runs with ``-O``.
    return RunMode(mode)
32
+
33
+
34
def main(my_args=tuple(sys.argv[1:])):
    """Parse the command line, configure logging and gin, then train or evaluate with repeated CV.

    For the ``train`` command the model/task (or experiment) gin files are parsed and
    hyperparameters are chosen and bound. For ``evaluate`` the ``train_config.gin`` of the
    source directory is parsed and weights are loaded from there. Results are aggregated at the
    end and optionally plotted.

    Args:
        my_args: Command-line arguments without the program name. The default is captured from
            ``sys.argv`` once, when this module is imported.
    """
    args, _ = build_parser().parse_known_args(my_args)

    # Set arguments for wandb sweep
    if args.wandb_sweep:
        args = apply_wandb_sweep(args)

    # Initialize loggers
    log_format = "%(asctime)s - %(levelname)s - %(name)s : %(message)s"
    date_format = "%Y-%m-%d %H:%M:%S"
    verbose = args.verbose
    setup_logging(date_format, log_format, verbose)

    # Load weights if in evaluation mode
    load_weights = args.command == "evaluate"
    data_dir = Path(args.data_dir)

    # Get arguments
    name = args.name
    task = args.task
    model = args.model
    # ``--reproducible`` is only registered on the ``train`` sub-command, so the attribute is
    # absent for ``evaluate``; fall back to the same default the train parser uses.
    reproducible = getattr(args, "reproducible", True)

    # Set experiment name
    if name is None:
        name = data_dir.name
    logging.info(f"Running experiment {name}.")

    # Load task config
    gin.parse_config_file(f"configs/tasks/{task}.gin")

    mode = get_mode()

    if args.wandb_sweep:
        run_name = f"{mode}_{model}_{name}"
        set_wandb_run_name(run_name)

    logging.info(f"Task mode: {mode}.")
    experiment = args.experiment

    pretrained_imputation_model = load_pretrained_imputation_model(args.pretrained_imputation)

    # Log imputation model to wandb
    update_wandb_config(
        {
            "pretrained_imputation_model": pretrained_imputation_model.__class__.__name__
            if pretrained_imputation_model is not None
            else "None"
        }
    )
    source_dir = None
    log_dir_name = args.log_dir / name
    log_dir = (
        (log_dir_name / experiment)
        if experiment
        else (log_dir_name / (args.task_name if args.task_name is not None else args.task) / model)
    )
    if torch.cuda.is_available():
        # Use a dedicated loop variable so the experiment ``name`` above is not overwritten.
        for gpu_index in range(torch.cuda.device_count()):
            log_full_line(f"Available GPU {gpu_index}: {torch.cuda.get_device_name(gpu_index)}", level=logging.INFO)
    else:
        log_full_line(
            "No GPUs available: please check your device and Torch,Cuda installation if unintended.", level=logging.WARNING
        )

    log_full_line(f"Logging to {log_dir}.", logging.INFO)

    if args.preprocessor:
        # Import custom supplied preprocessor
        log_full_line(f"Importing custom preprocessor from {args.preprocessor}.", logging.INFO)
        try:
            spec = importlib.util.spec_from_file_location("CustomPreprocessor", args.preprocessor)
            module = importlib.util.module_from_spec(spec)
            sys.modules["preprocessor"] = module
            spec.loader.exec_module(module)
            gin.bind_parameter("preprocess.preprocessor", module.CustomPreprocessor)
        except Exception as e:
            # Deliberately best-effort: the run continues with the configured preprocessor.
            logging.error(f"Could not import custom preprocessor from {args.preprocessor}: {e}")

    if load_weights:
        # Evaluate
        log_dir /= f"from_{args.source_name}"
        run_dir = create_run_dir(log_dir)
        source_dir = args.source_dir
        gin.parse_config_file(source_dir / "train_config.gin")
    else:
        # Train
        checkpoint = log_dir / args.checkpoint if args.checkpoint else None
        model_path = (
            Path("configs") / ("imputation_models" if mode == RunMode.imputation else "prediction_models") / f"{model}.gin"
        )
        gin_config_files = (
            [Path(f"configs/experiments/{args.experiment}.gin")]
            if args.experiment
            else [model_path, Path(f"configs/tasks/{task}.gin")]
        )
        gin.parse_config_files_and_bindings(gin_config_files, args.hyperparams, finalize_config=False)
        log_full_line(f"Data directory: {data_dir.resolve()}", level=logging.INFO)
        run_dir = create_run_dir(log_dir)
        choose_and_bind_hyperparameters(
            args.tune,
            data_dir,
            run_dir,
            args.seed,
            run_mode=mode,
            checkpoint=checkpoint,
            debug=args.debug,
            generate_cache=args.generate_cache,
            load_cache=args.load_cache,
            verbose=args.verbose,
        )

    log_full_line(f"Logging to {run_dir.resolve()}", level=logging.INFO)
    log_full_line("STARTING TRAINING", level=logging.INFO, char="=", num_newlines=3)
    start_time = datetime.now()
    execute_repeated_cv(
        data_dir,
        run_dir,
        args.seed,
        load_weights=load_weights,
        source_dir=source_dir,
        reproducible=reproducible,
        debug=args.debug,
        verbose=args.verbose,
        load_cache=args.load_cache,
        generate_cache=args.generate_cache,
        mode=mode,
        pretrained_imputation_model=pretrained_imputation_model,
        cpu=args.cpu,
        wandb=args.wandb_sweep,
    )

    log_full_line("FINISHED TRAINING", level=logging.INFO, char="=", num_newlines=3)
    execution_time = datetime.now() - start_time
    log_full_line(f"DURATION: {execution_time}", level=logging.INFO, char="")
    aggregate_results(run_dir, execution_time)
    if args.plot:
        plot_aggregated_results(run_dir, "aggregated_test_metrics.json")
173
+
174
+ """Main module."""
175
+ if __name__ == "__main__":
176
+ main()
@@ -0,0 +1,254 @@
1
+ import warnings
2
+ from math import sqrt
3
+
4
+ import torch
5
+ import json
6
+ from argparse import ArgumentParser, BooleanOptionalAction
7
+ from datetime import datetime, timedelta
8
+ import logging
9
+ from pathlib import Path
10
+ import scipy.stats as stats
11
+ import shutil
12
+ from statistics import mean, pstdev
13
+ from icu_benchmarks.models.utils import JsonResultLoggingEncoder
14
+ from icu_benchmarks.wandb_utils import wandb_log
15
+
16
+
17
def build_parser() -> ArgumentParser:
    """Assemble the command-line parser with its ``train`` and ``evaluate`` sub-commands.

    Options shared by both sub-commands are declared once on a help-less parent parser that
    each sub-command inherits.

    Returns:
        The configured ArgumentParser.
    """

    def add_toggle(target, *flags, help_text, **kwargs):
        # Registers a paired ``--flag`` / ``--no-flag`` boolean option.
        target.add_argument(*flags, action=BooleanOptionalAction, help=help_text, **kwargs)

    root = ArgumentParser(description="Benchmark lib for processing and evaluation of deep learning models on ICU data")
    shared = ArgumentParser(add_help=False)
    commands = root.add_subparsers(title="Commands", dest="command", required=True)

    # ARGUMENTS FOR ALL COMMANDS
    general = shared.add_argument_group("General arguments")
    general.add_argument("-d", "--data-dir", required=True, type=Path, help="Path to the parquet data directory.")
    # The default is never used because the option is required.
    general.add_argument("-t", "--task", default="BinaryClassification", required=True, help="Name of the task gin.")
    general.add_argument("-n", "--name", help="Name of the (target) dataset.")
    general.add_argument("-tn", "--task-name", help="Name of the task, used for naming experiments.")
    general.add_argument("-m", "--model", default="LGBMClassifier", help="Name of the model gin.")
    general.add_argument("-e", "--experiment", help="Name of the experiment gin.")
    general.add_argument(
        "-l", "--log-dir", default=Path("../yaib_logs/"), type=Path, help="Log directory with model weights."
    )
    general.add_argument(
        "-s", "--seed", default=1234, type=int, help="Random seed for processing, tuning and training."
    )
    add_toggle(
        general,
        "-v",
        "--verbose",
        default=False,
        help_text="Whether to use verbose logging. Disable for clean logs.",
    )
    add_toggle(general, "--cpu", default=False, help_text="Set to use CPU.")
    add_toggle(general, "-db", "--debug", default=False, help_text="Set to load less data.")
    add_toggle(general, "-lc", "--load_cache", default=False, help_text="Set to load generated data cache.")
    add_toggle(general, "-gc", "--generate_cache", default=False, help_text="Set to generate data cache.")
    general.add_argument("-p", "--preprocessor", type=Path, help="Load custom preprocessor from file.")
    # No explicit default: the attribute is None unless the flag is given.
    add_toggle(general, "-pl", "--plot", help_text="Generate common plots.")
    general.add_argument("-wd", "--wandb-sweep", action="store_true", help="Activates wandb hyper parameter sweep.")
    general.add_argument("-imp", "--pretrained-imputation", type=str, help="Path to pretrained imputation model.")

    # MODEL TRAINING ARGUMENTS
    train = commands.add_parser("train", help="Preprocess features and train model.", parents=[shared])
    add_toggle(train, "--reproducible", default=True, help_text="Make torch reproducible.")
    train.add_argument("-hp", "--hyperparams", nargs="+", help="Hyperparameters for model.")
    add_toggle(train, "--tune", default=False, help_text="Find best hyperparameters.")
    train.add_argument("--checkpoint", type=Path, help="Use previous checkpoint.")

    # EVALUATION PARSER
    evaluate = commands.add_parser("evaluate", help="Evaluate trained model on data.", parents=[shared])
    evaluate.add_argument("-sn", "--source-name", required=True, type=Path, help="Name of the source dataset.")
    evaluate.add_argument("--source-dir", required=True, type=Path, help="Directory containing gin and model weights.")

    return root
94
+
95
+
96
def create_run_dir(log_dir: Path, randomly_searched_params: str = None) -> Path:
    """Create a run directory inside ``log_dir`` that is named after the current time.

    The directory name has second resolution (``YYYY-MM-DDTHH-MM-SS``). Missing parents are
    created, but an already existing run directory makes ``mkdir`` raise ``FileExistsError``.
    If a parameter string is given, an empty file with that name is created in the run
    directory; the filename holds the fixed hyperparameters to check against in future runs.

    Args:
        log_dir: Parent directory to create run directory in.
        randomly_searched_params: String representing the randomly searched params.

    Returns:
        Path to the created run log directory.
    """
    timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    run_dir = log_dir / timestamp
    run_dir.mkdir(parents=True)
    if not randomly_searched_params:
        return run_dir
    marker_file = run_dir / randomly_searched_params
    marker_file.touch()
    return run_dir
114
+
115
+
116
def aggregate_results(log_dir: Path, execution_time: timedelta = None):
    """Collect the per-fold metrics below ``log_dir`` and write aggregated JSON summaries.

    Every sub-directory of ``log_dir`` is treated as a repetition and every entry inside it as
    a fold. Two files are written to ``log_dir``: ``aggregated_test_metrics.json`` (raw values
    per repetition and fold) and ``accumulated_test_metrics.json`` (statistics per metric). The
    accumulated statistics are also logged and passed to ``wandb_log``.

    Args:
        log_dir: Path to the log directory.
        execution_time: Overall execution time; recorded as 0.0 seconds when not given.
    """
    aggregated = {}
    for repetition_dir in log_dir.iterdir():
        if not repetition_dir.is_dir():
            continue
        fold_results = aggregated[repetition_dir.name] = {}
        for fold_entry in repetition_dir.iterdir():
            fold_metrics = fold_results[fold_entry.name] = {}
            # Test metrics take precedence; the validation metrics file is only a fallback.
            # NOTE(review): "val_metrics.csv" is parsed as JSON despite its extension - confirm its format.
            metrics_file = fold_entry / "test_metrics.json"
            if not metrics_file.is_file():
                metrics_file = fold_entry / "val_metrics.csv"
            # Durations are merged on top of the metrics whenever they exist.
            for source in (metrics_file, fold_entry / "durations.json"):
                if source.is_file():
                    with open(source, "r") as f:
                        fold_metrics.update(json.load(f))

    # Group the numeric scores of all repetitions and folds by metric name.
    scores_by_metric = {}
    for folds in aggregated.values():
        for result in folds.values():
            for metric, score in result.items():
                if isinstance(score, (float, int)):
                    scores_by_metric.setdefault(metric, []).append(score)

    averaged_scores = {metric: mean(scores) for metric, scores in scores_by_metric.items()}

    # Population standard deviation divided by sqrt(n); this is the value reported under "std".
    std_scores = {metric: pstdev(scores) / sqrt(len(scores)) for metric, scores in scores_by_metric.items()}

    # 95% t-interval around the mean, scaled by the standard error of the mean.
    confidence_interval = {
        metric: stats.t.interval(0.95, len(scores) - 1, loc=mean(scores), scale=stats.sem(scores))
        for metric, scores in scores_by_metric.items()
    }

    if execution_time is not None:
        elapsed_seconds = execution_time.total_seconds()
    else:
        elapsed_seconds = 0.0

    accumulated_metrics = {
        "avg": averaged_scores,
        "std": std_scores,
        "CI_0.95": confidence_interval,
        "execution_time": elapsed_seconds,
    }

    with open(log_dir / "aggregated_test_metrics.json", "w") as f:
        json.dump(aggregated, f, cls=JsonResultLoggingEncoder)

    with open(log_dir / "accumulated_test_metrics.json", "w") as f:
        json.dump(accumulated_metrics, f, cls=JsonResultLoggingEncoder)

    logging.info(f"Accumulated results: {accumulated_metrics}")

    # Round-trip through the project's encoder so that only plain JSON types reach wandb.
    wandb_log(json.loads(json.dumps(accumulated_metrics, cls=JsonResultLoggingEncoder)))
180
+
181
+
182
def log_full_line(msg: str, level: int = logging.INFO, char: str = "-", num_newlines: int = 0):
    """Logs a full line of a given character with a message centered.

    The line width is the terminal width (80 columns if it cannot be determined) minus the
    characters reserved for the log prefix. If the terminal is narrower than that prefix, the
    message is logged without padding instead of raising.

    Args:
        msg: Message to log.
        level: Logging level.
        char: Character to use for the line. An empty string pads with spaces.
        num_newlines: Number of newlines to append.
    """
    terminal_size = shutil.get_terminal_size((80, 20))
    reserved_chars = len(logging.getLevelName(level)) + 28
    # A negative width would be read as a sign inside the format spec, which is rejected for
    # strings with a ValueError; clamp it so narrow terminals cannot break logging.
    width = max(terminal_size.columns - reserved_chars, 0)
    logging.log(
        level,
        "{0:{char}^{width}}{1}".format(msg, "\n" * num_newlines, char=char, width=width),
    )
197
+
198
+
199
def load_pretrained_imputation_model(use_pretrained_imputation):
    """Loads a pretrained imputation model.

    The checkpoint is either a dict with the keys "class", "hyper_parameters", "trained_columns"
    and "state_dict", from which the model is rebuilt, or the stored model object itself. The
    model is moved to CUDA when available, otherwise to the CPU.

    Args:
        use_pretrained_imputation: Path to the pretrained imputation model, or None.

    Returns:
        The loaded imputation model, or None if no path was given or the path does not exist.
    """
    if use_pretrained_imputation is not None and not Path(use_pretrained_imputation).exists():
        logging.warning("The specified pretrained imputation model does not exist.")
        use_pretrained_imputation = None

    if use_pretrained_imputation is None:
        return None

    logging.info("Using pretrained imputation from " + str(use_pretrained_imputation))
    # NOTE(security): this unpickles arbitrary Python objects (the checkpoint stores a model
    # class or a whole model), so only load checkpoints from trusted sources. ``weights_only``
    # is set explicitly because its default differs between torch releases.
    checkpoint = torch.load(use_pretrained_imputation, map_location=torch.device("cpu"), weights_only=False)
    if isinstance(checkpoint, dict):
        imputation_model_class = checkpoint["class"]
        pretrained_imputation_model = imputation_model_class(**checkpoint["hyper_parameters"])
        pretrained_imputation_model.set_trained_columns(checkpoint["trained_columns"])
        pretrained_imputation_model.load_state_dict(checkpoint["state_dict"])
    else:
        pretrained_imputation_model = checkpoint
    pretrained_imputation_model = pretrained_imputation_model.to("cuda" if torch.cuda.is_available() else "cpu")
    try:
        # Best effort: a model without parameters, or one that rejects the attribute, keeps
        # whatever ``device`` it already has.
        device = next(pretrained_imputation_model.parameters()).device
        logging.info(f"imputation model device: {device}")
        pretrained_imputation_model.device = device
    except Exception as e:
        logging.debug(f"Could not set device of imputation model: {e}")

    return pretrained_imputation_model
229
+
230
+
231
def setup_logging(date_format, log_format, verbose):
    """
    Set up all loggers to use the same format and date format.

    Configures the root logger via ``logging.basicConfig`` and applies the same formatter to
    every handler already attached to the "pytorch_lightning" and "lightning_fabric" loggers.
    Verbose mode sets these loggers to DEBUG and shows warnings; otherwise they are set to INFO
    and warnings are ignored.

    Args:
        date_format: Format for the date.
        log_format: Format for the log.
        verbose: Whether to log debug messages.
    """
    logging.basicConfig(format=log_format, datefmt=date_format)
    library_loggers = ["pytorch_lightning", "lightning_fabric"]
    formatter = logging.Formatter(log_format, datefmt=date_format)
    for logger_name in library_loggers:
        # A library logger may have no handler of its own; indexing ``handlers[0]`` raised
        # IndexError in that case, so iterate over whatever handlers exist instead.
        for handler in logging.getLogger(logger_name).handlers:
            handler.setFormatter(formatter)

    if verbose:
        level, warning_action = logging.DEBUG, "default"
    else:
        level, warning_action = logging.INFO, "ignore"

    logging.getLogger().setLevel(level)
    for logger_name in library_loggers:
        logging.getLogger(logger_name).setLevel(level)
    warnings.filterwarnings(warning_action)
@@ -0,0 +1,61 @@
1
+ from argparse import Namespace
2
+ import logging
3
+ import wandb
4
+
5
+
6
def wandb_running() -> bool:
    """Report whether a wandb run is currently active, i.e. ``wandb.run`` is set."""
    active_run = wandb.run
    return active_run is not None
9
+
10
+
11
def update_wandb_config(config: dict) -> None:
    """Merge the given entries into the wandb config; only logs them when wandb is not running.

    Args:
        config (dict): config to set
    """
    logging.debug(f"Updating Wandb config: {config}")
    if not wandb_running():
        return
    wandb.config.update(config)
20
+
21
+
22
def apply_wandb_sweep(args: Namespace) -> Namespace:
    """Initialise wandb and apply its sweep configuration to the parsed arguments.

    Every sweep entry overwrites the attribute of the same name on ``args`` and is additionally
    appended to ``args.hyperparams`` as a ``key=value`` string (string values are wrapped in
    single quotes).

    Args:
        args (Namespace): parsed arguments

    Returns:
        Namespace: arguments with sweep configuration applied (some are applied via hyperparams)
    """
    wandb.init()
    sweep_config = wandb.config
    args.__dict__.update(sweep_config)
    # ``hyperparams`` is only registered on the ``train`` sub-command, so the attribute can be
    # missing entirely (not just None) when a sweep is started for another command.
    if getattr(args, "hyperparams", None) is None:
        args.hyperparams = []
    for key, value in sweep_config.items():
        rendered = f"'{value}'" if isinstance(value, str) else str(value)
        args.hyperparams.append(f"{key}={rendered}")
    logging.info(f"hyperparams after loading sweep config: {args.hyperparams}")
    return args
40
+
41
+
42
def wandb_log(log_dict):
    """Send a metric dict to wandb; does nothing when no wandb run is active.

    Args:
        log_dict (dict): metric dict to log
    """
    if not wandb_running():
        return
    wandb.log(log_dict)
50
+
51
+
52
def set_wandb_run_name(run_name):
    """Name the active wandb run and record that name in the wandb config.

    Does nothing when no wandb run is active.

    Args:
        run_name (str): name of the current run
    """
    if not wandb_running():
        return
    wandb.config.update({"run-name": run_name})
    active_run = wandb.run
    active_run.name = run_name
    active_run.save()
@@ -0,0 +1,25 @@
1
+
2
+
3
+ MIT License
4
+
5
+ Copyright (c) 2023, Robin van de Water, Hendrik Schmidt, Patrick Rockenschaub
6
+ Copyright (c) 2021, ETH Zurich, Biomedical Informatics Group; ratschlab
7
+
8
+ Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ of this software and associated documentation files (the "Software"), to deal
10
+ in the Software without restriction, including without limitation the rights
11
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ copies of the Software, and to permit persons to whom the Software is
13
+ furnished to do so, subject to the following conditions:
14
+
15
+ The above copyright notice and this permission notice shall be included in all
16
+ copies or substantial portions of the Software.
17
+
18
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ SOFTWARE.
25
+
@@ -0,0 +1,310 @@
1
+ Metadata-Version: 2.1
2
+ Name: yaib
3
+ Version: 0.3.1
4
+ Summary: Yet Another ICU Benchmark is a holistic framework for the automation of the development of clinical prediction models on ICU data. Users can create custom datasets, cohorts, prediction tasks, endpoints, and models.
5
+ Home-page: https://github.com/rvandewater/YAIB
6
+ Author: Robin van de Water
7
+ Author-email: robin.vandewater@hpi.de
8
+ License: MIT license
9
+ Keywords: benchmark mimic-iii eicu hirid clinical-ml machine-learning benchmark time-series mimic-iv patient-monitoring amsterdamumcdb clinical-data ehr icu ricu pyicu
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: Intended Audience :: Developers
12
+ Classifier: License :: OSI Approved :: MIT License
13
+ Classifier: Natural Language :: English
14
+ Classifier: Programming Language :: Python :: 3.10
15
+ Description-Content-Type: text/markdown
16
+ License-File: LICENSE
17
+ Requires-Dist: black (==23.3.0)
18
+ Requires-Dist: coverage (==7.2.3)
19
+ Requires-Dist: flake8 (==5.0.4)
20
+ Requires-Dist: matplotlib (==3.7.1)
21
+ Requires-Dist: gin-config (==0.5.0)
22
+ Requires-Dist: pytorch-ignite (==0.4.11)
23
+ Requires-Dist: torch (==2.0.1)
24
+ Requires-Dist: pytorch-cuda (==11.8)
25
+ Requires-Dist: lightgbm (==3.3.5)
26
+ Requires-Dist: numpy (==1.24.3)
27
+ Requires-Dist: pandas (==2.0.0)
28
+ Requires-Dist: pyarrow (==11.0.0)
29
+ Requires-Dist: pytest (==7.3.1)
30
+ Requires-Dist: scikit-learn (==1.2.2)
31
+ Requires-Dist: tensorboard (==2.12.2)
32
+ Requires-Dist: tqdm (==4.64.1)
33
+ Requires-Dist: pytorch-lightning (==2.0.3)
34
+ Requires-Dist: wandb (==0.15.4)
35
+ Requires-Dist: pip (==23.1)
36
+ Requires-Dist: einops (==0.6.1)
37
+ Requires-Dist: hydra-core (==1.3)
38
+ Requires-Dist: recipies (==0.1.1)
39
+ Provides-Extra: mps
40
+ Requires-Dist: mkl (<2022) ; extra == 'mps'
41
+
42
+ ![YAIB logo](https://github.com/rvandewater/YAIB/blob/development/docs/figures/yaib_logo.png?raw=true)
43
+
44
+ # 🧪 Yet Another ICU Benchmark
45
+
46
+ [![CI](https://github.com/rvandewater/YAIB/actions/workflows/ci.yml/badge.svg?branch=development)](https://github.com/rvandewater/YAIB/actions/workflows/ci.yml)
47
+ [![Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
48
+ ![Platform](https://img.shields.io/badge/platform-linux--64%20|%20win--64%20|%20osx--64-lightgrey)
49
+ [![arXiv](https://img.shields.io/badge/arXiv-2306.05109-b31b1b.svg)](http://arxiv.org/abs/2306.05109)
50
+ [![PyPI version shields.io](https://img.shields.io/pypi/v/yaib.svg)](https://pypi.python.org/pypi/yaib/)
51
+ [![License](https://img.shields.io/badge/license-MIT-green.svg)](LICENSE)
52
+
53
+ [//]: # (TODO: add coverage once we have some tests )
54
+
55
+ Yet another ICU benchmark (YAIB) provides a framework for doing clinical machine learning experiments on Intensive Care Unit (
56
+ ICU) EHR data.
57
+
58
+ We support the following datasets out of the box:
59
+
60
+ | **Dataset** | [MIMIC-III](https://physionet.org/content/mimiciii/) / [IV](https://physionet.org/content/mimiciv/) | [eICU-CRD](https://physionet.org/content/eicu-crd/) | [HiRID](https://physionet.org/content/hirid/1.1.1/) | [AUMCdb](https://doi.org/10.17026/dans-22u-f8vd) |
61
+ |-----------------------------|-----------------------------------------------------------------------------------------------------|-----------------------------------------------------|-----------------------------------------------------|--------------------------------------------------|
62
+ | **Admissions** | 40k / 73k | 200k | 33k | 23k |
63
+ | **Version** | v1.4 / v2.2 | v2.0 | v1.1.1 | v1.0.2 |
64
+ | **Frequency** (time-series) | 1 hour | 5 minutes | 2 / 5 minutes | up to 1 minute |
65
+ | **Originally published** | 2015 / 2020 | 2017 | 2020 | 2019 |
66
+ | **Origin** | USA | USA | Switzerland | Netherlands |
67
+
68
+ New datasets can also be added. We are currently working on a package to make this process as smooth as possible.
69
+ The benchmark is designed for operating on preprocessed parquet files.
70
+ <!-- We refer to PyICU (in development)
71
+ or [ricu package](https://github.com/eth-mds/ricu) for generating these parquet files for particular cohorts and endpoints. -->
72
+
73
+ We provide five common tasks for clinical prediction by default:
74
+
75
+ | No | Task | Frequency | Type |
76
+ |-----|---------------------------|---------------------------|-----------------------|
77
+ | 1 | ICU Mortality | Once per Stay (after 24H) | Binary Classification |
78
+ | 2 | Acute Kidney Injury (AKI) | Hourly (within 6H) | Binary Classification |
79
+ | 3 | Sepsis | Hourly (within 6H) | Binary Classification |
80
+ | 4 | Kidney Function(KF) | Once per stay | Regression |
81
+ | 5 | Length of Stay (LoS) | Hourly (within 7D) | Regression |
82
+
83
+ New tasks can be easily added.
84
+ For the purposes of getting started right away, we include the eICU and MIMIC-III demo datasets in our repository.
85
+
86
+ The following repositories may be relevant as well:
87
+
88
+ - [YAIB-cohorts](https://github.com/rvandewater/YAIB-cohorts): Cohort generation for YAIB.
89
+ - [YAIB-models](https://github.com/rvandewater/YAIB-models): Pretrained models for YAIB.
90
+ - [ReciPys](https://github.com/rvandewater/ReciPys): Preprocessing package for YAIB pipelines.
91
+
92
+ For all YAIB related repositories, please see: https://github.com/stars/rvandewater/lists/yaib.
93
+ # 📄Paper
94
+
95
+ To reproduce the benchmarks in our paper, please refer to the [ML reproducibility document](PAPER.md).
96
+ If you use this code in your research, please cite the following publication:
97
+
98
+ ```
99
+ @article{vandewaterYetAnotherICUBenchmark2023,
100
+ title = {Yet Another ICU Benchmark: A Flexible Multi-Center Framework for Clinical ML},
101
+ shorttitle = {Yet Another ICU Benchmark},
102
+ url = {http://arxiv.org/abs/2306.05109},
103
+ language = {en},
104
+ urldate = {2023-06-09},
105
+ publisher = {arXiv},
106
+ author = {Robin van de Water and Hendrik Schmidt and Paul Elbers and Patrick Thoral and Bert Arnrich and Patrick Rockenschaub},
107
+ month = jun,
108
+ year = {2023},
109
+ note = {arXiv:2306.05109 [cs]},
110
+ keywords = {Computer Science - Machine Learning},
111
+ }
112
+ ```
113
+
114
+ This paper can also be found on arXiv: [2306.05109](https://arxiv.org/abs/2306.05109).
115
+
116
+ # 💿Installation
117
+ YAIB is currently best installed from source; however, we also offer an early PyPI release.
118
+
119
+ ## Installation from source
120
+ First, we clone this repository using git:
121
+ ````
122
+ git clone https://github.com/rvandewater/YAIB.git
123
+ ````
124
+ Please note the branch. The newest features and fixes are available at the development branch:
125
+ ````
126
+ git checkout development
127
+ ````
128
+ YAIB can be installed using a conda environment (preferred) or pip. Below are the three CLI commands to install YAIB
129
+ using **conda**.
130
+
131
+ The first command will install an environment based on Python 3.10.
132
+
133
+ ```
134
+ conda env update -f <environment.yml|environment_mps.yml>
135
+ ```
136
+
137
+ > Use `environment.yml` on x86 hardware and `environment_mps.yml` on Macs with Metal Performance Shaders.
138
+
139
+ We then activate the environment and install a package called `icu-benchmarks`, after which YAIB should be operational.
140
+
141
+ ```
142
+ conda activate yaib
143
+ pip install -e .
144
+ ```
145
+
146
+ If you want to install the icu-benchmarks package with **pip**, execute the command below:
147
+
148
+ ```
149
+ pip install torch numpy && pip install -e .
150
+ ```
151
+ After installation, please check that your PyTorch version works with CUDA (if available) to ensure the best performance.
152
+ YAIB will automatically list available processors at initialization in its log files.
153
+
154
+ # 👩‍💻Usage
155
+
156
+ Please refer to [our wiki](https://github.com/rvandewater/YAIB/wiki) for detailed information on how to use YAIB.
157
+
158
+ ## Quickstart 🚀 (demo data)
159
+
160
+ In the folder `demo_data` we provide processed publicly available demo datasets from eICU and MIMIC with the necessary labels
161
+ for `Mortality at 24h`, `Sepsis`, `Acute Kidney Injury`, `Kidney Function`, and `Length of Stay`.
162
+
163
+ If you do not yet have access to the ICU datasets, you can run the following command to train models for the included demo
164
+ cohorts:
165
+
166
+ ```
167
+ wandb sweep --verbose experiments/demo_benchmark_classification.yml
168
+ wandb sweep --verbose experiments/demo_benchmark_regression.yml
169
+ ```
170
+
171
+ ```train
172
+ wandb agent <sweep_id>
173
+ ```
174
+
175
+ > Tip: You can choose to run each of the configurations on a SLURM cluster instance by `wandb agent --count 1 <sweep_id>`
176
+
177
+ > Note: You will need to have a wandb account and be logged in to run the above commands.
178
+
179
+ ## Getting the datasets
180
+
181
+ HiRID, eICU, and MIMIC IV can be accessed through [PhysioNet](https://physionet.org/). A guide to this process can be
182
+ found [here](https://eicu-crd.mit.edu/gettingstarted/access/).
183
+ AUMCdb can be accessed through a separate access [procedure](https://github.com/AmsterdamUMC/AmsterdamUMCdb). We do not have
184
+ involvement in the access procedure and cannot respond to any requests for data access.
185
+
186
+ ## Cohort creation
187
+
188
+ Since the datasets were created independently of each other, they do not share the same data structure or data identifiers. In
189
+ order to make them interoperable, use the preprocessing utilities
190
+ provided by the [ricu package](https://github.com/eth-mds/ricu).
191
+ Ricu pre-defines a large number of clinical concepts and how to load them from a given dataset, providing a common interface to
192
+ the data, that is used in this
193
+ benchmark. Please refer to our [cohort definition](https://github.com/rvandewater/YAIB-cohorts) code for generating the cohorts
194
+ using our python interface for ricu.
195
+ After this, you can run the benchmark once you have gained access to the datasets.
196
+
197
+ # 👟 Running YAIB
198
+
199
+ ## Preprocessing and Training
200
+
201
+ The following command will run training and evaluation on the MIMIC demo dataset for (Binary) mortality prediction at 24h with
202
+ the
203
+ LGBMClassifier. Child samples are reduced due to the small amount of training data. We generate a cache and, if available,
204
+ load
205
+ existing cache files.
206
+
207
+ ```
208
+ icu-benchmarks train \
209
+ -d demo_data/mortality24/mimic_demo \
210
+ -n mimic_demo \
211
+ -t BinaryClassification \
212
+ -tn Mortality24 \
213
+ -m LGBMClassifier \
214
+ -hp LGBMClassifier.min_child_samples=10 \
215
+ --generate_cache \
216
+ --load_cache \
217
+ --seed 2222 \
218
+ -s 2222 \
219
+ -l ../yaib_logs/ \
220
+ --tune
221
+ ```
222
+
223
+ > For a list of available flags, run `icu-benchmarks train -h`.
224
+
225
+ > Run with `PYTORCH_ENABLE_MPS_FALLBACK=1` on Macs with Metal Performance Shaders.
226
+
227
+ [//]: # (> Please note that, for Windows based systems, paths need to be formatted differently, e.g: ` r"\..\data\mortality_seq\hirid"`.)
228
+ > For Windows-based systems, the line-continuation character (\\) needs to be replaced by (^) (Command Prompt) or (`) (PowerShell),
229
+ > respectively.
230
+
231
+
232
+ Alternatively, the easiest method to train all the models in the paper is to run these commands from the directory root:
233
+
234
+ ```train
235
+ wandb sweep --verbose experiments/benchmark_classification.yml
236
+ wandb sweep --verbose experiments/benchmark_regression.yml
237
+ ```
238
+
239
+ This will create two hyperparameter sweeps for WandB for the classification and regression tasks.
240
+ This configuration will train all the models in the paper. You can then run the following command to train the models:
241
+
242
+ ```train
243
+ wandb agent <sweep_id>
244
+ ```
245
+
246
+ > Tip: You can choose to run each of the configurations on a SLURM cluster instance by `wandb agent --count 1 <sweep_id>`
247
+
248
+ > Note: You will need to have a wandb account and be logged in to run the above commands.
249
+
250
+ ## Evaluate
251
+
252
+ It is possible to evaluate a model trained on another dataset. In this case, the source dataset is the demo data from MIMIC and
253
+ the target is the eICU demo:
254
+
255
+ ```
256
+ icu-benchmarks evaluate \
257
+ -d demo_data/mortality24/eicu_demo \
258
+ -n eicu_demo \
259
+ -t BinaryClassification \
260
+ -tn Mortality24 \
261
+ -m LGBMClassifier \
262
+ --generate_cache \
263
+ --load_cache \
264
+ -s 2222 \
265
+ -l ../yaib_logs \
266
+ -sn mimic \
267
+ --source-dir ../yaib_logs/mimic_demo/Mortality24/LGBMClassifier/2022-12-12T15-24-46/fold_0
268
+ ```
269
+
270
+ ## Models
271
+
272
+ We provide several existing machine learning models that are commonly used for multivariate time-series data.
273
+ `pytorch` is used for the deep learning models, `lightgbm` for the boosted tree approaches, and `sklearn` for other classical
274
+ machine learning models.
275
+ The benchmark provides (among others) the following built-in models:
276
+
277
+ - [Logistic Regression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html?highlight=logistic+regression):
278
+ Standard regression approach.
279
+ - [Elastic Net](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ElasticNet.html): Linear regression with
280
+ combined L1 and L2 priors as regularizer.
281
+ - [LightGBM](https://proceedings.neurips.cc/paper/2017/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf): Efficient gradient
282
+ boosting trees.
283
+ - [Long Short-term Memory (LSTM)](https://ieeexplore.ieee.org/document/818041): The most commonly used type of Recurrent Neural
284
+ Networks for long sequences.
285
+ - [Gated Recurrent Unit (GRU)](https://arxiv.org/abs/1406.1078): An extension to LSTM which showed
286
+ improvements ([paper](https://arxiv.org/abs/1412.3555)).
287
+ - [Temporal Convolutional Networks (TCN)](https://arxiv.org/pdf/1803.01271): 1D convolution approach to sequence data. By
288
+ using dilated convolution to extend the receptive field of the network it has shown great performance on long-term
289
+ dependencies.
290
+ - [Transformers](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf): The most common Attention
291
+ based approach.
292
+
293
+ # 🛠️ Development
294
+
295
+ To adapt YAIB to your own use case, you can use
296
+ the [development information](https://github.com/rvandewater/YAIB/wiki/Contribution-and-development) page as a reference.
297
+ We appreciate contributions to the project. Please read the [contribution guidelines](CONTRIBUTING.MD) before submitting a pull
298
+ request.
299
+
300
+ # Acknowledgements
301
+
302
+ We do not own any of the datasets used in this benchmark. This project uses heavily adapted components of
303
+ the [HiRID benchmark](https://github.com/ratschlab/HIRID-ICU-Benchmark/). We thank the authors for providing this codebase and
304
+ encourage further development to benefit the scientific community. The demo datasets have been released under
305
+ an [Open Data Commons Open Database License (ODbL)](https://opendatacommons.org/licenses/odbl/1-0/).
306
+
307
+ # License
308
+
309
+ This source code is released under the MIT license, included [here](LICENSE). We do not own any of the datasets used or
310
+ included in this repository.
@@ -0,0 +1,12 @@
1
+ icu_benchmarks/__init__.py,sha256=pBMRirXTJQo-EO2fGRq6eOp7aR7f5yqa6iYkH_QVObI,159
2
+ icu_benchmarks/contants.py,sha256=_YB08CQmENCUnRWTIldaXzNKYeRlRo98qZi7qo6ZYck,155
3
+ icu_benchmarks/cross_validation.py,sha256=_DXeu-V4qTEyb-V-4ci3hrpNH0BZfSnvMpbyJh8y5RY,4875
4
+ icu_benchmarks/run.py,sha256=HrJC3Pj40Y9TDEaMCp13xgg4w5iou5SFcdzhkhmQpKI,6061
5
+ icu_benchmarks/run_utils.py,sha256=d0LsHBepoFqBHFQNgVMbgHyGXwIjCq2_lp6chFRaEHU,11821
6
+ icu_benchmarks/wandb_utils.py,sha256=zQOkXlI64TXn3ylbrxm9kbMc1TbNx2kVyNRNGmerBsU,1637
7
+ yaib-0.3.1.dist-info/LICENSE,sha256=aQO49L5qimHt0UxPTlRLicU5cSVhdUnm7k1Ec98-r8E,1215
8
+ yaib-0.3.1.dist-info/METADATA,sha256=9NpyKdl0p8qdgjxwcRU7DNaiHBRFTWquKaqk3Pe0KDw,15674
9
+ yaib-0.3.1.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
10
+ yaib-0.3.1.dist-info/entry_points.txt,sha256=2FvHz3-zDtaqFJXungdMpL9sWlL8roqgPttNzmKr8kQ,59
11
+ yaib-0.3.1.dist-info/top_level.txt,sha256=tMNfKvvlPACIbabj4GoWKS3Qd3sukUaiq7AcbqTXaiM,15
12
+ yaib-0.3.1.dist-info/RECORD,,
@@ -0,0 +1,6 @@
1
+ Wheel-Version: 1.0
2
+ Generator: bdist_wheel (0.38.4)
3
+ Root-Is-Purelib: true
4
+ Tag: py2-none-any
5
+ Tag: py3-none-any
6
+
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ icu-benchmarks = icu_benchmarks.run:main
@@ -0,0 +1 @@
1
+ icu_benchmarks