visionkit-pro 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- visionkit/__init__.py +23 -0
- visionkit/cli.py +111 -0
- visionkit/config.py +56 -0
- visionkit/data.py +175 -0
- visionkit/evaluate.py +111 -0
- visionkit/gradcam.py +193 -0
- visionkit/metrics.py +83 -0
- visionkit/models.py +86 -0
- visionkit/pipeline.py +99 -0
- visionkit/reproducibility.py +32 -0
- visionkit/train.py +129 -0
- visionkit_pro-0.1.2.dist-info/METADATA +291 -0
- visionkit_pro-0.1.2.dist-info/RECORD +17 -0
- visionkit_pro-0.1.2.dist-info/WHEEL +5 -0
- visionkit_pro-0.1.2.dist-info/entry_points.txt +2 -0
- visionkit_pro-0.1.2.dist-info/licenses/LICENSE +21 -0
- visionkit_pro-0.1.2.dist-info/top_level.txt +1 -0
visionkit/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""Reusable image-classification training, evaluation, and Grad-CAM tools."""
|
|
2
|
+
|
|
3
|
+
from visionkit.config import ExperimentConfig
|
|
4
|
+
from visionkit.data import CSVImageDataset, build_dataloaders, build_transforms
|
|
5
|
+
from visionkit.gradcam import GradCAM, generate_gradcam_report
|
|
6
|
+
from visionkit.models import build_model, load_checkpoint
|
|
7
|
+
from visionkit.pipeline import run_evaluation, run_gradcam, run_training
|
|
8
|
+
from visionkit.train import train_model
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"CSVImageDataset",
|
|
12
|
+
"ExperimentConfig",
|
|
13
|
+
"GradCAM",
|
|
14
|
+
"build_dataloaders",
|
|
15
|
+
"build_model",
|
|
16
|
+
"build_transforms",
|
|
17
|
+
"generate_gradcam_report",
|
|
18
|
+
"load_checkpoint",
|
|
19
|
+
"run_evaluation",
|
|
20
|
+
"run_gradcam",
|
|
21
|
+
"run_training",
|
|
22
|
+
"train_model",
|
|
23
|
+
]
|
visionkit/cli.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
from visionkit.config import ExperimentConfig
|
|
8
|
+
from visionkit.pipeline import run_evaluation, run_gradcam, run_training
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _image_size(value: str) -> tuple[int, int]:
    """Parse a command-line image size such as ``300x300`` or ``300,300``.

    The two numbers may be separated by ``x`` (any case) or a comma. They are
    returned in the order given; the result is later handed unchanged to the
    resize transform.

    Raises:
        argparse.ArgumentTypeError: if the text does not contain exactly two
            positive integers.
    """
    parts = value.lower().replace("x", ",").split(",")
    if len(parts) != 2:
        raise argparse.ArgumentTypeError("image size must look like 300x300")
    try:
        first, second = (int(part) for part in parts)
    except ValueError:
        # Surface the same friendly message instead of a bare int() error.
        raise argparse.ArgumentTypeError("image size must look like 300x300") from None
    if first <= 0 or second <= 0:
        # A zero or negative dimension can never be a valid resize target.
        raise argparse.ArgumentTypeError("image size values must be positive integers")
    return first, second
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _base_parser(parser: argparse.ArgumentParser) -> None:
    """Register the options that every ``visionkit`` subcommand accepts.

    The options are declared as a table and added in order, so ``--help``
    lists them in the sequence written here.
    """
    shared_options = (
        ("--csv", {"dest": "csv_path", "type": Path, "required": True}),
        ("--image-dir", {"type": Path, "required": True}),
        ("--output-dir", {"type": Path, "required": True}),
        ("--filename-column", {"default": "filename"}),
        ("--label-column", {"default": "label"}),
        ("--split-column", {"default": "split"}),
        ("--classes", {"nargs": "*"}),
        ("--architecture", {"default": "efficientnet_b3"}),
        ("--weights", {"default": "DEFAULT"}),
        ("--image-size", {"type": _image_size, "default": (300, 300)}),
        ("--batch-size", {"type": int, "default": 8}),
        ("--seed", {"type": int, "default": 123}),
        ("--positive-class", {}),
        ("--target-layer", {}),
        ("--device", {}),
    )
    for flag, options in shared_options:
        parser.add_argument(flag, **options)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _config(args) -> ExperimentConfig:
    """Translate parsed command-line arguments into an ``ExperimentConfig``.

    Some options are only defined by the ``train`` subcommand; for the other
    subcommands the attribute is absent from ``args`` and the fallback value
    listed below is used instead.
    """
    train_only_fallbacks = {
        "train_split": "train",
        "val_split": "validation",
        "epochs": 30,
        "learning_rate": 1e-3,
    }
    train_only = {
        name: getattr(args, name, fallback)
        for name, fallback in train_only_fallbacks.items()
    }
    shared = {
        "csv_path": args.csv_path,
        "image_dir": args.image_dir,
        "output_dir": args.output_dir,
        "filename_column": args.filename_column,
        "label_column": args.label_column,
        "split_column": args.split_column,
        "class_names": args.classes,
        "architecture": args.architecture,
        "weights": args.weights,
        "image_size": args.image_size,
        "batch_size": args.batch_size,
        "seed": args.seed,
        "positive_class": args.positive_class,
        "target_layer": args.target_layer,
        "device": args.device,
    }
    return ExperimentConfig(**shared, **train_only)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def build_parser() -> argparse.ArgumentParser:
    """Build the ``visionkit`` parser with its three required subcommands.

    ``train``, ``evaluate`` and ``gradcam`` all receive the shared options
    from ``_base_parser`` followed by their own subcommand-specific ones.
    """
    root = argparse.ArgumentParser(prog="visionkit")
    commands = root.add_subparsers(dest="command", required=True)

    def new_command(name: str, extra_options) -> None:
        # Shared options first, then the subcommand's own, in table order.
        sub = commands.add_parser(name)
        _base_parser(sub)
        for flag, options in extra_options:
            sub.add_argument(flag, **options)

    new_command(
        "train",
        (
            ("--train-split", {"default": "train"}),
            ("--val-split", {"default": "validation"}),
            ("--epochs", {"type": int, "default": 30}),
            ("--learning-rate", {"type": float, "default": 1e-3}),
        ),
    )
    new_command(
        "evaluate",
        (
            ("--checkpoint", {"type": Path, "required": True}),
            ("--split", {"default": "validation"}),
            ("--output-csv", {"type": Path}),
        ),
    )
    new_command(
        "gradcam",
        (
            ("--checkpoint", {"type": Path, "required": True}),
            ("--split", {"default": "validation"}),
            ("--gradcam-output-dir", {"type": Path}),
            ("--output-csv", {"type": Path}),
            ("--target-class", {}),
        ),
    )
    return root
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def main(argv: Optional[list[str]] = None) -> None:
    """Command-line entry point: parse ``argv`` and run the chosen subcommand.

    ``train`` prints the last row of the returned training history;
    ``evaluate`` and ``gradcam`` print the returned metrics.
    """
    args = build_parser().parse_args(argv)
    config = _config(args)
    command = args.command

    if command == "train":
        _, history = run_training(config)
        print(history.tail(1).to_string(index=False))
        return

    if command == "evaluate":
        _, metrics = run_evaluation(
            config,
            args.checkpoint,
            split_name=args.split,
            output_csv=args.output_csv,
        )
        print(metrics)
        return

    if command == "gradcam":
        _, metrics = run_gradcam(
            config,
            args.checkpoint,
            split_name=args.split,
            output_dir=args.gradcam_output_dir,
            output_csv=args.output_csv,
            target_class=args.target_class,
        )
        print(metrics)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
if __name__ == "__main__":
|
|
111
|
+
main()
|
visionkit/config.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Optional, Sequence
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class ExperimentConfig:
    """Settings shared by the training, evaluation and Grad-CAM entry points.

    Only the three path fields are required. They may be given as strings or
    ``Path`` objects; they are converted to ``Path`` on construction so the
    directory helpers below can join onto ``output_dir``.
    """

    # Metadata CSV, the directory filenames are resolved against, and the
    # root under which all outputs are written.
    csv_path: Path
    image_dir: Path
    output_dir: Path
    # Names of the CSV columns holding the image filename, label and split.
    filename_column: str = "filename"
    label_column: str = "label"
    split_column: Optional[str] = "split"
    # Values of the split column selecting the training / validation rows.
    train_split: str = "train"
    val_split: str = "validation"
    # Explicit class order; when empty the classes are inferred from the CSV.
    class_names: Optional[Sequence[str]] = None
    architecture: str = "efficientnet_b3"
    weights: Optional[str] = "DEFAULT"
    image_size: tuple[int, int] = (300, 300)
    batch_size: int = 8
    epochs: int = 30
    learning_rate: float = 1e-3
    # None is treated as 0 DataLoader workers by the data module.
    num_workers: Optional[int] = None
    seed: int = 123
    positive_class: Optional[str] = None
    target_layer: Optional[str] = None
    # None lets the callers pick CUDA when available, otherwise CPU.
    device: Optional[str] = None
    checkpoint_metric: str = "val_auc"
    save_every_epoch: bool = True
    # Per-channel normalisation statistics applied after ToTensor.
    normalize_mean: tuple[float, float, float] = (0.485, 0.456, 0.406)
    normalize_std: tuple[float, float, float] = (0.229, 0.224, 0.225)
    extra_metadata_columns: Sequence[str] = field(default_factory=tuple)

    def __post_init__(self) -> None:
        """Coerce the path fields so string arguments behave like ``Path``."""
        self.csv_path = Path(self.csv_path)
        self.image_dir = Path(self.image_dir)
        self.output_dir = Path(self.output_dir)

    def model_dir(self) -> Path:
        """Return ``<output_dir>/models``."""
        return self.output_dir / "models"

    def log_dir(self) -> Path:
        """Return ``<output_dir>/logs``."""
        return self.output_dir / "logs"

    def prediction_dir(self) -> Path:
        """Return ``<output_dir>/predictions``."""
        return self.output_dir / "predictions"

    def gradcam_dir(self) -> Path:
        """Return ``<output_dir>/gradcam``."""
        return self.output_dir / "gradcam"

    def ensure_output_dirs(self) -> None:
        """Create the output root and its four sub-directories if missing."""
        for path in (
            self.output_dir,
            self.model_dir(),
            self.log_dir(),
            self.prediction_dir(),
            self.gradcam_dir(),
        ):
            path.mkdir(parents=True, exist_ok=True)
|
visionkit/data.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from functools import partial
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Optional, Sequence
|
|
6
|
+
|
|
7
|
+
import pandas as pd
|
|
8
|
+
import torch
|
|
9
|
+
from PIL import Image
|
|
10
|
+
from torch.utils.data import DataLoader, Dataset
|
|
11
|
+
from torchvision import transforms
|
|
12
|
+
|
|
13
|
+
from visionkit.config import ExperimentConfig
|
|
14
|
+
from visionkit.reproducibility import seed_worker
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class CSVImageDataset(Dataset):
    """Image dataset described by the rows of a metadata table.

    Every row names an image file relative to ``image_dir`` and, optionally,
    a label that is translated to an integer through ``class_to_idx``. All
    rows are validated eagerly in the constructor so that problems surface
    before any batch is loaded.
    """

    def __init__(
        self,
        rows: pd.DataFrame,
        image_dir: Path,
        filename_column: str = "filename",
        label_column: Optional[str] = "label",
        class_to_idx: Optional[dict[str, int]] = None,
        transform=None,
        return_metadata: bool = False,
    ) -> None:
        """Validate ``rows`` and build the list of loadable samples.

        Args:
            rows: Metadata table; it is copied and re-indexed, not modified.
            image_dir: Directory that the filename column is relative to.
            filename_column: Column holding the image filename.
            label_column: Column holding the label. When ``None`` or absent
                from the table, every sample gets a target of ``None``.
            class_to_idx: Mapping from label text to class index.
            transform: Optional callable applied to each loaded RGB image.
            return_metadata: When true, ``__getitem__`` also returns the
                row's record dict.

        Raises:
            FileNotFoundError: if any listed image does not exist (the
                message previews up to five filenames).
            ValueError: if any label is missing from ``class_to_idx``.
        """
        self.rows = rows.reset_index(drop=True).copy()
        self.image_dir = Path(image_dir)
        self.filename_column = filename_column
        self.label_column = label_column
        self.class_to_idx = class_to_idx or {}
        self.transform = transform
        self.return_metadata = return_metadata
        self.samples: list[tuple[Path, Optional[int], dict]] = []

        # Collect every problem first so one error reports them together.
        missing_files = []
        unknown_labels = []
        for record in self.rows.to_dict(orient="records"):
            filename = str(record[filename_column])
            image_path = self.image_dir / filename
            if not image_path.exists():
                missing_files.append(filename)
                continue

            target = None
            if label_column is not None and label_column in record:
                label = str(record[label_column])
                if label not in self.class_to_idx:
                    unknown_labels.append(label)
                    continue
                target = self.class_to_idx[label]
            self.samples.append((image_path, target, record))

        if missing_files:
            preview = ", ".join(missing_files[:5])
            raise FileNotFoundError(f"Images listed in CSV were not found: {preview}")
        if unknown_labels:
            labels = ", ".join(sorted(set(unknown_labels)))
            raise ValueError(f"Labels are missing from class mapping: {labels}")

    def __len__(self) -> int:
        """Return the number of validated samples."""
        return len(self.samples)

    def __getitem__(self, index: int):
        """Load sample ``index`` as ``(image, target)`` or ``(image, target, record)``.

        The image is converted to RGB and passed through ``transform`` when
        one was supplied.
        """
        image_path, target, record = self.samples[index]
        # convert() returns an independent image, so the underlying file can
        # be closed immediately instead of lingering until garbage collection.
        with Image.open(image_path) as handle:
            image = handle.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        if self.return_metadata:
            return image, target, record
        return image, target
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def read_metadata(csv_path: Path) -> pd.DataFrame:
    """Load the metadata CSV at ``csv_path`` with pandas' default settings."""
    frame = pd.read_csv(csv_path)
    return frame
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def infer_class_names(df: pd.DataFrame, label_column: str, class_names: Optional[Sequence[str]]) -> list[str]:
    """Return the class names as strings, inferring them when none are given.

    A non-empty ``class_names`` is returned in its given order. Otherwise the
    distinct non-null values of ``label_column`` are stringified and sorted.
    """
    if not class_names:
        observed = df[label_column].dropna().unique()
        return sorted(map(str, observed))
    return list(map(str, class_names))
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def class_mapping(class_names: Sequence[str]) -> tuple[dict[str, int], list[str]]:
    """Return ``(name -> index, names)`` for the given class order.

    Names are stringified first; each name's index is its position in the
    resulting list.
    """
    names = list(map(str, class_names))
    lookup: dict[str, int] = {}
    for position, name in enumerate(names):
        lookup[name] = position
    return lookup, names
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def filter_split(df: pd.DataFrame, split_column: Optional[str], split_name: Optional[str]) -> pd.DataFrame:
    """Return a copy of the rows belonging to ``split_name``.

    When no split column or no split name is supplied, the whole table is
    copied. The comparison is done on string forms, so numeric split values
    match their textual equivalent.

    Raises:
        KeyError: if ``split_column`` is set but not a column of ``df``.
    """
    if split_column and split_name is not None:
        if split_column not in df.columns:
            raise KeyError(f"Split column '{split_column}' was not found in CSV.")
        wanted = df[split_column].astype(str) == str(split_name)
        return df[wanted].copy()
    return df.copy()
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def build_transforms(
    image_size: tuple[int, int] = (300, 300),
    train: bool = False,
    mean: tuple[float, float, float] = (0.485, 0.456, 0.406),
    std: tuple[float, float, float] = (0.229, 0.224, 0.225),
):
    """Compose the preprocessing pipeline for training or evaluation.

    Both modes resize, convert to a tensor and normalise with ``mean`` and
    ``std``. Training additionally inserts random horizontal and vertical
    flips and a random rotation of up to 10 degrees right after the resize.
    """
    pipeline = [transforms.Resize(image_size)]
    if train:
        pipeline += [
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(10),
        ]
    pipeline.append(transforms.ToTensor())
    pipeline.append(transforms.Normalize(mean=mean, std=std))
    return transforms.Compose(pipeline)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def build_dataset(
    config: ExperimentConfig,
    split_name: Optional[str],
    transform,
    return_metadata: bool = False,
) -> tuple[CSVImageDataset, list[str]]:
    """Build the dataset for one split together with the class-name list.

    Class names are resolved from the whole CSV before the split filter is
    applied, so every split shares the same label-to-index mapping.
    """
    metadata = read_metadata(config.csv_path)
    resolved_names = infer_class_names(metadata, config.label_column, config.class_names)
    class_to_idx, class_names = class_mapping(resolved_names)
    split_rows = filter_split(metadata, config.split_column, split_name)
    return (
        CSVImageDataset(
            rows=split_rows,
            image_dir=config.image_dir,
            filename_column=config.filename_column,
            label_column=config.label_column,
            class_to_idx=class_to_idx,
            transform=transform,
            return_metadata=return_metadata,
        ),
        class_names,
    )
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def build_dataloaders(config: ExperimentConfig) -> tuple[DataLoader, DataLoader, list[str]]:
    """Create the training and validation loaders described by ``config``.

    Returns ``(train_loader, val_loader, class_names)``. The training loader
    shuffles with a generator seeded from ``config.seed``; the validation
    loader keeps row order. Both seed their workers through ``seed_worker``.
    """

    def make_transform(is_training: bool):
        return build_transforms(
            image_size=config.image_size,
            train=is_training,
            mean=config.normalize_mean,
            std=config.normalize_std,
        )

    train_dataset, class_names = build_dataset(config, config.train_split, make_transform(True))
    val_dataset, _ = build_dataset(config, config.val_split, make_transform(False))

    # Settings common to both loaders.
    shared = {
        "batch_size": config.batch_size,
        "num_workers": 0 if config.num_workers is None else config.num_workers,
        "pin_memory": torch.cuda.is_available(),
        "worker_init_fn": partial(seed_worker, base_seed=config.seed),
    }
    shuffle_generator = torch.Generator().manual_seed(config.seed)
    train_loader = DataLoader(train_dataset, shuffle=True, generator=shuffle_generator, **shared)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared)
    return train_loader, val_loader, class_names
