ImgAlign 4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,781 @@
1
+ from mpl_interactions import zoom_factory, panhandler
2
+ import matplotlib as mpl
3
+ import matplotlib.pyplot as plt
4
+ from matplotlib.animation import FuncAnimation
5
+ import numpy as np
6
+ import os
7
+ import sys
8
+ import cv2
9
+ import glob
10
+ import textwrap
11
+ import torch
12
+ import argparse
13
+ from argparse import Namespace
14
+ import math
15
+ from itertools import groupby
16
+ from scipy.ndimage import map_coordinates
17
+ from .raft.raft import RAFT
18
+ from concurrent.futures import ThreadPoolExecutor, as_completed
19
+ from sklearn.linear_model import RANSACRegressor
20
+ from python_color_transfer.color_transfer import ColorTransfer
21
+ PT = ColorTransfer()
22
+
23
+
24
# Command-line interface. RawDescriptionHelpFormatter keeps the manual-mode
# key reference in the epilog formatted exactly as written below.
parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,epilog = textwrap.dedent('''\

Manual Keys:
Double click left: Select point.
Click and Drag left: Pan image.
Scroll Wheel: Zoom in and out.
Click Scroll Wheel: Delete matching pairs of points.
Double Click right: Reset image view.
Spacebar: Toggle edge detection view
u: Undo last point selection.
w: Close all windows to progress.
p: Preview alignment. Overlays images using current alignment points.'''))
parser.add_argument("-s", "--scale", help="Positive integer value. How many times bigger you want the HR resolution to be from the LR resolution.", required=True)
parser.add_argument("-m", "--mode", default=0, help="Options: 0 or 1. Mode 0 manipulates the HR images while remaining true to the LR images aside from cropping. Mode 1 manipulates the LR images and remains true to the HR images aside from cropping. In almost every case, you will want to use mode 0 so as not to alter the degradations on the LR images.")
parser.add_argument("-l", "--lr", default='', help="LR File or folder directory. Use this to specify your low resolution image file or folder of images. By default, ImgAlign will use images in the LR folder in the current working directory.")
parser.add_argument("-g", "--hr", default='', help="HR File or folder directory. Use this to specify your high resolution image file or folder of images. By default, ImgAlign will use images in the HR folder in the current working directory.")
parser.add_argument("-c", "--autocrop", action='store_true', default=False, help="Disabled by default. If enabled, this auto crops black boarders around HR and LR images. Manually cropping images before running through ImgAlign will usually yield more consistent results so that dark frames aren't overcropped")
parser.add_argument("-t", "--threshold", default=50, help="Integer 0-255, default 50. Luminance threshold for autocropping. Higher values cause more agressive cropping.")
parser.add_argument("-j", "--affine", action='store_true', default=False, help="Basic affine alignment. Used as default if no other option is specified")
parser.add_argument("-r", "--rotate", action='store_true', default=False, help="Disabled by default. If enabled, this allows rotations when aligning images.")
parser.add_argument("-f", "--full", action='store_true', default=False, help="Disabled by default. If enabled, this allows full homography mapping of the image, correcting rotations, translations, and perspecive warping.")
parser.add_argument("-w", "--warp", action='store_true', default=False, help="Disabled by default. Match images using Thin Plate Splines, allowing full image warping. Because of the nature of TPS warping, this option requires that manual or semiautomatic points are used.")
parser.add_argument("-ai", "--ai", action='store_true', default=False, help="Disabled by default. This option allows use of RAFT optical flow to align images. This can be used in conjunction with any of the aligning methods, affine, rotation, homography, or warping to improve alignment, or by itself. This method can occasionally cause artifacts in the output depending on the type of low resolution images being used, this can usually be fixed by lowering the quality parameter to 2 or 1.")
parser.add_argument("-q", "--quality", default=3, help="Integer 1-3, Default 3. Quality of the AI alignment. Higher numbers are more aggressive and ususally improves alignment, but can cause AI artifacts on some sources. Lower numbers might impact alignment, but causes fewer AI artifacts, uses less VRAM, runs a little faster, and is more suitable for multithreading.")
parser.add_argument("-u", "--manual", action='store_true', default=False, help="Disabled by default. Manual mode. If enabled, this opens windows for working pairs of images to be aligned. Double click pairs of matching points on each image in sequence, and close the windows when finished.")
parser.add_argument("-a", "--semiauto", action='store_true', default=False, help="Disabled by default. Semiautomatic mode. Automatically finds matching points, but loads them into a viewer window to manually delete or add more.")
parser.add_argument("-o", "--overlay", action='store_false', default=True, help="Enabled by default. After saving aligned images, this option will create a separate 50:50 merge of the aligned images in the Overlay folder. Useful for quickly checking through image sets for poorly aligned outputs.")
parser.add_argument("-i", "--color", default=0, help="Default disabled. After alignment, option -1 changes the colors of the HR image to match those of the LR image. Option 1 changes the color of the LR images to match the HR images. This can occasionally cause miscolored regions in the altered images, so examine the results carefully.")
parser.add_argument("-n", "--threads", default=1, help="Default 1. Number of threads to use for automatic matching. Large images require a lot of RAM, so start small to test first.")
parser.add_argument("-e", "--score", action='store_true', default=False, help="Disabled by default. Calculate an alignment score for each processed pair of images. These scores should be taken with a grain of salt, they are mainly to give a general idea of how well aligned things are.")

