zea 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. zea/__init__.py +54 -19
  2. zea/agent/__init__.py +12 -12
  3. zea/agent/masks.py +2 -1
  4. zea/backend/tensorflow/dataloader.py +2 -5
  5. zea/beamform/beamformer.py +100 -50
  6. zea/beamform/lens_correction.py +9 -2
  7. zea/beamform/pfield.py +9 -2
  8. zea/beamform/pixelgrid.py +1 -1
  9. zea/config.py +34 -25
  10. zea/data/__init__.py +22 -25
  11. zea/data/augmentations.py +221 -28
  12. zea/data/convert/__init__.py +1 -6
  13. zea/data/convert/__main__.py +123 -0
  14. zea/data/convert/camus.py +101 -40
  15. zea/data/convert/echonet.py +187 -86
  16. zea/data/convert/echonetlvh/README.md +2 -3
  17. zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
  18. zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
  19. zea/data/convert/echonetlvh/precompute_crop.py +43 -64
  20. zea/data/convert/picmus.py +37 -40
  21. zea/data/convert/utils.py +86 -0
  22. zea/data/convert/{matlab.py → verasonics.py} +44 -65
  23. zea/data/data_format.py +155 -34
  24. zea/data/dataloader.py +12 -7
  25. zea/data/datasets.py +112 -71
  26. zea/data/file.py +184 -73
  27. zea/data/file_operations.py +496 -0
  28. zea/data/layers.py +3 -3
  29. zea/data/preset_utils.py +1 -1
  30. zea/datapaths.py +16 -4
  31. zea/display.py +14 -13
  32. zea/interface.py +14 -16
  33. zea/internal/_generate_keras_ops.py +6 -7
  34. zea/internal/cache.py +2 -49
  35. zea/internal/checks.py +6 -12
  36. zea/internal/config/validation.py +1 -2
  37. zea/internal/core.py +69 -6
  38. zea/internal/device.py +6 -2
  39. zea/internal/dummy_scan.py +330 -0
  40. zea/internal/operators.py +118 -2
  41. zea/internal/parameters.py +101 -70
  42. zea/internal/setup_zea.py +5 -6
  43. zea/internal/utils.py +282 -0
  44. zea/io_lib.py +322 -146
  45. zea/keras_ops.py +74 -4
  46. zea/log.py +9 -7
  47. zea/metrics.py +15 -7
  48. zea/models/__init__.py +31 -21
  49. zea/models/base.py +30 -14
  50. zea/models/carotid_segmenter.py +19 -4
  51. zea/models/diffusion.py +235 -23
  52. zea/models/echonet.py +22 -8
  53. zea/models/echonetlvh.py +31 -7
  54. zea/models/lpips.py +19 -2
  55. zea/models/lv_segmentation.py +30 -11
  56. zea/models/preset_utils.py +5 -5
  57. zea/models/regional_quality.py +30 -10
  58. zea/models/taesd.py +21 -5
  59. zea/models/unet.py +15 -1
  60. zea/ops.py +770 -336
  61. zea/probes.py +6 -6
  62. zea/scan.py +121 -51
  63. zea/simulator.py +24 -21
  64. zea/tensor_ops.py +477 -353
  65. zea/tools/fit_scan_cone.py +90 -160
  66. zea/tools/hf.py +1 -1
  67. zea/tools/selection_tool.py +47 -86
  68. zea/tracking/__init__.py +16 -0
  69. zea/tracking/base.py +94 -0
  70. zea/tracking/lucas_kanade.py +474 -0
  71. zea/tracking/segmentation.py +110 -0
  72. zea/utils.py +101 -480
  73. zea/visualize.py +177 -39
  74. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/METADATA +6 -2
  75. zea-0.0.8.dist-info/RECORD +122 -0
  76. zea-0.0.6.dist-info/RECORD +0 -112
  77. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/WHEEL +0 -0
  78. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/entry_points.txt +0 -0
  79. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/licenses/LICENSE +0 -0
zea/ops.py CHANGED
@@ -11,39 +11,62 @@ Operations can be run on their own:
 
 Examples
 ^^^^^^^^
-.. code-block:: python
+.. doctest::
 
-    data = np.random.randn(2000, 128, 1)
-    # static arguments are passed in the constructor
-    envelope_detect = EnvelopeDetect(axis=-1)
-    # other parameters can be passed here along with the data
-    envelope_data = envelope_detect(data=data)
+    >>> import numpy as np
+    >>> from zea.ops import EnvelopeDetect
+    >>> data = np.random.randn(2000, 128, 1)
+    >>> # static arguments are passed in the constructor
+    >>> envelope_detect = EnvelopeDetect(axis=-1)
+    >>> # other parameters can be passed here along with the data
+    >>> envelope_data = envelope_detect(data=data)
 
 Using a pipeline
 ----------------
 
 You can initialize with a default pipeline or create your own custom pipeline.
 
-.. code-block:: python
+.. doctest::
 
-    pipeline = Pipeline.from_default()
+    >>> from zea.ops import Pipeline, EnvelopeDetect, Normalize, LogCompress
+    >>> pipeline = Pipeline.from_default()
 
-    operations = [
-        EnvelopeDetect(),
-        Normalize(),
-        LogCompress(),
-    ]
-    pipeline_custom = Pipeline(operations)
+    >>> operations = [
+    ...     EnvelopeDetect(),
+    ...     Normalize(),
+    ...     LogCompress(),
+    ... ]
+    >>> pipeline_custom = Pipeline(operations)
 
 One can also load a pipeline from a config or yaml/json file:
 
-.. code-block:: python
+.. doctest::
+
+    >>> from zea import Pipeline
+
+    >>> # From JSON string
+    >>> json_string = '{"operations": ["identity"]}'
+    >>> pipeline = Pipeline.from_json(json_string)
+
+    >>> # from yaml file
+    >>> import yaml
+    >>> from zea import Config
+    >>> # Create a sample pipeline YAML file
+    >>> pipeline_dict = {
+    ...     "operations": [
+    ...         {"name": "identity"},
+    ...     ]
+    ... }
+    >>> with open("pipeline.yaml", "w") as f:
+    ...     yaml.dump(pipeline_dict, f)
+    >>> yaml_file = "pipeline.yaml"
+    >>> pipeline = Pipeline.from_yaml(yaml_file)
+
+.. testcleanup::
 
-    json_string = '{"operations": ["identity"]}'
-    pipeline = Pipeline.from_json(json_string)
+    import os
 
-    yaml_file = "pipeline.yaml"
-    pipeline = Pipeline.from_yaml(yaml_file)
+    os.remove("pipeline.yaml")
 
 Example of a yaml file:
 
@@ -56,8 +79,6 @@ Example of a yaml file:
     params:
       operations:
         - name: tof_correction
-          params:
-            apply_phase_rotation: true
        - name: pfield_weighting
        - name: delay_and_sum
     num_patches: 100
@@ -67,10 +88,10 @@ Example of a yaml file:
 
 """
 
-import copy
 import hashlib
 import inspect
 import json
+import uuid
 from functools import partial
 from typing import Any, Dict, List, Union
 
@@ -79,7 +100,7 @@ import numpy as np
 import scipy
 import yaml
 from keras import ops
-from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
+from keras.src.layers.preprocessing.data_layer import DataLayer
 
 from zea import log
 from zea.backend import jit
@@ -99,8 +120,20 @@ from zea.internal.registry import ops_registry
 from zea.probes import Probe
 from zea.scan import Scan
 from zea.simulator import simulate_rf
-from zea.tensor_ops import batched_map, patched_map, resample, reshape_axis
-from zea.utils import FunctionTimer, deep_compare, map_negative_indices, translate
+from zea.tensor_ops import (
+    apply_along_axis,
+    correlate,
+    extend_n_dims,
+    resample,
+    reshape_axis,
+    translate,
+    vmap,
+)
+from zea.utils import (
+    FunctionTimer,
+    deep_compare,
+    map_negative_indices,
+)
 
 
 def get_ops(ops_name):
@@ -125,6 +158,7 @@ class Operation(keras.Operation):
         with_batch_dim: bool = True,
         jit_kwargs: dict | None = None,
         jittable: bool = True,
+        additional_output_keys: List[str] = None,
         **kwargs,
     ):
         """
@@ -143,6 +177,10 @@ class Operation(keras.Operation):
             with_batch_dim: Whether operations should expect a batch dimension in the input
             jit_kwargs: Additional keyword arguments for the JIT compiler
             jittable: Whether the operation can be JIT compiled
+            additional_output_keys: A list of additional output keys produced by the operation.
+                These are used to track if all keys are available for downstream operations.
+                If the operation has a conditional output, it is best to add all possible
+                output keys here.
         """
         super().__init__(**kwargs)
 
@@ -153,6 +191,7 @@ class Operation(keras.Operation):
         self.output_key = output_key  # Key for output data
         if self.output_key is None:
             self.output_key = self.key
+        self.additional_output_keys = additional_output_keys or []
 
         self.inputs = []  # Source(s) of input data (name of a previous operation)
         self.allow_multiple_inputs = False  # Only single input allowed by default
@@ -165,8 +204,6 @@ class Operation(keras.Operation):
         self._output_cache = {}
 
         # Obtain the input signature of the `call` method
-        self._input_signature = None
-        self._valid_keys = None  # Keys valid for the `call` method
         self._trace_signatures()
 
         if jit_kwargs is None:
@@ -186,6 +223,11 @@ class Operation(keras.Operation):
             with log.set_level("ERROR"):
                 self.set_jit(jit_compile)
 
+    @property
+    def output_keys(self) -> List[str]:
+        """Get the output keys of the operation."""
+        return [self.output_key] + self.additional_output_keys
+
     @property
     def static_params(self):
         """Get the static parameters of the operation."""
@@ -204,13 +246,18 @@ class Operation(keras.Operation):
         Analyze and store the input/output signatures of the `call` method.
         """
         self._input_signature = inspect.signature(self.call)
-        self._valid_keys = set(self._input_signature.parameters.keys())
+        self._valid_keys = set(self._input_signature.parameters.keys()) | {self.key}
 
     @property
-    def valid_keys(self):
+    def valid_keys(self) -> set:
         """Get the valid keys for the `call` method."""
         return self._valid_keys
 
+    @property
+    def needs_keys(self) -> set:
+        """Get a set of all input keys needed by the operation."""
+        return self.valid_keys
+
     @property
     def jittable(self):
         """Check if the operation can be JIT compiled."""
@@ -366,6 +413,36 @@
         return True
 
 
+class ImageOperation(Operation):
+    """
+    Base class for image processing operations.
+
+    This class extends the Operation class to provide a common interface
+    for operations that process image data, with shape (batch, height, width, channels)
+    or (height, width, channels) if batch dimension is not present.
+
+    Subclasses should implement the `call` method to define the image processing logic, and call
+    ``super().call(**kwargs)`` to validate the input data shape.
+    """
+
+    def call(self, **kwargs):
+        """
+        Validate input data shape for image operations.
+
+        Args:
+            **kwargs: Keyword arguments containing input data.
+
+        Raises:
+            AssertionError: If input data does not have the expected number of dimensions.
+        """
+        data = kwargs[self.key]
+
+        if self.with_batch_dim:
+            assert ops.ndim(data) == 4, "Input data must have 4 dimensions (b, h, w, c)."
+        else:
+            assert ops.ndim(data) == 3, "Input data must have 3 dimensions (h, w, c)."
+
+
 @ops_registry("pipeline")
 class Pipeline:
     """Pipeline class for processing ultrasound data through a series of operations."""
@@ -408,8 +485,6 @@ class Pipeline:
         """
         self._call_pipeline = self.call
         self.name = name
-        self.timer = FunctionTimer()
-        self.timed = timed
 
         self._pipeline_layers = operations
 
@@ -419,6 +494,24 @@ class Pipeline:
         self.with_batch_dim = with_batch_dim
         self._validate_flag = validate
 
+        # Setup timer
+        if jit_options == "pipeline" and timed:
+            raise ValueError(
+                "timed=True cannot be used with jit_options='pipeline' as the entire "
+                "pipeline is compiled into a single function. Try setting jit_options to "
+                "'ops' or None."
+            )
+        if timed:
+            log.warning(
+                "Timer has been initialized for the pipeline. To get an accurate timing estimate, "
+                "the `block_until_ready()` is used, which will slow down the execution, so "
+                "do not use for regular processing!"
+            )
+            self._callable_layers = self._get_timed_operations()
+        else:
+            self._callable_layers = self._pipeline_layers
+        self._timed = timed
+
         if validate:
             self.validate()
         else:
@@ -427,19 +520,40 @@ class Pipeline:
         if jit_kwargs is None:
             jit_kwargs = {}
 
-        if keras.backend.backend() == "jax" and self.static_params:
+        if keras.backend.backend() == "jax" and self.static_params != []:
             jit_kwargs = {"static_argnames": self.static_params}
 
         self.jit_kwargs = jit_kwargs
         self.jit_options = jit_options  # will handle the jit compilation
 
+        self._logged_difference_keys = False
+
+        # Do not log again for nested pipelines
+        for nested_pipeline in self._nested_pipelines:
+            nested_pipeline._logged_difference_keys = True
+
     def needs(self, key) -> bool:
-        """Check if the pipeline needs a specific key."""
-        return key in self.valid_keys
+        """Check if the pipeline needs a specific key at the input."""
+        return key in self.needs_keys
+
+    @property
+    def _nested_pipelines(self):
+        return [operation for operation in self.operations if isinstance(operation, Pipeline)]
+
+    @property
+    def output_keys(self) -> set:
+        """All output keys the pipeline guarantees to produce."""
+        output_keys = set()
+        for operation in self.operations:
+            output_keys.update(operation.output_keys)
+        return output_keys
 
     @property
     def valid_keys(self) -> set:
-        """Get a set of valid keys for the pipeline."""
+        """Get a set of valid keys for the pipeline.
+
+        This is all keys that can be passed to the pipeline as input.
+        """
         valid_keys = set()
         for operation in self.operations:
             valid_keys.update(operation.valid_keys)
@@ -453,8 +567,26 @@ class Pipeline:
             static_params.extend(operation.static_params)
         return list(set(static_params))
 
+    @property
+    def needs_keys(self) -> set:
+        """Get a set of all input keys needed by the pipeline.
+
+        Will keep track of keys that are already provided by previous operations.
+        """
+        needs = set()
+        has_so_far = set()
+        previous_operation = None
+        for operation in self.operations:
+            if previous_operation is not None:
+                has_so_far.update(previous_operation.output_keys)
+            needs.update(operation.needs_keys - has_so_far)
+            previous_operation = operation
+        return needs
+
     @classmethod
-    def from_default(cls, num_patches=100, baseband=False, pfield=False, **kwargs) -> "Pipeline":
+    def from_default(
+        cls, num_patches=100, baseband=False, pfield=False, timed=False, **kwargs
+    ) -> "Pipeline":
         """Create a default pipeline.
 
         Args:
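Reviewer note on the new `needs_keys` property above: it walks the operations in order and only counts a key as needed if no earlier operation has already produced it. A minimal standalone sketch of that bookkeeping (plain Python over hypothetical objects exposing `needs_keys`/`output_keys` sets; not the zea API itself):

    def pipeline_needs_keys(operations):
        # Keys required from outside, skipping keys produced upstream.
        needs, has_so_far = set(), set()
        previous = None
        for op in operations:
            if previous is not None:
                has_so_far |= previous.output_keys
            needs |= op.needs_keys - has_so_far
            previous = op
        return needs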
@@ -465,6 +597,7 @@ class Pipeline:
                 so input signal has a single channel dim and is still on carrier frequency.
             pfield (bool): If True, apply Pfield weighting. Defaults to False.
                 This will calculate pressure field and only beamform the data to those locations.
+            timed (bool, optional): Whether to time each operation. Defaults to False.
             **kwargs: Additional keyword arguments to be passed to the Pipeline constructor.
 
         """
@@ -476,7 +609,7 @@ class Pipeline:
 
         # Get beamforming ops
         beamforming = [
-            TOFCorrection(apply_phase_rotation=True),
+            TOFCorrection(),
             DelayAndSum(),
         ]
         if pfield:
@@ -491,11 +624,12 @@ class Pipeline:
 
         # Add display ops
         operations += [
+            ReshapeGrid(),
             EnvelopeDetect(),
             Normalize(),
             LogCompress(),
         ]
-        return cls(operations, **kwargs)
+        return cls(operations, timed=timed, **kwargs)
 
     def copy(self) -> "Pipeline":
         """Create a copy of the pipeline."""
@@ -506,57 +640,60 @@ class Pipeline:
             jit_kwargs=self.jit_kwargs,
             name=self.name,
             validate=self._validate_flag,
+            timed=self._timed,
+        )
+
+    def reinitialize(self):
+        """Reinitialize the pipeline in place."""
+        self.__init__(
+            self._pipeline_layers,
+            with_batch_dim=self.with_batch_dim,
+            jit_options=self.jit_options,
+            jit_kwargs=self.jit_kwargs,
+            name=self.name,
+            validate=self._validate_flag,
+            timed=self._timed,
         )
 
     def prepend(self, operation: Operation):
         """Prepend an operation to the pipeline."""
         self._pipeline_layers.insert(0, operation)
-        self.copy()
+        self.reinitialize()
 
     def append(self, operation: Operation):
         """Append an operation to the pipeline."""
         self._pipeline_layers.append(operation)
-        self.copy()
+        self.reinitialize()
 
     def insert(self, index: int, operation: Operation):
         """Insert an operation at a specific index in the pipeline."""
         if index < 0 or index > len(self._pipeline_layers):
             raise IndexError("Index out of bounds for inserting operation.")
         self._pipeline_layers.insert(index, operation)
-        return self.copy()
+        self.reinitialize()
 
     @property
     def operations(self):
         """Alias for self.layers to match the zea naming convention"""
         return self._pipeline_layers
 
-    def timed_call(self, **inputs):
-        """Process input data through the pipeline."""
+    def reset_timer(self):
+        """Reset the timer for timed operations."""
+        if self._timed:
+            self._callable_layers = self._get_timed_operations()
+        else:
+            log.warning(
+                "Timer has not been initialized. Set timed=True when initializing the pipeline."
+            )
 
-        for op in self._pipeline_layers:
-            timed_op = self.timer(op, name=op.__class__.__name__)
-            try:
-                outputs = timed_op(**inputs)
-            except KeyError as exc:
-                raise KeyError(
-                    f"[zea.Pipeline] Operation '{op.__class__.__name__}' "
-                    f"requires input key '{exc.args[0]}', "
-                    "but it was not provided in the inputs.\n"
-                    "Check whether the objects (such as `zea.Scan`) passed to "
-                    "`pipeline.prepare_parameters()` contain all required keys.\n"
-                    f"Current list of all passed keys: {list(inputs.keys())}\n"
-                    f"Valid keys for this pipeline: {self.valid_keys}"
-                ) from exc
-            except Exception as exc:
-                raise RuntimeError(
-                    f"[zea.Pipeline] Error in operation '{op.__class__.__name__}': {exc}"
-                ) from exc
-            inputs = outputs
-        return outputs
+    def _get_timed_operations(self):
+        """Get a list of timed operations."""
+        self.timer = FunctionTimer()
+        return [self.timer(op, name=op.__class__.__name__) for op in self._pipeline_layers]
 
     def call(self, **inputs):
         """Process input data through the pipeline."""
-        for operation in self._pipeline_layers:
+        for operation in self._callable_layers:
             try:
                 outputs = operation(**inputs)
             except KeyError as exc:
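With this rework, per-operation timing becomes a constructor option instead of the separate `timed_call` path. A hedged usage sketch based only on what this diff shows (`timed=True` wraps each operation in a FunctionTimer, and the new check rejects it together with jit_options='pipeline'):

    from zea.ops import Pipeline

    pipeline = Pipeline.from_default(timed=True, jit_options="ops")
    # ...run the pipeline, inspect pipeline.timer, then clear timings:
    pipeline.reset_timer()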
@@ -579,14 +716,9 @@ class Pipeline:
     def __call__(self, return_numpy=False, **inputs):
         """Process input data through the pipeline."""
 
-        if any(key in inputs for key in ["probe", "scan", "config"]):
-            raise ValueError(
-                "Probe, Scan and Config objects should be first processed with "
-                "`Pipeline.prepare_parameters` before calling the pipeline. "
-                "e.g. inputs = Pipeline.prepare_parameters(probe, scan, config)"
-            )
-
-        if any(isinstance(arg, ZEAObject) for arg in inputs.values()):
+        if any(key in inputs for key in ["probe", "scan", "config"]) or any(
+            isinstance(arg, ZEAObject) for arg in inputs.values()
+        ):
             raise ValueError(
                 "Probe, Scan and Config objects should be first processed with "
                 "`Pipeline.prepare_parameters` before calling the pipeline. "
@@ -599,6 +731,16 @@ class Pipeline:
                 "Please ensure all inputs are convertible to tensors."
             )
 
+        if not self._logged_difference_keys:
+            difference_keys = set(inputs.keys()) - self.valid_keys
+            if difference_keys:
+                log.debug(
+                    f"[zea.Pipeline] The following input keys are not used by the pipeline: "
+                    f"{difference_keys}. Make sure this is intended. "
+                    "This warning will only be shown once."
+                )
+            self._logged_difference_keys = True
+
         ## PROCESSING
         outputs = self._call_pipeline(**inputs)
 
@@ -640,18 +782,13 @@ class Pipeline:
             if operation.jittable and operation._jit_compile:
                 operation.set_jit(value == "ops")
 
-    @property
-    def _call_fn(self):
-        """Get the call function of the pipeline."""
-        return self.call if not self.timed else self.timed_call
-
     def jit(self):
         """JIT compile the pipeline."""
-        self._call_pipeline = jit(self._call_fn, **self.jit_kwargs)
+        self._call_pipeline = jit(self.call, **self.jit_kwargs)
 
     def unjit(self):
         """Un-JIT compile the pipeline."""
-        self._call_pipeline = self._call_fn
+        self._call_pipeline = self.call
 
     @property
     def jittable(self):
@@ -726,59 +863,9 @@ class Pipeline:
         return params
 
     def __str__(self):
-        """String representation of the pipeline.
-
-        Will print on two parallel pipeline lines if it detects a splitting operations
-        (such as multi_bandpass_filter)
-        Will merge the pipeline lines if it detects a stacking operation (such as stack)
-        """
-        split_operations = []
-        merge_operations = ["Stack"]
-
+        """String representation of the pipeline."""
         operations = [operation.__class__.__name__ for operation in self.operations]
         string = " -> ".join(operations)
-
-        if any(operation in split_operations for operation in operations):
-            # a second line is needed with same length as the first line
-            split_line = " " * len(string)
-            # find the splitting operation and index and print \-> instead of -> after
-            split_detected = False
-            merge_detected = False
-            split_operation = None
-            for operation in operations:
-                if operation in split_operations:
-                    index = string.index(operation)
-                    index = index + len(operation)
-                    split_line = split_line[:index] + "\\->" + split_line[index + len("\\->") :]
-                    split_detected = True
-                    merge_detected = False
-                    split_operation = operation
-                    continue
-
-                if operation in merge_operations:
-                    index = string.index(operation)
-                    index = index - 4
-                    split_line = split_line[:index] + "/" + split_line[index + 1 :]
-                    split_detected = False
-                    merge_detected = True
-                    continue
-
-                if split_detected:
-                    # print all operations in the second line
-                    index = string.index(operation)
-                    split_line = (
-                        split_line[:index]
-                        + operation
-                        + " -> "
-                        + split_line[index + len(operation) + len(" -> ") :]
-                    )
-            assert merge_detected is True, log.error(
-                "Pipeline was never merged back together (with Stack operation), even "
-                f"though it was split with {split_operation}. "
-                "Please properly define your operation chain."
-            )
-            return f"\n{string}\n{split_line}\n"
-
         return string
 
     def __repr__(self):
@@ -835,16 +922,17 @@ class Pipeline:
             Must have a ``pipeline`` key with a subkey ``operations``.
 
         Example:
-            .. code-block:: python
-
-                config = Config(
-                    {
-                        "operations": [
-                            "identity",
-                        ],
-                    }
-                )
-                pipeline = Pipeline.from_config(config)
+            .. doctest::
+
+                >>> from zea import Config, Pipeline
+                >>> config = Config(
+                ...     {
+                ...         "operations": [
+                ...             "identity",
+                ...         ],
+                ...     }
+                ... )
+                >>> pipeline = Pipeline.from_config(config)
         """
         return pipeline_from_config(Config(config), **kwargs)
 
@@ -860,9 +948,20 @@ class Pipeline:
 
             Must have the a `pipeline` key with a subkey `operations`.
 
        Example:
-            ```python
-            pipeline = Pipeline.from_yaml("pipeline.yaml")
-            ```
+            .. doctest::
+
+                >>> import yaml
+                >>> from zea import Config
+                >>> # Create a sample pipeline YAML file
+                >>> pipeline_dict = {
+                ...     "operations": [
+                ...         "identity",
+                ...     ],
+                ... }
+                >>> with open("pipeline.yaml", "w") as f:
+                ...     yaml.dump(pipeline_dict, f)
+                >>> from zea.ops import Pipeline
+                >>> pipeline = Pipeline.from_yaml("pipeline.yaml", jit_options=None)
         """
         return pipeline_from_yaml(file_path, **kwargs)
 
@@ -963,7 +1062,7 @@ class Pipeline:
         assert isinstance(scan, Scan), (
             f"Expected an instance of `zea.scan.Scan`, got {type(scan)}"
         )
-        scan_dict = scan.to_tensor(include=self.valid_keys, keep_as_is=self.static_params)
+        scan_dict = scan.to_tensor(include=self.needs_keys, keep_as_is=self.static_params)
 
         if config is not None:
             assert isinstance(config, Config), (
@@ -1004,15 +1103,17 @@ def make_operation_chain(
         list: List of operations to be performed.
 
     Example:
-        .. code-block:: python
-
-            chain = make_operation_chain(
-                [
-                    "envelope_detect",
-                    {"name": "normalize", "params": {"output_range": (0, 1)}},
-                    SomeCustomOperation(),
-                ]
-            )
+        .. doctest::
+
+            >>> from zea.ops import make_operation_chain, LogCompress
+            >>> SomeCustomOperation = LogCompress  # just for demonstration
+            >>> chain = make_operation_chain(
+            ...     [
+            ...         "envelope_detect",
+            ...         {"name": "normalize", "params": {"output_range": (0, 1)}},
+            ...         SomeCustomOperation(),
+            ...     ]
+            ... )
     """
     chain = []
     for operation in operation_chain:
@@ -1093,7 +1194,7 @@ def pipeline_from_config(config: Config, **kwargs) -> Pipeline:
     operations = make_operation_chain(config.operations)
 
     # merge pipeline config without operations with kwargs
-    pipeline_config = copy.deepcopy(config)
+    pipeline_config = config.copy()
     pipeline_config.pop("operations")
 
     kwargs = {**pipeline_config, **kwargs}
@@ -1164,33 +1265,134 @@ def pipeline_to_yaml(pipeline: Pipeline, file_path: str) -> None:
         yaml.dump(pipeline_dict, f, Dumper=yaml.Dumper, indent=4)
 
 
-@ops_registry("patched_grid")
-class PatchedGrid(Pipeline):
+@ops_registry("map")
+class Map(Pipeline):
     """
-    With this class you can form a pipeline that will be applied to patches of the grid.
-    This is useful to avoid OOM errors when processing large grids.
-
-    Some things to NOTE about this class:
-
-    - The ops have to use flatgrid and flat_pfield as inputs, these will be patched.
+    A pipeline that maps its operations over specified input arguments.
 
-    - Changing anything other than `self.output_data_type` in the dict will not be propagated!
+    This can be used to reduce memory usage by processing data in chunks.
 
+    Notes
+    -----
+    - When `chunks` and `batch_size` are both None (default), this behaves like a normal Pipeline.
+    - Changing anything other than ``self.output_key`` in the dict will not be propagated.
     - Will be jitted as a single operation, not the individual operations.
-
     - This class handles the batching.
 
+    For more information on how to use ``in_axes``, ``out_axes``, `see the documentation for
+    jax.vmap <https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html>`_.
+
+    Example
+    -------
+    .. doctest::
+
+        >>> from zea.ops import Map, Pipeline, Demodulate, TOFCorrection
+
+        >>> # apply operations in batches of 8
+        >>> # in this case, over the first axis of "data"
+        >>> # or more specifically, process 8 transmits at a time
+
+        >>> pipeline_mapped = Map(
+        ...     [
+        ...         Demodulate(),
+        ...         TOFCorrection(),
+        ...     ],
+        ...     argnames="data",
+        ...     batch_size=8,
+        ... )
+
+        >>> # you can also map a subset of the operations
+        >>> # for example, demodulate in 4 chunks
+        >>> # or more specifically, split the transmit axis into 4 parts
+
+        >>> pipeline_mapped = Pipeline(
+        ...     [
+        ...         Map([Demodulate()], argnames="data", chunks=4),
+        ...         TOFCorrection(),
+        ...     ],
+        ... )
     """
 
-    def __init__(self, *args, num_patches=10, **kwargs):
-        super().__init__(*args, name="patched_grid", **kwargs)
-        self.num_patches = num_patches
+    def __init__(
+        self,
+        operations: List[Operation],
+        argnames: List[str] | str,
+        in_axes: List[Union[int, None]] | int = 0,
+        out_axes: List[Union[int, None]] | int = 0,
+        chunks: int | None = None,
+        batch_size: int | None = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            operations (list): List of operations to be performed.
+            argnames (str or list): List of argument names (or keys) to map over.
+                Can also be a single string if only one argument is mapped over.
+            in_axes (int or list): Axes to map over for each argument.
+                If a single int is provided, it is used for all arguments.
+            out_axes (int or list): Axes to map over for each output.
+                If a single int is provided, it is used for all outputs.
+            chunks (int, optional): Number of chunks to split the input data into.
+                If None, no chunking is performed. Mutually exclusive with ``batch_size``.
+            batch_size (int, optional): Size of batches to process at once.
+                If None, no batching is performed. Mutually exclusive with ``chunks``.
+        """
+        super().__init__(operations, **kwargs)
 
-        for operation in self.operations:
-            if isinstance(operation, DelayAndSum):
-                operation.reshape_grid = False
+        if batch_size is not None and chunks is not None:
+            raise ValueError(
+                "batch_size and chunks are mutually exclusive. Please specify only one."
+            )
 
-        self._jittable_call = self.jittable_call
+        if batch_size is not None and batch_size <= 0:
+            raise ValueError("batch_size must be a positive integer.")
+
+        if chunks is not None and chunks <= 0:
+            raise ValueError("chunks must be a positive integer.")
+
+        if isinstance(argnames, str):
+            argnames = [argnames]
+
+        self.argnames = argnames
+        self.in_axes = in_axes
+        self.out_axes = out_axes
+        self.chunks = chunks
+        self.batch_size = batch_size
+
+        if chunks is None and batch_size is None:
+            log.warning(
+                "[zea.ops.Map] Both `chunks` and `batch_size` are None. "
+                "This will behave like a normal Pipeline. "
+                "Consider setting one of them to process data in chunks or batches."
+            )
+
+        def call_item(**inputs):
+            """Process data in patches."""
+            mapped_args = []
+            for argname in argnames:
+                mapped_args.append(inputs.pop(argname, None))
+
+            def patched_call(*args):
+                mapped_kwargs = [(k, v) for k, v in zip(argnames, args)]
+                out = super(Map, self).call(**dict(mapped_kwargs), **inputs)
+
+                # TODO: maybe it is possible to output everything?
+                # e.g. prepend a empty dimension to all inputs and just map over everything?
+                return out[self.output_key]
+
+            out = vmap(
+                patched_call,
+                in_axes=in_axes,
+                out_axes=out_axes,
+                chunks=chunks,
+                batch_size=batch_size,
+                fn_supports_batch=True,
+                disable_jit=not bool(self.jit_options),
+            )(*mapped_args)
+
+            return out
+
+        self.call_item = call_item
 
     @property
     def jit_options(self):
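For intuition on the mutually exclusive `chunks` and `batch_size` arguments of Map: a plain-Python sketch of how each slices the mapped axis (assuming, for simplicity, that the axis length divides evenly):

    n = 32                                      # length of the mapped axis
    chunks = 4                                  # -> 4 slices of 8 elements each
    sizes_from_chunks = [n // chunks] * chunks  # [8, 8, 8, 8]
    batch_size = 8                              # -> walk the axis 8 at a time
    starts_from_batch = list(range(0, n, batch_size))  # [0, 8, 16, 24]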
@@ -1230,59 +1432,28 @@ class PatchedGrid(Pipeline):
 
     @property
     def valid_keys(self) -> set:
-        """Get a set of valid keys for the pipeline. Adds the parameters that PatchedGrid itself
-        operates on (even if not used by operations inside it)."""
-        return super().valid_keys.union({"flatgrid", "grid_size_x", "grid_size_z"})
-
-    def call_item(self, inputs):
-        """Process data in patches."""
-        # Extract necessary parameters
-        # make sure to add those as valid keys above!
-        grid_size_x = inputs["grid_size_x"]
-        grid_size_z = inputs["grid_size_z"]
-        flatgrid = inputs.pop("flatgrid")
-
-        # TODO: maybe using n_tx and n_el from kwargs is better but these are tensors now
-        # and this is not supported in broadcast_to
-        n_tx = inputs[self.key].shape[0]
-        n_pix = flatgrid.shape[0]
-        n_el = inputs[self.key].shape[2]
-        inputs["rx_apo"] = ops.broadcast_to(inputs.get("rx_apo", 1.0), (n_tx, n_pix, n_el))
-        inputs["rx_apo"] = ops.swapaxes(inputs["rx_apo"], 0, 1)  # put n_pix first
-
-        # Define a list of keys to look up for patching
-        patch_keys = ["flat_pfield", "rx_apo"]
-
-        patch_arrays = {}
-        for key in patch_keys:
-            if key in inputs:
-                patch_arrays[key] = inputs.pop(key)
-
-        def patched_call(flatgrid, **patch_kwargs):
-            patch_args = {k: v for k, v in patch_kwargs.items() if v is not None}
-            patch_args["rx_apo"] = ops.swapaxes(patch_args["rx_apo"], 0, 1)
-            out = super(PatchedGrid, self).call(flatgrid=flatgrid, **patch_args, **inputs)
-            return out[self.output_key]
-
-        out = patched_map(
-            patched_call,
-            flatgrid,
-            self.num_patches,
-            **patch_arrays,
-            jit=bool(self.jit_options),
-        )
-        return ops.reshape(out, (grid_size_z, grid_size_x, *ops.shape(out)[1:]))
+        """Get a set of valid keys for the pipeline.
+        Adds the parameters that PatchedGrid itself operates on (even if not used by operations
+        inside it)."""
+        return super().valid_keys.union(self.argnames)
+
+    @property
+    def needs_keys(self) -> set:
+        """Get a set of all input keys needed by the pipeline.
+        Adds the parameters that PatchedGrid itself operates on (even if not used by operations
+        inside it)."""
+        return super().needs_keys.union(self.argnames)
 
     def jittable_call(self, **inputs):
         """Process input data through the pipeline."""
         if self._with_batch_dim:
             input_data = inputs.pop(self.key)
             output = ops.map(
-                lambda x: self.call_item({self.key: x, **inputs}),
+                lambda x: self.call_item(**{self.key: x, **inputs}),
                 input_data,
             )
         else:
-            output = self.call_item(inputs)
+            output = self.call_item(**inputs)
 
         return {self.output_key: output}
 
@@ -1295,11 +1466,61 @@ class PatchedGrid(Pipeline):
 
     def get_dict(self):
         """Get the configuration of the pipeline."""
         config = super().get_dict()
-        config.update({"name": "patched_grid"})
+        config["params"].update(
+            {
+                "argnames": self.argnames,
+                "in_axes": self.in_axes,
+                "out_axes": self.out_axes,
+                "chunks": self.chunks,
+                "batch_size": self.batch_size,
+            }
+        )
+        return config
+
+
+@ops_registry("patched_grid")
+class PatchedGrid(Map):
+    """
+    A pipeline that maps its operations over `flatgrid` and `flat_pfield` keys.
+
+    This can be used to reduce memory usage by processing data in chunks.
+
+    For more information and flexibility, see :class:`zea.ops.Map`.
+    """
+
+    def __init__(self, *args, num_patches=10, **kwargs):
+        super().__init__(*args, argnames=["flatgrid", "flat_pfield"], chunks=num_patches, **kwargs)
+        self.num_patches = num_patches
+
+    def get_dict(self):
+        """Get the configuration of the pipeline."""
+        config = super().get_dict()
+        config["params"].pop("argnames")
+        config["params"].pop("chunks")
         config["params"].update({"num_patches": self.num_patches})
         return config
 
 
+@ops_registry("reshape_grid")
+class ReshapeGrid(Operation):
+    """Reshape flat grid data to grid shape."""
+
+    def __init__(self, axis=0, **kwargs):
+        super().__init__(**kwargs)
+        self.axis = axis
+
+    def call(self, grid, **kwargs):
+        """
+        Args:
+            - data (Tensor): The flat grid data of shape (..., n_pix, ...).
+        Returns:
+            - reshaped_data (Tensor): The reshaped data of shape (..., grid.shape, ...).
+        """
+        data = kwargs[self.key]
+        reshaped_data = reshape_axis(data, grid.shape[:-1], self.axis + int(self.with_batch_dim))
+        return {self.output_key: reshaped_data}
+
+
 ## Base Operations
 
 
@@ -1309,7 +1530,7 @@ class Identity(Operation):
 
     def call(self, **kwargs) -> Dict:
         """Returns the input as is."""
-        return kwargs
+        return {}
 
 
 @ops_registry("merge")
@@ -1332,21 +1553,6 @@ class Merge(Operation):
         return merged
 
 
-@ops_registry("split")
-class Split(Operation):
-    """Operation that splits an input dictionary n copies."""
-
-    def __init__(self, n: int, **kwargs):
-        super().__init__(**kwargs)
-        self.n = n
-
-    def call(self, **kwargs) -> List[Dict]:
-        """
-        Splits the input dictionary into n copies.
-        """
-        return [kwargs.copy() for _ in range(self.n)]
-
-
 @ops_registry("stack")
 class Stack(Operation):
     """Stack multiple data arrays along a new axis.
@@ -1399,6 +1605,7 @@ class Simulate(Operation):
     def __init__(self, **kwargs):
         super().__init__(
             output_data_type=DataTypes.RAW_DATA,
+            additional_output_keys=["n_ch"],
             **kwargs,
         )
 
@@ -1451,18 +1658,16 @@ class TOFCorrection(Operation):
     STATIC_PARAMS = [
         "f_number",
         "apply_lens_correction",
-        "apply_phase_rotation",
         "grid_size_x",
        "grid_size_z",
    ]
 
-    def __init__(self, apply_phase_rotation=True, **kwargs):
+    def __init__(self, **kwargs):
         super().__init__(
             input_data_type=DataTypes.RAW_DATA,
             output_data_type=DataTypes.ALIGNED_DATA,
             **kwargs,
         )
-        self.apply_phase_rotation = apply_phase_rotation
 
     def call(
         self,
@@ -1477,6 +1682,8 @@ class TOFCorrection(Operation):
         tx_apodizations,
         initial_times,
         probe_geometry,
+        t_peak,
+        tx_waveform_indices,
         apply_lens_correction=None,
         lens_thickness=None,
         lens_sound_speed=None,
@@ -1497,6 +1704,9 @@ class TOFCorrection(Operation):
             tx_apodizations (ops.Tensor): Transmit apodizations
             initial_times (ops.Tensor): Initial times
             probe_geometry (ops.Tensor): Probe element positions
+            t_peak (float): Time to peak of the transmit pulse
+            tx_waveform_indices (ops.Tensor): Index of the transmit waveform for each
+                transmit. (All zero if there is only one waveform)
             apply_lens_correction (bool): Whether to apply lens correction
             lens_thickness (float): Lens thickness
             lens_sound_speed (float): Sound speed in the lens
@@ -1507,29 +1717,30 @@ class TOFCorrection(Operation):
 
         raw_data = kwargs[self.key]
 
-        kwargs = {
+        tof_kwargs = {
             "flatgrid": flatgrid,
-            "sound_speed": sound_speed,
-            "angles": polar_angles,
-            "focus_distances": focus_distances,
-            "sampling_frequency": sampling_frequency,
-            "fnum": f_number,
-            "apply_phase_rotation": self.apply_phase_rotation,
-            "demodulation_frequency": demodulation_frequency,
             "t0_delays": t0_delays,
             "tx_apodizations": tx_apodizations,
-            "initial_times": initial_times,
+            "sound_speed": sound_speed,
             "probe_geometry": probe_geometry,
+            "initial_times": initial_times,
+            "sampling_frequency": sampling_frequency,
+            "demodulation_frequency": demodulation_frequency,
+            "f_number": f_number,
+            "polar_angles": polar_angles,
+            "focus_distances": focus_distances,
+            "t_peak": t_peak,
+            "tx_waveform_indices": tx_waveform_indices,
             "apply_lens_correction": apply_lens_correction,
             "lens_thickness": lens_thickness,
             "lens_sound_speed": lens_sound_speed,
         }
 
         if not self.with_batch_dim:
-            tof_corrected = tof_correction(raw_data, **kwargs)
+            tof_corrected = tof_correction(raw_data, **tof_kwargs)
         else:
             tof_corrected = ops.map(
-                lambda data: tof_correction(data, **kwargs),
+                lambda data: tof_correction(data, **tof_kwargs),
                 raw_data,
             )
 
@@ -1556,7 +1767,7 @@ class PfieldWeighting(Operation):
         Returns:
             dict: Dictionary containing weighted data
         """
-        data = kwargs[self.key]
+        data = kwargs[self.key]  # must start with ((batch_size,) n_tx, n_pix, ...)
 
         if flat_pfield is None:
             return {self.output_key: data}
@@ -1564,14 +1775,16 @@
         # Swap (n_pix, n_tx) to (n_tx, n_pix)
         flat_pfield = ops.swapaxes(flat_pfield, 0, 1)
 
-        # Perform element-wise multiplication with the pressure weight mask
-        # Also add the required dimensions for broadcasting
+        # Add batch dimension if needed
         if self.with_batch_dim:
             pfield_expanded = ops.expand_dims(flat_pfield, axis=0)
         else:
             pfield_expanded = flat_pfield
 
-        pfield_expanded = pfield_expanded[..., None, None]
+        append_n_dims = ops.ndim(data) - ops.ndim(pfield_expanded)
+        pfield_expanded = extend_n_dims(pfield_expanded, axis=-1, n_dims=append_n_dims)
+
+        # Perform element-wise multiplication with the pressure weight mask
         weighted_data = data * pfield_expanded
 
         return {self.output_key: weighted_data}
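The broadcasting fix above appends exactly as many trailing singleton dimensions as the data requires, instead of hard-coding two. In numpy terms (a sketch of what `extend_n_dims(x, axis=-1, n_dims=k)` is assumed to do here):

    import numpy as np

    def extend_n_dims_np(x, n_dims):
        # Append n_dims trailing singleton axes: (8, 4096) -> (8, 4096, 1, 1)
        return x.reshape(x.shape + (1,) * n_dims)

    pfield = np.ones((8, 4096))        # (n_tx, n_pix)
    data = np.ones((8, 4096, 128, 2))  # (n_tx, n_pix, n_el, n_ch)
    k = data.ndim - pfield.ndim
    weighted = data * extend_n_dims_np(pfield, k)  # broadcasts cleanly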
@@ -1581,79 +1794,93 @@
 class DelayAndSum(Operation):
     """Sums time-delayed signals along channels and transmits."""
 
-    def __init__(
-        self,
-        reshape_grid=True,
-        **kwargs,
-    ):
+    def __init__(self, **kwargs):
         super().__init__(
-            input_data_type=None,
+            input_data_type=DataTypes.ALIGNED_DATA,
             output_data_type=DataTypes.BEAMFORMED_DATA,
             **kwargs,
         )
-        self.reshape_grid = reshape_grid
 
-    def process_image(self, data, rx_apo):
+    def process_image(self, data):
         """Performs DAS beamforming on tof-corrected input.
 
         Args:
             data (ops.Tensor): The TOF corrected input of shape `(n_tx, n_pix, n_el, n_ch)`
-            rx_apo (ops.Tensor): Receive apodization window of shape `(n_tx, n_pix, n_el, n_ch)`.
 
         Returns:
             ops.Tensor: The beamformed data of shape `(n_pix, n_ch)`
         """
         # Sum over the channels, i.e. DAS
-        data = ops.sum(rx_apo * data, -2)
+        data = ops.sum(data, -2)
 
         # Sum over transmits, i.e. Compounding
         data = ops.sum(data, 0)
 
         return data
 
-    def call(
-        self,
-        rx_apo=None,
-        grid=None,
-        **kwargs,
-    ):
+    def call(self, grid=None, **kwargs):
         """Performs DAS beamforming on tof-corrected input.
 
         Args:
             tof_corrected_data (ops.Tensor): The TOF corrected input of shape
                 `(n_tx, grid_size_z*grid_size_x, n_el, n_ch)` with optional batch dimension.
-            rx_apo (ops.Tensor): Receive apodization window
-                of shape `(n_tx, grid_size_z*grid_size_x, n_el)`
-                with optional batch dimension. Defaults to 1.0.
 
         Returns:
             dict: Dictionary containing beamformed_data
-                of shape `(grid_size_z*grid_size_x, n_ch)` when reshape_grid is False
-                or `(grid_size_z, grid_size_x, n_ch)` when reshape_grid is True,
+                of shape `(grid_size_z*grid_size_x, n_ch)`
                 with optional batch dimension.
         """
         data = kwargs[self.key]
 
-        if rx_apo is None:
-            rx_apo = ops.ones(1, dtype=ops.dtype(data))
-        rx_apo = ops.broadcast_to(rx_apo[..., None], data.shape)
-
         if not self.with_batch_dim:
-            beamformed_data = self.process_image(data, rx_apo)
+            beamformed_data = self.process_image(data)
         else:
             # Apply process_image to each item in the batch
-            beamformed_data = batched_map(
-                lambda data, rx_apo: self.process_image(data, rx_apo), data, rx_apo=rx_apo
-            )
-
-        if self.reshape_grid:
-            beamformed_data = reshape_axis(
-                beamformed_data, grid.shape[:2], axis=int(self.with_batch_dim)
-            )
+            beamformed_data = ops.map(self.process_image, data)
 
         return {self.output_key: beamformed_data}
 
 
+def envelope_detect(data, axis=-3):
+    """Envelope detection of RF signals.
+
+    If the input data is real, it first applies the Hilbert transform along the specified axis
+    and then computes the magnitude of the resulting complex signal.
+    If the input data is complex, it computes the magnitude directly.
+
+    Args:
+        - data (Tensor): The beamformed data of shape (..., grid_size_z, grid_size_x, n_ch).
+        - axis (int): Axis along which to apply the Hilbert transform. Defaults to -3.
+
+    Returns:
+        - envelope_data (Tensor): The envelope detected data
+            of shape (..., grid_size_z, grid_size_x).
+    """
+    if data.shape[-1] == 2:
+        data = channels_to_complex(data)
+    else:
+        n_ax = ops.shape(data)[axis]
+        n_ax_float = ops.cast(n_ax, "float32")
+
+        # Calculate next power of 2: M = 2^ceil(log2(n_ax))
+        # see https://github.com/tue-bmd/zea/discussions/147
+        log2_n_ax = ops.log2(n_ax_float)
+        M = ops.cast(2 ** ops.ceil(log2_n_ax), "int32")
+
+        data = hilbert(data, N=M, axis=axis)
+        indices = ops.arange(n_ax)
+
+        data = ops.take(data, indices, axis=axis)
+        data = ops.squeeze(data, axis=-1)
+
+    # data = ops.abs(data)
+    real = ops.real(data)
+    imag = ops.imag(data)
+    data = ops.sqrt(real**2 + imag**2)
+    data = ops.cast(data, "float32")
+    return data
+
+
 @ops_registry("envelope_detect")
 class EnvelopeDetect(Operation):
     """Envelope detection of RF signals."""
@@ -1680,23 +1907,7 @@ class EnvelopeDetect(Operation):
         """
         data = kwargs[self.key]
 
-        if data.shape[-1] == 2:
-            data = channels_to_complex(data)
-        else:
-            n_ax = data.shape[self.axis]
-            M = 2 ** int(np.ceil(np.log2(n_ax)))
-            # data = scipy.signal.hilbert(data, N=M, axis=self.axis)
-            data = hilbert(data, N=M, axis=self.axis)
-            indices = ops.arange(n_ax)
-
-            data = ops.take(data, indices, axis=self.axis)
-            data = ops.squeeze(data, axis=-1)
-
-        # data = ops.abs(data)
-        real = ops.real(data)
-        imag = ops.imag(data)
-        data = ops.sqrt(real**2 + imag**2)
-        data = ops.cast(data, "float32")
+        data = envelope_detect(data, axis=self.axis)
 
         return {self.output_key: data}
 
@@ -1734,19 +1945,29 @@ class UpMix(Operation):
         return {self.output_key: data}
 
 
+def log_compress(data, eps=1e-16):
+    """Apply logarithmic compression to data."""
+    eps = ops.convert_to_tensor(eps, dtype=data.dtype)
+    data = ops.where(data == 0, eps, data)  # Avoid log(0)
+    return 20 * keras.ops.log10(data)
+
+
 @ops_registry("log_compress")
 class LogCompress(Operation):
     """Logarithmic compression of data."""
 
-    def __init__(
-        self,
-        **kwargs,
-    ):
+    def __init__(self, clip: bool = True, **kwargs):
+        """Initialize the LogCompress operation.
+
+        Args:
+            clip (bool): Whether to clip the output to a dynamic range. Defaults to True.
+        """
         super().__init__(
             input_data_type=DataTypes.ENVELOPE_DATA,
             output_data_type=DataTypes.IMAGE,
             **kwargs,
         )
+        self.clip = clip
 
     def call(self, dynamic_range=None, **kwargs):
         """Apply logarithmic compression to data.
@@ -1763,20 +1984,43 @@ class LogCompress(Operation):
             dynamic_range = ops.array(DEFAULT_DYNAMIC_RANGE)
         dynamic_range = ops.cast(dynamic_range, data.dtype)
 
-        small_number = ops.convert_to_tensor(1e-16, dtype=data.dtype)
-        data = ops.where(data == 0, small_number, data)
-        compressed_data = 20 * ops.log10(data)
-        compressed_data = ops.clip(compressed_data, dynamic_range[0], dynamic_range[1])
+        compressed_data = log_compress(data)
+        if self.clip:
+            compressed_data = ops.clip(compressed_data, dynamic_range[0], dynamic_range[1])
 
         return {self.output_key: compressed_data}
 
 
+def normalize(data, output_range, input_range=None):
+    """Normalize data to a given range.
+
+    Equivalent to `translate` with clipping.
+
+    Args:
+        data (ops.Tensor): Input data to normalize.
+        output_range (tuple): Range to which data should be mapped, e.g., (0, 1).
+        input_range (tuple, optional): Range of input data.
+            If None, the range will be computed from the data.
+            Defaults to None.
+    """
+    if input_range is None:
+        input_range = (None, None)
+    minval, maxval = input_range
+    if minval is None:
+        minval = ops.min(data)
+    if maxval is None:
+        maxval = ops.max(data)
+    data = ops.clip(data, minval, maxval)
+    normalized_data = translate(data, (minval, maxval), output_range)
+    return normalized_data
+
+
 @ops_registry("normalize")
 class Normalize(Operation):
     """Normalize data to a given range."""
 
     def __init__(self, output_range=None, input_range=None, **kwargs):
-        super().__init__(**kwargs)
+        super().__init__(additional_output_keys=["minval", "maxval"], **kwargs)
         if output_range is None:
             output_range = (0, 1)
         self.output_range = self.to_float32(output_range)
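The module-level `normalize` above clips to the input range and then maps it linearly onto the output range. A numpy equivalent (a sketch; zea's `translate` is assumed to be this linear range mapping):

    import numpy as np

    def normalize_np(data, output_range=(0, 1), input_range=None):
        lo, hi = input_range if input_range is not None else (data.min(), data.max())
        data = np.clip(data, lo, hi)
        a, b = output_range
        return (data - lo) / (hi - lo) * (b - a) + a

    # e.g. a [-60, 0] dB image mapped onto [0, 1]
    print(normalize_np(np.array([-60.0, -30.0, 0.0])))  # [0.  0.5 1. ]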
@@ -1821,11 +2065,9 @@ class Normalize(Operation):
         if maxval is None:
             maxval = ops.max(data)
 
-        # Clip the data to the input range
-        data = ops.clip(data, minval, maxval)
-
-        # Map the data to the output range
-        normalized_data = translate(data, (minval, maxval), self.output_range)
+        normalized_data = normalize(
+            data, output_range=self.output_range, input_range=(minval, maxval)
+        )
 
         return {self.output_key: normalized_data, "minval": minval, "maxval": maxval}
 
@@ -1855,6 +2097,18 @@ class ScanConvert(Operation):
             input_data_type=DataTypes.IMAGE,
             output_data_type=DataTypes.IMAGE_SC,
             jittable=jittable,
+            additional_output_keys=[
+                "resolution",
+                "x_lim",
+                "y_lim",
+                "z_lim",
+                "rho_range",
+                "theta_range",
+                "phi_range",
+                "d_rho",
+                "d_theta",
+                "d_phi",
+            ],
             **kwargs,
         )
         self.order = order
@@ -1913,7 +2167,7 @@
 
 
 @ops_registry("gaussian_blur")
-class GaussianBlur(Operation):
+class GaussianBlur(ImageOperation):
     """
     GaussianBlur is an operation that applies a Gaussian blur to an input image.
     Uses scipy.ndimage.gaussian_filter to create a kernel.
@@ -1963,6 +2217,13 @@
         return ops.convert_to_tensor(kernel)
 
     def call(self, **kwargs):
+        """Apply a Gaussian filter to the input data.
+
+        Args:
+            data (ops.Tensor): Input image data of shape (height, width, channels) with
+                optional batch dimension if ``self.with_batch_dim``.
+        """
+        super().call(**kwargs)
         data = kwargs[self.key]
 
         # Add batch dimension if not present
@@ -1997,7 +2258,7 @@
 
 
 @ops_registry("lee_filter")
-class LeeFilter(Operation):
+class LeeFilter(ImageOperation):
     """
     The Lee filter is a speckle reduction filter commonly used in synthetic aperture radar (SAR)
     and ultrasound image processing. It smooths the image while preserving edges and details.
@@ -2027,7 +2288,7 @@
             pad_mode=self.pad_mode,
             with_batch_dim=self.with_batch_dim,
             jittable=self._jittable,
-            key=self.key,
+            key="data",
         )
 
     @property
@@ -2043,24 +2304,29 @@
         self.gaussian_blur.with_batch_dim = value
 
     def call(self, **kwargs):
-        data = kwargs[self.key]
+        """Apply the Lee filter to the input data.
+
+        Args:
+            data (ops.Tensor): Input image data of shape (height, width, channels) with
+                optional batch dimension if ``self.with_batch_dim``.
+        """
+        super().call(**kwargs)
+        data = kwargs.pop(self.key)
 
         # Apply Gaussian blur to get local mean
-        img_mean = self.gaussian_blur.call(**kwargs)[self.gaussian_blur.output_key]
+        img_mean = self.gaussian_blur.call(data=data, **kwargs)[self.gaussian_blur.output_key]
 
         # Apply Gaussian blur to squared data to get local squared mean
-        data_squared = data**2
-        kwargs[self.gaussian_blur.key] = data_squared
-        img_sqr_mean = self.gaussian_blur.call(**kwargs)[self.gaussian_blur.output_key]
+        img_sqr_mean = self.gaussian_blur.call(
+            data=data**2,
+            **kwargs,
+        )[self.gaussian_blur.output_key]
 
         # Calculate local variance
         img_variance = img_sqr_mean - img_mean**2
 
         # Calculate global variance (per channel)
-        if self.with_batch_dim:
-            overall_variance = ops.var(data, axis=(-3, -2), keepdims=True)
-        else:
-            overall_variance = ops.var(data, axis=(-2, -1), keepdims=True)
+        overall_variance = ops.var(data, axis=(-3, -2), keepdims=True)
 
         # Calculate adaptive weights
         img_weights = img_variance / (img_variance + overall_variance)
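For reference, the weights computed above feed the classic Lee filter blend: output = mean + w * (x - mean), with w = local_var / (local_var + global_var). A self-contained numpy sketch using a uniform window instead of zea's Gaussian blur (the final blend line is the standard Lee formula, assumed rather than shown in this hunk):

    import numpy as np
    from scipy.ndimage import uniform_filter

    def lee_filter_np(img, size=7):
        mean = uniform_filter(img, size)         # local mean
        sqr_mean = uniform_filter(img**2, size)  # local mean of squares
        var = sqr_mean - mean**2                 # local variance
        w = var / (var + img.var())              # adaptive weights
        return mean + w * (img - mean)           # smooth flat regions, keep edges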
@@ -2081,6 +2347,11 @@ class Demodulate(Operation):
             input_data_type=DataTypes.RAW_DATA,
             output_data_type=DataTypes.RAW_DATA,
             jittable=True,
+            additional_output_keys=[
+                "demodulation_frequency",
+                "center_frequency",
+                "n_ch",
+            ],
             **kwargs,
         )
         self.axis = axis
@@ -2106,6 +2377,121 @@ class Demodulate(Operation):
         }
 
 
+@ops_registry("fir_filter")
+class FirFilter(Operation):
+    """Apply a FIR filter to the input signal using convolution.
+
+    Looks for the filter taps in the input dictionary using the specified ``filter_key``.
+    """
+
+    def __init__(
+        self,
+        axis: int,
+        complex_channels: bool = False,
+        filter_key: str = "fir_filter_taps",
+        **kwargs,
+    ):
+        """
+        Args:
+            axis (int): Axis along which to apply the filter. Cannot be the batch dimension.
+                When using ``complex_channels=True``, the complex channels are removed to convert
+                to complex numbers before filtering, so adjust the ``axis`` accordingly!
+            complex_channels (bool): Whether the last dimension of the input signal represents
+                complex channels (real and imaginary parts). When True, it will convert the signal
+                to ``complex`` dtype before filtering and convert it back to two channels
+                after filtering.
+            filter_key (str): Key in the input dictionary where the FIR filter taps are stored.
+                Default is "fir_filter_taps".
+        """
+        super().__init__(**kwargs)
+        self._check_axis(axis)
+
+        self.axis = axis
+        self.complex_channels = complex_channels
+        self.filter_key = filter_key
+
+    def _check_axis(self, axis, ndim=None):
+        """Check if the axis is valid."""
+        if ndim is not None:
+            if axis < -ndim or axis >= ndim:
+                raise ValueError(f"Axis {axis} is out of bounds for array of dimension {ndim}.")
+
+        if self.with_batch_dim and (axis == 0 or (ndim is not None and axis == -ndim)):
+            raise ValueError("Cannot apply FIR filter along batch dimension.")
+
+    @property
+    def valid_keys(self):
+        """Get the valid keys for the `call` method."""
+        return self._valid_keys.union({self.filter_key})
+
+    def call(self, **kwargs):
+        signal = kwargs[self.key]
+        fir_filter_taps = kwargs[self.filter_key]
+
+        if self.complex_channels:
+            signal = channels_to_complex(signal)
+
+        self._check_axis(self.axis, ndim=ops.ndim(signal))
+
+        def _convolve(signal):
+            """Apply the filter to the signal using correlation."""
+            return correlate(signal, fir_filter_taps[::-1], mode="same")
+
+        filtered_signal = apply_along_axis(_convolve, self.axis, signal)
+
+        if self.complex_channels:
+            filtered_signal = complex_to_channels(filtered_signal)
+
+        return {self.output_key: filtered_signal}
+
+
+@ops_registry("low_pass_filter")
+class LowPassFilter(FirFilter):
+    """Apply a low-pass FIR filter to the input signal using convolution.
+
+    It is recommended to use :class:`FirFilter` with pre-computed filter taps for jittable
+    operations. The :class:`LowPassFilter` operation itself is not jittable and is provided
+    for convenience only.
+
+    Uses :func:`get_low_pass_iq_filter` to compute the filter taps.
+    """
+
+    def __init__(self, axis: int, complex_channels: bool = False, num_taps: int = 128, **kwargs):
+        """Initialize the LowPassFilter operation.
+
+        Args:
+            axis (int): Axis along which to apply the filter. Cannot be the batch dimension.
+                When using ``complex_channels=True``, the complex channels are removed to convert
+                to complex numbers before filtering, so adjust the ``axis`` accordingly.
+            complex_channels (bool): Whether the last dimension of the input signal represents
+                complex channels (real and imaginary parts). When True, it will convert the signal
+                to ``complex`` dtype before filtering and convert it back to two channels
+                after filtering.
+            num_taps (int): Number of taps in the FIR filter. Default is 128.
+        """
+        self._random_suffix = str(uuid.uuid4())
+        kwargs.pop("filter_key", None)
+        kwargs.pop("jittable", None)
+        super().__init__(
+            axis=axis,
+            complex_channels=complex_channels,
+            filter_key=f"low_pass_{self._random_suffix}",
+            jittable=False,
+            **kwargs,
+        )
+        self.num_taps = num_taps
+
+    def call(self, bandwidth, sampling_frequency, center_frequency, **kwargs):
+        lpf = get_low_pass_iq_filter(
+            self.num_taps,
+            ops.convert_to_numpy(sampling_frequency).item(),
+            ops.convert_to_numpy(center_frequency).item(),
+            ops.convert_to_numpy(bandwidth).item(),
+        )
+        kwargs[self.filter_key] = lpf
+        return super().call(**kwargs)
+
+
 @ops_registry("lambda")
 class Lambda(Operation):
     """Use any function as an operation."""
@@ -2140,12 +2526,15 @@ class Lambda(Operation):
 
     def call(self, **kwargs):
         data = kwargs[self.key]
-        data = self.func(data)
+        if self.with_batch_dim:
+            data = ops.map(self.func, data)
+        else:
+            data = self.func(data)
         return {self.output_key: data}
 
 
 @ops_registry("pad")
-class Pad(Operation, TFDataLayer):
+class Pad(Operation, DataLayer):
     """Pad layer for padding tensors to a specified shape."""
 
     def __init__(
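The `Lambda` change is behavioral: with a batch dimension, the wrapped function now sees one sample at a time. A Keras 3 sketch of the difference, with a made-up `normalize` function:

```python
from keras import ops

def normalize(frame):
    # With ops.map the max is taken per sample; without it, over the whole batch.
    return frame / ops.max(frame)

batch = ops.reshape(ops.arange(1, 25, dtype="float32"), (2, 3, 4))
per_sample = ops.map(normalize, batch)  # new behavior when with_batch_dim is set
whole_batch = normalize(batch)          # old behavior, even for batched input
```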
@@ -2330,6 +2719,7 @@ class Downsample(Operation):
 
     def __init__(self, factor: int = 1, phase: int = 0, axis: int = -3, **kwargs):
         super().__init__(
+            additional_output_keys=["sampling_frequency", "n_ax"],
            **kwargs,
        )
        self.factor = factor
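Re-emitting `sampling_frequency` and `n_ax` makes sense because axial decimation changes both. A NumPy sketch of the bookkeeping — the `factor`/`phase` slicing shown is the standard reading of those parameters, not code from this diff:

```python
import numpy as np

factor, phase = 2, 0
data = np.random.randn(2048, 128, 1)  # (n_ax, n_el, n_ch), axis=-3 is axial

downsampled = data[phase::factor]   # keep every `factor`-th axial sample
sampling_frequency = 25e6 / factor  # must be updated alongside the data
n_ax = downsampled.shape[0]
assert n_ax == data.shape[0] // factor
```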
@@ -2894,7 +3284,7 @@ def get_band_pass_filter(num_taps, sampling_frequency, f1, f2):
     return bpf
 
 
-def get_low_pass_iq_filter(num_taps, sampling_frequency, f, bw):
+def get_low_pass_iq_filter(num_taps, sampling_frequency, center_frequency, bandwidth):
     """Design complex low-pass filter.
 
     The filter is a low-pass FIR filter modulated to the center frequency.
@@ -2902,16 +3292,16 @@ def get_low_pass_iq_filter(num_taps, sampling_frequency, f, bw):
     Args:
         num_taps (int): number of taps in filter.
         sampling_frequency (float): sample frequency.
-        f (float): center frequency.
-        bw (float): bandwidth in Hz.
+        center_frequency (float): center frequency.
+        bandwidth (float): bandwidth in Hz.
 
     Raises:
-        ValueError: if cutoff frequency (bw / 2) is not within (0, sampling_frequency / 2)
+        ValueError: if cutoff frequency (bandwidth / 2) is not within (0, sampling_frequency / 2)
 
     Returns:
         ndarray: Complex-valued low-pass filter
     """
-    cutoff = bw / 2
+    cutoff = bandwidth / 2
     if not (0 < cutoff < sampling_frequency / 2):
         raise ValueError(
             f"Cutoff frequency must be within (0, sampling_frequency / 2), "
@@ -2921,7 +3311,7 @@ def get_low_pass_iq_filter(num_taps, sampling_frequency, f, bw):
     lpf = scipy.signal.firwin(num_taps, cutoff, pass_zero=True, fs=sampling_frequency)
     # Modulate to center frequency to make it complex
     time_points = np.arange(num_taps) / sampling_frequency
-    lpf_complex = lpf * np.exp(1j * 2 * np.pi * f * time_points)
+    lpf_complex = lpf * np.exp(1j * 2 * np.pi * center_frequency * time_points)
     return lpf_complex
 
 
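With the rename, a call site now reads naturally. A quick self-contained check of the returned taps (parameter values are arbitrary):

```python
import numpy as np

from zea.ops import get_low_pass_iq_filter

lpf = get_low_pass_iq_filter(
    num_taps=128,
    sampling_frequency=25e6,  # Hz
    center_frequency=5e6,     # Hz
    bandwidth=4e6,            # Hz; cutoff = 2 MHz, inside (0, 12.5 MHz)
)
assert lpf.shape == (128,) and np.iscomplexobj(lpf)
```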
@@ -3095,6 +3485,50 @@ def demodulate(data, center_frequency, sampling_frequency, axis=-3):
     iq_data_signal_complex = analytical_signal * ops.exp(phasor_exponent)
 
     # Split the complex signal into two channels
-    iq_data_two_channel = complex_to_channels(iq_data_signal_complex[..., 0])
+    iq_data_two_channel = complex_to_channels(ops.squeeze(iq_data_signal_complex, axis=-1))
 
     return iq_data_two_channel
+
+
+def compute_time_to_peak_stack(waveforms, center_frequencies, waveform_sampling_frequency=250e6):
+    """Compute the time of the peak of each waveform in a stack of waveforms.
+
+    Args:
+        waveforms (ndarray): The waveforms of shape (n_waveforms, n_samples).
+        center_frequencies (ndarray): The center frequencies of the waveforms in Hz of shape
+            (n_waveforms,) or a scalar if all waveforms have the same center frequency.
+        waveform_sampling_frequency (float): The sampling frequency of the waveforms in Hz.
+
+    Returns:
+        ndarray: The time to peak for each waveform in seconds.
+    """
+    t_peak = []
+    center_frequencies = center_frequencies * ops.ones((waveforms.shape[0],))
+    for waveform, center_frequency in zip(waveforms, center_frequencies):
+        t_peak.append(compute_time_to_peak(waveform, center_frequency, waveform_sampling_frequency))
+    return ops.stack(t_peak)
+
+
+def compute_time_to_peak(waveform, center_frequency, waveform_sampling_frequency=250e6):
+    """Compute the time of the peak of the waveform.
+
+    Args:
+        waveform (ndarray): The waveform of shape (n_samples,).
+        center_frequency (float): The center frequency of the waveform in Hz.
+        waveform_sampling_frequency (float): The sampling frequency of the waveform in Hz.
+
+    Returns:
+        float: The time to peak for the waveform in seconds.
+    """
+    n_samples = waveform.shape[0]
+    if n_samples == 0:
+        raise ValueError("Waveform has zero samples.")
+
+    waveforms_iq_complex_channels = demodulate(
+        waveform[..., None], center_frequency, waveform_sampling_frequency, axis=-1
+    )
+    waveforms_iq_complex = channels_to_complex(waveforms_iq_complex_channels)
+    envelope = ops.abs(waveforms_iq_complex)
+    peak_idx = ops.argmax(envelope, axis=-1)
+    t_peak = ops.cast(peak_idx, dtype="float32") / waveform_sampling_frequency
+    return t_peak
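A sanity sketch for the new `compute_time_to_peak` on a synthetic tone burst; the pulse parameters are made up, and passing NumPy arrays assumes `keras.ops` accepts them, as it does on the standard backends:

```python
import numpy as np
import scipy.signal

from zea.ops import compute_time_to_peak

fs = 250e6  # default waveform sampling frequency in Hz
fc = 5e6    # center frequency in Hz
t = np.arange(2048) / fs
pulse = scipy.signal.gausspulse(t - 2e-6, fc=fc).astype("float32")

t_peak = compute_time_to_peak(pulse, fc, waveform_sampling_frequency=fs)
# The envelope peak should land near the 2 µs burst center.
```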