py2ls 0.1.10.12__py3-none-any.whl → 0.2.7.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of py2ls might be problematic. Click here for more details.

Files changed (72)
  1. py2ls/.DS_Store +0 -0
  2. py2ls/.git/.DS_Store +0 -0
  3. py2ls/.git/index +0 -0
  4. py2ls/.git/logs/refs/remotes/origin/HEAD +1 -0
  5. py2ls/.git/objects/.DS_Store +0 -0
  6. py2ls/.git/refs/.DS_Store +0 -0
  7. py2ls/ImageLoader.py +621 -0
  8. py2ls/__init__.py +7 -5
  9. py2ls/apptainer2ls.py +3940 -0
  10. py2ls/batman.py +164 -42
  11. py2ls/bio.py +2595 -0
  12. py2ls/cell_image_clf.py +1632 -0
  13. py2ls/container2ls.py +4635 -0
  14. py2ls/corr.py +475 -0
  15. py2ls/data/.DS_Store +0 -0
  16. py2ls/data/email/email_html_template.html +88 -0
  17. py2ls/data/hyper_param_autogluon_zeroshot2024.json +2383 -0
  18. py2ls/data/hyper_param_tabrepo_2024.py +1753 -0
  19. py2ls/data/mygenes_fields_241022.txt +355 -0
  20. py2ls/data/re_common_pattern.json +173 -0
  21. py2ls/data/sns_info.json +74 -0
  22. py2ls/data/styles/.DS_Store +0 -0
  23. py2ls/data/styles/example/.DS_Store +0 -0
  24. py2ls/data/styles/stylelib/.DS_Store +0 -0
  25. py2ls/data/styles/stylelib/grid.mplstyle +15 -0
  26. py2ls/data/styles/stylelib/high-contrast.mplstyle +6 -0
  27. py2ls/data/styles/stylelib/high-vis.mplstyle +4 -0
  28. py2ls/data/styles/stylelib/ieee.mplstyle +15 -0
  29. py2ls/data/styles/stylelib/light.mplstyl +6 -0
  30. py2ls/data/styles/stylelib/muted.mplstyle +6 -0
  31. py2ls/data/styles/stylelib/nature-reviews-latex.mplstyle +616 -0
  32. py2ls/data/styles/stylelib/nature-reviews.mplstyle +616 -0
  33. py2ls/data/styles/stylelib/nature.mplstyle +31 -0
  34. py2ls/data/styles/stylelib/no-latex.mplstyle +10 -0
  35. py2ls/data/styles/stylelib/notebook.mplstyle +36 -0
  36. py2ls/data/styles/stylelib/paper.mplstyle +290 -0
  37. py2ls/data/styles/stylelib/paper2.mplstyle +305 -0
  38. py2ls/data/styles/stylelib/retro.mplstyle +4 -0
  39. py2ls/data/styles/stylelib/sans.mplstyle +10 -0
  40. py2ls/data/styles/stylelib/scatter.mplstyle +7 -0
  41. py2ls/data/styles/stylelib/science.mplstyle +48 -0
  42. py2ls/data/styles/stylelib/std-colors.mplstyle +4 -0
  43. py2ls/data/styles/stylelib/vibrant.mplstyle +6 -0
  44. py2ls/data/tiles.csv +146 -0
  45. py2ls/data/usages_pd.json +1417 -0
  46. py2ls/data/usages_sns.json +31 -0
  47. py2ls/docker2ls.py +5446 -0
  48. py2ls/ec2ls.py +61 -0
  49. py2ls/fetch_update.py +145 -0
  50. py2ls/ich2ls.py +1955 -296
  51. py2ls/im2.py +8242 -0
  52. py2ls/image_ml2ls.py +2100 -0
  53. py2ls/ips.py +33909 -3418
  54. py2ls/ml2ls.py +7700 -0
  55. py2ls/mol.py +289 -0
  56. py2ls/mount2ls.py +1307 -0
  57. py2ls/netfinder.py +873 -351
  58. py2ls/nl2ls.py +283 -0
  59. py2ls/ocr.py +1581 -458
  60. py2ls/plot.py +10394 -314
  61. py2ls/rna2ls.py +311 -0
  62. py2ls/ssh2ls.md +456 -0
  63. py2ls/ssh2ls.py +5933 -0
  64. py2ls/ssh2ls_v01.py +2204 -0
  65. py2ls/stats.py +66 -172
  66. py2ls/temp20251124.py +509 -0
  67. py2ls/translator.py +2 -0
  68. py2ls/utils/decorators.py +3564 -0
  69. py2ls/utils_bio.py +3453 -0
  70. {py2ls-0.1.10.12.dist-info → py2ls-0.2.7.10.dist-info}/METADATA +113 -224
  71. {py2ls-0.1.10.12.dist-info → py2ls-0.2.7.10.dist-info}/RECORD +72 -16
  72. {py2ls-0.1.10.12.dist-info → py2ls-0.2.7.10.dist-info}/WHEEL +0 -0
@@ -0,0 +1,1632 @@
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ from torch.utils.data import Dataset, DataLoader, random_split
8
+ from torchvision import models
9
+ import albumentations as A
10
+ from albumentations.pytorch import ToTensorV2
11
+ from sklearn.model_selection import train_test_split
12
+ from sklearn.metrics import (
13
+ classification_report,
14
+ confusion_matrix,
15
+ roc_auc_score,
16
+ average_precision_score,
17
+ )
18
+ from sklearn.preprocessing import label_binarize
19
+ import matplotlib.pyplot as plt
20
+ import seaborn as sns
21
+ import pandas as pd
22
+ import joblib
23
+ from tqdm import tqdm
24
+ import json
25
+ from datetime import datetime
26
+ import time
27
+ from PIL import Image
28
+ import gc
29
+ from torch.cuda import amp
30
+ from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
31
+ import warnings
32
+
33
+ # 忽略警告
34
+ warnings.filterwarnings("ignore")
35
+
36
+ from .ips import set_computing_device
37
+ set_computing_device()
38
+
39
+ class CellImageClassifier:
40
+ """
41
+ 终极细胞图像分类系统
42
+
43
+ 主要改进:
44
+ 1. 支持混合精度训练(AMP) - 加速训练并减少显存占用
45
+ 2. 添加多种学习率调度器
46
+ 3. 增强多通道图像处理能力
47
+ 4. 添加模型解释功能(Grad-CAM)
48
+ 5. 改进数据加载效率(使用缓存)
49
+ 6. 添加更详细的评估指标(AUC-ROC, PR曲线等)
50
+ 7. 增强错误处理和日志记录
51
+ """
52
+
53
def __init__(self, config=None):
    """Initialise the classifier from defaults merged with user overrides.

    Args:
        config (dict | None): Optional overrides for any key of
            ``self.default_config``. Keys that are not supplied keep their
            default value.

    Side effects:
        Creates the model, report and data-cache directories named in the
        configuration and prints a one-line status message.
    """
    # Default configuration.
    self.default_config = {
        # Supported: resnet18/34/50/101, efficientnet_b0/b3, densenet121,
        # vgg16, convnext_tiny (see build_model).
        "model_name": "resnet18",
        "num_classes": 5,  # number of target classes
        "input_channels": 3,  # number of image channels
        "image_size": (256, 256),  # input image size
        "batch_size": 16,  # mini-batch size
        "learning_rate": 0.001,  # initial learning rate
        "epochs": 30,  # number of training epochs
        "device": self._auto_select_device(),  # auto-detected device
        "augmentation_level": "high",  # none, low, medium, high
        "staining_type": "generic",  # generic, H&E, fluorescence, IHC
        "model_save_path": "models",  # where checkpoints are written
        "report_save_path": "reports",  # where reports are written
        "data_cache_path": "data_cache",  # where cached data is written
        "use_amp": True,  # enable mixed-precision training
        "lr_scheduler": "plateau",  # plateau, cosine, step
        "optimizer": "adam",  # adam, sgd
        "early_stopping": 5,  # patience in epochs (0 disables)
        "class_weights": None,  # per-class loss weights for imbalance
        "grad_cam_layers": ["layer4"],  # layers visualised by Grad-CAM
    }

    # User-supplied values take precedence over the defaults.
    self.config = {**self.default_config, **(config or {})}

    # Model and label state is filled in by training or load_model.
    self.model = None
    self.class_names = None
    self.label_encoder = None
    self.stats = {
        "train_loss": [],
        "val_loss": [],
        "val_accuracy": [],
        "val_auc": [],
        "val_f1": [],
    }

    # Create the output directories up front.
    for path_key in ("model_save_path", "report_save_path", "data_cache_path"):
        os.makedirs(self.config[path_key], exist_ok=True)

    # ``amp`` is torch.cuda.amp, so the scaler is only enabled when the
    # configured device is a CUDA device; elsewhere it is a pass-through.
    amp_enabled = bool(self.config["use_amp"]) and str(
        self.config["device"]
    ).startswith("cuda")
    self.scaler = amp.GradScaler(enabled=amp_enabled)

    print(f"细胞图像分类器初始化完成 | 设备: {self.config['device']}")