# Parsed options as a plain dict; individual settings are unpacked into
# module-level globals below.
args = vars(parser.parse_args())
56
# Unpack parsed CLI options into the module-level configuration globals
# that the rest of the script reads.
scale = float(args["scale"])
mode = int(args["mode"])
autocrop = args["autocrop"]
lumthresh = int(args["threshold"])
threads = int(args["threads"])
rotate = args["rotate"]
affi = args["affine"]
HRfolder = args["hr"]
LRfolder = args["lr"]
Overlay = args["overlay"]
Homography = args["full"]
Manual = args["manual"]
score = args["score"]
semiauto = args["semiauto"]
warp = args["warp"]
color_correction = int(args["color"])
optical_flow = args["ai"]
# BUG FIX: quality was previously used without int() conversion. argparse
# returns strings for CLI-supplied values (no type=int was declared), so
# "-q 3" made the later `quality == 3` / `quality == 2` checks both fail
# and silently forced the lowest-quality RAFT settings. Convert it like
# every other integer option above.
quality = int(args["quality"])
76
# Changing conflicting or priority setting
if optical_flow:
    # RAFT working resolution (Qh x Qw) and pre-blur kernel per quality level.
    # NOTE(review): quality comes straight from argparse without type=int, so a
    # CLI-supplied value may be a str here and fall through to the lowest
    # settings — verify it is converted to int upstream.
    if quality == 3:
        Qh, Qw, gauss = 1080, 1440, (5,5)
    elif quality == 2:
        Qh, Qw, gauss = 720, 960, (3,3)
    else:
        Qh, Qw, gauss = 544, 720, (1,1)
    if torch.cuda.is_available():
        DEVICE = 'cuda'
    else:
        DEVICE = 'cpu'
    def load_image(imfile):
        # Blur, then convert HxWxC uint8 -> 1xCxHxW float tensor on DEVICE,
        # the layout RAFT expects.
        imfile = cv2.GaussianBlur(imfile,gauss, 0)
        img = np.array(imfile).astype(np.uint8)
        img = torch.from_numpy(img).permute(2, 0, 1).float()
        return img[None].to(DEVICE)
    # Load pretrained RAFT weights shipped next to this module.
    # NOTE: this rebinds the global `args` (previously the CLI dict) to a
    # Namespace for RAFT — later code must not rely on the CLI dict.
    RAFT_MODULE_DIR = os.path.dirname(__file__)
    RAFT_THINGS_FILE = os.path.join(RAFT_MODULE_DIR, 'raft', 'raft-things.pth')
    args = Namespace(small=False, alternate_corr=False, mixed_precision=False, model=RAFT_THINGS_FILE)
    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.model,map_location=torch.device(DEVICE)))
    model = model.module
    model.to(DEVICE)
    model.eval()

# TPS warping is incompatible with homography/rotation and needs
# manually curated points, so force semiauto when not manual.
if warp:
    Homography = False
    rotate = False
    if not Manual:
        semiauto = True

if Manual or semiauto:
    affi = True

# Interactive point picking cannot run multithreaded.
if Manual or semiauto:
    threads = 1

# Cap on SIFT keypoints per image in auto_points().
MAX_FEATURES = 500
115
+
116
# Cropping dark boarders around images
def AutoCrop(image):
    """Trim dark borders from an image.

    Any row/column whose brightest pixel does not exceed the global
    `lumthresh` is cut away. A fully dark image collapses to a 1x1 crop.
    """
    # Collapse color images to a per-pixel max so one bright channel counts.
    lum = np.max(image, 2) if len(image.shape) == 3 else image
    assert len(lum.shape) == 2

    bright_cols = np.where(np.max(lum, 0) > lumthresh)[0]
    if not bright_cols.size:
        # Nothing bright enough anywhere: keep a single pixel.
        return image[:1, :1]

    bright_rows = np.where(np.max(lum, 1) > lumthresh)[0]
    return image[bright_rows[0]: bright_rows[-1] + 1,
                 bright_cols[0]: bright_cols[-1] + 1]
134
+
135
# Remove outliers from points
def ransac(pnt1, pnt2):
    """Drop outlier correspondences using per-axis RANSAC line fits.

    Fits x1->x2 and y1->y2 independently; a pair is kept only when it is
    an inlier of both fits. Returns the filtered (pnt1, pnt2).
    """
    flat1 = pnt1.reshape(-1, 2)
    flat2 = pnt2.reshape(-1, 2)
    fit_x = RANSACRegressor().fit(flat1[:, 0].reshape(-1, 1), flat2[:, 0].reshape(-1, 1))
    fit_y = RANSACRegressor().fit(flat1[:, 1].reshape(-1, 1), flat2[:, 1].reshape(-1, 1))
    # Elementwise product of the boolean masks == logical AND.
    keep = fit_x.inlier_mask_ * fit_y.inlier_mask_
    return pnt1[keep], pnt2[keep]
146
+
147
# Create and apply Thin Plate Spline transform to an image
def WarpImage_TPS(source, target, img, interp):
    """Warp `img` with a thin-plate-spline mapping source -> target points.

    interp == 0 selects nearest-neighbour sampling (for masks); anything
    else uses cubic interpolation. Returns the warped image.
    """
    tps = cv2.createThinPlateSplineShapeTransformer()

    # OpenCV wants control points shaped (1, N, 2).
    source = source.reshape(-1, max(source.shape[0], source.shape[1]), 2)
    target = target.reshape(-1, max(target.shape[0], target.shape[1]), 2)

    # Identity correspondence: point i of source pairs with point i of target.
    matches = [cv2.DMatch(i, i, 0) for i in range(len(source[0]))]

    tps.estimateTransformation(target, source, matches)

    sampling = cv2.INTER_NEAREST if interp == 0 else cv2.INTER_CUBIC
    return tps.warpImage(img, flags=sampling)
