rawforge 0.0.2__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. {rawforge-0.0.2 → rawforge-0.2.0}/PKG-INFO +10 -2
  2. rawforge-0.2.0/RawForge/application/ImageSaver.py +55 -0
  3. rawforge-0.2.0/RawForge/application/InferenceWorker.py +79 -0
  4. rawforge-0.2.0/RawForge/application/MODEL_REGISTRY.py +135 -0
  5. {rawforge-0.0.2 → rawforge-0.2.0}/RawForge/application/ModelHandler.py +15 -94
  6. rawforge-0.2.0/RawForge/application/helpers/censored_fit.py +100 -0
  7. rawforge-0.2.0/RawForge/application/helpers/get_image.py +41 -0
  8. rawforge-0.2.0/RawForge/application/helpers/utils.py +17 -0
  9. rawforge-0.2.0/RawForge/application/postprocessing.py +92 -0
  10. {rawforge-0.0.2 → rawforge-0.2.0}/RawForge/main.py +39 -21
  11. {rawforge-0.0.2 → rawforge-0.2.0}/pyproject.toml +10 -3
  12. {rawforge-0.0.2 → rawforge-0.2.0}/rawforge.egg-info/PKG-INFO +10 -2
  13. {rawforge-0.0.2 → rawforge-0.2.0}/rawforge.egg-info/SOURCES.txt +4 -2
  14. rawforge-0.2.0/rawforge.egg-info/entry_points.txt +2 -0
  15. {rawforge-0.0.2 → rawforge-0.2.0}/rawforge.egg-info/requires.txt +13 -1
  16. rawforge-0.0.2/RawForge/application/InferenceWorker.py +0 -98
  17. rawforge-0.0.2/RawForge/application/InferenceWorkerRawpy.py +0 -108
  18. rawforge-0.0.2/RawForge/application/MODEL_REGISTRY.py +0 -60
  19. rawforge-0.0.2/RawForge/application/postprocessing.py +0 -87
  20. rawforge-0.0.2/RawForge/application/utils.py +0 -10
  21. rawforge-0.0.2/rawforge.egg-info/entry_points.txt +0 -2
  22. {rawforge-0.0.2 → rawforge-0.2.0}/LICENSE +0 -0
  23. {rawforge-0.0.2 → rawforge-0.2.0}/README.md +0 -0
  24. {rawforge-0.0.2 → rawforge-0.2.0}/RawForge/__init__.py +0 -0
  25. {rawforge-0.0.2 → rawforge-0.2.0}/RawForge/application/dng_utils.py +0 -0
  26. {rawforge-0.0.2 → rawforge-0.2.0}/RawForge/scripts/generate_keys.py +0 -0
  27. {rawforge-0.0.2 → rawforge-0.2.0}/RawForge/scripts/sign_models.py +0 -0
  28. {rawforge-0.0.2 → rawforge-0.2.0}/rawforge.egg-info/dependency_links.txt +0 -0
  29. {rawforge-0.0.2 → rawforge-0.2.0}/rawforge.egg-info/top_level.txt +0 -0
  30. {rawforge-0.0.2 → rawforge-0.2.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rawforge
3
- Version: 0.0.2
3
+ Version: 0.2.0
4
4
  Summary: A compute backend/CLI application for using machine learning models on raw images.
5
5
  Author: Ryan Mueller
6
6
  License: MIT License
@@ -40,12 +40,20 @@ License-File: LICENSE
40
40
  Requires-Dist: numpy>=2.2
41
41
  Requires-Dist: RawHandler~=0.2.0
42
42
  Requires-Dist: colour_demosaicing~=0.2.6
43
- Requires-Dist: blended_tiling>=0.0.1.dev7
43
+ Requires-Dist: blended-tiling-numpy
44
44
  Requires-Dist: requests~=2.32
45
45
  Requires-Dist: platformdirs~=4.5
46
46
  Requires-Dist: tqdm~=4.67
47
47
  Requires-Dist: cryptography>=46.0
48
48
  Requires-Dist: tifffile
49
+ Provides-Extra: cpu
50
+ Requires-Dist: onnxruntime; extra == "cpu"
51
+ Provides-Extra: cuda
52
+ Requires-Dist: onnxruntime-gpu; extra == "cuda"
53
+ Provides-Extra: web
54
+ Requires-Dist: onnxruntime-web; extra == "web"
55
+ Provides-Extra: directml
56
+ Requires-Dist: onnxruntime-directml; extra == "directml"
49
57
  Dynamic: license-file
50
58
 
51
59
  # RawRefinery
@@ -0,0 +1,55 @@
1
+ import numpy as np
2
+ from RawForge.application.dng_utils import convert_color_matrix, to_dng
3
+ import tifffile
4
+
5
class ImageSaver():
    """Write a processed image to disk as an 8-bit TIFF or as a DNG.

    Parameters
    ----------
    model_params : dict
        Parameters of the model that produced the image. ``to_raw`` reads
        the ``'demosaicing'`` key to pick the export path.
    rh : object
        Raw handler for the source file. The methods below call
        ``rgb_colorspace_transform``, ``compute_mask_and_sparse``,
        ``to_dng`` and read ``core_metadata`` / ``colorspace`` on it.
    dims : optional
        Forwarded unchanged to ``rh.compute_mask_and_sparse`` in ``to_raw``.
    """

    def __init__(self, model_params, rh, dims=None):
        self.rh = rh
        self.model_params = model_params
        self.dims = dims

    def to_tiff(self, image, filename, apply_ccm=True):
        """Save ``image`` as a gamma-encoded 8-bit RGB TIFF.

        ``image[0]`` must be 3-D; its first axis is moved last before
        writing (assumes a (1, C, H, W) layout - TODO confirm with caller).

        Parameters
        ----------
        image : array
            Image data, values expected in [0, 1] (anything outside is
            clipped).
        filename : path-like
            Destination passed to ``tifffile.imwrite``.
        apply_ccm : bool
            When True, multiply the pixels by the transpose of the handler's
            ``'lin_rec2020'`` colour-space transform first.
        """
        pixels = image[0].transpose(1, 2, 0)
        if apply_ccm:
            transform_matrix = self.rh.rgb_colorspace_transform(colorspace='lin_rec2020')
            pixels = pixels @ transform_matrix.T

        pixels = np.clip(pixels, 0, 1)
        pixels = pixels ** (1/2.2)          # display gamma
        pixels = pixels * (2 ** 8 - 1)      # scale to the 8-bit range

        uint_img = pixels.astype(np.uint8)

        tifffile.imwrite(
            filename,
            uint_img,
            photometric='rgb',      # Explicitly define the color space
            compression='deflate'   # Optional: Lossless compression supported by darktable
        )

    def to_raw(self, denoised, filename, save_cfa):
        """Save ``denoised`` as a DNG.

        Two export paths, selected by ``model_params['demosaicing']``:

        * ``'rawpy'``: the image is masked with the handler's CFA mask,
          collapsed to a single plane, rescaled with the handler's white and
          black levels and written through ``rh.to_dng``. ``save_cfa`` is not
          used on this path.
        * anything else: the image is mapped back through the inverse of the
          handler's colour-space transform, quantised to 16 bit and written
          through ``to_dng`` together with a colour matrix derived from the
          handler's ``'XYZ'`` transform.

        Parameters
        ----------
        denoised : array
            Model output; ``denoised[0]`` is the image that gets written.
        filename : path-like
            Destination file.
        save_cfa : bool
            Forwarded to ``to_dng`` on the non-rawpy path.
        """
        if self.model_params['demosaicing'] == 'rawpy':
            # Compute CFA
            _, mask = self.rh.compute_mask_and_sparse(dims=self.dims)
            plane = denoised[0].clip(0, 1)

            # Keep only the masked samples, then merge the channel axis.
            plane = np.where(mask, plane, 0).sum(axis=0)
            metadata = self.rh.core_metadata
            plane = plane * metadata.white_level + metadata.black_level_per_channel[0]
            self.rh.to_dng(filename, uint_img=plane)
            return

        transform_matrix = np.linalg.inv(
            self.rh.rgb_colorspace_transform(colorspace=self.rh.colorspace)
        )
        CCM = np.linalg.inv(self.rh.rgb_colorspace_transform(colorspace='XYZ'))

        pixels = denoised[0].transpose(1, 2, 0)
        transformed = pixels @ transform_matrix.T

        # Scale by the full 16-bit range. The parentheses matter: without
        # them ``transformed * 2**16-1`` means ``transformed * 65536 - 1``.
        max_value = 2 ** 16 - 1
        uint_img = np.clip(transformed * max_value, 0, max_value).astype(np.uint16)
        ccm1 = convert_color_matrix(CCM)
        to_dng(uint_img, self.rh, filename, ccm1, save_cfa=save_cfa, convert_to_cfa=True)
@@ -0,0 +1,79 @@
1
+ import onnxruntime as ort
2
+ import numpy as np
3
+ from blended_tiling_numpy import TilingModule
4
+ from RawForge.application.postprocessing import match_colors_linear
5
+ from tqdm import tqdm
6
+ import rawpy
7
+
8
class InferenceWorker():
    """Run a model over an image tile by tile and stitch the result.

    Parameters
    ----------
    model : object
        Inference session; ``get_inputs()`` and ``run()`` are called on it.
    model_params : dict
        Model settings. ``'tile_size'`` and ``'batch_size'`` entries, when
        present, override the corresponding constructor arguments.
    conditioning : sequence
        Conditioning values; only the first one is used, divided by 6400
        (presumably an ISO value - verify against caller).
    tile_size : int
        Edge length of the square tiles.
    tile_overlap : float
        Overlap handed to ``TilingModule`` for both axes.
    batch_size : int
        Number of tiles per model call.
    disable_tqdm : bool
        Suppress the progress bar.
    """

    def __init__(self, model, model_params, conditioning, tile_size=512, tile_overlap=0.25, batch_size=2, disable_tqdm=False):
        super().__init__()
        self.model = model
        self.model_params = model_params
        self.conditioning = conditioning
        # Values in the model parameters win over the call-site defaults.
        self.tile_size = model_params.get('tile_size', tile_size)
        self.tile_overlap = tile_overlap
        self.batch_size = model_params.get('batch_size', batch_size)
        self._is_cancelled = False
        self.disable_tqdm = disable_tqdm

    def cancel(self):
        """Ask a running ``_tile_process`` loop to stop before its next batch."""
        self._is_cancelled = True

    def _build_conditioning(self, model_params):
        """Return the (1, 1) float16 conditioning array fed to the model.

        The first conditioning value is divided by 6400 and, when
        ``model_params`` has a ``'cond_scale'`` entry, multiplied by it.
        """
        cond = np.array([self.conditioning]).astype(np.float32)
        cond[:, 0] /= 6400.
        # Only the first column is passed on to the model.
        cond = cond[:, 0:1].astype(np.float16)
        if 'cond_scale' in model_params:
            # Scale linearly. Multiplying by ``cond * cond_scale`` would
            # square the conditioning value.
            cond *= model_params['cond_scale']
        return cond

    def _tile_process(self, image_RGB, model_params):
        """Split ``image_RGB`` into tiles, run the model, rebuild the image.

        ``image_RGB`` must be 4-D; axes 2 and 3 are used as the size to tile.

        Returns
        -------
        tuple
            ``(image_RGB, stitched)``, or ``(None, None)`` when ``cancel()``
            was called before all batches were processed.
        """
        # Prepare Data
        full_size = [image_RGB.shape[2], image_RGB.shape[3]]
        tile_size = [self.tile_size, self.tile_size]
        overlap = [self.tile_overlap, self.tile_overlap]

        # Tiling Setup
        tiling_module_rgb = TilingModule(tile_size=tile_size, tile_overlap=overlap, base_size=full_size)
        tiles_rgb = tiling_module_rgb.split_into_tiles(image_RGB)
        batches_rgb = [tiles_rgb[start : start + self.batch_size]
                       for start in range(0, len(tiles_rgb), self.batch_size)]

        # Conditioning Setup
        cond_tensor = self._build_conditioning(model_params)

        # The model's input names do not change between batches.
        model_inputs = {model_input.name for model_input in self.model.get_inputs()}

        processed_batches = []

        # Inference Loop (iterating the list directly lets tqdm show a total)
        for batch_rgb in tqdm(batches_rgb, disable=self.disable_tqdm):
            if self._is_cancelled:
                return None, None

            B = batch_rgb.shape[0]
            # Expand conditioning to match batch size
            curr_cond = np.broadcast_to(cond_tensor, (B, cond_tensor.shape[-1])).astype(np.float16)

            payload = {
                "input": batch_rgb,
                "cond": curr_cond
            }
            # Only feed the inputs this model actually declares.
            filtered_inputs = {k: v for k, v in payload.items() if k in model_inputs}
            output = self.model.run(["output"], filtered_inputs)

            processed_batches.append(output[0])

        # Rebuild
        tiles_out = np.concatenate(processed_batches, axis=0)
        stitched = tiling_module_rgb.rebuild_with_masks(tiles_out)

        return image_RGB, stitched

    def run(self, model_params, image_RGB):
        """Process ``image_RGB`` and return ``(image_RGB, denoised)``."""
        img, denoised_img = self._tile_process(image_RGB, model_params)
        return img, denoised_img
79
+
@@ -0,0 +1,135 @@
1
# Registry of downloadable ONNX models, keyed by the name used to select them.
#
# Keys found in the entries:
#   url          download location of the model file
#   filename     name of the model file (last path component of the url)
#   demosaicing  "Malvar2004" or "rawpy"
#   max_iso, conditioning, batch_size, crop_size, cond_scale   optional
#
# NOTE(review): "conditioning" is the *string* "false", which is truthy in
# Python - confirm that consumers compare against the string.
MODEL_REGISTRY = {
    "TreeNetDenoise": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/ShadowWeightedL1.onnx",
        "filename": "ShadowWeightedL1.onnx",
        "max_iso": 65535,
        "demosaicing": "Malvar2004",
    },
    "TreeNetDenoiseLight": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/ShadowWeightedL1_light.onnx",
        "filename": "ShadowWeightedL1_light.onnx",
        "max_iso": 65535,
        "demosaicing": "Malvar2004",
    },
    "TreeNetDenoiseSuperLight": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/ShadowWeightedL1_super_light.onnx",
        "filename": "ShadowWeightedL1_super_light.onnx",
        "max_iso": 65535,
        "demosaicing": "Malvar2004",
    },
    "TreeNetDenoiseHeavy": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/ShadowWeightedL1_24_deep_500.onnx",
        "filename": "ShadowWeightedL1_24_deep_500.onnx",
        "max_iso": 65535,
        "demosaicing": "Malvar2004",
    },
    "Deblur": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/realblur_gamma_140.onnx",
        "filename": "realblur_gamma_140.onnx",
        "demosaicing": "Malvar2004",
    },
    "DeepSharpen": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/Deblur_deep_24.onnx",
        "filename": "Deblur_deep_24.onnx",
        "demosaicing": "Malvar2004",
    },
    "TreeNetDenoiseXTrans": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/xtrans_fixed_exposure_no_conditioning_400.onnx",
        "filename": "xtrans_fixed_exposure_no_conditioning_400.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
        "cond_scale": 0,
    },
    "RestormerXTrans": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/restormer.onnx",
        "filename": "restormer.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
    },
    "XFormerXTrans2": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/xformer_fp16.onnx",
        "filename": "xformer_fp16.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
    },

    "XFormerDenoise": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/xformer_static.onnx",
        "filename": "xformer_static.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    "NAFDenoise": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/NAF_static.onnx",
        "filename": "NAF_static.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },

    "NAFDenoise_dynamic": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/NAF_dynamic.onnx",
        "filename": "NAF_dynamic.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    "NAFDenoise_dynamic_fp16": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/NAF_dynamic_fp16.onnx",
        "filename": "NAF_dynamic_fp16.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    "NAF_static_fp16": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/NAF_static_fp16.onnx",
        "filename": "NAF_static_fp16.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    # NOTE(review): this entry points at the same file as
    # "TreeNetDenoiseXTrans" - confirm that is intended.
    "NAF_trace_test": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/xtrans_fixed_exposure_no_conditioning_400.onnx",
        "filename": "xtrans_fixed_exposure_no_conditioning_400.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    "non_trace_test": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/non_trace_test.onnx",
        "filename": "non_trace_test.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    "trace_test": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/trace_test.onnx",
        "filename": "trace_test.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
    "non_trace_test_opset_20_static_simp": {
        "url": "https://github.com/rymuelle/RawForge/releases/download/onnx_v1.0.0/non_trace_test_opset_20_static_simp.onnx",
        "filename": "non_trace_test_opset_20_static_simp.onnx",
        "demosaicing": "rawpy",
        "conditioning": "false",
        "batch_size": 1,
        "crop_size": 256,
    },
}
@@ -1,21 +1,15 @@
1
- import torch
1
+ import onnxruntime as ort
2
+ import os
2
3
  import numpy as np
