rawforge 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
RawForge/__init__.py ADDED
File without changes
@@ -0,0 +1,99 @@
1
+ import torch
2
+ from blended_tiling import TilingModule
3
+ from colour_demosaicing import demosaicing_CFA_Bayer_Malvar2004
4
+ from RawForge.application.postprocessing import match_colors_linear
5
+ from tqdm import tqdm
6
+
7
class InferenceWorker():
    """Runs tiled model inference over a raw image held by a RawHandler.

    The image is split into overlapping tiles (RGGB tiles at half resolution,
    RGB tiles at full resolution), batches of RGB tiles are pushed through the
    model under autocast, and the outputs are stitched back into a full image.

    Args:
        model: callable taking (rgb_batch, conditioning) tensors.
        model_params: registry entry for the model; the presence of an
            "affine" key enables linear color matching of the output.
        device: torch.device to run inference on.
        rh: RawHandler providing the raw image data.
        conditioning: sequence whose first entry is the ISO (normalized by
            1/6400 before being fed to the model) and whose second entry is
            the blend percentage used in run().
        dims: optional crop dimensions forwarded to the RawHandler.
        tile_size: tile edge length in pixels at RGB resolution.
        tile_overlap: fractional overlap between adjacent tiles.
        batch_size: number of tiles per forward pass.
        disable_tqdm: suppress the progress bar when True.
    """

    def __init__(self, model, model_params, device, rh, conditioning, dims, tile_size=256, tile_overlap=0.25, batch_size=2, disable_tqdm=False):
        super().__init__()
        self.model = model
        self.model_params = model_params
        self.device = device
        self.rh = rh
        self.conditioning = conditioning
        self.dims = dims
        self.tile_size = tile_size
        self.tile_overlap = tile_overlap
        self.batch_size = batch_size
        self._is_cancelled = False
        self.disable_tqdm = disable_tqdm

    def cancel(self):
        """Request cancellation; checked once per batch in the inference loop."""
        self._is_cancelled = True

    def _tile_process(self):
        """Tile the image, run the model batch-wise, and stitch the result.

        Returns:
            (image_RGB, stitched) as (H, W, C) numpy arrays, or (None, None)
            if cancelled mid-run.
        """
        # Prepare data: RGGB is the packed mosaic (half resolution per plane),
        # RGB is the demosaiced full-resolution image fed to the model.
        image_RGGB = self.rh.as_rggb(dims=self.dims, colorspace='lin_rec2020')
        image_RGB = self.rh.as_rgb(dims=self.dims, demosaicing_func=demosaicing_CFA_Bayer_Malvar2004, colorspace='lin_rec2020', clip=True)

        tensor_image = torch.from_numpy(image_RGGB).unsqueeze(0).contiguous()
        tensor_RGB = torch.from_numpy(image_RGB).unsqueeze(0).contiguous()

        full_size = [image_RGGB.shape[1], image_RGGB.shape[2]]
        # RGGB tiles are half the RGB tile size so the two tile grids align.
        tile_size = [self.tile_size // 2, self.tile_size // 2]
        overlap = [self.tile_overlap, self.tile_overlap]

        # Tiling setup: one module per resolution, plus one for rebuilding.
        tiling_module = TilingModule(tile_size=tile_size, tile_overlap=overlap, base_size=full_size)
        tiling_module_rgb = TilingModule(tile_size=[s * 2 for s in tile_size], tile_overlap=overlap, base_size=[s * 2 for s in full_size])
        tiling_module_rebuild = TilingModule(tile_size=[s * 2 for s in tile_size], tile_overlap=overlap, base_size=[s * 2 for s in full_size])

        tiles = tiling_module.split_into_tiles(tensor_image).float().to(self.device)
        tiles_rgb = tiling_module_rgb.split_into_tiles(tensor_RGB).float().to(self.device)

        batches = torch.split(tiles, self.batch_size)
        batches_rgb = torch.split(tiles_rgb, self.batch_size)

        # Conditioning: normalize ISO by 6400, zero then drop the blend entry
        # (blending is applied in run(), not by the model).
        cond_tensor = torch.as_tensor(self.conditioning, device=self.device).float().unsqueeze(0)
        cond_tensor[:, 0] /= 6400
        cond_tensor[:, 1] = 0
        cond_tensor = cond_tensor[:, 0:1]

        processed_batches = []

        # Half precision on GPU backends, bfloat16 on CPU.
        dtype_map = {'mps': torch.float16, 'cuda': torch.float16, 'cpu': torch.bfloat16}
        autocast_dtype = dtype_map.get(self.device.type, torch.float32)

        total_batches = len(batches_rgb)

        # Inference loop
        with torch.no_grad():
            with torch.autocast(device_type=self.device.type, dtype=autocast_dtype):
                # Fix: total_batches was computed but never used; pass it so
                # tqdm can show progress for the zip iterator. The unused
                # enumerate index is dropped.
                for batch, batch_rgb in tqdm(zip(batches, batches_rgb), total=total_batches, disable=self.disable_tqdm):
                    if self._is_cancelled:
                        return None, None

                    B = batch.shape[0]
                    # Expand conditioning to match batch size
                    curr_cond = cond_tensor.expand(B, -1)

                    output = self.model(batch_rgb, curr_cond)

                    # Optionally match output colors to the input tile.
                    if "affine" in self.model_params:
                        output, _, _ = match_colors_linear(output, batch_rgb)
                    processed_batches.append(output.cpu())

        # Rebuild the full image from the processed tiles.
        tiles_out = torch.cat(processed_batches, dim=0)
        stitched = tiling_module_rebuild.rebuild_with_masks(tiles_out).detach().cpu().numpy()[0]

        torch.cuda.empty_cache()

        return image_RGB.transpose(1, 2, 0), stitched.transpose(1, 2, 0)

    def run(self):
        """Run tiled inference and blend the result with the original image.

        Returns:
            (original_rgb, blended_output) as (H, W, C) arrays, or
            (None, None) on cancellation or failure.
        """
        try:
            img, denoised_img = self._tile_process()
            # Fix: cancellation yields (None, None); previously the blend
            # below crashed multiplying None.
            if img is None or denoised_img is None:
                return None, None

            # Blend the model output back toward the original by
            # conditioning[1] percent.
            blend_alpha = self.conditioning[1] / 100
            final_denoised = (denoised_img * (1 - blend_alpha)) + (img * blend_alpha)

            return img, final_denoised

        except Exception as e:
            # Fix: previously returned bare None, which broke the
            # `img, denoised = worker.run()` unpacking in callers.
            print(str(e))
            return None, None
@@ -0,0 +1,30 @@
1
# Registry of downloadable TorchScript models. Each entry maps a model key to
# its release URL and local cache filename; the optional "affine" flag enables
# linear color matching of the model output during inference.
MODEL_REGISTRY = {
    "TreeNetDenoise": {
        "url": "https://github.com/rymuelle/RawRefinery/releases/download/v1.2.1-alpha/ShadowWeightedL1.pt",
        "filename": "ShadowWeightedL1.pt"
    },
    "TreeNetDenoiseLight": {
        # Fix: this URL previously had a leading space, which requests
        # rejects as an invalid URL.
        "url": "https://github.com/rymuelle/RawRefinery/releases/download/v1.2.1-alpha/ShadowWeightedL1_light.pt",
        "filename": "ShadowWeightedL1_light.pt"
    },
    "TreeNetDenoiseSuperLight": {
        "url": "https://github.com/rymuelle/RawRefinery/releases/download/v1.2.1-alpha/ShadowWeightedL1_super_light.pt",
        "filename": "ShadowWeightedL1_super_light.pt"
    },
    "TreeNetDenoiseHeavy": {
        "url": "https://github.com/rymuelle/RawRefinery/releases/download/v1.2.1-alpha/ShadowWeightedL1_24_deep_500.pt",
        "filename": "ShadowWeightedL1_24_deep_500.pt"
    },
    "Deblur": {
        "url": "https://github.com/rymuelle/RawRefinery/releases/download/v1.2.1-alpha/realblur_gamma_140.pt",
        "filename": "realblur_gamma_140.pt",
        "affine": True,
    },
    "DeepSharpen": {
        "url": "https://github.com/rymuelle/RawRefinery/releases/download/v1.2.1-alpha/Deblur_deep_24.pt",
        "filename": "Deblur_deep_24.pt",
        "affine": True,
    },
}
@@ -0,0 +1,188 @@
1
+ import torch
2
+ import numpy as np
3
+ from pathlib import Path
4
+ from platformdirs import user_data_dir
5
+ from time import perf_counter
6
+ import requests
7
+ from cryptography.hazmat.primitives import hashes, serialization
8
+ from cryptography.hazmat.primitives.asymmetric import padding
9
+
10
+ from RawHandler.RawHandler import RawHandler
11
+ from RawForge.application.dng_utils import convert_color_matrix, to_dng
12
+ from RawForge.application.utils import can_use_gpu
13
+
14
+ from RawForge.application.MODEL_REGISTRY import MODEL_REGISTRY
15
+ from RawForge.application.InferenceWorker import InferenceWorker
16
+
17
+ key_string = '''-----BEGIN PUBLIC KEY-----
18
+ MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA8iRGMPqFIFVF0TM/AbMI
19
+ DJUqdjY1S7dGn6rYjLixnhohHLKIo2ZhFUfPaeYrDoqJblP9MxbBLm6a782/Us0A
20
+ vblTQOsdVFHOlVEiDUkG9CJrzh7arqJF+v2LLP9qPIcL5QdIHM+BCKlbNPBU/TJB
21
+ 49b6a+1FfKCEeY1z9F8H6GCHGeRB43lz5/1yMoBnq//Rc7NrvinwlNcFYHHM1oj6
22
+ Hk6KPkgitya11QgTTva+XimR7cbw7h9/vJKbrS7tValApio3Ypmx7AKf6/k16S9K
23
+ BCFDN3cyWmjItQNzEWbO2nuM9d3PX2O4FcZVfsA/GU0qSuKFUrrN0KcxKGglLdu4
24
+ 3Nt3JmOh+VebVWPSTeMzn2R1LDs2CsDpGG+KnHso80HBBq6RuHTugTiUZ2EwjiXN
25
+ lRS7olKFQOPwT0tm1EVkH8IxQgV4KJbCb6hAScvWfsDdsP+bu4R+QI9hfU6HCWG3
26
+ a8w1AY+5GT7zp1pzKifmnXgMXF3VnAPTGRhpIvPQfum2+tppLZueXlalobK0MDzi
27
+ n36TNhRELao1W7Tvc18fxyZn37BBgKs89JO85/cjD72yhVowW7Hy9lL7RnB+etaN
28
+ ehXoYFsJReNmD5KNgRtmXbsCUJ+D8v7BVYNGl1UgebmQnMdMWyiU/3l1Uuy8HS3L
29
+ 1QJYp42f5QqONttCqVzgzrECAwEAAQ==
30
+ -----END PUBLIC KEY-----'''
31
+
32
+
33
class ModelHandler():
    """
    Manages the lifecycle of the model, the RawHandler, and inference runs.

    Responsible for device selection, downloading and signature-verifying
    model files, running the tiled inference worker, and writing the
    processed result back out as a DNG.
    """
    def __init__(self):
        super().__init__()

        self.model = None
        self.rh = None
        self.iso = 100
        self.colorspace = 'lin_rec2020'

        # Device preference order: cuda, mps, cpu.
        # Fix: the cpu entry was `lambda: True` — an uncalled lambda that was
        # only truthy by accident; use a plain boolean like the other entries.
        devices = {
            "cuda": can_use_gpu(),
            "mps": torch.backends.mps.is_available(),
            "cpu": True,
        }
        self.devices = [d for d, is_available in devices.items() if is_available]
        self.set_device(self.devices[0])

        self.filename = None
        self.start_time = None
        self.model_params = {}

        # Public key used to verify downloaded model signatures.
        self.pub = serialization.load_pem_public_key(key_string.encode('utf-8'))

    def load_rh(self, path):
        """Load the raw file handler and return the image's ISO (default 100)."""
        self.rh = RawHandler(path, colorspace=self.colorspace)
        if 'EXIF ISOSpeedRatings' in self.rh.full_metadata:
            self.iso = int(self.rh.full_metadata['EXIF ISOSpeedRatings'].values[0])
        else:
            self.iso = 100
        return self.iso

    def load_model(self, model_key):
        """Load a model by key from the registry, downloading and verifying it if needed."""
        if model_key not in MODEL_REGISTRY:
            print(f"Model {model_key} not found in registry.")
            return

        conf = MODEL_REGISTRY[model_key]
        self.model_params = conf
        app_name = "RawForge"
        data_dir = Path(user_data_dir(app_name))
        model_path = data_dir / conf["filename"]

        # Handle download. The URL is stripped defensively in case a registry
        # entry carries stray whitespace.
        if not model_path.is_file():
            if conf["url"]:
                print(f"Downloading {model_key}...")
                if not self._download_file(conf["url"].strip(), model_path):
                    print("Failed to download model.")
                    return
            else:
                print(f"Model file not found at {model_path}")
                return

        try:
            print(f"Loading model: {model_path}")
            # Fix: the verification result was previously ignored, so an
            # unverified model was still loaded. Abort on failure instead.
            if not self._verify_model(model_path, model_path.with_suffix(f'{model_path.suffix}.sig')):
                print("Model signature verification failed; not loading.")
                return

            loaded = torch.jit.load(model_path, map_location='cpu')
            self.model = loaded.eval().to(self.device)
        except Exception as e:
            print(f"Failed to load model: {e}")

    def set_device(self, device):
        """Select the torch device and move the loaded model onto it."""
        self.device = torch.device(device)
        if self.model:
            self.model.to(self.device)
        print(f"Using Device {self.device} from {device}")

    def run_inference(self, conditioning, dims=None, inference_kwargs=None):
        """Run tiled inference; returns (original, denoised) or None if not ready."""
        # Fix: avoid a shared mutable default argument.
        if inference_kwargs is None:
            inference_kwargs = {}
        if not self.model or not self.rh:
            print("Model or Image not loaded.")
            return

        worker = InferenceWorker(self.model, self.model_params, self.device, self.rh, conditioning, dims, **inference_kwargs)
        img, final_denoised = worker.run()

        return img, final_denoised

    def generate_thumbnail(self, size=400):
        """Return a clipped preview thumbnail, or None when no image is loaded."""
        if not self.rh:
            return None
        thumb = self.rh.generate_thumbnail(min_preview_size=size, clip=True)
        return thumb

    def _verify_model(self, dest_path, sig_path):
        """Verify the RSA-PSS/SHA-256 signature of a model file.

        On failure both the model file and its signature are deleted so a
        clean re-download is attempted next time. Returns True on success.
        """
        try:
            data = Path(dest_path).read_bytes()
            signature = Path(sig_path).read_bytes()
            self.pub.verify(
                signature,
                data,
                padding.PSS(
                    mgf=padding.MGF1(hashes.SHA256()),
                    salt_length=padding.PSS.MAX_LENGTH
                ),
                hashes.SHA256(),
            )
            print(f"Model {dest_path} verified!")
            return True
        except Exception as e:
            print(e)
            if dest_path.exists():
                dest_path.unlink()
            if sig_path.exists():
                sig_path.unlink()
            print(f"Model {dest_path} not verified! Deleting.")
            return False

    def _download_file(self, url, dest_path):
        """Download a model file and its .sig, then verify; returns True on success."""
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            # Fix: requests has no default timeout; without one a stalled
            # connection hangs forever.
            r = requests.get(url, stream=True, timeout=60)
            r.raise_for_status()
            with open(dest_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

            # Download the detached model signature alongside the model.
            r = requests.get(url + '.sig', stream=True, timeout=60)
            r.raise_for_status()
            sig_path = dest_path.with_suffix(f'{dest_path.suffix}.sig')
            with open(sig_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
            return self._verify_model(dest_path, sig_path)

        except Exception as e:
            print(e)
            return False

    def handle_full_image(self, denoised, filename, save_cfa):
        """Convert a processed linear image back to camera space and write a DNG.

        Args:
            denoised: (H, W, 3) float image in the working colorspace,
                expected in [0, 1].
            filename: output DNG path.
            save_cfa: write as a CFA (mosaic) DNG rather than linear RGB.
        """
        # Invert the working-colorspace transform to get back to camera RGB.
        transform_matrix = np.linalg.inv(
            self.rh.rgb_colorspace_transform(colorspace=self.colorspace)
        )

        # ColorMatrix1 is the inverse of the camera->XYZ transform.
        CCM = self.rh.rgb_colorspace_transform(colorspace='XYZ')
        CCM = np.linalg.inv(CCM)

        transformed = denoised @ transform_matrix.T
        # Fix: precedence bug — `transformed * 2**16-1` scaled by 65536 and
        # then subtracted 1; the intent is to scale to the full 16-bit range.
        uint_img = np.clip(transformed * (2**16 - 1), 0, 2**16 - 1).astype(np.uint16)
        ccm1 = convert_color_matrix(CCM)
        to_dng(uint_img, self.rh, filename, ccm1, save_cfa=save_cfa, convert_to_cfa=True)
+
188
+
@@ -0,0 +1,201 @@
1
+ import numpy as np
2
+ from pidng.core import RAW2DNG, DNGTags, Tag
3
+ from pidng.defs import *
4
+
5
def get_ratios(string, rh):
    """Return each EXIF value stored under *string* as an integer ratio tuple."""
    values = rh.full_metadata[string].values
    return [value.as_integer_ratio() for value in values]
7
+
8
+
9
def get_as_shot_neutral(rh, denominator=10000):
    """Compute the DNG AsShotNeutral tag from the camera white balance.

    Returns three [numerator, denominator] rationals (R, G, B) normalized so
    green equals 1.0. Falls back to unity white balance when the red or blue
    multiplier is zero (avoiding a division by zero).
    """
    cam_mul = rh.core_metadata.camera_white_balance

    if cam_mul[0] == 0 or cam_mul[2] == 0:
        return [
            [denominator, denominator],
            [denominator, denominator],
            [denominator, denominator],
        ]

    green = cam_mul[1]
    neutrals = (green / cam_mul[0], 1.0, green / cam_mul[2])
    return [[int(value * denominator), denominator] for value in neutrals]
25
def convert_ccm_to_rational(matrix_3x3, denominator=10000):
    """Convert a float color matrix into row-major [numerator, denominator] pairs."""
    scaled = np.round(matrix_3x3 * denominator).astype(int)
    return [[numerator, denominator] for numerator in scaled.flatten()]
32
+
33
+
34
+
35
def simulate_CFA(image, pattern="RGGB", cfa_type="bayer"):
    """
    Simulate a CFA image from an RGB image.

    Args:
        image: numpy array (H, W, 3), RGB image.
        pattern: CFA pattern string, one of {"RGGB","BGGR","GRBG","GBRG"} for Bayer,
            or ignored if cfa_type="xtrans".
        cfa_type: "bayer" or "xtrans".

    Returns:
        cfa: numpy array (H, W) CFA image.
        sparse_mask: numpy array (H, W, 3), 1 where each channel was sampled.

    Raises:
        ValueError: on an unknown Bayer pattern or CFA type.
    """
    width = image.shape[1]
    height = image.shape[0]
    cfa = np.zeros((height, width, 3), dtype=image.dtype)
    sparse_mask = np.zeros((height, width, 3), dtype=image.dtype)
    if cfa_type == "bayer":
        # 2x2 Bayer masks
        masks = {
            "RGGB": np.array([["R", "G"], ["G", "B"]]),
            "BGGR": np.array([["B", "G"], ["G", "R"]]),
            "GRBG": np.array([["G", "R"], ["B", "G"]]),
            "GBRG": np.array([["G", "B"], ["R", "G"]]),
        }
        if pattern not in masks:
            raise ValueError(f"Unknown Bayer pattern: {pattern}")

        mask = masks[pattern]
        cmap = {"R": 0, "G": 1, "B": 2}

        for i in range(2):
            for j in range(2):
                ch = cmap[mask[i, j]]
                cfa[i::2, j::2, ch] = image[i::2, j::2, ch]
                sparse_mask[i::2, j::2, ch] = 1
    elif cfa_type == "xtrans":
        # Fuji X-Trans 6x6 repeating pattern
        xtrans_pattern = np.array([
            ["G", "B", "R", "G", "R", "B"],
            ["R", "G", "G", "B", "G", "G"],
            ["B", "G", "G", "R", "G", "G"],
            ["G", "R", "B", "G", "B", "R"],
            ["B", "G", "G", "R", "G", "G"],
            ["R", "G", "G", "B", "G", "G"],
        ])
        cmap = {"R": 0, "G": 1, "B": 2}

        for i in range(6):
            for j in range(6):
                ch = cmap[xtrans_pattern[i, j]]
                cfa[i::6, j::6, ch] = image[i::6, j::6, ch]
                # Fix: the mask previously used a 2x2 stride ([i::2, j::2])
                # copied from the Bayer branch, marking the wrong pixels.
                sparse_mask[i::6, j::6, ch] = 1
    else:
        raise ValueError(f"Unknown CFA type: {cfa_type}")

    # Each pixel has exactly one sampled channel, so summing collapses to 2D.
    return cfa.sum(axis=2), sparse_mask
93
+
94
def to_dng(uint_img, rh, filepath, ccm1, save_cfa=True, convert_to_cfa=True, use_orig_wb_points=False):
    """Write a uint16 image to a DNG file via pidng.

    Args:
        uint_img: (H, W, 3) uint16 RGB image (or an already-mosaiced (H, W)
            image when save_cfa is True and convert_to_cfa is False).
        rh: RawHandler providing metadata for the output tags.
        filepath: destination DNG path.
        ccm1: ColorMatrix1 as a list of [numerator, denominator] pairs.
        save_cfa: write a Bayer CFA DNG rather than linear RGB.
        convert_to_cfa: simulate an RGGB mosaic from the RGB input first.
        use_orig_wb_points: reuse the camera's original black/white levels;
            intended only for saving testing patches.
    """
    width = uint_img.shape[1]
    height = uint_img.shape[0]
    bpp = 16

    t = DNGTags()

    if save_cfa:
        if convert_to_cfa:
            cfa, _ = simulate_CFA(uint_img, pattern="RGGB", cfa_type="bayer")
            uint_img = cfa.astype(np.uint16)
        t.set(Tag.BitsPerSample, bpp)
        t.set(Tag.SamplesPerPixel, 1)
        t.set(Tag.PhotometricInterpretation, PhotometricInterpretation.Color_Filter_Array)
        t.set(Tag.CFARepeatPatternDim, [2, 2])
        t.set(Tag.CFAPattern, CFAPattern.RGGB)
        t.set(Tag.BlackLevelRepeatDim, [2, 2])
        # This should not be used except to save testing patches
        if use_orig_wb_points:
            bl = rh.core_metadata.black_level_per_channel
            t.set(Tag.BlackLevel, bl)
            t.set(Tag.WhiteLevel, rh.core_metadata.white_level)
        else:
            t.set(Tag.BlackLevel, [0, 0, 0, 0])
            t.set(Tag.WhiteLevel, 65535)
    else:
        t.set(Tag.BitsPerSample, [bpp, bpp, bpp])  # 3 channels for RGB
        t.set(Tag.SamplesPerPixel, 3)  # 3 for RGB
        t.set(Tag.PhotometricInterpretation, PhotometricInterpretation.Linear_Raw)
        t.set(Tag.BlackLevel, [0, 0, 0])
        t.set(Tag.WhiteLevel, [65535, 65535, 65535])

    t.set(Tag.ImageWidth, width)
    t.set(Tag.ImageLength, height)
    t.set(Tag.PlanarConfiguration, 1)  # 1 for chunky (interleaved RGB)

    # One tile covering the entire image.
    t.set(Tag.TileWidth, width)
    t.set(Tag.TileLength, height)

    t.set(Tag.ColorMatrix1, ccm1)
    t.set(Tag.CalibrationIlluminant1, CalibrationIlluminant.D65)
    wb = get_as_shot_neutral(rh)
    t.set(Tag.AsShotNeutral, wb)
    t.set(Tag.BaselineExposure, [[0, 100]])

    # EXIF tags are best-effort: a missing key on any camera should not
    # prevent the DNG from being written.
    try:
        t.set(Tag.Make, rh.full_metadata['Image Make'].values)
        t.set(Tag.Model, rh.full_metadata['Image Model'].values)
        t.set(Tag.Orientation, rh.full_metadata['Image Orientation'].values[0])
        exposures = get_ratios('EXIF ExposureTime', rh)
        fnumber = get_ratios('EXIF FNumber', rh)
        ExposureBiasValue = get_ratios('EXIF ExposureBiasValue', rh)
        FocalLength = get_ratios('EXIF FocalLength', rh)
        t.set(Tag.FocalLength, FocalLength)
        t.set(Tag.EXIFPhotoLensModel, rh.full_metadata['EXIF LensModel'].values)
        t.set(Tag.ExposureBiasValue, ExposureBiasValue)
        t.set(Tag.ExposureTime, exposures)
        t.set(Tag.FNumber, fnumber)
        t.set(Tag.PhotographicSensitivity, rh.full_metadata['EXIF ISOSpeedRatings'].values)
    except Exception:
        # Fix: was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt.
        print("Could not save EXIF")
    t.set(Tag.DNGVersion, DNGVersion.V1_4)
    t.set(Tag.DNGBackwardVersion, DNGVersion.V1_2)
    t.set(Tag.PreviewColorSpace, PreviewColorSpace.Adobe_RGB)

    r = RAW2DNG()

    r.options(t, path="", compress=False)

    r.convert(uint_img, filename=filepath)
165
+
166
+
167
+
168
def convert_color_matrix(matrix):
    """
    Converts a 3x3 NumPy matrix of floats into a list of integer pairs.

    Each float value is expressed as a fraction with denominator 10000; the
    numerator is the value scaled by 10000 and rounded to the nearest integer.

    Args:
        matrix: A 3x3 NumPy array with floating-point numbers.

    Returns:
        A list of 9 [numerator, denominator] integer pairs in row-major order.

    Raises:
        TypeError: if the input is not a NumPy array.
    """
    if not isinstance(matrix, np.ndarray):
        raise TypeError("Input must be a NumPy array.")

    denominator = 10000
    return [
        [int(round(value * denominator)), denominator]
        for value in matrix.flatten()
    ]
@@ -0,0 +1,58 @@
1
+ import torch
2
+
3
def match_colors_linear(
    src: torch.Tensor,
    tgt: torch.Tensor,
    sample_fraction: float = 0.05
):
    """
    Fit per-channel affine color transforms so that tgt ≈ scale * src + bias.

    A random pixel subset (with replacement) is used to estimate the
    least-squares scale and bias per channel; the transform is then applied
    to the full source image.

    Args:
        src: [B, C, H, W] source tensor
        tgt: [B, C, H, W] target tensor
        sample_fraction: fraction of pixels to use for fitting (at least 64)

    Returns:
        transformed_src: source after color matching
        scale: [B, C]
        bias: [B, C]
    """
    batch, channels, _, _ = src.shape

    # Collapse spatial dimensions for per-channel statistics.
    flat_src = src.view(batch, channels, -1)
    flat_tgt = tgt.view(batch, channels, -1)

    num_pixels = flat_src.shape[-1]
    sample_count = max(64, int(num_pixels * sample_fraction))
    idx = torch.randint(0, num_pixels, (sample_count,), device=src.device)

    sampled_src = flat_src[..., idx]  # [B, C, k]
    sampled_tgt = flat_tgt[..., idx]

    # Least-squares fit per channel: scale = cov(src, tgt) / var(src).
    mean_src = sampled_src.mean(-1, keepdim=True)
    mean_tgt = sampled_tgt.mean(-1, keepdim=True)
    centered_src = sampled_src - mean_src
    centered_tgt = sampled_tgt - mean_tgt

    variance = (centered_src ** 2).mean(-1)
    covariance = (centered_src * centered_tgt).mean(-1)
    scale = covariance / (variance + 1e-8)  # [B, C]
    bias = mean_tgt.squeeze(-1) - scale * mean_src.squeeze(-1)

    # Apply the fitted transform to the full-resolution source.
    transformed = src * scale.view(batch, channels, 1, 1) + bias.view(batch, channels, 1, 1)

    return transformed, scale, bias
@@ -0,0 +1,10 @@
1
+ import torch
2
+
3
def can_use_gpu():
    """Return True when CUDA is both reported available and actually usable."""
    if not torch.cuda.is_available():
        return False
    try:
        # Allocating a tiny tensor catches setups where CUDA is reported but
        # fails at runtime (e.g. driver/library mismatches).
        torch.zeros(1, device="cuda")
    except Exception:
        return False
    return True
RawForge/main.py ADDED
@@ -0,0 +1,41 @@
1
+ import sys
2
+ import platform
3
+ from RawForge.application.ModelHandler import ModelHandler
4
+ import argparse
5
+
6
def main():
    """Command-line entry point: load a model, process a raw file, save a DNG."""
    parser = argparse.ArgumentParser(description='A command line utility for processing raw images.')
    parser.add_argument('model', type=str, help='The name of the model to use.')
    parser.add_argument('in_file', type=str, help='The name of the file to open.')
    parser.add_argument('out_file', type=str, help='The name of the file to save.')
    parser.add_argument('--conditioning', type=str, help='Comma-separated conditioning array to feed model (e.g. "800,20").')
    parser.add_argument('--dims', type=int, nargs=4, metavar=("x0", "x1", "y0", "y1"), help='Optional crop dimensions.')

    parser.add_argument('--cfa', action='store_true', help='Save the image as a CFA image (default: False).')
    parser.add_argument('--device', type=str, help='Set device backend (cuda, cpu, mps).')
    parser.add_argument('--disable_tqdm', action='store_true', help='Disable the progress bar.')
    parser.add_argument('--tile_size', type=int, help='Set tile size. (default: 256)', default=256)

    args = parser.parse_args()

    handler = ModelHandler()

    handler.load_model(args.model)

    iso = handler.load_rh(args.in_file)

    # Fix: the original only assigned `conditioning` when --conditioning was
    # absent, so passing the flag raised a NameError; parse the
    # comma-separated string here.
    if args.conditioning:
        conditioning = [float(v) for v in args.conditioning.split(',')]
    else:
        conditioning = [iso, 0]

    if args.device:
        handler.set_device(args.device)

    inference_kwargs = {"disable_tqdm": args.disable_tqdm,
                        "tile_size": args.tile_size}
    img, denoised_image = handler.run_inference(conditioning=conditioning, dims=args.dims, inference_kwargs=inference_kwargs)

    handler.handle_full_image(denoised_image, args.out_file, args.cfa)
38
+
39
+
40
+ if __name__ == '__main__':
41
+ main()
@@ -0,0 +1,106 @@
1
+ Metadata-Version: 2.4
2
+ Name: rawforge
3
+ Version: 0.0.1
4
+ Summary: A compute backend/CLI application for using machine learning models on raw images.
5
+ Author: Ryan Mueller
6
+ License: MIT License
7
+
8
+ Copyright (c) 2025 rymuelle
9
+
10
+ Permission is hereby granted, free of charge, to any person obtaining a copy
11
+ of this software and associated documentation files (the "Software"), to deal
12
+ in the Software without restriction, including without limitation the rights
13
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14
+ copies of the Software, and to permit persons to whom the Software is
15
+ furnished to do so, subject to the following conditions:
16
+
17
+ The above copyright notice and this permission notice shall be included in all
18
+ copies or substantial portions of the Software.
19
+
20
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26
+ SOFTWARE.
27
+
28
+ Project-URL: Homepage, https://github.com/rymuelle/RawForge
29
+ Project-URL: Bug Tracker, https://github.com/rymuelle/RawForge/issues
30
+ Keywords: python,Machine Learning,Camera Raw
31
+ Classifier: Programming Language :: Python :: 3
32
+ Classifier: Programming Language :: Python :: 3.11
33
+ Classifier: Programming Language :: Python :: 3.12
34
+ Classifier: Programming Language :: Python :: 3.13
35
+ Classifier: License :: OSI Approved :: MIT License
36
+ Classifier: Operating System :: OS Independent
37
+ Requires-Python: >=3.11
38
+ Description-Content-Type: text/markdown
39
+ License-File: LICENSE
40
+ Requires-Dist: pidng~=4.0.9
41
+ Requires-Dist: numpy>=2.3
42
+ Requires-Dist: RawHandler~=0.0.2
43
+ Requires-Dist: colour_demosaicing~=0.2.6
44
+ Requires-Dist: blended_tiling>=0.0.1.dev7
45
+ Requires-Dist: requests~=2.32
46
+ Requires-Dist: platformdirs~=4.5
47
+ Requires-Dist: tqdm~=4.67
48
+ Requires-Dist: cryptography~=46.0
49
+ Dynamic: license-file
50
+
51
+ # RawForge
52
+
53
+ <!-- [![PyPI version](https://img.shields.io/pypi/v/rawrefinery.svg)](https://pypi.org/project/rawrefinery/)
54
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
55
+ [![Python version](https://img.shields.io/pypi/pyversions/rawrefinery.svg)](https://pypi.org/project/rawrefinery/) -->
56
+
57
+ **RawForge** is an open-source command line application for **raw image quality refinement and denoising**.
58
+
59
+ Currently in **alpha release**.
60
+
61
+ Sorry for the lack of instructions!
62
+
63
+ ---
64
+ Example command line syntax:
65
+
66
+ ```bash
67
+ rawforge TreeNetDenoiseHeavy test.CR2 test_heavy.dng --cfa
68
+ ```
69
+
70
+
71
+ ```bash
72
+ usage: rawforge [-h] [--conditioning CONDITIONING] [--dims x0 x1 y0 y1] [--cfa] [--device DEVICE] [--disable_tqdm] [--tile_size TILE_SIZE] model in_file out_file
73
+
74
+ A command line utility for processing raw images.
75
+
76
+ positional arguments:
77
+ model The name of the model to use.
78
+ in_file The name of the file to open.
79
+ out_file The name of the file to save.
80
+
81
+ options:
82
+ -h, --help show this help message and exit
83
+ --conditioning CONDITIONING
84
+ Conditioning array to feed model.
85
+ --dims x0 x1 y0 y1 Optional crop dimensions.
86
+ --cfa Save the image as a CFA image (default: False).
87
+ --device DEVICE Set device backend (cuda, cpu, mps).
88
+ --disable_tqdm Disable the progress bar.
89
+ --tile_size TILE_SIZE
90
+ Set tile size. (default: 256)
91
+ ```
92
+
93
+ ----
94
+
95
+ ## Acknowledgments
96
+
97
+ With thanks to:
98
+
99
+
100
+ > Brummer, Benoit; De Vleeschouwer, Christophe. (2025).
101
+ > *Raw Natural Image Noise Dataset.*
102
+ > [https://doi.org/10.14428/DVN/DEQCIM](https://doi.org/10.14428/DVN/DEQCIM), Open Data @ UCLouvain, V1.
103
+
104
+ > Chen, Liangyu; Chu, Xiaojie; Zhang, Xiangyu; Chen, Jianhao. (2022).
105
+ > *NAFNet: Simple Baselines for Image Restoration.*
106
+ > [https://doi.org/10.48550/arXiv.2208.04677](https://doi.org/10.48550/arXiv.2208.04677), arXiv, V1.
@@ -0,0 +1,14 @@
1
+ RawForge/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ RawForge/main.py,sha256=vhA0Q8AT9-5DoZiAEbDXr9pSXDfaGIqvs4Bg12Dvt4o,1696
3
+ RawForge/application/InferenceWorker.py,sha256=7a_zlukd39x9n0Hq1hnIJJq3shLbHvYPONHW5ywLMJo,4216
4
+ RawForge/application/MODEL_REGISTRY.py,sha256=DnVUACayTTXoLQvdCH1DUekGEGQXLgWIlFzvQXs2pmI,1237
5
+ RawForge/application/ModelHandler.py,sha256=j2mlTi_WsCHHW7U808sB9ChQjbeRgn0jbGFLV5Xjfl8,6815
6
+ RawForge/application/dng_utils.py,sha256=F_TFWQ7A19QIBhK3aXAEkUuXpHO8vOA5LixDQNGxGnM,7053
7
+ RawForge/application/postprocessing.py,sha256=_neffnnVbIDp5Ut2LisFk-xX5QIId72qKEBPTSc0c9Q,1493
8
+ RawForge/application/utils.py,sha256=lpGLpGkmhhlGBrav1r0ygBCAbBT0A6kC9fje46_eIu8,205
9
+ rawforge-0.0.1.dist-info/licenses/LICENSE,sha256=C6iC-8U95WxetY9TsV-RHN-CbI9wPI3nFMHMpu_Ve4E,1065
10
+ rawforge-0.0.1.dist-info/METADATA,sha256=WcnBA2HYI68sPOZmNBzKvJVfA_rrH6mXLYsC_ysHADI,4266
11
+ rawforge-0.0.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
12
+ rawforge-0.0.1.dist-info/entry_points.txt,sha256=7yVkkrWfZQqHO9012zxV1igRLOpkNf42KemeslwzhL0,48
13
+ rawforge-0.0.1.dist-info/top_level.txt,sha256=tCW6Kg9PQlghoo0V4yDVw5j49b2LbGZOslrsaRXE1iw,9
14
+ rawforge-0.0.1.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.9.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ rawforge = RawForge.main:main
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 rymuelle
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1 @@
1
+ RawForge