visionkit-pro 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
visionkit/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ """Reusable image-classification training, evaluation, and Grad-CAM tools."""
2
+
3
+ from visionkit.config import ExperimentConfig
4
+ from visionkit.data import CSVImageDataset, build_dataloaders, build_transforms
5
+ from visionkit.gradcam import GradCAM, generate_gradcam_report
6
+ from visionkit.models import build_model, load_checkpoint
7
+ from visionkit.pipeline import run_evaluation, run_gradcam, run_training
8
+ from visionkit.train import train_model
9
+
10
# Names re-exported as the package's public API (ASCII-sorted).
__all__ = [
    "CSVImageDataset", "ExperimentConfig", "GradCAM",
    "build_dataloaders", "build_model", "build_transforms",
    "generate_gradcam_report", "load_checkpoint",
    "run_evaluation", "run_gradcam", "run_training",
    "train_model",
]
visionkit/cli.py ADDED
@@ -0,0 +1,111 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ from visionkit.config import ExperimentConfig
8
+ from visionkit.pipeline import run_evaluation, run_gradcam, run_training
9
+
10
+
11
def _image_size(value: str) -> tuple[int, int]:
    """Parse an image-size argument such as ``300x300`` or ``300,300``.

    Args:
        value: Two positive integers separated by ``x`` (any case) or ``,``.

    Returns:
        The two integers, in the order they were written.

    Raises:
        argparse.ArgumentTypeError: If there are not exactly two parts, a part
            is not an integer, or a dimension is not strictly positive.
    """
    parts = [part.strip() for part in value.lower().replace("x", ",").split(",")]
    if len(parts) != 2:
        raise argparse.ArgumentTypeError("image size must look like 300x300")
    try:
        first, second = int(parts[0]), int(parts[1])
    except ValueError:
        # Report malformed numbers with the same hint as a malformed shape
        # instead of argparse's generic "invalid _image_size value".
        raise argparse.ArgumentTypeError("image size must look like 300x300") from None
    if first <= 0 or second <= 0:
        raise argparse.ArgumentTypeError("image size dimensions must be positive integers")
    return first, second
16
+
17
+
18
def _base_parser(parser: argparse.ArgumentParser) -> None:
    """Register the options shared by every sub-command on *parser*.

    The options are declared as a table and added in order, so the help output
    lists them in the sequence written here.
    """
    shared_options = (
        # Required input/output locations.
        ("--csv", {"dest": "csv_path", "type": Path, "required": True}),
        ("--image-dir", {"type": Path, "required": True}),
        ("--output-dir", {"type": Path, "required": True}),
        # CSV column names.
        ("--filename-column", {"default": "filename"}),
        ("--label-column", {"default": "label"}),
        ("--split-column", {"default": "split"}),
        ("--classes", {"nargs": "*"}),
        # Model and data-loading settings.
        ("--architecture", {"default": "efficientnet_b3"}),
        ("--weights", {"default": "DEFAULT"}),
        ("--image-size", {"type": _image_size, "default": (300, 300)}),
        ("--batch-size", {"type": int, "default": 8}),
        ("--seed", {"type": int, "default": 123}),
        # Optional free-form settings; each is None when omitted.
        ("--positive-class", {}),
        ("--target-layer", {}),
        ("--device", {}),
    )
    for flag, options in shared_options:
        parser.add_argument(flag, **options)
34
+
35
+
36
def _config(args) -> ExperimentConfig:
    """Translate a parsed argparse namespace into an ``ExperimentConfig``.

    Options that only the ``train`` sub-command defines are read with
    ``getattr`` so that namespaces from the other sub-commands fall back to
    the same defaults the ``train`` parser declares.
    """
    train_only = {
        "train_split": getattr(args, "train_split", "train"),
        "val_split": getattr(args, "val_split", "validation"),
        "epochs": getattr(args, "epochs", 30),
        "learning_rate": getattr(args, "learning_rate", 1e-3),
    }
    shared = {
        "csv_path": args.csv_path,
        "image_dir": args.image_dir,
        "output_dir": args.output_dir,
        "filename_column": args.filename_column,
        "label_column": args.label_column,
        "split_column": args.split_column,
        "class_names": args.classes,
        "architecture": args.architecture,
        "weights": args.weights,
        "image_size": args.image_size,
        "batch_size": args.batch_size,
        "seed": args.seed,
        "positive_class": args.positive_class,
        "target_layer": args.target_layer,
        "device": args.device,
    }
    return ExperimentConfig(**shared, **train_only)
