zea 0.0.7__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +1 -1
- zea/backend/tensorflow/dataloader.py +0 -4
- zea/beamform/pixelgrid.py +1 -1
- zea/data/__init__.py +0 -9
- zea/data/augmentations.py +221 -28
- zea/data/convert/__init__.py +1 -6
- zea/data/convert/__main__.py +123 -0
- zea/data/convert/camus.py +99 -39
- zea/data/convert/echonet.py +183 -82
- zea/data/convert/echonetlvh/README.md +2 -3
- zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +173 -102
- zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
- zea/data/convert/echonetlvh/precompute_crop.py +43 -64
- zea/data/convert/picmus.py +37 -40
- zea/data/convert/utils.py +86 -0
- zea/data/convert/{matlab.py → verasonics.py} +33 -61
- zea/data/data_format.py +124 -4
- zea/data/dataloader.py +12 -7
- zea/data/datasets.py +109 -70
- zea/data/file.py +91 -82
- zea/data/file_operations.py +496 -0
- zea/data/preset_utils.py +1 -1
- zea/display.py +7 -8
- zea/internal/checks.py +6 -12
- zea/internal/operators.py +4 -0
- zea/io_lib.py +108 -160
- zea/models/__init__.py +1 -1
- zea/models/diffusion.py +62 -11
- zea/models/lv_segmentation.py +2 -0
- zea/ops.py +398 -158
- zea/scan.py +18 -8
- zea/tensor_ops.py +82 -62
- zea/tools/fit_scan_cone.py +90 -160
- zea/tracking/__init__.py +16 -0
- zea/tracking/base.py +94 -0
- zea/tracking/lucas_kanade.py +474 -0
- zea/tracking/segmentation.py +110 -0
- zea/utils.py +11 -2
- {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/METADATA +3 -1
- {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/RECORD +43 -35
- {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/WHEEL +0 -0
- {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/entry_points.txt +0 -0
- {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/licenses/LICENSE +0 -0
zea/ops.py
CHANGED
|
@@ -88,10 +88,10 @@ Example of a yaml file:
|
|
|
88
88
|
|
|
89
89
|
"""
|
|
90
90
|
|
|
91
|
-
import copy
|
|
92
91
|
import hashlib
|
|
93
92
|
import inspect
|
|
94
93
|
import json
|
|
94
|
+
import uuid
|
|
95
95
|
from functools import partial
|
|
96
96
|
from typing import Any, Dict, List, Union
|
|
97
97
|
|
|
@@ -120,7 +120,15 @@ from zea.internal.registry import ops_registry
|
|
|
120
120
|
from zea.probes import Probe
|
|
121
121
|
from zea.scan import Scan
|
|
122
122
|
from zea.simulator import simulate_rf
|
|
123
|
-
from zea.tensor_ops import
|
|
123
|
+
from zea.tensor_ops import (
|
|
124
|
+
apply_along_axis,
|
|
125
|
+
correlate,
|
|
126
|
+
extend_n_dims,
|
|
127
|
+
resample,
|
|
128
|
+
reshape_axis,
|
|
129
|
+
translate,
|
|
130
|
+
vmap,
|
|
131
|
+
)
|
|
124
132
|
from zea.utils import (
|
|
125
133
|
FunctionTimer,
|
|
126
134
|
deep_compare,
|
|
@@ -238,7 +246,7 @@ class Operation(keras.Operation):
|
|
|
238
246
|
Analyze and store the input/output signatures of the `call` method.
|
|
239
247
|
"""
|
|
240
248
|
self._input_signature = inspect.signature(self.call)
|
|
241
|
-
self._valid_keys = set(self._input_signature.parameters.keys())
|
|
249
|
+
self._valid_keys = set(self._input_signature.parameters.keys()) | {self.key}
|
|
242
250
|
|
|
243
251
|
@property
|
|
244
252
|
def valid_keys(self) -> set:
|
|
@@ -405,6 +413,36 @@ class Operation(keras.Operation):
|
|
|
405
413
|
return True
|
|
406
414
|
|
|
407
415
|
|
|
416
|
+
class ImageOperation(Operation):
|
|
417
|
+
"""
|
|
418
|
+
Base class for image processing operations.
|
|
419
|
+
|
|
420
|
+
This class extends the Operation class to provide a common interface
|
|
421
|
+
for operations that process image data, with shape (batch, height, width, channels)
|
|
422
|
+
or (height, width, channels) if batch dimension is not present.
|
|
423
|
+
|
|
424
|
+
Subclasses should implement the `call` method to define the image processing logic, and call
|
|
425
|
+
``super().call(**kwargs)`` to validate the input data shape.
|
|
426
|
+
"""
|
|
427
|
+
|
|
428
|
+
def call(self, **kwargs):
|
|
429
|
+
"""
|
|
430
|
+
Validate input data shape for image operations.
|
|
431
|
+
|
|
432
|
+
Args:
|
|
433
|
+
**kwargs: Keyword arguments containing input data.
|
|
434
|
+
|
|
435
|
+
Raises:
|
|
436
|
+
AssertionError: If input data does not have the expected number of dimensions.
|
|
437
|
+
"""
|
|
438
|
+
data = kwargs[self.key]
|
|
439
|
+
|
|
440
|
+
if self.with_batch_dim:
|
|
441
|
+
assert ops.ndim(data) == 4, "Input data must have 4 dimensions (b, h, w, c)."
|
|
442
|
+
else:
|
|
443
|
+
assert ops.ndim(data) == 3, "Input data must have 3 dimensions (h, w, c)."
|
|
444
|
+
|
|
445
|
+
|
|
408
446
|
@ops_registry("pipeline")
|
|
409
447
|
class Pipeline:
|
|
410
448
|
"""Pipeline class for processing ultrasound data through a series of operations."""
|
|
@@ -488,10 +526,20 @@ class Pipeline:
|
|
|
488
526
|
self.jit_kwargs = jit_kwargs
|
|
489
527
|
self.jit_options = jit_options # will handle the jit compilation
|
|
490
528
|
|
|
529
|
+
self._logged_difference_keys = False
|
|
530
|
+
|
|
531
|
+
# Do not log again for nested pipelines
|
|
532
|
+
for nested_pipeline in self._nested_pipelines:
|
|
533
|
+
nested_pipeline._logged_difference_keys = True
|
|
534
|
+
|
|
491
535
|
def needs(self, key) -> bool:
|
|
492
536
|
"""Check if the pipeline needs a specific key at the input."""
|
|
493
537
|
return key in self.needs_keys
|
|
494
538
|
|
|
539
|
+
@property
|
|
540
|
+
def _nested_pipelines(self):
|
|
541
|
+
return [operation for operation in self.operations if isinstance(operation, Pipeline)]
|
|
542
|
+
|
|
495
543
|
@property
|
|
496
544
|
def output_keys(self) -> set:
|
|
497
545
|
"""All output keys the pipeline guarantees to produce."""
|
|
@@ -576,6 +624,7 @@ class Pipeline:
|
|
|
576
624
|
|
|
577
625
|
# Add display ops
|
|
578
626
|
operations += [
|
|
627
|
+
ReshapeGrid(),
|
|
579
628
|
EnvelopeDetect(),
|
|
580
629
|
Normalize(),
|
|
581
630
|
LogCompress(),
|
|
@@ -682,6 +731,16 @@ class Pipeline:
|
|
|
682
731
|
"Please ensure all inputs are convertible to tensors."
|
|
683
732
|
)
|
|
684
733
|
|
|
734
|
+
if not self._logged_difference_keys:
|
|
735
|
+
difference_keys = set(inputs.keys()) - self.valid_keys
|
|
736
|
+
if difference_keys:
|
|
737
|
+
log.debug(
|
|
738
|
+
f"[zea.Pipeline] The following input keys are not used by the pipeline: "
|
|
739
|
+
f"{difference_keys}. Make sure this is intended. "
|
|
740
|
+
"This warning will only be shown once."
|
|
741
|
+
)
|
|
742
|
+
self._logged_difference_keys = True
|
|
743
|
+
|
|
685
744
|
## PROCESSING
|
|
686
745
|
outputs = self._call_pipeline(**inputs)
|
|
687
746
|
|
|
@@ -804,59 +863,9 @@ class Pipeline:
|
|
|
804
863
|
return params
|
|
805
864
|
|
|
806
865
|
def __str__(self):
|
|
807
|
-
"""String representation of the pipeline.
|
|
808
|
-
|
|
809
|
-
Will print on two parallel pipeline lines if it detects a splitting operations
|
|
810
|
-
(such as multi_bandpass_filter)
|
|
811
|
-
Will merge the pipeline lines if it detects a stacking operation (such as stack)
|
|
812
|
-
"""
|
|
813
|
-
split_operations = []
|
|
814
|
-
merge_operations = ["Stack"]
|
|
815
|
-
|
|
866
|
+
"""String representation of the pipeline."""
|
|
816
867
|
operations = [operation.__class__.__name__ for operation in self.operations]
|
|
817
868
|
string = " -> ".join(operations)
|
|
818
|
-
|
|
819
|
-
if any(operation in split_operations for operation in operations):
|
|
820
|
-
# a second line is needed with same length as the first line
|
|
821
|
-
split_line = " " * len(string)
|
|
822
|
-
# find the splitting operation and index and print \-> instead of -> after
|
|
823
|
-
split_detected = False
|
|
824
|
-
merge_detected = False
|
|
825
|
-
split_operation = None
|
|
826
|
-
for operation in operations:
|
|
827
|
-
if operation in split_operations:
|
|
828
|
-
index = string.index(operation)
|
|
829
|
-
index = index + len(operation)
|
|
830
|
-
split_line = split_line[:index] + "\\->" + split_line[index + len("\\->") :]
|
|
831
|
-
split_detected = True
|
|
832
|
-
merge_detected = False
|
|
833
|
-
split_operation = operation
|
|
834
|
-
continue
|
|
835
|
-
|
|
836
|
-
if operation in merge_operations:
|
|
837
|
-
index = string.index(operation)
|
|
838
|
-
index = index - 4
|
|
839
|
-
split_line = split_line[:index] + "/" + split_line[index + 1 :]
|
|
840
|
-
split_detected = False
|
|
841
|
-
merge_detected = True
|
|
842
|
-
continue
|
|
843
|
-
|
|
844
|
-
if split_detected:
|
|
845
|
-
# print all operations in the second line
|
|
846
|
-
index = string.index(operation)
|
|
847
|
-
split_line = (
|
|
848
|
-
split_line[:index]
|
|
849
|
-
+ operation
|
|
850
|
-
+ " -> "
|
|
851
|
-
+ split_line[index + len(operation) + len(" -> ") :]
|
|
852
|
-
)
|
|
853
|
-
assert merge_detected is True, log.error(
|
|
854
|
-
"Pipeline was never merged back together (with Stack operation), even "
|
|
855
|
-
f"though it was split with {split_operation}. "
|
|
856
|
-
"Please properly define your operation chain."
|
|
857
|
-
)
|
|
858
|
-
return f"\n{string}\n{split_line}\n"
|
|
859
|
-
|
|
860
869
|
return string
|
|
861
870
|
|
|
862
871
|
def __repr__(self):
|
|
@@ -1185,7 +1194,7 @@ def pipeline_from_config(config: Config, **kwargs) -> Pipeline:
|
|
|
1185
1194
|
operations = make_operation_chain(config.operations)
|
|
1186
1195
|
|
|
1187
1196
|
# merge pipeline config without operations with kwargs
|
|
1188
|
-
pipeline_config = copy
|
|
1197
|
+
pipeline_config = config.copy()
|
|
1189
1198
|
pipeline_config.pop("operations")
|
|
1190
1199
|
|
|
1191
1200
|
kwargs = {**pipeline_config, **kwargs}
|
|
@@ -1256,33 +1265,134 @@ def pipeline_to_yaml(pipeline: Pipeline, file_path: str) -> None:
|
|
|
1256
1265
|
yaml.dump(pipeline_dict, f, Dumper=yaml.Dumper, indent=4)
|
|
1257
1266
|
|
|
1258
1267
|
|
|
1259
|
-
@ops_registry("
|
|
1260
|
-
class
|
|
1268
|
+
@ops_registry("map")
|
|
1269
|
+
class Map(Pipeline):
|
|
1261
1270
|
"""
|
|
1262
|
-
|
|
1263
|
-
|
|
1271
|
+
A pipeline that maps its operations over specified input arguments.
|
|
1272
|
+
|
|
1273
|
+
This can be used to reduce memory usage by processing data in chunks.
|
|
1274
|
+
|
|
1275
|
+
Notes
|
|
1276
|
+
-----
|
|
1277
|
+
- When `chunks` and `batch_size` are both None (default), this behaves like a normal Pipeline.
|
|
1278
|
+
- Changing anything other than ``self.output_key`` in the dict will not be propagated.
|
|
1279
|
+
- Will be jitted as a single operation, not the individual operations.
|
|
1280
|
+
- This class handles the batching.
|
|
1264
1281
|
|
|
1265
|
-
|
|
1282
|
+
For more information on how to use ``in_axes``, ``out_axes``, `see the documentation for
|
|
1283
|
+
jax.vmap <https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html>`_.
|
|
1266
1284
|
|
|
1267
|
-
|
|
1285
|
+
Example
|
|
1286
|
+
-------
|
|
1287
|
+
.. doctest::
|
|
1268
1288
|
|
|
1269
|
-
|
|
1289
|
+
>>> from zea.ops import Map, Pipeline, Demodulate, TOFCorrection
|
|
1270
1290
|
|
|
1271
|
-
|
|
1291
|
+
>>> # apply operations in batches of 8
|
|
1292
|
+
>>> # in this case, over the first axis of "data"
|
|
1293
|
+
>>> # or more specifically, process 8 transmits at a time
|
|
1272
1294
|
|
|
1273
|
-
|
|
1295
|
+
>>> pipeline_mapped = Map(
|
|
1296
|
+
... [
|
|
1297
|
+
... Demodulate(),
|
|
1298
|
+
... TOFCorrection(),
|
|
1299
|
+
... ],
|
|
1300
|
+
... argnames="data",
|
|
1301
|
+
... batch_size=8,
|
|
1302
|
+
... )
|
|
1303
|
+
|
|
1304
|
+
>>> # you can also map a subset of the operations
|
|
1305
|
+
>>> # for example, demodulate in 4 chunks
|
|
1306
|
+
>>> # or more specifically, split the transmit axis into 4 parts
|
|
1274
1307
|
|
|
1308
|
+
>>> pipeline_mapped = Pipeline(
|
|
1309
|
+
... [
|
|
1310
|
+
... Map([Demodulate()], argnames="data", chunks=4),
|
|
1311
|
+
... TOFCorrection(),
|
|
1312
|
+
... ],
|
|
1313
|
+
... )
|
|
1275
1314
|
"""
|
|
1276
1315
|
|
|
1277
|
-
def __init__(
|
|
1278
|
-
|
|
1279
|
-
|
|
1316
|
+
def __init__(
|
|
1317
|
+
self,
|
|
1318
|
+
operations: List[Operation],
|
|
1319
|
+
argnames: List[str] | str,
|
|
1320
|
+
in_axes: List[Union[int, None]] | int = 0,
|
|
1321
|
+
out_axes: List[Union[int, None]] | int = 0,
|
|
1322
|
+
chunks: int | None = None,
|
|
1323
|
+
batch_size: int | None = None,
|
|
1324
|
+
**kwargs,
|
|
1325
|
+
):
|
|
1326
|
+
"""
|
|
1327
|
+
Args:
|
|
1328
|
+
operations (list): List of operations to be performed.
|
|
1329
|
+
argnames (str or list): List of argument names (or keys) to map over.
|
|
1330
|
+
Can also be a single string if only one argument is mapped over.
|
|
1331
|
+
in_axes (int or list): Axes to map over for each argument.
|
|
1332
|
+
If a single int is provided, it is used for all arguments.
|
|
1333
|
+
out_axes (int or list): Axes to map over for each output.
|
|
1334
|
+
If a single int is provided, it is used for all outputs.
|
|
1335
|
+
chunks (int, optional): Number of chunks to split the input data into.
|
|
1336
|
+
If None, no chunking is performed. Mutually exclusive with ``batch_size``.
|
|
1337
|
+
batch_size (int, optional): Size of batches to process at once.
|
|
1338
|
+
If None, no batching is performed. Mutually exclusive with ``chunks``.
|
|
1339
|
+
"""
|
|
1340
|
+
super().__init__(operations, **kwargs)
|
|
1280
1341
|
|
|
1281
|
-
|
|
1282
|
-
|
|
1283
|
-
|
|
1342
|
+
if batch_size is not None and chunks is not None:
|
|
1343
|
+
raise ValueError(
|
|
1344
|
+
"batch_size and chunks are mutually exclusive. Please specify only one."
|
|
1345
|
+
)
|
|
1284
1346
|
|
|
1285
|
-
|
|
1347
|
+
if batch_size is not None and batch_size <= 0:
|
|
1348
|
+
raise ValueError("batch_size must be a positive integer.")
|
|
1349
|
+
|
|
1350
|
+
if chunks is not None and chunks <= 0:
|
|
1351
|
+
raise ValueError("chunks must be a positive integer.")
|
|
1352
|
+
|
|
1353
|
+
if isinstance(argnames, str):
|
|
1354
|
+
argnames = [argnames]
|
|
1355
|
+
|
|
1356
|
+
self.argnames = argnames
|
|
1357
|
+
self.in_axes = in_axes
|
|
1358
|
+
self.out_axes = out_axes
|
|
1359
|
+
self.chunks = chunks
|
|
1360
|
+
self.batch_size = batch_size
|
|
1361
|
+
|
|
1362
|
+
if chunks is None and batch_size is None:
|
|
1363
|
+
log.warning(
|
|
1364
|
+
"[zea.ops.Map] Both `chunks` and `batch_size` are None. "
|
|
1365
|
+
"This will behave like a normal Pipeline. "
|
|
1366
|
+
"Consider setting one of them to process data in chunks or batches."
|
|
1367
|
+
)
|
|
1368
|
+
|
|
1369
|
+
def call_item(**inputs):
|
|
1370
|
+
"""Process data in patches."""
|
|
1371
|
+
mapped_args = []
|
|
1372
|
+
for argname in argnames:
|
|
1373
|
+
mapped_args.append(inputs.pop(argname, None))
|
|
1374
|
+
|
|
1375
|
+
def patched_call(*args):
|
|
1376
|
+
mapped_kwargs = [(k, v) for k, v in zip(argnames, args)]
|
|
1377
|
+
out = super(Map, self).call(**dict(mapped_kwargs), **inputs)
|
|
1378
|
+
|
|
1379
|
+
# TODO: maybe it is possible to output everything?
|
|
1380
|
+
# e.g. prepend a empty dimension to all inputs and just map over everything?
|
|
1381
|
+
return out[self.output_key]
|
|
1382
|
+
|
|
1383
|
+
out = vmap(
|
|
1384
|
+
patched_call,
|
|
1385
|
+
in_axes=in_axes,
|
|
1386
|
+
out_axes=out_axes,
|
|
1387
|
+
chunks=chunks,
|
|
1388
|
+
batch_size=batch_size,
|
|
1389
|
+
fn_supports_batch=True,
|
|
1390
|
+
disable_jit=not bool(self.jit_options),
|
|
1391
|
+
)(*mapped_args)
|
|
1392
|
+
|
|
1393
|
+
return out
|
|
1394
|
+
|
|
1395
|
+
self.call_item = call_item
|
|
1286
1396
|
|
|
1287
1397
|
@property
|
|
1288
1398
|
def jit_options(self):
|
|
@@ -1320,60 +1430,30 @@ class PatchedGrid(Pipeline):
|
|
|
1320
1430
|
for operation in self.operations:
|
|
1321
1431
|
operation.with_batch_dim = False
|
|
1322
1432
|
|
|
1323
|
-
@property
|
|
1324
|
-
def _extra_keys(self):
|
|
1325
|
-
return {"flatgrid", "grid_size_x", "grid_size_z"}
|
|
1326
|
-
|
|
1327
1433
|
@property
|
|
1328
1434
|
def valid_keys(self) -> set:
|
|
1329
1435
|
"""Get a set of valid keys for the pipeline.
|
|
1330
1436
|
Adds the parameters that PatchedGrid itself operates on (even if not used by operations
|
|
1331
1437
|
inside it)."""
|
|
1332
|
-
return super().valid_keys.union(self.
|
|
1438
|
+
return super().valid_keys.union(self.argnames)
|
|
1333
1439
|
|
|
1334
1440
|
@property
|
|
1335
1441
|
def needs_keys(self) -> set:
|
|
1336
1442
|
"""Get a set of all input keys needed by the pipeline.
|
|
1337
1443
|
Adds the parameters that PatchedGrid itself operates on (even if not used by operations
|
|
1338
1444
|
inside it)."""
|
|
1339
|
-
return super().needs_keys.union(self.
|
|
1340
|
-
|
|
1341
|
-
def call_item(self, inputs):
|
|
1342
|
-
"""Process data in patches."""
|
|
1343
|
-
# Extract necessary parameters
|
|
1344
|
-
# make sure to add those as valid keys above!
|
|
1345
|
-
grid_size_x = inputs["grid_size_x"]
|
|
1346
|
-
grid_size_z = inputs["grid_size_z"]
|
|
1347
|
-
flatgrid = inputs.pop("flatgrid")
|
|
1348
|
-
|
|
1349
|
-
# Define a list of keys to look up for patching
|
|
1350
|
-
flat_pfield = inputs.pop("flat_pfield", None)
|
|
1351
|
-
|
|
1352
|
-
def patched_call(flatgrid, flat_pfield):
|
|
1353
|
-
out = super(PatchedGrid, self).call(
|
|
1354
|
-
flatgrid=flatgrid, flat_pfield=flat_pfield, **inputs
|
|
1355
|
-
)
|
|
1356
|
-
return out[self.output_key]
|
|
1357
|
-
|
|
1358
|
-
out = vmap(
|
|
1359
|
-
patched_call,
|
|
1360
|
-
chunks=self.num_patches,
|
|
1361
|
-
fn_supports_batch=True,
|
|
1362
|
-
disable_jit=not bool(self.jit_options),
|
|
1363
|
-
)(flatgrid, flat_pfield)
|
|
1364
|
-
|
|
1365
|
-
return ops.reshape(out, (grid_size_z, grid_size_x, *ops.shape(out)[1:]))
|
|
1445
|
+
return super().needs_keys.union(self.argnames)
|
|
1366
1446
|
|
|
1367
1447
|
def jittable_call(self, **inputs):
|
|
1368
1448
|
"""Process input data through the pipeline."""
|
|
1369
1449
|
if self._with_batch_dim:
|
|
1370
1450
|
input_data = inputs.pop(self.key)
|
|
1371
1451
|
output = ops.map(
|
|
1372
|
-
lambda x: self.call_item({self.key: x, **inputs}),
|
|
1452
|
+
lambda x: self.call_item(**{self.key: x, **inputs}),
|
|
1373
1453
|
input_data,
|
|
1374
1454
|
)
|
|
1375
1455
|
else:
|
|
1376
|
-
output = self.call_item(inputs)
|
|
1456
|
+
output = self.call_item(**inputs)
|
|
1377
1457
|
|
|
1378
1458
|
return {self.output_key: output}
|
|
1379
1459
|
|
|
@@ -1386,11 +1466,61 @@ class PatchedGrid(Pipeline):
|
|
|
1386
1466
|
def get_dict(self):
|
|
1387
1467
|
"""Get the configuration of the pipeline."""
|
|
1388
1468
|
config = super().get_dict()
|
|
1389
|
-
config.update(
|
|
1469
|
+
config["params"].update(
|
|
1470
|
+
{
|
|
1471
|
+
"argnames": self.argnames,
|
|
1472
|
+
"in_axes": self.in_axes,
|
|
1473
|
+
"out_axes": self.out_axes,
|
|
1474
|
+
"chunks": self.chunks,
|
|
1475
|
+
"batch_size": self.batch_size,
|
|
1476
|
+
}
|
|
1477
|
+
)
|
|
1478
|
+
return config
|
|
1479
|
+
|
|
1480
|
+
|
|
1481
|
+
@ops_registry("patched_grid")
|
|
1482
|
+
class PatchedGrid(Map):
|
|
1483
|
+
"""
|
|
1484
|
+
A pipeline that maps its operations over `flatgrid` and `flat_pfield` keys.
|
|
1485
|
+
|
|
1486
|
+
This can be used to reduce memory usage by processing data in chunks.
|
|
1487
|
+
|
|
1488
|
+
For more information and flexibility, see :class:`zea.ops.Map`.
|
|
1489
|
+
"""
|
|
1490
|
+
|
|
1491
|
+
def __init__(self, *args, num_patches=10, **kwargs):
|
|
1492
|
+
super().__init__(*args, argnames=["flatgrid", "flat_pfield"], chunks=num_patches, **kwargs)
|
|
1493
|
+
self.num_patches = num_patches
|
|
1494
|
+
|
|
1495
|
+
def get_dict(self):
|
|
1496
|
+
"""Get the configuration of the pipeline."""
|
|
1497
|
+
config = super().get_dict()
|
|
1498
|
+
config["params"].pop("argnames")
|
|
1499
|
+
config["params"].pop("chunks")
|
|
1390
1500
|
config["params"].update({"num_patches": self.num_patches})
|
|
1391
1501
|
return config
|
|
1392
1502
|
|
|
1393
1503
|
|
|
1504
|
+
@ops_registry("reshape_grid")
|
|
1505
|
+
class ReshapeGrid(Operation):
|
|
1506
|
+
"""Reshape flat grid data to grid shape."""
|
|
1507
|
+
|
|
1508
|
+
def __init__(self, axis=0, **kwargs):
|
|
1509
|
+
super().__init__(**kwargs)
|
|
1510
|
+
self.axis = axis
|
|
1511
|
+
|
|
1512
|
+
def call(self, grid, **kwargs):
|
|
1513
|
+
"""
|
|
1514
|
+
Args:
|
|
1515
|
+
- data (Tensor): The flat grid data of shape (..., n_pix, ...).
|
|
1516
|
+
Returns:
|
|
1517
|
+
- reshaped_data (Tensor): The reshaped data of shape (..., grid.shape, ...).
|
|
1518
|
+
"""
|
|
1519
|
+
data = kwargs[self.key]
|
|
1520
|
+
reshaped_data = reshape_axis(data, grid.shape[:-1], self.axis + int(self.with_batch_dim))
|
|
1521
|
+
return {self.output_key: reshaped_data}
|
|
1522
|
+
|
|
1523
|
+
|
|
1394
1524
|
## Base Operations
|
|
1395
1525
|
|
|
1396
1526
|
|
|
@@ -1423,21 +1553,6 @@ class Merge(Operation):
|
|
|
1423
1553
|
return merged
|
|
1424
1554
|
|
|
1425
1555
|
|
|
1426
|
-
@ops_registry("split")
|
|
1427
|
-
class Split(Operation):
|
|
1428
|
-
"""Operation that splits an input dictionary n copies."""
|
|
1429
|
-
|
|
1430
|
-
def __init__(self, n: int, **kwargs):
|
|
1431
|
-
super().__init__(**kwargs)
|
|
1432
|
-
self.n = n
|
|
1433
|
-
|
|
1434
|
-
def call(self, **kwargs) -> List[Dict]:
|
|
1435
|
-
"""
|
|
1436
|
-
Splits the input dictionary into n copies.
|
|
1437
|
-
"""
|
|
1438
|
-
return [kwargs.copy() for _ in range(self.n)]
|
|
1439
|
-
|
|
1440
|
-
|
|
1441
1556
|
@ops_registry("stack")
|
|
1442
1557
|
class Stack(Operation):
|
|
1443
1558
|
"""Stack multiple data arrays along a new axis.
|
|
@@ -1652,7 +1767,7 @@ class PfieldWeighting(Operation):
|
|
|
1652
1767
|
Returns:
|
|
1653
1768
|
dict: Dictionary containing weighted data
|
|
1654
1769
|
"""
|
|
1655
|
-
data = kwargs[self.key]
|
|
1770
|
+
data = kwargs[self.key] # must start with ((batch_size,) n_tx, n_pix, ...)
|
|
1656
1771
|
|
|
1657
1772
|
if flat_pfield is None:
|
|
1658
1773
|
return {self.output_key: data}
|
|
@@ -1660,14 +1775,16 @@ class PfieldWeighting(Operation):
|
|
|
1660
1775
|
# Swap (n_pix, n_tx) to (n_tx, n_pix)
|
|
1661
1776
|
flat_pfield = ops.swapaxes(flat_pfield, 0, 1)
|
|
1662
1777
|
|
|
1663
|
-
#
|
|
1664
|
-
# Also add the required dimensions for broadcasting
|
|
1778
|
+
# Add batch dimension if needed
|
|
1665
1779
|
if self.with_batch_dim:
|
|
1666
1780
|
pfield_expanded = ops.expand_dims(flat_pfield, axis=0)
|
|
1667
1781
|
else:
|
|
1668
1782
|
pfield_expanded = flat_pfield
|
|
1669
1783
|
|
|
1670
|
-
|
|
1784
|
+
append_n_dims = ops.ndim(data) - ops.ndim(pfield_expanded)
|
|
1785
|
+
pfield_expanded = extend_n_dims(pfield_expanded, axis=-1, n_dims=append_n_dims)
|
|
1786
|
+
|
|
1787
|
+
# Perform element-wise multiplication with the pressure weight mask
|
|
1671
1788
|
weighted_data = data * pfield_expanded
|
|
1672
1789
|
|
|
1673
1790
|
return {self.output_key: weighted_data}
|
|
@@ -1677,17 +1794,12 @@ class PfieldWeighting(Operation):
|
|
|
1677
1794
|
class DelayAndSum(Operation):
|
|
1678
1795
|
"""Sums time-delayed signals along channels and transmits."""
|
|
1679
1796
|
|
|
1680
|
-
def __init__(
|
|
1681
|
-
self,
|
|
1682
|
-
reshape_grid=True,
|
|
1683
|
-
**kwargs,
|
|
1684
|
-
):
|
|
1797
|
+
def __init__(self, **kwargs):
|
|
1685
1798
|
super().__init__(
|
|
1686
1799
|
input_data_type=DataTypes.ALIGNED_DATA,
|
|
1687
1800
|
output_data_type=DataTypes.BEAMFORMED_DATA,
|
|
1688
1801
|
**kwargs,
|
|
1689
1802
|
)
|
|
1690
|
-
self.reshape_grid = reshape_grid
|
|
1691
1803
|
|
|
1692
1804
|
def process_image(self, data):
|
|
1693
1805
|
"""Performs DAS beamforming on tof-corrected input.
|
|
@@ -1715,8 +1827,7 @@ class DelayAndSum(Operation):
|
|
|
1715
1827
|
|
|
1716
1828
|
Returns:
|
|
1717
1829
|
dict: Dictionary containing beamformed_data
|
|
1718
|
-
of shape `(grid_size_z*grid_size_x, n_ch)`
|
|
1719
|
-
or `(grid_size_z, grid_size_x, n_ch)` when reshape_grid is True,
|
|
1830
|
+
of shape `(grid_size_z*grid_size_x, n_ch)`
|
|
1720
1831
|
with optional batch dimension.
|
|
1721
1832
|
"""
|
|
1722
1833
|
data = kwargs[self.key]
|
|
@@ -1727,11 +1838,6 @@ class DelayAndSum(Operation):
|
|
|
1727
1838
|
# Apply process_image to each item in the batch
|
|
1728
1839
|
beamformed_data = ops.map(self.process_image, data)
|
|
1729
1840
|
|
|
1730
|
-
if self.reshape_grid:
|
|
1731
|
-
beamformed_data = reshape_axis(
|
|
1732
|
-
beamformed_data, grid.shape[:2], axis=int(self.with_batch_dim)
|
|
1733
|
-
)
|
|
1734
|
-
|
|
1735
1841
|
return {self.output_key: beamformed_data}
|
|
1736
1842
|
|
|
1737
1843
|
|
|
@@ -2061,7 +2167,7 @@ class ScanConvert(Operation):
|
|
|
2061
2167
|
|
|
2062
2168
|
|
|
2063
2169
|
@ops_registry("gaussian_blur")
|
|
2064
|
-
class GaussianBlur(
|
|
2170
|
+
class GaussianBlur(ImageOperation):
|
|
2065
2171
|
"""
|
|
2066
2172
|
GaussianBlur is an operation that applies a Gaussian blur to an input image.
|
|
2067
2173
|
Uses scipy.ndimage.gaussian_filter to create a kernel.
|
|
@@ -2111,6 +2217,13 @@ class GaussianBlur(Operation):
|
|
|
2111
2217
|
return ops.convert_to_tensor(kernel)
|
|
2112
2218
|
|
|
2113
2219
|
def call(self, **kwargs):
|
|
2220
|
+
"""Apply a Gaussian filter to the input data.
|
|
2221
|
+
|
|
2222
|
+
Args:
|
|
2223
|
+
data (ops.Tensor): Input image data of shape (height, width, channels) with
|
|
2224
|
+
optional batch dimension if ``self.with_batch_dim``.
|
|
2225
|
+
"""
|
|
2226
|
+
super().call(**kwargs)
|
|
2114
2227
|
data = kwargs[self.key]
|
|
2115
2228
|
|
|
2116
2229
|
# Add batch dimension if not present
|
|
@@ -2145,7 +2258,7 @@ class GaussianBlur(Operation):
|
|
|
2145
2258
|
|
|
2146
2259
|
|
|
2147
2260
|
@ops_registry("lee_filter")
|
|
2148
|
-
class LeeFilter(
|
|
2261
|
+
class LeeFilter(ImageOperation):
|
|
2149
2262
|
"""
|
|
2150
2263
|
The Lee filter is a speckle reduction filter commonly used in synthetic aperture radar (SAR)
|
|
2151
2264
|
and ultrasound image processing. It smooths the image while preserving edges and details.
|
|
@@ -2175,7 +2288,7 @@ class LeeFilter(Operation):
|
|
|
2175
2288
|
pad_mode=self.pad_mode,
|
|
2176
2289
|
with_batch_dim=self.with_batch_dim,
|
|
2177
2290
|
jittable=self._jittable,
|
|
2178
|
-
key=
|
|
2291
|
+
key="data",
|
|
2179
2292
|
)
|
|
2180
2293
|
|
|
2181
2294
|
@property
|
|
@@ -2191,24 +2304,29 @@ class LeeFilter(Operation):
|
|
|
2191
2304
|
self.gaussian_blur.with_batch_dim = value
|
|
2192
2305
|
|
|
2193
2306
|
def call(self, **kwargs):
|
|
2194
|
-
|
|
2307
|
+
"""Apply the Lee filter to the input data.
|
|
2308
|
+
|
|
2309
|
+
Args:
|
|
2310
|
+
data (ops.Tensor): Input image data of shape (height, width, channels) with
|
|
2311
|
+
optional batch dimension if ``self.with_batch_dim``.
|
|
2312
|
+
"""
|
|
2313
|
+
super().call(**kwargs)
|
|
2314
|
+
data = kwargs.pop(self.key)
|
|
2195
2315
|
|
|
2196
2316
|
# Apply Gaussian blur to get local mean
|
|
2197
|
-
img_mean = self.gaussian_blur.call(**kwargs)[self.gaussian_blur.output_key]
|
|
2317
|
+
img_mean = self.gaussian_blur.call(data=data, **kwargs)[self.gaussian_blur.output_key]
|
|
2198
2318
|
|
|
2199
2319
|
# Apply Gaussian blur to squared data to get local squared mean
|
|
2200
|
-
|
|
2201
|
-
|
|
2202
|
-
|
|
2320
|
+
img_sqr_mean = self.gaussian_blur.call(
|
|
2321
|
+
data=data**2,
|
|
2322
|
+
**kwargs,
|
|
2323
|
+
)[self.gaussian_blur.output_key]
|
|
2203
2324
|
|
|
2204
2325
|
# Calculate local variance
|
|
2205
2326
|
img_variance = img_sqr_mean - img_mean**2
|
|
2206
2327
|
|
|
2207
2328
|
# Calculate global variance (per channel)
|
|
2208
|
-
|
|
2209
|
-
overall_variance = ops.var(data, axis=(-3, -2), keepdims=True)
|
|
2210
|
-
else:
|
|
2211
|
-
overall_variance = ops.var(data, axis=(-2, -1), keepdims=True)
|
|
2329
|
+
overall_variance = ops.var(data, axis=(-3, -2), keepdims=True)
|
|
2212
2330
|
|
|
2213
2331
|
# Calculate adaptive weights
|
|
2214
2332
|
img_weights = img_variance / (img_variance + overall_variance)
|
|
@@ -2229,7 +2347,11 @@ class Demodulate(Operation):
|
|
|
2229
2347
|
input_data_type=DataTypes.RAW_DATA,
|
|
2230
2348
|
output_data_type=DataTypes.RAW_DATA,
|
|
2231
2349
|
jittable=True,
|
|
2232
|
-
additional_output_keys=[
|
|
2350
|
+
additional_output_keys=[
|
|
2351
|
+
"demodulation_frequency",
|
|
2352
|
+
"center_frequency",
|
|
2353
|
+
"n_ch",
|
|
2354
|
+
],
|
|
2233
2355
|
**kwargs,
|
|
2234
2356
|
)
|
|
2235
2357
|
self.axis = axis
|
|
@@ -2255,6 +2377,121 @@ class Demodulate(Operation):
|
|
|
2255
2377
|
}
|
|
2256
2378
|
|
|
2257
2379
|
|
|
2380
|
+
@ops_registry("fir_filter")
|
|
2381
|
+
class FirFilter(Operation):
|
|
2382
|
+
"""Apply a FIR filter to the input signal using convolution.
|
|
2383
|
+
|
|
2384
|
+
Looks for the filter taps in the input dictionary using the specified ``filter_key``.
|
|
2385
|
+
"""
|
|
2386
|
+
|
|
2387
|
+
def __init__(
|
|
2388
|
+
self,
|
|
2389
|
+
axis: int,
|
|
2390
|
+
complex_channels: bool = False,
|
|
2391
|
+
filter_key: str = "fir_filter_taps",
|
|
2392
|
+
**kwargs,
|
|
2393
|
+
):
|
|
2394
|
+
"""
|
|
2395
|
+
Args:
|
|
2396
|
+
axis (int): Axis along which to apply the filter. Cannot be the batch dimension.
|
|
2397
|
+
When using ``complex_channels=True``, the complex channels are removed to convert
|
|
2398
|
+
to complex numbers before filtering, so adjust the ``axis`` accordingly!
|
|
2399
|
+
complex_channels (bool): Whether the last dimension of the input signal represents
|
|
2400
|
+
complex channels (real and imaginary parts). When True, it will convert the signal
|
|
2401
|
+
to ``complex`` dtype before filtering and convert it back to two channels
|
|
2402
|
+
after filtering.
|
|
2403
|
+
filter_key (str): Key in the input dictionary where the FIR filter taps are stored.
|
|
2404
|
+
Default is "fir_filter_taps".
|
|
2405
|
+
"""
|
|
2406
|
+
super().__init__(**kwargs)
|
|
2407
|
+
self._check_axis(axis)
|
|
2408
|
+
|
|
2409
|
+
self.axis = axis
|
|
2410
|
+
self.complex_channels = complex_channels
|
|
2411
|
+
self.filter_key = filter_key
|
|
2412
|
+
|
|
2413
|
+
def _check_axis(self, axis, ndim=None):
|
|
2414
|
+
"""Check if the axis is valid."""
|
|
2415
|
+
if ndim is not None:
|
|
2416
|
+
if axis < -ndim or axis >= ndim:
|
|
2417
|
+
raise ValueError(f"Axis {axis} is out of bounds for array of dimension {ndim}.")
|
|
2418
|
+
|
|
2419
|
+
if self.with_batch_dim and (axis == 0 or (ndim is not None and axis == -ndim)):
|
|
2420
|
+
raise ValueError("Cannot apply FIR filter along batch dimension.")
|
|
2421
|
+
|
|
2422
|
+
@property
|
|
2423
|
+
def valid_keys(self):
|
|
2424
|
+
"""Get the valid keys for the `call` method."""
|
|
2425
|
+
return self._valid_keys.union({self.filter_key})
|
|
2426
|
+
|
|
2427
|
+
def call(self, **kwargs):
|
|
2428
|
+
signal = kwargs[self.key]
|
|
2429
|
+
fir_filter_taps = kwargs[self.filter_key]
|
|
2430
|
+
|
|
2431
|
+
if self.complex_channels:
|
|
2432
|
+
signal = channels_to_complex(signal)
|
|
2433
|
+
|
|
2434
|
+
self._check_axis(self.axis, ndim=ops.ndim(signal))
|
|
2435
|
+
|
|
2436
|
+
def _convolve(signal):
|
|
2437
|
+
"""Apply the filter to the signal using correlation."""
|
|
2438
|
+
return correlate(signal, fir_filter_taps[::-1], mode="same")
|
|
2439
|
+
|
|
2440
|
+
filtered_signal = apply_along_axis(_convolve, self.axis, signal)
|
|
2441
|
+
|
|
2442
|
+
if self.complex_channels:
|
|
2443
|
+
filtered_signal = complex_to_channels(filtered_signal)
|
|
2444
|
+
|
|
2445
|
+
return {self.output_key: filtered_signal}
|
|
2446
|
+
|
|
2447
|
+
|
|
2448
|
+
@ops_registry("low_pass_filter")
class LowPassFilter(FirFilter):
    """Low-pass filter a signal with an FIR filter designed on the fly.

    The filter taps are computed at call time from ``bandwidth``,
    ``sampling_frequency`` and ``center_frequency`` via
    :func:`get_low_pass_iq_filter`, which makes this operation non-jittable.
    It is provided for convenience only; for jittable pipelines prefer
    :class:`FirFilter` with pre-computed filter taps.
    """

    def __init__(self, axis: int, complex_channels: bool = False, num_taps: int = 128, **kwargs):
        """Initialize the LowPassFilter operation.

        Args:
            axis (int): Axis along which to apply the filter. Cannot be the batch
                dimension. When using ``complex_channels=True``, the complex channels
                are removed to convert to complex numbers before filtering, so adjust
                the ``axis`` accordingly.
            complex_channels (bool): Whether the last dimension of the input signal
                represents complex channels (real and imaginary parts). When True,
                the signal is converted to ``complex`` dtype before filtering and
                converted back to two channels afterwards.
            num_taps (int): Number of taps in the FIR filter. Default is 128.
        """
        # Unique suffix keeps this instance's filter key from colliding with
        # other filter operations in the same pipeline.
        self._random_suffix = str(uuid.uuid4())
        # filter_key and jittable are owned by this class; drop any overrides.
        for reserved in ("filter_key", "jittable"):
            kwargs.pop(reserved, None)
        super().__init__(
            axis=axis,
            complex_channels=complex_channels,
            filter_key=f"low_pass_{self._random_suffix}",
            jittable=False,
            **kwargs,
        )
        self.num_taps = num_taps

    def call(self, bandwidth, sampling_frequency, center_frequency, **kwargs):
        """Design the low-pass IQ filter and apply it to the signal."""

        def _scalar(value):
            # Filter design runs in numpy, so collapse tensors to Python scalars.
            return ops.convert_to_numpy(value).item()

        taps = get_low_pass_iq_filter(
            self.num_taps,
            _scalar(sampling_frequency),
            _scalar(center_frequency),
            _scalar(bandwidth),
        )
        kwargs[self.filter_key] = taps
        return super().call(**kwargs)
+
|
|
2258
2495
|
@ops_registry("lambda")
|
|
2259
2496
|
class Lambda(Operation):
|
|
2260
2497
|
"""Use any function as an operation."""
|
|
@@ -2289,7 +2526,10 @@ class Lambda(Operation):
|
|
|
2289
2526
|
|
|
2290
2527
|
def call(self, **kwargs):
    """Run the wrapped function on the input data.

    When ``with_batch_dim`` is set, ``self.func`` is mapped over the leading
    (batch) axis instead of being applied to the whole tensor at once.
    """
    payload = kwargs[self.key]
    if self.with_batch_dim:
        result = ops.map(self.func, payload)
    else:
        result = self.func(payload)
    return {self.output_key: result}
|
|
2295
2535
|
|
|
@@ -3044,7 +3284,7 @@ def get_band_pass_filter(num_taps, sampling_frequency, f1, f2):
|
|
|
3044
3284
|
return bpf
|
|
3045
3285
|
|
|
3046
3286
|
|
|
3047
|
-
def get_low_pass_iq_filter(num_taps, sampling_frequency,
|
|
3287
|
+
def get_low_pass_iq_filter(num_taps, sampling_frequency, center_frequency, bandwidth):
|
|
3048
3288
|
"""Design complex low-pass filter.
|
|
3049
3289
|
|
|
3050
3290
|
The filter is a low-pass FIR filter modulated to the center frequency.
|
|
@@ -3052,16 +3292,16 @@ def get_low_pass_iq_filter(num_taps, sampling_frequency, f, bw):
|
|
|
3052
3292
|
Args:
|
|
3053
3293
|
num_taps (int): number of taps in filter.
|
|
3054
3294
|
sampling_frequency (float): sample frequency.
|
|
3055
|
-
|
|
3056
|
-
|
|
3295
|
+
center_frequency (float): center frequency.
|
|
3296
|
+
bandwidth (float): bandwidth in Hz.
|
|
3057
3297
|
|
|
3058
3298
|
Raises:
|
|
3059
|
-
ValueError: if cutoff frequency (
|
|
3299
|
+
ValueError: if cutoff frequency (bandwidth / 2) is not within (0, sampling_frequency / 2)
|
|
3060
3300
|
|
|
3061
3301
|
Returns:
|
|
3062
3302
|
ndarray: Complex-valued low-pass filter
|
|
3063
3303
|
"""
|
|
3064
|
-
cutoff =
|
|
3304
|
+
cutoff = bandwidth / 2
|
|
3065
3305
|
if not (0 < cutoff < sampling_frequency / 2):
|
|
3066
3306
|
raise ValueError(
|
|
3067
3307
|
f"Cutoff frequency must be within (0, sampling_frequency / 2), "
|
|
@@ -3071,7 +3311,7 @@ def get_low_pass_iq_filter(num_taps, sampling_frequency, f, bw):
|
|
|
3071
3311
|
lpf = scipy.signal.firwin(num_taps, cutoff, pass_zero=True, fs=sampling_frequency)
|
|
3072
3312
|
# Modulate to center frequency to make it complex
|
|
3073
3313
|
time_points = np.arange(num_taps) / sampling_frequency
|
|
3074
|
-
lpf_complex = lpf * np.exp(1j * 2 * np.pi *
|
|
3314
|
+
lpf_complex = lpf * np.exp(1j * 2 * np.pi * center_frequency * time_points)
|
|
3075
3315
|
return lpf_complex
|
|
3076
3316
|
|
|
3077
3317
|
|