zea 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. zea/__init__.py +54 -19
  2. zea/agent/__init__.py +12 -12
  3. zea/agent/masks.py +2 -1
  4. zea/backend/tensorflow/dataloader.py +2 -1
  5. zea/beamform/beamformer.py +100 -50
  6. zea/beamform/lens_correction.py +9 -2
  7. zea/beamform/pfield.py +9 -2
  8. zea/config.py +34 -25
  9. zea/data/__init__.py +22 -16
  10. zea/data/convert/camus.py +2 -1
  11. zea/data/convert/echonet.py +4 -4
  12. zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +1 -1
  13. zea/data/convert/matlab.py +11 -4
  14. zea/data/data_format.py +31 -30
  15. zea/data/datasets.py +7 -5
  16. zea/data/file.py +104 -2
  17. zea/data/layers.py +3 -3
  18. zea/datapaths.py +16 -4
  19. zea/display.py +7 -5
  20. zea/interface.py +14 -16
  21. zea/internal/_generate_keras_ops.py +6 -7
  22. zea/internal/cache.py +2 -49
  23. zea/internal/config/validation.py +1 -2
  24. zea/internal/core.py +69 -6
  25. zea/internal/device.py +6 -2
  26. zea/internal/dummy_scan.py +330 -0
  27. zea/internal/operators.py +114 -2
  28. zea/internal/parameters.py +101 -70
  29. zea/internal/setup_zea.py +5 -6
  30. zea/internal/utils.py +282 -0
  31. zea/io_lib.py +247 -19
  32. zea/keras_ops.py +74 -4
  33. zea/log.py +9 -7
  34. zea/metrics.py +15 -7
  35. zea/models/__init__.py +30 -20
  36. zea/models/base.py +30 -14
  37. zea/models/carotid_segmenter.py +19 -4
  38. zea/models/diffusion.py +173 -12
  39. zea/models/echonet.py +22 -8
  40. zea/models/echonetlvh.py +31 -7
  41. zea/models/lpips.py +19 -2
  42. zea/models/lv_segmentation.py +28 -11
  43. zea/models/preset_utils.py +5 -5
  44. zea/models/regional_quality.py +30 -10
  45. zea/models/taesd.py +21 -5
  46. zea/models/unet.py +15 -1
  47. zea/ops.py +390 -196
  48. zea/probes.py +6 -6
  49. zea/scan.py +109 -49
  50. zea/simulator.py +24 -21
  51. zea/tensor_ops.py +406 -302
  52. zea/tools/hf.py +1 -1
  53. zea/tools/selection_tool.py +47 -86
  54. zea/utils.py +92 -480
  55. zea/visualize.py +177 -39
  56. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/METADATA +4 -2
  57. zea-0.0.7.dist-info/RECORD +114 -0
  58. zea-0.0.6.dist-info/RECORD +0 -112
  59. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/WHEEL +0 -0
  60. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/entry_points.txt +0 -0
  61. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/licenses/LICENSE +0 -0
zea/internal/core.py CHANGED
@@ -1,6 +1,7 @@
1
1
  """Base classes for the toolbox"""
2
2
 
3
3
  import enum
4
+ import hashlib
4
5
  import json
5
6
  import pickle
6
7
  from copy import deepcopy
@@ -8,7 +9,8 @@ from copy import deepcopy
8
9
  import keras
9
10
  import numpy as np
10
11
 
11
- from zea.utils import reduce_to_signature, update_dictionary
12
+ from zea.internal.utils import reduce_to_signature
13
+ from zea.utils import update_dictionary
12
14
 
13
15
  CONVERT_TO_KERAS_TYPES = (np.ndarray, int, float, list, tuple, bool)
14
16
  BASE_FLOAT_PRECISION = "float32"
@@ -76,7 +78,7 @@ class Object:
76
78
  attributes.pop(
77
79
  "_serialized", None
78
80
  ) # Remove the cached serialized attribute to avoid recursion
79
- self._serialized = pickle.dumps(attributes)
81
+ self._serialized = serialize_elements([attributes])
80
82
  return self._serialized
81
83
 
82
84
  def __setattr__(self, name: str, value):
@@ -167,9 +169,7 @@ def _skip_to_tensor(value):
167
169
  # Skip str (because JIT does not support it)
168
170
  # Skip methods and functions
169
171
  # Skip byte strings
170
- if isinstance(value, str) or callable(value) or isinstance(value, bytes):
171
- return True
172
- return False
172
+ return isinstance(value, str) or callable(value) or isinstance(value, bytes)
173
173
 
174
174
 
175
175
  def dict_to_tensor(dictionary, keep_as_is=None):
