zea 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +54 -19
- zea/agent/__init__.py +12 -12
- zea/agent/masks.py +2 -1
- zea/backend/tensorflow/dataloader.py +2 -1
- zea/beamform/beamformer.py +100 -50
- zea/beamform/lens_correction.py +9 -2
- zea/beamform/pfield.py +9 -2
- zea/config.py +34 -25
- zea/data/__init__.py +22 -16
- zea/data/convert/camus.py +2 -1
- zea/data/convert/echonet.py +4 -4
- zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +1 -1
- zea/data/convert/matlab.py +11 -4
- zea/data/data_format.py +31 -30
- zea/data/datasets.py +7 -5
- zea/data/file.py +104 -2
- zea/data/layers.py +3 -3
- zea/datapaths.py +16 -4
- zea/display.py +7 -5
- zea/interface.py +14 -16
- zea/internal/_generate_keras_ops.py +6 -7
- zea/internal/cache.py +2 -49
- zea/internal/config/validation.py +1 -2
- zea/internal/core.py +69 -6
- zea/internal/device.py +6 -2
- zea/internal/dummy_scan.py +330 -0
- zea/internal/operators.py +114 -2
- zea/internal/parameters.py +101 -70
- zea/internal/setup_zea.py +5 -6
- zea/internal/utils.py +282 -0
- zea/io_lib.py +247 -19
- zea/keras_ops.py +74 -4
- zea/log.py +9 -7
- zea/metrics.py +15 -7
- zea/models/__init__.py +30 -20
- zea/models/base.py +30 -14
- zea/models/carotid_segmenter.py +19 -4
- zea/models/diffusion.py +173 -12
- zea/models/echonet.py +22 -8
- zea/models/echonetlvh.py +31 -7
- zea/models/lpips.py +19 -2
- zea/models/lv_segmentation.py +28 -11
- zea/models/preset_utils.py +5 -5
- zea/models/regional_quality.py +30 -10
- zea/models/taesd.py +21 -5
- zea/models/unet.py +15 -1
- zea/ops.py +390 -196
- zea/probes.py +6 -6
- zea/scan.py +109 -49
- zea/simulator.py +24 -21
- zea/tensor_ops.py +406 -302
- zea/tools/hf.py +1 -1
- zea/tools/selection_tool.py +47 -86
- zea/utils.py +92 -480
- zea/visualize.py +177 -39
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/METADATA +4 -2
- zea-0.0.7.dist-info/RECORD +114 -0
- zea-0.0.6.dist-info/RECORD +0 -112
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/WHEEL +0 -0
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/entry_points.txt +0 -0
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/licenses/LICENSE +0 -0
zea/internal/core.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Base classes for the toolbox"""
|
|
2
2
|
|
|
3
3
|
import enum
|
|
4
|
+
import hashlib
|
|
4
5
|
import json
|
|
5
6
|
import pickle
|
|
6
7
|
from copy import deepcopy
|
|
@@ -8,7 +9,8 @@ from copy import deepcopy
|
|
|
8
9
|
import keras
|
|
9
10
|
import numpy as np
|
|
10
11
|
|
|
11
|
-
from zea.utils import reduce_to_signature
|
|
12
|
+
from zea.internal.utils import reduce_to_signature
|
|
13
|
+
from zea.utils import update_dictionary
|
|
12
14
|
|
|
13
15
|
CONVERT_TO_KERAS_TYPES = (np.ndarray, int, float, list, tuple, bool)
|
|
14
16
|
BASE_FLOAT_PRECISION = "float32"
|
|
@@ -76,7 +78,7 @@ class Object:
|
|
|
76
78
|
attributes.pop(
|
|
77
79
|
"_serialized", None
|
|
78
80
|
) # Remove the cached serialized attribute to avoid recursion
|
|
79
|
-
self._serialized =
|
|
81
|
+
self._serialized = serialize_elements([attributes])
|
|
80
82
|
return self._serialized
|
|
81
83
|
|
|
82
84
|
def __setattr__(self, name: str, value):
|
|
@@ -167,9 +169,7 @@ def _skip_to_tensor(value):
|
|
|
167
169
|
# Skip str (because JIT does not support it)
|
|
168
170
|
# Skip methods and functions
|
|
169
171
|
# Skip byte strings
|
|
170
|
-
|
|
171
|
-
return True
|
|
172
|
-
return False
|
|
172
|
+
return isinstance(value, str) or callable(value) or isinstance(value, bytes)
|
|
173
173
|
|
|
174
174
|
|
|
175
175
|
def dict_to_tensor(dictionary, keep_as_is=None):
|
|
@@ -184,8 +184,9 @@ def dict_to_tensor(dictionary, keep_as_is=None):
|
|
|
184
184
|
# Get the value from the dictionary
|
|
185
185
|
value = dictionary[key]
|
|
186
186
|
|
|
187
|
-
if isinstance(value, Object):
|
|
187
|
+
if isinstance(value, Object) and hasattr(value, "to_tensor"):
|
|
188
188
|
snapshot[key] = value.to_tensor(keep_as_is=keep_as_is)
|
|
189
|
+
continue
|
|
189
190
|
|
|
190
191
|
# Skip certain types
|
|
191
192
|
if _skip_to_tensor(value):
|
|
@@ -288,3 +289,65 @@ class ZEADecoderJSON(json.JSONDecoder):
|
|
|
288
289
|
obj[key] = self._MOD_TYPES_MAP[value] if value is not None else None
|
|
289
290
|
|
|
290
291
|
return obj
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def serialize_elements(key_elements: list) -> str:
    """Serialize elements of a list to a deterministic string.

    Generally uses the hex-encoded pickle representation of each element.
    Containers are serialized recursively and wrapped in delimiters that can
    never appear in hex output (``[...]`` for lists, ``(...)`` for tuples,
    ``{...}`` for dicts). As a result, nesting depth and container type are
    part of the result: ``[[1, 2]]``, ``[1, 2]`` and ``[(1, 2)]`` all
    serialize differently.

    Args:
        key_elements (list): List of elements to serialize. Can contain nested
            lists, tuples or dicts; these are serialized recursively. An
            ``Object`` with a ``serialized`` attribute contributes that value,
            and a ``keras.random.SeedGenerator`` contributes its current state.

    Returns:
        str: A serialized string representation of the elements, joined by
        underscores. An empty input yields ``""``.
    """

    def _serialize(element) -> str:
        return pickle.dumps(element).hex()

    def _serialize_element(element) -> str:
        # Fast path: plain scalars can never be an ``Object`` or a
        # ``SeedGenerator``, so pickle them directly.
        if element is None or isinstance(element, (str, bytes, bool, int, float)):
            return _serialize(element)
        if isinstance(element, list):
            # Brackets are not hex characters, so they make nesting explicit.
            return "[" + serialize_elements(element) + "]"
        if isinstance(element, tuple):
            # Parentheses keep tuples distinct from lists with the same items.
            return "(" + serialize_elements(element) + ")"
        if isinstance(element, Object) and hasattr(element, "serialized"):
            # Use the serialized attribute if it exists
            return str(element.serialized)
        if isinstance(element, keras.random.SeedGenerator):
            # If element is a SeedGenerator, use the state
            state = keras.ops.convert_to_numpy(element.state.value)
            return _serialize(state)
        if isinstance(element, dict):
            # Order entries by their *serialized* key so that neither insertion
            # order nor mixed, mutually non-comparable key types affect (or
            # break) the result.
            entries = sorted(
                ((_serialize_element(key), value) for key, value in element.items()),
                key=lambda entry: entry[0],
            )
            keys = "_".join(serialized_key for serialized_key, _ in entries)
            values = serialize_elements([value for _, value in entries])
            return "{k_" + keys + "_v_" + values + "}"
        # Otherwise, serialize the element directly
        return _serialize(element)

    return "_".join(_serialize_element(element) for element in key_elements)
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def hash_elements(key_elements: list) -> str:
    """Generate an MD5 hash of the elements.

    The elements are first turned into a string with ``serialize_elements``,
    then hashed. The hash serves as an identity/cache key, not as a security
    primitive. MD5 is therefore requested with ``usedforsecurity=False``,
    which keeps it available on FIPS-restricted Python builds where a plain
    ``hashlib.md5()`` call raises. The digest value is unchanged by that flag.

    Args:
        key_elements (list): List of elements to serialize and hash.

    Returns:
        str: An MD5 hash (32 hex characters) of the serialized elements.
    """
    serialized = serialize_elements(key_elements)
    return hashlib.md5(serialized.encode(), usedforsecurity=False).hexdigest()
|
zea/internal/device.py
CHANGED
|
@@ -377,7 +377,11 @@ def init_device(
|
|
|
377
377
|
allow_preallocate: bool = True,
|
|
378
378
|
verbose: bool = True,
|
|
379
379
|
):
|
|
380
|
-
"""
|
|
380
|
+
"""Automatically selects a GPU or CPU device.
|
|
381
|
+
|
|
382
|
+
Useful to call at the start of a script to set the device for
|
|
383
|
+
tensorflow, jax or pytorch. The function will select a GPU based
|
|
384
|
+
on available memory, or fall back to CPU if no GPU is available.
|
|
381
385
|
|
|
382
386
|
Args:
|
|
383
387
|
backend (str): String indicating which backend to use. Can be
|
|
@@ -412,7 +416,7 @@ def init_device(
|
|
|
412
416
|
elif backend in ["numpy", "cpu"]:
|
|
413
417
|
device = "cpu"
|
|
414
418
|
else:
|
|
415
|
-
raise ValueError(f"Unknown backend ({backend})
|
|
419
|
+
raise ValueError(f"Unknown backend ({backend}).")
|
|
416
420
|
|
|
417
421
|
# Early exit if device is CPU
|
|
418
422
|
if device == "cpu":
|
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
"""Module to create dummy Scan objects for testing and simulation purposes."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from zea.beamform.delays import compute_t0_delays_focused, compute_t0_delays_planewave
|
|
6
|
+
from zea.probes import Probe
|
|
7
|
+
from zea.scan import Scan
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _get_linear_probe():
    """Build the 128-element linear probe used by the simulation tests.

    The elements are spaced evenly along the x-axis over a 30 mm aperture
    centered on the origin, with y = z = 0.

    Returns:
        Probe: Probe with a 2.5 MHz center frequency, sampled at 10 MHz.
    """
    num_elements = 128
    half_aperture = 30e-3 / 2
    x_positions = np.linspace(-half_aperture, half_aperture, num_elements)
    flat = np.zeros(num_elements)
    element_positions = np.stack([x_positions, flat, flat], axis=1)

    return Probe(
        probe_geometry=element_positions,
        center_frequency=2.5e6,
        sampling_frequency=10e6,
    )
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _get_phased_array_probe():
    """Build the 80-element phased-array probe used by the simulation tests.

    The elements are spaced evenly along the x-axis over a 20 mm aperture
    centered on the origin, with y = z = 0.

    Returns:
        Probe: Probe with a 3.12 MHz center frequency, sampled at 12.5 MHz.
    """
    num_elements = 80
    half_aperture = 20e-3 / 2
    x_positions = np.linspace(-half_aperture, half_aperture, num_elements)
    flat = np.zeros(num_elements)
    element_positions = np.stack([x_positions, flat, flat], axis=1)

    return Probe(
        probe_geometry=element_positions,
        center_frequency=3.12e6,
        sampling_frequency=12.5e6,
    )
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _get_n_ax(ultrasound_probe):
|
|
51
|
+
"""Returns the number of ax for ultrasound simulation tests based on the center
|
|
52
|
+
frequency. A probe with a higher center frequency needs more samples to cover
|
|
53
|
+
the image depth.
|
|
54
|
+
"""
|
|
55
|
+
is_low_frequency_probe = ultrasound_probe.center_frequency < 4e6
|
|
56
|
+
|
|
57
|
+
if is_low_frequency_probe:
|
|
58
|
+
return 510
|
|
59
|
+
|
|
60
|
+
return 1024
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _get_probe(kind) -> Probe:
    """Return a simulation probe of the requested kind.

    Args:
        kind (str): Either ``"linear"`` or ``"phased_array"``.

    Returns:
        Probe: A freshly built probe.

    Raises:
        ValueError: If ``kind`` is not one of the supported names.
    """
    factories = (
        ("linear", _get_linear_probe),
        ("phased_array", _get_phased_array_probe),
    )
    for name, factory in factories:
        if kind == name:
            return factory()
    raise ValueError(f"Unknown probe kind: {kind}")
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _get_constant_scan_kwargs():
|
|
73
|
+
return {
|
|
74
|
+
"lens_sound_speed": 1000,
|
|
75
|
+
"lens_thickness": 1e-3,
|
|
76
|
+
"n_ch": 1,
|
|
77
|
+
"selected_transmits": "all",
|
|
78
|
+
"sound_speed": 1540.0,
|
|
79
|
+
"apply_lens_correction": False,
|
|
80
|
+
"attenuation_coef": 0.0,
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _get_lims_and_gridsize(center_frequency, sound_speed):
|
|
85
|
+
"""Returns the limits and gridsize for ultrasound simulation tests."""
|
|
86
|
+
xlims, zlims = (-20e-3, 20e-3), (0, 35e-3)
|
|
87
|
+
width, height = xlims[1] - xlims[0], zlims[1] - zlims[0]
|
|
88
|
+
wavelength = sound_speed / center_frequency
|
|
89
|
+
gridsize = (
|
|
90
|
+
int(width / (0.5 * wavelength)) + 1,
|
|
91
|
+
int(height / (0.5 * wavelength)) + 1,
|
|
92
|
+
)
|
|
93
|
+
return {"xlims": xlims, "zlims": zlims, "grid_size_x": gridsize[0], "grid_size_z": gridsize[1]}
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _get_planewave_scan(ultrasound_probe, grid_type, **kwargs):
    """Returns a plane-wave scan for ultrasound simulation tests.

    Eight plane waves are steered from +10 to -10 degrees. Each is
    transmitted by all elements with a Hanning apodization and an infinite
    focus distance.

    Args:
        ultrasound_probe (Probe): Probe supplying ``n_el``,
            ``probe_geometry``, ``center_frequency`` and
            ``sampling_frequency``.
        grid_type (str): Grid type passed on to ``Scan``.
        **kwargs: Additional ``Scan`` arguments. They take precedence over the
            defaults chosen here (e.g. ``sound_speed``, ``n_ax``, ``xlims``)
            instead of raising a duplicate-keyword ``TypeError``. A
            ``sound_speed`` override is also used for the transmit delays and
            the grid spacing. Shape-defining values such as ``n_tx`` are not
            re-derived.

    Returns:
        Scan: The configured plane-wave scan.
    """
    constant_scan_kwargs = _get_constant_scan_kwargs()
    # Use an overridden sound speed consistently for delays and grid spacing.
    sound_speed = kwargs.get("sound_speed", constant_scan_kwargs["sound_speed"])

    n_el = ultrasound_probe.n_el
    n_tx = 8
    probe_geometry = ultrasound_probe.probe_geometry

    tx_apodizations = np.ones((n_tx, n_el)) * np.hanning(n_el)[None]
    angles = np.linspace(10, -10, n_tx) * np.pi / 180
    focus_distances = np.ones(n_tx) * np.inf
    t0_delays = compute_t0_delays_planewave(
        probe_geometry=probe_geometry, polar_angles=angles, sound_speed=sound_speed
    )

    scan_kwargs = {
        "n_tx": n_tx,
        "n_el": n_el,
        "center_frequency": ultrasound_probe.center_frequency,
        "sampling_frequency": ultrasound_probe.sampling_frequency,
        "probe_geometry": probe_geometry,
        "t0_delays": t0_delays,
        "tx_apodizations": tx_apodizations,
        # Element pitch of the (evenly spaced) array.
        "element_width": np.linalg.norm(probe_geometry[1] - probe_geometry[0]),
        "focus_distances": focus_distances,
        "polar_angles": angles,
        "initial_times": np.ones(n_tx) * 1e-6,
        "n_ax": _get_n_ax(ultrasound_probe),
        "grid_type": grid_type,
        **_get_lims_and_gridsize(ultrasound_probe.center_frequency, sound_speed),
        **constant_scan_kwargs,
        # Caller-supplied values override every default above.
        **kwargs,
    }
    return Scan(**scan_kwargs)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _get_multistatic_scan(ultrasound_probe, grid_type, **kwargs):
    """Returns a multistatic scan for ultrasound simulation tests.

    Eight single-element transmits are used. The firing elements are spread
    evenly across the array, including the first and last element. All
    transmits use zero delays, zero focus distance and zero polar angle.

    Args:
        ultrasound_probe (Probe): Probe supplying ``n_el``,
            ``probe_geometry``, ``center_frequency`` and
            ``sampling_frequency``.
        grid_type (str): Grid type passed on to ``Scan``.
        **kwargs: Additional ``Scan`` arguments. They take precedence over the
            defaults chosen here (e.g. ``sound_speed``, ``n_ax``, ``xlims``)
            instead of raising a duplicate-keyword ``TypeError``. A
            ``sound_speed`` override is also used for the grid spacing.
            Shape-defining values such as ``n_tx`` are not re-derived.

    Returns:
        Scan: The configured multistatic scan.
    """
    constant_scan_kwargs = _get_constant_scan_kwargs()
    sound_speed = kwargs.get("sound_speed", constant_scan_kwargs["sound_speed"])

    n_el = ultrasound_probe.n_el
    n_tx = 8
    probe_geometry = ultrasound_probe.probe_geometry

    # One active element per transmit.
    tx_apodizations = np.zeros((n_tx, n_el))
    for n, idx in enumerate(np.linspace(0, n_el - 1, n_tx, dtype=int)):
        tx_apodizations[n, idx] = 1

    focus_distances = np.zeros(n_tx)
    t0_delays = np.zeros((n_tx, n_el))

    scan_kwargs = {
        "n_tx": n_tx,
        "n_el": n_el,
        "center_frequency": ultrasound_probe.center_frequency,
        "sampling_frequency": ultrasound_probe.sampling_frequency,
        "probe_geometry": probe_geometry,
        "t0_delays": t0_delays,
        "tx_apodizations": tx_apodizations,
        # Element pitch of the (evenly spaced) array.
        "element_width": np.linalg.norm(probe_geometry[1] - probe_geometry[0]),
        "focus_distances": focus_distances,
        "polar_angles": np.zeros(n_tx),
        "initial_times": np.ones(n_tx) * 1e-6,
        "n_ax": _get_n_ax(ultrasound_probe),
        "grid_type": grid_type,
        **_get_lims_and_gridsize(ultrasound_probe.center_frequency, sound_speed),
        **constant_scan_kwargs,
        # Caller-supplied values override every default above.
        **kwargs,
    }
    return Scan(**scan_kwargs)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _get_diverging_scan(ultrasound_probe, grid_type, **kwargs):
    """Returns a diverging-wave scan for ultrasound simulation tests.

    Eight transmits are steered from +10 to -10 degrees. Each uses a Hanning
    apodization over all elements, a transmit origin at (0, 0, 0) and a
    negative focus distance of -15 mm.

    Args:
        ultrasound_probe (Probe): Probe supplying ``n_el``,
            ``probe_geometry``, ``center_frequency`` and
            ``sampling_frequency``.
        grid_type (str): Grid type passed on to ``Scan``.
        **kwargs: Additional ``Scan`` arguments. They take precedence over the
            defaults chosen here (e.g. ``sound_speed``, ``n_ax``, ``xlims``)
            instead of raising a duplicate-keyword ``TypeError``. A
            ``sound_speed`` override is also used for the transmit delays and
            the grid spacing. Shape-defining values such as ``n_tx`` are not
            re-derived.

    Returns:
        Scan: The configured diverging-wave scan.
    """
    constant_scan_kwargs = _get_constant_scan_kwargs()
    # Use an overridden sound speed consistently for delays and grid spacing.
    sound_speed = kwargs.get("sound_speed", constant_scan_kwargs["sound_speed"])

    n_el = ultrasound_probe.n_el
    n_tx = 8
    probe_geometry = ultrasound_probe.probe_geometry

    tx_apodizations = np.ones((n_tx, n_el)) * np.hanning(n_el)[None]
    angles = np.linspace(10, -10, n_tx) * np.pi / 180
    focus_distances = np.ones(n_tx) * -15e-3
    t0_delays = compute_t0_delays_focused(
        origins=np.zeros((n_tx, 3)),
        focus_distances=focus_distances,
        probe_geometry=probe_geometry,
        polar_angles=angles,
        sound_speed=sound_speed,
    )
    # Element pitch of the (evenly spaced) array.
    element_width = np.linalg.norm(probe_geometry[1] - probe_geometry[0])

    scan_kwargs = {
        "n_tx": n_tx,
        "n_el": n_el,
        "center_frequency": ultrasound_probe.center_frequency,
        "sampling_frequency": ultrasound_probe.sampling_frequency,
        "probe_geometry": probe_geometry,
        "t0_delays": t0_delays,
        "tx_apodizations": tx_apodizations,
        "element_width": element_width,
        "focus_distances": focus_distances,
        "polar_angles": angles,
        "initial_times": np.ones(n_tx) * 1e-6,
        "n_ax": _get_n_ax(ultrasound_probe),
        "grid_type": grid_type,
        **_get_lims_and_gridsize(ultrasound_probe.center_frequency, sound_speed),
        **constant_scan_kwargs,
        # Caller-supplied values override every default above.
        **kwargs,
    }
    return Scan(**scan_kwargs)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def _get_focused_scan(ultrasound_probe, grid_type, **kwargs):
    """Returns a focused (sector) scan for ultrasound simulation tests.

    Eight transmits are steered from +30 to -30 degrees. Each uses a Hanning
    apodization over all elements, a transmit origin at (0, 0, 0) and a
    focus distance of 15 mm.

    Args:
        ultrasound_probe (Probe): Probe supplying ``n_el``,
            ``probe_geometry``, ``center_frequency`` and
            ``sampling_frequency``.
        grid_type (str): Grid type passed on to ``Scan``.
        **kwargs: Additional ``Scan`` arguments. They take precedence over the
            defaults chosen here (e.g. ``sound_speed``, ``n_ax``, ``xlims``)
            instead of raising a duplicate-keyword ``TypeError``. A
            ``sound_speed`` override is also used for the transmit delays and
            the grid spacing. Shape-defining values such as ``n_tx`` are not
            re-derived.

    Returns:
        Scan: The configured focused scan.
    """
    constant_scan_kwargs = _get_constant_scan_kwargs()
    # Use an overridden sound speed consistently for delays and grid spacing.
    sound_speed = kwargs.get("sound_speed", constant_scan_kwargs["sound_speed"])

    n_el = ultrasound_probe.n_el
    n_tx = 8
    probe_geometry = ultrasound_probe.probe_geometry

    tx_apodizations = np.ones((n_tx, n_el)) * np.hanning(n_el)[None]
    angles = np.linspace(30, -30, n_tx) * np.pi / 180
    focus_distances = np.ones(n_tx) * 15e-3
    t0_delays = compute_t0_delays_focused(
        origins=np.zeros((n_tx, 3)),
        focus_distances=focus_distances,
        probe_geometry=probe_geometry,
        polar_angles=angles,
        sound_speed=sound_speed,
    )
    # Element pitch of the (evenly spaced) array.
    element_width = np.linalg.norm(probe_geometry[1] - probe_geometry[0])

    scan_kwargs = {
        "n_tx": n_tx,
        "n_el": n_el,
        "center_frequency": ultrasound_probe.center_frequency,
        "sampling_frequency": ultrasound_probe.sampling_frequency,
        "probe_geometry": probe_geometry,
        "t0_delays": t0_delays,
        "tx_apodizations": tx_apodizations,
        "element_width": element_width,
        "focus_distances": focus_distances,
        "polar_angles": angles,
        "initial_times": np.ones(n_tx) * 1e-6,
        "n_ax": _get_n_ax(ultrasound_probe),
        "grid_type": grid_type,
        **_get_lims_and_gridsize(ultrasound_probe.center_frequency, sound_speed),
        **constant_scan_kwargs,
        # Caller-supplied values override every default above.
        **kwargs,
    }
    return Scan(**scan_kwargs)
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def _get_linescan_scan(ultrasound_probe, grid_type, **kwargs):
    """Returns a line-scan scan for ultrasound simulation tests.

    Eight transmits each fire straight ahead (polar angle 0) from a
    sub-aperture of up to 24 elements, clipped at the array edges. Each
    sub-aperture is Hanning-apodized and centered on one of eight evenly
    spaced interior elements. Every transmit originates at its center element
    and focuses at 15 mm.

    Args:
        ultrasound_probe (Probe): Probe supplying ``n_el``,
            ``probe_geometry``, ``center_frequency`` and
            ``sampling_frequency``.
        grid_type (str): Grid type passed on to ``Scan``.
        **kwargs: Additional ``Scan`` arguments. They take precedence over the
            defaults chosen here (e.g. ``sound_speed``, ``n_ax``, ``xlims``)
            instead of raising a duplicate-keyword ``TypeError``. A
            ``sound_speed`` override is also used for the transmit delays and
            the grid spacing. Shape-defining values such as ``n_tx`` are not
            re-derived.

    Returns:
        Scan: The configured line-scan scan.
    """
    constant_scan_kwargs = _get_constant_scan_kwargs()
    # Use an overridden sound speed consistently for delays and grid spacing.
    sound_speed = kwargs.get("sound_speed", constant_scan_kwargs["sound_speed"])

    n_el = ultrasound_probe.n_el
    n_tx = 8
    probe_geometry = ultrasound_probe.probe_geometry

    # Evenly spaced center elements, excluding the two outer end points.
    center_elements = np.linspace(0, n_el + 1, n_tx + 2, dtype=int)[1:-1]
    tx_apodizations = np.zeros((n_tx, n_el))
    aperture_size_elements = 24

    # Define subapertures
    origins = []
    for n, idx in enumerate(center_elements):
        el0 = np.clip(idx - aperture_size_elements // 2, 0, n_el)
        el1 = np.clip(idx + aperture_size_elements // 2, 0, n_el)
        tx_apodizations[n, el0:el1] = np.hanning(el1 - el0)
        origins.append(probe_geometry[idx])
    origins = np.stack(origins, axis=0)

    # All angles should be zero because each line fires straight ahead
    angles = np.zeros(n_tx)
    focus_distances = np.ones(n_tx) * 15e-3
    t0_delays = compute_t0_delays_focused(
        origins=origins,
        focus_distances=focus_distances,
        probe_geometry=probe_geometry,
        polar_angles=angles,
        sound_speed=sound_speed,
    )
    # Element pitch of the (evenly spaced) array.
    element_width = np.linalg.norm(probe_geometry[1] - probe_geometry[0])

    scan_kwargs = {
        "n_tx": n_tx,
        "n_el": n_el,
        "center_frequency": ultrasound_probe.center_frequency,
        "sampling_frequency": ultrasound_probe.sampling_frequency,
        "probe_geometry": probe_geometry,
        "t0_delays": t0_delays,
        "tx_apodizations": tx_apodizations,
        "element_width": element_width,
        "focus_distances": focus_distances,
        "polar_angles": angles,
        "initial_times": np.ones(n_tx) * 1e-6,
        "n_ax": _get_n_ax(ultrasound_probe),
        "grid_type": grid_type,
        **_get_lims_and_gridsize(ultrasound_probe.center_frequency, sound_speed),
        **constant_scan_kwargs,
        # Caller-supplied values override every default above.
        **kwargs,
    }
    return Scan(**scan_kwargs)
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def _get_scan(ultrasound_probe, kind, grid_type="cartesian", **kwargs) -> Scan:
    """Build a dummy scan of the given transmit ``kind`` for a probe.

    Args:
        ultrasound_probe (Probe): Probe the scan is built for.
        kind (str): One of ``"planewave"``, ``"multistatic"``,
            ``"diverging"``, ``"focused"`` or ``"linescan"``.
        grid_type (str): Grid type passed on to the scan builder.
        **kwargs: Extra arguments forwarded to the scan builder.

    Returns:
        Scan: The scan produced by the matching builder.

    Raises:
        ValueError: If ``kind`` is not one of the supported names.
    """
    builders = (
        ("planewave", _get_planewave_scan),
        ("multistatic", _get_multistatic_scan),
        ("diverging", _get_diverging_scan),
        ("focused", _get_focused_scan),
        ("linescan", _get_linescan_scan),
    )
    for name, builder in builders:
        if kind == name:
            return builder(ultrasound_probe, grid_type, **kwargs)
    raise ValueError(f"Unknown scan kind: {kind}")
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def get_scan(kind="planewave", probe_kind="linear", grid_type="cartesian", **kwargs) -> Scan:
    """Returns a scan for ultrasound simulation tests.

    Args:
        kind (str): Transmit scheme: ``"planewave"``, ``"multistatic"``,
            ``"diverging"``, ``"focused"`` or ``"linescan"``.
        probe_kind (str): Probe to simulate: ``"linear"`` or
            ``"phased_array"``.
        grid_type (str): Grid type passed on to ``Scan``.
        **kwargs: Extra arguments forwarded to the scan builder.

    Returns:
        Scan: The configured dummy scan.

    Raises:
        ValueError: If ``probe_kind`` or ``kind`` is not recognized.
    """
    return _get_scan(_get_probe(probe_kind), kind, grid_type, **kwargs)
|
zea/internal/operators.py
CHANGED
|
@@ -6,6 +6,7 @@ Handles task-dependent operations (A) and noises (n) to simulate a measurement y
|
|
|
6
6
|
|
|
7
7
|
import abc
|
|
8
8
|
|
|
9
|
+
import numpy as np
|
|
9
10
|
from keras import ops
|
|
10
11
|
|
|
11
12
|
from zea.internal.core import Object
|
|
@@ -26,13 +27,16 @@ class Operator(abc.ABC, Object):
|
|
|
26
27
|
|
|
27
28
|
"""
|
|
28
29
|
|
|
29
|
-
sigma = 0.0
|
|
30
|
-
|
|
31
30
|
@abc.abstractmethod
def forward(self, data, *args, **kwargs):
    """Implements the forward operator A: x -> y.

    Abstract; every concrete operator must override it.

    Args:
        data: The input ``x`` the operator is applied to.
        *args: Operator-specific extra operands (for example, a mask).
        **kwargs: Operator-specific keyword options.

    Returns:
        The operator output ``A x``.

    Raises:
        NotImplementedError: If called on the base class.
    """
    raise NotImplementedError
|
|
35
34
|
|
|
35
|
+
@abc.abstractmethod
def transpose(self, data, *args, **kwargs):
    """Implements the transpose (or adjoint) of the operator A^T: y -> x.

    Abstract; every concrete operator must override it. A self-adjoint
    operator can simply delegate to ``forward``.

    Args:
        data: The input ``y`` the adjoint is applied to.
        *args: Operator-specific extra operands (for example, a mask).
        **kwargs: Operator-specific keyword options.

    Returns:
        The adjoint output ``A^T y``.

    Raises:
        NotImplementedError: If called on the base class.
    """
    raise NotImplementedError
|
|
39
|
+
|
|
36
40
|
@abc.abstractmethod
|
|
37
41
|
def __str__(self):
|
|
38
42
|
"""String representation of the operator."""
|
|
@@ -78,5 +82,113 @@ class InpaintingOperator(Operator):
|
|
|
78
82
|
# return self.mask * data
|
|
79
83
|
return ops.where(mask, data, self.min_val)
|
|
80
84
|
|
|
85
|
+
def transpose(self, data, mask):
    """Apply the transpose of the masking operator.

    Masking acts element-wise (a diagonal operator), so it is treated as its
    own transpose and this reuses ``forward`` with the same mask.

    NOTE(review): ``forward`` fills the masked-out entries with
    ``self.min_val``. The identity A^T = A holds only for the purely linear
    case ``min_val == 0``; confirm this is intended for non-zero ``min_val``.

    Args:
        data: Tensor to apply the transpose to.
        mask: Mask used as the selection condition, exactly as in ``forward``.

    Returns:
        The result of ``self.forward(data, mask)``.
    """
    adjoint_output = self.forward(data, mask)
    return adjoint_output
|
|
88
|
+
|
|
81
89
|
def __str__(self):
    """Return a human-readable formula describing the inpainting operator."""
    return "y = Ax + n, where A = I * M"
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@operator_registry(name="fourier_blur")
class FourierBlurOperator(Operator):
    """Fourier-domain blurring operator class.

    Applies blurring by attenuating high spatial frequencies in the Fourier
    domain.

    Formally defined as:
        y = F^(-1)(M * F(x))

    where F is the 2D FFT over the spatial axes (H, W), F^(-1) is the inverse
    FFT and M is a real, radially symmetric low-pass frequency mask. Because
    M is real, the operator is self-adjoint (see ``transpose``).
    """

    def __init__(self, shape, cutoff_freq=0.5, smooth=True, **kwargs):
        """Initialize the Fourier blur operator.

        Args:
            shape: Shape of the input data (H, W), (H, W, C), or (B, H, W, C).
                Only H and W are used, to build the frequency mask.
            cutoff_freq: Cutoff as a normalized spatial frequency radius in
                cycles per sample, as produced by ``np.fft.fftfreq`` (the
                Nyquist frequency is 0.5). With ``smooth=True`` the mask
                equals 0.5 at this radius; otherwise frequencies with radius
                ``>= cutoff_freq`` are removed.
            smooth: If True, use Gaussian rolloff; otherwise use hard cutoff.
            **kwargs: Additional arguments passed to the base ``Operator``.
        """
        super().__init__(**kwargs)
        self.cutoff_freq = cutoff_freq
        self.smooth = smooth
        self.shape = shape

        # Precompute the frequency mask once; it depends only on (H, W).
        self.freq_mask = self.make_lowpass_mask(
            shape=shape, cutoff_freq=cutoff_freq, smooth=smooth
        )

    def make_lowpass_mask(self, shape, cutoff_freq=0.1, smooth=True):
        """Create a low-pass Fourier mask for the spatial size in ``shape``.

        The mask uses the unshifted ``fft2`` layout (DC at index [0, 0])
        because it is built directly from ``np.fft.fftfreq``.

        Args:
            shape: (H, W), (H, W, C) or (B, H, W, C).
            cutoff_freq: Cutoff radius in cycles per sample (Nyquist = 0.5).
            smooth: If True, use a Gaussian rolloff that equals 0.5 at
                ``cutoff_freq``; otherwise use a hard cutoff.

        Returns:
            A float32 tensor of shape (H, W).

        Raises:
            ValueError: If ``shape`` has fewer than two dimensions.
        """
        # Accept (H, W), (H, W, C) or (B, H, W, C)
        if len(shape) == 2:
            H, W = shape
        elif len(shape) >= 3:
            H, W = shape[-3], shape[-2]
        else:
            raise ValueError(f"Invalid shape {shape}. Expected (H, W), (H, W, C), or (B, H, W, C).")
        fy = np.fft.fftfreq(H)
        fx = np.fft.fftfreq(W)
        FX, FY = np.meshgrid(fx, fy)  # both (H, W) with default "xy" indexing
        R = np.sqrt(FX**2 + FY**2)

        if smooth:
            # Chosen so that exp(-cutoff^2 / (2 sigma^2)) == 0.5.
            sigma = cutoff_freq / np.sqrt(2 * np.log(2))
            mask = np.exp(-(R**2) / (2 * sigma**2))
        else:
            mask = R < cutoff_freq

        # fftfreq already puts DC at index 0, matching fft2, so no shift is
        # needed. Cast to float32 so the mask matches the float32 FFT output;
        # np.exp yields float64, and some backends (e.g. TensorFlow) refuse
        # to multiply float32 and float64 tensors.
        return ops.convert_to_tensor(mask.astype(np.float32))

    def forward(self, data):
        """Apply Fourier-domain blurring.

        Args:
            data: Real-valued input tensor of shape (B, H, W, C), (H, W, C) or
                (H, W). H and W must match the shape given at construction.

        Returns:
            Blurred data tensor with the same shape and dtype as ``data``.

        Raises:
            ValueError: If ``data`` does not have 2, 3 or 4 dimensions.
        """
        ndim = len(ops.shape(data))
        if ndim not in (2, 3, 4):
            raise ValueError(
                f"Expected data of shape (H, W), (H, W, C) or (B, H, W, C), got rank {ndim}."
            )

        # Convert to float32 for FFT, and promote to (B, H, W, C) so a single
        # code path handles every supported rank.
        x = ops.cast(data, "float32")
        if ndim == 2:
            x = ops.expand_dims(x, -1)
        if ndim in (2, 3):
            x = ops.expand_dims(x, 0)

        # fft2 transforms the last two axes, so move (H, W) to the end.
        x_real = ops.transpose(x, (0, 3, 1, 2))
        x_imag = ops.zeros_like(x_real)

        # keras.ops.fft2 takes and returns (real, imag) tuples.
        fft_real, fft_imag = ops.fft2((x_real, x_imag))

        # The mask is real, so it scales real and imaginary parts alike and
        # broadcasts over the leading (B, C) axes.
        mask = ops.cast(ops.real(self.freq_mask), "float32")
        blurred_real, _ = ops.ifft2((fft_real * mask, fft_imag * mask))

        # Back to (B, H, W, C). Keep the real part only; the imaginary part is
        # ~0 for real input and a symmetric mask.
        y = ops.transpose(blurred_real, (0, 2, 3, 1))
        if ndim in (2, 3):
            y = ops.squeeze(y, axis=0)
        if ndim == 2:
            y = ops.squeeze(y, axis=-1)

        return ops.cast(y, data.dtype)

    def transpose(self, data):
        """Apply the adjoint operator, which equals the forward operator.

        With A = F^(-1) M F and a real diagonal mask M (so M^* = M):

            A^* = F^* M^* (F^(-1))^* = F^(-1) M F = A,

        because for the unnormalized DFT of size N, F^* = N F^(-1) and
        (F^(-1))^* = F / N, so the scale factors cancel. The operator is
        therefore self-adjoint.
        """
        return self.forward(data)

    def __str__(self):
        return f"y = F^(-1)(M * F(x)) filter at {self.cutoff_freq}"
|