py2ls 0.1.10.12__py3-none-any.whl → 0.2.7.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of py2ls might be problematic.
- py2ls/.DS_Store +0 -0
- py2ls/.git/.DS_Store +0 -0
- py2ls/.git/index +0 -0
- py2ls/.git/logs/refs/remotes/origin/HEAD +1 -0
- py2ls/.git/objects/.DS_Store +0 -0
- py2ls/.git/refs/.DS_Store +0 -0
- py2ls/ImageLoader.py +621 -0
- py2ls/__init__.py +7 -5
- py2ls/apptainer2ls.py +3940 -0
- py2ls/batman.py +164 -42
- py2ls/bio.py +2595 -0
- py2ls/cell_image_clf.py +1632 -0
- py2ls/container2ls.py +4635 -0
- py2ls/corr.py +475 -0
- py2ls/data/.DS_Store +0 -0
- py2ls/data/email/email_html_template.html +88 -0
- py2ls/data/hyper_param_autogluon_zeroshot2024.json +2383 -0
- py2ls/data/hyper_param_tabrepo_2024.py +1753 -0
- py2ls/data/mygenes_fields_241022.txt +355 -0
- py2ls/data/re_common_pattern.json +173 -0
- py2ls/data/sns_info.json +74 -0
- py2ls/data/styles/.DS_Store +0 -0
- py2ls/data/styles/example/.DS_Store +0 -0
- py2ls/data/styles/stylelib/.DS_Store +0 -0
- py2ls/data/styles/stylelib/grid.mplstyle +15 -0
- py2ls/data/styles/stylelib/high-contrast.mplstyle +6 -0
- py2ls/data/styles/stylelib/high-vis.mplstyle +4 -0
- py2ls/data/styles/stylelib/ieee.mplstyle +15 -0
- py2ls/data/styles/stylelib/light.mplstyl +6 -0
- py2ls/data/styles/stylelib/muted.mplstyle +6 -0
- py2ls/data/styles/stylelib/nature-reviews-latex.mplstyle +616 -0
- py2ls/data/styles/stylelib/nature-reviews.mplstyle +616 -0
- py2ls/data/styles/stylelib/nature.mplstyle +31 -0
- py2ls/data/styles/stylelib/no-latex.mplstyle +10 -0
- py2ls/data/styles/stylelib/notebook.mplstyle +36 -0
- py2ls/data/styles/stylelib/paper.mplstyle +290 -0
- py2ls/data/styles/stylelib/paper2.mplstyle +305 -0
- py2ls/data/styles/stylelib/retro.mplstyle +4 -0
- py2ls/data/styles/stylelib/sans.mplstyle +10 -0
- py2ls/data/styles/stylelib/scatter.mplstyle +7 -0
- py2ls/data/styles/stylelib/science.mplstyle +48 -0
- py2ls/data/styles/stylelib/std-colors.mplstyle +4 -0
- py2ls/data/styles/stylelib/vibrant.mplstyle +6 -0
- py2ls/data/tiles.csv +146 -0
- py2ls/data/usages_pd.json +1417 -0
- py2ls/data/usages_sns.json +31 -0
- py2ls/docker2ls.py +5446 -0
- py2ls/ec2ls.py +61 -0
- py2ls/fetch_update.py +145 -0
- py2ls/ich2ls.py +1955 -296
- py2ls/im2.py +8242 -0
- py2ls/image_ml2ls.py +2100 -0
- py2ls/ips.py +33909 -3418
- py2ls/ml2ls.py +7700 -0
- py2ls/mol.py +289 -0
- py2ls/mount2ls.py +1307 -0
- py2ls/netfinder.py +873 -351
- py2ls/nl2ls.py +283 -0
- py2ls/ocr.py +1581 -458
- py2ls/plot.py +10394 -314
- py2ls/rna2ls.py +311 -0
- py2ls/ssh2ls.md +456 -0
- py2ls/ssh2ls.py +5933 -0
- py2ls/ssh2ls_v01.py +2204 -0
- py2ls/stats.py +66 -172
- py2ls/temp20251124.py +509 -0
- py2ls/translator.py +2 -0
- py2ls/utils/decorators.py +3564 -0
- py2ls/utils_bio.py +3453 -0
- {py2ls-0.1.10.12.dist-info → py2ls-0.2.7.10.dist-info}/METADATA +113 -224
- {py2ls-0.1.10.12.dist-info → py2ls-0.2.7.10.dist-info}/RECORD +72 -16
- {py2ls-0.1.10.12.dist-info → py2ls-0.2.7.10.dist-info}/WHEEL +0 -0
py2ls/cell_image_clf.py
ADDED
@@ -0,0 +1,1632 @@
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import models
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    roc_auc_score,
    average_precision_score,
)
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import joblib
from tqdm import tqdm
import json
from datetime import datetime
import time
from PIL import Image
import gc
from torch.cuda import amp
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
import warnings

# Suppress warnings
warnings.filterwarnings("ignore")

from .ips import set_computing_device
set_computing_device()