166
+
167
# Make and manipulate plots for manual point selection
def manual_points(img1, img2, pointsA = None, pointsB = None):
    """Interactive matplotlib UI for picking matching point pairs.

    Opens one window per image; the user double-clicks matching points in
    the same order on each. Optional pointsA/pointsB pre-seed the point
    lists (semiauto mode). State is shared with the nested event handlers
    through module-level globals. Loops until at least 4 pairs exist and
    both images have the same number of points.

    Returns (pnts1, pnts2): two (N, 2) float arrays of matched points.
    """
    global pnts1, pnts2, markers1, markers2, active, dis1, dis2, axis1, axis2

    # dis1/dis2 are what is currently shown; spacebar swaps them with the
    # Canny edge maps computed below.
    dis1 = img1
    dis2 = img2
    img1edge = cv2.Canny(cv2.GaussianBlur(cv2.cvtColor(img1,cv2.COLOR_BGR2GRAY), (5,5), 0), 50,150)
    img2edge = cv2.Canny(cv2.GaussianBlur(cv2.cvtColor(img2,cv2.COLOR_BGR2GRAY), (5,5), 0), 50,150)
    # Selected points as (N, 2) arrays, one row per click.
    pnts1 = np.array([])
    pnts2 = np.array([])
    pnts1.shape = (0,2)
    pnts2.shape = (0,2)
    # Plotted marker artists, parallel to pnts1/pnts2.
    markers1 = []
    markers2 = []
    # Click-order log: 1 or 2 per click, consumed by undo/delete logic.
    active = []
    # Marker glyph cycle so successive pairs are visually distinguishable.
    characters=['o', 'v','^','<','>','1','2','3','4','s','p','P','*','+','x','X','D','d']
    if pointsA is not None:
        # Pre-seed with automatically detected points (semiauto mode).
        pointsA, pointsB = pointsA.reshape(-1,2), pointsB.reshape(-1,2)
        for row in pointsA:
            pnts1 = np.concatenate((pnts1,row.reshape(1,2)))
        for row in pointsB:
            pnts2 = np.concatenate((pnts2,row.reshape(1,2)))

    # Matplotlib UI function
    # Count backwards, helps redraw after scroll wheel click paired point deletion
    def tnuoc(mnum1, acnum):
        # Find the index in `active` of the mnum1-th occurrence of acnum,
        # counting from the end. NOTE(review): if fewer than mnum1
        # occurrences exist, `ele` is unbound and this raises — presumably
        # the callers guarantee enough entries; confirm.
        global active
        numA = 0
        count = 0
        for idx in range(len(active)-1,-1,-1):
            if active[idx] == acnum:
                count += 1
                if count == mnum1:
                    ele = idx
                    break
        return ele

    # Click event functions
    def onclick(event, graph):
        # Select a point
        if event.dblclick and event.button == 1:
            global pnts1, pnts2, markers1, markers2, active

            ix, iy = event.xdata, event.ydata
            if ix != None and iy != None:
                if graph == 1:
                    active.append(1)
                    pnts1 = np.concatenate((pnts1,np.array([[ix,iy]])))
                    # Glyph and hue cycle with the point index so pairs match visually.
                    marker = plt.plot(event.xdata, event.ydata, characters[len(markers1)%18], color=mpl.colormaps.get_cmap('hsv')((len(markers1)*25)%256), picker = 5)
                    markers1.append(marker)
                    plt.draw()
                    print(f'x1 = {ix}, y1 = {iy}')
                if graph == 2:
                    active.append(2)
                    pnts2 = np.concatenate((pnts2,np.array([[ix,iy]])))
                    marker = plt.plot(event.xdata, event.ydata, characters[len(markers2)%18], color=mpl.colormaps.get_cmap('hsv')((len(markers2)*25)%256), picker = 5)
                    markers2.append(marker)
                    plt.draw()
                    print(f'x2 = {ix}, y2 = {iy}')
        # Reset view
        if event.dblclick and event.button == 3:
            plt.autoscale(enable=True, axis='both', tight=None)
            plt.draw()


    def onpick(event, graph):
        global pnts1, pnts2, markers1, markers2, active
        # Scroll wheel click for paired point deletion
        if event.mouseevent.button == 2:
            pairpoints = event.artist
            rowel = (pairpoints.get_xdata()[0], pairpoints.get_ydata()[0])
            # Locate the clicked marker's row in the point array.
            if graph == 1:
                rownum = np.where(np.all(pnts1 == rowel, axis = 1))[0][0]
            else:
                rownum = np.where(np.all(pnts2 == rowel, axis = 1))[0][0]
            # Only delete when both images have a point at that row (a pair).
            if len(pnts1) >= rownum + 1 and len(pnts2) >= rownum + 1:
                pnts1 = np.delete(pnts1, rownum, axis = 0)
                pnts2 = np.delete(pnts2, rownum, axis = 0)
                if len(markers1) - rownum <= active.count(1):
                    m1row = len(markers1) - rownum
                    m2row = len(markers2) - rownum
                    A1 = tnuoc(m1row, 1)
                    active.pop(A1)
                    A2 = tnuoc(m2row, 2)
                    active.pop(A2)
                markers1.pop(rownum)[0].remove()
                markers2.pop(rownum)[0].remove()
                fig1.canvas.draw()
                fig2.canvas.draw()

    def on_key_press(event):
        global pnts1, pnts2, markers1, markers2, active, dis1, dis2, axis1, axis2

        # Undo, preview, edge detection, and close all key strokes
        if event.key == 'u':
            print('Undo')
            if active:
                # Remove the most recently placed point on whichever image
                # was clicked last.
                if active[-1] == 1:
                    lastact = active.pop()
                    last_marker = markers1.pop()
                    last_marker[0].remove()
                    pnts1 = pnts1[:-1]
                else:
                    lastact = active.pop()
                    last_marker = markers2.pop()
                    last_marker[0].remove()
                    pnts2 = pnts2[:-1]
                fig1.canvas.draw()
                fig2.canvas.draw()

        if event.key == 'w':
            # Close everything; the outer while loop decides whether to reopen.
            plt.close('all')

        if event.key == ' ':
            # Toggle between the photos and their Canny edge maps.
            if dis1 is img1:
                dis1 = img1edge
                dis2 = img2edge
            else:
                dis1 = img1
                dis2 = img2
            axis1.set_data(cv2.cvtColor(dis1,cv2.COLOR_BGR2RGB))
            axis2.set_data(cv2.cvtColor(dis2,cv2.COLOR_BGR2RGB))
            fig1.canvas.draw()
            fig2.canvas.draw()

        if event.key == 'p':
            # Preview: warp image 1 onto image 2 with the current points and
            # blink between them in a third window.
            if len(pnts1) >= 4 and (len(pnts1) == len(pnts2)):
                def on_press(event):
                    if event.button == 1:
                        fig3.canvas.toolbar.press_pan(event)
                def on_release(event):
                    if event.button == 1:
                        fig3.canvas.toolbar.release_pan(event)

                # Close any previous preview window before opening a new one.
                for fig in plt.get_fignums():
                    if plt.figure(fig).canvas.manager.get_window_title() == 'Figure 3':
                        plt.close(fig)
                        break
                temp2 = dis2
                if Homography:
                    hom, _ = cv2.findHomography(pnts1, pnts2, cv2.RANSAC)
                    temp1 = cv2.warpPerspective(dis1, hom, (dis2.shape[1],dis2.shape[0]), flags = cv2.INTER_CUBIC)
                elif warp:
                    # TPS needs image 1 at least as large as image 2; pad then crop.
                    if len(dis1.shape) == 2:
                        temp1 = np.pad(dis1,[(0,max(0,dis2.shape[0]-dis1.shape[0])),(0,max(0,dis2.shape[1]-dis1.shape[1]))])
                        temp1 = WarpImage_TPS(pnts1, pnts2, temp1, 1)
                        temp1 = temp1[0:dis2.shape[0],0:dis2.shape[1]]
                    else:
                        temp1 = np.pad(dis1,[(0,max(0,dis2.shape[0]-dis1.shape[0])),(0,max(0,dis2.shape[1]-dis1.shape[1])),(0,0)])
                        temp1 = WarpImage_TPS(pnts1, pnts2, temp1, 1)
                        temp1 = temp1[0:dis2.shape[0],0:dis2.shape[1],:]
                else:
                    hom, _ = cv2.estimateAffine2D(pnts1, pnts2, cv2.RANSAC)
                    if not rotate:
                        # Strip rotation/shear, keeping only the scale factors.
                        sx = math.sqrt(hom[0,0]**2+hom[1,0]**2)
                        sy = math.sqrt(hom[0,1]**2+hom[1,1]**2)
                        hom[:,:2] = np.array([[sx,0],[0,sy]])
                    temp1 = cv2.warpAffine(dis1, hom, (dis2.shape[1],dis2.shape[0]), flags = cv2.INTER_CUBIC)
                with plt.ioff():
                    fig3, ax3 = plt.subplots()
                ani = ax3.imshow(cv2.cvtColor(temp1,cv2.COLOR_BGR2RGB))
                def update(frame):
                    # Alternate warped image 1 and image 2 every frame.
                    if frame % 2 == 0:
                        ani.set_array(cv2.cvtColor(temp1,cv2.COLOR_BGR2RGB))
                    else:
                        ani.set_array(cv2.cvtColor(temp2,cv2.COLOR_BGR2RGB))
                animation = FuncAnimation(fig3, update, frames=np.arange(0,10), interval=500)
                plt.get_current_fig_manager().toolbar.pack_forget()
                disconnect_zoom3 = zoom_factory(ax3)
                cid = fig3.canvas.mpl_connect('button_press_event', lambda event: onclick(event, 3) if event.dblclick else None)
                cid_key = fig3.canvas.mpl_connect('key_press_event', on_key_press)
                cidpress = fig3.canvas.mpl_connect('button_press_event', on_press)
                cidrelease = fig3.canvas.mpl_connect('button_release_event', on_release)
                plt.axis('off')
                plt.tight_layout()
                plt.show()
            else:
                print('At least 4 points must be selected and the same number of points must be on each image.')

    # Generate the plots and link functions and controls
    preview = dis2[:]  # NOTE(review): apparently unused — confirm before removing
    repeat = 0
    # Keep reopening the windows until enough matched pairs exist.
    while len(pnts1) < 4 or (len(pnts1) != len(pnts2)) or repeat == 0:
        with plt.ioff():
            fig1, ax1 = plt.subplots()
        fig1.subplots_adjust(left=0,bottom=0,right=1,top=1)
        axis1 = ax1.imshow(cv2.cvtColor(dis1,cv2.COLOR_BGR2RGB))
        plt.get_current_fig_manager().toolbar.pack_forget()
        disconnect_zoom1 = zoom_factory(ax1)
        pan_handler1 = panhandler(fig1,button=1)

        if len(pnts1) > 0 or len(pnts2) > 0:
            # Redraw markers for points that survived a previous round.
            markers1 = []
            i = 0
            for redo in pnts1:
                markers1.append(plt.plot(redo[0], redo[1], characters[i%18], color=mpl.colormaps.get_cmap('hsv')((i*25)%256), picker = 5))
                i += 1
        # Disable matplotlib's default key bindings so ours take over.
        fig1.canvas.mpl_disconnect(fig1.canvas.manager.key_press_handler_id)
        cid = fig1.canvas.mpl_connect('button_press_event', lambda event: onclick(event, 1) if event.dblclick else None)
        cid_pick = fig1.canvas.mpl_connect('pick_event', lambda event: onpick(event,1))
        cid_key = fig1.canvas.mpl_connect('key_press_event', on_key_press)

        plt.axis('off')
        plt.tight_layout()

        with plt.ioff():
            fig2, ax2 = plt.subplots()
        fig2.subplots_adjust(left=0,bottom=0,right=1,top=1)
        axis2 = ax2.imshow(cv2.cvtColor(dis2,cv2.COLOR_BGR2RGB))
        plt.get_current_fig_manager().toolbar.pack_forget()
        disconnect_zoom2 = zoom_factory(ax2)
        pan_handler2 = panhandler(fig2,button=1)

        if len(pnts1) > 0 or len(pnts2) > 0:
            markers2 = []
            i = 0
            for redo in pnts2:
                markers2.append(plt.plot(redo[0], redo[1], characters[i%18], color=mpl.colormaps.get_cmap('hsv')((i*25)%256), picker = 5))
                i += 1
        fig2.canvas.mpl_disconnect(fig2.canvas.manager.key_press_handler_id)
        cid = fig2.canvas.mpl_connect('button_press_event', lambda event: onclick(event, 2) if event.dblclick else None)
        cid_pick = fig2.canvas.mpl_connect('pick_event', lambda event: onpick(event,2))
        cid_key = fig2.canvas.mpl_connect('key_press_event', on_key_press)
        plt.axis('off')
        plt.tight_layout()
        plt.show()

        if len(pnts1) < 4 or (len(pnts1) != len(pnts2)):
            print('At least 4 points must be selected and the same number of points must be on each image.')
        repeat = 1
    return pnts1, pnts2
