ColorCorrectionPipeline 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Configs/configs.py ADDED
@@ -0,0 +1,40 @@
1
+
2
+ import numpy as np
3
+ from typing import Optional, Any
4
+
5
+
6
# Public API of Configs/configs.py: only the Config container is exported.
__all__ = ['Config']
7
+
8
class Config:
    """
    Simple configuration container for the pipeline steps.

    Every constructor argument is stored as an attribute of the same name.
    Reading an attribute that was never set raises ``AttributeError``; use
    ``get_attr(name, default)`` to read it with a fallback (None by default).

    Args:
        do_ffc: enable the flat-field-correction (FFC) step.
        do_gc: enable the GC step.
        do_wb: enable the WB step.
        do_cc: enable the CC step.
        save: whether results should be saved.
        check_saturation: flag stored for the pipeline (its use is not
            visible in this module).
        save_path: destination used when ``save`` is True.
        REF_ILLUMINANT: optional reference illuminant array.
        FFC_kwargs, GC_kwargs, WB_kwargs, CC_kwargs: per-step option
            objects, stored as given.
    """

    def __init__(
        self,
        do_ffc: bool = True,
        do_gc: bool = True,
        do_wb: bool = True,
        do_cc: bool = True,
        save: bool = False,
        check_saturation: bool = True,
        save_path: Optional[str] = None,
        REF_ILLUMINANT: Optional[np.ndarray] = None,
        FFC_kwargs: Optional[Any] = None,
        GC_kwargs: Optional[Any] = None,
        WB_kwargs: Optional[Any] = None,
        CC_kwargs: Optional[Any] = None,
    ) -> None:
        self.do_ffc = do_ffc
        self.do_gc = do_gc
        self.do_wb = do_wb
        self.do_cc = do_cc
        self.save = save
        self.save_path = save_path
        self.REF_ILLUMINANT = REF_ILLUMINANT
        self.FFC_kwargs = FFC_kwargs
        self.GC_kwargs = GC_kwargs
        self.WB_kwargs = WB_kwargs
        self.CC_kwargs = CC_kwargs
        self.check_saturation = check_saturation

    def get_attr(self, name: str, default: Any = None) -> Any:
        """Return attribute ``name``, or ``default`` if it has not been set."""
        return getattr(self, name, default)

    def __repr__(self) -> str:
        fields = ", ".join(f"{key}={value!r}" for key, value in vars(self).items())
        return f"{type(self).__name__}({fields})"
40
+
FFC/FF_correction.py ADDED
@@ -0,0 +1,508 @@
1
+ import cv2
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import plotly.graph_objects as go
5
+ import sys
6
+
7
+ from sklearn.linear_model import LinearRegression
8
+ from sklearn.svm import SVR
9
+ from sklearn.neural_network import MLPRegressor
10
+ from sklearn.cross_decomposition import PLSRegression
11
+ from sklearn.preprocessing import PolynomialFeatures
12
+
13
+ from utils.logger_ import log_, ThrowDlg, match_keywords
14
+ from ultralytics import YOLO
15
+ import torch
16
+
17
+ import gc
18
+
19
# Make sure automatic garbage collection is active (it is on by default);
# the image/model code in this module also calls gc.collect() explicitly.
gc.enable()

# dtype aliases used throughout the module
UINT8 = np.uint8
FLOAT = np.float32

