rawforge 0.1.0__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rawforge-0.1.0 → rawforge-0.2.0}/PKG-INFO +10 -2
- rawforge-0.2.0/RawForge/application/ImageSaver.py +55 -0
- rawforge-0.2.0/RawForge/application/InferenceWorker.py +79 -0
- rawforge-0.2.0/RawForge/application/MODEL_REGISTRY.py +135 -0
- {rawforge-0.1.0 → rawforge-0.2.0}/RawForge/application/ModelHandler.py +15 -98
- rawforge-0.2.0/RawForge/application/helpers/censored_fit.py +100 -0
- rawforge-0.2.0/RawForge/application/helpers/get_image.py +41 -0
- rawforge-0.2.0/RawForge/application/helpers/utils.py +17 -0
- rawforge-0.2.0/RawForge/application/postprocessing.py +92 -0
- {rawforge-0.1.0 → rawforge-0.2.0}/RawForge/main.py +36 -21
- {rawforge-0.1.0 → rawforge-0.2.0}/pyproject.toml +10 -3
- {rawforge-0.1.0 → rawforge-0.2.0}/rawforge.egg-info/PKG-INFO +10 -2
- {rawforge-0.1.0 → rawforge-0.2.0}/rawforge.egg-info/SOURCES.txt +4 -2
- rawforge-0.2.0/rawforge.egg-info/entry_points.txt +2 -0
- {rawforge-0.1.0 → rawforge-0.2.0}/rawforge.egg-info/requires.txt +13 -1
- rawforge-0.1.0/RawForge/application/InferenceWorker.py +0 -98
- rawforge-0.1.0/RawForge/application/InferenceWorkerRawpy.py +0 -108
- rawforge-0.1.0/RawForge/application/MODEL_REGISTRY.py +0 -59
- rawforge-0.1.0/RawForge/application/postprocessing.py +0 -87
- rawforge-0.1.0/RawForge/application/utils.py +0 -10
- rawforge-0.1.0/rawforge.egg-info/entry_points.txt +0 -2
- {rawforge-0.1.0 → rawforge-0.2.0}/LICENSE +0 -0
- {rawforge-0.1.0 → rawforge-0.2.0}/README.md +0 -0
- {rawforge-0.1.0 → rawforge-0.2.0}/RawForge/__init__.py +0 -0
- {rawforge-0.1.0 → rawforge-0.2.0}/RawForge/application/dng_utils.py +0 -0
- {rawforge-0.1.0 → rawforge-0.2.0}/RawForge/scripts/generate_keys.py +0 -0
- {rawforge-0.1.0 → rawforge-0.2.0}/RawForge/scripts/sign_models.py +0 -0
- {rawforge-0.1.0 → rawforge-0.2.0}/rawforge.egg-info/dependency_links.txt +0 -0
- {rawforge-0.1.0 → rawforge-0.2.0}/rawforge.egg-info/top_level.txt +0 -0
- {rawforge-0.1.0 → rawforge-0.2.0}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: rawforge
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: A compute backend/CLI application for using machine learning models on raw images.
|
|
5
5
|
Author: Ryan Mueller
|
|
6
6
|
License: MIT License
|
|
@@ -40,12 +40,20 @@ License-File: LICENSE
|
|
|
40
40
|
Requires-Dist: numpy>=2.2
|
|
41
41
|
Requires-Dist: RawHandler~=0.2.0
|
|
42
42
|
Requires-Dist: colour_demosaicing~=0.2.6
|
|
43
|
-
Requires-Dist:
|
|
43
|
+
Requires-Dist: blended-tiling-numpy
|
|
44
44
|
Requires-Dist: requests~=2.32
|
|
45
45
|
Requires-Dist: platformdirs~=4.5
|
|
46
46
|
Requires-Dist: tqdm~=4.67
|
|
47
47
|
Requires-Dist: cryptography>=46.0
|
|
48
48
|
Requires-Dist: tifffile
|
|
49
|
+
Provides-Extra: cpu
|
|
50
|
+
Requires-Dist: onnxruntime; extra == "cpu"
|
|
51
|
+
Provides-Extra: cuda
|
|
52
|
+
Requires-Dist: onnxruntime-gpu; extra == "cuda"
|
|
53
|
+
Provides-Extra: web
|
|
54
|
+
Requires-Dist: onnxruntime-web; extra == "web"
|
|
55
|
+
Provides-Extra: directml
|
|
56
|
+
Requires-Dist: onnxruntime-directml; extra == "directml"
|
|
49
57
|
Dynamic: license-file
|
|
50
58
|
|
|
51
59
|
# RawRefinery
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from RawForge.application.dng_utils import convert_color_matrix, to_dng
|
|
3
|
+
import tifffile
|
|
4
|
+
|
|
5
|
+
class ImageSaver():
    """Write model output to disk as an 8-bit TIFF preview or as a DNG.

    Args:
        model_params: model registry entry; ``to_raw`` reads its
            ``'demosaicing'`` key to choose the output path.
        rh: raw-file handler for the source image. It supplies colour
            transforms, core metadata and (rawpy path) DNG writing.
        dims: optional crop forwarded to ``rh.compute_mask_and_sparse`` on the
            rawpy path.
    """

    def __init__(self, model_params, rh, dims=None):
        self.rh = rh
        self.model_params = model_params
        self.dims = dims

    @staticmethod
    def _to_uint16(image):
        """Scale values in [0, 1] to the full uint16 range.

        Out-of-range values are clipped and fractional parts truncated.
        """
        max_val = 2**16 - 1
        # Parenthesised on purpose: `x * 2**16-1` would compute x*65536 - 1.
        return np.clip(image * max_val, 0, max_val).astype(np.uint16)

    def to_tiff(self, image, filename, apply_ccm=True):
        """Save ``image[0]`` (channel-first) as a gamma-2.2, 8-bit RGB TIFF.

        Args:
            image: batched array; the first item is transposed from
                (C, H, W) to (H, W, C) before writing.
            filename: destination path passed to ``tifffile.imwrite``.
            apply_ccm: if True, apply ``rh``'s 'lin_rec2020' colourspace
                transform before clipping and gamma encoding.
        """
        hwc = image[0].transpose(1, 2, 0)
        if apply_ccm:
            transform_matrix = self.rh.rgb_colorspace_transform(colorspace='lin_rec2020')
            hwc = hwc @ transform_matrix.T

        encoded = np.clip(hwc, 0, 1) ** (1 / 2.2)
        uint_img = (encoded * (2**8 - 1)).astype(np.uint8)

        tifffile.imwrite(
            filename,
            uint_img,
            photometric='rgb',  # Explicitly define the color space
            compression='deflate'  # Optional: Lossless compression supported by darktable
        )

    def to_raw(self, denoised, filename, save_cfa):
        """Save the model output as a DNG.

        On the 'rawpy' path the RGB output is folded back to one value per
        pixel using ``rh``'s CFA mask, rescaled with the white/black levels and
        written by ``rh.to_dng``. Otherwise the image is transformed back from
        ``rh.colorspace`` and handed to ``to_dng`` with an XYZ-derived colour
        matrix.

        Args:
            denoised: batched channel-first output; only ``denoised[0]`` is used.
            filename: destination DNG path.
            save_cfa: forwarded to ``to_dng`` on the non-rawpy path only.
        """
        if self.model_params['demosaicing'] == 'rawpy':
            # NOTE(review): 0.1.0 passed dims=(0, 99999, 0, 99999) here for a
            # full-frame mask; self.dims defaults to None - confirm RawHandler
            # treats None as "whole image".
            _, mask = self.rh.compute_mask_and_sparse(dims=self.dims)
            img = denoised[0].clip(0, 1)
            # Keep only masked entries and sum over the channel axis, giving one
            # value per pixel (assumes mask selects the CFA channel recorded at
            # each site - TODO confirm against RawHandler).
            cfa = np.where(mask, img, 0).sum(axis=0)
            meta = self.rh.core_metadata
            # NOTE(review): scales by white_level then adds the black level;
            # verify whether (white_level - black_level) was intended.
            cfa = cfa * (meta.white_level) + meta.black_level_per_channel[0]
            self.rh.to_dng(filename, uint_img=cfa)
        else:
            transform_matrix = np.linalg.inv(
                self.rh.rgb_colorspace_transform(colorspace=self.rh.colorspace)
            )
            ccm = np.linalg.inv(self.rh.rgb_colorspace_transform(colorspace='XYZ'))

            hwc = denoised[0].transpose(1, 2, 0)
            uint_img = self._to_uint16(hwc @ transform_matrix.T)
            ccm1 = convert_color_matrix(ccm)
            to_dng(uint_img, self.rh, filename, ccm1, save_cfa=save_cfa, convert_to_cfa=True)
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
import onnxruntime as ort
|
|
2
|
+
import numpy as np
|
|
3
|
+
from blended_tiling_numpy import TilingModule
|
|
4
|
+
from RawForge.application.postprocessing import match_colors_linear
|
|
5
|
+
from tqdm import tqdm
|
|
6
|
+
import rawpy
|
|
7
|
+
|
|
8
|
+
class InferenceWorker():
    """Run a model over an image in overlapping, batched tiles.

    ``model`` is used through ``get_inputs()`` and ``run(["output"], feeds)``,
    i.e. an onnxruntime ``InferenceSession``-like object.

    Args:
        model: session object exposing ``get_inputs`` and ``run``.
        model_params: registry entry; ``tile_size``, ``batch_size``,
            ``max_iso`` and ``cond_scale`` are honoured when present.
            NOTE(review): several registry entries define ``crop_size``,
            which this class does not read.
        conditioning: sequence whose first element is the conditioning value
            (the ISO on the default CLI path); other elements are ignored.
        tile_size, tile_overlap, batch_size: tiling defaults.
        disable_tqdm: hide the progress bar.
    """

    def __init__(self, model, model_params, conditioning, tile_size=512, tile_overlap=0.25, batch_size=2, disable_tqdm=False):
        super().__init__()
        self.model = model
        self.model_params = model_params
        self.conditioning = conditioning
        # Registry entries may override the caller's tile/batch sizes.
        self.tile_size = model_params['tile_size'] if 'tile_size' in model_params else tile_size
        self.tile_overlap = tile_overlap
        self.batch_size = model_params['batch_size'] if 'batch_size' in model_params else batch_size
        self._is_cancelled = False
        self.disable_tqdm = disable_tqdm

    def cancel(self):
        """Ask a running ``_tile_process`` to stop before its next batch."""
        self._is_cancelled = True

    def _prepare_conditioning(self, model_params):
        """Build the (1, 1) float16 conditioning tensor for the "cond" input.

        Only ``self.conditioning[0]`` is used. It is clamped to ``max_iso`` when
        the model declares one, normalised by 6400 and multiplied by
        ``cond_scale`` when present. ``self.conditioning`` is not modified.
        """
        value = float(self.conditioning[0])
        # Some older models were trained with a lower max iso.
        if 'max_iso' in model_params:
            value = min(value, model_params['max_iso'])
        cond = np.array([[value / 6400.]], dtype=np.float32)
        if 'cond_scale' in model_params:
            # Scale once (the previous `c *= c * s` squared the value first).
            cond = cond * model_params['cond_scale']
        return cond.astype(np.float16)

    def _tile_process(self, image_RGB, model_params):
        """Tile ``image_RGB`` (N, C, H, W), run the model per batch, stitch.

        Returns:
            (image_RGB, stitched), or (None, None) if cancelled mid-run.
        """
        full_size = [image_RGB.shape[2], image_RGB.shape[3]]
        tiling = TilingModule(tile_size=[self.tile_size, self.tile_size],
                              tile_overlap=[self.tile_overlap, self.tile_overlap],
                              base_size=list(full_size))

        tiles = tiling.split_into_tiles(image_RGB)
        batches = [tiles[i:i + self.batch_size]
                   for i in range(0, len(tiles), self.batch_size)]

        cond_tensor = self._prepare_conditioning(model_params)
        # Feed only inputs the graph declares (some models take no "cond").
        model_inputs = {inp.name for inp in self.model.get_inputs()}

        processed_batches = []
        for batch_rgb in tqdm(batches, disable=self.disable_tqdm):
            if self._is_cancelled:
                return None, None

            # Expand conditioning to match batch size
            curr_cond = np.broadcast_to(
                cond_tensor, (batch_rgb.shape[0], cond_tensor.shape[-1])
            ).astype(np.float16)
            payload = {"input": batch_rgb, "cond": curr_cond}
            feeds = {k: v for k, v in payload.items() if k in model_inputs}
            output = self.model.run(["output"], feeds)
            processed_batches.append(output[0])

        tiles_out = np.concatenate(processed_batches, axis=0)
        stitched = tiling.rebuild_with_masks(tiles_out)
        return image_RGB, stitched

    def run(self, model_params, image_RGB):
        """Process ``image_RGB``; returns ``(image_RGB, output)`` or ``(None, None)``."""
        return self._tile_process(image_RGB, model_params)