58
+
59
+
60
def build_parser() -> argparse.ArgumentParser:
    """Create the ``visionkit`` parser with its three sub-commands.

    ``train``, ``evaluate`` and ``gradcam`` each receive the shared options
    from ``_base_parser`` first and then their own; choosing a sub-command is
    mandatory and its name is stored in ``command``.
    """
    root = argparse.ArgumentParser(prog="visionkit")
    commands = root.add_subparsers(dest="command", required=True)

    train_options = (
        ("--train-split", {"default": "train"}),
        ("--val-split", {"default": "validation"}),
        ("--epochs", {"type": int, "default": 30}),
        ("--learning-rate", {"type": float, "default": 1e-3}),
    )
    evaluate_options = (
        ("--checkpoint", {"type": Path, "required": True}),
        ("--split", {"default": "validation"}),
        ("--output-csv", {"type": Path}),
    )
    gradcam_options = (
        ("--checkpoint", {"type": Path, "required": True}),
        ("--split", {"default": "validation"}),
        ("--gradcam-output-dir", {"type": Path}),
        ("--output-csv", {"type": Path}),
        ("--target-class", {}),
    )

    per_command = (
        ("train", train_options),
        ("evaluate", evaluate_options),
        ("gradcam", gradcam_options),
    )
    for command_name, options in per_command:
        sub = commands.add_parser(command_name)
        _base_parser(sub)
        for flag, kwargs in options:
            sub.add_argument(flag, **kwargs)
    return root
85
+
86
+
87
def main(argv: Optional[list[str]] = None) -> None:
    """Command-line entry point: parse *argv*, run the chosen command, print a summary.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read ``sys.argv``.
    """
    args = build_parser().parse_args(argv)
    config = _config(args)
    command = args.command

    if command == "train":
        _, history = run_training(config)
        # Only the final row of the training history is echoed.
        last_row = history.tail(1)
        print(last_row.to_string(index=False))
        return

    if command == "evaluate":
        _, metrics = run_evaluation(
            config,
            args.checkpoint,
            split_name=args.split,
            output_csv=args.output_csv,
        )
        print(metrics)
        return

    if command == "gradcam":
        gradcam_kwargs = {
            "split_name": args.split,
            "output_dir": args.gradcam_output_dir,
            "output_csv": args.output_csv,
            "target_class": args.target_class,
        }
        _, metrics = run_gradcam(config, args.checkpoint, **gradcam_kwargs)
        print(metrics)


# Allow the module to be executed directly as a script.
if __name__ == "__main__":
    main()
visionkit/config.py ADDED
@@ -0,0 +1,56 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from pathlib import Path
5
+ from typing import Optional, Sequence
6
+
7
+
8
@dataclass
class ExperimentConfig:
    """Settings shared by the training, evaluation and Grad-CAM entry points.

    ``csv_path``, ``image_dir`` and ``output_dir`` may be given as ``Path``
    objects or as plain strings; they are normalised to ``Path`` on
    construction so the directory helpers below can join onto them.
    """

    # Metadata CSV, directory the CSV filenames are relative to, and the root
    # under which models/logs/predictions/gradcam sub-directories live.
    csv_path: Path
    image_dir: Path
    output_dir: Path
    # Column names looked up in the metadata CSV.
    filename_column: str = "filename"
    label_column: str = "label"
    split_column: Optional[str] = "split"
    # Values of ``split_column`` selecting the training / validation rows.
    train_split: str = "train"
    val_split: str = "validation"
    # Explicit class order; when empty or None the classes are inferred from
    # the label column.
    class_names: Optional[Sequence[str]] = None
    # NOTE(review): architecture/weights are presumably consumed by the model
    # builder, which is not part of this module -- confirm there.
    architecture: str = "efficientnet_b3"
    weights: Optional[str] = "DEFAULT"
    # Size handed to the resize transform, and the DataLoader batch size.
    image_size: tuple[int, int] = (300, 300)
    batch_size: int = 8
    # NOTE(review): epochs/learning_rate are presumably read by the training
    # loop -- confirm there.
    epochs: int = 30
    learning_rate: float = 1e-3
    # None is treated as 0 worker processes when the data loaders are built.
    num_workers: Optional[int] = None
    seed: int = 123
    positive_class: Optional[str] = None
    target_layer: Optional[str] = None
    # None lets the prediction code pick "cuda" when available, else "cpu".
    device: Optional[str] = None
    # NOTE(review): checkpointing options; their consumer is outside this module.
    checkpoint_metric: str = "val_auc"
    save_every_epoch: bool = True
    # Per-channel statistics passed to the normalisation transform.
    normalize_mean: tuple[float, float, float] = (0.485, 0.456, 0.406)
    normalize_std: tuple[float, float, float] = (0.229, 0.224, 0.225)
    extra_metadata_columns: Sequence[str] = field(default_factory=tuple)

    def __post_init__(self) -> None:
        """Coerce the three location fields to ``Path``."""
        # Without this, a str ``output_dir`` makes ``output_dir / "models"``
        # raise TypeError in every helper below.
        self.csv_path = Path(self.csv_path)
        self.image_dir = Path(self.image_dir)
        self.output_dir = Path(self.output_dir)

    def model_dir(self) -> Path:
        """Return ``<output_dir>/models``."""
        return self.output_dir / "models"

    def log_dir(self) -> Path:
        """Return ``<output_dir>/logs``."""
        return self.output_dir / "logs"

    def prediction_dir(self) -> Path:
        """Return ``<output_dir>/predictions``."""
        return self.output_dir / "predictions"

    def gradcam_dir(self) -> Path:
        """Return ``<output_dir>/gradcam``."""
        return self.output_dir / "gradcam"

    def ensure_output_dirs(self) -> None:
        """Create the output root and all four sub-directories if missing."""
        for path in (
            self.output_dir,
            self.model_dir(),
            self.log_dir(),
            self.prediction_dir(),
            self.gradcam_dir(),
        ):
            path.mkdir(parents=True, exist_ok=True)
