zea 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +54 -19
- zea/agent/__init__.py +12 -12
- zea/agent/masks.py +2 -1
- zea/backend/tensorflow/dataloader.py +2 -1
- zea/beamform/beamformer.py +100 -50
- zea/beamform/lens_correction.py +9 -2
- zea/beamform/pfield.py +9 -2
- zea/config.py +34 -25
- zea/data/__init__.py +22 -16
- zea/data/convert/camus.py +2 -1
- zea/data/convert/echonet.py +4 -4
- zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +1 -1
- zea/data/convert/matlab.py +11 -4
- zea/data/data_format.py +31 -30
- zea/data/datasets.py +7 -5
- zea/data/file.py +104 -2
- zea/data/layers.py +3 -3
- zea/datapaths.py +16 -4
- zea/display.py +7 -5
- zea/interface.py +14 -16
- zea/internal/_generate_keras_ops.py +6 -7
- zea/internal/cache.py +2 -49
- zea/internal/config/validation.py +1 -2
- zea/internal/core.py +69 -6
- zea/internal/device.py +6 -2
- zea/internal/dummy_scan.py +330 -0
- zea/internal/operators.py +114 -2
- zea/internal/parameters.py +101 -70
- zea/internal/setup_zea.py +5 -6
- zea/internal/utils.py +282 -0
- zea/io_lib.py +247 -19
- zea/keras_ops.py +74 -4
- zea/log.py +9 -7
- zea/metrics.py +15 -7
- zea/models/__init__.py +30 -20
- zea/models/base.py +30 -14
- zea/models/carotid_segmenter.py +19 -4
- zea/models/diffusion.py +173 -12
- zea/models/echonet.py +22 -8
- zea/models/echonetlvh.py +31 -7
- zea/models/lpips.py +19 -2
- zea/models/lv_segmentation.py +28 -11
- zea/models/preset_utils.py +5 -5
- zea/models/regional_quality.py +30 -10
- zea/models/taesd.py +21 -5
- zea/models/unet.py +15 -1
- zea/ops.py +390 -196
- zea/probes.py +6 -6
- zea/scan.py +109 -49
- zea/simulator.py +24 -21
- zea/tensor_ops.py +406 -302
- zea/tools/hf.py +1 -1
- zea/tools/selection_tool.py +47 -86
- zea/utils.py +92 -480
- zea/visualize.py +177 -39
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/METADATA +4 -2
- zea-0.0.7.dist-info/RECORD +114 -0
- zea-0.0.6.dist-info/RECORD +0 -112
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/WHEEL +0 -0
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/entry_points.txt +0 -0
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/licenses/LICENSE +0 -0
zea/ops.py
CHANGED
|
@@ -11,39 +11,62 @@ Operations can be run on their own:
|
|
|
11
11
|
|
|
12
12
|
Examples
|
|
13
13
|
^^^^^^^^
|
|
14
|
-
..
|
|
14
|
+
.. doctest::
|
|
15
15
|
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
#
|
|
20
|
-
|
|
16
|
+
>>> import numpy as np
|
|
17
|
+
>>> from zea.ops import EnvelopeDetect
|
|
18
|
+
>>> data = np.random.randn(2000, 128, 1)
|
|
19
|
+
>>> # static arguments are passed in the constructor
|
|
20
|
+
>>> envelope_detect = EnvelopeDetect(axis=-1)
|
|
21
|
+
>>> # other parameters can be passed here along with the data
|
|
22
|
+
>>> envelope_data = envelope_detect(data=data)
|
|
21
23
|
|
|
22
24
|
Using a pipeline
|
|
23
25
|
----------------
|
|
24
26
|
|
|
25
27
|
You can initialize with a default pipeline or create your own custom pipeline.
|
|
26
28
|
|
|
27
|
-
..
|
|
29
|
+
.. doctest::
|
|
28
30
|
|
|
29
|
-
|
|
31
|
+
>>> from zea.ops import Pipeline, EnvelopeDetect, Normalize, LogCompress
|
|
32
|
+
>>> pipeline = Pipeline.from_default()
|
|
30
33
|
|
|
31
|
-
operations = [
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
]
|
|
36
|
-
pipeline_custom = Pipeline(operations)
|
|
34
|
+
>>> operations = [
|
|
35
|
+
... EnvelopeDetect(),
|
|
36
|
+
... Normalize(),
|
|
37
|
+
... LogCompress(),
|
|
38
|
+
... ]
|
|
39
|
+
>>> pipeline_custom = Pipeline(operations)
|
|
37
40
|
|
|
38
41
|
One can also load a pipeline from a config or yaml/json file:
|
|
39
42
|
|
|
40
|
-
..
|
|
43
|
+
.. doctest::
|
|
44
|
+
|
|
45
|
+
>>> from zea import Pipeline
|
|
46
|
+
|
|
47
|
+
>>> # From JSON string
|
|
48
|
+
>>> json_string = '{"operations": ["identity"]}'
|
|
49
|
+
>>> pipeline = Pipeline.from_json(json_string)
|
|
50
|
+
|
|
51
|
+
>>> # from yaml file
|
|
52
|
+
>>> import yaml
|
|
53
|
+
>>> from zea import Config
|
|
54
|
+
>>> # Create a sample pipeline YAML file
|
|
55
|
+
>>> pipeline_dict = {
|
|
56
|
+
... "operations": [
|
|
57
|
+
... {"name": "identity"},
|
|
58
|
+
... ]
|
|
59
|
+
... }
|
|
60
|
+
>>> with open("pipeline.yaml", "w") as f:
|
|
61
|
+
... yaml.dump(pipeline_dict, f)
|
|
62
|
+
>>> yaml_file = "pipeline.yaml"
|
|
63
|
+
>>> pipeline = Pipeline.from_yaml(yaml_file)
|
|
64
|
+
|
|
65
|
+
.. testcleanup::
|
|
41
66
|
|
|
42
|
-
|
|
43
|
-
pipeline = Pipeline.from_json(json_string)
|
|
67
|
+
import os
|
|
44
68
|
|
|
45
|
-
|
|
46
|
-
pipeline = Pipeline.from_yaml(yaml_file)
|
|
69
|
+
os.remove("pipeline.yaml")
|
|
47
70
|
|
|
48
71
|
Example of a yaml file:
|
|
49
72
|
|
|
@@ -56,8 +79,6 @@ Example of a yaml file:
|
|
|
56
79
|
params:
|
|
57
80
|
operations:
|
|
58
81
|
- name: tof_correction
|
|
59
|
-
params:
|
|
60
|
-
apply_phase_rotation: true
|
|
61
82
|
- name: pfield_weighting
|
|
62
83
|
- name: delay_and_sum
|
|
63
84
|
num_patches: 100
|
|
@@ -79,7 +100,7 @@ import numpy as np
|
|
|
79
100
|
import scipy
|
|
80
101
|
import yaml
|
|
81
102
|
from keras import ops
|
|
82
|
-
from keras.src.layers.preprocessing.
|
|
103
|
+
from keras.src.layers.preprocessing.data_layer import DataLayer
|
|
83
104
|
|
|
84
105
|
from zea import log
|
|
85
106
|
from zea.backend import jit
|
|
@@ -99,8 +120,12 @@ from zea.internal.registry import ops_registry
|
|
|
99
120
|
from zea.probes import Probe
|
|
100
121
|
from zea.scan import Scan
|
|
101
122
|
from zea.simulator import simulate_rf
|
|
102
|
-
from zea.tensor_ops import
|
|
103
|
-
from zea.utils import
|
|
123
|
+
from zea.tensor_ops import resample, reshape_axis, translate, vmap
|
|
124
|
+
from zea.utils import (
|
|
125
|
+
FunctionTimer,
|
|
126
|
+
deep_compare,
|
|
127
|
+
map_negative_indices,
|
|
128
|
+
)
|
|
104
129
|
|
|
105
130
|
|
|
106
131
|
def get_ops(ops_name):
|
|
@@ -125,6 +150,7 @@ class Operation(keras.Operation):
|
|
|
125
150
|
with_batch_dim: bool = True,
|
|
126
151
|
jit_kwargs: dict | None = None,
|
|
127
152
|
jittable: bool = True,
|
|
153
|
+
additional_output_keys: List[str] = None,
|
|
128
154
|
**kwargs,
|
|
129
155
|
):
|
|
130
156
|
"""
|
|
@@ -143,6 +169,10 @@ class Operation(keras.Operation):
|
|
|
143
169
|
with_batch_dim: Whether operations should expect a batch dimension in the input
|
|
144
170
|
jit_kwargs: Additional keyword arguments for the JIT compiler
|
|
145
171
|
jittable: Whether the operation can be JIT compiled
|
|
172
|
+
additional_output_keys: A list of additional output keys produced by the operation.
|
|
173
|
+
These are used to track if all keys are available for downstream operations.
|
|
174
|
+
If the operation has a conditional output, it is best to add all possible
|
|
175
|
+
output keys here.
|
|
146
176
|
"""
|
|
147
177
|
super().__init__(**kwargs)
|
|
148
178
|
|
|
@@ -153,6 +183,7 @@ class Operation(keras.Operation):
|
|
|
153
183
|
self.output_key = output_key # Key for output data
|
|
154
184
|
if self.output_key is None:
|
|
155
185
|
self.output_key = self.key
|
|
186
|
+
self.additional_output_keys = additional_output_keys or []
|
|
156
187
|
|
|
157
188
|
self.inputs = [] # Source(s) of input data (name of a previous operation)
|
|
158
189
|
self.allow_multiple_inputs = False # Only single input allowed by default
|
|
@@ -165,8 +196,6 @@ class Operation(keras.Operation):
|
|
|
165
196
|
self._output_cache = {}
|
|
166
197
|
|
|
167
198
|
# Obtain the input signature of the `call` method
|
|
168
|
-
self._input_signature = None
|
|
169
|
-
self._valid_keys = None # Keys valid for the `call` method
|
|
170
199
|
self._trace_signatures()
|
|
171
200
|
|
|
172
201
|
if jit_kwargs is None:
|
|
@@ -186,6 +215,11 @@ class Operation(keras.Operation):
|
|
|
186
215
|
with log.set_level("ERROR"):
|
|
187
216
|
self.set_jit(jit_compile)
|
|
188
217
|
|
|
218
|
+
@property
|
|
219
|
+
def output_keys(self) -> List[str]:
|
|
220
|
+
"""Get the output keys of the operation."""
|
|
221
|
+
return [self.output_key] + self.additional_output_keys
|
|
222
|
+
|
|
189
223
|
@property
|
|
190
224
|
def static_params(self):
|
|
191
225
|
"""Get the static parameters of the operation."""
|
|
@@ -207,10 +241,15 @@ class Operation(keras.Operation):
|
|
|
207
241
|
self._valid_keys = set(self._input_signature.parameters.keys())
|
|
208
242
|
|
|
209
243
|
@property
|
|
210
|
-
def valid_keys(self):
|
|
244
|
+
def valid_keys(self) -> set:
|
|
211
245
|
"""Get the valid keys for the `call` method."""
|
|
212
246
|
return self._valid_keys
|
|
213
247
|
|
|
248
|
+
@property
|
|
249
|
+
def needs_keys(self) -> set:
|
|
250
|
+
"""Get a set of all input keys needed by the operation."""
|
|
251
|
+
return self.valid_keys
|
|
252
|
+
|
|
214
253
|
@property
|
|
215
254
|
def jittable(self):
|
|
216
255
|
"""Check if the operation can be JIT compiled."""
|
|
@@ -408,8 +447,6 @@ class Pipeline:
|
|
|
408
447
|
"""
|
|
409
448
|
self._call_pipeline = self.call
|
|
410
449
|
self.name = name
|
|
411
|
-
self.timer = FunctionTimer()
|
|
412
|
-
self.timed = timed
|
|
413
450
|
|
|
414
451
|
self._pipeline_layers = operations
|
|
415
452
|
|
|
@@ -419,6 +456,24 @@ class Pipeline:
|
|
|
419
456
|
self.with_batch_dim = with_batch_dim
|
|
420
457
|
self._validate_flag = validate
|
|
421
458
|
|
|
459
|
+
# Setup timer
|
|
460
|
+
if jit_options == "pipeline" and timed:
|
|
461
|
+
raise ValueError(
|
|
462
|
+
"timed=True cannot be used with jit_options='pipeline' as the entire "
|
|
463
|
+
"pipeline is compiled into a single function. Try setting jit_options to "
|
|
464
|
+
"'ops' or None."
|
|
465
|
+
)
|
|
466
|
+
if timed:
|
|
467
|
+
log.warning(
|
|
468
|
+
"Timer has been initialized for the pipeline. To get an accurate timing estimate, "
|
|
469
|
+
"the `block_until_ready()` is used, which will slow down the execution, so "
|
|
470
|
+
"do not use for regular processing!"
|
|
471
|
+
)
|
|
472
|
+
self._callable_layers = self._get_timed_operations()
|
|
473
|
+
else:
|
|
474
|
+
self._callable_layers = self._pipeline_layers
|
|
475
|
+
self._timed = timed
|
|
476
|
+
|
|
422
477
|
if validate:
|
|
423
478
|
self.validate()
|
|
424
479
|
else:
|
|
@@ -427,19 +482,30 @@ class Pipeline:
|
|
|
427
482
|
if jit_kwargs is None:
|
|
428
483
|
jit_kwargs = {}
|
|
429
484
|
|
|
430
|
-
if keras.backend.backend() == "jax" and self.static_params:
|
|
485
|
+
if keras.backend.backend() == "jax" and self.static_params != []:
|
|
431
486
|
jit_kwargs = {"static_argnames": self.static_params}
|
|
432
487
|
|
|
433
488
|
self.jit_kwargs = jit_kwargs
|
|
434
489
|
self.jit_options = jit_options # will handle the jit compilation
|
|
435
490
|
|
|
436
491
|
def needs(self, key) -> bool:
|
|
437
|
-
"""Check if the pipeline needs a specific key."""
|
|
438
|
-
return key in self.
|
|
492
|
+
"""Check if the pipeline needs a specific key at the input."""
|
|
493
|
+
return key in self.needs_keys
|
|
494
|
+
|
|
495
|
+
@property
|
|
496
|
+
def output_keys(self) -> set:
|
|
497
|
+
"""All output keys the pipeline guarantees to produce."""
|
|
498
|
+
output_keys = set()
|
|
499
|
+
for operation in self.operations:
|
|
500
|
+
output_keys.update(operation.output_keys)
|
|
501
|
+
return output_keys
|
|
439
502
|
|
|
440
503
|
@property
|
|
441
504
|
def valid_keys(self) -> set:
|
|
442
|
-
"""Get a set of valid keys for the pipeline.
|
|
505
|
+
"""Get a set of valid keys for the pipeline.
|
|
506
|
+
|
|
507
|
+
This is all keys that can be passed to the pipeline as input.
|
|
508
|
+
"""
|
|
443
509
|
valid_keys = set()
|
|
444
510
|
for operation in self.operations:
|
|
445
511
|
valid_keys.update(operation.valid_keys)
|
|
@@ -453,8 +519,26 @@ class Pipeline:
|
|
|
453
519
|
static_params.extend(operation.static_params)
|
|
454
520
|
return list(set(static_params))
|
|
455
521
|
|
|
522
|
+
@property
|
|
523
|
+
def needs_keys(self) -> set:
|
|
524
|
+
"""Get a set of all input keys needed by the pipeline.
|
|
525
|
+
|
|
526
|
+
Will keep track of keys that are already provided by previous operations.
|
|
527
|
+
"""
|
|
528
|
+
needs = set()
|
|
529
|
+
has_so_far = set()
|
|
530
|
+
previous_operation = None
|
|
531
|
+
for operation in self.operations:
|
|
532
|
+
if previous_operation is not None:
|
|
533
|
+
has_so_far.update(previous_operation.output_keys)
|
|
534
|
+
needs.update(operation.needs_keys - has_so_far)
|
|
535
|
+
previous_operation = operation
|
|
536
|
+
return needs
|
|
537
|
+
|
|
456
538
|
@classmethod
|
|
457
|
-
def from_default(
|
|
539
|
+
def from_default(
|
|
540
|
+
cls, num_patches=100, baseband=False, pfield=False, timed=False, **kwargs
|
|
541
|
+
) -> "Pipeline":
|
|
458
542
|
"""Create a default pipeline.
|
|
459
543
|
|
|
460
544
|
Args:
|
|
@@ -465,6 +549,7 @@ class Pipeline:
|
|
|
465
549
|
so input signal has a single channel dim and is still on carrier frequency.
|
|
466
550
|
pfield (bool): If True, apply Pfield weighting. Defaults to False.
|
|
467
551
|
This will calculate pressure field and only beamform the data to those locations.
|
|
552
|
+
timed (bool, optional): Whether to time each operation. Defaults to False.
|
|
468
553
|
**kwargs: Additional keyword arguments to be passed to the Pipeline constructor.
|
|
469
554
|
|
|
470
555
|
"""
|
|
@@ -476,7 +561,7 @@ class Pipeline:
|
|
|
476
561
|
|
|
477
562
|
# Get beamforming ops
|
|
478
563
|
beamforming = [
|
|
479
|
-
TOFCorrection(
|
|
564
|
+
TOFCorrection(),
|
|
480
565
|
DelayAndSum(),
|
|
481
566
|
]
|
|
482
567
|
if pfield:
|
|
@@ -495,7 +580,7 @@ class Pipeline:
|
|
|
495
580
|
Normalize(),
|
|
496
581
|
LogCompress(),
|
|
497
582
|
]
|
|
498
|
-
return cls(operations, **kwargs)
|
|
583
|
+
return cls(operations, timed=timed, **kwargs)
|
|
499
584
|
|
|
500
585
|
def copy(self) -> "Pipeline":
|
|
501
586
|
"""Create a copy of the pipeline."""
|
|
@@ -506,57 +591,60 @@ class Pipeline:
|
|
|
506
591
|
jit_kwargs=self.jit_kwargs,
|
|
507
592
|
name=self.name,
|
|
508
593
|
validate=self._validate_flag,
|
|
594
|
+
timed=self._timed,
|
|
595
|
+
)
|
|
596
|
+
|
|
597
|
+
def reinitialize(self):
|
|
598
|
+
"""Reinitialize the pipeline in place."""
|
|
599
|
+
self.__init__(
|
|
600
|
+
self._pipeline_layers,
|
|
601
|
+
with_batch_dim=self.with_batch_dim,
|
|
602
|
+
jit_options=self.jit_options,
|
|
603
|
+
jit_kwargs=self.jit_kwargs,
|
|
604
|
+
name=self.name,
|
|
605
|
+
validate=self._validate_flag,
|
|
606
|
+
timed=self._timed,
|
|
509
607
|
)
|
|
510
608
|
|
|
511
609
|
def prepend(self, operation: Operation):
|
|
512
610
|
"""Prepend an operation to the pipeline."""
|
|
513
611
|
self._pipeline_layers.insert(0, operation)
|
|
514
|
-
self.
|
|
612
|
+
self.reinitialize()
|
|
515
613
|
|
|
516
614
|
def append(self, operation: Operation):
|
|
517
615
|
"""Append an operation to the pipeline."""
|
|
518
616
|
self._pipeline_layers.append(operation)
|
|
519
|
-
self.
|
|
617
|
+
self.reinitialize()
|
|
520
618
|
|
|
521
619
|
def insert(self, index: int, operation: Operation):
|
|
522
620
|
"""Insert an operation at a specific index in the pipeline."""
|
|
523
621
|
if index < 0 or index > len(self._pipeline_layers):
|
|
524
622
|
raise IndexError("Index out of bounds for inserting operation.")
|
|
525
623
|
self._pipeline_layers.insert(index, operation)
|
|
526
|
-
|
|
624
|
+
self.reinitialize()
|
|
527
625
|
|
|
528
626
|
@property
|
|
529
627
|
def operations(self):
|
|
530
628
|
"""Alias for self.layers to match the zea naming convention"""
|
|
531
629
|
return self._pipeline_layers
|
|
532
630
|
|
|
533
|
-
def
|
|
534
|
-
"""
|
|
631
|
+
def reset_timer(self):
|
|
632
|
+
"""Reset the timer for timed operations."""
|
|
633
|
+
if self._timed:
|
|
634
|
+
self._callable_layers = self._get_timed_operations()
|
|
635
|
+
else:
|
|
636
|
+
log.warning(
|
|
637
|
+
"Timer has not been initialized. Set timed=True when initializing the pipeline."
|
|
638
|
+
)
|
|
535
639
|
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
except KeyError as exc:
|
|
541
|
-
raise KeyError(
|
|
542
|
-
f"[zea.Pipeline] Operation '{op.__class__.__name__}' "
|
|
543
|
-
f"requires input key '{exc.args[0]}', "
|
|
544
|
-
"but it was not provided in the inputs.\n"
|
|
545
|
-
"Check whether the objects (such as `zea.Scan`) passed to "
|
|
546
|
-
"`pipeline.prepare_parameters()` contain all required keys.\n"
|
|
547
|
-
f"Current list of all passed keys: {list(inputs.keys())}\n"
|
|
548
|
-
f"Valid keys for this pipeline: {self.valid_keys}"
|
|
549
|
-
) from exc
|
|
550
|
-
except Exception as exc:
|
|
551
|
-
raise RuntimeError(
|
|
552
|
-
f"[zea.Pipeline] Error in operation '{op.__class__.__name__}': {exc}"
|
|
553
|
-
) from exc
|
|
554
|
-
inputs = outputs
|
|
555
|
-
return outputs
|
|
640
|
+
def _get_timed_operations(self):
|
|
641
|
+
"""Get a list of timed operations."""
|
|
642
|
+
self.timer = FunctionTimer()
|
|
643
|
+
return [self.timer(op, name=op.__class__.__name__) for op in self._pipeline_layers]
|
|
556
644
|
|
|
557
645
|
def call(self, **inputs):
|
|
558
646
|
"""Process input data through the pipeline."""
|
|
559
|
-
for operation in self.
|
|
647
|
+
for operation in self._callable_layers:
|
|
560
648
|
try:
|
|
561
649
|
outputs = operation(**inputs)
|
|
562
650
|
except KeyError as exc:
|
|
@@ -579,14 +667,9 @@ class Pipeline:
|
|
|
579
667
|
def __call__(self, return_numpy=False, **inputs):
|
|
580
668
|
"""Process input data through the pipeline."""
|
|
581
669
|
|
|
582
|
-
if any(key in inputs for key in ["probe", "scan", "config"])
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
"`Pipeline.prepare_parameters` before calling the pipeline. "
|
|
586
|
-
"e.g. inputs = Pipeline.prepare_parameters(probe, scan, config)"
|
|
587
|
-
)
|
|
588
|
-
|
|
589
|
-
if any(isinstance(arg, ZEAObject) for arg in inputs.values()):
|
|
670
|
+
if any(key in inputs for key in ["probe", "scan", "config"]) or any(
|
|
671
|
+
isinstance(arg, ZEAObject) for arg in inputs.values()
|
|
672
|
+
):
|
|
590
673
|
raise ValueError(
|
|
591
674
|
"Probe, Scan and Config objects should be first processed with "
|
|
592
675
|
"`Pipeline.prepare_parameters` before calling the pipeline. "
|
|
@@ -640,18 +723,13 @@ class Pipeline:
|
|
|
640
723
|
if operation.jittable and operation._jit_compile:
|
|
641
724
|
operation.set_jit(value == "ops")
|
|
642
725
|
|
|
643
|
-
@property
|
|
644
|
-
def _call_fn(self):
|
|
645
|
-
"""Get the call function of the pipeline."""
|
|
646
|
-
return self.call if not self.timed else self.timed_call
|
|
647
|
-
|
|
648
726
|
def jit(self):
|
|
649
727
|
"""JIT compile the pipeline."""
|
|
650
|
-
self._call_pipeline = jit(self.
|
|
728
|
+
self._call_pipeline = jit(self.call, **self.jit_kwargs)
|
|
651
729
|
|
|
652
730
|
def unjit(self):
|
|
653
731
|
"""Un-JIT compile the pipeline."""
|
|
654
|
-
self._call_pipeline = self.
|
|
732
|
+
self._call_pipeline = self.call
|
|
655
733
|
|
|
656
734
|
@property
|
|
657
735
|
def jittable(self):
|
|
@@ -835,16 +913,17 @@ class Pipeline:
|
|
|
835
913
|
Must have a ``pipeline`` key with a subkey ``operations``.
|
|
836
914
|
|
|
837
915
|
Example:
|
|
838
|
-
..
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
916
|
+
.. doctest::
|
|
917
|
+
|
|
918
|
+
>>> from zea import Config, Pipeline
|
|
919
|
+
>>> config = Config(
|
|
920
|
+
... {
|
|
921
|
+
... "operations": [
|
|
922
|
+
... "identity",
|
|
923
|
+
... ],
|
|
924
|
+
... }
|
|
925
|
+
... )
|
|
926
|
+
>>> pipeline = Pipeline.from_config(config)
|
|
848
927
|
"""
|
|
849
928
|
return pipeline_from_config(Config(config), **kwargs)
|
|
850
929
|
|
|
@@ -860,9 +939,20 @@ class Pipeline:
|
|
|
860
939
|
Must have the a `pipeline` key with a subkey `operations`.
|
|
861
940
|
|
|
862
941
|
Example:
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
942
|
+
.. doctest::
|
|
943
|
+
|
|
944
|
+
>>> import yaml
|
|
945
|
+
>>> from zea import Config
|
|
946
|
+
>>> # Create a sample pipeline YAML file
|
|
947
|
+
>>> pipeline_dict = {
|
|
948
|
+
... "operations": [
|
|
949
|
+
... "identity",
|
|
950
|
+
... ],
|
|
951
|
+
... }
|
|
952
|
+
>>> with open("pipeline.yaml", "w") as f:
|
|
953
|
+
... yaml.dump(pipeline_dict, f)
|
|
954
|
+
>>> from zea.ops import Pipeline
|
|
955
|
+
>>> pipeline = Pipeline.from_yaml("pipeline.yaml", jit_options=None)
|
|
866
956
|
"""
|
|
867
957
|
return pipeline_from_yaml(file_path, **kwargs)
|
|
868
958
|
|
|
@@ -963,7 +1053,7 @@ class Pipeline:
|
|
|
963
1053
|
assert isinstance(scan, Scan), (
|
|
964
1054
|
f"Expected an instance of `zea.scan.Scan`, got {type(scan)}"
|
|
965
1055
|
)
|
|
966
|
-
scan_dict = scan.to_tensor(include=self.
|
|
1056
|
+
scan_dict = scan.to_tensor(include=self.needs_keys, keep_as_is=self.static_params)
|
|
967
1057
|
|
|
968
1058
|
if config is not None:
|
|
969
1059
|
assert isinstance(config, Config), (
|
|
@@ -1004,15 +1094,17 @@ def make_operation_chain(
|
|
|
1004
1094
|
list: List of operations to be performed.
|
|
1005
1095
|
|
|
1006
1096
|
Example:
|
|
1007
|
-
..
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
)
|
|
1097
|
+
.. doctest::
|
|
1098
|
+
|
|
1099
|
+
>>> from zea.ops import make_operation_chain, LogCompress
|
|
1100
|
+
>>> SomeCustomOperation = LogCompress # just for demonstration
|
|
1101
|
+
>>> chain = make_operation_chain(
|
|
1102
|
+
... [
|
|
1103
|
+
... "envelope_detect",
|
|
1104
|
+
... {"name": "normalize", "params": {"output_range": (0, 1)}},
|
|
1105
|
+
... SomeCustomOperation(),
|
|
1106
|
+
... ]
|
|
1107
|
+
... )
|
|
1016
1108
|
"""
|
|
1017
1109
|
chain = []
|
|
1018
1110
|
for operation in operation_chain:
|
|
@@ -1228,11 +1320,23 @@ class PatchedGrid(Pipeline):
|
|
|
1228
1320
|
for operation in self.operations:
|
|
1229
1321
|
operation.with_batch_dim = False
|
|
1230
1322
|
|
|
1323
|
+
@property
|
|
1324
|
+
def _extra_keys(self):
|
|
1325
|
+
return {"flatgrid", "grid_size_x", "grid_size_z"}
|
|
1326
|
+
|
|
1231
1327
|
@property
|
|
1232
1328
|
def valid_keys(self) -> set:
|
|
1233
|
-
"""Get a set of valid keys for the pipeline.
|
|
1234
|
-
operates on (even if not used by operations
|
|
1235
|
-
|
|
1329
|
+
"""Get a set of valid keys for the pipeline.
|
|
1330
|
+
Adds the parameters that PatchedGrid itself operates on (even if not used by operations
|
|
1331
|
+
inside it)."""
|
|
1332
|
+
return super().valid_keys.union(self._extra_keys)
|
|
1333
|
+
|
|
1334
|
+
@property
|
|
1335
|
+
def needs_keys(self) -> set:
|
|
1336
|
+
"""Get a set of all input keys needed by the pipeline.
|
|
1337
|
+
Adds the parameters that PatchedGrid itself operates on (even if not used by operations
|
|
1338
|
+
inside it)."""
|
|
1339
|
+
return super().needs_keys.union(self._extra_keys)
|
|
1236
1340
|
|
|
1237
1341
|
def call_item(self, inputs):
|
|
1238
1342
|
"""Process data in patches."""
|
|
@@ -1242,35 +1346,22 @@ class PatchedGrid(Pipeline):
|
|
|
1242
1346
|
grid_size_z = inputs["grid_size_z"]
|
|
1243
1347
|
flatgrid = inputs.pop("flatgrid")
|
|
1244
1348
|
|
|
1245
|
-
# TODO: maybe using n_tx and n_el from kwargs is better but these are tensors now
|
|
1246
|
-
# and this is not supported in broadcast_to
|
|
1247
|
-
n_tx = inputs[self.key].shape[0]
|
|
1248
|
-
n_pix = flatgrid.shape[0]
|
|
1249
|
-
n_el = inputs[self.key].shape[2]
|
|
1250
|
-
inputs["rx_apo"] = ops.broadcast_to(inputs.get("rx_apo", 1.0), (n_tx, n_pix, n_el))
|
|
1251
|
-
inputs["rx_apo"] = ops.swapaxes(inputs["rx_apo"], 0, 1) # put n_pix first
|
|
1252
|
-
|
|
1253
1349
|
# Define a list of keys to look up for patching
|
|
1254
|
-
|
|
1350
|
+
flat_pfield = inputs.pop("flat_pfield", None)
|
|
1255
1351
|
|
|
1256
|
-
|
|
1257
|
-
|
|
1258
|
-
|
|
1259
|
-
|
|
1260
|
-
|
|
1261
|
-
def patched_call(flatgrid, **patch_kwargs):
|
|
1262
|
-
patch_args = {k: v for k, v in patch_kwargs.items() if v is not None}
|
|
1263
|
-
patch_args["rx_apo"] = ops.swapaxes(patch_args["rx_apo"], 0, 1)
|
|
1264
|
-
out = super(PatchedGrid, self).call(flatgrid=flatgrid, **patch_args, **inputs)
|
|
1352
|
+
def patched_call(flatgrid, flat_pfield):
|
|
1353
|
+
out = super(PatchedGrid, self).call(
|
|
1354
|
+
flatgrid=flatgrid, flat_pfield=flat_pfield, **inputs
|
|
1355
|
+
)
|
|
1265
1356
|
return out[self.output_key]
|
|
1266
1357
|
|
|
1267
|
-
out =
|
|
1358
|
+
out = vmap(
|
|
1268
1359
|
patched_call,
|
|
1269
|
-
|
|
1270
|
-
|
|
1271
|
-
|
|
1272
|
-
|
|
1273
|
-
|
|
1360
|
+
chunks=self.num_patches,
|
|
1361
|
+
fn_supports_batch=True,
|
|
1362
|
+
disable_jit=not bool(self.jit_options),
|
|
1363
|
+
)(flatgrid, flat_pfield)
|
|
1364
|
+
|
|
1274
1365
|
return ops.reshape(out, (grid_size_z, grid_size_x, *ops.shape(out)[1:]))
|
|
1275
1366
|
|
|
1276
1367
|
def jittable_call(self, **inputs):
|
|
@@ -1309,7 +1400,7 @@ class Identity(Operation):
|
|
|
1309
1400
|
|
|
1310
1401
|
def call(self, **kwargs) -> Dict:
|
|
1311
1402
|
"""Returns the input as is."""
|
|
1312
|
-
return
|
|
1403
|
+
return {}
|
|
1313
1404
|
|
|
1314
1405
|
|
|
1315
1406
|
@ops_registry("merge")
|
|
@@ -1399,6 +1490,7 @@ class Simulate(Operation):
|
|
|
1399
1490
|
def __init__(self, **kwargs):
|
|
1400
1491
|
super().__init__(
|
|
1401
1492
|
output_data_type=DataTypes.RAW_DATA,
|
|
1493
|
+
additional_output_keys=["n_ch"],
|
|
1402
1494
|
**kwargs,
|
|
1403
1495
|
)
|
|
1404
1496
|
|
|
@@ -1451,18 +1543,16 @@ class TOFCorrection(Operation):
|
|
|
1451
1543
|
STATIC_PARAMS = [
|
|
1452
1544
|
"f_number",
|
|
1453
1545
|
"apply_lens_correction",
|
|
1454
|
-
"apply_phase_rotation",
|
|
1455
1546
|
"grid_size_x",
|
|
1456
1547
|
"grid_size_z",
|
|
1457
1548
|
]
|
|
1458
1549
|
|
|
1459
|
-
def __init__(self,
|
|
1550
|
+
def __init__(self, **kwargs):
|
|
1460
1551
|
super().__init__(
|
|
1461
1552
|
input_data_type=DataTypes.RAW_DATA,
|
|
1462
1553
|
output_data_type=DataTypes.ALIGNED_DATA,
|
|
1463
1554
|
**kwargs,
|
|
1464
1555
|
)
|
|
1465
|
-
self.apply_phase_rotation = apply_phase_rotation
|
|
1466
1556
|
|
|
1467
1557
|
def call(
|
|
1468
1558
|
self,
|
|
@@ -1477,6 +1567,8 @@ class TOFCorrection(Operation):
|
|
|
1477
1567
|
tx_apodizations,
|
|
1478
1568
|
initial_times,
|
|
1479
1569
|
probe_geometry,
|
|
1570
|
+
t_peak,
|
|
1571
|
+
tx_waveform_indices,
|
|
1480
1572
|
apply_lens_correction=None,
|
|
1481
1573
|
lens_thickness=None,
|
|
1482
1574
|
lens_sound_speed=None,
|
|
@@ -1497,6 +1589,9 @@ class TOFCorrection(Operation):
|
|
|
1497
1589
|
tx_apodizations (ops.Tensor): Transmit apodizations
|
|
1498
1590
|
initial_times (ops.Tensor): Initial times
|
|
1499
1591
|
probe_geometry (ops.Tensor): Probe element positions
|
|
1592
|
+
t_peak (float): Time to peak of the transmit pulse
|
|
1593
|
+
tx_waveform_indices (ops.Tensor): Index of the transmit waveform for each
|
|
1594
|
+
transmit. (All zero if there is only one waveform)
|
|
1500
1595
|
apply_lens_correction (bool): Whether to apply lens correction
|
|
1501
1596
|
lens_thickness (float): Lens thickness
|
|
1502
1597
|
lens_sound_speed (float): Sound speed in the lens
|
|
@@ -1507,29 +1602,30 @@ class TOFCorrection(Operation):
|
|
|
1507
1602
|
|
|
1508
1603
|
raw_data = kwargs[self.key]
|
|
1509
1604
|
|
|
1510
|
-
|
|
1605
|
+
tof_kwargs = {
|
|
1511
1606
|
"flatgrid": flatgrid,
|
|
1512
|
-
"sound_speed": sound_speed,
|
|
1513
|
-
"angles": polar_angles,
|
|
1514
|
-
"focus_distances": focus_distances,
|
|
1515
|
-
"sampling_frequency": sampling_frequency,
|
|
1516
|
-
"fnum": f_number,
|
|
1517
|
-
"apply_phase_rotation": self.apply_phase_rotation,
|
|
1518
|
-
"demodulation_frequency": demodulation_frequency,
|
|
1519
1607
|
"t0_delays": t0_delays,
|
|
1520
1608
|
"tx_apodizations": tx_apodizations,
|
|
1521
|
-
"
|
|
1609
|
+
"sound_speed": sound_speed,
|
|
1522
1610
|
"probe_geometry": probe_geometry,
|
|
1611
|
+
"initial_times": initial_times,
|
|
1612
|
+
"sampling_frequency": sampling_frequency,
|
|
1613
|
+
"demodulation_frequency": demodulation_frequency,
|
|
1614
|
+
"f_number": f_number,
|
|
1615
|
+
"polar_angles": polar_angles,
|
|
1616
|
+
"focus_distances": focus_distances,
|
|
1617
|
+
"t_peak": t_peak,
|
|
1618
|
+
"tx_waveform_indices": tx_waveform_indices,
|
|
1523
1619
|
"apply_lens_correction": apply_lens_correction,
|
|
1524
1620
|
"lens_thickness": lens_thickness,
|
|
1525
1621
|
"lens_sound_speed": lens_sound_speed,
|
|
1526
1622
|
}
|
|
1527
1623
|
|
|
1528
1624
|
if not self.with_batch_dim:
|
|
1529
|
-
tof_corrected = tof_correction(raw_data, **
|
|
1625
|
+
tof_corrected = tof_correction(raw_data, **tof_kwargs)
|
|
1530
1626
|
else:
|
|
1531
1627
|
tof_corrected = ops.map(
|
|
1532
|
-
lambda data: tof_correction(data, **
|
|
1628
|
+
lambda data: tof_correction(data, **tof_kwargs),
|
|
1533
1629
|
raw_data,
|
|
1534
1630
|
)
|
|
1535
1631
|
|
|
@@ -1587,44 +1683,35 @@ class DelayAndSum(Operation):
|
|
|
1587
1683
|
**kwargs,
|
|
1588
1684
|
):
|
|
1589
1685
|
super().__init__(
|
|
1590
|
-
input_data_type=
|
|
1686
|
+
input_data_type=DataTypes.ALIGNED_DATA,
|
|
1591
1687
|
output_data_type=DataTypes.BEAMFORMED_DATA,
|
|
1592
1688
|
**kwargs,
|
|
1593
1689
|
)
|
|
1594
1690
|
self.reshape_grid = reshape_grid
|
|
1595
1691
|
|
|
1596
|
-
def process_image(self, data
|
|
1692
|
+
def process_image(self, data):
|
|
1597
1693
|
"""Performs DAS beamforming on tof-corrected input.
|
|
1598
1694
|
|
|
1599
1695
|
Args:
|
|
1600
1696
|
data (ops.Tensor): The TOF corrected input of shape `(n_tx, n_pix, n_el, n_ch)`
|
|
1601
|
-
rx_apo (ops.Tensor): Receive apodization window of shape `(n_tx, n_pix, n_el, n_ch)`.
|
|
1602
1697
|
|
|
1603
1698
|
Returns:
|
|
1604
1699
|
ops.Tensor: The beamformed data of shape `(n_pix, n_ch)`
|
|
1605
1700
|
"""
|
|
1606
1701
|
# Sum over the channels, i.e. DAS
|
|
1607
|
-
data = ops.sum(
|
|
1702
|
+
data = ops.sum(data, -2)
|
|
1608
1703
|
|
|
1609
1704
|
# Sum over transmits, i.e. Compounding
|
|
1610
1705
|
data = ops.sum(data, 0)
|
|
1611
1706
|
|
|
1612
1707
|
return data
|
|
1613
1708
|
|
|
1614
|
-
def call(
|
|
1615
|
-
self,
|
|
1616
|
-
rx_apo=None,
|
|
1617
|
-
grid=None,
|
|
1618
|
-
**kwargs,
|
|
1619
|
-
):
|
|
1709
|
+
def call(self, grid=None, **kwargs):
|
|
1620
1710
|
"""Performs DAS beamforming on tof-corrected input.
|
|
1621
1711
|
|
|
1622
1712
|
Args:
|
|
1623
1713
|
tof_corrected_data (ops.Tensor): The TOF corrected input of shape
|
|
1624
1714
|
`(n_tx, grid_size_z*grid_size_x, n_el, n_ch)` with optional batch dimension.
|
|
1625
|
-
rx_apo (ops.Tensor): Receive apodization window
|
|
1626
|
-
of shape `(n_tx, grid_size_z*grid_size_x, n_el)`
|
|
1627
|
-
with optional batch dimension. Defaults to 1.0.
|
|
1628
1715
|
|
|
1629
1716
|
Returns:
|
|
1630
1717
|
dict: Dictionary containing beamformed_data
|
|
@@ -1634,17 +1721,11 @@ class DelayAndSum(Operation):
|
|
|
1634
1721
|
"""
|
|
1635
1722
|
data = kwargs[self.key]
|
|
1636
1723
|
|
|
1637
|
-
if rx_apo is None:
|
|
1638
|
-
rx_apo = ops.ones(1, dtype=ops.dtype(data))
|
|
1639
|
-
rx_apo = ops.broadcast_to(rx_apo[..., None], data.shape)
|
|
1640
|
-
|
|
1641
1724
|
if not self.with_batch_dim:
|
|
1642
|
-
beamformed_data = self.process_image(data
|
|
1725
|
+
beamformed_data = self.process_image(data)
|
|
1643
1726
|
else:
|
|
1644
1727
|
# Apply process_image to each item in the batch
|
|
1645
|
-
beamformed_data =
|
|
1646
|
-
lambda data, rx_apo: self.process_image(data, rx_apo), data, rx_apo=rx_apo
|
|
1647
|
-
)
|
|
1728
|
+
beamformed_data = ops.map(self.process_image, data)
|
|
1648
1729
|
|
|
1649
1730
|
if self.reshape_grid:
|
|
1650
1731
|
beamformed_data = reshape_axis(
|
|
@@ -1654,6 +1735,46 @@ class DelayAndSum(Operation):
|
|
|
1654
1735
|
return {self.output_key: beamformed_data}
|
|
1655
1736
|
|
|
1656
1737
|
|
|
1738
|
+
def envelope_detect(data, axis=-3):
|
|
1739
|
+
"""Envelope detection of RF signals.
|
|
1740
|
+
|
|
1741
|
+
If the input data is real, it first applies the Hilbert transform along the specified axis
|
|
1742
|
+
and then computes the magnitude of the resulting complex signal.
|
|
1743
|
+
If the input data is complex, it computes the magnitude directly.
|
|
1744
|
+
|
|
1745
|
+
Args:
|
|
1746
|
+
- data (Tensor): The beamformed data of shape (..., grid_size_z, grid_size_x, n_ch).
|
|
1747
|
+
- axis (int): Axis along which to apply the Hilbert transform. Defaults to -3.
|
|
1748
|
+
|
|
1749
|
+
Returns:
|
|
1750
|
+
- envelope_data (Tensor): The envelope detected data
|
|
1751
|
+
of shape (..., grid_size_z, grid_size_x).
|
|
1752
|
+
"""
|
|
1753
|
+
if data.shape[-1] == 2:
|
|
1754
|
+
data = channels_to_complex(data)
|
|
1755
|
+
else:
|
|
1756
|
+
n_ax = ops.shape(data)[axis]
|
|
1757
|
+
n_ax_float = ops.cast(n_ax, "float32")
|
|
1758
|
+
|
|
1759
|
+
# Calculate next power of 2: M = 2^ceil(log2(n_ax))
|
|
1760
|
+
# see https://github.com/tue-bmd/zea/discussions/147
|
|
1761
|
+
log2_n_ax = ops.log2(n_ax_float)
|
|
1762
|
+
M = ops.cast(2 ** ops.ceil(log2_n_ax), "int32")
|
|
1763
|
+
|
|
1764
|
+
data = hilbert(data, N=M, axis=axis)
|
|
1765
|
+
indices = ops.arange(n_ax)
|
|
1766
|
+
|
|
1767
|
+
data = ops.take(data, indices, axis=axis)
|
|
1768
|
+
data = ops.squeeze(data, axis=-1)
|
|
1769
|
+
|
|
1770
|
+
# data = ops.abs(data)
|
|
1771
|
+
real = ops.real(data)
|
|
1772
|
+
imag = ops.imag(data)
|
|
1773
|
+
data = ops.sqrt(real**2 + imag**2)
|
|
1774
|
+
data = ops.cast(data, "float32")
|
|
1775
|
+
return data
|
|
1776
|
+
|
|
1777
|
+
|
|
1657
1778
|
@ops_registry("envelope_detect")
|
|
1658
1779
|
class EnvelopeDetect(Operation):
|
|
1659
1780
|
"""Envelope detection of RF signals."""
|
|
@@ -1680,23 +1801,7 @@ class EnvelopeDetect(Operation):
|
|
|
1680
1801
|
"""
|
|
1681
1802
|
data = kwargs[self.key]
|
|
1682
1803
|
|
|
1683
|
-
|
|
1684
|
-
data = channels_to_complex(data)
|
|
1685
|
-
else:
|
|
1686
|
-
n_ax = data.shape[self.axis]
|
|
1687
|
-
M = 2 ** int(np.ceil(np.log2(n_ax)))
|
|
1688
|
-
# data = scipy.signal.hilbert(data, N=M, axis=self.axis)
|
|
1689
|
-
data = hilbert(data, N=M, axis=self.axis)
|
|
1690
|
-
indices = ops.arange(n_ax)
|
|
1691
|
-
|
|
1692
|
-
data = ops.take(data, indices, axis=self.axis)
|
|
1693
|
-
data = ops.squeeze(data, axis=-1)
|
|
1694
|
-
|
|
1695
|
-
# data = ops.abs(data)
|
|
1696
|
-
real = ops.real(data)
|
|
1697
|
-
imag = ops.imag(data)
|
|
1698
|
-
data = ops.sqrt(real**2 + imag**2)
|
|
1699
|
-
data = ops.cast(data, "float32")
|
|
1804
|
+
data = envelope_detect(data, axis=self.axis)
|
|
1700
1805
|
|
|
1701
1806
|
return {self.output_key: data}
|
|
1702
1807
|
|
|
@@ -1734,19 +1839,29 @@ class UpMix(Operation):
|
|
|
1734
1839
|
return {self.output_key: data}
|
|
1735
1840
|
|
|
1736
1841
|
|
|
1842
|
+
def log_compress(data, eps=1e-16):
|
|
1843
|
+
"""Apply logarithmic compression to data."""
|
|
1844
|
+
eps = ops.convert_to_tensor(eps, dtype=data.dtype)
|
|
1845
|
+
data = ops.where(data == 0, eps, data) # Avoid log(0)
|
|
1846
|
+
return 20 * keras.ops.log10(data)
|
|
1847
|
+
|
|
1848
|
+
|
|
1737
1849
|
@ops_registry("log_compress")
|
|
1738
1850
|
class LogCompress(Operation):
|
|
1739
1851
|
"""Logarithmic compression of data."""
|
|
1740
1852
|
|
|
1741
|
-
def __init__(
|
|
1742
|
-
|
|
1743
|
-
|
|
1744
|
-
|
|
1853
|
+
def __init__(self, clip: bool = True, **kwargs):
|
|
1854
|
+
"""Initialize the LogCompress operation.
|
|
1855
|
+
|
|
1856
|
+
Args:
|
|
1857
|
+
clip (bool): Whether to clip the output to a dynamic range. Defaults to True.
|
|
1858
|
+
"""
|
|
1745
1859
|
super().__init__(
|
|
1746
1860
|
input_data_type=DataTypes.ENVELOPE_DATA,
|
|
1747
1861
|
output_data_type=DataTypes.IMAGE,
|
|
1748
1862
|
**kwargs,
|
|
1749
1863
|
)
|
|
1864
|
+
self.clip = clip
|
|
1750
1865
|
|
|
1751
1866
|
def call(self, dynamic_range=None, **kwargs):
|
|
1752
1867
|
"""Apply logarithmic compression to data.
|
|
@@ -1763,20 +1878,43 @@ class LogCompress(Operation):
|
|
|
1763
1878
|
dynamic_range = ops.array(DEFAULT_DYNAMIC_RANGE)
|
|
1764
1879
|
dynamic_range = ops.cast(dynamic_range, data.dtype)
|
|
1765
1880
|
|
|
1766
|
-
|
|
1767
|
-
|
|
1768
|
-
|
|
1769
|
-
compressed_data = ops.clip(compressed_data, dynamic_range[0], dynamic_range[1])
|
|
1881
|
+
compressed_data = log_compress(data)
|
|
1882
|
+
if self.clip:
|
|
1883
|
+
compressed_data = ops.clip(compressed_data, dynamic_range[0], dynamic_range[1])
|
|
1770
1884
|
|
|
1771
1885
|
return {self.output_key: compressed_data}
|
|
1772
1886
|
|
|
1773
1887
|
|
|
1888
|
+
def normalize(data, output_range, input_range=None):
|
|
1889
|
+
"""Normalize data to a given range.
|
|
1890
|
+
|
|
1891
|
+
Equivalent to `translate` with clipping.
|
|
1892
|
+
|
|
1893
|
+
Args:
|
|
1894
|
+
data (ops.Tensor): Input data to normalize.
|
|
1895
|
+
output_range (tuple): Range to which data should be mapped, e.g., (0, 1).
|
|
1896
|
+
input_range (tuple, optional): Range of input data.
|
|
1897
|
+
If None, the range will be computed from the data.
|
|
1898
|
+
Defaults to None.
|
|
1899
|
+
"""
|
|
1900
|
+
if input_range is None:
|
|
1901
|
+
input_range = (None, None)
|
|
1902
|
+
minval, maxval = input_range
|
|
1903
|
+
if minval is None:
|
|
1904
|
+
minval = ops.min(data)
|
|
1905
|
+
if maxval is None:
|
|
1906
|
+
maxval = ops.max(data)
|
|
1907
|
+
data = ops.clip(data, minval, maxval)
|
|
1908
|
+
normalized_data = translate(data, (minval, maxval), output_range)
|
|
1909
|
+
return normalized_data
|
|
1910
|
+
|
|
1911
|
+
|
|
1774
1912
|
@ops_registry("normalize")
|
|
1775
1913
|
class Normalize(Operation):
|
|
1776
1914
|
"""Normalize data to a given range."""
|
|
1777
1915
|
|
|
1778
1916
|
def __init__(self, output_range=None, input_range=None, **kwargs):
|
|
1779
|
-
super().__init__(**kwargs)
|
|
1917
|
+
super().__init__(additional_output_keys=["minval", "maxval"], **kwargs)
|
|
1780
1918
|
if output_range is None:
|
|
1781
1919
|
output_range = (0, 1)
|
|
1782
1920
|
self.output_range = self.to_float32(output_range)
|
|
@@ -1821,11 +1959,9 @@ class Normalize(Operation):
|
|
|
1821
1959
|
if maxval is None:
|
|
1822
1960
|
maxval = ops.max(data)
|
|
1823
1961
|
|
|
1824
|
-
|
|
1825
|
-
|
|
1826
|
-
|
|
1827
|
-
# Map the data to the output range
|
|
1828
|
-
normalized_data = translate(data, (minval, maxval), self.output_range)
|
|
1962
|
+
normalized_data = normalize(
|
|
1963
|
+
data, output_range=self.output_range, input_range=(minval, maxval)
|
|
1964
|
+
)
|
|
1829
1965
|
|
|
1830
1966
|
return {self.output_key: normalized_data, "minval": minval, "maxval": maxval}
|
|
1831
1967
|
|
|
@@ -1855,6 +1991,18 @@ class ScanConvert(Operation):
|
|
|
1855
1991
|
input_data_type=DataTypes.IMAGE,
|
|
1856
1992
|
output_data_type=DataTypes.IMAGE_SC,
|
|
1857
1993
|
jittable=jittable,
|
|
1994
|
+
additional_output_keys=[
|
|
1995
|
+
"resolution",
|
|
1996
|
+
"x_lim",
|
|
1997
|
+
"y_lim",
|
|
1998
|
+
"z_lim",
|
|
1999
|
+
"rho_range",
|
|
2000
|
+
"theta_range",
|
|
2001
|
+
"phi_range",
|
|
2002
|
+
"d_rho",
|
|
2003
|
+
"d_theta",
|
|
2004
|
+
"d_phi",
|
|
2005
|
+
],
|
|
1858
2006
|
**kwargs,
|
|
1859
2007
|
)
|
|
1860
2008
|
self.order = order
|
|
@@ -2081,6 +2229,7 @@ class Demodulate(Operation):
|
|
|
2081
2229
|
input_data_type=DataTypes.RAW_DATA,
|
|
2082
2230
|
output_data_type=DataTypes.RAW_DATA,
|
|
2083
2231
|
jittable=True,
|
|
2232
|
+
additional_output_keys=["demodulation_frequency", "center_frequency", "n_ch"],
|
|
2084
2233
|
**kwargs,
|
|
2085
2234
|
)
|
|
2086
2235
|
self.axis = axis
|
|
@@ -2145,7 +2294,7 @@ class Lambda(Operation):
|
|
|
2145
2294
|
|
|
2146
2295
|
|
|
2147
2296
|
@ops_registry("pad")
|
|
2148
|
-
class Pad(Operation,
|
|
2297
|
+
class Pad(Operation, DataLayer):
|
|
2149
2298
|
"""Pad layer for padding tensors to a specified shape."""
|
|
2150
2299
|
|
|
2151
2300
|
def __init__(
|
|
@@ -2330,6 +2479,7 @@ class Downsample(Operation):
|
|
|
2330
2479
|
|
|
2331
2480
|
def __init__(self, factor: int = 1, phase: int = 0, axis: int = -3, **kwargs):
|
|
2332
2481
|
super().__init__(
|
|
2482
|
+
additional_output_keys=["sampling_frequency", "n_ax"],
|
|
2333
2483
|
**kwargs,
|
|
2334
2484
|
)
|
|
2335
2485
|
self.factor = factor
|
|
@@ -3095,6 +3245,50 @@ def demodulate(data, center_frequency, sampling_frequency, axis=-3):
|
|
|
3095
3245
|
iq_data_signal_complex = analytical_signal * ops.exp(phasor_exponent)
|
|
3096
3246
|
|
|
3097
3247
|
# Split the complex signal into two channels
|
|
3098
|
-
iq_data_two_channel = complex_to_channels(iq_data_signal_complex
|
|
3248
|
+
iq_data_two_channel = complex_to_channels(ops.squeeze(iq_data_signal_complex, axis=-1))
|
|
3099
3249
|
|
|
3100
3250
|
return iq_data_two_channel
|
|
3251
|
+
|
|
3252
|
+
|
|
3253
|
+
def compute_time_to_peak_stack(waveforms, center_frequencies, waveform_sampling_frequency=250e6):
|
|
3254
|
+
"""Compute the time of the peak of each waveform in a stack of waveforms.
|
|
3255
|
+
|
|
3256
|
+
Args:
|
|
3257
|
+
waveforms (ndarray): The waveforms of shape (n_waveforms, n_samples).
|
|
3258
|
+
center_frequencies (ndarray): The center frequencies of the waveforms in Hz of shape
|
|
3259
|
+
(n_waveforms,) or a scalar if all waveforms have the same center frequency.
|
|
3260
|
+
waveform_sampling_frequency (float): The sampling frequency of the waveforms in Hz.
|
|
3261
|
+
|
|
3262
|
+
Returns:
|
|
3263
|
+
ndarray: The time to peak for each waveform in seconds.
|
|
3264
|
+
"""
|
|
3265
|
+
t_peak = []
|
|
3266
|
+
center_frequencies = center_frequencies * ops.ones((waveforms.shape[0],))
|
|
3267
|
+
for waveform, center_frequency in zip(waveforms, center_frequencies):
|
|
3268
|
+
t_peak.append(compute_time_to_peak(waveform, center_frequency, waveform_sampling_frequency))
|
|
3269
|
+
return ops.stack(t_peak)
|
|
3270
|
+
|
|
3271
|
+
|
|
3272
|
+
def compute_time_to_peak(waveform, center_frequency, waveform_sampling_frequency=250e6):
|
|
3273
|
+
"""Compute the time of the peak of the waveform.
|
|
3274
|
+
|
|
3275
|
+
Args:
|
|
3276
|
+
waveform (ndarray): The waveform of shape (n_samples).
|
|
3277
|
+
center_frequency (float): The center frequency of the waveform in Hz.
|
|
3278
|
+
waveform_sampling_frequency (float): The sampling frequency of the waveform in Hz.
|
|
3279
|
+
|
|
3280
|
+
Returns:
|
|
3281
|
+
float: The time to peak for the waveform in seconds.
|
|
3282
|
+
"""
|
|
3283
|
+
n_samples = waveform.shape[0]
|
|
3284
|
+
if n_samples == 0:
|
|
3285
|
+
raise ValueError("Waveform has zero samples.")
|
|
3286
|
+
|
|
3287
|
+
waveforms_iq_complex_channels = demodulate(
|
|
3288
|
+
waveform[..., None], center_frequency, waveform_sampling_frequency, axis=-1
|
|
3289
|
+
)
|
|
3290
|
+
waveforms_iq_complex = channels_to_complex(waveforms_iq_complex_channels)
|
|
3291
|
+
envelope = ops.abs(waveforms_iq_complex)
|
|
3292
|
+
peak_idx = ops.argmax(envelope, axis=-1)
|
|
3293
|
+
t_peak = ops.cast(peak_idx, dtype="float32") / waveform_sampling_frequency
|
|
3294
|
+
return t_peak
|