101
+
102
+ @staticmethod
103
def _auto_select_device():
    """Pick the best available compute device.

    Returns:
        str: ``"cuda"`` if CUDA is available, otherwise ``"mps"`` if the MPS
        backend is available, otherwise ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    # Older PyTorch builds have no ``torch.backends.mps`` attribute at all,
    # so look it up defensively instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
111
+
112
def _clear_cache(self):
    """Free unreferenced Python objects, then release cached CUDA memory."""
    # Collect first: tensors that have just become garbage must be freed
    # before the CUDA caching allocator can hand their blocks back.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
116
+
117
def build_model(self):
    """Create the classification network described by ``self.config``.

    Reads ``model_name``, ``num_classes``, ``input_channels`` and ``device``
    from the configuration, instantiates the torchvision backbone (trying
    pretrained weights first), adapts the first convolution to the requested
    channel count, replaces the classification head, and moves the model to
    the configured device.

    Returns:
        torch.nn.Module: The model, also stored in ``self.model``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported backbones.
    """
    model_name = self.config["model_name"].lower()
    num_classes = self.config["num_classes"]
    in_channels = self.config["input_channels"]

    # Supported backbones.
    supported_models = {
        "resnet18": models.resnet18,
        "resnet34": models.resnet34,
        "resnet50": models.resnet50,
        "resnet101": models.resnet101,
        "efficientnet_b0": models.efficientnet_b0,
        "efficientnet_b3": models.efficientnet_b3,
        "densenet121": models.densenet121,
        "vgg16": models.vgg16,
        "convnext_tiny": models.convnext_tiny,
    }

    if model_name not in supported_models:
        raise ValueError(
            f"不支持的模型: {model_name}。支持: {list(supported_models.keys())}"
        )

    model_func = supported_models[model_name]
    # Best effort: fall back to random initialisation when the pretrained
    # weights cannot be obtained (e.g. no network access).
    # NOTE(review): ``pretrained=`` is deprecated in recent torchvision in
    # favour of ``weights=``; kept for compatibility with older releases.
    try:
        pretrained_model = model_func(pretrained=True)
    except Exception as e:
        print(f"加载预训练模型失败: {e}, 使用随机初始化")
        pretrained_model = model_func(pretrained=False)

    def adapt_input_conv(original):
        """Return a first conv that accepts ``in_channels`` inputs."""
        if original.in_channels == in_channels:
            # Channel count already matches: keep the layer so that its
            # (pretrained) weights are not thrown away.
            return original
        return nn.Conv2d(
            in_channels=in_channels,
            out_channels=original.out_channels,
            kernel_size=original.kernel_size,
            stride=original.stride,
            padding=original.padding,
            # ``bias`` expects a bool, not the bias tensor itself.
            bias=original.bias is not None,
        )

    if "resnet" in model_name:
        pretrained_model.conv1 = adapt_input_conv(pretrained_model.conv1)
        # Replace the final fully connected layer with a small MLP head.
        num_features = pretrained_model.fc.in_features
        pretrained_model.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    elif "efficientnet" in model_name:
        pretrained_model.features[0][0] = adapt_input_conv(
            pretrained_model.features[0][0]
        )
        num_features = pretrained_model.classifier[1].in_features
        pretrained_model.classifier = nn.Sequential(
            nn.Dropout(0.4), nn.Linear(num_features, num_classes)
        )

    elif "densenet" in model_name:
        pretrained_model.features.conv0 = adapt_input_conv(
            pretrained_model.features.conv0
        )
        num_features = pretrained_model.classifier.in_features
        pretrained_model.classifier = nn.Linear(num_features, num_classes)

    elif "vgg" in model_name:
        pretrained_model.features[0] = adapt_input_conv(
            pretrained_model.features[0]
        )
        num_features = pretrained_model.classifier[6].in_features
        pretrained_model.classifier[6] = nn.Linear(num_features, num_classes)

    elif "convnext" in model_name:
        pretrained_model.features[0][0] = adapt_input_conv(
            pretrained_model.features[0][0]
        )
        num_features = pretrained_model.classifier[2].in_features
        pretrained_model.classifier[2] = nn.Linear(num_features, num_classes)

    self.model = pretrained_model.to(self.config["device"])
    print(
        f"创建模型: {model_name} | 输入通道: {in_channels} | 类别数: {num_classes}"
    )
    return self.model
236
+
237
def get_augmentations(self):
    """Build the albumentations pipeline for the configured staining type.

    Reads ``staining_type``, ``augmentation_level``, ``image_size`` and
    ``input_channels`` from ``self.config``. Resize, normalisation and tensor
    conversion are always applied last; with ``augmentation_level == "none"``
    they are the only transforms.

    Returns:
        albumentations.Compose: The composed transform.
    """
    staining_type = self.config["staining_type"]
    aug_level = self.config["augmentation_level"]
    image_size = self.config["image_size"]
    num_channels = self.config.get("input_channels", 3)

    # Normalisation statistics must have one entry per channel. The ImageNet
    # values only describe RGB, so other channel counts use their average.
    if num_channels == 3:
        norm_mean = [0.485, 0.456, 0.406]
        norm_std = [0.229, 0.224, 0.225]
    else:
        norm_mean = [0.449] * num_channels
        norm_std = [0.226] * num_channels

    # Base transforms (always applied).
    base_transforms = [
        A.Resize(image_size[0], image_size[1]),
        A.Normalize(mean=norm_mean, std=norm_std),
        ToTensorV2(),
    ]

    if aug_level == "none":
        return A.Compose(base_transforms)

    # Augmentations shared by every staining type.
    common_aug = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(limit=30, p=0.5),
        A.ShiftScaleRotate(
            shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5
        ),
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5),
    ]

    # Staining-specific augmentations.
    # NOTE(review): several of these are colour transforms and presumably
    # expect 3-channel input - confirm before using other channel counts.
    if staining_type == "H&E":
        # H&E: emphasise nucleus/cytoplasm contrast.
        staining_aug = [
            A.CLAHE(clip_limit=3.0, tile_grid_size=(8, 8), p=0.7),
            A.HueSaturationValue(
                hue_shift_limit=5, sat_shift_limit=20, val_shift_limit=10, p=0.5
            ),
            A.RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.3),
            A.RandomGamma(gamma_limit=(70, 130), p=0.5),  # contrast variation
        ]
    elif staining_type == "fluorescence":
        # Fluorescence: mostly intensity changes plus rare channel faults.
        staining_aug = [
            A.RandomGamma(gamma_limit=(80, 120), p=0.5),
            A.MultiplicativeNoise(multiplier=[0.9, 1.1], elementwise=True, p=0.2),
            A.ChannelShuffle(p=0.1),  # simulate swapped channels
            A.ChannelDropout(
                channel_drop_range=(1, 1), fill_value=0, p=0.1
            ),  # simulate a missing channel
        ]
    elif staining_type == "IHC":
        # Immunohistochemistry: vary the brown precipitate appearance.
        staining_aug = [
            A.ColorJitter(
                brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.7
            ),
            A.RGBShift(r_shift_limit=20, g_shift_limit=10, b_shift_limit=5, p=0.5),
            A.Sharpen(alpha=(0.2, 0.5), p=0.3),  # enhance detail
        ]
    else:  # generic
        staining_aug = [
            A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.3),
            A.HueSaturationValue(
                hue_shift_limit=5, sat_shift_limit=10, val_shift_limit=5, p=0.5
            ),
            A.RandomGamma(gamma_limit=(80, 120), p=0.5),
        ]

    # Stronger augmentations for the "medium" and "high" levels.
    advanced_aug = []
    if aug_level in ["medium", "high"]:
        advanced_aug = [
            A.OneOf(
                [
                    A.MotionBlur(blur_limit=3, p=0.3),
                    A.GaussianBlur(blur_limit=3, p=0.3),
                    A.MedianBlur(blur_limit=3, p=0.3),
                ],
                p=0.5,
            ),
            A.OneOf(
                [
                    A.OpticalDistortion(distort_limit=0.5, shift_limit=0.1, p=0.3),
                    A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3),
                    A.ElasticTransform(alpha=1, sigma=20, alpha_affine=10, p=0.3),
                ],
                p=0.5,
            ),
            A.CoarseDropout(max_holes=3, max_height=20, max_width=20, p=0.3),
        ]

    if aug_level == "high":
        advanced_aug += [
            A.RandomShadow(
                shadow_roi=(0, 0.5, 1, 1),
                num_shadows_lower=1,
                num_shadows_upper=2,
                p=0.2,
            ),
            A.RandomSunFlare(src_radius=100, p=0.1),
            A.RandomFog(fog_coef_lower=0.1, fog_coef_upper=0.3, p=0.1),
            A.RandomRain(p=0.1),  # simulate droplets on the microscope
        ]

    # Random augmentations first, deterministic base transforms last.
    all_aug = common_aug + staining_aug + advanced_aug
    return A.Compose(all_aug + base_transforms)
344
+
345
def create_dataset(self, image_paths, labels, is_train=True):
    """Create a cell-image dataset.

    Args:
        image_paths (list): Image file paths.
        labels (list | None): Labels aligned with ``image_paths``; ``None``
            for prediction, in which case every sample gets the dummy
            label ``-1``.
        is_train (bool): Whether the dataset is used for training. Training
            datasets get the random augmentation pipeline and the data
            cache; other datasets only get the deterministic base
            transforms and no cache.

    Returns:
        CellDataset: The dataset object.
    """
    if labels is None:
        # Prediction: no real labels, and the encoder must not be built
        # from the dummy value.
        encoded_labels = [-1] * len(image_paths)
    else:
        # Build the label encoder on first use.
        if self.label_encoder is None:
            unique_labels = sorted(set(labels))
            self.label_encoder = {
                label: idx for idx, label in enumerate(unique_labels)
            }
            self.class_names = list(self.label_encoder.keys())
            self.config["num_classes"] = len(self.class_names)
            print(f"创建标签编码器 | 类别: {self.class_names}")

        # Labels unknown to the encoder are mapped to -1.
        encoded_labels = [self.label_encoder.get(label, -1) for label in labels]

    if is_train:
        transform = self.get_augmentations()
    else:
        # Evaluation and prediction must be deterministic: temporarily
        # request the "none" level so only the base transforms are built.
        saved_level = self.config["augmentation_level"]
        self.config["augmentation_level"] = "none"
        try:
            transform = self.get_augmentations()
        finally:
            self.config["augmentation_level"] = saved_level

    return CellDataset(
        image_paths=image_paths,
        labels=encoded_labels,
        transform=transform,
        is_train=is_train,
        input_channels=self.config["input_channels"],
        cache_dir=self.config["data_cache_path"] if is_train else None,
    )
383
+
384
def train(self, image_paths, labels, val_split=0.2, save_best=True):
    """Train the cell-image classification model.

    Args:
        image_paths (list): Image file paths.
        labels (list): Labels aligned with ``image_paths``.
        val_split (float): Fraction of the data held out for validation;
            ``0`` disables validation.
        save_best (bool): Save ``best_model.pth`` whenever validation
            accuracy improves.

    Returns:
        dict: ``self.stats`` with per-epoch ``train_loss``, ``val_loss``,
        ``val_accuracy``, ``val_auc`` (from class probabilities; ``nan``
        when undefined for the split) and ``val_f1`` (macro-averaged F1).
    """
    # Local imports: these names are not imported at module level.
    from sklearn.metrics import f1_score
    from torch.utils.data import Subset

    self._clear_cache()

    device = self.config["device"]
    on_cuda = str(device).startswith("cuda")
    # ``amp`` is torch.cuda.amp, so mixed precision only applies on CUDA.
    use_amp = bool(self.config["use_amp"]) and on_cuda

    full_dataset = self.create_dataset(image_paths, labels, is_train=True)

    # Split into training and validation sets.
    val_size = int(len(full_dataset) * val_split) if val_split > 0 else 0
    if val_size > 0:
        train_size = len(full_dataset) - val_size
        train_dataset, val_subset = random_split(
            full_dataset,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(42),
        )
        # Validation must not see random augmentations: build a second,
        # non-training view of the same samples with the "none" level and
        # reuse the indices chosen by the split.
        saved_level = self.config["augmentation_level"]
        self.config["augmentation_level"] = "none"
        try:
            eval_dataset = self.create_dataset(image_paths, labels, is_train=False)
        finally:
            self.config["augmentation_level"] = saved_level
        val_dataset = Subset(eval_dataset, val_subset.indices)
    else:
        train_dataset = full_dataset
        val_dataset = None

    # ``persistent_workers`` requires num_workers > 0, so it is not set for
    # these single-process loaders; pinned memory only helps CUDA transfers.
    train_loader = DataLoader(
        train_dataset,
        batch_size=self.config["batch_size"],
        shuffle=True,
        num_workers=0,
        pin_memory=on_cuda,
    )

    val_loader = None
    if val_dataset is not None:
        val_loader = DataLoader(
            val_dataset,
            batch_size=self.config["batch_size"],
            shuffle=False,
            num_workers=0,
            pin_memory=on_cuda,
        )

    # Build the model if it does not exist yet.
    if self.model is None:
        self.build_model()

    # Optimizer (Adam unless "sgd" is requested).
    if self.config["optimizer"].lower() == "sgd":
        optimizer = optim.SGD(
            self.model.parameters(),
            lr=self.config["learning_rate"],
            momentum=0.9,
            weight_decay=1e-4,
        )
    else:
        optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.config["learning_rate"],
            weight_decay=1e-4,
        )

    # Loss function, optionally weighted for class imbalance.
    class_weights = self.config["class_weights"]
    if class_weights is not None and len(class_weights) > 0:
        weights = torch.tensor(class_weights, dtype=torch.float32, device=device)
        criterion = nn.CrossEntropyLoss(weight=weights)
    else:
        criterion = nn.CrossEntropyLoss()

    # Learning-rate scheduler (plateau for any unrecognised name).
    scheduler_name = self.config["lr_scheduler"].lower()
    if scheduler_name == "cosine":
        scheduler = CosineAnnealingLR(optimizer, T_max=self.config["epochs"])
    elif scheduler_name == "step":
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    else:
        # ``verbose`` is deprecated in recent PyTorch; learning-rate changes
        # are reported manually after each step instead.
        scheduler = ReduceLROnPlateau(optimizer, mode="min", patience=3, factor=0.5)
    # Decide how to step from the actual scheduler type, not its name.
    uses_plateau = isinstance(scheduler, ReduceLROnPlateau)

    best_val_accuracy = 0.0
    epochs_no_improve = 0
    patience = self.config["early_stopping"]

    for epoch in range(self.config["epochs"]):
        start_time = time.time()

        # ---- training phase ----
        self.model.train()
        train_loss = 0.0
        progress_bar = tqdm(
            train_loader, desc=f"Epoch {epoch+1}/{self.config['epochs']} [Train]"
        )

        for images, targets in progress_bar:
            images = images.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            with amp.autocast(enabled=use_amp):
                outputs = self.model(images)
                loss = criterion(outputs, targets)

            if use_amp:
                self.scaler.scale(loss).backward()
                self.scaler.step(optimizer)
                self.scaler.update()
            else:
                loss.backward()
                optimizer.step()

            train_loss += loss.item() * images.size(0)
            progress_bar.set_postfix(loss=loss.item())

        train_loss = train_loss / len(train_loader.dataset)
        self.stats["train_loss"].append(train_loss)

        # ---- validation phase ----
        val_loss = 0.0
        val_accuracy = 0.0

        if val_loader:
            self.model.eval()
            correct = 0
            total = 0
            all_labels = []
            all_preds = []
            all_probs = []
            progress_bar = tqdm(
                val_loader, desc=f"Epoch {epoch+1}/{self.config['epochs']} [Val]"
            )

            with torch.no_grad():
                for images, targets in progress_bar:
                    images = images.to(device)
                    targets = targets.to(device)

                    with amp.autocast(enabled=use_amp):
                        outputs = self.model(images)
                        loss = criterion(outputs, targets)

                    val_loss += loss.item() * images.size(0)

                    # Softmax in float32 to avoid half-precision overflow.
                    probs = torch.softmax(outputs.float(), dim=1)
                    predicted = probs.argmax(dim=1)
                    total += targets.size(0)
                    correct += (predicted == targets).sum().item()

                    # Keep everything needed for the epoch-level metrics.
                    all_labels.extend(targets.cpu().numpy())
                    all_preds.extend(predicted.cpu().numpy())
                    all_probs.extend(probs.cpu().numpy())

                    progress_bar.set_postfix(
                        loss=loss.item(), accuracy=correct / total
                    )

            val_loss = val_loss / len(val_loader.dataset)
            val_accuracy = correct / total
            self.stats["val_loss"].append(val_loss)
            self.stats["val_accuracy"].append(val_accuracy)

            # AUC from probabilities. It is undefined when a class is
            # missing from the split; record NaN instead of aborting.
            y_true = np.asarray(all_labels)
            y_prob = np.asarray(all_probs)
            num_classes = len(self.class_names)
            try:
                if num_classes > 2:
                    y_true_bin = label_binarize(
                        y_true, classes=list(range(num_classes))
                    )
                    auc = float(
                        roc_auc_score(y_true_bin, y_prob, multi_class="ovr")
                    )
                else:
                    auc = float(roc_auc_score(y_true, y_prob[:, 1]))
            except (ValueError, IndexError):
                auc = float("nan")
            f1 = float(
                f1_score(y_true, all_preds, average="macro", zero_division=0)
            )

            self.stats["val_auc"].append(auc)
            self.stats["val_f1"].append(f1)

            # Update the learning rate.
            if uses_plateau:
                lr_before = optimizer.param_groups[0]["lr"]
                scheduler.step(val_loss)
                lr_after = optimizer.param_groups[0]["lr"]
                if lr_after != lr_before:
                    print(f"学习率调整: {lr_before:.2e} -> {lr_after:.2e}")
            else:
                scheduler.step()

            # Track improvement independently of ``save_best`` so early
            # stopping also works when checkpoints are not being saved.
            if val_accuracy > best_val_accuracy:
                best_val_accuracy = val_accuracy
                epochs_no_improve = 0
                if save_best:
                    self.save_model("best_model.pth")
                    print(
                        f"保存最佳模型 | 验证准确率: {val_accuracy:.4f} | AUC: {auc:.4f}"
                    )
            else:
                epochs_no_improve += 1
                if patience > 0 and epochs_no_improve >= patience:
                    print(f"早停触发: 验证准确率连续{epochs_no_improve}轮未提升")
                    break
        else:
            # Without validation data a plateau scheduler has no metric.
            if not uses_plateau:
                scheduler.step()

        epoch_time = time.time() - start_time

        # Epoch summary.
        if val_loader:
            print(
                f"Epoch {epoch+1}/{self.config['epochs']} | "
                f"Time: {epoch_time:.1f}s | "
                f"Train Loss: {train_loss:.4f} | "
                f"Val Loss: {val_loss:.4f} | "
                f"Val Acc: {val_accuracy:.4f} | "
                f"AUC: {auc:.4f} | "
                f"F1: {f1:.4f}"
            )
        else:
            print(
                f"Epoch {epoch+1}/{self.config['epochs']} | "
                f"Time: {epoch_time:.1f}s | "
                f"Train Loss: {train_loss:.4f}"
            )

    # Save the final model.
    self.save_model("final_model.pth")
    print("训练完成!")

    # Plot the training curves.
    self.plot_training_history()

    return self.stats
630
+
631
def predict(self, image_paths, output_dir=None):
    """Run inference with the trained model.

    Args:
        image_paths (list): Image file paths.
        output_dir (str | None): When given, a CSV of predictions is written
            there and the prediction report and (if ``grad_cam_layers`` is
            configured) Grad-CAM visualisations are generated into it.

    Returns:
        tuple: ``(predicted labels, probability array)`` where the labels
        are decoded through ``self.class_names``.

    Raises:
        RuntimeError: If no model is loaded or the class names are unknown.
    """
    if self.model is None:
        raise RuntimeError("模型未加载,请先加载或训练模型")
    if self.class_names is None:
        # Without class names the predicted indices cannot be decoded.
        raise RuntimeError("类别名称未设置,请先加载或训练模型")

    self._clear_cache()

    device = self.config["device"]
    on_cuda = str(device).startswith("cuda")
    # ``amp`` is torch.cuda.amp, so mixed precision only applies on CUDA.
    use_amp = bool(self.config["use_amp"]) and on_cuda

    # Prediction dataset (no labels).
    dataset = self.create_dataset(image_paths, labels=None, is_train=False)
    dataloader = DataLoader(
        dataset,
        batch_size=self.config["batch_size"],
        shuffle=False,
        num_workers=0,
        pin_memory=on_cuda,
    )

    self.model.eval()
    all_predictions = []
    all_probabilities = []
    all_logits = []
    with torch.no_grad():
        for images, _ in tqdm(dataloader, desc="预测中"):
            images = images.to(device)

            with amp.autocast(enabled=use_amp):
                outputs = self.model(images)

            # Autocast may yield half-precision logits; do the softmax in
            # float32 so large logits cannot overflow.
            logits = outputs.float()
            probabilities = torch.softmax(logits, dim=1)
            predicted = torch.argmax(logits, dim=1)

            all_predictions.extend(predicted.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())
            all_logits.extend(logits.cpu().numpy())

    # Decode numeric predictions back to the original labels.
    decoded_predictions = [self.class_names[p] for p in all_predictions]

    # Save the prediction results.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        report_path = os.path.join(output_dir, f"predictions_{timestamp}.csv")

        # One row per image: top class, its probability, then all classes.
        results = []
        for i, path in enumerate(image_paths):
            row = {
                "image_path": path,
                "predicted_class": decoded_predictions[i],
                "probability": all_probabilities[i][all_predictions[i]],
            }
            for cls, prob_val in zip(self.class_names, all_probabilities[i]):
                row[f"prob_{cls}"] = prob_val
            results.append(row)

        df = pd.DataFrame(results)
        df.to_csv(report_path, index=False)
        print(f"预测结果保存至: {report_path}")

        # Visual report.
        self.generate_prediction_report(
            image_paths, decoded_predictions, all_probabilities, output_dir
        )

        # Grad-CAM visualisations.
        if self.config["grad_cam_layers"]:
            self.generate_grad_cam(
                image_paths,
                all_logits,
                output_dir,
                layer_names=self.config["grad_cam_layers"],
            )

    return decoded_predictions, np.array(all_probabilities)
752
+
753
def evaluate(self, image_paths, labels):
    """Evaluate the model on labelled images and write the reports.

    Args:
        image_paths (list): Image file paths.
        labels (list): Ground-truth labels aligned with ``image_paths``.

    Returns:
        dict: ``classification_report``, ``confusion_matrix`` (nested list),
        ``auc_roc`` and ``average_precision``. The last two are ``nan`` when
        they are undefined for the given labels.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """
    # Fail before running inference rather than inside the metric calls.
    if len(image_paths) != len(labels):
        raise ValueError(
            f"image_paths ({len(image_paths)}) and labels ({len(labels)}) "
            "must have the same length"
        )

    predictions, probabilities = self.predict(image_paths)

    # Replace NaN probabilities with 0.
    probabilities = np.nan_to_num(probabilities, nan=0.0)

    # Encode ground truth and decoded predictions as class indices.
    true_labels_encoded = [self.label_encoder[label] for label in labels]
    predictions_encoded = [self.label_encoder[pred] for pred in predictions]

    num_classes = len(self.class_names)
    class_indices = range(num_classes)

    report = classification_report(
        true_labels_encoded, predictions_encoded, output_dict=True
    )
    cm = confusion_matrix(
        true_labels_encoded, predictions_encoded, labels=class_indices
    )

    # AUC-ROC and average precision. Both are undefined when a class has no
    # positive samples; report NaN instead of discarding the evaluation.
    try:
        if num_classes > 2:
            y_true_bin = label_binarize(true_labels_encoded, classes=class_indices)
            auc = roc_auc_score(y_true_bin, probabilities, multi_class="ovr")
            ap = average_precision_score(y_true_bin, probabilities)
        else:
            auc = roc_auc_score(true_labels_encoded, probabilities[:, 1])
            ap = average_precision_score(true_labels_encoded, probabilities[:, 1])
    except (ValueError, IndexError):
        auc = float("nan")
        ap = float("nan")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    report_dir = self.config["report_save_path"]
    os.makedirs(report_dir, exist_ok=True)

    # Classification report (JSON).
    report_path = os.path.join(
        report_dir, f"classification_report_{timestamp}.json"
    )
    with open(report_path, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=4)

    # Confusion matrix heatmap; always close the figure, even on failure.
    cm_path = os.path.join(report_dir, f"confusion_matrix_{timestamp}.png")
    plt.figure(figsize=(12, 10))
    try:
        sns.heatmap(
            cm,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=self.class_names,
            yticklabels=self.class_names,
        )
        plt.xlabel("预测标签")
        plt.ylabel("真实标签")
        plt.title("混淆矩阵")
        plt.savefig(cm_path, bbox_inches="tight")
    finally:
        plt.close()

    # ROC curve.
    self.plot_roc_curve(
        true_labels_encoded,
        probabilities,
        save_path=os.path.join(report_dir, f"roc_curve_{timestamp}.png"),
    )

    # Precision-recall curve.
    self.plot_pr_curve(
        true_labels_encoded,
        probabilities,
        save_path=os.path.join(report_dir, f"pr_curve_{timestamp}.png"),
    )

    print(f"评估报告保存至: {report_path}")
    print(f"混淆矩阵保存至: {cm_path}")
    print(f"AUC-ROC: {auc:.4f} | Average Precision: {ap:.4f}")

    return {
        "classification_report": report,
        "confusion_matrix": cm.tolist(),
        "auc_roc": auc,
        "average_precision": ap,
    }
857
+
858
def save_model(self, filename):
    """Save the model weights together with the metadata needed to reload.

    Args:
        filename (str): Checkpoint file name, created inside
            ``self.config["model_save_path"]``.

    Returns:
        str: Full path of the written checkpoint.

    Raises:
        RuntimeError: If there is no model to save.
    """
    if self.model is None:
        raise RuntimeError("No model to save: build, train or load one first")

    save_dir = self.config["model_save_path"]
    # The directory is created in __init__, but the configuration can be
    # replaced afterwards (e.g. by load_model), so make sure it exists.
    os.makedirs(save_dir, exist_ok=True)
    model_path = os.path.join(save_dir, filename)

    checkpoint = {
        "model_state_dict": self.model.state_dict(),
        "config": self.config,
        "label_encoder": self.label_encoder,
        "class_names": self.class_names,
        "stats": self.stats,
    }
    torch.save(checkpoint, model_path)

    print(f"模型保存至: {model_path}")
    return model_path
881
+
882
def load_model(self, model_path):
    """Load a checkpoint written by ``save_model``.

    Restores the configuration (keeping the current device), the label
    encoder / class names and the training statistics, rebuilds the
    architecture with ``build_model`` and loads the weights.

    Args:
        model_path (str): Path of the checkpoint file.

    Returns:
        torch.nn.Module: The loaded model in evaluation mode.

    Raises:
        KeyError: If the checkpoint has no ``model_state_dict`` entry.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(
        model_path, map_location=torch.device(self.config["device"])
    )

    if "model_state_dict" not in checkpoint:
        raise KeyError(f"'model_state_dict' missing from checkpoint: {model_path}")

    # Restore the configuration without overriding the current device and
    # without mutating the dictionary stored in the checkpoint.
    if "config" in checkpoint:
        loaded_config = {**checkpoint["config"], "device": self.config["device"]}
        self.config = {**self.config, **loaded_config}
        # The restored paths may point at directories that do not exist on
        # this machine yet.
        for path_key in ("model_save_path", "report_save_path", "data_cache_path"):
            directory = self.config.get(path_key)
            if directory:
                os.makedirs(directory, exist_ok=True)

    # Restore the label encoder and class names.
    if "label_encoder" in checkpoint:
        self.label_encoder = checkpoint["label_encoder"]
        self.class_names = checkpoint.get(
            "class_names", list(self.label_encoder.keys())
        )
        self.config["num_classes"] = len(self.class_names)
        print(f"加载标签编码器 | 类别: {self.class_names}")
    elif "class_names" in checkpoint:
        self.class_names = checkpoint["class_names"]
        self.label_encoder = {cls: idx for idx, cls in enumerate(self.class_names)}
        self.config["num_classes"] = len(self.class_names)
        print(f"从类别名称重建标签编码器 | 类别: {self.class_names}")

    # Restore the training statistics.
    if "stats" in checkpoint:
        self.stats = checkpoint["stats"]

    # Rebuild the architecture, then load the weights into it.
    self.build_model()
    self.model.load_state_dict(checkpoint["model_state_dict"])
    self.model.eval()
    print(f"模型从 {model_path} 加载成功")

    return self.model
