rawforge 0.1.0__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. {rawforge-0.1.0 → rawforge-0.2.0}/PKG-INFO +10 -2
  2. rawforge-0.2.0/RawForge/application/ImageSaver.py +55 -0
  3. rawforge-0.2.0/RawForge/application/InferenceWorker.py +79 -0
  4. rawforge-0.2.0/RawForge/application/MODEL_REGISTRY.py +135 -0
  5. {rawforge-0.1.0 → rawforge-0.2.0}/RawForge/application/ModelHandler.py +15 -98
  6. rawforge-0.2.0/RawForge/application/helpers/censored_fit.py +100 -0
  7. rawforge-0.2.0/RawForge/application/helpers/get_image.py +41 -0
  8. rawforge-0.2.0/RawForge/application/helpers/utils.py +17 -0
  9. rawforge-0.2.0/RawForge/application/postprocessing.py +92 -0
  10. {rawforge-0.1.0 → rawforge-0.2.0}/RawForge/main.py +36 -21
  11. {rawforge-0.1.0 → rawforge-0.2.0}/pyproject.toml +10 -3
  12. {rawforge-0.1.0 → rawforge-0.2.0}/rawforge.egg-info/PKG-INFO +10 -2
  13. {rawforge-0.1.0 → rawforge-0.2.0}/rawforge.egg-info/SOURCES.txt +4 -2
  14. rawforge-0.2.0/rawforge.egg-info/entry_points.txt +2 -0
  15. {rawforge-0.1.0 → rawforge-0.2.0}/rawforge.egg-info/requires.txt +13 -1
  16. rawforge-0.1.0/RawForge/application/InferenceWorker.py +0 -98
  17. rawforge-0.1.0/RawForge/application/InferenceWorkerRawpy.py +0 -108
  18. rawforge-0.1.0/RawForge/application/MODEL_REGISTRY.py +0 -59
  19. rawforge-0.1.0/RawForge/application/postprocessing.py +0 -87
  20. rawforge-0.1.0/RawForge/application/utils.py +0 -10
  21. rawforge-0.1.0/rawforge.egg-info/entry_points.txt +0 -2
  22. {rawforge-0.1.0 → rawforge-0.2.0}/LICENSE +0 -0
  23. {rawforge-0.1.0 → rawforge-0.2.0}/README.md +0 -0
  24. {rawforge-0.1.0 → rawforge-0.2.0}/RawForge/__init__.py +0 -0
  25. {rawforge-0.1.0 → rawforge-0.2.0}/RawForge/application/dng_utils.py +0 -0
  26. {rawforge-0.1.0 → rawforge-0.2.0}/RawForge/scripts/generate_keys.py +0 -0
  27. {rawforge-0.1.0 → rawforge-0.2.0}/RawForge/scripts/sign_models.py +0 -0
  28. {rawforge-0.1.0 → rawforge-0.2.0}/rawforge.egg-info/dependency_links.txt +0 -0
  29. {rawforge-0.1.0 → rawforge-0.2.0}/rawforge.egg-info/top_level.txt +0 -0
  30. {rawforge-0.1.0 → rawforge-0.2.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rawforge
3
- Version: 0.1.0
3
+ Version: 0.2.0
4
4
  Summary: A compute backend/CLI application for using machine learning models on raw images.
5
5
  Author: Ryan Mueller
6
6
  License: MIT License
@@ -40,12 +40,20 @@ License-File: LICENSE
40
40
  Requires-Dist: numpy>=2.2
41
41
  Requires-Dist: RawHandler~=0.2.0
42
42
  Requires-Dist: colour_demosaicing~=0.2.6
43
- Requires-Dist: blended_tiling>=0.0.1.dev7
43
+ Requires-Dist: blended-tiling-numpy
44
44
  Requires-Dist: requests~=2.32
45
45
  Requires-Dist: platformdirs~=4.5
46
46
  Requires-Dist: tqdm~=4.67
47
47
  Requires-Dist: cryptography>=46.0
48
48
  Requires-Dist: tifffile
49
+ Provides-Extra: cpu
50
+ Requires-Dist: onnxruntime; extra == "cpu"
51
+ Provides-Extra: cuda
52
+ Requires-Dist: onnxruntime-gpu; extra == "cuda"
53
+ Provides-Extra: web
54
+ Requires-Dist: onnxruntime-web; extra == "web"
55
+ Provides-Extra: directml
56
+ Requires-Dist: onnxruntime-directml; extra == "directml"
49
57
  Dynamic: license-file
50
58
 
51
59
  # RawRefinery
@@ -0,0 +1,55 @@
1
+ import numpy as np
2
+ from RawForge.application.dng_utils import convert_color_matrix, to_dng
3
+ import tifffile
4
+
5
class ImageSaver():
    """Write a processed image to disk as an 8-bit TIFF or as a DNG.

    Args:
        model_params: registry entry of the model that produced the image;
            only its ``'demosaicing'`` key is read here.
        rh: raw handler of the source file; supplies colour transforms,
            metadata and (for the rawpy path) ``to_dng``.
        dims: crop window forwarded to ``rh.compute_mask_and_sparse`` on the
            rawpy path.
    """

    def __init__(self, model_params, rh, dims=None):
        self.rh = rh
        self.model_params = model_params
        self.dims = dims

    def to_tiff(self, image, filename, apply_ccm=True):
        """Save ``image`` as a gamma-2.2 encoded 8-bit RGB TIFF.

        Args:
            image: array whose ``image[0]`` is channels-first (C, H, W).
            filename: output path.
            apply_ccm: if True, multiply the pixels by
                ``rh.rgb_colorspace_transform(colorspace='lin_rec2020')``
                before encoding.
        """
        rgb = image[0].transpose(1, 2, 0)
        if apply_ccm:
            transform_matrix = self.rh.rgb_colorspace_transform(colorspace='lin_rec2020')
            rgb = rgb @ transform_matrix.T

        rgb = np.clip(rgb, 0, 1)
        rgb = rgb ** (1 / 2.2)
        uint_img = (rgb * (2 ** 8 - 1)).astype(np.uint8)

        tifffile.imwrite(
            filename,
            uint_img,
            photometric='rgb',  # Explicitly define the color space
            compression='deflate',  # Optional: Lossless compression supported by darktable
        )

    def to_raw(self, denoised, filename, save_cfa):
        """Save ``denoised`` (batch-first, ``denoised[0]`` is (C, H, W)) as a DNG.

        For 'rawpy' models the image is collapsed back to a single-plane CFA
        with ``rh``'s mask, rescaled by ``white_level`` plus the first black
        level, and written with ``rh.to_dng``. Otherwise it is converted out
        of ``rh.colorspace`` and written as 16-bit data through ``to_dng``
        with an XYZ-derived colour matrix.
        """
        if self.model_params['demosaicing'] == 'rawpy':
            _, mask = self.rh.compute_mask_and_sparse(dims=self.dims)
            cfa = denoised[0].clip(0, 1)
            # Keep only the masked entries per pixel and collapse the channel axis.
            # (Presumably the mask marks the sensor's CFA colour — TODO confirm.)
            cfa = np.where(mask, cfa, 0).sum(axis=0)
            meta = self.rh.core_metadata
            cfa = cfa * meta.white_level + meta.black_level_per_channel[0]
            self.rh.to_dng(filename, uint_img=cfa)
        else:
            transform_matrix = np.linalg.inv(
                self.rh.rgb_colorspace_transform(colorspace=self.rh.colorspace)
            )
            ccm = np.linalg.inv(self.rh.rgb_colorspace_transform(colorspace='XYZ'))

            rgb = denoised[0].transpose(1, 2, 0)
            transformed = rgb @ transform_matrix.T
            # Scale to the full uint16 range: multiply by 65535, not by
            # 65536 and then subtract 1 (which `x * 2**16-1` would do).
            max_val = 2 ** 16 - 1
            uint_img = np.clip(transformed * max_val, 0, max_val).astype(np.uint16)
            ccm1 = convert_color_matrix(ccm)
            to_dng(uint_img, self.rh, filename, ccm1, save_cfa=save_cfa, convert_to_cfa=True)
@@ -0,0 +1,79 @@
1
+ import onnxruntime as ort
2
+ import numpy as np
3
+ from blended_tiling_numpy import TilingModule
4
+ from RawForge.application.postprocessing import match_colors_linear
5
+ from tqdm import tqdm
6
+ import rawpy
7
+
8
class InferenceWorker():
    """Run a model over an image in overlapping, blended tiles.

    ``model`` is used through ``get_inputs()`` and ``run(["output"], inputs)``,
    i.e. the ONNX Runtime ``InferenceSession`` that ModelHandler builds.
    ``'tile_size'`` and ``'batch_size'`` entries in ``model_params`` override
    the corresponding constructor arguments.
    """

    def __init__(self, model, model_params, conditioning, tile_size=512, tile_overlap=0.25, batch_size=2, disable_tqdm=False):
        super().__init__()
        self.model = model
        self.model_params = model_params
        # Only the first entry is fed to the model (main() passes [iso, 0] by default).
        self.conditioning = conditioning
        self.tile_size = model_params.get('tile_size', tile_size)
        self.tile_overlap = tile_overlap
        self.batch_size = model_params.get('batch_size', batch_size)
        self._is_cancelled = False
        self.disable_tqdm = disable_tqdm

    def cancel(self):
        """Ask a running ``_tile_process`` to stop before its next batch."""
        self._is_cancelled = True

    def _tile_process(self, image_RGB, model_params):
        """Split ``image_RGB`` into tiles, run the model on them, and re-stitch.

        Args:
            image_RGB: array whose axes 2 and 3 are height and width
                (``get_image`` produces shape (1, C, H, W)).
            model_params: registry entry; ``'cond_scale'`` is read here.

        Returns:
            ``(image_RGB, stitched)``, or ``(None, None)`` when cancelled.
        """
        full_size = [image_RGB.shape[2], image_RGB.shape[3]]
        tile_size = [self.tile_size, self.tile_size]
        overlap = [self.tile_overlap, self.tile_overlap]
        tiling_module_rgb = TilingModule(tile_size=tile_size, tile_overlap=overlap, base_size=full_size)

        tiles_rgb = tiling_module_rgb.split_into_tiles(image_RGB)
        batches_rgb = [tiles_rgb[i:i + self.batch_size]
                       for i in range(0, len(tiles_rgb), self.batch_size)]

        # Only the first conditioning value is used, divided by 6400. Slicing
        # first means a single-element conditioning list also works.
        cond_tensor = np.array([self.conditioning], dtype=np.float32)[:, 0:1]
        cond_tensor /= 6400.
        cond_tensor = cond_tensor.astype(np.float16)
        if 'cond_scale' in model_params:
            # Linear scaling by the registry factor (not cond * cond * scale).
            cond_tensor *= model_params['cond_scale']

        # The graph's declared input names are fixed, so look them up once.
        model_inputs = {inp.name for inp in self.model.get_inputs()}
        processed_batches = []

        for batch_rgb in tqdm(batches_rgb, disable=self.disable_tqdm):
            if self._is_cancelled:
                return None, None

            batch = batch_rgb.shape[0]
            # Expand conditioning to match batch size
            curr_cond = np.broadcast_to(cond_tensor, (batch, cond_tensor.shape[-1])).astype(np.float16)

            payload = {"input": batch_rgb, "cond": curr_cond}
            # Models without a "cond" input simply do not receive it.
            filtered_inputs = {k: v for k, v in payload.items() if k in model_inputs}
            output = self.model.run(["output"], filtered_inputs)
            processed_batches.append(output[0])

        tiles_out = np.concat(processed_batches, axis=0)
        stitched = tiling_module_rgb.rebuild_with_masks(tiles_out)
        return image_RGB, stitched

    def run(self, model_params, image_RGB):
        """Convenience wrapper around ``_tile_process``."""
        img, denoised_img = self._tile_process(image_RGB, model_params)
        return img, denoised_img
79
+
@@ -0,0 +1,135 @@
1
# Registry of downloadable models, keyed by the name passed on the command line.
#   url / filename : release asset and its local file name (url must end with filename).
#   demosaicing    : 'Malvar2004' or 'rawpy'; selects the loading and saving pipeline.
#   batch_size     : read by InferenceWorker; overrides the default batch size.
#   cond_scale     : read by InferenceWorker; multiplies the conditioning input.
# NOTE(review): 'max_iso', 'conditioning' and 'crop_size' are not read by the
# code visible in this package version (InferenceWorker reads 'tile_size', not
# 'crop_size'); 'conditioning' is the string "false", which is truthy if ever
# tested as a bool — confirm intent.
MODEL_REGISTRY = {
    "TreeNetDenoise": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/ShadowWeightedL1.onnx",
        "filename": "ShadowWeightedL1.onnx",
        "max_iso": 65535,
        "demosaicing": "Malvar2004",
    },
    "TreeNetDenoiseLight": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/ShadowWeightedL1_light.onnx",
        "filename": "ShadowWeightedL1_light.onnx",
        "max_iso": 65535,
        "demosaicing": "Malvar2004",
    },
    "TreeNetDenoiseSuperLight": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/ShadowWeightedL1_super_light.onnx",
        "filename": "ShadowWeightedL1_super_light.onnx",
        "max_iso": 65535,
        "demosaicing": "Malvar2004",
    },
    "TreeNetDenoiseHeavy": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/ShadowWeightedL1_24_deep_500.onnx",
        "filename": "ShadowWeightedL1_24_deep_500.onnx",
        "max_iso": 65535,
        "demosaicing": "Malvar2004",
    },
    "Deblur": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/realblur_gamma_140.onnx",
        "filename": "realblur_gamma_140.onnx",
        "demosaicing": "Malvar2004",
    },
    "DeepSharpen": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/Deblur_deep_24.onnx",
        "filename": "Deblur_deep_24.onnx",
        "demosaicing": "Malvar2004",
    },
    "TreeNetDenoiseXTrans": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/xtrans_fixed_exposure_no_conditioning_400.onnx",
        "filename": "xtrans_fixed_exposure_no_conditioning_400.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
        "cond_scale": 0,
    },
    "RestormerXTrans": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/restormer.onnx",
        "filename": "restormer.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
    },
    "XFormerXTrans2": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/xformer_fp16.onnx",
        "filename": "xformer_fp16.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
    },
    "XFormerDenoise": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/xformer_static.onnx",
        "filename": "xformer_static.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    "NAFDenoise": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/NAF_static.onnx",
        "filename": "NAF_static.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    "NAFDenoise_dynamic": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/NAF_dynamic.onnx",
        "filename": "NAF_dynamic.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    "NAFDenoise_dynamic_fp16": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/NAF_dynamic_fp16.onnx",
        "filename": "NAF_dynamic_fp16.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    "NAF_static_fp16": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/NAF_static_fp16.onnx",
        "filename": "NAF_static_fp16.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    # Test entries below.
    "NAF_trace_test": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/xtrans_fixed_exposure_no_conditioning_400.onnx",
        "filename": "xtrans_fixed_exposure_no_conditioning_400.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    "non_trace_test": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/non_trace_test.onnx",
        "filename": "non_trace_test.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    "trace_test": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/trace_test.onnx",
        "filename": "trace_test.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    "non_trace_test_opset_20_static_simp": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/non_trace_test_opset_20_static_simp.onnx",
        "filename": "non_trace_test_opset_20_static_simp.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
}
@@ -1,21 +1,15 @@
1
- import torch
1
+ import onnxruntime as ort
2
+ import os
2
3
  import numpy as np
