py2ls 0.1.10.12__py3-none-any.whl → 0.2.7.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of py2ls might be problematic. Click here for more details.

Files changed (72) hide show
  1. py2ls/.DS_Store +0 -0
  2. py2ls/.git/.DS_Store +0 -0
  3. py2ls/.git/index +0 -0
  4. py2ls/.git/logs/refs/remotes/origin/HEAD +1 -0
  5. py2ls/.git/objects/.DS_Store +0 -0
  6. py2ls/.git/refs/.DS_Store +0 -0
  7. py2ls/ImageLoader.py +621 -0
  8. py2ls/__init__.py +7 -5
  9. py2ls/apptainer2ls.py +3940 -0
  10. py2ls/batman.py +164 -42
  11. py2ls/bio.py +2595 -0
  12. py2ls/cell_image_clf.py +1632 -0
  13. py2ls/container2ls.py +4635 -0
  14. py2ls/corr.py +475 -0
  15. py2ls/data/.DS_Store +0 -0
  16. py2ls/data/email/email_html_template.html +88 -0
  17. py2ls/data/hyper_param_autogluon_zeroshot2024.json +2383 -0
  18. py2ls/data/hyper_param_tabrepo_2024.py +1753 -0
  19. py2ls/data/mygenes_fields_241022.txt +355 -0
  20. py2ls/data/re_common_pattern.json +173 -0
  21. py2ls/data/sns_info.json +74 -0
  22. py2ls/data/styles/.DS_Store +0 -0
  23. py2ls/data/styles/example/.DS_Store +0 -0
  24. py2ls/data/styles/stylelib/.DS_Store +0 -0
  25. py2ls/data/styles/stylelib/grid.mplstyle +15 -0
  26. py2ls/data/styles/stylelib/high-contrast.mplstyle +6 -0
  27. py2ls/data/styles/stylelib/high-vis.mplstyle +4 -0
  28. py2ls/data/styles/stylelib/ieee.mplstyle +15 -0
  29. py2ls/data/styles/stylelib/light.mplstyl +6 -0
  30. py2ls/data/styles/stylelib/muted.mplstyle +6 -0
  31. py2ls/data/styles/stylelib/nature-reviews-latex.mplstyle +616 -0
  32. py2ls/data/styles/stylelib/nature-reviews.mplstyle +616 -0
  33. py2ls/data/styles/stylelib/nature.mplstyle +31 -0
  34. py2ls/data/styles/stylelib/no-latex.mplstyle +10 -0
  35. py2ls/data/styles/stylelib/notebook.mplstyle +36 -0
  36. py2ls/data/styles/stylelib/paper.mplstyle +290 -0
  37. py2ls/data/styles/stylelib/paper2.mplstyle +305 -0
  38. py2ls/data/styles/stylelib/retro.mplstyle +4 -0
  39. py2ls/data/styles/stylelib/sans.mplstyle +10 -0
  40. py2ls/data/styles/stylelib/scatter.mplstyle +7 -0
  41. py2ls/data/styles/stylelib/science.mplstyle +48 -0
  42. py2ls/data/styles/stylelib/std-colors.mplstyle +4 -0
  43. py2ls/data/styles/stylelib/vibrant.mplstyle +6 -0
  44. py2ls/data/tiles.csv +146 -0
  45. py2ls/data/usages_pd.json +1417 -0
  46. py2ls/data/usages_sns.json +31 -0
  47. py2ls/docker2ls.py +5446 -0
  48. py2ls/ec2ls.py +61 -0
  49. py2ls/fetch_update.py +145 -0
  50. py2ls/ich2ls.py +1955 -296
  51. py2ls/im2.py +8242 -0
  52. py2ls/image_ml2ls.py +2100 -0
  53. py2ls/ips.py +33909 -3418
  54. py2ls/ml2ls.py +7700 -0
  55. py2ls/mol.py +289 -0
  56. py2ls/mount2ls.py +1307 -0
  57. py2ls/netfinder.py +873 -351
  58. py2ls/nl2ls.py +283 -0
  59. py2ls/ocr.py +1581 -458
  60. py2ls/plot.py +10394 -314
  61. py2ls/rna2ls.py +311 -0
  62. py2ls/ssh2ls.md +456 -0
  63. py2ls/ssh2ls.py +5933 -0
  64. py2ls/ssh2ls_v01.py +2204 -0
  65. py2ls/stats.py +66 -172
  66. py2ls/temp20251124.py +509 -0
  67. py2ls/translator.py +2 -0
  68. py2ls/utils/decorators.py +3564 -0
  69. py2ls/utils_bio.py +3453 -0
  70. {py2ls-0.1.10.12.dist-info → py2ls-0.2.7.10.dist-info}/METADATA +113 -224
  71. {py2ls-0.1.10.12.dist-info → py2ls-0.2.7.10.dist-info}/RECORD +72 -16
  72. {py2ls-0.1.10.12.dist-info → py2ls-0.2.7.10.dist-info}/WHEEL +0 -0