visionkit/data.py ADDED
@@ -0,0 +1,175 @@
1
+ from __future__ import annotations
2
+
3
+ from functools import partial
4
+ from pathlib import Path
5
+ from typing import Optional, Sequence
6
+
7
+ import pandas as pd
8
+ import torch
9
+ from PIL import Image
10
+ from torch.utils.data import DataLoader, Dataset
11
+ from torchvision import transforms
12
+
13
+ from visionkit.config import ExperimentConfig
14
+ from visionkit.reproducibility import seed_worker
15
+
16
+
17
class CSVImageDataset(Dataset):
    """Image dataset driven by the rows of a metadata table.

    Every row names an image file relative to ``image_dir`` and, optionally, a
    label that is mapped to an integer through ``class_to_idx``. All rows are
    validated eagerly at construction time.
    """

    def __init__(
        self,
        rows: pd.DataFrame,
        image_dir: Path,
        filename_column: str = "filename",
        label_column: Optional[str] = "label",
        class_to_idx: Optional[dict[str, int]] = None,
        transform=None,
        return_metadata: bool = False,
    ) -> None:
        """Index the rows and check that every image and label is usable.

        Args:
            rows: Metadata table; a re-indexed copy is stored on the instance.
            image_dir: Directory that the filenames are relative to.
            filename_column: Column holding the image filename.
            label_column: Column holding the label, or ``None`` for unlabeled
                data (targets are then ``None``).
            class_to_idx: Mapping from label string to class index.
            transform: Optional callable applied to each loaded RGB image.
            return_metadata: When true, ``__getitem__`` also returns the row
                as a dict.

        Raises:
            FileNotFoundError: If any listed image is missing (the message
                previews up to five filenames).
            ValueError: If any label is absent from ``class_to_idx``.
        """
        self.rows = rows.reset_index(drop=True).copy()
        self.image_dir = Path(image_dir)
        self.filename_column = filename_column
        self.label_column = label_column
        self.class_to_idx = class_to_idx or {}
        self.transform = transform
        self.return_metadata = return_metadata
        self.samples: list[tuple[Path, Optional[int], dict]] = []

        # Collect every problem first so one error reports them together.
        missing_files = []
        unknown_labels = []
        for record in self.rows.to_dict(orient="records"):
            filename = str(record[filename_column])
            image_path = self.image_dir / filename
            if not image_path.exists():
                missing_files.append(filename)
                continue

            target = None
            if label_column is not None and label_column in record:
                label = str(record[label_column])
                if label not in self.class_to_idx:
                    unknown_labels.append(label)
                    continue
                target = self.class_to_idx[label]
            self.samples.append((image_path, target, record))

        if missing_files:
            preview = ", ".join(missing_files[:5])
            raise FileNotFoundError(f"Images listed in CSV were not found: {preview}")
        if unknown_labels:
            labels = ", ".join(sorted(set(unknown_labels)))
            raise ValueError(f"Labels are missing from class mapping: {labels}")

    def __len__(self) -> int:
        """Return the number of validated samples."""
        return len(self.samples)

    def __getitem__(self, index: int):
        """Load sample *index* as ``(image, target)`` or ``(image, target, record)``."""
        image_path, target, record = self.samples[index]
        # ``convert`` returns an independent image, so the source file can be
        # closed deterministically instead of waiting for garbage collection.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        if self.return_metadata:
            return image, target, record
        return image, target