928
+
929
def plot_training_history(self):
    """Plot the recorded training curves and save them as a PNG.

    Draws up to four panels from ``self.stats`` (loss, validation accuracy,
    validation AUC, validation F1) and writes the figure into
    ``self.config["report_save_path"]``. Prints a message and returns
    without plotting when no training loss has been recorded.
    """
    if not self.stats["train_loss"]:
        print("没有训练历史数据")
        return

    report_dir = self.config["report_save_path"]
    # The directory may be missing if the configuration changed after
    # __init__ created the original one.
    os.makedirs(report_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    plot_path = os.path.join(report_dir, f"training_history_{timestamp}.png")

    # Validation panels: (stats key, subplot index, label, title, y-label, colour).
    val_panels = [
        ("val_accuracy", 2, "验证准确率", "验证准确率", "准确率", "green"),
        ("val_auc", 3, "验证AUC", "验证AUC", "AUC", "purple"),
        ("val_f1", 4, "验证F1分数", "验证F1分数", "F1分数", "orange"),
    ]

    plt.figure(figsize=(16, 12))
    try:
        # Loss curves.
        plt.subplot(2, 2, 1)
        plt.plot(self.stats["train_loss"], label="训练损失")
        if self.stats["val_loss"]:
            plt.plot(self.stats["val_loss"], label="验证损失")
        plt.title("训练和验证损失")
        plt.xlabel("轮次")
        plt.ylabel("损失")
        plt.legend()
        plt.grid(True)

        # One panel per validation metric that has data.
        for key, position, label, title, ylabel, color in val_panels:
            if not self.stats[key]:
                continue
            plt.subplot(2, 2, position)
            plt.plot(self.stats[key], label=label, color=color)
            plt.title(title)
            plt.xlabel("轮次")
            plt.ylabel(ylabel)
            plt.legend()
            plt.grid(True)

        plt.tight_layout()
        plt.savefig(plot_path, bbox_inches="tight", dpi=300)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close()

    print(f"训练历史图保存至: {plot_path}")
988
+
989
def plot_roc_curve(self, true_labels, probabilities, save_path=None):
    """
    Plot ROC curves for the model's predictions.

    With exactly two classes a single curve is drawn from the
    positive-class scores. Otherwise one curve per class plus a
    micro-averaged curve is drawn.

    Args:
        true_labels: Integer class indices, one per sample (any
            array-like).
        probabilities: Per-class scores of shape (n_samples, n_classes)
            (any array-like). In the two-class case a 1-D array of
            positive-class scores is accepted as well.
        save_path (str | None): When given, the figure is saved there and
            closed; otherwise it is shown interactively.

    Returns:
        ``save_path`` when the figure was saved, otherwise ``None``.
    """
    from itertools import cycle

    from sklearn.metrics import auc, roc_curve
    from sklearn.preprocessing import label_binarize

    # Plain lists do not support the 2-D slicing / ravel() used below.
    true_labels = np.asarray(true_labels)
    probabilities = np.asarray(probabilities)
    num_classes = len(self.class_names)

    if num_classes == 2:
        # Binary case: score of the positive class (column 1).
        positive_scores = (
            probabilities if probabilities.ndim == 1 else probabilities[:, 1]
        )
        fpr, tpr, _ = roc_curve(true_labels, positive_scores)
        roc_auc = auc(fpr, tpr)

        plt.figure()
        plt.plot(
            fpr,
            tpr,
            color="darkorange",
            lw=2,
            label=f"ROC曲线 (AUC = {roc_auc:.2f})",
        )
        plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel("假正例率")
        plt.ylabel("真正例率")
        plt.title("ROC曲线")
        plt.legend(loc="lower right")
    else:
        # Multi-class case: one-vs-rest curves on binarized labels.
        fpr = {}
        tpr = {}
        roc_auc = {}

        y_true_bin = label_binarize(true_labels, classes=range(num_classes))

        for i in range(num_classes):
            fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], probabilities[:, i])
            roc_auc[i] = auc(fpr[i], tpr[i])

        # Micro-average: pool every (sample, class) decision together.
        fpr["micro"], tpr["micro"], _ = roc_curve(
            y_true_bin.ravel(), probabilities.ravel()
        )
        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

        plt.figure(figsize=(10, 8))
        colors = cycle(
            ["aqua", "darkorange", "cornflowerblue", "green", "red", "purple"]
        )

        for i, color in zip(range(num_classes), colors):
            plt.plot(
                fpr[i],
                tpr[i],
                color=color,
                lw=2,
                label=f"ROC {self.class_names[i]} (AUC = {roc_auc[i]:.2f})",
            )

        plt.plot(
            fpr["micro"],
            tpr["micro"],
            label=f'Micro-average ROC (AUC = {roc_auc["micro"]:.2f})',
            color="deeppink",
            linestyle=":",
            linewidth=4,
        )

        # Chance-level diagonal for reference.
        plt.plot([0, 1], [0, 1], "k--", lw=2)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel("假正例率")
        plt.ylabel("真正例率")
        plt.title("多类别ROC曲线")
        plt.legend(loc="lower right")

    if save_path:
        plt.savefig(save_path, bbox_inches="tight", dpi=300)
        plt.close()
        return save_path

    plt.show()
    return None
