zea 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. zea/__init__.py +54 -19
  2. zea/agent/__init__.py +12 -12
  3. zea/agent/masks.py +2 -1
  4. zea/backend/tensorflow/dataloader.py +2 -1
  5. zea/beamform/beamformer.py +100 -50
  6. zea/beamform/lens_correction.py +9 -2
  7. zea/beamform/pfield.py +9 -2
  8. zea/config.py +34 -25
  9. zea/data/__init__.py +22 -16
  10. zea/data/convert/camus.py +2 -1
  11. zea/data/convert/echonet.py +4 -4
  12. zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +1 -1
  13. zea/data/convert/matlab.py +11 -4
  14. zea/data/data_format.py +31 -30
  15. zea/data/datasets.py +7 -5
  16. zea/data/file.py +104 -2
  17. zea/data/layers.py +3 -3
  18. zea/datapaths.py +16 -4
  19. zea/display.py +7 -5
  20. zea/interface.py +14 -16
  21. zea/internal/_generate_keras_ops.py +6 -7
  22. zea/internal/cache.py +2 -49
  23. zea/internal/config/validation.py +1 -2
  24. zea/internal/core.py +69 -6
  25. zea/internal/device.py +6 -2
  26. zea/internal/dummy_scan.py +330 -0
  27. zea/internal/operators.py +114 -2
  28. zea/internal/parameters.py +101 -70
  29. zea/internal/setup_zea.py +5 -6
  30. zea/internal/utils.py +282 -0
  31. zea/io_lib.py +247 -19
  32. zea/keras_ops.py +74 -4
  33. zea/log.py +9 -7
  34. zea/metrics.py +15 -7
  35. zea/models/__init__.py +30 -20
  36. zea/models/base.py +30 -14
  37. zea/models/carotid_segmenter.py +19 -4
  38. zea/models/diffusion.py +173 -12
  39. zea/models/echonet.py +22 -8
  40. zea/models/echonetlvh.py +31 -7
  41. zea/models/lpips.py +19 -2
  42. zea/models/lv_segmentation.py +28 -11
  43. zea/models/preset_utils.py +5 -5
  44. zea/models/regional_quality.py +30 -10
  45. zea/models/taesd.py +21 -5
  46. zea/models/unet.py +15 -1
  47. zea/ops.py +390 -196
  48. zea/probes.py +6 -6
  49. zea/scan.py +109 -49
  50. zea/simulator.py +24 -21
  51. zea/tensor_ops.py +406 -302
  52. zea/tools/hf.py +1 -1
  53. zea/tools/selection_tool.py +47 -86
  54. zea/utils.py +92 -480
  55. zea/visualize.py +177 -39
  56. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/METADATA +4 -2
  57. zea-0.0.7.dist-info/RECORD +114 -0
  58. zea-0.0.6.dist-info/RECORD +0 -112
  59. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/WHEEL +0 -0
  60. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/entry_points.txt +0 -0
  61. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/licenses/LICENSE +0 -0
zea/ops.py CHANGED
@@ -11,39 +11,62 @@ Operations can be run on their own:
 
 Examples
 ^^^^^^^^
-.. code-block:: python
+.. doctest::
 
-    data = np.random.randn(2000, 128, 1)
-    # static arguments are passed in the constructor
-    envelope_detect = EnvelopeDetect(axis=-1)
-    # other parameters can be passed here along with the data
-    envelope_data = envelope_detect(data=data)
+    >>> import numpy as np
+    >>> from zea.ops import EnvelopeDetect
+    >>> data = np.random.randn(2000, 128, 1)
+    >>> # static arguments are passed in the constructor
+    >>> envelope_detect = EnvelopeDetect(axis=-1)
+    >>> # other parameters can be passed here along with the data
+    >>> envelope_data = envelope_detect(data=data)
 
 Using a pipeline
 ----------------
 
 You can initialize with a default pipeline or create your own custom pipeline.
 
-.. code-block:: python
+.. doctest::
 
-    pipeline = Pipeline.from_default()
+    >>> from zea.ops import Pipeline, EnvelopeDetect, Normalize, LogCompress
+    >>> pipeline = Pipeline.from_default()
 
-    operations = [
-        EnvelopeDetect(),
-        Normalize(),
-        LogCompress(),
-    ]
-    pipeline_custom = Pipeline(operations)
+    >>> operations = [
+    ...     EnvelopeDetect(),
+    ...     Normalize(),
+    ...     LogCompress(),
+    ... ]
+    >>> pipeline_custom = Pipeline(operations)
 
 One can also load a pipeline from a config or yaml/json file:
 
-.. code-block:: python
+.. doctest::
+
+    >>> from zea import Pipeline
+
+    >>> # From JSON string
+    >>> json_string = '{"operations": ["identity"]}'
+    >>> pipeline = Pipeline.from_json(json_string)
+
+    >>> # from yaml file
+    >>> import yaml
+    >>> from zea import Config
+    >>> # Create a sample pipeline YAML file
+    >>> pipeline_dict = {
+    ...     "operations": [
+    ...         {"name": "identity"},
+    ...     ]
+    ... }
+    >>> with open("pipeline.yaml", "w") as f:
+    ...     yaml.dump(pipeline_dict, f)
+    >>> yaml_file = "pipeline.yaml"
+    >>> pipeline = Pipeline.from_yaml(yaml_file)
+
+.. testcleanup::
 
-    json_string = '{"operations": ["identity"]}'
-    pipeline = Pipeline.from_json(json_string)
+    import os
 
-    yaml_file = "pipeline.yaml"
-    pipeline = Pipeline.from_yaml(yaml_file)
+    os.remove("pipeline.yaml")
 
 Example of a yaml file:
 
@@ -56,8 +79,6 @@ Example of a yaml file:
   params:
     operations:
       - name: tof_correction
-        params:
-          apply_phase_rotation: true
      - name: pfield_weighting
      - name: delay_and_sum
    num_patches: 100
@@ -79,7 +100,7 @@ import numpy as np
 import scipy
 import yaml
 from keras import ops
-from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
+from keras.src.layers.preprocessing.data_layer import DataLayer
 
 from zea import log
 from zea.backend import jit
@@ -99,8 +120,12 @@ from zea.internal.registry import ops_registry
 from zea.probes import Probe
 from zea.scan import Scan
 from zea.simulator import simulate_rf
-from zea.tensor_ops import batched_map, patched_map, resample, reshape_axis
-from zea.utils import FunctionTimer, deep_compare, map_negative_indices, translate
+from zea.tensor_ops import resample, reshape_axis, translate, vmap
+from zea.utils import (
+    FunctionTimer,
+    deep_compare,
+    map_negative_indices,
+)
 
 
 def get_ops(ops_name):
@@ -125,6 +150,7 @@ class Operation(keras.Operation):
         with_batch_dim: bool = True,
         jit_kwargs: dict | None = None,
         jittable: bool = True,
+        additional_output_keys: List[str] = None,
         **kwargs,
     ):
         """
@@ -143,6 +169,10 @@ class Operation(keras.Operation):
            with_batch_dim: Whether operations should expect a batch dimension in the input
            jit_kwargs: Additional keyword arguments for the JIT compiler
            jittable: Whether the operation can be JIT compiled
+            additional_output_keys: A list of additional output keys produced by the operation.
+                These are used to track if all keys are available for downstream operations.
+                If the operation has a conditional output, it is best to add all possible
+                output keys here.
         """
         super().__init__(**kwargs)
 
@@ -153,6 +183,7 @@ class Operation(keras.Operation):
         self.output_key = output_key  # Key for output data
         if self.output_key is None:
             self.output_key = self.key
+        self.additional_output_keys = additional_output_keys or []
 
         self.inputs = []  # Source(s) of input data (name of a previous operation)
         self.allow_multiple_inputs = False  # Only single input allowed by default
@@ -165,8 +196,6 @@ class Operation(keras.Operation):
         self._output_cache = {}
 
         # Obtain the input signature of the `call` method
-        self._input_signature = None
-        self._valid_keys = None  # Keys valid for the `call` method
         self._trace_signatures()
 
         if jit_kwargs is None:
@@ -186,6 +215,11 @@ class Operation(keras.Operation):
         with log.set_level("ERROR"):
             self.set_jit(jit_compile)
 
+    @property
+    def output_keys(self) -> List[str]:
+        """Get the output keys of the operation."""
+        return [self.output_key] + self.additional_output_keys
+
     @property
     def static_params(self):
         """Get the static parameters of the operation."""
@@ -207,10 +241,15 @@ class Operation(keras.Operation):
         self._valid_keys = set(self._input_signature.parameters.keys())
 
     @property
-    def valid_keys(self):
+    def valid_keys(self) -> set:
         """Get the valid keys for the `call` method."""
         return self._valid_keys
 
+    @property
+    def needs_keys(self) -> set:
+        """Get a set of all input keys needed by the operation."""
+        return self.valid_keys
+
     @property
     def jittable(self):
         """Check if the operation can be JIT compiled."""
@@ -408,8 +447,6 @@ class Pipeline:
         """
         self._call_pipeline = self.call
         self.name = name
-        self.timer = FunctionTimer()
-        self.timed = timed
 
         self._pipeline_layers = operations
 
@@ -419,6 +456,24 @@ class Pipeline:
         self.with_batch_dim = with_batch_dim
         self._validate_flag = validate
 
+        # Setup timer
+        if jit_options == "pipeline" and timed:
+            raise ValueError(
+                "timed=True cannot be used with jit_options='pipeline' as the entire "
+                "pipeline is compiled into a single function. Try setting jit_options to "
+                "'ops' or None."
+            )
+        if timed:
+            log.warning(
+                "Timer has been initialized for the pipeline. To get an accurate timing estimate, "
+                "the `block_until_ready()` is used, which will slow down the execution, so "
+                "do not use for regular processing!"
+            )
+            self._callable_layers = self._get_timed_operations()
+        else:
+            self._callable_layers = self._pipeline_layers
+        self._timed = timed
+
         if validate:
             self.validate()
         else:
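Note that the timed mode is now resolved once at construction time: every operation is wrapped in a `FunctionTimer` up front, instead of switching between `call` and a separate `timed_call` at dispatch time. A minimal usage sketch, assuming a working zea 0.0.7 install and suitable input data:

    from zea.ops import Pipeline, EnvelopeDetect, Normalize, LogCompress

    # timed=True wraps every operation in the pipeline's FunctionTimer;
    # it is rejected together with jit_options="pipeline", since the whole
    # pipeline would then be compiled into one opaque function
    pipeline = Pipeline(
        [EnvelopeDetect(), Normalize(), LogCompress()],
        timed=True,
        jit_options="ops",
    )
    # after running data through the pipeline, per-operation timings live on
    # pipeline.timer; reset_timer() (further down) starts a fresh FunctionTimer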
@@ -427,19 +482,30 @@ class Pipeline:
         if jit_kwargs is None:
             jit_kwargs = {}
 
-        if keras.backend.backend() == "jax" and self.static_params:
+        if keras.backend.backend() == "jax" and self.static_params != []:
             jit_kwargs = {"static_argnames": self.static_params}
 
         self.jit_kwargs = jit_kwargs
         self.jit_options = jit_options  # will handle the jit compilation
 
     def needs(self, key) -> bool:
-        """Check if the pipeline needs a specific key."""
-        return key in self.valid_keys
+        """Check if the pipeline needs a specific key at the input."""
+        return key in self.needs_keys
+
+    @property
+    def output_keys(self) -> set:
+        """All output keys the pipeline guarantees to produce."""
+        output_keys = set()
+        for operation in self.operations:
+            output_keys.update(operation.output_keys)
+        return output_keys
 
     @property
     def valid_keys(self) -> set:
-        """Get a set of valid keys for the pipeline."""
+        """Get a set of valid keys for the pipeline.
+
+        This is all keys that can be passed to the pipeline as input.
+        """
         valid_keys = set()
         for operation in self.operations:
             valid_keys.update(operation.valid_keys)
@@ -453,8 +519,26 @@ class Pipeline:
            static_params.extend(operation.static_params)
        return list(set(static_params))
 
+    @property
+    def needs_keys(self) -> set:
+        """Get a set of all input keys needed by the pipeline.
+
+        Will keep track of keys that are already provided by previous operations.
+        """
+        needs = set()
+        has_so_far = set()
+        previous_operation = None
+        for operation in self.operations:
+            if previous_operation is not None:
+                has_so_far.update(previous_operation.output_keys)
+            needs.update(operation.needs_keys - has_so_far)
+            previous_operation = operation
+        return needs
+
     @classmethod
-    def from_default(cls, num_patches=100, baseband=False, pfield=False, **kwargs) -> "Pipeline":
+    def from_default(
+        cls, num_patches=100, baseband=False, pfield=False, timed=False, **kwargs
+    ) -> "Pipeline":
         """Create a default pipeline.
 
         Args:
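The `needs_keys` traversal only counts a key as a required pipeline input if no earlier operation already produces it. A standalone sketch of the same accumulation, with mock (needs, produces) pairs in place of real zea operations:

    # each tuple: (keys the op consumes, keys the op produces)
    ops_chain = [
        ({"data", "sampling_frequency"}, {"data"}),   # e.g. a filter stage
        ({"data", "dynamic_range"}, {"image"}),       # e.g. log compression
    ]

    needs, has_so_far = set(), set()
    previous_outputs = None
    for op_needs, op_outputs in ops_chain:
        if previous_outputs is not None:
            has_so_far |= previous_outputs
        needs |= op_needs - has_so_far
        previous_outputs = op_outputs

    print(sorted(needs))  # ['data', 'dynamic_range', 'sampling_frequency']

"dynamic_range" is still needed at the input because the first stage does not produce it, while "image" never appears: it is an output, not an input.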
@@ -465,6 +549,7 @@ class Pipeline:
                so input signal has a single channel dim and is still on carrier frequency.
            pfield (bool): If True, apply Pfield weighting. Defaults to False.
                This will calculate pressure field and only beamform the data to those locations.
+            timed (bool, optional): Whether to time each operation. Defaults to False.
            **kwargs: Additional keyword arguments to be passed to the Pipeline constructor.
 
         """
@@ -476,7 +561,7 @@ class Pipeline:
 
         # Get beamforming ops
         beamforming = [
-            TOFCorrection(apply_phase_rotation=True),
+            TOFCorrection(),
             DelayAndSum(),
         ]
         if pfield:
@@ -495,7 +580,7 @@ class Pipeline:
             Normalize(),
             LogCompress(),
         ]
-        return cls(operations, **kwargs)
+        return cls(operations, timed=timed, **kwargs)
 
     def copy(self) -> "Pipeline":
         """Create a copy of the pipeline."""
@@ -506,57 +591,60 @@ class Pipeline:
             jit_kwargs=self.jit_kwargs,
             name=self.name,
             validate=self._validate_flag,
+            timed=self._timed,
+        )
+
+    def reinitialize(self):
+        """Reinitialize the pipeline in place."""
+        self.__init__(
+            self._pipeline_layers,
+            with_batch_dim=self.with_batch_dim,
+            jit_options=self.jit_options,
+            jit_kwargs=self.jit_kwargs,
+            name=self.name,
+            validate=self._validate_flag,
+            timed=self._timed,
         )
 
     def prepend(self, operation: Operation):
         """Prepend an operation to the pipeline."""
         self._pipeline_layers.insert(0, operation)
-        self.copy()
+        self.reinitialize()
 
     def append(self, operation: Operation):
         """Append an operation to the pipeline."""
         self._pipeline_layers.append(operation)
-        self.copy()
+        self.reinitialize()
 
     def insert(self, index: int, operation: Operation):
         """Insert an operation at a specific index in the pipeline."""
         if index < 0 or index > len(self._pipeline_layers):
             raise IndexError("Index out of bounds for inserting operation.")
         self._pipeline_layers.insert(index, operation)
-        return self.copy()
+        self.reinitialize()
 
     @property
     def operations(self):
         """Alias for self.layers to match the zea naming convention"""
         return self._pipeline_layers
 
-    def timed_call(self, **inputs):
-        """Process input data through the pipeline."""
+    def reset_timer(self):
+        """Reset the timer for timed operations."""
+        if self._timed:
+            self._callable_layers = self._get_timed_operations()
+        else:
+            log.warning(
+                "Timer has not been initialized. Set timed=True when initializing the pipeline."
+            )
 
-        for op in self._pipeline_layers:
-            timed_op = self.timer(op, name=op.__class__.__name__)
-            try:
-                outputs = timed_op(**inputs)
-            except KeyError as exc:
-                raise KeyError(
-                    f"[zea.Pipeline] Operation '{op.__class__.__name__}' "
-                    f"requires input key '{exc.args[0]}', "
-                    "but it was not provided in the inputs.\n"
-                    "Check whether the objects (such as `zea.Scan`) passed to "
-                    "`pipeline.prepare_parameters()` contain all required keys.\n"
-                    f"Current list of all passed keys: {list(inputs.keys())}\n"
-                    f"Valid keys for this pipeline: {self.valid_keys}"
-                ) from exc
-            except Exception as exc:
-                raise RuntimeError(
-                    f"[zea.Pipeline] Error in operation '{op.__class__.__name__}': {exc}"
-                ) from exc
-            inputs = outputs
-        return outputs
+    def _get_timed_operations(self):
+        """Get a list of timed operations."""
+        self.timer = FunctionTimer()
+        return [self.timer(op, name=op.__class__.__name__) for op in self._pipeline_layers]
 
     def call(self, **inputs):
         """Process input data through the pipeline."""
-        for operation in self._pipeline_layers:
+        for operation in self._callable_layers:
            try:
                outputs = operation(**inputs)
            except KeyError as exc:
@@ -579,14 +667,9 @@ class Pipeline:
     def __call__(self, return_numpy=False, **inputs):
         """Process input data through the pipeline."""
 
-        if any(key in inputs for key in ["probe", "scan", "config"]):
-            raise ValueError(
-                "Probe, Scan and Config objects should be first processed with "
-                "`Pipeline.prepare_parameters` before calling the pipeline. "
-                "e.g. inputs = Pipeline.prepare_parameters(probe, scan, config)"
-            )
-
-        if any(isinstance(arg, ZEAObject) for arg in inputs.values()):
+        if any(key in inputs for key in ["probe", "scan", "config"]) or any(
+            isinstance(arg, ZEAObject) for arg in inputs.values()
+        ):
            raise ValueError(
                "Probe, Scan and Config objects should be first processed with "
                "`Pipeline.prepare_parameters` before calling the pipeline. "
@@ -640,18 +723,13 @@ class Pipeline:
             if operation.jittable and operation._jit_compile:
                 operation.set_jit(value == "ops")
 
-    @property
-    def _call_fn(self):
-        """Get the call function of the pipeline."""
-        return self.call if not self.timed else self.timed_call
-
     def jit(self):
         """JIT compile the pipeline."""
-        self._call_pipeline = jit(self._call_fn, **self.jit_kwargs)
+        self._call_pipeline = jit(self.call, **self.jit_kwargs)
 
     def unjit(self):
         """Un-JIT compile the pipeline."""
-        self._call_pipeline = self._call_fn
+        self._call_pipeline = self.call
 
     @property
     def jittable(self):
@@ -835,16 +913,17 @@ class Pipeline:
                Must have a ``pipeline`` key with a subkey ``operations``.
 
        Example:
-            .. code-block:: python
-
-                config = Config(
-                    {
-                        "operations": [
-                            "identity",
-                        ],
-                    }
-                )
-                pipeline = Pipeline.from_config(config)
+            .. doctest::
+
+                >>> from zea import Config, Pipeline
+                >>> config = Config(
+                ...     {
+                ...         "operations": [
+                ...             "identity",
+                ...         ],
+                ...     }
+                ... )
+                >>> pipeline = Pipeline.from_config(config)
        """
        return pipeline_from_config(Config(config), **kwargs)
 
@@ -860,9 +939,20 @@ class Pipeline:
            Must have a `pipeline` key with a subkey `operations`.
 
        Example:
-            ```python
-            pipeline = Pipeline.from_yaml("pipeline.yaml")
-            ```
+            .. doctest::
+
+                >>> import yaml
+                >>> from zea import Config
+                >>> # Create a sample pipeline YAML file
+                >>> pipeline_dict = {
+                ...     "operations": [
+                ...         "identity",
+                ...     ],
+                ... }
+                >>> with open("pipeline.yaml", "w") as f:
+                ...     yaml.dump(pipeline_dict, f)
+                >>> from zea.ops import Pipeline
+                >>> pipeline = Pipeline.from_yaml("pipeline.yaml", jit_options=None)
        """
        return pipeline_from_yaml(file_path, **kwargs)
 
@@ -963,7 +1053,7 @@ class Pipeline:
        assert isinstance(scan, Scan), (
            f"Expected an instance of `zea.scan.Scan`, got {type(scan)}"
        )
-        scan_dict = scan.to_tensor(include=self.valid_keys, keep_as_is=self.static_params)
+        scan_dict = scan.to_tensor(include=self.needs_keys, keep_as_is=self.static_params)
 
        if config is not None:
            assert isinstance(config, Config), (
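With `needs_keys` in place, `prepare_parameters` only pulls the tensors the pipeline will actually consume out of the `Scan` object. A hedged usage sketch, where `probe`, `scan`, `config`, and `raw_data` are assumed to be built elsewhere:

    # probe/scan/config: zea Probe, Scan, and Config instances
    inputs = pipeline.prepare_parameters(probe, scan, config)
    outputs = pipeline(**inputs, data=raw_data)  # "data" key name illustrative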
@@ -1004,15 +1094,17 @@ def make_operation_chain(
        list: List of operations to be performed.
 
    Example:
-        .. code-block:: python
-
-            chain = make_operation_chain(
-                [
-                    "envelope_detect",
-                    {"name": "normalize", "params": {"output_range": (0, 1)}},
-                    SomeCustomOperation(),
-                ]
-            )
+        .. doctest::
+
+            >>> from zea.ops import make_operation_chain, LogCompress
+            >>> SomeCustomOperation = LogCompress  # just for demonstration
+            >>> chain = make_operation_chain(
+            ...     [
+            ...         "envelope_detect",
+            ...         {"name": "normalize", "params": {"output_range": (0, 1)}},
+            ...         SomeCustomOperation(),
+            ...     ]
+            ... )
    """
    chain = []
    for operation in operation_chain:
@@ -1228,11 +1320,23 @@ class PatchedGrid(Pipeline):
        for operation in self.operations:
            operation.with_batch_dim = False
 
+    @property
+    def _extra_keys(self):
+        return {"flatgrid", "grid_size_x", "grid_size_z"}
+
    @property
    def valid_keys(self) -> set:
-        """Get a set of valid keys for the pipeline. Adds the parameters that PatchedGrid itself
-        operates on (even if not used by operations inside it)."""
-        return super().valid_keys.union({"flatgrid", "grid_size_x", "grid_size_z"})
+        """Get a set of valid keys for the pipeline.
+        Adds the parameters that PatchedGrid itself operates on (even if not used by operations
+        inside it)."""
+        return super().valid_keys.union(self._extra_keys)
+
+    @property
+    def needs_keys(self) -> set:
+        """Get a set of all input keys needed by the pipeline.
+        Adds the parameters that PatchedGrid itself operates on (even if not used by operations
+        inside it)."""
+        return super().needs_keys.union(self._extra_keys)
 
    def call_item(self, inputs):
        """Process data in patches."""
@@ -1242,35 +1346,22 @@ class PatchedGrid(Pipeline):
        grid_size_z = inputs["grid_size_z"]
        flatgrid = inputs.pop("flatgrid")
 
-        # TODO: maybe using n_tx and n_el from kwargs is better but these are tensors now
-        # and this is not supported in broadcast_to
-        n_tx = inputs[self.key].shape[0]
-        n_pix = flatgrid.shape[0]
-        n_el = inputs[self.key].shape[2]
-        inputs["rx_apo"] = ops.broadcast_to(inputs.get("rx_apo", 1.0), (n_tx, n_pix, n_el))
-        inputs["rx_apo"] = ops.swapaxes(inputs["rx_apo"], 0, 1)  # put n_pix first
-
        # Define a list of keys to look up for patching
-        patch_keys = ["flat_pfield", "rx_apo"]
+        flat_pfield = inputs.pop("flat_pfield", None)
 
-        patch_arrays = {}
-        for key in patch_keys:
-            if key in inputs:
-                patch_arrays[key] = inputs.pop(key)
-
-        def patched_call(flatgrid, **patch_kwargs):
-            patch_args = {k: v for k, v in patch_kwargs.items() if v is not None}
-            patch_args["rx_apo"] = ops.swapaxes(patch_args["rx_apo"], 0, 1)
-            out = super(PatchedGrid, self).call(flatgrid=flatgrid, **patch_args, **inputs)
+        def patched_call(flatgrid, flat_pfield):
+            out = super(PatchedGrid, self).call(
+                flatgrid=flatgrid, flat_pfield=flat_pfield, **inputs
+            )
            return out[self.output_key]
 
-        out = patched_map(
+        out = vmap(
            patched_call,
-            flatgrid,
-            self.num_patches,
-            **patch_arrays,
-            jit=bool(self.jit_options),
-        )
+            chunks=self.num_patches,
+            fn_supports_batch=True,
+            disable_jit=not bool(self.jit_options),
+        )(flatgrid, flat_pfield)
+
        return ops.reshape(out, (grid_size_z, grid_size_x, *ops.shape(out)[1:]))
 
    def jittable_call(self, **inputs):
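`PatchedGrid` now delegates the patching itself to `zea.tensor_ops.vmap`, which splits the flattened pixel grid into `num_patches` chunks so peak memory stays bounded by one patch of intermediate beamforming tensors. A generic numpy sketch of that chunked-map idea (illustrative only, not zea's `vmap` API):

    import numpy as np

    def chunked_map(fn, x, chunks):
        # split the leading (pixel) axis, apply fn per chunk, re-concatenate;
        # peak memory is bounded by one chunk of intermediates
        return np.concatenate(
            [fn(c) for c in np.array_split(x, chunks, axis=0)], axis=0
        )

    flatgrid = np.random.rand(100_000, 3)  # (n_pix, xyz)
    out = chunked_map(lambda g: g.sum(axis=-1), flatgrid, chunks=100)
    print(out.shape)  # (100000,)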
@@ -1309,7 +1400,7 @@ class Identity(Operation):
 
    def call(self, **kwargs) -> Dict:
        """Returns the input as is."""
-        return kwargs
+        return {}
 
 
 @ops_registry("merge")
@@ -1399,6 +1490,7 @@ class Simulate(Operation):
    def __init__(self, **kwargs):
        super().__init__(
            output_data_type=DataTypes.RAW_DATA,
+            additional_output_keys=["n_ch"],
            **kwargs,
        )
 
@@ -1451,18 +1543,16 @@ class TOFCorrection(Operation):
    STATIC_PARAMS = [
        "f_number",
        "apply_lens_correction",
-        "apply_phase_rotation",
        "grid_size_x",
        "grid_size_z",
    ]
 
-    def __init__(self, apply_phase_rotation=True, **kwargs):
+    def __init__(self, **kwargs):
        super().__init__(
            input_data_type=DataTypes.RAW_DATA,
            output_data_type=DataTypes.ALIGNED_DATA,
            **kwargs,
        )
-        self.apply_phase_rotation = apply_phase_rotation
 
    def call(
        self,
@@ -1477,6 +1567,8 @@ class TOFCorrection(Operation):
        tx_apodizations,
        initial_times,
        probe_geometry,
+        t_peak,
+        tx_waveform_indices,
        apply_lens_correction=None,
        lens_thickness=None,
        lens_sound_speed=None,
@@ -1497,6 +1589,9 @@ class TOFCorrection(Operation):
            tx_apodizations (ops.Tensor): Transmit apodizations
            initial_times (ops.Tensor): Initial times
            probe_geometry (ops.Tensor): Probe element positions
+            t_peak (float): Time to peak of the transmit pulse
+            tx_waveform_indices (ops.Tensor): Index of the transmit waveform for each
+                transmit. (All zero if there is only one waveform)
            apply_lens_correction (bool): Whether to apply lens correction
            lens_thickness (float): Lens thickness
            lens_sound_speed (float): Sound speed in the lens
@@ -1507,29 +1602,30 @@ class TOFCorrection(Operation):
 
        raw_data = kwargs[self.key]
 
-        kwargs = {
+        tof_kwargs = {
            "flatgrid": flatgrid,
-            "sound_speed": sound_speed,
-            "angles": polar_angles,
-            "focus_distances": focus_distances,
-            "sampling_frequency": sampling_frequency,
-            "fnum": f_number,
-            "apply_phase_rotation": self.apply_phase_rotation,
-            "demodulation_frequency": demodulation_frequency,
            "t0_delays": t0_delays,
            "tx_apodizations": tx_apodizations,
-            "initial_times": initial_times,
+            "sound_speed": sound_speed,
            "probe_geometry": probe_geometry,
+            "initial_times": initial_times,
+            "sampling_frequency": sampling_frequency,
+            "demodulation_frequency": demodulation_frequency,
+            "f_number": f_number,
+            "polar_angles": polar_angles,
+            "focus_distances": focus_distances,
+            "t_peak": t_peak,
+            "tx_waveform_indices": tx_waveform_indices,
            "apply_lens_correction": apply_lens_correction,
            "lens_thickness": lens_thickness,
            "lens_sound_speed": lens_sound_speed,
        }
 
        if not self.with_batch_dim:
-            tof_corrected = tof_correction(raw_data, **kwargs)
+            tof_corrected = tof_correction(raw_data, **tof_kwargs)
        else:
            tof_corrected = ops.map(
-                lambda data: tof_correction(data, **kwargs),
+                lambda data: tof_correction(data, **tof_kwargs),
                raw_data,
            )
 
@@ -1587,44 +1683,35 @@ class DelayAndSum(Operation):
        **kwargs,
    ):
        super().__init__(
-            input_data_type=None,
+            input_data_type=DataTypes.ALIGNED_DATA,
            output_data_type=DataTypes.BEAMFORMED_DATA,
            **kwargs,
        )
        self.reshape_grid = reshape_grid
 
-    def process_image(self, data, rx_apo):
+    def process_image(self, data):
        """Performs DAS beamforming on tof-corrected input.
 
        Args:
            data (ops.Tensor): The TOF corrected input of shape `(n_tx, n_pix, n_el, n_ch)`
-            rx_apo (ops.Tensor): Receive apodization window of shape `(n_tx, n_pix, n_el, n_ch)`.
 
        Returns:
            ops.Tensor: The beamformed data of shape `(n_pix, n_ch)`
        """
        # Sum over the channels, i.e. DAS
-        data = ops.sum(rx_apo * data, -2)
+        data = ops.sum(data, -2)
 
        # Sum over transmits, i.e. Compounding
        data = ops.sum(data, 0)
 
        return data
 
-    def call(
-        self,
-        rx_apo=None,
-        grid=None,
-        **kwargs,
-    ):
+    def call(self, grid=None, **kwargs):
        """Performs DAS beamforming on tof-corrected input.
 
        Args:
            tof_corrected_data (ops.Tensor): The TOF corrected input of shape
                `(n_tx, grid_size_z*grid_size_x, n_el, n_ch)` with optional batch dimension.
-            rx_apo (ops.Tensor): Receive apodization window
-                of shape `(n_tx, grid_size_z*grid_size_x, n_el)`
-                with optional batch dimension. Defaults to 1.0.
 
        Returns:
            dict: Dictionary containing beamformed_data
@@ -1634,17 +1721,11 @@ class DelayAndSum(Operation):
        """
        data = kwargs[self.key]
 
-        if rx_apo is None:
-            rx_apo = ops.ones(1, dtype=ops.dtype(data))
-        rx_apo = ops.broadcast_to(rx_apo[..., None], data.shape)
-
        if not self.with_batch_dim:
-            beamformed_data = self.process_image(data, rx_apo)
+            beamformed_data = self.process_image(data)
        else:
            # Apply process_image to each item in the batch
-            beamformed_data = batched_map(
-                lambda data, rx_apo: self.process_image(data, rx_apo), data, rx_apo=rx_apo
-            )
+            beamformed_data = ops.map(self.process_image, data)
 
        if self.reshape_grid:
            beamformed_data = reshape_axis(
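With receive apodization removed, `DelayAndSum` reduces to two plain sums over the TOF-corrected tensor. A numpy sketch of the same reduction, with shapes as in the docstring above:

    import numpy as np

    n_tx, n_pix, n_el, n_ch = 11, 4096, 128, 2
    tof_corrected = np.random.randn(n_tx, n_pix, n_el, n_ch).astype("float32")

    das = tof_corrected.sum(axis=-2)  # sum over elements: delay-and-sum
    das = das.sum(axis=0)             # sum over transmits: compounding
    print(das.shape)                  # (4096, 2), i.e. (n_pix, n_ch)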
@@ -1654,6 +1735,46 @@
        return {self.output_key: beamformed_data}
 
 
+def envelope_detect(data, axis=-3):
+    """Envelope detection of RF signals.
+
+    If the input data is real, it first applies the Hilbert transform along the specified axis
+    and then computes the magnitude of the resulting complex signal.
+    If the input data is complex, it computes the magnitude directly.
+
+    Args:
+        - data (Tensor): The beamformed data of shape (..., grid_size_z, grid_size_x, n_ch).
+        - axis (int): Axis along which to apply the Hilbert transform. Defaults to -3.
+
+    Returns:
+        - envelope_data (Tensor): The envelope detected data
+            of shape (..., grid_size_z, grid_size_x).
+    """
+    if data.shape[-1] == 2:
+        data = channels_to_complex(data)
+    else:
+        n_ax = ops.shape(data)[axis]
+        n_ax_float = ops.cast(n_ax, "float32")
+
+        # Calculate next power of 2: M = 2^ceil(log2(n_ax))
+        # see https://github.com/tue-bmd/zea/discussions/147
+        log2_n_ax = ops.log2(n_ax_float)
+        M = ops.cast(2 ** ops.ceil(log2_n_ax), "int32")
+
+        data = hilbert(data, N=M, axis=axis)
+        indices = ops.arange(n_ax)
+
+        data = ops.take(data, indices, axis=axis)
+        data = ops.squeeze(data, axis=-1)
+
+    # data = ops.abs(data)
+    real = ops.real(data)
+    imag = ops.imag(data)
+    data = ops.sqrt(real**2 + imag**2)
+    data = ops.cast(data, "float32")
+    return data
+
+
 @ops_registry("envelope_detect")
 class EnvelopeDetect(Operation):
    """Envelope detection of RF signals."""
@@ -1680,23 +1801,7 @@ class EnvelopeDetect(Operation):
        """
        data = kwargs[self.key]
 
-        if data.shape[-1] == 2:
-            data = channels_to_complex(data)
-        else:
-            n_ax = data.shape[self.axis]
-            M = 2 ** int(np.ceil(np.log2(n_ax)))
-            # data = scipy.signal.hilbert(data, N=M, axis=self.axis)
-            data = hilbert(data, N=M, axis=self.axis)
-            indices = ops.arange(n_ax)
-
-            data = ops.take(data, indices, axis=self.axis)
-            data = ops.squeeze(data, axis=-1)
-
-        # data = ops.abs(data)
-        real = ops.real(data)
-        imag = ops.imag(data)
-        data = ops.sqrt(real**2 + imag**2)
-        data = ops.cast(data, "float32")
+        data = envelope_detect(data, axis=self.axis)
 
        return {self.output_key: data}
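The next-power-of-2 length keeps the FFT inside `hilbert` fast for awkward sample counts; the first `n_ax` samples are then taken back out, so the padding never leaks into the result. A small numpy sketch of the padding rule used above:

    import numpy as np

    for n_ax in (1024, 2000, 3000):
        M = 2 ** int(np.ceil(np.log2(n_ax)))  # next power of two
        print(n_ax, "->", M)
    # 1024 -> 1024, 2000 -> 2048, 3000 -> 4096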
 
@@ -1734,19 +1839,29 @@ class UpMix(Operation):
        return {self.output_key: data}
 
 
+def log_compress(data, eps=1e-16):
+    """Apply logarithmic compression to data."""
+    eps = ops.convert_to_tensor(eps, dtype=data.dtype)
+    data = ops.where(data == 0, eps, data)  # Avoid log(0)
+    return 20 * keras.ops.log10(data)
+
+
 @ops_registry("log_compress")
 class LogCompress(Operation):
    """Logarithmic compression of data."""
 
-    def __init__(
-        self,
-        **kwargs,
-    ):
+    def __init__(self, clip: bool = True, **kwargs):
+        """Initialize the LogCompress operation.
+
+        Args:
+            clip (bool): Whether to clip the output to a dynamic range. Defaults to True.
+        """
        super().__init__(
            input_data_type=DataTypes.ENVELOPE_DATA,
            output_data_type=DataTypes.IMAGE,
            **kwargs,
        )
+        self.clip = clip
 
    def call(self, dynamic_range=None, **kwargs):
        """Apply logarithmic compression to data.
@@ -1763,20 +1878,43 @@ class LogCompress(Operation):
            dynamic_range = ops.array(DEFAULT_DYNAMIC_RANGE)
        dynamic_range = ops.cast(dynamic_range, data.dtype)
 
-        small_number = ops.convert_to_tensor(1e-16, dtype=data.dtype)
-        data = ops.where(data == 0, small_number, data)
-        compressed_data = 20 * ops.log10(data)
-        compressed_data = ops.clip(compressed_data, dynamic_range[0], dynamic_range[1])
+        compressed_data = log_compress(data)
+        if self.clip:
+            compressed_data = ops.clip(compressed_data, dynamic_range[0], dynamic_range[1])
 
        return {self.output_key: compressed_data}
 
 
+def normalize(data, output_range, input_range=None):
+    """Normalize data to a given range.
+
+    Equivalent to `translate` with clipping.
+
+    Args:
+        data (ops.Tensor): Input data to normalize.
+        output_range (tuple): Range to which data should be mapped, e.g., (0, 1).
+        input_range (tuple, optional): Range of input data.
+            If None, the range will be computed from the data.
+            Defaults to None.
+    """
+    if input_range is None:
+        input_range = (None, None)
+    minval, maxval = input_range
+    if minval is None:
+        minval = ops.min(data)
+    if maxval is None:
+        maxval = ops.max(data)
+    data = ops.clip(data, minval, maxval)
+    normalized_data = translate(data, (minval, maxval), output_range)
+    return normalized_data
+
+
 @ops_registry("normalize")
 class Normalize(Operation):
    """Normalize data to a given range."""
 
    def __init__(self, output_range=None, input_range=None, **kwargs):
-        super().__init__(**kwargs)
+        super().__init__(additional_output_keys=["minval", "maxval"], **kwargs)
        if output_range is None:
            output_range = (0, 1)
        self.output_range = self.to_float32(output_range)
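`log_compress` maps envelope amplitude to decibels via 20 * log10, so a unit-normalized envelope lands in (-inf, 0] dB and each factor of 10 in amplitude spans 20 dB; exact zeros are replaced by `eps` first. A numpy sketch of the same mapping (the clip to `dynamic_range` is what `LogCompress(clip=True)` adds on top):

    import numpy as np

    envelope = np.array([1.0, 0.1, 0.01, 0.0])
    eps = 1e-16
    db = 20 * np.log10(np.where(envelope == 0, eps, envelope))
    print(db)  # [   0.  -20.  -40. -320.]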
@@ -1821,11 +1959,9 @@ class Normalize(Operation):
        if maxval is None:
            maxval = ops.max(data)
 
-        # Clip the data to the input range
-        data = ops.clip(data, minval, maxval)
-
-        # Map the data to the output range
-        normalized_data = translate(data, (minval, maxval), self.output_range)
+        normalized_data = normalize(
+            data, output_range=self.output_range, input_range=(minval, maxval)
+        )
 
        return {self.output_key: normalized_data, "minval": minval, "maxval": maxval}
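`normalize` clips to the input range and then linearly maps it onto `output_range`; when the bounds are not given they come from the data itself, and `Normalize` now also returns them as the `minval`/`maxval` outputs. A numpy sketch of the clip-then-map behavior, where `translate` is assumed to be the linear range mapping from `zea.tensor_ops`:

    import numpy as np

    def translate(x, in_range, out_range):
        # linear range mapping, assumed equivalent to zea.tensor_ops.translate
        (a, b), (c, d) = in_range, out_range
        return (x - a) / (b - a) * (d - c) + c

    data = np.array([-80.0, -30.0, 0.0, 10.0])   # e.g. dB values
    minval, maxval = -60.0, 0.0                  # fixed input range
    out = translate(np.clip(data, minval, maxval), (minval, maxval), (0.0, 1.0))
    print(out)  # [0.  0.5 1.  1. ]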
 
@@ -1855,6 +1991,18 @@ class ScanConvert(Operation):
            input_data_type=DataTypes.IMAGE,
            output_data_type=DataTypes.IMAGE_SC,
            jittable=jittable,
+            additional_output_keys=[
+                "resolution",
+                "x_lim",
+                "y_lim",
+                "z_lim",
+                "rho_range",
+                "theta_range",
+                "phi_range",
+                "d_rho",
+                "d_theta",
+                "d_phi",
+            ],
            **kwargs,
        )
        self.order = order
@@ -2081,6 +2229,7 @@ class Demodulate(Operation):
            input_data_type=DataTypes.RAW_DATA,
            output_data_type=DataTypes.RAW_DATA,
            jittable=True,
+            additional_output_keys=["demodulation_frequency", "center_frequency", "n_ch"],
            **kwargs,
        )
        self.axis = axis
@@ -2145,7 +2294,7 @@ class Lambda(Operation):
 
 
 @ops_registry("pad")
-class Pad(Operation, TFDataLayer):
+class Pad(Operation, DataLayer):
    """Pad layer for padding tensors to a specified shape."""
 
    def __init__(
@@ -2330,6 +2479,7 @@ class Downsample(Operation):
 
    def __init__(self, factor: int = 1, phase: int = 0, axis: int = -3, **kwargs):
        super().__init__(
+            additional_output_keys=["sampling_frequency", "n_ax"],
            **kwargs,
        )
        self.factor = factor
@@ -3095,6 +3245,50 @@ def demodulate(data, center_frequency, sampling_frequency, axis=-3):
    iq_data_signal_complex = analytical_signal * ops.exp(phasor_exponent)
 
    # Split the complex signal into two channels
-    iq_data_two_channel = complex_to_channels(iq_data_signal_complex[..., 0])
+    iq_data_two_channel = complex_to_channels(ops.squeeze(iq_data_signal_complex, axis=-1))
 
    return iq_data_two_channel
+
+
+def compute_time_to_peak_stack(waveforms, center_frequencies, waveform_sampling_frequency=250e6):
+    """Compute the time of the peak of each waveform in a stack of waveforms.
+
+    Args:
+        waveforms (ndarray): The waveforms of shape (n_waveforms, n_samples).
+        center_frequencies (ndarray): The center frequencies of the waveforms in Hz of shape
+            (n_waveforms,) or a scalar if all waveforms have the same center frequency.
+        waveform_sampling_frequency (float): The sampling frequency of the waveforms in Hz.
+
+    Returns:
+        ndarray: The time to peak for each waveform in seconds.
+    """
+    t_peak = []
+    center_frequencies = center_frequencies * ops.ones((waveforms.shape[0],))
+    for waveform, center_frequency in zip(waveforms, center_frequencies):
+        t_peak.append(compute_time_to_peak(waveform, center_frequency, waveform_sampling_frequency))
+    return ops.stack(t_peak)
+
+
+def compute_time_to_peak(waveform, center_frequency, waveform_sampling_frequency=250e6):
+    """Compute the time of the peak of the waveform.
+
+    Args:
+        waveform (ndarray): The waveform of shape (n_samples).
+        center_frequency (float): The center frequency of the waveform in Hz.
+        waveform_sampling_frequency (float): The sampling frequency of the waveform in Hz.
+
+    Returns:
+        float: The time to peak for the waveform in seconds.
+    """
+    n_samples = waveform.shape[0]
+    if n_samples == 0:
+        raise ValueError("Waveform has zero samples.")
+
+    waveforms_iq_complex_channels = demodulate(
+        waveform[..., None], center_frequency, waveform_sampling_frequency, axis=-1
+    )
+    waveforms_iq_complex = channels_to_complex(waveforms_iq_complex_channels)
+    envelope = ops.abs(waveforms_iq_complex)
+    peak_idx = ops.argmax(envelope, axis=-1)
+    t_peak = ops.cast(peak_idx, dtype="float32") / waveform_sampling_frequency
+    return t_peak
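`compute_time_to_peak` locates the peak of the pulse envelope rather than of the raw waveform, which makes the result insensitive to carrier phase; the resulting `t_peak` is what `TOFCorrection.call` now receives as an argument. A scipy-based sketch of the same idea, using `scipy.signal.hilbert` for the envelope instead of zea's demodulate/`channels_to_complex` route:

    import numpy as np
    from scipy.signal import hilbert

    fs = 250e6  # waveform sampling frequency in Hz, as in the default above
    t = np.arange(int(1e-6 * fs)) / fs  # 1 us of samples
    fc = 5e6    # carrier frequency
    # illustrative Gaussian-windowed tone burst centered at 0.5 us
    waveform = np.exp(-((t - 0.5e-6) ** 2) / (2 * (0.1e-6) ** 2)) * np.sin(
        2 * np.pi * fc * t
    )

    envelope = np.abs(hilbert(waveform))
    t_peak = np.argmax(envelope) / fs
    print(t_peak)  # ~5e-07 s: the envelope peak, not a zero crossing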