398
+
399
# Automatic point finding with SIFT
def auto_points(im1, im2):
    """Find matching keypoint pairs between two images automatically.

    Both images are resized to a common (max) resolution, SIFT keypoints
    are matched with a Lowe ratio test, coordinates are mapped back to the
    original resolutions, and outliers are removed with per-axis RANSAC.

    Returns (points1, points2): float32 arrays of shape (N, 1, 2).

    Raises:
        ValueError: when too few usable matches are found. (The original
        code fell through to an UnboundLocalError / implicit None here.)
    """
    im1y, im1x, _ = im1.shape
    im2y, im2x, _ = im2.shape

    # Compare at a shared resolution so descriptors are commensurate.
    im1 = cv2.resize(im1,(max(im1x,im2x),max(im1y,im2y)),interpolation=cv2.INTER_CUBIC)
    im2 = cv2.resize(im2,(max(im1x,im2x),max(im1y,im2y)),interpolation=cv2.INTER_CUBIC)

    im1Gray = cv2.cvtColor(im1, cv2.COLOR_BGR2GRAY)
    im2Gray = cv2.cvtColor(im2, cv2.COLOR_BGR2GRAY)

    sift = cv2.SIFT_create(MAX_FEATURES)
    keypoints1, descriptors1 = sift.detectAndCompute(im1Gray, None)
    keypoints2, descriptors2 = sift.detectAndCompute(im2Gray, None)

    # detectAndCompute returns None descriptors for featureless images.
    if descriptors1 is None or descriptors2 is None:
        raise ValueError('Not enough SIFT features detected to match images.')

    bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=False)
    matches = bf.knnMatch(descriptors1,descriptors2,k=2)

    # Lowe ratio test; knnMatch can return fewer than 2 neighbours, so
    # guard the unpack instead of crashing.
    good = [pair[0] for pair in matches
            if len(pair) == 2 and pair[0].distance < 0.7*pair[1].distance]

    if len(good) <= 5:
        raise ValueError('Not enough good matches found to align images.')

    points1 = np.float32([ keypoints1[m.queryIdx].pt for m in good ]).reshape(-1,1,2)
    points2 = np.float32([ keypoints2[m.trainIdx].pt for m in good ]).reshape(-1,1,2)

    # Map coordinates from the shared resolution back to each original.
    points1[:,0,0], points1[:,0,1] = points1[:,0,0]*im1x/max(im1x,im2x), points1[:,0,1]*im1y/max(im1y,im2y)
    points2[:,0,0], points2[:,0,1] = points2[:,0,0]*im2x/max(im1x,im2x), points2[:,0,1]*im2y/max(im1y,im2y)

    # RANSAC outlier rejection, then drop duplicate points: keep only rows
    # that are unique in BOTH arrays.
    points1, points2 = ransac(points1, points2)
    _, ind1 = np.unique(points1, axis=0, return_index=True)
    _, ind2 = np.unique(points2, axis=0, return_index=True)
    remrows = np.intersect1d(ind1, ind2)
    points1, points2 = points1[remrows], points2[remrows]

    return points1, points2