1076
+
1077
def plot_pr_curve(self, true_labels, probabilities, save_path=None):
    """
    Plot precision-recall curves for the model's predictions.

    With exactly two classes a single curve is drawn from the
    positive-class scores. Otherwise one curve per class plus a
    micro-averaged curve is drawn.

    Args:
        true_labels: Integer class indices, one per sample (any
            array-like).
        probabilities: Per-class scores of shape (n_samples, n_classes)
            (any array-like). In the two-class case a 1-D array of
            positive-class scores is accepted as well.
        save_path (str | None): When given, the figure is saved there and
            closed; otherwise it is shown interactively.

    Returns:
        ``save_path`` when the figure was saved, otherwise ``None``.
    """
    from itertools import cycle

    from sklearn.metrics import average_precision_score, precision_recall_curve
    from sklearn.preprocessing import label_binarize

    # Plain lists do not support the 2-D slicing / ravel() used below.
    true_labels = np.asarray(true_labels)
    probabilities = np.asarray(probabilities)
    num_classes = len(self.class_names)

    if num_classes == 2:
        # Binary case: score of the positive class (column 1).
        positive_scores = (
            probabilities if probabilities.ndim == 1 else probabilities[:, 1]
        )
        precision, recall, _ = precision_recall_curve(true_labels, positive_scores)
        ap = average_precision_score(true_labels, positive_scores)

        plt.figure()
        plt.plot(
            recall,
            precision,
            color="darkorange",
            lw=2,
            label=f"PR曲线 (AP = {ap:.2f})",
        )
        plt.xlabel("召回率")
        plt.ylabel("精确率")
        plt.title("精确率-召回率曲线")
        plt.legend(loc="upper right")
    else:
        # Multi-class case: one-vs-rest curves on binarized labels.
        precision = {}
        recall = {}
        average_precision = {}

        y_true_bin = label_binarize(true_labels, classes=range(num_classes))

        for i in range(num_classes):
            precision[i], recall[i], _ = precision_recall_curve(
                y_true_bin[:, i], probabilities[:, i]
            )
            average_precision[i] = average_precision_score(
                y_true_bin[:, i], probabilities[:, i]
            )

        # Micro-average: pool every (sample, class) decision together.
        precision["micro"], recall["micro"], _ = precision_recall_curve(
            y_true_bin.ravel(), probabilities.ravel()
        )
        average_precision["micro"] = average_precision_score(
            y_true_bin, probabilities, average="micro"
        )

        plt.figure(figsize=(10, 8))
        colors = cycle(
            ["aqua", "darkorange", "cornflowerblue", "green", "red", "purple"]
        )

        for i, color in zip(range(num_classes), colors):
            plt.plot(
                recall[i],
                precision[i],
                color=color,
                lw=2,
                label=f"PR {self.class_names[i]} (AP = {average_precision[i]:.2f})",
            )

        plt.plot(
            recall["micro"],
            precision["micro"],
            label=f'Micro-average PR (AP = {average_precision["micro"]:.2f})',
            color="deeppink",
            linestyle=":",
            linewidth=4,
        )

        plt.xlabel("召回率")
        plt.ylabel("精确率")
        plt.title("多类别精确率-召回率曲线")
        plt.legend(loc="upper right")

    if save_path:
        plt.savefig(save_path, bbox_inches="tight", dpi=300)
        plt.close()
        return save_path

    plt.show()
    return None