73
+
74
+
75
def read_metadata(csv_path: Path) -> pd.DataFrame:
    """Load the metadata CSV at *csv_path* with pandas' default parsing options."""
    frame = pd.read_csv(csv_path)
    return frame
77
+
78
+
79
def infer_class_names(df: pd.DataFrame, label_column: str, class_names: Optional[Sequence[str]]) -> list[str]:
    """Return the class names to use, as strings.

    A non-empty *class_names* is returned in the given order. Otherwise the
    distinct non-null values of ``df[label_column]`` are stringified and
    sorted as strings.
    """
    if not class_names:
        observed = df[label_column].dropna().unique()
        inferred = [str(label) for label in observed]
        inferred.sort()
        return inferred
    return list(map(str, class_names))
83
+
84
+
85
def class_mapping(class_names: Sequence[str]) -> tuple[dict[str, int], list[str]]:
    """Build the name-to-index lookup for *class_names*.

    Returns:
        ``(lookup, names)``, where ``names`` is the stringified input and
        ``lookup`` maps each name to its position in ``names``. If a name is
        repeated, its last position wins.
    """
    names = list(map(str, class_names))
    lookup = {}
    for position, name in enumerate(names):
        lookup[name] = position
    return lookup, names
88
+
89
+
90
def filter_split(df: pd.DataFrame, split_column: Optional[str], split_name: Optional[str]) -> pd.DataFrame:
    """Return a copy of the rows of *df* that belong to *split_name*.

    With no split column, or with ``split_name`` of ``None``, every row is
    kept. Split values are compared as strings, and the original index labels
    are preserved.

    Raises:
        KeyError: If a split is requested but ``split_column`` is not a column.
    """
    keep_everything = (not split_column) or split_name is None
    if keep_everything:
        return df.copy()
    if split_column in df.columns:
        selected = df[split_column].astype(str) == str(split_name)
        return df[selected].copy()
    raise KeyError(f"Split column '{split_column}' was not found in CSV.")