437
+
438
# Creates a histogram of the longest stretch of consecutive ones for every row in an array
def longest_ones(matrix):
    """Return, per row, the length of its longest run of consecutive 1s.

    Rows with no 1s contribute 0. The result has one entry per row.
    """
    return [
        max((len(list(run)) for value, run in groupby(row) if value == 1),
            default=0)
        for row in matrix
    ]
448
+
449
# Finds the bounds of the largest rectangle in a histogram
def get_largest_rectangle_indices(heights):
    """Return (start, end) column indices of the max-area histogram rectangle.

    Classic monotonic-stack sweep. A sentinel height of -1 after the last
    bar flushes everything left on the stack, so one loop handles both the
    sweep and the final drain. Ties keep the first rectangle found.
    """
    stack = [-1]
    best_area = 0
    best_span = (0, 0)
    n = len(heights)
    for i in range(n + 1):
        # Bars are non-negative, so -1 is strictly lower than all of them.
        bar = heights[i] if i < n else -1
        while stack[-1] != -1 and heights[stack[-1]] >= bar:
            height = heights[stack.pop()]
            width = i - stack[-1] - 1
            if height * width > best_area:
                best_area = height * width
                best_span = (stack[-1] + 1, i - 1)
        if i < n:
            stack.append(i)
    return best_span
471
+
472
# Find a large usable rectangle from a transformed dummy array
def find_rectangle(arr):
    """Find a large all-ones axis-aligned rectangle inside binary mask `arr`.

    Candidate bounds come from the largest-rectangle heuristic applied to
    row/column run-length histograms; if the candidate still contains
    zeros it is shrunk symmetrically until clean, then greedily re-grown
    one edge at a time. Returns (top_left, bottom_right) as np.arrays of
    [row, col].
    """
    rowhist = longest_ones(arr)
    colhist = longest_ones(arr.T)
    rows = get_largest_rectangle_indices(rowhist)
    cols = get_largest_rectangle_indices(colhist)

    if 0 in arr[rows[0]:rows[1]+1,cols[0]:cols[1]+1]:
        # Shrink all four edges inward until the region is free of zeros.
        # Note: tuple + ndarray broadcasting turns rows/cols into mutable
        # ndarrays here, which the element assignments below rely on.
        while 0 in arr[rows[0]:rows[1]+1,cols[0]:cols[1]+1]:
            rows += np.array([1,-1])
            cols += np.array([1,-1])
        # Grow each edge back out as far as it stays all ones.
        while cols[0] > 0 and 0 not in arr[rows[0]:rows[1]+1,cols[0]-1]:
            cols[0] -= 1
        while cols[1] < arr.shape[1]-1 and 0 not in arr[rows[0]:rows[1]+1,cols[1]+1]:
            cols[1] += 1
        while rows[0] > 0 and 0 not in arr[rows[0]-1,cols[0]:cols[1]+1]:
            rows[0] -= 1
        while rows[1] < arr.shape[0]-1 and 0 not in arr[rows[1]+1,cols[0]:cols[1]+1]:
            rows[1] += 1

    return np.array([rows[0], cols[0]]), np.array([rows[1], cols[1]])