1166
+
1167
def generate_grad_cam(
    self, image_paths, logits, output_dir, layer_names, num_samples=10
):
    """
    Write Grad-CAM overlays for a random subset of images.

    For every sampled image and every requested layer, the heat map from
    ``_compute_grad_cam`` (for the class with the highest logit) is
    colour-mapped, blended with the original image and saved as a JPEG
    under ``<output_dir>/grad_cam``. Images that cannot be read and layers
    for which no heat map is produced are skipped.

    Args:
        image_paths (list): Paths of the images.
        logits (list): Raw model outputs, aligned with ``image_paths``.
        output_dir (str): Directory in which ``grad_cam`` is created.
        layer_names (list): Names of the layers to visualise; nothing is
            done when this is empty.
        num_samples (int): Maximum number of images to visualise.
    """
    if not layer_names:
        return

    # Make sure the model is in evaluation mode.
    self.model.eval()

    cam_dir = os.path.join(output_dir, "grad_cam")
    os.makedirs(cam_dir, exist_ok=True)

    # Sample without replacement, capped at the number of images.
    total = len(image_paths)
    chosen = np.random.choice(total, min(num_samples, total), replace=False)

    for position in chosen:
        path = image_paths[position]
        target = np.argmax(logits[position])

        original = cv2.imread(path)
        if original is None:
            continue

        stem = os.path.splitext(os.path.basename(path))[0]

        for layer_name in layer_names:
            cam = self._compute_grad_cam(path, layer_name, target)
            if cam is None:
                continue

            # Blend the colour-mapped heat map over the original image.
            heatmap = cv2.applyColorMap(cam, cv2.COLORMAP_JET)
            overlay = heatmap * 0.4 + original * 0.6

            target_file = os.path.join(
                cam_dir, f"{stem}_{layer_name}_gradcam.jpg"
            )
            cv2.imwrite(target_file, overlay)

    print(f"Grad-CAM可视化保存至: {cam_dir}")
