hyperbench 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. hyperbench/__about__.py +7 -0
  2. hyperbench/__init__.py +157 -0
  3. hyperbench/adapters/__init__.py +19 -0
  4. hyperbench/adapters/base.py +178 -0
  5. hyperbench/adapters/callable.py +62 -0
  6. hyperbench/adapters/pipeline.py +192 -0
  7. hyperbench/adapters/tensorflow.py +67 -0
  8. hyperbench/adapters/torch.py +108 -0
  9. hyperbench/benchmark/__init__.py +18 -0
  10. hyperbench/benchmark/case.py +121 -0
  11. hyperbench/benchmark/generator.py +122 -0
  12. hyperbench/benchmark/results.py +127 -0
  13. hyperbench/benchmark/runner.py +392 -0
  14. hyperbench/cli.py +209 -0
  15. hyperbench/config.py +122 -0
  16. hyperbench/degradations/__init__.py +65 -0
  17. hyperbench/degradations/preprocessing.py +111 -0
  18. hyperbench/degradations/psf.py +233 -0
  19. hyperbench/degradations/spatial.py +100 -0
  20. hyperbench/degradations/spectral.py +97 -0
  21. hyperbench/degradations/srf.py +189 -0
  22. hyperbench/exceptions.py +53 -0
  23. hyperbench/io/__init__.py +15 -0
  24. hyperbench/io/loaders.py +70 -0
  25. hyperbench/io/matlab.py +97 -0
  26. hyperbench/metrics/__init__.py +27 -0
  27. hyperbench/metrics/core.py +181 -0
  28. hyperbench/metrics/hyperspectral.py +67 -0
  29. hyperbench/types.py +24 -0
  30. hyperbench/utils/__init__.py +90 -0
  31. hyperbench/utils/frameworks.py +314 -0
  32. hyperbench/utils/logging.py +152 -0
  33. hyperbench/utils/random.py +92 -0
  34. hyperbench/utils/validation.py +213 -0
  35. hyperbench/utils/visualization.py +255 -0
  36. hyperbench-0.1.0.dist-info/METADATA +171 -0
  37. hyperbench-0.1.0.dist-info/RECORD +41 -0
  38. hyperbench-0.1.0.dist-info/WHEEL +5 -0
  39. hyperbench-0.1.0.dist-info/entry_points.txt +2 -0
  40. hyperbench-0.1.0.dist-info/licenses/LICENSE.txt +21 -0
  41. hyperbench-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,7 @@
1
# Author: Ritik Shah

"""Package metadata for HyperBench."""

__all__ = ["__version__"]