|
visionkit/evaluate.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn.functional as F
|
|
9
|
+
from torch.utils.data import DataLoader
|
|
10
|
+
from tqdm.auto import tqdm
|
|
11
|
+
|
|
12
|
+
from visionkit.data import build_dataset, build_transforms
|
|
13
|
+
from visionkit.metrics import classification_metrics
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _record_at(records, index: int) -> dict:
    """Extract the metadata of sample ``index`` from a batch of records.

    ``records`` is either a sequence of per-sample mappings, or a single dict
    whose values are indexable per-sample containers. In the dict form,
    values exposing ``.item()`` are unwrapped to plain Python scalars when
    they hold exactly one element.
    """
    if not isinstance(records, dict):
        return dict(records[index])
    row = {}
    for key, column in records.items():
        item = column[index]
        if hasattr(item, "item"):
            try:
                item = item.item()
            except (ValueError, RuntimeError):
                # Multi-element values cannot become a scalar; depending on
                # the library and version this raises either exception type.
                # Keep the value as it is.
                pass
        row[key] = item
    return row
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def predict(
    model: torch.nn.Module,
    dataloader: DataLoader,
    class_names: list[str],
    device: Optional[str] = None,
    positive_index: Optional[int] = None,
) -> pd.DataFrame:
    """Run ``model`` over ``dataloader`` and return one row per sample.

    The loader must yield ``(images, labels, records)`` batches. Each output
    row holds the sample's metadata record plus the true and predicted class
    (index and name), the probability of the positive class, and one
    ``probability_<class>`` column per class.

    Args:
        device: Defaults to CUDA when available, otherwise CPU.
        positive_index: Class index reported as the positive class; defaults
            to ``min(1, len(class_names) - 1)``.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    positive_index = positive_index if positive_index is not None else min(1, len(class_names) - 1)
    model.to(device)
    model.eval()
    rows = []

    with torch.no_grad():
        for images, labels, records in tqdm(dataloader, desc="predict"):
            images = images.to(device)
            outputs = model(images)
            probabilities = F.softmax(outputs, dim=1)
            predicted = probabilities.argmax(dim=1).cpu().tolist()
            # One transfer per batch instead of one .item() call per sample
            # and class inside the loops below.
            per_sample = probabilities.cpu().tolist()

            for idx in range(images.size(0)):
                record = _record_at(records, idx)
                true_label_idx = labels[idx].item() if labels[idx] is not None else None
                sample_probabilities = per_sample[idx]
                row = dict(record)
                row.update(
                    {
                        "true_label_num": true_label_idx,
                        "true_label": class_names[true_label_idx] if true_label_idx is not None else None,
                        "predicted_class_num": predicted[idx],
                        "predicted_class": class_names[predicted[idx]],
                        "positive_class_probability": sample_probabilities[positive_index],
                    }
                )
                for class_idx, class_name in enumerate(class_names):
                    row[f"probability_{class_name}"] = sample_probabilities[class_idx]
                rows.append(row)
    return pd.DataFrame(rows)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def evaluate_dataframe(
    predictions: pd.DataFrame,
    positive_index: int = 1,
    n_bootstraps: int = 1000,
    seed: int = 123,
) -> dict[str, float]:
    """Compute classification metrics from a predictions table.

    Reads the ``true_label_num``, ``predicted_class_num`` and
    ``positive_class_probability`` columns. Rows without a true label are
    ignored. An empty dict is returned when the table is empty or fewer than
    two distinct true classes remain.
    """
    if predictions.empty:
        return {}
    # Unlabelled rows carry a missing true label; casting them to int would
    # raise, so score only the labelled rows.
    labelled = predictions[predictions["true_label_num"].notna()].reset_index(drop=True)
    if labelled["true_label_num"].nunique() < 2:
        return {}
    return classification_metrics(
        labelled["true_label_num"].astype(int),
        labelled["predicted_class_num"].astype(int),
        labelled["positive_class_probability"],
        positive_index=positive_index,
        n_bootstraps=n_bootstraps,
        seed=seed,
    )
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def evaluate_model(
    model: torch.nn.Module,
    config,
    split_name: Optional[str],
    class_names: list[str],
    output_csv: Path,
    positive_index: int = 1,
) -> tuple[pd.DataFrame, dict[str, float]]:
    """Predict on one split, write the predictions CSV and compute metrics.

    Predictions are always written to ``output_csv``. When metrics could be
    computed they are also written next to it as ``<stem>_metrics.csv``.
    Returns ``(predictions, metrics)``.
    """
    eval_transform = build_transforms(
        image_size=config.image_size,
        train=False,
        mean=config.normalize_mean,
        std=config.normalize_std,
    )
    dataset = build_dataset(config, split_name, eval_transform, return_metadata=True)[0]
    loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=False, num_workers=0)

    predictions = predict(
        model,
        loader,
        class_names,
        device=config.device,
        positive_index=positive_index,
    )
    output_csv.parent.mkdir(parents=True, exist_ok=True)
    predictions.to_csv(output_csv, index=False)

    metrics = evaluate_dataframe(predictions, positive_index=positive_index, seed=config.seed)
    if metrics:
        metrics_csv = output_csv.with_name(f"{output_csv.stem}_metrics.csv")
        pd.DataFrame([metrics]).to_csv(metrics_csv, index=False)
    return predictions, metrics
|
visionkit/gradcam.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Optional, Sequence, Union
|
|
5
|
+
|
|
6
|
+
import matplotlib.pyplot as plt
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn.functional as F
|
|
11
|
+
from PIL import Image
|
|
12
|
+
from tqdm.auto import tqdm
|
|
13
|
+
|
|
14
|
+
from visionkit.data import build_transforms, class_mapping, filter_split, infer_class_names, read_metadata
|
|
15
|
+
from visionkit.evaluate import evaluate_dataframe
|
|
16
|
+
from visionkit.models import get_module_by_path
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class GradCAM:
    """Grad-CAM heatmap generator bound to one layer of a model.

    Hooks on ``target_layer`` record its output during the forward pass and
    the gradient with respect to that output during the backward pass. Call
    ``close()`` when finished to remove the hooks, or use the instance as a
    context manager, which does so on exit.
    """

    def __init__(self, model: torch.nn.Module, target_layer: torch.nn.Module) -> None:
        """Store ``model`` and attach the capture hooks to ``target_layer``."""
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.handles = []
        self._register_hooks()

    def __enter__(self) -> "GradCAM":
        """Return the instance so it can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Remove the hooks; exceptions from the block are not suppressed."""
        self.close()

    def _register_hooks(self) -> None:
        """Attach the forward and full-backward hooks and keep their handles."""

        def forward_hook(module, inputs, output):
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()

        self.handles.append(self.target_layer.register_forward_hook(forward_hook))
        self.handles.append(self.target_layer.register_full_backward_hook(backward_hook))

    def generate(self, input_tensor: torch.Tensor, target_class: Optional[int] = None) -> np.ndarray:
        """Return the Grad-CAM map for the first sample of ``input_tensor``.

        Args:
            input_tensor: Model input; only sample 0 is explained, and the
                default class selection requires a batch of exactly one.
            target_class: Class index to explain; defaults to the model's
                predicted class.

        Returns:
            A 2-D array matching the input's spatial size, min-max scaled to
            ``[0, 1]`` (all zeros when the map is constant).

        Raises:
            RuntimeError: if the hooks captured nothing during this call.
        """
        # Drop anything captured by an earlier call so the check below
        # reflects this call only.
        self.gradients = None
        self.activations = None

        output = self.model(input_tensor)
        if target_class is None:
            target_class = int(output.argmax(dim=1).item())
        self.model.zero_grad()
        output[0, target_class].backward()

        if self.gradients is None or self.activations is None:
            raise RuntimeError("Grad-CAM hooks did not capture gradients or activations.")

        gradients = self.gradients[0]
        activations = self.activations[0]
        # Channel weights are the spatial mean of the gradients.
        weights = gradients.mean(dim=(1, 2), keepdim=True)
        cam = torch.sum(weights * activations, dim=0)
        cam = F.relu(cam)
        cam = F.interpolate(
            cam.unsqueeze(0).unsqueeze(0),
            size=input_tensor.shape[2:],
            mode="bilinear",
            align_corners=False,
        )
        cam_np = cam.squeeze().detach().cpu().numpy()
        denominator = cam_np.max() - cam_np.min()
        if denominator <= 1e-12:
            # A flat map cannot be rescaled; report it as all zeros.
            return np.zeros_like(cam_np)
        return (cam_np - cam_np.min()) / denominator

    def close(self) -> None:
        """Remove all registered hooks; safe to call more than once."""
        for handle in self.handles:
            handle.remove()
        self.handles.clear()
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _denormalize(image_tensor: torch.Tensor, mean: Sequence[float], std: Sequence[float]) -> np.ndarray:
    """Undo per-channel normalisation and return a channels-last array.

    Works on a detached CPU copy, so ``image_tensor`` is left untouched.
    Values are clamped to ``[0, 1]`` after the statistics are reversed.
    """
    restored = image_tensor.detach().cpu().clone()
    for channel_index, channel_stats in enumerate(zip(mean, std)):
        channel_mean, channel_std = channel_stats
        restored[channel_index] = restored[channel_index] * channel_std + channel_mean
    return restored.clamp(0, 1).permute(1, 2, 0).numpy()
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def save_gradcam_figure(
    original_image: np.ndarray,
    cam: np.ndarray,
    save_path: Path,
    title: str,
    alpha: float = 0.4,
) -> None:
    """Save a three-panel figure: image, heatmap, and heatmap over image.

    Args:
        original_image: Image shown in the first and third panels.
        cam: Heatmap drawn with the ``jet`` colormap.
        save_path: Output file; parent directories are created as needed.
        title: Title of the first panel.
        alpha: Opacity of the heatmap in the overlay panel.
    """
    save_path.parent.mkdir(parents=True, exist_ok=True)
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        axes[0].imshow(original_image)
        axes[0].set_title(title)
        axes[1].imshow(cam, cmap="jet")
        axes[1].set_title("Grad-CAM")
        axes[2].imshow(original_image)
        axes[2].imshow(cam, cmap="jet", alpha=alpha)
        axes[2].set_title("Overlay")
        for axis in axes:
            axis.axis("off")
        fig.tight_layout()
        fig.savefig(save_path, dpi=150)
    finally:
        # Release the figure even when drawing or saving fails; this runs
        # once per image, so leaked figures would otherwise accumulate.
        plt.close(fig)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def generate_gradcam_report(
    model: torch.nn.Module,
    csv_path: Path,
    image_dir: Path,
    output_dir: Path,
    output_csv: Path,
    filename_column: str = "filename",
    label_column: str = "label",
    split_column: Optional[str] = "split",
    split_name: Optional[str] = None,
    class_names: Optional[Sequence[str]] = None,
    target_class: Optional[Union[str, int]] = None,
    target_layer: Optional[str] = None,
    image_size: tuple[int, int] = (300, 300),
    mean: tuple[float, float, float] = (0.485, 0.456, 0.406),
    std: tuple[float, float, float] = (0.229, 0.224, 0.225),
    device: Optional[str] = None,
) -> tuple[pd.DataFrame, dict[str, float]]:
    """Write a Grad-CAM figure per image plus a CSV of the predictions.

    For every row of the selected split whose image exists, a figure is saved
    under ``output_dir`` and a result row is collected. Rows whose image file
    is missing are skipped silently. The collected rows are written to
    ``output_csv``; when metrics can be computed they are written next to it
    as ``<stem>_metrics.csv``.

    Args:
        target_class: Class to explain, as an index (``int``) or a class
            name. Defaults to the positive class, ``min(1, n_classes - 1)``.
        target_layer: Path of the layer to hook, resolved by
            ``get_module_by_path``.
        device: Defaults to CUDA when available, otherwise CPU.

    Returns:
        ``(results, metrics)`` where ``metrics`` may be empty.

    Raises:
        KeyError: if ``target_class`` is a name missing from the class list.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Resolve classes from the whole CSV before filtering, exactly as the
    # dataset builder does; a split that lacks a class would otherwise shift
    # the label indices away from the ones the model was trained with.
    metadata = read_metadata(csv_path)
    class_names = infer_class_names(metadata, label_column, class_names)
    class_to_idx, class_names = class_mapping(class_names)
    df = filter_split(metadata, split_column, split_name)

    positive_index = min(1, len(class_names) - 1)
    if target_class is None:
        target_index = positive_index
    elif isinstance(target_class, int):
        target_index = target_class
    else:
        target_index = class_to_idx[str(target_class)]

    transform = build_transforms(image_size=image_size, train=False, mean=mean, std=std)
    target_module = get_module_by_path(model, target_layer)
    gradcam = GradCAM(model, target_module)
    output_dir = Path(output_dir)
    output_csv = Path(output_csv)
    results = []

    try:
        for record in tqdm(df.to_dict(orient="records"), desc="gradcam"):
            filename = str(record[filename_column])
            image_path = Path(image_dir) / filename
            if not image_path.exists():
                continue
            # convert() yields an independent image, so the file handle can
            # be released straight away.
            with Image.open(image_path) as handle:
                image = handle.convert("RGB")
            image_tensor = transform(image).unsqueeze(0).to(device)
            cam = gradcam.generate(image_tensor, target_class=target_index)

            # Second, gradient-free pass for the reported probabilities.
            with torch.no_grad():
                outputs = model(image_tensor)
                probabilities = F.softmax(outputs, dim=1)[0].detach().cpu()
                predicted_index = int(probabilities.argmax().item())

            true_label = str(record[label_column]) if label_column in record else None
            true_index = class_to_idx[true_label] if true_label in class_to_idx else None
            original = _denormalize(image_tensor[0], mean, std)
            pred_label = class_names[predicted_index]
            save_name = f"{Path(filename).stem}__pred-{pred_label}.png"
            save_gradcam_figure(
                original,
                cam,
                output_dir / save_name,
                title=f"{filename}\ntrue={true_label} pred={pred_label}",
            )

            result = dict(record)
            result.update(
                {
                    "gradcam_file": save_name,
                    "target_class_num": target_index,
                    "target_class": class_names[target_index],
                    "true_label_num": true_index,
                    "true_label": true_label,
                    "predicted_class_num": predicted_index,
                    "predicted_class": pred_label,
                    "positive_class_probability": float(probabilities[positive_index].item()),
                }
            )
            for class_idx, class_name in enumerate(class_names):
                result[f"probability_{class_name}"] = float(probabilities[class_idx].item())
            results.append(result)
    finally:
        # Always detach the hooks from the caller's model.
        gradcam.close()

    results_df = pd.DataFrame(results)
    output_csv.parent.mkdir(parents=True, exist_ok=True)
    results_df.to_csv(output_csv, index=False)
    metrics = evaluate_dataframe(results_df, positive_index=positive_index)
    if metrics:
        pd.DataFrame([metrics]).to_csv(output_csv.with_name(output_csv.stem + "_metrics.csv"), index=False)
    return results_df, metrics
|