zea 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +3 -3
- zea/agent/masks.py +2 -2
- zea/agent/selection.py +3 -3
- zea/backend/__init__.py +1 -1
- zea/backend/tensorflow/dataloader.py +1 -5
- zea/beamform/beamformer.py +4 -2
- zea/beamform/pfield.py +2 -2
- zea/beamform/pixelgrid.py +1 -1
- zea/data/__init__.py +0 -9
- zea/data/augmentations.py +222 -29
- zea/data/convert/__init__.py +1 -6
- zea/data/convert/__main__.py +164 -0
- zea/data/convert/camus.py +106 -40
- zea/data/convert/echonet.py +184 -83
- zea/data/convert/echonetlvh/README.md +2 -3
- zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
- zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
- zea/data/convert/echonetlvh/precompute_crop.py +43 -64
- zea/data/convert/picmus.py +37 -40
- zea/data/convert/utils.py +86 -0
- zea/data/convert/verasonics.py +1247 -0
- zea/data/data_format.py +124 -6
- zea/data/dataloader.py +12 -7
- zea/data/datasets.py +109 -70
- zea/data/file.py +119 -82
- zea/data/file_operations.py +496 -0
- zea/data/preset_utils.py +2 -2
- zea/display.py +8 -9
- zea/doppler.py +5 -5
- zea/func/__init__.py +109 -0
- zea/{tensor_ops.py → func/tensor.py} +113 -69
- zea/func/ultrasound.py +500 -0
- zea/internal/_generate_keras_ops.py +5 -5
- zea/internal/checks.py +6 -12
- zea/internal/operators.py +4 -0
- zea/io_lib.py +108 -160
- zea/metrics.py +6 -5
- zea/models/__init__.py +1 -1
- zea/models/diffusion.py +63 -12
- zea/models/echonetlvh.py +1 -1
- zea/models/gmm.py +1 -1
- zea/models/lv_segmentation.py +2 -0
- zea/ops/__init__.py +188 -0
- zea/ops/base.py +442 -0
- zea/{keras_ops.py → ops/keras_ops.py} +2 -2
- zea/ops/pipeline.py +1472 -0
- zea/ops/tensor.py +356 -0
- zea/ops/ultrasound.py +890 -0
- zea/probes.py +2 -10
- zea/scan.py +35 -28
- zea/tools/fit_scan_cone.py +90 -160
- zea/tools/selection_tool.py +1 -1
- zea/tracking/__init__.py +16 -0
- zea/tracking/base.py +94 -0
- zea/tracking/lucas_kanade.py +474 -0
- zea/tracking/segmentation.py +110 -0
- zea/utils.py +11 -2
- {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/METADATA +5 -1
- {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/RECORD +62 -48
- zea/data/convert/matlab.py +0 -1237
- zea/ops.py +0 -3294
- {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/WHEEL +0 -0
- {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/entry_points.txt +0 -0
- {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/licenses/LICENSE +0 -0
zea/ops/ultrasound.py
ADDED
|
@@ -0,0 +1,890 @@
|
|
|
1
|
+
import uuid
|
|
2
|
+
|
|
3
|
+
import keras
|
|
4
|
+
import numpy as np
|
|
5
|
+
from keras import ops
|
|
6
|
+
|
|
7
|
+
from zea import log
|
|
8
|
+
from zea.beamform.beamformer import tof_correction
|
|
9
|
+
from zea.display import scan_convert
|
|
10
|
+
from zea.func.tensor import (
|
|
11
|
+
apply_along_axis,
|
|
12
|
+
correlate,
|
|
13
|
+
extend_n_dims,
|
|
14
|
+
reshape_axis,
|
|
15
|
+
)
|
|
16
|
+
from zea.func.ultrasound import (
|
|
17
|
+
channels_to_complex,
|
|
18
|
+
complex_to_channels,
|
|
19
|
+
demodulate,
|
|
20
|
+
envelope_detect,
|
|
21
|
+
get_low_pass_iq_filter,
|
|
22
|
+
log_compress,
|
|
23
|
+
upmix,
|
|
24
|
+
)
|
|
25
|
+
from zea.internal.core import (
|
|
26
|
+
DEFAULT_DYNAMIC_RANGE,
|
|
27
|
+
DataTypes,
|
|
28
|
+
)
|
|
29
|
+
from zea.internal.registry import ops_registry
|
|
30
|
+
from zea.ops.base import (
|
|
31
|
+
ImageOperation,
|
|
32
|
+
Operation,
|
|
33
|
+
)
|
|
34
|
+
from zea.ops.tensor import (
|
|
35
|
+
GaussianBlur,
|
|
36
|
+
)
|
|
37
|
+
from zea.simulator import simulate_rf
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@ops_registry("simulate_rf")
class Simulate(Operation):
    """Simulate RF data from point scatterers.

    Pipeline wrapper around :func:`zea.simulator.simulate_rf`. Besides the
    simulated data (stored under ``self.output_key``) it also emits
    ``n_ch = 1``: the simulator always returns RF data, which has a single
    channel.
    """

    # Operation-specific static parameters.
    STATIC_PARAMS = ["n_ax", "apply_lens_correction"]

    def __init__(self, **kwargs):
        super().__init__(
            output_data_type=DataTypes.RAW_DATA,
            additional_output_keys=["n_ch"],
            **kwargs,
        )

    def call(
        self,
        scatterer_positions,
        scatterer_magnitudes,
        probe_geometry,
        apply_lens_correction,
        lens_thickness,
        lens_sound_speed,
        sound_speed,
        n_ax,
        center_frequency,
        sampling_frequency,
        t0_delays,
        initial_times,
        element_width,
        attenuation_coef,
        tx_apodizations,
        **kwargs,
    ):
        """Run the RF simulator.

        The scatterer positions and magnitudes are converted to tensors first;
        every other argument is forwarded unchanged to ``simulate_rf``.

        Returns:
            dict: ``{self.output_key: <simulated RF data>, "n_ch": 1}``.
        """
        positions = ops.convert_to_tensor(scatterer_positions)
        magnitudes = ops.convert_to_tensor(scatterer_magnitudes)

        simulated = simulate_rf(
            positions,
            magnitudes,
            probe_geometry=probe_geometry,
            apply_lens_correction=apply_lens_correction,
            lens_thickness=lens_thickness,
            lens_sound_speed=lens_sound_speed,
            sound_speed=sound_speed,
            n_ax=n_ax,
            center_frequency=center_frequency,
            sampling_frequency=sampling_frequency,
            t0_delays=t0_delays,
            initial_times=initial_times,
            element_width=element_width,
            attenuation_coef=attenuation_coef,
            tx_apodizations=tx_apodizations,
        )

        # Simulation always yields RF data, hence a single channel.
        return {self.output_key: simulated, "n_ch": 1}
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@ops_registry("tof_correction")
class TOFCorrection(Operation):
    """Time-of-flight correction operation for ultrasound data.

    Consumes ``DataTypes.RAW_DATA`` and produces ``DataTypes.ALIGNED_DATA`` by
    delegating to :func:`zea.beamform.beamformer.tof_correction`.
    """

    # Operation-specific static parameters.
    STATIC_PARAMS = ["f_number", "apply_lens_correction"]

    def __init__(self, **kwargs):
        super().__init__(
            input_data_type=DataTypes.RAW_DATA,
            output_data_type=DataTypes.ALIGNED_DATA,
            **kwargs,
        )

    def call(
        self,
        flatgrid,
        sound_speed,
        polar_angles,
        focus_distances,
        sampling_frequency,
        f_number,
        demodulation_frequency,
        t0_delays,
        tx_apodizations,
        initial_times,
        probe_geometry,
        t_peak,
        tx_waveform_indices,
        apply_lens_correction=None,
        lens_thickness=None,
        lens_sound_speed=None,
        **kwargs,
    ):
        """Apply time-of-flight correction to the raw data.

        The raw data is read from ``kwargs[self.key]``. When
        ``self.with_batch_dim`` is set, the correction is mapped over the
        leading (batch) axis with ``ops.map``.

        Args:
            flatgrid (ops.Tensor): Grid points at which to evaluate the time-of-flight.
            sound_speed (float): Sound speed in the medium.
            polar_angles (ops.Tensor): Polar angles for scan lines.
            focus_distances (ops.Tensor): Focus distances for scan lines.
            sampling_frequency (float): Sampling frequency.
            f_number (float): F-number for apodization.
            demodulation_frequency (float): Demodulation frequency.
            t0_delays (ops.Tensor): T0 delays.
            tx_apodizations (ops.Tensor): Transmit apodizations.
            initial_times (ops.Tensor): Initial times.
            probe_geometry (ops.Tensor): Probe element positions.
            t_peak (float): Time to peak of the transmit pulse.
            tx_waveform_indices (ops.Tensor): Index of the transmit waveform for
                each transmit (all zero if there is only one waveform).
            apply_lens_correction (bool): Whether to apply lens correction.
            lens_thickness (float): Lens thickness.
            lens_sound_speed (float): Sound speed in the lens.

        Returns:
            dict: ``{self.output_key: <tof-corrected data>}``.
        """
        raw_data = kwargs[self.key]

        tof_kwargs = dict(
            flatgrid=flatgrid,
            t0_delays=t0_delays,
            tx_apodizations=tx_apodizations,
            sound_speed=sound_speed,
            probe_geometry=probe_geometry,
            initial_times=initial_times,
            sampling_frequency=sampling_frequency,
            demodulation_frequency=demodulation_frequency,
            f_number=f_number,
            polar_angles=polar_angles,
            focus_distances=focus_distances,
            t_peak=t_peak,
            tx_waveform_indices=tx_waveform_indices,
            apply_lens_correction=apply_lens_correction,
            lens_thickness=lens_thickness,
            lens_sound_speed=lens_sound_speed,
        )

        def _correct(single):
            return tof_correction(single, **tof_kwargs)

        if self.with_batch_dim:
            tof_corrected = ops.map(_correct, raw_data)
        else:
            tof_corrected = _correct(raw_data)

        return {self.output_key: tof_corrected}
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
@ops_registry("pfield_weighting")
class PfieldWeighting(Operation):
    """Weight aligned data with the pressure field.

    Input and output are both ``DataTypes.ALIGNED_DATA``. When no pressure
    field is supplied, the data is passed through unchanged.
    """

    def __init__(self, **kwargs):
        super().__init__(
            input_data_type=DataTypes.ALIGNED_DATA,
            output_data_type=DataTypes.ALIGNED_DATA,
            **kwargs,
        )

    def call(self, flat_pfield=None, **kwargs):
        """Multiply the data element-wise with the pressure-field weights.

        Args:
            flat_pfield (ops.Tensor): Pressure field weight mask of shape
                (n_pix, n_tx), or None to skip weighting.

        Returns:
            dict: ``{self.output_key: <weighted data>}``.
        """
        # Expected layout: ((batch_size,) n_tx, n_pix, ...)
        data = kwargs[self.key]

        if flat_pfield is None:
            return {self.output_key: data}

        # (n_pix, n_tx) -> (n_tx, n_pix), to match the data layout.
        weights = ops.swapaxes(flat_pfield, 0, 1)
        if self.with_batch_dim:
            weights = ops.expand_dims(weights, axis=0)

        # Pad the weights with trailing axes so they broadcast against any
        # remaining data dimensions.
        n_missing = ops.ndim(data) - ops.ndim(weights)
        weights = extend_n_dims(weights, axis=-1, n_dims=n_missing)

        return {self.output_key: data * weights}
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
@ops_registry("scan_convert")
class ScanConvert(Operation):
    """Scan convert images to cartesian coordinates.

    Consumes ``DataTypes.IMAGE`` and produces ``DataTypes.IMAGE_SC``. In addition
    to the converted image, the parameters returned by
    :func:`zea.display.scan_convert` (resolution, limits, ranges and steps) are
    emitted as extra outputs.
    """

    STATIC_PARAMS = ["fill_value"]

    def __init__(self, order=1, **kwargs):
        """Initialize the ScanConvert operation.

        Args:
            order (int, optional): Interpolation order. Defaults to 1. Currently
                there is only GPU support for order=1; higher orders disable jit.
        """
        if order > 1:
            jittable = False
            log.warning(
                "GPU support for order > 1 is not available. Disabling jit for ScanConvert."
            )
        else:
            jittable = True

        super().__init__(
            input_data_type=DataTypes.IMAGE,
            output_data_type=DataTypes.IMAGE_SC,
            jittable=jittable,
            additional_output_keys=[
                "resolution",
                "x_lim",
                "y_lim",
                "z_lim",
                "rho_range",
                "theta_range",
                "phi_range",
                "d_rho",
                "d_theta",
                "d_phi",
            ],
            **kwargs,
        )
        self.order = order

    def call(
        self,
        rho_range=None,
        theta_range=None,
        phi_range=None,
        resolution=None,
        coordinates=None,
        fill_value=None,
        **kwargs,
    ):
        """Scan convert images to cartesian coordinates.

        Args:
            rho_range (Tuple): Range of the rho axis in the polar coordinate system.
                Defined in meters.
            theta_range (Tuple): Range of the theta axis in the polar coordinate system.
                Defined in radians.
            phi_range (Tuple): Range of the phi axis in the polar coordinate system.
                Defined in radians.
            resolution (float): Resolution of the output image in meters per pixel.
                If None, the resolution is computed based on the input data.
            coordinates (Tensor): Coordinates for scan conversion. If None, they are
                computed from rho_range, theta_range, phi_range and resolution. If
                provided, this operation can be jitted.
            fill_value (float): Value to fill the image with outside the defined
                region. Defaults to NaN.

        Returns:
            dict: The scan-converted image under ``self.output_key``, merged with
            the parameter dict returned by ``scan_convert``.
        """
        if fill_value is None:
            fill_value = np.nan

        data = kwargs[self.key]

        if self._jit_compile and self.jittable:
            # Only the precomputed-coordinates path supports jit (see `coordinates`).
            # The trailing space in the first literal separates the two sentences.
            assert coordinates is not None, (
                "coordinates must be provided to jit scan conversion. "
                "You can set ScanConvert(jit_compile=False) to disable jitting."
            )

        data_out, parameters = scan_convert(
            data,
            rho_range,
            theta_range,
            phi_range,
            resolution,
            coordinates,
            fill_value,
            self.order,
            with_batch_dim=self.with_batch_dim,
        )

        return {self.output_key: data_out, **parameters}
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
@ops_registry("demodulate")
class Demodulate(Operation):
    """Demodulate the input data to baseband.

    After this operation the carrier frequency is removed (0 Hz) and the data is
    in IQ format, stored as two real-valued channels. The former center frequency
    is reported as the ``demodulation_frequency`` output.
    """

    def __init__(self, axis=-3, **kwargs):
        super().__init__(
            input_data_type=DataTypes.RAW_DATA,
            output_data_type=DataTypes.RAW_DATA,
            jittable=True,
            additional_output_keys=["demodulation_frequency", "center_frequency", "n_ch"],
            **kwargs,
        )
        # Axis along which `demodulate` operates.
        self.axis = axis

    def call(self, center_frequency=None, sampling_frequency=None, **kwargs):
        """Demodulate ``kwargs[self.key]`` using ``center_frequency`` as the mixing frequency.

        Returns:
            dict: The two-channel IQ data plus updated ``demodulation_frequency``
            (the old center frequency), ``center_frequency`` (0.0) and ``n_ch`` (2).
        """
        iq_two_channel = demodulate(
            data=kwargs[self.key],
            center_frequency=center_frequency,
            sampling_frequency=sampling_frequency,
            axis=self.axis,
        )

        return {
            self.output_key: iq_two_channel,
            "demodulation_frequency": center_frequency,
            "center_frequency": 0.0,
            "n_ch": 2,
        }
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
@ops_registry("fir_filter")
class FirFilter(Operation):
    """Apply a FIR filter to the input signal using convolution.

    The filter taps are read from the input dictionary under ``filter_key``.
    """

    def __init__(
        self,
        axis: int,
        complex_channels: bool = False,
        filter_key: str = "fir_filter_taps",
        **kwargs,
    ):
        """
        Args:
            axis (int): Axis along which to apply the filter. Must not be the batch
                dimension. With ``complex_channels=True`` the channel axis is
                folded into a complex dtype before filtering, so choose ``axis``
                with that axis already removed!
            complex_channels (bool): Whether the last dimension of the input holds
                complex channels (real and imaginary parts). If True, the signal is
                converted to ``complex`` dtype before filtering and split back into
                two channels afterwards.
            filter_key (str): Key in the input dictionary holding the FIR filter
                taps. Default is "fir_filter_taps".
        """
        super().__init__(**kwargs)
        self._check_axis(axis)

        self.axis = axis
        self.complex_channels = complex_channels
        self.filter_key = filter_key

    def _check_axis(self, axis, ndim=None):
        """Validate ``axis``; raise ValueError if out of bounds or on the batch axis.

        The bounds check only runs when ``ndim`` is known.
        """
        if ndim is not None and not (-ndim <= axis < ndim):
            raise ValueError(f"Axis {axis} is out of bounds for array of dimension {ndim}.")

        if not self.with_batch_dim:
            return

        on_batch_axis = axis == 0 or (ndim is not None and axis == -ndim)
        if on_batch_axis:
            raise ValueError("Cannot apply FIR filter along batch dimension.")

    @property
    def valid_keys(self):
        """Get the valid keys for the `call` method (including ``filter_key``)."""
        return self._valid_keys.union({self.filter_key})

    def call(self, **kwargs):
        """Filter ``kwargs[self.key]`` with the taps in ``kwargs[self.filter_key]``."""
        signal = kwargs[self.key]
        fir_filter_taps = kwargs[self.filter_key]

        if self.complex_channels:
            signal = channels_to_complex(signal)

        self._check_axis(self.axis, ndim=ops.ndim(signal))

        def _filter_1d(trace):
            # Reversing the taps turns the correlation into a convolution.
            return correlate(trace, fir_filter_taps[::-1], mode="same")

        filtered = apply_along_axis(_filter_1d, self.axis, signal)

        if self.complex_channels:
            filtered = complex_to_channels(filtered)

        return {self.output_key: filtered}
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
@ops_registry("low_pass_filter")
class LowPassFilter(FirFilter):
    """Apply a low-pass FIR filter to the input signal using convolution.

    It is recommended to use :class:`FirFilter` with pre-computed filter taps for
    jittable operations. The :class:`LowPassFilter` operation itself is not
    jittable and is provided for convenience only.

    Uses :func:`get_low_pass_iq_filter` to compute the filter taps.
    """

    def __init__(self, axis: int, complex_channels: bool = False, num_taps: int = 128, **kwargs):
        """Initialize the LowPassFilter operation.

        Args:
            axis (int): Axis along which to apply the filter. Must not be the batch
                dimension. With ``complex_channels=True`` the channel axis is
                folded into a complex dtype before filtering, so choose ``axis``
                with that axis already removed.
            complex_channels (bool): Whether the last dimension of the input holds
                complex channels (real and imaginary parts). If True, the signal is
                converted to ``complex`` dtype before filtering and split back into
                two channels afterwards.
            num_taps (int): Number of taps in the FIR filter. Default is 128.
        """
        # Per-instance key under which `call` stores the computed taps
        # (presumably so that several instances never share a key).
        self._random_suffix = str(uuid.uuid4())

        # This subclass fixes both settings, so caller-provided values are dropped.
        for fixed_setting in ("filter_key", "jittable"):
            kwargs.pop(fixed_setting, None)

        super().__init__(
            axis=axis,
            complex_channels=complex_channels,
            filter_key=f"low_pass_{self._random_suffix}",
            jittable=False,
            **kwargs,
        )
        self.num_taps = num_taps

    def call(self, bandwidth, sampling_frequency, center_frequency, **kwargs):
        """Design the low-pass taps, then run the parent FIR filter with them.

        The frequency parameters are converted to Python scalars with ``.item()``,
        which is why this operation is not jittable.
        """

        def _as_scalar(value):
            return ops.convert_to_numpy(value).item()

        kwargs[self.filter_key] = get_low_pass_iq_filter(
            self.num_taps,
            _as_scalar(sampling_frequency),
            _as_scalar(center_frequency),
            _as_scalar(bandwidth),
        )
        return super().call(**kwargs)
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
@ops_registry("channels_to_complex")
class ChannelsToComplex(Operation):
    """Merge real/imaginary channels into a complex-valued tensor.

    Thin wrapper around :func:`zea.func.ultrasound.channels_to_complex` applied
    to ``kwargs[self.key]``.
    """

    def call(self, **kwargs):
        return {self.output_key: channels_to_complex(kwargs[self.key])}
|
|
485
|
+
|
|
486
|
+
|
|
487
|
+
@ops_registry("complex_to_channels")
class ComplexToChannels(Operation):
    """Split a complex-valued tensor into real/imaginary channels.

    Thin wrapper around :func:`zea.func.ultrasound.complex_to_channels`; the
    ``axis`` argument (default -1) is forwarded to it.
    """

    def __init__(self, axis=-1, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, **kwargs):
        return {self.output_key: complex_to_channels(kwargs[self.key], axis=self.axis)}
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
@ops_registry("lee_filter")
class LeeFilter(ImageOperation):
    """
    The Lee filter is a speckle reduction filter commonly used in synthetic aperture radar (SAR)
    and ultrasound image processing. It smooths the image while preserving edges and details.
    This implementation uses Gaussian filter for local statistics and treats channels independently.

    Lee, J.S. (1980). Digital image enhancement and noise filtering by use of local statistics.
    IEEE Transactions on Pattern Analysis and Machine Intelligence, (2), 165-168.
    """

    def __init__(self, sigma=3, kernel_size=None, pad_mode="symmetric", **kwargs):
        """
        Args:
            sigma (float): Standard deviation for Gaussian kernel. Default is 3.
            kernel_size (int, optional): Size of the Gaussian kernel. If None,
                it will be calculated based on sigma.
            pad_mode (str): Padding mode to be used for Gaussian blur. Default is "symmetric".
        """
        super().__init__(**kwargs)
        self.sigma = sigma
        self.kernel_size = kernel_size
        self.pad_mode = pad_mode

        # Internal GaussianBlur used to compute the local (windowed) statistics.
        self.gaussian_blur = GaussianBlur(
            sigma=self.sigma,
            kernel_size=self.kernel_size,
            pad_mode=self.pad_mode,
            with_batch_dim=self.with_batch_dim,
            jittable=self._jittable,
            key="data",
        )

    @property
    def with_batch_dim(self):
        """Get the with_batch_dim property of the LeeFilter operation."""
        return self._with_batch_dim

    @with_batch_dim.setter
    def with_batch_dim(self, value):
        """Set with_batch_dim and propagate it to the inner GaussianBlur (once it exists)."""
        self._with_batch_dim = value
        if hasattr(self, "gaussian_blur"):
            self.gaussian_blur.with_batch_dim = value

    def call(self, **kwargs):
        """Apply the Lee filter to the input data.

        Args:
            data (ops.Tensor): Input image data of shape (height, width, channels) with
                optional batch dimension if ``self.with_batch_dim``.

        Returns:
            dict: ``{self.output_key: <filtered image>}``.
        """
        super().call(**kwargs)
        data = kwargs.pop(self.key)

        # Local mean via Gaussian blur.
        img_mean = self.gaussian_blur.call(data=data, **kwargs)[self.gaussian_blur.output_key]

        # Local mean of the squared data.
        img_sqr_mean = self.gaussian_blur.call(
            data=data**2,
            **kwargs,
        )[self.gaussian_blur.output_key]

        # Local variance E[x^2] - E[x]^2. Floating-point cancellation can make it
        # slightly negative (e.g. in flat regions). A negative variance would give
        # negative weights, or a near-zero denominator when the global variance is
        # also ~0, so clamp it at zero.
        img_variance = ops.maximum(img_sqr_mean - img_mean**2, 0.0)

        # Global variance over the two spatial axes (per channel, per batch item).
        overall_variance = ops.var(data, axis=(-3, -2), keepdims=True)

        # Adaptive weights: close to 1 where local variance dominates (edges),
        # close to 0 in homogeneous regions.
        eps = keras.config.epsilon()
        img_weights = img_variance / (img_variance + overall_variance + eps)

        # Lee filter formula.
        img_output = img_mean + img_weights * (data - img_mean)

        return {self.output_key: img_output}
|
|
578
|
+
|
|
579
|
+
|
|
580
|
+
@ops_registry("companding")
class Companding(Operation):
    """Companding according to the A- or μ-law algorithm.

    Invertible compressing operation, used to compress the dynamic range of
    input data (and subsequently expand it). Inputs are clipped to [-1, 1]
    before the companding curve is applied.

    μ-law companding:
    https://en.wikipedia.org/wiki/%CE%9C-law_algorithm
    A-law companding:
    https://en.wikipedia.org/wiki/A-law_algorithm

    Args:
        expand (bool, optional): If set to False (default),
            data is compressed, else expanded.
        comp_type (str): Either `a` or `mu` (case-insensitive). Defaults to `mu`.

    The compression parameters are passed to :meth:`call`, not the constructor:
        mu (float, optional): μ-law compression parameter. Defaults to 255.
        A (float, optional): A-law compression parameter. Defaults to 87.6.
    """

    def __init__(self, expand=False, comp_type="mu", **kwargs):
        super().__init__(**kwargs)
        self.expand = expand
        self.comp_type = comp_type.lower()
        if self.comp_type not in ["mu", "a"]:
            raise ValueError("comp_type must be 'mu' or 'a'.")

        if self.comp_type == "mu":
            self._compand_func = self._mu_law_expand if self.expand else self._mu_law_compress
        else:
            self._compand_func = self._a_law_expand if self.expand else self._a_law_compress

    # The companding functions accept **kwargs so that `call` can pass both `mu`
    # and `A`; each function ignores the parameter it does not use.

    @staticmethod
    def _mu_law_compress(x, mu=255, **kwargs):
        x = ops.clip(x, -1, 1)
        return ops.sign(x) * ops.log(1.0 + mu * ops.abs(x)) / ops.log(1.0 + mu)

    @staticmethod
    def _mu_law_expand(y, mu=255, **kwargs):
        y = ops.clip(y, -1, 1)
        return ops.sign(y) * ((1.0 + mu) ** ops.abs(y) - 1.0) / mu

    @staticmethod
    def _a_law_compress(x, A=87.6, **kwargs):
        x = ops.clip(x, -1, 1)
        x_sign = ops.sign(x)
        x_abs = ops.abs(x)
        A_log = ops.log(A)
        val1 = x_sign * A * x_abs / (1.0 + A_log)
        # The log branch is only selected where |x| >= 1/A, i.e. A*|x| >= 1, so
        # clamping its argument to >= 1 leaves the selected values unchanged. It
        # avoids evaluating log(0) = -inf at x = 0, which would otherwise leak
        # NaN gradients through `ops.where`.
        val2 = x_sign * (1.0 + ops.log(ops.maximum(A * x_abs, 1.0))) / (1.0 + A_log)
        # x_abs is non-negative by construction, so only the upper bound matters.
        y = ops.where(x_abs < (1.0 / A), val1, val2)
        return y

    @staticmethod
    def _a_law_expand(y, A=87.6, **kwargs):
        y = ops.clip(y, -1, 1)
        y_sign = ops.sign(y)
        y_abs = ops.abs(y)
        A_log = ops.log(A)
        val1 = y_sign * y_abs * (1.0 + A_log) / A
        val2 = y_sign * ops.exp(y_abs * (1.0 + A_log) - 1.0) / A
        # y_abs is non-negative by construction, so only the upper bound matters.
        x = ops.where(y_abs < (1.0 / (1.0 + A_log)), val1, val2)
        return x

    def call(self, mu=255, A=87.6, **kwargs):
        """Compress or expand ``kwargs[self.key]`` with the configured law.

        Args:
            mu (float): μ-law parameter (used when ``comp_type == "mu"``).
            A (float): A-law parameter (used when ``comp_type == "a"``).

        Returns:
            dict: ``{self.output_key: <companded data>}``.
        """
        data = kwargs[self.key]

        # Cast the parameters so the formulas are evaluated in the data's dtype.
        mu = ops.cast(mu, data.dtype)
        A = ops.cast(A, data.dtype)

        data_out = self._compand_func(data, mu=mu, A=A)
        return {self.output_key: data_out}
|
|
652
|
+
|
|
653
|
+
|
|
654
|
+
@ops_registry("downsample")
class Downsample(Operation):
    """Downsample data along a specific axis.

    Keeps the samples at indices ``phase, phase + factor, phase + 2 * factor, ...``
    along ``axis``. When provided, ``sampling_frequency`` and ``n_ax`` are updated
    to match the downsampled data.
    """

    def __init__(self, factor: int = 1, phase: int = 0, axis: int = -3, **kwargs):
        """
        Args:
            factor (int): Keep every ``factor``-th sample. Must be >= 1.
            phase (int): Index of the first kept sample. Must satisfy
                ``0 <= phase < factor``.
            axis (int): Axis to downsample along. Default is -3.

        Raises:
            ValueError: If ``factor`` or ``phase`` is out of range.
        """
        super().__init__(
            additional_output_keys=["sampling_frequency", "n_ax"],
            **kwargs,
        )
        if factor < 1:
            raise ValueError("Downsample factor must be >= 1.")
        if phase < 0 or phase >= factor:
            raise ValueError("phase must satisfy 0 <= phase < factor.")
        self.factor = factor
        self.phase = phase
        self.axis = axis

    def call(self, sampling_frequency=None, n_ax=None, **kwargs):
        """Downsample ``kwargs[self.key]`` and update the related parameters.

        Returns:
            dict: The downsampled data, plus ``sampling_frequency`` and/or ``n_ax``
            when they were given.
        """
        data = kwargs[self.key]
        length = ops.shape(data)[self.axis]
        sample_idx = ops.arange(self.phase, length, self.factor)
        data_downsampled = ops.take(data, sample_idx, axis=self.axis)

        output = {self.output_key: data_downsampled}
        # Downsampling also reduces the sampling frequency.
        if sampling_frequency is not None:
            sampling_frequency = sampling_frequency / self.factor
            output["sampling_frequency"] = sampling_frequency
        if n_ax is not None:
            # Number of indices in range(phase, n_ax, factor), i.e.
            # ceil((n_ax - phase) / factor). Plain `n_ax // factor` undercounts
            # when n_ax is not a multiple of factor (n_ax=10, factor=3 keeps
            # samples 0, 3, 6, 9 -> 4, not 3). Since phase < factor, the
            # numerator below is never negative.
            n_ax = (n_ax - self.phase + self.factor - 1) // self.factor
            output["n_ax"] = n_ax
        return output
|
|
686
|
+
|
|
687
|
+
|
|
688
|
+
@ops_registry("anisotropic_diffusion")
class AnisotropicDiffusion(Operation):
    """Speckle Reducing Anisotropic Diffusion (SRAD) filter.

    Reference:
        - https://www.researchgate.net/publication/5602035_Speckle_reducing_anisotropic_diffusion
        - https://nl.mathworks.com/matlabcentral/fileexchange/54044-image-despeckle-filtering-toolbox
    """

    def call(self, niter=100, lmbda=0.1, rect=None, eps=1e-6, **kwargs):
        """Anisotropic diffusion filter.

        Assumes input data is non-negative. Images are processed one at a time
        in a Python loop over the batch.

        Args:
            niter: Number of iterations.
            lmbda: Lambda parameter (step size of each update).
            rect: Rectangle [x1, y1, x2, y2] selecting a homogeneous region
                ``image[x1:x2, y1:y2]`` (x indexes rows, y indexes columns), used
                to estimate the speckle scale q0. If None, the whole image is used
                for that estimate.
            eps: Small epsilon for stability.

        Returns:
            dict: Filtered image (2D tensor or batch of images) under
            ``self.output_key``.
        """
        data = kwargs[self.key]

        if not self.with_batch_dim:
            data = ops.expand_dims(data, axis=0)

        batch_size = ops.shape(data)[0]

        results = [
            self._anisotropic_diffusion_single(data[i], niter, lmbda, rect, eps)
            for i in range(batch_size)
        ]
        result = ops.stack(results, axis=0)

        if not self.with_batch_dim:
            result = ops.squeeze(result, axis=0)

        return {self.output_key: result}

    def _anisotropic_diffusion_single(self, image, niter, lmbda, rect, eps):
        """Apply anisotropic diffusion to a single 2D image of shape (M, N)."""
        # Diffusion runs on exp(image); the result is mapped back with log.
        image = ops.exp(image)
        M, N = image.shape

        # Zero padding used for the one-pixel shifts (loop-invariant).
        zero_row = ops.zeros((1, N), dtype=image.dtype)
        zero_col = ops.zeros((M, 1), dtype=image.dtype)

        for _ in range(niter):
            # One-pixel shifted copies of the image, zero-padded at the border.
            iN = ops.concatenate([image[1:], zero_row], axis=0)
            iS = ops.concatenate([zero_row, image[:-1]], axis=0)
            jW = ops.concatenate([image[:, 1:], zero_col], axis=1)
            jE = ops.concatenate([zero_col, image[:, :-1]], axis=1)

            # Speckle scale q0^2 = (std / mean)^2 of the homogeneous region. The
            # original code never set q0 when rect was None and instead reused an
            # unrelated intermediate; fall back to the whole image instead.
            if rect is not None:
                x1, y1, x2, y2 = rect
                homogeneous = image[x1:x2, y1:y2]
            else:
                homogeneous = image
            q0_squared = (ops.std(homogeneous) / (ops.mean(homogeneous) + eps)) ** 2

            dN = iN - image
            dS = iS - image
            dW = jW - image
            dE = jE - image

            # Instantaneous coefficient of variation q^2.
            G2 = (dN**2 + dS**2 + dW**2 + dE**2) / (image**2 + eps)
            L = (dN + dS + dW + dE) / (image + eps)
            q_num = (0.5 * G2) - ((1 / 16) * (L**2))
            q_den = (1 + ((1 / 4) * L)) ** 2
            q_squared = q_num / (q_den + eps)

            # Diffusion coefficient.
            rel = (q_squared - q0_squared) / (q0_squared * (1 + q0_squared) + eps)
            c = 1.0 / (1 + rel)
            cS = ops.concatenate([zero_row, c[:-1]], axis=0)
            cE = ops.concatenate([zero_col, c[:, :-1]], axis=1)

            D = (cS * dS) + (c * dN) + (cE * dE) + (c * dW)
            image = image + (lmbda / 4) * D

        return ops.log(image)
|
|
768
|
+
|
|
769
|
+
|
|
770
|
+
@ops_registry("envelope_detect")
class EnvelopeDetect(Operation):
    """Detect the envelope of beamformed signals.

    Wraps ``envelope_detect`` as a pipeline operation that consumes
    ``DataTypes.BEAMFORMED_DATA`` and produces ``DataTypes.ENVELOPE_DATA``.
    """

    def __init__(self, axis=-3, **kwargs):
        """Initialize the operation.

        Args:
            axis (int): Axis forwarded to ``envelope_detect``. With the input
                layout ``(..., grid_size_z, grid_size_x, n_ch)`` the default
                ``-3`` is the ``grid_size_z`` axis.
            **kwargs: Forwarded to ``Operation.__init__``.
        """
        super().__init__(
            input_data_type=DataTypes.BEAMFORMED_DATA,
            output_data_type=DataTypes.ENVELOPE_DATA,
            **kwargs,
        )
        self.axis = axis

    def call(self, **kwargs):
        """Compute the envelope of the input.

        Args:
            - data (Tensor): Beamformed data stored under ``self.key``, of shape
              (..., grid_size_z, grid_size_x, n_ch).

        Returns:
            - envelope_data (Tensor): Envelope stored under ``self.output_key``,
              of shape (..., grid_size_z, grid_size_x).
        """
        beamformed = kwargs[self.key]
        envelope = envelope_detect(beamformed, axis=self.axis)
        return {self.output_key: envelope}
|
|
799
|
+
|
|
800
|
+
|
|
801
|
+
@ops_registry("upmix")
class UpMix(Operation):
    """Convert IQ (baseband) data to RF data by upmixing."""

    def __init__(self, upsampling_rate=1, **kwargs):
        """Initialize the operation.

        Args:
            upsampling_rate: Upsampling factor forwarded to ``upmix``.
            **kwargs: Forwarded to ``Operation.__init__``.
        """
        super().__init__(**kwargs)
        self.upsampling_rate = upsampling_rate

    def call(self, sampling_frequency=None, center_frequency=None, **kwargs):
        """Upmix the data stored under ``self.key``.

        The size of the last axis selects the path:
            - 1: treated as RF already; a warning is logged and the data is
              returned unchanged.
            - 2: treated as I/Q channels and merged with ``channels_to_complex``
              before upmixing.
            - anything else: passed to ``upmix`` as-is.

        Args:
            sampling_frequency: Sampling frequency forwarded to ``upmix``.
            center_frequency: Center frequency forwarded to ``upmix``.
            **kwargs: Must contain the input tensor under ``self.key``.

        Returns:
            dict: ``{self.output_key: rf}`` with a trailing singleton channel
            axis appended to the ``upmix`` result.
        """
        signal = kwargs[self.key]
        n_channels = signal.shape[-1]

        if n_channels == 1:
            log.warning("Upmixing is not applicable to RF data.")
            return {self.output_key: signal}

        if n_channels == 2:
            signal = channels_to_complex(signal)

        rf = upmix(signal, sampling_frequency, center_frequency, self.upsampling_rate)
        return {self.output_key: ops.expand_dims(rf, axis=-1)}
|
|
832
|
+
|
|
833
|
+
|
|
834
|
+
@ops_registry("log_compress")
class LogCompress(Operation):
    """Log-compress envelope data into an image, optionally clipping it."""

    def __init__(self, clip: bool = True, **kwargs):
        """Initialize the LogCompress operation.

        Args:
            clip (bool): If True, clip the compressed output to the dynamic
                range passed to ``call``. Defaults to True.
            **kwargs: Forwarded to ``Operation.__init__``.
        """
        super().__init__(
            input_data_type=DataTypes.ENVELOPE_DATA,
            output_data_type=DataTypes.IMAGE,
            **kwargs,
        )
        self.clip = clip

    def call(self, dynamic_range=None, **kwargs):
        """Log-compress the data stored under ``self.key``.

        Args:
            dynamic_range (tuple, optional): ``(low, high)`` limits in dB used
                for clipping. When None, ``DEFAULT_DYNAMIC_RANGE`` is used. The
                limits are cast to the input's dtype.
            **kwargs: Must contain the input tensor under ``self.key``.

        Returns:
            dict: ``{self.output_key: compressed}``; ``compressed`` is clipped
            to the dynamic range only when ``self.clip`` is True.
        """
        envelope = kwargs[self.key]

        limits = ops.array(DEFAULT_DYNAMIC_RANGE) if dynamic_range is None else dynamic_range
        limits = ops.cast(limits, envelope.dtype)

        image = log_compress(envelope)
        if self.clip:
            low, high = limits[0], limits[1]
            image = ops.clip(image, low, high)

        return {self.output_key: image}
|
|
871
|
+
|
|
872
|
+
|
|
873
|
+
@ops_registry("reshape_grid")
class ReshapeGrid(Operation):
    """Reshape flat grid data to grid shape.

    The flat pixel axis of the input is unflattened into ``grid.shape[:-1]``,
    i.e. every dimension of ``grid`` except the last one.
    """

    def __init__(self, axis=0, **kwargs):
        """Initialize the operation.

        Args:
            axis (int): Position of the flat pixel axis in the un-batched data.
                Negative values count from the end. Defaults to 0.
            **kwargs: Forwarded to ``Operation.__init__``.
        """
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, grid, **kwargs):
        """
        Args:
            - grid (Tensor): Grid whose shape, minus its last dimension, is the
              target shape of the pixel axis.
            - data (Tensor): The flat grid data of shape (..., n_pix, ...),
              stored under ``self.key``.
        Returns:
            - reshaped_data (Tensor): The reshaped data of shape (..., grid.shape[:-1], ...).
        """
        data = kwargs[self.key]

        axis = self.axis
        # A leading batch dimension shifts non-negative axes by one. Negative
        # axes count from the end and are unaffected by it, so they must not
        # be offset (e.g. axis=-1 would otherwise become 0).
        if axis >= 0:
            axis = axis + int(self.with_batch_dim)

        reshaped_data = reshape_axis(data, grid.shape[:-1], axis)
        return {self.output_key: reshaped_data}
|