96
+
97
+
98
def build_transforms(
    image_size: tuple[int, int] = (300, 300),
    train: bool = False,
    mean: tuple[float, float, float] = (0.485, 0.456, 0.406),
    std: tuple[float, float, float] = (0.229, 0.224, 0.225),
):
    """Compose the preprocessing pipeline for training or evaluation.

    The pipeline is resize -> (training-only random flips and a rotation of
    up to 10 degrees) -> tensor conversion -> per-channel normalisation.

    Args:
        image_size: Size passed to the resize transform.
        train: Whether to include the random augmentations.
        mean: Per-channel mean for normalisation.
        std: Per-channel standard deviation for normalisation.
    """
    augmentations = []
    if train:
        augmentations = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(10),
        ]
    pipeline = [
        transforms.Resize(image_size),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
    return transforms.Compose(pipeline)
115
+
116
+
117
def build_dataset(
    config: ExperimentConfig,
    split_name: Optional[str],
    transform,
    return_metadata: bool = False,
) -> tuple[CSVImageDataset, list[str]]:
    """Create the dataset for one split and return it with the class names.

    Class names come from ``config.class_names`` or, failing that, from the
    label column of the whole CSV, before the split filter is applied.
    """
    metadata = read_metadata(config.csv_path)
    resolved_names = infer_class_names(metadata, config.label_column, config.class_names)
    class_to_idx, class_names = class_mapping(resolved_names)
    split_rows = filter_split(metadata, config.split_column, split_name)
    dataset_kwargs = {
        "rows": split_rows,
        "image_dir": config.image_dir,
        "filename_column": config.filename_column,
        "label_column": config.label_column,
        "class_to_idx": class_to_idx,
        "transform": transform,
        "return_metadata": return_metadata,
    }
    return CSVImageDataset(**dataset_kwargs), class_names
137
+
138
+
139
def build_dataloaders(config: ExperimentConfig) -> tuple[DataLoader, DataLoader, list[str]]:
    """Build the training and validation loaders described by *config*.

    Returns:
        ``(train_loader, val_loader, class_names)``. The training loader is
        shuffled using a generator seeded with ``config.seed``; the validation
        loader keeps dataset order. Both share batch size, worker count,
        memory pinning and the seeded worker initialiser.
    """
    # One pipeline per mode: True -> with augmentations, False -> without.
    transform_for = {
        is_train: build_transforms(
            image_size=config.image_size,
            train=is_train,
            mean=config.normalize_mean,
            std=config.normalize_std,
        )
        for is_train in (True, False)
    }
    train_dataset, class_names = build_dataset(config, config.train_split, transform_for[True])
    val_dataset, _ = build_dataset(config, config.val_split, transform_for[False])

    shuffle_generator = torch.Generator().manual_seed(config.seed)
    worker_count = 0 if config.num_workers is None else config.num_workers
    shared_kwargs = {
        "batch_size": config.batch_size,
        "num_workers": worker_count,
        "pin_memory": torch.cuda.is_available(),
        "worker_init_fn": partial(seed_worker, base_seed=config.seed),
    }
    train_loader = DataLoader(train_dataset, shuffle=True, generator=shuffle_generator, **shared_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_kwargs)
    return train_loader, val_loader, class_names
visionkit/evaluate.py ADDED
@@ -0,0 +1,111 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Optional
5
+
6
+ import pandas as pd
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch.utils.data import DataLoader
10
+ from tqdm.auto import tqdm
11
+
12
+ from visionkit.data import build_dataset, build_transforms
13
+ from visionkit.metrics import classification_metrics
14
+
15
+
16
def _record_at(records, index: int) -> dict:
    """Extract the metadata of sample *index* from a batch of records.

    Args:
        records: Either a dict mapping each column name to an indexable
            per-sample container, or a sequence of per-sample mappings.
            (Presumably whatever the DataLoader's collation produced for the
            metadata dicts -- confirm against the loader in use.)
        index: Position of the sample within the batch.

    Returns:
        A new dict for that sample. Values exposing ``.item()`` are unwrapped
        to plain Python scalars when they hold exactly one element.
    """
    if not isinstance(records, dict):
        return dict(records[index])

    row = {}
    for key, values in records.items():
        item = values[index]
        if hasattr(item, "item"):
            try:
                item = item.item()
            except (ValueError, RuntimeError):
                # Multi-element containers cannot be unwrapped: NumPy raises
                # ValueError and recent PyTorch raises RuntimeError. Keep the
                # original value in either case.
                pass
        row[key] = item
    return row
29
+
30
+
31
def predict(
    model: torch.nn.Module,
    dataloader: DataLoader,
    class_names: list[str],
    device: Optional[str] = None,
    positive_index: Optional[int] = None,
) -> pd.DataFrame:
    """Run *model* over *dataloader* and tabulate per-sample predictions.

    The loader must yield ``(images, labels, records)`` batches. Each output
    row contains the sample's metadata, true and predicted class (index and
    name), the softmax probability of the positive class, and one
    ``probability_<class>`` column per class.

    Args:
        model: Network to evaluate; it is moved to *device* and put in eval mode.
        dataloader: Source of batches.
        class_names: Class names indexed by class number.
        device: Target device; defaults to "cuda" when available, else "cpu".
        positive_index: Class index treated as positive; defaults to 1, or 0
            when there is a single class.
    """
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if positive_index is None:
        positive_index = min(1, len(class_names) - 1)
    model.to(device)
    model.eval()

    collected = []
    with torch.no_grad():
        for images, labels, records in tqdm(dataloader, desc="predict"):
            images = images.to(device)
            probabilities = F.softmax(model(images), dim=1)
            predicted = probabilities.argmax(dim=1).cpu().tolist()
            scores = probabilities[:, positive_index].cpu().tolist()
            # One host transfer for the whole batch instead of one per cell.
            probability_table = probabilities.cpu().tolist()

            for idx in range(images.size(0)):
                label_entry = labels[idx]
                true_label_idx = label_entry.item() if label_entry is not None else None
                true_label = class_names[true_label_idx] if true_label_idx is not None else None
                predicted_idx = predicted[idx]
                row = {
                    **_record_at(records, idx),
                    "true_label_num": true_label_idx,
                    "true_label": true_label,
                    "predicted_class_num": predicted_idx,
                    "predicted_class": class_names[predicted_idx],
                    "positive_class_probability": scores[idx],
                }
                for class_idx, class_name in enumerate(class_names):
                    row[f"probability_{class_name}"] = probability_table[idx][class_idx]
                collected.append(row)
    return pd.DataFrame(collected)
69
+
70
+
71
def evaluate_dataframe(
    predictions: pd.DataFrame,
    positive_index: int = 1,
    n_bootstraps: int = 1000,
    seed: int = 123,
) -> dict[str, float]:
    """Compute classification metrics from a predictions table.

    Args:
        predictions: Table with ``true_label_num``, ``predicted_class_num``
            and ``positive_class_probability`` columns.
        positive_index: Forwarded to ``classification_metrics``.
        n_bootstraps: Forwarded to ``classification_metrics``.
        seed: Forwarded to ``classification_metrics``.

    Returns:
        The metrics dict, or an empty dict when the table is empty or fewer
        than two distinct true labels are present.
    """
    if predictions.empty:
        return {}
    # Rows without a ground-truth label cannot contribute to the metrics, and
    # their NaN would make the integer cast below raise. Drop them first.
    labelled = predictions[predictions["true_label_num"].notna()]
    if labelled["true_label_num"].nunique() < 2:
        return {}
    return classification_metrics(
        labelled["true_label_num"].astype(int),
        labelled["predicted_class_num"].astype(int),
        labelled["positive_class_probability"],
        positive_index=positive_index,
        n_bootstraps=n_bootstraps,
        seed=seed,
    )
87
+
88
+
89
def evaluate_model(
    model: torch.nn.Module,
    config,
    split_name: Optional[str],
    class_names: list[str],
    output_csv: Path,
    positive_index: int = 1,
) -> tuple[pd.DataFrame, dict[str, float]]:
    """Predict one split, persist the predictions and return them with metrics.

    Predictions are written to *output_csv*. When metrics could be computed
    they are also written next to it as ``<stem>_metrics.csv``.

    Returns:
        ``(predictions, metrics)``; ``metrics`` may be an empty dict.
    """
    transform_kwargs = {
        "image_size": config.image_size,
        "train": False,
        "mean": config.normalize_mean,
        "std": config.normalize_std,
    }
    dataset, _ = build_dataset(
        config,
        split_name,
        build_transforms(**transform_kwargs),
        return_metadata=True,
    )
    loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=False, num_workers=0)
    predictions = predict(
        model,
        loader,
        class_names,
        device=config.device,
        positive_index=positive_index,
    )

    output_csv.parent.mkdir(parents=True, exist_ok=True)
    predictions.to_csv(output_csv, index=False)

    metrics = evaluate_dataframe(predictions, positive_index=positive_index, seed=config.seed)
    if metrics:
        metrics_path = output_csv.with_name(f"{output_csv.stem}_metrics.csv")
        pd.DataFrame([metrics]).to_csv(metrics_path, index=False)
    return predictions, metrics