1225
+
1226
def _compute_grad_cam(self, image_path, layer_name, target_class):
    """
    Compute a Grad-CAM heat map for one image, layer and class.

    The image is read with OpenCV, passed through
    ``self.get_augmentations()`` and the model; the gradient of the
    ``target_class`` output with respect to the layer's output weights the
    layer's activation channels.

    Args:
        image_path (str): Path of the image to explain.
        layer_name (str): Name of the layer as listed by
            ``self.model.named_modules()``.
        target_class (int): Index of the output unit to explain.

    Returns:
        A ``uint8`` array with values in 0-255 and the height/width of the
        source image, or ``None`` when the layer does not exist, the image
        cannot be read, no gradient reached the layer, or the layer's
        output is not a (channels, height, width) feature map.
    """
    layer = dict(self.model.named_modules()).get(layer_name)
    if layer is None:
        print(f"未找到层: {layer_name}")
        return None

    # Read the image before installing any hook so that the early exit
    # leaves the model untouched.
    img = cv2.imread(image_path)
    if img is None:
        return None

    captured = {}

    def store_gradient(grad):
        captured["gradients"] = grad.detach()

    def forward_hook(module, inputs, output):
        # Record only the first call so activation and gradient always
        # belong to the same forward pass of the layer.
        if "activations" in captured:
            return
        captured["activations"] = output.detach()
        # A tensor hook replaces the deprecated Module.register_backward_hook.
        if output.requires_grad:
            output.register_hook(store_gradient)

    handle = layer.register_forward_hook(forward_hook)
    try:
        # Preprocess the image the same way as the rest of the pipeline.
        transform = self.get_augmentations()
        augmented = transform(image=img)
        input_tensor = augmented["image"].unsqueeze(0).to(self.config["device"])
        # Lets gradients reach the layer even when the parameters in
        # front of it are frozen.
        if input_tensor.is_floating_point():
            input_tensor.requires_grad_(True)

        output = self.model(input_tensor)

        # Back-propagate only the score of the target class.
        self.model.zero_grad()
        one_hot = torch.zeros_like(output)
        one_hot[0, target_class] = 1
        output.backward(gradient=one_hot)
        # Do not leave parameter gradients from this pass behind.
        self.model.zero_grad()
    finally:
        handle.remove()

    if "activations" not in captured or "gradients" not in captured:
        return None

    # Drop the batch dimension.
    activations = captured["activations"].cpu().numpy()[0]
    gradients = captured["gradients"].cpu().numpy()[0]

    if activations.ndim != 3:
        print(f"层 {layer_name} 的输出不是特征图,无法计算Grad-CAM")
        return None

    # Channel weight = spatial mean of the gradient; the map is the
    # weighted sum of the activation channels.
    weights = gradients.mean(axis=(1, 2))
    cam = np.tensordot(weights, activations, axes=1).astype(np.float32)

    # Keep positive evidence only, scale to [0, 1], resize to the image.
    cam = np.maximum(cam, 0)
    cam = cam - np.min(cam)
    peak = np.max(cam)
    if peak > 0:
        cam = cam / peak
    cam = cv2.resize(cam, (img.shape[1], img.shape[0]))

    return (cam * 255).astype(np.uint8)