@@ -184,8 +184,9 @@ def dict_to_tensor(dictionary, keep_as_is=None):
184
184
  # Get the value from the dictionary
185
185
  value = dictionary[key]
186
186
 
187
- if isinstance(value, Object):
187
+ if isinstance(value, Object) and hasattr(value, "to_tensor"):
188
188
  snapshot[key] = value.to_tensor(keep_as_is=keep_as_is)
189
+ continue
189
190
 
190
191
  # Skip certain types
191
192
  if _skip_to_tensor(value):
@@ -288,3 +289,65 @@ class ZEADecoderJSON(json.JSONDecoder):
288
289
  obj[key] = self._MOD_TYPES_MAP[value] if value is not None else None
289
290
 
290
291
  return obj
292
+
293
+
294
def serialize_elements(key_elements: list) -> str:
    """Serialize elements of a list to a deterministic string.

    Leaf elements are serialized as the hex string of their pickle
    representation. Special cases:

    - lists and tuples are serialized recursively and wrapped in ``[...]``;
      a list and a tuple with equal contents serialize identically.
    - dictionaries are serialized recursively as ``{k_<keys>_v_<values>}``,
      with entries ordered by their serialized key. Insertion order therefore
      does not matter, and keys of mixed (non-comparable) types are supported.
    - ``Object`` instances use their ``serialized`` attribute.
    - ``keras.random.SeedGenerator`` instances are serialized by their state.

    The brackets keep nesting unambiguous, so e.g. ``[[1, 2]]`` and ``[1, 2]``
    (or ``[[]]`` and ``[]``) produce different strings.

    Args:
        key_elements (list): List (or tuple) of elements to serialize. Can be
            nested lists, tuples or dicts; these are serialized recursively.

    Returns:
        str: The serialized elements, joined by underscores.
    """

    def _serialize(element) -> str:
        return pickle.dumps(element).hex()

    def _serialize_element(element) -> str:
        if isinstance(element, (list, tuple)):
            # Bracket nested sequences so their boundaries survive the "_" join.
            return f"[{serialize_elements(element)}]"
        if isinstance(element, dict):
            # Order entries by serialized key: deterministic regardless of the
            # dict's insertion order, and works for unorderable key types.
            items = sorted(
                ((_serialize_element(key), value) for key, value in element.items()),
                key=lambda item: item[0],
            )
            keys = "_".join(key for key, _ in items)
            values = serialize_elements([value for _, value in items])
            return f"{{k_{keys}_v_{values}}}"
        if isinstance(element, Object) and hasattr(element, "serialized"):
            # Use the object's own (cached) serialized representation
            return str(element.serialized)
        if isinstance(element, keras.random.SeedGenerator):
            # Serialize the generator by its current state
            return _serialize(keras.ops.convert_to_numpy(element.state.value))
        # Otherwise, serialize the element directly
        return _serialize(element)

    return "_".join(_serialize_element(element) for element in key_elements)
341
+
342
+
343
def hash_elements(key_elements: list) -> str:
    """Return the hexadecimal MD5 digest of the serialized elements.

    The elements are first turned into a string with ``serialize_elements``.
    That string is then UTF-8 encoded and hashed. MD5 serves here only as a
    compact fingerprint; it is not suitable for security purposes.

    Args:
        key_elements (list): List of elements to serialize and hash.

    Returns:
        str: The 32-character hex MD5 digest.
    """
    payload = serialize_elements(key_elements).encode()
    digest = hashlib.md5(payload)
    return digest.hexdigest()
zea/internal/device.py CHANGED
@@ -377,7 +377,11 @@ def init_device(
377
377
  allow_preallocate: bool = True,
378
378
  verbose: bool = True,
379
379
  ):
380
- """Selects a GPU or CPU device based on the config.
380
+ """Automatically selects a GPU or CPU device.
381
+
382
+ Useful to call at the start of a script to set the device for
383
+ tensorflow, jax or pytorch. The function will select a GPU based
384
+ on available memory, or fall back to CPU if no GPU is available.
381
385
 
382
386
  Args:
383
387
  backend (str): String indicating which backend to use. Can be