py2ls/image_ml2ls.py ADDED
@@ -0,0 +1,2100 @@
1
+ r"""
2
+ Ultimate Image Machine Learning Master Function (Optimized)
3
+ Author: DeepSeek AI
4
+ Date: 2025-05-30
5
+ Version: 2.0
6
+
7
+ 新增功能:
8
+ 1. 模型导出与导入 (pickle/joblib)
9
+ 2. 完整的预测报告生成 (HTML格式)
10
+ 3. 预测结果可视化增强
11
+ 4. 支持从保存的模型直接预测
12
+ 5. 更完善的配置管理
13
+
14
+
15
+ # usage: 250601
16
+ from py2ls.ips import listdir
17
+ import re
18
+
19
+ f = listdir("/Users/macjianfeng/Desktop/img_datasets/", "folder", verbose=True)
20
+
21
+
22
+ train_images = [
23
+ *listdir(f["path"][0], "tif")["path"].tolist(),
24
+ *listdir(f["path"][1], "tif")["path"].tolist(),
25
+ ]
26
+
27
+
28
+ labels = [re.findall(r"\d+x", i)[0] for i in train_images]
29
+
30
+ # 初始化主函数
31
+ mlearner = ImageMLearner()
32
+ # 加载图像
33
+ images, filtered_labels = mlearner.load_images(train_images, labels)
34
+
35
+ # 自动模型比较
36
+ comparison_results = mlearner.auto_compare_models()
37
+ # 保存最佳模型
38
+ mlearner.save_model(mlearner.best_model, "models/best_model.joblib")
39
+ dir_test = "/Users/macjianfeng/Desktop/20250515_CDX2_LiCl_7d/HNT-34/50µM/"
40
+ f = listdir(dir_test, "tif", verbose=True)
41
+
42
+ # 使用最佳模型进行预测
43
+ test_images = f.path.tolist()
44
+
45
+ predictions = mlearner.predict(test_images, output_dir="reports")
46
+ print("Predictions:", predictions)
47
+
48
+ """
49
+
50
+ """
51
+ -------------------------------------------------------------------------------
52
+ This file provides a complete machine learning pipeline for image analysis with:
53
+
54
+ Core Functionality:
55
+ Image loading and preprocessing
56
+ Feature extraction (HOG, LBP, CNN, etc.)
57
+ Model training and evaluation (many classifiers/regressors)
58
+ Model comparison and selection
59
+ Prediction and reporting
60
+
61
+ Key Components:
62
+ Multiple feature extraction methods
63
+ Wide range of ML models (traditional and deep learning)
64
+ Model evaluation and visualization
65
+ Report generation (HTML/Markdown)
66
+
67
+ Use Cases:
68
+ When you need end-to-end image classification/regression
69
+ When you want to compare multiple models automatically
70
+ When you need feature extraction from images
71
+ When you want detailed reports and visualizations
72
+ When you need GPU acceleration support
73
+
74
+ You need a complete ML pipeline from images to predictions
75
+ You want to compare multiple ML models automatically
76
+ You need different feature extraction methods
77
+ You want detailed evaluation metrics and visualizations
78
+ You need HTML/Markdown reports of your analysis
79
+ You want to leverage both traditional ML and deep learning
80
+ -------------------------------------------------------------------------------
81
+ """
82
+
83
+
84
+ usage_str="""
85
+ -------------如果要使用tensorflow的话, 一定要最先调用--------
86
+ import tensorflow as tf
87
+ print("Eager execution enabled:", tf.executing_eagerly())
88
+ tf.config.set_visible_devices([], "GPU")
89
+ -------------如果要使用tensorflow的话, 一定要最先调用--------
90
+ """
91
+ print(usage_str)
92
+ from .ips import has_gpu #set_computing_device
93
+ from .ImageLoader import ImageLoader
94
+ import os
95
+ import json
96
+ import yaml
97
+ import pickle
98
+ import joblib
99
+ import numpy as np
100
+ import pandas as pd
101
+ import matplotlib.pyplot as plt
102
+ import seaborn as sns
103
+ import time
104
+ from functools import partial
105
+ from collections import defaultdict
106
+ from typing import Union, Dict, List, Tuple, Optional, Callable, Any
107
+ from IPython.display import HTML, display
108
+ from tqdm import tqdm
109
+ import gc
110
+ import logging
111
+
112
+ # 图像处理库
113
+ import cv2
114
+ from skimage import io, color, exposure
115
+ from skimage import transform, feature, filters, morphology, segmentation
116
+ from PIL import Image
117
+ # 机器学习库
118
+ from sklearn.base import BaseEstimator
119
+ from sklearn.pipeline import Pipeline
120
+ from sklearn.preprocessing import StandardScaler, LabelEncoder
121
+ from sklearn.decomposition import PCA
122
+ from sklearn.manifold import TSNE
123
+ from sklearn.model_selection import train_test_split
124
+ from sklearn.metrics import (
125
+ accuracy_score,
126
+ f1_score,
127
+ confusion_matrix,
128
+ mean_squared_error,
129
+ r2_score,
130
+ classification_report,
131
+ )
132
+
133
+ # 分类模型
134
+ from sklearn.ensemble import (
135
+ RandomForestClassifier,
136
+ RandomForestRegressor,
137
+ ExtraTreesClassifier,
138
+ ExtraTreesRegressor,
139
+ HistGradientBoostingRegressor,
140
+ BaggingClassifier,
141
+ BaggingRegressor,
142
+ AdaBoostClassifier,
143
+ AdaBoostRegressor,
144
+ )
145
+ from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor,StackingClassifier,StackingRegressor
146
+ from sklearn.svm import SVC
147
+
148
+ from sklearn.linear_model import (
149
+ LogisticRegression,ElasticNet,ElasticNetCV,
150
+ LinearRegression,Lasso,RidgeClassifierCV,Perceptron,SGDClassifier,
151
+ RidgeCV,Ridge,TheilSenRegressor,HuberRegressor,PoissonRegressor,Lars, LassoLars, BayesianRidge,
152
+ GammaRegressor, TweedieRegressor, LassoCV, LassoLarsCV, LarsCV,
153
+ OrthogonalMatchingPursuit, OrthogonalMatchingPursuitCV, PassiveAggressiveRegressor
154
+ )
155
+ from sklearn.linear_model import LogisticRegression
156
+ from sklearn.svm import SVC
157
+ from sklearn.neighbors import KNeighborsClassifier
158
+ from sklearn.tree import DecisionTreeClassifier
159
+ import xgboost as xgb
160
+ from sklearn.naive_bayes import GaussianNB,BernoulliNB
161
+ from sklearn.discriminant_analysis import (
162
+ LinearDiscriminantAnalysis,
163
+ QuadraticDiscriminantAnalysis,
164
+ )
165
+ from sklearn.naive_bayes import GaussianNB, BernoulliNB
166
+ import lightgbm as lgb
167
+ import catboost as cb
168
+ from sklearn.neural_network import MLPClassifier, MLPRegressor
169
+ # 回归模型
170
+ from sklearn.svm import SVR
171
+ from sklearn.neighbors import KNeighborsRegressor
172
+ from sklearn.tree import DecisionTreeRegressor
173
+ from sklearn.ensemble import (
174
+ RandomForestRegressor,
175
+ GradientBoostingRegressor,
176
+ AdaBoostRegressor,
177
+ ExtraTreesRegressor,
178
+ )
179
+
180
+ from sklearn.linear_model import (
181
+ LassoCV,
182
+ LogisticRegression,
183
+ LinearRegression,
184
+ Lasso,
185
+ Ridge,
186
+ ElasticNet,
187
+ )
188
+ # 深度学习特征提取
189
+ # try:
190
+ # from tensorflow.keras.utils import to_categorical
191
+
192
+ # except Exception as e:
193
+ # print("Error importing tensorflow.keras.utils or sklearn.utils.class_weight:", e)
194
+
195
+ # try:
196
+
197
+ # from tensorflow.keras.applications import (
198
+ # VGG16,
199
+ # VGG19,
200
+ # ResNet50,
201
+ # InceptionV3,
202
+ # MobileNet,
203
+ # DenseNet121,
204
+ # EfficientNetB0,
205
+ # )
206
+ # from tensorflow.keras.models import Model
207
+ # from tensorflow.keras.preprocessing import image
208
+ # from tensorflow.keras.applications.imagenet_utils import preprocess_input
209
+
210
+ # DL_ENABLED = True
211
+ # except ImportError:
212
+ # DL_ENABLED = False
213
+ # Conditional imports for deep learning
214
+ DL_ENABLED = False
215
+ try:
216
+ import tensorflow as tf
217
+ from sklearn.utils.class_weight import compute_class_weight
218
+ from tensorflow.keras.preprocessing.image import ImageDataGenerator
219
+ from tensorflow.keras.utils import to_categorical
220
+ from tensorflow.keras.applications import (
221
+ VGG16, VGG19, ResNet50, InceptionV3,
222
+ MobileNet, DenseNet121, EfficientNetB0
223
+ )
224
+ from tensorflow.keras.models import Model
225
+ from tensorflow.keras.preprocessing import image
226
+ from tensorflow.keras.applications.imagenet_utils import preprocess_input
227
+ from tensorflow.keras.callbacks import (
228
+ EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
229
+ )
230
+ from concurrent.futures import ThreadPoolExecutor
231
+ from tensorflow.keras.mixed_precision import set_global_policy
232
+ from tensorflow.keras.optimizers import Adam
233
+ DL_ENABLED = True
234
+ except ImportError:
235
+ pass
236
+ #====check gpu available====
237
+ # set_computing_device()
238
+ # Global configuration
239
+
240
# Module-wide default configuration.
# ImageMLearner._load_cfg merges user-supplied settings on top of this mapping.
DEFAULT_CONFIG = {
    "preprocessing": {
        "resize": (128, 128),
        "remove_background": True,
        "denoise": "median",
        "normalize": True,
        "hist_equalize": False,
        "chunk_size": 100,
        "augmentation": {
            "rotation_range": 10,
            "width_shift_range": 0.1,
            "height_shift_range": 0.1,
            "shear_range": 0.05,
            "zoom_range": 0.1,
            "horizontal_flip": True,
        },
    },
    "feature_extraction": {
        "method": "hog",
        # A CNN backbone is only named when the deep-learning imports succeeded.
        "cnn_model": "resnet50" if DL_ENABLED else None,
        "hog_params": {"orientations": 8, "pixels_per_cell": (16, 16), "cells_per_block": (1, 1)},
        "lbp_params": {"P": 8, "R": 1, "method": "uniform"},
    },
    "cnn": {
        "conv_layers": [(16, 3, 1), (32, 3, 1)],
        "dense_layers": [32],
        "use_batchnorm": True,
        "use_dropout": False,
        "dropout_rate": 0.2,
        "pooling": "max",
        "learning_rate": 0.001,
        "optimizer": "adam",
    },
    "model": {"name": "cnn", "params": {}},
    "auto_comparison": True,
    "test_size": 0.2,
    "random_state": 1,
    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to 1 so the arithmetic cannot raise TypeError at import time.
    "n_cpu": max((os.cpu_count() or 1) - 2, 4),
    "visualization": {
        "enable": True,
        "confusion_matrix": True,
        "feature_space": True,
        "sample_images": 2,
        "cmap": "viridis",
    },
    "report": {"format": "html", "include_visuals": True},
}
287
def build_cnn_model(
    input_shape: tuple,
    num_classes: int,
    *,
    # Architecture configuration
    conv_layers: list = None,
    dense_layers: list = None,
    use_batchnorm: bool = True,
    use_dropout: bool = False,
    dropout_rate: float = 0.2,
    pooling: str = "max",
    # Training configuration
    learning_rate: float = 0.001,
    optimizer: str = "adam",
    # Logging/utility
    verbose: bool = True,
    logger: logging.Logger = None
) -> "tf.keras.Model":
    """
    Flexible CNN builder with customizable architecture.

    The return annotation is a string on purpose: TensorFlow is imported
    conditionally by this module, and an eagerly evaluated ``tf.keras.Model``
    annotation would raise ``NameError`` at import time when it is missing.

    Args:
        input_shape: Tuple (height, width, channels) or (height, width);
            a 2-tuple gets a single channel appended.
        num_classes: Number of output classes. 1 builds a linear regression
            head (MSE loss), 2 a single sigmoid unit (binary cross-entropy),
            anything else a softmax head (sparse categorical cross-entropy).
        conv_layers: List of tuples (filters, kernel_size, stride) for conv layers
            Default: [(16, 3, 1), (32, 3, 1)]
        dense_layers: List of units for dense layers. Default: [32]
        use_batchnorm: Whether to use batch normalization
        use_dropout: Whether to use dropout
        dropout_rate: Dropout rate if use_dropout=True
        pooling: "max" for max pooling; any other value selects average pooling
        learning_rate: Learning rate for optimizer
        optimizer: "adam", "sgd", or "rmsprop"; an unknown name falls back
            to Adam and logs a warning
        verbose: Whether to log model info
        logger: Custom logger instance

    Returns:
        Compiled Keras model

    Raises:
        ValueError: If ``input_shape`` does not have 2 or 3 entries.
    """
    # Initialize logging
    logger = logger or logging.getLogger("cnn_builder")
    if verbose and not logger.hasHandlers():
        # Nothing is configured anywhere up the logger hierarchy: attach a
        # console handler and lower the level, otherwise the INFO messages
        # below are filtered out by the default WARNING threshold.
        console = logging.StreamHandler()
        console.setFormatter(
            logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        )
        logger.addHandler(console)
        if logger.level == logging.NOTSET:
            logger.setLevel(logging.INFO)

    start_time = time.time()

    # Validate and normalize input shape
    if len(input_shape) == 2:
        input_shape = (*input_shape, 1)
    elif len(input_shape) != 3:
        raise ValueError("input_shape must be (h,w) or (h,w,c)")

    # Set default architecture if not provided
    conv_layers = conv_layers or [(16, 3, 1), (32, 3, 1)]
    dense_layers = dense_layers or [32]

    # Configure output layer
    if num_classes == 1:  # Regression
        final_config = {"units": 1, "activation": "linear", "loss": "mse", "metrics": ["mae"]}
    elif num_classes == 2:  # Binary classification
        final_config = {"units": 1, "activation": "sigmoid", "loss": "binary_crossentropy", "metrics": ["accuracy"]}
    else:  # Multi-class
        final_config = {"units": num_classes, "activation": "softmax", "loss": "sparse_categorical_crossentropy", "metrics": ["accuracy"]}

    # Build model
    layers = tf.keras.layers
    model = tf.keras.Sequential()
    model.add(layers.InputLayer(input_shape=input_shape))

    # Add convolutional blocks
    for i, (filters, kernel_size, stride) in enumerate(conv_layers):
        model.add(layers.Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            strides=stride,
            padding="same",
            name=f"conv_{i}"
        ))

        if use_batchnorm:
            model.add(layers.BatchNormalization(name=f"bn_{i}"))

        model.add(layers.ReLU(name=f"relu_{i}"))

        # Add pooling every other conv layer
        if i % 2 == 1:
            pool_cls = layers.MaxPooling2D if pooling == "max" else layers.AveragePooling2D
            model.add(pool_cls(pool_size=(2, 2), name=f"pool_{i}"))

        if use_dropout:
            model.add(layers.Dropout(dropout_rate, name=f"dropout_{i}"))

    # Transition to dense layers
    model.add(layers.Flatten())

    # Add dense layers
    for i, units in enumerate(dense_layers):
        model.add(layers.Dense(units, activation="relu", name=f"dense_{i}"))
        if use_batchnorm:
            model.add(layers.BatchNormalization(name=f"bn_dense_{i}"))
        if use_dropout:
            model.add(layers.Dropout(dropout_rate, name=f"dropout_dense_{i}"))

    # Output layer
    model.add(layers.Dense(
        units=final_config["units"],
        activation=final_config["activation"],
        name="output"
    ))

    # Configure optimizer: factories so that only the requested one is built.
    optimizer_factories = {
        "adam": lambda: tf.keras.optimizers.Adam(learning_rate=learning_rate),
        "sgd": lambda: tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9),
        "rmsprop": lambda: tf.keras.optimizers.RMSprop(learning_rate=learning_rate),
    }
    make_optimizer = optimizer_factories.get(optimizer)
    if make_optimizer is None:
        # Keep the original best-effort fallback, but make it visible.
        logger.warning("Unknown optimizer %r; falling back to 'adam'", optimizer)
        make_optimizer = optimizer_factories["adam"]

    # Compile model
    model.compile(
        optimizer=make_optimizer(),
        loss=final_config["loss"],
        metrics=final_config["metrics"]
    )

    # Logging
    if verbose:
        logger.info("Built CNN with:")
        logger.info(f"- {len(conv_layers)} conv layers")
        logger.info(f"- {len(dense_layers)} dense layers")
        logger.info(f"- BatchNorm: {use_batchnorm}")
        logger.info(f"- Dropout: {use_dropout} ({dropout_rate if use_dropout else 'N/A'})")
        logger.info(f"- Optimizer: {optimizer} (lr={learning_rate})")
        logger.info(f"Model built in {time.time()-start_time:.2f}s\n")
        model.summary(print_fn=logger.info)

    return model