1298
+
1299
def generate_prediction_report(
    self, image_paths, predictions, probabilities, output_dir, num_samples=10
):
    """
    Write an HTML report summarising a batch of predictions.

    The report contains the distribution of predicted classes, the latest
    validation metrics found in ``self.stats`` (only when a validation AUC
    was recorded), a random sample of images with their predictions, and a
    table of all predictions.

    Args:
        image_paths (list): Paths of the images.
        predictions (list): Predicted class names, aligned with
            ``image_paths``; each must be a key of ``self.label_encoder``.
        probabilities (list): Per-class probabilities for each image,
            indexed by the values of ``self.label_encoder``.
        output_dir (str): Directory for the report; created if missing.
        num_samples (int): Maximum number of images in the sample section.

    Returns:
        str: Path of the written ``prediction_report.html``.
    """
    from html import escape

    html_template = """
    <!DOCTYPE html>
    <html>
    <head>
    <title>细胞图像分类报告</title>
    <style>
    body {{ font-family: Arial, sans-serif; margin: 20px; background-color: #f5f5f5; }}
    .container {{ max-width: 1400px; margin: auto; background: white; padding: 20px; box-shadow: 0 0 10px rgba(0,0,0,0.1); border-radius: 8px; }}
    h1, h2 {{ color: #2c3e50; border-bottom: 2px solid #3498db; padding-bottom: 10px; }}
    .grid {{ display: grid; grid-template-columns: repeat(auto-fill, minmax(220px, 1fr)); gap: 20px; margin: 20px 0; }}
    .cell-card {{ border: 1px solid #ddd; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 5px rgba(0,0,0,0.1); transition: transform 0.3s; }}
    .cell-card:hover {{ transform: translateY(-5px); box-shadow: 0 5px 15px rgba(0,0,0,0.1); }}
    .cell-card img {{ width: 100%; height: 180px; object-fit: cover; border-bottom: 1px solid #ddd; }}
    .card-content {{ padding: 15px; }}
    .prediction {{ font-weight: bold; font-size: 16px; margin: 10px 0; }}
    .prob-bar {{ height: 20px; background: #eee; border-radius: 10px; margin: 10px 0; overflow: hidden; }}
    .prob-fill {{ height: 100%; background: #3498db; border-radius: 10px; }}
    .prob-text {{ font-size: 14px; color: #555; }}
    .class-distribution {{ display: flex; flex-wrap: wrap; gap: 15px; margin: 20px 0; }}
    .class-item {{ background: #f8f9fa; padding: 10px 15px; border-radius: 8px; border-left: 4px solid #3498db; }}
    table {{ width: 100%; border-collapse: collapse; margin: 20px 0; }}
    th, td {{ border: 1px solid #ddd; padding: 12px; text-align: left; }}
    th {{ background-color: #3498db; color: white; }}
    tr:nth-child(even) {{ background-color: #f2f2f2; }}
    .metrics {{ display: grid; grid-template-columns: repeat(auto-fill, minmax(300px, 1fr)); gap: 20px; margin: 20px 0; }}
    .metric-card {{ background: #f8f9fa; padding: 20px; border-radius: 8px; text-align: center; }}
    .metric-value {{ font-size: 24px; font-weight: bold; color: #3498db; }}
    .metric-label {{ font-size: 16px; color: #555; }}
    </style>
    </head>
    <body>
    <div class="container">
    <h1>细胞图像分类报告</h1>
    <p><strong>生成时间:</strong> {timestamp}</p>
    <p><strong>模型:</strong> {model_name}</p>
    <p><strong>样本总数:</strong> {total_samples}</p>

    <h2>预测分布</h2>
    <div class="class-distribution">
    {class_distribution}
    </div>

    <h2>模型评估指标</h2>
    <div class="metrics">
    {metrics_cards}
    </div>

    <h2>随机样本预测结果</h2>
    <div class="grid">
    {sample_images}
    </div>

    <h2>所有预测结果</h2>
    {predictions_table}
    </div>
    </body>
    </html>
    """

    # --- Distribution of predicted classes -------------------------------
    class_counts = {cls: 0 for cls in self.class_names}
    for pred in predictions:
        class_counts[pred] += 1

    total_predictions = len(predictions)
    dist_parts = []
    for cls, count in class_counts.items():
        # An empty batch must not divide by zero.
        percentage = count / total_predictions if total_predictions else 0.0
        dist_parts.append(
            f"""
            <div class="class-item">
            <div><strong>{escape(str(cls))}:</strong> {count} ({percentage:.1%})</div>
            <div class="prob-bar">
            <div class="prob-fill" style="width: {percentage*100}%"></div>
            </div>
            </div>
            """
        )
    dist_chart = "".join(dist_parts)

    # --- Latest validation metrics ----------------------------------------
    # Shown only when a validation AUC was recorded; any individual series
    # that is missing or empty is skipped instead of raising.
    stats = getattr(self, "stats", None) or {}
    metric_parts = []
    if stats.get("val_auc"):
        metric_keys = (
            ("准确率", "val_accuracy"),
            ("AUC", "val_auc"),
            ("F1分数", "val_f1"),
            ("损失", "val_loss"),
        )
        for name, key in metric_keys:
            series = stats.get(key)
            if not series:
                continue
            metric_parts.append(
                f"""
                <div class="metric-card">
                <div class="metric-value">{series[-1]:.4f}</div>
                <div class="metric-label">{name}</div>
                </div>
                """
            )
    metrics_cards = "".join(metric_parts)

    # --- Random sample of predictions --------------------------------------
    sample_count = min(num_samples, len(image_paths))
    if sample_count > 0:
        indices = np.random.choice(len(image_paths), sample_count, replace=False)
    else:
        indices = []

    sample_parts = []
    for idx in indices:
        img_path = image_paths[idx]
        pred = predictions[idx]
        prob = probabilities[idx][self.label_encoder[pred]]

        # Paths and labels are escaped so that characters such as & < > "
        # cannot break the markup.
        img_tag = f'<img src="{escape(str(img_path), quote=True)}" alt="Cell Image">'

        prob_bar = f"""
            <div class="prob-bar">
            <div class="prob-fill" style="width: {prob*100:.1f}%"></div>
            </div>
            <div class="prob-text">置信度: {prob:.3f}</div>
            """

        sample_parts.append(
            f"""
            <div class="cell-card">
            {img_tag}
            <div class="card-content">
            <div class="prediction">预测: {escape(str(pred))}</div>
            {prob_bar}
            </div>
            </div>
            """
        )
    sample_html = "".join(sample_parts)

    # --- Table of all predictions -------------------------------------------
    row_parts = []
    for number, (img_path, pred) in enumerate(zip(image_paths, predictions), start=1):
        prob = probabilities[number - 1][self.label_encoder[pred]]
        row_parts.append(
            f"""
            <tr>
            <td>{number}</td>
            <td>{escape(os.path.basename(str(img_path)))}</td>
            <td>{escape(str(pred))}</td>
            <td>{prob:.4f}</td>
            </tr>
            """
        )
    table_rows = "".join(row_parts)

    predictions_table = f"""
        <table>
        <thead>
        <tr>
        <th>序号</th>
        <th>图像</th>
        <th>预测类别</th>
        <th>置信度</th>
        </tr>
        </thead>
        <tbody>
        {table_rows}
        </tbody>
        </table>
        """

    # --- Fill the template and write the file -------------------------------
    html_content = html_template.format(
        timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        model_name=escape(str(self.config["model_name"])),
        total_samples=len(image_paths),
        class_distribution=dist_chart,
        metrics_cards=metrics_cards,
        sample_images=sample_html,
        predictions_table=predictions_table,
    )

    os.makedirs(output_dir, exist_ok=True)
    report_path = os.path.join(output_dir, "prediction_report.html")
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(html_content)

    print(f"预测报告保存至: {report_path}")
    return report_path