# Matplotlib colormap names; show_3d picks one at random per surface.
cmaps = [
    'viridis',
    'plasma',
    'jet',
    'Greys',
    'cividis',
]
24
+
25
+
26
class FlatFieldCorrection():
    """
    Class for computing flat-field correction from white image.

    Workflow:

    1. ``detect_and_crop`` locates the white target, either with a YOLO model
       (``model_path``) or interactively through ``cv2.selectROI``.
    2. ``compute_multiplier`` derives a per-pixel gain ``max(L) / L`` from the
       lightness channel inside the crop. It fits a regression on polynomial
       features of the normalised image coordinates and extrapolates the gain
       to the full image.
    3. ``apply_ffc`` multiplies the L channel of an image by that gain map.

    Keyword arguments accepted by ``__init__``:

    - ``model_path`` (str): YOLO weights; an empty string forces manual cropping.
    - ``manual_crop`` (bool): select the ROI interactively instead of with YOLO.
    - ``show`` (bool): display intermediate windows and plots.
    - ``bins`` (int): samples per axis of the grid used to fit the multiplier.
    - ``smooth_window`` (int): Gaussian-blur kernel size and sampling-window
      size. ``cv2.GaussianBlur`` requires it to be positive and odd.
    - ``crop_rect``: stored, but ``detect_and_crop`` always overwrites it.

    Example usage:

    ```
    import cv2
    from FFC.FF_correction import FlatFieldCorrection

    if __name__ == "__main__":
        # Example: Load a white-background image and perform FFC
        path_ = r"Data/Backgrounds/White/Blank.JPG"
        w_img = cv2.imread(path_, cv2.IMREAD_COLOR)

        ffc_params = {
            "model_path": "best_models/PD_trained_512_dauntless-sweep-1/weights/best.pt",
            "manual_crop": False,
            "smooth_window": 11,
            "bins": 50,
            "show": True,
        }

        fit_params = {
            "degree": 5,
            "interactions": True,
            "fit_method": "nn",  # linear, nn, pls, svm
            "max_iter": 1000,
            "tol": 1e-8,
            "verbose": False,
            "rand_seed": 0,
        }

        ffc = FlatFieldCorrection(img=w_img, **ffc_params)
        multiplier = ffc.compute_multiplier(**fit_params)
        # np.save("mult.npy", multiplier)

        corrected_img = ffc.apply_ffc(img=w_img, multiplier=multiplier, show=True)
    ```
    """

    def __init__(self, img=None, **kwargs):
        self.img = img
        self.model_path = kwargs.get('model_path', '')
        self.manual_crop = kwargs.get('manual_crop', False)
        # Without detector weights the ROI can only be selected by hand.
        if self.model_path == '':
            self.manual_crop = True
        self.show = kwargs.get('show', False)
        self.bins = kwargs.get('bins', 50)
        self.smooth_window = kwargs.get('smooth_window', 5)
        self.crop_rect = kwargs.get('crop_rect', None)
        self.model = None
        self.img_cropped = None
        self.cropped_multiplier = None
        self.final_multiplier = None
        self.is_color = self.check_color(self.img) if self.img is not None else None

        if not self.manual_crop:
            self.model = YOLO(self.model_path)

    def check_color(self, img):
        """Return True if ``img`` has 3 channels, also caching it in ``self.is_color``."""
        self.is_color = img.ndim == 3 and img.shape[2] == 3
        return self.is_color

    def resize_image(self, img, factor=None, size=None):
        """
        Resize ``img`` by a scale ``factor`` or to ``size`` given as (rows, cols).

        If both are given, ``size`` wins (applied to the original image).
        Returns ``img`` unchanged when neither is given.
        """
        img_ = img
        if factor is not None:
            img_ = cv2.resize(img, None, fx=factor, fy=factor, interpolation=cv2.INTER_CUBIC)
        if size is not None:
            # cv2.resize takes (width, height) = (cols, rows).
            img_ = cv2.resize(img, (size[1], size[0]), interpolation=cv2.INTER_CUBIC)
        return img_

    def transform_extremity(self, x: np.ndarray, cut_off: float=1.5, max: float=2.0):
        """
        Softly compress values above ``cut_off`` with a tanh roll-off.

        Values ``<= cut_off`` are unchanged. Larger values become
        ``cut_off + (max - cut_off) * tanh(max * (v - cut_off))``, which
        approaches ``max`` asymptotically. Returns a new array of the input's
        shape; the input is not modified.
        """
        arr = np.asarray(x)
        x_ = arr.flatten()  # flatten() copies, so the caller's array is untouched
        mask = x_ > cut_off
        x_[mask] = cut_off + (max - cut_off) * np.tanh(max * (x_[mask] - cut_off))
        return x_.reshape(arr.shape)

    def show_results(self, img_correct, img_original):
        """Show the corrected and original images side by side; return the figure."""
        fig, ax = plt.subplots(1, 2, figsize=(15, 7))
        if len(img_correct.shape) == 3:
            ax[0].imshow(cv2.cvtColor(img_correct, cv2.COLOR_BGR2RGB))
            ax[1].imshow(cv2.cvtColor(img_original, cv2.COLOR_BGR2RGB))
        else:
            img_correct = cv2.normalize(img_correct, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
            img_original = cv2.normalize(img_original, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
            ax[0].imshow(img_correct, cmap='gray')
            ax[1].imshow(img_original, cmap='gray')
        ax[0].set_title("FF Corrected Image")
        ax[1].set_title("Original Image")
        plt.show()

        return fig

    @staticmethod
    def _window_means(arr, rows, cols, h_win):
        """
        Sample ``arr`` on the grid ``rows x cols`` with a local box mean.

        Each output cell is the mean of ``arr`` over the inclusive window
        ``[r - h_win, r + h_win] x [c - h_win, c + h_win]``, clipped at the
        border. Sample positions are truncated to int. A trailing channel axis,
        if present, is averaged as well. Returns float64 with shape
        (len(rows), len(cols)).
        """
        n_rows, n_cols = arr.shape[:2]
        h_win = max(0, int(h_win))
        out = np.empty((len(rows), len(cols)), dtype=np.float64)
        for i, r in enumerate(rows):
            r = int(r)
            r0, r1 = max(0, r - h_win), min(n_rows, r + h_win + 1)
            for j, c in enumerate(cols):
                c = int(c)
                c0, c1 = max(0, c - h_win), min(n_cols, c + h_win + 1)
                out[i, j] = np.mean(arr[r0:r1, c0:c1])
        return out

    def plot_intensity_distribution(self, Z, Z_flat, half=False):
        """
        Plot ``Z`` and ``Z_flat`` as overlaid plotly surfaces.

        Inputs larger than ``self.bins`` along either axis are down-sampled to
        a ``bins x bins`` grid of local means. With ``half=True`` only the top
        half of the rows is shown.
        """
        if half:
            half_rows = Z.shape[0] // 2
            Z = Z[0:half_rows, :]
            Z_flat = Z_flat[0:half_rows, :]

        w, h = Z.shape[:2]

        bins = self.bins
        if w > bins or h > bins:
            x = np.linspace(0, w - 1, bins)
            y = np.linspace(0, h - 1, bins)
            h_win = int((self.smooth_window - 1) / 2)
            Z_ = self._window_means(Z, x, y, h_win)
            Z_flat_ = self._window_means(Z_flat, x, y, h_win)
        else:
            Z_ = Z
            Z_flat_ = Z_flat

        fig = go.Figure(data=[
            # offset by -10 so the two surfaces do not overlap visually
            go.Surface(z=Z_ - 10, opacity=1, colorscale='Viridis'),
            go.Surface(z=Z_flat_, opacity=0.3, colorscale='Jet'),
        ])
        fig.update_layout(title='Intensity distribution', autosize=True,
                          margin=dict(l=65, r=50, b=65, t=90),
                          scene=dict(
                              xaxis_title='X',
                              yaxis_title='Y',
                              zaxis_title='Intensity*'))
        fig.update_scenes(zaxis_range=[0, 256])
        fig.update_traces(contours_z=dict(show=True, usecolormap=True,
                                          highlightcolor="limegreen", project_z=True))
        fig.show()

    def plot_multiplier(self, multiplier):
        """Show ``multiplier`` as a colour-mapped image; return the figure."""
        fig, ax = plt.subplots(figsize=(15, 10))
        ax.set_title("Multiplier")
        im = ax.imshow(multiplier, cmap='jet')
        plt.colorbar(im, ax=ax, orientation='vertical')
        plt.show()
        return fig

    def show_3d(self, img_list, names=None):
        """Plot each 2-D array in ``img_list`` as a 3-D surface; return the figure."""
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        surface = None
        for img in img_list:
            x, y = np.meshgrid(range(img.shape[1]), range(img.shape[0]))
            randi = np.random.randint(0, len(cmaps))
            surface = ax.plot_surface(x, y, img, cmap=cmaps[randi], alpha=0.5)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        if names is not None:
            ax.legend(names)

        # An empty img_list draws no surface, so there is nothing to colour-bar.
        if surface is not None:
            fig.colorbar(surface, ax=ax)
        plt.show()
        return fig

    @staticmethod
    def _shrink_box(x1, y1, x2, y2, sr=0.95):
        """Shrink box ``(x1, y1, x2, y2)`` by ``(1 - sr)`` of its width/height on each side."""
        dx = (1 - sr) * (x2 - x1)
        dy = (1 - sr) * (y2 - y1)
        return (int(round(x1 + dx)), int(round(y1 + dy)),
                int(round(x2 - dx)), int(round(y2 - dy)))

    def _detect_roi(self, sr=0.95):
        """Return the shrunken YOLO box (x1, y1, x2, y2) of the largest detection, or the full image."""
        results = self.model.predict(
            source=self.img,
            half=False,
            show=False,
            save=False,
            save_txt=False,
            conf=0.7,
            iou=0.6
        )
        box_list, prob_list = [], []
        for result in results:
            xyxy = result.boxes.xyxy.cpu().numpy()
            if len(xyxy) == 0:
                continue
            box_list.append(np.round(xyxy).astype(int))
            prob_list.append(result.boxes.conf.cpu().numpy())
        del results

        if not box_list:
            log_('No object detected; using the full image as ROI', 'light_yellow', 'italic', 'warning')
            return 0, 0, self.img.shape[1], self.img.shape[0]

        boxes = np.concatenate(box_list, axis=0)
        probs = np.concatenate(prob_list, axis=0)
        idx = 0
        if len(boxes) > 1:
            log_(f'{len(boxes)} objects detected', 'light_yellow', 'italic', 'warning')
            areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
            idx = int(np.argmax(areas))
            log_(
                f'biggest object ID: {idx} BB: {boxes[idx]} probability: {probs[idx]} selected',
                'light_yellow',
                'italic',
                'warning'
            )

        x1, y1, x2, y2 = boxes[idx]
        # Shrink inward to stay clear of the target's edges.
        return self._shrink_box(x1, y1, x2, y2, sr)

    def _select_roi_manually(self):
        """Let the user draw the ROI; return (x1, y1, x2, y2), or the full image if none."""
        log_('Select ROI manually\nPress "ENTER" when done selecting ROI', 'light_blue', 'italic', 'info')
        window = "Press 'ENTER' when done"
        cv2.namedWindow(window, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(window, 1200, 800)
        rect = cv2.selectROI(window, self.img, True)
        cv2.destroyAllWindows()

        try:
            x, y, w, h = (int(v) for v in rect)
        except (TypeError, ValueError):
            w = h = 0
        # A cancelled selection comes back as a zero-size rectangle.
        if w <= 0 or h <= 0:
            log_('ROI not selected', 'light_yellow', 'italic', 'warning')
            return 0, 0, self.img.shape[1], self.img.shape[0]
        return x, y, x + w, y + h

    def detect_and_crop(self):
        """
        Locate the white target and store ``self.crop_rect`` / ``self.img_cropped``.

        ``crop_rect`` is ``[x1, y1, x2, y2]`` in pixel (col, row) order. If no
        target is found or selected, the full image is used.
        """
        if self.manual_crop:
            x1, y1, x2, y2 = self._select_roi_manually()
        else:
            x1, y1, x2, y2 = self._detect_roi()

        self.crop_rect = [x1, y1, x2, y2]
        self.img_cropped = self.img[y1:y2, x1:x2]

        if self.show:
            img = self.img.copy()
            cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 10)
            cv2.namedWindow("Image ROI", cv2.WINDOW_NORMAL)
            cv2.resizeWindow("Image ROI", 1200, 800)
            cv2.imshow("Image ROI", img)
            cv2.waitKey(500)
            cv2.destroyAllWindows()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    def get_L(self, img, smooth=False):
        """
        Return ``(L, img_LAB)`` for ``img``.

        For a 3-channel BGR image, ``L`` is the LAB lightness channel and
        ``img_LAB`` the full LAB image. Otherwise ``L`` is ``img`` itself and
        ``img_LAB`` is None. With ``smooth=True``, ``L`` is Gaussian-blurred
        using a ``smooth_window`` square kernel.
        """
        is_color = self.check_color(img)
        img_LAB = None
        if is_color:
            img_LAB = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
            L = img_LAB[:, :, 0]
        else:
            L = img
        if smooth:
            L = cv2.GaussianBlur(L, (self.smooth_window, self.smooth_window), 0)

        return L, img_LAB

    def polynomial_features(self, X, **kwargs):
        """
        Expand 2-column coordinates ``X`` into polynomial features.

        ``degree`` (default 5) sets the degree. ``interactions=True`` keeps the
        full polynomial (``interaction_only=False``). The default False keeps
        only products of distinct features (``interaction_only=True``).
        Returns ``(X_poly, feature_names, fitted_transformer)``.
        """
        degree = kwargs.get('degree', 5)
        interaction_only = not kwargs.get('interactions', False)
        poly = PolynomialFeatures(degree=degree, interaction_only=interaction_only)
        X_poly = poly.fit_transform(X)
        names = poly.get_feature_names_out(['x', 'y'])
        return X_poly, names, poly

    def fit_model(self, X, y, **kwargs):
        """
        Fit and return a regressor mapping features ``X`` to targets ``y``.

        ``fit_method`` is matched against 'linear', 'nn' (MLP), 'pls' and 'svm'.
        For an unrecognised method the user is asked whether to fall back to
        linear regression; declining exits the process (``sys.exit(0)``).
        The returned model is always fitted.
        """
        method = kwargs.get('fit_method', 'nn')
        max_iter = kwargs.get('max_iter', 1000)
        tol = kwargs.get('tol', 1e-8)
        verbose = kwargs.get('verbose', False)
        rand_seed = kwargs.get('rand_seed', 0)

        options = ['linear', 'nn', 'pls', 'svm']

        fit_method = match_keywords(method, options)

        # Factories, so only the selected estimator is constructed.
        model_factories = {
            'linear': lambda: LinearRegression(
                fit_intercept=True,
                n_jobs=8,
            ),
            'nn': lambda: MLPRegressor(
                activation='relu',
                solver='adam',
                learning_rate='adaptive',
                learning_rate_init=0.001,
                hidden_layer_sizes=(100,),
                max_iter=max_iter,
                shuffle=True,
                random_state=rand_seed,
                tol=tol,
                verbose=verbose,
                nesterovs_momentum=True,
                early_stopping=True,
                # scikit-learn requires n_iter_no_change >= 1
                n_iter_no_change=max(1, int(max_iter * 0.1)),
                validation_fraction=0.15,
            ),
            'pls': lambda: PLSRegression(
                n_components=max(1, np.shape(X)[1] - 1),
                max_iter=max_iter,
                tol=tol,
            ),
            'svm': lambda: SVR(
                kernel='rbf',
                degree=3,
                verbose=verbose,
                epsilon=0.1,
                tol=tol,
                max_iter=max_iter,
            ),
        }

        if fit_method not in options:
            response = ThrowDlg.yesno(
                f"Fit method '{fit_method}' is not recognized. Try one of: {options}.\n"
                "Do you want to continue with the default method (Linear regression)?"
            )
            if str(response).lower() != "yes":
                log_(f"Cancelled fitting using method '{fit_method}'", color='orange', font_style='bold', level='WARNING')
                sys.exit(0)
            fit_method = 'linear'
            log_(f"Using the default method '{fit_method}' for fitting", color='orange', font_style='bold', level='WARNING')
        else:
            log_(f"ffc fitting using method '{fit_method}'...", color='cyan', font_style='italic', level='INFO')

        # The fallback path previously returned an unfitted estimator; fit in all cases.
        model = model_factories[fit_method]()
        model.fit(X, y)
        return model

    def compute_multiplier(self, **kwargs):
        """
        Estimate the full-image flat-field multiplier from the white target.

        Steps:

        - Crop the target and compute the gain ``max(L) / L`` inside it.
        - Sample that gain on a ``bins x bins`` grid.
        - Fit a regressor on polynomial features of normalised (row, col)
          image coordinates.
        - Predict over the full image, soften gains above 1.3 with
          ``transform_extremity``, and upsample with cubic interpolation.

        Args:
            **kwargs: forwarded to ``polynomial_features`` (``degree``,
                ``interactions``) and ``fit_model`` (``fit_method``,
                ``max_iter``, ``tol``, ``verbose``, ``rand_seed``).

        Returns:
            np.ndarray: multiplier with the rows/cols of ``self.img``, also
            stored in ``self.final_multiplier``.
        """
        # 1. Crop the white target and extract smoothed lightness.
        self.detect_and_crop()

        img_full = self.img.copy()
        img_cropped = self.img_cropped.copy()

        L_full, _ = self.get_L(img_full, smooth=True)
        L_cropped, _ = self.get_L(img_cropped, smooth=True)

        # 2. Per-pixel gain inside the crop.
        L_float = L_cropped.astype(FLOAT) / 255
        # Floor at one grey level: a zero pixel would otherwise give an
        # infinite gain, which then poisons the window means and z-scaling.
        self.cropped_multiplier = np.max(L_float) / np.maximum(L_float, FLOAT(1.0 / 255))

        flat_cropped = (255 * (L_float * self.cropped_multiplier)).astype(UINT8)

        # 3. Extrapolate the multiplier from the crop to the full image.
        # crop_rect is [x1, y1, x2, y2] = [col1, row1, col2, row2].
        if self.crop_rect is None:
            col1, row1, col2, row2 = 0, 0, self.img.shape[1], self.img.shape[0]
        else:
            col1, row1, col2, row2 = self.crop_rect

        # Grid sample positions, in full-image and in crop coordinates.
        rows = np.linspace(row1, row2 - 1, self.bins)
        cols = np.linspace(col1, col2 - 1, self.bins)
        rows_c = np.linspace(0, L_cropped.shape[0] - 1, self.bins)
        cols_c = np.linspace(0, L_cropped.shape[1] - 1, self.bins)

        h_win = int((self.smooth_window - 1) / 2)
        Z_m = self._window_means(self.cropped_multiplier, rows_c, cols_c, h_win)

        # indexing='ij' makes R[i, j] / C[i, j] the row / col of Z_m[i, j].
        # The default 'xy' layout pairs each target with transposed
        # coordinates, which distorts the fit whenever the crop is not the full frame.
        R, C = np.meshgrid(rows, cols, indexing='ij')

        n_rows, n_cols = L_full.shape[0], L_full.shape[1]
        eps = 1e-15
        r_flat = R.ravel() / (n_rows + eps)
        c_flat = C.ravel() / (n_cols + eps)
        z_flat = Z_m.ravel()
        min_z, max_z = np.min(z_flat), np.max(z_flat)
        z_flat = (z_flat - min_z) / (max_z - min_z + eps)

        features, _, poly = self.polynomial_features(np.stack([r_flat, c_flat], axis=1), **kwargs)
        model = self.fit_model(features, z_flat, **kwargs)

        # Predict on a bins x bins grid spanning the whole image (same 'ij' layout).
        R_full, C_full = np.meshgrid(
            np.linspace(0, n_rows - 1, self.bins),
            np.linspace(0, n_cols - 1, self.bins),
            indexing='ij',
        )
        full_features = poly.transform(np.stack(
            [R_full.ravel() / (n_rows + eps), C_full.ravel() / (n_cols + eps)], axis=1
        ))

        f_multiplier = np.asarray(model.predict(full_features))
        f_multiplier = (f_multiplier * (max_z - min_z) + min_z).reshape(self.bins, self.bins)

        f_multiplier = self.transform_extremity(
            f_multiplier,
            max=1.8,
            cut_off=1.3,
        )

        # cv2.resize takes (width, height) = (cols, rows).
        f_multiplier = cv2.resize(f_multiplier, (n_cols, n_rows), interpolation=cv2.INTER_CUBIC)
        self.final_multiplier = f_multiplier

        # 4. Apply FFC to the image (per-pixel multiplication of L).
        img_corrected = self.apply_ffc(img_full)

        if self.show:
            self.show_3d([flat_cropped, L_cropped], names=['Flat', 'Original'])
            self.show_3d([self.final_multiplier], names=['Final Multiplier'])
            self.show_results(img_corrected, img_full)

        gc.collect()
        return self.final_multiplier

    def apply_ffc(self, img, multiplier=None, show=False):
        """
        Apply flat-field correction to ``img`` by multiplying its L channel by the multiplier.

        Args:
            img: uint8 image, either BGR (corrected in LAB space) or single-channel.
            multiplier: optional gain map; when given it replaces
                ``self.final_multiplier``.
            show: display corrected vs original images.

        Returns:
            np.ndarray: corrected uint8 image. Products above 255 saturate to 255.

        Raises:
            ValueError: no multiplier is available, or the image size differs
                from the multiplier's and the user declines resizing.
        """
        img_orig = img if not show else img.copy()
        assert img_orig.dtype == UINT8, 'Image must be of type UINT8'

        if multiplier is not None:
            self.final_multiplier = multiplier
        if self.final_multiplier is None:
            raise ValueError('No multiplier available: pass `multiplier` or call compute_multiplier() first')

        w, h = img.shape[:2]
        w_o, h_o = self.final_multiplier.shape[:2]

        if (w, h) != (w_o, h_o):
            log_(
                f'Image size: {w}x{h} | Final multiplier size: {w_o}x{h_o}',
                'light_yellow', 'italic'
            )
            rtn = ThrowDlg.yesno(
                msg='Image size does not match final multiplier size. Do you want to resize the image?',
            )
            if str(rtn).lower() == 'yes':
                img = self.resize_image(img, size=(w_o, h_o))
            else:
                raise ValueError(f'Image size {w}x{h} does not match multiplier size {w_o}x{h_o}')

        L, img_LAB = self.get_L(img, smooth=False)

        # Clip before casting: a uint8 cast of values > 255 wraps around
        # (e.g. 256 -> 0) instead of saturating.
        L_ = np.clip(L.astype(FLOAT) * self.final_multiplier, 0, 255).astype(UINT8)

        if self.check_color(img):
            img_LAB[:, :, 0] = L_
            img_corrected = cv2.cvtColor(img_LAB, cv2.COLOR_LAB2BGR)
        else:
            img_corrected = L_

        if show:
            self.show_results(img_corrected, img_orig)

        return np.clip(img_corrected, 0, 255)
508
+
FFC/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .FF_correction import FlatFieldCorrection
2
+
3
# FlatFieldCorrection (imported from .FF_correction) is the only name
# exported by `from FFC import *`.
__all__ = ['FlatFieldCorrection']
__init__.py ADDED
@@ -0,0 +1,9 @@
1
# Package version string.
# NOTE(review): the wheel is published as 1.1.1; "1.01.1" normalises to the
# same version under PEP 440, but spelling it "1.1.1" would match textually.
__version__ = "1.01.1"
2
+
3
+ from core import ColorCorrection
4
+ from Configs.configs import Config
5
+ from models import MyModels
6
+ from FFC.FF_correction import FlatFieldCorrection
7
+ from key_functions import *
8
+ from utils.logger_ import *
9
+ from utils.metrics_ import *