3
4
  from pathlib import Path
4
5
  from platformdirs import user_data_dir
5
- from time import perf_counter
6
6
  import requests
7
7
  from cryptography.hazmat.primitives import hashes, serialization
8
8
  from cryptography.hazmat.primitives.asymmetric import padding
9
9
 
10
- from RawHandler.RawHandler import RawHandler
11
- from RawHandler.RawHandlerRawpy import RawHandlerRawpy
12
-
13
- from RawForge.application.dng_utils import convert_color_matrix, to_dng
14
- from RawForge.application.utils import can_use_gpu
15
-
16
10
  from RawForge.application.MODEL_REGISTRY import MODEL_REGISTRY
17
11
  from RawForge.application.InferenceWorker import InferenceWorker
18
- from RawForge.application.InferenceWorkerRawpy import InferenceWorkerRawpy
12
+ from RawForge.application.helpers.utils import get_best_providers
19
13
 
20
14
  key_string = '''-----BEGIN PUBLIC KEY-----
21
15
  MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA8iRGMPqFIFVF0TM/AbMI
@@ -40,46 +34,23 @@ class ModelHandler():
40
34
  def __init__(self):
41
35
  super().__init__()
42
36
 
43
- self.model = None
44
- self.rh = None
45
- self.iso = 100
46
- self.colorspace = 'lin_rec2020'
47
-
37
+ self.model = None
38
+ app_name = "RawForge"
39
+ self.data_dir = Path(user_data_dir(app_name))
40
+ self.cache_dir = Path(os.path.expanduser(f"{self.data_dir}/model_cache")).resolve()
48
41
  # Manage devices
49
- devices = {
50
- "cuda": can_use_gpu(),
51
- "mps": torch.backends.mps.is_available(),
52
- "cpu": lambda : True
53
- }
54
- self.devices = [d for d, is_available in devices.items() if is_available]
55
- self.set_device(self.devices[0])
56
-
57
- self.filename = None
58
- self.start_time = None
59
- self.model_params = {}
60
-
42
+ self.providers = get_best_providers(cache_dir=self.cache_dir)
61
43
  self.pub = serialization.load_pem_public_key(key_string.encode('utf-8'))
62
44
 
63
- def load_rh(self, path):
64
- """Loads the raw file handler"""
65
- if 'backend' in self.model_params and self.model_params['backend'] == 'rawpy':
66
- self.rh = RawHandlerRawpy(path)
67
- else:
68
- self.rh = RawHandler(path)
69
- self.iso = self.rh.full_metadata.get_ISO()
70
- return self.iso
71
45
 
72
46
  def load_model(self, model_key):
73
47
  """Loads a model by key from the registry"""
74
48
  if model_key not in MODEL_REGISTRY:
75
49
  print(f"Model {model_key} not found in registry.")
76
50
  return
77
-
78
51
  conf = MODEL_REGISTRY[model_key]
79
52
  self.model_params = conf
80
- app_name = "RawForge"
81
- data_dir = Path(user_data_dir(app_name))
82
- model_path = data_dir / conf["filename"]
53
+ model_path = self.data_dir / conf["filename"]
83
54
 
84
55
  # Handle Download
85
56
  if not model_path.is_file():
@@ -97,37 +68,15 @@ class ModelHandler():
97
68
  # Verify model before load
98
69
  self._verify_model(model_path, model_path.with_suffix(f'{model_path.suffix}.sig'))
99
70
 
100
- loaded = torch.jit.load(model_path, map_location='cpu')
101
- self.model = loaded.eval().to(self.device)
71
+ session = ort.InferenceSession(
72
+ model_path,
73
+ providers=self.providers,
74
+ )
75
+ print("Loaded!")
76
+ self.model = session
102
77
  except Exception as e:
103
78
  print(f"Failed to load model: {e}")
104
79
 
105
- def set_device(self, device):
106
- self.device = torch.device(device)
107
- if self.model:
108
- self.model.to(self.device)
109
- print(f"Using Device {self.device} from {device}")
110
-
111
- def run_inference(self, conditioning, dims=None, inference_kwargs={}):
112
- """Starts the worker thread"""
113
- if not self.model or not self.rh:
114
- print("Model or Image not loaded.")
115
- return
116
-
117
- if 'backend' in self.model_params and self.model_params['backend'] == 'rawpy':
118
- worker = InferenceWorkerRawpy(self.model, self.model_params, self.device, self.rh, conditioning, dims, **inference_kwargs)
119
- else:
120
- worker = InferenceWorker(self.model, self.model_params, self.device, self.rh, conditioning, dims, **inference_kwargs)
121
- img, final_denoised = worker.run()
122
-
123
- return img, final_denoised
124
-
125
-
126
- def generate_thumbnail(self, size=400):
127
- if not self.rh: return None
128
- thumb = self.rh.generate_thumbnail(min_preview_size=size, clip=True)
129
- return thumb
130
-
131
80
  def _verify_model(self, dest_path, sig_path):
132
81
  try:
133
82
  data = Path(dest_path).read_bytes()
@@ -152,7 +101,6 @@ class ModelHandler():
152
101
  print(f"Model {dest_path} not verified! Deleting.")
153
102
  return False
154
103
 
155
-
156
104
  def _download_file(self, url, dest_path):
157
105
  dest_path.parent.mkdir(parents=True, exist_ok=True)
158
106
  try:
@@ -175,31 +123,4 @@ class ModelHandler():
175
123
  except Exception as e:
176
124
  print(e)
177
125
  return False
178
-
179
-
180
- def handle_full_image(self, denoised, filename, save_cfa):
181
- # Compute CFA
182
- if 'backend' in self.model_params and self.model_params['backend'] == 'rawpy':
183
- _, mask = self.rh.compute_mask_and_sparse(dims=(0, 99999, 0, 99999))
184
- denoised = denoised.transpose(2, 0, 1)
185
- denoised = denoised.clip(0, 1)
186
-
187
- denoised = np.where(mask, denoised, 0)
188
- denoised = denoised.sum(axis=0)
189
- denoised = denoised * ( self.rh.core_metadata.white_level) + self.rh.core_metadata.black_level_per_channel[0]
190
- self.rh.to_dng(filename, uint_img=denoised)
191
- else:
192
- transform_matrix = np.linalg.inv(
193
- self.rh.rgb_colorspace_transform(colorspace=self.colorspace)
194
- )
195
-
196
- CCM = self.rh.rgb_colorspace_transform(colorspace='XYZ')
197
- CCM = np.linalg.inv(CCM)
198
-
199
- transformed = denoised @ transform_matrix.T
200
- uint_img = np.clip(transformed * 2**16-1, 0, 2**16-1).astype(np.uint16)
201
- ccm1 = convert_color_matrix(CCM)
202
- to_dng(uint_img, self.rh, filename, ccm1, save_cfa=save_cfa, convert_to_cfa=True)
203
-
204
-
205
126
 
@@ -0,0 +1,100 @@
1
+ import numpy as np
2
+ import math
3
+
4
def norm_cdf(z):
    """Return the standard normal CDF of ``z``, element-wise.

    ``z`` may be a scalar or an array of any shape, including empty.
    """
    # Fixing ``otypes`` lets np.vectorize handle size-0 input; without it,
    # vectorize has to call the function once to infer the output dtype and
    # raises on an empty array.
    erf_vec = np.vectorize(math.erf, otypes=[float])
    return 0.5 * (1 + erf_vec(z / np.sqrt(2)))
7
+
8
def norm_pdf(z):
    """Return the standard normal density at ``z``, element-wise."""
    normaliser = 1 / np.sqrt(2 * np.pi)
    kernel = np.exp(-0.5 * z**2)
    return normaliser * kernel
10
+
11
+
12
def _impute_censored(mu, sigma, clip, upper):
    """Expected latent value of points censored at ``clip``.

    ``mu`` holds the current model predictions for the censored points.
    ``upper`` selects the high clip (latent value >= clip) or the low clip
    (latent value <= clip).
    """
    if sigma == 0:
        # No noise: the expectation is the prediction itself, and the
        # division below would be 0/0 or x/0.
        return mu.copy()

    z = (clip - mu) / sigma
    density = norm_pdf(z)
    cumulative = norm_cdf(z)
    tail = 1.0 - cumulative if upper else cumulative
    sign = 1.0 if upper else -1.0

    # Inverse Mills ratio; left at zero where the tail probability has
    # underflowed and the ratio would be meaningless.
    ratio = np.zeros_like(z)
    usable = tail > 1e-15
    ratio[usable] = sign * density[usable] / tail[usable]
    return mu + sigma * ratio


def censored_linear_fit_twosided(x, y, clip_low=0, clip_high=1,
                                 max_iter=200, tol=1e-6, include_offset=True):
    """
    Fit y ≈ a + b*x + ε, ε ~ N(0, σ²) under two-sided censoring:
        clip_low ≤ y_true ≤ clip_high
    Observed y are clipped to [clip_low, clip_high].
    Returns (a, b, sigma) estimated via EM.

    Parameters
    ----------
    x, y : array_like
        Input data. Both are flattened and converted to float; pairs with a
        non-finite member are dropped.
    clip_low, clip_high : float or None
        Lower/upper clip levels. Can be None for one-sided clipping.
    max_iter : int
        Maximum EM iterations.
    tol : float
        Tolerance (used as both relative and absolute) for convergence.
    include_offset : bool
        If False, forces intercept a=0 (fit y ≈ b*x).

    Returns
    -------
    a, b, sigma : floats
        Estimated regression parameters.

    Raises
    ------
    ValueError
        If fewer than three finite (x, y) pairs remain.
    """
    # Work in float so imputed expectations are not truncated when the
    # caller passes integer data.
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()
    finite = np.isfinite(x) & np.isfinite(y)
    x, y = x[finite], y[finite]

    if len(x) < 3:
        raise ValueError("Not enough data points.")

    # These depend only on the data, so build them once.
    if include_offset:
        design = np.vstack([np.ones_like(x), x]).T
    else:
        x_dot_x = np.dot(x, x)
    high_mask = y >= clip_high - 1e-12 if clip_high is not None else None
    low_mask = y <= clip_low + 1e-12 if clip_low is not None else None

    def least_squares(target):
        """Ordinary least-squares (a, b) for ``target`` against ``x``."""
        if include_offset:
            intercept, slope = np.linalg.lstsq(design, target, rcond=None)[0]
            return intercept, slope
        return 0.0, np.dot(x, target) / x_dot_x

    # Initial guess (ordinary least squares)
    a, b = least_squares(y)
    sigma = np.std(y - (a + b*x))

    for _ in range(max_iter):
        mu = a + b*x
        y_exp = y.copy()

        # E-step: replace clipped observations by their conditional mean.
        if high_mask is not None and np.any(high_mask):
            y_exp[high_mask] = _impute_censored(mu[high_mask], sigma, clip_high, upper=True)
        if low_mask is not None and np.any(low_mask):
            y_exp[low_mask] = _impute_censored(mu[low_mask], sigma, clip_low, upper=False)

        # M-step: re-fit with imputed expectations
        a_new, b_new = least_squares(y_exp)
        sigma_new = np.std(y_exp - (a_new + b_new*x))

        if np.allclose([a, b, sigma], [a_new, b_new, sigma_new],
                       rtol=tol, atol=tol):
            break

        a, b, sigma = a_new, b_new, sigma_new

    return a, b, sigma
@@ -0,0 +1,41 @@
1
+ import numpy as np
2
+ import rawpy
3
+ from RawHandler.RawHandler import RawHandler
4
+ from RawHandler.RawHandlerRawpy import RawHandlerRawpy
5
+
6
def load_rh(path, demosaicing):
    """Loads the raw file handler.

    Parameters
    ----------
    path : path-like
        Raw file to open.
    demosaicing : str
        ``'rawpy'`` selects ``RawHandlerRawpy``; ``'Malvar2004'`` selects
        ``RawHandler``.

    Returns
    -------
    tuple
        ``(rh, iso)``: the handler and the value of
        ``rh.full_metadata.get_ISO()``.

    Raises
    ------
    ValueError
        If ``demosaicing`` is not one of the two supported values.
    """
    if demosaicing == 'rawpy':
        rh = RawHandlerRawpy(path)
    elif demosaicing == 'Malvar2004':
        rh = RawHandler(path)
    else:
        # Without this branch ``rh`` would be unbound below.
        raise ValueError(
            f"Unsupported demosaicing {demosaicing!r}; expected 'rawpy' or 'Malvar2004'."
        )
    iso = rh.full_metadata.get_ISO()
    return rh, iso
14
+
15
def get_image(path, model_params, dims=None):
    """Open a raw file and return it as a float16 array with a leading axis.

    Parameters
    ----------
    path : path-like
        Raw file to open.
    model_params : dict
        Must contain ``'demosaicing'``: ``'Malvar2004'`` or ``'rawpy'``.
    dims : sequence of four ints, optional
        Crop window. Forwarded to ``rh.as_rgb`` on the Malvar2004 path; on
        the rawpy path it is applied as
        ``image[dims[0]:dims[1], dims[2]:dims[3]]``.

    Returns
    -------
    tuple
        ``(rh, image_RGB, iso)``.

    Raises
    ------
    ValueError
        If ``model_params['demosaicing']`` is not a supported value.
    """
    demosaicing = model_params['demosaicing']
    if demosaicing not in ('Malvar2004', 'rawpy'):
        # Fail before opening the file; otherwise ``image_RGB`` would be
        # unbound further down.
        raise ValueError(
            f"Unsupported demosaicing {demosaicing!r}; expected 'rawpy' or 'Malvar2004'."
        )
    rh, iso = load_rh(path, demosaicing)

    if demosaicing == 'Malvar2004':
        from colour_demosaicing import demosaicing_CFA_Bayer_Malvar2004
        image_RGB = rh.as_rgb(dims=dims,
                              demosaicing_func=demosaicing_CFA_Bayer_Malvar2004,
                              colorspace='lin_rec2020', clip=True)
    else:
        # Linear, unscaled, un-white-balanced output in the camera's own
        # colour space, normalised by the sensor white level.
        image_RGB = rh.rawpy_object.postprocess(
            user_wb=[1, 1, 1, 1],
            output_color=rawpy.ColorSpace.raw,
            demosaic_algorithm=rawpy.DemosaicAlgorithm(3),
            no_auto_bright=True,
            use_camera_wb=False,
            use_auto_wb=False,
            gamma=(1, 1),
            user_flip=0,
            output_bps=16,
            no_auto_scale=True,
        ) / rh.rawpy_object.white_level
        # NOTE(review): crop and transpose are placed on the rawpy path only,
        # because the Malvar2004 path already passes ``dims`` to ``as_rgb``
        # - confirm against the original file's indentation.
        if dims is not None:
            image_RGB = image_RGB[dims[0]:dims[1], dims[2]:dims[3]]
        image_RGB = image_RGB.transpose(2, 0, 1)

    image_RGB = np.expand_dims(image_RGB, axis=0).astype(np.float16)
    return rh, image_RGB, iso
@@ -0,0 +1,17 @@
1
+ import onnxruntime as ort
2
+ import os
3
+
4
def get_best_providers(cache_dir):
    """Return the execution providers to hand to an ONNX Runtime session.

    Parameters
    ----------
    cache_dir : path-like
        Directory that is created if it does not exist yet.

    Returns
    -------
    list of str
        Whatever ``onnxruntime.get_available_providers()`` reports, in the
        order it reports them.
    """
    os.makedirs(cache_dir, exist_ok=True)
    # The list is returned as reported. A hand-written priority filter used
    # to follow this return and could never run, so it was removed.
    # NOTE(review): ONNX Runtime is understood to report providers in its
    # own default priority order - confirm before re-adding any filter.
    return ort.get_available_providers()