427
+
428
+
429
+
430
+ class ImageMLearner:
431
+ """
432
+ 通用图像机器学习主函数 (优化版)
433
+
434
+ 参数:
435
+ cfg (Union[dict, str]): 配置字典或配置文件路径 (JSON/YAML)
436
+ verbose (bool): 是否显示详细输出
437
+ """
438
+
439
+ def __init__(self, cfg: Union[dict, str] = None, verbose: bool = True):
440
+ self.verbose = verbose
441
+ self.dl_enabled = DL_ENABLED
442
+ self.cfg = self._load_cfg(cfg)
443
+ self.models = {}
444
+ self.feature_extractors = {}
445
+ self.preprocessors = {}
446
+ self.chunk_size = self.cfg["preprocessing"].get("chunk_size", 128)
447
+ self.n_cpu=self.cfg.get("n_cpu", max(os.cpu_count(), 8))
448
+ self._register_default_components()
449
+
450
+ # 状态跟踪
451
+ self.images, self.image_input, self.labels = [], [], []
452
+ self.X_train, self.y_train = None, None
453
+ self.X_test, self.y_test = None, None
454
+ self.features = None
455
+ self.model_performance = {}
456
+ self.best_model = None
457
+ self.label_encoder = None
458
+ self.task_type = None
459
+ self.class_names = None
460
+
461
+ # Hardware optimization
462
+ self.has_gpu = has_gpu(verbose=self.verbose)
463
+ self.cfg_hash = hash(json.dumps(self.cfg["preprocessing"], sort_keys=True))
464
+
465
+ if self.has_gpu and DL_ENABLED:
466
+ self._configure_gpu()
467
+
468
+ # Initialize ImageLoader
469
+ self.image_loader = ImageLoader(
470
+ target_size=self.cfg["preprocessing"].get("resize", (128, 128)),
471
+ grayscale=self.cfg["preprocessing"].get("to_grayscale", True),
472
+ n_jobs=self.cfg.get("n_cpu", max(os.cpu_count() - 2, 4)),
473
+ verbose=verbose,
474
+ scaler="normalize" if self.cfg["preprocessing"].get("normalize", True) else "raw",
475
+ chunk_size=self.cfg["preprocessing"].get("chunk_size", 100)
476
+ )
477
+ # 初始化CNN模型缓存
478
+ self._cnn_models = {}
479
+ if cfg is not None:
480
+ self.clear_cache()
481
+ # If DL is disabled but config requests CNN, fall back to HOG
482
+ if not self.dl_enabled and self.cfg["feature_extraction"].get("method") == "cnn":
483
+ self.cfg["feature_extraction"]["method"] = "hog"
484
+ if self.verbose:
485
+ print("CNN requested but DL not available. Falling back to HOG.")
486
+ if self.verbose:
487
+ print(f"init: ImageMLearner initialized with cfg: {self.cfg}")
488
+ def _configure_gpu(self):
489
+ """Configure GPU settings for optimal performance"""
490
+ try:
491
+ gpus = tf.config.list_physical_devices('GPU')
492
+ if gpus:
493
+ for gpu in gpus:
494
+ tf.config.experimental.set_memory_growth(gpu, True)
495
+ set_global_policy('mixed_float16')
496
+ if self.verbose:
497
+ print("GPU configured with mixed precision")
498
+ except Exception as e:
499
+ print(f"Error configuring GPU: {str(e)}")
500
+ def _load_cfg(self, cfg: Union[dict, str]) -> dict:
501
+ """加载配置文件"""
502
+ default_cfg = {
503
+ "preprocessing": {
504
+ "resize": (128, 128),
505
+ "remove_background": True,
506
+ "denoise": "median",
507
+ "normalize": True,
508
+ "hist_equalize": False,
509
+ "chunk_size": 100,
510
+ },
511
+ "feature_extraction": {
512
+ "method": "hog",
513
+ "cnn_model": "resnet50" if self.dl_enabled else None,
514
+ "hog_params": {
515
+ "orientations": 8,
516
+ "pixels_per_cell": (16, 16),
517
+ "cells_per_block": (1, 1),
518
+ },
519
+ "lbp_params": {"P": 8, "R": 1, "method": "uniform"},
520
+ },
521
+ "model": {"name": "random_forest", "params": {}},
522
+ "auto_comparison": True,
523
+ "test_size": 0.2,
524
+ "random_state": 1,
525
+ "n_cpu":max(os.cpu_count(), 8),
526
+ "visualization": {
527
+ "enable": True,
528
+ "confusion_matrix": True,
529
+ "feature_space": True,
530
+ "sample_images": 2,
531
+ "cmap": "viridis",
532
+ },
533
+ "report": {"format": "html", "include_visuals": True}, # html/md
534
+ }
535
+
536
+ if cfg is None:
537
+ return default_cfg
538
+
539
+ if isinstance(cfg, str):
540
+ if cfg.endswith(".json"):
541
+ with open(cfg, "r") as f:
542
+ cfg = json.load(f)
543
+ elif cfg.endswith((".yaml", ".yml")):
544
+ with open(cfg, "r") as f:
545
+ cfg = yaml.safe_load(f)
546
+ else:
547
+ raise ValueError("Unsupported cfg file format. Use JSON or YAML.")
548
+
549
+ # 深度合并配置
550
+ def deep_merge(target, source):
551
+ for key, value in source.items():
552
+ if (key in target) and isinstance(target[key], dict) and isinstance(value, dict):
553
+ deep_merge(target[key], value)
554
+ else:
555
+ target[key] = value
556
+ return target
557
+
558
+ config = DEFAULT_CONFIG.copy()
559
+ return deep_merge(config, cfg)
560
+ def _register_default_components(self):
561
+ """注册默认预处理、特征提取和模型"""
562
+ # 注册预处理方法
563
+ self.register_preprocessor("grabcut", self._grabcut_background_removal)
564
+ self.register_preprocessor("kmeans", self._kmeans_background_removal)
565
+ self.register_preprocessor("morphology", self._morphology_background_removal)
566
+
567
+ # 注册特征提取方法
568
+ self.register_feature_extractor("pixels", self._extract_raw_pixels)
569
+ self.register_feature_extractor("hog", self._extract_hog)
570
+ self.register_feature_extractor("lbp", self._extract_lbp)
571
+ self.register_feature_extractor("color_hist", self._extract_color_histogram)
572
+
573
+ # if self.dl_enabled:
574
+ # self.register_feature_extractor("cnn", self._extract_cnn_features)
575
+
576
+ # 注册模型
577
+ models = {
578
+ # 分类模型
579
+ "logistic_regression": LogisticRegression,
580
+ "svm": SVC,
581
+ "knn": KNeighborsClassifier,
582
+ "decision_tree": DecisionTreeClassifier,
583
+ "random_forest": RandomForestClassifier,
584
+ "gradient_boosting": GradientBoostingClassifier,
585
+ "ada_boost": AdaBoostClassifier,
586
+ "naive_bayes": GaussianNB,
587
+ "lda": LinearDiscriminantAnalysis,
588
+ "qda": QuadraticDiscriminantAnalysis,
589
+ "extra_trees": ExtraTreesClassifier,
590
+ # added
591
+ "lasso_logistic": LogisticRegression(penalty='l1', solver='saga'),
592
+ "ridge_classifier": RidgeClassifierCV(),
593
+ "elastic_net_cls": ElasticNet(),
594
+ "xgb": xgb.XGBClassifier(), #if XGB_ENABLED else None,
595
+ "lightgbm": lgb.LGBMClassifier(), #if LGBM_ENABLED else None,
596
+ "catboost": cb.CatBoostClassifier(verbose=0), #if CATBOOST_ENABLED else None,
597
+ "bagging": BaggingClassifier(),
598
+ "mlp": MLPClassifier(max_iter=500),
599
+ "quadratic_discriminant": QuadraticDiscriminantAnalysis(),
600
+ "bernoulli_nb": BernoulliNB(),
601
+ "sgd": SGDClassifier(),
602
+ # 回归模型
603
+ "linear_regression": LinearRegression,
604
+ "ridge": Ridge,
605
+ "lasso": Lasso,
606
+ "elastic_net": ElasticNet,
607
+ "svr": SVR,
608
+ "knn_regressor": KNeighborsRegressor,
609
+ "decision_tree_reg": DecisionTreeRegressor,
610
+ "random_forest_reg": RandomForestRegressor,
611
+ "gradient_boosting_reg": GradientBoostingRegressor,
612
+ "ada_boost_reg": AdaBoostRegressor,
613
+ "extra_trees_reg": ExtraTreesRegressor,
614
+ "cnn":self._build_cnn_model,
615
+ # added
616
+ "lasso_cv": LassoCV(),
617
+ "elastic_net_cv": ElasticNetCV(),
618
+ "xgb_reg": xgb.XGBRegressor(), #if XGB_ENABLED else None,
619
+ "lightgbm_reg": lgb.LGBMRegressor(), #if LGBM_ENABLED else None,
620
+ "catboost_reg": cb.CatBoostRegressor(verbose=0), #if CATBOOST_ENABLED else None,
621
+ "bagging_reg": BaggingRegressor(),
622
+ "mlp_reg": MLPRegressor(max_iter=500),
623
+ "theil_sen": TheilSenRegressor(),
624
+ "huber": HuberRegressor(),
625
+ "poisson": PoissonRegressor(),
626
+ }
627
+
628
+ for name, model in models.items():
629
+ self.register_model(name, model)
630
+
631
+ def register_preprocessor(self, name: str, func: Callable):
632
+ """注册新的预处理方法"""
633
+ self.preprocessors[name] = func
634
+ # if self.verbose:
635
+ # print(f"Registered preprocessor: {name}")
636
+
637
+ def register_feature_extractor(self, name: str, func: Callable):
638
+ """注册新的特征提取方法"""
639
+ self.feature_extractors[name] = func
640
+ # if self.verbose:
641
+ # print(f"Registered feature extractor: {name}")
642
+
643
+ def register_model(self, name: str, model_class: BaseEstimator):
644
+ """注册新的机器学习模型"""
645
+ self.models[name] = model_class
646
+ # if self.verbose:
647
+ # print(f"Registered model: {name}")
648
+
649
+ def load_images(self, image_paths: List[str], labels: List = None):
650
+ """
651
+ 加载图像数据集
652
+
653
+ 参数:
654
+ image_paths: Can be:
655
+ - List of file paths (e.g., ["img1.jpg", "img2.png"])
656
+ - NumPy array of shape (n_samples, height, width[, channels])
657
+ - List of NumPy arrays (each shape: (height, width[, channels]))
658
+ labels: 对应的标签列表 (可选)
659
+ """
660
+ self._reset_state()
661
+ if self.verbose:
662
+ print(f"Loading {len(image_paths)} images...")
663
+ # return images, filtered_labels
664
+ # Case 1: Input is already arrays
665
+ if not isinstance(image_paths[0], str):
666
+ self.images = list(image_paths)
667
+ self.image_input = None
668
+ if labels is not None:
669
+ self._process_labels(labels)
670
+ return self.images, self.labels
671
+
672
+ # Case 2: Input is file paths - use ImageLoader
673
+ df = pd.DataFrame({"path": image_paths})
674
+ if labels is not None:
675
+ df["label"] = labels
676
+
677
+ # Process images using ImageLoader
678
+ result = self.image_loader.process(
679
+ data=df,
680
+ x_col="path",
681
+ y_col="label" if labels is not None else None,
682
+ encoder="label" if labels is not None else None,
683
+ output="array"
684
+ )
685
+
686
+ if labels is not None:
687
+ images, processed_labels = result
688
+ self.labels = processed_labels
689
+ else:
690
+ images = result
691
+
692
+ self.images = list(images)
693
+ self.image_input = image_paths
694
+
695
+ # Detect task type and process labels
696
+ if labels is not None:
697
+ self._process_labels(labels)
698
+
699
+ if self.verbose:
700
+ print(f"Loaded {len(self.images)} images, {len(self.labels) if labels is not None else 0} labels")
701
+
702
+ return self.images, self.labels if labels is not None else None
703
+
704
+
705
+ def _process_labels(self, labels: List):
706
+ """Internal method to process labels and detect task type"""
707
+ labels = np.array(labels).flatten()
708
+
709
+ # Detect task type
710
+ unique_labels = set(labels)
711
+ if all(isinstance(label, (int, float)) for label in labels):
712
+ self.task_type = "regression"
713
+ if self.verbose:
714
+ print("Detected regression task")
715
+ else:
716
+ self.task_type = "classification"
717
+ self.class_names = sorted(unique_labels)
718
+ self.label_encoder = LabelEncoder()
719
+ self.labels = self.label_encoder.fit_transform(labels)
720
+ if self.verbose:
721
+ print(f"Detected classification task with {len(unique_labels)} classes")
722
+ def preprocess_image(self, img: np.ndarray) -> np.ndarray:
723
+ """应用配置的预处理步骤到单个图像"""
724
+ cfg = self.cfg["preprocessing"]
725
+
726
+ # Convert to 3-channel BGR if needed
727
+ # if img.ndim == 2: # Grayscale
728
+ # img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
729
+ # elif img.ndim == 3 and img.shape[2] == 4: # RGBA
730
+ # img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
731
+ # elif img.ndim == 3 and img.shape[2] == 3: # RGB
732
+ # img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
733
+ if img.ndim == 2:
734
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
735
+ elif img.shape[2] == 4:
736
+ img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
737
+
738
+ # Ensure uint8 (0-255) for OpenCV operations
739
+ if img.dtype != np.uint8:
740
+ img = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8)
741
+ # or: img = img.astype(np.uint8) # If already [0,255]
742
+
743
+ # 调整大小 # Resize with OpenCV for performance
744
+ if cfg.get("resize") and img.shape[:2] != cfg["resize"]:
745
+ img = cv2.resize(img, tuple(cfg["resize"][::-1]), interpolation=cv2.INTER_AREA)
746
+
747
+ # 去背景
748
+ if cfg.get("remove_background"):
749
+ method = cfg.get("bg_method", "grabcut")
750
+ if method in self.preprocessors:
751
+ img = self.preprocessors[method](img)
752
+ # 转换为灰度(如果需要)
753
+ if cfg.get("to_grayscale", True) and img.ndim == 3:
754
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
755
+ # 降噪
756
+ denoise_method = cfg.get("denoise")
757
+ if denoise_method:
758
+ if denoise_method == "median":
759
+ img = cv2.medianBlur(img, 3)
760
+ elif denoise_method == "gaussian":
761
+ img = cv2.GaussianBlur(img, (5, 5), 0)
762
+ elif denoise_method == "bilateral":
763
+ img = cv2.bilateralFilter(img, 9, 75, 75)
764
+
765
+ # 直方图均衡化
766
+ if cfg.get("hist_equalize", False):
767
+ if img.ndim == 3:
768
+ img_yuv = cv2.cvtColor(img, cv2.COLOR_RGB2YUV)
769
+ img_yuv[:, :, 0] = cv2.equalizeHist(img_yuv[:, :, 0])
770
+ img = cv2.cvtColor(img_yuv, cv2.COLOR_YUV2RGB)
771
+ else:
772
+ img = exposure.equalize_hist(img)
773
+
774
+ # 标准化
775
+ if cfg.get("normalize", True):
776
+ img = img.astype(np.float32) / 255.0
777
+
778
+ return img
779
+
780
+ def _grabcut_background_removal(self, img: np.ndarray) -> np.ndarray:
781
+ """使用GrabCut算法去除背景"""
782
+ mask = np.zeros(img.shape[:2], np.uint8)
783
+ bgd_model = np.zeros((1, 65), np.float64)
784
+ fgd_model = np.zeros((1, 65), np.float64)
785
+
786
+ # 定义ROI(整个图像)
787
+ h, w = img.shape[:2]
788
+ margin=10
789
+ rect = (
790
+ max(0, margin),
791
+ max(0, margin),
792
+ max(1, w - 2 * margin),
793
+ max(1, h - 2 * margin)
794
+ )
795
+
796
+ cv2.grabCut(img, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)
797
+
798
+ mask2 = np.where((mask == 2) | (mask == 0), 0, 1).astype("uint8")
799
+ result = img * mask2[:, :, np.newaxis]
800
+
801
+ # 清除边界
802
+ result = segmentation.clear_border(result)
803
+
804
+ return result
805
+
806
def _kmeans_background_removal(self, img: np.ndarray) -> np.ndarray:
    """Remove the background of ``img`` via 2-cluster K-means on pixel values.

    The larger of the two clusters is assumed to be the background. Its mask
    is cleaned with an erosion followed by a dilation (disk of radius 5) and
    the masked pixels are set to 0.
    """
    # One row per pixel; 3 columns for a 3-D input, otherwise a single column.
    n_columns = 3 if img.ndim == 3 else 1
    samples = np.float32(img.reshape((-1, n_columns)))

    # K-means settings: two clusters, 10 attempts, stop after 10 iterations
    # or once the centres move by less than 1.0.
    n_clusters = 2
    stop_criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    _, assignments, _ = cv2.kmeans(
        samples, n_clusters, None, stop_criteria, 10, cv2.KMEANS_RANDOM_CENTERS
    )

    # The most populated cluster is taken to be the background.
    background_id = np.argmax(np.bincount(assignments.flatten()))

    background = np.uint8(assignments == background_id).reshape(img.shape[:2])

    # Morphological clean-up of the background mask.
    background = morphology.binary_erosion(background, morphology.disk(5))
    background = morphology.binary_dilation(background, morphology.disk(5))

    if img.ndim != 3:
        return img * (1 - background)

    cleaned = img.copy()
    cleaned[background == 1] = 0
    return cleaned
