zea 0.0.7__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43):
  1. zea/__init__.py +1 -1
  2. zea/backend/tensorflow/dataloader.py +0 -4
  3. zea/beamform/pixelgrid.py +1 -1
  4. zea/data/__init__.py +0 -9
  5. zea/data/augmentations.py +221 -28
  6. zea/data/convert/__init__.py +1 -6
  7. zea/data/convert/__main__.py +123 -0
  8. zea/data/convert/camus.py +99 -39
  9. zea/data/convert/echonet.py +183 -82
  10. zea/data/convert/echonetlvh/README.md +2 -3
  11. zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +173 -102
  12. zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
  13. zea/data/convert/echonetlvh/precompute_crop.py +43 -64
  14. zea/data/convert/picmus.py +37 -40
  15. zea/data/convert/utils.py +86 -0
  16. zea/data/convert/{matlab.py → verasonics.py} +33 -61
  17. zea/data/data_format.py +124 -4
  18. zea/data/dataloader.py +12 -7
  19. zea/data/datasets.py +109 -70
  20. zea/data/file.py +91 -82
  21. zea/data/file_operations.py +496 -0
  22. zea/data/preset_utils.py +1 -1
  23. zea/display.py +7 -8
  24. zea/internal/checks.py +6 -12
  25. zea/internal/operators.py +4 -0
  26. zea/io_lib.py +108 -160
  27. zea/models/__init__.py +1 -1
  28. zea/models/diffusion.py +62 -11
  29. zea/models/lv_segmentation.py +2 -0
  30. zea/ops.py +398 -158
  31. zea/scan.py +18 -8
  32. zea/tensor_ops.py +82 -62
  33. zea/tools/fit_scan_cone.py +90 -160
  34. zea/tracking/__init__.py +16 -0
  35. zea/tracking/base.py +94 -0
  36. zea/tracking/lucas_kanade.py +474 -0
  37. zea/tracking/segmentation.py +110 -0
  38. zea/utils.py +11 -2
  39. {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/METADATA +3 -1
  40. {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/RECORD +43 -35
  41. {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/WHEEL +0 -0
  42. {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/entry_points.txt +0 -0
  43. {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/licenses/LICENSE +0 -0
zea/ops.py CHANGED
@@ -88,10 +88,10 @@ Example of a yaml file:
88
88
 
89
89
  """
90
90
 
91
- import copy
92
91
  import hashlib
93
92
  import inspect
94
93
  import json
94
+ import uuid
95
95
  from functools import partial
96
96
  from typing import Any, Dict, List, Union
97
97
 
@@ -120,7 +120,15 @@ from zea.internal.registry import ops_registry
120
120
  from zea.probes import Probe
121
121
  from zea.scan import Scan
122
122
  from zea.simulator import simulate_rf
123
- from zea.tensor_ops import resample, reshape_axis, translate, vmap
123
+ from zea.tensor_ops import (
124
+ apply_along_axis,
125
+ correlate,
126
+ extend_n_dims,
127
+ resample,
128
+ reshape_axis,
129
+ translate,
130
+ vmap,
131
+ )
124
132
  from zea.utils import (
125
133
  FunctionTimer,
126
134
  deep_compare,
@@ -238,7 +246,7 @@ class Operation(keras.Operation):
238
246
  Analyze and store the input/output signatures of the `call` method.
239
247
  """
240
248
  self._input_signature = inspect.signature(self.call)
241
- self._valid_keys = set(self._input_signature.parameters.keys())
249
+ self._valid_keys = set(self._input_signature.parameters.keys()) | {self.key}
242
250
 
243
251
  @property
244
252
  def valid_keys(self) -> set:
@@ -405,6 +413,36 @@ class Operation(keras.Operation):
405
413
  return True
406
414
 
407
415
 
416
+ class ImageOperation(Operation):
417
+ """
418
+ Base class for image processing operations.
419
+
420
+ This class extends the Operation class to provide a common interface
421
+ for operations that process image data, with shape (batch, height, width, channels)
422
+ or (height, width, channels) if batch dimension is not present.
423
+
424
+ Subclasses should implement the `call` method to define the image processing logic, and call
425
+ ``super().call(**kwargs)`` to validate the input data shape.
426
+ """
427
+
428
+ def call(self, **kwargs):
429
+ """
430
+ Validate input data shape for image operations.
431
+
432
+ Args:
433
+ **kwargs: Keyword arguments containing input data.
434
+
435
+ Raises:
436
+ AssertionError: If input data does not have the expected number of dimensions.
437
+ """
438
+ data = kwargs[self.key]
439
+
440
+ if self.with_batch_dim:
441
+ assert ops.ndim(data) == 4, "Input data must have 4 dimensions (b, h, w, c)."
442
+ else:
443
+ assert ops.ndim(data) == 3, "Input data must have 3 dimensions (h, w, c)."
444
+
445
+
408
446
  @ops_registry("pipeline")
409
447
  class Pipeline:
410
448
  """Pipeline class for processing ultrasound data through a series of operations."""
@@ -488,10 +526,20 @@ class Pipeline:
488
526
  self.jit_kwargs = jit_kwargs
489
527
  self.jit_options = jit_options # will handle the jit compilation
490
528
 
529
+ self._logged_difference_keys = False
530
+
531
+ # Do not log again for nested pipelines
532
+ for nested_pipeline in self._nested_pipelines:
533
+ nested_pipeline._logged_difference_keys = True
534
+
491
535
  def needs(self, key) -> bool:
492
536
  """Check if the pipeline needs a specific key at the input."""
493
537
  return key in self.needs_keys
494
538
 
539
+ @property
540
+ def _nested_pipelines(self):
541
+ return [operation for operation in self.operations if isinstance(operation, Pipeline)]
542
+
495
543
  @property
496
544
  def output_keys(self) -> set:
497
545
  """All output keys the pipeline guarantees to produce."""
@@ -576,6 +624,7 @@ class Pipeline:
576
624
 
577
625
  # Add display ops
578
626
  operations += [
627
+ ReshapeGrid(),
579
628
  EnvelopeDetect(),
580
629
  Normalize(),
581
630
  LogCompress(),
@@ -682,6 +731,16 @@ class Pipeline:
682
731
  "Please ensure all inputs are convertible to tensors."
683
732
  )
684
733
 
734
+ if not self._logged_difference_keys:
735
+ difference_keys = set(inputs.keys()) - self.valid_keys
736
+ if difference_keys:
737
+ log.debug(
738
+ f"[zea.Pipeline] The following input keys are not used by the pipeline: "
739
+ f"{difference_keys}. Make sure this is intended. "
740
+ "This warning will only be shown once."
741
+ )
742
+ self._logged_difference_keys = True
743
+
685
744
  ## PROCESSING
686
745
  outputs = self._call_pipeline(**inputs)
687
746
 
@@ -804,59 +863,9 @@ class Pipeline:
804
863
  return params
805
864
 
806
865
  def __str__(self):
807
- """String representation of the pipeline.
808
-
809
- Will print on two parallel pipeline lines if it detects a splitting operations
810
- (such as multi_bandpass_filter)
811
- Will merge the pipeline lines if it detects a stacking operation (such as stack)
812
- """
813
- split_operations = []
814
- merge_operations = ["Stack"]
815
-
866
+ """String representation of the pipeline."""
816
867
  operations = [operation.__class__.__name__ for operation in self.operations]
817
868
  string = " -> ".join(operations)
818
-
819
- if any(operation in split_operations for operation in operations):
820
- # a second line is needed with same length as the first line
821
- split_line = " " * len(string)
822
- # find the splitting operation and index and print \-> instead of -> after
823
- split_detected = False
824
- merge_detected = False
825
- split_operation = None
826
- for operation in operations:
827
- if operation in split_operations:
828
- index = string.index(operation)
829
- index = index + len(operation)
830
- split_line = split_line[:index] + "\\->" + split_line[index + len("\\->") :]
831
- split_detected = True
832
- merge_detected = False
833
- split_operation = operation
834
- continue
835
-
836
- if operation in merge_operations:
837
- index = string.index(operation)
838
- index = index - 4
839
- split_line = split_line[:index] + "/" + split_line[index + 1 :]
840
- split_detected = False
841
- merge_detected = True
842
- continue
843
-
844
- if split_detected:
845
- # print all operations in the second line
846
- index = string.index(operation)
847
- split_line = (
848
- split_line[:index]
849
- + operation
850
- + " -> "
851
- + split_line[index + len(operation) + len(" -> ") :]
852
- )
853
- assert merge_detected is True, log.error(
854
- "Pipeline was never merged back together (with Stack operation), even "
855
- f"though it was split with {split_operation}. "
856
- "Please properly define your operation chain."
857
- )
858
- return f"\n{string}\n{split_line}\n"
859
-
860
869
  return string
861
870
 
862
871
  def __repr__(self):
@@ -1185,7 +1194,7 @@ def pipeline_from_config(config: Config, **kwargs) -> Pipeline:
1185
1194
  operations = make_operation_chain(config.operations)
1186
1195
 
1187
1196
  # merge pipeline config without operations with kwargs
1188
- pipeline_config = copy.deepcopy(config)
1197
+ pipeline_config = config.copy()
1189
1198
  pipeline_config.pop("operations")
1190
1199
 
1191
1200
  kwargs = {**pipeline_config, **kwargs}
@@ -1256,33 +1265,134 @@ def pipeline_to_yaml(pipeline: Pipeline, file_path: str) -> None:
1256
1265
  yaml.dump(pipeline_dict, f, Dumper=yaml.Dumper, indent=4)
1257
1266
 
1258
1267
 
1259
- @ops_registry("patched_grid")
1260
- class PatchedGrid(Pipeline):
1268
+ @ops_registry("map")
1269
+ class Map(Pipeline):
1261
1270
  """
1262
- With this class you can form a pipeline that will be applied to patches of the grid.
1263
- This is useful to avoid OOM errors when processing large grids.
1271
+ A pipeline that maps its operations over specified input arguments.
1272
+
1273
+ This can be used to reduce memory usage by processing data in chunks.
1274
+
1275
+ Notes
1276
+ -----
1277
+ - When `chunks` and `batch_size` are both None (default), this behaves like a normal Pipeline.
1278
+ - Changing anything other than ``self.output_key`` in the dict will not be propagated.
1279
+ - Will be jitted as a single operation, not the individual operations.
1280
+ - This class handles the batching.
1264
1281
 
1265
- Some things to NOTE about this class:
1282
+ For more information on how to use ``in_axes``, ``out_axes``, `see the documentation for
1283
+ jax.vmap <https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html>`_.
1266
1284
 
1267
- - The ops have to use flatgrid and flat_pfield as inputs, these will be patched.
1285
+ Example
1286
+ -------
1287
+ .. doctest::
1268
1288
 
1269
- - Changing anything other than `self.output_data_type` in the dict will not be propagated!
1289
+ >>> from zea.ops import Map, Pipeline, Demodulate, TOFCorrection
1270
1290
 
1271
- - Will be jitted as a single operation, not the individual operations.
1291
+ >>> # apply operations in batches of 8
1292
+ >>> # in this case, over the first axis of "data"
1293
+ >>> # or more specifically, process 8 transmits at a time
1272
1294
 
1273
- - This class handles the batching.
1295
+ >>> pipeline_mapped = Map(
1296
+ ... [
1297
+ ... Demodulate(),
1298
+ ... TOFCorrection(),
1299
+ ... ],
1300
+ ... argnames="data",
1301
+ ... batch_size=8,
1302
+ ... )
1303
+
1304
+ >>> # you can also map a subset of the operations
1305
+ >>> # for example, demodulate in 4 chunks
1306
+ >>> # or more specifically, split the transmit axis into 4 parts
1274
1307
 
1308
+ >>> pipeline_mapped = Pipeline(
1309
+ ... [
1310
+ ... Map([Demodulate()], argnames="data", chunks=4),
1311
+ ... TOFCorrection(),
1312
+ ... ],
1313
+ ... )
1275
1314
  """
1276
1315
 
1277
- def __init__(self, *args, num_patches=10, **kwargs):
1278
- super().__init__(*args, name="patched_grid", **kwargs)
1279
- self.num_patches = num_patches
1316
+ def __init__(
1317
+ self,
1318
+ operations: List[Operation],
1319
+ argnames: List[str] | str,
1320
+ in_axes: List[Union[int, None]] | int = 0,
1321
+ out_axes: List[Union[int, None]] | int = 0,
1322
+ chunks: int | None = None,
1323
+ batch_size: int | None = None,
1324
+ **kwargs,
1325
+ ):
1326
+ """
1327
+ Args:
1328
+ operations (list): List of operations to be performed.
1329
+ argnames (str or list): List of argument names (or keys) to map over.
1330
+ Can also be a single string if only one argument is mapped over.
1331
+ in_axes (int or list): Axes to map over for each argument.
1332
+ If a single int is provided, it is used for all arguments.
1333
+ out_axes (int or list): Axes to map over for each output.
1334
+ If a single int is provided, it is used for all outputs.
1335
+ chunks (int, optional): Number of chunks to split the input data into.
1336
+ If None, no chunking is performed. Mutually exclusive with ``batch_size``.
1337
+ batch_size (int, optional): Size of batches to process at once.
1338
+ If None, no batching is performed. Mutually exclusive with ``chunks``.
1339
+ """
1340
+ super().__init__(operations, **kwargs)
1280
1341
 
1281
- for operation in self.operations:
1282
- if isinstance(operation, DelayAndSum):
1283
- operation.reshape_grid = False
1342
+ if batch_size is not None and chunks is not None:
1343
+ raise ValueError(
1344
+ "batch_size and chunks are mutually exclusive. Please specify only one."
1345
+ )
1284
1346
 
1285
- self._jittable_call = self.jittable_call
1347
+ if batch_size is not None and batch_size <= 0:
1348
+ raise ValueError("batch_size must be a positive integer.")
1349
+
1350
+ if chunks is not None and chunks <= 0:
1351
+ raise ValueError("chunks must be a positive integer.")
1352
+
1353
+ if isinstance(argnames, str):
1354
+ argnames = [argnames]
1355
+
1356
+ self.argnames = argnames
1357
+ self.in_axes = in_axes
1358
+ self.out_axes = out_axes
1359
+ self.chunks = chunks
1360
+ self.batch_size = batch_size
1361
+
1362
+ if chunks is None and batch_size is None:
1363
+ log.warning(
1364
+ "[zea.ops.Map] Both `chunks` and `batch_size` are None. "
1365
+ "This will behave like a normal Pipeline. "
1366
+ "Consider setting one of them to process data in chunks or batches."
1367
+ )
1368
+
1369
+ def call_item(**inputs):
1370
+ """Process data in patches."""
1371
+ mapped_args = []
1372
+ for argname in argnames:
1373
+ mapped_args.append(inputs.pop(argname, None))
1374
+
1375
+ def patched_call(*args):
1376
+ mapped_kwargs = [(k, v) for k, v in zip(argnames, args)]
1377
+ out = super(Map, self).call(**dict(mapped_kwargs), **inputs)
1378
+
1379
+ # TODO: maybe it is possible to output everything?
1380
+ # e.g. prepend a empty dimension to all inputs and just map over everything?
1381
+ return out[self.output_key]
1382
+
1383
+ out = vmap(
1384
+ patched_call,
1385
+ in_axes=in_axes,
1386
+ out_axes=out_axes,
1387
+ chunks=chunks,
1388
+ batch_size=batch_size,
1389
+ fn_supports_batch=True,
1390
+ disable_jit=not bool(self.jit_options),
1391
+ )(*mapped_args)
1392
+
1393
+ return out
1394
+
1395
+ self.call_item = call_item
1286
1396
 
1287
1397
  @property
1288
1398
  def jit_options(self):
@@ -1320,60 +1430,30 @@ class PatchedGrid(Pipeline):
1320
1430
  for operation in self.operations:
1321
1431
  operation.with_batch_dim = False
1322
1432
 
1323
- @property
1324
- def _extra_keys(self):
1325
- return {"flatgrid", "grid_size_x", "grid_size_z"}
1326
-
1327
1433
  @property
1328
1434
  def valid_keys(self) -> set:
1329
1435
  """Get a set of valid keys for the pipeline.
1330
1436
  Adds the parameters that PatchedGrid itself operates on (even if not used by operations
1331
1437
  inside it)."""
1332
- return super().valid_keys.union(self._extra_keys)
1438
+ return super().valid_keys.union(self.argnames)
1333
1439
 
1334
1440
  @property
1335
1441
  def needs_keys(self) -> set:
1336
1442
  """Get a set of all input keys needed by the pipeline.
1337
1443
  Adds the parameters that PatchedGrid itself operates on (even if not used by operations
1338
1444
  inside it)."""
1339
- return super().needs_keys.union(self._extra_keys)
1340
-
1341
- def call_item(self, inputs):
1342
- """Process data in patches."""
1343
- # Extract necessary parameters
1344
- # make sure to add those as valid keys above!
1345
- grid_size_x = inputs["grid_size_x"]
1346
- grid_size_z = inputs["grid_size_z"]
1347
- flatgrid = inputs.pop("flatgrid")
1348
-
1349
- # Define a list of keys to look up for patching
1350
- flat_pfield = inputs.pop("flat_pfield", None)
1351
-
1352
- def patched_call(flatgrid, flat_pfield):
1353
- out = super(PatchedGrid, self).call(
1354
- flatgrid=flatgrid, flat_pfield=flat_pfield, **inputs
1355
- )
1356
- return out[self.output_key]
1357
-
1358
- out = vmap(
1359
- patched_call,
1360
- chunks=self.num_patches,
1361
- fn_supports_batch=True,
1362
- disable_jit=not bool(self.jit_options),
1363
- )(flatgrid, flat_pfield)
1364
-
1365
- return ops.reshape(out, (grid_size_z, grid_size_x, *ops.shape(out)[1:]))
1445
+ return super().needs_keys.union(self.argnames)
1366
1446
 
1367
1447
  def jittable_call(self, **inputs):
1368
1448
  """Process input data through the pipeline."""
1369
1449
  if self._with_batch_dim:
1370
1450
  input_data = inputs.pop(self.key)
1371
1451
  output = ops.map(
1372
- lambda x: self.call_item({self.key: x, **inputs}),
1452
+ lambda x: self.call_item(**{self.key: x, **inputs}),
1373
1453
  input_data,
1374
1454
  )
1375
1455
  else:
1376
- output = self.call_item(inputs)
1456
+ output = self.call_item(**inputs)
1377
1457
 
1378
1458
  return {self.output_key: output}
1379
1459
 
@@ -1386,11 +1466,61 @@ class PatchedGrid(Pipeline):
1386
1466
  def get_dict(self):
1387
1467
  """Get the configuration of the pipeline."""
1388
1468
  config = super().get_dict()
1389
- config.update({"name": "patched_grid"})
1469
+ config["params"].update(
1470
+ {
1471
+ "argnames": self.argnames,
1472
+ "in_axes": self.in_axes,
1473
+ "out_axes": self.out_axes,
1474
+ "chunks": self.chunks,
1475
+ "batch_size": self.batch_size,
1476
+ }
1477
+ )
1478
+ return config
1479
+
1480
+
1481
+ @ops_registry("patched_grid")
1482
+ class PatchedGrid(Map):
1483
+ """
1484
+ A pipeline that maps its operations over `flatgrid` and `flat_pfield` keys.
1485
+
1486
+ This can be used to reduce memory usage by processing data in chunks.
1487
+
1488
+ For more information and flexibility, see :class:`zea.ops.Map`.
1489
+ """
1490
+
1491
+ def __init__(self, *args, num_patches=10, **kwargs):
1492
+ super().__init__(*args, argnames=["flatgrid", "flat_pfield"], chunks=num_patches, **kwargs)
1493
+ self.num_patches = num_patches
1494
+
1495
+ def get_dict(self):
1496
+ """Get the configuration of the pipeline."""
1497
+ config = super().get_dict()
1498
+ config["params"].pop("argnames")
1499
+ config["params"].pop("chunks")
1390
1500
  config["params"].update({"num_patches": self.num_patches})
1391
1501
  return config
1392
1502
 
1393
1503
 
1504
+ @ops_registry("reshape_grid")
1505
+ class ReshapeGrid(Operation):
1506
+ """Reshape flat grid data to grid shape."""
1507
+
1508
+ def __init__(self, axis=0, **kwargs):
1509
+ super().__init__(**kwargs)
1510
+ self.axis = axis
1511
+
1512
+ def call(self, grid, **kwargs):
1513
+ """
1514
+ Args:
1515
+ - data (Tensor): The flat grid data of shape (..., n_pix, ...).
1516
+ Returns:
1517
+ - reshaped_data (Tensor): The reshaped data of shape (..., grid.shape, ...).
1518
+ """
1519
+ data = kwargs[self.key]
1520
+ reshaped_data = reshape_axis(data, grid.shape[:-1], self.axis + int(self.with_batch_dim))
1521
+ return {self.output_key: reshaped_data}
1522
+
1523
+
1394
1524
  ## Base Operations
1395
1525
 
1396
1526
 
@@ -1423,21 +1553,6 @@ class Merge(Operation):
1423
1553
  return merged
1424
1554
 
1425
1555
 
1426
- @ops_registry("split")
1427
- class Split(Operation):
1428
- """Operation that splits an input dictionary n copies."""
1429
-
1430
- def __init__(self, n: int, **kwargs):
1431
- super().__init__(**kwargs)
1432
- self.n = n
1433
-
1434
- def call(self, **kwargs) -> List[Dict]:
1435
- """
1436
- Splits the input dictionary into n copies.
1437
- """
1438
- return [kwargs.copy() for _ in range(self.n)]
1439
-
1440
-
1441
1556
  @ops_registry("stack")
1442
1557
  class Stack(Operation):
1443
1558
  """Stack multiple data arrays along a new axis.
@@ -1652,7 +1767,7 @@ class PfieldWeighting(Operation):
1652
1767
  Returns:
1653
1768
  dict: Dictionary containing weighted data
1654
1769
  """
1655
- data = kwargs[self.key]
1770
+ data = kwargs[self.key] # must start with ((batch_size,) n_tx, n_pix, ...)
1656
1771
 
1657
1772
  if flat_pfield is None:
1658
1773
  return {self.output_key: data}
@@ -1660,14 +1775,16 @@ class PfieldWeighting(Operation):
1660
1775
  # Swap (n_pix, n_tx) to (n_tx, n_pix)
1661
1776
  flat_pfield = ops.swapaxes(flat_pfield, 0, 1)
1662
1777
 
1663
- # Perform element-wise multiplication with the pressure weight mask
1664
- # Also add the required dimensions for broadcasting
1778
+ # Add batch dimension if needed
1665
1779
  if self.with_batch_dim:
1666
1780
  pfield_expanded = ops.expand_dims(flat_pfield, axis=0)
1667
1781
  else:
1668
1782
  pfield_expanded = flat_pfield
1669
1783
 
1670
- pfield_expanded = pfield_expanded[..., None, None]
1784
+ append_n_dims = ops.ndim(data) - ops.ndim(pfield_expanded)
1785
+ pfield_expanded = extend_n_dims(pfield_expanded, axis=-1, n_dims=append_n_dims)
1786
+
1787
+ # Perform element-wise multiplication with the pressure weight mask
1671
1788
  weighted_data = data * pfield_expanded
1672
1789
 
1673
1790
  return {self.output_key: weighted_data}
@@ -1677,17 +1794,12 @@ class PfieldWeighting(Operation):
1677
1794
  class DelayAndSum(Operation):
1678
1795
  """Sums time-delayed signals along channels and transmits."""
1679
1796
 
1680
- def __init__(
1681
- self,
1682
- reshape_grid=True,
1683
- **kwargs,
1684
- ):
1797
+ def __init__(self, **kwargs):
1685
1798
  super().__init__(
1686
1799
  input_data_type=DataTypes.ALIGNED_DATA,
1687
1800
  output_data_type=DataTypes.BEAMFORMED_DATA,
1688
1801
  **kwargs,
1689
1802
  )
1690
- self.reshape_grid = reshape_grid
1691
1803
 
1692
1804
  def process_image(self, data):
1693
1805
  """Performs DAS beamforming on tof-corrected input.
@@ -1715,8 +1827,7 @@ class DelayAndSum(Operation):
1715
1827
 
1716
1828
  Returns:
1717
1829
  dict: Dictionary containing beamformed_data
1718
- of shape `(grid_size_z*grid_size_x, n_ch)` when reshape_grid is False
1719
- or `(grid_size_z, grid_size_x, n_ch)` when reshape_grid is True,
1830
+ of shape `(grid_size_z*grid_size_x, n_ch)`
1720
1831
  with optional batch dimension.
1721
1832
  """
1722
1833
  data = kwargs[self.key]
@@ -1727,11 +1838,6 @@ class DelayAndSum(Operation):
1727
1838
  # Apply process_image to each item in the batch
1728
1839
  beamformed_data = ops.map(self.process_image, data)
1729
1840
 
1730
- if self.reshape_grid:
1731
- beamformed_data = reshape_axis(
1732
- beamformed_data, grid.shape[:2], axis=int(self.with_batch_dim)
1733
- )
1734
-
1735
1841
  return {self.output_key: beamformed_data}
1736
1842
 
1737
1843
 
@@ -2061,7 +2167,7 @@ class ScanConvert(Operation):
2061
2167
 
2062
2168
 
2063
2169
  @ops_registry("gaussian_blur")
2064
- class GaussianBlur(Operation):
2170
+ class GaussianBlur(ImageOperation):
2065
2171
  """
2066
2172
  GaussianBlur is an operation that applies a Gaussian blur to an input image.
2067
2173
  Uses scipy.ndimage.gaussian_filter to create a kernel.
@@ -2111,6 +2217,13 @@ class GaussianBlur(Operation):
2111
2217
  return ops.convert_to_tensor(kernel)
2112
2218
 
2113
2219
  def call(self, **kwargs):
2220
+ """Apply a Gaussian filter to the input data.
2221
+
2222
+ Args:
2223
+ data (ops.Tensor): Input image data of shape (height, width, channels) with
2224
+ optional batch dimension if ``self.with_batch_dim``.
2225
+ """
2226
+ super().call(**kwargs)
2114
2227
  data = kwargs[self.key]
2115
2228
 
2116
2229
  # Add batch dimension if not present
@@ -2145,7 +2258,7 @@ class GaussianBlur(Operation):
2145
2258
 
2146
2259
 
2147
2260
  @ops_registry("lee_filter")
2148
- class LeeFilter(Operation):
2261
+ class LeeFilter(ImageOperation):
2149
2262
  """
2150
2263
  The Lee filter is a speckle reduction filter commonly used in synthetic aperture radar (SAR)
2151
2264
  and ultrasound image processing. It smooths the image while preserving edges and details.
@@ -2175,7 +2288,7 @@ class LeeFilter(Operation):
2175
2288
  pad_mode=self.pad_mode,
2176
2289
  with_batch_dim=self.with_batch_dim,
2177
2290
  jittable=self._jittable,
2178
- key=self.key,
2291
+ key="data",
2179
2292
  )
2180
2293
 
2181
2294
  @property
@@ -2191,24 +2304,29 @@ class LeeFilter(Operation):
2191
2304
  self.gaussian_blur.with_batch_dim = value
2192
2305
 
2193
2306
  def call(self, **kwargs):
2194
- data = kwargs[self.key]
2307
+ """Apply the Lee filter to the input data.
2308
+
2309
+ Args:
2310
+ data (ops.Tensor): Input image data of shape (height, width, channels) with
2311
+ optional batch dimension if ``self.with_batch_dim``.
2312
+ """
2313
+ super().call(**kwargs)
2314
+ data = kwargs.pop(self.key)
2195
2315
 
2196
2316
  # Apply Gaussian blur to get local mean
2197
- img_mean = self.gaussian_blur.call(**kwargs)[self.gaussian_blur.output_key]
2317
+ img_mean = self.gaussian_blur.call(data=data, **kwargs)[self.gaussian_blur.output_key]
2198
2318
 
2199
2319
  # Apply Gaussian blur to squared data to get local squared mean
2200
- data_squared = data**2
2201
- kwargs[self.gaussian_blur.key] = data_squared
2202
- img_sqr_mean = self.gaussian_blur.call(**kwargs)[self.gaussian_blur.output_key]
2320
+ img_sqr_mean = self.gaussian_blur.call(
2321
+ data=data**2,
2322
+ **kwargs,
2323
+ )[self.gaussian_blur.output_key]
2203
2324
 
2204
2325
  # Calculate local variance
2205
2326
  img_variance = img_sqr_mean - img_mean**2
2206
2327
 
2207
2328
  # Calculate global variance (per channel)
2208
- if self.with_batch_dim:
2209
- overall_variance = ops.var(data, axis=(-3, -2), keepdims=True)
2210
- else:
2211
- overall_variance = ops.var(data, axis=(-2, -1), keepdims=True)
2329
+ overall_variance = ops.var(data, axis=(-3, -2), keepdims=True)
2212
2330
 
2213
2331
  # Calculate adaptive weights
2214
2332
  img_weights = img_variance / (img_variance + overall_variance)
@@ -2229,7 +2347,11 @@ class Demodulate(Operation):
2229
2347
  input_data_type=DataTypes.RAW_DATA,
2230
2348
  output_data_type=DataTypes.RAW_DATA,
2231
2349
  jittable=True,
2232
- additional_output_keys=["demodulation_frequency", "center_frequency", "n_ch"],
2350
+ additional_output_keys=[
2351
+ "demodulation_frequency",
2352
+ "center_frequency",
2353
+ "n_ch",
2354
+ ],
2233
2355
  **kwargs,
2234
2356
  )
2235
2357
  self.axis = axis
@@ -2255,6 +2377,121 @@ class Demodulate(Operation):
2255
2377
  }
2256
2378
 
2257
2379
 
2380
+ @ops_registry("fir_filter")
2381
+ class FirFilter(Operation):
2382
+ """Apply a FIR filter to the input signal using convolution.
2383
+
2384
+ Looks for the filter taps in the input dictionary using the specified ``filter_key``.
2385
+ """
2386
+
2387
+ def __init__(
2388
+ self,
2389
+ axis: int,
2390
+ complex_channels: bool = False,
2391
+ filter_key: str = "fir_filter_taps",
2392
+ **kwargs,
2393
+ ):
2394
+ """
2395
+ Args:
2396
+ axis (int): Axis along which to apply the filter. Cannot be the batch dimension.
2397
+ When using ``complex_channels=True``, the complex channels are removed to convert
2398
+ to complex numbers before filtering, so adjust the ``axis`` accordingly!
2399
+ complex_channels (bool): Whether the last dimension of the input signal represents
2400
+ complex channels (real and imaginary parts). When True, it will convert the signal
2401
+ to ``complex`` dtype before filtering and convert it back to two channels
2402
+ after filtering.
2403
+ filter_key (str): Key in the input dictionary where the FIR filter taps are stored.
2404
+ Default is "fir_filter_taps".
2405
+ """
2406
+ super().__init__(**kwargs)
2407
+ self._check_axis(axis)
2408
+
2409
+ self.axis = axis
2410
+ self.complex_channels = complex_channels
2411
+ self.filter_key = filter_key
2412
+
2413
+ def _check_axis(self, axis, ndim=None):
2414
+ """Check if the axis is valid."""
2415
+ if ndim is not None:
2416
+ if axis < -ndim or axis >= ndim:
2417
+ raise ValueError(f"Axis {axis} is out of bounds for array of dimension {ndim}.")
2418
+
2419
+ if self.with_batch_dim and (axis == 0 or (ndim is not None and axis == -ndim)):
2420
+ raise ValueError("Cannot apply FIR filter along batch dimension.")
2421
+
2422
+ @property
2423
+ def valid_keys(self):
2424
+ """Get the valid keys for the `call` method."""
2425
+ return self._valid_keys.union({self.filter_key})
2426
+
2427
+ def call(self, **kwargs):
2428
+ signal = kwargs[self.key]
2429
+ fir_filter_taps = kwargs[self.filter_key]
2430
+
2431
+ if self.complex_channels:
2432
+ signal = channels_to_complex(signal)
2433
+
2434
+ self._check_axis(self.axis, ndim=ops.ndim(signal))
2435
+
2436
+ def _convolve(signal):
2437
+ """Apply the filter to the signal using correlation."""
2438
+ return correlate(signal, fir_filter_taps[::-1], mode="same")
2439
+
2440
+ filtered_signal = apply_along_axis(_convolve, self.axis, signal)
2441
+
2442
+ if self.complex_channels:
2443
+ filtered_signal = complex_to_channels(filtered_signal)
2444
+
2445
+ return {self.output_key: filtered_signal}
2446
+
2447
+
2448
+ @ops_registry("low_pass_filter")
2449
+ class LowPassFilter(FirFilter):
2450
+ """Apply a low-pass FIR filter to the input signal using convolution.
2451
+
2452
+ It is recommended to use :class:`FirFilter` with pre-computed filter taps for jittable
2453
+ operations. The :class:`LowPassFilter` operation itself is not jittable and is provided
2454
+ for convenience only.
2455
+
2456
+ Uses :func:`get_low_pass_iq_filter` to compute the filter taps.
2457
+ """
2458
+
2459
+ def __init__(self, axis: int, complex_channels: bool = False, num_taps: int = 128, **kwargs):
2460
+ """Initialize the LowPassFilter operation.
2461
+
2462
+ Args:
2463
+ axis (int): Axis along which to apply the filter. Cannot be the batch dimension.
2464
+ When using ``complex_channels=True``, the complex channels are removed to convert
2465
+ to complex numbers before filtering, so adjust the ``axis`` accordingly.
2466
+ complex_channels (bool): Whether the last dimension of the input signal represents
2467
+ complex channels (real and imaginary parts). When True, it will convert the signal
2468
+ to ``complex`` dtype before filtering and convert it back to two channels
2469
+ after filtering.
2470
+ num_taps (int): Number of taps in the FIR filter. Default is 128.
2471
+ """
2472
+ self._random_suffix = str(uuid.uuid4())
2473
+ kwargs.pop("filter_key", None)
2474
+ kwargs.pop("jittable", None)
2475
+ super().__init__(
2476
+ axis=axis,
2477
+ complex_channels=complex_channels,
2478
+ filter_key=f"low_pass_{self._random_suffix}",
2479
+ jittable=False,
2480
+ **kwargs,
2481
+ )
2482
+ self.num_taps = num_taps
2483
+
2484
+ def call(self, bandwidth, sampling_frequency, center_frequency, **kwargs):
2485
+ lpf = get_low_pass_iq_filter(
2486
+ self.num_taps,
2487
+ ops.convert_to_numpy(sampling_frequency).item(),
2488
+ ops.convert_to_numpy(center_frequency).item(),
2489
+ ops.convert_to_numpy(bandwidth).item(),
2490
+ )
2491
+ kwargs[self.filter_key] = lpf
2492
+ return super().call(**kwargs)
2493
+
2494
+
2258
2495
  @ops_registry("lambda")
2259
2496
  class Lambda(Operation):
2260
2497
  """Use any function as an operation."""
@@ -2289,7 +2526,10 @@ class Lambda(Operation):
2289
2526
 
2290
2527
  def call(self, **kwargs):
2291
2528
  data = kwargs[self.key]
2292
- data = self.func(data)
2529
+ if self.with_batch_dim:
2530
+ data = ops.map(self.func, data)
2531
+ else:
2532
+ data = self.func(data)
2293
2533
  return {self.output_key: data}
2294
2534
 
2295
2535
 
@@ -3044,7 +3284,7 @@ def get_band_pass_filter(num_taps, sampling_frequency, f1, f2):
3044
3284
  return bpf
3045
3285
 
3046
3286
 
3047
- def get_low_pass_iq_filter(num_taps, sampling_frequency, f, bw):
3287
+ def get_low_pass_iq_filter(num_taps, sampling_frequency, center_frequency, bandwidth):
3048
3288
  """Design complex low-pass filter.
3049
3289
 
3050
3290
  The filter is a low-pass FIR filter modulated to the center frequency.
@@ -3052,16 +3292,16 @@ def get_low_pass_iq_filter(num_taps, sampling_frequency, f, bw):
3052
3292
  Args:
3053
3293
  num_taps (int): number of taps in filter.
3054
3294
  sampling_frequency (float): sample frequency.
3055
- f (float): center frequency.
3056
- bw (float): bandwidth in Hz.
3295
+ center_frequency (float): center frequency.
3296
+ bandwidth (float): bandwidth in Hz.
3057
3297
 
3058
3298
  Raises:
3059
- ValueError: if cutoff frequency (bw / 2) is not within (0, sampling_frequency / 2)
3299
+ ValueError: if cutoff frequency (bandwidth / 2) is not within (0, sampling_frequency / 2)
3060
3300
 
3061
3301
  Returns:
3062
3302
  ndarray: Complex-valued low-pass filter
3063
3303
  """
3064
- cutoff = bw / 2
3304
+ cutoff = bandwidth / 2
3065
3305
  if not (0 < cutoff < sampling_frequency / 2):
3066
3306
  raise ValueError(
3067
3307
  f"Cutoff frequency must be within (0, sampling_frequency / 2), "
@@ -3071,7 +3311,7 @@ def get_low_pass_iq_filter(num_taps, sampling_frequency, f, bw):
3071
3311
  lpf = scipy.signal.firwin(num_taps, cutoff, pass_zero=True, fs=sampling_frequency)
3072
3312
  # Modulate to center frequency to make it complex
3073
3313
  time_points = np.arange(num_taps) / sampling_frequency
3074
- lpf_complex = lpf * np.exp(1j * 2 * np.pi * f * time_points)
3314
+ lpf_complex = lpf * np.exp(1j * 2 * np.pi * center_frequency * time_points)
3075
3315
  return lpf_complex
3076
3316
 
3077
3317