# Single source of truth for the package version; re-exported by
# hyperbench/__init__.py.
__version__ = "0.1.0"
hyperbench/__init__.py ADDED
@@ -0,0 +1,157 @@
1
+ # Author: Ritik Shah
2
+
3
+ """HyperBench: synthetic benchmarking for hyperspectral super-resolution."""
4
+
5
+ from .__about__ import __version__
6
+
7
+ from .adapters import (
8
+ BaseAdapter,
9
+ CallableAdapter,
10
+ PipelineAdapter,
11
+ ReconstructionInputs,
12
+ TensorFlowModelAdapter,
13
+ TorchModelAdapter,
14
+ )
15
+ from .benchmark import (
16
+ BenchmarkCase,
17
+ BenchmarkConfig,
18
+ BenchmarkResults,
19
+ DegradationSpec,
20
+ SyntheticCase,
21
+ generate_cases,
22
+ run_benchmark,
23
+ )
24
+ from .config import benchmark_config_from_dict, load_benchmark_config
25
+ from .degradations import (
26
+ AVAILABLE_PSFS,
27
+ SUPPORTED_SRF_BAND_COUNTS,
28
+ make_psf,
29
+ normalize_image,
30
+ spatial_degradation,
31
+ spectral_degradation,
32
+ )
33
+ from .exceptions import (
34
+ AdapterError,
35
+ AdapterOutputError,
36
+ ConfigurationError,
37
+ FrameworkAvailabilityError,
38
+ HyperBenchError,
39
+ IOValidationError,
40
+ MetricsError,
41
+ PipelineError,
42
+ SceneKeyError,
43
+ ShapeMismatchError,
44
+ UnsupportedPSFError,
45
+ UnsupportedSRFError,
46
+ )
47
+ from .io import load_hsi
48
+ from .metrics import (
49
+ AVAILABLE_METRICS,
50
+ DEFAULT_METRICS,
51
+ compute_ergas,
52
+ compute_psnr,
53
+ compute_rmse,
54
+ compute_sam,
55
+ compute_ssim,
56
+ compute_uiqi,
57
+ evaluate_metrics,
58
+ )
59
+ from .types import Array, Metadata, PathLike
60
+ from .utils import (
61
+ convert_prediction_to_numpy_hwc,
62
+ get_preferred_tensorflow_device,
63
+ get_preferred_torch_device,
64
+ get_tensorflow_device_info,
65
+ get_torch_device_info,
66
+ is_tensorflow_available,
67
+ is_torch_available,
68
+ numpy_hwc_to_tf_image,
69
+ numpy_hwc_to_torch_image,
70
+ numpy_prediction_to_hwc,
71
+ numpy_to_tf_matrix,
72
+ numpy_to_torch_matrix,
73
+ plot_spectra,
74
+ print_data_stats,
75
+ print_framework_device_summary,
76
+ tf_image_to_numpy_hwc,
77
+ tf_matrix_to_numpy,
78
+ torch_image_to_numpy_hwc,
79
+ torch_matrix_to_numpy,
80
+ visualize_band,
81
+ visualize_hsi,
82
+ visualize_multispectral_with_srf,
83
+ visualize_psfs,
84
+ )
85
+
86
+ __all__ = [
87
+ "__version__",
88
+ "Array",
89
+ "PathLike",
90
+ "Metadata",
91
+ "HyperBenchError",
92
+ "ConfigurationError",
93
+ "IOValidationError",
94
+ "SceneKeyError",
95
+ "ShapeMismatchError",
96
+ "UnsupportedPSFError",
97
+ "UnsupportedSRFError",
98
+ "AdapterError",
99
+ "AdapterOutputError",
100
+ "PipelineError",
101
+ "MetricsError",
102
+ "FrameworkAvailabilityError",
103
+ "BaseAdapter",
104
+ "CallableAdapter",
105
+ "PipelineAdapter",
106
+ "TorchModelAdapter",
107
+ "TensorFlowModelAdapter",
108
+ "ReconstructionInputs",
109
+ "DegradationSpec",
110
+ "BenchmarkConfig",
111
+ "BenchmarkCase",
112
+ "SyntheticCase",
113
+ "BenchmarkResults",
114
+ "generate_cases",
115
+ "run_benchmark",
116
+ "benchmark_config_from_dict",
117
+ "load_benchmark_config",
118
+ "load_hsi",
119
+ "AVAILABLE_PSFS",
120
+ "SUPPORTED_SRF_BAND_COUNTS",
121
+ "make_psf",
122
+ "normalize_image",
123
+ "spatial_degradation",
124
+ "spectral_degradation",
125
+ "AVAILABLE_METRICS",
126
+ "DEFAULT_METRICS",
127
+ "compute_rmse",
128
+ "compute_psnr",
129
+ "compute_ssim",
130
+ "compute_uiqi",
131
+ "compute_ergas",
132
+ "compute_sam",
133
+ "evaluate_metrics",
134
+ "print_data_stats",
135
+ "visualize_hsi",
136
+ "visualize_band",
137
+ "visualize_multispectral_with_srf",
138
+ "plot_spectra",
139
+ "visualize_psfs",
140
+ "is_torch_available",
141
+ "is_tensorflow_available",
142
+ "get_torch_device_info",
143
+ "get_tensorflow_device_info",
144
+ "get_preferred_torch_device",
145
+ "get_preferred_tensorflow_device",
146
+ "print_framework_device_summary",
147
+ "numpy_hwc_to_tf_image",
148
+ "tf_image_to_numpy_hwc",
149
+ "numpy_to_tf_matrix",
150
+ "tf_matrix_to_numpy",
151
+ "numpy_hwc_to_torch_image",
152
+ "torch_image_to_numpy_hwc",
153
+ "numpy_to_torch_matrix",
154
+ "torch_matrix_to_numpy",
155
+ "numpy_prediction_to_hwc",
156
+ "convert_prediction_to_numpy_hwc",
157
+ ]
@@ -0,0 +1,19 @@
1
+ # Author: Ritik Shah
2
+
3
+ """Model adapter layer for HyperBench."""
4
+
5
+ from .base import BaseAdapter, ReconstructionInputs, ShapePolicy
6
+ from .callable import CallableAdapter
7
+ from .pipeline import PipelineAdapter
8
+ from .tensorflow import TensorFlowModelAdapter
9
+ from .torch import TorchModelAdapter
10
+
11
+ __all__ = [
12
+ "ShapePolicy",
13
+ "ReconstructionInputs",
14
+ "BaseAdapter",
15
+ "CallableAdapter",
16
+ "PipelineAdapter",
17
+ "TorchModelAdapter",
18
+ "TensorFlowModelAdapter",
19
+ ]
@@ -0,0 +1,178 @@
1
+ # Author: Ritik Shah
2
+
3
+ """Base adapter abstractions for HyperBench."""
4
+
5
+ from __future__ import annotations
6
+
7
+ from dataclasses import dataclass, field
8
+ from typing import Any, Dict, Literal, Optional
9
+
10
+ import numpy as np
11
+
12
+
13
# Module-level type aliases for the adapter layer.
Array = np.ndarray  # adapters exchange plain NumPy ndarrays
# How an adapter reconciles input sizes with its divisibility requirements:
# "strict" rejects, "crop" trims, "pad" reflect-pads (see BaseAdapter).
ShapePolicy = Literal["strict", "crop", "pad"]
15
+
16
+
17
@dataclass
class ReconstructionInputs:
    """Canonical inference-time input bundle for HyperBench.

    Groups everything an adapter receives for a single benchmark case.
    Image arrays are laid out H x W x C (height, width, channels).
    """

    # Low-resolution hyperspectral cube (H x W x C).
    lr_hsi: Array
    # High-resolution multispectral image (H x W x C).
    hr_msi: Array
    # Optional SRF (spectral response function) matrix for the sensor.
    srf: Optional[Array] = None
    # Optional PSF (point spread function) kernel.
    psf: Optional[Array] = None
    # Free-form per-case metadata. Also used internally by BaseAdapter's
    # "pad" shape policy to stash the original (pre-padding) shapes.
    metadata: Dict[str, Any] = field(default_factory=dict)