1486
+
1487
+
1488
class CellDataset(Dataset):
    """
    Dataset that reads cell images from disk with OpenCV.

    Each item is loaded, adapted to the requested number of channels and
    passed through the optional ``transform``. For training data an extra
    set of blur / distortion / gamma augmentations is applied first.
    Images that cannot be loaded are replaced by an all-zero image so that
    a single bad file does not abort an epoch.

    Args:
        image_paths: Paths of the images.
        labels: Labels aligned with ``image_paths``, or ``None`` for
            unlabeled data (items then carry the label ``-1``).
        transform: Callable invoked as ``transform(image=array)`` that
            returns a mapping with an ``"image"`` entry (albumentations
            style). Optional.
        is_train (bool): Enables the extra cell-specific augmentations.
        input_channels (int): Number of channels the model expects.
        cache_dir: When given, the directory is created. Nothing in this
            class reads from or writes to it.
        fallback_size (tuple): ``(height, width)`` of the blank
            replacement image, used when ``transform`` does not expose
            ``height`` / ``width`` attributes.

    Raises:
        ValueError: If ``labels`` is given and its length differs from
            that of ``image_paths``.
    """

    def __init__(
        self,
        image_paths,
        labels,
        transform=None,
        is_train=True,
        input_channels=3,
        cache_dir=None,
        fallback_size=(224, 224),
    ):
        # Fail fast on mismatched inputs before building any pipeline.
        if labels is not None and len(image_paths) != len(labels):
            raise ValueError("图像路径数量和标签数量不匹配")

        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.is_train = is_train
        self.input_channels = input_channels
        self.cache_dir = cache_dir
        self.fallback_size = tuple(fallback_size)

        # Extra augmentations for training images: one random blur, one
        # random geometric distortion, and a random gamma change.
        self.cell_specific_aug = A.Compose(
            [
                A.OneOf(
                    [
                        A.MotionBlur(blur_limit=3, p=0.3),
                        A.GaussianBlur(blur_limit=3, p=0.3),
                        A.MedianBlur(blur_limit=3, p=0.3),
                    ],
                    p=0.5,
                ),
                A.OneOf(
                    [
                        A.OpticalDistortion(distort_limit=0.5, shift_limit=0.1, p=0.3),
                        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3),
                        A.ElasticTransform(alpha=1, sigma=20, alpha_affine=10, p=0.3),
                    ],
                    p=0.5,
                ),
                A.RandomGamma(gamma_limit=(80, 120), p=0.5),
            ]
        )

        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)

    def __len__(self):
        """Return the number of images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for item ``idx``; the image is a tensor."""
        image = self._load_image(self.image_paths[idx])

        if image.dtype != np.uint8:
            image = image.astype(np.uint8)

        if self.transform:
            image = self._augment(image)

        label = self.labels[idx] if self.labels is not None else -1

        # Without a tensor-producing transform the array is wrapped as-is.
        if not isinstance(image, torch.Tensor):
            image = torch.from_numpy(image)

        return image, label

    def _load_image(self, img_path):
        """Read an image and match its channels; return a blank image on failure."""
        try:
            image = cv2.imread(img_path)
            if image is None:
                raise FileNotFoundError(f"无法加载图像: {img_path}")
            return self._match_channels(np.asarray(image))
        except Exception as e:  # deliberate best-effort: keep the epoch running
            print(f"加载图像{img_path}出错: {e}")
            return self._blank_image()

    def _blank_image(self):
        """Build the all-zero replacement for an unreadable image."""
        # Use the transform's size when it exposes one; otherwise fall
        # back to the configured default instead of raising.
        height = getattr(self.transform, "height", None) or self.fallback_size[0]
        width = getattr(self.transform, "width", None) or self.fallback_size[1]
        return np.zeros((height, width, self.input_channels), dtype=np.uint8)

    def _match_channels(self, image):
        """Convert ``image`` towards ``self.input_channels`` channels (H, W, C)."""
        if image.ndim == 2:
            if self.input_channels == 3:
                return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
            return np.expand_dims(image, axis=-1)

        channels = image.shape[2]
        if channels == 1 and self.input_channels == 3:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if channels == 3 and self.input_channels == 1:
            # NOTE(review): cv2.imread returns BGR, so COLOR_RGB2GRAY swaps
            # the red/blue weights. Left unchanged so existing checkpoints
            # keep seeing the same pixel statistics; confirm before
            # switching to COLOR_BGR2GRAY.
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
            return np.expand_dims(gray, axis=-1)
        if channels == 4 and self.input_channels == 3:
            return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        if channels > self.input_channels:
            # Keep only the leading channels.
            return image[:, :, : self.input_channels]
        return image

    def _augment(self, image):
        """Apply the training-only augmentations, then ``self.transform``."""
        # Single-channel images are passed to the pipelines as 2-D arrays.
        if image.ndim == 3 and image.shape[2] == 1:
            image = np.squeeze(image, axis=2)

        # The cell-specific augmentations run first, while the image is
        # still an (H, W[, C]) uint8 array. Running them after a transform
        # that converts to a tensor would apply them to the wrong layout.
        if self.is_train and self.cell_specific_aug:
            try:
                image = self.cell_specific_aug(image=image)["image"]
            except Exception as e:  # best-effort: continue without them
                print(f"数据增强出错: {e}")

        try:
            image = self.transform(image=image)["image"]
        except Exception as e:  # best-effort: the image is returned untransformed
            print(f"数据增强出错: {e}")

        return image