py2ls 0.1.10.12__py3-none-any.whl → 0.2.7.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of py2ls might be problematic. Click here for more details.
- py2ls/.DS_Store +0 -0
- py2ls/.git/.DS_Store +0 -0
- py2ls/.git/index +0 -0
- py2ls/.git/logs/refs/remotes/origin/HEAD +1 -0
- py2ls/.git/objects/.DS_Store +0 -0
- py2ls/.git/refs/.DS_Store +0 -0
- py2ls/ImageLoader.py +621 -0
- py2ls/__init__.py +7 -5
- py2ls/apptainer2ls.py +3940 -0
- py2ls/batman.py +164 -42
- py2ls/bio.py +2595 -0
- py2ls/cell_image_clf.py +1632 -0
- py2ls/container2ls.py +4635 -0
- py2ls/corr.py +475 -0
- py2ls/data/.DS_Store +0 -0
- py2ls/data/email/email_html_template.html +88 -0
- py2ls/data/hyper_param_autogluon_zeroshot2024.json +2383 -0
- py2ls/data/hyper_param_tabrepo_2024.py +1753 -0
- py2ls/data/mygenes_fields_241022.txt +355 -0
- py2ls/data/re_common_pattern.json +173 -0
- py2ls/data/sns_info.json +74 -0
- py2ls/data/styles/.DS_Store +0 -0
- py2ls/data/styles/example/.DS_Store +0 -0
- py2ls/data/styles/stylelib/.DS_Store +0 -0
- py2ls/data/styles/stylelib/grid.mplstyle +15 -0
- py2ls/data/styles/stylelib/high-contrast.mplstyle +6 -0
- py2ls/data/styles/stylelib/high-vis.mplstyle +4 -0
- py2ls/data/styles/stylelib/ieee.mplstyle +15 -0
- py2ls/data/styles/stylelib/light.mplstyl +6 -0
- py2ls/data/styles/stylelib/muted.mplstyle +6 -0
- py2ls/data/styles/stylelib/nature-reviews-latex.mplstyle +616 -0
- py2ls/data/styles/stylelib/nature-reviews.mplstyle +616 -0
- py2ls/data/styles/stylelib/nature.mplstyle +31 -0
- py2ls/data/styles/stylelib/no-latex.mplstyle +10 -0
- py2ls/data/styles/stylelib/notebook.mplstyle +36 -0
- py2ls/data/styles/stylelib/paper.mplstyle +290 -0
- py2ls/data/styles/stylelib/paper2.mplstyle +305 -0
- py2ls/data/styles/stylelib/retro.mplstyle +4 -0
- py2ls/data/styles/stylelib/sans.mplstyle +10 -0
- py2ls/data/styles/stylelib/scatter.mplstyle +7 -0
- py2ls/data/styles/stylelib/science.mplstyle +48 -0
- py2ls/data/styles/stylelib/std-colors.mplstyle +4 -0
- py2ls/data/styles/stylelib/vibrant.mplstyle +6 -0
- py2ls/data/tiles.csv +146 -0
- py2ls/data/usages_pd.json +1417 -0
- py2ls/data/usages_sns.json +31 -0
- py2ls/docker2ls.py +5446 -0
- py2ls/ec2ls.py +61 -0
- py2ls/fetch_update.py +145 -0
- py2ls/ich2ls.py +1955 -296
- py2ls/im2.py +8242 -0
- py2ls/image_ml2ls.py +2100 -0
- py2ls/ips.py +33909 -3418
- py2ls/ml2ls.py +7700 -0
- py2ls/mol.py +289 -0
- py2ls/mount2ls.py +1307 -0
- py2ls/netfinder.py +873 -351
- py2ls/nl2ls.py +283 -0
- py2ls/ocr.py +1581 -458
- py2ls/plot.py +10394 -314
- py2ls/rna2ls.py +311 -0
- py2ls/ssh2ls.md +456 -0
- py2ls/ssh2ls.py +5933 -0
- py2ls/ssh2ls_v01.py +2204 -0
- py2ls/stats.py +66 -172
- py2ls/temp20251124.py +509 -0
- py2ls/translator.py +2 -0
- py2ls/utils/decorators.py +3564 -0
- py2ls/utils_bio.py +3453 -0
- {py2ls-0.1.10.12.dist-info → py2ls-0.2.7.10.dist-info}/METADATA +113 -224
- {py2ls-0.1.10.12.dist-info → py2ls-0.2.7.10.dist-info}/RECORD +72 -16
- {py2ls-0.1.10.12.dist-info → py2ls-0.2.7.10.dist-info}/WHEEL +0 -0
py2ls/image_ml2ls.py
ADDED
|
@@ -0,0 +1,2100 @@
|
|
|
1
|
+
r"""
|
|
2
|
+
Ultimate Image Machine Learning Master Function (Optimized)
|
|
3
|
+
Author: DeepSeek AI
|
|
4
|
+
Date: 2025-05-30
|
|
5
|
+
Version: 2.0
|
|
6
|
+
|
|
7
|
+
新增功能:
|
|
8
|
+
1. 模型导出与导入 (pickle/joblib)
|
|
9
|
+
2. 完整的预测报告生成 (HTML格式)
|
|
10
|
+
3. 预测结果可视化增强
|
|
11
|
+
4. 支持从保存的模型直接预测
|
|
12
|
+
5. 更完善的配置管理
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# usage: 250601
|
|
16
|
+
from py2ls.ips import listdir
|
|
17
|
+
import re
|
|
18
|
+
|
|
19
|
+
f = listdir("/Users/macjianfeng/Desktop/img_datasets/", "folder", verbose=True)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
train_images = [
|
|
23
|
+
*listdir(f["path"][0], "tif")["path"].tolist(),
|
|
24
|
+
*listdir(f["path"][1], "tif")["path"].tolist(),
|
|
25
|
+
]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
labels = [re.findall(r"\d+x", i)[0] for i in train_images]
|
|
29
|
+
|
|
30
|
+
# 初始化主函数
|
|
31
|
+
mlearner = ImageMLearner()
|
|
32
|
+
# 加载图像
|
|
33
|
+
images, filtered_labels = mlearner.load_images(train_images, labels)
|
|
34
|
+
|
|
35
|
+
# 自动模型比较
|
|
36
|
+
comparison_results = mlearner.auto_compare_models()
|
|
37
|
+
# 保存最佳模型
|
|
38
|
+
mlearner.save_model(mlearner.best_model, "models/best_model.joblib")
|
|
39
|
+
dir_test = "/Users/macjianfeng/Desktop/20250515_CDX2_LiCl_7d/HNT-34/50µM/"
|
|
40
|
+
f = listdir(dir_test, "tif", verbose=True)
|
|
41
|
+
|
|
42
|
+
# 使用最佳模型进行预测
|
|
43
|
+
test_images = f.path.tolist()
|
|
44
|
+
|
|
45
|
+
predictions = mlearner.predict(test_images, output_dir="reports")
|
|
46
|
+
print("Predictions:", predictions)
|
|
47
|
+
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
"""
|
|
51
|
+
-------------------------------------------------------------------------------
|
|
52
|
+
This file provides a complete machine learning pipeline for image analysis with:
|
|
53
|
+
|
|
54
|
+
Core Functionality:
|
|
55
|
+
Image loading and preprocessing
|
|
56
|
+
Feature extraction (HOG, LBP, CNN, etc.)
|
|
57
|
+
Model training and evaluation (many classifiers/regressors)
|
|
58
|
+
Model comparison and selection
|
|
59
|
+
Prediction and reporting
|
|
60
|
+
|
|
61
|
+
Key Components:
|
|
62
|
+
Multiple feature extraction methods
|
|
63
|
+
Wide range of ML models (traditional and deep learning)
|
|
64
|
+
Model evaluation and visualization
|
|
65
|
+
Report generation (HTML/Markdown)
|
|
66
|
+
|
|
67
|
+
Use Cases:
|
|
68
|
+
When you need end-to-end image classification/regression
|
|
69
|
+
When you want to compare multiple models automatically
|
|
70
|
+
When you need feature extraction from images
|
|
71
|
+
When you want detailed reports and visualizations
|
|
72
|
+
When you need GPU acceleration support
|
|
73
|
+
|
|
74
|
+
You need a complete ML pipeline from images to predictions
|
|
75
|
+
You want to compare multiple ML models automatically
|
|
76
|
+
You need different feature extraction methods
|
|
77
|
+
You want detailed evaluation metrics and visualizations
|
|
78
|
+
You need HTML/Markdown reports of your analysis
|
|
79
|
+
You want to leverage both traditional ML and deep learning
|
|
80
|
+
-------------------------------------------------------------------------------
|
|
81
|
+
"""
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# Import-time banner. It reminds users that TensorFlow has to be imported
# (and its GPU visibility configured) before anything else when they want to
# use it together with this module.
usage_str = "\n".join(
    (
        "",
        "-------------如果要使用tensorflow的话, 一定要最先调用--------",
        "import tensorflow as tf",
        'print("Eager execution enabled:", tf.executing_eagerly())',
        'tf.config.set_visible_devices([], "GPU")',
        "-------------如果要使用tensorflow的话, 一定要最先调用--------",
        "",
    )
)
print(usage_str)
|
|
92
|
+
from .ips import has_gpu #set_computing_device
|
|
93
|
+
from .ImageLoader import ImageLoader
|
|
94
|
+
import os
|
|
95
|
+
import json
|
|
96
|
+
import yaml
|
|
97
|
+
import pickle
|
|
98
|
+
import joblib
|
|
99
|
+
import numpy as np
|
|
100
|
+
import pandas as pd
|
|
101
|
+
import matplotlib.pyplot as plt
|
|
102
|
+
import seaborn as sns
|
|
103
|
+
import time
|
|
104
|
+
from functools import partial
|
|
105
|
+
from collections import defaultdict
|
|
106
|
+
from typing import Union, Dict, List, Tuple, Optional, Callable, Any
|
|
107
|
+
from IPython.display import HTML, display
|
|
108
|
+
from tqdm import tqdm
|
|
109
|
+
import gc
|
|
110
|
+
import logging
|
|
111
|
+
|
|
112
|
+
# 图像处理库
|
|
113
|
+
import cv2
|
|
114
|
+
from skimage import io, color, exposure
|
|
115
|
+
from skimage import transform, feature, filters, morphology, segmentation
|
|
116
|
+
from PIL import Image
|
|
117
|
+
# 机器学习库
|
|
118
|
+
from sklearn.base import BaseEstimator
|
|
119
|
+
from sklearn.pipeline import Pipeline
|
|
120
|
+
from sklearn.preprocessing import StandardScaler, LabelEncoder
|
|
121
|
+
from sklearn.decomposition import PCA
|
|
122
|
+
from sklearn.manifold import TSNE
|
|
123
|
+
from sklearn.model_selection import train_test_split
|
|
124
|
+
from sklearn.metrics import (
|
|
125
|
+
accuracy_score,
|
|
126
|
+
f1_score,
|
|
127
|
+
confusion_matrix,
|
|
128
|
+
mean_squared_error,
|
|
129
|
+
r2_score,
|
|
130
|
+
classification_report,
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
# 分类模型
|
|
134
|
+
from sklearn.ensemble import (
|
|
135
|
+
RandomForestClassifier,
|
|
136
|
+
RandomForestRegressor,
|
|
137
|
+
ExtraTreesClassifier,
|
|
138
|
+
ExtraTreesRegressor,
|
|
139
|
+
HistGradientBoostingRegressor,
|
|
140
|
+
BaggingClassifier,
|
|
141
|
+
BaggingRegressor,
|
|
142
|
+
AdaBoostClassifier,
|
|
143
|
+
AdaBoostRegressor,
|
|
144
|
+
)
|
|
145
|
+
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor,StackingClassifier,StackingRegressor
|
|
146
|
+
from sklearn.svm import SVC
|
|
147
|
+
|
|
148
|
+
from sklearn.linear_model import (
|
|
149
|
+
LogisticRegression,ElasticNet,ElasticNetCV,
|
|
150
|
+
LinearRegression,Lasso,RidgeClassifierCV,Perceptron,SGDClassifier,
|
|
151
|
+
RidgeCV,Ridge,TheilSenRegressor,HuberRegressor,PoissonRegressor,Lars, LassoLars, BayesianRidge,
|
|
152
|
+
GammaRegressor, TweedieRegressor, LassoCV, LassoLarsCV, LarsCV,
|
|
153
|
+
OrthogonalMatchingPursuit, OrthogonalMatchingPursuitCV, PassiveAggressiveRegressor
|
|
154
|
+
)
|
|
155
|
+
from sklearn.linear_model import LogisticRegression
|
|
156
|
+
from sklearn.svm import SVC
|
|
157
|
+
from sklearn.neighbors import KNeighborsClassifier
|
|
158
|
+
from sklearn.tree import DecisionTreeClassifier
|
|
159
|
+
import xgboost as xgb
|
|
160
|
+
from sklearn.naive_bayes import GaussianNB,BernoulliNB
|
|
161
|
+
from sklearn.discriminant_analysis import (
|
|
162
|
+
LinearDiscriminantAnalysis,
|
|
163
|
+
QuadraticDiscriminantAnalysis,
|
|
164
|
+
)
|
|
165
|
+
from sklearn.naive_bayes import GaussianNB, BernoulliNB
|
|
166
|
+
import lightgbm as lgb
|
|
167
|
+
import catboost as cb
|
|
168
|
+
from sklearn.neural_network import MLPClassifier, MLPRegressor
|
|
169
|
+
# 回归模型
|
|
170
|
+
from sklearn.svm import SVR
|
|
171
|
+
from sklearn.neighbors import KNeighborsRegressor
|
|
172
|
+
from sklearn.tree import DecisionTreeRegressor
|
|
173
|
+
from sklearn.ensemble import (
|
|
174
|
+
RandomForestRegressor,
|
|
175
|
+
GradientBoostingRegressor,
|
|
176
|
+
AdaBoostRegressor,
|
|
177
|
+
ExtraTreesRegressor,
|
|
178
|
+
)
|
|
179
|
+
|
|
180
|
+
from sklearn.linear_model import (
|
|
181
|
+
LassoCV,
|
|
182
|
+
LogisticRegression,
|
|
183
|
+
LinearRegression,
|
|
184
|
+
Lasso,
|
|
185
|
+
Ridge,
|
|
186
|
+
ElasticNet,
|
|
187
|
+
)
|
|
188
|
+
# 深度学习特征提取
|
|
189
|
+
# try:
|
|
190
|
+
# from tensorflow.keras.utils import to_categorical
|
|
191
|
+
|
|
192
|
+
# except Exception as e:
|
|
193
|
+
# print("Error importing tensorflow.keras.utils or sklearn.utils.class_weight:", e)
|
|
194
|
+
|
|
195
|
+
# try:
|
|
196
|
+
|
|
197
|
+
# from tensorflow.keras.applications import (
|
|
198
|
+
# VGG16,
|
|
199
|
+
# VGG19,
|
|
200
|
+
# ResNet50,
|
|
201
|
+
# InceptionV3,
|
|
202
|
+
# MobileNet,
|
|
203
|
+
# DenseNet121,
|
|
204
|
+
# EfficientNetB0,
|
|
205
|
+
# )
|
|
206
|
+
# from tensorflow.keras.models import Model
|
|
207
|
+
# from tensorflow.keras.preprocessing import image
|
|
208
|
+
# from tensorflow.keras.applications.imagenet_utils import preprocess_input
|
|
209
|
+
|
|
210
|
+
# DL_ENABLED = True
|
|
211
|
+
# except ImportError:
|
|
212
|
+
# DL_ENABLED = False
|
|
213
|
+
# Conditional imports for deep learning
|
|
214
|
+
DL_ENABLED = False
|
|
215
|
+
try:
|
|
216
|
+
import tensorflow as tf
|
|
217
|
+
from sklearn.utils.class_weight import compute_class_weight
|
|
218
|
+
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
|
219
|
+
from tensorflow.keras.utils import to_categorical
|
|
220
|
+
from tensorflow.keras.applications import (
|
|
221
|
+
VGG16, VGG19, ResNet50, InceptionV3,
|
|
222
|
+
MobileNet, DenseNet121, EfficientNetB0
|
|
223
|
+
)
|
|
224
|
+
from tensorflow.keras.models import Model
|
|
225
|
+
from tensorflow.keras.preprocessing import image
|
|
226
|
+
from tensorflow.keras.applications.imagenet_utils import preprocess_input
|
|
227
|
+
from tensorflow.keras.callbacks import (
|
|
228
|
+
EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
|
|
229
|
+
)
|
|
230
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
231
|
+
from tensorflow.keras.mixed_precision import set_global_policy
|
|
232
|
+
from tensorflow.keras.optimizers import Adam
|
|
233
|
+
DL_ENABLED = True
|
|
234
|
+
except ImportError:
|
|
235
|
+
pass
|
|
236
|
+
#====check gpu available====
|
|
237
|
+
# set_computing_device()
|
|
238
|
+
# Global configuration
|
|
239
|
+
|
|
240
|
+
DEFAULT_CONFIG = {
    "preprocessing": {
        # Target size as (height, width); preprocess_image reverses it for cv2.resize.
        "resize": (128, 128),
        "remove_background": True,
        "denoise": "median",
        "normalize": True,
        "hist_equalize": False,
        "chunk_size": 100,
        "augmentation": {
            "rotation_range": 10,
            "width_shift_range": 0.1,
            "height_shift_range": 0.1,
            "shear_range": 0.05,
            "zoom_range": 0.1,
            "horizontal_flip": True,
        },
    },
    "feature_extraction": {
        "method": "hog",
        # CNN backbone name; None when the TensorFlow imports above failed.
        "cnn_model": "resnet50" if DL_ENABLED else None,
        "hog_params": {"orientations": 8, "pixels_per_cell": (16, 16), "cells_per_block": (1, 1)},
        "lbp_params": {"P": 8, "R": 1, "method": "uniform"},
    },
    # Option names mirror the keyword arguments of build_cnn_model.
    "cnn": {
        "conv_layers": [(16, 3, 1), (32, 3, 1)],
        "dense_layers": [32],
        "use_batchnorm": True,
        "use_dropout": False,
        "dropout_rate": 0.2,
        "pooling": "max",
        "learning_rate": 0.001,
        "optimizer": "adam",
    },
    "model": {"name": "cnn", "params": {}},
    "auto_comparison": True,
    "test_size": 0.2,
    "random_state": 1,
    # All cores but two, never fewer than 4. os.cpu_count() may return None
    # (undeterminable core count), which would make the subtraction raise
    # TypeError at import time, so treat "unknown" as a single core.
    "n_cpu": max((os.cpu_count() or 1) - 2, 4),
    "visualization": {
        "enable": True,
        "confusion_matrix": True,
        "feature_space": True,
        "sample_images": 2,
        "cmap": "viridis",
    },
    "report": {"format": "html", "include_visuals": True},
}
|
|
287
|
+
def build_cnn_model(
    input_shape: tuple,
    num_classes: int,
    *,
    # Architecture configuration
    conv_layers: list = None,
    dense_layers: list = None,
    use_batchnorm: bool = True,
    use_dropout: bool = False,
    dropout_rate: float = 0.2,
    pooling: str = "max",
    # Training configuration
    learning_rate: float = 0.001,
    optimizer: str = "adam",
    # Logging/utility
    verbose: bool = True,
    logger: logging.Logger = None,
) -> "tf.keras.Model":
    """Build and compile a configurable Sequential CNN.

    The return annotation is a string on purpose. TensorFlow is imported
    inside a try/except in this module, and a bare ``tf.keras.Model``
    annotation would be evaluated when the function is defined. That would
    raise NameError whenever TensorFlow is unavailable and make the whole
    module unimportable.

    Args:
        input_shape: ``(height, width, channels)`` or ``(height, width)``; a
            2-tuple gets a trailing channel dimension of 1.
        num_classes: Selects the output head.
            1 -> regression (one linear unit, MSE loss, MAE metric).
            2 -> binary classification (one sigmoid unit, binary cross-entropy).
            >2 -> multi-class (softmax, sparse categorical cross-entropy, so
            labels must be integer-encoded).
        conv_layers: ``(filters, kernel_size, stride)`` tuple per conv block.
            None or empty -> ``[(16, 3, 1), (32, 3, 1)]``.
        dense_layers: Units of each hidden dense layer. None or empty -> ``[32]``.
        use_batchnorm: Add BatchNormalization after every conv and dense layer.
        use_dropout: Add Dropout after every conv block and dense layer.
        dropout_rate: Dropout rate used when ``use_dropout`` is True.
        pooling: ``"max"`` selects max pooling; any other value selects average
            pooling. Pooling (2x2) follows every second conv layer.
        learning_rate: Optimizer learning rate.
        optimizer: ``"adam"``, ``"sgd"`` (momentum 0.9) or ``"rmsprop"``,
            case-insensitive. Unknown names fall back to Adam with a warning.
        verbose: Log a build summary and the Keras model summary.
        logger: Logger to use; defaults to the ``"cnn_builder"`` logger.

    Returns:
        A compiled ``tf.keras.Sequential`` model.

    Raises:
        ValueError: If ``input_shape`` has neither 2 nor 3 entries.
    """
    logger = logger or logging.getLogger("cnn_builder")
    if verbose and not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        )
        logger.addHandler(handler)
        # An unconfigured logger inherits WARNING from the root logger, which
        # would silently drop every logger.info() call below.
        if logger.level == logging.NOTSET:
            logger.setLevel(logging.INFO)

    start_time = time.perf_counter()

    # Validate and normalize the input shape: (h, w) -> (h, w, 1)
    input_shape = tuple(input_shape)
    if len(input_shape) == 2:
        input_shape = (*input_shape, 1)
    elif len(input_shape) != 3:
        raise ValueError("input_shape must be (h,w) or (h,w,c)")

    conv_layers = conv_layers or [(16, 3, 1), (32, 3, 1)]
    dense_layers = dense_layers or [32]

    # Output head, loss and metrics depend on the task.
    if num_classes == 1:  # Regression
        final_config = {"units": 1, "activation": "linear", "loss": "mse", "metrics": ["mae"]}
    elif num_classes == 2:  # Binary classification
        final_config = {"units": 1, "activation": "sigmoid", "loss": "binary_crossentropy", "metrics": ["accuracy"]}
    else:  # Multi-class
        final_config = {"units": num_classes, "activation": "softmax", "loss": "sparse_categorical_crossentropy", "metrics": ["accuracy"]}

    model = tf.keras.Sequential()
    # keras.Input is the documented way to declare the input of a Sequential
    # model; InputLayer(input_shape=...) is deprecated in Keras 3.
    model.add(tf.keras.Input(shape=input_shape))

    # Convolutional blocks: Conv -> [BN] -> ReLU -> [pool every 2nd] -> [dropout]
    pool_layer_cls = (
        tf.keras.layers.MaxPooling2D if pooling == "max" else tf.keras.layers.AveragePooling2D
    )
    for i, (filters, kernel_size, stride) in enumerate(conv_layers):
        model.add(tf.keras.layers.Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            strides=stride,
            padding="same",
            name=f"conv_{i}",
        ))
        if use_batchnorm:
            model.add(tf.keras.layers.BatchNormalization(name=f"bn_{i}"))
        model.add(tf.keras.layers.ReLU(name=f"relu_{i}"))

        # Downsample after every second conv layer (i = 1, 3, 5, ...).
        if i % 2 == 1:
            model.add(pool_layer_cls(pool_size=(2, 2), name=f"pool_{i}"))

        if use_dropout:
            model.add(tf.keras.layers.Dropout(dropout_rate, name=f"dropout_{i}"))

    model.add(tf.keras.layers.Flatten())

    for i, units in enumerate(dense_layers):
        model.add(tf.keras.layers.Dense(units, activation="relu", name=f"dense_{i}"))
        if use_batchnorm:
            model.add(tf.keras.layers.BatchNormalization(name=f"bn_dense_{i}"))
        if use_dropout:
            model.add(tf.keras.layers.Dropout(dropout_rate, name=f"dropout_dense_{i}"))

    # The output layer is pinned to float32. Under the global "mixed_float16"
    # policy that ImageMLearner._configure_gpu may install, a float16
    # softmax/sigmoid/linear head is numerically unstable. Under the default
    # float32 policy this is a no-op.
    model.add(tf.keras.layers.Dense(
        units=final_config["units"],
        activation=final_config["activation"],
        dtype="float32",
        name="output",
    ))

    # Build only the requested optimizer (factories avoid creating all three).
    optimizer_factories = {
        "adam": lambda: tf.keras.optimizers.Adam(learning_rate=learning_rate),
        "sgd": lambda: tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9),
        "rmsprop": lambda: tf.keras.optimizers.RMSprop(learning_rate=learning_rate),
    }
    opt_key = optimizer.lower() if isinstance(optimizer, str) else optimizer
    if opt_key not in optimizer_factories:
        logger.warning(f"Unknown optimizer {optimizer!r}; falling back to 'adam'")
        opt_key = "adam"

    model.compile(
        optimizer=optimizer_factories[opt_key](),
        loss=final_config["loss"],
        metrics=final_config["metrics"],
    )

    if verbose:
        logger.info("Built CNN with:")
        logger.info(f"- {len(conv_layers)} conv layers")
        logger.info(f"- {len(dense_layers)} dense layers")
        logger.info(f"- BatchNorm: {use_batchnorm}")
        logger.info(f"- Dropout: {use_dropout} ({dropout_rate if use_dropout else 'N/A'})")
        logger.info(f"- Optimizer: {opt_key} (lr={learning_rate})")
        logger.info(f"Model built in {time.perf_counter() - start_time:.2f}s\n")
        model.summary(print_fn=logger.info)

    return model
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
class ImageMLearner:
|
|
431
|
+
"""
|
|
432
|
+
通用图像机器学习主函数 (优化版)
|
|
433
|
+
|
|
434
|
+
参数:
|
|
435
|
+
cfg (Union[dict, str]): 配置字典或配置文件路径 (JSON/YAML)
|
|
436
|
+
verbose (bool): 是否显示详细输出
|
|
437
|
+
"""
|
|
438
|
+
|
|
439
|
+
def __init__(self, cfg: Union[dict, str] = None, verbose: bool = True):
    """Create a learner: resolve the config, register components, set up I/O.

    Args:
        cfg: Configuration dict, or a path to a JSON/YAML config file. ``None``
            uses the built-in defaults from ``_load_cfg``.
        verbose: Print progress and diagnostic messages.
    """
    self.verbose = verbose
    # Must be set before _load_cfg(), which reads it to build its defaults.
    self.dl_enabled = DL_ENABLED
    self.cfg = self._load_cfg(cfg)
    prep_cfg = self.cfg["preprocessing"]

    # Component registries, filled by _register_default_components().
    self.models = {}
    self.feature_extractors = {}
    self.preprocessors = {}
    self.chunk_size = prep_cfg.get("chunk_size", 128)
    # os.cpu_count() may return None; max(None, 8) would raise TypeError, and
    # the fallback is evaluated eagerly even when cfg provides "n_cpu".
    self.n_cpu = self.cfg.get("n_cpu", max(os.cpu_count() or 1, 8))
    self._register_default_components()

    # State tracking
    self.images, self.image_input, self.labels = [], [], []
    self.X_train, self.y_train = None, None
    self.X_test, self.y_test = None, None
    self.features = None
    self.model_performance = {}
    self.best_model = None
    self.label_encoder = None
    self.task_type = None
    self.class_names = None

    # Hardware optimization
    self.has_gpu = has_gpu(verbose=self.verbose)
    # NOTE(review): built-in hash() of a str is salted per process
    # (PYTHONHASHSEED), so this value is only stable within a single run.
    self.cfg_hash = hash(json.dumps(prep_cfg, sort_keys=True))

    if self.has_gpu and DL_ENABLED:
        self._configure_gpu()

    # The loader reuses the worker/chunk settings resolved above, so the
    # learner and its loader cannot disagree on them.
    self.image_loader = ImageLoader(
        target_size=prep_cfg.get("resize", (128, 128)),
        grayscale=prep_cfg.get("to_grayscale", True),
        n_jobs=self.n_cpu,
        verbose=verbose,
        scaler="normalize" if prep_cfg.get("normalize", True) else "raw",
        chunk_size=self.chunk_size,
    )
    # Cache of built CNN models
    self._cnn_models = {}
    if cfg is not None:
        self.clear_cache()
    # If deep learning is unavailable but the config requests CNN features,
    # fall back to HOG.
    if not self.dl_enabled and self.cfg["feature_extraction"].get("method") == "cnn":
        self.cfg["feature_extraction"]["method"] = "hog"
        if self.verbose:
            print("CNN requested but DL not available. Falling back to HOG.")
    if self.verbose:
        print(f"init: ImageMLearner initialized with cfg: {self.cfg}")
|
|
488
|
+
def _configure_gpu(self):
    """Turn on per-GPU memory growth and the Keras ``mixed_float16`` policy.

    Does nothing when TensorFlow lists no physical GPU. Any exception is
    caught and reported with ``print`` rather than raised. ``set_global_policy``
    changes the Keras dtype policy for the whole process, not just this learner.
    """
    try:
        physical_gpus = tf.config.list_physical_devices('GPU')
        if not physical_gpus:
            return
        for device in physical_gpus:
            tf.config.experimental.set_memory_growth(device, True)
        set_global_policy('mixed_float16')
        if self.verbose:
            print("GPU configured with mixed precision")
    except Exception as exc:
        print(f"Error configuring GPU: {exc}")
|
|
500
|
+
def _load_cfg(self, cfg: Union[dict, str]) -> dict:
    """Resolve the learner configuration.

    Args:
        cfg: ``None`` for the built-in defaults defined below; a dict of
            overrides; or a path (``str`` or ``os.PathLike``) to a
            ``.json``/``.yaml``/``.yml`` file holding the overrides. The
            extension is matched case-insensitively.

    Returns:
        A new configuration dict. Overrides are deep-merged into a deep copy of
        the module-level ``DEFAULT_CONFIG``. Nested dicts are merged key by key;
        any other value replaces the default.

    Raises:
        ValueError: If a config path has an unsupported extension.
        TypeError: If the overrides are not a mapping.
    """
    # Local imports: only this method needs them.
    import copy
    from collections.abc import Mapping

    # NOTE(review): these defaults are used only when cfg is None. A non-None
    # cfg is merged onto DEFAULT_CONFIG instead, whose defaults differ (e.g.
    # model "cnn" vs "random_forest", extra "cnn"/"augmentation" sections).
    default_cfg = {
        "preprocessing": {
            "resize": (128, 128),
            "remove_background": True,
            "denoise": "median",
            "normalize": True,
            "hist_equalize": False,
            "chunk_size": 100,
        },
        "feature_extraction": {
            "method": "hog",
            "cnn_model": "resnet50" if self.dl_enabled else None,
            "hog_params": {
                "orientations": 8,
                "pixels_per_cell": (16, 16),
                "cells_per_block": (1, 1),
            },
            "lbp_params": {"P": 8, "R": 1, "method": "uniform"},
        },
        "model": {"name": "random_forest", "params": {}},
        "auto_comparison": True,
        "test_size": 0.2,
        "random_state": 1,
        # os.cpu_count() may return None, which max() cannot compare with 8.
        "n_cpu": max(os.cpu_count() or 1, 8),
        "visualization": {
            "enable": True,
            "confusion_matrix": True,
            "feature_space": True,
            "sample_images": 2,
            "cmap": "viridis",
        },
        "report": {"format": "html", "include_visuals": True},  # html/md
    }

    if cfg is None:
        return default_cfg

    if isinstance(cfg, (str, os.PathLike)):
        cfg_path = os.fspath(cfg)
        lowered = cfg_path.lower()
        if lowered.endswith(".json"):
            with open(cfg_path, "r", encoding="utf-8") as f:
                cfg = json.load(f)
        elif lowered.endswith((".yaml", ".yml")):
            with open(cfg_path, "r", encoding="utf-8") as f:
                cfg = yaml.safe_load(f)
        else:
            raise ValueError("Unsupported cfg file format. Use JSON or YAML.")
        # An empty YAML file loads as None: treat it as "no overrides".
        if cfg is None:
            cfg = {}

    if not isinstance(cfg, Mapping):
        raise TypeError(
            f"cfg must be a dict or a path to a JSON/YAML file, got {type(cfg).__name__}"
        )

    def deep_merge(target, source):
        """Recursively merge ``source`` into ``target`` in place; return ``target``."""
        for key, value in source.items():
            if (key in target) and isinstance(target[key], dict) and isinstance(value, dict):
                deep_merge(target[key], value)
            else:
                # Copy so that later edits to self.cfg never reach the caller's objects.
                target[key] = copy.deepcopy(value)
        return target

    # Deep copy: a shallow DEFAULT_CONFIG.copy() shares its nested dicts, so the
    # merge would mutate the module-level defaults and leak these overrides
    # into every learner created afterwards.
    return deep_merge(copy.deepcopy(DEFAULT_CONFIG), cfg)
|
|
560
|
+
def _register_default_components(self):
    """Register the built-in preprocessors, feature extractors and models.

    Model entries mix estimator classes (e.g. ``LogisticRegression``) with
    ready-made estimator instances (e.g. ``RidgeClassifierCV()``). Code that
    consumes ``self.models`` therefore has to accept both forms.
    """
    # Background-removal preprocessors
    self.register_preprocessor("grabcut", self._grabcut_background_removal)
    self.register_preprocessor("kmeans", self._kmeans_background_removal)
    self.register_preprocessor("morphology", self._morphology_background_removal)

    # Feature extractors
    self.register_feature_extractor("pixels", self._extract_raw_pixels)
    self.register_feature_extractor("hog", self._extract_hog)
    self.register_feature_extractor("lbp", self._extract_lbp)
    self.register_feature_extractor("color_hist", self._extract_color_histogram)
    # NOTE(review): no "cnn" feature extractor is registered here, even when
    # deep learning is enabled; confirm "cnn" features are handled elsewhere.

    models = {
        # Classification models
        "logistic_regression": LogisticRegression,
        "svm": SVC,
        "knn": KNeighborsClassifier,
        "decision_tree": DecisionTreeClassifier,
        "random_forest": RandomForestClassifier,
        "gradient_boosting": GradientBoostingClassifier,
        "ada_boost": AdaBoostClassifier,
        "naive_bayes": GaussianNB,
        "lda": LinearDiscriminantAnalysis,
        "qda": QuadraticDiscriminantAnalysis,
        "extra_trees": ExtraTreesClassifier,
        "lasso_logistic": LogisticRegression(penalty='l1', solver='saga'),
        "ridge_classifier": RidgeClassifierCV(),
        # ElasticNet is a regressor with continuous predictions, so it cannot
        # serve as a classifier. The classification counterpart is logistic
        # regression with an elastic-net penalty ("saga" is the solver that
        # supports it, and l1_ratio is required).
        "elastic_net_cls": LogisticRegression(penalty="elasticnet", solver="saga", l1_ratio=0.5),
        "xgb": xgb.XGBClassifier(),
        "lightgbm": lgb.LGBMClassifier(),
        "catboost": cb.CatBoostClassifier(verbose=0),
        "bagging": BaggingClassifier(),
        "mlp": MLPClassifier(max_iter=500),
        "quadratic_discriminant": QuadraticDiscriminantAnalysis(),
        "bernoulli_nb": BernoulliNB(),
        "sgd": SGDClassifier(),
        # Regression models
        "linear_regression": LinearRegression,
        "ridge": Ridge,
        "lasso": Lasso,
        "elastic_net": ElasticNet,
        "svr": SVR,
        "knn_regressor": KNeighborsRegressor,
        "decision_tree_reg": DecisionTreeRegressor,
        "random_forest_reg": RandomForestRegressor,
        "gradient_boosting_reg": GradientBoostingRegressor,
        "ada_boost_reg": AdaBoostRegressor,
        "extra_trees_reg": ExtraTreesRegressor,
        "cnn": self._build_cnn_model,
        "lasso_cv": LassoCV(),
        "elastic_net_cv": ElasticNetCV(),
        "xgb_reg": xgb.XGBRegressor(),
        "lightgbm_reg": lgb.LGBMRegressor(),
        "catboost_reg": cb.CatBoostRegressor(verbose=0),
        "bagging_reg": BaggingRegressor(),
        "mlp_reg": MLPRegressor(max_iter=500),
        "theil_sen": TheilSenRegressor(),
        "huber": HuberRegressor(),
        "poisson": PoissonRegressor(),
    }

    for name, model in models.items():
        self.register_model(name, model)
|
|
630
|
+
|
|
631
|
+
def register_preprocessor(self, name: str, func: Callable):
    """Add (or overwrite) the preprocessing callable stored under ``name``."""
    self.preprocessors.update({name: func})
|
|
636
|
+
|
|
637
|
+
def register_feature_extractor(self, name: str, func: Callable):
    """Add (or overwrite) the feature-extraction callable stored under ``name``."""
    self.feature_extractors.update({name: func})
|
|
642
|
+
|
|
643
|
+
def register_model(self, name: str, model_class: BaseEstimator):
    """Add (or overwrite) the model stored under ``name``.

    ``model_class`` may be an estimator class or an estimator instance; it is
    stored unchanged.
    """
    self.models.update({name: model_class})
|
|
648
|
+
|
|
649
|
+
def load_images(self, image_paths: List[str], labels: List = None):
    """Load an image dataset and, optionally, its labels.

    Args:
        image_paths: One of the following.
            - A list of file paths (e.g. ``["img1.jpg", "img2.png"]``), read
              through ``self.image_loader``.
            - A NumPy array of shape ``(n_samples, height, width[, channels])``.
            - A list of NumPy arrays, each ``(height, width[, channels])``.
            In-memory arrays are stored as given, without going through the loader.
        labels: Optional labels, one per image. They are passed to
            ``_process_labels``, which infers the task type and encodes
            classification labels.

    Returns:
        ``(images, labels)``. For file-path input without labels the second
        element is ``None``; for array input it is ``self.labels``.

    Raises:
        ValueError: If ``image_paths`` is empty.
    """
    self._reset_state()
    # Fail clearly instead of with an IndexError on image_paths[0] below.
    if len(image_paths) == 0:
        raise ValueError("load_images: image_paths is empty")
    if self.verbose:
        print(f"Loading {len(image_paths)} images...")

    def _check_alignment():
        # Misaligned images/labels only fail much later (e.g. in the train/test
        # split). The loader may skip unreadable files -- TODO confirm in
        # ImageLoader -- and the raw labels passed here would not be filtered.
        n_labels = 0 if self.labels is None else len(self.labels)
        if n_labels != len(self.images):
            logging.getLogger(__name__).warning(
                "load_images: got %d images but %d labels; they may be misaligned.",
                len(self.images),
                n_labels,
            )

    # Case 1: input is already image arrays.
    if not isinstance(image_paths[0], str):
        self.images = list(image_paths)
        self.image_input = None
        if labels is not None:
            self._process_labels(labels)
            _check_alignment()
        return self.images, self.labels

    # Case 2: input is file paths -- read them with ImageLoader.
    df = pd.DataFrame({"path": image_paths})
    if labels is not None:
        df["label"] = labels

    result = self.image_loader.process(
        data=df,
        x_col="path",
        y_col="label" if labels is not None else None,
        encoder="label" if labels is not None else None,
        output="array",
    )

    if labels is not None:
        images, processed_labels = result
        self.labels = processed_labels
    else:
        images = result

    self.images = list(images)
    self.image_input = image_paths

    # Detect the task type and encode labels.
    if labels is not None:
        self._process_labels(labels)
        _check_alignment()

    if self.verbose:
        print(f"Loaded {len(self.images)} images, {len(self.labels) if labels is not None else 0} labels")

    return self.images, self.labels if labels is not None else None
|
|
703
|
+
|
|
704
|
+
|
|
705
|
+
def _process_labels(self, labels: List):
|
|
706
|
+
"""Internal method to process labels and detect task type"""
|
|
707
|
+
labels = np.array(labels).flatten()
|
|
708
|
+
|
|
709
|
+
# Detect task type
|
|
710
|
+
unique_labels = set(labels)
|
|
711
|
+
if all(isinstance(label, (int, float)) for label in labels):
|
|
712
|
+
self.task_type = "regression"
|
|
713
|
+
if self.verbose:
|
|
714
|
+
print("Detected regression task")
|
|
715
|
+
else:
|
|
716
|
+
self.task_type = "classification"
|
|
717
|
+
self.class_names = sorted(unique_labels)
|
|
718
|
+
self.label_encoder = LabelEncoder()
|
|
719
|
+
self.labels = self.label_encoder.fit_transform(labels)
|
|
720
|
+
if self.verbose:
|
|
721
|
+
print(f"Detected classification task with {len(unique_labels)} classes")
|
|
722
|
+
def preprocess_image(self, img: np.ndarray) -> np.ndarray:
    """Apply the configured preprocessing pipeline to a single image.

    Steps, each controlled by ``self.cfg["preprocessing"]``:

    1. Convert to ``uint8`` in [0, 255]. Float images with max <= 1 are
       scaled by 255, and values are clipped before the cast.
    2. Convert to 3-channel RGB: grayscale is replicated, RGBA drops alpha.
    3. ``resize``: resize to ``(height, width)`` using ``INTER_AREA``.
    4. ``remove_background``: apply ``self.preprocessors[bg_method]``
       (default ``"grabcut"``) if that method is registered.
    5. ``to_grayscale`` (default True): RGB -> single channel.
    6. ``denoise``: apply a ``"median"``, ``"gaussian"`` or ``"bilateral"``
       filter.
    7. ``hist_equalize`` (default False): equalize luminance (color) or
       intensity (grayscale). The result stays in the 0-255 range.
    8. ``normalize`` (default True): cast to float32 and divide by 255.

    Args:
        img: Image array, either 2-D grayscale or 3-D with 3 (RGB) or
            4 (RGBA) channels.

    Returns:
        The preprocessed image.
    """
    cfg = self.cfg["preprocessing"]

    # Convert to uint8 *before* any OpenCV call: cv2.cvtColor rejects float64
    # and wide integer inputs.
    if img.dtype != np.uint8:
        if img.size and img.max() <= 1.0:
            img = img * 255
        img = np.clip(img, 0, 255).astype(np.uint8)

    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[2] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)

    resize = cfg.get("resize")
    # Compare as tuples; a list never equals img.shape[:2], which forced a
    # resize on every call.
    if resize and img.shape[:2] != tuple(resize):
        # cfg["resize"] is (height, width); cv2.resize expects (width, height).
        img = cv2.resize(img, tuple(resize[::-1]), interpolation=cv2.INTER_AREA)

    if cfg.get("remove_background"):
        method = cfg.get("bg_method", "grabcut")
        if method in self.preprocessors:
            img = self.preprocessors[method](img)

    if cfg.get("to_grayscale", True) and img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    denoise_method = cfg.get("denoise")
    if denoise_method == "median":
        img = cv2.medianBlur(img, 3)
    elif denoise_method == "gaussian":
        img = cv2.GaussianBlur(img, (5, 5), 0)
    elif denoise_method == "bilateral":
        img = cv2.bilateralFilter(img, 9, 75, 75)

    if cfg.get("hist_equalize", False):
        if img.ndim == 3:
            img_yuv = cv2.cvtColor(img, cv2.COLOR_RGB2YUV)
            img_yuv[:, :, 0] = cv2.equalizeHist(img_yuv[:, :, 0])
            img = cv2.cvtColor(img_yuv, cv2.COLOR_YUV2RGB)
        else:
            # exposure.equalize_hist returns floats in [0, 1]. Rescale to
            # 0-255 so the normalization below does not divide by 255 twice.
            img = (exposure.equalize_hist(img) * 255).astype(np.uint8)

    if cfg.get("normalize", True):
        img = img.astype(np.float32) / 255.0

    return img
|
|
779
|
+
|
|
780
|
+
def _grabcut_background_removal(self, img: np.ndarray, margin: int = 10) -> np.ndarray:
    """Remove the background with OpenCV GrabCut.

    GrabCut starts from a rectangle inset by ``margin`` pixels on every side;
    pixels outside that rectangle are treated as definite background.
    Foreground regions of the resulting 2-D mask that touch the image border
    are then discarded, and background pixels are set to 0.

    Args:
        img: 8-bit, 3-channel image, the format ``cv2.grabCut`` requires.
        margin: Inset of the initial rectangle in pixels. It is clamped so
            the rectangle always lies inside the image.

    Returns:
        ``img`` with background pixels zeroed (same shape and dtype).
    """
    h, w = img.shape[:2]
    # Clamp so the rectangle stays inside small images; grabCut rejects an
    # initial rectangle that extends past the image.
    m = max(0, min(margin, (w - 1) // 2, (h - 1) // 2))
    rect = (m, m, max(1, w - 2 * m), max(1, h - 2 * m))

    mask = np.zeros((h, w), np.uint8)
    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)
    cv2.grabCut(img, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)

    # Mask values 0 (definite bg) and 2 (probable bg) are background.
    fg_mask = np.where((mask == 2) | (mask == 0), 0, 1).astype("uint8")

    # Clear border-touching regions on the 2-D mask. Running clear_border on
    # the (H, W, 3) image treats the channel axis as a spatial axis, so every
    # region in channel 0 or 2 counts as "on the border" and the foreground
    # is wiped out.
    fg_mask = segmentation.clear_border(fg_mask)

    return img * fg_mask[:, :, np.newaxis]
|
|
805
|
+
|
|
806
|
+
def _kmeans_background_removal(self, img: np.ndarray) -> np.ndarray:
    """Remove the background with 2-cluster K-means on pixel values.

    Pixels are split into two clusters, and the larger cluster is assumed to
    be background. The background mask is cleaned with an erosion followed
    by a dilation (disk of radius 5). Background pixels are then set to 0.

    Args:
        img: 2-D grayscale or 3-D ``(H, W, C)`` image.

    Returns:
        A copy of ``img`` (same shape and dtype) with background pixels zeroed.
    """
    n_channels = img.shape[2] if img.ndim == 3 else 1
    samples = np.float32(img.reshape((-1, n_channels)))

    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    _, cluster_ids, _ = cv2.kmeans(
        samples, 2, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS
    )

    # Assume the background is the most populous cluster.
    cluster_ids = cluster_ids.flatten()
    bg_cluster = np.argmax(np.bincount(cluster_ids))
    bg_mask = (cluster_ids == bg_cluster).reshape(img.shape[:2])

    bg_mask = morphology.binary_erosion(bg_mask, morphology.disk(5))
    bg_mask = morphology.binary_dilation(bg_mask, morphology.disk(5))

    # Boolean indexing on a copy keeps the input dtype. The old grayscale
    # path, `img * (1 - mask)`, upcast to int64, which later OpenCV calls
    # reject.
    result = img.copy()
    result[bg_mask] = 0
    return result
|
|
842
|
+
|
|
843
|
+
def _morphology_background_removal(
    self, img: np.ndarray, threshold: float = 0.5, radius: int = 5
) -> np.ndarray:
    """Remove the background by thresholding and morphological cleanup.

    The image is converted to an intensity in [0, 1], and pixels brighter
    than ``threshold`` are kept as foreground. The mask is cleaned with an
    erosion followed by a dilation (disk of ``radius``), and pixels outside
    it are set to 0.

    Args:
        img: 2-D grayscale or 3-D RGB image.
        threshold: Intensity threshold on the [0, 1] scale.
        radius: Radius of the disk structuring element.

    Returns:
        ``img`` with background pixels zeroed (same shape).
    """
    if img.ndim == 3:
        # rgb2gray already returns floats in [0, 1].
        gray = color.rgb2gray(img)
    elif np.issubdtype(img.dtype, np.integer):
        # Scale integer grayscale to [0, 1] so the threshold means the same
        # as in the RGB branch. Before, any pixel >= 1 passed a 0.5 threshold.
        gray = img / np.iinfo(img.dtype).max
    else:
        gray = img

    foreground = gray > threshold
    footprint = morphology.disk(radius)
    foreground = morphology.binary_erosion(foreground, footprint)
    foreground = morphology.binary_dilation(foreground, footprint)

    if img.ndim == 3:
        result = img.copy()
        result[~foreground] = 0
    else:
        result = img * foreground
    return result
|
|
865
|
+
|
|
866
|
+
def extract_features(self, images: List[np.ndarray] = None) -> np.ndarray:
    """Extract a feature matrix from a list of images.

    Each image is preprocessed with ``preprocess_image`` and passed to the
    extractor registered in ``self.feature_extractors`` under
    ``self.cfg["feature_extraction"]["method"]``. Both stages run in a thread
    pool of ``self.n_cpu`` workers.

    The on-disk caches (``features_cache_<cfg_hash>.npy`` and the
    preprocessed-image cache) are only used for ``self.images``. Their key
    does not identify the images, so caching results for an explicit
    ``images`` argument (e.g. new images at prediction time) would return
    another dataset's features. A cached matrix with the wrong number of rows
    is ignored.

    Args:
        images: Images to process. Defaults to ``self.images``.

    Returns:
        Feature matrix with one row per image, also stored in ``self.features``.

    Raises:
        ValueError: If there are no images or the method is unsupported.
    """
    use_cache = images is None or images is self.images
    if images is None:
        images = self.images

    # len() instead of truthiness: `not array` raises for NumPy arrays.
    if images is None or len(images) == 0:
        raise ValueError("No images available for feature extraction")

    method = self.cfg["feature_extraction"]["method"]
    extractor = self.feature_extractors.get(method)
    if not extractor:
        raise ValueError(f"Unsupported feature extraction method: {method}")

    if self.verbose:
        print(f"Extracting features using {method} method...")

    cache_file = f"features_cache_{self.cfg_hash}.npy"
    if use_cache and os.path.exists(cache_file):
        cached = np.load(cache_file, allow_pickle=True)
        if len(cached) == len(images):
            if self.verbose:
                print("Loading features from cache")
            self.features = cached
            return self.features

    start_time = time.time()
    if use_cache:
        preprocessed_images = self._cache_preprocessed(images)
    else:
        # Skip the preprocessed-image cache as well; it is keyed the same way.
        with ThreadPoolExecutor(max_workers=self.n_cpu) as executor:
            preprocessed_images = list(executor.map(self.preprocess_image, images))

    with ThreadPoolExecutor(max_workers=self.n_cpu) as executor:
        features = list(tqdm(
            executor.map(extractor, preprocessed_images),
            total=len(images),
            desc="Extracting features",
            disable=not self.verbose,
        ))
    self.features = np.array(features)
    if use_cache:
        np.save(cache_file, self.features)

    if self.verbose:
        print(f"Features extracted in {time.time()-start_time:.2f}s")
        print(f"Feature matrix shape: {self.features.shape}")

    return self.features
|
|
911
|
+
|
|
912
|
+
def _cache_preprocessed(self, images):
    """Preprocess images, caching the result on disk for later runs.

    The cache file is ``preprocessed_cache_<cfg_hash>.npy`` in the current
    working directory. The key does not identify the images, so a cached
    result is reused only if it holds exactly ``len(images)`` entries.
    Otherwise the images are preprocessed again and the cache is rewritten.

    Args:
        images: Sequence of images accepted by ``preprocess_image``.

    Returns:
        The preprocessed images, one per input in the same order: an array
        when loaded from cache, otherwise a list.
    """
    cache_file = f"preprocessed_cache_{self.cfg_hash}.npy"
    if os.path.exists(cache_file):
        cached = np.load(cache_file, allow_pickle=True)
        if len(cached) == len(images):
            if self.verbose:
                print("Loading preprocessed images from cache")
            return cached
        if self.verbose:
            print("Ignoring stale preprocessed cache (image count changed)")

    if self.verbose:
        print(f"Preprocessing {len(images)} images...")
    start_time = time.time()
    processed = []

    # Process in chunks so each thread pool stays short-lived.
    for i in tqdm(range(0, len(images), self.chunk_size),
                  desc="Preprocessing",
                  disable=not self.verbose):
        chunk = images[i:i + self.chunk_size]
        with ThreadPoolExecutor(max_workers=self.n_cpu) as executor:
            processed.extend(executor.map(self.preprocess_image, chunk))

    # Convert before np.save: np.save creates the file first, so a ragged
    # list (images of different shapes) failing inside it would leave an
    # empty, unreadable cache file behind.
    try:
        to_cache = np.asarray(processed)
    except ValueError as err:
        if self.verbose:
            print(f"Skipping preprocessed cache: {err}")
    else:
        np.save(cache_file, to_cache)

    if self.verbose:
        print(f"Preprocessed in {time.time()-start_time:.2f}s")

    return processed
|
|
942
|
+
|
|
943
|
+
def clear_cache(self, pattern: str = "*.npy") -> int:
    """Delete every file matching a glob pattern.

    Args:
        pattern: Glob pattern selecting the files to delete, relative to the
            current working directory unless absolute.

    Returns:
        The number of files actually removed. A file that cannot be removed
        is reported (when verbose) and skipped rather than raising.
    """
    import glob

    removed_count = 0
    for path in glob.glob(pattern):
        try:
            os.remove(path)
        except Exception as err:
            if self.verbose:
                print(f"Failed to remove {path}: {str(err)}")
            continue
        removed_count += 1
        if self.verbose:
            print(f"Removed {path}")
    return removed_count
|
|
957
|
+
|
|
958
|
+
def _extract_raw_pixels(self, img: np.ndarray) -> np.ndarray:
|
|
959
|
+
"""提取原始像素特征"""
|
|
960
|
+
return img.flatten()
|
|
961
|
+
def _reset_state(self):
    """Drop cached features and splits and delete on-disk caches to free memory.

    Clears the feature matrix and the train/test splits. Then removes every
    preprocessed, feature and CNN cache file in the working directory, runs
    the garbage collector and, when ``DL_ENABLED`` is set, clears the Keras
    session.
    """
    self.features = None
    self.X_train = self.y_train = None
    self.X_test = self.y_test = None

    for cache_pattern in (
        "preprocessed_cache_*.npy",
        "features_cache_*.npy",
        "cnn_preprocessed_*.npy",
    ):
        self.clear_cache(cache_pattern)

    gc.collect()
    if DL_ENABLED:
        tf.keras.backend.clear_session()
|
|
973
|
+
def _extract_hog(self, img: np.ndarray) -> np.ndarray:
    """Compute a flattened HOG descriptor for one image.

    Color (3-D) images are converted to grayscale first. The orientation
    count, cell size and block size come from
    ``self.cfg["feature_extraction"]["hog_params"]``.

    Args:
        img: 2-D grayscale or 3-D RGB image.

    Returns:
        1-D HOG feature vector.
    """
    hog_cfg = self.cfg["feature_extraction"]["hog_params"]
    gray = color.rgb2gray(img) if img.ndim == 3 else img
    return feature.hog(
        gray,
        orientations=hog_cfg["orientations"],
        pixels_per_cell=hog_cfg["pixels_per_cell"],
        cells_per_block=hog_cfg["cells_per_block"],
        feature_vector=True,
    )
|
|
987
|
+
|
|
988
|
+
def _extract_lbp(self, img: np.ndarray) -> np.ndarray:
    """Compute a normalized Local Binary Pattern histogram for one image.

    ``P`` (number of neighbours), ``R`` (radius) and ``method`` come from
    ``self.cfg["feature_extraction"]["lbp_params"]``. Color images are
    converted to grayscale first.

    The bin count is fixed by ``P`` and ``method``:

    - ``P + 2`` for ``"uniform"``
    - ``P * (P - 1) + 3`` for ``"nri_uniform"``
    - ``2 ** P`` for ``"default"`` and ``"ror"``

    Every image therefore yields a vector of the same length, and bin *i*
    always counts LBP code *i*. Deriving the bins from each image's own
    maximum code, as before, gave vectors of different lengths with
    misaligned bins, which cannot be stacked into a feature matrix. For
    ``"var"``, whose output is continuous, the previous data-dependent
    binning is kept.

    Args:
        img: 2-D grayscale or 3-D RGB image.

    Returns:
        1-D histogram. With unit-width bins, ``density=True`` makes it sum
        to 1.
    """
    params = self.cfg["feature_extraction"]["lbp_params"]
    if img.ndim == 3:
        img = color.rgb2gray(img)

    n_points = params["P"]
    lbp_method = params["method"]
    lbp = feature.local_binary_pattern(img, P=n_points, R=params["R"], method=lbp_method)

    fixed_bins = {
        "uniform": n_points + 2,
        "nri_uniform": n_points * (n_points - 1) + 3,
        "default": 2 ** n_points,
        "ror": 2 ** n_points,
    }.get(lbp_method)

    if fixed_bins is None:
        n_bins = int(lbp.max() + 1)
        hist, _ = np.histogram(lbp, bins=n_bins, density=True)
    else:
        hist, _ = np.histogram(lbp, bins=fixed_bins, range=(0, fixed_bins), density=True)
    return hist
|
|
1004
|
+
|
|
1005
|
+
def _extract_color_histogram(self, img: np.ndarray) -> np.ndarray:
|
|
1006
|
+
"""提取颜色直方图特征"""
|
|
1007
|
+
if img.ndim == 2: # 灰度图像
|
|
1008
|
+
hist = cv2.calcHist([img], [0], None, [256], [0, 256])
|
|
1009
|
+
return hist.flatten()
|
|
1010
|
+
|
|
1011
|
+
# 彩色图像
|
|
1012
|
+
channels = cv2.split(img)
|
|
1013
|
+
hist_features = []
|
|
1014
|
+
|
|
1015
|
+
for i, channel in enumerate(channels):
|
|
1016
|
+
hist = cv2.calcHist([channel], [0], None, [256], [0, 256])
|
|
1017
|
+
hist_features.extend(hist.flatten())
|
|
1018
|
+
|
|
1019
|
+
return np.array(hist_features)
|
|
1020
|
+
|
|
1021
|
+
def _extract_cnn_features(self, img: np.ndarray) -> np.ndarray:
    """Extract a pooled feature vector with an ImageNet-pretrained CNN.

    The backbone is chosen by ``self.cfg["feature_extraction"]["cnn_model"]``
    (case-insensitive). It is built once per name with ``include_top=False``
    and average pooling, then cached in ``self._cnn_models``.

    The image is processed as follows:

    1. Replicated to 3 channels if grayscale.
    2. Resized to 224x224.
    3. Brought to the 0-255 range the Keras ``preprocess_input`` functions
       expect.
    4. Passed through the ``preprocess_input`` belonging to the chosen
       backbone. VGG, ResNet, Inception, etc. use different schemes, so one
       shared function is wrong for most of them.

    Args:
        img: 2-D grayscale or 3-D RGB image (uint8, or float in [0, 1]).

    Returns:
        1-D feature vector.

    Raises:
        ImportError: If deep-learning support is disabled.
        ValueError: If the configured model name is unsupported.
    """
    if not self.dl_enabled:
        raise ImportError("Deep learning libraries not available")
    model_name = self.cfg["feature_extraction"]["cnn_model"].lower()
    # name -> (constructor, tf.keras.applications submodule holding its preprocess_input)
    models = {
        "vgg16": (VGG16, "vgg16"),
        "vgg19": (VGG19, "vgg19"),
        "resnet50": (ResNet50, "resnet50"),
        "inceptionv3": (InceptionV3, "inception_v3"),
        "mobilenet": (MobileNet, "mobilenet"),
        "densenet121": (DenseNet121, "densenet"),
        "efficientnetb0": (EfficientNetB0, "efficientnet"),
    }
    if model_name not in models:
        raise ValueError(f"Unsupported CNN model: {model_name}")
    model_ctor, app_module = models[model_name]

    if model_name not in self._cnn_models:
        base_model = model_ctor(weights="imagenet", include_top=False, pooling="avg")
        self._cnn_models[model_name] = Model(
            inputs=base_model.input, outputs=base_model.output
        )
    model = self._cnn_models[model_name]

    if img.ndim == 2:
        img = np.stack((img,) * 3, axis=-1)

    # Float images in [0, 1] (e.g. after normalization) are scaled to 0-255.
    # preserve_range keeps integer input in its original 0-255 range.
    scale = 255.0 if np.issubdtype(img.dtype, np.floating) and img.size and img.max() <= 1.0 else 1.0
    img = transform.resize(img, (224, 224), preserve_range=True).astype(np.float32) * scale
    batch = np.expand_dims(image.img_to_array(img), axis=0)
    batch = getattr(tf.keras.applications, app_module).preprocess_input(batch)

    # verbose=0: this runs once per image, and a progress bar each time is noise.
    features = model.predict(batch, verbose=0)
    return features.flatten()
|
|
1062
|
+
|
|
1063
|
+
|
|
1064
|
+
def preprocess_image_for_cnn(self, img: np.ndarray) -> np.ndarray:
    """Prepare a single image for the CNN.

    Steps:

    1. PIL images are converted to arrays.
    2. Grayscale (2-D or single-channel) is replicated to 3 channels, and an
       alpha channel is dropped.
    3. A 3-channel image has its channel order reversed, as
       ``cv2.COLOR_BGR2RGB`` did before.
    4. The image is resized to ``cfg["preprocessing"]["resize"]``
       (default ``(32, 32)``).
    5. The result is cast to float32 and divided by 255.

    Note:
        ``target_size`` is passed to ``cv2.resize`` as-is, i.e. treated as
        ``(width, height)``, the same way ``_preprocess_single_cnn`` treats
        it. Keep the two in sync.
        NOTE(review): unlike ``_preprocess_single_cnn``, this does not apply
        ``to_grayscale``. Confirm that callers expect 3-channel output.

    Args:
        img: ``np.ndarray`` or ``PIL.Image.Image``.

    Returns:
        float32 array of the resized image.
    """
    target_size = tuple(self.cfg["preprocessing"].get("resize", (32, 32)))

    if isinstance(img, Image.Image):
        img = np.array(img)

    if img.ndim == 2:
        img = np.stack((img,) * 3, axis=-1)
    elif img.shape[2] == 4:
        img = img[..., :3]
    elif img.shape[2] == 1:
        img = np.concatenate([img] * 3, axis=-1)

    if img.shape[2] == 3:
        # Reverse channel order with NumPy so any dtype works; cv2.cvtColor
        # rejects e.g. float64. Make it contiguous for cv2.resize.
        img = np.ascontiguousarray(img[..., ::-1])

    # cv2.resize takes (width, height), so compare in that order. One resize
    # is enough; the old code repeated it.
    if (img.shape[1], img.shape[0]) != target_size:
        img = cv2.resize(img, target_size, interpolation=cv2.INTER_AREA)

    return img.astype(np.float32) / 255.0
|
|
1103
|
+
def _build_cnn_model(self, input_shape=None, num_classes=None):
    """Build the CNN via ``build_cnn_model``, filling in configured defaults.

    Args:
        input_shape: Model input shape. Defaults to
            ``self.cfg["preprocessing"]["resize"]``. NOTE(review): that value
            has no channel axis; confirm ``build_cnn_model`` accepts it.
        num_classes: Output size. Defaults to:
            - the number of distinct labels, for classification;
            - 1, for regression;
            - 2, when no labels are loaded.

    Returns:
        Whatever ``build_cnn_model`` returns. It is called with the extra
        keyword arguments from ``self.cfg["cnn"]``.
    """
    shape = self.cfg["preprocessing"]["resize"] if input_shape is None else input_shape

    if num_classes is None:
        labels = getattr(self, "labels", None)
        if labels is None:
            num_classes = 2  # fallback when no labels are loaded
        elif self.task_type == "classification":
            num_classes = len(np.unique(labels))
        else:
            num_classes = 1

    return build_cnn_model(
        input_shape=shape,
        num_classes=num_classes,
        **self.cfg["cnn"]
    )
|
|
1119
|
+
|
|
1120
|
+
def _train_cnn_model(self, model_builder: Callable, params: dict):
    """Train a CNN on ``self.images`` / ``self.labels`` and record the result.

    Pipeline:

    1. Enable XLA and, on GPU, the ``mixed_float16`` policy.
    2. Preprocess images with ``_preprocess_for_cnn``.
    3. Split train/validation, stratified for classification.
    4. Build the model with ``model_builder(input_shape, num_classes)``.
    5. Smoke-test the model on one batch.
    6. Fit with the callbacks from ``_create_callbacks``.

    Args:
        model_builder: Callable ``(input_shape, num_classes) -> model``.
        params: Training options:
            - ``batch_size`` (default 32)
            - ``epochs`` (default 50)
            - ``use_augmentation`` (default False; applied through
              ``_create_datagen``)
            - ``patience`` (read by the callbacks)

    Returns:
        ``(model, metrics)``. ``metrics`` holds:
            - ``train_time``
            - ``best_epoch``: 1-based, chosen by the metric EarlyStopping
              monitors
            - ``val_loss`` and ``best_val_loss``
            - for classification, ``val_accuracy`` and
              ``best_val_accuracy`` when the model reports accuracy.

    Raises:
        ImportError: If deep-learning support is disabled.
    """
    import math

    if not self.dl_enabled:
        raise ImportError("Deep learning libraries not available")

    tf.config.optimizer.set_jit(True)  # Enable XLA compilation
    tf.keras.backend.clear_session()  # Clean any previous models
    if self.has_gpu:
        set_global_policy('mixed_float16')

    target_size = self.cfg["preprocessing"].get("resize", (224, 224))
    X = self._preprocess_for_cnn(self.images, target_size)
    y = np.array(self.labels).flatten()

    if self.verbose:
        print(f"X shape: {X.shape}, dtype: {X.dtype}")
        print(f"y shape: {y.shape}, dtype: {y.dtype}")
        print(f"Unique labels: {np.unique(y)}")

    is_classification = self.task_type == "classification"
    X_train, X_val, y_train, y_val = train_test_split(
        X, y,
        test_size=self.cfg["test_size"],
        random_state=self.cfg["random_state"],
        stratify=y if is_classification else None,
    )
    X_train = X_train.astype(np.float32)
    X_val = X_val.astype(np.float32)

    input_shape = X.shape[1:]
    num_classes = len(np.unique(y)) if is_classification else 1
    model = model_builder(input_shape, num_classes)

    batch_size = params.get('batch_size', 32)
    # Training batches come from _create_datagen, which augments only when
    # use_augmentation is set. Previously that generator was built but never
    # used, so augmentation had no effect. Validation data is never augmented.
    train_generator = self._create_datagen(params).flow(
        X_train, y_train, batch_size=batch_size, shuffle=True
    )
    val_generator = ImageDataGenerator().flow(
        X_val, y_val, batch_size=batch_size, shuffle=False
    )

    # Round up so the last partial batch is used. Floor division skipped up
    # to batch_size - 1 samples per epoch, including validation samples.
    steps_per_epoch = max(1, math.ceil(len(X_train) / batch_size))
    validation_steps = max(1, math.ceil(len(X_val) / batch_size))

    callbacks = self._create_callbacks(params)

    if self.verbose:
        print(f"\nStarting training with:")
        print(f" Train samples: {len(X_train)}")
        print(f" Validation samples: {len(X_val)}")
        print(f" Batch size: {batch_size}")
        print(f" Steps per epoch: {steps_per_epoch}")
        print(f" Validation steps: {validation_steps}")
        print("Compiling model...")

    # Smoke-test the model on one batch so shape/dtype errors surface early.
    try:
        if self.verbose:
            print("Testing with one batch...")
        test_batch = next(train_generator)
        model.predict(test_batch[0], verbose=0)
        if self.verbose:
            print("1st batch test successful")
    except Exception as e:
        print(f"❌ Error in single batch test: {str(e)}")
        raise

    start_time = time.time()
    history = model.fit(
        train_generator,
        steps_per_epoch=steps_per_epoch,
        epochs=params.get('epochs', 50),
        validation_data=val_generator,
        validation_steps=validation_steps,
        callbacks=callbacks,
        verbose=1 if self.verbose else 0,  # 1 = per-batch progress
    )
    train_time = time.time() - start_time

    hist = history.history
    val_loss = hist['val_loss']
    val_acc = hist.get('val_accuracy') if is_classification else None
    # Pick the epoch by the metric EarlyStopping monitors (val_accuracy for
    # classification, val_loss otherwise). The reported numbers then describe
    # the weights restore_best_weights keeps.
    best_idx = int(np.argmax(val_acc)) if val_acc else int(np.argmin(val_loss))
    metrics = {
        'train_time': train_time,
        'best_epoch': best_idx + 1,
        'val_loss': val_loss[best_idx],
        'best_val_loss': min(val_loss),
    }
    if val_acc:
        metrics['val_accuracy'] = val_acc[best_idx]
        metrics['best_val_accuracy'] = max(val_acc)

    self.model_performance["cnn"] = {
        "model": model,
        "metrics": metrics,
        "history": hist,
        "params": params,
    }
    self._update_best_model(model, metrics)

    tf.keras.backend.clear_session()
    gc.collect()

    if self.verbose:
        print(f"Training completed in {train_time:.2f}s")
        if 'best_val_accuracy' in metrics:
            print(f"Best validation accuracy: {metrics['best_val_accuracy']:.4f}")
        else:
            print(f"Best validation loss: {metrics['best_val_loss']:.4f}")

    return model, metrics
|
|
1257
|
+
|
|
1258
|
+
|
|
1259
|
+
def _create_datagen(self, params: dict):
    """Return an ``ImageDataGenerator`` that augments only when requested.

    If ``params["use_augmentation"]`` is truthy (default False), the
    generator applies random rotation, width/height shifts, shear, zoom and
    horizontal flips, filling exposed pixels by reflection. Each setting can
    be overridden under ``self.cfg["preprocessing"]["augmentation"]``.
    Otherwise a plain pass-through generator is returned.
    """
    if not params.get('use_augmentation', False):
        return ImageDataGenerator()

    aug = self.cfg["preprocessing"].get("augmentation", {})
    defaults = {
        "rotation_range": 10,
        "width_shift_range": 0.1,
        "height_shift_range": 0.1,
        "shear_range": 0.05,
        "zoom_range": 0.1,
        "horizontal_flip": True,
    }
    options = {key: aug.get(key, fallback) for key, fallback in defaults.items()}
    return ImageDataGenerator(fill_mode='reflect', **options)
|
|
1280
|
+
def _preprocess_for_cnn(self, images, target_size: Tuple[int, int]) -> np.ndarray:
|
|
1281
|
+
"""Efficient CNN preprocessing with parallel processing"""
|
|
1282
|
+
cache_file = f"cnn_preprocessed_{self.cfg_hash}.npy"
|
|
1283
|
+
|
|
1284
|
+
if os.path.exists(cache_file):
|
|
1285
|
+
if self.verbose:
|
|
1286
|
+
print("Loading CNN preprocessed images from cache")
|
|
1287
|
+
return np.load(cache_file)
|
|
1288
|
+
|
|
1289
|
+
if self.verbose:
|
|
1290
|
+
print("Preprocessing images for CNN...")
|
|
1291
|
+
|
|
1292
|
+
start_time = time.time()
|
|
1293
|
+
processed = []
|
|
1294
|
+
|
|
1295
|
+
# Process in chunks
|
|
1296
|
+
for i in range(0, len(images), self.chunk_size):
|
|
1297
|
+
chunk = images[i:i+self.chunk_size]
|
|
1298
|
+
with ThreadPoolExecutor(max_workers=self.n_cpu) as executor:
|
|
1299
|
+
processed_chunk = list(executor.map(
|
|
1300
|
+
partial(self._preprocess_single_cnn, target_size=target_size),
|
|
1301
|
+
chunk
|
|
1302
|
+
))
|
|
1303
|
+
processed.extend(processed_chunk)
|
|
1304
|
+
|
|
1305
|
+
processed = np.array(processed)
|
|
1306
|
+
|
|
1307
|
+
# FIX: Ensure proper shape (samples, height, width, channels)
|
|
1308
|
+
if processed.ndim == 3:
|
|
1309
|
+
# Add channel dimension for grayscale
|
|
1310
|
+
processed = np.expand_dims(processed, axis=-1)
|
|
1311
|
+
|
|
1312
|
+
np.save(cache_file, processed)
|
|
1313
|
+
|
|
1314
|
+
if self.verbose:
|
|
1315
|
+
print(f"Preprocessed in {time.time()-start_time:.2f}s")
|
|
1316
|
+
print(f"Processed shape: {processed.shape}")
|
|
1317
|
+
|
|
1318
|
+
return processed
|
|
1319
|
+
|
|
1320
|
+
|
|
1321
|
+
def _preprocess_single_cnn(self, img: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray:
    """Prepare one image for CNN training.

    Steps:

    1. Grayscale (2-D or single-channel) is replicated to 3 channels, and an
       alpha channel is dropped.
    2. A 3-channel image has its channel order reversed (RGB -> BGR).
    3. The image is resized with ``cv2.resize(img, target_size)``, so
       ``target_size`` is treated as ``(width, height)``.
    4. The result is cast to float32 and divided by 255.
    5. If ``cfg["preprocessing"]["to_grayscale"]`` (default True), it is
       reduced to one channel of shape ``(H, W, 1)``.

    Args:
        img: 2-D or 3-D image array.
        target_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        float32 image, ``(H, W, 3)``, or ``(H, W, 1)`` when converted to
        grayscale.
    """
    target_size = tuple(target_size)
    if img.ndim == 2:
        img = np.stack((img,) * 3, axis=-1)
    elif img.shape[2] == 4:
        img = img[..., :3]
    elif img.shape[2] == 1:
        img = np.concatenate([img] * 3, axis=-1)

    if img.shape[2] == 3:
        # RGB -> BGR via slicing, which works for every dtype (cv2.cvtColor
        # rejects some). Make it contiguous for OpenCV.
        img = np.ascontiguousarray(img[..., ::-1])

    # cv2.resize takes (width, height); compare in the same order.
    if (img.shape[1], img.shape[0]) != target_size:
        img = cv2.resize(img, target_size, interpolation=cv2.INTER_AREA)

    img = img.astype(np.float32) / 255.0

    if self.cfg["preprocessing"].get("to_grayscale", True):
        # Channels are in BGR order here, so use BGR2GRAY. RGB2GRAY applied
        # the red weight to blue and vice versa.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img = np.expand_dims(img, axis=-1)

    return img
|
|
1347
|
+
|
|
1348
|
+
def _create_callbacks(self, params: dict) -> "List[tf.keras.callbacks.Callback]":
    """Build the Keras callbacks used for CNN training.

    - ``EarlyStopping``: monitors ``val_accuracy`` (classification) or
      ``val_loss``, with patience ``params["patience"]`` (default 8), and
      restores the best weights.
    - ``ModelCheckpoint``: saves the best model to ``best_model.keras`` in
      the current working directory, judged on the same metric.
    - ``ReduceLROnPlateau``: halves the learning rate after 3 epochs without
      ``val_loss`` improvement, down to a floor of 1e-6.

    The return annotation is a string, so defining this method does not
    evaluate ``tf``. NOTE(review): presumably ``tf`` is undefined when the
    TensorFlow import fails, in which case an eager annotation would break
    class creation. Confirm against the file's import block.
    """
    is_classification = self.task_type == "classification"
    monitor = 'val_accuracy' if is_classification else 'val_loss'
    keras_verbose = 1 if self.verbose else 0
    return [
        EarlyStopping(
            patience=params.get('patience', 8),
            monitor=monitor,
            restore_best_weights=True,
            verbose=keras_verbose,
        ),
        ModelCheckpoint(
            'best_model.keras',
            save_best_only=True,
            monitor=monitor,
            mode='max' if is_classification else 'min',
        ),
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=3,
            min_lr=1e-6,
            verbose=keras_verbose,
        ),
    ]
|
|
1371
|
+
|
|
1372
|
+
def train_model(self, model_name: str = None, params: dict = None):
    """Train one model on the extracted features, or the CNN on raw images.

    Args:
        model_name: Key into ``self.models``. Defaults to
            ``self.cfg["model"]["name"]``. ``"cnn"`` is delegated to
            ``_train_cnn_model``.
        params: Estimator parameters. If empty or None,
            ``self.cfg["model"]["params"]`` is used. The dict is copied, so
            neither the caller's dict nor the config is modified. Parameters
            that the estimator's ``get_params()`` does not list are dropped.

    Returns:
        ``(model, metrics)``, with metrics from ``evaluate_model``.

    Raises:
        ValueError: If labels are missing or the model name is unsupported.
        ImportError: For ``"cnn"`` when deep learning is unavailable.
    """
    model_cfg = self.cfg["model"]
    model_name = model_name or model_cfg["name"]
    # Resolve and copy params up front. The GPU/class-weight tweaks below
    # used to call params.update(...) while params could still be None, and
    # they mutated the caller's dict or the config in place.
    params = dict(params or model_cfg.get("params", {}) or {})

    if model_name != "cnn" and self.features is None:
        self.extract_features()
    print(f"Training model: {model_name}")
    if self.labels is None:
        raise ValueError("Labels are required for training")
    if model_name not in self.models:
        raise ValueError(f"Unsupported model: {model_name}")

    if model_name == "cnn":
        if not self.dl_enabled:
            raise ImportError("Deep learning libraries not available")
        return self._train_cnn_model(self.models[model_name], params)

    is_classification = self.task_type == "classification"
    # Stratify only for classification; train_test_split cannot stratify
    # continuous regression targets.
    self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
        self.features,
        self.labels,
        test_size=self.cfg["test_size"],
        random_state=self.cfg["random_state"],
        stratify=self.labels if is_classification else None,
    )
    if self.verbose:
        print("划分训练测试集")

    # Boosting libraries: GPU acceleration and balanced class weights (for
    # imbalanced data). Class weights come from the training split, so this
    # must run after train_test_split; previously self.y_train was still
    # None or stale here.
    if model_name in ["xgb", "lightgbm", "catboost"]:
        if self.has_gpu:
            params.update({"tree_method": "gpu_hist", "gpu_id": 0})
        if is_classification:
            classes = np.unique(self.y_train)
            class_weights = compute_class_weight('balanced', classes=classes, y=self.y_train)
            params["class_weight"] = dict(zip(classes, class_weights))

    model_class = self.models[model_name]
    supported_params = model_class().get_params().keys()
    filtered_params = {k: v for k, v in params.items() if k in supported_params}
    if len(filtered_params) != len(params) and self.verbose:
        removed_params = set(params.keys()) - set(filtered_params.keys())
        print(f"Removed unsupported params for {model_name}: {removed_params}")

    model = model_class(**filtered_params)
    if self.verbose:
        print("创建模型")
        print(f"🤖 Training {model_name} model...")
        print(f"📐 Parameters: {params}")

    start_time = time.time()
    model.fit(self.X_train, self.y_train)
    train_time = time.time() - start_time
    if self.verbose:
        print(" 训练模型")

    metrics = self.evaluate_model(model)
    if self.verbose:
        print(" 评估模型")

    self.model_performance[model_name] = {
        "model": model,
        "metrics": metrics,
        "train_time": train_time,
        "params": params,
    }
    if self.verbose:
        print(" 保存结果")
    self._update_best_model(model, metrics)

    if self.verbose:
        print(f"Model trained in {train_time:.2f}s")
        print("Performance metrics:")
        for k, v in metrics.items():
            # Non-scalar metrics (reports, matrices) cannot be formatted :.4f.
            if isinstance(v, (int, float, np.number)):
                print(f" {k}: {v:.4f}")
            else:
                print(f" {k}: {v}")

    return model, metrics
|
|
1460
|
+
|
|
1461
|
+
def auto_compare_models(
    self, model_names: List[str] = None, params_list: List[dict] = None
):
    """Train several models, rank them on the held-out split and keep the best.

    Args:
        model_names: keys of the models to train via ``self.train_model``.
            When ``None``, a default roster is chosen from ``self.task_type``.
        params_list: one hyper-parameter dict per entry of ``model_names``.
            When ``None``, built-in defaults are used for the default roster
            and empty dicts otherwise. A caller-supplied list is always used.

    Returns:
        pandas.DataFrame with one row per successfully trained model
        (``model``, ``params`` and every metric returned by ``train_model``),
        sorted by ``accuracy`` (classification) or ``r2`` (regression), best
        first. ``self.best_model`` is set to the top-ranked model.

    Raises:
        ValueError: classification labels with fewer than two classes, or
            ``model_names`` and ``params_list`` of different lengths.
        RuntimeError: no model could be trained and scored on the ranking metric.
    """
    if model_names is None:
        # Default roster per task type, paired with default hyper-parameters.
        if self.task_type == "classification":
            model_names = [
                "random_forest", "svm", "knn", "gradient_boosting", "logistic_regression",
                "xgb", "lightgbm", "catboost", "extra_trees", "mlp", "bagging",
            ]
            default_params = [
                {"n_estimators": 100},  # RF
                {"C": 1.0, "kernel": "rbf"},  # SVM
                {"n_neighbors": 5},  # KNN
                {"n_estimators": 100},  # GBM
                {"C": 1.0, "penalty": "l2"},  # Logistic
                {"n_estimators": 100},  # XGBoost
                {"n_estimators": 100},  # LightGBM
                {"iterations": 100},  # CatBoost
                {"n_estimators": 100},  # Extra Trees
                {"hidden_layer_sizes": (64,)},  # MLP
                {"n_estimators": 50},  # Bagging
            ]
        else:  # regression
            model_names = [
                "random_forest_reg", "svr", "knn_regressor", "gradient_boosting_reg", "linear_regression",
                "xgb_reg", "lightgbm_reg", "catboost_reg", "extra_trees_reg", "mlp_reg", "bagging_reg",
            ]
            default_params = [
                {"n_estimators": 100},  # RF Reg
                {"C": 1.0, "kernel": "rbf"},  # SVR
                {"n_neighbors": 5},  # KNN Reg
                {"n_estimators": 100},  # GBM Reg
                {},  # Linear Regression
                {"n_estimators": 100},  # XGBoost Reg
                {"n_estimators": 100},  # LightGBM Reg
                {"iterations": 100},  # CatBoost Reg
                {"n_estimators": 100},  # Extra Trees Reg
                {"hidden_layer_sizes": (64,)},  # MLP Reg
                {"n_estimators": 50},  # Bagging Reg
            ]
        # Only fall back to the defaults when the caller supplied no params;
        # previously a user-provided params_list was silently discarded here.
        if params_list is None:
            params_list = default_params

    # Classification needs at least two distinct classes.
    if self.task_type == "classification":
        unique_classes = np.unique(self.labels)
        if len(unique_classes) < 2:
            raise ValueError(
                "Classification requires ≥2 classes. Your data has only one class."
            )

    if params_list is None:
        # One independent dict per model (``[{}] * n`` would share one dict).
        params_list = [{} for _ in model_names]

    if len(model_names) != len(params_list):
        raise ValueError("model_names and params_list must have the same length")

    if self.verbose:
        print(f"Comparing {len(model_names)} models...")

    # Train and evaluate every model; a failing model must not abort the run.
    results = []
    for name, params in zip(model_names, params_list):
        try:
            _, metrics = self.train_model(name, params)
            results.append({"model": name, "params": params, **metrics})
        except Exception as e:
            print(f"Error training {name}: {str(e)}")

    sort_by = "accuracy" if self.task_type == "classification" else "r2"
    results_df = pd.DataFrame(results)
    # Without this check an all-failed run surfaced as an opaque KeyError
    # from sort_values.
    if results_df.empty or sort_by not in results_df.columns:
        raise RuntimeError(
            f"No model was trained and evaluated successfully; cannot rank by '{sort_by}'."
        )

    results_df = results_df.sort_values(by=sort_by, ascending=False)

    # Keep the top-ranked model as the best model.
    best_model_name = results_df.iloc[0]["model"]
    self.best_model = self.model_performance[best_model_name]["model"]

    if self.verbose:
        print("\n🏆 Model Comparison Results:")
        print(results_df[["model", sort_by]])
        print(f"\n🥇 Best model: {best_model_name}")

    return results_df
|
|
1550
|
+
|
|
1551
|
+
def _update_best_model(self, model, metrics):
|
|
1552
|
+
"""Helper to update the best model reference"""
|
|
1553
|
+
if self.best_model is None:
|
|
1554
|
+
self.best_model = model
|
|
1555
|
+
else:
|
|
1556
|
+
current_metric = self._get_primary_metric(self.best_model)
|
|
1557
|
+
new_metric = self._get_primary_metric(model, metrics)
|
|
1558
|
+
|
|
1559
|
+
if new_metric > current_metric:
|
|
1560
|
+
self.best_model = model
|
|
1561
|
+
|
|
1562
|
+
def _get_primary_metric(self, model, metrics=None):
|
|
1563
|
+
"""Get the appropriate primary metric for comparison"""
|
|
1564
|
+
if metrics is None:
|
|
1565
|
+
metrics = self.evaluate_model(model=model)
|
|
1566
|
+
|
|
1567
|
+
return metrics.get(
|
|
1568
|
+
"accuracy" if self.task_type == "classification" else "r2", 0
|
|
1569
|
+
)
|
|
1570
|
+
|
|
1571
|
+
def evaluate_model(self, model: "BaseEstimator" = None) -> dict:
    """Score a model on the held-out split ``self.X_test`` / ``self.y_test``.

    Args:
        model: fitted model exposing ``predict``; defaults to ``self.best_model``.

    Returns:
        dict of metrics: ``accuracy`` and ``f1_weighted`` for classification,
        otherwise ``mse``, ``mae`` and ``r2``. Evaluation is best-effort. On
        failure the error is printed and the metrics computed so far are
        returned, possibly an empty dict.

    Raises:
        ValueError: no ``model`` was given and ``self.best_model`` is None.
    """
    if model is None:
        if self.best_model is None:
            raise ValueError("No model available for evaluation")
        model = self.best_model

    metrics = {}
    # _reset_prediction_state (called by predict) sets the split to None.
    # Report that explicitly instead of failing inside model.predict(None).
    if getattr(self, "X_test", None) is None or getattr(self, "y_test", None) is None:
        print("Error evaluating model: no test split available (train a model first)")
        return metrics

    try:
        y_pred = model.predict(self.X_test)
        if self.task_type == "classification":
            metrics["accuracy"] = accuracy_score(self.y_test, y_pred)
            metrics["f1_weighted"] = f1_score(self.y_test, y_pred, average="weighted")
            if self.verbose:
                print("\nClassification Report:")
                print(
                    classification_report(
                        self.y_test, y_pred, target_names=self.class_names
                    )
                )
            # Confusion-matrix plot, when enabled in the visualization config.
            if (
                self.cfg["visualization"]["enable"]
                and self.cfg["visualization"]["confusion_matrix"]
            ):
                self.plot_confusion_matrix(self.y_test, y_pred)
        else:
            metrics["mse"] = mean_squared_error(self.y_test, y_pred)
            metrics["mae"] = mean_absolute_error(self.y_test, y_pred)
            metrics["r2"] = r2_score(self.y_test, y_pred)
    except Exception as e:
        # Deliberately best-effort: callers use whatever metrics were computed.
        print(f"Error evaluating model: {str(e)}")
    return metrics
|
|
1607
|
+
|
|
1608
|
+
def _reset_prediction_state(self):
|
|
1609
|
+
"""Reset internal state before new prediction"""
|
|
1610
|
+
# Clear feature cache
|
|
1611
|
+
self.features = None
|
|
1612
|
+
self.clear_cache()
|
|
1613
|
+
# Clear CNN-specific cache
|
|
1614
|
+
if hasattr(self, 'preprocessed_images'):
|
|
1615
|
+
del self.preprocessed_images
|
|
1616
|
+
# Clear image cache
|
|
1617
|
+
self.images = []
|
|
1618
|
+
self.image_input = []
|
|
1619
|
+
# Clear train/test splits
|
|
1620
|
+
if hasattr(self, 'X_train'): self.X_train = None
|
|
1621
|
+
if hasattr(self, 'X_test'): self.X_test = None
|
|
1622
|
+
if hasattr(self, 'y_train'): self.y_train = None
|
|
1623
|
+
if hasattr(self, 'y_test'): self.y_test = None
|
|
1624
|
+
|
|
1625
|
+
def predict(
    self,
    image_paths: List[str],
    model: "BaseEstimator" = None,
    output_dir: str = None,
) -> np.ndarray:
    """Predict labels/values for new images and optionally write a report.

    Args:
        image_paths: paths of the images to predict; any iterable, which is
            materialised into a list.
        model: fitted model; defaults to ``self.best_model``. A model whose
            ``layers`` include a Conv2D layer receives the images preprocessed
            by ``self._preprocess_for_cnn``. Every other model (e.g. a
            scikit-learn estimator, which has no ``layers``) receives
            ``self.extract_features(images)``.
        output_dir: when given, an HTML prediction report is written there.

    Returns:
        The predictions. The raw model output is kept in ``self.predictions``.
        For classification, probability-style output is converted to class
        indices: argmax for multi-column output, a 0.5 threshold for
        single-column float scores within [0, 1]. Predictions are then decoded
        with ``self.label_encoder`` when one is set. Label vectors (e.g. from
        scikit-learn classifiers) are never thresholded.

    Raises:
        ValueError: no ``model`` given and ``self.best_model`` is None.
    """
    self._reset_prediction_state()
    if model is None:
        if self.best_model is None:
            raise ValueError("No model available for prediction")
        model = self.best_model
    image_paths = list(image_paths)

    # Load and preprocess the images.
    images, _ = self.load_images(image_paths)

    # Only Keras-style models expose ``layers``; reading it unconditionally
    # raised AttributeError for every scikit-learn estimator.
    layers = getattr(model, "layers", None) or []
    is_cnn = any(
        isinstance(layer, (tf.keras.layers.Conv2D, tf.keras.layers.Convolution2D))
        for layer in layers
    )
    if is_cnn:
        if self.verbose:
            print("Using direct CNN prediction")
        target_size = self.cfg["preprocessing"].get("resize", (32, 32))
        # _preprocess_for_cnn is expected to return an (N, H, W, C) batch.
        images = self._preprocess_for_cnn(images, target_size=target_size)
        predictions = model.predict(images)
    else:
        features = self.extract_features(images)
        predictions = model.predict(features)
        self.clear_cache()
    self.predictions = predictions

    if self.task_type == "classification":
        predictions = np.asarray(predictions)
        if predictions.ndim > 1 and predictions.shape[1] > 1:
            # Per-class scores/probabilities -> class index.
            predictions = np.argmax(predictions, axis=1)
        elif (
            np.issubdtype(predictions.dtype, np.floating)
            and predictions.size > 0
            and predictions.min() >= 0.0
            and predictions.max() <= 1.0
        ):
            # Single-column probability (e.g. a sigmoid output) -> 0/1,
            # flattened so (N, 1) output becomes a label vector.
            predictions = (predictions.reshape(-1) > 0.5).astype(int)
        # Anything else is already a label vector; thresholding it would
        # collapse multi-class labels (e.g. 2 -> 1) or fail on strings.

        # Decode to the original class labels when an encoder is available.
        if getattr(self, "label_encoder", None):
            try:
                predictions = self.label_encoder.inverse_transform(predictions)
            except ValueError as e:
                print(f"Label transform warning: {e}")
                print("Returning numeric predictions instead")

    # Visualize the predictions.
    if self.cfg["visualization"]["enable"]:
        self.visualize_predictions(images, predictions, image_paths)

    # Write the prediction report.
    if output_dir:
        self.generate_prediction_report(image_paths, predictions, output_dir)

    return predictions
|
|
1682
|
+
|
|
1683
|
+
def save_model(self, model: "BaseEstimator", file_path: str, format: str = "joblib"):
    """Serialise a trained model to disk.

    Args:
        model: the model object to save.
        file_path: destination file. Missing parent directories are created;
            a bare filename is written to the current working directory.
        format: ``"joblib"`` (default) or ``"pickle"``.

    Raises:
        ValueError: unsupported ``format`` (checked before anything is created).
    """
    if format not in ("joblib", "pickle"):
        raise ValueError("Unsupported format. Use 'joblib' or 'pickle'")

    # os.makedirs("") raises FileNotFoundError, so only create a parent
    # directory when the path actually has one.
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    if format == "joblib":
        joblib.dump(model, file_path)
    else:
        with open(file_path, "wb") as f:
            pickle.dump(model, f)

    if self.verbose:
        print(f"Model saved to {file_path}")
|
|
1704
|
+
|
|
1705
|
+
def load_model(self, file_path: str, format: str = "joblib") -> "BaseEstimator":
    """Load a model previously written by ``save_model``.

    If no best model is set yet, the loaded model also becomes
    ``self.best_model``. An existing best model is left unchanged.

    Security: joblib and pickle files are both pickle-based, and loading them
    can execute arbitrary code. Only load files from trusted sources.

    Args:
        file_path: path of the model file.
        format: ``"joblib"`` (default) or ``"pickle"``.

    Returns:
        The loaded model object.

    Raises:
        ValueError: unsupported ``format``.
    """
    if format == "joblib":
        model = joblib.load(file_path)
    elif format == "pickle":
        # NOTE(security): unpickling runs arbitrary code; trusted files only.
        with open(file_path, "rb") as f:
            model = pickle.load(f)
    else:
        raise ValueError("Unsupported format. Use 'joblib' or 'pickle'")

    if self.best_model is None:
        self.best_model = model
        # Report only when the loaded model really became the best model.
        if self.verbose:
            print(f"the best model is loaded: {model}")
    if self.verbose:
        print(f"Model loaded from {file_path}")

    return model
|
|
1731
|
+
|
|
1732
|
+
def visualize_predictions(
    self, images: List[np.ndarray], predictions: List, paths: List[str] = None
):
    """Show up to ``cfg["visualization"]["sample_images"]`` images with their predictions.

    Images are laid out on a two-row grid. Each is titled with its prediction
    and, when available, the file name from ``paths``. When
    ``cfg["report"]["include_visuals"]`` is true (the default), the figure is
    also saved to ``reports/prediction_visuals.png``.

    Args:
        images: images to display, aligned with ``predictions``.
        predictions: one prediction per image.
        paths: optional source paths used for the titles.
    """
    n_samples = min(self.cfg["visualization"]["sample_images"], len(images))
    n_cols = (n_samples + 1) // 2

    plt.figure(figsize=(15, 8))
    for i in range(n_samples):
        plt.subplot(2, n_cols, i + 1)

        img = np.asarray(images[i])
        # imshow rejects (H, W, 1) arrays. Single-channel images, such as a
        # grayscale batch from the CNN path in predict (assumed to be
        # (N, H, W, C)), are shown as 2-D instead.
        if img.ndim == 3 and img.shape[-1] == 1:
            img = img[..., 0]
        if img.ndim == 2:
            plt.imshow(img, cmap="gray")
        else:
            plt.imshow(img)

        title = f"Pred: {predictions[i]}"
        # Also guard the index: ``paths`` may be shorter than ``images``.
        if paths and i < len(paths) and isinstance(paths[i], str) and paths[i]:
            title += f"\n{os.path.basename(paths[i])}"

        plt.title(title, fontsize=10)
        plt.axis("off")

    plt.tight_layout()

    # Save the figure for the reports.
    if self.cfg["report"].get("include_visuals", True):
        vis_path = os.path.join("reports", "prediction_visuals.png")
        os.makedirs(os.path.dirname(vis_path), exist_ok=True)
        plt.savefig(vis_path, dpi=150, bbox_inches="tight")

    plt.show()
|
|
1763
|
+
|
|
1764
|
+
def plot_confusion_matrix(self, y_true, y_pred):
    """Render the confusion matrix of ``y_true`` vs ``y_pred`` as an annotated heatmap.

    Both axes are labelled with ``self.class_names``. When
    ``cfg["report"]["include_visuals"]`` is true (the default), the figure is
    written to ``reports/confusion_matrix.png`` before being shown.
    """
    matrix = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(10, 8))
    heatmap_kwargs = dict(
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=self.class_names,
        yticklabels=self.class_names,
    )
    sns.heatmap(matrix, **heatmap_kwargs)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")

    if not self.cfg["report"].get("include_visuals", True):
        plt.show()
        return

    out_file = os.path.join("reports", "confusion_matrix.png")
    os.makedirs(os.path.dirname(out_file), exist_ok=True)
    plt.savefig(out_file, dpi=150, bbox_inches="tight")
    plt.show()
|
|
1787
|
+
|
|
1788
|
+
def visualize_feature_space(self):
    """Plot ``self.features`` in two dimensions, coloured by ``self.labels`` if set.

    Features with more than two columns are first reduced with PCA and then
    embedded in 2-D with t-SNE. PCA uses at most 50 components, capped by the
    number of samples and features. A single feature column is plotted
    against a zero second axis.

    When ``cfg["report"]["include_visuals"]`` is true (the default), the
    figure is saved to ``reports/feature_space.png`` and
    ``reports/feature_space.pdf``. The PNG is the file the training reports
    embed; the PDF matches the previous output.

    Raises:
        ValueError: no features are available.
    """
    if self.features is None:
        raise ValueError("No features available for visualization")
    features = np.asarray(self.features)
    n_samples, n_features = features.shape[0], features.shape[1]
    if self.verbose:
        print(f"features dimensions: {n_features}")

    # Dimensionality reduction.
    if n_features > 2:
        if self.verbose:
            print(
                "Reducing feature dimensions for visualization...\n\t第一步:PCA;\n\t第二步:t-SNE to 2D)"
            )
        # PCA rejects n_components > min(n_samples, n_features), so a fixed
        # value of 50 failed on small feature vectors or small image sets.
        pca = PCA(n_components=min(50, n_samples, n_features))
        features_reduced = pca.fit_transform(features)

        # t-SNE requires perplexity < n_samples; 30 is scikit-learn's default.
        tsne = TSNE(
            n_components=2,
            perplexity=min(30.0, max(n_samples - 1, 1)),
            random_state=self.cfg["random_state"],
        )
        features_2d = tsne.fit_transform(features_reduced)
    elif n_features == 1:
        # Indexing column 1 below would fail; plot against a zero axis.
        features_2d = np.column_stack([features[:, 0], np.zeros(n_samples)])
    else:
        features_2d = features

    # Draw the feature space.
    plt.figure(figsize=(12, 8))

    if self.labels is not None:
        scatter = plt.scatter(
            features_2d[:, 0],
            features_2d[:, 1],
            c=self.labels,
            cmap=self.cfg["visualization"]["cmap"],
            alpha=0.7,
        )
        if self.task_type == "classification":
            plt.legend(*scatter.legend_elements(), title="Classes")
    else:
        plt.scatter(features_2d[:, 0], features_2d[:, 1], alpha=0.7)

    plt.title("Feature Space Visualization")
    plt.xlabel("Dimension 1")
    plt.ylabel("Dimension 2")

    # Save the figure. The report generators look for feature_space.png.
    if self.cfg["report"].get("include_visuals", True):
        os.makedirs("reports", exist_ok=True)
        for ext in ("png", "pdf"):
            plt.savefig(
                os.path.join("reports", f"feature_space.{ext}"),
                dpi=150,
                bbox_inches="tight",
            )

    plt.show()
|
|
1838
|
+
|
|
1839
|
+
def generate_report(self, output_path: str = "reports/training_report.html"):
    """Write the training report to ``output_path`` and return its text.

    A path ending in ``.html`` produces an HTML report; any other path
    produces Markdown. Missing parent directories are created (the default
    ``reports/`` directory previously had to exist already). The file is
    written as UTF-8.

    Args:
        output_path: destination file.

    Returns:
        The report text.
    """
    if output_path.endswith(".html"):
        report = self._generate_html_report()
    else:
        report = self._generate_markdown_report()

    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(report)

    if self.verbose:
        print(f"Report generated at {output_path}")

    return report
|
|
1856
|
+
|
|
1857
|
+
def generate_prediction_report(
    self, image_paths: List[str], predictions: List, output_dir: str = "reports"
):
    """Write an HTML table of images and their predictions.

    The report is saved as ``<output_dir>/prediction_report.html``. Image
    ``src`` links are written relative to ``output_dir`` so thumbnails resolve
    when the report is opened. When no relative path exists, an absolute
    ``file://`` URI is used instead. File names and predictions are
    HTML-escaped. A row that cannot be rendered is reported and skipped;
    later rows are still written.

    Args:
        image_paths: paths of the predicted images.
        predictions: one prediction per image, in the same order.
        output_dir: directory for the report; created if missing.

    Returns:
        The HTML text written to disk.
    """
    from html import escape
    from pathlib import Path

    os.makedirs(output_dir, exist_ok=True)
    report_path = os.path.join(output_dir, "prediction_report.html")
    report_dir = os.path.abspath(output_dir)

    template = """
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>Image Prediction Report</title>
    <style>
        body {{ font-family: Arial, sans-serif; margin: 20px; }}
        h1, h2 {{ color: #2c3e50; }}
        table {{ border-collapse: collapse; width: 100%; margin-bottom: 20px; }}
        th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
        th {{ background-color: #f2f2f2; }}
        .summary {{ background-color: #f9f9f9; padding: 15px; border-radius: 5px; }}
        .section {{ margin-bottom: 30px; }}
        img {{ max-width: 100%; height: auto; }}
    </style>
</head>
<body>
    <h1>Image Prediction Report</h1>
    <p>Generated on: {date}</p>

    <div class="summary">
        <h2>Summary</h2>
        <p><strong>Total Predictions:</strong> {count}</p>
        <p><strong>Model Used:</strong> {model_name}</p>
        <p><strong>Task Type:</strong> {task_type}</p>
    </div>

    <h2>Prediction Details</h2>
    <table>
        <tr>
            <th>Image</th>
            <th>Filename</th>
            <th>Prediction</th>
        </tr>
        {rows}
    </table>
</body>
</html>
"""

    # Build one table row per image.
    row_parts = []
    for path, pred in zip(image_paths, predictions):
        try:
            file_name = os.path.basename(path)
            try:
                src = os.path.relpath(os.path.abspath(path), report_dir).replace(os.sep, "/")
            except ValueError:
                # No relative path exists (e.g. different Windows drives).
                src = Path(os.path.abspath(path)).as_uri()
        except (TypeError, ValueError) as e:
            # Skip only this row (e.g. a None path); the previous `break`
            # dropped every remaining row.
            print(e)
            continue
        row_parts.append(
            f'<tr><td><img src="{escape(src)}" alt="{escape(file_name)}"></td>'
            f"<td>{escape(file_name)}</td><td>{escape(str(pred))}</td></tr>"
        )

    # Fill in the report.
    report = template.format(
        date=time.strftime("%Y-%m-%d %H:%M:%S"),
        count=len(image_paths),
        model_name=escape(
            self.best_model.__class__.__name__ if self.best_model else "Unknown"
        ),
        task_type=escape(self.task_type.capitalize()),
        rows="\n".join(row_parts),
    )

    with open(report_path, "w", encoding="utf-8") as f:
        f.write(report)

    if self.verbose:
        print(f"Prediction report generated at {report_path}")

    return report
|
|
1934
|
+
|
|
1935
|
+
def _generate_html_report(self) -> str:
|
|
1936
|
+
"""生成HTML格式的训练报告"""
|
|
1937
|
+
report = """
|
|
1938
|
+
<!DOCTYPE html>
|
|
1939
|
+
<html>
|
|
1940
|
+
<head>
|
|
1941
|
+
<title>Model Training Report</title>
|
|
1942
|
+
<style>
|
|
1943
|
+
body {{font-family: Arial; margin: 20px; }}
|
|
1944
|
+
h1, h2 {{ color: #2c3e50; }}
|
|
1945
|
+
table {{ border-collapse: collapse; width: 100%; margin-bottom: 20px; }}
|
|
1946
|
+
th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
|
|
1947
|
+
th {{ background-color: #f2f2f2; }}
|
|
1948
|
+
.summary {{ background-color: #f9f9f9; padding: 15px; border-radius: 5px; }}
|
|
1949
|
+
.section {{ margin-bottom: 30px; }}
|
|
1950
|
+
img {{ max-width: 100%; height: auto; }}
|
|
1951
|
+
</style>
|
|
1952
|
+
</head>
|
|
1953
|
+
<body>
|
|
1954
|
+
<h1>Image Machine Learning Training Report</h1>
|
|
1955
|
+
<p>Generated on: {{date}}</p>
|
|
1956
|
+
|
|
1957
|
+
<div class="summary">
|
|
1958
|
+
<h2>Summary</h2>
|
|
1959
|
+
<p><strong>Task Type:</strong> {task_type}</p>
|
|
1960
|
+
<p><strong>Number of Images:</strong> {num_images}</p>
|
|
1961
|
+
<p><strong>Feature Extraction Method:</strong> {feat_method}</p>
|
|
1962
|
+
<p><strong>Best Model:</strong> {best_model} (Accuracy/R2: {best_metric:.4f})</p>
|
|
1963
|
+
</div>
|
|
1964
|
+
|
|
1965
|
+
<div class="section">
|
|
1966
|
+
<h2>Model Performance Comparison</h2>
|
|
1967
|
+
{model_table}
|
|
1968
|
+
</div>
|
|
1969
|
+
|
|
1970
|
+
<div class="section">
|
|
1971
|
+
<h2>cfguration</h2>
|
|
1972
|
+
<pre>{cfg}</pre>
|
|
1973
|
+
</div>
|
|
1974
|
+
|
|
1975
|
+
<div class="section">
|
|
1976
|
+
<h2>Visualizations</h2>
|
|
1977
|
+
{visualizations}
|
|
1978
|
+
</div>
|
|
1979
|
+
</body>
|
|
1980
|
+
</html>
|
|
1981
|
+
"""
|
|
1982
|
+
|
|
1983
|
+
# 生成模型性能表
|
|
1984
|
+
if self.model_performance:
|
|
1985
|
+
model_table = """
|
|
1986
|
+
<table>
|
|
1987
|
+
<tr>
|
|
1988
|
+
<th>Model</th>
|
|
1989
|
+
<th>Accuracy/R2</th>
|
|
1990
|
+
<th>F1/MSE</th>
|
|
1991
|
+
<th>Training Time (s)</th>
|
|
1992
|
+
</tr>
|
|
1993
|
+
{rows}
|
|
1994
|
+
</table>
|
|
1995
|
+
"""
|
|
1996
|
+
|
|
1997
|
+
rows = ""
|
|
1998
|
+
for name, perf in self.model_performance.items():
|
|
1999
|
+
metrics = perf["metrics"]
|
|
2000
|
+
if self.task_type == "classification":
|
|
2001
|
+
row = f"<tr><td>{name}</td><td>{metrics['accuracy']:.4f}</td><td>{metrics['f1_weighted']:.4f}</td><td>{perf['train_time']:.2f}</td></tr>"
|
|
2002
|
+
else:
|
|
2003
|
+
row = f"<tr><td>{name}</td><td>{metrics['r2']:.4f}</td><td>{metrics['mse']:.4f}</td><td>{perf['train_time']:.2f}</td></tr>"
|
|
2004
|
+
rows += row
|
|
2005
|
+
|
|
2006
|
+
model_table = model_table.format(rows=rows)
|
|
2007
|
+
else:
|
|
2008
|
+
model_table = "<p>No models trained yet.</p>"
|
|
2009
|
+
|
|
2010
|
+
# 生成可视化部分
|
|
2011
|
+
visualizations = ""
|
|
2012
|
+
if self.cfg["report"].get("include_visuals", True):
|
|
2013
|
+
# 特征空间可视化
|
|
2014
|
+
fs_path = os.path.join("reports", "feature_space.png")
|
|
2015
|
+
if os.path.exists(fs_path):
|
|
2016
|
+
visualizations += (
|
|
2017
|
+
f'<h3>Feature Space</h3><img src="{fs_path}" alt="Feature Space">'
|
|
2018
|
+
)
|
|
2019
|
+
|
|
2020
|
+
# 混淆矩阵
|
|
2021
|
+
if self.task_type == "classification":
|
|
2022
|
+
cm_path = os.path.join("reports", "confusion_matrix.png")
|
|
2023
|
+
if os.path.exists(cm_path):
|
|
2024
|
+
visualizations += f'<h3>Confusion Matrix</h3><img src="{cm_path}" alt="Confusion Matrix">'
|
|
2025
|
+
|
|
2026
|
+
# 填充报告
|
|
2027
|
+
best_metric = self.evaluate_model(self.best_model).get(
|
|
2028
|
+
"accuracy" if self.task_type == "classification" else "r2", 0
|
|
2029
|
+
)
|
|
2030
|
+
|
|
2031
|
+
report = report.format(
|
|
2032
|
+
date=time.strftime("%Y-%m-%d %H:%M:%S"),
|
|
2033
|
+
task_type=self.task_type.capitalize(),
|
|
2034
|
+
num_images=len(self.images),
|
|
2035
|
+
feat_method=self.cfg["feature_extraction"]["method"],
|
|
2036
|
+
best_model=self.best_model.__class__.__name__ if self.best_model else "N/A",
|
|
2037
|
+
best_metric=best_metric,
|
|
2038
|
+
model_table=model_table,
|
|
2039
|
+
cfg=json.dumps(self.cfg, indent=2),
|
|
2040
|
+
visualizations=visualizations,
|
|
2041
|
+
)
|
|
2042
|
+
|
|
2043
|
+
return report
|
|
2044
|
+
|
|
2045
|
+
def _generate_markdown_report(self) -> str:
|
|
2046
|
+
"""生成Markdown格式的训练报告"""
|
|
2047
|
+
report = "# Image Machine Learning Training Report\n\n"
|
|
2048
|
+
report += f"**Generated on**: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
|
2049
|
+
|
|
2050
|
+
# 摘要部分
|
|
2051
|
+
best_metric = self.evaluate_model(self.best_model).get(
|
|
2052
|
+
"accuracy" if self.task_type == "classification" else "r2", 0
|
|
2053
|
+
)
|
|
2054
|
+
|
|
2055
|
+
report += "## Summary\n"
|
|
2056
|
+
report += f"- **Task Type**: {self.task_type.capitalize()}\n"
|
|
2057
|
+
report += f"- **Number of Images**: {len(self.images)}\n"
|
|
2058
|
+
|
|
2059
|
+
if self.task_type == "classification":
|
|
2060
|
+
report += f"- **Number of Classes**: {len(self.class_names)}\n"
|
|
2061
|
+
|
|
2062
|
+
report += f"- **Feature Extraction Method**: {self.cfg['feature_extraction']['method']}\n"
|
|
2063
|
+
report += f"- **Best Model**: {self.best_model.__class__.__name__ if self.best_model else 'N/A'} (Accuracy/R2: {best_metric:.4f})\n\n"
|
|
2064
|
+
|
|
2065
|
+
# 模型性能比较
|
|
2066
|
+
if self.model_performance:
|
|
2067
|
+
report += "## Model Performance Comparison\n\n"
|
|
2068
|
+
report += "| Model | Accuracy/R2 | F1/MSE | Training Time (s) |\n"
|
|
2069
|
+
report += "|-------|-------------|--------|-------------------|\n"
|
|
2070
|
+
|
|
2071
|
+
for name, perf in self.model_performance.items():
|
|
2072
|
+
metrics = perf["metrics"]
|
|
2073
|
+
if self.task_type == "classification":
|
|
2074
|
+
report += f"| {name} | {metrics['accuracy']:.4f} | {metrics['f1_weighted']:.4f} | {perf['train_time']:.2f} |\n"
|
|
2075
|
+
else:
|
|
2076
|
+
report += f"| {name} | {metrics['r2']:.4f} | {metrics['mse']:.4f} | {perf['train_time']:.2f} |\n"
|
|
2077
|
+
else:
|
|
2078
|
+
report += "No models trained yet.\n\n"
|
|
2079
|
+
|
|
2080
|
+
# 配置信息
|
|
2081
|
+
report += "\n## cfguration\n```json\n"
|
|
2082
|
+
report += json.dumps(self.cfg, indent=2)
|
|
2083
|
+
report += "\n```\n\n"
|
|
2084
|
+
|
|
2085
|
+
# 可视化
|
|
2086
|
+
if self.cfg["report"].get("include_visuals", True):
|
|
2087
|
+
report += "## Visualizations\n"
|
|
2088
|
+
|
|
2089
|
+
# 特征空间可视化
|
|
2090
|
+
fs_path = os.path.join("reports", "feature_space.png")
|
|
2091
|
+
if os.path.exists(fs_path):
|
|
2092
|
+
report += f"### Feature Space\n\n\n"
|
|
2093
|
+
|
|
2094
|
+
# 混淆矩阵
|
|
2095
|
+
if self.task_type == "classification":
|
|
2096
|
+
cm_path = os.path.join("reports", "confusion_matrix.png")
|
|
2097
|
+
if os.path.exists(cm_path):
|
|
2098
|
+
report += f"### Confusion Matrix\n\n\n"
|
|
2099
|
+
|
|
2100
|
+
return report
|