494
+
495
# Improves alignment using RAFT
def AI_Align_Process(aim1, aim2, pre = 0):
    """Warp aim1 onto aim2 using RAFT optical flow.

    pre == 0 means the images are not yet roughly aligned: a SIFT homography
    is used first to crop both images to their overlapping region. The flow
    is computed at the fixed quality resolution (Qh x Qw) on grayscale
    copies, then rescaled and applied to aim1 at scale times aim2's size.

    Returns (warp, aim2): the warped image 1 and the matching crop of
    image 2. Relies on module globals: model, load_image, Qh, Qw, mode,
    scale, ogscale.
    """
    # Precrops the images to overlapping regions if they aren't prealigned
    if pre == 0:
        aim1y, aim1x, _ = aim1.shape
        aim2y, aim2x, _ = aim2.shape
        prewhite = np.ones_like(aim1[:,:,0])
        # Match points at a shared (min) resolution, then rescale them back.
        taim1 = cv2.resize(aim1,(min(aim1x,aim2x),min(aim1y,aim2y)),interpolation=cv2.INTER_CUBIC)
        taim2 = cv2.resize(aim2,(min(aim1x,aim2x),min(aim1y,aim2y)),interpolation=cv2.INTER_CUBIC)
        points1, points2 = auto_points(taim1, taim2)
        points1[:,0,0], points1[:,0,1] = points1[:,0,0]*aim1x/min(aim1x,aim2x), points1[:,0,1]*aim1y/min(aim1y,aim2y)
        points2[:,0,0], points2[:,0,1] = points2[:,0,0]*aim2x/min(aim1x,aim2x), points2[:,0,1]*aim2y/min(aim1y,aim2y)
        h, _ = cv2.findHomography(points1, points2, cv2.RANSAC)
        # Warp an all-ones mask to find the usable overlap in aim2 space.
        prewarp = cv2.warpPerspective(prewhite,h,(aim2x,aim2y),flags=0)
        pntsA, pntsD = find_rectangle(prewarp)
        # Map the overlap rectangle's corners back into aim1 space via the
        # inverse homography ([x, y, 1] order: col, row).
        ih = np.linalg.inv(h)
        A = np.matmul(ih,np.array([pntsA[1], pntsA[0], 1]).T)
        B = np.matmul(ih,np.array([pntsA[1], pntsD[0], 1]).T)
        C = np.matmul(ih,np.array([pntsD[1], pntsA[0], 1]).T)
        D = np.matmul(ih,np.array([pntsD[1], pntsD[0], 1]).T)
        top = int(np.clip(min(A[1], B[1], C[1], D[1]),0,aim1y))
        bottom = int(np.clip(max(A[1], B[1], C[1], D[1]),0,aim1y))
        left = int(np.clip(min(A[0], B[0], C[0], D[0]),0,aim1x))
        right = int(np.clip(max(A[0], B[0], C[0], D[0]),0,aim1x))
        aim1 = aim1[top:bottom+1,left:right+1]
        aim2 = aim2[pntsA[0]:pntsD[0]+1,pntsA[1]:pntsD[1]+1]
    # RAFT mapping
    aim1y, aim1x, _ = aim1.shape
    aim2y, aim2x, _ = aim2.shape
    if mode == 1:
        # Trim aim2 so its dimensions are an integer multiple of ogscale.
        aim2 = aim2[:aim2y-int((aim2y%ogscale)),:aim2x-int((aim2x%ogscale)),:]
    # Flow is estimated on grayscale copies at the fixed quality resolution.
    aim1r = cv2.resize(aim1,(Qw,Qh),cv2.INTER_CUBIC)
    aim2r = cv2.resize(aim2,(Qw,Qh),cv2.INTER_CUBIC)
    aim1g=cv2.cvtColor(cv2.cvtColor(aim1r,cv2.COLOR_BGR2GRAY),cv2.COLOR_GRAY2BGR)
    aim2g=cv2.cvtColor(cv2.cvtColor(aim2r,cv2.COLOR_BGR2GRAY),cv2.COLOR_GRAY2BGR)
    with torch.no_grad():
        image1 = load_image(aim1g)
        image2 = load_image(aim2g)
        flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
    displacement = flow_up[0].permute(1,2,0).detach().cpu().numpy()
    # Build an (x, y) sampling grid and subtract the flow to get source coords.
    grid_array = np.indices((Qh,Qw),dtype='float').transpose(1,2,0)
    grid_array[:,:,[0,1]] = grid_array[:,:,[1,0]]
    dis = grid_array-displacement
    # Rescale the sampling map to the output resolution, then into aim1 pixels.
    map = cv2.resize(dis, (int(scale*aim2x),int(scale*aim2y)),cv2.INTER_CUBIC)
    map[:,:,0] = map[:,:,0]*aim1x/Qw
    map[:,:,1] = map[:,:,1]*aim1y/Qh
    # Resample each channel of aim1 along the flow map.
    warpr = map_coordinates(aim1[:,:,0],(map[:,:,1],map[:,:,0]), order=3, mode='nearest')
    warpb = map_coordinates(aim1[:,:,1],(map[:,:,1],map[:,:,0]), order=3, mode='nearest')
    warpg = map_coordinates(aim1[:,:,2],(map[:,:,1],map[:,:,0]), order=3, mode='nearest')
    warp = cv2.merge((warpr,warpb,warpg))
    # Warp an all-ones mask the same way to locate the valid region.
    white = np.ones_like(aim1[:,:,0])
    mapw = cv2.resize(dis, (aim2x,aim2y),cv2.INTER_CUBIC)
    mapw[:,:,0] = mapw[:,:,0]*aim1x/Qw
    mapw[:,:,1] = mapw[:,:,1]*aim1y/Qh
    warpw = map_coordinates(white,(mapw[:,:,1],mapw[:,:,0]), order=3, mode='constant')
    top_left, bottom_right = find_rectangle(warpw)
    if mode == 1:
        # Snap the crop bounds to multiples of ogscale.
        top_left[0] = top_left[0] + top_left[0] % ogscale
        top_left[1] = top_left[1] + top_left[1] % ogscale
        bottom_right[0] = bottom_right[0] - (bottom_right[0]+1) % ogscale
        bottom_right[1] = bottom_right[1] - (bottom_right[1]+1) % ogscale
    warp = warp[int(scale*top_left[0]):int(scale*(bottom_right[0]+1)),int(scale*top_left[1]):int(scale*(bottom_right[1]+1))]
    aim2 = aim2[top_left[0]:(bottom_right[0]+1),top_left[1]:(bottom_right[1]+1)]
    return warp, aim2