visionkit/gradcam.py ADDED
@@ -0,0 +1,193 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Optional, Sequence, Union
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import pandas as pd
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from PIL import Image
12
+ from tqdm.auto import tqdm
13
+
14
+ from visionkit.data import build_transforms, class_mapping, filter_split, infer_class_names, read_metadata
15
+ from visionkit.evaluate import evaluate_dataframe
16
+ from visionkit.models import get_module_by_path
17
+
18
+
19
class GradCAM:
    """Grad-CAM heat-map generator for a single target layer.

    Hooks on ``target_layer`` capture its forward output and the gradient of
    the chosen class score with respect to that output. Call ``close()`` (or
    use the instance as a context manager) to remove the hooks.
    """

    def __init__(self, model: torch.nn.Module, target_layer: torch.nn.Module) -> None:
        """Store the model and layer and install the capture hooks."""
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.handles = []
        self._register_hooks()

    def __enter__(self) -> "GradCAM":
        """Return the instance so it can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Remove the hooks when leaving the ``with`` block."""
        self.close()

    def _register_hooks(self) -> None:
        """Attach the forward and full-backward hooks to the target layer."""

        def forward_hook(module, inputs, output):
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()

        self.handles.append(self.target_layer.register_forward_hook(forward_hook))
        self.handles.append(self.target_layer.register_full_backward_hook(backward_hook))

    def generate(self, input_tensor: torch.Tensor, target_class: Optional[int] = None) -> np.ndarray:
        """Compute the normalised Grad-CAM map for one input.

        Args:
            input_tensor: Model input; only the first sample is used, and the
                map is resized to ``input_tensor.shape[2:]``.
            target_class: Class index whose score is back-propagated; defaults
                to the model's arg-max prediction.

        Returns:
            The map min-max scaled to ``[0, 1]``, or all zeros when it is
            constant.

        Raises:
            RuntimeError: If the hooks captured no activations or gradients.
        """
        # Drop captures from any previous call so they are never reused
        # silently if the hooks do not fire this time.
        self.gradients = None
        self.activations = None

        # Grad-CAM needs a backward pass, so gradients are forced on here
        # even if the caller is inside ``torch.no_grad()``.
        with torch.enable_grad():
            output = self.model(input_tensor)
            if target_class is None:
                target_class = int(output.argmax(dim=1).item())
            self.model.zero_grad()
            output[0, target_class].backward()

        if self.gradients is None or self.activations is None:
            raise RuntimeError("Grad-CAM hooks did not capture gradients or activations.")

        gradients = self.gradients[0]
        activations = self.activations[0]
        # Channel weights are the spatial mean of each channel's gradient.
        weights = gradients.mean(dim=(1, 2), keepdim=True)
        cam = torch.sum(weights * activations, dim=0)
        cam = F.relu(cam)
        cam = F.interpolate(
            cam.unsqueeze(0).unsqueeze(0),
            size=input_tensor.shape[2:],
            mode="bilinear",
            align_corners=False,
        )
        cam_np = cam.squeeze().detach().cpu().numpy()
        denominator = cam_np.max() - cam_np.min()
        if denominator <= 1e-12:
            # Constant map: avoid dividing by (almost) zero.
            return np.zeros_like(cam_np)
        return (cam_np - cam_np.min()) / denominator

    def close(self) -> None:
        """Remove every installed hook; safe to call more than once."""
        for handle in self.handles:
            handle.remove()
        self.handles.clear()