3
4
  from pathlib import Path
4
5
  from platformdirs import user_data_dir
5
- from time import perf_counter
6
6
  import requests
7
7
  from cryptography.hazmat.primitives import hashes, serialization
8
8
  from cryptography.hazmat.primitives.asymmetric import padding
9
9
 
10
- from RawHandler.RawHandler import RawHandler
11
- from RawHandler.RawHandlerRawpy import RawHandlerRawpy
12
-
13
- from RawForge.application.dng_utils import convert_color_matrix, to_dng
14
- from RawForge.application.utils import can_use_gpu
15
-
16
10
  from RawForge.application.MODEL_REGISTRY import MODEL_REGISTRY
17
11
  from RawForge.application.InferenceWorker import InferenceWorker
18
- from RawForge.application.InferenceWorkerRawpy import InferenceWorkerRawpy
12
+ from RawForge.application.helpers.utils import get_best_providers
19
13
 
20
14
  key_string = '''-----BEGIN PUBLIC KEY-----
21
15
  MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA8iRGMPqFIFVF0TM/AbMI
@@ -40,46 +34,23 @@ class ModelHandler():
40
34
  def __init__(self):
41
35
  super().__init__()
42
36
 
43
- self.model = None
44
- self.rh = None
45
- self.iso = 100
46
- self.colorspace = 'lin_rec2020'
47
-
37
+ self.model = None
38
+ app_name = "RawForge"
39
+ self.data_dir = Path(user_data_dir(app_name))
40
+ self.cache_dir = Path(os.path.expanduser(f"{self.data_dir}/model_cache")).resolve()
48
41
  # Manage devices
49
- devices = {
50
- "cuda": can_use_gpu(),
51
- "mps": torch.backends.mps.is_available(),
52
- "cpu": lambda : True
53
- }
54
- self.devices = [d for d, is_available in devices.items() if is_available]
55
- self.set_device(self.devices[0])
56
-
57
- self.filename = None
58
- self.start_time = None
59
- self.model_params = {}
60
-
42
+ self.providers = get_best_providers(cache_dir=self.cache_dir)
61
43
  self.pub = serialization.load_pem_public_key(key_string.encode('utf-8'))
62
44
 
63
- def load_rh(self, path):
64
- """Loads the raw file handler"""
65
- if 'backend' in self.model_params and self.model_params['backend'] == 'rawpy':
66
- self.rh = RawHandlerRawpy(path)
67
- else:
68
- self.rh = RawHandler(path)
69
- self.iso = self.rh.full_metadata.get_ISO()
70
- return self.iso
71
45
 
72
46
  def load_model(self, model_key):
73
47
  """Loads a model by key from the registry"""
74
48
  if model_key not in MODEL_REGISTRY:
75
49
  print(f"Model {model_key} not found in registry.")
76
50
  return
77
-
78
51
  conf = MODEL_REGISTRY[model_key]
79
52
  self.model_params = conf
80
- app_name = "RawForge"
81
- data_dir = Path(user_data_dir(app_name))
82
- model_path = data_dir / conf["filename"]
53
+ model_path = self.data_dir / conf["filename"]
83
54
 
84
55
  # Handle Download
85
56
  if not model_path.is_file():
@@ -97,41 +68,15 @@ class ModelHandler():
97
68
  # Verify model before load
98
69
  self._verify_model(model_path, model_path.with_suffix(f'{model_path.suffix}.sig'))
99
70
 
100
- loaded = torch.jit.load(model_path, map_location='cpu')
101
- self.model = loaded.eval().to(self.device)
71
+ session = ort.InferenceSession(
72
+ model_path,
73
+ providers=self.providers,
74
+ )
75
+ print("Loaded!")
76
+ self.model = session
102
77
  except Exception as e:
103
78
  print(f"Failed to load model: {e}")
104
79
 
105
- def set_device(self, device):
106
- self.device = torch.device(device)
107
- if self.model:
108
- self.model.to(self.device)
109
- print(f"Using Device {self.device} from {device}")
110
-
111
- def run_inference(self, conditioning, dims=None, inference_kwargs={}):
112
- """Starts the worker thread"""
113
- if not self.model or not self.rh:
114
- print("Model or Image not loaded.")
115
- return
116
-
117
- # Some older models were trained with a lower max iso
118
- if "max_iso" in self.model_params:
119
- conditioning[0] = min(conditioning[0], self.model_params["max_iso"])
120
-
121
- if 'backend' in self.model_params and self.model_params['backend'] == 'rawpy':
122
- worker = InferenceWorkerRawpy(self.model, self.model_params, self.device, self.rh, conditioning, dims, **inference_kwargs)
123
- else:
124
- worker = InferenceWorker(self.model, self.model_params, self.device, self.rh, conditioning, dims, **inference_kwargs)
125
- img, final_denoised = worker.run()
126
-
127
- return img, final_denoised
128
-
129
-
130
- def generate_thumbnail(self, size=400):
131
- if not self.rh: return None
132
- thumb = self.rh.generate_thumbnail(min_preview_size=size, clip=True)
133
- return thumb
134
-
135
80
  def _verify_model(self, dest_path, sig_path):
136
81
  try:
137
82
  data = Path(dest_path).read_bytes()
@@ -156,7 +101,6 @@ class ModelHandler():
156
101
  print(f"Model {dest_path} not verified! Deleting.")
157
102
  return False
158
103
 
159
-
160
104
  def _download_file(self, url, dest_path):
161
105
  dest_path.parent.mkdir(parents=True, exist_ok=True)
162
106
  try:
@@ -179,31 +123,4 @@ class ModelHandler():
179
123
  except Exception as e:
180
124
  print(e)
181
125
  return False
182
-
183
-
184
- def handle_full_image(self, denoised, filename, save_cfa):
185
- # Compute CFA
186
- if 'backend' in self.model_params and self.model_params['backend'] == 'rawpy':
187
- _, mask = self.rh.compute_mask_and_sparse(dims=(0, 99999, 0, 99999))
188
- denoised = denoised.transpose(2, 0, 1)
189
- denoised = denoised.clip(0, 1)
190
-
191
- denoised = np.where(mask, denoised, 0)
192
- denoised = denoised.sum(axis=0)
193
- denoised = denoised * ( self.rh.core_metadata.white_level) + self.rh.core_metadata.black_level_per_channel[0]
194
- self.rh.to_dng(filename, uint_img=denoised)
195
- else:
196
- transform_matrix = np.linalg.inv(
197
- self.rh.rgb_colorspace_transform(colorspace=self.colorspace)
198
- )
199
-
200
- CCM = self.rh.rgb_colorspace_transform(colorspace='XYZ')
201
- CCM = np.linalg.inv(CCM)
202
-
203
- transformed = denoised @ transform_matrix.T
204
- uint_img = np.clip(transformed * 2**16-1, 0, 2**16-1).astype(np.uint16)
205
- ccm1 = convert_color_matrix(CCM)
206
- to_dng(uint_img, self.rh, filename, ccm1, save_cfa=save_cfa, convert_to_cfa=True)
207
-
208
-
209
126
 
@@ -0,0 +1,100 @@
1
+ import numpy as np
2
+ import math
3
+
4
def norm_cdf(z):
    """Standard normal CDF, Phi(z) = 0.5 * (1 + erf(z / sqrt(2))), elementwise.

    Args:
        z: scalar or array of standard scores.

    Returns:
        Float array (or 0-d result for scalar input) of probabilities.
        Empty input yields an empty float array.
    """
    # `otypes` fixes the output dtype up front; without it np.vectorize must
    # probe the first element and raises on size-0 input.
    erf_vec = np.vectorize(math.erf, otypes=[float])
    return 0.5 * (1 + erf_vec(z / np.sqrt(2)))
7
+
8
def norm_pdf(z):
    """Standard normal density exp(-z**2 / 2) / sqrt(2*pi), evaluated elementwise."""
    scale = 1 / np.sqrt(2 * np.pi)
    kernel = np.exp(-0.5 * z**2)
    return scale * kernel
10
+
11
+
12
def censored_linear_fit_twosided(x, y, clip_low=0, clip_high=1,
                                 max_iter=200, tol=1e-6, include_offset=True):
    """
    Fit y ≈ a + b*x + ε, ε ~ N(0, σ²) under two-sided censoring:
        clip_low ≤ y_true ≤ clip_high
    Observed y are clipped to [clip_low, clip_high].
    Returns (a, b, sigma) estimated via EM.

    Observations within 1e-12 of a clip level are treated as censored. Each
    iteration replaces them with their conditional expectation under the
    current Gaussian model (E-step), refits (a, b) by least squares, and
    recomputes sigma from the imputed residuals (M-step). The sigma update
    has no conditional-variance correction, so this is an approximate EM.

    Parameters
    ----------
    x, y : array_like
        Input data. Both are flattened; pairs with a non-finite value are dropped.
    clip_low, clip_high : float or None
        Lower/upper clip levels. Can be None for one-sided clipping.
    max_iter : int
        Maximum EM iterations.
    tol : float
        Relative and absolute tolerance for convergence.
    include_offset : bool
        If False, forces intercept a=0 (fit y ≈ b*x).

    Returns
    -------
    a, b, sigma : floats
        Estimated regression parameters from the last iteration performed.

    Raises
    ------
    ValueError
        If fewer than 3 finite (x, y) pairs remain.
    """
    x = np.asarray(x).ravel()
    y = np.asarray(y).ravel()
    finite = np.isfinite(x) & np.isfinite(y)
    x, y = x[finite], y[finite]

    if len(x) < 3:
        raise ValueError("Not enough data points.")

    # The design matrix and the censoring masks are the same on every
    # iteration, so build them once.
    design = np.vstack([np.ones_like(x), x]).T if include_offset else None
    high_mask = (y >= clip_high - 1e-12) if clip_high is not None else None
    low_mask = (y <= clip_low + 1e-12) if clip_low is not None else None
    has_high = high_mask is not None and np.any(high_mask)
    has_low = low_mask is not None and np.any(low_mask)

    def _fit(target):
        """Least-squares (a, b) for target ≈ a + b*x (a pinned to 0 without offset)."""
        if include_offset:
            a_fit, b_fit = np.linalg.lstsq(design, target, rcond=None)[0]
            return a_fit, b_fit
        return 0.0, np.dot(x, target) / np.dot(x, x)

    # Initial guess (ordinary least squares on the clipped data)
    a, b = _fit(y)
    sigma = np.std(y - (a + b * x))

    for _ in range(max_iter):
        mu = a + b * x
        y_exp = y.copy()

        # A perfect fit gives sigma == 0, so z becomes inf/nan. Those entries
        # fail the `valid` tests below (lambda stays 0), so the warnings are noise.
        with np.errstate(divide='ignore', invalid='ignore'):
            # Handle right-censoring (high clip)
            if has_high:
                z = (clip_high - mu[high_mask]) / sigma
                one_minus_Phi = 1.0 - norm_cdf(z)
                phi = norm_pdf(z)
                lambda_ = np.zeros_like(z)
                valid = one_minus_Phi > 1e-15
                lambda_[valid] = phi[valid] / one_minus_Phi[valid]
                y_exp[high_mask] = mu[high_mask] + sigma * lambda_

            # Handle left-censoring (low clip)
            if has_low:
                z = (clip_low - mu[low_mask]) / sigma
                Phi = norm_cdf(z)
                phi = norm_pdf(z)
                lambda_ = np.zeros_like(z)
                valid = Phi > 1e-15
                lambda_[valid] = -phi[valid] / Phi[valid]
                y_exp[low_mask] = mu[low_mask] + sigma * lambda_

        # M-step: re-fit with imputed expectations
        a_new, b_new = _fit(y_exp)
        sigma_new = np.std(y_exp - (a_new + b_new * x))

        converged = np.allclose([a, b, sigma], [a_new, b_new, sigma_new],
                                rtol=tol, atol=tol)
        # Keep the newest estimate even on the iteration that stops the loop.
        a, b, sigma = a_new, b_new, sigma_new
        if converged:
            break

    return a, b, sigma
@@ -0,0 +1,41 @@
1
+ import numpy as np
2
+ import rawpy
3
+ from RawHandler.RawHandler import RawHandler
4
+ from RawHandler.RawHandlerRawpy import RawHandlerRawpy
5
+
6
def load_rh(path, demosaicing):
    """Loads the raw file handler

    Args:
        path: raw file to open.
        demosaicing: 'rawpy' selects ``RawHandlerRawpy``; 'Malvar2004'
            selects ``RawHandler``.

    Returns:
        ``(rh, iso)``, where ``iso`` comes from ``rh.full_metadata.get_ISO()``.

    Raises:
        ValueError: if ``demosaicing`` is neither 'rawpy' nor 'Malvar2004'
            (previously this surfaced as an UnboundLocalError on ``rh``).
    """
    if demosaicing == 'rawpy':
        rh = RawHandlerRawpy(path)
    elif demosaicing == 'Malvar2004':
        rh = RawHandler(path)
    else:
        raise ValueError(f"Unsupported demosaicing method: {demosaicing!r}")
    iso = rh.full_metadata.get_ISO()
    return rh, iso
14
+
15
def get_image(path, model_params, dims=None):
    """Load a raw file and return it as a float16 image batch for the model.

    Args:
        path: raw file to open.
        model_params: registry entry; its ``'demosaicing'`` key selects the
            pipeline ('Malvar2004' or 'rawpy').
        dims: optional crop. For 'Malvar2004' it is forwarded to
            ``rh.as_rgb``; for 'rawpy' it is applied as
            ``image[dims[0]:dims[1], dims[2]:dims[3]]`` (rows, then columns).

    Returns:
        ``(rh, image_RGB, iso)``, where ``image_RGB`` is float16 with a
        leading batch axis of size 1.

    Raises:
        ValueError: for an unsupported demosaicing method; raised before the
            file is opened.
    """
    demosaicing = model_params['demosaicing']
    if demosaicing not in ('Malvar2004', 'rawpy'):
        raise ValueError(f"Unsupported demosaicing method: {demosaicing!r}")

    rh, iso = load_rh(path, demosaicing)

    if demosaicing == 'Malvar2004':
        from colour_demosaicing import demosaicing_CFA_Bayer_Malvar2004
        # NOTE(review): assumes rh.as_rgb already returns channels-first
        # (C, H, W), since it is not transposed here — TODO confirm.
        image_RGB = rh.as_rgb(dims=dims,
                              demosaicing_func=demosaicing_CFA_Bayer_Malvar2004,
                              colorspace='lin_rec2020', clip=True)
    else:
        # Linear camera-space output with no white balance or auto scaling,
        # normalised by the sensor white level.
        image_RGB = rh.rawpy_object.postprocess(
            user_wb=[1, 1, 1, 1],
            output_color=rawpy.ColorSpace.raw,
            demosaic_algorithm=rawpy.DemosaicAlgorithm(3),
            no_auto_bright=True,
            use_camera_wb=False,
            use_auto_wb=False,
            gamma=(1, 1),
            user_flip=0,
            output_bps=16,
            no_auto_scale=True,
        ) / rh.rawpy_object.white_level
        if dims is not None:
            image_RGB = image_RGB[dims[0]:dims[1], dims[2]:dims[3]]
        # rawpy returns (H, W, C); the model expects channels first.
        image_RGB = image_RGB.transpose(2, 0, 1)

    image_RGB = np.expand_dims(image_RGB, axis=0).astype(np.float16)
    return rh, image_RGB, iso
@@ -0,0 +1,17 @@
1
+ import onnxruntime as ort
2
+ import os
3
+
4
def get_best_providers(cache_dir, available=None):
    """Return the ONNX Runtime execution providers to use, best first.

    Args:
        cache_dir: directory that is created if it does not exist.
        available: provider names to choose from; defaults to
            ``ort.get_available_providers()``.

    Returns:
        The providers from ``available`` that appear in the priority list
        below, in that priority order. If none of them is available,
        ``available`` is returned as a list so the session is never created
        with an empty provider list.
    """
    os.makedirs(cache_dir, exist_ok=True)
    if available is None:
        available = ort.get_available_providers()

    priority = [
        "CUDAExecutionProvider",  # NVIDIA
        "ROCMExecutionProvider",  # AMD (Direct)
        "MIGraphXExecutionProvider",  # AMD (Optimized)
        "CoreMLExecutionProvider",  # Apple Silicon
        "DmlExecutionProvider",  # Windows (AMD/Intel/Generic GPU)
        "CPUExecutionProvider",  # The fallback
    ]

    preferred = [p for p in priority if p in available]
    return preferred or list(available)
@@ -0,0 +1,92 @@
1
+ import numpy as np
2
+ from RawForge.application.helpers.censored_fit import censored_linear_fit_twosided
3
+
4
def match_colors_linear(
    src: np.ndarray,
    tgt: np.ndarray,
    sample_fraction: float = 0.05,
    censored_fit: bool = True,
    rng=None,
):
    """
    Fit per-channel affine color transforms using NumPy.
    src/tgt: [B, C, H, W]

    For each (batch, channel) pair, ``scale`` and ``bias`` solve
    ``tgt ≈ scale * src + bias`` by least squares (regularised by 1e-8) on a
    random subset of pixels, and ``src`` is mapped through that transform.
    With ``censored_fit`` one more affine ``a + b * transformed`` from
    ``censored_linear_fit_twosided`` is applied on top.

    Args:
        src: image to correct, shape [B, C, H, W].
        tgt: reference image, same shape as ``src``.
        sample_fraction: fraction of the H*W pixel positions to sample.
            At least 64 are drawn, with replacement, and the same positions
            are used for every batch and channel.
        censored_fit: apply the censored (clip-aware) second stage.
        rng: optional ``np.random.Generator`` for reproducible sampling.
            When omitted, the legacy global ``np.random`` state is used as before.

    Returns:
        ``(transformed, scale, bias)``; ``scale`` and ``bias`` have shape
        [B, C] and describe the first stage only.
    """
    B, C, H, W = src.shape

    # Flatten spatial dims: [B, C, H*W]
    src_flat = src.reshape(B, C, -1)
    tgt_flat = tgt.reshape(B, C, -1)

    # Sample subset of pixels
    N = src_flat.shape[-1]
    k = max(64, int(N * sample_fraction))
    if rng is None:
        idx = np.random.randint(0, N, size=(k,))
    else:
        idx = rng.integers(0, N, size=(k,))

    src_s = src_flat[..., idx]  # [B, C, k]
    tgt_s = tgt_flat[..., idx]

    # Compute scale and bias using least squares
    src_mean = src_s.mean(axis=-1, keepdims=True)
    tgt_mean = tgt_s.mean(axis=-1, keepdims=True)
    src_centered = src_s - src_mean
    tgt_centered = tgt_s - tgt_mean

    var_src = (src_centered ** 2).mean(axis=-1)
    cov = (src_centered * tgt_centered).mean(axis=-1)

    scale = cov / (var_src + 1e-8)  # [B, C]
    bias = tgt_mean.squeeze(-1) - scale * src_mean.squeeze(-1)

    # Apply correction: reshape for broadcasting to [B, C, H, W]
    scale_ = scale[:, :, np.newaxis, np.newaxis]
    bias_ = bias[:, :, np.newaxis, np.newaxis]
    transformed = src * scale_ + bias_

    if censored_fit:
        # NOTE(review): x is the target sample and y the source sample here,
        # pooled over all batches and channels, so the fit models src as a
        # function of tgt; it is then applied to the already-matched output.
        # Confirm this argument order is intended.
        a, b, sigma = censored_linear_fit_twosided(tgt_s.astype(np.float32), src_s.astype(np.float32))
        transformed = a + b * transformed

    return transformed, scale, bias
54
+
55
+
56
def scaled_dot_product(x1, x2, eps=1e-6):
    """Dot product of ``x1`` and ``x2`` along axis 2, divided by the sum of
    their axis-2 Euclidean norms plus ``eps``. Axis 2 is kept (size 1).

    NOTE(review): cosine similarity would divide by the *product* of the
    norms; confirm that the sum is intended.
    """
    reduce_kw = {"axis": 2, "keepdims": True}
    numerator = (x1 * x2).sum(**reduce_kw)
    norm1 = np.sqrt((x1 * x1).sum(**reduce_kw))
    norm2 = np.sqrt((x2 * x2).sum(**reduce_kw))
    denominator = norm1 + norm2 + eps
    return numerator / denominator
62
+
63
+
64
def postprocess(img, denoised, lumi_blend=0, chroma_blend=0, eps=1e-6,
                clip_highlights=False, affine=False):
    """Blend the model output back toward the original image.

    Per pixel, the original image is split into its projection onto the
    denoised colour vector (``lumi``) and the remainder (``chroma``):

        output = (1 - lumi_blend) * denoised + lumi_blend * lumi
                 + chroma_blend * chroma

    Args:
        img: original image, either (H, W, C) or (N, C, H, W).
        denoised: model output with the same layout as ``img``.
        lumi_blend: weight given to the projected component of ``img``.
        chroma_blend: weight given to the residual component of ``img``.
        eps: added to the squared norm in the projection denominator.
        clip_highlights: if True, positions where ``img == 1`` are set to 1.
        affine: if True, first fit ``denoised`` to ``img`` with
            ``match_colors_linear``.

    Returns:
        The blended image, in the same layout as the inputs.
    """
    # main() passes get_image()'s (N, C, H, W) arrays, where channels are
    # axis 1; (H, W, C) inputs keep channels on axis 2 as before.
    channel_axis = 1 if img.ndim == 4 else 2

    if affine:
        if img.ndim == 4:
            denoised, _, _ = match_colors_linear(denoised, img)
        else:
            _denoised = np.expand_dims(denoised.transpose(2, 0, 1), axis=0)
            _img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
            _denoised, _, _ = match_colors_linear(_denoised, _img)
            denoised = _denoised[0].transpose(1, 2, 0)

    dot = (img * denoised).sum(axis=channel_axis, keepdims=True)
    denoised_sq_norm = (denoised * denoised).sum(axis=channel_axis, keepdims=True)

    # Project the original image onto the denoised colour direction.
    lumi = dot / (denoised_sq_norm + eps) * denoised
    chroma = img - lumi
    output = (1 - lumi_blend) * denoised + lumi * lumi_blend + chroma_blend * chroma

    if clip_highlights:
        output = clip_highlights_func(img, output)

    return output
87
+
88
def clip_highlights_func(img, denoised):
    """Return a copy of ``denoised`` with every position where ``img == 1`` set to 1.

    ``denoised`` is left unmodified. ``img == 1`` is used as a boolean index,
    so ``img`` must have the same shape as ``denoised``.
    """
    clipped = np.copy(denoised)
    clipped[img == 1] = 1
    return clipped
@@ -1,8 +1,15 @@
1
1
  import sys
2
2
  import platform
3
+ import argparse
4
+ from pathlib import Path
3
5
  from RawForge.application.ModelHandler import ModelHandler
4
6
  from RawForge.application.postprocessing import postprocess
5
- import argparse
7
+ from RawForge.application.MODEL_REGISTRY import MODEL_REGISTRY
8
+ from RawForge.application.InferenceWorker import InferenceWorker
9
+ from RawForge.application.ImageSaver import ImageSaver
10
+ from RawForge.application.helpers.get_image import get_image
11
+ import numpy as np
12
+
6
13
  # import glob
7
14
 
8
15
  def main():
@@ -17,13 +24,12 @@ def main():
17
24
  parser.add_argument('--device', type=str, help='Set device backend (cuda, cpu, mps).')
18
25
  parser.add_argument('--disable_tqdm', action='store_true', help='Disable the progress bar.')
19
26
  parser.add_argument('--tile_size', type=int, help='Set tile size. (default: 256)', default=256)
27
+ parser.add_argument('--tile_overlap', type=float, help='Set tile overlap. (default: 0.25)', default=0.25)
20
28
 
21
29
  parser.add_argument('--lumi', type=float, help='Lumi noise (0-1).', default=0)
22
30
  parser.add_argument('--chroma', type=float, help='Chroma noise (0-1).', default=0)
23
31
  parser.add_argument('--clip_highlights', action='store_true', help='Do not run model on clipped highlights.')
24
-
25
- args = parser.parse_args()
26
-
32
+ parser.add_argument('--affine', action='store_true', help='Affine fit the model to the input.')
27
33
  # # Glob handeling
28
34
  # in_files = sorted(glob.glob(args.in_file))
29
35
  # if not in_files:
@@ -33,30 +39,39 @@ def main():
33
39
  # raise ValueError(
34
40
  # "When using glob input, out_file must be a directory."
35
41
  # )
36
-
37
- handler = ModelHandler()
38
-
39
- handler.load_model(args.model)
40
-
41
- iso = handler.load_rh(args.in_file)
42
+ args = parser.parse_args()
42
43
 
44
+ # Initialize
45
+ models = args.model.split(',')
46
+ model_params = MODEL_REGISTRY[models[0]]
47
+ rh, image_RGB, iso = get_image(args.in_file, model_params, dims=args.dims)
43
48
  if not args.conditioning:
44
49
  conditioning = [iso, 0]
45
50
  else:
46
51
  conditioning = [int(x) for x in args.conditioning.split(',')]
47
-
48
- if args.device:
49
- handler.set_device(args.device)
50
-
51
-
52
52
  inference_kwargs = {"disable_tqdm": args.disable_tqdm,
53
- "tile_size": args.tile_size}
54
- img, denoised_image = handler.run_inference(conditioning=conditioning, dims=args.dims, inference_kwargs=inference_kwargs)
55
-
56
- output = postprocess(img, denoised_image, lumi_blend=args.lumi, chroma_blend=args.chroma, eps=1e-6,
57
- clip_highlights=args.clip_highlights)
58
- handler.handle_full_image(output, args.out_file, args.cfa)
53
+ "tile_size": args.tile_size,
54
+ "tile_overlap": args.tile_overlap}
55
+
56
+ # Run processing
57
+ output_img = image_RGB
58
+ for model in models:
59
+ handler = ModelHandler()
60
+ handler.load_model(model)
61
+ worker = InferenceWorker(handler.model, handler.model_params, conditioning, **inference_kwargs)
62
+ _, output_img = worker._tile_process(output_img, handler.model_params)
59
63
 
64
+ # Postprocess
65
+ output = postprocess(image_RGB, output_img, lumi_blend=args.lumi, chroma_blend=args.chroma, eps=1e-6,
66
+ clip_highlights=args.clip_highlights, affine=args.affine)
67
+
68
+ # Save image
69
+ saver = ImageSaver(model_params, rh, dims=args.dims)
70
+ apply_ccm = model_params['demosaicing'] == 'rawpy'
71
+ if Path(args.out_file).suffix=='.tiff':
72
+ saver.to_tiff(output, args.out_file, apply_ccm=apply_ccm)
73
+ else:
74
+ saver.to_raw(output, args.out_file, args.cfa)
60
75
 
61
76
  if __name__ == '__main__':
62
77
  main()
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "rawforge"
3
- version = "0.1.0"
3
+ version = "0.2.0"
4
4
  description = "A compute backend/CLI application for using machine learning models on raw images."
5
5
  authors = [
6
6
  { name = "Ryan Mueller"},
@@ -21,7 +21,7 @@ dependencies = [
21
21
  "numpy >=2.2",
22
22
  "RawHandler ~=0.2.0",
23
23
  "colour_demosaicing ~=0.2.6",
24
- "blended_tiling >=0.0.1.dev7",
24
+ "blended-tiling-numpy",
25
25
  "requests ~=2.32",
26
26
  "platformdirs ~=4.5",
27
27
  "tqdm ~=4.67",
@@ -30,6 +30,13 @@ dependencies = [
30
30
  ]
31
31
 
32
32
 
33
+ [project.optional-dependencies]
34
+ cpu = ["onnxruntime"]
35
+ cuda = ["onnxruntime-gpu"]
36
+ web = ["onnxruntime-web"]
37
+ directml = ["onnxruntime-directml"]
38
+
39
+
33
40
  [project.urls]
34
41
  "Homepage" = "https://github.com/rymuelle/RawForge"
35
42
  "Bug Tracker" = "https://github.com/rymuelle/RawForge/issues"
@@ -43,4 +50,4 @@ where = ["."]
43
50
  include = ["RawForge*"]
44
51
 
45
52
  [project.scripts]
46
- rawforge = "RawForge.main:main"
53
+ rawforgeonnx = "RawForge.main:main"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rawforge
3
- Version: 0.1.0
3
+ Version: 0.2.0
4
4
  Summary: A compute backend/CLI application for using machine learning models on raw images.
5
5
  Author: Ryan Mueller
6
6
  License: MIT License
@@ -40,12 +40,20 @@ License-File: LICENSE
40
40
  Requires-Dist: numpy>=2.2
41
41
  Requires-Dist: RawHandler~=0.2.0
42
42
  Requires-Dist: colour_demosaicing~=0.2.6
43
- Requires-Dist: blended_tiling>=0.0.1.dev7
43
+ Requires-Dist: blended-tiling-numpy
44
44
  Requires-Dist: requests~=2.32
45
45
  Requires-Dist: platformdirs~=4.5
46
46
  Requires-Dist: tqdm~=4.67
47
47
  Requires-Dist: cryptography>=46.0
48
48
  Requires-Dist: tifffile
49
+ Provides-Extra: cpu
50
+ Requires-Dist: onnxruntime; extra == "cpu"
51
+ Provides-Extra: cuda
52
+ Requires-Dist: onnxruntime-gpu; extra == "cuda"
53
+ Provides-Extra: web
54
+ Requires-Dist: onnxruntime-web; extra == "web"
55
+ Provides-Extra: directml
56
+ Requires-Dist: onnxruntime-directml; extra == "directml"
49
57
  Dynamic: license-file
50
58
 
51
59
  # RawRefinery
@@ -3,13 +3,15 @@ README.md
3
3
  pyproject.toml
4
4
  RawForge/__init__.py
5
5
  RawForge/main.py
6
+ RawForge/application/ImageSaver.py
6
7
  RawForge/application/InferenceWorker.py
7
- RawForge/application/InferenceWorkerRawpy.py
8
8
  RawForge/application/MODEL_REGISTRY.py
9
9
  RawForge/application/ModelHandler.py
10
10
  RawForge/application/dng_utils.py
11
11
  RawForge/application/postprocessing.py
12
- RawForge/application/utils.py
12
+ RawForge/application/helpers/censored_fit.py
13
+ RawForge/application/helpers/get_image.py
14
+ RawForge/application/helpers/utils.py
13
15
  RawForge/scripts/generate_keys.py
14
16
  RawForge/scripts/sign_models.py
15
17
  rawforge.egg-info/PKG-INFO
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ rawforgeonnx = RawForge.main:main
@@ -1,9 +1,21 @@
1
1
  numpy>=2.2
2
2
  RawHandler~=0.2.0
3
3
  colour_demosaicing~=0.2.6
4
- blended_tiling>=0.0.1.dev7
4
+ blended-tiling-numpy
5
5
  requests~=2.32
6
6
  platformdirs~=4.5
7
7
  tqdm~=4.67
8
8
  cryptography>=46.0
9
9
  tifffile
10
+
11
+ [cpu]
12
+ onnxruntime
13
+
14
+ [cuda]
15
+ onnxruntime-gpu
16
+
17
+ [directml]
18
+ onnxruntime-directml
19
+
20
+ [web]
21
+ onnxruntime-web
@@ -1,98 +0,0 @@
1
- import torch
2
- from blended_tiling import TilingModule
3
- from colour_demosaicing import demosaicing_CFA_Bayer_Malvar2004
4
- from RawForge.application.postprocessing import match_colors_linear
5
- from tqdm import tqdm
6
-
7
- class InferenceWorker():
8
- def __init__(self, model, model_params, device, rh, conditioning, dims, tile_size=256, tile_overlap=0.25, batch_size=2, disable_tqdm=False):
9
- super().__init__()
10
- self.model = model
11
- self.model_params = model_params
12
- self.device = device
13
- self.rh = rh
14
- self.conditioning = conditioning
15
- self.dims = dims
16
- # Quick and dirty hack to force to be even
17
- self.tile_size = tile_size
18
- self.tile_overlap = tile_overlap
19
- self.batch_size = batch_size
20
- self._is_cancelled = False
21
- self.disable_tqdm = disable_tqdm
22
-
23
- def cancel(self):
24
- self._is_cancelled = True
25
-
26
- def _tile_process(self):
27
- # Prepare Data
28
- image_RGGB = self.rh.as_rggb(dims=self.dims, colorspace='lin_rec2020')
29
- image_RGB = self.rh.as_rgb(dims=self.dims, demosaicing_func=demosaicing_CFA_Bayer_Malvar2004, colorspace='lin_rec2020', clip=True)
30
- tensor_image = torch.from_numpy(image_RGGB).unsqueeze(0).contiguous()
31
- tensor_RGB = torch.from_numpy(image_RGB).unsqueeze(0).contiguous()
32
-
33
- full_size = [image_RGGB.shape[1], image_RGGB.shape[2]]
34
- tile_size = [self.tile_size // 2, self.tile_size // 2]
35
- overlap = [self.tile_overlap, self.tile_overlap]
36
-
37
- # Tiling Setup
38
- tiling_module = TilingModule(tile_size=tile_size, tile_overlap=overlap, base_size=full_size)
39
- tiling_module_rgb = TilingModule(tile_size=[s*2 for s in tile_size], tile_overlap=overlap, base_size=[s*2 for s in full_size])
40
- tiling_module_rebuild = TilingModule(tile_size=[s*2 for s in tile_size], tile_overlap=overlap, base_size=[s*2 for s in full_size])
41
-
42
- tiles = tiling_module.split_into_tiles(tensor_image).float().to(self.device)
43
- tiles_rgb = tiling_module_rgb.split_into_tiles(tensor_RGB).float().to(self.device)
44
-
45
- batches = torch.split(tiles, self.batch_size)
46
- batches_rgb = torch.split(tiles_rgb, self.batch_size)
47
-
48
- # Conditioning Setup
49
- cond_tensor = torch.as_tensor(self.conditioning, device=self.device).float().unsqueeze(0)
50
- cond_tensor[:, 0] /= 6400
51
- cond_tensor[:, 1] = 0
52
- cond_tensor = cond_tensor[:, 0:1]
53
-
54
- processed_batches = []
55
-
56
- # Determine Dtype
57
- dtype_map = {'mps': torch.float16, 'cuda': torch.float16, 'cpu': torch.bfloat16}
58
- autocast_dtype = dtype_map.get(self.device.type, torch.float32)
59
-
60
- total_batches = len(batches_rgb)
61
-
62
- # Inference Loop
63
- with torch.no_grad():
64
- with torch.autocast(device_type=self.device.type, dtype=autocast_dtype):
65
- for i, (batch, batch_rgb) in tqdm(enumerate(zip(batches, batches_rgb)), disable=self.disable_tqdm):
66
- if self._is_cancelled: return None, None
67
-
68
- B = batch.shape[0]
69
- # Expand conditioning to match batch size
70
- curr_cond = cond_tensor.expand(B, -1)
71
-
72
- output = self.model(batch_rgb, curr_cond)
73
-
74
- # Output processing
75
- if "affine" in self.model_params:
76
- output, _, _ = match_colors_linear(output, batch_rgb)
77
- processed_batches.append(output.cpu())
78
-
79
- # Rebuild
80
- tiles_out = torch.cat(processed_batches, dim=0)
81
- stitched = tiling_module_rebuild.rebuild_with_masks(tiles_out).detach().cpu().numpy()[0]
82
-
83
- torch.cuda.empty_cache()
84
-
85
- return image_RGB.transpose(1, 2, 0), stitched.transpose(1, 2, 0)
86
-
87
- def run(self):
88
- try:
89
- img, denoised_img = self._tile_process()
90
-
91
- # Post-process blending
92
- blend_alpha = self.conditioning[1] / 100
93
- final_denoised = (denoised_img * (1 - blend_alpha)) + (img * blend_alpha)
94
-
95
- return img, final_denoised
96
-
97
- except Exception as e:
98
- print(str(e))
@@ -1,108 +0,0 @@
1
- import torch
2
- from blended_tiling import TilingModule
3
- from RawForge.application.postprocessing import match_colors_linear
4
- from tqdm import tqdm
5
- import rawpy
6
-
7
- class InferenceWorkerRawpy():
8
- def __init__(self, model, model_params, device, rh, conditioning, dims, tile_size=512, tile_overlap=0.25, batch_size=2, disable_tqdm=False):
9
- super().__init__()
10
- self.model = model
11
- self.model_params = model_params
12
- self.device = device
13
- self.rh = rh
14
- self.conditioning = conditioning
15
- self.dims = dims
16
- self.tile_size = tile_size
17
- if 'tile_size' in model_params:
18
- self.tile_size = model_params['tile_size']
19
- self.tile_overlap = tile_overlap
20
- self.batch_size = batch_size
21
- if 'batch_size' in model_params:
22
- self.batch_size = model_params['batch_size']
23
- self._is_cancelled = False
24
- self.disable_tqdm = disable_tqdm
25
-
26
- def cancel(self):
27
- self._is_cancelled = True
28
-
29
- def _tile_process(self):
30
- # Prepare Data
31
- image_RGB = self.rh.rawpy_object.postprocess(
32
- user_wb=[1, 1, 1, 1],
33
- output_color=rawpy.ColorSpace.raw,
34
- demosaic_algorithm= rawpy.DemosaicAlgorithm(3),
35
- no_auto_bright=True,
36
- use_camera_wb=False,
37
- use_auto_wb=False,
38
- gamma=(1, 1),
39
- user_flip=0,
40
- output_bps=16,
41
- no_auto_scale=True,
42
- ) / self.rh.rawpy_object.white_level
43
-
44
- image_RGB = image_RGB.transpose(2, 0, 1)
45
-
46
- tensor_RGB = torch.from_numpy(image_RGB).unsqueeze(0).contiguous()
47
-
48
- full_size = [image_RGB.shape[1], image_RGB.shape[2]]
49
- tile_size = [self.tile_size, self.tile_size]
50
- overlap = [self.tile_overlap, self.tile_overlap]
51
-
52
- # Tiling Setup
53
- tiling_module_rgb = TilingModule(tile_size=[s for s in tile_size], tile_overlap=overlap, base_size=[s for s in full_size])
54
-
55
- tiles_rgb = tiling_module_rgb.split_into_tiles(tensor_RGB).float().to(self.device)
56
-
57
- batches_rgb = torch.split(tiles_rgb, self.batch_size)
58
-
59
- # Conditioning Setup
60
- cond_tensor = torch.as_tensor(self.conditioning, device=self.device).float().unsqueeze(0)
61
- cond_tensor[:, 0] /= 6400
62
- cond_tensor[:, 1] = 0
63
- cond_tensor = cond_tensor[:, 0:1]
64
-
65
- processed_batches = []
66
-
67
- # Determine Dtype
68
- dtype_map = {'mps': torch.float16, 'cuda': torch.float16, 'cpu': torch.bfloat16}
69
- autocast_dtype = dtype_map.get(self.device.type, torch.float32)
70
-
71
- total_batches = len(batches_rgb)
72
-
73
- # Inference Loop
74
- with torch.no_grad():
75
- with torch.autocast(device_type=self.device.type, dtype=autocast_dtype):
76
- for i, (batch_rgb) in tqdm(enumerate(batches_rgb), disable=self.disable_tqdm):
77
- if self._is_cancelled: return None, None
78
-
79
- B = batch_rgb.shape[0]
80
- # Expand conditioning to match batch size
81
- curr_cond = cond_tensor.expand(B, -1)
82
- output = self.model(batch_rgb, curr_cond*0)
83
- processed_batches.append(output.cpu())
84
-
85
- # Rebuild
86
- tiles_out = torch.cat(processed_batches, dim=0)
87
- stitched = tiling_module_rgb.rebuild_with_masks(tiles_out).detach().cpu()
88
-
89
- if "affine" in self.model_params:
90
- stitched, _, _ = match_colors_linear(stitched, tensor_RGB)
91
-
92
- stitched = stitched.numpy()[0]
93
- torch.cuda.empty_cache()
94
-
95
- return image_RGB.transpose(1, 2, 0), stitched.transpose(1, 2, 0)
96
-
97
- def run(self):
98
- try:
99
- img, denoised_img = self._tile_process()
100
-
101
- # Post-process blending
102
- blend_alpha = self.conditioning[1] / 100
103
- final_denoised = (denoised_img * (1 - blend_alpha)) + (img * blend_alpha)
104
-
105
- return img, final_denoised
106
-
107
- except Exception as e:
108
- print(str(e))
@@ -1,59 +0,0 @@
1
- MODEL_REGISTRY = {
2
- "TreeNetDenoise": {
3
- "url": "https://github.com/rymuelle/RawRefinery/releases/download/v1.2.1-alpha/ShadowWeightedL1.pt",
4
- "filename": "ShadowWeightedL1.pt",
5
- "max_iso": 65535,
6
- },
7
- "TreeNetDenoiseLight": {
8
- "url": " https://github.com/rymuelle/RawRefinery/releases/download/v1.2.1-alpha/ShadowWeightedL1_light.pt",
9
- "filename": "ShadowWeightedL1_light.pt",
10
- "max_iso": 65535,
11
- },
12
- "TreeNetDenoiseSuperLight": {
13
- "url": "https://github.com/rymuelle/RawRefinery/releases/download/v1.2.1-alpha/ShadowWeightedL1_super_light.pt",
14
- "filename": "ShadowWeightedL1_super_light.pt",
15
- "max_iso": 65535,
16
- },
17
- "TreeNetDenoiseHeavy": {
18
- "url": "https://github.com/rymuelle/RawRefinery/releases/download/v1.2.1-alpha/ShadowWeightedL1_24_deep_500.pt",
19
- "filename": "ShadowWeightedL1_24_deep_500.pt",
20
- "max_iso": 65535,
21
- },
22
- "Deblur": {
23
- "url": "https://github.com/rymuelle/RawRefinery/releases/download/v1.2.1-alpha/realblur_gamma_140.pt",
24
- "filename": "realblur_gamma_140.pt",
25
- "affine": True,
26
- },
27
- "DeepSharpen": {
28
- "url": "https://github.com/rymuelle/RawRefinery/releases/download/v1.2.1-alpha/Deblur_deep_24.pt",
29
- "filename": "Deblur_deep_24.pt",
30
- "affine": True,
31
- },
32
- "TreeNetDenoiseXTrans": {
33
- "url": "https://github.com/rymuelle/RawForge/releases/download/xtrans_v1.0.0/xtrans_fixed_exposure_no_conditioning_400.pt",
34
- "filename": "xtrans_fixed_exposure_no_conditioning_400.pt",
35
- "backend": "rawpy",
36
- "conditioning": "false",
37
- },
38
- "XFormerXTrans": {
39
- "url": "https://github.com/rymuelle/RawForge/releases/download/xtrans_v1.0.0/xformer.pt",
40
- "filename": "xformer.pt",
41
- "backend": "rawpy",
42
- "conditioning": "false",
43
- "batch_size": 1,
44
- },
45
- "XFormerXTrans352": {
46
- "url": "https://github.com/rymuelle/RawForge/releases/download/xtrans_v1.0.0/xformer_352.pt",
47
- "filename": "xformer_352.pt",
48
- "backend": "rawpy",
49
- "conditioning": "false",
50
- "batch_size": 1,
51
- },
52
- "RestormerXTrans": {
53
- "url": "https://github.com/rymuelle/RawForge/releases/download/xtrans_v1.0.0/restormer.pt",
54
- "filename": "restormer.pt",
55
- "backend": "rawpy",
56
- "conditioning": "false",
57
- "batch_size": 1,
58
- },
59
- }
@@ -1,87 +0,0 @@
1
- import torch
2
-
3
- def match_colors_linear(
4
- src: torch.Tensor,
5
- tgt: torch.Tensor,
6
- sample_fraction: float = 0.05
7
- ):
8
- """
9
- Fit per-channel affine color transforms:
10
- tgt ≈ scale * src + bias
11
-
12
- Args:
13
- src: [B, C, H, W] source tensor
14
- tgt: [B, C, H, W] target tensor
15
- sample_fraction: fraction of pixels to use for fitting
16
-
17
- Returns:
18
- transformed_src: source after color matching
19
- scale: [B, C]
20
- bias: [B, C]
21
- """
22
-
23
- B, C, H, W = src.shape
24
- device = src.device
25
-
26
- # Flatten spatial dims
27
- src_flat = src.view(B, C, -1)
28
- tgt_flat = tgt.view(B, C, -1)
29
-
30
- # Sample subset of pixels
31
- N = src_flat.shape[-1]
32
- k = max(64, int(N * sample_fraction))
33
-
34
- idx = torch.randint(0, N, (k,), device=device)
35
-
36
- src_s = src_flat[..., idx] # [B, C, k]
37
- tgt_s = tgt_flat[..., idx]
38
-
39
- # Compute scale and bias using least squares
40
- # scale = cov(src, tgt) / var(src)
41
- src_mean = src_s.mean(-1, keepdim=True)
42
- tgt_mean = tgt_s.mean(-1, keepdim=True)
43
-
44
- src_centered = src_s - src_mean
45
- tgt_centered = tgt_s - tgt_mean
46
-
47
- var_src = (src_centered ** 2).mean(-1)
48
- cov = (src_centered * tgt_centered).mean(-1)
49
-
50
- scale = cov / (var_src + 1e-8) # [B, C]
51
- bias = tgt_mean.squeeze(-1) - scale * src_mean.squeeze(-1)
52
-
53
- # Apply correction
54
- scale_ = scale.view(B, C, 1, 1)
55
- bias_ = bias.view(B, C, 1, 1)
56
- transformed = src * scale_ + bias_
57
-
58
- return transformed, scale, bias
59
-
60
-
61
- def scaled_dot_product(x1, x2, eps=1e-6):
62
- dot = (x1 * x2).sum(axis=2, keepdims=True)
63
- x1_mag = (x1 * x1).sum(axis=2, keepdims=True) ** .5
64
- x2_mag = (x2 * x2).sum(axis=2, keepdims=True) ** .5
65
- return dot/(x1_mag+x2_mag+eps)
66
-
67
-
68
- def postprocess(img, denoised, lumi_blend=0, chroma_blend=0, eps=1e-6,
69
- clip_highlights=False):
70
- # Suggested by Jakob Andrén
71
- dot = (img * denoised).sum(axis=2, keepdims=True)
72
- img_mag = (img * img).sum(axis=2, keepdims=True) ** .5
73
- denoised_mag = (denoised * denoised).sum(axis=2, keepdims=True) ** .5
74
- # Project denoised along original image vector
75
- lumi = dot / (denoised_mag ** 2 + eps) * denoised
76
- chroma = img - lumi
77
- output = (1-lumi_blend) * denoised + lumi * (lumi_blend) + chroma_blend * chroma
78
-
79
- if clip_highlights:
80
- output = clip_highlights_func(img, output)
81
-
82
- return output
83
-
84
- def clip_highlights_func(img, denoised):
85
- mask = img == 1
86
- denoised[mask] = 1
87
- return denoised
@@ -1,10 +0,0 @@
1
- import torch
2
-
3
- def can_use_gpu():
4
- if not torch.cuda.is_available():
5
- return False
6
- try:
7
- x = torch.zeros(1, device="cuda")
8
- return True
9
- except Exception:
10
- return False
@@ -1,2 +0,0 @@
1
- [console_scripts]
2
- rawforge = RawForge.main:main
File without changes
File without changes
File without changes
File without changes