26
+
27
+
28
+ class BaseAdapter:
29
+ """Base class for all HyperBench adapters."""
30
+
31
+ def __init__(
32
+ self,
33
+ name: str,
34
+ shape_policy: ShapePolicy = "strict",
35
+ hr_multiple: int = 1,
36
+ lr_multiple: int = 1,
37
+ ) -> None:
38
+ if shape_policy not in {"strict", "crop", "pad"}:
39
+ raise ValueError(
40
+ "shape_policy must be one of {'strict', 'crop', 'pad'}, got {!r}".format(
41
+ shape_policy
42
+ )
43
+ )
44
+
45
+ self.name = name
46
+ self.shape_policy = shape_policy
47
+ self.hr_multiple = int(hr_multiple)
48
+ self.lr_multiple = int(lr_multiple)
49
+
50
+ self.config = {
51
+ "adapter_name": self.name,
52
+ "adapter_type": type(self).__name__,
53
+ "shape_policy": self.shape_policy,
54
+ "hr_multiple": self.hr_multiple,
55
+ "lr_multiple": self.lr_multiple,
56
+ }
57
+
58
+ def predict(self, inputs: ReconstructionInputs) -> Any:
59
+ """Run model inference for one HyperBench case.
60
+
61
+ Valid returns:
62
+ - prediction
63
+ - (prediction, stats_dict)
64
+ """
65
+ raise NotImplementedError
66
+
67
+ def _validate_multiple(self, value: int, name: str) -> None:
68
+ if value < 1:
69
+ raise ValueError("{} must be >= 1, got {}".format(name, value))
70
+
71
+ def prepare_inputs(self, inputs: ReconstructionInputs) -> ReconstructionInputs:
72
+ """Apply shape policy before inference."""
73
+ self._validate_multiple(self.hr_multiple, "hr_multiple")
74
+ self._validate_multiple(self.lr_multiple, "lr_multiple")
75
+
76
+ if self.shape_policy == "strict":
77
+ self._ensure_valid_shapes(inputs)
78
+ return inputs
79
+ elif self.shape_policy == "crop":
80
+ return self._crop_inputs_to_valid(inputs)
81
+ elif self.shape_policy == "pad":
82
+ return self._pad_inputs_to_valid(inputs)
83
+ else:
84
+ raise ValueError("Unsupported shape_policy {!r}".format(self.shape_policy))
85
+
86
+ def _ensure_valid_shapes(self, inputs: ReconstructionInputs) -> None:
87
+ hr_h, hr_w = inputs.hr_msi.shape[:2]
88
+ lr_h, lr_w = inputs.lr_hsi.shape[:2]
89
+
90
+ if hr_h % self.hr_multiple != 0 or hr_w % self.hr_multiple != 0:
91
+ raise ValueError(
92
+ "HR input shape {} is not divisible by hr_multiple={}".format(
93
+ inputs.hr_msi.shape[:2], self.hr_multiple
94
+ )
95
+ )
96
+
97
+ if lr_h % self.lr_multiple != 0 or lr_w % self.lr_multiple != 0:
98
+ raise ValueError(
99
+ "LR input shape {} is not divisible by lr_multiple={}".format(
100
+ inputs.lr_hsi.shape[:2], self.lr_multiple
101
+ )
102
+ )
103
+
104
+ def _crop_inputs_to_valid(self, inputs: ReconstructionInputs) -> ReconstructionInputs:
105
+ hr_h, hr_w = inputs.hr_msi.shape[:2]
106
+ lr_h, lr_w = inputs.lr_hsi.shape[:2]
107
+
108
+ new_hr_h = hr_h - (hr_h % self.hr_multiple)
109
+ new_hr_w = hr_w - (hr_w % self.hr_multiple)
110
+ new_lr_h = lr_h - (lr_h % self.lr_multiple)
111
+ new_lr_w = lr_w - (lr_w % self.lr_multiple)
112
+
113
+ if new_hr_h <= 0 or new_hr_w <= 0:
114
+ raise ValueError("Cropping would result in invalid HR shape")
115
+ if new_lr_h <= 0 or new_lr_w <= 0:
116
+ raise ValueError("Cropping would result in invalid LR shape")
117
+
118
+ return ReconstructionInputs(
119
+ lr_hsi=inputs.lr_hsi[:new_lr_h, :new_lr_w, :],
120
+ hr_msi=inputs.hr_msi[:new_hr_h, :new_hr_w, :],
121
+ srf=inputs.srf,
122
+ psf=inputs.psf,
123
+ metadata=dict(inputs.metadata),
124
+ )
125
+
126
+ def _compute_padding(self, size: int, multiple: int) -> int:
127
+ remainder = size % multiple
128
+ return 0 if remainder == 0 else multiple - remainder
129
+
130
+ def _pad_inputs_to_valid(self, inputs: ReconstructionInputs) -> ReconstructionInputs:
131
+ lr_h, lr_w, lr_c = inputs.lr_hsi.shape
132
+ hr_h, hr_w, hr_c = inputs.hr_msi.shape
133
+
134
+ pad_lr_h = self._compute_padding(lr_h, self.lr_multiple)
135
+ pad_lr_w = self._compute_padding(lr_w, self.lr_multiple)
136
+ pad_hr_h = self._compute_padding(hr_h, self.hr_multiple)
137
+ pad_hr_w = self._compute_padding(hr_w, self.hr_multiple)
138
+
139
+ lr_padded = np.pad(
140
+ inputs.lr_hsi,
141
+ ((0, pad_lr_h), (0, pad_lr_w), (0, 0)),
142
+ mode="reflect",
143
+ ).astype(np.float32)
144
+
145
+ hr_padded = np.pad(
146
+ inputs.hr_msi,
147
+ ((0, pad_hr_h), (0, pad_hr_w), (0, 0)),
148
+ mode="reflect",
149
+ ).astype(np.float32)
150
+
151
+ metadata = dict(inputs.metadata)
152
+ metadata["_original_lr_shape"] = (lr_h, lr_w, lr_c)
153
+ metadata["_original_hr_shape"] = (hr_h, hr_w, hr_c)
154
+
155
+ return ReconstructionInputs(
156
+ lr_hsi=lr_padded,
157
+ hr_msi=hr_padded,
158
+ srf=inputs.srf,
159
+ psf=inputs.psf,
160
+ metadata=metadata,
161
+ )
162
+
163
+ def _restore_output_shape(self, output: Array, inputs: ReconstructionInputs) -> Array:
164
+ """Restore output to original HR shape when using pad policy."""
165
+ if self.shape_policy != "pad":
166
+ return output
167
+
168
+ original_hr_shape = inputs.metadata.get("_original_hr_shape")
169
+ if original_hr_shape is None:
170
+ return output
171
+
172
+ h, w, c = original_hr_shape
173
+ return output[:h, :w, :c]
174
+
175
+ @property
176
+ def csv_metadata(self) -> Dict[str, Any]:
177
+ """Optional static metadata to merge into CSV rows."""
178
+ return {}
@@ -0,0 +1,62 @@
1
+ # Author: Ritik Shah
2
+
3
+ """Callable adapter for HyperBench."""
4
+
5
+ from __future__ import annotations
6
+
7
+ from typing import Any, Callable, Dict, Tuple
8
+
9
+ import numpy as np
10
+
11
+ from hyperbench.utils import numpy_prediction_to_hwc
12
+
13
+ from .base import BaseAdapter, ReconstructionInputs, ShapePolicy
14
+
15
+
16
Array = np.ndarray  # local alias: predictions are plain NumPy arrays
17
+
18
+
19
class CallableAdapter(BaseAdapter):
    """Expose a plain Python function as a HyperBench adapter.

    The wrapped function is invoked as
    ``fn(lr_hsi, hr_msi, srf=..., psf=..., metadata=...)`` and may return
    either a prediction or a ``(prediction, stats_dict)`` pair.
    """

    def __init__(
        self,
        fn: Callable,
        name: str = "callable_model",
        shape_policy: ShapePolicy = "strict",
        hr_multiple: int = 1,
        lr_multiple: int = 1,
    ) -> None:
        super().__init__(
            name=name,
            shape_policy=shape_policy,
            hr_multiple=hr_multiple,
            lr_multiple=lr_multiple,
        )
        self.fn = fn

    def predict(self, inputs: ReconstructionInputs) -> Any:
        """Apply the shape policy, call the wrapped function, and normalise
        its output to a float32 H x W x C array (plus optional stats)."""
        prepared = self.prepare_inputs(inputs)

        raw = self.fn(
            prepared.lr_hsi,
            prepared.hr_msi,
            srf=prepared.srf,
            psf=prepared.psf,
            metadata=prepared.metadata,
        )

        # Unpack an optional (prediction, stats_dict) pair.
        stats = None
        if isinstance(raw, tuple):
            if len(raw) != 2 or not isinstance(raw[1], dict):
                raise ValueError(
                    "CallableAdapter functions may return either prediction or "
                    "(prediction, stats_dict)."
                )
            raw, stats = raw

        hwc = numpy_prediction_to_hwc(raw, remove_batch_dim=True)
        hwc = self._restore_output_shape(hwc, prepared).astype(np.float32)
        return hwc if stats is None else (hwc, stats)