@@ -412,7 +416,7 @@ def init_device(
412
416
  elif backend in ["numpy", "cpu"]:
413
417
  device = "cpu"
414
418
  else:
415
- raise ValueError(f"Unknown backend ({backend}) in config.")
419
+ raise ValueError(f"Unknown backend ({backend}).")
416
420
 
417
421
  # Early exit if device is CPU
418
422
  if device == "cpu":
@@ -0,0 +1,330 @@
1
+ """Module to create dummy Scan objects for testing and simulation purposes."""
2
+
3
+ import numpy as np
4
+
5
+ from zea.beamform.delays import compute_t0_delays_focused, compute_t0_delays_planewave
6
+ from zea.probes import Probe
7
+ from zea.scan import Scan
8
+
9
+
10
def _get_linear_probe():
    """Build the linear-array ``Probe`` used in ultrasound simulation tests.

    128 elements are spaced uniformly along x over a 30 mm aperture, with
    y = z = 0. The center frequency is 2.5 MHz and the sampling frequency
    10 MHz.
    """
    num_elements = 128
    aperture_width = 30e-3
    # Columns are (x, y, z); only x varies across the array
    geometry = np.zeros((num_elements, 3))
    geometry[:, 0] = np.linspace(-aperture_width / 2, aperture_width / 2, num_elements)

    return Probe(
        probe_geometry=geometry,
        center_frequency=2.5e6,
        sampling_frequency=10e6,
    )
28
+
29
+
30
def _get_phased_array_probe():
    """Build the phased-array ``Probe`` used in ultrasound simulation tests.

    80 elements are spaced uniformly along x over a 20 mm aperture, with
    y = z = 0. The center frequency is 3.12 MHz and the sampling frequency
    12.5 MHz.
    """
    num_elements = 80
    aperture_width = 20e-3
    # Columns are (x, y, z); only x varies across the array
    geometry = np.zeros((num_elements, 3))
    geometry[:, 0] = np.linspace(-aperture_width / 2, aperture_width / 2, num_elements)

    return Probe(
        probe_geometry=geometry,
        center_frequency=3.12e6,
        sampling_frequency=12.5e6,
    )
48
+
49
+
50
def _get_n_ax(ultrasound_probe):
    """Return the number of axial samples (``n_ax``) for a dummy scan.

    Probes with a center frequency below 4 MHz get 510 samples. All others
    get 1024, since a higher-frequency probe needs more samples to cover the
    image depth.

    Args:
        ultrasound_probe: Object exposing ``center_frequency`` in Hz.

    Returns:
        int: 510 or 1024.
    """
    return 510 if ultrasound_probe.center_frequency < 4e6 else 1024
61
+
62
+
63
def _get_probe(kind) -> Probe:
    """Return the dummy test probe for ``kind``.

    Args:
        kind (str): Either "linear" or "phased_array".

    Returns:
        Probe: The corresponding probe.

    Raises:
        ValueError: If ``kind`` is not a known probe kind.
    """
    if kind == "linear":
        return _get_linear_probe()
    if kind == "phased_array":
        return _get_phased_array_probe()
    raise ValueError(f"Unknown probe kind: {kind}")
70
+
71
+
72
def _get_constant_scan_kwargs():
    """Return the ``Scan`` keyword arguments shared by every dummy scan.

    Lens correction is disabled and attenuation is zero. A new dict is built
    on each call, so callers may mutate the result freely.
    """
    return dict(
        lens_sound_speed=1000,
        lens_thickness=1e-3,
        n_ch=1,
        selected_transmits="all",
        sound_speed=1540.0,
        apply_lens_correction=False,
        attenuation_coef=0.0,
    )
82
+
83
+
84
def _get_lims_and_gridsize(center_frequency, sound_speed):
    """Return image limits and grid size for ultrasound simulation tests.

    The region is x in [-20, 20] mm and z in [0, 35] mm. It is sampled at
    half-wavelength spacing (wavelength = sound_speed / center_frequency).

    Args:
        center_frequency (float): Center frequency in Hz.
        sound_speed (float): Speed of sound in m/s.

    Returns:
        dict: ``xlims``, ``zlims``, ``grid_size_x`` and ``grid_size_z``.
    """
    x_min, x_max = -20e-3, 20e-3
    z_min, z_max = 0, 35e-3
    half_wavelength = 0.5 * (sound_speed / center_frequency)
    return {
        "xlims": (x_min, x_max),
        "zlims": (z_min, z_max),
        "grid_size_x": int((x_max - x_min) / half_wavelength) + 1,
        "grid_size_z": int((z_max - z_min) / half_wavelength) + 1,
    }
94
+
95
+
96
def _get_planewave_scan(ultrasound_probe, grid_type, **kwargs):
    """Return a plane-wave ``Scan`` for ultrasound simulation tests.

    Transmits are steered uniformly from +10 to -10 degrees. Each uses a Hann
    apodization over all elements, an infinite focus distance and an initial
    time of 1 us.

    Args:
        ultrasound_probe (Probe): Probe supplying geometry and frequencies.
        grid_type (str): Grid type forwarded to ``Scan``.
        **kwargs: Additional ``Scan`` arguments.
            - ``n_tx`` (default 8) sets the number of transmits.
            - ``sound_speed`` is also used for the delays and the grid.
            - Any other key overrides the corresponding default below.
              Previously, overriding a default raised a duplicate-keyword
              ``TypeError``.

    Returns:
        Scan: The plane-wave scan.
    """
    constant_scan_kwargs = _get_constant_scan_kwargs()
    n_tx = kwargs.pop("n_tx", 8)
    # Use the caller's sound speed (if given) so delays and grid stay consistent
    sound_speed = kwargs.get("sound_speed", constant_scan_kwargs["sound_speed"])
    n_el = ultrasound_probe.n_el
    probe_geometry = ultrasound_probe.probe_geometry

    tx_apodizations = np.ones((n_tx, n_el)) * np.hanning(n_el)[None]
    angles = np.linspace(10, -10, n_tx) * np.pi / 180
    t0_delays = compute_t0_delays_planewave(
        probe_geometry=probe_geometry, polar_angles=angles, sound_speed=sound_speed
    )

    scan_kwargs = {
        "n_tx": n_tx,
        "n_el": n_el,
        "center_frequency": ultrasound_probe.center_frequency,
        "sampling_frequency": ultrasound_probe.sampling_frequency,
        "probe_geometry": probe_geometry,
        "t0_delays": t0_delays,
        "tx_apodizations": tx_apodizations,
        "element_width": np.linalg.norm(probe_geometry[1] - probe_geometry[0]),
        "focus_distances": np.ones(n_tx) * np.inf,
        "polar_angles": angles,
        "initial_times": np.ones(n_tx) * 1e-6,
        "n_ax": _get_n_ax(ultrasound_probe),
        "grid_type": grid_type,
        **_get_lims_and_gridsize(ultrasound_probe.center_frequency, sound_speed),
        **constant_scan_kwargs,
    }
    # Caller-supplied kwargs take precedence over the defaults above
    scan_kwargs.update(kwargs)
    return Scan(**scan_kwargs)
131
+
132
+
133
def _get_multistatic_scan(ultrasound_probe, grid_type, **kwargs):
    """Return a multistatic ``Scan`` for ultrasound simulation tests.

    Each transmit fires a single element, and the firing elements are spread
    uniformly across the array. Delays, focus distances and polar angles are
    all zero; the initial time is 1 us.

    Args:
        ultrasound_probe (Probe): Probe supplying geometry and frequencies.
        grid_type (str): Grid type forwarded to ``Scan``.
        **kwargs: Additional ``Scan`` arguments.
            - ``n_tx`` (default 8) sets the number of transmits.
            - ``sound_speed`` is also used to size the grid.
            - Any other key overrides the corresponding default below.
              Previously, overriding a default raised a duplicate-keyword
              ``TypeError``.

    Returns:
        Scan: The multistatic scan.
    """
    constant_scan_kwargs = _get_constant_scan_kwargs()
    n_tx = kwargs.pop("n_tx", 8)
    sound_speed = kwargs.get("sound_speed", constant_scan_kwargs["sound_speed"])
    n_el = ultrasound_probe.n_el
    probe_geometry = ultrasound_probe.probe_geometry

    # One active element per transmit, spread evenly over the aperture
    tx_apodizations = np.zeros((n_tx, n_el))
    active_elements = np.linspace(0, n_el - 1, n_tx, dtype=int)
    tx_apodizations[np.arange(n_tx), active_elements] = 1

    scan_kwargs = {
        "n_tx": n_tx,
        "n_el": n_el,
        "center_frequency": ultrasound_probe.center_frequency,
        "sampling_frequency": ultrasound_probe.sampling_frequency,
        "probe_geometry": probe_geometry,
        "t0_delays": np.zeros((n_tx, n_el)),
        "tx_apodizations": tx_apodizations,
        "element_width": np.linalg.norm(probe_geometry[1] - probe_geometry[0]),
        "focus_distances": np.zeros(n_tx),
        "polar_angles": np.zeros(n_tx),
        "initial_times": np.ones(n_tx) * 1e-6,
        "n_ax": _get_n_ax(ultrasound_probe),
        "grid_type": grid_type,
        **_get_lims_and_gridsize(ultrasound_probe.center_frequency, sound_speed),
        **constant_scan_kwargs,
    }
    # Caller-supplied kwargs take precedence over the defaults above
    scan_kwargs.update(kwargs)
    return Scan(**scan_kwargs)
167
+
168
+
169
def _get_diverging_scan(ultrasound_probe, grid_type, **kwargs):
    """Return a diverging-wave ``Scan`` for ultrasound simulation tests.

    Transmits are steered uniformly from +10 to -10 degrees. They have a
    focus distance of -15 mm (a virtual source behind the probe), origins at
    (0, 0, 0), a Hann apodization over all elements and an initial time
    of 1 us.

    Args:
        ultrasound_probe (Probe): Probe supplying geometry and frequencies.
        grid_type (str): Grid type forwarded to ``Scan``.
        **kwargs: Additional ``Scan`` arguments.
            - ``n_tx`` (default 8) sets the number of transmits.
            - ``sound_speed`` is also used for the delays and the grid.
            - Any other key overrides the corresponding default below.
              Previously, overriding a default raised a duplicate-keyword
              ``TypeError``.

    Returns:
        Scan: The diverging-wave scan.
    """
    constant_scan_kwargs = _get_constant_scan_kwargs()
    n_tx = kwargs.pop("n_tx", 8)
    # Use the caller's sound speed (if given) so delays and grid stay consistent
    sound_speed = kwargs.get("sound_speed", constant_scan_kwargs["sound_speed"])
    n_el = ultrasound_probe.n_el
    probe_geometry = ultrasound_probe.probe_geometry

    tx_apodizations = np.ones((n_tx, n_el)) * np.hanning(n_el)[None]
    angles = np.linspace(10, -10, n_tx) * np.pi / 180
    # Negative focus distance -> virtual source behind the array
    focus_distances = np.ones(n_tx) * -15e-3
    t0_delays = compute_t0_delays_focused(
        origins=np.zeros((n_tx, 3)),
        focus_distances=focus_distances,
        probe_geometry=probe_geometry,
        polar_angles=angles,
        sound_speed=sound_speed,
    )

    scan_kwargs = {
        "n_tx": n_tx,
        "n_el": n_el,
        "center_frequency": ultrasound_probe.center_frequency,
        "sampling_frequency": ultrasound_probe.sampling_frequency,
        "probe_geometry": probe_geometry,
        "t0_delays": t0_delays,
        "tx_apodizations": tx_apodizations,
        "element_width": np.linalg.norm(probe_geometry[1] - probe_geometry[0]),
        "focus_distances": focus_distances,
        "polar_angles": angles,
        "initial_times": np.ones(n_tx) * 1e-6,
        "n_ax": _get_n_ax(ultrasound_probe),
        "grid_type": grid_type,
        **_get_lims_and_gridsize(ultrasound_probe.center_frequency, sound_speed),
        **constant_scan_kwargs,
    }
    # Caller-supplied kwargs take precedence over the defaults above
    scan_kwargs.update(kwargs)
    return Scan(**scan_kwargs)
210
+
211
+
212
def _get_focused_scan(ultrasound_probe, grid_type, **kwargs):
    """Return a focused ``Scan`` for ultrasound simulation tests.

    Transmits are steered uniformly from +30 to -30 degrees. They are focused
    at 15 mm from origins at (0, 0, 0), use a Hann apodization over all
    elements and have an initial time of 1 us.

    Args:
        ultrasound_probe (Probe): Probe supplying geometry and frequencies.
        grid_type (str): Grid type forwarded to ``Scan``.
        **kwargs: Additional ``Scan`` arguments.
            - ``n_tx`` (default 8) sets the number of transmits.
            - ``sound_speed`` is also used for the delays and the grid.
            - Any other key overrides the corresponding default below.
              Previously, overriding a default raised a duplicate-keyword
              ``TypeError``.

    Returns:
        Scan: The focused scan.
    """
    constant_scan_kwargs = _get_constant_scan_kwargs()
    n_tx = kwargs.pop("n_tx", 8)
    # Use the caller's sound speed (if given) so delays and grid stay consistent
    sound_speed = kwargs.get("sound_speed", constant_scan_kwargs["sound_speed"])
    n_el = ultrasound_probe.n_el
    probe_geometry = ultrasound_probe.probe_geometry

    tx_apodizations = np.ones((n_tx, n_el)) * np.hanning(n_el)[None]
    angles = np.linspace(30, -30, n_tx) * np.pi / 180
    focus_distances = np.ones(n_tx) * 15e-3
    t0_delays = compute_t0_delays_focused(
        origins=np.zeros((n_tx, 3)),
        focus_distances=focus_distances,
        probe_geometry=probe_geometry,
        polar_angles=angles,
        sound_speed=sound_speed,
    )

    scan_kwargs = {
        "n_tx": n_tx,
        "n_el": n_el,
        "center_frequency": ultrasound_probe.center_frequency,
        "sampling_frequency": ultrasound_probe.sampling_frequency,
        "probe_geometry": probe_geometry,
        "t0_delays": t0_delays,
        "tx_apodizations": tx_apodizations,
        "element_width": np.linalg.norm(probe_geometry[1] - probe_geometry[0]),
        "focus_distances": focus_distances,
        "polar_angles": angles,
        "initial_times": np.ones(n_tx) * 1e-6,
        "n_ax": _get_n_ax(ultrasound_probe),
        "grid_type": grid_type,
        **_get_lims_and_gridsize(ultrasound_probe.center_frequency, sound_speed),
        **constant_scan_kwargs,
    }
    # Caller-supplied kwargs take precedence over the defaults above
    scan_kwargs.update(kwargs)
    return Scan(**scan_kwargs)
253
+
254
+
255
def _get_linescan_scan(ultrasound_probe, grid_type, **kwargs):
    """Return a focused line-scan ``Scan`` for ultrasound simulation tests.

    Each transmit fires straight ahead (polar angle 0) from a 24-element,
    Hann-apodized sub-aperture centered on a different element. Each
    transmit is focused at 15 mm from that element's position, with an
    initial time of 1 us.

    Args:
        ultrasound_probe (Probe): Probe supplying geometry and frequencies.
        grid_type (str): Grid type forwarded to ``Scan``.
        **kwargs: Additional ``Scan`` arguments.
            - ``n_tx`` (default 8) sets the number of transmits.
            - ``sound_speed`` is also used for the delays and the grid.
            - Any other key overrides the corresponding default below.
              Previously, overriding a default raised a duplicate-keyword
              ``TypeError``.

    Returns:
        Scan: The line-scan scan.
    """
    constant_scan_kwargs = _get_constant_scan_kwargs()
    n_tx = kwargs.pop("n_tx", 8)
    # Use the caller's sound speed (if given) so delays and grid stay consistent
    sound_speed = kwargs.get("sound_speed", constant_scan_kwargs["sound_speed"])
    n_el = ultrasound_probe.n_el
    probe_geometry = ultrasound_probe.probe_geometry
    aperture_size_elements = 24
    half_aperture = aperture_size_elements // 2

    # Spread the sub-aperture centers over the array, dropping both end points
    center_elements = np.linspace(0, n_el + 1, n_tx + 2, dtype=int)[1:-1]
    # Keep centers on valid element indices (a no-op for the default n_tx)
    center_elements = np.clip(center_elements, 0, n_el - 1)

    # Define subapertures
    tx_apodizations = np.zeros((n_tx, n_el))
    origins = []
    for n, idx in enumerate(center_elements):
        el0 = np.clip(idx - half_aperture, 0, n_el)
        el1 = np.clip(idx + half_aperture, 0, n_el)
        tx_apodizations[n, el0:el1] = np.hanning(el1 - el0)
        origins.append(probe_geometry[idx])
    origins = np.stack(origins, axis=0)

    # All angles are zero because each line fires straight ahead
    angles = np.zeros(n_tx)
    focus_distances = np.ones(n_tx) * 15e-3
    t0_delays = compute_t0_delays_focused(
        origins=origins,
        focus_distances=focus_distances,
        probe_geometry=probe_geometry,
        polar_angles=angles,
        sound_speed=sound_speed,
    )

    scan_kwargs = {
        "n_tx": n_tx,
        "n_el": n_el,
        "center_frequency": ultrasound_probe.center_frequency,
        "sampling_frequency": ultrasound_probe.sampling_frequency,
        "probe_geometry": probe_geometry,
        "t0_delays": t0_delays,
        "tx_apodizations": tx_apodizations,
        "element_width": np.linalg.norm(probe_geometry[1] - probe_geometry[0]),
        "focus_distances": focus_distances,
        "polar_angles": angles,
        "initial_times": np.ones(n_tx) * 1e-6,
        "n_ax": _get_n_ax(ultrasound_probe),
        "grid_type": grid_type,
        **_get_lims_and_gridsize(ultrasound_probe.center_frequency, sound_speed),
        **constant_scan_kwargs,
    }
    # Caller-supplied kwargs take precedence over the defaults above
    scan_kwargs.update(kwargs)
    return Scan(**scan_kwargs)
310
+
311
+
312
def _get_scan(ultrasound_probe, kind, grid_type="cartesian", **kwargs) -> Scan:
    """Dispatch to the dummy-scan builder for ``kind``.

    Args:
        ultrasound_probe (Probe): Probe to build the scan for.
        kind (str): One of "planewave", "multistatic", "diverging", "focused"
            or "linescan".
        grid_type (str): Grid type forwarded to the builder.
        **kwargs: Forwarded unchanged to the builder.

    Returns:
        Scan: The scan produced by the selected builder.

    Raises:
        ValueError: If ``kind`` is not a known scan kind.
    """
    if kind == "planewave":
        builder = _get_planewave_scan
    elif kind == "multistatic":
        builder = _get_multistatic_scan
    elif kind == "diverging":
        builder = _get_diverging_scan
    elif kind == "focused":
        builder = _get_focused_scan
    elif kind == "linescan":
        builder = _get_linescan_scan
    else:
        raise ValueError(f"Unknown scan kind: {kind}")
    return builder(ultrasound_probe, grid_type, **kwargs)
325
+
326
+
327
def get_scan(kind="planewave", probe_kind="linear", grid_type="cartesian", **kwargs) -> Scan:
    """Build a dummy ``Scan`` for ultrasound simulation tests.

    Args:
        kind (str): Transmit scheme: "planewave", "multistatic", "diverging",
            "focused" or "linescan".
        probe_kind (str): Probe to use: "linear" or "phased_array".
        grid_type (str): Grid type forwarded to the scan builder.
        **kwargs: Additional keyword arguments forwarded to the scan builder.

    Returns:
        Scan: The requested scan.

    Raises:
        ValueError: If ``probe_kind`` or ``kind`` is unknown (the probe kind
            is checked first).
    """
    probe = _get_probe(probe_kind)
    scan = _get_scan(probe, kind, grid_type, **kwargs)
    return scan
zea/internal/operators.py CHANGED
@@ -6,6 +6,7 @@ Handles task-dependent operations (A) and noises (n) to simulate a measurement y
6
6
 
7
7
  import abc
8
8
 
9
+ import numpy as np
9
10
  from keras import ops
10
11
 
11
12
  from zea.internal.core import Object
@@ -26,13 +27,16 @@ class Operator(abc.ABC, Object):
26
27
 
27
28
  """
28
29
 
29
- sigma = 0.0
30
-
31
30
@abc.abstractmethod
def forward(self, data, *args, **kwargs):
    """Implements the forward operator A: x -> y.

    Subclasses must override this method; the base version only raises.

    Raises:
        NotImplementedError: Always, in the abstract base implementation.
    """
    raise NotImplementedError
34
 
35
@abc.abstractmethod
def transpose(self, data, *args, **kwargs):
    """Implements the transpose (or adjoint) of the operator A^T: y -> x.

    Subclasses must override this method; the base version only raises.

    Raises:
        NotImplementedError: Always, in the abstract base implementation.
    """
    raise NotImplementedError
39
+
36
40
  @abc.abstractmethod
37
41
  def __str__(self):
38
42
  """String representation of the operator."""
@@ -78,5 +82,113 @@ class InpaintingOperator(Operator):
78
82
  # return self.mask * data
79
83
  return ops.where(mask, data, self.min_val)
80
84
 
85
def transpose(self, data, mask):
    """Apply the transpose (adjoint) of the masking operator.

    Masking acts element-wise, i.e. it is a diagonal operator, so A^T = A.
    The call is therefore simply delegated to ``forward``.
    """
    adjoint = self.forward
    return adjoint(data, mask)
88
+
81
89
def __str__(self):
    """Describe the inpainting measurement model (identity with a mask)."""
    description = "y = Ax + n, where A = I * M"
    return description
91
+
92
+
93
@operator_registry(name="fourier_blur")
class FourierBlurOperator(Operator):
    """Fourier-domain blurring operator class.

    Applies blurring by attenuating high frequencies in the Fourier domain.

    Formally defined as:
        y = F^(-1)(M * F(x))

    Here F is the 2D FFT over the spatial (H, W) axes and F^(-1) is the
    inverse FFT. M is a real, radially symmetric low-pass frequency mask.
    """

    def __init__(self, shape, cutoff_freq=0.5, smooth=True, **kwargs):
        """Initialize the Fourier blur operator.

        Args:
            shape: Shape of the input data (H, W), (H, W, C), or (B, H, W, C).
                Only H and W are used, to precompute the frequency mask.
            cutoff_freq: Cutoff radius in cycles per sample, i.e. in the units
                of ``np.fft.fftfreq``, where the Nyquist frequency along each
                axis is 0.5. With ``smooth=True`` the mask equals 0.5 at this
                radius.
            smooth: If True, use Gaussian rolloff; otherwise use hard cutoff.
            **kwargs: Additional arguments passed to ``Operator``.
        """
        super().__init__(**kwargs)
        self.cutoff_freq = cutoff_freq
        self.shape = shape

        # Precompute frequency mask (depends only on H, W)
        self.freq_mask = self.make_lowpass_mask(shape=shape, cutoff_freq=cutoff_freq, smooth=smooth)

    def make_lowpass_mask(self, shape, cutoff_freq=0.1, smooth=True):
        """Create a low-pass Fourier mask for the spatial size in ``shape``.

        Args:
            shape: (H, W), (H, W, C) or (B, H, W, C); H and W set the mask size.
            cutoff_freq: Cutoff radius in cycles per sample (Nyquist = 0.5).
            smooth: If True, use a Gaussian rolloff whose half-maximum lies at
                ``cutoff_freq``; otherwise use a hard (binary) cutoff.

        Returns:
            float32 tensor of shape (H, W). It is in unshifted FFT layout,
            with DC at index [0, 0].

        Raises:
            ValueError: If ``shape`` has fewer than two dimensions.
        """
        # Accept (H, W), (H, W, C) or (B, H, W, C)
        if len(shape) == 2:
            H, W = shape
        elif len(shape) >= 3:
            H, W = shape[-3], shape[-2]
        else:
            raise ValueError(f"Invalid shape {shape}. Expected (H, W), (H, W, C), or (B, H, W, C).")

        # fftfreq already uses the unshifted layout of fft2 (DC first), so no
        # ifftshift of the mask is needed.
        fy = np.fft.fftfreq(H)
        fx = np.fft.fftfreq(W)
        FX, FY = np.meshgrid(fx, fy)
        R = np.sqrt(FX**2 + FY**2)

        if smooth:
            # sigma is chosen so that exp(-R^2 / (2 sigma^2)) == 0.5 at R == cutoff
            sigma = cutoff_freq / np.sqrt(2 * np.log(2))
            mask = np.exp(-(R**2) / (2 * sigma**2))
        else:
            mask = R < cutoff_freq

        # Always float32: NumPy's Gaussian branch is float64, which would clash
        # with the float32 FFT outputs on some backends (e.g. TensorFlow).
        return ops.convert_to_tensor(mask.astype(np.float32))

    def forward(self, data):
        """Apply Fourier-domain blurring.

        Args:
            data: Real input tensor of shape (B, H, W, C). Inputs of shape
                (H, W) and (H, W, C) are also accepted and are returned in
                their original shape.

        Returns:
            Blurred data tensor with the same shape and dtype as ``data``.

        Raises:
            ValueError: If ``data`` is not of rank 2, 3 or 4.
        """
        ndim = len(ops.shape(data))

        # Convert to float32 for FFT and bring the input to (B, H, W, C)
        x = ops.cast(data, "float32")
        if ndim == 2:
            x = ops.expand_dims(ops.expand_dims(x, 0), -1)
        elif ndim == 3:
            x = ops.expand_dims(x, 0)
        elif ndim != 4:
            raise ValueError(
                f"Expected data of shape (H, W), (H, W, C) or (B, H, W, C), got rank {ndim}."
            )

        # fft2 calculates the 2D FFT on the last two dims, so move H, W there
        data_real = ops.transpose(x, (0, 3, 1, 2))
        data_imag = ops.zeros_like(data_real)

        # Apply FFT - expects tuple (real, imag), returns tuple (real, imag)
        fft_real, fft_imag = ops.fft2((data_real, data_imag))

        # The mask is real, so it scales real and imaginary parts equally.
        # Cast in case freq_mask was replaced by a non-float32 tensor.
        mask = ops.cast(ops.real(self.freq_mask), "float32")
        masked_fft_real = fft_real * mask
        masked_fft_imag = fft_imag * mask

        # Apply inverse FFT; the imaginary part is ~0 for real input and a
        # symmetric mask, so only the real part is kept.
        blurred_real, _ = ops.ifft2((masked_fft_real, masked_fft_imag))

        # Transpose back to channels-last and restore the caller's rank
        blurred = ops.transpose(blurred_real, (0, 2, 3, 1))
        if ndim == 2:
            blurred = blurred[0, :, :, 0]
        elif ndim == 3:
            blurred = blurred[0]

        return ops.cast(blurred, data.dtype)

    def transpose(self, data):
        """Apply the adjoint A^*, which equals A (the operator is self-adjoint).

        A = F^(-1) M F, where M is a real diagonal mask (M^* = M). For the
        (unnormalized) DFT, F^* = N F^(-1) and (F^(-1))^* = F / N. Hence:

            A^* = (F^(-1) M F)^* = F^* M^* (F^(-1))^* = F^(-1) M F = A

        The transpose is therefore just the forward pass.
        """
        return self.forward(data)

    def __str__(self):
        return f"y = F^(-1)(M * F(x)) filter at {self.cutoff_freq}"