69
+
70
+
71
def _denormalize(image_tensor: torch.Tensor, mean: Sequence[float], std: Sequence[float]) -> np.ndarray:
    """Undo per-channel normalisation and return a channels-last array.

    Works on a CPU copy, so *image_tensor* is left untouched. Each channel
    that has both a mean and a std is mapped back with ``x * std + mean``.
    The result is clamped to ``[0, 1]`` and its axes are permuted ``(1, 2, 0)``.
    """
    restored = image_tensor.detach().cpu().clone()
    channel_stats = zip(mean, std)
    for channel, (channel_mean, channel_std) in enumerate(channel_stats):
        scaled = restored[channel] * channel_std
        restored[channel] = scaled + channel_mean
    channels_last = restored.clamp(0, 1).permute(1, 2, 0)
    return channels_last.numpy()
77
+
78
+
79
def save_gradcam_figure(
    original_image: np.ndarray,
    cam: np.ndarray,
    save_path: Path,
    title: str,
    alpha: float = 0.4,
) -> None:
    """Save a three-panel figure: the image, the heat map, and their overlay.

    Args:
        original_image: Image shown in the first and third panels.
        cam: Heat map drawn with the "jet" colour map.
        save_path: Output file; missing parent directories are created.
        title: Title of the first panel.
        alpha: Opacity of the heat map in the overlay panel.
    """
    save_path.parent.mkdir(parents=True, exist_ok=True)
    fig, (image_ax, cam_ax, overlay_ax) = plt.subplots(1, 3, figsize=(12, 4))

    image_ax.imshow(original_image)
    image_ax.set_title(title)

    cam_ax.imshow(cam, cmap="jet")
    cam_ax.set_title("Grad-CAM")

    # Overlay: image underneath, semi-transparent heat map on top.
    overlay_ax.imshow(original_image)
    overlay_ax.imshow(cam, cmap="jet", alpha=alpha)
    overlay_ax.set_title("Overlay")

    for panel in (image_ax, cam_ax, overlay_ax):
        panel.axis("off")
    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    plt.close(fig)
