py2ls 0.2.4.15__py3-none-any.whl → 0.2.4.16__py3-none-any.whl
- py2ls/.git/index +0 -0
- py2ls/ips.py +722 -12
- py2ls/ml2ls copy.py +2906 -0
- py2ls/ml2ls.py +345 -12
- py2ls/plot.py +409 -24
- {py2ls-0.2.4.15.dist-info → py2ls-0.2.4.16.dist-info}/METADATA +1 -1
- {py2ls-0.2.4.15.dist-info → py2ls-0.2.4.16.dist-info}/RECORD +8 -7
- {py2ls-0.2.4.15.dist-info → py2ls-0.2.4.16.dist-info}/WHEEL +0 -0
py2ls/ml2ls.py
CHANGED
@@ -506,7 +506,7 @@ def get_models(
         "Support Vector Machine(svm)",
         "naive bayes",
         "Linear Discriminant Analysis (lda)",
-        "
+        "AdaBoost",
         "DecisionTree",
         "KNeighbors",
         "Bagging",

@@ -585,7 +585,7 @@ def get_features(
         "Support Vector Machine(svm)",
         "naive bayes",
         "Linear Discriminant Analysis (lda)",
-        "
+        "AdaBoost",
         "DecisionTree",
         "KNeighbors",
         "Bagging",

@@ -699,9 +699,11 @@ def get_features(
         "Support Vector Machine(svm)",
         "Naive Bayes",
         "Linear Discriminant Analysis (lda)",
-        "
+        "AdaBoost",
     ]
     cls = [ips.strcmp(i, cls_)[0] for i in cls]
+
+    feature_importances = {}

     # Lasso Feature Selection
     lasso_importances = (
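
Editor's note: the entry added here has to use the canonical spelling because every requested name is fuzzy-matched against this list via ips.strcmp(i, cls_)[0]. A minimal sketch of that idea, using difflib as a stand-in for py2ls's matcher (the exact behaviour and return format of ips.strcmp are assumptions here):

    # Sketch only: difflib stands in for ips.strcmp; return format is an assumption.
    from difflib import get_close_matches

    canonical = ["Support Vector Machine(svm)", "Naive Bayes",
                 "Linear Discriminant Analysis (lda)", "AdaBoost",
                 "DecisionTree", "KNeighbors", "Bagging"]

    def match_name(name, candidates=canonical):
        lowered = [c.lower() for c in candidates]
        hit = get_close_matches(name.lower(), lowered, n=1, cutoff=0.1)
        return candidates[lowered.index(hit[0])] if hit else None

    print(match_name("adaboost"))       # -> "AdaBoost"
    print(match_name("decision tree"))  # -> "DecisionTree"
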
@@ -712,6 +714,7 @@ def get_features(
     lasso_selected_features = (
         lasso_importances.head(n_features)["feature"].values if "lasso" in cls else []
     )
+    feature_importances['lasso']=lasso_importances.head(n_features)
     # Ridge
     ridge_importances = (
         features_ridge(x_train, y_train, ridge_params)
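
Editor's note: this hunk and the ones that follow all apply the same pattern, storing each selector's top-n ranking in the shared feature_importances dict. A self-contained sketch of that pattern with synthetic data (this is not the package's features_lasso helper):

    # Sketch of the top-n accumulation pattern; data and feature names are made up.
    import numpy as np
    import pandas as pd
    from sklearn.linear_model import LassoCV

    rng = np.random.default_rng(0)
    X = pd.DataFrame(rng.normal(size=(100, 5)), columns=[f"gene_{i}" for i in range(5)])
    y = X["gene_0"] * 2 + rng.normal(size=100)

    lasso = LassoCV(cv=5).fit(X, y)
    ranking = (pd.DataFrame({"feature": X.columns, "importance": np.abs(lasso.coef_)})
                 .sort_values("importance", ascending=False))

    feature_importances = {}
    feature_importances["lasso"] = ranking.head(3)  # keep only the top-n rows, as above
    print(feature_importances["lasso"])
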
@@ -721,6 +724,7 @@ def get_features(
     selected_ridge_features = (
         ridge_importances.head(n_features)["feature"].values if "ridge" in cls else []
     )
+    feature_importances['ridge']=ridge_importances.head(n_features)
     # Elastic Net
     enet_importances = (
         features_enet(x_train, y_train, enet_params)

@@ -730,6 +734,7 @@ def get_features(
     selected_enet_features = (
         enet_importances.head(n_features)["feature"].values if "Enet" in cls else []
     )
+    feature_importances['Enet']=enet_importances.head(n_features)
     # Random Forest Feature Importance
     rf_importances = (
         features_rf(x_train, y_train, rf_params)

@@ -741,6 +746,7 @@ def get_features(
         if "Random Forest" in cls
         else []
     )
+    feature_importances['Random Forest']=rf_importances.head(n_features)
     # Gradient Boosting Feature Importance
     gb_importances = (
         features_gradient_boosting(x_train, y_train, gb_params)

@@ -752,6 +758,7 @@ def get_features(
         if "Gradient Boosting" in cls
         else []
     )
+    feature_importances['Gradient Boosting']=gb_importances.head(n_features)
     # xgb
     xgb_importances = (
         features_xgb(x_train, y_train, xgb_params) if "xgb" in cls else pd.DataFrame()

@@ -759,6 +766,7 @@ def get_features(
     top_xgb_features = (
         xgb_importances.head(n_features)["feature"].values if "xgb" in cls else []
     )
+    feature_importances['xgb']=xgb_importances.head(n_features)

     # SVM with RFE
     selected_svm_features = (

@@ -773,6 +781,7 @@ def get_features(
     selected_lda_features = (
         lda_importances.head(n_features)["feature"].values if "lda" in cls else []
     )
+    feature_importances['lda']=lda_importances.head(n_features)
     # AdaBoost Feature Importance
     adaboost_importances = (
         features_adaboost(x_train, y_train, adaboost_params)

@@ -784,6 +793,7 @@ def get_features(
         if "AdaBoost" in cls
         else []
     )
+    feature_importances['AdaBoost']=adaboost_importances.head(n_features)
     # Decision Tree Feature Importance
     dt_importances = (
         features_decision_tree(x_train, y_train, dt_params)

@@ -794,7 +804,8 @@ def get_features(
         dt_importances.head(n_features)["feature"].values
         if "Decision Tree" in cls
         else []
-    )
+    )
+    feature_importances['Decision Tree']=dt_importances.head(n_features)
     # Bagging Feature Importance
     bagging_importances = (
         features_bagging(x_train, y_train, bagging_params)

@@ -806,6 +817,7 @@ def get_features(
         if "Bagging" in cls
         else []
     )
+    feature_importances['Bagging']=bagging_importances.head(n_features)
     # KNN Feature Importance via Permutation
     knn_importances = (
         features_knn(x_train, y_train, knn_params) if "KNN" in cls else pd.DataFrame()

@@ -813,6 +825,7 @@ def get_features(
     top_knn_features = (
         knn_importances.head(n_features)["feature"].values if "KNN" in cls else []
     )
+    feature_importances['KNN']=knn_importances.head(n_features)

     #! Find common features
     common_features = ips.shared(

@@ -915,6 +928,7 @@ def get_features(
         "cv_train_scores": cv_train_results_df,
         "cv_test_scores": rank_models(cv_test_results_df, plot_=plot_),
         "common_features": list(common_features),
+        "feature_importances":feature_importances
     }
     if all([plot_, dir_save]):
         from datetime import datetime

@@ -927,6 +941,7 @@ def get_features(
         "cv_train_scores": pd.DataFrame(),
         "cv_test_scores": pd.DataFrame(),
         "common_features": [],
+        "feature_importances":{}
     }
     print(f"Warning: 没有找到共同的genes, when n_shared={n_shared}")
     return results
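
Editor's note: with both branches updated, callers can read the per-model rankings from the returned dict. A hedged sketch of the assumed shape (model name mapped to a top-n DataFrame with a "feature" column); the contents below are illustrative only, not real output:

    # Stand-in for the dict returned by get_features.
    import pandas as pd

    results = {
        "feature_importances": {
            "lasso": pd.DataFrame({"feature": ["gene_2", "gene_0"], "importance": [0.8, 0.5]}),
            "Random Forest": pd.DataFrame({"feature": ["gene_0", "gene_2"], "importance": [0.4, 0.3]}),
        }
    }
    for model_name, top_n in results["feature_importances"].items():
        print(model_name, list(top_n["feature"]))
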
@@ -2227,12 +2242,16 @@ def predict(
     # else:
     #     y_train = ips.df_encoder(pd.DataFrame(y_train), method="dummy").values.ravel()
     y_train=pd.DataFrame(y_train)
-
-
-
-
-
+    if y_train.select_dtypes(include=np.number).empty:
+        y_train_=ips.df_encoder(y_train, method="dummy",drop=None)
+        is_binary = False if y_train_.shape[1] >2 else True
+    else:
+        y_train_=ips.flatten(y_train.values)
+        is_binary = False if len(y_train_)>2 else True

+    if is_binary:
+        y_train = ips.df_encoder(pd.DataFrame(y_train), method="label")
+    print('is_binary:',is_binary)
     if x_true is None:
         x_train, x_true, y_train, y_true = train_test_split(
             x_train,
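
Editor's note: the new block infers binary vs. multi-class handling from the target column alone: non-numeric targets are dummy-encoded and the column count is checked, numeric targets are flattened and counted. A standalone analogue of that decision rule in plain pandas (py2ls's df_encoder/flatten helpers are replaced with their presumed equivalents, so treat this as an approximation of the intent rather than the package's code):

    # Approximate sketch of the is_binary decision added above.
    import numpy as np
    import pandas as pd

    def infer_is_binary(y) -> bool:
        y = pd.DataFrame(y)
        if y.select_dtypes(include=np.number).empty:
            n_classes = pd.get_dummies(y.iloc[:, 0]).shape[1]  # non-numeric labels
        else:
            n_classes = len(np.unique(y.values))               # numeric labels
        return n_classes <= 2

    print(infer_is_binary(["cat", "dog", "cat"]))  # True
    print(infer_is_binary([0, 1, 2, 1, 0]))        # False
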
@@ -2893,7 +2912,11 @@ def predict(
     x_true = x_true.reindex(columns=x_train.columns, fill_value=0)
     y_pred = best_clf.predict(x_true)
     if hasattr(best_clf, "predict_proba"):
-        y_pred_proba = best_clf.predict_proba(x_true)
+        y_pred_proba = best_clf.predict_proba(x_true)
+        print("Shape of predicted probabilities:", y_pred_proba.shape)
+        if y_pred_proba.shape[1] == 1:
+            y_pred_proba = np.hstack([1 - y_pred_proba, y_pred_proba])  # Add missing class probabilities
+        y_pred_proba = y_pred_proba[:, 1]
     elif hasattr(best_clf, "decision_function"):
         # If predict_proba is not available, use decision_function (e.g., for SVM)
         y_pred_proba = best_clf.decision_function(x_true)
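
Editor's note: the added guard handles estimators whose predict_proba returns a single column, padding it so that column 1 always holds the positive-class probability. A minimal reproduction with a synthetic array:

    # Sketch: pad a one-column probability matrix before taking column 1.
    import numpy as np

    y_pred_proba = np.array([[0.9], [0.2], [0.7]])  # shape (n_samples, 1)
    if y_pred_proba.shape[1] == 1:
        y_pred_proba = np.hstack([1 - y_pred_proba, y_pred_proba])
    print(y_pred_proba[:, 1])  # [0.9 0.2 0.7]
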
@@ -3078,7 +3101,7 @@ def predict(
         ips.figsave(dir_save + f"scores_sorted_heatmap{now_}.pdf")

     df_scores=df_scores.select_dtypes(include=np.number)
-
+
     if df_scores.shape[0] > 1:  # draw cluster
         plot.heatmap(df_scores, kind="direct", cluster=True)
         plot.figsets(xangle=30)

@@ -3169,7 +3192,14 @@ def cal_metrics(

     # Confusion matrix to calculate specificity
     if is_binary:
-
+        cm = confusion_matrix(y_true, y_pred)
+        if cm.size == 4:
+            tn, fp, fn, tp = cm.ravel()
+        else:
+            # Handle single-class predictions
+            tn, fp, fn, tp = 0, 0, 0, 0
+            print("Warning: Only one class found in y_pred or y_true.")
+
         # Specificity calculation
         validation_scores["specificity"] = (
             tn / (tn + fp) if (tn + fp) > 0 else 0
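
Editor's note: the guard makes the specificity calculation safe when only one class appears in the labels or predictions. A small self-contained sketch of the full computation:

    # Sketch: specificity = TN / (TN + FP), with the single-class guard from the diff.
    from sklearn.metrics import confusion_matrix

    def specificity(y_true, y_pred):
        cm = confusion_matrix(y_true, y_pred)
        if cm.size == 4:
            tn, fp, fn, tp = cm.ravel()
        else:
            tn = fp = fn = tp = 0  # only one class present
        return tn / (tn + fp) if (tn + fp) > 0 else 0

    print(specificity([0, 0, 1, 1], [0, 1, 1, 1]))  # 0.5
    print(specificity([1, 1, 1], [1, 1, 1]))        # 0 (degenerate single-class case)
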
@@ -3217,3 +3247,306 @@ def cal_metrics(
     )

     return validation_scores
+
+def plot_trees(
+    X, y, cls, max_trees=500, test_size=0.2, random_state=42, early_stopping_rounds=None
+):
+    """
+    # # Example usage:
+    # X = np.random.rand(100, 10)  # Example data with 100 samples and 10 features
+    # y = np.random.randint(0, 2, 100)  # Example binary target
+    # # Using the function with different classifiers
+    # # Random Forest example
+    # plot_trees(X, y, RandomForestClassifier(), max_trees=100)
+    # # Gradient Boosting with early stopping example
+    # plot_trees(X, y, GradientBoostingClassifier(), max_trees=100, early_stopping_rounds=10)
+    # # Extra Trees example
+    # plot_trees(X, y, ExtraTreesClassifier(), max_trees=100)
+    Master function to plot error rates (OOB, training, and testing) for different tree-based ensemble classifiers.
+
+    Parameters:
+    - X (array-like): Feature matrix.
+    - y (array-like): Target labels.
+    - cls (object): Tree-based ensemble classifier instance (e.g., RandomForestClassifier()).
+    - max_trees (int): Maximum number of trees to evaluate. Default is 500.
+    - test_size (float): Proportion of data to use as test set for testing error. Default is 0.2.
+    - random_state (int): Random state for reproducibility. Default is 42.
+    - early_stopping_rounds (int): For boosting models only, stops training if validation error doesn't improve after specified rounds.
+
+    Returns:
+    - None
+    """
+    from sklearn.model_selection import train_test_split
+    from sklearn.metrics import accuracy_score
+    from sklearn.ensemble import (
+        RandomForestClassifier,
+        BaggingClassifier,
+        ExtraTreesClassifier,
+    )
+    from sklearn.ensemble import AdaBoostClassifier, GradientBoostingClassifier
+    # Split data for training and testing error calculation
+    x_train, x_test, y_train, y_test = train_test_split(
+        X, y, test_size=test_size, random_state=random_state
+    )
+
+    # Initialize lists to store error rates
+    oob_error_rate = []
+    train_error_rate = []
+    test_error_rate = []
+    validation_error = None
+
+    # Configure classifier based on type
+    oob_enabled = False  # Default to no OOB error unless explicitly set
+
+    if isinstance(cls, (RandomForestClassifier, ExtraTreesClassifier)):
+        # Enable OOB if cls supports it and is using bootstrapping
+        cls.set_params(warm_start=True, n_estimators=1)
+        if hasattr(cls, "oob_score"):
+            cls.set_params(bootstrap=True, oob_score=True)
+            oob_enabled = True
+    elif isinstance(cls, BaggingClassifier):
+        cls.set_params(warm_start=True, bootstrap=True, oob_score=True, n_estimators=1)
+        oob_enabled = True
+    elif isinstance(cls, (AdaBoostClassifier, GradientBoostingClassifier)):
+        cls.set_params(n_estimators=1)
+        oob_enabled = False
+        if early_stopping_rounds:
+            validation_error = []
+
+    # Train and evaluate with an increasing number of trees
+    for i in range(1, max_trees + 1):
+        cls.set_params(n_estimators=i)
+        cls.fit(x_train, y_train)
+
+        # Calculate OOB error (for models that support it)
+        if oob_enabled and hasattr(cls, "oob_score_") and cls.oob_score:
+            oob_error = 1 - cls.oob_score_
+            oob_error_rate.append(oob_error)
+
+        # Calculate training error
+        train_error = 1 - accuracy_score(y_train, cls.predict(x_train))
+        train_error_rate.append(train_error)
+
+        # Calculate testing error
+        test_error = 1 - accuracy_score(y_test, cls.predict(x_test))
+        test_error_rate.append(test_error)
+
+        # For boosting models, use validation error with early stopping
+        if early_stopping_rounds and isinstance(
+            cls, (AdaBoostClassifier, GradientBoostingClassifier)
+        ):
+            val_error = 1 - accuracy_score(y_test, cls.predict(x_test))
+            validation_error.append(val_error)
+            if len(validation_error) > early_stopping_rounds:
+                # Stop if validation error has not improved in early_stopping_rounds
+                if validation_error[-early_stopping_rounds:] == sorted(
+                    validation_error[-early_stopping_rounds:]
+                ):
+                    print(f"Early stopping at tree {i} due to lack of improvement in validation error.")
+                    break
+
+    # Plot results
+    plt.figure(figsize=(10, 6))
+    if oob_error_rate:
+        plt.plot(
+            range(1, len(oob_error_rate) + 1),
+            oob_error_rate,
+            color="black",
+            label="OOB Error Rate",
+            linewidth=2,
+        )
+    if train_error_rate:
+        plt.plot(
+            range(1, len(train_error_rate) + 1),
+            train_error_rate,
+            linestyle="dotted",
+            color="green",
+            label="Training Error Rate",
+        )
+    if test_error_rate:
+        plt.plot(
+            range(1, len(test_error_rate) + 1),
+            test_error_rate,
+            linestyle="dashed",
+            color="red",
+            label="Testing Error Rate",
+        )
+    if validation_error:
+        plt.plot(
+            range(1, len(validation_error) + 1),
+            validation_error,
+            linestyle="solid",
+            color="blue",
+            label="Validation Error (Boosting)",
+        )
+
+    # Customize plot
+    plt.xlabel("Number of Trees")
+    plt.ylabel("Error Rate")
+    plt.title(f"Error Rate Analysis for {cls.__class__.__name__}")
+    plt.legend(loc="upper right")
+    plt.grid(True)
+    plt.show()
+
+def img_datasets_preprocessing(
+    data: pd.DataFrame,
+    x_col: str,
+    y_col: str=None,
+    target_size: tuple = (224, 224),
+    batch_size: int = 128,
+    class_mode: str = "raw",
+    shuffle: bool = False,
+    augment: bool = False,
+    scaler: str = 'normalize',  # 'normalize', 'standardize', 'clahe', 'raw'
+    grayscale: bool = False,
+    encoder: str = "label",  # Options: 'label', 'onehot', 'binary'
+    label_encoder=None,
+    kws_augmentation: dict = None,
+    verbose: bool = True,
+    drop_missing: bool = True,
+    output="df",  # "iterator":data_iterator,'df':return DataFrame
+):
+    """
+    Enhanced preprocessing function for loading and preparing image data from a DataFrame.
+
+    Parameters:
+    - df (pd.DataFrame): Input DataFrame with image paths and labels.
+    - x_col (str): Column in `df` containing image file paths.
+    - y_col (str): Column in `df` containing image labels.
+    - target_size (tuple): Desired image size in (height, width).
+    - batch_size (int): Number of images per batch.
+    - class_mode (str): Mode of label ('raw', 'categorical', 'binary').
+    - shuffle (bool): Shuffle the images in the DataFrame.
+    - augment (bool): Apply data augmentation.
+    - scaler (str): 'normalize', # 'normalize', 'standardize', 'clahe', 'raw'
+    - grayscale (bool): Convert images to grayscale.
+    - normalize (bool): Normalize image data to [0, 1] range.
+    - encoder (str): Label encoder method ('label', 'onehot', 'binary').
+    - label_encoder: Optional pre-defined label encoder.
+    - kws_augmentation (dict): Parameters for data augmentation.
+    - verbose (bool): Print status messages.
+    - drop_missing (bool): Drop rows with missing or invalid image paths.
+
+    Returns:
+    - pd.DataFrame: DataFrame with flattened image pixels and 'Label' column.
+    """
+    from tensorflow.keras.preprocessing.image import ImageDataGenerator
+    from tensorflow.keras.utils import to_categorical
+    from sklearn.preprocessing import LabelEncoder
+    from PIL import Image
+    import os
+
+    # Validate input DataFrame for required columns
+    if y_col:
+        assert (
+            x_col in data.columns and y_col in data.columns
+        ), "Missing required columns in DataFrame."
+    if y_col is None:
+        class_mode=None
+    # 输出格式
+    output = ips.strcmp(output,[
+        "generator","tf","iterator","transform","transformer","dataframe",
+        "df","pd","pandas"])[0]
+
+    # Handle missing file paths
+    if drop_missing:
+        data = data[
+            data[x_col].apply(lambda path: os.path.exists(path) and os.path.isfile(path))
+        ]
+
+    # Encoding labels if necessary
+    if encoder and y_col is not None:
+        if encoder == "binary":
+            data[y_col] = (data[y_col] == data[y_col].unique()[0]).astype(int)
+        elif encoder == "onehot":
+            if not label_encoder:
+                label_encoder = LabelEncoder()
+            data[y_col] = label_encoder.fit_transform(data[y_col])
+            data[y_col] = to_categorical(data[y_col])
+        elif encoder == "label":
+            if not label_encoder:
+                label_encoder = LabelEncoder()
+            data[y_col] = label_encoder.fit_transform(data[y_col])
+
+    # Set up data augmentation
+    if augment:
+        aug_params = {
+            "rotation_range": 20,
+            "width_shift_range": 0.2,
+            "height_shift_range": 0.2,
+            "shear_range": 0.2,
+            "zoom_range": 0.2,
+            "horizontal_flip": True,
+            "fill_mode": "nearest",
+        }
+        if kws_augmentation:
+            aug_params.update(kws_augmentation)
+        dat = ImageDataGenerator(rescale=scaler, **aug_params)
+        dat = ImageDataGenerator(
+            rescale=1.0 / 255 if scaler == 'normalize' else None, **aug_params)
+
+    else:
+        dat = ImageDataGenerator(
+            rescale=1.0 / 255 if scaler == 'normalize' else None)
+
+    # Create DataFrameIterator
+    data_iterator = dat.flow_from_dataframe(
+        dataframe=data,
+        x_col=x_col,
+        y_col=y_col,
+        target_size=target_size,
+        color_mode="grayscale" if grayscale else "rgb",
+        batch_size=batch_size,
+        class_mode=class_mode,
+        shuffle=shuffle,
+    )
+    print(f"target_size:{target_size}")
+    if output.lower() in ["generator", "tf", "iterator", "transform", "transformer"]:
+        return data_iterator
+    elif output.lower() in ["dataframe", "df", "pd", "pandas"]:
+        # Initialize list to collect processed data
+        data_list = []
+        total_batches = data_iterator.n // batch_size
+
+        # Load, resize, and process images in batches
+        for i, (batch_images, batch_labels) in enumerate(data_iterator):
+            for img, label in zip(batch_images, batch_labels):
+                if scaler == ['normalize','raw']:
+                    # Already rescaled by 1.0/255 in ImageDataGenerator
+                    pass
+                elif scaler == 'standardize':
+                    # Standardize by subtracting mean and dividing by std
+                    img = (img - np.mean(img)) / np.std(img)
+                elif scaler == 'clahe':
+                    # Apply CLAHE to the image
+                    img = apply_clahe(img)
+                flat_img = img.flatten()
+                data_list.append(np.append(flat_img, label))
+
+            # Stop when all images have been processed
+            if i >= total_batches:
+                break
+
+        # Define column names for flattened image data
+        pixel_count = target_size[0] * target_size[1] * (1 if grayscale else 3)
+        column_names = [f"pixel_{i}" for i in range(pixel_count)] + ["Label"]
+
+        # Create DataFrame from flattened data
+        df_img = pd.DataFrame(data_list, columns=column_names)
+
+        if verbose:
+            print("Processed images:", len(df_img))
+            print("Final DataFrame shape:", df_img.shape)
+            display(df_img.head())
+
+        return df_img
+# Function to apply CLAHE (Contrast Limited Adaptive Histogram Equalization)
+def apply_clahe(img):
+    import cv2
+    lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)  # Convert to LAB color space
+    l, a, b = cv2.split(lab)  # Split into channels
+    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
+    cl = clahe.apply(l)  # Apply CLAHE to the L channel
+    limg = cv2.merge((cl, a, b))  # Merge back the channels
+    img_clahe = cv2.cvtColor(limg, cv2.COLOR_LAB2RGB)  # Convert back to RGB
+    return img_clahe