class CellImageClassifier:
    """
    Comprehensive cell-image classification system.

    Key improvements:
    1. Mixed-precision training (AMP) - faster training, lower GPU memory use
    2. Multiple learning-rate schedulers
    3. Better multi-channel image handling
    4. Model interpretation via Grad-CAM
    5. More efficient data loading (with caching)
    6. Richer evaluation metrics (AUC-ROC, PR curves, etc.)
    7. Improved error handling and logging
    """

    def __init__(self, config=None):
        # Default configuration
        self.default_config = {
            "model_name": "resnet18",  # supported models: resnet18, resnet50, efficientnet_b0-3, densenet121, vgg16, convnext_tiny
            "num_classes": 5,  # number of classes
            "input_channels": 3,  # number of input channels
            "image_size": (256, 256),  # input image size
            "batch_size": 16,  # batch size
            "learning_rate": 0.001,  # initial learning rate
            "epochs": 30,  # number of training epochs
            "device": self._auto_select_device(),  # select device automatically
            "augmentation_level": "high",  # augmentation level: none, low, medium, high
            "staining_type": "generic",  # staining type: generic, H&E, fluorescence, IHC
            "model_save_path": "models",  # model save directory
            "report_save_path": "reports",  # report save directory
            "data_cache_path": "data_cache",  # data cache directory
            "use_amp": True,  # enable mixed-precision training
            "lr_scheduler": "plateau",  # LR scheduler: plateau, cosine, step
            "optimizer": "adam",  # optimizer: adam, sgd
            "early_stopping": 5,  # early-stopping patience (0 disables)
            "class_weights": None,  # class weights (for imbalanced data)
            "grad_cam_layers": ["layer4"],  # layers used for Grad-CAM visualization
        }

        # Merge the user configuration
        self.config = {**self.default_config, **(config or {})}

        # Model placeholders
        self.model = None
        self.class_names = None
        self.label_encoder = None
        self.stats = {
            "train_loss": [],
            "val_loss": [],
            "val_accuracy": [],
            "val_auc": [],
            "val_f1": [],
        }

        # Create the required directories
        os.makedirs(self.config["model_save_path"], exist_ok=True)
        os.makedirs(self.config["report_save_path"], exist_ok=True)
        os.makedirs(self.config["data_cache_path"], exist_ok=True)

        # Initialize mixed-precision training
        self.scaler = amp.GradScaler(enabled=self.config["use_amp"])

        print(f"Cell image classifier initialized | device: {self.config['device']}")

    @staticmethod
    def _auto_select_device():
        """Select the best available computing device."""
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"

    def _clear_cache(self):
        """Clear caches to free memory."""
        torch.cuda.empty_cache()
        gc.collect()

    def build_model(self):
        """Create the image-classification model."""
        model_name = self.config["model_name"].lower()
        num_classes = self.config["num_classes"]
        in_channels = self.config["input_channels"]

        # Supported model factories
        supported_models = {
            "resnet18": models.resnet18,
            "resnet34": models.resnet34,
            "resnet50": models.resnet50,
            "resnet101": models.resnet101,
            "efficientnet_b0": models.efficientnet_b0,
            "efficientnet_b3": models.efficientnet_b3,
            "densenet121": models.densenet121,
            "vgg16": models.vgg16,
            "convnext_tiny": models.convnext_tiny,
        }

        if model_name not in supported_models:
            raise ValueError(
                f"Unsupported model: {model_name}. Supported: {list(supported_models.keys())}"
            )

        # Create the pretrained model
        try:
            model_func = supported_models[model_name]
            pretrained_model = model_func(pretrained=True)
        except Exception as e:
            print(f"Failed to load pretrained weights: {e}, using random initialization")
            pretrained_model = model_func(pretrained=False)

        # Adapt the input channels for each architecture
        if "resnet" in model_name:
            # Replace the first convolution of the ResNet
            original_conv1 = pretrained_model.conv1
            pretrained_model.conv1 = nn.Conv2d(
                in_channels=in_channels,
                out_channels=original_conv1.out_channels,
                kernel_size=original_conv1.kernel_size,
                stride=original_conv1.stride,
                padding=original_conv1.padding,
                bias=original_conv1.bias is not None,
            )
            # Replace the final fully connected layer
            num_features = pretrained_model.fc.in_features
            pretrained_model.fc = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(num_features, 512),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(512, num_classes),
            )

        elif "efficientnet" in model_name:
            # Replace the first convolution of the EfficientNet
            original_conv = pretrained_model.features[0][0]
            pretrained_model.features[0][0] = nn.Conv2d(
                in_channels=in_channels,
                out_channels=original_conv.out_channels,
                kernel_size=original_conv.kernel_size,
                stride=original_conv.stride,
                padding=original_conv.padding,
                bias=False,
            )
            # Replace the classification head
            num_features = pretrained_model.classifier[1].in_features
            pretrained_model.classifier = nn.Sequential(
                nn.Dropout(0.4), nn.Linear(num_features, num_classes)
            )

        elif "densenet" in model_name:
            # Replace the first convolution of the DenseNet
            original_conv = pretrained_model.features.conv0
            pretrained_model.features.conv0 = nn.Conv2d(
                in_channels=in_channels,
                out_channels=original_conv.out_channels,
                kernel_size=original_conv.kernel_size,
                stride=original_conv.stride,
                padding=original_conv.padding,
                bias=False,
            )
            # Replace the classification head
            num_features = pretrained_model.classifier.in_features
            pretrained_model.classifier = nn.Linear(num_features, num_classes)

        elif "vgg" in model_name:
            # Replace the first convolution of the VGG
            original_conv = pretrained_model.features[0]
            pretrained_model.features[0] = nn.Conv2d(
                in_channels=in_channels,
                out_channels=original_conv.out_channels,
                kernel_size=original_conv.kernel_size,
                stride=original_conv.stride,
                padding=original_conv.padding,
            )
            # Replace the classification head
            num_features = pretrained_model.classifier[6].in_features
            pretrained_model.classifier[6] = nn.Linear(num_features, num_classes)

        elif "convnext" in model_name:
            # Replace the first convolution of the ConvNeXt
            original_conv = pretrained_model.features[0][0]
            pretrained_model.features[0][0] = nn.Conv2d(
                in_channels=in_channels,
                out_channels=original_conv.out_channels,
                kernel_size=original_conv.kernel_size,
                stride=original_conv.stride,
                padding=original_conv.padding,
            )
            # Replace the classification head
            num_features = pretrained_model.classifier[2].in_features
            pretrained_model.classifier[2] = nn.Linear(num_features, num_classes)

        self.model = pretrained_model.to(self.config["device"])
        print(
            f"Created model: {model_name} | input channels: {in_channels} | classes: {num_classes}"
        )
        return self.model

    def get_augmentations(self):
        """Build the augmentation pipeline, tailored to staining type and augmentation level."""
        staining_type = self.config["staining_type"]
        aug_level = self.config["augmentation_level"]
        image_size = self.config["image_size"]

        # Base transforms (always applied)
        base_transforms = [
            A.Resize(image_size[0], image_size[1]),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ]

        # Choose augmentations by level
        if aug_level == "none":
            return A.Compose(base_transforms)

        # Generic augmentations
        common_aug = [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Rotate(limit=30, p=0.5),
            A.ShiftScaleRotate(
                shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5
            ),
            A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5),
        ]

        # Staining-specific augmentations
        staining_aug = []
        if staining_type == "H&E":
            # H&E staining: boost nucleus/cytoplasm contrast
            staining_aug = [
                A.CLAHE(clip_limit=3.0, tile_grid_size=(8, 8), p=0.7),
                A.HueSaturationValue(
                    hue_shift_limit=5, sat_shift_limit=20, val_shift_limit=10, p=0.5
                ),
                A.RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.3),
                A.RandomGamma(gamma_limit=(70, 130), p=0.5),  # boost contrast
            ]
        elif staining_type == "fluorescence":
            # Fluorescence staining: preserve channel relationships
            staining_aug = [
                A.RandomGamma(gamma_limit=(80, 120), p=0.5),
                A.MultiplicativeNoise(multiplier=[0.9, 1.1], elementwise=True, p=0.2),
                A.ChannelShuffle(p=0.1),  # simulate channel misalignment
                A.ChannelDropout(
                    channel_drop_range=(1, 1), fill_value=0, p=0.1
                ),  # simulate channel loss
            ]
        elif staining_type == "IHC":
            # IHC staining: emphasize the brown (DAB) precipitate
            staining_aug = [
                A.ColorJitter(
                    brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.7
                ),
                A.RGBShift(r_shift_limit=20, g_shift_limit=10, b_shift_limit=5, p=0.5),
                A.Sharpen(alpha=(0.2, 0.5), p=0.3),  # enhance detail
            ]
        else:  # generic
            staining_aug = [
                A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.3),
                A.HueSaturationValue(
                    hue_shift_limit=5, sat_shift_limit=10, val_shift_limit=5, p=0.5
                ),
                A.RandomGamma(gamma_limit=(80, 120), p=0.5),
            ]

        # Advanced augmentations
        advanced_aug = []
        if aug_level in ["medium", "high"]:
            advanced_aug = [
                A.OneOf(
                    [
                        A.MotionBlur(blur_limit=3, p=0.3),
                        A.GaussianBlur(blur_limit=3, p=0.3),
                        A.MedianBlur(blur_limit=3, p=0.3),
                    ],
                    p=0.5,
                ),
                A.OneOf(
                    [
                        A.OpticalDistortion(distort_limit=0.5, shift_limit=0.1, p=0.3),
                        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3),
                        A.ElasticTransform(alpha=1, sigma=20, alpha_affine=10, p=0.3),
                    ],
                    p=0.5,
                ),
                A.CoarseDropout(max_holes=3, max_height=20, max_width=20, p=0.3),
            ]

        if aug_level == "high":
            advanced_aug += [
                A.RandomShadow(
                    shadow_roi=(0, 0.5, 1, 1),
                    num_shadows_lower=1,
                    num_shadows_upper=2,
                    p=0.2,
                ),
                A.RandomSunFlare(src_radius=100, p=0.1),
                A.RandomFog(fog_coef_lower=0.1, fog_coef_upper=0.3, p=0.1),
                A.RandomRain(p=0.1),  # simulate droplets on the microscope
            ]

        # Combine all augmentations
        all_aug = common_aug + staining_aug + advanced_aug
        return A.Compose(all_aug + base_transforms)

    def create_dataset(self, image_paths, labels, is_train=True):
        """
        Create a cell-image dataset.

        Args:
            image_paths (list): list of image paths
            labels (list): corresponding labels
            is_train (bool): whether this is a training set

        Returns:
            CellDataset: custom dataset object
        """
        # Without labels, create dummy labels (used for prediction)
        if labels is None:
            labels = [-1] * len(image_paths)

        # Create the label encoder (on first use)
        if self.label_encoder is None:
            unique_labels = sorted(set(labels))
            self.label_encoder = {label: idx for idx, label in enumerate(unique_labels)}
            self.class_names = list(self.label_encoder.keys())
            self.config["num_classes"] = len(self.class_names)
            print(f"Created label encoder | classes: {self.class_names}")

        # Encode the labels as integers
        encoded_labels = [self.label_encoder.get(l, -1) for l in labels]

        # Build the augmentation pipeline
        # (note: the same pipeline is currently used for training and evaluation)
        transform = self.get_augmentations()

        return CellDataset(
            image_paths=image_paths,
            labels=encoded_labels,
            transform=transform,
            is_train=is_train,
            input_channels=self.config["input_channels"],
            cache_dir=self.config["data_cache_path"] if is_train else None,
        )

    def train(self, image_paths, labels, val_split=0.2, save_best=True):
        """
        Train the cell-image classification model.

        Args:
            image_paths (list): list of image paths
            labels (list): corresponding labels
            val_split (float): fraction of data used for validation
            save_best (bool): whether to save the best model
        """
        # Clear caches
        self._clear_cache()

        # Create the dataset
        full_dataset = self.create_dataset(image_paths, labels, is_train=True)

        # Split into training and validation sets
        if val_split > 0:
            val_size = int(len(full_dataset) * val_split)
            train_size = len(full_dataset) - val_size
            train_dataset, val_dataset = random_split(
                full_dataset,
                [train_size, val_size],
                generator=torch.Generator().manual_seed(42),
            )
        else:
            train_dataset = full_dataset
            val_dataset = None

        # Create the data loaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.config["batch_size"],
            shuffle=True,
            num_workers=0,  # min(4, os.cpu_count())
            pin_memory=True,
            persistent_workers=False,  # requires num_workers > 0, so disabled here
        )

        val_loader = None
        if val_dataset:
            val_loader = DataLoader(
                val_dataset,
                batch_size=self.config["batch_size"],
                shuffle=False,
                num_workers=0,  # min(2, os.cpu_count())
                pin_memory=True,
            )

        # Build the model (if not built yet)
        if self.model is None:
            self.build_model()

        # Set up the optimizer
        optimizer_name = self.config["optimizer"].lower()
        if optimizer_name == "sgd":
            optimizer = optim.SGD(
                self.model.parameters(),
                lr=self.config["learning_rate"],
                momentum=0.9,
                weight_decay=1e-4,
            )
        else:  # default to Adam
            optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.config["learning_rate"],
                weight_decay=1e-4,
            )

        # Set up the loss function (optionally weighted for class imbalance)
        if self.config["class_weights"]:
            weights = torch.tensor(
                self.config["class_weights"], device=self.config["device"]
            )
            criterion = nn.CrossEntropyLoss(weight=weights)
        else:
            criterion = nn.CrossEntropyLoss()

        # Set up the learning-rate scheduler
        scheduler_name = self.config["lr_scheduler"].lower()
        if scheduler_name == "cosine":
            scheduler = CosineAnnealingLR(optimizer, T_max=self.config["epochs"])
        elif scheduler_name == "step":
            scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        else:  # default to ReduceLROnPlateau
            scheduler = ReduceLROnPlateau(
                optimizer, mode="min", patience=3, factor=0.5, verbose=True
            )

        # Training loop
        best_val_loss = float("inf")
        best_val_accuracy = 0.0
        epochs_no_improve = 0

        for epoch in range(self.config["epochs"]):
            start_time = time.time()

            # Training phase
            self.model.train()
            train_loss = 0.0
            progress_bar = tqdm(
                train_loader, desc=f"Epoch {epoch+1}/{self.config['epochs']} [Train]"
            )

            for images, labels in progress_bar:
                images, labels = images.to(self.config["device"]), labels.to(
                    self.config["device"]
                )

                # Mixed-precision forward pass
                with amp.autocast(enabled=self.config["use_amp"]):
                    outputs = self.model(images)
                    loss = criterion(outputs, labels)

                # Backward pass
                optimizer.zero_grad()
                self.scaler.scale(loss).backward()
                self.scaler.step(optimizer)
                self.scaler.update()

                train_loss += loss.item() * images.size(0)
                progress_bar.set_postfix(loss=loss.item())

            # Average training loss
            train_loss = train_loss / len(train_loader.dataset)
            self.stats["train_loss"].append(train_loss)

            # Validation phase
            val_loss = 0.0
            val_accuracy = 0.0
            all_labels = []
            all_preds = []

            if val_loader:
                self.model.eval()
                correct = 0
                total = 0
                progress_bar = tqdm(
                    val_loader, desc=f"Epoch {epoch+1}/{self.config['epochs']} [Val]"
                )

                with torch.no_grad():
                    for images, labels in progress_bar:
                        images, labels = images.to(self.config["device"]), labels.to(
                            self.config["device"]
                        )

                        with amp.autocast(enabled=self.config["use_amp"]):
                            outputs = self.model(images)
                            loss = criterion(outputs, labels)

                        val_loss += loss.item() * images.size(0)

                        _, predicted = torch.max(outputs.data, 1)
                        total += labels.size(0)
                        correct += (predicted == labels).sum().item()

                        # Collect predictions for the advanced metrics
                        all_labels.extend(labels.cpu().numpy())
                        all_preds.extend(predicted.cpu().numpy())

                        accuracy = correct / total
                        progress_bar.set_postfix(loss=loss.item(), accuracy=accuracy)

                # Validation metrics
                val_loss = val_loss / len(val_loader.dataset)
                val_accuracy = correct / total
                self.stats["val_loss"].append(val_loss)
                self.stats["val_accuracy"].append(val_accuracy)

                # AUC and "F1" (computed from hard predictions; the F1 slot actually stores average precision)
                if len(self.class_names) > 2:
                    # Multi-class AUC
                    y_true_bin = label_binarize(
                        all_labels, classes=range(len(self.class_names))
                    )
                    y_pred_bin = label_binarize(
                        all_preds, classes=range(len(self.class_names))
                    )
                    auc = roc_auc_score(y_true_bin, y_pred_bin, multi_class="ovr")
                    f1 = average_precision_score(y_true_bin, y_pred_bin)
                else:
                    # Binary AUC
                    auc = roc_auc_score(all_labels, all_preds)
                    f1 = average_precision_score(all_labels, all_preds)

                self.stats["val_auc"].append(auc)
                self.stats["val_f1"].append(f1)

                # Update the learning rate
                if scheduler_name == "plateau":
                    scheduler.step(val_loss)
                else:
                    scheduler.step()

                # Save the best model
                if save_best and val_accuracy > best_val_accuracy:
                    best_val_accuracy = val_accuracy
                    best_val_loss = val_loss
                    epochs_no_improve = 0
                    self.save_model("best_model.pth")
                    print(
                        f"Saved best model | val accuracy: {val_accuracy:.4f} | AUC: {auc:.4f}"
                    )
                else:
                    epochs_no_improve += 1
                    if (
                        self.config["early_stopping"] > 0
                        and epochs_no_improve >= self.config["early_stopping"]
                    ):
                        print(f"Early stopping: validation accuracy has not improved for {epochs_no_improve} epochs")
                        break
            else:
                # Without a validation set, only step non-plateau schedulers
                if scheduler_name != "plateau":
                    scheduler.step()

            # Epoch timing
            epoch_time = time.time() - start_time

            # Epoch summary
            if val_loader:
                print(
                    f"Epoch {epoch+1}/{self.config['epochs']} | "
                    f"Time: {epoch_time:.1f}s | "
                    f"Train Loss: {train_loss:.4f} | "
                    f"Val Loss: {val_loss:.4f} | "
                    f"Val Acc: {val_accuracy:.4f} | "
                    f"AUC: {auc:.4f} | "
                    f"F1: {f1:.4f}"
                )
            else:
                print(
                    f"Epoch {epoch+1}/{self.config['epochs']} | "
                    f"Time: {epoch_time:.1f}s | "
                    f"Train Loss: {train_loss:.4f}"
                )

        # Save the final model
        self.save_model("final_model.pth")
        print("Training finished!")

        # Visualize the training history
        self.plot_training_history()

        return self.stats

    def predict(self, image_paths, output_dir=None):
        """
        Run inference with the trained model.

        Args:
            image_paths (list): list of image paths
            output_dir (str): directory for saving prediction results

        Returns:
            tuple: (list of predicted labels, array of predicted probabilities)
        """
        if self.model is None:
            raise RuntimeError("Model not loaded; load or train a model first")

        # Clear caches
        self._clear_cache()

        # Create the prediction dataset
        dataset = self.create_dataset(image_paths, labels=None, is_train=False)
        dataloader = DataLoader(
            dataset,
            batch_size=self.config["batch_size"],
            shuffle=False,
            num_workers=0,  # min(2, os.cpu_count())
            pin_memory=True,
        )
        # print("---debug---")
        # print("If any parameter contains NaN or Inf, the saved weights are corrupt; reload or retrain the model.")
        # for name, param in self.model.named_parameters():
        #     if param.requires_grad:
        #         print(f"{name} mean: {param.data.mean().item()}, std: {param.data.std().item()}")
        #         if torch.isnan(param).any():
        #             print(f"Parameter {name} contains NaN")
        #         if torch.isinf(param).any():
        #             print(f"Parameter {name} contains Inf")
        # print("---debug---")
        # Inference
        self.model.eval()
        all_predictions = []
        all_probabilities = []
        all_logits = []
        with torch.no_grad():
            for images, _ in tqdm(dataloader, desc="Predicting"):
                images = images.to(self.config["device"])
                # ---- debug ----
                # sample_img = images[0].unsqueeze(0).to(self.config["device"])
                # output = self.model(sample_img)
                # print("Single-image model output:", output)
                # print("Contains NaN:", torch.isnan(output).any().item())
                # print("Contains Inf:", torch.isinf(output).any().item())
                # ---- debug ----

                # print("Input min:", images.min().item())
                # print("Input max:", images.max().item())
                # print("Input has NaN:", torch.isnan(images).any().item())
                # print("Input has Inf:", torch.isinf(images).any().item())
                # break  # only inspect the first batch to limit output

                with amp.autocast(enabled=self.config["use_amp"]):
                    outputs = self.model(images)
                    # ----debug----
                    # print("outputs min:", outputs.min().item())
                    # print("outputs max:", outputs.max().item())
                    # print("outputs has NaN:", torch.isnan(outputs).any().item())
                    # print("outputs has Inf:", torch.isinf(outputs).any().item())
                    # ----debug----

                probabilities = torch.softmax(outputs, dim=1)
                _, predicted = torch.max(outputs, 1)

                all_predictions.extend(predicted.cpu().numpy())
                all_probabilities.extend(probabilities.cpu().numpy())
                all_logits.extend(outputs.cpu().numpy())

        # Decode the numeric labels back to the original labels
        decoded_predictions = [self.class_names[p] for p in all_predictions]

        # Save the prediction results
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            report_path = os.path.join(output_dir, f"predictions_{timestamp}.csv")

            # Build the results DataFrame
            results = []
            for i, path in enumerate(image_paths):
                pred_label = decoded_predictions[i]
                prob = all_probabilities[i][all_predictions[i]]
                results.append(
                    {
                        "image_path": path,
                        "predicted_class": pred_label,
                        "probability": prob,
                        **{
                            f"prob_{cls}": prob_val
                            for cls, prob_val in zip(
                                self.class_names, all_probabilities[i]
                            )
                        },
                    }
                )

            df = pd.DataFrame(results)
            df.to_csv(report_path, index=False)
            print(f"Predictions saved to: {report_path}")

            # Generate the visual report
            self.generate_prediction_report(
                image_paths, decoded_predictions, all_probabilities, output_dir
            )

            # Generate Grad-CAM visualizations
            if self.config["grad_cam_layers"]:
                self.generate_grad_cam(
                    image_paths,
                    all_logits,
                    output_dir,
                    layer_names=self.config["grad_cam_layers"],
                )

        return decoded_predictions, np.array(all_probabilities)

    def evaluate(self, image_paths, labels):
        """
        Evaluate model performance.

        Args:
            image_paths (list): list of image paths
            labels (list): ground-truth labels

        Returns:
            dict: detailed evaluation metrics
        """
        # Predict
        predictions, probabilities = self.predict(image_paths)
        # -------debug------
        # prob_array = np.array(probabilities)
        # print("probabilities shape:", prob_array.shape)
        # print("NaN count in probabilities:", np.isnan(prob_array).sum())
        # print("Sample probabilities:", prob_array[:5])
        # -------debug------

        # Replace NaNs with 0
        probabilities = np.nan_to_num(probabilities, nan=0.0)

        # Encode the ground-truth labels
        true_labels_encoded = [self.label_encoder[l] for l in labels]
        # Convert the string predictions back to numeric codes
        predictions_encoded = [self.label_encoder[p] for p in predictions]

        # Compute the evaluation metrics
        report = classification_report(
            true_labels_encoded, predictions_encoded, output_dict=True
        )
        cm = confusion_matrix(
            true_labels_encoded, predictions_encoded, labels=range(len(self.class_names))
        )

        # Compute AUC-ROC
        if len(self.class_names) > 2:
            # Multi-class AUC
            y_true_bin = label_binarize(
                true_labels_encoded, classes=range(len(self.class_names))
            )
            auc = roc_auc_score(y_true_bin, probabilities, multi_class="ovr")
            ap = average_precision_score(y_true_bin, probabilities)
        else:
            # Binary AUC
            auc = roc_auc_score(true_labels_encoded, probabilities[:, 1])
            ap = average_precision_score(true_labels_encoded, probabilities[:, 1])

        # Plot the confusion matrix
        plt.figure(figsize=(12, 10))
        sns.heatmap(
            cm,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=self.class_names,
            yticklabels=self.class_names,
        )
        plt.xlabel("Predicted label")
        plt.ylabel("True label")
        plt.title("Confusion matrix")

        # Save the evaluation report
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        report_dir = self.config["report_save_path"]
        os.makedirs(report_dir, exist_ok=True)

        # Save the classification report
        report_path = os.path.join(
            report_dir, f"classification_report_{timestamp}.json"
        )
        with open(report_path, "w") as f:
            json.dump(report, f, indent=4)

        # Save the confusion matrix
        cm_path = os.path.join(report_dir, f"confusion_matrix_{timestamp}.png")
        plt.savefig(cm_path, bbox_inches="tight")
        plt.close()

        # Save the ROC curve
        self.plot_roc_curve(
            true_labels_encoded,
            probabilities,
            save_path=os.path.join(report_dir, f"roc_curve_{timestamp}.png"),
        )

        # Save the PR curve
        self.plot_pr_curve(
            true_labels_encoded,
            probabilities,
            save_path=os.path.join(report_dir, f"pr_curve_{timestamp}.png"),
        )

        print(f"Evaluation report saved to: {report_path}")
        print(f"Confusion matrix saved to: {cm_path}")
        print(f"AUC-ROC: {auc:.4f} | Average Precision: {ap:.4f}")

        return {
            "classification_report": report,
            "confusion_matrix": cm.tolist(),
            "auc_roc": auc,
            "average_precision": ap,
        }

    def save_model(self, filename):
        """
        Save the model and its configuration.

        Args:
            filename (str): model file name
        """
        model_path = os.path.join(self.config["model_save_path"], filename)

        # Save the model state
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "config": self.config,
                "label_encoder": self.label_encoder,
                "class_names": self.class_names,
                "stats": self.stats,
            },
            model_path,
        )

        print(f"Model saved to: {model_path}")
        return model_path

    def load_model(self, model_path):
        """
        Load a previously trained model.

        Args:
            model_path (str): path to the model file
        """
        # Load the checkpoint
        checkpoint = torch.load(
            model_path, map_location=torch.device(self.config["device"])
        )

        # Load the configuration
        if "config" in checkpoint:
            loaded_config = checkpoint["config"]
            # Update the config but keep the current device setting
            loaded_config["device"] = self.config["device"]
            self.config = {**self.config, **loaded_config}

        # Load the label encoder and class names
        if "label_encoder" in checkpoint:
            self.label_encoder = checkpoint["label_encoder"]
            self.class_names = checkpoint.get(
                "class_names", list(self.label_encoder.keys())
            )
            self.config["num_classes"] = len(self.class_names)
            print(f"Loaded label encoder | classes: {self.class_names}")
        elif "class_names" in checkpoint:
            self.class_names = checkpoint["class_names"]
            self.label_encoder = {cls: idx for idx, cls in enumerate(self.class_names)}
            self.config["num_classes"] = len(self.class_names)
            print(f"Rebuilt label encoder from class names | classes: {self.class_names}")

        # Load the training statistics
        if "stats" in checkpoint:
            self.stats = checkpoint["stats"]

        # Build the model architecture
        self.build_model()

        # Load the model weights
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model.eval()
        print(f"Model loaded from {model_path}")

        return self.model

    def plot_training_history(self):
        """Visualize the training history."""
        if not self.stats["train_loss"]:
            print("No training history available")
            return

        plt.figure(figsize=(16, 12))

        # Loss curves
        plt.subplot(2, 2, 1)
        plt.plot(self.stats["train_loss"], label="Training loss")
        if self.stats["val_loss"]:
            plt.plot(self.stats["val_loss"], label="Validation loss")
        plt.title("Training and validation loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend()
        plt.grid(True)

        # Accuracy curve
        if self.stats["val_accuracy"]:
            plt.subplot(2, 2, 2)
            plt.plot(self.stats["val_accuracy"], label="Validation accuracy", color="green")
            plt.title("Validation accuracy")
            plt.xlabel("Epoch")
            plt.ylabel("Accuracy")
            plt.legend()
            plt.grid(True)

        # AUC curve
        if self.stats["val_auc"]:
            plt.subplot(2, 2, 3)
            plt.plot(self.stats["val_auc"], label="Validation AUC", color="purple")
            plt.title("Validation AUC")
            plt.xlabel("Epoch")
            plt.ylabel("AUC")
            plt.legend()
            plt.grid(True)

        # F1 curve
        if self.stats["val_f1"]:
            plt.subplot(2, 2, 4)
            plt.plot(self.stats["val_f1"], label="Validation F1 score", color="orange")
            plt.title("Validation F1 score")
            plt.xlabel("Epoch")
            plt.ylabel("F1 score")
            plt.legend()
            plt.grid(True)

        plt.tight_layout()

        # Save the figure
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        plot_path = os.path.join(
            self.config["report_save_path"], f"training_history_{timestamp}.png"
        )
        plt.savefig(plot_path, bbox_inches="tight", dpi=300)
        plt.close()
        print(f"Training history plot saved to: {plot_path}")

    def plot_roc_curve(self, true_labels, probabilities, save_path=None):
        """Plot the ROC curve."""
        from sklearn.metrics import roc_curve, auc
        from itertools import cycle

        if len(self.class_names) == 2:
            # Binary classification
            fpr, tpr, _ = roc_curve(true_labels, probabilities[:, 1])
            roc_auc = auc(fpr, tpr)

            plt.figure()
            plt.plot(
                fpr,
                tpr,
                color="darkorange",
                lw=2,
                label=f"ROC curve (AUC = {roc_auc:.2f})",
            )
            plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            plt.xlabel("False positive rate")
            plt.ylabel("True positive rate")
            plt.title("ROC curve")
            plt.legend(loc="lower right")
        else:
            # Multi-class
            fpr = dict()
            tpr = dict()
            roc_auc = dict()

            # Binarize the ground-truth labels
            y_true_bin = label_binarize(
                true_labels, classes=range(len(self.class_names))
            )

            # Per-class ROC curves and AUC
            for i in range(len(self.class_names)):
                fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], probabilities[:, i])
                roc_auc[i] = auc(fpr[i], tpr[i])

            # Micro-average ROC curve and AUC
            fpr["micro"], tpr["micro"], _ = roc_curve(
                y_true_bin.ravel(), probabilities.ravel()
            )
            roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

            # Plot the ROC curves
            plt.figure(figsize=(10, 8))
            colors = cycle(
                ["aqua", "darkorange", "cornflowerblue", "green", "red", "purple"]
            )

            # Per-class curves
            for i, color in zip(range(len(self.class_names)), colors):
                plt.plot(
                    fpr[i],
                    tpr[i],
                    color=color,
                    lw=2,
                    label=f"ROC {self.class_names[i]} (AUC = {roc_auc[i]:.2f})",
                )

            # Micro-average curve
            plt.plot(
                fpr["micro"],
                tpr["micro"],
                label=f'Micro-average ROC (AUC = {roc_auc["micro"]:.2f})',
                color="deeppink",
                linestyle=":",
                linewidth=4,
            )

            plt.plot([0, 1], [0, 1], "k--", lw=2)
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            plt.xlabel("False positive rate")
            plt.ylabel("True positive rate")
            plt.title("Multi-class ROC curves")
            plt.legend(loc="lower right")

        if save_path:
            plt.savefig(save_path, bbox_inches="tight", dpi=300)
            plt.close()
            return save_path
        else:
            plt.show()

    def plot_pr_curve(self, true_labels, probabilities, save_path=None):
        """Plot the precision-recall curve."""
        from sklearn.metrics import precision_recall_curve, average_precision_score
        from itertools import cycle

        if len(self.class_names) == 2:
            # Binary classification
            precision, recall, _ = precision_recall_curve(
                true_labels, probabilities[:, 1]
            )
            ap = average_precision_score(true_labels, probabilities[:, 1])

            plt.figure()
            plt.plot(
                recall,
                precision,
                color="darkorange",
                lw=2,
                label=f"PR curve (AP = {ap:.2f})",
            )
            plt.xlabel("Recall")
            plt.ylabel("Precision")
            plt.title("Precision-recall curve")
            plt.legend(loc="upper right")
        else:
            # Multi-class
            precision = dict()
            recall = dict()
            average_precision = dict()

            # Binarize the ground-truth labels
            y_true_bin = label_binarize(
                true_labels, classes=range(len(self.class_names))
            )

            # Per-class PR curves and AP
            for i in range(len(self.class_names)):
                precision[i], recall[i], _ = precision_recall_curve(
                    y_true_bin[:, i], probabilities[:, i]
                )
                average_precision[i] = average_precision_score(
                    y_true_bin[:, i], probabilities[:, i]
                )

            # Micro-average PR curve
            precision["micro"], recall["micro"], _ = precision_recall_curve(
                y_true_bin.ravel(), probabilities.ravel()
            )
            average_precision["micro"] = average_precision_score(
                y_true_bin, probabilities, average="micro"
            )

            # Plot the PR curves
            plt.figure(figsize=(10, 8))
            colors = cycle(
                ["aqua", "darkorange", "cornflowerblue", "green", "red", "purple"]
            )

            # Per-class curves
            for i, color in zip(range(len(self.class_names)), colors):
                plt.plot(
                    recall[i],
                    precision[i],
                    color=color,
                    lw=2,
                    label=f"PR {self.class_names[i]} (AP = {average_precision[i]:.2f})",
                )

            # Micro-average curve
            plt.plot(
                recall["micro"],
                precision["micro"],
                label=f'Micro-average PR (AP = {average_precision["micro"]:.2f})',
                color="deeppink",
                linestyle=":",
                linewidth=4,
            )

            plt.xlabel("Recall")
            plt.ylabel("Precision")
            plt.title("Multi-class precision-recall curves")
            plt.legend(loc="upper right")

        if save_path:
            plt.savefig(save_path, bbox_inches="tight", dpi=300)
            plt.close()
            return save_path
        else:
            plt.show()

    def generate_grad_cam(
        self, image_paths, logits, output_dir, layer_names, num_samples=10
    ):
        """
        Generate Grad-CAM visualizations to explain model decisions.

        Args:
            image_paths (list): list of image paths
            logits (list): raw model outputs
            output_dir (str): output directory
            layer_names (list): names of the layers to visualize
            num_samples (int): number of samples to visualize
        """
        if not layer_names:
            return

        # Make sure the model is in eval mode
        self.model.eval()

        # Create the Grad-CAM directory
        cam_dir = os.path.join(output_dir, "grad_cam")
        os.makedirs(cam_dir, exist_ok=True)

        # Randomly select samples
        indices = np.random.choice(
            len(image_paths), min(num_samples, len(image_paths)), replace=False
        )

        # Generate Grad-CAM for each sample
        for idx in indices:
            img_path = image_paths[idx]
            logit = logits[idx]
            predicted_class = np.argmax(logit)

            # Load the original image
            img = cv2.imread(img_path)
            if img is None:
                continue

            # Generate Grad-CAM for each requested layer
            for layer_name in layer_names:
                # Compute the Grad-CAM
                cam = self._compute_grad_cam(img_path, layer_name, predicted_class)
                if cam is None:
                    continue

                # Overlay the CAM on the original image
                heatmap = cv2.applyColorMap(cam, cv2.COLORMAP_JET)
                superimposed_img = heatmap * 0.4 + img * 0.6

                # Save the result
                img_name = os.path.splitext(os.path.basename(img_path))[0]
                save_path = os.path.join(
                    cam_dir, f"{img_name}_{layer_name}_gradcam.jpg"
                )
                cv2.imwrite(save_path, superimposed_img)

        print(f"Grad-CAM visualizations saved to: {cam_dir}")

    def _compute_grad_cam(self, image_path, layer_name, target_class):
        """Compute Grad-CAM for the given layer and target class."""
        # Find the target layer
        layer = None
        for name, module in self.model.named_modules():
            if name == layer_name:
                layer = module
                break

        if layer is None:
            print(f"Layer not found: {layer_name}")
            return None

        # Register hooks
        activations = []
        gradients = []

        def forward_hook(module, input, output):
            activations.append(output.detach())

        def backward_hook(module, grad_input, grad_output):
            gradients.append(grad_output[0].detach())

        forward_handle = layer.register_forward_hook(forward_hook)
        backward_handle = layer.register_backward_hook(backward_hook)

        try:
            # Preprocess the image
            transform = self.get_augmentations()
            img = cv2.imread(image_path)
            if img is None:
                return None

            augmented = transform(image=img)
            input_tensor = augmented["image"].unsqueeze(0).to(self.config["device"])

            # Forward pass
            output = self.model(input_tensor)

            # Backward pass
            self.model.zero_grad()
            one_hot = torch.zeros_like(output)
            one_hot[0, target_class] = 1
            output.backward(gradient=one_hot)

            # Fetch the activations and gradients
            if not activations or not gradients:
                return None

            activations = activations[0].cpu().numpy()[0]
            gradients = gradients[0].cpu().numpy()[0]

            # Channel weights
            weights = np.mean(gradients, axis=(1, 2))
            cam = np.zeros(activations.shape[1:], dtype=np.float32)

            # Weighted sum of the activation maps
            for i, w in enumerate(weights):
                cam += w * activations[i]

            # Post-process the CAM
            cam = np.maximum(cam, 0)
            cam = cam - np.min(cam)
            cam = cam / np.max(cam) if np.max(cam) > 0 else cam
            cam = cv2.resize(cam, (img.shape[1], img.shape[0]))

            return (cam * 255).astype(np.uint8)

        finally:
            # Remove the hooks
            forward_handle.remove()
            backward_handle.remove()

    def generate_prediction_report(
        self, image_paths, predictions, probabilities, output_dir, num_samples=10
    ):
        """
        Generate an HTML prediction report.

        Args:
            image_paths (list): list of image paths
            predictions (list): predicted labels
            probabilities (list): predicted probabilities
            output_dir (str): output directory
            num_samples (int): number of samples to include in the report
        """
        # Build the HTML report
        html_content = """
        <!DOCTYPE html>
        <html>
        <head>
            <title>Cell Image Classification Report</title>
            <style>
                body {{ font-family: Arial, sans-serif; margin: 20px; background-color: #f5f5f5; }}
                .container {{ max-width: 1400px; margin: auto; background: white; padding: 20px; box-shadow: 0 0 10px rgba(0,0,0,0.1); border-radius: 8px; }}
                h1, h2 {{ color: #2c3e50; border-bottom: 2px solid #3498db; padding-bottom: 10px; }}
                .grid {{ display: grid; grid-template-columns: repeat(auto-fill, minmax(220px, 1fr)); gap: 20px; margin: 20px 0; }}
                .cell-card {{ border: 1px solid #ddd; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 5px rgba(0,0,0,0.1); transition: transform 0.3s; }}
                .cell-card:hover {{ transform: translateY(-5px); box-shadow: 0 5px 15px rgba(0,0,0,0.1); }}
                .cell-card img {{ width: 100%; height: 180px; object-fit: cover; border-bottom: 1px solid #ddd; }}
                .card-content {{ padding: 15px; }}
                .prediction {{ font-weight: bold; font-size: 16px; margin: 10px 0; }}
                .prob-bar {{ height: 20px; background: #eee; border-radius: 10px; margin: 10px 0; overflow: hidden; }}
                .prob-fill {{ height: 100%; background: #3498db; border-radius: 10px; }}
                .prob-text {{ font-size: 14px; color: #555; }}
                .class-distribution {{ display: flex; flex-wrap: wrap; gap: 15px; margin: 20px 0; }}
                .class-item {{ background: #f8f9fa; padding: 10px 15px; border-radius: 8px; border-left: 4px solid #3498db; }}
                table {{ width: 100%; border-collapse: collapse; margin: 20px 0; }}
                th, td {{ border: 1px solid #ddd; padding: 12px; text-align: left; }}
                th {{ background-color: #3498db; color: white; }}
                tr:nth-child(even) {{ background-color: #f2f2f2; }}
                .metrics {{ display: grid; grid-template-columns: repeat(auto-fill, minmax(300px, 1fr)); gap: 20px; margin: 20px 0; }}
                .metric-card {{ background: #f8f9fa; padding: 20px; border-radius: 8px; text-align: center; }}
                .metric-value {{ font-size: 24px; font-weight: bold; color: #3498db; }}
                .metric-label {{ font-size: 16px; color: #555; }}
            </style>
        </head>
        <body>
            <div class="container">
                <h1>Cell Image Classification Report</h1>
                <p><strong>Generated:</strong> {timestamp}</p>
                <p><strong>Model:</strong> {model_name}</p>
                <p><strong>Total samples:</strong> {total_samples}</p>

                <h2>Prediction Distribution</h2>
                <div class="class-distribution">
                    {class_distribution}
                </div>

                <h2>Model Evaluation Metrics</h2>
                <div class="metrics">
                    {metrics_cards}
                </div>

                <h2>Random Sample Predictions</h2>
                <div class="grid">
                    {sample_images}
                </div>

                <h2>All Predictions</h2>
                {predictions_table}
            </div>
        </body>
        </html>
        """

        # Class distribution
        class_counts = {cls: 0 for cls in self.class_names}
        for pred in predictions:
            class_counts[pred] += 1

        dist_chart = ""
        for cls, count in class_counts.items():
            percentage = count / len(predictions)
            dist_chart += f"""
            <div class="class-item">
                <div><strong>{cls}:</strong> {count} ({percentage:.1%})</div>
                <div class="prob-bar">
                    <div class="prob-fill" style="width: {percentage*100}%"></div>
                </div>
            </div>
            """

        # Evaluation metric cards
        metrics_cards = ""
        if hasattr(self, "stats") and self.stats.get("val_auc"):
            metrics = {
                "Accuracy": self.stats["val_accuracy"][-1],
                "AUC": self.stats["val_auc"][-1],
                "F1 score": self.stats["val_f1"][-1],
                "Loss": self.stats["val_loss"][-1],
            }

            for name, value in metrics.items():
                metrics_cards += f"""
                <div class="metric-card">
                    <div class="metric-value">{value:.4f}</div>
                    <div class="metric-label">{name}</div>
                </div>
                """

        # Randomly select samples
        indices = np.random.choice(
            len(image_paths), min(num_samples, len(image_paths)), replace=False
        )
        sample_html = ""

        for idx in indices:
            img_path = image_paths[idx]
            pred = predictions[idx]
            prob = probabilities[idx][self.label_encoder[pred]]

            # Image tag
            img_tag = f'<img src="{img_path}" alt="Cell Image">'

            # Probability bar
            prob_bar = f"""
            <div class="prob-bar">
                <div class="prob-fill" style="width: {prob*100:.1f}%"></div>
            </div>
            <div class="prob-text">Confidence: {prob:.3f}</div>
            """

            sample_html += f"""
            <div class="cell-card">
                {img_tag}
                <div class="card-content">
                    <div class="prediction">Prediction: {pred}</div>
                    {prob_bar}
                </div>
            </div>
            """

        # Prediction table
        table_rows = ""
        for i, (img_path, pred) in enumerate(zip(image_paths, predictions)):
            prob = probabilities[i][self.label_encoder[pred]]
            table_rows += f"""
            <tr>
                <td>{i+1}</td>
                <td>{os.path.basename(img_path)}</td>
                <td>{pred}</td>
                <td>{prob:.4f}</td>
            </tr>
            """

        predictions_table = f"""
        <table>
            <thead>
                <tr>
                    <th>No.</th>
                    <th>Image</th>
                    <th>Predicted class</th>
                    <th>Confidence</th>
                </tr>
            </thead>
            <tbody>
                {table_rows}
            </tbody>
        </table>
        """

        # Fill in the HTML template
        html_content = html_content.format(
            timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            model_name=self.config["model_name"],
            total_samples=len(image_paths),
            class_distribution=dist_chart,
            metrics_cards=metrics_cards,
            sample_images=sample_html,
            predictions_table=predictions_table,
        )

        # Save the HTML report
        report_path = os.path.join(output_dir, "prediction_report.html")
        with open(report_path, "w", encoding="utf-8") as f:
            f.write(html_content)

        print(f"Prediction report saved to: {report_path}")
        return report_path