@@ -0,0 +1,192 @@
1
+ # Author: Ritik Shah
2
+
3
+ """Generic pipeline adapter for HyperBench.
4
+
5
+ This adapter is the recommended path for real-world research models.
6
+
7
+ Expected user-side contract:
8
+ - an object with a `run_pipeline(...)` method, or
9
+ - a callable function itself
10
+
11
+ The pipeline should accept:
12
+ run_pipeline(HR_MSI, LR_HSI, srf, psf=None, metadata=None)
13
+
14
+ Valid returns:
15
+ - prediction
16
+ - (prediction, stats_dict)
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ from typing import Any, Dict, Optional
22
+
23
+ import numpy as np
24
+
25
+ from hyperbench.utils import (
26
+ convert_prediction_to_numpy_hwc,
27
+ get_preferred_tensorflow_device,
28
+ get_preferred_torch_device,
29
+ numpy_hwc_to_tf_image,
30
+ numpy_hwc_to_torch_image,
31
+ numpy_to_tf_matrix,
32
+ numpy_to_torch_matrix,
33
+ numpy_prediction_to_hwc,
34
+ )
35
+
36
+ from .base import BaseAdapter, ReconstructionInputs, ShapePolicy
37
+
38
+
39
Array = np.ndarray  # local alias: canonical inputs/outputs are NumPy arrays
40
+
41
+
42
class PipelineAdapter(BaseAdapter):
    """Adapter around objects (or callables) exposing ``run_pipeline(...)``.

    Converts the canonical NumPy inputs into the configured input backend
    ("numpy", "tensorflow" or "torch"), invokes the pipeline, then converts
    the prediction back to a float32 H x W x C NumPy array.
    """

    def __init__(
        self,
        pipeline: Any,
        name: str = "pipeline_model",
        input_backend: str = "numpy",
        output_backend: Optional[str] = None,
        add_batch_dim: bool = True,
        device: str = "auto",
        shape_policy: ShapePolicy = "strict",
        hr_multiple: int = 1,
        lr_multiple: int = 1,
    ) -> None:
        super().__init__(
            name=name,
            shape_policy=shape_policy,
            hr_multiple=hr_multiple,
            lr_multiple=lr_multiple,
        )

        in_backend = input_backend.lower()
        if in_backend not in {"numpy", "tensorflow", "torch"}:
            raise ValueError(
                "input_backend must be one of {'numpy', 'tensorflow', 'torch'}, got {!r}".format(
                    input_backend
                )
            )

        # Output backend defaults to the (normalised) input backend.
        out_backend = in_backend if output_backend is None else output_backend.lower()
        if out_backend not in {"numpy", "tensorflow", "torch"}:
            raise ValueError(
                "output_backend must be one of {'numpy', 'tensorflow', 'torch'}, got {!r}".format(
                    output_backend
                )
            )

        self.pipeline = pipeline
        self.input_backend = in_backend
        self.output_backend = out_backend
        self.add_batch_dim = bool(add_batch_dim)
        self.device = device

        self.config.update(
            {
                "input_backend": self.input_backend,
                "output_backend": self.output_backend,
                "add_batch_dim": self.add_batch_dim,
                "device": self.device,
            }
        )

    def _resolve_device(self) -> str:
        # An explicit device always wins; "auto" picks per input backend.
        if self.device != "auto":
            return self.device
        if self.input_backend == "torch":
            return get_preferred_torch_device()
        if self.input_backend == "tensorflow":
            return get_preferred_tensorflow_device()
        return "cpu"

    def _prepare_image_input(self, image: Array, device: str) -> Any:
        # Convert one H x W x C array into the configured input backend,
        # optionally prepending a batch dimension.
        if self.input_backend == "numpy":
            arr = np.asarray(image, dtype=np.float32)
            return np.expand_dims(arr, axis=0) if self.add_batch_dim else arr
        if self.input_backend == "tensorflow":
            return numpy_hwc_to_tf_image(image, add_batch_dim=self.add_batch_dim)
        if self.input_backend == "torch":
            return numpy_hwc_to_torch_image(
                image,
                add_batch_dim=self.add_batch_dim,
                device=device,
            )
        raise ValueError("Unsupported input_backend {!r}".format(self.input_backend))

    def _prepare_matrix_input(self, matrix: Optional[Array], device: str) -> Any:
        # SRF/PSF matrices are optional; None passes through untouched.
        if matrix is None:
            return None
        if self.input_backend == "numpy":
            return np.asarray(matrix, dtype=np.float32)
        if self.input_backend == "tensorflow":
            return numpy_to_tf_matrix(matrix)
        if self.input_backend == "torch":
            return numpy_to_torch_matrix(matrix, device=device)
        raise ValueError("Unsupported input_backend {!r}".format(self.input_backend))

    def _get_pipeline_callable(self):
        # Prefer an explicit run_pipeline method; fall back to a bare callable.
        if hasattr(self.pipeline, "run_pipeline"):
            return self.pipeline.run_pipeline
        if callable(self.pipeline):
            return self.pipeline
        raise ValueError(
            "pipeline must be a callable or an object exposing run_pipeline(...)."
        )

    def _normalize_result(self, result: Any, prepared: ReconstructionInputs):
        # Unpack an optional (prediction, stats_dict) pair.
        stats = None
        if isinstance(result, tuple):
            if len(result) != 2 or not isinstance(result[1], dict):
                raise ValueError(
                    "PipelineAdapter pipelines may return either prediction or "
                    "(prediction, stats_dict)."
                )
            result, stats = result

        # Bring the prediction back to NumPy H x W x C.
        if self.output_backend == "numpy":
            pred = numpy_prediction_to_hwc(result, remove_batch_dim=self.add_batch_dim)
        else:
            pred = convert_prediction_to_numpy_hwc(
                result,
                output_backend=self.output_backend,
                remove_batch_dim=self.add_batch_dim,
            )

        pred = self._restore_output_shape(pred, prepared)
        pred = np.asarray(pred, dtype=np.float32)
        return pred if stats is None else (pred, stats)

    def predict(self, inputs: ReconstructionInputs) -> Any:
        """Run the wrapped pipeline on one case and normalise its output."""
        prepared = self.prepare_inputs(inputs)
        device = self._resolve_device()

        lr_in = self._prepare_image_input(prepared.lr_hsi, device=device)
        hr_in = self._prepare_image_input(prepared.hr_msi, device=device)
        srf_in = self._prepare_matrix_input(prepared.srf, device=device)
        psf_in = self._prepare_matrix_input(prepared.psf, device=device)

        run = self._get_pipeline_callable()
        result = run(
            HR_MSI=hr_in,
            LR_HSI=lr_in,
            srf=srf_in,
            psf=psf_in,
            metadata=prepared.metadata,
        )

        return self._normalize_result(result, prepared)
@@ -0,0 +1,67 @@
1
+ # Author: Ritik Shah
2
+
3
+ """TensorFlow model adapter for HyperBench."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import numpy as np
8
+
9
+ from .base import BaseAdapter, ReconstructionInputs, ShapePolicy
10
+
11
+
12
Array = np.ndarray  # local alias: predictions are returned as NumPy arrays
13
+
14
+
15
class TensorFlowModelAdapter(BaseAdapter):
    """Adapter for TensorFlow / Keras models.

    The wrapped model is called as ``model(lr, hr, srf=..., psf=...)`` with
    TensorFlow tensors; its output is converted back to a float32 H x W x C
    NumPy array.
    """

    def __init__(
        self,
        model,
        name: str = "tensorflow_model",
        shape_policy: ShapePolicy = "strict",
        hr_multiple: int = 1,
        lr_multiple: int = 1,
        add_batch_dim: bool = True,
    ) -> None:
        super().__init__(
            name=name,
            shape_policy=shape_policy,
            hr_multiple=hr_multiple,
            lr_multiple=lr_multiple,
        )
        self.model = model
        self.add_batch_dim = add_batch_dim

        # Record batching behaviour alongside the base adapter config.
        self.config.update({"add_batch_dim": self.add_batch_dim})

    def predict(self, inputs: ReconstructionInputs) -> Array:
        """Run the TF model on one case and return a float32 HWC array."""
        # TensorFlow is optional; import lazily so this module stays
        # importable without it.
        try:
            import tensorflow as tf
        except ImportError as exc:
            raise ImportError(
                "TensorFlow is not installed. Install it separately before using "
                "TensorFlowModelAdapter."
            ) from exc

        from hyperbench.utils import (
            numpy_hwc_to_tf_image,
            numpy_to_tf_matrix,
            tf_image_to_numpy_hwc,
        )

        prepared = self.prepare_inputs(inputs)

        lr_tensor = numpy_hwc_to_tf_image(prepared.lr_hsi, add_batch_dim=self.add_batch_dim)
        hr_tensor = numpy_hwc_to_tf_image(prepared.hr_msi, add_batch_dim=self.add_batch_dim)
        srf_tensor = None if prepared.srf is None else numpy_to_tf_matrix(prepared.srf)
        psf_tensor = None if prepared.psf is None else numpy_to_tf_matrix(prepared.psf)

        raw = self.model(lr_tensor, hr_tensor, srf=srf_tensor, psf=psf_tensor)
        hwc = tf_image_to_numpy_hwc(raw, remove_batch_dim=self.add_batch_dim)
        hwc = self._restore_output_shape(hwc, prepared)
        return np.asarray(hwc, dtype=np.float32)