842
+
843
def _morphology_background_removal(self, img: np.ndarray) -> np.ndarray:
    """Remove the background of ``img`` with a threshold plus morphology.

    The image (converted to grayscale when it is 3-D) is thresholded at 0.5;
    the resulting mask is eroded and then dilated with a disk of radius 5,
    and every pixel outside the mask is set to 0.
    """
    gray = color.rgb2gray(img) if img.ndim == 3 else img

    # Threshold, then erode followed by dilate to drop small specks.
    keep = morphology.binary_erosion(gray > 0.5, morphology.disk(5))
    keep = morphology.binary_dilation(keep, morphology.disk(5))

    if img.ndim != 3:
        return img * keep

    masked = img.copy()
    masked[~keep] = 0
    return masked
865
+
866
def extract_features(self, images: List[np.ndarray] = None) -> np.ndarray:
    """Extract one feature vector per image with the configured extractor.

    Args:
        images: Images to process. Defaults to ``self.images``.

    Returns:
        The feature matrix, also stored on ``self.features``.

    Raises:
        ValueError: If there are no images or the configured method has no
            registered extractor.
    """
    if images is None:
        images = self.images

    # len() works for both lists and NumPy arrays; ``not images`` raises for
    # arrays holding more than one element.
    if images is None or len(images) == 0:
        raise ValueError("No images available for feature extraction")

    method = self.cfg["feature_extraction"]["method"]
    extractor = self.feature_extractors.get(method)

    if not extractor:
        raise ValueError(f"Unsupported feature extraction method: {method}")

    if self.verbose:
        print(f"Extracting features using {method} method...")

    # Reuse cached features when a cache file for this configuration exists.
    # NOTE(review): the cache name depends only on cfg_hash, not on the
    # images, so a stale file is returned for a different image set.
    # NOTE(security): allow_pickle=True unpickles the file; only load cache
    # files this process (or a trusted one) wrote.
    cache_file = f"features_cache_{self.cfg_hash}.npy"
    if os.path.exists(cache_file):
        if self.verbose:
            print("Loading features from cache")
        self.features = np.load(cache_file, allow_pickle=True)
        return self.features

    start_time = time.time()
    preprocessed_images = self._cache_preprocessed(images)

    with ThreadPoolExecutor(max_workers=self.n_cpu) as executor:
        features = list(tqdm(
            executor.map(extractor, preprocessed_images),
            total=len(images),
            desc="Extracting features",
            disable=not self.verbose
        ))
    self.features = np.array(features)

    # Persist for subsequent runs with the same configuration.
    np.save(cache_file, self.features)

    if self.verbose:
        print(f"Features extracted in {time.time()-start_time:.2f}s")
        print(f"Feature matrix shape: {self.features.shape}")

    return self.features
911
+
912
+ def _cache_preprocessed(self, images):
913
+ """Cache preprocessed images to disk for faster subsequent runs"""
914
+ import concurrent.futures
915
+ cache_file = f"preprocessed_cache_{self.cfg_hash}.npy"
916
+ if os.path.exists(cache_file):
917
+ if self.verbose:
918
+ print("Loading preprocessed images from cache")
919
+ return np.load(cache_file, allow_pickle=True)
920
+
921
+ if self.verbose:
922
+ print(f"Preprocessing {len(images)} images...")
923
+ start_time = time.time()
924
+ processed = []
925
+
926
+ # Process in chunks
927
+ for i in tqdm(range(0, len(images), self.chunk_size),
928
+ desc="Preprocessing",
929
+ disable=not self.verbose):
930
+ chunk = images[i:i+self.chunk_size]
931
+ with ThreadPoolExecutor(max_workers=self.n_cpu) as executor:
932
+ processed_chunk = list(executor.map(self.preprocess_image, chunk))
933
+ processed.extend(processed_chunk)
934
+
935
+ # Save to cache
936
+ np.save(cache_file, processed)
937
+
938
+ if self.verbose:
939
+ print(f"Preprocessed in {time.time()-start_time:.2f}s")
940
+
941
+ return processed
942
+
943
def clear_cache(self, pattern: str = "*.npy") -> int:
    """Delete files matching a glob pattern and report how many were removed.

    Args:
        pattern: Glob pattern, resolved by ``glob.glob`` (relative patterns
            are relative to the current working directory).

    Returns:
        Number of files successfully removed. Files that cannot be removed
        are skipped (and reported when ``self.verbose`` is set).
    """
    import glob

    deleted = 0
    for path in glob.glob(pattern):
        # Only the removal itself is guarded, and only against OS-level
        # failures (missing file, permissions, path is a directory, ...).
        try:
            os.remove(path)
        except OSError as exc:
            if self.verbose:
                print(f"Failed to remove {path}: {str(exc)}")
        else:
            deleted += 1
            if self.verbose:
                print(f"Removed {path}")
    return deleted
957
+
958
+ def _extract_raw_pixels(self, img: np.ndarray) -> np.ndarray:
959
+ """提取原始像素特征"""
960
+ return img.flatten()
961
def _reset_state(self):
    """Drop features, train/test splits and cache files to free memory.

    Also triggers a garbage-collection pass and, when ``DL_ENABLED`` is set,
    clears the Keras backend session.
    """
    self.features = None
    self.X_train = None
    self.y_train = None
    self.X_test = None
    self.y_test = None

    # Cache files written by the preprocessing / feature-extraction steps.
    for cache_pattern in (
        "preprocessed_cache_*.npy",
        "features_cache_*.npy",
        "cnn_preprocessed_*.npy",
    ):
        self.clear_cache(cache_pattern)
    gc.collect()

    if DL_ENABLED:
        tf.keras.backend.clear_session()
973
def _extract_hog(self, img: np.ndarray) -> np.ndarray:
    """Compute the HOG descriptor of ``img`` as a flat feature vector.

    Settings come from ``cfg["feature_extraction"]["hog_params"]``; 3-D
    inputs are converted to grayscale first.
    """
    hog_cfg = self.cfg["feature_extraction"]["hog_params"]

    gray = color.rgb2gray(img) if img.ndim == 3 else img

    descriptor = feature.hog(
        gray,
        orientations=hog_cfg["orientations"],
        pixels_per_cell=hog_cfg["pixels_per_cell"],
        cells_per_block=hog_cfg["cells_per_block"],
        feature_vector=True,
    )
    return descriptor
987
+
988
def _extract_lbp(self, img: np.ndarray) -> np.ndarray:
    """Compute a normalised Local Binary Pattern histogram for ``img``.

    Settings (``P``, ``R``, ``method``) come from
    ``cfg["feature_extraction"]["lbp_params"]``; 3-D inputs are converted to
    grayscale first.

    For the code-valued LBP methods the histogram has a fixed number of bins
    derived from ``P`` and ``method``, so every image yields a vector of the
    same length regardless of which codes happen to occur in it.

    Returns:
        The density-normalised histogram of LBP values.
    """
    params = self.cfg["feature_extraction"]["lbp_params"]

    if img.ndim == 3:
        img = color.rgb2gray(img)

    n_points = params["P"]
    method = params["method"]
    lbp = feature.local_binary_pattern(
        img, P=n_points, R=params["R"], method=method
    )

    # Number of distinct codes each method can emit for P sampling points.
    code_counts = {
        "default": 2 ** n_points,
        "ror": 2 ** n_points,
        "uniform": n_points + 2,
        "nri_uniform": n_points * (n_points - 1) + 3,
    }
    n_bins = code_counts.get(method)

    if n_bins is None:
        # Other methods (e.g. "var") are not integer codes; fall back to a
        # data-dependent bin count.
        # NOTE(review): the vector length can then vary between images.
        n_bins = int(lbp.max() + 1)
        hist, _ = np.histogram(lbp, bins=n_bins, density=True)
    else:
        hist, _ = np.histogram(lbp, bins=n_bins, range=(0, n_bins), density=True)

    return hist
