zea 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +54 -19
- zea/agent/__init__.py +12 -12
- zea/agent/masks.py +2 -1
- zea/backend/tensorflow/dataloader.py +2 -5
- zea/beamform/beamformer.py +100 -50
- zea/beamform/lens_correction.py +9 -2
- zea/beamform/pfield.py +9 -2
- zea/beamform/pixelgrid.py +1 -1
- zea/config.py +34 -25
- zea/data/__init__.py +22 -25
- zea/data/augmentations.py +221 -28
- zea/data/convert/__init__.py +1 -6
- zea/data/convert/__main__.py +123 -0
- zea/data/convert/camus.py +101 -40
- zea/data/convert/echonet.py +187 -86
- zea/data/convert/echonetlvh/README.md +2 -3
- zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
- zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
- zea/data/convert/echonetlvh/precompute_crop.py +43 -64
- zea/data/convert/picmus.py +37 -40
- zea/data/convert/utils.py +86 -0
- zea/data/convert/{matlab.py → verasonics.py} +44 -65
- zea/data/data_format.py +155 -34
- zea/data/dataloader.py +12 -7
- zea/data/datasets.py +112 -71
- zea/data/file.py +184 -73
- zea/data/file_operations.py +496 -0
- zea/data/layers.py +3 -3
- zea/data/preset_utils.py +1 -1
- zea/datapaths.py +16 -4
- zea/display.py +14 -13
- zea/interface.py +14 -16
- zea/internal/_generate_keras_ops.py +6 -7
- zea/internal/cache.py +2 -49
- zea/internal/checks.py +6 -12
- zea/internal/config/validation.py +1 -2
- zea/internal/core.py +69 -6
- zea/internal/device.py +6 -2
- zea/internal/dummy_scan.py +330 -0
- zea/internal/operators.py +118 -2
- zea/internal/parameters.py +101 -70
- zea/internal/setup_zea.py +5 -6
- zea/internal/utils.py +282 -0
- zea/io_lib.py +322 -146
- zea/keras_ops.py +74 -4
- zea/log.py +9 -7
- zea/metrics.py +15 -7
- zea/models/__init__.py +31 -21
- zea/models/base.py +30 -14
- zea/models/carotid_segmenter.py +19 -4
- zea/models/diffusion.py +235 -23
- zea/models/echonet.py +22 -8
- zea/models/echonetlvh.py +31 -7
- zea/models/lpips.py +19 -2
- zea/models/lv_segmentation.py +30 -11
- zea/models/preset_utils.py +5 -5
- zea/models/regional_quality.py +30 -10
- zea/models/taesd.py +21 -5
- zea/models/unet.py +15 -1
- zea/ops.py +770 -336
- zea/probes.py +6 -6
- zea/scan.py +121 -51
- zea/simulator.py +24 -21
- zea/tensor_ops.py +477 -353
- zea/tools/fit_scan_cone.py +90 -160
- zea/tools/hf.py +1 -1
- zea/tools/selection_tool.py +47 -86
- zea/tracking/__init__.py +16 -0
- zea/tracking/base.py +94 -0
- zea/tracking/lucas_kanade.py +474 -0
- zea/tracking/segmentation.py +110 -0
- zea/utils.py +101 -480
- zea/visualize.py +177 -39
- {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/METADATA +6 -2
- zea-0.0.8.dist-info/RECORD +122 -0
- zea-0.0.6.dist-info/RECORD +0 -112
- {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/WHEEL +0 -0
- {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/entry_points.txt +0 -0
- {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/licenses/LICENSE +0 -0
zea/ops.py
CHANGED
|
@@ -11,39 +11,62 @@ Operations can be run on their own:
|
|
|
11
11
|
|
|
12
12
|
Examples
|
|
13
13
|
^^^^^^^^
|
|
14
|
-
..
|
|
14
|
+
.. doctest::
|
|
15
15
|
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
#
|
|
20
|
-
|
|
16
|
+
>>> import numpy as np
|
|
17
|
+
>>> from zea.ops import EnvelopeDetect
|
|
18
|
+
>>> data = np.random.randn(2000, 128, 1)
|
|
19
|
+
>>> # static arguments are passed in the constructor
|
|
20
|
+
>>> envelope_detect = EnvelopeDetect(axis=-1)
|
|
21
|
+
>>> # other parameters can be passed here along with the data
|
|
22
|
+
>>> envelope_data = envelope_detect(data=data)
|
|
21
23
|
|
|
22
24
|
Using a pipeline
|
|
23
25
|
----------------
|
|
24
26
|
|
|
25
27
|
You can initialize with a default pipeline or create your own custom pipeline.
|
|
26
28
|
|
|
27
|
-
..
|
|
29
|
+
.. doctest::
|
|
28
30
|
|
|
29
|
-
|
|
31
|
+
>>> from zea.ops import Pipeline, EnvelopeDetect, Normalize, LogCompress
|
|
32
|
+
>>> pipeline = Pipeline.from_default()
|
|
30
33
|
|
|
31
|
-
operations = [
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
]
|
|
36
|
-
pipeline_custom = Pipeline(operations)
|
|
34
|
+
>>> operations = [
|
|
35
|
+
... EnvelopeDetect(),
|
|
36
|
+
... Normalize(),
|
|
37
|
+
... LogCompress(),
|
|
38
|
+
... ]
|
|
39
|
+
>>> pipeline_custom = Pipeline(operations)
|
|
37
40
|
|
|
38
41
|
One can also load a pipeline from a config or yaml/json file:
|
|
39
42
|
|
|
40
|
-
..
|
|
43
|
+
.. doctest::
|
|
44
|
+
|
|
45
|
+
>>> from zea import Pipeline
|
|
46
|
+
|
|
47
|
+
>>> # From JSON string
|
|
48
|
+
>>> json_string = '{"operations": ["identity"]}'
|
|
49
|
+
>>> pipeline = Pipeline.from_json(json_string)
|
|
50
|
+
|
|
51
|
+
>>> # from yaml file
|
|
52
|
+
>>> import yaml
|
|
53
|
+
>>> from zea import Config
|
|
54
|
+
>>> # Create a sample pipeline YAML file
|
|
55
|
+
>>> pipeline_dict = {
|
|
56
|
+
... "operations": [
|
|
57
|
+
... {"name": "identity"},
|
|
58
|
+
... ]
|
|
59
|
+
... }
|
|
60
|
+
>>> with open("pipeline.yaml", "w") as f:
|
|
61
|
+
... yaml.dump(pipeline_dict, f)
|
|
62
|
+
>>> yaml_file = "pipeline.yaml"
|
|
63
|
+
>>> pipeline = Pipeline.from_yaml(yaml_file)
|
|
64
|
+
|
|
65
|
+
.. testcleanup::
|
|
41
66
|
|
|
42
|
-
|
|
43
|
-
pipeline = Pipeline.from_json(json_string)
|
|
67
|
+
import os
|
|
44
68
|
|
|
45
|
-
|
|
46
|
-
pipeline = Pipeline.from_yaml(yaml_file)
|
|
69
|
+
os.remove("pipeline.yaml")
|
|
47
70
|
|
|
48
71
|
Example of a yaml file:
|
|
49
72
|
|
|
@@ -56,8 +79,6 @@ Example of a yaml file:
|
|
|
56
79
|
params:
|
|
57
80
|
operations:
|
|
58
81
|
- name: tof_correction
|
|
59
|
-
params:
|
|
60
|
-
apply_phase_rotation: true
|
|
61
82
|
- name: pfield_weighting
|
|
62
83
|
- name: delay_and_sum
|
|
63
84
|
num_patches: 100
|
|
@@ -67,10 +88,10 @@ Example of a yaml file:
|
|
|
67
88
|
|
|
68
89
|
"""
|
|
69
90
|
|
|
70
|
-
import copy
|
|
71
91
|
import hashlib
|
|
72
92
|
import inspect
|
|
73
93
|
import json
|
|
94
|
+
import uuid
|
|
74
95
|
from functools import partial
|
|
75
96
|
from typing import Any, Dict, List, Union
|
|
76
97
|
|
|
@@ -79,7 +100,7 @@ import numpy as np
|
|
|
79
100
|
import scipy
|
|
80
101
|
import yaml
|
|
81
102
|
from keras import ops
|
|
82
|
-
from keras.src.layers.preprocessing.
|
|
103
|
+
from keras.src.layers.preprocessing.data_layer import DataLayer
|
|
83
104
|
|
|
84
105
|
from zea import log
|
|
85
106
|
from zea.backend import jit
|
|
@@ -99,8 +120,20 @@ from zea.internal.registry import ops_registry
|
|
|
99
120
|
from zea.probes import Probe
|
|
100
121
|
from zea.scan import Scan
|
|
101
122
|
from zea.simulator import simulate_rf
|
|
102
|
-
from zea.tensor_ops import
|
|
103
|
-
|
|
123
|
+
from zea.tensor_ops import (
|
|
124
|
+
apply_along_axis,
|
|
125
|
+
correlate,
|
|
126
|
+
extend_n_dims,
|
|
127
|
+
resample,
|
|
128
|
+
reshape_axis,
|
|
129
|
+
translate,
|
|
130
|
+
vmap,
|
|
131
|
+
)
|
|
132
|
+
from zea.utils import (
|
|
133
|
+
FunctionTimer,
|
|
134
|
+
deep_compare,
|
|
135
|
+
map_negative_indices,
|
|
136
|
+
)
|
|
104
137
|
|
|
105
138
|
|
|
106
139
|
def get_ops(ops_name):
|
|
@@ -125,6 +158,7 @@ class Operation(keras.Operation):
|
|
|
125
158
|
with_batch_dim: bool = True,
|
|
126
159
|
jit_kwargs: dict | None = None,
|
|
127
160
|
jittable: bool = True,
|
|
161
|
+
additional_output_keys: List[str] = None,
|
|
128
162
|
**kwargs,
|
|
129
163
|
):
|
|
130
164
|
"""
|
|
@@ -143,6 +177,10 @@ class Operation(keras.Operation):
|
|
|
143
177
|
with_batch_dim: Whether operations should expect a batch dimension in the input
|
|
144
178
|
jit_kwargs: Additional keyword arguments for the JIT compiler
|
|
145
179
|
jittable: Whether the operation can be JIT compiled
|
|
180
|
+
additional_output_keys: A list of additional output keys produced by the operation.
|
|
181
|
+
These are used to track if all keys are available for downstream operations.
|
|
182
|
+
If the operation has a conditional output, it is best to add all possible
|
|
183
|
+
output keys here.
|
|
146
184
|
"""
|
|
147
185
|
super().__init__(**kwargs)
|
|
148
186
|
|
|
@@ -153,6 +191,7 @@ class Operation(keras.Operation):
|
|
|
153
191
|
self.output_key = output_key # Key for output data
|
|
154
192
|
if self.output_key is None:
|
|
155
193
|
self.output_key = self.key
|
|
194
|
+
self.additional_output_keys = additional_output_keys or []
|
|
156
195
|
|
|
157
196
|
self.inputs = [] # Source(s) of input data (name of a previous operation)
|
|
158
197
|
self.allow_multiple_inputs = False # Only single input allowed by default
|
|
@@ -165,8 +204,6 @@ class Operation(keras.Operation):
|
|
|
165
204
|
self._output_cache = {}
|
|
166
205
|
|
|
167
206
|
# Obtain the input signature of the `call` method
|
|
168
|
-
self._input_signature = None
|
|
169
|
-
self._valid_keys = None # Keys valid for the `call` method
|
|
170
207
|
self._trace_signatures()
|
|
171
208
|
|
|
172
209
|
if jit_kwargs is None:
|
|
@@ -186,6 +223,11 @@ class Operation(keras.Operation):
|
|
|
186
223
|
with log.set_level("ERROR"):
|
|
187
224
|
self.set_jit(jit_compile)
|
|
188
225
|
|
|
226
|
+
@property
|
|
227
|
+
def output_keys(self) -> List[str]:
|
|
228
|
+
"""Get the output keys of the operation."""
|
|
229
|
+
return [self.output_key] + self.additional_output_keys
|
|
230
|
+
|
|
189
231
|
@property
|
|
190
232
|
def static_params(self):
|
|
191
233
|
"""Get the static parameters of the operation."""
|
|
@@ -204,13 +246,18 @@ class Operation(keras.Operation):
|
|
|
204
246
|
Analyze and store the input/output signatures of the `call` method.
|
|
205
247
|
"""
|
|
206
248
|
self._input_signature = inspect.signature(self.call)
|
|
207
|
-
self._valid_keys = set(self._input_signature.parameters.keys())
|
|
249
|
+
self._valid_keys = set(self._input_signature.parameters.keys()) | {self.key}
|
|
208
250
|
|
|
209
251
|
@property
|
|
210
|
-
def valid_keys(self):
|
|
252
|
+
def valid_keys(self) -> set:
|
|
211
253
|
"""Get the valid keys for the `call` method."""
|
|
212
254
|
return self._valid_keys
|
|
213
255
|
|
|
256
|
+
@property
|
|
257
|
+
def needs_keys(self) -> set:
|
|
258
|
+
"""Get a set of all input keys needed by the operation."""
|
|
259
|
+
return self.valid_keys
|
|
260
|
+
|
|
214
261
|
@property
|
|
215
262
|
def jittable(self):
|
|
216
263
|
"""Check if the operation can be JIT compiled."""
|
|
@@ -366,6 +413,36 @@ class Operation(keras.Operation):
|
|
|
366
413
|
return True
|
|
367
414
|
|
|
368
415
|
|
|
416
|
+
class ImageOperation(Operation):
|
|
417
|
+
"""
|
|
418
|
+
Base class for image processing operations.
|
|
419
|
+
|
|
420
|
+
This class extends the Operation class to provide a common interface
|
|
421
|
+
for operations that process image data, with shape (batch, height, width, channels)
|
|
422
|
+
or (height, width, channels) if batch dimension is not present.
|
|
423
|
+
|
|
424
|
+
Subclasses should implement the `call` method to define the image processing logic, and call
|
|
425
|
+
``super().call(**kwargs)`` to validate the input data shape.
|
|
426
|
+
"""
|
|
427
|
+
|
|
428
|
+
def call(self, **kwargs):
|
|
429
|
+
"""
|
|
430
|
+
Validate input data shape for image operations.
|
|
431
|
+
|
|
432
|
+
Args:
|
|
433
|
+
**kwargs: Keyword arguments containing input data.
|
|
434
|
+
|
|
435
|
+
Raises:
|
|
436
|
+
AssertionError: If input data does not have the expected number of dimensions.
|
|
437
|
+
"""
|
|
438
|
+
data = kwargs[self.key]
|
|
439
|
+
|
|
440
|
+
if self.with_batch_dim:
|
|
441
|
+
assert ops.ndim(data) == 4, "Input data must have 4 dimensions (b, h, w, c)."
|
|
442
|
+
else:
|
|
443
|
+
assert ops.ndim(data) == 3, "Input data must have 3 dimensions (h, w, c)."
|
|
444
|
+
|
|
445
|
+
|
|
369
446
|
@ops_registry("pipeline")
|
|
370
447
|
class Pipeline:
|
|
371
448
|
"""Pipeline class for processing ultrasound data through a series of operations."""
|
|
@@ -408,8 +485,6 @@ class Pipeline:
|
|
|
408
485
|
"""
|
|
409
486
|
self._call_pipeline = self.call
|
|
410
487
|
self.name = name
|
|
411
|
-
self.timer = FunctionTimer()
|
|
412
|
-
self.timed = timed
|
|
413
488
|
|
|
414
489
|
self._pipeline_layers = operations
|
|
415
490
|
|
|
@@ -419,6 +494,24 @@ class Pipeline:
|
|
|
419
494
|
self.with_batch_dim = with_batch_dim
|
|
420
495
|
self._validate_flag = validate
|
|
421
496
|
|
|
497
|
+
# Setup timer
|
|
498
|
+
if jit_options == "pipeline" and timed:
|
|
499
|
+
raise ValueError(
|
|
500
|
+
"timed=True cannot be used with jit_options='pipeline' as the entire "
|
|
501
|
+
"pipeline is compiled into a single function. Try setting jit_options to "
|
|
502
|
+
"'ops' or None."
|
|
503
|
+
)
|
|
504
|
+
if timed:
|
|
505
|
+
log.warning(
|
|
506
|
+
"Timer has been initialized for the pipeline. To get an accurate timing estimate, "
|
|
507
|
+
"the `block_until_ready()` is used, which will slow down the execution, so "
|
|
508
|
+
"do not use for regular processing!"
|
|
509
|
+
)
|
|
510
|
+
self._callable_layers = self._get_timed_operations()
|
|
511
|
+
else:
|
|
512
|
+
self._callable_layers = self._pipeline_layers
|
|
513
|
+
self._timed = timed
|
|
514
|
+
|
|
422
515
|
if validate:
|
|
423
516
|
self.validate()
|
|
424
517
|
else:
|
|
@@ -427,19 +520,40 @@ class Pipeline:
|
|
|
427
520
|
if jit_kwargs is None:
|
|
428
521
|
jit_kwargs = {}
|
|
429
522
|
|
|
430
|
-
if keras.backend.backend() == "jax" and self.static_params:
|
|
523
|
+
if keras.backend.backend() == "jax" and self.static_params != []:
|
|
431
524
|
jit_kwargs = {"static_argnames": self.static_params}
|
|
432
525
|
|
|
433
526
|
self.jit_kwargs = jit_kwargs
|
|
434
527
|
self.jit_options = jit_options # will handle the jit compilation
|
|
435
528
|
|
|
529
|
+
self._logged_difference_keys = False
|
|
530
|
+
|
|
531
|
+
# Do not log again for nested pipelines
|
|
532
|
+
for nested_pipeline in self._nested_pipelines:
|
|
533
|
+
nested_pipeline._logged_difference_keys = True
|
|
534
|
+
|
|
436
535
|
def needs(self, key) -> bool:
|
|
437
|
-
"""Check if the pipeline needs a specific key."""
|
|
438
|
-
return key in self.
|
|
536
|
+
"""Check if the pipeline needs a specific key at the input."""
|
|
537
|
+
return key in self.needs_keys
|
|
538
|
+
|
|
539
|
+
@property
|
|
540
|
+
def _nested_pipelines(self):
|
|
541
|
+
return [operation for operation in self.operations if isinstance(operation, Pipeline)]
|
|
542
|
+
|
|
543
|
+
@property
|
|
544
|
+
def output_keys(self) -> set:
|
|
545
|
+
"""All output keys the pipeline guarantees to produce."""
|
|
546
|
+
output_keys = set()
|
|
547
|
+
for operation in self.operations:
|
|
548
|
+
output_keys.update(operation.output_keys)
|
|
549
|
+
return output_keys
|
|
439
550
|
|
|
440
551
|
@property
|
|
441
552
|
def valid_keys(self) -> set:
|
|
442
|
-
"""Get a set of valid keys for the pipeline.
|
|
553
|
+
"""Get a set of valid keys for the pipeline.
|
|
554
|
+
|
|
555
|
+
This is all keys that can be passed to the pipeline as input.
|
|
556
|
+
"""
|
|
443
557
|
valid_keys = set()
|
|
444
558
|
for operation in self.operations:
|
|
445
559
|
valid_keys.update(operation.valid_keys)
|
|
@@ -453,8 +567,26 @@ class Pipeline:
|
|
|
453
567
|
static_params.extend(operation.static_params)
|
|
454
568
|
return list(set(static_params))
|
|
455
569
|
|
|
570
|
+
@property
|
|
571
|
+
def needs_keys(self) -> set:
|
|
572
|
+
"""Get a set of all input keys needed by the pipeline.
|
|
573
|
+
|
|
574
|
+
Will keep track of keys that are already provided by previous operations.
|
|
575
|
+
"""
|
|
576
|
+
needs = set()
|
|
577
|
+
has_so_far = set()
|
|
578
|
+
previous_operation = None
|
|
579
|
+
for operation in self.operations:
|
|
580
|
+
if previous_operation is not None:
|
|
581
|
+
has_so_far.update(previous_operation.output_keys)
|
|
582
|
+
needs.update(operation.needs_keys - has_so_far)
|
|
583
|
+
previous_operation = operation
|
|
584
|
+
return needs
|
|
585
|
+
|
|
456
586
|
@classmethod
|
|
457
|
-
def from_default(
|
|
587
|
+
def from_default(
|
|
588
|
+
cls, num_patches=100, baseband=False, pfield=False, timed=False, **kwargs
|
|
589
|
+
) -> "Pipeline":
|
|
458
590
|
"""Create a default pipeline.
|
|
459
591
|
|
|
460
592
|
Args:
|
|
@@ -465,6 +597,7 @@ class Pipeline:
|
|
|
465
597
|
so input signal has a single channel dim and is still on carrier frequency.
|
|
466
598
|
pfield (bool): If True, apply Pfield weighting. Defaults to False.
|
|
467
599
|
This will calculate pressure field and only beamform the data to those locations.
|
|
600
|
+
timed (bool, optional): Whether to time each operation. Defaults to False.
|
|
468
601
|
**kwargs: Additional keyword arguments to be passed to the Pipeline constructor.
|
|
469
602
|
|
|
470
603
|
"""
|
|
@@ -476,7 +609,7 @@ class Pipeline:
|
|
|
476
609
|
|
|
477
610
|
# Get beamforming ops
|
|
478
611
|
beamforming = [
|
|
479
|
-
TOFCorrection(
|
|
612
|
+
TOFCorrection(),
|
|
480
613
|
DelayAndSum(),
|
|
481
614
|
]
|
|
482
615
|
if pfield:
|
|
@@ -491,11 +624,12 @@ class Pipeline:
|
|
|
491
624
|
|
|
492
625
|
# Add display ops
|
|
493
626
|
operations += [
|
|
627
|
+
ReshapeGrid(),
|
|
494
628
|
EnvelopeDetect(),
|
|
495
629
|
Normalize(),
|
|
496
630
|
LogCompress(),
|
|
497
631
|
]
|
|
498
|
-
return cls(operations, **kwargs)
|
|
632
|
+
return cls(operations, timed=timed, **kwargs)
|
|
499
633
|
|
|
500
634
|
def copy(self) -> "Pipeline":
|
|
501
635
|
"""Create a copy of the pipeline."""
|
|
@@ -506,57 +640,60 @@ class Pipeline:
|
|
|
506
640
|
jit_kwargs=self.jit_kwargs,
|
|
507
641
|
name=self.name,
|
|
508
642
|
validate=self._validate_flag,
|
|
643
|
+
timed=self._timed,
|
|
644
|
+
)
|
|
645
|
+
|
|
646
|
+
def reinitialize(self):
|
|
647
|
+
"""Reinitialize the pipeline in place."""
|
|
648
|
+
self.__init__(
|
|
649
|
+
self._pipeline_layers,
|
|
650
|
+
with_batch_dim=self.with_batch_dim,
|
|
651
|
+
jit_options=self.jit_options,
|
|
652
|
+
jit_kwargs=self.jit_kwargs,
|
|
653
|
+
name=self.name,
|
|
654
|
+
validate=self._validate_flag,
|
|
655
|
+
timed=self._timed,
|
|
509
656
|
)
|
|
510
657
|
|
|
511
658
|
def prepend(self, operation: Operation):
|
|
512
659
|
"""Prepend an operation to the pipeline."""
|
|
513
660
|
self._pipeline_layers.insert(0, operation)
|
|
514
|
-
self.
|
|
661
|
+
self.reinitialize()
|
|
515
662
|
|
|
516
663
|
def append(self, operation: Operation):
|
|
517
664
|
"""Append an operation to the pipeline."""
|
|
518
665
|
self._pipeline_layers.append(operation)
|
|
519
|
-
self.
|
|
666
|
+
self.reinitialize()
|
|
520
667
|
|
|
521
668
|
def insert(self, index: int, operation: Operation):
|
|
522
669
|
"""Insert an operation at a specific index in the pipeline."""
|
|
523
670
|
if index < 0 or index > len(self._pipeline_layers):
|
|
524
671
|
raise IndexError("Index out of bounds for inserting operation.")
|
|
525
672
|
self._pipeline_layers.insert(index, operation)
|
|
526
|
-
|
|
673
|
+
self.reinitialize()
|
|
527
674
|
|
|
528
675
|
@property
|
|
529
676
|
def operations(self):
|
|
530
677
|
"""Alias for self.layers to match the zea naming convention"""
|
|
531
678
|
return self._pipeline_layers
|
|
532
679
|
|
|
533
|
-
def
|
|
534
|
-
"""
|
|
680
|
+
def reset_timer(self):
|
|
681
|
+
"""Reset the timer for timed operations."""
|
|
682
|
+
if self._timed:
|
|
683
|
+
self._callable_layers = self._get_timed_operations()
|
|
684
|
+
else:
|
|
685
|
+
log.warning(
|
|
686
|
+
"Timer has not been initialized. Set timed=True when initializing the pipeline."
|
|
687
|
+
)
|
|
535
688
|
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
except KeyError as exc:
|
|
541
|
-
raise KeyError(
|
|
542
|
-
f"[zea.Pipeline] Operation '{op.__class__.__name__}' "
|
|
543
|
-
f"requires input key '{exc.args[0]}', "
|
|
544
|
-
"but it was not provided in the inputs.\n"
|
|
545
|
-
"Check whether the objects (such as `zea.Scan`) passed to "
|
|
546
|
-
"`pipeline.prepare_parameters()` contain all required keys.\n"
|
|
547
|
-
f"Current list of all passed keys: {list(inputs.keys())}\n"
|
|
548
|
-
f"Valid keys for this pipeline: {self.valid_keys}"
|
|
549
|
-
) from exc
|
|
550
|
-
except Exception as exc:
|
|
551
|
-
raise RuntimeError(
|
|
552
|
-
f"[zea.Pipeline] Error in operation '{op.__class__.__name__}': {exc}"
|
|
553
|
-
) from exc
|
|
554
|
-
inputs = outputs
|
|
555
|
-
return outputs
|
|
689
|
+
def _get_timed_operations(self):
|
|
690
|
+
"""Get a list of timed operations."""
|
|
691
|
+
self.timer = FunctionTimer()
|
|
692
|
+
return [self.timer(op, name=op.__class__.__name__) for op in self._pipeline_layers]
|
|
556
693
|
|
|
557
694
|
def call(self, **inputs):
|
|
558
695
|
"""Process input data through the pipeline."""
|
|
559
|
-
for operation in self.
|
|
696
|
+
for operation in self._callable_layers:
|
|
560
697
|
try:
|
|
561
698
|
outputs = operation(**inputs)
|
|
562
699
|
except KeyError as exc:
|
|
@@ -579,14 +716,9 @@ class Pipeline:
|
|
|
579
716
|
def __call__(self, return_numpy=False, **inputs):
|
|
580
717
|
"""Process input data through the pipeline."""
|
|
581
718
|
|
|
582
|
-
if any(key in inputs for key in ["probe", "scan", "config"])
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
"`Pipeline.prepare_parameters` before calling the pipeline. "
|
|
586
|
-
"e.g. inputs = Pipeline.prepare_parameters(probe, scan, config)"
|
|
587
|
-
)
|
|
588
|
-
|
|
589
|
-
if any(isinstance(arg, ZEAObject) for arg in inputs.values()):
|
|
719
|
+
if any(key in inputs for key in ["probe", "scan", "config"]) or any(
|
|
720
|
+
isinstance(arg, ZEAObject) for arg in inputs.values()
|
|
721
|
+
):
|
|
590
722
|
raise ValueError(
|
|
591
723
|
"Probe, Scan and Config objects should be first processed with "
|
|
592
724
|
"`Pipeline.prepare_parameters` before calling the pipeline. "
|
|
@@ -599,6 +731,16 @@ class Pipeline:
|
|
|
599
731
|
"Please ensure all inputs are convertible to tensors."
|
|
600
732
|
)
|
|
601
733
|
|
|
734
|
+
if not self._logged_difference_keys:
|
|
735
|
+
difference_keys = set(inputs.keys()) - self.valid_keys
|
|
736
|
+
if difference_keys:
|
|
737
|
+
log.debug(
|
|
738
|
+
f"[zea.Pipeline] The following input keys are not used by the pipeline: "
|
|
739
|
+
f"{difference_keys}. Make sure this is intended. "
|
|
740
|
+
"This warning will only be shown once."
|
|
741
|
+
)
|
|
742
|
+
self._logged_difference_keys = True
|
|
743
|
+
|
|
602
744
|
## PROCESSING
|
|
603
745
|
outputs = self._call_pipeline(**inputs)
|
|
604
746
|
|
|
@@ -640,18 +782,13 @@ class Pipeline:
|
|
|
640
782
|
if operation.jittable and operation._jit_compile:
|
|
641
783
|
operation.set_jit(value == "ops")
|
|
642
784
|
|
|
643
|
-
@property
|
|
644
|
-
def _call_fn(self):
|
|
645
|
-
"""Get the call function of the pipeline."""
|
|
646
|
-
return self.call if not self.timed else self.timed_call
|
|
647
|
-
|
|
648
785
|
def jit(self):
|
|
649
786
|
"""JIT compile the pipeline."""
|
|
650
|
-
self._call_pipeline = jit(self.
|
|
787
|
+
self._call_pipeline = jit(self.call, **self.jit_kwargs)
|
|
651
788
|
|
|
652
789
|
def unjit(self):
|
|
653
790
|
"""Un-JIT compile the pipeline."""
|
|
654
|
-
self._call_pipeline = self.
|
|
791
|
+
self._call_pipeline = self.call
|
|
655
792
|
|
|
656
793
|
@property
|
|
657
794
|
def jittable(self):
|
|
@@ -726,59 +863,9 @@ class Pipeline:
|
|
|
726
863
|
return params
|
|
727
864
|
|
|
728
865
|
def __str__(self):
|
|
729
|
-
"""String representation of the pipeline.
|
|
730
|
-
|
|
731
|
-
Will print on two parallel pipeline lines if it detects a splitting operations
|
|
732
|
-
(such as multi_bandpass_filter)
|
|
733
|
-
Will merge the pipeline lines if it detects a stacking operation (such as stack)
|
|
734
|
-
"""
|
|
735
|
-
split_operations = []
|
|
736
|
-
merge_operations = ["Stack"]
|
|
737
|
-
|
|
866
|
+
"""String representation of the pipeline."""
|
|
738
867
|
operations = [operation.__class__.__name__ for operation in self.operations]
|
|
739
868
|
string = " -> ".join(operations)
|
|
740
|
-
|
|
741
|
-
if any(operation in split_operations for operation in operations):
|
|
742
|
-
# a second line is needed with same length as the first line
|
|
743
|
-
split_line = " " * len(string)
|
|
744
|
-
# find the splitting operation and index and print \-> instead of -> after
|
|
745
|
-
split_detected = False
|
|
746
|
-
merge_detected = False
|
|
747
|
-
split_operation = None
|
|
748
|
-
for operation in operations:
|
|
749
|
-
if operation in split_operations:
|
|
750
|
-
index = string.index(operation)
|
|
751
|
-
index = index + len(operation)
|
|
752
|
-
split_line = split_line[:index] + "\\->" + split_line[index + len("\\->") :]
|
|
753
|
-
split_detected = True
|
|
754
|
-
merge_detected = False
|
|
755
|
-
split_operation = operation
|
|
756
|
-
continue
|
|
757
|
-
|
|
758
|
-
if operation in merge_operations:
|
|
759
|
-
index = string.index(operation)
|
|
760
|
-
index = index - 4
|
|
761
|
-
split_line = split_line[:index] + "/" + split_line[index + 1 :]
|
|
762
|
-
split_detected = False
|
|
763
|
-
merge_detected = True
|
|
764
|
-
continue
|
|
765
|
-
|
|
766
|
-
if split_detected:
|
|
767
|
-
# print all operations in the second line
|
|
768
|
-
index = string.index(operation)
|
|
769
|
-
split_line = (
|
|
770
|
-
split_line[:index]
|
|
771
|
-
+ operation
|
|
772
|
-
+ " -> "
|
|
773
|
-
+ split_line[index + len(operation) + len(" -> ") :]
|
|
774
|
-
)
|
|
775
|
-
assert merge_detected is True, log.error(
|
|
776
|
-
"Pipeline was never merged back together (with Stack operation), even "
|
|
777
|
-
f"though it was split with {split_operation}. "
|
|
778
|
-
"Please properly define your operation chain."
|
|
779
|
-
)
|
|
780
|
-
return f"\n{string}\n{split_line}\n"
|
|
781
|
-
|
|
782
869
|
return string
|
|
783
870
|
|
|
784
871
|
def __repr__(self):
|
|
@@ -835,16 +922,17 @@ class Pipeline:
|
|
|
835
922
|
Must have a ``pipeline`` key with a subkey ``operations``.
|
|
836
923
|
|
|
837
924
|
Example:
|
|
838
|
-
..
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
925
|
+
.. doctest::
|
|
926
|
+
|
|
927
|
+
>>> from zea import Config, Pipeline
|
|
928
|
+
>>> config = Config(
|
|
929
|
+
... {
|
|
930
|
+
... "operations": [
|
|
931
|
+
... "identity",
|
|
932
|
+
... ],
|
|
933
|
+
... }
|
|
934
|
+
... )
|
|
935
|
+
>>> pipeline = Pipeline.from_config(config)
|
|
848
936
|
"""
|
|
849
937
|
return pipeline_from_config(Config(config), **kwargs)
|
|
850
938
|
|
|
@@ -860,9 +948,20 @@ class Pipeline:
|
|
|
860
948
|
Must have the a `pipeline` key with a subkey `operations`.
|
|
861
949
|
|
|
862
950
|
Example:
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
951
|
+
.. doctest::
|
|
952
|
+
|
|
953
|
+
>>> import yaml
|
|
954
|
+
>>> from zea import Config
|
|
955
|
+
>>> # Create a sample pipeline YAML file
|
|
956
|
+
>>> pipeline_dict = {
|
|
957
|
+
... "operations": [
|
|
958
|
+
... "identity",
|
|
959
|
+
... ],
|
|
960
|
+
... }
|
|
961
|
+
>>> with open("pipeline.yaml", "w") as f:
|
|
962
|
+
... yaml.dump(pipeline_dict, f)
|
|
963
|
+
>>> from zea.ops import Pipeline
|
|
964
|
+
>>> pipeline = Pipeline.from_yaml("pipeline.yaml", jit_options=None)
|
|
866
965
|
"""
|
|
867
966
|
return pipeline_from_yaml(file_path, **kwargs)
|
|
868
967
|
|
|
@@ -963,7 +1062,7 @@ class Pipeline:
|
|
|
963
1062
|
assert isinstance(scan, Scan), (
|
|
964
1063
|
f"Expected an instance of `zea.scan.Scan`, got {type(scan)}"
|
|
965
1064
|
)
|
|
966
|
-
scan_dict = scan.to_tensor(include=self.
|
|
1065
|
+
scan_dict = scan.to_tensor(include=self.needs_keys, keep_as_is=self.static_params)
|
|
967
1066
|
|
|
968
1067
|
if config is not None:
|
|
969
1068
|
assert isinstance(config, Config), (
|
|
@@ -1004,15 +1103,17 @@ def make_operation_chain(
|
|
|
1004
1103
|
list: List of operations to be performed.
|
|
1005
1104
|
|
|
1006
1105
|
Example:
|
|
1007
|
-
..
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
)
|
|
1106
|
+
.. doctest::
|
|
1107
|
+
|
|
1108
|
+
>>> from zea.ops import make_operation_chain, LogCompress
|
|
1109
|
+
>>> SomeCustomOperation = LogCompress # just for demonstration
|
|
1110
|
+
>>> chain = make_operation_chain(
|
|
1111
|
+
... [
|
|
1112
|
+
... "envelope_detect",
|
|
1113
|
+
... {"name": "normalize", "params": {"output_range": (0, 1)}},
|
|
1114
|
+
... SomeCustomOperation(),
|
|
1115
|
+
... ]
|
|
1116
|
+
... )
|
|
1016
1117
|
"""
|
|
1017
1118
|
chain = []
|
|
1018
1119
|
for operation in operation_chain:
|
|
@@ -1093,7 +1194,7 @@ def pipeline_from_config(config: Config, **kwargs) -> Pipeline:
|
|
|
1093
1194
|
operations = make_operation_chain(config.operations)
|
|
1094
1195
|
|
|
1095
1196
|
# merge pipeline config without operations with kwargs
|
|
1096
|
-
pipeline_config = copy
|
|
1197
|
+
pipeline_config = config.copy()
|
|
1097
1198
|
pipeline_config.pop("operations")
|
|
1098
1199
|
|
|
1099
1200
|
kwargs = {**pipeline_config, **kwargs}
|
|
@@ -1164,33 +1265,134 @@ def pipeline_to_yaml(pipeline: Pipeline, file_path: str) -> None:
|
|
|
1164
1265
|
yaml.dump(pipeline_dict, f, Dumper=yaml.Dumper, indent=4)
|
|
1165
1266
|
|
|
1166
1267
|
|
|
1167
|
-
@ops_registry("
|
|
1168
|
-
class
|
|
1268
|
+
@ops_registry("map")
|
|
1269
|
+
class Map(Pipeline):
|
|
1169
1270
|
"""
|
|
1170
|
-
|
|
1171
|
-
This is useful to avoid OOM errors when processing large grids.
|
|
1172
|
-
|
|
1173
|
-
Some things to NOTE about this class:
|
|
1174
|
-
|
|
1175
|
-
- The ops have to use flatgrid and flat_pfield as inputs, these will be patched.
|
|
1271
|
+
A pipeline that maps its operations over specified input arguments.
|
|
1176
1272
|
|
|
1177
|
-
|
|
1273
|
+
This can be used to reduce memory usage by processing data in chunks.
|
|
1178
1274
|
|
|
1275
|
+
Notes
|
|
1276
|
+
-----
|
|
1277
|
+
- When `chunks` and `batch_size` are both None (default), this behaves like a normal Pipeline.
|
|
1278
|
+
- Changing anything other than ``self.output_key`` in the dict will not be propagated.
|
|
1179
1279
|
- Will be jitted as a single operation, not the individual operations.
|
|
1180
|
-
|
|
1181
1280
|
- This class handles the batching.
|
|
1182
1281
|
|
|
1282
|
+
For more information on how to use ``in_axes``, ``out_axes``, `see the documentation for
|
|
1283
|
+
jax.vmap <https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html>`_.
|
|
1284
|
+
|
|
1285
|
+
Example
|
|
1286
|
+
-------
|
|
1287
|
+
.. doctest::
|
|
1288
|
+
|
|
1289
|
+
>>> from zea.ops import Map, Pipeline, Demodulate, TOFCorrection
|
|
1290
|
+
|
|
1291
|
+
>>> # apply operations in batches of 8
|
|
1292
|
+
>>> # in this case, over the first axis of "data"
|
|
1293
|
+
>>> # or more specifically, process 8 transmits at a time
|
|
1294
|
+
|
|
1295
|
+
>>> pipeline_mapped = Map(
|
|
1296
|
+
... [
|
|
1297
|
+
... Demodulate(),
|
|
1298
|
+
... TOFCorrection(),
|
|
1299
|
+
... ],
|
|
1300
|
+
... argnames="data",
|
|
1301
|
+
... batch_size=8,
|
|
1302
|
+
... )
|
|
1303
|
+
|
|
1304
|
+
>>> # you can also map a subset of the operations
|
|
1305
|
+
>>> # for example, demodulate in 4 chunks
|
|
1306
|
+
>>> # or more specifically, split the transmit axis into 4 parts
|
|
1307
|
+
|
|
1308
|
+
>>> pipeline_mapped = Pipeline(
|
|
1309
|
+
... [
|
|
1310
|
+
... Map([Demodulate()], argnames="data", chunks=4),
|
|
1311
|
+
... TOFCorrection(),
|
|
1312
|
+
... ],
|
|
1313
|
+
... )
|
|
1183
1314
|
"""
|
|
1184
1315
|
|
|
1185
|
-
def __init__(
|
|
1186
|
-
|
|
1187
|
-
|
|
1316
|
+
def __init__(
|
|
1317
|
+
self,
|
|
1318
|
+
operations: List[Operation],
|
|
1319
|
+
argnames: List[str] | str,
|
|
1320
|
+
in_axes: List[Union[int, None]] | int = 0,
|
|
1321
|
+
out_axes: List[Union[int, None]] | int = 0,
|
|
1322
|
+
chunks: int | None = None,
|
|
1323
|
+
batch_size: int | None = None,
|
|
1324
|
+
**kwargs,
|
|
1325
|
+
):
|
|
1326
|
+
"""
|
|
1327
|
+
Args:
|
|
1328
|
+
operations (list): List of operations to be performed.
|
|
1329
|
+
argnames (str or list): List of argument names (or keys) to map over.
|
|
1330
|
+
Can also be a single string if only one argument is mapped over.
|
|
1331
|
+
in_axes (int or list): Axes to map over for each argument.
|
|
1332
|
+
If a single int is provided, it is used for all arguments.
|
|
1333
|
+
out_axes (int or list): Axes to map over for each output.
|
|
1334
|
+
If a single int is provided, it is used for all outputs.
|
|
1335
|
+
chunks (int, optional): Number of chunks to split the input data into.
|
|
1336
|
+
If None, no chunking is performed. Mutually exclusive with ``batch_size``.
|
|
1337
|
+
batch_size (int, optional): Size of batches to process at once.
|
|
1338
|
+
If None, no batching is performed. Mutually exclusive with ``chunks``.
|
|
1339
|
+
"""
|
|
1340
|
+
super().__init__(operations, **kwargs)
|
|
1188
1341
|
|
|
1189
|
-
|
|
1190
|
-
|
|
1191
|
-
|
|
1342
|
+
if batch_size is not None and chunks is not None:
|
|
1343
|
+
raise ValueError(
|
|
1344
|
+
"batch_size and chunks are mutually exclusive. Please specify only one."
|
|
1345
|
+
)
|
|
1192
1346
|
|
|
1193
|
-
|
|
1347
|
+
if batch_size is not None and batch_size <= 0:
|
|
1348
|
+
raise ValueError("batch_size must be a positive integer.")
|
|
1349
|
+
|
|
1350
|
+
if chunks is not None and chunks <= 0:
|
|
1351
|
+
raise ValueError("chunks must be a positive integer.")
|
|
1352
|
+
|
|
1353
|
+
if isinstance(argnames, str):
|
|
1354
|
+
argnames = [argnames]
|
|
1355
|
+
|
|
1356
|
+
self.argnames = argnames
|
|
1357
|
+
self.in_axes = in_axes
|
|
1358
|
+
self.out_axes = out_axes
|
|
1359
|
+
self.chunks = chunks
|
|
1360
|
+
self.batch_size = batch_size
|
|
1361
|
+
|
|
1362
|
+
if chunks is None and batch_size is None:
|
|
1363
|
+
log.warning(
|
|
1364
|
+
"[zea.ops.Map] Both `chunks` and `batch_size` are None. "
|
|
1365
|
+
"This will behave like a normal Pipeline. "
|
|
1366
|
+
"Consider setting one of them to process data in chunks or batches."
|
|
1367
|
+
)
|
|
1368
|
+
|
|
1369
|
+
def call_item(**inputs):
|
|
1370
|
+
"""Process data in patches."""
|
|
1371
|
+
mapped_args = []
|
|
1372
|
+
for argname in argnames:
|
|
1373
|
+
mapped_args.append(inputs.pop(argname, None))
|
|
1374
|
+
|
|
1375
|
+
def patched_call(*args):
|
|
1376
|
+
mapped_kwargs = [(k, v) for k, v in zip(argnames, args)]
|
|
1377
|
+
out = super(Map, self).call(**dict(mapped_kwargs), **inputs)
|
|
1378
|
+
|
|
1379
|
+
# TODO: maybe it is possible to output everything?
|
|
1380
|
+
# e.g. prepend a empty dimension to all inputs and just map over everything?
|
|
1381
|
+
return out[self.output_key]
|
|
1382
|
+
|
|
1383
|
+
out = vmap(
|
|
1384
|
+
patched_call,
|
|
1385
|
+
in_axes=in_axes,
|
|
1386
|
+
out_axes=out_axes,
|
|
1387
|
+
chunks=chunks,
|
|
1388
|
+
batch_size=batch_size,
|
|
1389
|
+
fn_supports_batch=True,
|
|
1390
|
+
disable_jit=not bool(self.jit_options),
|
|
1391
|
+
)(*mapped_args)
|
|
1392
|
+
|
|
1393
|
+
return out
|
|
1394
|
+
|
|
1395
|
+
self.call_item = call_item
|
|
1194
1396
|
|
|
1195
1397
|
@property
|
|
1196
1398
|
def jit_options(self):
|
|
@@ -1230,59 +1432,28 @@ class PatchedGrid(Pipeline):
|
|
|
1230
1432
|
|
|
1231
1433
|
@property
|
|
1232
1434
|
def valid_keys(self) -> set:
|
|
1233
|
-
"""Get a set of valid keys for the pipeline.
|
|
1234
|
-
operates on (even if not used by operations
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
|
|
1243
|
-
|
|
1244
|
-
|
|
1245
|
-
# TODO: maybe using n_tx and n_el from kwargs is better but these are tensors now
|
|
1246
|
-
# and this is not supported in broadcast_to
|
|
1247
|
-
n_tx = inputs[self.key].shape[0]
|
|
1248
|
-
n_pix = flatgrid.shape[0]
|
|
1249
|
-
n_el = inputs[self.key].shape[2]
|
|
1250
|
-
inputs["rx_apo"] = ops.broadcast_to(inputs.get("rx_apo", 1.0), (n_tx, n_pix, n_el))
|
|
1251
|
-
inputs["rx_apo"] = ops.swapaxes(inputs["rx_apo"], 0, 1) # put n_pix first
|
|
1252
|
-
|
|
1253
|
-
# Define a list of keys to look up for patching
|
|
1254
|
-
patch_keys = ["flat_pfield", "rx_apo"]
|
|
1255
|
-
|
|
1256
|
-
patch_arrays = {}
|
|
1257
|
-
for key in patch_keys:
|
|
1258
|
-
if key in inputs:
|
|
1259
|
-
patch_arrays[key] = inputs.pop(key)
|
|
1260
|
-
|
|
1261
|
-
def patched_call(flatgrid, **patch_kwargs):
|
|
1262
|
-
patch_args = {k: v for k, v in patch_kwargs.items() if v is not None}
|
|
1263
|
-
patch_args["rx_apo"] = ops.swapaxes(patch_args["rx_apo"], 0, 1)
|
|
1264
|
-
out = super(PatchedGrid, self).call(flatgrid=flatgrid, **patch_args, **inputs)
|
|
1265
|
-
return out[self.output_key]
|
|
1266
|
-
|
|
1267
|
-
out = patched_map(
|
|
1268
|
-
patched_call,
|
|
1269
|
-
flatgrid,
|
|
1270
|
-
self.num_patches,
|
|
1271
|
-
**patch_arrays,
|
|
1272
|
-
jit=bool(self.jit_options),
|
|
1273
|
-
)
|
|
1274
|
-
return ops.reshape(out, (grid_size_z, grid_size_x, *ops.shape(out)[1:]))
|
|
1435
|
+
"""Get a set of valid keys for the pipeline.
|
|
1436
|
+
Adds the parameters that PatchedGrid itself operates on (even if not used by operations
|
|
1437
|
+
inside it)."""
|
|
1438
|
+
return super().valid_keys.union(self.argnames)
|
|
1439
|
+
|
|
1440
|
+
@property
|
|
1441
|
+
def needs_keys(self) -> set:
|
|
1442
|
+
"""Get a set of all input keys needed by the pipeline.
|
|
1443
|
+
Adds the parameters that PatchedGrid itself operates on (even if not used by operations
|
|
1444
|
+
inside it)."""
|
|
1445
|
+
return super().needs_keys.union(self.argnames)
|
|
1275
1446
|
|
|
1276
1447
|
def jittable_call(self, **inputs):
|
|
1277
1448
|
"""Process input data through the pipeline."""
|
|
1278
1449
|
if self._with_batch_dim:
|
|
1279
1450
|
input_data = inputs.pop(self.key)
|
|
1280
1451
|
output = ops.map(
|
|
1281
|
-
lambda x: self.call_item({self.key: x, **inputs}),
|
|
1452
|
+
lambda x: self.call_item(**{self.key: x, **inputs}),
|
|
1282
1453
|
input_data,
|
|
1283
1454
|
)
|
|
1284
1455
|
else:
|
|
1285
|
-
output = self.call_item(inputs)
|
|
1456
|
+
output = self.call_item(**inputs)
|
|
1286
1457
|
|
|
1287
1458
|
return {self.output_key: output}
|
|
1288
1459
|
|
|
@@ -1295,11 +1466,61 @@ class PatchedGrid(Pipeline):
|
|
|
1295
1466
|
def get_dict(self):
|
|
1296
1467
|
"""Get the configuration of the pipeline."""
|
|
1297
1468
|
config = super().get_dict()
|
|
1298
|
-
config.update(
|
|
1469
|
+
config["params"].update(
|
|
1470
|
+
{
|
|
1471
|
+
"argnames": self.argnames,
|
|
1472
|
+
"in_axes": self.in_axes,
|
|
1473
|
+
"out_axes": self.out_axes,
|
|
1474
|
+
"chunks": self.chunks,
|
|
1475
|
+
"batch_size": self.batch_size,
|
|
1476
|
+
}
|
|
1477
|
+
)
|
|
1478
|
+
return config
|
|
1479
|
+
|
|
1480
|
+
|
|
1481
|
+
@ops_registry("patched_grid")
|
|
1482
|
+
class PatchedGrid(Map):
|
|
1483
|
+
"""
|
|
1484
|
+
A pipeline that maps its operations over `flatgrid` and `flat_pfield` keys.
|
|
1485
|
+
|
|
1486
|
+
This can be used to reduce memory usage by processing data in chunks.
|
|
1487
|
+
|
|
1488
|
+
For more information and flexibility, see :class:`zea.ops.Map`.
|
|
1489
|
+
"""
|
|
1490
|
+
|
|
1491
|
+
def __init__(self, *args, num_patches=10, **kwargs):
|
|
1492
|
+
super().__init__(*args, argnames=["flatgrid", "flat_pfield"], chunks=num_patches, **kwargs)
|
|
1493
|
+
self.num_patches = num_patches
|
|
1494
|
+
|
|
1495
|
+
def get_dict(self):
|
|
1496
|
+
"""Get the configuration of the pipeline."""
|
|
1497
|
+
config = super().get_dict()
|
|
1498
|
+
config["params"].pop("argnames")
|
|
1499
|
+
config["params"].pop("chunks")
|
|
1299
1500
|
config["params"].update({"num_patches": self.num_patches})
|
|
1300
1501
|
return config
|
|
1301
1502
|
|
|
1302
1503
|
|
|
1504
|
+
@ops_registry("reshape_grid")
|
|
1505
|
+
class ReshapeGrid(Operation):
|
|
1506
|
+
"""Reshape flat grid data to grid shape."""
|
|
1507
|
+
|
|
1508
|
+
def __init__(self, axis=0, **kwargs):
|
|
1509
|
+
super().__init__(**kwargs)
|
|
1510
|
+
self.axis = axis
|
|
1511
|
+
|
|
1512
|
+
def call(self, grid, **kwargs):
|
|
1513
|
+
"""
|
|
1514
|
+
Args:
|
|
1515
|
+
- data (Tensor): The flat grid data of shape (..., n_pix, ...).
|
|
1516
|
+
Returns:
|
|
1517
|
+
- reshaped_data (Tensor): The reshaped data of shape (..., grid.shape, ...).
|
|
1518
|
+
"""
|
|
1519
|
+
data = kwargs[self.key]
|
|
1520
|
+
reshaped_data = reshape_axis(data, grid.shape[:-1], self.axis + int(self.with_batch_dim))
|
|
1521
|
+
return {self.output_key: reshaped_data}
|
|
1522
|
+
|
|
1523
|
+
|
|
1303
1524
|
## Base Operations
|
|
1304
1525
|
|
|
1305
1526
|
|
|
@@ -1309,7 +1530,7 @@ class Identity(Operation):
|
|
|
1309
1530
|
|
|
1310
1531
|
def call(self, **kwargs) -> Dict:
|
|
1311
1532
|
"""Returns the input as is."""
|
|
1312
|
-
return
|
|
1533
|
+
return {}
|
|
1313
1534
|
|
|
1314
1535
|
|
|
1315
1536
|
@ops_registry("merge")
|
|
@@ -1332,21 +1553,6 @@ class Merge(Operation):
|
|
|
1332
1553
|
return merged
|
|
1333
1554
|
|
|
1334
1555
|
|
|
1335
|
-
@ops_registry("split")
|
|
1336
|
-
class Split(Operation):
|
|
1337
|
-
"""Operation that splits an input dictionary n copies."""
|
|
1338
|
-
|
|
1339
|
-
def __init__(self, n: int, **kwargs):
|
|
1340
|
-
super().__init__(**kwargs)
|
|
1341
|
-
self.n = n
|
|
1342
|
-
|
|
1343
|
-
def call(self, **kwargs) -> List[Dict]:
|
|
1344
|
-
"""
|
|
1345
|
-
Splits the input dictionary into n copies.
|
|
1346
|
-
"""
|
|
1347
|
-
return [kwargs.copy() for _ in range(self.n)]
|
|
1348
|
-
|
|
1349
|
-
|
|
1350
1556
|
@ops_registry("stack")
|
|
1351
1557
|
class Stack(Operation):
|
|
1352
1558
|
"""Stack multiple data arrays along a new axis.
|
|
@@ -1399,6 +1605,7 @@ class Simulate(Operation):
|
|
|
1399
1605
|
def __init__(self, **kwargs):
|
|
1400
1606
|
super().__init__(
|
|
1401
1607
|
output_data_type=DataTypes.RAW_DATA,
|
|
1608
|
+
additional_output_keys=["n_ch"],
|
|
1402
1609
|
**kwargs,
|
|
1403
1610
|
)
|
|
1404
1611
|
|
|
@@ -1451,18 +1658,16 @@ class TOFCorrection(Operation):
|
|
|
1451
1658
|
STATIC_PARAMS = [
|
|
1452
1659
|
"f_number",
|
|
1453
1660
|
"apply_lens_correction",
|
|
1454
|
-
"apply_phase_rotation",
|
|
1455
1661
|
"grid_size_x",
|
|
1456
1662
|
"grid_size_z",
|
|
1457
1663
|
]
|
|
1458
1664
|
|
|
1459
|
-
def __init__(self,
|
|
1665
|
+
def __init__(self, **kwargs):
|
|
1460
1666
|
super().__init__(
|
|
1461
1667
|
input_data_type=DataTypes.RAW_DATA,
|
|
1462
1668
|
output_data_type=DataTypes.ALIGNED_DATA,
|
|
1463
1669
|
**kwargs,
|
|
1464
1670
|
)
|
|
1465
|
-
self.apply_phase_rotation = apply_phase_rotation
|
|
1466
1671
|
|
|
1467
1672
|
def call(
|
|
1468
1673
|
self,
|
|
@@ -1477,6 +1682,8 @@ class TOFCorrection(Operation):
|
|
|
1477
1682
|
tx_apodizations,
|
|
1478
1683
|
initial_times,
|
|
1479
1684
|
probe_geometry,
|
|
1685
|
+
t_peak,
|
|
1686
|
+
tx_waveform_indices,
|
|
1480
1687
|
apply_lens_correction=None,
|
|
1481
1688
|
lens_thickness=None,
|
|
1482
1689
|
lens_sound_speed=None,
|
|
@@ -1497,6 +1704,9 @@ class TOFCorrection(Operation):
|
|
|
1497
1704
|
tx_apodizations (ops.Tensor): Transmit apodizations
|
|
1498
1705
|
initial_times (ops.Tensor): Initial times
|
|
1499
1706
|
probe_geometry (ops.Tensor): Probe element positions
|
|
1707
|
+
t_peak (float): Time to peak of the transmit pulse
|
|
1708
|
+
tx_waveform_indices (ops.Tensor): Index of the transmit waveform for each
|
|
1709
|
+
transmit. (All zero if there is only one waveform)
|
|
1500
1710
|
apply_lens_correction (bool): Whether to apply lens correction
|
|
1501
1711
|
lens_thickness (float): Lens thickness
|
|
1502
1712
|
lens_sound_speed (float): Sound speed in the lens
|
|
@@ -1507,29 +1717,30 @@ class TOFCorrection(Operation):
|
|
|
1507
1717
|
|
|
1508
1718
|
raw_data = kwargs[self.key]
|
|
1509
1719
|
|
|
1510
|
-
|
|
1720
|
+
tof_kwargs = {
|
|
1511
1721
|
"flatgrid": flatgrid,
|
|
1512
|
-
"sound_speed": sound_speed,
|
|
1513
|
-
"angles": polar_angles,
|
|
1514
|
-
"focus_distances": focus_distances,
|
|
1515
|
-
"sampling_frequency": sampling_frequency,
|
|
1516
|
-
"fnum": f_number,
|
|
1517
|
-
"apply_phase_rotation": self.apply_phase_rotation,
|
|
1518
|
-
"demodulation_frequency": demodulation_frequency,
|
|
1519
1722
|
"t0_delays": t0_delays,
|
|
1520
1723
|
"tx_apodizations": tx_apodizations,
|
|
1521
|
-
"
|
|
1724
|
+
"sound_speed": sound_speed,
|
|
1522
1725
|
"probe_geometry": probe_geometry,
|
|
1726
|
+
"initial_times": initial_times,
|
|
1727
|
+
"sampling_frequency": sampling_frequency,
|
|
1728
|
+
"demodulation_frequency": demodulation_frequency,
|
|
1729
|
+
"f_number": f_number,
|
|
1730
|
+
"polar_angles": polar_angles,
|
|
1731
|
+
"focus_distances": focus_distances,
|
|
1732
|
+
"t_peak": t_peak,
|
|
1733
|
+
"tx_waveform_indices": tx_waveform_indices,
|
|
1523
1734
|
"apply_lens_correction": apply_lens_correction,
|
|
1524
1735
|
"lens_thickness": lens_thickness,
|
|
1525
1736
|
"lens_sound_speed": lens_sound_speed,
|
|
1526
1737
|
}
|
|
1527
1738
|
|
|
1528
1739
|
if not self.with_batch_dim:
|
|
1529
|
-
tof_corrected = tof_correction(raw_data, **
|
|
1740
|
+
tof_corrected = tof_correction(raw_data, **tof_kwargs)
|
|
1530
1741
|
else:
|
|
1531
1742
|
tof_corrected = ops.map(
|
|
1532
|
-
lambda data: tof_correction(data, **
|
|
1743
|
+
lambda data: tof_correction(data, **tof_kwargs),
|
|
1533
1744
|
raw_data,
|
|
1534
1745
|
)
|
|
1535
1746
|
|
|
@@ -1556,7 +1767,7 @@ class PfieldWeighting(Operation):
|
|
|
1556
1767
|
Returns:
|
|
1557
1768
|
dict: Dictionary containing weighted data
|
|
1558
1769
|
"""
|
|
1559
|
-
data = kwargs[self.key]
|
|
1770
|
+
data = kwargs[self.key] # must start with ((batch_size,) n_tx, n_pix, ...)
|
|
1560
1771
|
|
|
1561
1772
|
if flat_pfield is None:
|
|
1562
1773
|
return {self.output_key: data}
|
|
@@ -1564,14 +1775,16 @@ class PfieldWeighting(Operation):
|
|
|
1564
1775
|
# Swap (n_pix, n_tx) to (n_tx, n_pix)
|
|
1565
1776
|
flat_pfield = ops.swapaxes(flat_pfield, 0, 1)
|
|
1566
1777
|
|
|
1567
|
-
#
|
|
1568
|
-
# Also add the required dimensions for broadcasting
|
|
1778
|
+
# Add batch dimension if needed
|
|
1569
1779
|
if self.with_batch_dim:
|
|
1570
1780
|
pfield_expanded = ops.expand_dims(flat_pfield, axis=0)
|
|
1571
1781
|
else:
|
|
1572
1782
|
pfield_expanded = flat_pfield
|
|
1573
1783
|
|
|
1574
|
-
|
|
1784
|
+
append_n_dims = ops.ndim(data) - ops.ndim(pfield_expanded)
|
|
1785
|
+
pfield_expanded = extend_n_dims(pfield_expanded, axis=-1, n_dims=append_n_dims)
|
|
1786
|
+
|
|
1787
|
+
# Perform element-wise multiplication with the pressure weight mask
|
|
1575
1788
|
weighted_data = data * pfield_expanded
|
|
1576
1789
|
|
|
1577
1790
|
return {self.output_key: weighted_data}
|
|
@@ -1581,79 +1794,93 @@ class PfieldWeighting(Operation):
|
|
|
1581
1794
|
class DelayAndSum(Operation):
|
|
1582
1795
|
"""Sums time-delayed signals along channels and transmits."""
|
|
1583
1796
|
|
|
1584
|
-
def __init__(
|
|
1585
|
-
self,
|
|
1586
|
-
reshape_grid=True,
|
|
1587
|
-
**kwargs,
|
|
1588
|
-
):
|
|
1797
|
+
def __init__(self, **kwargs):
|
|
1589
1798
|
super().__init__(
|
|
1590
|
-
input_data_type=
|
|
1799
|
+
input_data_type=DataTypes.ALIGNED_DATA,
|
|
1591
1800
|
output_data_type=DataTypes.BEAMFORMED_DATA,
|
|
1592
1801
|
**kwargs,
|
|
1593
1802
|
)
|
|
1594
|
-
self.reshape_grid = reshape_grid
|
|
1595
1803
|
|
|
1596
|
-
def process_image(self, data
|
|
1804
|
+
def process_image(self, data):
|
|
1597
1805
|
"""Performs DAS beamforming on tof-corrected input.
|
|
1598
1806
|
|
|
1599
1807
|
Args:
|
|
1600
1808
|
data (ops.Tensor): The TOF corrected input of shape `(n_tx, n_pix, n_el, n_ch)`
|
|
1601
|
-
rx_apo (ops.Tensor): Receive apodization window of shape `(n_tx, n_pix, n_el, n_ch)`.
|
|
1602
1809
|
|
|
1603
1810
|
Returns:
|
|
1604
1811
|
ops.Tensor: The beamformed data of shape `(n_pix, n_ch)`
|
|
1605
1812
|
"""
|
|
1606
1813
|
# Sum over the channels, i.e. DAS
|
|
1607
|
-
data = ops.sum(
|
|
1814
|
+
data = ops.sum(data, -2)
|
|
1608
1815
|
|
|
1609
1816
|
# Sum over transmits, i.e. Compounding
|
|
1610
1817
|
data = ops.sum(data, 0)
|
|
1611
1818
|
|
|
1612
1819
|
return data
|
|
1613
1820
|
|
|
1614
|
-
def call(
|
|
1615
|
-
self,
|
|
1616
|
-
rx_apo=None,
|
|
1617
|
-
grid=None,
|
|
1618
|
-
**kwargs,
|
|
1619
|
-
):
|
|
1821
|
+
def call(self, grid=None, **kwargs):
|
|
1620
1822
|
"""Performs DAS beamforming on tof-corrected input.
|
|
1621
1823
|
|
|
1622
1824
|
Args:
|
|
1623
1825
|
tof_corrected_data (ops.Tensor): The TOF corrected input of shape
|
|
1624
1826
|
`(n_tx, grid_size_z*grid_size_x, n_el, n_ch)` with optional batch dimension.
|
|
1625
|
-
rx_apo (ops.Tensor): Receive apodization window
|
|
1626
|
-
of shape `(n_tx, grid_size_z*grid_size_x, n_el)`
|
|
1627
|
-
with optional batch dimension. Defaults to 1.0.
|
|
1628
1827
|
|
|
1629
1828
|
Returns:
|
|
1630
1829
|
dict: Dictionary containing beamformed_data
|
|
1631
|
-
of shape `(grid_size_z*grid_size_x, n_ch)`
|
|
1632
|
-
or `(grid_size_z, grid_size_x, n_ch)` when reshape_grid is True,
|
|
1830
|
+
of shape `(grid_size_z*grid_size_x, n_ch)`
|
|
1633
1831
|
with optional batch dimension.
|
|
1634
1832
|
"""
|
|
1635
1833
|
data = kwargs[self.key]
|
|
1636
1834
|
|
|
1637
|
-
if rx_apo is None:
|
|
1638
|
-
rx_apo = ops.ones(1, dtype=ops.dtype(data))
|
|
1639
|
-
rx_apo = ops.broadcast_to(rx_apo[..., None], data.shape)
|
|
1640
|
-
|
|
1641
1835
|
if not self.with_batch_dim:
|
|
1642
|
-
beamformed_data = self.process_image(data
|
|
1836
|
+
beamformed_data = self.process_image(data)
|
|
1643
1837
|
else:
|
|
1644
1838
|
# Apply process_image to each item in the batch
|
|
1645
|
-
beamformed_data =
|
|
1646
|
-
lambda data, rx_apo: self.process_image(data, rx_apo), data, rx_apo=rx_apo
|
|
1647
|
-
)
|
|
1648
|
-
|
|
1649
|
-
if self.reshape_grid:
|
|
1650
|
-
beamformed_data = reshape_axis(
|
|
1651
|
-
beamformed_data, grid.shape[:2], axis=int(self.with_batch_dim)
|
|
1652
|
-
)
|
|
1839
|
+
beamformed_data = ops.map(self.process_image, data)
|
|
1653
1840
|
|
|
1654
1841
|
return {self.output_key: beamformed_data}
|
|
1655
1842
|
|
|
1656
1843
|
|
|
1844
|
+
def envelope_detect(data, axis=-3):
|
|
1845
|
+
"""Envelope detection of RF signals.
|
|
1846
|
+
|
|
1847
|
+
If the input data is real, it first applies the Hilbert transform along the specified axis
|
|
1848
|
+
and then computes the magnitude of the resulting complex signal.
|
|
1849
|
+
If the input data is complex, it computes the magnitude directly.
|
|
1850
|
+
|
|
1851
|
+
Args:
|
|
1852
|
+
- data (Tensor): The beamformed data of shape (..., grid_size_z, grid_size_x, n_ch).
|
|
1853
|
+
- axis (int): Axis along which to apply the Hilbert transform. Defaults to -3.
|
|
1854
|
+
|
|
1855
|
+
Returns:
|
|
1856
|
+
- envelope_data (Tensor): The envelope detected data
|
|
1857
|
+
of shape (..., grid_size_z, grid_size_x).
|
|
1858
|
+
"""
|
|
1859
|
+
if data.shape[-1] == 2:
|
|
1860
|
+
data = channels_to_complex(data)
|
|
1861
|
+
else:
|
|
1862
|
+
n_ax = ops.shape(data)[axis]
|
|
1863
|
+
n_ax_float = ops.cast(n_ax, "float32")
|
|
1864
|
+
|
|
1865
|
+
# Calculate next power of 2: M = 2^ceil(log2(n_ax))
|
|
1866
|
+
# see https://github.com/tue-bmd/zea/discussions/147
|
|
1867
|
+
log2_n_ax = ops.log2(n_ax_float)
|
|
1868
|
+
M = ops.cast(2 ** ops.ceil(log2_n_ax), "int32")
|
|
1869
|
+
|
|
1870
|
+
data = hilbert(data, N=M, axis=axis)
|
|
1871
|
+
indices = ops.arange(n_ax)
|
|
1872
|
+
|
|
1873
|
+
data = ops.take(data, indices, axis=axis)
|
|
1874
|
+
data = ops.squeeze(data, axis=-1)
|
|
1875
|
+
|
|
1876
|
+
# data = ops.abs(data)
|
|
1877
|
+
real = ops.real(data)
|
|
1878
|
+
imag = ops.imag(data)
|
|
1879
|
+
data = ops.sqrt(real**2 + imag**2)
|
|
1880
|
+
data = ops.cast(data, "float32")
|
|
1881
|
+
return data
|
|
1882
|
+
|
|
1883
|
+
|
|
1657
1884
|
@ops_registry("envelope_detect")
|
|
1658
1885
|
class EnvelopeDetect(Operation):
|
|
1659
1886
|
"""Envelope detection of RF signals."""
|
|
@@ -1680,23 +1907,7 @@ class EnvelopeDetect(Operation):
|
|
|
1680
1907
|
"""
|
|
1681
1908
|
data = kwargs[self.key]
|
|
1682
1909
|
|
|
1683
|
-
|
|
1684
|
-
data = channels_to_complex(data)
|
|
1685
|
-
else:
|
|
1686
|
-
n_ax = data.shape[self.axis]
|
|
1687
|
-
M = 2 ** int(np.ceil(np.log2(n_ax)))
|
|
1688
|
-
# data = scipy.signal.hilbert(data, N=M, axis=self.axis)
|
|
1689
|
-
data = hilbert(data, N=M, axis=self.axis)
|
|
1690
|
-
indices = ops.arange(n_ax)
|
|
1691
|
-
|
|
1692
|
-
data = ops.take(data, indices, axis=self.axis)
|
|
1693
|
-
data = ops.squeeze(data, axis=-1)
|
|
1694
|
-
|
|
1695
|
-
# data = ops.abs(data)
|
|
1696
|
-
real = ops.real(data)
|
|
1697
|
-
imag = ops.imag(data)
|
|
1698
|
-
data = ops.sqrt(real**2 + imag**2)
|
|
1699
|
-
data = ops.cast(data, "float32")
|
|
1910
|
+
data = envelope_detect(data, axis=self.axis)
|
|
1700
1911
|
|
|
1701
1912
|
return {self.output_key: data}
|
|
1702
1913
|
|
|
@@ -1734,19 +1945,29 @@ class UpMix(Operation):
|
|
|
1734
1945
|
return {self.output_key: data}
|
|
1735
1946
|
|
|
1736
1947
|
|
|
1948
|
+
def log_compress(data, eps=1e-16):
|
|
1949
|
+
"""Apply logarithmic compression to data."""
|
|
1950
|
+
eps = ops.convert_to_tensor(eps, dtype=data.dtype)
|
|
1951
|
+
data = ops.where(data == 0, eps, data) # Avoid log(0)
|
|
1952
|
+
return 20 * keras.ops.log10(data)
|
|
1953
|
+
|
|
1954
|
+
|
|
1737
1955
|
@ops_registry("log_compress")
|
|
1738
1956
|
class LogCompress(Operation):
|
|
1739
1957
|
"""Logarithmic compression of data."""
|
|
1740
1958
|
|
|
1741
|
-
def __init__(
|
|
1742
|
-
|
|
1743
|
-
|
|
1744
|
-
|
|
1959
|
+
def __init__(self, clip: bool = True, **kwargs):
|
|
1960
|
+
"""Initialize the LogCompress operation.
|
|
1961
|
+
|
|
1962
|
+
Args:
|
|
1963
|
+
clip (bool): Whether to clip the output to a dynamic range. Defaults to True.
|
|
1964
|
+
"""
|
|
1745
1965
|
super().__init__(
|
|
1746
1966
|
input_data_type=DataTypes.ENVELOPE_DATA,
|
|
1747
1967
|
output_data_type=DataTypes.IMAGE,
|
|
1748
1968
|
**kwargs,
|
|
1749
1969
|
)
|
|
1970
|
+
self.clip = clip
|
|
1750
1971
|
|
|
1751
1972
|
def call(self, dynamic_range=None, **kwargs):
|
|
1752
1973
|
"""Apply logarithmic compression to data.
|
|
@@ -1763,20 +1984,43 @@ class LogCompress(Operation):
|
|
|
1763
1984
|
dynamic_range = ops.array(DEFAULT_DYNAMIC_RANGE)
|
|
1764
1985
|
dynamic_range = ops.cast(dynamic_range, data.dtype)
|
|
1765
1986
|
|
|
1766
|
-
|
|
1767
|
-
|
|
1768
|
-
|
|
1769
|
-
compressed_data = ops.clip(compressed_data, dynamic_range[0], dynamic_range[1])
|
|
1987
|
+
compressed_data = log_compress(data)
|
|
1988
|
+
if self.clip:
|
|
1989
|
+
compressed_data = ops.clip(compressed_data, dynamic_range[0], dynamic_range[1])
|
|
1770
1990
|
|
|
1771
1991
|
return {self.output_key: compressed_data}
|
|
1772
1992
|
|
|
1773
1993
|
|
|
1994
|
+
def normalize(data, output_range, input_range=None):
|
|
1995
|
+
"""Normalize data to a given range.
|
|
1996
|
+
|
|
1997
|
+
Equivalent to `translate` with clipping.
|
|
1998
|
+
|
|
1999
|
+
Args:
|
|
2000
|
+
data (ops.Tensor): Input data to normalize.
|
|
2001
|
+
output_range (tuple): Range to which data should be mapped, e.g., (0, 1).
|
|
2002
|
+
input_range (tuple, optional): Range of input data.
|
|
2003
|
+
If None, the range will be computed from the data.
|
|
2004
|
+
Defaults to None.
|
|
2005
|
+
"""
|
|
2006
|
+
if input_range is None:
|
|
2007
|
+
input_range = (None, None)
|
|
2008
|
+
minval, maxval = input_range
|
|
2009
|
+
if minval is None:
|
|
2010
|
+
minval = ops.min(data)
|
|
2011
|
+
if maxval is None:
|
|
2012
|
+
maxval = ops.max(data)
|
|
2013
|
+
data = ops.clip(data, minval, maxval)
|
|
2014
|
+
normalized_data = translate(data, (minval, maxval), output_range)
|
|
2015
|
+
return normalized_data
|
|
2016
|
+
|
|
2017
|
+
|
|
1774
2018
|
@ops_registry("normalize")
|
|
1775
2019
|
class Normalize(Operation):
|
|
1776
2020
|
"""Normalize data to a given range."""
|
|
1777
2021
|
|
|
1778
2022
|
def __init__(self, output_range=None, input_range=None, **kwargs):
|
|
1779
|
-
super().__init__(**kwargs)
|
|
2023
|
+
super().__init__(additional_output_keys=["minval", "maxval"], **kwargs)
|
|
1780
2024
|
if output_range is None:
|
|
1781
2025
|
output_range = (0, 1)
|
|
1782
2026
|
self.output_range = self.to_float32(output_range)
|
|
@@ -1821,11 +2065,9 @@ class Normalize(Operation):
|
|
|
1821
2065
|
if maxval is None:
|
|
1822
2066
|
maxval = ops.max(data)
|
|
1823
2067
|
|
|
1824
|
-
|
|
1825
|
-
|
|
1826
|
-
|
|
1827
|
-
# Map the data to the output range
|
|
1828
|
-
normalized_data = translate(data, (minval, maxval), self.output_range)
|
|
2068
|
+
normalized_data = normalize(
|
|
2069
|
+
data, output_range=self.output_range, input_range=(minval, maxval)
|
|
2070
|
+
)
|
|
1829
2071
|
|
|
1830
2072
|
return {self.output_key: normalized_data, "minval": minval, "maxval": maxval}
|
|
1831
2073
|
|
|
@@ -1855,6 +2097,18 @@ class ScanConvert(Operation):
|
|
|
1855
2097
|
input_data_type=DataTypes.IMAGE,
|
|
1856
2098
|
output_data_type=DataTypes.IMAGE_SC,
|
|
1857
2099
|
jittable=jittable,
|
|
2100
|
+
additional_output_keys=[
|
|
2101
|
+
"resolution",
|
|
2102
|
+
"x_lim",
|
|
2103
|
+
"y_lim",
|
|
2104
|
+
"z_lim",
|
|
2105
|
+
"rho_range",
|
|
2106
|
+
"theta_range",
|
|
2107
|
+
"phi_range",
|
|
2108
|
+
"d_rho",
|
|
2109
|
+
"d_theta",
|
|
2110
|
+
"d_phi",
|
|
2111
|
+
],
|
|
1858
2112
|
**kwargs,
|
|
1859
2113
|
)
|
|
1860
2114
|
self.order = order
|
|
@@ -1913,7 +2167,7 @@ class ScanConvert(Operation):
|
|
|
1913
2167
|
|
|
1914
2168
|
|
|
1915
2169
|
@ops_registry("gaussian_blur")
|
|
1916
|
-
class GaussianBlur(
|
|
2170
|
+
class GaussianBlur(ImageOperation):
|
|
1917
2171
|
"""
|
|
1918
2172
|
GaussianBlur is an operation that applies a Gaussian blur to an input image.
|
|
1919
2173
|
Uses scipy.ndimage.gaussian_filter to create a kernel.
|
|
@@ -1963,6 +2217,13 @@ class GaussianBlur(Operation):
|
|
|
1963
2217
|
return ops.convert_to_tensor(kernel)
|
|
1964
2218
|
|
|
1965
2219
|
def call(self, **kwargs):
|
|
2220
|
+
"""Apply a Gaussian filter to the input data.
|
|
2221
|
+
|
|
2222
|
+
Args:
|
|
2223
|
+
data (ops.Tensor): Input image data of shape (height, width, channels) with
|
|
2224
|
+
optional batch dimension if ``self.with_batch_dim``.
|
|
2225
|
+
"""
|
|
2226
|
+
super().call(**kwargs)
|
|
1966
2227
|
data = kwargs[self.key]
|
|
1967
2228
|
|
|
1968
2229
|
# Add batch dimension if not present
|
|
@@ -1997,7 +2258,7 @@ class GaussianBlur(Operation):
|
|
|
1997
2258
|
|
|
1998
2259
|
|
|
1999
2260
|
@ops_registry("lee_filter")
|
|
2000
|
-
class LeeFilter(
|
|
2261
|
+
class LeeFilter(ImageOperation):
|
|
2001
2262
|
"""
|
|
2002
2263
|
The Lee filter is a speckle reduction filter commonly used in synthetic aperture radar (SAR)
|
|
2003
2264
|
and ultrasound image processing. It smooths the image while preserving edges and details.
|
|
@@ -2027,7 +2288,7 @@ class LeeFilter(Operation):
|
|
|
2027
2288
|
pad_mode=self.pad_mode,
|
|
2028
2289
|
with_batch_dim=self.with_batch_dim,
|
|
2029
2290
|
jittable=self._jittable,
|
|
2030
|
-
key=
|
|
2291
|
+
key="data",
|
|
2031
2292
|
)
|
|
2032
2293
|
|
|
2033
2294
|
@property
|
|
@@ -2043,24 +2304,29 @@ class LeeFilter(Operation):
|
|
|
2043
2304
|
self.gaussian_blur.with_batch_dim = value
|
|
2044
2305
|
|
|
2045
2306
|
def call(self, **kwargs):
|
|
2046
|
-
|
|
2307
|
+
"""Apply the Lee filter to the input data.
|
|
2308
|
+
|
|
2309
|
+
Args:
|
|
2310
|
+
data (ops.Tensor): Input image data of shape (height, width, channels) with
|
|
2311
|
+
optional batch dimension if ``self.with_batch_dim``.
|
|
2312
|
+
"""
|
|
2313
|
+
super().call(**kwargs)
|
|
2314
|
+
data = kwargs.pop(self.key)
|
|
2047
2315
|
|
|
2048
2316
|
# Apply Gaussian blur to get local mean
|
|
2049
|
-
img_mean = self.gaussian_blur.call(**kwargs)[self.gaussian_blur.output_key]
|
|
2317
|
+
img_mean = self.gaussian_blur.call(data=data, **kwargs)[self.gaussian_blur.output_key]
|
|
2050
2318
|
|
|
2051
2319
|
# Apply Gaussian blur to squared data to get local squared mean
|
|
2052
|
-
|
|
2053
|
-
|
|
2054
|
-
|
|
2320
|
+
img_sqr_mean = self.gaussian_blur.call(
|
|
2321
|
+
data=data**2,
|
|
2322
|
+
**kwargs,
|
|
2323
|
+
)[self.gaussian_blur.output_key]
|
|
2055
2324
|
|
|
2056
2325
|
# Calculate local variance
|
|
2057
2326
|
img_variance = img_sqr_mean - img_mean**2
|
|
2058
2327
|
|
|
2059
2328
|
# Calculate global variance (per channel)
|
|
2060
|
-
|
|
2061
|
-
overall_variance = ops.var(data, axis=(-3, -2), keepdims=True)
|
|
2062
|
-
else:
|
|
2063
|
-
overall_variance = ops.var(data, axis=(-2, -1), keepdims=True)
|
|
2329
|
+
overall_variance = ops.var(data, axis=(-3, -2), keepdims=True)
|
|
2064
2330
|
|
|
2065
2331
|
# Calculate adaptive weights
|
|
2066
2332
|
img_weights = img_variance / (img_variance + overall_variance)
|
|
@@ -2081,6 +2347,11 @@ class Demodulate(Operation):
|
|
|
2081
2347
|
input_data_type=DataTypes.RAW_DATA,
|
|
2082
2348
|
output_data_type=DataTypes.RAW_DATA,
|
|
2083
2349
|
jittable=True,
|
|
2350
|
+
additional_output_keys=[
|
|
2351
|
+
"demodulation_frequency",
|
|
2352
|
+
"center_frequency",
|
|
2353
|
+
"n_ch",
|
|
2354
|
+
],
|
|
2084
2355
|
**kwargs,
|
|
2085
2356
|
)
|
|
2086
2357
|
self.axis = axis
|
|
@@ -2106,6 +2377,121 @@ class Demodulate(Operation):
|
|
|
2106
2377
|
}
|
|
2107
2378
|
|
|
2108
2379
|
|
|
2380
|
+
@ops_registry("fir_filter")
|
|
2381
|
+
class FirFilter(Operation):
|
|
2382
|
+
"""Apply a FIR filter to the input signal using convolution.
|
|
2383
|
+
|
|
2384
|
+
Looks for the filter taps in the input dictionary using the specified ``filter_key``.
|
|
2385
|
+
"""
|
|
2386
|
+
|
|
2387
|
+
def __init__(
|
|
2388
|
+
self,
|
|
2389
|
+
axis: int,
|
|
2390
|
+
complex_channels: bool = False,
|
|
2391
|
+
filter_key: str = "fir_filter_taps",
|
|
2392
|
+
**kwargs,
|
|
2393
|
+
):
|
|
2394
|
+
"""
|
|
2395
|
+
Args:
|
|
2396
|
+
axis (int): Axis along which to apply the filter. Cannot be the batch dimension.
|
|
2397
|
+
When using ``complex_channels=True``, the complex channels are removed to convert
|
|
2398
|
+
to complex numbers before filtering, so adjust the ``axis`` accordingly!
|
|
2399
|
+
complex_channels (bool): Whether the last dimension of the input signal represents
|
|
2400
|
+
complex channels (real and imaginary parts). When True, it will convert the signal
|
|
2401
|
+
to ``complex`` dtype before filtering and convert it back to two channels
|
|
2402
|
+
after filtering.
|
|
2403
|
+
filter_key (str): Key in the input dictionary where the FIR filter taps are stored.
|
|
2404
|
+
Default is "fir_filter_taps".
|
|
2405
|
+
"""
|
|
2406
|
+
super().__init__(**kwargs)
|
|
2407
|
+
self._check_axis(axis)
|
|
2408
|
+
|
|
2409
|
+
self.axis = axis
|
|
2410
|
+
self.complex_channels = complex_channels
|
|
2411
|
+
self.filter_key = filter_key
|
|
2412
|
+
|
|
2413
|
+
def _check_axis(self, axis, ndim=None):
|
|
2414
|
+
"""Check if the axis is valid."""
|
|
2415
|
+
if ndim is not None:
|
|
2416
|
+
if axis < -ndim or axis >= ndim:
|
|
2417
|
+
raise ValueError(f"Axis {axis} is out of bounds for array of dimension {ndim}.")
|
|
2418
|
+
|
|
2419
|
+
if self.with_batch_dim and (axis == 0 or (ndim is not None and axis == -ndim)):
|
|
2420
|
+
raise ValueError("Cannot apply FIR filter along batch dimension.")
|
|
2421
|
+
|
|
2422
|
+
@property
|
|
2423
|
+
def valid_keys(self):
|
|
2424
|
+
"""Get the valid keys for the `call` method."""
|
|
2425
|
+
return self._valid_keys.union({self.filter_key})
|
|
2426
|
+
|
|
2427
|
+
def call(self, **kwargs):
    """Filter ``kwargs[self.key]`` with the taps stored in ``kwargs[self.filter_key]``.

    Returns:
        dict: ``{self.output_key: filtered_signal}``. With ``complex_channels`` set, the
        filtered complex signal is split back into (real, imag) channels.
    """
    signal = kwargs[self.key]
    taps = kwargs[self.filter_key]

    if self.complex_channels:
        signal = channels_to_complex(signal)

    # The rank is only known now (after the optional complex merge), so bounds-check here.
    self._check_axis(self.axis, ndim=ops.ndim(signal))

    # Convolution is expressed as correlation with the time-reversed taps.
    # NOTE(review): if ``correlate`` conjugates its second argument (as numpy.correlate
    # does), complex taps are effectively conjugated here -- confirm against correlate.
    reversed_taps = taps[::-1]
    filtered = apply_along_axis(
        lambda trace: correlate(trace, reversed_taps, mode="same"),
        self.axis,
        signal,
    )

    if self.complex_channels:
        filtered = complex_to_channels(filtered)
    return {self.output_key: filtered}
|
|
2446
|
+
|
|
2447
|
+
|
|
2448
|
+
@ops_registry("low_pass_filter")
class LowPassFilter(FirFilter):
    """Low-pass FIR filtering with taps designed on the fly.

    Convenience wrapper around :class:`FirFilter`. On every call, the taps are built by
    :func:`get_low_pass_iq_filter` from the ``bandwidth``, ``sampling_frequency`` and
    ``center_frequency`` inputs. Those inputs are converted to Python scalars for the
    design step, so this operation is not jittable. For jitted pipelines, prefer
    :class:`FirFilter` with pre-computed filter taps.
    """

    def __init__(self, axis: int, complex_channels: bool = False, num_taps: int = 128, **kwargs):
        """Create the low-pass filter operation.

        Args:
            axis (int): Axis along which to apply the filter. Cannot be the batch dimension.
                When using ``complex_channels=True``, the complex channels are merged into
                a complex dtype before filtering, so give ``axis`` relative to that array.
            complex_channels (bool): Whether the last axis of the input holds (real, imag)
                channels. If True, the signal is converted to ``complex`` dtype for
                filtering and split back into two channels afterwards.
            num_taps (int): Number of taps in the FIR filter. Default is 128.
            **kwargs: Forwarded to :class:`FirFilter`. Any ``filter_key`` or ``jittable``
                entries are discarded. The taps are stored under a per-instance key and the
                operation is always created with ``jittable=False``.
        """
        # Per-instance random key, so the generated taps are unlikely to clash with
        # other input entries.
        self._random_suffix = str(uuid.uuid4())
        for overridden in ("filter_key", "jittable"):
            kwargs.pop(overridden, None)
        super().__init__(
            axis=axis,
            complex_channels=complex_channels,
            filter_key=f"low_pass_{self._random_suffix}",
            jittable=False,
            **kwargs,
        )
        self.num_taps = num_taps

    def call(self, bandwidth, sampling_frequency, center_frequency, **kwargs):
        """Design low-pass taps for the given parameters and apply them.

        Args:
            bandwidth: Bandwidth in Hz (a single value; converted with ``.item()``).
            sampling_frequency: Sampling frequency in Hz (a single value).
            center_frequency: Center frequency in Hz (a single value).
            **kwargs: Remaining inputs, passed on to :meth:`FirFilter.call`.

        Returns:
            dict: The output of :meth:`FirFilter.call`.
        """

        def _to_scalar(value):
            return ops.convert_to_numpy(value).item()

        fs, fc, bw = (
            _to_scalar(v) for v in (sampling_frequency, center_frequency, bandwidth)
        )
        kwargs[self.filter_key] = get_low_pass_iq_filter(self.num_taps, fs, fc, bw)
        return super().call(**kwargs)
|
|
2493
|
+
|
|
2494
|
+
|
|
2109
2495
|
@ops_registry("lambda")
|
|
2110
2496
|
class Lambda(Operation):
|
|
2111
2497
|
"""Use any function as an operation."""
|
|
@@ -2140,12 +2526,15 @@ class Lambda(Operation):
|
|
|
2140
2526
|
|
|
2141
2527
|
def call(self, **kwargs):
    """Apply ``self.func`` to ``kwargs[self.key]``.

    With ``with_batch_dim`` set, ``func`` is mapped over the leading (batch) axis through
    ``ops.map``. Otherwise it is applied to the whole input at once.

    Returns:
        dict: ``{self.output_key: result}``.
    """
    data = kwargs[self.key]
    result = ops.map(self.func, data) if self.with_batch_dim else self.func(data)
    return {self.output_key: result}
|
|
2145
2534
|
|
|
2146
2535
|
|
|
2147
2536
|
@ops_registry("pad")
|
|
2148
|
-
class Pad(Operation,
|
|
2537
|
+
class Pad(Operation, DataLayer):
|
|
2149
2538
|
"""Pad layer for padding tensors to a specified shape."""
|
|
2150
2539
|
|
|
2151
2540
|
def __init__(
|
|
@@ -2330,6 +2719,7 @@ class Downsample(Operation):
|
|
|
2330
2719
|
|
|
2331
2720
|
def __init__(self, factor: int = 1, phase: int = 0, axis: int = -3, **kwargs):
|
|
2332
2721
|
super().__init__(
|
|
2722
|
+
additional_output_keys=["sampling_frequency", "n_ax"],
|
|
2333
2723
|
**kwargs,
|
|
2334
2724
|
)
|
|
2335
2725
|
self.factor = factor
|
|
@@ -2894,7 +3284,7 @@ def get_band_pass_filter(num_taps, sampling_frequency, f1, f2):
|
|
|
2894
3284
|
return bpf
|
|
2895
3285
|
|
|
2896
3286
|
|
|
2897
|
-
def get_low_pass_iq_filter(num_taps, sampling_frequency,
|
|
3287
|
+
def get_low_pass_iq_filter(num_taps, sampling_frequency, center_frequency, bandwidth):
|
|
2898
3288
|
"""Design complex low-pass filter.
|
|
2899
3289
|
|
|
2900
3290
|
The filter is a low-pass FIR filter modulated to the center frequency.
|
|
@@ -2902,16 +3292,16 @@ def get_low_pass_iq_filter(num_taps, sampling_frequency, f, bw):
|
|
|
2902
3292
|
Args:
|
|
2903
3293
|
num_taps (int): number of taps in filter.
|
|
2904
3294
|
sampling_frequency (float): sample frequency.
|
|
2905
|
-
|
|
2906
|
-
|
|
3295
|
+
center_frequency (float): center frequency.
|
|
3296
|
+
bandwidth (float): bandwidth in Hz.
|
|
2907
3297
|
|
|
2908
3298
|
Raises:
|
|
2909
|
-
ValueError: if cutoff frequency (
|
|
3299
|
+
ValueError: if cutoff frequency (bandwidth / 2) is not within (0, sampling_frequency / 2)
|
|
2910
3300
|
|
|
2911
3301
|
Returns:
|
|
2912
3302
|
ndarray: Complex-valued low-pass filter
|
|
2913
3303
|
"""
|
|
2914
|
-
cutoff =
|
|
3304
|
+
cutoff = bandwidth / 2
|
|
2915
3305
|
if not (0 < cutoff < sampling_frequency / 2):
|
|
2916
3306
|
raise ValueError(
|
|
2917
3307
|
f"Cutoff frequency must be within (0, sampling_frequency / 2), "
|
|
@@ -2921,7 +3311,7 @@ def get_low_pass_iq_filter(num_taps, sampling_frequency, f, bw):
|
|
|
2921
3311
|
lpf = scipy.signal.firwin(num_taps, cutoff, pass_zero=True, fs=sampling_frequency)
|
|
2922
3312
|
# Modulate to center frequency to make it complex
|
|
2923
3313
|
time_points = np.arange(num_taps) / sampling_frequency
|
|
2924
|
-
lpf_complex = lpf * np.exp(1j * 2 * np.pi *
|
|
3314
|
+
lpf_complex = lpf * np.exp(1j * 2 * np.pi * center_frequency * time_points)
|
|
2925
3315
|
return lpf_complex
|
|
2926
3316
|
|
|
2927
3317
|
|
|
@@ -3095,6 +3485,50 @@ def demodulate(data, center_frequency, sampling_frequency, axis=-3):
|
|
|
3095
3485
|
iq_data_signal_complex = analytical_signal * ops.exp(phasor_exponent)
|
|
3096
3486
|
|
|
3097
3487
|
# Split the complex signal into two channels
|
|
3098
|
-
iq_data_two_channel = complex_to_channels(iq_data_signal_complex
|
|
3488
|
+
iq_data_two_channel = complex_to_channels(ops.squeeze(iq_data_signal_complex, axis=-1))
|
|
3099
3489
|
|
|
3100
3490
|
return iq_data_two_channel
|
|
3491
|
+
|
|
3492
|
+
|
|
3493
|
+
def compute_time_to_peak_stack(waveforms, center_frequencies, waveform_sampling_frequency=250e6):
    """Compute the time of the peak of each waveform in a stack of waveforms.

    Args:
        waveforms (ndarray): The waveforms of shape (n_waveforms, n_samples).
        center_frequencies (ndarray): The center frequencies of the waveforms in Hz, of
            shape (n_waveforms,), or a scalar if all waveforms share the same center
            frequency.
        waveform_sampling_frequency (float): The sampling frequency of the waveforms in Hz.

    Returns:
        ndarray: The time to peak for each waveform in seconds, shape (n_waveforms,).
        An empty stack (``n_waveforms == 0``) yields an empty float32 array.
    """
    n_waveforms = waveforms.shape[0]
    if n_waveforms == 0:
        # ops.stack fails on an empty list (e.g. with the NumPy/JAX backends), so return
        # an empty result. The dtype matches compute_time_to_peak's float32 output.
        return ops.zeros((0,), dtype="float32")

    # Broadcast a scalar center frequency to one value per waveform.
    center_frequencies = center_frequencies * ops.ones((n_waveforms,))
    return ops.stack(
        [
            compute_time_to_peak(waveform, center_frequency, waveform_sampling_frequency)
            for waveform, center_frequency in zip(waveforms, center_frequencies)
        ]
    )
|
|
3510
|
+
|
|
3511
|
+
|
|
3512
|
+
def compute_time_to_peak(waveform, center_frequency, waveform_sampling_frequency=250e6):
    """Return the time, in seconds, at which the envelope of ``waveform`` peaks.

    The waveform is passed through ``demodulate`` and converted to complex IQ. Its
    magnitude is used as the envelope, and the index of the envelope maximum is
    converted to time by dividing by the sampling frequency.

    Args:
        waveform (ndarray): The waveform of shape (n_samples,).
        center_frequency (float): The center frequency of the waveform in Hz.
        waveform_sampling_frequency (float): The sampling frequency of the waveform in Hz.

    Returns:
        float32 scalar tensor: The time to peak for the waveform in seconds.

    Raises:
        ValueError: If ``waveform`` has zero samples.
    """
    if waveform.shape[0] == 0:
        raise ValueError("Waveform has zero samples.")

    # NOTE(review): ``waveform[..., None]`` has shape (n_samples, 1). So ``axis=-1``
    # points at the trailing singleton axis rather than the sample axis (-2). demodulate
    # defaults to axis=-3 and squeezes axis -1 at the end, which suggests it counts that
    # trailing channel axis -- confirm the intended axis here.
    iq_channels = demodulate(
        waveform[..., None], center_frequency, waveform_sampling_frequency, axis=-1
    )
    envelope = ops.abs(channels_to_complex(iq_channels))
    peak_index = ops.argmax(envelope, axis=-1)
    return ops.cast(peak_index, dtype="float32") / waveform_sampling_frequency
|