559
+
560
def Align_Process(im1, im2):
    """Geometrically align im1 to im2 and crop both to the usable overlap.

    Point pairs come from manual selection, SIFT auto-detection, or both
    (semiauto). The transform is a homography, a thin-plate spline, or an
    affine map (optionally with rotation stripped), chosen by the module
    globals Homography / warp / rotate. im1 is rendered at `scale` times
    im2's coordinate frame. Returns the aligned (im1, im2).
    """
    # Make dummy array the dimensions of image 1
    im1y, im1x, _ = im1.shape
    im2y, im2x, _ = im2.shape
    white1 = np.ones_like(im1[:,:,0])

    if Manual:
        points1, points2 = manual_points(im1, im2)
    else:
        points1, points2 = auto_points(im1, im2)
        if semiauto:
            # Let the user review/edit the automatic points.
            points1, points2 = manual_points(im1, im2, points1, points2)

    # Find transform based on points. The all-ones mask `white1` is warped
    # with the same transform so find_rectangle can locate the valid region.
    if Homography:
        smat = np.array([[scale,0,0],[0,scale,0],[0,0,1]])
        h, _ = cv2.findHomography(points1, points2, cv2.RANSAC)
        warp1 = cv2.warpPerspective(white1,h,(im2x,im2y),flags=0)

    elif warp:
        # TPS needs the mask at least as large as im2; pad then crop.
        white1 = np.pad(white1,[(0,max(0,im2y-im1y)),(0,max(0,im2x-im1x))])
        warp1 = WarpImage_TPS(points1, points2, white1, 0)
        warp1 = warp1[0:im2y,0:im2x]

    else:
        smat = np.array([[scale,0],[0,scale]])
        h, _ = cv2.estimateAffine2D(points1, points2, cv2.RANSAC)
        if not rotate:
            # Keep only per-axis scale; discard rotation/shear components.
            sx = math.sqrt(h[0,0]**2+h[1,0]**2)
            sy = math.sqrt(h[0,1]**2+h[1,1]**2)
            h[:,:2] = np.array([[sx,0],[0,sy]])

        warp1 = cv2.warpAffine(white1,h,(im2x,im2y),flags=0)

    # Get usable overlapping region
    top_left, bottom_right = find_rectangle(warp1)

    if not warp:
        # Compose the HR scale factor into the transform.
        newh = smat @ h

    # Ensure integer multiple scale down for mode 1
    if mode == 1:
        bottom_right[0] = bottom_right[0] - (bottom_right[0] - top_left[0] + 1) % (1/scale)
        bottom_right[1] = bottom_right[1] - (bottom_right[1] - top_left[1] + 1) % (1/scale)

    # Transform image 1 into scale * im2 coordinates.
    if Homography:
        im1 = cv2.warpPerspective(im1,newh,(int(scale*(bottom_right[1]+1)),int(scale*(bottom_right[0]+1))),flags=cv2.INTER_CUBIC)
    elif warp:
        im1 = np.pad(im1,[(0,int(np.around(max(0,scale*im2y-im1y)))),(0,int(np.around(max(0,scale*im2x-im1x)))),(0,0)])
        im1 = WarpImage_TPS(points1, scale*points2, im1, 1)
        im1 = im1[:int(scale*(bottom_right[0]+1)),:int(scale*(bottom_right[1]+1))]
    else:
        im1 = cv2.warpAffine(im1,newh,(int(scale*(bottom_right[1]+1)),int(scale*(bottom_right[0]+1))),flags=cv2.INTER_CUBIC)

    # Crop images to the shared valid rectangle (im1 at scale resolution).
    im1 = im1[int(scale*top_left[0]):,int(scale*top_left[1]):]
    im2 = im2[top_left[0]:(bottom_right[0]+1),top_left[1]:(bottom_right[1]+1)]

    if optical_flow:
        # Optional RAFT refinement on the already-aligned crops.
        im1, im2 = AI_Align_Process(im1, im2, pre = 1)

    return im1, im2
624
+
625
def align_score(img1, img2):
    """Rough 0..1 alignment score from residual keypoint displacement.

    Both images are resized to 256x256, matched with SIFT, filtered with
    RANSAC, and scored by the mean absolute offset between matched points.
    1.0 means no measurable offset; the score floors at 0.
    """
    small1 = cv2.resize(img1, (256, 256), interpolation=cv2.INTER_CUBIC)
    small2 = cv2.resize(img2, (256, 256), interpolation=cv2.INTER_CUBIC)

    matched1, matched2 = auto_points(small1, small2)
    matched1, matched2 = ransac(matched1, matched2)

    residual = matched2 - matched1
    mean_offset = np.sum(abs(residual)) / len(residual)
    return max(1 - 3 * mean_offset / 100, 0)
635
+
636
def sort(file):
    """Sort the lines of a text file in place, lexicographically."""
    with open(file, "r") as handle:
        ordered = sorted(handle.readlines())

    with open(file, "w") as handle:
        handle.writelines(ordered)
