hyperbench 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyperbench/__about__.py +7 -0
- hyperbench/__init__.py +157 -0
- hyperbench/adapters/__init__.py +19 -0
- hyperbench/adapters/base.py +178 -0
- hyperbench/adapters/callable.py +62 -0
- hyperbench/adapters/pipeline.py +192 -0
- hyperbench/adapters/tensorflow.py +67 -0
- hyperbench/adapters/torch.py +108 -0
- hyperbench/benchmark/__init__.py +18 -0
- hyperbench/benchmark/case.py +121 -0
- hyperbench/benchmark/generator.py +122 -0
- hyperbench/benchmark/results.py +127 -0
- hyperbench/benchmark/runner.py +392 -0
- hyperbench/cli.py +209 -0
- hyperbench/config.py +122 -0
- hyperbench/degradations/__init__.py +65 -0
- hyperbench/degradations/preprocessing.py +111 -0
- hyperbench/degradations/psf.py +233 -0
- hyperbench/degradations/spatial.py +100 -0
- hyperbench/degradations/spectral.py +97 -0
- hyperbench/degradations/srf.py +189 -0
- hyperbench/exceptions.py +53 -0
- hyperbench/io/__init__.py +15 -0
- hyperbench/io/loaders.py +70 -0
- hyperbench/io/matlab.py +97 -0
- hyperbench/metrics/__init__.py +27 -0
- hyperbench/metrics/core.py +181 -0
- hyperbench/metrics/hyperspectral.py +67 -0
- hyperbench/types.py +24 -0
- hyperbench/utils/__init__.py +90 -0
- hyperbench/utils/frameworks.py +314 -0
- hyperbench/utils/logging.py +152 -0
- hyperbench/utils/random.py +92 -0
- hyperbench/utils/validation.py +213 -0
- hyperbench/utils/visualization.py +255 -0
- hyperbench-0.1.0.dist-info/METADATA +171 -0
- hyperbench-0.1.0.dist-info/RECORD +41 -0
- hyperbench-0.1.0.dist-info/WHEEL +5 -0
- hyperbench-0.1.0.dist-info/entry_points.txt +2 -0
- hyperbench-0.1.0.dist-info/licenses/LICENSE.txt +21 -0
- hyperbench-0.1.0.dist-info/top_level.txt +1 -0
hyperbench/__about__.py
ADDED
hyperbench/__init__.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
# Author: Ritik Shah
|
|
2
|
+
|
|
3
|
+
"""HyperBench: synthetic benchmarking for hyperspectral super-resolution."""
|
|
4
|
+
|
|
5
|
+
from .__about__ import __version__
|
|
6
|
+
|
|
7
|
+
from .adapters import (
|
|
8
|
+
BaseAdapter,
|
|
9
|
+
CallableAdapter,
|
|
10
|
+
PipelineAdapter,
|
|
11
|
+
ReconstructionInputs,
|
|
12
|
+
TensorFlowModelAdapter,
|
|
13
|
+
TorchModelAdapter,
|
|
14
|
+
)
|
|
15
|
+
from .benchmark import (
|
|
16
|
+
BenchmarkCase,
|
|
17
|
+
BenchmarkConfig,
|
|
18
|
+
BenchmarkResults,
|
|
19
|
+
DegradationSpec,
|
|
20
|
+
SyntheticCase,
|
|
21
|
+
generate_cases,
|
|
22
|
+
run_benchmark,
|
|
23
|
+
)
|
|
24
|
+
from .config import benchmark_config_from_dict, load_benchmark_config
|
|
25
|
+
from .degradations import (
|
|
26
|
+
AVAILABLE_PSFS,
|
|
27
|
+
SUPPORTED_SRF_BAND_COUNTS,
|
|
28
|
+
make_psf,
|
|
29
|
+
normalize_image,
|
|
30
|
+
spatial_degradation,
|
|
31
|
+
spectral_degradation,
|
|
32
|
+
)
|
|
33
|
+
from .exceptions import (
|
|
34
|
+
AdapterError,
|
|
35
|
+
AdapterOutputError,
|
|
36
|
+
ConfigurationError,
|
|
37
|
+
FrameworkAvailabilityError,
|
|
38
|
+
HyperBenchError,
|
|
39
|
+
IOValidationError,
|
|
40
|
+
MetricsError,
|
|
41
|
+
PipelineError,
|
|
42
|
+
SceneKeyError,
|
|
43
|
+
ShapeMismatchError,
|
|
44
|
+
UnsupportedPSFError,
|
|
45
|
+
UnsupportedSRFError,
|
|
46
|
+
)
|
|
47
|
+
from .io import load_hsi
|
|
48
|
+
from .metrics import (
|
|
49
|
+
AVAILABLE_METRICS,
|
|
50
|
+
DEFAULT_METRICS,
|
|
51
|
+
compute_ergas,
|
|
52
|
+
compute_psnr,
|
|
53
|
+
compute_rmse,
|
|
54
|
+
compute_sam,
|
|
55
|
+
compute_ssim,
|
|
56
|
+
compute_uiqi,
|
|
57
|
+
evaluate_metrics,
|
|
58
|
+
)
|
|
59
|
+
from .types import Array, Metadata, PathLike
|
|
60
|
+
from .utils import (
|
|
61
|
+
convert_prediction_to_numpy_hwc,
|
|
62
|
+
get_preferred_tensorflow_device,
|
|
63
|
+
get_preferred_torch_device,
|
|
64
|
+
get_tensorflow_device_info,
|
|
65
|
+
get_torch_device_info,
|
|
66
|
+
is_tensorflow_available,
|
|
67
|
+
is_torch_available,
|
|
68
|
+
numpy_hwc_to_tf_image,
|
|
69
|
+
numpy_hwc_to_torch_image,
|
|
70
|
+
numpy_prediction_to_hwc,
|
|
71
|
+
numpy_to_tf_matrix,
|
|
72
|
+
numpy_to_torch_matrix,
|
|
73
|
+
plot_spectra,
|
|
74
|
+
print_data_stats,
|
|
75
|
+
print_framework_device_summary,
|
|
76
|
+
tf_image_to_numpy_hwc,
|
|
77
|
+
tf_matrix_to_numpy,
|
|
78
|
+
torch_image_to_numpy_hwc,
|
|
79
|
+
torch_matrix_to_numpy,
|
|
80
|
+
visualize_band,
|
|
81
|
+
visualize_hsi,
|
|
82
|
+
visualize_multispectral_with_srf,
|
|
83
|
+
visualize_psfs,
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
# Public API re-exported at the package root, grouped by the subpackage each
# name is imported from above.
__all__ = [
    # Package version (from .__about__).
    "__version__",
    # Type aliases (from .types).
    "Array",
    "PathLike",
    "Metadata",
    # Exception types (from .exceptions).
    "HyperBenchError",
    "ConfigurationError",
    "IOValidationError",
    "SceneKeyError",
    "ShapeMismatchError",
    "UnsupportedPSFError",
    "UnsupportedSRFError",
    "AdapterError",
    "AdapterOutputError",
    "PipelineError",
    "MetricsError",
    "FrameworkAvailabilityError",
    # Model adapters and their input bundle (from .adapters).
    "BaseAdapter",
    "CallableAdapter",
    "PipelineAdapter",
    "TorchModelAdapter",
    "TensorFlowModelAdapter",
    "ReconstructionInputs",
    # Benchmark specs, case generation and execution (from .benchmark).
    "DegradationSpec",
    "BenchmarkConfig",
    "BenchmarkCase",
    "SyntheticCase",
    "BenchmarkResults",
    "generate_cases",
    "run_benchmark",
    # Benchmark configuration loading (from .config).
    "benchmark_config_from_dict",
    "load_benchmark_config",
    # Data loading (from .io).
    "load_hsi",
    # Degradation models (from .degradations).
    "AVAILABLE_PSFS",
    "SUPPORTED_SRF_BAND_COUNTS",
    "make_psf",
    "normalize_image",
    "spatial_degradation",
    "spectral_degradation",
    # Quality metrics (from .metrics).
    "AVAILABLE_METRICS",
    "DEFAULT_METRICS",
    "compute_rmse",
    "compute_psnr",
    "compute_ssim",
    "compute_uiqi",
    "compute_ergas",
    "compute_sam",
    "evaluate_metrics",
    # Inspection and plotting helpers (from .utils).
    "print_data_stats",
    "visualize_hsi",
    "visualize_band",
    "visualize_multispectral_with_srf",
    "plot_spectra",
    "visualize_psfs",
    # Framework availability / device helpers (from .utils).
    "is_torch_available",
    "is_tensorflow_available",
    "get_torch_device_info",
    "get_tensorflow_device_info",
    "get_preferred_torch_device",
    "get_preferred_tensorflow_device",
    "print_framework_device_summary",
    # NumPy <-> framework tensor conversion helpers (from .utils).
    "numpy_hwc_to_tf_image",
    "tf_image_to_numpy_hwc",
    "numpy_to_tf_matrix",
    "tf_matrix_to_numpy",
    "numpy_hwc_to_torch_image",
    "torch_image_to_numpy_hwc",
    "numpy_to_torch_matrix",
    "torch_matrix_to_numpy",
    "numpy_prediction_to_hwc",
    "convert_prediction_to_numpy_hwc",
]
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Author: Ritik Shah
|
|
2
|
+
|
|
3
|
+
"""Model adapter layer for HyperBench."""
|
|
4
|
+
|
|
5
|
+
from .base import BaseAdapter, ReconstructionInputs, ShapePolicy
|
|
6
|
+
from .callable import CallableAdapter
|
|
7
|
+
from .pipeline import PipelineAdapter
|
|
8
|
+
from .tensorflow import TensorFlowModelAdapter
|
|
9
|
+
from .torch import TorchModelAdapter
|
|
10
|
+
|
|
11
|
+
# Public names of the adapter subpackage.
__all__ = [
    # Shared types (from .base).
    "ShapePolicy",
    "ReconstructionInputs",
    # Adapter implementations.
    "BaseAdapter",
    "CallableAdapter",
    "PipelineAdapter",
    "TorchModelAdapter",
    "TensorFlowModelAdapter",
]
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
# Author: Ritik Shah
|
|
2
|
+
|
|
3
|
+
"""Base adapter abstractions for HyperBench."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from typing import Any, Dict, Literal, Optional
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Annotation alias for the NumPy arrays passed between adapters.
Array = np.ndarray

# The accepted values of ``BaseAdapter(shape_policy=...)``.
ShapePolicy = Literal[
    "strict",
    "crop",
    "pad",
]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class ReconstructionInputs:
    """Everything an adapter receives for a single reconstruction.

    Only the two observed images are required; the degradation operators and
    the metadata mapping are optional and default to ``None`` and a fresh
    empty dict respectively.
    """

    # Low-resolution hyperspectral observation (HWC).
    lr_hsi: Array
    # High-resolution multispectral observation (HWC).
    hr_msi: Array
    # Spectral response function, or None when not provided.
    srf: Optional[Array] = field(default=None)
    # Point spread function, or None when not provided.
    psf: Optional[Array] = field(default=None)
    # Free-form per-case metadata; every instance gets its own dict.
    metadata: Dict[str, Any] = field(default_factory=dict)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class BaseAdapter:
    """Base class for all HyperBench adapters.

    Subclasses implement :meth:`predict`. The base class stores the spatial
    shape policy and provides two helpers that subclasses call around
    inference: :meth:`prepare_inputs` before running the model, and
    :meth:`_restore_output_shape` on the prediction afterwards.

    Args:
        name: Adapter name, recorded in ``self.config``.
        shape_policy: What to do when spatial sizes are not divisible by the
            required multiples. ``"strict"`` raises, ``"crop"`` drops
            trailing rows/columns, and ``"pad"`` reflect-pads at the
            bottom/right.
        hr_multiple: Required divisor of the HR-MSI height and width.
        lr_multiple: Required divisor of the LR-HSI height and width.

    Raises:
        ValueError: If ``shape_policy`` is not one of the supported values.
    """

    def __init__(
        self,
        name: str,
        shape_policy: ShapePolicy = "strict",
        hr_multiple: int = 1,
        lr_multiple: int = 1,
    ) -> None:
        if shape_policy not in {"strict", "crop", "pad"}:
            # The braces around the allowed set are doubled so str.format
            # emits them literally. A single "{...}" would be parsed as a
            # named field and raise KeyError instead of this ValueError.
            raise ValueError(
                "shape_policy must be one of {{'strict', 'crop', 'pad'}}, got {!r}".format(
                    shape_policy
                )
            )

        self.name = name
        self.shape_policy = shape_policy
        self.hr_multiple = int(hr_multiple)
        self.lr_multiple = int(lr_multiple)

        # Static description of the adapter; subclasses add their own keys.
        self.config = {
            "adapter_name": self.name,
            "adapter_type": type(self).__name__,
            "shape_policy": self.shape_policy,
            "hr_multiple": self.hr_multiple,
            "lr_multiple": self.lr_multiple,
        }

    def predict(self, inputs: ReconstructionInputs) -> Any:
        """Run model inference for one HyperBench case.

        Valid returns:
            - prediction
            - (prediction, stats_dict)

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError

    def _validate_multiple(self, value: int, name: str) -> None:
        """Raise ``ValueError`` unless *value* (named *name* in the message) is >= 1."""
        if value < 1:
            raise ValueError("{} must be >= 1, got {}".format(name, value))

    def prepare_inputs(self, inputs: ReconstructionInputs) -> ReconstructionInputs:
        """Apply ``shape_policy`` to *inputs* before inference.

        Returns:
            *inputs* itself for ``"strict"`` (after validating the shapes), or
            a new :class:`ReconstructionInputs` with cropped (``"crop"``) or
            padded (``"pad"``) images.

        Raises:
            ValueError: If a multiple is < 1, if ``"strict"`` finds a
                non-divisible shape, or if cropping would empty an image.
        """
        self._validate_multiple(self.hr_multiple, "hr_multiple")
        self._validate_multiple(self.lr_multiple, "lr_multiple")

        if self.shape_policy == "strict":
            self._ensure_valid_shapes(inputs)
            return inputs
        elif self.shape_policy == "crop":
            return self._crop_inputs_to_valid(inputs)
        elif self.shape_policy == "pad":
            return self._pad_inputs_to_valid(inputs)
        else:
            # Reachable only if shape_policy was reassigned after __init__.
            raise ValueError("Unsupported shape_policy {!r}".format(self.shape_policy))

    def _ensure_valid_shapes(self, inputs: ReconstructionInputs) -> None:
        """Raise ``ValueError`` if either image's height/width is not divisible."""
        hr_h, hr_w = inputs.hr_msi.shape[:2]
        lr_h, lr_w = inputs.lr_hsi.shape[:2]

        if hr_h % self.hr_multiple != 0 or hr_w % self.hr_multiple != 0:
            raise ValueError(
                "HR input shape {} is not divisible by hr_multiple={}".format(
                    inputs.hr_msi.shape[:2], self.hr_multiple
                )
            )

        if lr_h % self.lr_multiple != 0 or lr_w % self.lr_multiple != 0:
            raise ValueError(
                "LR input shape {} is not divisible by lr_multiple={}".format(
                    inputs.lr_hsi.shape[:2], self.lr_multiple
                )
            )

    def _crop_inputs_to_valid(self, inputs: ReconstructionInputs) -> ReconstructionInputs:
        """Crop both images down to the nearest valid multiple, anchored top-left.

        The cropped images are NumPy slices of the originals. The metadata dict
        is shallow-copied, and ``srf``/``psf`` are passed through unchanged.
        """
        hr_h, hr_w = inputs.hr_msi.shape[:2]
        lr_h, lr_w = inputs.lr_hsi.shape[:2]

        new_hr_h = hr_h - (hr_h % self.hr_multiple)
        new_hr_w = hr_w - (hr_w % self.hr_multiple)
        new_lr_h = lr_h - (lr_h % self.lr_multiple)
        new_lr_w = lr_w - (lr_w % self.lr_multiple)

        if new_hr_h <= 0 or new_hr_w <= 0:
            raise ValueError("Cropping would result in invalid HR shape")
        if new_lr_h <= 0 or new_lr_w <= 0:
            raise ValueError("Cropping would result in invalid LR shape")

        return ReconstructionInputs(
            lr_hsi=inputs.lr_hsi[:new_lr_h, :new_lr_w, :],
            hr_msi=inputs.hr_msi[:new_hr_h, :new_hr_w, :],
            srf=inputs.srf,
            psf=inputs.psf,
            metadata=dict(inputs.metadata),
        )

    def _compute_padding(self, size: int, multiple: int) -> int:
        """Return how many elements to append so *size* becomes a multiple of *multiple*."""
        remainder = size % multiple
        return 0 if remainder == 0 else multiple - remainder

    def _pad_inputs_to_valid(self, inputs: ReconstructionInputs) -> ReconstructionInputs:
        """Reflect-pad both images at the bottom/right up to valid multiples.

        The padded images are cast to ``float32``. The pre-padding shapes are
        stored in the copied metadata under ``"_original_lr_shape"`` and
        ``"_original_hr_shape"`` so :meth:`_restore_output_shape` can undo the
        padding on the prediction.
        """
        lr_h, lr_w, lr_c = inputs.lr_hsi.shape
        hr_h, hr_w, hr_c = inputs.hr_msi.shape

        pad_lr_h = self._compute_padding(lr_h, self.lr_multiple)
        pad_lr_w = self._compute_padding(lr_w, self.lr_multiple)
        pad_hr_h = self._compute_padding(hr_h, self.hr_multiple)
        pad_hr_w = self._compute_padding(hr_w, self.hr_multiple)

        lr_padded = np.pad(
            inputs.lr_hsi,
            ((0, pad_lr_h), (0, pad_lr_w), (0, 0)),
            mode="reflect",
        ).astype(np.float32)

        hr_padded = np.pad(
            inputs.hr_msi,
            ((0, pad_hr_h), (0, pad_hr_w), (0, 0)),
            mode="reflect",
        ).astype(np.float32)

        metadata = dict(inputs.metadata)
        metadata["_original_lr_shape"] = (lr_h, lr_w, lr_c)
        metadata["_original_hr_shape"] = (hr_h, hr_w, hr_c)

        return ReconstructionInputs(
            lr_hsi=lr_padded,
            hr_msi=hr_padded,
            srf=inputs.srf,
            psf=inputs.psf,
            metadata=metadata,
        )

    def _restore_output_shape(self, output: Array, inputs: ReconstructionInputs) -> Array:
        """Crop a prediction back to the original HR spatial size under ``"pad"``.

        Only height and width are cropped. The channel axis is kept intact:
        the prediction is the super-resolved hyperspectral cube, whose band
        count generally differs from the HR-MSI band count recorded in
        ``_original_hr_shape``.

        Returns:
            *output* unchanged when the policy is not ``"pad"`` or no original
            shape was recorded; otherwise ``output[:h, :w, ...]``.
        """
        if self.shape_policy != "pad":
            return output

        original_hr_shape = inputs.metadata.get("_original_hr_shape")
        if original_hr_shape is None:
            return output

        h, w = original_hr_shape[:2]
        return output[:h, :w, ...]

    @property
    def csv_metadata(self) -> Dict[str, Any]:
        """Optional static metadata to merge into CSV rows."""
        return {}
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
# Author: Ritik Shah
|
|
2
|
+
|
|
3
|
+
"""Callable adapter for HyperBench."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from typing import Any, Callable, Dict, Tuple
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from hyperbench.utils import numpy_prediction_to_hwc
|
|
12
|
+
|
|
13
|
+
from .base import BaseAdapter, ReconstructionInputs, ShapePolicy
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Module-level alias for ``np.ndarray``, matching the alias in
# ``hyperbench.adapters.base``.
Array = np.ndarray
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class CallableAdapter(BaseAdapter):
    """Expose an ordinary Python callable through the HyperBench adapter API.

    After the configured shape policy has been applied, the callable is invoked
    as ``fn(lr_hsi, hr_msi, srf=..., psf=..., metadata=...)``. It may return
    either a prediction or a ``(prediction, stats_dict)`` pair.

    Args:
        fn: The callable to wrap.
        name: Adapter name.
        shape_policy: Forwarded to :class:`BaseAdapter`.
        hr_multiple: Forwarded to :class:`BaseAdapter`.
        lr_multiple: Forwarded to :class:`BaseAdapter`.
    """

    def __init__(
        self,
        fn: Callable,
        name: str = "callable_model",
        shape_policy: ShapePolicy = "strict",
        hr_multiple: int = 1,
        lr_multiple: int = 1,
    ) -> None:
        super().__init__(
            name=name,
            shape_policy=shape_policy,
            hr_multiple=hr_multiple,
            lr_multiple=lr_multiple,
        )
        self.fn = fn

    def predict(self, inputs: ReconstructionInputs) -> Any:
        """Call the wrapped function on one case and normalize its output.

        Returns:
            A ``float32`` HWC array that has been passed through
            ``numpy_prediction_to_hwc`` and ``_restore_output_shape``, or a
            ``(prediction, stats_dict)`` tuple when the function returned
            stats.

        Raises:
            ValueError: If the function returns a tuple that is not exactly
                ``(prediction, dict)``.
        """
        ready = self.prepare_inputs(inputs)

        raw = self.fn(
            ready.lr_hsi,
            ready.hr_msi,
            srf=ready.srf,
            psf=ready.psf,
            metadata=ready.metadata,
        )

        stats = None
        if isinstance(raw, tuple):
            well_formed = len(raw) == 2 and isinstance(raw[1], dict)
            if not well_formed:
                raise ValueError(
                    "CallableAdapter functions may return either prediction or "
                    "(prediction, stats_dict)."
                )
            raw, stats = raw

        # Shared post-processing for both return forms.
        hwc = numpy_prediction_to_hwc(raw, remove_batch_dim=True)
        hwc = self._restore_output_shape(hwc, ready).astype(np.float32)

        if stats is None:
            return hwc
        return hwc, stats
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
# Author: Ritik Shah
|
|
2
|
+
|
|
3
|
+
"""Generic pipeline adapter for HyperBench.
|
|
4
|
+
|
|
5
|
+
This adapter is the recommended path for real-world research models.
|
|
6
|
+
|
|
7
|
+
Expected user-side contract:
|
|
8
|
+
- an object with a `run_pipeline(...)` method, or
|
|
9
|
+
- a callable function itself
|
|
10
|
+
|
|
11
|
+
The pipeline should accept:
|
|
12
|
+
run_pipeline(HR_MSI, LR_HSI, srf, psf=None, metadata=None)
|
|
13
|
+
|
|
14
|
+
Valid returns:
|
|
15
|
+
- prediction
|
|
16
|
+
- (prediction, stats_dict)
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from __future__ import annotations
|
|
20
|
+
|
|
21
|
+
from typing import Any, Dict, Optional
|
|
22
|
+
|
|
23
|
+
import numpy as np
|
|
24
|
+
|
|
25
|
+
from hyperbench.utils import (
|
|
26
|
+
convert_prediction_to_numpy_hwc,
|
|
27
|
+
get_preferred_tensorflow_device,
|
|
28
|
+
get_preferred_torch_device,
|
|
29
|
+
numpy_hwc_to_tf_image,
|
|
30
|
+
numpy_hwc_to_torch_image,
|
|
31
|
+
numpy_to_tf_matrix,
|
|
32
|
+
numpy_to_torch_matrix,
|
|
33
|
+
numpy_prediction_to_hwc,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
from .base import BaseAdapter, ReconstructionInputs, ShapePolicy
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# Module-level alias for ``np.ndarray``, matching the alias in
# ``hyperbench.adapters.base``.
Array = np.ndarray
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class PipelineAdapter(BaseAdapter):
    """Generic adapter for model wrappers exposing ``run_pipeline(...)``.

    The inputs are converted to ``input_backend`` and passed by keyword as
    ``HR_MSI=..., LR_HSI=..., srf=..., psf=..., metadata=...`` to either
    ``pipeline.run_pipeline`` or ``pipeline`` itself. The result is then
    converted back to a ``float32`` NumPy HWC array according to
    ``output_backend``.

    Args:
        pipeline: Object exposing ``run_pipeline``, or a plain callable.
        name: Adapter name.
        input_backend: ``"numpy"``, ``"tensorflow"`` or ``"torch"``
            (case-insensitive).
        output_backend: Backend of the returned prediction. Defaults to the
            input backend.
        add_batch_dim: Add a leading batch axis to image inputs and strip it
            from the prediction.
        device: Explicit device string, or ``"auto"`` to choose one from the
            input backend.
        shape_policy: Forwarded to :class:`BaseAdapter`.
        hr_multiple: Forwarded to :class:`BaseAdapter`.
        lr_multiple: Forwarded to :class:`BaseAdapter`.

    Raises:
        ValueError: If either backend is not a supported name.
    """

    def __init__(
        self,
        pipeline: Any,
        name: str = "pipeline_model",
        input_backend: str = "numpy",
        output_backend: Optional[str] = None,
        add_batch_dim: bool = True,
        device: str = "auto",
        shape_policy: ShapePolicy = "strict",
        hr_multiple: int = 1,
        lr_multiple: int = 1,
    ) -> None:
        super().__init__(
            name=name,
            shape_policy=shape_policy,
            hr_multiple=hr_multiple,
            lr_multiple=lr_multiple,
        )

        backend = input_backend.lower()
        if backend not in {"numpy", "tensorflow", "torch"}:
            # The braces around the allowed set are doubled so str.format
            # emits them literally instead of raising KeyError.
            raise ValueError(
                "input_backend must be one of {{'numpy', 'tensorflow', 'torch'}}, got {!r}".format(
                    input_backend
                )
            )

        # Unless told otherwise, assume the pipeline returns the same kind of
        # object it was given.
        out_backend = output_backend.lower() if output_backend is not None else backend
        if out_backend not in {"numpy", "tensorflow", "torch"}:
            raise ValueError(
                "output_backend must be one of {{'numpy', 'tensorflow', 'torch'}}, got {!r}".format(
                    output_backend
                )
            )

        self.pipeline = pipeline
        self.input_backend = backend
        self.output_backend = out_backend
        self.add_batch_dim = bool(add_batch_dim)
        self.device = device

        self.config.update(
            {
                "input_backend": self.input_backend,
                "output_backend": self.output_backend,
                "add_batch_dim": self.add_batch_dim,
                "device": self.device,
            }
        )

    def _resolve_device(self) -> str:
        """Return ``self.device``, or the preferred device for the input backend if ``"auto"``."""
        if self.device != "auto":
            return self.device

        if self.input_backend == "torch":
            return get_preferred_torch_device()
        elif self.input_backend == "tensorflow":
            return get_preferred_tensorflow_device()
        else:
            return "cpu"

    def _prepare_image_input(self, image: Array, device: str) -> Any:
        """Convert an HWC NumPy image to the input backend's representation.

        NOTE: *device* is only forwarded on the torch path; the tensorflow
        conversion here does not receive it.
        """
        if self.input_backend == "numpy":
            out = np.asarray(image, dtype=np.float32)
            if self.add_batch_dim:
                out = np.expand_dims(out, axis=0)
            return out
        elif self.input_backend == "tensorflow":
            return numpy_hwc_to_tf_image(image, add_batch_dim=self.add_batch_dim)
        elif self.input_backend == "torch":
            return numpy_hwc_to_torch_image(
                image,
                add_batch_dim=self.add_batch_dim,
                device=device,
            )
        else:
            raise ValueError("Unsupported input_backend {!r}".format(self.input_backend))

    def _prepare_matrix_input(self, matrix: Optional[Array], device: str) -> Any:
        """Convert an optional SRF/PSF array to the input backend; ``None`` passes through."""
        if matrix is None:
            return None

        if self.input_backend == "numpy":
            return np.asarray(matrix, dtype=np.float32)
        elif self.input_backend == "tensorflow":
            return numpy_to_tf_matrix(matrix)
        elif self.input_backend == "torch":
            return numpy_to_torch_matrix(matrix, device=device)
        else:
            raise ValueError("Unsupported input_backend {!r}".format(self.input_backend))

    def _get_pipeline_callable(self):
        """Return ``pipeline.run_pipeline`` if present, else ``pipeline`` if callable.

        Raises:
            ValueError: If the pipeline is neither.
        """
        if hasattr(self.pipeline, "run_pipeline"):
            return self.pipeline.run_pipeline
        elif callable(self.pipeline):
            return self.pipeline
        else:
            raise ValueError(
                "pipeline must be a callable or an object exposing run_pipeline(...)."
            )

    def _normalize_result(self, result: Any, prepared: ReconstructionInputs):
        """Convert raw pipeline output to a ``float32`` NumPy HWC prediction.

        Returns:
            The prediction, or ``(prediction, stats)`` when the pipeline
            returned a stats dict.

        Raises:
            ValueError: If *result* is a tuple that is not ``(prediction, dict)``.
        """
        if isinstance(result, tuple):
            if len(result) != 2 or not isinstance(result[1], dict):
                raise ValueError(
                    "PipelineAdapter pipelines may return either prediction or "
                    "(prediction, stats_dict)."
                )
            prediction, stats = result
        else:
            prediction, stats = result, None

        if self.output_backend == "numpy":
            pred_np = numpy_prediction_to_hwc(prediction, remove_batch_dim=self.add_batch_dim)
        else:
            pred_np = convert_prediction_to_numpy_hwc(
                prediction,
                output_backend=self.output_backend,
                remove_batch_dim=self.add_batch_dim,
            )

        pred_np = self._restore_output_shape(pred_np, prepared)
        pred_np = np.asarray(pred_np, dtype=np.float32)

        if stats is None:
            return pred_np
        return pred_np, stats

    def predict(self, inputs: ReconstructionInputs) -> Any:
        """Apply the shape policy, run the pipeline, and normalize its output."""
        prepared = self.prepare_inputs(inputs)
        device = self._resolve_device()

        LR_HSI = self._prepare_image_input(prepared.lr_hsi, device=device)
        HR_MSI = self._prepare_image_input(prepared.hr_msi, device=device)
        srf = self._prepare_matrix_input(prepared.srf, device=device)
        psf = self._prepare_matrix_input(prepared.psf, device=device)

        pipeline_fn = self._get_pipeline_callable()

        # Arguments are passed by keyword, so the pipeline's parameters must
        # use exactly these names.
        result = pipeline_fn(
            HR_MSI=HR_MSI,
            LR_HSI=LR_HSI,
            srf=srf,
            psf=psf,
            metadata=prepared.metadata,
        )

        return self._normalize_result(result, prepared)
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# Author: Ritik Shah
|
|
2
|
+
|
|
3
|
+
"""TensorFlow model adapter for HyperBench."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from .base import BaseAdapter, ReconstructionInputs, ShapePolicy
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Module-level alias for ``np.ndarray``, matching the alias in
# ``hyperbench.adapters.base``.
Array = np.ndarray
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TensorFlowModelAdapter(BaseAdapter):
    """Adapter for TensorFlow / Keras models.

    The model is called as ``model(lr, hr, srf=srf, psf=psf)`` with TensorFlow
    tensors built from the shape-policy-adjusted inputs. TensorFlow is imported
    lazily inside :meth:`predict`, so constructing the adapter does not require
    it.

    Args:
        model: The TensorFlow / Keras model (any callable with the signature
            above).
        name: Adapter name.
        shape_policy: Forwarded to :class:`BaseAdapter`.
        hr_multiple: Forwarded to :class:`BaseAdapter`.
        lr_multiple: Forwarded to :class:`BaseAdapter`.
        add_batch_dim: Add a leading batch axis to image inputs and strip it
            from the prediction.
    """

    def __init__(
        self,
        model,
        name: str = "tensorflow_model",
        shape_policy: ShapePolicy = "strict",
        hr_multiple: int = 1,
        lr_multiple: int = 1,
        add_batch_dim: bool = True,
    ) -> None:
        super().__init__(
            name=name,
            shape_policy=shape_policy,
            hr_multiple=hr_multiple,
            lr_multiple=lr_multiple,
        )
        self.model = model
        # Normalized to a real bool, like PipelineAdapter, so config is consistent.
        self.add_batch_dim = bool(add_batch_dim)

        self.config.update(
            {
                "add_batch_dim": self.add_batch_dim,
            }
        )

    def predict(self, inputs: ReconstructionInputs) -> object:
        """Run the model on one case.

        Returns:
            A ``float32`` NumPy HWC prediction. If the model returns a
            ``(prediction, stats_dict)`` tuple (the form allowed by
            ``BaseAdapter.predict``), returns ``(prediction, stats_dict)``.
            Any other tuple is passed to the converter unchanged, as before.

        Raises:
            ImportError: If TensorFlow is not installed.
        """
        try:
            import tensorflow as tf  # noqa: F401  (availability check only)
        except ImportError as exc:
            raise ImportError(
                "TensorFlow is not installed. Install it separately before using "
                "TensorFlowModelAdapter."
            ) from exc

        from hyperbench.utils import (
            numpy_hwc_to_tf_image,
            numpy_to_tf_matrix,
            tf_image_to_numpy_hwc,
        )

        prepared = self.prepare_inputs(inputs)

        lr = numpy_hwc_to_tf_image(prepared.lr_hsi, add_batch_dim=self.add_batch_dim)
        hr = numpy_hwc_to_tf_image(prepared.hr_msi, add_batch_dim=self.add_batch_dim)
        srf = numpy_to_tf_matrix(prepared.srf) if prepared.srf is not None else None
        psf = numpy_to_tf_matrix(prepared.psf) if prepared.psf is not None else None

        result = self.model(lr, hr, srf=srf, psf=psf)

        # Split off an optional stats dict, consistent with the callable and
        # pipeline adapters. Other tuples are left untouched.
        stats = None
        if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
            result, stats = result

        output = tf_image_to_numpy_hwc(result, remove_batch_dim=self.add_batch_dim)
        output = self._restore_output_shape(output, prepared)
        output = np.asarray(output, dtype=np.float32)

        if stats is None:
            return output
        return output, stats
|