|
|
79
|
+
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
# Model registry: key -> download/config entry.
#   url / filename : release asset to download and its local file name
#                    (every url must end with "/<filename>", no stray whitespace)
#   demosaicing    : "Malvar2004" or "rawpy"; selects the raw loader and save path
#   batch_size     : tiles per inference call (overrides the worker default)
#   cond_scale     : multiplier applied to the normalised conditioning value
#   max_iso        : presumably the highest ISO the model was trained on - confirm
#   conditioning, crop_size : NOTE(review): not read by any code visible here
MODEL_REGISTRY = {
    "TreeNetDenoise": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/ShadowWeightedL1.onnx",
        "filename": "ShadowWeightedL1.onnx",
        "max_iso": 65535,
        "demosaicing": "Malvar2004",
    },
    "TreeNetDenoiseLight": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/ShadowWeightedL1_light.onnx",
        "filename": "ShadowWeightedL1_light.onnx",
        "max_iso": 65535,
        "demosaicing": "Malvar2004",
    },
    "TreeNetDenoiseSuperLight": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/ShadowWeightedL1_super_light.onnx",
        "filename": "ShadowWeightedL1_super_light.onnx",
        "max_iso": 65535,
        "demosaicing": "Malvar2004",
    },
    "TreeNetDenoiseHeavy": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/ShadowWeightedL1_24_deep_500.onnx",
        "filename": "ShadowWeightedL1_24_deep_500.onnx",
        "max_iso": 65535,
        "demosaicing": "Malvar2004",
    },
    "Deblur": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/realblur_gamma_140.onnx",
        "filename": "realblur_gamma_140.onnx",
        "demosaicing": "Malvar2004",
    },
    "DeepSharpen": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/Deblur_deep_24.onnx",
        "filename": "Deblur_deep_24.onnx",
        "demosaicing": "Malvar2004",
    },
    "TreeNetDenoiseXTrans": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/xtrans_fixed_exposure_no_conditioning_400.onnx",
        "filename": "xtrans_fixed_exposure_no_conditioning_400.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
        "cond_scale": 0,
    },
    "RestormerXTrans": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/restormer.onnx",
        "filename": "restormer.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
    },
    "XFormerXTrans2": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/xformer_fp16.onnx",
        "filename": "xformer_fp16.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
    },

    "XFormerDenoise": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/xformer_static.onnx",
        "filename": "xformer_static.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    "NAFDenoise": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/NAF_static.onnx",
        "filename": "NAF_static.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },

    "NAFDenoise_dynamic": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/NAF_dynamic.onnx",
        "filename": "NAF_dynamic.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    "NAFDenoise_dynamic_fp16": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/NAF_dynamic_fp16.onnx",
        "filename": "NAF_dynamic_fp16.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    "NAF_static_fp16": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/NAF_static_fp16.onnx",
        "filename": "NAF_static_fp16.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    "NAF_trace_test": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/xtrans_fixed_exposure_no_conditioning_400.onnx",
        "filename": "xtrans_fixed_exposure_no_conditioning_400.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    "non_trace_test": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/non_trace_test.onnx",
        "filename": "non_trace_test.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    "trace_test": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/trace_test.onnx",
        "filename": "trace_test.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    "non_trace_test_opset_20_static_simp": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/non_trace_test_opset_20_static_simp.onnx",
        "filename": "non_trace_test_opset_20_static_simp.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
}
|
|
@@ -1,21 +1,15 @@
|
|
|
1
|
-
import
|
|
1
|
+
import onnxruntime as ort
|
|
2
|
+
import os
|
|
2
3
|
import numpy as np
|
|
3
4
|
from pathlib import Path
|
|
4
5
|
from platformdirs import user_data_dir
|
|
5
|
-
from time import perf_counter
|
|
6
6
|
import requests
|
|
7
7
|
from cryptography.hazmat.primitives import hashes, serialization
|
|
8
8
|
from cryptography.hazmat.primitives.asymmetric import padding
|
|
9
9
|
|
|
10
|
-
from RawHandler.RawHandler import RawHandler
|
|
11
|
-
from RawHandler.RawHandlerRawpy import RawHandlerRawpy
|
|
12
|
-
|
|
13
|
-
from RawForge.application.dng_utils import convert_color_matrix, to_dng
|
|
14
|
-
from RawForge.application.utils import can_use_gpu
|
|
15
|
-
|
|
16
10
|
from RawForge.application.MODEL_REGISTRY import MODEL_REGISTRY
|
|
17
11
|
from RawForge.application.InferenceWorker import InferenceWorker
|
|
18
|
-
from RawForge.application.
|
|
12
|
+
from RawForge.application.helpers.utils import get_best_providers
|
|
19
13
|
|
|
20
14
|
key_string = '''-----BEGIN PUBLIC KEY-----
|
|
21
15
|
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA8iRGMPqFIFVF0TM/AbMI
|
|
@@ -40,46 +34,23 @@ class ModelHandler():
|
|
|
40
34
|
def __init__(self):
|
|
41
35
|
super().__init__()
|
|
42
36
|
|
|
43
|
-
self.model = None
|
|
44
|
-
|
|
45
|
-
self.
|
|
46
|
-
self.
|
|
47
|
-
|
|
37
|
+
self.model = None
|
|
38
|
+
app_name = "RawForge"
|
|
39
|
+
self.data_dir = Path(user_data_dir(app_name))
|
|
40
|
+
self.cache_dir = Path(os.path.expanduser(f"{self.data_dir}/model_cache")).resolve()
|
|
48
41
|
# Manage devices
|
|
49
|
-
|
|
50
|
-
"cuda": can_use_gpu(),
|
|
51
|
-
"mps": torch.backends.mps.is_available(),
|
|
52
|
-
"cpu": lambda : True
|
|
53
|
-
}
|
|
54
|
-
self.devices = [d for d, is_available in devices.items() if is_available]
|
|
55
|
-
self.set_device(self.devices[0])
|
|
56
|
-
|
|
57
|
-
self.filename = None
|
|
58
|
-
self.start_time = None
|
|
59
|
-
self.model_params = {}
|
|
60
|
-
|
|
42
|
+
self.providers = get_best_providers(cache_dir=self.cache_dir)
|
|
61
43
|
self.pub = serialization.load_pem_public_key(key_string.encode('utf-8'))
|
|
62
44
|
|
|
63
|
-
def load_rh(self, path):
|
|
64
|
-
"""Loads the raw file handler"""
|
|
65
|
-
if 'backend' in self.model_params and self.model_params['backend'] == 'rawpy':
|
|
66
|
-
self.rh = RawHandlerRawpy(path)
|
|
67
|
-
else:
|
|
68
|
-
self.rh = RawHandler(path)
|
|
69
|
-
self.iso = self.rh.full_metadata.get_ISO()
|
|
70
|
-
return self.iso
|
|
71
45
|
|
|
72
46
|
def load_model(self, model_key):
|
|
73
47
|
"""Loads a model by key from the registry"""
|
|
74
48
|
if model_key not in MODEL_REGISTRY:
|
|
75
49
|
print(f"Model {model_key} not found in registry.")
|
|
76
50
|
return
|
|
77
|
-
|
|
78
51
|
conf = MODEL_REGISTRY[model_key]
|
|
79
52
|
self.model_params = conf
|
|
80
|
-
|
|
81
|
-
data_dir = Path(user_data_dir(app_name))
|
|
82
|
-
model_path = data_dir / conf["filename"]
|
|
53
|
+
model_path = self.data_dir / conf["filename"]
|
|
83
54
|
|
|
84
55
|
# Handle Download
|
|
85
56
|
if not model_path.is_file():
|
|
@@ -97,41 +68,15 @@ class ModelHandler():
|
|
|
97
68
|
# Verify model before load
|
|
98
69
|
self._verify_model(model_path, model_path.with_suffix(f'{model_path.suffix}.sig'))
|
|
99
70
|
|
|
100
|
-
|
|
101
|
-
|
|
71
|
+
session = ort.InferenceSession(
|
|
72
|
+
model_path,
|
|
73
|
+
providers=self.providers,
|
|
74
|
+
)
|
|
75
|
+
print("Loaded!")
|
|
76
|
+
self.model = session
|
|
102
77
|
except Exception as e:
|
|
103
78
|
print(f"Failed to load model: {e}")
|
|
104
79
|
|
|
105
|
-
def set_device(self, device):
|
|
106
|
-
self.device = torch.device(device)
|
|
107
|
-
if self.model:
|
|
108
|
-
self.model.to(self.device)
|
|
109
|
-
print(f"Using Device {self.device} from {device}")
|
|
110
|
-
|
|
111
|
-
def run_inference(self, conditioning, dims=None, inference_kwargs={}):
|
|
112
|
-
"""Starts the worker thread"""
|
|
113
|
-
if not self.model or not self.rh:
|
|
114
|
-
print("Model or Image not loaded.")
|
|
115
|
-
return
|
|
116
|
-
|
|
117
|
-
# Some older models were trained with a lower max iso
|
|
118
|
-
if "max_iso" in self.model_params:
|
|
119
|
-
conditioning[0] = min(conditioning[0], self.model_params["max_iso"])
|
|
120
|
-
|
|
121
|
-
if 'backend' in self.model_params and self.model_params['backend'] == 'rawpy':
|
|
122
|
-
worker = InferenceWorkerRawpy(self.model, self.model_params, self.device, self.rh, conditioning, dims, **inference_kwargs)
|
|
123
|
-
else:
|
|
124
|
-
worker = InferenceWorker(self.model, self.model_params, self.device, self.rh, conditioning, dims, **inference_kwargs)
|
|
125
|
-
img, final_denoised = worker.run()
|
|
126
|
-
|
|
127
|
-
return img, final_denoised
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
def generate_thumbnail(self, size=400):
|
|
131
|
-
if not self.rh: return None
|
|
132
|
-
thumb = self.rh.generate_thumbnail(min_preview_size=size, clip=True)
|
|
133
|
-
return thumb
|
|
134
|
-
|
|
135
80
|
def _verify_model(self, dest_path, sig_path):
|
|
136
81
|
try:
|
|
137
82
|
data = Path(dest_path).read_bytes()
|
|
@@ -156,7 +101,6 @@ class ModelHandler():
|
|
|
156
101
|
print(f"Model {dest_path} not verified! Deleting.")
|
|
157
102
|
return False
|
|
158
103
|
|
|
159
|
-
|
|
160
104
|
def _download_file(self, url, dest_path):
|
|
161
105
|
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
|
162
106
|
try:
|
|
@@ -179,31 +123,4 @@ class ModelHandler():
|
|
|
179
123
|
except Exception as e:
|
|
180
124
|
print(e)
|
|
181
125
|
return False
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
def handle_full_image(self, denoised, filename, save_cfa):
|
|
185
|
-
# Compute CFA
|
|
186
|
-
if 'backend' in self.model_params and self.model_params['backend'] == 'rawpy':
|
|
187
|
-
_, mask = self.rh.compute_mask_and_sparse(dims=(0, 99999, 0, 99999))
|
|
188
|
-
denoised = denoised.transpose(2, 0, 1)
|
|
189
|
-
denoised = denoised.clip(0, 1)
|
|
190
|
-
|
|
191
|
-
denoised = np.where(mask, denoised, 0)
|
|
192
|
-
denoised = denoised.sum(axis=0)
|
|
193
|
-
denoised = denoised * ( self.rh.core_metadata.white_level) + self.rh.core_metadata.black_level_per_channel[0]
|
|
194
|
-
self.rh.to_dng(filename, uint_img=denoised)
|
|
195
|
-
else:
|
|
196
|
-
transform_matrix = np.linalg.inv(
|
|
197
|
-
self.rh.rgb_colorspace_transform(colorspace=self.colorspace)
|
|
198
|
-
)
|
|
199
|
-
|
|
200
|
-
CCM = self.rh.rgb_colorspace_transform(colorspace='XYZ')
|
|
201
|
-
CCM = np.linalg.inv(CCM)
|
|
202
|
-
|
|
203
|
-
transformed = denoised @ transform_matrix.T
|
|
204
|
-
uint_img = np.clip(transformed * 2**16-1, 0, 2**16-1).astype(np.uint16)
|
|
205
|
-
ccm1 = convert_color_matrix(CCM)
|
|
206
|
-
to_dng(uint_img, self.rh, filename, ccm1, save_cfa=save_cfa, convert_to_cfa=True)
|
|
207
|
-
|
|
208
|
-
|
|
209
126
|
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import math
|
|
3
|
+
|
|
4
|
+
def norm_cdf(z):
|
|
5
|
+
erf_vec = np.vectorize(math.erf)
|
|
6
|
+
return 0.5 * (1 + erf_vec(z / np.sqrt(2)))
|
|
7
|
+
|
|
8
|
+
def norm_pdf(z):
|
|
9
|
+
return (1 / np.sqrt(2 * np.pi)) * np.exp(-0.5 * z**2)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def censored_linear_fit_twosided(x, y, clip_low=0, clip_high=1,
|
|
13
|
+
max_iter=200, tol=1e-6, include_offset=True):
|
|
14
|
+
"""
|
|
15
|
+
Fit y ≈ a + b*x + ε, ε ~ N(0, σ²) under two-sided censoring:
|
|
16
|
+
clip_low ≤ y_true ≤ clip_high
|
|
17
|
+
Observed y are clipped to [clip_low, clip_high].
|
|
18
|
+
Returns (a, b, sigma) estimated via EM.
|
|
19
|
+
|
|
20
|
+
Parameters
|
|
21
|
+
----------
|
|
22
|
+
x, y : array_like
|
|
23
|
+
Input data.
|
|
24
|
+
clip_low, clip_high : float or None
|
|
25
|
+
Lower/upper clip levels. Can be None for one-sided clipping.
|
|
26
|
+
max_iter : int
|
|
27
|
+
Maximum EM iterations.
|
|
28
|
+
tol : float
|
|
29
|
+
Relative tolerance for convergence.
|
|
30
|
+
include_offset : bool
|
|
31
|
+
If False, forces intercept a=0 (fit y ≈ b*x).
|
|
32
|
+
|
|
33
|
+
Returns
|
|
34
|
+
-------
|
|
35
|
+
a, b, sigma : floats
|
|
36
|
+
Estimated regression parameters.
|
|
37
|
+
"""
|
|
38
|
+
x = np.asarray(x).ravel()
|
|
39
|
+
y = np.asarray(y).ravel()
|
|
40
|
+
mask = np.isfinite(x) & np.isfinite(y)
|
|
41
|
+
x, y = x[mask], y[mask]
|
|
42
|
+
|
|
43
|
+
if len(x) < 3:
|
|
44
|
+
raise ValueError("Not enough data points.")
|
|
45
|
+
|
|
46
|
+
# Initial guess (ordinary least squares)
|
|
47
|
+
if include_offset:
|
|
48
|
+
A = np.vstack([np.ones_like(x), x]).T
|
|
49
|
+
a, b = np.linalg.lstsq(A, y, rcond=None)[0]
|
|
50
|
+
else:
|
|
51
|
+
b = np.dot(x, y) / np.dot(x, x)
|
|
52
|
+
a = 0.0
|
|
53
|
+
sigma = np.std(y - (a + b*x))
|
|
54
|
+
|
|
55
|
+
for _ in range(max_iter):
|
|
56
|
+
mu = a + b*x
|
|
57
|
+
y_exp = y.copy()
|
|
58
|
+
|
|
59
|
+
# Handle right-censoring (high clip)
|
|
60
|
+
if clip_high is not None:
|
|
61
|
+
high_mask = y >= clip_high - 1e-12
|
|
62
|
+
if np.any(high_mask):
|
|
63
|
+
z = (clip_high - mu[high_mask]) / sigma
|
|
64
|
+
Phi = norm_cdf(z)
|
|
65
|
+
phi = norm_pdf(z)
|
|
66
|
+
one_minus_Phi = 1.0 - Phi
|
|
67
|
+
lambda_ = np.zeros_like(z)
|
|
68
|
+
valid = one_minus_Phi > 1e-15
|
|
69
|
+
lambda_[valid] = phi[valid] / one_minus_Phi[valid]
|
|
70
|
+
y_exp[high_mask] = mu[high_mask] + sigma * lambda_
|
|
71
|
+
|
|
72
|
+
# Handle left-censoring (low clip)
|
|
73
|
+
if clip_low is not None:
|
|
74
|
+
low_mask = y <= clip_low + 1e-12
|
|
75
|
+
if np.any(low_mask):
|
|
76
|
+
z = (clip_low - mu[low_mask]) / sigma
|
|
77
|
+
Phi = norm_cdf(z)
|
|
78
|
+
phi = norm_pdf(z)
|
|
79
|
+
lambda_ = np.zeros_like(z)
|
|
80
|
+
valid = Phi > 1e-15
|
|
81
|
+
lambda_[valid] = -phi[valid] / Phi[valid]
|
|
82
|
+
y_exp[low_mask] = mu[low_mask] + sigma * lambda_
|
|
83
|
+
|
|
84
|
+
# M-step: re-fit with imputed expectations
|
|
85
|
+
if include_offset:
|
|
86
|
+
A = np.vstack([np.ones_like(x), x]).T
|
|
87
|
+
a_new, b_new = np.linalg.lstsq(A, y_exp, rcond=None)[0]
|
|
88
|
+
else:
|
|
89
|
+
b_new = np.dot(x, y_exp) / np.dot(x, x)
|
|
90
|
+
a_new = 0.0
|
|
91
|
+
|
|
92
|
+
sigma_new = np.std(y_exp - (a_new + b_new*x))
|
|
93
|
+
|
|
94
|
+
if np.allclose([a, b, sigma], [a_new, b_new, sigma_new],
|
|
95
|
+
rtol=tol, atol=tol):
|
|
96
|
+
break
|
|
97
|
+
|
|
98
|
+
a, b, sigma = a_new, b_new, sigma_new
|
|
99
|
+
|
|
100
|
+
return a, b, sigma
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import rawpy
|
|
3
|
+
from RawHandler.RawHandler import RawHandler
|
|
4
|
+
from RawHandler.RawHandlerRawpy import RawHandlerRawpy
|
|
5
|
+
|
|
6
|
+
def load_rh(path, demosaicing):
    """Open a raw file with the handler that matches the demosaicing mode.

    Args:
        path: path to the raw file.
        demosaicing: 'rawpy' (RawHandlerRawpy) or 'Malvar2004' (RawHandler).

    Returns:
        (rh, iso): the handler and the ISO read from its metadata.

    Raises:
        ValueError: for any other demosaicing mode (previously this fell
            through and failed with UnboundLocalError on ``rh``).
    """
    if demosaicing == 'rawpy':
        rh = RawHandlerRawpy(path)
    elif demosaicing == 'Malvar2004':
        rh = RawHandler(path)
    else:
        raise ValueError(
            f"Unsupported demosaicing mode {demosaicing!r}; "
            "expected 'rawpy' or 'Malvar2004'."
        )
    iso = rh.full_metadata.get_ISO()
    return rh, iso
|
|
14
|
+
|
|
15
|
+
def get_image(path, model_params, dims=None):
    """Load a raw file and demosaic it into a batched float16 tensor.

    The demosaicing path comes from ``model_params['demosaicing']``:
      * 'Malvar2004' - ``rh.as_rgb`` with the colour_demosaicing Malvar2004
        routine in 'lin_rec2020', clipped; ``dims`` is forwarded to ``as_rgb``.
      * 'rawpy' - libraw postprocessing with unit white balance, linear gamma
        and raw colour space, divided by the sensor white level; the result
        is then cropped to ``dims`` (rows dims[0]:dims[1], cols dims[2]:dims[3]).

    Returns:
        (rh, image_RGB, iso) where image_RGB has an added leading batch axis and
        channels moved first (the array is transposed with (2, 0, 1)).
    """
    method = model_params['demosaicing']
    rh, iso = load_rh(path, method)

    if method == 'Malvar2004':
        from colour_demosaicing import demosaicing_CFA_Bayer_Malvar2004
        rgb = rh.as_rgb(
            dims=dims,
            demosaicing_func=demosaicing_CFA_Bayer_Malvar2004,
            colorspace='lin_rec2020',
            clip=True,
        )
    elif method == 'rawpy':
        libraw_options = dict(
            user_wb=[1, 1, 1, 1],
            output_color=rawpy.ColorSpace.raw,
            demosaic_algorithm=rawpy.DemosaicAlgorithm(3),
            no_auto_bright=True,
            use_camera_wb=False,
            use_auto_wb=False,
            gamma=(1, 1),
            user_flip=0,
            output_bps=16,
            no_auto_scale=True,
        )
        developed = rh.rawpy_object.postprocess(**libraw_options)
        rgb = developed / rh.rawpy_object.white_level
        if dims is not None:
            rgb = rgb[dims[0]:dims[1], dims[2]:dims[3]]

    batched = rgb.transpose(2, 0, 1)[np.newaxis, ...]
    return rh, batched.astype(np.float16), iso
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import onnxruntime as ort
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
def get_best_providers(cache_dir, available=None):
    """Pick ONNX Runtime execution providers, best first.

    Creates ``cache_dir`` if it does not exist, then keeps only the providers
    in a fixed preference list (GPU backends first, CPU last) that are
    available. Previously an early ``return available`` made this filter
    unreachable.

    Args:
        cache_dir: directory to create. NOTE(review): not otherwise used here.
        available: provider names to choose from; defaults to
            ``ort.get_available_providers()``.

    Returns:
        List of provider names in preference order. If none of the preferred
        providers is available, the available list is returned unchanged.
    """
    os.makedirs(cache_dir, exist_ok=True)
    if available is None:
        available = ort.get_available_providers()

    priority = [
        "CUDAExecutionProvider",       # NVIDIA
        "ROCMExecutionProvider",       # AMD (Direct)
        "MIGraphXExecutionProvider",   # AMD (Optimized)
        "CoreMLExecutionProvider",     # Apple Silicon
        "DmlExecutionProvider",        # Windows (AMD/Intel/Generic GPU)
        "CPUExecutionProvider"         # The fallback
    ]

    preferred = [p for p in priority if p in available]
    # Never hand the session an empty provider list.
    return preferred or list(available)
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from RawForge.application.helpers.censored_fit import censored_linear_fit_twosided
|
|
3
|
+
|
|
4
|
+
def match_colors_linear(
    src: np.ndarray,
    tgt: np.ndarray,
    sample_fraction: float = 0.05,
    censored_fit: bool = True
):
    """
    Map ``src`` onto ``tgt`` with one affine (scale, bias) per batch item and channel.

    Both inputs are [B, C, H, W]. The fit uses a random sample (with
    replacement, from the global NumPy RNG) of max(64, N * sample_fraction)
    pixel positions, shared by every batch item and channel.

    If ``censored_fit`` is True, a single (a, b) from
    ``censored_linear_fit_twosided`` is additionally applied as
    ``a + b * transformed``.
    NOTE(review): that call regresses the src sample (as y) on the tgt sample
    (as x), pooled over all channels, and is applied on top of the per-channel
    mapping - confirm the argument order and double application are intended.

    Returns:
        (transformed, scale, bias) with scale and bias shaped [B, C].
    """
    batch, channels, _height, _width = src.shape

    # One row of pixels per (batch, channel): [B, C, H*W]
    src_pixels = src.reshape(batch, channels, -1)
    tgt_pixels = tgt.reshape(batch, channels, -1)

    n_pixels = src_pixels.shape[-1]
    n_samples = max(64, int(n_pixels * sample_fraction))
    picks = np.random.randint(0, n_pixels, size=(n_samples,))

    src_s = src_pixels[..., picks]  # [B, C, k]
    tgt_s = tgt_pixels[..., picks]

    # Closed-form least squares per channel: scale = cov / var, bias from the means.
    src_mu = src_s.mean(axis=-1)  # [B, C]
    tgt_mu = tgt_s.mean(axis=-1)
    src_dev = src_s - src_mu[..., np.newaxis]
    tgt_dev = tgt_s - tgt_mu[..., np.newaxis]

    variance = (src_dev ** 2).mean(axis=-1)
    covariance = (src_dev * tgt_dev).mean(axis=-1)
    scale = covariance / (variance + 1e-8)
    bias = tgt_mu - scale * src_mu

    # Broadcast the [B, C] coefficients over H and W.
    transformed = src * scale[:, :, None, None] + bias[:, :, None, None]

    if censored_fit:
        a, b, _sigma = censored_linear_fit_twosided(
            tgt_s.astype(np.float32), src_s.astype(np.float32)
        )
        transformed = a + b * transformed

    return transformed, scale, bias
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def scaled_dot_product(x1, x2, eps=1e-6):
    """Per-position dot product over axis 2, divided by (|x1| + |x2| + eps).

    Reductions keep the summed axis, so the result has size 1 along axis 2.
    NOTE(review): the denominator is the SUM of the two magnitudes, not their
    product, so this is not cosine similarity - confirm that is intended.
    """
    def _reduce(t):
        return t.sum(axis=2, keepdims=True)

    numerator = _reduce(x1 * x2)
    norm_1 = np.sqrt(_reduce(x1 * x1))
    norm_2 = np.sqrt(_reduce(x2 * x2))
    return numerator / (norm_1 + norm_2 + eps)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def postprocess(img, denoised, lumi_blend=0, chroma_blend=0, eps=1e-6,
                clip_highlights=False, affine=False):
    """Blend the model output with the original image.

    ``img`` and ``denoised`` are reduced over axis 2 and (when ``affine``)
    transposed with (2, 0, 1), so they are treated as H x W x C arrays.

    Per pixel, ``img`` is split into
      * lumi   - its projection onto the direction of ``denoised``, and
      * chroma - the remainder, ``img - lumi``.
    The result is ``(1 - lumi_blend) * denoised + lumi_blend * lumi
    + chroma_blend * chroma``.

    Args:
        img: original (noisy) image.
        denoised: model output.
        lumi_blend: weight moving the output from ``denoised`` towards ``lumi``.
        chroma_blend: weight of the orthogonal residual added back.
        eps: guards the division for near-zero ``denoised`` pixels.
        clip_highlights: if True, force output to 1 wherever ``img == 1``.
        affine: if True, first colour-match ``denoised`` to ``img`` with
            ``match_colors_linear``.
    """
    if affine:
        # match_colors_linear expects [B, C, H, W].
        _denoised = np.expand_dims(denoised.transpose(2, 0, 1), axis=0)
        _img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
        _denoised, _, _ = match_colors_linear(_denoised, _img)
        denoised = _denoised[0].transpose(1, 2, 0)

    dot = (img * denoised).sum(axis=2, keepdims=True)
    # Squared magnitude directly (avoids a sqrt followed by squaring).
    denoised_sq = (denoised * denoised).sum(axis=2, keepdims=True)

    # Project the original image onto the denoised pixel direction.
    lumi = dot / (denoised_sq + eps) * denoised
    chroma = img - lumi
    output = (1 - lumi_blend) * denoised + lumi_blend * lumi + chroma_blend * chroma

    if clip_highlights:
        output = clip_highlights_func(img, output)

    return output
|
|
87
|
+
|
|
88
|
+
def clip_highlights_func(img, denoised):
    """Return a copy of ``denoised`` set to 1 wherever ``img`` is exactly 1.

    The comparison is elementwise, so saturation is handled per channel;
    ``denoised`` itself is left untouched.
    """
    result = np.array(denoised, copy=True)
    saturated = (img == 1)
    result[saturated] = 1
    return result
|
|
@@ -1,8 +1,15 @@
|
|
|
1
1
|
import sys
|
|
2
2
|
import platform
|
|
3
|
+
import argparse
|
|
4
|
+
from pathlib import Path
|
|
3
5
|
from RawForge.application.ModelHandler import ModelHandler
|
|
4
6
|
from RawForge.application.postprocessing import postprocess
|
|
5
|
-
import
|
|
7
|
+
from RawForge.application.MODEL_REGISTRY import MODEL_REGISTRY
|
|
8
|
+
from RawForge.application.InferenceWorker import InferenceWorker
|
|
9
|
+
from RawForge.application.ImageSaver import ImageSaver
|
|
10
|
+
from RawForge.application.helpers.get_image import get_image
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
6
13
|
# import glob
|
|
7
14
|
|
|
8
15
|
def main():
|
|
@@ -17,13 +24,12 @@ def main():
|
|
|
17
24
|
parser.add_argument('--device', type=str, help='Set device backend (cuda, cpu, mps).')
|
|
18
25
|
parser.add_argument('--disable_tqdm', action='store_true', help='Disable the progress bar.')
|
|
19
26
|
parser.add_argument('--tile_size', type=int, help='Set tile size. (default: 256)', default=256)
|
|
27
|
+
parser.add_argument('--tile_overlap', type=float, help='Set tile overlap. (default: 0.25)', default=0.25)
|
|
20
28
|
|
|
21
29
|
parser.add_argument('--lumi', type=float, help='Lumi noise (0-1).', default=0)
|
|
22
30
|
parser.add_argument('--chroma', type=float, help='Chroma noise (0-1).', default=0)
|
|
23
31
|
parser.add_argument('--clip_highlights', action='store_true', help='Do not run model on clipped highlights.')
|
|
24
|
-
|
|
25
|
-
args = parser.parse_args()
|
|
26
|
-
|
|
32
|
+
parser.add_argument('--affine', action='store_true', help='Affine fit the model to the input.')
|
|
27
33
|
# # Glob handeling
|
|
28
34
|
# in_files = sorted(glob.glob(args.in_file))
|
|
29
35
|
# if not in_files:
|
|
@@ -33,30 +39,39 @@ def main():
|
|
|
33
39
|
# raise ValueError(
|
|
34
40
|
# "When using glob input, out_file must be a directory."
|
|
35
41
|
# )
|
|
36
|
-
|
|
37
|
-
handler = ModelHandler()
|
|
38
|
-
|
|
39
|
-
handler.load_model(args.model)
|
|
40
|
-
|
|
41
|
-
iso = handler.load_rh(args.in_file)
|
|
42
|
+
args = parser.parse_args()
|
|
42
43
|
|
|
44
|
+
# Initialize
|
|
45
|
+
models = args.model.split(',')
|
|
46
|
+
model_params = MODEL_REGISTRY[models[0]]
|
|
47
|
+
rh, image_RGB, iso = get_image(args.in_file, model_params, dims=args.dims)
|
|
43
48
|
if not args.conditioning:
|
|
44
49
|
conditioning = [iso, 0]
|
|
45
50
|
else:
|
|
46
51
|
conditioning = [int(x) for x in args.conditioning.split(',')]
|
|
47
|
-
|
|
48
|
-
if args.device:
|
|
49
|
-
handler.set_device(args.device)
|
|
50
|
-
|
|
51
|
-
|
|
52
52
|
inference_kwargs = {"disable_tqdm": args.disable_tqdm,
|
|
53
|
-
"tile_size": args.tile_size
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
53
|
+
"tile_size": args.tile_size,
|
|
54
|
+
"tile_overlap": args.tile_overlap}
|
|
55
|
+
|
|
56
|
+
# Run processing
|
|
57
|
+
output_img = image_RGB
|
|
58
|
+
for model in models:
|
|
59
|
+
handler = ModelHandler()
|
|
60
|
+
handler.load_model(model)
|
|
61
|
+
worker = InferenceWorker(handler.model, handler.model_params, conditioning, **inference_kwargs)
|
|
62
|
+
_, output_img = worker._tile_process(output_img, handler.model_params)
|
|
59
63
|
|
|
64
|
+
# Postprocess
|
|
65
|
+
output = postprocess(image_RGB, output_img, lumi_blend=args.lumi, chroma_blend=args.chroma, eps=1e-6,
|
|
66
|
+
clip_highlights=args.clip_highlights, affine=args.affine)
|
|
67
|
+
|
|
68
|
+
# Save image
|
|
69
|
+
saver = ImageSaver(model_params, rh, dims=args.dims)
|
|
70
|
+
apply_ccm = model_params['demosaicing'] == 'rawpy'
|
|
71
|
+
if Path(args.out_file).suffix=='.tiff':
|
|
72
|
+
saver.to_tiff(output, args.out_file, apply_ccm=apply_ccm)
|
|
73
|
+
else:
|
|
74
|
+
saver.to_raw(output, args.out_file, args.cfa)
|
|
60
75
|
|
|
61
76
|
if __name__ == '__main__':
|
|
62
77
|
main()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "rawforge"
|
|
3
|
-
version = "0.
|
|
3
|
+
version = "0.2.0"
|
|
4
4
|
description = "A compute backend/CLI application for using machine learning models on raw images."
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "Ryan Mueller"},
|
|
@@ -21,7 +21,7 @@ dependencies = [
|
|
|
21
21
|
"numpy >=2.2",
|
|
22
22
|
"RawHandler ~=0.2.0",
|
|
23
23
|
"colour_demosaicing ~=0.2.6",
|
|
24
|
-
"
|
|
24
|
+
"blended-tiling-numpy",
|
|
25
25
|
"requests ~=2.32",
|
|
26
26
|
"platformdirs ~=4.5",
|
|
27
27
|
"tqdm ~=4.67",
|
|
@@ -30,6 +30,13 @@ dependencies = [
|
|
|
30
30
|
]
|
|
31
31
|
|
|
32
32
|
|
|
33
|
+
[project.optional-dependencies]
|
|
34
|
+
cpu = ["onnxruntime"]
|
|
35
|
+
cuda = ["onnxruntime-gpu"]
|
|
36
|
+
web = ["onnxruntime-web"]
|
|
37
|
+
directml = ["onnxruntime-directml"]
|
|
38
|
+
|
|
39
|
+
|
|
33
40
|
[project.urls]
|
|
34
41
|
"Homepage" = "https://github.com/rymuelle/RawForge"
|
|
35
42
|
"Bug Tracker" = "https://github.com/rymuelle/RawForge/issues"
|
|
@@ -43,4 +50,4 @@ where = ["."]
|
|
|
43
50
|
include = ["RawForge*"]
|
|
44
51
|
|
|
45
52
|
[project.scripts]
|
|
46
|
-
|
|
53
|
+
rawforgeonnx = "RawForge.main:main"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: rawforge
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: A compute backend/CLI application for using machine learning models on raw images.
|
|
5
5
|
Author: Ryan Mueller
|
|
6
6
|
License: MIT License
|
|
@@ -40,12 +40,20 @@ License-File: LICENSE
|
|
|
40
40
|
Requires-Dist: numpy>=2.2
|
|
41
41
|
Requires-Dist: RawHandler~=0.2.0
|
|
42
42
|
Requires-Dist: colour_demosaicing~=0.2.6
|
|
43
|
-
Requires-Dist:
|
|
43
|
+
Requires-Dist: blended-tiling-numpy
|
|
44
44
|
Requires-Dist: requests~=2.32
|
|
45
45
|
Requires-Dist: platformdirs~=4.5
|
|
46
46
|
Requires-Dist: tqdm~=4.67
|
|
47
47
|
Requires-Dist: cryptography>=46.0
|
|
48
48
|
Requires-Dist: tifffile
|
|
49
|
+
Provides-Extra: cpu
|
|
50
|
+
Requires-Dist: onnxruntime; extra == "cpu"
|
|
51
|
+
Provides-Extra: cuda
|
|
52
|
+
Requires-Dist: onnxruntime-gpu; extra == "cuda"
|
|
53
|
+
Provides-Extra: web
|
|
54
|
+
Requires-Dist: onnxruntime-web; extra == "web"
|
|
55
|
+
Provides-Extra: directml
|
|
56
|
+
Requires-Dist: onnxruntime-directml; extra == "directml"
|
|
49
57
|
Dynamic: license-file
|
|
50
58
|
|
|
51
59
|
# RawRefinery
|
|
@@ -3,13 +3,15 @@ README.md
|
|
|
3
3
|
pyproject.toml
|
|
4
4
|
RawForge/__init__.py
|
|
5
5
|
RawForge/main.py
|
|
6
|
+
RawForge/application/ImageSaver.py
|
|
6
7
|
RawForge/application/InferenceWorker.py
|
|
7
|
-
RawForge/application/InferenceWorkerRawpy.py
|
|
8
8
|
RawForge/application/MODEL_REGISTRY.py
|
|
9
9
|
RawForge/application/ModelHandler.py
|
|
10
10
|
RawForge/application/dng_utils.py
|
|
11
11
|
RawForge/application/postprocessing.py
|
|
12
|
-
RawForge/application/
|
|
12
|
+
RawForge/application/helpers/censored_fit.py
|
|
13
|
+
RawForge/application/helpers/get_image.py
|
|
14
|
+
RawForge/application/helpers/utils.py
|
|
13
15
|
RawForge/scripts/generate_keys.py
|
|
14
16
|
RawForge/scripts/sign_models.py
|
|
15
17
|
rawforge.egg-info/PKG-INFO
|
|
@@ -1,9 +1,21 @@
|
|
|
1
1
|
numpy>=2.2
|
|
2
2
|
RawHandler~=0.2.0
|
|
3
3
|
colour_demosaicing~=0.2.6
|
|
4
|
-
|
|
4
|
+
blended-tiling-numpy
|
|
5
5
|
requests~=2.32
|
|
6
6
|
platformdirs~=4.5
|
|
7
7
|
tqdm~=4.67
|
|
8
8
|
cryptography>=46.0
|
|
9
9
|
tifffile
|
|
10
|
+
|
|
11
|
+
[cpu]
|
|
12
|
+
onnxruntime
|
|
13
|
+
|
|
14
|
+
[cuda]
|
|
15
|
+
onnxruntime-gpu
|
|
16
|
+
|
|
17
|
+
[directml]
|
|
18
|
+
onnxruntime-directml
|
|
19
|
+
|
|
20
|
+
[web]
|
|
21
|
+
onnxruntime-web
|
|
@@ -1,98 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from blended_tiling import TilingModule
|
|
3
|
-
from colour_demosaicing import demosaicing_CFA_Bayer_Malvar2004
|
|
4
|
-
from RawForge.application.postprocessing import match_colors_linear
|
|
5
|
-
from tqdm import tqdm
|
|
6
|
-
|
|
7
|
-
class InferenceWorker():
|
|
8
|
-
def __init__(self, model, model_params, device, rh, conditioning, dims, tile_size=256, tile_overlap=0.25, batch_size=2, disable_tqdm=False):
|
|
9
|
-
super().__init__()
|
|
10
|
-
self.model = model
|
|
11
|
-
self.model_params = model_params
|
|
12
|
-
self.device = device
|
|
13
|
-
self.rh = rh
|
|
14
|
-
self.conditioning = conditioning
|
|
15
|
-
self.dims = dims
|
|
16
|
-
# Quick and dirty hack to force to be even
|
|
17
|
-
self.tile_size = tile_size
|
|
18
|
-
self.tile_overlap = tile_overlap
|
|
19
|
-
self.batch_size = batch_size
|
|
20
|
-
self._is_cancelled = False
|
|
21
|
-
self.disable_tqdm = disable_tqdm
|
|
22
|
-
|
|
23
|
-
def cancel(self):
|
|
24
|
-
self._is_cancelled = True
|
|
25
|
-
|
|
26
|
-
def _tile_process(self):
|
|
27
|
-
# Prepare Data
|
|
28
|
-
image_RGGB = self.rh.as_rggb(dims=self.dims, colorspace='lin_rec2020')
|
|
29
|
-
image_RGB = self.rh.as_rgb(dims=self.dims, demosaicing_func=demosaicing_CFA_Bayer_Malvar2004, colorspace='lin_rec2020', clip=True)
|
|
30
|
-
tensor_image = torch.from_numpy(image_RGGB).unsqueeze(0).contiguous()
|
|
31
|
-
tensor_RGB = torch.from_numpy(image_RGB).unsqueeze(0).contiguous()
|
|
32
|
-
|
|
33
|
-
full_size = [image_RGGB.shape[1], image_RGGB.shape[2]]
|
|
34
|
-
tile_size = [self.tile_size // 2, self.tile_size // 2]
|
|
35
|
-
overlap = [self.tile_overlap, self.tile_overlap]
|
|
36
|
-
|
|
37
|
-
# Tiling Setup
|
|
38
|
-
tiling_module = TilingModule(tile_size=tile_size, tile_overlap=overlap, base_size=full_size)
|
|
39
|
-
tiling_module_rgb = TilingModule(tile_size=[s*2 for s in tile_size], tile_overlap=overlap, base_size=[s*2 for s in full_size])
|
|
40
|
-
tiling_module_rebuild = TilingModule(tile_size=[s*2 for s in tile_size], tile_overlap=overlap, base_size=[s*2 for s in full_size])
|
|
41
|
-
|
|
42
|
-
tiles = tiling_module.split_into_tiles(tensor_image).float().to(self.device)
|
|
43
|
-
tiles_rgb = tiling_module_rgb.split_into_tiles(tensor_RGB).float().to(self.device)
|
|
44
|
-
|
|
45
|
-
batches = torch.split(tiles, self.batch_size)
|
|
46
|
-
batches_rgb = torch.split(tiles_rgb, self.batch_size)
|
|
47
|
-
|
|
48
|
-
# Conditioning Setup
|
|
49
|
-
cond_tensor = torch.as_tensor(self.conditioning, device=self.device).float().unsqueeze(0)
|
|
50
|
-
cond_tensor[:, 0] /= 6400
|
|
51
|
-
cond_tensor[:, 1] = 0
|
|
52
|
-
cond_tensor = cond_tensor[:, 0:1]
|
|
53
|
-
|
|
54
|
-
processed_batches = []
|
|
55
|
-
|
|
56
|
-
# Determine Dtype
|
|
57
|
-
dtype_map = {'mps': torch.float16, 'cuda': torch.float16, 'cpu': torch.bfloat16}
|
|
58
|
-
autocast_dtype = dtype_map.get(self.device.type, torch.float32)
|
|
59
|
-
|
|
60
|
-
total_batches = len(batches_rgb)
|
|
61
|
-
|
|
62
|
-
# Inference Loop
|
|
63
|
-
with torch.no_grad():
|
|
64
|
-
with torch.autocast(device_type=self.device.type, dtype=autocast_dtype):
|
|
65
|
-
for i, (batch, batch_rgb) in tqdm(enumerate(zip(batches, batches_rgb)), disable=self.disable_tqdm):
|
|
66
|
-
if self._is_cancelled: return None, None
|
|
67
|
-
|
|
68
|
-
B = batch.shape[0]
|
|
69
|
-
# Expand conditioning to match batch size
|
|
70
|
-
curr_cond = cond_tensor.expand(B, -1)
|
|
71
|
-
|
|
72
|
-
output = self.model(batch_rgb, curr_cond)
|
|
73
|
-
|
|
74
|
-
# Output processing
|
|
75
|
-
if "affine" in self.model_params:
|
|
76
|
-
output, _, _ = match_colors_linear(output, batch_rgb)
|
|
77
|
-
processed_batches.append(output.cpu())
|
|
78
|
-
|
|
79
|
-
# Rebuild
|
|
80
|
-
tiles_out = torch.cat(processed_batches, dim=0)
|
|
81
|
-
stitched = tiling_module_rebuild.rebuild_with_masks(tiles_out).detach().cpu().numpy()[0]
|
|
82
|
-
|
|
83
|
-
torch.cuda.empty_cache()
|
|
84
|
-
|
|
85
|
-
return image_RGB.transpose(1, 2, 0), stitched.transpose(1, 2, 0)
|
|
86
|
-
|
|
87
|
-
def run(self):
|
|
88
|
-
try:
|
|
89
|
-
img, denoised_img = self._tile_process()
|
|
90
|
-
|
|
91
|
-
# Post-process blending
|
|
92
|
-
blend_alpha = self.conditioning[1] / 100
|
|
93
|
-
final_denoised = (denoised_img * (1 - blend_alpha)) + (img * blend_alpha)
|
|
94
|
-
|
|
95
|
-
return img, final_denoised
|
|
96
|
-
|
|
97
|
-
except Exception as e:
|
|
98
|
-
print(str(e))
|
|
@@ -1,108 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from blended_tiling import TilingModule
|
|
3
|
-
from RawForge.application.postprocessing import match_colors_linear
|
|
4
|
-
from tqdm import tqdm
|
|
5
|
-
import rawpy
|
|
6
|
-
|
|
7
|
-
class InferenceWorkerRawpy():
|
|
8
|
-
def __init__(self, model, model_params, device, rh, conditioning, dims, tile_size=512, tile_overlap=0.25, batch_size=2, disable_tqdm=False):
|
|
9
|
-
super().__init__()
|
|
10
|
-
self.model = model
|
|
11
|
-
self.model_params = model_params
|
|
12
|
-
self.device = device
|
|
13
|
-
self.rh = rh
|
|
14
|
-
self.conditioning = conditioning
|
|
15
|
-
self.dims = dims
|
|
16
|
-
self.tile_size = tile_size
|
|
17
|
-
if 'tile_size' in model_params:
|
|
18
|
-
self.tile_size = model_params['tile_size']
|
|
19
|
-
self.tile_overlap = tile_overlap
|
|
20
|
-
self.batch_size = batch_size
|
|
21
|
-
if 'batch_size' in model_params:
|
|
22
|
-
self.batch_size = model_params['batch_size']
|
|
23
|
-
self._is_cancelled = False
|
|
24
|
-
self.disable_tqdm = disable_tqdm
|
|
25
|
-
|
|
26
|
-
def cancel(self):
|
|
27
|
-
self._is_cancelled = True
|
|
28
|
-
|
|
29
|
-
def _tile_process(self):
|
|
30
|
-
# Prepare Data
|
|
31
|
-
image_RGB = self.rh.rawpy_object.postprocess(
|
|
32
|
-
user_wb=[1, 1, 1, 1],
|
|
33
|
-
output_color=rawpy.ColorSpace.raw,
|
|
34
|
-
demosaic_algorithm= rawpy.DemosaicAlgorithm(3),
|
|
35
|
-
no_auto_bright=True,
|
|
36
|
-
use_camera_wb=False,
|
|
37
|
-
use_auto_wb=False,
|
|
38
|
-
gamma=(1, 1),
|
|
39
|
-
user_flip=0,
|
|
40
|
-
output_bps=16,
|
|
41
|
-
no_auto_scale=True,
|
|
42
|
-
) / self.rh.rawpy_object.white_level
|
|
43
|
-
|
|
44
|
-
image_RGB = image_RGB.transpose(2, 0, 1)
|
|
45
|
-
|
|
46
|
-
tensor_RGB = torch.from_numpy(image_RGB).unsqueeze(0).contiguous()
|
|
47
|
-
|
|
48
|
-
full_size = [image_RGB.shape[1], image_RGB.shape[2]]
|
|
49
|
-
tile_size = [self.tile_size, self.tile_size]
|
|
50
|
-
overlap = [self.tile_overlap, self.tile_overlap]
|
|
51
|
-
|
|
52
|
-
# Tiling Setup
|
|
53
|
-
tiling_module_rgb = TilingModule(tile_size=[s for s in tile_size], tile_overlap=overlap, base_size=[s for s in full_size])
|
|
54
|
-
|
|
55
|
-
tiles_rgb = tiling_module_rgb.split_into_tiles(tensor_RGB).float().to(self.device)
|
|
56
|
-
|
|
57
|
-
batches_rgb = torch.split(tiles_rgb, self.batch_size)
|
|
58
|
-
|
|
59
|
-
# Conditioning Setup
|
|
60
|
-
cond_tensor = torch.as_tensor(self.conditioning, device=self.device).float().unsqueeze(0)
|
|
61
|
-
cond_tensor[:, 0] /= 6400
|
|
62
|
-
cond_tensor[:, 1] = 0
|
|
63
|
-
cond_tensor = cond_tensor[:, 0:1]
|
|
64
|
-
|
|
65
|
-
processed_batches = []
|
|
66
|
-
|
|
67
|
-
# Determine Dtype
|
|
68
|
-
dtype_map = {'mps': torch.float16, 'cuda': torch.float16, 'cpu': torch.bfloat16}
|
|
69
|
-
autocast_dtype = dtype_map.get(self.device.type, torch.float32)
|
|
70
|
-
|
|
71
|
-
total_batches = len(batches_rgb)
|
|
72
|
-
|
|
73
|
-
# Inference Loop
|
|
74
|
-
with torch.no_grad():
|
|
75
|
-
with torch.autocast(device_type=self.device.type, dtype=autocast_dtype):
|
|
76
|
-
for i, (batch_rgb) in tqdm(enumerate(batches_rgb), disable=self.disable_tqdm):
|
|
77
|
-
if self._is_cancelled: return None, None
|
|
78
|
-
|
|
79
|
-
B = batch_rgb.shape[0]
|
|
80
|
-
# Expand conditioning to match batch size
|
|
81
|
-
curr_cond = cond_tensor.expand(B, -1)
|
|
82
|
-
output = self.model(batch_rgb, curr_cond*0)
|
|
83
|
-
processed_batches.append(output.cpu())
|
|
84
|
-
|
|
85
|
-
# Rebuild
|
|
86
|
-
tiles_out = torch.cat(processed_batches, dim=0)
|
|
87
|
-
stitched = tiling_module_rgb.rebuild_with_masks(tiles_out).detach().cpu()
|
|
88
|
-
|
|
89
|
-
if "affine" in self.model_params:
|
|
90
|
-
stitched, _, _ = match_colors_linear(stitched, tensor_RGB)
|
|
91
|
-
|
|
92
|
-
stitched = stitched.numpy()[0]
|
|
93
|
-
torch.cuda.empty_cache()
|
|
94
|
-
|
|
95
|
-
return image_RGB.transpose(1, 2, 0), stitched.transpose(1, 2, 0)
|
|
96
|
-
|
|
97
|
-
def run(self):
|
|
98
|
-
try:
|
|
99
|
-
img, denoised_img = self._tile_process()
|
|
100
|
-
|
|
101
|
-
# Post-process blending
|
|
102
|
-
blend_alpha = self.conditioning[1] / 100
|
|
103
|
-
final_denoised = (denoised_img * (1 - blend_alpha)) + (img * blend_alpha)
|
|
104
|
-
|
|
105
|
-
return img, final_denoised
|
|
106
|
-
|
|
107
|
-
except Exception as e:
|
|
108
|
-
print(str(e))
|
|
@@ -1,59 +0,0 @@
|
|
|
1
|
-
MODEL_REGISTRY = {
|
|
2
|
-
"TreeNetDenoise": {
|
|
3
|
-
"url": "https://github.com/rymuelle/RawRefinery/releases/download/v1.2.1-alpha/ShadowWeightedL1.pt",
|
|
4
|
-
"filename": "ShadowWeightedL1.pt",
|
|
5
|
-
"max_iso": 65535,
|
|
6
|
-
},
|
|
7
|
-
"TreeNetDenoiseLight": {
|
|
8
|
-
"url": " https://github.com/rymuelle/RawRefinery/releases/download/v1.2.1-alpha/ShadowWeightedL1_light.pt",
|
|
9
|
-
"filename": "ShadowWeightedL1_light.pt",
|
|
10
|
-
"max_iso": 65535,
|
|
11
|
-
},
|
|
12
|
-
"TreeNetDenoiseSuperLight": {
|
|
13
|
-
"url": "https://github.com/rymuelle/RawRefinery/releases/download/v1.2.1-alpha/ShadowWeightedL1_super_light.pt",
|
|
14
|
-
"filename": "ShadowWeightedL1_super_light.pt",
|
|
15
|
-
"max_iso": 65535,
|
|
16
|
-
},
|
|
17
|
-
"TreeNetDenoiseHeavy": {
|
|
18
|
-
"url": "https://github.com/rymuelle/RawRefinery/releases/download/v1.2.1-alpha/ShadowWeightedL1_24_deep_500.pt",
|
|
19
|
-
"filename": "ShadowWeightedL1_24_deep_500.pt",
|
|
20
|
-
"max_iso": 65535,
|
|
21
|
-
},
|
|
22
|
-
"Deblur": {
|
|
23
|
-
"url": "https://github.com/rymuelle/RawRefinery/releases/download/v1.2.1-alpha/realblur_gamma_140.pt",
|
|
24
|
-
"filename": "realblur_gamma_140.pt",
|
|
25
|
-
"affine": True,
|
|
26
|
-
},
|
|
27
|
-
"DeepSharpen": {
|
|
28
|
-
"url": "https://github.com/rymuelle/RawRefinery/releases/download/v1.2.1-alpha/Deblur_deep_24.pt",
|
|
29
|
-
"filename": "Deblur_deep_24.pt",
|
|
30
|
-
"affine": True,
|
|
31
|
-
},
|
|
32
|
-
"TreeNetDenoiseXTrans": {
|
|
33
|
-
"url": "https://github.com/rymuelle/RawForge/releases/download/xtrans_v1.0.0/xtrans_fixed_exposure_no_conditioning_400.pt",
|
|
34
|
-
"filename": "xtrans_fixed_exposure_no_conditioning_400.pt",
|
|
35
|
-
"backend": "rawpy",
|
|
36
|
-
"conditioning": "false",
|
|
37
|
-
},
|
|
38
|
-
"XFormerXTrans": {
|
|
39
|
-
"url": "https://github.com/rymuelle/RawForge/releases/download/xtrans_v1.0.0/xformer.pt",
|
|
40
|
-
"filename": "xformer.pt",
|
|
41
|
-
"backend": "rawpy",
|
|
42
|
-
"conditioning": "false",
|
|
43
|
-
"batch_size": 1,
|
|
44
|
-
},
|
|
45
|
-
"XFormerXTrans352": {
|
|
46
|
-
"url": "https://github.com/rymuelle/RawForge/releases/download/xtrans_v1.0.0/xformer_352.pt",
|
|
47
|
-
"filename": "xformer_352.pt",
|
|
48
|
-
"backend": "rawpy",
|
|
49
|
-
"conditioning": "false",
|
|
50
|
-
"batch_size": 1,
|
|
51
|
-
},
|
|
52
|
-
"RestormerXTrans": {
|
|
53
|
-
"url": "https://github.com/rymuelle/RawForge/releases/download/xtrans_v1.0.0/restormer.pt",
|
|
54
|
-
"filename": "restormer.pt",
|
|
55
|
-
"backend": "rawpy",
|
|
56
|
-
"conditioning": "false",
|
|
57
|
-
"batch_size": 1,
|
|
58
|
-
},
|
|
59
|
-
}
|
|
@@ -1,87 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
|
|
3
|
-
def match_colors_linear(
|
|
4
|
-
src: torch.Tensor,
|
|
5
|
-
tgt: torch.Tensor,
|
|
6
|
-
sample_fraction: float = 0.05
|
|
7
|
-
):
|
|
8
|
-
"""
|
|
9
|
-
Fit per-channel affine color transforms:
|
|
10
|
-
tgt ≈ scale * src + bias
|
|
11
|
-
|
|
12
|
-
Args:
|
|
13
|
-
src: [B, C, H, W] source tensor
|
|
14
|
-
tgt: [B, C, H, W] target tensor
|
|
15
|
-
sample_fraction: fraction of pixels to use for fitting
|
|
16
|
-
|
|
17
|
-
Returns:
|
|
18
|
-
transformed_src: source after color matching
|
|
19
|
-
scale: [B, C]
|
|
20
|
-
bias: [B, C]
|
|
21
|
-
"""
|
|
22
|
-
|
|
23
|
-
B, C, H, W = src.shape
|
|
24
|
-
device = src.device
|
|
25
|
-
|
|
26
|
-
# Flatten spatial dims
|
|
27
|
-
src_flat = src.view(B, C, -1)
|
|
28
|
-
tgt_flat = tgt.view(B, C, -1)
|
|
29
|
-
|
|
30
|
-
# Sample subset of pixels
|
|
31
|
-
N = src_flat.shape[-1]
|
|
32
|
-
k = max(64, int(N * sample_fraction))
|
|
33
|
-
|
|
34
|
-
idx = torch.randint(0, N, (k,), device=device)
|
|
35
|
-
|
|
36
|
-
src_s = src_flat[..., idx] # [B, C, k]
|
|
37
|
-
tgt_s = tgt_flat[..., idx]
|
|
38
|
-
|
|
39
|
-
# Compute scale and bias using least squares
|
|
40
|
-
# scale = cov(src, tgt) / var(src)
|
|
41
|
-
src_mean = src_s.mean(-1, keepdim=True)
|
|
42
|
-
tgt_mean = tgt_s.mean(-1, keepdim=True)
|
|
43
|
-
|
|
44
|
-
src_centered = src_s - src_mean
|
|
45
|
-
tgt_centered = tgt_s - tgt_mean
|
|
46
|
-
|
|
47
|
-
var_src = (src_centered ** 2).mean(-1)
|
|
48
|
-
cov = (src_centered * tgt_centered).mean(-1)
|
|
49
|
-
|
|
50
|
-
scale = cov / (var_src + 1e-8) # [B, C]
|
|
51
|
-
bias = tgt_mean.squeeze(-1) - scale * src_mean.squeeze(-1)
|
|
52
|
-
|
|
53
|
-
# Apply correction
|
|
54
|
-
scale_ = scale.view(B, C, 1, 1)
|
|
55
|
-
bias_ = bias.view(B, C, 1, 1)
|
|
56
|
-
transformed = src * scale_ + bias_
|
|
57
|
-
|
|
58
|
-
return transformed, scale, bias
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
def scaled_dot_product(x1, x2, eps=1e-6):
|
|
62
|
-
dot = (x1 * x2).sum(axis=2, keepdims=True)
|
|
63
|
-
x1_mag = (x1 * x1).sum(axis=2, keepdims=True) ** .5
|
|
64
|
-
x2_mag = (x2 * x2).sum(axis=2, keepdims=True) ** .5
|
|
65
|
-
return dot/(x1_mag+x2_mag+eps)
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
def postprocess(img, denoised, lumi_blend=0, chroma_blend=0, eps=1e-6,
|
|
69
|
-
clip_highlights=False):
|
|
70
|
-
# Suggested by Jakob Andrén
|
|
71
|
-
dot = (img * denoised).sum(axis=2, keepdims=True)
|
|
72
|
-
img_mag = (img * img).sum(axis=2, keepdims=True) ** .5
|
|
73
|
-
denoised_mag = (denoised * denoised).sum(axis=2, keepdims=True) ** .5
|
|
74
|
-
# Project denoised along original image vector
|
|
75
|
-
lumi = dot / (denoised_mag ** 2 + eps) * denoised
|
|
76
|
-
chroma = img - lumi
|
|
77
|
-
output = (1-lumi_blend) * denoised + lumi * (lumi_blend) + chroma_blend * chroma
|
|
78
|
-
|
|
79
|
-
if clip_highlights:
|
|
80
|
-
output = clip_highlights_func(img, output)
|
|
81
|
-
|
|
82
|
-
return output
|
|
83
|
-
|
|
84
|
-
def clip_highlights_func(img, denoised):
|
|
85
|
-
mask = img == 1
|
|
86
|
-
denoised[mask] = 1
|
|
87
|
-
return denoised
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|