1004
+
1005
def _extract_color_histogram(self, img: np.ndarray) -> np.ndarray:
    """Build a 256-bin intensity histogram per channel and concatenate them.

    A 2-D image yields a single 256-value histogram; otherwise the image is
    split into channels and the per-channel histograms are joined in channel
    order.
    """
    def channel_hist(plane):
        return cv2.calcHist([plane], [0], None, [256], [0, 256]).flatten()

    # Grayscale image: a single histogram.
    if img.ndim == 2:
        return channel_hist(img)

    # Colour image: one histogram per channel, concatenated.
    values = [v for plane in cv2.split(img) for v in channel_hist(plane)]
    return np.array(values)
1020
+
1021
def _extract_cnn_features(self, img: np.ndarray) -> np.ndarray:
    """Extract a feature vector from ``img`` with a pretrained CNN backbone.

    The backbone named by ``cfg["feature_extraction"]["cnn_model"]`` is built
    once (ImageNet weights, no top, average pooling) and kept in
    ``self._cnn_models`` for later calls.

    Raises:
        ImportError: If deep-learning support is not enabled.
        ValueError: If the configured backbone name is not supported.
    """
    if not self.dl_enabled:
        raise ImportError("Deep learning libraries not available")

    backbone_key = self.cfg["feature_extraction"]["cnn_model"].lower()
    backbones = {
        "vgg16": VGG16,
        "vgg19": VGG19,
        "resnet50": ResNet50,
        "inceptionv3": InceptionV3,
        "mobilenet": MobileNet,
        "densenet121": DenseNet121,
        "efficientnetb0": EfficientNetB0,
    }
    if backbone_key not in backbones:
        raise ValueError(f"Unsupported CNN model: {backbone_key}")

    # Lazily build the backbone the first time it is requested.
    if backbone_key not in self._cnn_models:
        base = backbones[backbone_key](
            weights="imagenet", include_top=False, pooling="avg"
        )
        self._cnn_models[backbone_key] = Model(inputs=base.input, outputs=base.output)
    extractor = self._cnn_models[backbone_key]

    # Prepare the input: 3 channels, 224x224, batch of one.
    if img.ndim == 2:
        img = np.stack((img,) * 3, axis=-1)
    resized = transform.resize(img, (224, 224))
    batch = np.expand_dims(image.img_to_array(resized), axis=0)
    batch = preprocess_input(batch)

    return extractor.predict(batch).flatten()
1062
+
1063
+
1064
def preprocess_image_for_cnn(self, img: np.ndarray) -> np.ndarray:
    """Convert one image to a normalised 3-channel float32 array for a CNN.

    Steps: PIL image -> array, force 3 channels, swap channel order, resize
    to ``cfg["preprocessing"]["resize"]`` (default ``(32, 32)``), scale by
    1/255.

    Args:
        img: Image as a NumPy array or a PIL image.

    Returns:
        float32 array produced by the steps above.
    """
    # cv2.resize takes its target size as (width, height).
    target_size = tuple(self.cfg["preprocessing"].get("resize", (32, 32)))

    if isinstance(img, Image.Image):
        img = np.array(img)

    # Normalise the channel count to 3.
    if img.ndim == 2:  # grayscale
        img = np.stack((img,) * 3, axis=-1)
    elif img.shape[2] == 4:  # RGBA: drop alpha
        img = img[..., :3]
    elif img.shape[2] == 1:  # single explicit channel
        img = np.concatenate([img] * 3, axis=-1)

    # NOTE(review): every 3-channel image gets its first and last channel
    # swapped here; confirm the loader really delivers BGR.
    if img.shape[2] == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # img.shape is (rows, cols) while target_size is (width, height), so
    # compare like with like before deciding whether to resize.
    if (img.shape[1], img.shape[0]) != target_size:
        img = cv2.resize(img, target_size, interpolation=cv2.INTER_AREA)

    return img.astype(np.float32) / 255.0
1103
def _build_cnn_model(self, input_shape=None, num_classes=None):
    """Build a CNN through ``build_cnn_model`` using instance defaults.

    ``input_shape`` falls back to ``cfg["preprocessing"]["resize"]``.
    ``num_classes`` falls back to the number of distinct labels for
    classification, 1 for other task types, or 2 when no labels are set.
    Remaining keyword arguments come from ``cfg["cnn"]``.
    """
    if input_shape is None:
        input_shape = self.cfg["preprocessing"]["resize"]

    if num_classes is None:
        known_labels = getattr(self, "labels", None)
        if known_labels is None:
            num_classes = 2  # default when nothing is known about the labels
        elif self.task_type == "classification":
            num_classes = len(np.unique(known_labels))
        else:
            num_classes = 1

    return build_cnn_model(
        input_shape=input_shape,
        num_classes=num_classes,
        **self.cfg["cnn"]
    )