class CellDataset(Dataset):
    """
    Cell image dataset - optimized version.

    Improvements:
    1. Image caching to speed up subsequent training
    2. More robust error handling
    3. Support for multiple image formats
    4. Optimized multi-channel handling
    """

    def __init__(
        self,
        image_paths,
        labels,
        transform=None,
        is_train=True,
        input_channels=3,
        cache_dir=None,
    ):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.is_train = is_train
        self.input_channels = input_channels
        self.cache_dir = cache_dir

        # Cell-specific augmentation pipeline
        self.cell_specific_aug = A.Compose(
            [
                A.OneOf(
                    [
                        A.MotionBlur(blur_limit=3, p=0.3),
                        A.GaussianBlur(blur_limit=3, p=0.3),
                        A.MedianBlur(blur_limit=3, p=0.3),
                    ],
                    p=0.5,
                ),
                A.OneOf(
                    [
                        A.OpticalDistortion(distort_limit=0.5, shift_limit=0.1, p=0.3),
                        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3),
                        A.ElasticTransform(alpha=1, sigma=20, alpha_affine=10, p=0.3),
                    ],
                    p=0.5,
                ),
                A.RandomGamma(gamma_limit=(80, 120), p=0.5),
            ]
        )

        # Check that image paths and labels match
        if labels is not None and len(image_paths) != len(labels):
            raise ValueError("Number of image paths does not match number of labels")

        # Create the cache directory
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]

        # Read the image - make sure a NumPy array is returned
        try:
            # Read the image with OpenCV
            image = cv2.imread(img_path)
            if image is None:
                raise FileNotFoundError(f"Could not load image: {img_path}")

            # Make sure the image is a NumPy array
            if not isinstance(image, np.ndarray):
                image = np.array(image)

            # Make sure the image has the expected number of channels
            if len(image.shape) == 2:  # grayscale image
                if self.input_channels == 3:
                    image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
                elif self.input_channels == 1:
                    image = np.expand_dims(image, axis=-1)
            else:  # color or multi-channel image
                if image.shape[2] == 3 and self.input_channels == 1:
                    image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
                    image = np.expand_dims(image, axis=-1)
                elif image.shape[2] == 4 and self.input_channels == 3:
                    image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
                elif image.shape[2] > self.input_channels:
                    image = image[:, :, :self.input_channels]

        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Fall back to a blank image (fixed 256x256; A.Compose exposes no height/width)
            image = np.zeros((256, 256, self.input_channels), dtype=np.uint8)

        if image.ndim == 2:
            image = np.expand_dims(image, axis=-1)

        if image.shape[-1] == 1 and self.input_channels == 3:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[-1] == 3 and self.input_channels == 1:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
            image = np.expand_dims(image, axis=-1)

        if image.dtype != np.uint8:
            image = image.astype(np.uint8)
        # Apply data augmentation
        if self.transform:
            try:
                # albumentations requires the image to be a NumPy array
                if isinstance(image, torch.Tensor):
                    image = image.numpy()
                if isinstance(image, Image.Image):  # PIL image
                    image = np.array(image)

                if image.ndim == 3 and image.shape[2] == 1:
                    image = np.squeeze(image, axis=2)

                augmented = self.transform(image=image)
                image = augmented['image']

                # For training data, apply the cell-specific augmentations
                if self.is_train and self.cell_specific_aug:
                    if isinstance(image, torch.Tensor):
                        image = image.numpy()
                    if isinstance(image, Image.Image):
                        image = np.array(image)

                    if image.ndim == 3 and image.shape[2] == 1:
                        image = np.squeeze(image, axis=2)

                    augmented = self.cell_specific_aug(image=image)
                    image = augmented['image']
            except Exception as e:
                print(f"Data augmentation failed: {e}")

        # Get the label
        label = self.labels[idx] if self.labels is not None else -1

        # Make sure the final output is a tensor
        if not isinstance(image, torch.Tensor):
            image = torch.from_numpy(image)

        return image, label