643
+
644
+
645
+
646
def Do_Work(hrimg, lrimg, base = None):
    """Align one HR/LR image pair and write the results under Output/.

    hrimg/lrimg are file paths; `base` is the output filename stem.
    Behaviour is driven by module globals (autocrop, mode, optical_flow,
    color_correction, Overlay, score, ogscale, ...). Side effects: writes
    Output/HR/<base>.png, Output/LR/<base>.png, and optionally
    Output/Overlay/<base>.png and a line in Output/AlignmentScore.txt.
    """
    highres = cv2.imread(hrimg, cv2.IMREAD_COLOR)
    lowres = cv2.imread(lrimg, cv2.IMREAD_COLOR)

    if autocrop:
        highres = AutoCrop(highres)
        lowres = AutoCrop(lowres)

    # Pure optical-flow alignment when no geometric method was requested;
    # mode decides which image gets manipulated.
    if optical_flow and not (affi or rotate or Homography or warp):
        if mode == 0:
            highres, lowres = AI_Align_Process(highres, lowres)

        if mode == 1:
            lowres, highres = AI_Align_Process(lowres, highres)

    else:
        if mode == 0:
            highres, lowres = Align_Process(highres, lowres)

        if mode == 1:
            lowres, highres = Align_Process(lowres, highres)

    # Force even LR dimensions, then crop HR to exactly ogscale * LR.
    lowres = lowres[:lowres.shape[0]-lowres.shape[0]%2,:lowres.shape[1]-lowres.shape[1]%2]
    highres = highres[:int(ogscale*lowres.shape[0]),:int(ogscale*lowres.shape[1])]

    # Optional color transfer between the pair (-1: HR->LR colors, 1: LR->HR).
    if color_correction == -1:
        highres = PT.pdf_transfer(img_arr_in = highres, img_arr_ref = lowres, regrain = True)
    elif color_correction == 1:
        lowres = PT.pdf_transfer(img_arr_in = lowres, img_arr_ref = highres, regrain = True)

    cv2.imwrite('Output/HR/{:s}.png'.format(base), highres)
    cv2.imwrite('Output/LR/{:s}.png'.format(base), lowres)


    if Overlay:
        # 50:50 blend of HR and upscaled LR for quick visual QA.
        hhr, whr, _ = highres.shape
        dim_overlay = (whr, hhr)
        scalelr = cv2.resize(lowres,dim_overlay, interpolation=cv2.INTER_CUBIC)
        overlay = cv2.addWeighted(highres,0.5,scalelr,0.5,0)
        cv2.imwrite('Output/Overlay/{:s}.png'.format(base), overlay)

    if score:
        # Best effort: scoring failures are recorded as 0 rather than raised.
        try:
            ascore = align_score(lowres, highres)
        except:
            ascore = 0
        print('{:s}'.format(base)+' score: '+str(ascore))
        with open('Output/AlignmentScore.txt', 'a+') as f:
            f.write('{:s}'.format(base)+' '+ str(ascore) +'\n')
            f.close()  # redundant inside `with`, kept as-is
698
+
699
+
700
# Build the output directory tree before any work starts.
if not os.path.exists('Output'):
    os.mkdir('Output')
if not os.path.exists('Output/LR'):
    os.mkdir('Output/LR')
if not os.path.exists('Output/HR'):
    os.mkdir('Output/HR')
if Overlay:
    if not os.path.exists('Output/Overlay'):
        os.mkdir('Output/Overlay')

# Keep the user-supplied factor in ogscale; mode 1 works with the inverse
# because the LR image is the one being manipulated.
ogscale = scale
if mode == 1:
    scale = 1/scale

# Single image pair execution
if os.path.isfile(HRfolder):
    base = os.path.splitext(os.path.basename(HRfolder))[0]
    hrim = HRfolder
    lrim = LRfolder
    Do_Work(hrim, lrim, base)

elif threads > 1:

    # Default to ./HR and ./LR when no folders were given.
    if len(HRfolder) == 0:
        HRfolder = 'HR'
        LRfolder = 'LR/'

    # Create multithreading function
    def multi(path):
        # Align one HR file against the LR file with the same basename;
        # failures are appended to Output/Failed.txt instead of raised.
        base = os.path.splitext(os.path.basename(path))[0]
        extention = os.path.splitext(os.path.basename(path))[1]
        hrim = path
        lrim = LRfolder+'/'+base+extention
        print('{:s}'.format(base)+extention)
        try:
            Do_Work(hrim,lrim,base)
        except:
            with open('Output/Failed.txt', 'a+') as f:
                f.write('{:s}'.format(base)+extention+'\n')
                f.close()  # redundant inside `with`, kept as-is
            print('Match failed for ','{:s}'.format(base)+extention)
    with ThreadPoolExecutor(max_workers=threads) as executor:
        futures = [executor.submit(multi,path)for path in glob.glob(HRfolder+'/*')]
        try:
            for future in as_completed(futures):
                future.result()
        except KeyboardInterrupt:
            # Ctrl-C: cancel whatever has not started yet.
            for future in futures:
                future.cancel()
    if os.path.exists('Output/Failed.txt'):
        sort('Output/Failed.txt')
    if score:
        sort('Output/AlignmentScore.txt')

# Single threaded execution
else:
    if len(HRfolder) == 0:
        HRfolder = 'HR'
        LRfolder = 'LR/'
    for path in glob.glob(HRfolder+'/*'):
        base = os.path.splitext(os.path.basename(path))[0]
        extention = os.path.splitext(os.path.basename(path))[1]
        hrim = path
        lrim = LRfolder+'/'+base+extention
        print('{:s}'.format(base)+extention)
        try:
            Do_Work(hrim, lrim, base)
        except KeyboardInterrupt:
            break
        except:
            # Record the failed pair and keep going with the rest.
            with open('Output/Failed.txt', 'a+') as f:
                f.write('{:s}'.format(base)+extention+'\n')
                f.close()  # redundant inside `with`, kept as-is
            print('Match failed for ','{:s}'.format(base)+extention)

    if os.path.exists('Output/Failed.txt'):
        sort('Output/Failed.txt')
    if score:
        sort('Output/AlignmentScore.txt')
779
+
780
def __main__():
    """Placeholder entry point; the module body above already does all the work."""
    return None