1119
+
1120
def _train_cnn_model(self, model_builder: Callable, params: dict):
    """Train a CNN on ``self.images`` / ``self.labels`` and record the result.

    Args:
        model_builder: Callable invoked as ``model_builder(input_shape,
            num_classes)`` that returns the model to train.
        params: Training options. Keys read here: ``batch_size`` (default
            32) and ``epochs`` (default 50); the dict is also passed to
            ``_create_datagen`` and ``_create_callbacks``.

    Returns:
        ``(model, metrics)`` where ``metrics`` holds the training time, the
        best epoch (lowest validation loss) and the validation scores.

    Raises:
        ImportError: If deep-learning support is not enabled.
    """
    if not self.dl_enabled:
        raise ImportError("Deep learning libraries not available")

    tf.config.optimizer.set_jit(True)  # enable XLA compilation
    tf.keras.backend.clear_session()  # drop any previous models
    if self.has_gpu:
        set_global_policy('mixed_float16')

    # Preprocessing
    target_size = self.cfg["preprocessing"].get("resize", (224, 224))
    X = self._preprocess_for_cnn(self.images, target_size)
    y = np.array(self.labels).flatten()

    if self.verbose:
        print(f"X shape: {X.shape}, dtype: {X.dtype}")
        print(f"y shape: {y.shape}, dtype: {y.dtype}")
        print(f"Unique labels: {np.unique(y)}")

    # Train / validation split (stratified for classification only).
    stratify = y if self.task_type == "classification" else None
    X_train, X_val, y_train, y_val = train_test_split(
        X, y,
        test_size=self.cfg["test_size"],
        random_state=self.cfg["random_state"],
        stratify=stratify
    )

    X_train = X_train.astype(np.float32)
    X_val = X_val.astype(np.float32)

    # Build model
    input_shape = X.shape[1:]
    num_classes = len(np.unique(y)) if self.task_type == "classification" else 1
    model = model_builder(input_shape, num_classes)

    batch_size = params.get('batch_size', 32)

    # Training batches come from the generator built by _create_datagen so
    # that the augmentation settings take effect when enabled; validation
    # batches are never augmented.
    train_datagen = self._create_datagen(params)
    val_datagen = ImageDataGenerator()

    train_generator = train_datagen.flow(
        X_train, y_train,
        batch_size=batch_size,
        shuffle=True
    )
    val_generator = val_datagen.flow(
        X_val, y_val,
        batch_size=batch_size,
        shuffle=False
    )

    steps_per_epoch = max(1, len(X_train) // batch_size)
    validation_steps = max(1, len(X_val) // batch_size)

    callbacks = self._create_callbacks(params)

    if self.verbose:
        print(f"\nStarting training with:")
        print(f" Train samples: {len(X_train)}")
        print(f" Validation samples: {len(X_val)}")
        print(f" Batch size: {batch_size}")
        print(f" Steps per epoch: {steps_per_epoch}")
        print(f" Validation steps: {validation_steps}")
        print("Compiling model...")

    # Smoke-test the model on one batch so shape/dtype problems surface
    # before the training loop starts.
    try:
        if self.verbose:
            print("Testing with one batch...")
        test_batch = next(train_generator)
        model.predict(test_batch[0], verbose=0)
        if self.verbose:
            print("1st batch test successful")
    except Exception as e:
        print(f"❌ Error in single batch test: {str(e)}")
        raise

    start_time = time.time()
    history = model.fit(
        train_generator,
        steps_per_epoch=steps_per_epoch,
        epochs=params.get('epochs', 50),
        validation_data=val_generator,
        validation_steps=validation_steps,
        callbacks=callbacks,
        verbose=1 if self.verbose else 0
    )

    # Collect metrics; the best epoch is the one with the lowest val_loss.
    train_time = time.time() - start_time
    val_losses = history.history['val_loss']
    best_epoch = int(np.argmin(val_losses)) + 1
    metrics = {
        'train_time': train_time,
        'best_epoch': best_epoch,
        'val_loss': val_losses[best_epoch - 1],
        'best_val_loss': min(val_losses)
    }

    # val_accuracy is only present when the model was compiled with an
    # accuracy metric.
    val_accuracies = history.history.get('val_accuracy')
    if self.task_type == 'classification' and val_accuracies:
        metrics['val_accuracy'] = val_accuracies[best_epoch - 1]
        metrics['best_val_accuracy'] = max(val_accuracies)

    # Save results
    self.model_performance["cnn"] = {
        "model": model,
        "metrics": metrics,
        "history": history.history,
        "params": params
    }
    self._update_best_model(model, metrics)

    # Cleanup
    tf.keras.backend.clear_session()
    gc.collect()

    if self.verbose:
        print(f"Training completed in {train_time:.2f}s")
        if 'best_val_accuracy' in metrics:
            print(f"Best validation accuracy: {metrics['best_val_accuracy']:.4f}")
        else:
            print(f"Best validation loss: {metrics['best_val_loss']:.4f}")

    return model, metrics
1257
+
1258
+
1259
def _create_datagen(self, params: dict):
    """Return an ``ImageDataGenerator``, augmenting only when requested.

    Augmentation is applied only if ``params["use_augmentation"]`` is truthy;
    its settings are read from ``cfg["preprocessing"]["augmentation"]`` with
    the defaults shown below.
    """
    if not params.get('use_augmentation', False):
        # Plain pass-through generator.
        return ImageDataGenerator()

    aug = self.cfg["preprocessing"].get("augmentation", {})
    return ImageDataGenerator(
        rotation_range=aug.get("rotation_range", 10),
        width_shift_range=aug.get("width_shift_range", 0.1),
        height_shift_range=aug.get("height_shift_range", 0.1),
        shear_range=aug.get("shear_range", 0.05),
        zoom_range=aug.get("zoom_range", 0.1),
        horizontal_flip=aug.get("horizontal_flip", True),
        fill_mode='reflect'
    )
1280
+ def _preprocess_for_cnn(self, images, target_size: Tuple[int, int]) -> np.ndarray:
1281
+ """Efficient CNN preprocessing with parallel processing"""
1282
+ cache_file = f"cnn_preprocessed_{self.cfg_hash}.npy"
1283
+
1284
+ if os.path.exists(cache_file):
1285
+ if self.verbose:
1286
+ print("Loading CNN preprocessed images from cache")
1287
+ return np.load(cache_file)
1288
+
1289
+ if self.verbose:
1290
+ print("Preprocessing images for CNN...")
1291
+
1292
+ start_time = time.time()
1293
+ processed = []
1294
+
1295
+ # Process in chunks
1296
+ for i in range(0, len(images), self.chunk_size):
1297
+ chunk = images[i:i+self.chunk_size]
1298
+ with ThreadPoolExecutor(max_workers=self.n_cpu) as executor:
1299
+ processed_chunk = list(executor.map(
1300
+ partial(self._preprocess_single_cnn, target_size=target_size),
1301
+ chunk
1302
+ ))
1303
+ processed.extend(processed_chunk)
1304
+
1305
+ processed = np.array(processed)
1306
+
1307
+ # FIX: Ensure proper shape (samples, height, width, channels)
1308
+ if processed.ndim == 3:
1309
+ # Add channel dimension for grayscale
1310
+ processed = np.expand_dims(processed, axis=-1)
1311
+
1312
+ np.save(cache_file, processed)
1313
+
1314
+ if self.verbose:
1315
+ print(f"Preprocessed in {time.time()-start_time:.2f}s")
1316
+ print(f"Processed shape: {processed.shape}")
1317
+
1318
+ return processed
1319
+
1320
+
1321
def _preprocess_single_cnn(self, img: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray:
    """Preprocess one image for CNN input.

    Steps: force 3 channels, swap channel order, resize to ``target_size``,
    scale by 1/255 as float32, and (when ``cfg["preprocessing"]
    ["to_grayscale"]`` is truthy, which is the default) collapse to a single
    channel.

    Args:
        img: Image array, 2-D or (rows, cols, channels).
        target_size: Size handed to ``cv2.resize``, i.e. (width, height).

    Returns:
        float32 array with a trailing channel axis.
    """
    # Convert to 3 channels where needed.
    if img.ndim == 2:
        img = np.stack((img,) * 3, axis=-1)
    elif img.shape[2] == 4:
        img = img[..., :3]

    if img.ndim == 3:
        if img.shape[2] == 3:
            # NOTE(review): channels are swapped here and the grayscale step
            # below then uses COLOR_RGB2GRAY; confirm the channel order the
            # loader actually delivers.
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        elif img.shape[2] == 1:
            img = np.concatenate([img] * 3, axis=-1)

    # img.shape is (rows, cols) whereas cv2.resize takes (width, height);
    # compare in the same order so non-square targets are handled correctly.
    dsize = tuple(target_size)
    if (img.shape[1], img.shape[0]) != dsize:
        img = cv2.resize(img, dsize, interpolation=cv2.INTER_AREA)

    img = img.astype(np.float32) / 255.0

    if self.cfg["preprocessing"].get("to_grayscale", True):
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        img = np.expand_dims(img, axis=-1)

    return img
1347
+
1348
+ def _create_callbacks(self, params: dict) -> List[tf.keras.callbacks.Callback]:
1349
+ """Create training callbacks"""
1350
+ return [
1351
+ EarlyStopping(
1352
+ patience=params.get('patience', 8),
1353
+ monitor='val_accuracy' if self.task_type == "classification" else 'val_loss',
1354
+ restore_best_weights=True,
1355
+ verbose=1 if self.verbose else 0
1356
+ ),
1357
+ ModelCheckpoint(
1358
+ 'best_model.keras',
1359
+ save_best_only=True,
1360
+ monitor='val_accuracy' if self.task_type == "classification" else 'val_loss',
1361
+ mode='max' if self.task_type == "classification" else 'min'
1362
+ ),
1363
+ ReduceLROnPlateau(
1364
+ monitor='val_loss',
1365
+ factor=0.5,
1366
+ patience=3,
1367
+ min_lr=1e-6,
1368
+ verbose=1 if self.verbose else 0
1369
+ )
1370
+ ]
1371
+
1372
def train_model(self, model_name: str = None, params: dict = None):
    """Train one model on the extracted features (or the CNN on raw images).

    Args:
        model_name: Key into ``self.models``. Defaults to
            ``cfg["model"]["name"]``.
        params: Constructor parameters. When empty or ``None``,
            ``cfg["model"]["params"]`` is used. The caller's dict is never
            modified.

    Returns:
        ``(model, metrics)``.

    Raises:
        ValueError: If labels are missing or the model name is unknown.
        ImportError: If ``"cnn"`` is requested without deep-learning support.
    """
    # Resolve name and parameters up front; copy so that the additions below
    # do not leak into the caller's (possibly shared) dict.
    model_cfg = self.cfg["model"]
    model_name = model_name or model_cfg["name"]
    params = dict(params or model_cfg.get("params", {}))

    if model_name != "cnn" and self.features is None:
        self.extract_features()
    print(f"Training model: {model_name}")
    if self.labels is None:
        raise ValueError("Labels are required for training")

    if model_name not in self.models:
        raise ValueError(f"Unsupported model: {model_name}")

    # The CNN has its own training pipeline.
    if model_name == "cnn":
        if not self.dl_enabled:
            raise ImportError("Deep learning libraries not available")
        return self._train_cnn_model(self.models[model_name], params)

    # Train/test split; stratification only makes sense for class labels.
    stratify = self.labels if self.task_type == "classification" else None
    self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
        self.features,
        self.labels,
        test_size=self.cfg["test_size"],
        random_state=self.cfg["random_state"],
        stratify=stratify,
    )
    print("划分训练测试集")

    # Extras for the gradient-boosting libraries. This runs after the split
    # because the class weights are computed from y_train.
    if model_name in ["xgb", "lightgbm", "catboost"]:
        # GPU acceleration when available.
        if self.has_gpu:
            params.update({"tree_method": "gpu_hist", "gpu_id": 0})

        # Class weights for imbalanced data.
        if self.task_type == "classification":
            classes = np.unique(self.y_train)
            class_weights = compute_class_weight('balanced', classes=classes, y=self.y_train)
            params["class_weight"] = dict(zip(classes, class_weights))

    # Drop parameters the estimator does not accept.
    model_class = self.models[model_name]

    supported_params = model_class().get_params().keys()
    filtered_params = {k: v for k, v in params.items() if k in supported_params}

    if len(filtered_params) != len(params) and self.verbose:
        removed_params = set(params.keys()) - set(filtered_params.keys())
        print(f"Removed unsupported params for {model_name}: {removed_params}")

    model = model_class(**filtered_params)

    print("创建模型")
    if self.verbose:
        print(f"🤖 Training {model_name} model...")
        print(f"📐 Parameters: {params}")

    # Fit
    start_time = time.time()
    model.fit(self.X_train, self.y_train)
    train_time = time.time() - start_time

    print(" 训练模型")
    # Evaluate
    metrics = self.evaluate_model(model)
    print(" 评估模型")
    # Record
    self.model_performance[model_name] = {
        "model": model,
        "metrics": metrics,
        "train_time": train_time,
        "params": params,
    }
    print(" 保存结果")
    # Track the best model seen so far.
    self._update_best_model(model, metrics)

    if self.verbose:
        print(f"Model trained in {train_time:.2f}s")
        print("Performance metrics:")
        for k, v in metrics.items():
            print(f" {k}: {v:.4f}")

    return model, metrics
1460
+
1461
def auto_compare_models(
    self, model_names: List[str] = None, params_list: List[dict] = None
):
    """Train several models and rank them by the task's primary metric.

    Args:
        model_names: Keys into ``self.models``. When ``None`` a default
            list for the task type is used, together with matching default
            parameters (any ``params_list`` passed is then ignored).
        params_list: One parameter dict per model name. Defaults to an
            independent empty dict per model.

    Returns:
        DataFrame with one row per successfully trained model, sorted by
        ``accuracy`` (classification) or ``r2`` (otherwise), best first.
        If no model produced that metric, the unsorted (possibly empty)
        frame is returned and ``self.best_model`` is left untouched.

    Raises:
        ValueError: For classification with fewer than two classes, or when
            the two lists differ in length.
    """
    if model_names is None:
        # Default candidates per task type.
        if self.task_type == "classification":
            model_names = [
                "random_forest", "svm", "knn", "gradient_boosting", "logistic_regression",
                "xgb", "lightgbm", "catboost", "extra_trees", "mlp", "bagging"
            ]
            params_list = [
                {"n_estimators": 100},  # RF
                {"C": 1.0, "kernel": "rbf"},  # SVM
                {"n_neighbors": 5},  # KNN
                {"n_estimators": 100},  # GBM
                {"C": 1.0, "penalty": "l2"},  # Logistic
                {"n_estimators": 100},  # XGBoost
                {"n_estimators": 100},  # LightGBM
                {"iterations": 100},  # CatBoost
                {"n_estimators": 100},  # Extra Trees
                {"hidden_layer_sizes": (64,)},  # MLP
                {"n_estimators": 50}  # Bagging
            ]
        else:  # regression
            model_names = [
                "random_forest_reg", "svr", "knn_regressor", "gradient_boosting_reg", "linear_regression",
                "xgb_reg", "lightgbm_reg", "catboost_reg", "extra_trees_reg", "mlp_reg", "bagging_reg"
            ]
            params_list = [
                {"n_estimators": 100},  # RF Reg
                {"C": 1.0, "kernel": "rbf"},  # SVR
                {"n_neighbors": 5},  # KNN Reg
                {"n_estimators": 100},  # GBM Reg
                {},  # Linear Regression
                {"n_estimators": 100},  # XGBoost Reg
                {"n_estimators": 100},  # LightGBM Reg
                {"iterations": 100},  # CatBoost Reg
                {"n_estimators": 100},  # Extra Trees Reg
                {"hidden_layer_sizes": (64,)},  # MLP Reg
                {"n_estimators": 50}  # Bagging Reg
            ]

    # Classification needs at least two classes.
    if self.task_type == "classification":
        unique_classes = np.unique(self.labels)
        if len(unique_classes) < 2:
            raise ValueError(
                "Classification requires ≥2 classes. Your data has only one class."
            )

    if params_list is None:
        # One independent dict per model; ``[{}] * n`` would alias a single
        # dict across all models.
        params_list = [{} for _ in model_names]

    if len(model_names) != len(params_list):
        raise ValueError("model_names and params_list must have the same length")

    if self.verbose:
        print(f"Comparing {len(model_names)} models...")

    # Train and evaluate every candidate; a failing model is reported and
    # skipped so the remaining ones still run.
    results = []
    for name, params in zip(model_names, params_list):
        try:
            _, metrics = self.train_model(name, params)
            results.append({"model": name, "params": params, **metrics})
        except Exception as e:
            print(f"Error training {name}: {str(e)}")

    results_df = pd.DataFrame(results)

    sort_by = "accuracy" if self.task_type == "classification" else "r2"

    # Nothing to rank when every model failed or none reported the metric.
    if results_df.empty or sort_by not in results_df.columns:
        print(f"No model produced the '{sort_by}' metric; nothing to rank.")
        return results_df

    results_df = results_df.sort_values(by=sort_by, ascending=False)

    # Keep the top-ranked model.
    best_model_name = results_df.iloc[0]["model"]
    self.best_model = self.model_performance[best_model_name]["model"]

    if self.verbose:
        print("\n🏆 Model Comparison Results:")
        print(results_df[["model", sort_by]])
        print(f"\n🥇 Best model: {best_model_name}")

    return results_df
1550
+
1551
+ def _update_best_model(self, model, metrics):
1552
+ """Helper to update the best model reference"""
1553
+ if self.best_model is None:
1554
+ self.best_model = model
1555
+ else:
1556
+ current_metric = self._get_primary_metric(self.best_model)
1557
+ new_metric = self._get_primary_metric(model, metrics)
1558
+
1559
+ if new_metric > current_metric:
1560
+ self.best_model = model
1561
+
1562
+ def _get_primary_metric(self, model, metrics=None):
1563
+ """Get the appropriate primary metric for comparison"""
1564
+ if metrics is None:
1565
+ metrics = self.evaluate_model(model=model)
1566
+
1567
+ return metrics.get(
1568
+ "accuracy" if self.task_type == "classification" else "r2", 0
1569
+ )
1570
+
1571
def evaluate_model(self, model: BaseEstimator = None) -> dict:
    """Score ``model`` (default: the best model) on the held-out test split.

    Classification yields ``accuracy`` and ``f1_weighted`` (plus an optional
    printed report and confusion-matrix plot); other task types yield
    ``mse``, ``mae`` and ``r2``. Any error during prediction or scoring is
    printed and the metrics gathered so far are returned.

    Raises:
        ValueError: If no model is given and no best model is set.
    """
    if model is None:
        model = self.best_model
        if model is None:
            raise ValueError("No model available for evaluation")

    scores = {}
    try:
        y_pred = model.predict(self.X_test)
        print("try to predict it again")
        if self.task_type != "classification":
            scores["mse"] = mean_squared_error(self.y_test, y_pred)
            scores["mae"] = mean_absolute_error(self.y_test, y_pred)
            scores["r2"] = r2_score(self.y_test, y_pred)
        else:
            scores["accuracy"] = accuracy_score(self.y_test, y_pred)
            scores["f1_weighted"] = f1_score(self.y_test, y_pred, average="weighted")
            if self.verbose:
                print("\nClassification Report:")
                report = classification_report(
                    self.y_test, y_pred, target_names=self.class_names
                )
                print(report)
            # Optional confusion-matrix plot.
            vis_cfg = self.cfg["visualization"]
            if vis_cfg["enable"] and self.cfg["visualization"]["confusion_matrix"]:
                self.plot_confusion_matrix(self.y_test, y_pred)
    except Exception as e:
        print(f"Error evaluating model: {str(e)}")
    return scores
1607
+
1608
+ def _reset_prediction_state(self):
1609
+ """Reset internal state before new prediction"""
1610
+ # Clear feature cache
1611
+ self.features = None
1612
+ self.clear_cache()
1613
+ # Clear CNN-specific cache
1614
+ if hasattr(self, 'preprocessed_images'):
1615
+ del self.preprocessed_images
1616
+ # Clear image cache
1617
+ self.images = []
1618
+ self.image_input = []
1619
+ # Clear train/test splits
1620
+ if hasattr(self, 'X_train'): self.X_train = None
1621
+ if hasattr(self, 'X_test'): self.X_test = None
1622
+ if hasattr(self, 'y_train'): self.y_train = None
1623
+ if hasattr(self, 'y_test'): self.y_test = None
1624
+
1625
def predict(
    self,
    image_paths: List[str],
    model: "BaseEstimator" = None,
    output_dir: str = None,
) -> np.ndarray:
    """Predict labels for new images and optionally visualise / report them.

    Args:
        image_paths: Paths of the images to load and predict.
        model: Model to use; defaults to ``self.best_model``. Objects that
            expose a ``layers`` attribute are treated as neural networks.
        output_dir: When given, a prediction report is written there.

    Returns:
        The predictions, decoded through ``self.label_encoder`` when one is
        set and decoding succeeds. The raw model output is kept on
        ``self.predictions``.

    Raises:
        ValueError: If no model is given and no best model is set.
    """
    self._reset_prediction_state()
    if model is None:
        if self.best_model is None:
            raise ValueError("No model available for prediction")
        model = self.best_model
    image_paths = list(image_paths)

    images, _ = self.load_images(image_paths)

    # Estimators without ``layers`` have no such attribute, so look it up
    # defensively; ``tf`` is only touched when there are layers to inspect.
    layers = getattr(model, "layers", None)
    is_network = layers is not None
    is_cnn = any(
        isinstance(layer, (tf.keras.layers.Conv2D, tf.keras.layers.Convolution2D))
        for layer in (layers or [])
    )
    if is_cnn:
        print("Using direct CNN prediction")
        target_size = self.cfg["preprocessing"].get("resize", (32, 32))
        images = self._preprocess_for_cnn(images, target_size=target_size)
        predictions = model.predict(images)
    else:
        features = self.extract_features(images)
        predictions = model.predict(features)

    # Drop the cache files written for this prediction run, and nothing
    # else: the cache names do not depend on the images, so leaving them
    # behind would feed stale data to the next run.
    for cache_pattern in (
        "preprocessed_cache_*.npy",
        "features_cache_*.npy",
        "cnn_preprocessed_*.npy",
    ):
        self.clear_cache(cache_pattern)
    self.predictions = predictions

    if self.task_type == "classification":
        predictions = np.asarray(predictions)
        # Only network outputs are scores that need converting to class
        # indices; other estimators already return class labels, and
        # thresholding those would collapse every class >= 1 into 1.
        if is_network:
            if predictions.ndim > 1 and predictions.shape[1] > 1:
                # Multi-class: highest score wins.
                predictions = np.argmax(predictions, axis=1)
            else:
                # Binary: threshold at 0.5, flattened to one label per image.
                predictions = (predictions > 0.5).astype(int).ravel()

        # Decode to the original labels when an encoder is available.
        if hasattr(self, 'label_encoder') and self.label_encoder:
            try:
                predictions = self.label_encoder.inverse_transform(predictions)
            except ValueError as e:
                print(f"Label transform warning: {e}")
                print("Returning numeric predictions instead")

    if self.cfg["visualization"]["enable"]:
        self.visualize_predictions(images, predictions, image_paths)

    if output_dir:
        self.generate_prediction_report(image_paths, predictions, output_dir)

    return predictions
1682
+
1683
def save_model(self, model: "BaseEstimator", file_path: str, format: str = "joblib"):
    """Serialise a trained model to disk.

    Args:
        model: The model object to save.
        file_path: Destination path; missing parent directories are created.
        format: ``'joblib'`` or ``'pickle'``.

    Raises:
        ValueError: If ``format`` is neither ``'joblib'`` nor ``'pickle'``.
    """
    # A bare file name has an empty dirname, which os.makedirs rejects.
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    if format == "joblib":
        joblib.dump(model, file_path)
    elif format == "pickle":
        with open(file_path, "wb") as f:
            pickle.dump(model, f)
    else:
        raise ValueError("Unsupported format. Use 'joblib' or 'pickle'")

    if self.verbose:
        print(f"Model saved to {file_path}")
1704
+
1705
def load_model(self, file_path: str, format: str = "joblib") -> BaseEstimator:
    """Load a previously saved model from disk.

    Both supported formats deserialise with pickle semantics, so only load
    files from a trusted source.

    Args:
        file_path: Path of the model file.
        format: ``'joblib'`` or ``'pickle'``.

    Returns:
        The loaded model. It also becomes ``self.best_model`` when no best
        model is set yet.

    Raises:
        ValueError: If ``format`` is neither ``'joblib'`` nor ``'pickle'``.
    """
    if format not in ("joblib", "pickle"):
        raise ValueError("Unsupported format. Use 'joblib' or 'pickle'")

    if format == "joblib":
        loaded = joblib.load(file_path)
    else:
        with open(file_path, "rb") as handle:
            loaded = pickle.load(handle)

    if self.best_model is None:
        self.best_model = loaded
        print(f"the best model is loaded: {loaded}")
    if self.verbose:
        print(f"Model loaded from {file_path}")

    return loaded
1731
+
1732
def visualize_predictions(
    self, images: List[np.ndarray], predictions: List, paths: List[str] = None
):
    """Show a two-row grid of sample images titled with their predictions.

    At most ``cfg["visualization"]["sample_images"]`` images are shown. When
    ``cfg["report"]["include_visuals"]`` is truthy (default), the figure is
    also saved to ``reports/prediction_visuals.png``.
    """
    shown = min(self.cfg["visualization"]["sample_images"], len(images))

    plt.figure(figsize=(15, 8))
    for idx in range(shown):
        plt.subplot(2, (shown + 1) // 2, idx + 1)

        picture = images[idx]
        # 2-D arrays are drawn with a gray colormap.
        imshow_kwargs = {"cmap": "gray"} if picture.ndim == 2 else {}
        plt.imshow(picture, **imshow_kwargs)

        caption = f"Pred: {predictions[idx]}"
        if paths and isinstance(paths[idx], str) and paths[idx]:
            caption += f"\n{os.path.basename(paths[idx])}"

        plt.title(caption, fontsize=10)
        plt.axis("off")

    plt.tight_layout()

    # Save the figure alongside the report if requested.
    if self.cfg["report"].get("include_visuals", True):
        out_path = os.path.join("reports", "prediction_visuals.png")
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        plt.savefig(out_path, dpi=150, bbox_inches="tight")

    plt.show()
1763
+
1764
def plot_confusion_matrix(self, y_true, y_pred):
    """
    Draw the confusion matrix of ``y_true`` against ``y_pred`` as a heatmap.

    Rows are the true classes and columns the predicted ones; both axes are
    labelled with ``self.class_names``. When
    ``cfg["report"]["include_visuals"]`` is true (the default) the figure is
    saved to ``reports/confusion_matrix.png`` before being shown.
    """
    matrix = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(10, 8))

    heatmap_options = {
        "annot": True,
        "fmt": "d",
        "cmap": "Blues",
        "xticklabels": self.class_names,
        "yticklabels": self.class_names,
    }
    sns.heatmap(matrix, **heatmap_options)

    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")

    # Persist the figure for the reports before displaying it.
    save_figure = self.cfg["report"].get("include_visuals", True)
    if save_figure:
        target = os.path.join("reports", "confusion_matrix.png")
        target_dir = os.path.dirname(target)
        os.makedirs(target_dir, exist_ok=True)
        plt.savefig(target, dpi=150, bbox_inches="tight")

    plt.show()
1787
+
1788
def visualize_feature_space(self):
    """
    Plot the extracted features in two dimensions.

    Features with more than two dimensions are reduced with PCA (at most 50
    components) followed by t-SNE to 2-D; features that already have two
    columns are plotted as they are. Points are coloured by ``self.labels``
    when labels are available, with a class legend for classification tasks.

    When ``cfg["report"]["include_visuals"]`` is true (the default) the
    figure is saved under ``reports/`` as both ``feature_space.pdf`` and
    ``feature_space.png``.

    Raises:
        ValueError: If no features have been extracted yet.
    """
    if self.features is None:
        raise ValueError("No features available for visualization")

    n_samples, n_dims = self.features.shape[0], self.features.shape[1]
    if self.verbose:
        print(f"features dimensions: {n_dims}")

    # Dimensionality reduction
    if n_dims > 2:
        if self.verbose:
            print(
                "Reducing feature dimensions for visualization...\n\t第一步:PCA;\n\t第二步:t-SNE to 2D)"
            )
        # PCA cannot extract more components than min(n_samples, n_features),
        # so the fixed target of 50 is clamped to what the data supports.
        pca = PCA(n_components=min(50, n_samples, n_dims))
        features_reduced = pca.fit_transform(self.features)

        # t-SNE requires perplexity < n_samples; 30 is its default, so larger
        # data sets behave exactly as before.
        perplexity = min(30.0, max(n_samples - 1, 1))
        tsne = TSNE(
            n_components=2,
            perplexity=perplexity,
            random_state=self.cfg["random_state"],
        )
        features_2d = tsne.fit_transform(features_reduced)
    else:
        features_2d = self.features

    # Draw the feature space
    plt.figure(figsize=(12, 8))

    if self.labels is not None:
        scatter = plt.scatter(
            features_2d[:, 0],
            features_2d[:, 1],
            c=self.labels,
            cmap=self.cfg["visualization"]["cmap"],
            alpha=0.7,
        )

        if self.task_type == "classification":
            plt.legend(*scatter.legend_elements(), title="Classes")
    else:
        plt.scatter(features_2d[:, 0], features_2d[:, 1], alpha=0.7)

    plt.title("Feature Space Visualization")
    plt.xlabel("Dimension 1")
    plt.ylabel("Dimension 2")

    # Save the figure. The report generators look for feature_space.png, so
    # the PNG is written in addition to the PDF this method always produced.
    if self.cfg["report"].get("include_visuals", True):
        os.makedirs("reports", exist_ok=True)
        for file_name in ("feature_space.pdf", "feature_space.png"):
            plt.savefig(
                os.path.join("reports", file_name), dpi=150, bbox_inches="tight"
            )

    plt.show()
1838
+
1839
def generate_report(self, output_path: str = "reports/training_report.html"):
    """
    Generate the training report and write it to ``output_path``.

    The format is chosen from the file name: a path ending in ``.html``
    produces the HTML report, anything else produces the Markdown report.

    Args:
        output_path: Destination file; missing parent directories are created.

    Returns:
        The report text that was written.
    """
    if output_path.endswith(".html"):
        report = self._generate_html_report()
    else:
        report = self._generate_markdown_report()

    # The default path lives in "reports/", which may not exist yet.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Explicit UTF-8 so non-ASCII report text is written the same way on
    # every platform.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(report)

    if self.verbose:
        print(f"Report generated at {output_path}")

    return report
1856
+
1857
def generate_prediction_report(
    self, image_paths: List[str], predictions: List, output_dir: str = "reports"
):
    """
    Generate an HTML report listing each image next to its prediction.

    The report is written to ``<output_dir>/prediction_report.html``.

    Args:
        image_paths: Path of every predicted image; used as the ``<img>``
            source and, reduced to its file name, as the displayed name.
        predictions: Prediction for each image, in the same order. Pairs are
            formed with ``zip``, so surplus items on either side are ignored.
        output_dir: Directory that receives the report; created if missing.

    Returns:
        The generated HTML text.
    """
    # Local import keeps this method free of new module-level dependencies.
    from html import escape

    os.makedirs(output_dir, exist_ok=True)
    report_path = os.path.join(output_dir, "prediction_report.html")

    # Report skeleton; literal CSS braces are doubled for str.format.
    report = """
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="utf-8">
        <title>Image Prediction Report</title>
        <style>
            body {{ font-family: Arial, sans-serif; margin: 20px; }}
            h1, h2 {{ color: #2c3e50; }}
            table {{ border-collapse: collapse; width: 100%; margin-bottom: 20px; }}
            th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
            th {{ background-color: #f2f2f2; }}
            .summary {{ background-color: #f9f9f9; padding: 15px; border-radius: 5px; }}
            .section {{ margin-bottom: 30px; }}
            img {{ max-width: 100%; height: auto; }}
        </style>
    </head>
    <body>
        <h1>Image Prediction Report</h1>
        <p>Generated on: {date}</p>

        <div class="summary">
            <h2>Summary</h2>
            <p><strong>Total Predictions:</strong> {count}</p>
            <p><strong>Model Used:</strong> {model_name}</p>
            <p><strong>Task Type:</strong> {task_type}</p>
        </div>

        <h2>Prediction Details</h2>
        <table>
            <tr>
                <th>Image</th>
                <th>Filename</th>
                <th>Prediction</th>
            </tr>
            {rows}
        </table>
    </body>
    </html>
    """

    # Build the table rows. Values are converted with str() and HTML-escaped,
    # so an unusual path or prediction can neither abort the table nor inject
    # markup into the page.
    row_parts = []
    for path, pred in zip(image_paths, predictions):
        path_text = str(path)
        file_name = escape(os.path.basename(path_text))
        row_parts.append(
            f'<tr><td><img src="{escape(path_text)}" alt="{file_name}"></td>'
            f"<td>{file_name}</td><td>{escape(str(pred))}</td></tr>"
        )

    # Fill in the template. The count is the number of rows actually shown.
    report = report.format(
        date=time.strftime("%Y-%m-%d %H:%M:%S"),
        count=len(row_parts),
        model_name=(
            self.best_model.__class__.__name__
            if self.best_model is not None
            else "Unknown"
        ),
        task_type=self.task_type.capitalize(),
        rows="".join(row_parts),
    )

    # Write as UTF-8 to match the charset declared in the document head.
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(report)

    if self.verbose:
        print(f"Prediction report generated at {report_path}")

    return report
1934
+
1935
def _generate_html_report(self) -> str:
    """
    Build the training report as an HTML document.

    The report contains a summary (task type, number of images, feature
    extraction method, best model and its accuracy or R2), a per-model
    performance table, the configuration as JSON and, when
    ``cfg["report"]["include_visuals"]`` is true, any figures already saved
    under ``reports/``.

    Returns:
        The HTML text; nothing is written to disk here.
    """
    # Local import keeps this method free of new module-level dependencies.
    from html import escape

    # Report skeleton; literal CSS braces are doubled for str.format, while
    # the placeholders (including the date) use single braces.
    report = """
    <!DOCTYPE html>
    <html>
    <head>
        <title>Model Training Report</title>
        <style>
            body {{font-family: Arial; margin: 20px; }}
            h1, h2 {{ color: #2c3e50; }}
            table {{ border-collapse: collapse; width: 100%; margin-bottom: 20px; }}
            th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
            th {{ background-color: #f2f2f2; }}
            .summary {{ background-color: #f9f9f9; padding: 15px; border-radius: 5px; }}
            .section {{ margin-bottom: 30px; }}
            img {{ max-width: 100%; height: auto; }}
        </style>
    </head>
    <body>
        <h1>Image Machine Learning Training Report</h1>
        <p>Generated on: {date}</p>

        <div class="summary">
            <h2>Summary</h2>
            <p><strong>Task Type:</strong> {task_type}</p>
            <p><strong>Number of Images:</strong> {num_images}</p>
            <p><strong>Feature Extraction Method:</strong> {feat_method}</p>
            <p><strong>Best Model:</strong> {best_model} (Accuracy/R2: {best_metric:.4f})</p>
        </div>

        <div class="section">
            <h2>Model Performance Comparison</h2>
            {model_table}
        </div>

        <div class="section">
            <h2>Configuration</h2>
            <pre>{cfg}</pre>
        </div>

        <div class="section">
            <h2>Visualizations</h2>
            {visualizations}
        </div>
    </body>
    </html>
    """

    is_classification = self.task_type == "classification"

    # Model performance table
    if self.model_performance:
        # The same two columns carry accuracy/F1 or R2/MSE depending on task.
        first_key, second_key = (
            ("accuracy", "f1_weighted") if is_classification else ("r2", "mse")
        )
        row_parts = []
        for name, perf in self.model_performance.items():
            metrics = perf["metrics"]
            row_parts.append(
                f"<tr><td>{escape(str(name))}</td>"
                f"<td>{metrics[first_key]:.4f}</td>"
                f"<td>{metrics[second_key]:.4f}</td>"
                f"<td>{perf['train_time']:.2f}</td></tr>"
            )
        model_table = (
            """
        <table>
            <tr>
                <th>Model</th>
                <th>Accuracy/R2</th>
                <th>F1/MSE</th>
                <th>Training Time (s)</th>
            </tr>
            """
            + "".join(row_parts)
            + """
        </table>
        """
        )
    else:
        model_table = "<p>No models trained yet.</p>"

    # Visualization section: only figures that already exist on disk are used.
    visualizations = ""
    if self.cfg["report"].get("include_visuals", True):
        # Feature space
        fs_path = os.path.join("reports", "feature_space.png")
        if os.path.exists(fs_path):
            visualizations += (
                f'<h3>Feature Space</h3><img src="{fs_path}" alt="Feature Space">'
            )

        # Confusion matrix
        if is_classification:
            cm_path = os.path.join("reports", "confusion_matrix.png")
            if os.path.exists(cm_path):
                visualizations += f'<h3>Confusion Matrix</h3><img src="{cm_path}" alt="Confusion Matrix">'

    # Without a trained model there is nothing to evaluate; report 0 rather
    # than handing None to evaluate_model.
    if self.best_model is not None:
        best_model_name = self.best_model.__class__.__name__
        best_metric = self.evaluate_model(self.best_model).get(
            "accuracy" if is_classification else "r2", 0
        )
    else:
        best_model_name = "N/A"
        best_metric = 0.0

    # default=str keeps the dump from failing on config values json cannot
    # serialize; quote=False leaves the JSON quotes readable inside <pre>.
    cfg_text = escape(json.dumps(self.cfg, indent=2, default=str), quote=False)

    # Fill in the template
    report = report.format(
        date=time.strftime("%Y-%m-%d %H:%M:%S"),
        task_type=self.task_type.capitalize(),
        num_images=len(self.images),
        feat_method=self.cfg["feature_extraction"]["method"],
        best_model=best_model_name,
        best_metric=best_metric,
        model_table=model_table,
        cfg=cfg_text,
        visualizations=visualizations,
    )

    return report
2044
+
2045
def _generate_markdown_report(self) -> str:
    """
    Build the training report as Markdown text.

    The report contains a summary (task type, number of images, number of
    classes for classification, feature extraction method, best model and its
    accuracy or R2), a per-model performance table, the configuration as a
    JSON code block and, when ``cfg["report"]["include_visuals"]`` is true,
    links to any figures already saved under ``reports/``.

    Returns:
        The Markdown text; nothing is written to disk here.
    """
    is_classification = self.task_type == "classification"

    # Without a trained model there is nothing to evaluate; report 0 rather
    # than handing None to evaluate_model.
    if self.best_model is not None:
        best_model_name = self.best_model.__class__.__name__
        best_metric = self.evaluate_model(self.best_model).get(
            "accuracy" if is_classification else "r2", 0
        )
    else:
        best_model_name = "N/A"
        best_metric = 0.0

    # Collect the pieces and join once instead of repeated string +=.
    parts = [
        "# Image Machine Learning Training Report\n\n",
        f"**Generated on**: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n",
        # Summary
        "## Summary\n",
        f"- **Task Type**: {self.task_type.capitalize()}\n",
        f"- **Number of Images**: {len(self.images)}\n",
    ]

    # class_names may be unset; omit the line instead of failing.
    if is_classification and self.class_names is not None:
        parts.append(f"- **Number of Classes**: {len(self.class_names)}\n")

    parts.append(
        f"- **Feature Extraction Method**: {self.cfg['feature_extraction']['method']}\n"
    )
    parts.append(
        f"- **Best Model**: {best_model_name} (Accuracy/R2: {best_metric:.4f})\n\n"
    )

    # Model performance comparison
    if self.model_performance:
        # The same two columns carry accuracy/F1 or R2/MSE depending on task.
        first_key, second_key = (
            ("accuracy", "f1_weighted") if is_classification else ("r2", "mse")
        )
        parts.append("## Model Performance Comparison\n\n")
        parts.append("| Model | Accuracy/R2 | F1/MSE | Training Time (s) |\n")
        parts.append("|-------|-------------|--------|-------------------|\n")
        for name, perf in self.model_performance.items():
            metrics = perf["metrics"]
            parts.append(
                f"| {name} | {metrics[first_key]:.4f} | {metrics[second_key]:.4f} "
                f"| {perf['train_time']:.2f} |\n"
            )
    else:
        parts.append("No models trained yet.\n\n")

    # Configuration. default=str keeps the dump from failing on config values
    # json cannot serialize.
    parts.append("\n## Configuration\n```json\n")
    parts.append(json.dumps(self.cfg, indent=2, default=str))
    parts.append("\n```\n\n")

    # Visualizations: only figures that already exist on disk are linked.
    if self.cfg["report"].get("include_visuals", True):
        parts.append("## Visualizations\n")

        # Feature space
        fs_path = os.path.join("reports", "feature_space.png")
        if os.path.exists(fs_path):
            parts.append(f"### Feature Space\n![]({fs_path})\n\n")

        # Confusion matrix
        if is_classification:
            cm_path = os.path.join("reports", "confusion_matrix.png")
            if os.path.exists(cm_path):
                parts.append(f"### Confusion Matrix\n![]({cm_path})\n\n")

    return "".join(parts)