zea 0.0.8__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. zea/__init__.py +3 -3
  2. zea/agent/masks.py +2 -2
  3. zea/agent/selection.py +3 -3
  4. zea/backend/__init__.py +1 -1
  5. zea/backend/tensorflow/dataloader.py +1 -1
  6. zea/beamform/beamformer.py +4 -2
  7. zea/beamform/pfield.py +2 -2
  8. zea/data/augmentations.py +1 -1
  9. zea/data/convert/__main__.py +93 -52
  10. zea/data/convert/camus.py +8 -2
  11. zea/data/convert/echonet.py +1 -1
  12. zea/data/convert/echonetlvh/__init__.py +1 -1
  13. zea/data/convert/verasonics.py +810 -772
  14. zea/data/data_format.py +0 -2
  15. zea/data/file.py +28 -0
  16. zea/data/preset_utils.py +1 -1
  17. zea/display.py +1 -1
  18. zea/doppler.py +5 -5
  19. zea/func/__init__.py +109 -0
  20. zea/{tensor_ops.py → func/tensor.py} +32 -8
  21. zea/func/ultrasound.py +500 -0
  22. zea/internal/_generate_keras_ops.py +5 -5
  23. zea/metrics.py +6 -5
  24. zea/models/diffusion.py +1 -1
  25. zea/models/echonetlvh.py +1 -1
  26. zea/models/gmm.py +1 -1
  27. zea/ops/__init__.py +188 -0
  28. zea/ops/base.py +442 -0
  29. zea/{keras_ops.py → ops/keras_ops.py} +2 -2
  30. zea/ops/pipeline.py +1472 -0
  31. zea/ops/tensor.py +356 -0
  32. zea/ops/ultrasound.py +890 -0
  33. zea/probes.py +2 -10
  34. zea/scan.py +17 -20
  35. zea/tools/fit_scan_cone.py +1 -1
  36. zea/tools/selection_tool.py +1 -1
  37. zea/tracking/lucas_kanade.py +1 -1
  38. zea/tracking/segmentation.py +1 -1
  39. {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/METADATA +3 -1
  40. {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/RECORD +43 -37
  41. zea/ops.py +0 -3534
  42. {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/WHEEL +0 -0
  43. {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/entry_points.txt +0 -0
  44. {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/licenses/LICENSE +0 -0
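The largest change in this release replaces the monolithic `zea/ops.py` with the new `zea.ops` package, whose core is the `Pipeline` class in the added `zea/ops/pipeline.py` shown in full below. The following is a minimal construction sketch based only on the `Pipeline.from_default` docstring in that file; the runtime inputs (probe, scan, raw data) and the exact keys they provide are omitted here.

```python
# Minimal sketch based on the Pipeline.from_default() docstring in
# zea/ops/pipeline.py (diffed below); runtime inputs are intentionally omitted.
from zea.ops import Pipeline

# Default RF -> B-mode chain:
# Demodulate -> Beamform -> EnvelopeDetect -> Normalize -> LogCompress
pipeline = Pipeline.from_default(
    beamformer="delay_and_sum",
    num_patches=100,      # increase this if beamforming runs out of memory
    enable_pfield=False,  # optionally weight by the simulated pressure field
    jit_options="ops",    # compile each operation separately (the default)
)

print(pipeline)             # operation names joined with " -> "
print(pipeline.needs_keys)  # input keys that must be supplied when calling it
```

At call time, `Probe`, `Scan`, and `Config` objects are first converted to a tensor dictionary with `pipeline.prepare_parameters(...)`, and the result is passed as keyword arguments to `pipeline(...)`.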
zea/ops/pipeline.py ADDED
@@ -0,0 +1,1472 @@
1
+ import json
2
+ from typing import Dict, List, Union
3
+
4
+ import keras
5
+ import numpy as np
6
+ import yaml
7
+ from keras import ops
8
+
9
+ from zea import log
10
+ from zea.backend import jit
11
+ from zea.config import Config
12
+ from zea.func.tensor import (
13
+ vmap,
14
+ )
15
+ from zea.func.ultrasound import channels_to_complex, complex_to_channels
16
+ from zea.internal.core import (
17
+ DataTypes,
18
+ ZEADecoderJSON,
19
+ ZEAEncoderJSON,
20
+ dict_to_tensor,
21
+ )
22
+ from zea.internal.core import Object as ZEAObject
23
+ from zea.internal.registry import ops_registry
24
+ from zea.ops.base import (
25
+ Operation,
26
+ get_ops,
27
+ )
28
+ from zea.ops.tensor import Normalize
29
+ from zea.ops.ultrasound import (
30
+ Demodulate,
31
+ EnvelopeDetect,
32
+ LogCompress,
33
+ PfieldWeighting,
34
+ ReshapeGrid,
35
+ TOFCorrection,
36
+ )
37
+ from zea.probes import Probe
38
+ from zea.scan import Scan
39
+ from zea.utils import (
40
+ FunctionTimer,
41
+ )
42
+
43
+
44
+ @ops_registry("pipeline")
45
+ class Pipeline:
46
+ """Pipeline class for processing ultrasound data through a series of operations."""
47
+
48
+ def __init__(
49
+ self,
50
+ operations: List[Operation],
51
+ with_batch_dim: bool = True,
52
+ jit_options: Union[str, None] = "ops",
53
+ jit_kwargs: dict | None = None,
54
+ name="pipeline",
55
+ validate=True,
56
+ timed: bool = False,
57
+ ):
58
+ """
59
+ Initialize a pipeline.
60
+
61
+ Args:
62
+ operations (list): A list of Operation instances representing the operations
63
+ to be performed.
64
+ with_batch_dim (bool, optional): Whether operations should expect a batch dimension.
65
+ Defaults to True.
66
+ jit_options (str, optional): The JIT options to use. Must be "pipeline", "ops", or None.
67
+
68
+ - "pipeline": compiles the entire pipeline as a single function.
69
+ This may be faster but does not preserve Python control flow, such as caching.
70
+
71
+ - "ops": compiles each operation separately. This preserves Python control flow and
72
+ caching functionality while still speeding up the individual operations.
73
+
74
+ - None: disables JIT compilation.
75
+
76
+ Defaults to "ops".
77
+
78
+ jit_kwargs (dict, optional): Additional keyword arguments for the JIT compiler.
79
+ name (str, optional): The name of the pipeline. Defaults to "pipeline".
80
+ validate (bool, optional): Whether to validate the pipeline. Defaults to True.
81
+ timed (bool, optional): Whether to time each operation. Defaults to False.
82
+
83
+ """
84
+ self._call_pipeline = self.call
85
+ self.name = name
86
+
87
+ self._pipeline_layers = operations
88
+
89
+ if jit_options not in ["pipeline", "ops", None]:
90
+ raise ValueError("jit_options must be 'pipeline', 'ops', or None")
91
+
92
+ self.with_batch_dim = with_batch_dim
93
+ self._validate_flag = validate
94
+
95
+ # Setup timer
96
+ if jit_options == "pipeline" and timed:
97
+ raise ValueError(
98
+ "timed=True cannot be used with jit_options='pipeline' as the entire "
99
+ "pipeline is compiled into a single function. Try setting jit_options to "
100
+ "'ops' or None."
101
+ )
102
+ if timed:
103
+ log.warning(
104
+ "Timer has been initialized for the pipeline. To get an accurate timing estimate, "
105
+ "the `block_until_ready()` is used, which will slow down the execution, so "
106
+ "do not use for regular processing!"
107
+ )
108
+ self._callable_layers = self._get_timed_operations()
109
+ else:
110
+ self._callable_layers = self._pipeline_layers
111
+ self._timed = timed
112
+
113
+ if validate:
114
+ self.validate()
115
+ else:
116
+ log.warning("Pipeline validation is disabled, make sure to validate manually.")
117
+
118
+ if jit_kwargs is None:
119
+ jit_kwargs = {}
120
+
121
+ if keras.backend.backend() == "jax" and self.static_params != []:
122
+ jit_kwargs = {"static_argnames": self.static_params}
123
+
124
+ self.jit_kwargs = jit_kwargs
125
+ self.jit_options = jit_options # will handle the jit compilation
126
+
127
+ self._logged_difference_keys = False
128
+
129
+ # Do not log again for nested pipelines
130
+ for nested_pipeline in self._nested_pipelines:
131
+ nested_pipeline._logged_difference_keys = True
132
+
133
+ def needs(self, key) -> bool:
134
+ """Check if the pipeline needs a specific key at the input."""
135
+ return key in self.needs_keys
136
+
137
+ @property
138
+ def _nested_pipelines(self):
139
+ return [operation for operation in self.operations if isinstance(operation, Pipeline)]
140
+
141
+ @property
142
+ def output_keys(self) -> set:
143
+ """All output keys the pipeline guarantees to produce."""
144
+ output_keys = set()
145
+ for operation in self.operations:
146
+ output_keys.update(operation.output_keys)
147
+ return output_keys
148
+
149
+ @property
150
+ def valid_keys(self) -> set:
151
+ """Get a set of valid keys for the pipeline.
152
+
153
+ These are all the keys that can be passed to the pipeline as input.
154
+ """
155
+ valid_keys = set()
156
+ for operation in self.operations:
157
+ valid_keys.update(operation.valid_keys)
158
+ return valid_keys
159
+
160
+ @property
161
+ def static_params(self) -> List[str]:
162
+ """Get a list of static parameters for the pipeline."""
163
+ static_params = []
164
+ for operation in self.operations:
165
+ static_params.extend(operation.static_params)
166
+ return list(set(static_params))
167
+
168
+ @property
169
+ def needs_keys(self) -> set:
170
+ """Get a set of all input keys needed by the pipeline.
171
+
172
+ Will keep track of keys that are already provided by previous operations.
173
+ """
174
+ needs = set()
175
+ has_so_far = set()
176
+ previous_operation = None
177
+ for operation in self.operations:
178
+ if previous_operation is not None:
179
+ has_so_far.update(previous_operation.output_keys)
180
+ needs.update(operation.needs_keys - has_so_far)
181
+ previous_operation = operation
182
+ return needs
183
+
184
+ @classmethod
185
+ def from_default(
186
+ cls,
187
+ beamformer="delay_and_sum",
188
+ num_patches=100,
189
+ baseband=False,
190
+ enable_pfield=False,
191
+ timed=False,
192
+ **kwargs,
193
+ ) -> "Pipeline":
194
+ """Create a default pipeline.
195
+
196
+ Args:
197
+ beamformer (str): Type of beamformer to use. Currently supported:
198
+ "delay_and_sum" and "delay_multiply_and_sum". Defaults to "delay_and_sum".
199
+ num_patches (int): Number of patches for the PatchedGrid operation.
200
+ Defaults to 100. If you get an out of memory error, try to increase this number.
201
+ baseband (bool): If True, assume the input data is baseband (I/Q) data,
202
+ which has 2 channels (last dim). Defaults to False, which assumes RF data,
203
+ so the input signal has a single channel dim and is still on the carrier frequency.
204
+ enable_pfield (bool): If True, apply PfieldWeighting. Defaults to False.
205
+ This will calculate the pressure field and only beamform the data at those locations.
206
+ timed (bool, optional): Whether to time each operation. Defaults to False.
207
+ **kwargs: Additional keyword arguments to be passed to the Pipeline constructor.
208
+
209
+ """
210
+ operations = []
211
+
212
+ # Add the demodulate operation
213
+ if not baseband:
214
+ operations.append(Demodulate())
215
+
216
+ # Add beamforming ops
217
+ operations.append(
218
+ Beamform(
219
+ beamformer=beamformer,
220
+ num_patches=num_patches,
221
+ enable_pfield=enable_pfield,
222
+ ),
223
+ )
224
+
225
+ # Add display ops
226
+ operations += [
227
+ EnvelopeDetect(),
228
+ Normalize(),
229
+ LogCompress(),
230
+ ]
231
+ return cls(operations, timed=timed, **kwargs)
232
+
233
+ def copy(self) -> "Pipeline":
234
+ """Create a copy of the pipeline."""
235
+ return Pipeline(
236
+ self._pipeline_layers.copy(),
237
+ with_batch_dim=self.with_batch_dim,
238
+ jit_options=self.jit_options,
239
+ jit_kwargs=self.jit_kwargs,
240
+ name=self.name,
241
+ validate=self._validate_flag,
242
+ timed=self._timed,
243
+ )
244
+
245
+ def reinitialize(self):
246
+ """Reinitialize the pipeline in place."""
247
+ self.__init__(
248
+ self._pipeline_layers,
249
+ with_batch_dim=self.with_batch_dim,
250
+ jit_options=self.jit_options,
251
+ jit_kwargs=self.jit_kwargs,
252
+ name=self.name,
253
+ validate=self._validate_flag,
254
+ timed=self._timed,
255
+ )
256
+
257
+ def prepend(self, operation: Operation):
258
+ """Prepend an operation to the pipeline."""
259
+ self._pipeline_layers.insert(0, operation)
260
+ self.reinitialize()
261
+
262
+ def append(self, operation: Operation):
263
+ """Append an operation to the pipeline."""
264
+ self._pipeline_layers.append(operation)
265
+ self.reinitialize()
266
+
267
+ def insert(self, index: int, operation: Operation):
268
+ """Insert an operation at a specific index in the pipeline."""
269
+ if index < 0 or index > len(self._pipeline_layers):
270
+ raise IndexError("Index out of bounds for inserting operation.")
271
+ self._pipeline_layers.insert(index, operation)
272
+ self.reinitialize()
273
+
274
+ @property
275
+ def operations(self):
276
+ """Alias for self.layers to match the zea naming convention"""
277
+ return self._pipeline_layers
278
+
279
+ def reset_timer(self):
280
+ """Reset the timer for timed operations."""
281
+ if self._timed:
282
+ self._callable_layers = self._get_timed_operations()
283
+ else:
284
+ log.warning(
285
+ "Timer has not been initialized. Set timed=True when initializing the pipeline."
286
+ )
287
+
288
+ def _get_timed_operations(self):
289
+ """Get a list of timed operations."""
290
+ self.timer = FunctionTimer()
291
+ return [self.timer(op, name=op.__class__.__name__) for op in self._pipeline_layers]
292
+
293
+ def call(self, **inputs):
294
+ """Process input data through the pipeline."""
295
+ for operation in self._callable_layers:
296
+ try:
297
+ outputs = operation(**inputs)
298
+ except KeyError as exc:
299
+ raise KeyError(
300
+ f"[zea.Pipeline] Operation '{operation.__class__.__name__}' "
301
+ f"requires input key '{exc.args[0]}', "
302
+ "but it was not provided in the inputs.\n"
303
+ "Check whether the objects (such as `zea.Scan`) passed to "
304
+ "`pipeline.prepare_parameters()` contain all required keys.\n"
305
+ f"Current list of all passed keys: {list(inputs.keys())}\n"
306
+ f"Valid keys for this pipeline: {self.valid_keys}"
307
+ ) from exc
308
+ except Exception as exc:
309
+ raise RuntimeError(
310
+ f"[zea.Pipeline] Error in operation '{operation.__class__.__name__}': {exc}"
311
+ ) from exc
312
+ inputs = outputs
313
+ return outputs
314
+
315
+ def __call__(self, return_numpy=False, **inputs):
316
+ """Process input data through the pipeline."""
317
+
318
+ if any(key in inputs for key in ["probe", "scan", "config"]) or any(
319
+ isinstance(arg, ZEAObject) for arg in inputs.values()
320
+ ):
321
+ raise ValueError(
322
+ "Probe, Scan and Config objects should be first processed with "
323
+ "`Pipeline.prepare_parameters` before calling the pipeline. "
324
+ "e.g. inputs = Pipeline.prepare_parameters(probe, scan, config)"
325
+ )
326
+
327
+ if any(isinstance(arg, str) for arg in inputs.values()):
328
+ raise ValueError(
329
+ "Pipeline does not support string inputs. "
330
+ "Please ensure all inputs are convertible to tensors."
331
+ )
332
+
333
+ if not self._logged_difference_keys:
334
+ difference_keys = set(inputs.keys()) - self.valid_keys
335
+ if difference_keys:
336
+ log.debug(
337
+ f"[zea.Pipeline] The following input keys are not used by the pipeline: "
338
+ f"{difference_keys}. Make sure this is intended. "
339
+ "This warning will only be shown once."
340
+ )
341
+ self._logged_difference_keys = True
342
+
343
+ ## PROCESSING
344
+ outputs = self._call_pipeline(**inputs)
345
+
346
+ ## PREPARE OUTPUT
347
+ if return_numpy:
348
+ # Convert tensors to numpy arrays but preserve None values
349
+ outputs = {
350
+ k: ops.convert_to_numpy(v) if ops.is_tensor(v) else v for k, v in outputs.items()
351
+ }
352
+
353
+ return outputs
354
+
355
+ @property
356
+ def jit_options(self):
357
+ """Get the jit_options property of the pipeline."""
358
+ return self._jit_options
359
+
360
+ @jit_options.setter
361
+ def jit_options(self, value: Union[str, None]):
362
+ """Set the jit_options property of the pipeline."""
363
+ self._jit_options = value
364
+ if value == "pipeline":
365
+ assert self.jittable, log.error(
366
+ "jit_options 'pipeline' cannot be used as the entire pipeline is not jittable. "
367
+ "The following operations are not jittable: "
368
+ f"{self.unjitable_ops}. "
369
+ "Try setting jit_options to 'ops' or None."
370
+ )
371
+ self.jit()
372
+ return
373
+ else:
374
+ self.unjit()
375
+
376
+ for operation in self.operations:
377
+ if isinstance(operation, Pipeline):
378
+ operation.jit_options = value
379
+ else:
380
+ if operation.jittable and operation._jit_compile:
381
+ operation.set_jit(value == "ops")
382
+
383
+ def jit(self):
384
+ """JIT compile the pipeline."""
385
+ self._call_pipeline = jit(self.call, **self.jit_kwargs)
386
+
387
+ def unjit(self):
388
+ """Un-JIT compile the pipeline."""
389
+ self._call_pipeline = self.call
390
+
391
+ @property
392
+ def jittable(self):
393
+ """Check if all operations in the pipeline are jittable."""
394
+ return all(operation.jittable for operation in self.operations)
395
+
396
+ @property
397
+ def unjitable_ops(self):
398
+ """Get a list of operations that are not jittable."""
399
+ return [operation for operation in self.operations if not operation.jittable]
400
+
401
+ @property
402
+ def with_batch_dim(self):
403
+ """Get the with_batch_dim property of the pipeline."""
404
+ return self._with_batch_dim
405
+
406
+ @with_batch_dim.setter
407
+ def with_batch_dim(self, value):
408
+ """Set the with_batch_dim property of the pipeline."""
409
+ self._with_batch_dim = value
410
+ for operation in self.operations:
411
+ operation.with_batch_dim = value
412
+
413
+ @property
414
+ def input_data_type(self):
415
+ """Get the input_data_type property of the pipeline."""
416
+ return self.operations[0].input_data_type
417
+
418
+ @property
419
+ def output_data_type(self):
420
+ """Get the output_data_type property of the pipeline."""
421
+ return self.operations[-1].output_data_type
422
+
423
+ def validate(self):
424
+ """Validate the pipeline by checking the compatibility of the operations."""
425
+ operations = self.operations
426
+ for i in range(len(operations) - 1):
427
+ if operations[i].output_data_type is None:
428
+ continue
429
+ if operations[i + 1].input_data_type is None:
430
+ continue
431
+ if operations[i].output_data_type != operations[i + 1].input_data_type:
432
+ raise ValueError(
433
+ f"Operation {operations[i].__class__.__name__} output data type "
434
+ f"({operations[i].output_data_type}) is not compatible "
435
+ f"with the input data type ({operations[i + 1].input_data_type}) "
436
+ f"of operation {operations[i + 1].__class__.__name__}"
437
+ )
438
+
439
+ def set_params(self, **params):
440
+ """Set parameters for the operations in the pipeline by adding them to the cache."""
441
+ for operation in self.operations:
442
+ operation_params = {
443
+ key: value for key, value in params.items() if key in operation.valid_keys
444
+ }
445
+ if operation_params:
446
+ operation.set_input_cache(operation_params)
447
+
448
+ def get_params(self, per_operation: bool = False):
449
+ """Get a snapshot of the current parameters of the operations in the pipeline.
450
+
451
+ Args:
452
+ per_operation (bool): If True, return a list of dictionaries for each operation.
453
+ If False, return a single dictionary with all parameters combined.
454
+ """
455
+ if per_operation:
456
+ return [operation._input_cache.copy() for operation in self.operations]
457
+ else:
458
+ params = {}
459
+ for operation in self.operations:
460
+ params.update(operation._input_cache)
461
+ return params
462
+
463
+ def __str__(self):
464
+ """String representation of the pipeline."""
465
+ operations = []
466
+ for operation in self.operations:
467
+ if isinstance(operation, Pipeline):
468
+ operations.append(f"{operation.__class__.__name__}({str(operation)})")
469
+ else:
470
+ operations.append(operation.__class__.__name__)
471
+ string = " -> ".join(operations)
472
+ return string
473
+
474
+ def __repr__(self):
475
+ """String representation of the pipeline."""
476
+ operations = []
477
+ for operation in self.operations:
478
+ if isinstance(operation, Pipeline):
479
+ operations.append(repr(operation))
480
+ else:
481
+ operations.append(operation.__class__.__name__)
482
+ return f"<Pipeline {self.name}=({', '.join(operations)})>"
483
+
484
+ @classmethod
485
+ def load(cls, file_path: str, **kwargs) -> "Pipeline":
486
+ """Load a pipeline from a JSON or YAML file."""
487
+ if file_path.endswith(".json"):
488
+ with open(file_path, "r", encoding="utf-8") as f:
489
+ json_str = f.read()
490
+ return pipeline_from_json(json_str, **kwargs)
491
+ elif file_path.endswith(".yaml") or file_path.endswith(".yml"):
492
+ return pipeline_from_yaml(file_path, **kwargs)
493
+ else:
494
+ raise ValueError("File must have extension .json, .yaml, or .yml")
495
+
496
+ def get_dict(self) -> dict:
497
+ """Convert the pipeline to a dictionary."""
498
+ config = {}
499
+ config["name"] = ops_registry.get_name(self)
500
+ config["operations"] = self._pipeline_to_list(self)
501
+ config["params"] = {
502
+ "with_batch_dim": self.with_batch_dim,
503
+ "jit_options": self.jit_options,
504
+ "jit_kwargs": self.jit_kwargs,
505
+ }
506
+ return config
507
+
508
+ @staticmethod
509
+ def _pipeline_to_list(pipeline):
510
+ """Convert the pipeline to a list of operations."""
511
+ ops_list = []
512
+ for op in pipeline.operations:
513
+ ops_list.append(op.get_dict())
514
+ return ops_list
515
+
516
+ @classmethod
517
+ def from_config(cls, config: Dict, **kwargs) -> "Pipeline":
518
+ """Create a pipeline from a dictionary or ``zea.Config`` object.
519
+
520
+ Args:
521
+ config (dict or Config): Configuration dictionary or ``zea.Config`` object.
522
+ **kwargs: Additional keyword arguments to be passed to the pipeline.
523
+
524
+ Note:
525
+ Must have an ``operations`` key.
526
+
527
+ Example:
528
+ .. doctest::
529
+
530
+ >>> from zea import Config, Pipeline
531
+ >>> config = Config(
532
+ ... {
533
+ ... "operations": [
534
+ ... "identity",
535
+ ... ],
536
+ ... }
537
+ ... )
538
+ >>> pipeline = Pipeline.from_config(config)
539
+ """
540
+ return pipeline_from_config(Config(config), **kwargs)
541
+
542
+ @classmethod
543
+ def from_yaml(cls, file_path: str, **kwargs) -> "Pipeline":
544
+ """Create a pipeline from a YAML file.
545
+
546
+ Args:
547
+ file_path (str): Path to the YAML file.
548
+ **kwargs: Additional keyword arguments to be passed to the pipeline.
549
+
550
+ Note:
551
+ Must have an `operations` key.
552
+
553
+ Example:
554
+ .. doctest::
555
+
556
+ >>> import yaml
557
+ >>> from zea import Config
558
+ >>> # Create a sample pipeline YAML file
559
+ >>> pipeline_dict = {
560
+ ... "operations": [
561
+ ... "identity",
562
+ ... ],
563
+ ... }
564
+ >>> with open("pipeline.yaml", "w") as f:
565
+ ... yaml.dump(pipeline_dict, f)
566
+ >>> from zea.ops import Pipeline
567
+ >>> pipeline = Pipeline.from_yaml("pipeline.yaml", jit_options=None)
568
+ """
569
+ return pipeline_from_yaml(file_path, **kwargs)
570
+
571
+ @classmethod
572
+ def from_json(cls, json_string: str, **kwargs) -> "Pipeline":
573
+ """Create a pipeline from a JSON string.
574
+
575
+ Args:
576
+ json_string (str): JSON string representing the pipeline.
577
+ **kwargs: Additional keyword arguments to be passed to the pipeline.
578
+
579
+ Note:
580
+ Must have the `operations` key.
581
+
582
+ Example:
583
+ ```python
584
+ json_string = '{"operations": ["identity"]}'
585
+ pipeline = Pipeline.from_json(json_string)
586
+ ```
587
+ """
588
+ return pipeline_from_json(json_string, **kwargs)
589
+
590
+ def to_config(self) -> Config:
591
+ """Convert the pipeline to a `zea.Config` object."""
592
+ return pipeline_to_config(self)
593
+
594
+ def to_json(self) -> str:
595
+ """Convert the pipeline to a JSON string."""
596
+ return pipeline_to_json(self)
597
+
598
+ def to_yaml(self, file_path: str) -> None:
599
+ """Convert the pipeline to a YAML file."""
600
+ pipeline_to_yaml(self, file_path)
601
+
602
+ @property
603
+ def key(self) -> str:
604
+ """Input key of the pipeline."""
605
+ return self.operations[0].key
606
+
607
+ @property
608
+ def output_key(self) -> str:
609
+ """Output key of the pipeline."""
610
+ return self.operations[-1].output_key
611
+
612
+ def __eq__(self, other):
613
+ """Check if two pipelines are equal."""
614
+ if not isinstance(other, Pipeline):
615
+ return False
616
+
617
+ # Compare the operations in both pipelines
618
+ if len(self.operations) != len(other.operations):
619
+ return False
620
+
621
+ for op1, op2 in zip(self.operations, other.operations):
622
+ if not op1 == op2:
623
+ return False
624
+
625
+ return True
626
+
627
+ def prepare_parameters(
628
+ self,
629
+ probe: Probe = None,
630
+ scan: Scan = None,
631
+ config: Config = None,
632
+ **kwargs,
633
+ ):
634
+ """Prepare Probe, Scan and Config objects for the pipeline.
635
+
636
+ Serializes `zea.core.Object` instances and converts them to
637
+ a dictionary of tensors.
638
+
639
+ Args:
640
+ probe: Probe object.
641
+ scan: Scan object.
642
+ config: Config object.
643
+ **kwargs: Additional keyword arguments to be included in the inputs.
644
+
645
+ Returns:
646
+ dict: Dictionary of inputs with all values as tensors.
647
+ """
648
+ # Initialize dictionaries for probe, scan, and config
649
+ probe_dict, scan_dict, config_dict = {}, {}, {}
650
+
651
+ # Process args to extract Probe, Scan, and Config objects
652
+ if probe is not None:
653
+ assert isinstance(probe, Probe), (
654
+ f"Expected an instance of `zea.probes.Probe`, got {type(probe)}"
655
+ )
656
+ probe_dict = probe.to_tensor(keep_as_is=self.static_params)
657
+
658
+ if scan is not None:
659
+ assert isinstance(scan, Scan), (
660
+ f"Expected an instance of `zea.scan.Scan`, got {type(scan)}"
661
+ )
662
+ scan_dict = scan.to_tensor(include=self.needs_keys, keep_as_is=self.static_params)
663
+
664
+ if config is not None:
665
+ assert isinstance(config, Config), (
666
+ f"Expected an instance of `zea.config.Config`, got {type(config)}"
667
+ )
668
+ config_dict.update(config.to_tensor(keep_as_is=self.static_params))
669
+
670
+ # Convert all kwargs to tensors
671
+ tensor_kwargs = dict_to_tensor(kwargs, keep_as_is=self.static_params)
672
+
673
+ # combine probe, scan, config and kwargs
674
+ # explicitly so we know which keys overwrite which
675
+ # kwargs > config > scan > probe
676
+ inputs = {
677
+ **probe_dict,
678
+ **scan_dict,
679
+ **config_dict,
680
+ **tensor_kwargs,
681
+ }
682
+
683
+ return inputs
684
+
685
+
686
+ @ops_registry("branched_pipeline")
687
+ class BranchedPipeline(Operation):
688
+ """Operation that processes data through multiple branches.
689
+
690
+ This operation takes input data, processes it through multiple parallel branches,
691
+ and then merges the results from those branches using the specified merge strategy.
692
+ """
693
+
694
+ def __init__(self, branches=None, merge_strategy="nested", **kwargs):
695
+ """Initialize a branched pipeline.
696
+
697
+ Args:
698
+ branches (List[Union[List, Pipeline, Operation]]): List of branch operations
699
+ merge_strategy (str or callable): How to merge the outputs from branches:
700
+ - "nested" (default): Return outputs as a dictionary keyed by branch name
701
+ - "flatten": Flatten outputs by prefixing keys with the branch name
702
+ - "suffix": Flatten outputs by suffixing keys with the branch name
703
+ - callable: A custom merge function that accepts the branch outputs dict
704
+ **kwargs: Additional arguments for the Operation base class
705
+ """
706
+ super().__init__(**kwargs)
707
+
708
+ # Convert branch specifications to operation chains
709
+ if branches is None:
710
+ branches = []
711
+
712
+ self.branches = {}
713
+ for i, branch in enumerate(branches, start=1):
714
+ branch_name = f"branch_{i}"
715
+ # Convert different branch specification types
716
+ if isinstance(branch, list):
717
+ # Convert list to operation chain
718
+ self.branches[branch_name] = make_operation_chain(branch)
719
+ elif isinstance(branch, (Pipeline, Operation)):
720
+ # Already a pipeline or operation
721
+ self.branches[branch_name] = branch
722
+ else:
723
+ raise ValueError(
724
+ f"Branch must be a list, Pipeline, or Operation, got {type(branch)}"
725
+ )
726
+
727
+ # Set merge strategy
728
+ self.merge_strategy = merge_strategy
729
+ if isinstance(merge_strategy, str):
730
+ if merge_strategy == "nested":
731
+ self._merge_function = lambda outputs: outputs
732
+ elif merge_strategy == "flatten":
733
+ self._merge_function = self.flatten_outputs
734
+ elif merge_strategy == "suffix":
735
+ self._merge_function = self.suffix_merge_outputs
736
+ else:
737
+ raise ValueError(f"Unknown merge_strategy: {merge_strategy}")
738
+ elif callable(merge_strategy):
739
+ self._merge_function = merge_strategy
740
+ else:
741
+ raise ValueError("Invalid merge_strategy type provided.")
742
+
743
+ def call(self, **kwargs):
744
+ """Process input through branches and merge results.
745
+
746
+ Args:
747
+ **kwargs: Input keyword arguments
748
+
749
+ Returns:
750
+ dict: Merged outputs from all branches according to merge strategy
751
+ """
752
+ branch_outputs = {}
753
+ for branch_name, branch in self.branches.items():
754
+ # Each branch gets a fresh copy of kwargs to avoid interference
755
+ branch_kwargs = kwargs.copy()
756
+
757
+ # Process through the branch
758
+ branch_result = branch(**branch_kwargs)
759
+
760
+ # Store branch outputs
761
+ branch_outputs[branch_name] = branch_result
762
+
763
+ # Apply merge strategy to combine outputs
764
+ merged_outputs = self._merge_function(branch_outputs)
765
+
766
+ return merged_outputs
767
+
768
+ def flatten_outputs(self, outputs: dict) -> dict:
769
+ """
770
+ Flatten a nested dictionary by prefixing keys with the branch name.
771
+ For each branch, the resulting key is "{branch_name}_{original_key}".
772
+ """
773
+ flat = {}
774
+ for branch_name, branch_dict in outputs.items():
775
+ for key, value in branch_dict.items():
776
+ new_key = f"{branch_name}_{key}"
777
+ if new_key in flat:
778
+ raise ValueError(f"Key collision detected for {new_key}")
779
+ flat[new_key] = value
780
+ return flat
781
+
782
+ def suffix_merge_outputs(self, outputs: dict) -> dict:
783
+ """
784
+ Flatten a nested dictionary by suffixing keys with the branch name.
785
+ For each branch, the resulting key is "{original_key}_{branch_name}".
786
+ """
787
+ flat = {}
788
+ for branch_name, branch_dict in outputs.items():
789
+ for key, value in branch_dict.items():
790
+ new_key = f"{key}_{branch_name}"
791
+ if new_key in flat:
792
+ raise ValueError(f"Key collision detected for {new_key}")
793
+ flat[new_key] = value
794
+ return flat
795
+
796
+ def get_config(self):
797
+ """Return the config dictionary for serialization."""
798
+ config = super().get_config()
799
+
800
+ # Add branch configurations
801
+ branch_configs = {}
802
+ for branch_name, branch in self.branches.items():
803
+ if isinstance(branch, Pipeline):
804
+ # Get the operations list from the Pipeline
805
+ branch_configs[branch_name] = branch.get_config()
806
+ elif isinstance(branch, list):
807
+ # Convert list of operations to list of operation configs
808
+ branch_op_configs = []
809
+ for op in branch:
810
+ branch_op_configs.append(op.get_config())
811
+ branch_configs[branch_name] = {"operations": branch_op_configs}
812
+ else:
813
+ # Single operation
814
+ branch_configs[branch_name] = branch.get_config()
815
+
816
+ # Add merge strategy
817
+ if isinstance(self.merge_strategy, str):
818
+ merge_strategy_config = self.merge_strategy
819
+ else:
820
+ # For custom functions, use the name if available
821
+ merge_strategy_config = getattr(self.merge_strategy, "__name__", "custom")
822
+
823
+ config.update(
824
+ {
825
+ "branches": branch_configs,
826
+ "merge_strategy": merge_strategy_config,
827
+ }
828
+ )
829
+
830
+ return config
831
+
832
+ def get_dict(self):
833
+ """Get the configuration of the operation."""
834
+ config = super().get_dict()
835
+ config.update({"name": "branched_pipeline"})
836
+
837
+ # Add branches (recursively) to the config
838
+ branches = {}
839
+ for branch_name, branch in self.branches.items():
840
+ if isinstance(branch, Pipeline):
841
+ branches[branch_name] = branch.get_dict()
842
+ elif isinstance(branch, list):
843
+ branches[branch_name] = [op.get_dict() for op in branch]
844
+ else:
845
+ branches[branch_name] = branch.get_dict()
846
+ config["branches"] = branches
847
+ config["merge_strategy"] = self.merge_strategy
848
+ return config
849
+
850
+
851
+ @ops_registry("map")
852
+ class Map(Pipeline):
853
+ """
854
+ A pipeline that maps its operations over specified input arguments.
855
+
856
+ This can be used to reduce memory usage by processing data in chunks.
857
+
858
+ Notes
859
+ -----
860
+ - When `chunks` and `batch_size` are both None (default), this behaves like a normal Pipeline.
861
+ - Only ``self.output_key`` is propagated from the mapped operations; changes to other keys are not.
862
+ - Will be jitted as a single operation, not the individual operations.
863
+ - This class handles the batching.
864
+
865
+ For more information on how to use ``in_axes``, ``out_axes``, `see the documentation for
866
+ jax.vmap <https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html>`_.
867
+
868
+ Example
869
+ -------
870
+ .. doctest::
871
+
872
+ >>> from zea.ops import Map, Pipeline, Demodulate, TOFCorrection
873
+
874
+ >>> # apply operations in batches of 8
875
+ >>> # in this case, over the first axis of "data"
876
+ >>> # or more specifically, process 8 transmits at a time
877
+
878
+ >>> pipeline_mapped = Map(
879
+ ... [
880
+ ... Demodulate(),
881
+ ... TOFCorrection(),
882
+ ... ],
883
+ ... argnames="data",
884
+ ... batch_size=8,
885
+ ... )
886
+
887
+ >>> # you can also map a subset of the operations
888
+ >>> # for example, demodulate in 4 chunks
889
+ >>> # or more specifically, split the transmit axis into 4 parts
890
+
891
+ >>> pipeline_mapped = Pipeline(
892
+ ... [
893
+ ... Map([Demodulate()], argnames="data", chunks=4),
894
+ ... TOFCorrection(),
895
+ ... ],
896
+ ... )
897
+ """
898
+
899
+ def __init__(
900
+ self,
901
+ operations: List[Operation],
902
+ argnames: List[str] | str,
903
+ in_axes: List[Union[int, None]] | int = 0,
904
+ out_axes: List[Union[int, None]] | int = 0,
905
+ chunks: int | None = None,
906
+ batch_size: int | None = None,
907
+ **kwargs,
908
+ ):
909
+ """
910
+ Args:
911
+ operations (list): List of operations to be performed.
912
+ argnames (str or list): List of argument names (or keys) to map over.
913
+ Can also be a single string if only one argument is mapped over.
914
+ in_axes (int or list): Axes to map over for each argument.
915
+ If a single int is provided, it is used for all arguments.
916
+ out_axes (int or list): Axes to map over for each output.
917
+ If a single int is provided, it is used for all outputs.
918
+ chunks (int, optional): Number of chunks to split the input data into.
919
+ If None, no chunking is performed. Mutually exclusive with ``batch_size``.
920
+ batch_size (int, optional): Size of batches to process at once.
921
+ If None, no batching is performed. Mutually exclusive with ``chunks``.
922
+ """
923
+ super().__init__(operations, **kwargs)
924
+
925
+ if batch_size is not None and chunks is not None:
926
+ raise ValueError(
927
+ "batch_size and chunks are mutually exclusive. Please specify only one."
928
+ )
929
+
930
+ if batch_size is not None and batch_size <= 0:
931
+ raise ValueError("batch_size must be a positive integer.")
932
+
933
+ if chunks is not None and chunks <= 0:
934
+ raise ValueError("chunks must be a positive integer.")
935
+
936
+ if isinstance(argnames, str):
937
+ argnames = [argnames]
938
+
939
+ self.argnames = argnames
940
+ self.in_axes = in_axes
941
+ self.out_axes = out_axes
942
+ self.chunks = chunks
943
+ self.batch_size = batch_size
944
+
945
+ if chunks is None and batch_size is None:
946
+ log.warning(
947
+ "[zea.ops.Map] Both `chunks` and `batch_size` are None. "
948
+ "This will behave like a normal Pipeline. "
949
+ "Consider setting one of them to process data in chunks or batches."
950
+ )
951
+
952
+ def call_item(**inputs):
953
+ """Process data in patches."""
954
+ mapped_args = []
955
+ for argname in argnames:
956
+ mapped_args.append(inputs.pop(argname, None))
957
+
958
+ def patched_call(*args):
959
+ mapped_kwargs = [(k, v) for k, v in zip(argnames, args)]
960
+ out = super(Map, self).call(**dict(mapped_kwargs), **inputs)
961
+
962
+ # TODO: maybe it is possible to output everything?
963
+ # e.g. prepend an empty dimension to all inputs and just map over everything?
964
+ return out[self.output_key]
965
+
966
+ out = vmap(
967
+ patched_call,
968
+ in_axes=in_axes,
969
+ out_axes=out_axes,
970
+ chunks=chunks,
971
+ batch_size=batch_size,
972
+ fn_supports_batch=True,
973
+ disable_jit=not bool(self.jit_options),
974
+ )(*mapped_args)
975
+
976
+ return out
977
+
978
+ self.call_item = call_item
979
+
980
+ @property
981
+ def jit_options(self):
982
+ """Get the jit_options property of the pipeline."""
983
+ return self._jit_options
984
+
985
+ @jit_options.setter
986
+ def jit_options(self, value):
987
+ """Set the jit_options property of the pipeline."""
988
+ self._jit_options = value
989
+ if value in ["pipeline", "ops"]:
990
+ self.jit()
991
+ else:
992
+ self.unjit()
993
+
994
+ def jit(self):
995
+ """JIT compile the pipeline."""
996
+ self._jittable_call = jit(self.jittable_call, **self.jit_kwargs)
997
+
998
+ def unjit(self):
999
+ """Un-JIT compile the pipeline."""
1000
+ self._jittable_call = self.jittable_call
1001
+ self._call_pipeline = self.call
1002
+
1003
+ @property
1004
+ def with_batch_dim(self):
1005
+ """Get the with_batch_dim property of the pipeline."""
1006
+ return self._with_batch_dim
1007
+
1008
+ @with_batch_dim.setter
1009
+ def with_batch_dim(self, value):
1010
+ """Set the with_batch_dim property of the pipeline.
1011
+ The Map class handles the batching itself, so ``with_batch_dim`` is set to False for the individual operations."""
1012
+ self._with_batch_dim = value
1013
+ for operation in self.operations:
1014
+ operation.with_batch_dim = False
1015
+
1016
+ def jittable_call(self, **inputs):
1017
+ """Process input data through the pipeline."""
1018
+ if self._with_batch_dim:
1019
+ input_data = inputs.pop(self.key)
1020
+ output = ops.map(
1021
+ lambda x: self.call_item(**{self.key: x, **inputs}),
1022
+ input_data,
1023
+ )
1024
+ else:
1025
+ output = self.call_item(**inputs)
1026
+
1027
+ return {self.output_key: output}
1028
+
1029
+ def call(self, **inputs):
1030
+ """Process input data through the pipeline."""
1031
+ output = self._jittable_call(**inputs)
1032
+ inputs.update(output)
1033
+ return inputs
1034
+
1035
+ def get_dict(self):
1036
+ """Get the configuration of the pipeline."""
1037
+ config = super().get_dict()
1038
+ config.update({"name": "map"})
1039
+
1040
+ config["params"].update(
1041
+ {
1042
+ "argnames": self.argnames,
1043
+ "in_axes": self.in_axes,
1044
+ "out_axes": self.out_axes,
1045
+ "chunks": self.chunks,
1046
+ "batch_size": self.batch_size,
1047
+ }
1048
+ )
1049
+ return config
1050
+
1051
+
1052
+ @ops_registry("patched_grid")
1053
+ class PatchedGrid(Map):
1054
+ """
1055
+ A pipeline that maps its operations over `flatgrid` and `flat_pfield` keys.
1056
+
1057
+ This can be used to reduce memory usage by processing data in chunks.
1058
+
1059
+ For more information and flexibility, see :class:`zea.ops.Map`.
1060
+ """
1061
+
1062
+ def __init__(self, *args, num_patches=10, **kwargs):
1063
+ super().__init__(*args, argnames=["flatgrid", "flat_pfield"], chunks=num_patches, **kwargs)
1064
+ self.num_patches = num_patches
1065
+
1066
+ def get_dict(self):
1067
+ """Get the configuration of the pipeline."""
1068
+ config = super().get_dict()
1069
+ config.update({"name": "patched_grid"})
1070
+
1071
+ config["params"].pop("argnames")
1072
+ config["params"].pop("chunks")
1073
+ config["params"].update({"num_patches": self.num_patches})
1074
+ return config
1075
+
1076
+
1077
+ @ops_registry("beamform")
1078
+ class Beamform(Pipeline):
1079
+ """Classical beamforming pipeline for ultrasound image formation.
1080
+
1081
+ Expected input data type is `DataTypes.RF_DATA` which has shape `(n_tx, n_ax, n_el, n_ch)`.
1082
+
1083
+ Will run the following operations in sequence:
1084
+ - TOFCorrection (output type `DataTypes.ALIGNED_DATA`: `(n_tx, n_ax, n_el, n_ch)`)
1085
+ - PfieldWeighting (optional, output type `DataTypes.ALIGNED_DATA`: `(n_tx, n_ax, n_el, n_ch)`)
1086
+ - Sum over channels (DAS)
1087
+ - Sum over transmits (Compounding) (output type `DataTypes.BEAMFORMED_DATA`: `(grid_size_z, grid_size_x, n_ch)`)
1088
+ - ReshapeGrid (flattened grid is also reshaped to `(grid_size_z, grid_size_x)`)
1089
+ """ # noqa: E501
1090
+
1091
+ def __init__(self, beamformer="delay_and_sum", num_patches=100, enable_pfield=False, **kwargs):
1092
+ """Initialize a Delay-and-Sum beamforming `zea.Pipeline`.
1093
+
1094
+ Args:
1095
+ beamformer (str): Type of beamformer to use. Currently supported:
1096
+ "delay_and_sum" and "delay_multiply_and_sum".
1097
+ num_patches (int): Number of patches to split the grid into for patch-wise
1098
+ beamforming. If 1, no patching is performed.
1099
+ enable_pfield (bool): Whether to include pressure field weighting in the beamforming.
1100
+
1101
+ """
1102
+
1103
+ self.beamformer_type = beamformer
1104
+ self.num_patches = num_patches
1105
+ self.enable_pfield = enable_pfield
1106
+
1107
+ # for backwards compatibility
1108
+ name_mapping = {
1109
+ "das": "delay_and_sum",
1110
+ "dmas": "delay_multiply_and_sum",
1111
+ }
1112
+ if beamformer in name_mapping:
1113
+ log.deprecated(
1114
+ f"Beamformer name '{beamformer}' is deprecated. "
1115
+ f"Please use '{name_mapping[beamformer]}' instead."
1116
+ )
1117
+ self.beamformer_type = name_mapping[beamformer]
1118
+
1119
+ if self.beamformer_type not in ["delay_and_sum", "delay_multiply_and_sum"]:
1120
+ raise ValueError(
1121
+ f"Unsupported beamformer type: {self.beamformer_type}. "
1122
+ "Supported types are 'delay_and_sum' and 'delay_multiply_and_sum'."
1123
+ )
1124
+
1125
+ # Get beamforming ops
1126
+ beamforming = [
1127
+ TOFCorrection(),
1128
+ # PfieldWeighting(), # Inserted conditionally
1129
+ get_ops(self.beamformer_type)(),
1130
+ ]
1131
+
1132
+ if self.enable_pfield:
1133
+ beamforming.insert(1, PfieldWeighting())
1134
+
1135
+ # Optionally add patching
1136
+ if self.num_patches > 1:
1137
+ beamforming = [
1138
+ PatchedGrid(
1139
+ operations=beamforming,
1140
+ num_patches=self.num_patches,
1141
+ **kwargs,
1142
+ )
1143
+ ]
1144
+
1145
+ # Reshape the grid to image shape
1146
+ beamforming.append(ReshapeGrid())
1147
+
1148
+ # Set the output data type of the last operation
1149
+ # which also defines the pipeline output type
1150
+ beamforming[-1].output_data_type = DataTypes.BEAMFORMED_DATA
1151
+
1152
+ super().__init__(operations=beamforming, **kwargs)
1153
+
1154
+ def __repr__(self):
1155
+ """String representation of the pipeline."""
1156
+ operations = []
1157
+ for operation in self.operations:
1158
+ if isinstance(operation, Pipeline):
1159
+ operations.append(repr(operation))
1160
+ else:
1161
+ operations.append(operation.__class__.__name__)
1162
+ return f"<Beamform {self.name}=({', '.join(operations)})>"
1163
+
1164
+ def get_dict(self) -> dict:
1165
+ """Convert the pipeline to a dictionary."""
1166
+ config = super().get_dict()
1167
+ config.update({"name": "beamform"})
1168
+ config["params"].update(
1169
+ {
1170
+ "beamformer": self.beamformer_type,
1171
+ "num_patches": self.num_patches,
1172
+ "enable_pfield": self.enable_pfield,
1173
+ }
1174
+ )
1175
+ return config
1176
+
1177
+
1178
+ @ops_registry("delay_and_sum")
1179
+ class DelayAndSum(Operation):
1180
+ """Sums time-delayed signals along channels and transmits."""
1181
+
1182
+ def __init__(self, **kwargs):
1183
+ super().__init__(
1184
+ input_data_type=DataTypes.ALIGNED_DATA,
1185
+ output_data_type=DataTypes.BEAMFORMED_DATA,
1186
+ **kwargs,
1187
+ )
1188
+
1189
+ def call(self, grid=None, **kwargs):
1190
+ """Performs DAS beamforming on tof-corrected input.
1191
+
1192
+ Args:
1193
+ tof_corrected_data (ops.Tensor): The TOF corrected input of shape
1194
+ `(n_tx, grid_size_z*grid_size_x, n_el, n_ch)` with optional batch dimension.
1195
+
1196
+ Returns:
1197
+ dict: Dictionary containing beamformed_data
1198
+ of shape `(grid_size_z*grid_size_x, n_ch)`
1199
+ with optional batch dimension.
1200
+ """
1201
+ data = kwargs[self.key]
1202
+
1203
+ # Sum over the channels (n_el), i.e. DAS
1204
+ beamformed_data = ops.sum(data, -2)
1205
+ # Sum over transmits (n_tx), i.e. Compounding
1206
+ beamformed_data = ops.sum(beamformed_data, -3)
1207
+
1208
+ return {self.output_key: beamformed_data}
1209
+
1210
+
1211
+ @ops_registry("delay_multiply_and_sum")
1212
+ class DelayMultiplyAndSum(Operation):
1213
+ """Performs the operations for the Delay-Multiply-and-Sum beamformer except the delay.
1214
+ The delay should be performed by the TOF correction operation.
1215
+ """
1216
+
1217
+ def __init__(self, **kwargs):
1218
+ super().__init__(
1219
+ input_data_type=DataTypes.ALIGNED_DATA,
1220
+ output_data_type=DataTypes.BEAMFORMED_DATA,
1221
+ **kwargs,
1222
+ )
1223
+
1224
+ def process_image(self, data):
1225
+ """Performs DMAS beamforming on tof-corrected input.
1226
+
1227
+ Args:
1228
+ data (ops.Tensor): The TOF corrected input of shape `(n_tx, n_pix, n_el, n_ch)`
1229
+
1230
+ Returns:
1231
+ ops.Tensor: The beamformed data of shape `(n_pix, n_ch)`
1232
+ """
1233
+
1234
+ if not data.shape[-1] == 2:
1235
+ raise ValueError(
1236
+ "MultiplyAndSum operation requires IQ data with 2 channels. "
1237
+ f"Got data with shape {data.shape}."
1238
+ )
1239
+
1240
+ # Compute the correlation matrix
1241
+ data = channels_to_complex(data)
1242
+
1243
+ data = self._multiply(data)
1244
+ data = self._select_lower_triangle(data)
1245
+ data = ops.sum(data, axis=(0, 2, 3))
1246
+
1247
+ data = complex_to_channels(data)
1248
+
1249
+ return data
1250
+
1251
+ def _select_lower_triangle(self, data):
1252
+ """Select only the lower triangle of the correlation matrix."""
1253
+ n_el = data.shape[3]
1254
+ mask = ops.ones((n_el, n_el), dtype=data.dtype) - ops.eye(n_el, dtype=data.dtype)
1255
+ data = data * mask[None, None, :, :] / 2
1256
+ return data
1257
+
1258
+ def _multiply(self, data):
1259
+ """Apply the DMAS multiplication step."""
1260
+ channel_products = data[:, :, :, None] * data[:, :, None, :]
1261
+
1262
+ data = ops.sign(channel_products) * ops.cast(
1263
+ ops.sqrt(ops.abs(channel_products)), data.dtype
1264
+ )
1265
+ return data
1266
+
1267
+ def call(self, grid=None, **kwargs):
1268
+ """Performs DMAS beamforming on tof-corrected input.
1269
+
1270
+ Args:
1271
+ tof_corrected_data (ops.Tensor): The TOF corrected input of shape
1272
+ `(n_tx, grid_size_z*grid_size_x, n_el, n_ch)` with optional batch dimension.
1273
+
1274
+ Returns:
1275
+ dict: Dictionary containing beamformed_data
1276
+ of shape `(grid_size_z*grid_size_x, n_ch)`
1277
+ with optional batch dimension.
1278
+ """
1279
+ data = kwargs[self.key]
1280
+
1281
+ if not self.with_batch_dim:
1282
+ beamformed_data = self.process_image(data)
1283
+ else:
1284
+ # Apply process_image to each item in the batch
1285
+ beamformed_data = ops.map(self.process_image, data)
1286
+
1287
+ return {self.output_key: beamformed_data}
1288
+
1289
+
1290
+ def make_operation_chain(
1291
+ operation_chain: List[Union[str, Dict, Config, Operation, Pipeline]],
1292
+ ) -> List[Operation]:
1293
+ """Make an operation chain from a custom list of operations.
1294
+
1295
+ Args:
1296
+ operation_chain (list): List of operations to be performed.
1297
+ Each operation can be:
1298
+ - A string: operation initialized with default parameters
1299
+ - A dictionary: operation initialized with parameters in the dictionary
1300
+ - A Config object: converted to a dictionary and initialized
1301
+ - An Operation/Pipeline instance: used as-is
1302
+
1303
+ Returns:
1304
+ list: List of operations to be performed.
1305
+
1306
+ Example:
1307
+ .. doctest::
1308
+
1309
+ >>> from zea.ops import make_operation_chain, LogCompress
1310
+ >>> SomeCustomOperation = LogCompress # just for demonstration
1311
+ >>> chain = make_operation_chain(
1312
+ ... [
1313
+ ... "envelope_detect",
1314
+ ... {"name": "normalize", "params": {"output_range": (0, 1)}},
1315
+ ... SomeCustomOperation(),
1316
+ ... ]
1317
+ ... )
1318
+ """
1319
+ chain = []
1320
+ for operation in operation_chain:
1321
+ # Handle already instantiated Operation or Pipeline objects
1322
+ if isinstance(operation, (Operation, Pipeline)):
1323
+ chain.append(operation)
1324
+ continue
1325
+
1326
+ assert isinstance(operation, (str, dict, Config)), (
1327
+ f"Operation {operation} should be a string, dict, Config object, Operation, or Pipeline"
1328
+ )
1329
+
1330
+ if isinstance(operation, str):
1331
+ operation_instance = get_ops(operation)()
1332
+
1333
+ else:
1334
+ if isinstance(operation, Config):
1335
+ operation = operation.serialize()
1336
+
1337
+ params = operation.get("params", {})
1338
+ op_name = operation.get("name")
1339
+ operation_cls = get_ops(op_name)
1340
+
1341
+ # Handle branches for branched pipeline
1342
+ if op_name == "branched_pipeline" and "branches" in operation:
1343
+ branch_configs = operation.get("branches", {})
1344
+ branches = []
1345
+
1346
+ # Convert each branch configuration to an operation chain
1347
+ for _, branch_config in branch_configs.items():
1348
+ if isinstance(branch_config, (list, np.ndarray)):
1349
+ # This is a list of operations
1350
+ branch = make_operation_chain(branch_config)
1351
+ elif "operations" in branch_config:
1352
+ # This is a pipeline-like branch
1353
+ branch = make_operation_chain(branch_config["operations"])
1354
+ else:
1355
+ # This is a single operation branch
1356
+ branch_op_cls = get_ops(branch_config["name"])
1357
+ branch_params = branch_config.get("params", {})
1358
+ branch = branch_op_cls(**branch_params)
1359
+
1360
+ branches.append(branch)
1361
+
1362
+ # Create the branched pipeline instance
1363
+ operation_instance = operation_cls(branches=branches, **params)
1364
+ # Check for nested operations at the same level as params
1365
+ elif "operations" in operation:
1366
+ nested_operations = make_operation_chain(operation["operations"])
1367
+ # Instantiate pipeline-type operations with nested operations
1368
+ if issubclass(operation_cls, Beamform):
1369
+ # some pipelines, such as `zea.ops.Beamform`, are initialized
1370
+ # not with a list of operations but with other parameters that then
1371
+ # internally create a list of operations
1372
+ operation_instance = operation_cls(**params)
1373
+ elif issubclass(operation_cls, Pipeline):
1374
+ # in most cases we want to pass an operations list to
1375
+ # initialize a pipeline
1376
+ operation_instance = operation_cls(operations=nested_operations, **params)
1377
+ else:
1378
+ operation_instance = operation_cls(operations=nested_operations, **params)
1379
+ elif operation["name"] in ["patched_grid"]:
1380
+ nested_operations = make_operation_chain(operation["params"].pop("operations"))
1381
+ operation_instance = operation_cls(operations=nested_operations, **params)
1382
+ else:
1383
+ operation_instance = operation_cls(**params)
1384
+
1385
+ chain.append(operation_instance)
1386
+
1387
+ return chain
1388
+
1389
+
1390
+ def pipeline_from_config(config: Config, **kwargs) -> Pipeline:
1391
+ """
1392
+ Create a Pipeline instance from a Config object.
1393
+ """
1394
+ assert "operations" in config, (
1395
+ "Config object must have an 'operations' key for pipeline creation."
1396
+ )
1397
+ assert isinstance(config.operations, (list, np.ndarray)), (
1398
+ "Config object must have a list or numpy array of operations for pipeline creation."
1399
+ )
1400
+
1401
+ operations = make_operation_chain(config.operations)
1402
+
1403
+ # merge pipeline config without operations with kwargs
1404
+ pipeline_config = config.copy()
1405
+ pipeline_config.pop("operations")
1406
+
1407
+ kwargs = {**pipeline_config, **kwargs}
1408
+ return Pipeline(operations=operations, **kwargs)
1409
+
1410
+
1411
+ def pipeline_from_json(json_string: str, **kwargs) -> Pipeline:
1412
+ """
1413
+ Create a Pipeline instance from a JSON string.
1414
+ """
1415
+ pipeline_config = Config(json.loads(json_string, cls=ZEADecoderJSON))
1416
+ return pipeline_from_config(pipeline_config, **kwargs)
1417
+
1418
+
1419
+ def pipeline_from_yaml(yaml_path: str, **kwargs) -> Pipeline:
1420
+ """
1421
+ Create a Pipeline instance from a YAML file.
1422
+ """
1423
+ with open(yaml_path, "r", encoding="utf-8") as f:
1424
+ pipeline_config = yaml.safe_load(f)
1425
+ operations = pipeline_config["operations"]
1426
+ return pipeline_from_config(Config({"operations": operations}), **kwargs)
1427
+
1428
+
1429
+ def pipeline_to_config(pipeline: Pipeline) -> Config:
1430
+ """
1431
+ Convert a Pipeline instance into a Config object.
1432
+ """
1433
+ # TODO: we currently add the full pipeline as 1 operation to the config.
1434
+ # In another PR we should add a "pipeline" entry to the config instead of the "operations"
1435
+ # entry. This allows us to also have non-default pipeline classes as top level op.
1436
+ pipeline_dict = {"operations": [pipeline.get_dict()]}
1437
+
1438
+ # HACK: If the top level operation is a single pipeline, collapse it into the operations list.
1439
+ ops = pipeline_dict["operations"]
1440
+ if ops[0]["name"] == "pipeline" and len(ops) == 1:
1441
+ pipeline_dict = {"operations": ops[0]["operations"]}
1442
+
1443
+ return Config(pipeline_dict)
1444
+
1445
+
1446
+ def pipeline_to_json(pipeline: Pipeline) -> str:
1447
+ """
1448
+ Convert a Pipeline instance into a JSON string.
1449
+ """
1450
+ pipeline_dict = {"operations": [pipeline.get_dict()]}
1451
+
1452
+ # HACK: If the top level operation is a single pipeline, collapse it into the operations list.
1453
+ ops = pipeline_dict["operations"]
1454
+ if ops[0]["name"] == "pipeline" and len(ops) == 1:
1455
+ pipeline_dict = {"operations": ops[0]["operations"]}
1456
+
1457
+ return json.dumps(pipeline_dict, cls=ZEAEncoderJSON, indent=4)
1458
+
1459
+
1460
+ def pipeline_to_yaml(pipeline: Pipeline, file_path: str) -> None:
1461
+ """
1462
+ Convert a Pipeline instance into a YAML file.
1463
+ """
1464
+ pipeline_dict = pipeline.get_dict()
1465
+
1466
+ # HACK: If the top level operation is a single pipeline, collapse it into the operations list.
1467
+ ops = pipeline_dict["operations"]
1468
+ if ops[0]["name"] == "pipeline" and len(ops) == 1:
1469
+ pipeline_dict = {"operations": ops[0]["operations"]}
1470
+
1471
+ with open(file_path, "w", encoding="utf-8") as f:
1472
+ yaml.dump(pipeline_dict, f, Dumper=yaml.Dumper, indent=4)
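The module closes with the serialization helpers (`pipeline_from_config`, `pipeline_from_json`, `pipeline_from_yaml`, and their `pipeline_to_*` counterparts). Below is a small round-trip sketch that only uses operation names and call forms already shown in the docstrings above ("envelope_detect", "normalize"); whether every operation's default parameters survive a round trip unchanged is an assumption, not something this diff guarantees.

```python
# Round-trip sketch for the serialization helpers defined above.
from zea import Config
from zea.ops import Pipeline

# Build a pipeline from an in-memory Config (mirrors the from_config docstring).
pipeline = Pipeline.from_config(
    Config({"operations": ["envelope_detect", "normalize"]}), jit_options=None
)

# The same thing from a JSON string (mirrors the from_json docstring).
pipeline = Pipeline.from_json(
    '{"operations": ["envelope_detect", "normalize"]}', jit_options=None
)

# to_json() collapses the top-level pipeline into an "operations" list,
# which from_json() accepts again.
json_str = pipeline.to_json()
restored = Pipeline.from_json(json_str, jit_options=None)
print(restored == pipeline)  # Pipeline.__eq__ compares the operation lists element-wise
```

The YAML helpers (`Pipeline.to_yaml` / `Pipeline.from_yaml`) use the same `operations` layout, written to and read from a file instead of a string.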