100
+
101
+
102
def generate_gradcam_report(
    model: torch.nn.Module,
    csv_path: Path,
    image_dir: Path,
    output_dir: Path,
    output_csv: Path,
    filename_column: str = "filename",
    label_column: str = "label",
    split_column: Optional[str] = "split",
    split_name: Optional[str] = None,
    class_names: Optional[Sequence[str]] = None,
    target_class: Optional[Union[str, int]] = None,
    target_layer: Optional[str] = None,
    image_size: tuple[int, int] = (300, 300),
    mean: tuple[float, float, float] = (0.485, 0.456, 0.406),
    std: tuple[float, float, float] = (0.229, 0.224, 0.225),
    device: Optional[str] = None,
) -> tuple[pd.DataFrame, dict[str, float]]:
    """Create Grad-CAM figures and a predictions table for the rows of a CSV.

    For every selected row whose image exists, a three-panel figure is saved
    to *output_dir* and a result row (metadata, target class, true/predicted
    class, class probabilities) is collected. Rows whose image file is missing
    are skipped. The table is written to *output_csv*; when metrics can be
    computed they are written next to it as ``<stem>_metrics.csv``.

    Args:
        model: Network to explain; moved to *device* and put in eval mode.
        csv_path: Metadata CSV.
        image_dir: Directory the CSV filenames are relative to.
        output_dir: Directory receiving the figures.
        output_csv: Destination of the results table.
        filename_column: Column with the image filename.
        label_column: Column with the true label.
        split_column: Column used to select rows, or ``None`` for all rows.
        split_name: Split value to keep, or ``None`` for all rows.
        class_names: Explicit class order; inferred from the CSV when empty.
        target_class: Class to explain, as index or name; defaults to the
            positive class (index 1, or 0 when there is a single class).
        target_layer: Path handed to ``get_module_by_path`` to locate the layer.
        image_size: Size for the resize transform.
        mean: Normalisation mean, also used to undo normalisation for display.
        std: Normalisation std, also used to undo normalisation for display.
        device: Target device; defaults to "cuda" when available, else "cpu".

    Returns:
        ``(results_df, metrics)``; ``metrics`` may be an empty dict.

    Raises:
        KeyError: If *target_class* is a name missing from the class mapping,
            or the split column is requested but absent.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Resolve classes from the WHOLE CSV before filtering, as build_dataset
    # does. Inferring them from one split would shift the class indices
    # whenever that split lacks a class.
    metadata = read_metadata(csv_path)
    class_names = infer_class_names(metadata, label_column, class_names)
    class_to_idx, class_names = class_mapping(class_names)
    df = filter_split(metadata, split_column, split_name)

    positive_index = min(1, len(class_names) - 1)
    if target_class is None:
        target_index = positive_index
    elif isinstance(target_class, int):
        target_index = target_class
    else:
        target_index = class_to_idx[str(target_class)]

    transform = build_transforms(image_size=image_size, train=False, mean=mean, std=std)
    target_module = get_module_by_path(model, target_layer)
    gradcam = GradCAM(model, target_module)
    image_dir = Path(image_dir)
    output_dir = Path(output_dir)
    output_csv = Path(output_csv)
    results = []

    try:
        for record in tqdm(df.to_dict(orient="records"), desc="gradcam"):
            filename = str(record[filename_column])
            image_path = image_dir / filename
            if not image_path.exists():
                # Best effort: rows without an image are left out of the report.
                continue
            # ``convert`` yields an independent image, so the file handle can
            # be released right away.
            with Image.open(image_path) as source:
                image = source.convert("RGB")
            image_tensor = transform(image).unsqueeze(0).to(device)
            cam = gradcam.generate(image_tensor, target_class=target_index)

            with torch.no_grad():
                outputs = model(image_tensor)
                probabilities = F.softmax(outputs, dim=1)[0].detach().cpu()
            predicted_index = int(probabilities.argmax().item())

            true_label = str(record[label_column]) if label_column in record else None
            # None when the label is absent or not part of the class mapping.
            true_index = class_to_idx.get(true_label)
            original = _denormalize(image_tensor[0], mean, std)
            pred_label = class_names[predicted_index]
            save_name = f"{Path(filename).stem}__pred-{pred_label}.png"
            save_gradcam_figure(
                original,
                cam,
                output_dir / save_name,
                title=f"{filename}\ntrue={true_label} pred={pred_label}",
            )

            result = dict(record)
            result.update(
                {
                    "gradcam_file": save_name,
                    "target_class_num": target_index,
                    "target_class": class_names[target_index],
                    "true_label_num": true_index,
                    "true_label": true_label,
                    "predicted_class_num": predicted_index,
                    "predicted_class": pred_label,
                    "positive_class_probability": float(probabilities[positive_index].item()),
                }
            )
            for class_idx, class_name in enumerate(class_names):
                result[f"probability_{class_name}"] = float(probabilities[class_idx].item())
            results.append(result)
    finally:
        # Always detach the hooks from the model, even if a row failed.
        gradcam.close()

    results_df = pd.DataFrame(results)
    output_csv.parent.mkdir(parents=True, exist_ok=True)
    results_df.to_csv(output_csv, index=False)
    metrics = evaluate_dataframe(results_df, positive_index=positive_index)
    if metrics:
        pd.DataFrame([metrics]).to_csv(output_csv.with_name(output_csv.stem + "_metrics.csv"), index=False)
    return results_df, metrics