zea 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +3 -3
- zea/agent/masks.py +2 -2
- zea/agent/selection.py +3 -3
- zea/backend/__init__.py +1 -1
- zea/backend/tensorflow/dataloader.py +1 -5
- zea/beamform/beamformer.py +4 -2
- zea/beamform/pfield.py +2 -2
- zea/beamform/pixelgrid.py +1 -1
- zea/data/__init__.py +0 -9
- zea/data/augmentations.py +222 -29
- zea/data/convert/__init__.py +1 -6
- zea/data/convert/__main__.py +164 -0
- zea/data/convert/camus.py +106 -40
- zea/data/convert/echonet.py +184 -83
- zea/data/convert/echonetlvh/README.md +2 -3
- zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
- zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
- zea/data/convert/echonetlvh/precompute_crop.py +43 -64
- zea/data/convert/picmus.py +37 -40
- zea/data/convert/utils.py +86 -0
- zea/data/convert/verasonics.py +1247 -0
- zea/data/data_format.py +124 -6
- zea/data/dataloader.py +12 -7
- zea/data/datasets.py +109 -70
- zea/data/file.py +119 -82
- zea/data/file_operations.py +496 -0
- zea/data/preset_utils.py +2 -2
- zea/display.py +8 -9
- zea/doppler.py +5 -5
- zea/func/__init__.py +109 -0
- zea/{tensor_ops.py → func/tensor.py} +113 -69
- zea/func/ultrasound.py +500 -0
- zea/internal/_generate_keras_ops.py +5 -5
- zea/internal/checks.py +6 -12
- zea/internal/operators.py +4 -0
- zea/io_lib.py +108 -160
- zea/metrics.py +6 -5
- zea/models/__init__.py +1 -1
- zea/models/diffusion.py +63 -12
- zea/models/echonetlvh.py +1 -1
- zea/models/gmm.py +1 -1
- zea/models/lv_segmentation.py +2 -0
- zea/ops/__init__.py +188 -0
- zea/ops/base.py +442 -0
- zea/{keras_ops.py → ops/keras_ops.py} +2 -2
- zea/ops/pipeline.py +1472 -0
- zea/ops/tensor.py +356 -0
- zea/ops/ultrasound.py +890 -0
- zea/probes.py +2 -10
- zea/scan.py +35 -28
- zea/tools/fit_scan_cone.py +90 -160
- zea/tools/selection_tool.py +1 -1
- zea/tracking/__init__.py +16 -0
- zea/tracking/base.py +94 -0
- zea/tracking/lucas_kanade.py +474 -0
- zea/tracking/segmentation.py +110 -0
- zea/utils.py +11 -2
- {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/METADATA +5 -1
- {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/RECORD +62 -48
- zea/data/convert/matlab.py +0 -1237
- zea/ops.py +0 -3294
- {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/WHEEL +0 -0
- {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/entry_points.txt +0 -0
- {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/licenses/LICENSE +0 -0
zea/ops.py
DELETED
|
@@ -1,3294 +0,0 @@
|
|
|
1
|
-
"""Operations and Pipelines for ultrasound data processing.
|
|
2
|
-
|
|
3
|
-
This module contains two important classes, :class:`Operation` and :class:`Pipeline`,
|
|
4
|
-
which are used to process ultrasound data. A pipeline is a sequence of operations
|
|
5
|
-
that are applied to the data in a specific order.
|
|
6
|
-
|
|
7
|
-
Stand-alone manual usage
|
|
8
|
-
------------------------
|
|
9
|
-
|
|
10
|
-
Operations can be run on their own:
|
|
11
|
-
|
|
12
|
-
Examples
|
|
13
|
-
^^^^^^^^
|
|
14
|
-
.. doctest::
|
|
15
|
-
|
|
16
|
-
>>> import numpy as np
|
|
17
|
-
>>> from zea.ops import EnvelopeDetect
|
|
18
|
-
>>> data = np.random.randn(2000, 128, 1)
|
|
19
|
-
>>> # static arguments are passed in the constructor
|
|
20
|
-
>>> envelope_detect = EnvelopeDetect(axis=-1)
|
|
21
|
-
>>> # other parameters can be passed here along with the data
|
|
22
|
-
>>> envelope_data = envelope_detect(data=data)
|
|
23
|
-
|
|
24
|
-
Using a pipeline
|
|
25
|
-
----------------
|
|
26
|
-
|
|
27
|
-
You can initialize with a default pipeline or create your own custom pipeline.
|
|
28
|
-
|
|
29
|
-
.. doctest::
|
|
30
|
-
|
|
31
|
-
>>> from zea.ops import Pipeline, EnvelopeDetect, Normalize, LogCompress
|
|
32
|
-
>>> pipeline = Pipeline.from_default()
|
|
33
|
-
|
|
34
|
-
>>> operations = [
|
|
35
|
-
... EnvelopeDetect(),
|
|
36
|
-
... Normalize(),
|
|
37
|
-
... LogCompress(),
|
|
38
|
-
... ]
|
|
39
|
-
>>> pipeline_custom = Pipeline(operations)
|
|
40
|
-
|
|
41
|
-
One can also load a pipeline from a config, a JSON string, or a YAML file:
|
|
42
|
-
|
|
43
|
-
.. doctest::
|
|
44
|
-
|
|
45
|
-
>>> from zea import Pipeline
|
|
46
|
-
|
|
47
|
-
>>> # From JSON string
|
|
48
|
-
>>> json_string = '{"operations": ["identity"]}'
|
|
49
|
-
>>> pipeline = Pipeline.from_json(json_string)
|
|
50
|
-
|
|
51
|
-
>>> # from yaml file
|
|
52
|
-
>>> import yaml
|
|
53
|
-
>>> from zea import Config
|
|
54
|
-
>>> # Create a sample pipeline YAML file
|
|
55
|
-
>>> pipeline_dict = {
|
|
56
|
-
... "operations": [
|
|
57
|
-
... {"name": "identity"},
|
|
58
|
-
... ]
|
|
59
|
-
... }
|
|
60
|
-
>>> with open("pipeline.yaml", "w") as f:
|
|
61
|
-
... yaml.dump(pipeline_dict, f)
|
|
62
|
-
>>> yaml_file = "pipeline.yaml"
|
|
63
|
-
>>> pipeline = Pipeline.from_yaml(yaml_file)
|
|
64
|
-
|
|
65
|
-
.. testcleanup::
|
|
66
|
-
|
|
67
|
-
import os
|
|
68
|
-
|
|
69
|
-
os.remove("pipeline.yaml")
|
|
70
|
-
|
|
71
|
-
Example of a yaml file:
|
|
72
|
-
|
|
73
|
-
.. code-block:: yaml
|
|
74
|
-
|
|
75
|
-
pipeline:
|
|
76
|
-
operations:
|
|
77
|
-
- name: demodulate
|
|
78
|
-
- name: "patched_grid"
|
|
79
|
-
params:
|
|
80
|
-
operations:
|
|
81
|
-
- name: tof_correction
|
|
82
|
-
- name: pfield_weighting
|
|
83
|
-
- name: delay_and_sum
|
|
84
|
-
num_patches: 100
|
|
85
|
-
- name: envelope_detect
|
|
86
|
-
- name: normalize
|
|
87
|
-
- name: log_compress
|
|
88
|
-
|
|
89
|
-
"""
|
|
90
|
-
|
|
91
|
-
import copy
|
|
92
|
-
import hashlib
|
|
93
|
-
import inspect
|
|
94
|
-
import json
|
|
95
|
-
from functools import partial
|
|
96
|
-
from typing import Any, Dict, List, Union
|
|
97
|
-
|
|
98
|
-
import keras
|
|
99
|
-
import numpy as np
|
|
100
|
-
import scipy
|
|
101
|
-
import yaml
|
|
102
|
-
from keras import ops
|
|
103
|
-
from keras.src.layers.preprocessing.data_layer import DataLayer
|
|
104
|
-
|
|
105
|
-
from zea import log
|
|
106
|
-
from zea.backend import jit
|
|
107
|
-
from zea.beamform.beamformer import tof_correction
|
|
108
|
-
from zea.config import Config
|
|
109
|
-
from zea.display import scan_convert
|
|
110
|
-
from zea.internal.checks import _assert_keys_and_axes
|
|
111
|
-
from zea.internal.core import (
|
|
112
|
-
DEFAULT_DYNAMIC_RANGE,
|
|
113
|
-
DataTypes,
|
|
114
|
-
ZEADecoderJSON,
|
|
115
|
-
ZEAEncoderJSON,
|
|
116
|
-
dict_to_tensor,
|
|
117
|
-
)
|
|
118
|
-
from zea.internal.core import Object as ZEAObject
|
|
119
|
-
from zea.internal.registry import ops_registry
|
|
120
|
-
from zea.probes import Probe
|
|
121
|
-
from zea.scan import Scan
|
|
122
|
-
from zea.simulator import simulate_rf
|
|
123
|
-
from zea.tensor_ops import resample, reshape_axis, translate, vmap
|
|
124
|
-
from zea.utils import (
|
|
125
|
-
FunctionTimer,
|
|
126
|
-
deep_compare,
|
|
127
|
-
map_negative_indices,
|
|
128
|
-
)
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
def get_ops(ops_name):
    """Look up a registered operation by name.

    Args:
        ops_name: The name under which the operation was registered in ``ops_registry``.

    Returns:
        The registry entry stored for ``ops_name``.
    """
    registry = ops_registry
    return registry[ops_name]
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
class Operation(keras.Operation):
    """
    A base abstract class for operations in the pipeline with caching functionality.

    Subclasses implement :meth:`call`, which takes keyword arguments and must return a
    dictionary. Calling an instance (``op(**kwargs)``) merges cached inputs with the given
    keyword arguments, forwards the ones accepted by ``call`` to the (optionally
    JIT-compiled) implementation, and returns the inputs updated with the outputs.
    """

    def __init__(
        self,
        input_data_type: Union[DataTypes, None] = None,
        output_data_type: Union[DataTypes, None] = None,
        key: Union[str, None] = "data",
        output_key: Union[str, None] = None,
        cache_inputs: Union[bool, List[str]] = False,
        cache_outputs: Union[bool, List[str]] = False,
        jit_compile: bool = True,
        with_batch_dim: bool = True,
        jit_kwargs: dict | None = None,
        jittable: bool = True,
        additional_output_keys: Union[List[str], None] = None,
        **kwargs,
    ):
        """
        Args:
            input_data_type (DataTypes): The data type of the input data
            output_data_type (DataTypes): The data type of the output data
            key: The key for the input data (operation will operate on this key)
                Defaults to "data".
            output_key: The key for the output data (operation will output to this key)
                Defaults to the same as the input key. If you want to store intermediate
                results, you can set this to a different key. But make sure to update the
                input key of the next operation to match the output key of this operation.
            cache_inputs: A list of input keys to cache or True to cache all inputs.
                Within this class the value is only stored (and reported by
                :meth:`get_dict`); inputs are cached through :meth:`set_input_cache`.
            cache_outputs: A list of output keys to cache or True to cache all outputs
            jit_compile: Whether to JIT compile the 'call' method for faster execution
            with_batch_dim: Whether operations should expect a batch dimension in the input
            jit_kwargs: Additional keyword arguments for the JIT compiler. The dictionary
                is copied, so the caller's object is never modified.
            jittable: Whether the operation can be JIT compiled
            additional_output_keys: A list of additional output keys produced by the operation.
                These are used to track if all keys are available for downstream operations.
                If the operation has a conditional output, it is best to add all possible
                output keys here.
            **kwargs: Forwarded to ``keras.Operation``.
        """
        super().__init__(**kwargs)

        self.input_data_type = input_data_type
        self.output_data_type = output_data_type

        self.key = key  # Key for input data
        self.output_key = output_key  # Key for output data
        if self.output_key is None:
            self.output_key = self.key
        self.additional_output_keys = additional_output_keys or []

        self.inputs = []  # Source(s) of input data (name of a previous operation)
        self.allow_multiple_inputs = False  # Only single input allowed by default

        self.cache_inputs = cache_inputs
        self.cache_outputs = cache_outputs

        # Initialize input and output caches
        self._input_cache = {}
        self._output_cache = {}

        # Obtain the input signature of the `call` method
        self._trace_signatures()

        # Work on a private copy: adding `static_argnames` must not mutate a dict owned by
        # the caller, which may be shared between several operations.
        jit_kwargs = dict(jit_kwargs) if jit_kwargs is not None else {}

        if keras.backend.backend() == "jax" and self.static_params:
            jit_kwargs["static_argnames"] = self.static_params

        self.jit_kwargs = jit_kwargs

        self.with_batch_dim = with_batch_dim
        self._jittable = jittable

        # Set the jit compilation flag and compile the `call` method
        # Set zea logger level to suppress warnings regarding
        # torch not being able to compile the function
        with log.set_level("ERROR"):
            self.set_jit(jit_compile)

    @property
    def output_keys(self) -> List[str]:
        """Get the output keys of the operation (main output key first)."""
        return [self.output_key] + self.additional_output_keys

    @property
    def static_params(self):
        """Get the static parameters of the operation (class attribute ``STATIC_PARAMS``)."""
        return getattr(self.__class__, "STATIC_PARAMS", [])

    def set_jit(self, jit_compile: bool):
        """Set the JIT compilation flag and set the `_call` method accordingly.

        ``call`` is only compiled when both ``jit_compile`` and :attr:`jittable` are true.
        """
        self._jit_compile = jit_compile
        if self._jit_compile and self.jittable:
            self._call = jit(self.call, **self.jit_kwargs)
        else:
            self._call = self.call

    def _trace_signatures(self):
        """
        Analyze and store the input signature of the `call` method.

        The parameter names become :attr:`valid_keys`.
        """
        self._input_signature = inspect.signature(self.call)
        self._valid_keys = set(self._input_signature.parameters.keys())

    @property
    def valid_keys(self) -> set:
        """Get the valid keys for the `call` method (its parameter names)."""
        return self._valid_keys

    @property
    def needs_keys(self) -> set:
        """Get a set of all input keys needed by the operation."""
        return self.valid_keys

    @property
    def jittable(self):
        """Check if the operation can be JIT compiled."""
        return self._jittable

    def call(self, **kwargs):
        """
        Abstract method that defines the processing logic for the operation.
        Subclasses must implement this method and return a dictionary.
        """
        raise NotImplementedError

    def set_input_cache(self, input_cache: Dict[str, Any]):
        """
        Set a cache for inputs, then retrace the function if necessary.

        Cached inputs are merged under the keyword arguments of every call, so
        explicitly passed values take precedence.

        Args:
            input_cache: A dictionary containing cached inputs.
        """
        self._input_cache.update(input_cache)
        self._trace_signatures()  # Retrace after updating cache to ensure correctness.

    def set_output_cache(self, output_cache: Dict[str, Any]):
        """
        Set a cache for outputs, then retrace the function if necessary.

        Args:
            output_cache: A dictionary mapping input hashes to cached outputs.
        """
        self._output_cache.update(output_cache)
        self._trace_signatures()  # Retrace after updating cache to ensure correctness.

    def clear_cache(self):
        """
        Clear the input and output caches.
        """
        self._input_cache.clear()
        self._output_cache.clear()

    def _hash_inputs(self, kwargs: Dict) -> str:
        """
        Generate a hash for the given inputs to use as a cache key.

        Array-like values (anything exposing ``__array__``) are fingerprinted by dtype,
        shape and a digest of their full contents. Their ``str()`` cannot be used for
        this: NumPy-style printing summarizes large arrays with ``...``, so different
        arrays could map to the same key. Values whose conversion to NumPy fails fall
        back to ``str()``, as do all other non-JSON-serializable values.

        Args:
            kwargs: Keyword arguments.

        Returns:
            A unique hash representing the inputs.
        """

        def _fingerprint(value):
            if hasattr(value, "__array__"):
                try:
                    array = np.asarray(value)
                    digest = hashlib.md5(np.ascontiguousarray(array).tobytes()).hexdigest()
                except Exception:  # best effort: keep the str() representation
                    return str(value)
                return f"ndarray(dtype={array.dtype}, shape={array.shape}, md5={digest})"
            return str(value)

        input_json = json.dumps(kwargs, sort_keys=True, default=_fingerprint)
        return hashlib.md5(input_json.encode()).hexdigest()

    def __call__(self, *args, **kwargs) -> Dict:
        """
        Process the input keyword arguments and return the processed results.

        Args:
            kwargs: Keyword arguments to be processed.

        Returns:
            Combined input and output as kwargs.

        Raises:
            TypeError: If positional arguments are given, or if `call` does not
                return a dictionary.
        """
        if args:
            example_usage = f" result = {ops_registry.get_name(self)}({self.key}=my_data"
            valid_keys_no_kwargs = self.valid_keys - {"kwargs"}
            if valid_keys_no_kwargs:
                # sorted() keeps the example deterministic (set order is arbitrary)
                example_usage += f", {sorted(valid_keys_no_kwargs)[0]}=param1, ..., **kwargs)"
            else:
                example_usage += ", **kwargs)"
            raise TypeError(
                f"{self.__class__.__name__}.__call__() only accepts keyword arguments. "
                "Positional arguments are not allowed.\n"
                f"Received positional arguments: {args}\n"
                "Example usage:\n"
                f"{example_usage}"
            )

        # Merge cached inputs with provided ones
        merged_kwargs = {**self._input_cache, **kwargs}

        # Return cached output if available
        if self.cache_outputs:
            cache_key = self._hash_inputs(merged_kwargs)
            if cache_key in self._output_cache:
                return {**merged_kwargs, **self._output_cache[cache_key]}

        # Filter kwargs to match the valid keys of the `call` method, unless `call`
        # accepts arbitrary keyword arguments itself.
        if "kwargs" not in self.valid_keys:
            filtered_kwargs = {k: v for k, v in merged_kwargs.items() if k in self.valid_keys}
        else:
            filtered_kwargs = merged_kwargs

        # Call the processing function
        # If you want to jump in with debugger please set `jit_compile=False`
        # when initializing the pipeline.
        processed_output = self._call(**filtered_kwargs)

        # Ensure the output is always a dictionary
        if not isinstance(processed_output, dict):
            raise TypeError(
                f"The `call` method must return a dictionary. Got {type(processed_output)}."
            )

        # Merge outputs with inputs
        combined_kwargs = {**merged_kwargs, **processed_output}

        # Cache the result if caching is enabled
        if self.cache_outputs:
            if isinstance(self.cache_outputs, list):
                cached_output = {
                    k: v for k, v in processed_output.items() if k in self.cache_outputs
                }
            else:
                cached_output = processed_output
            self._output_cache[cache_key] = cached_output

        return combined_kwargs

    def get_dict(self):
        """Get the configuration of the operation as a serializable dictionary.

        Returns:
            dict: ``{"name": <registry name>, "params": {...}}``.
        """
        config = {}
        config.update({"name": ops_registry.get_name(self)})
        config["params"] = {
            "key": self.key,
            "output_key": self.output_key,
            "cache_inputs": self.cache_inputs,
            "cache_outputs": self.cache_outputs,
            "jit_compile": self._jit_compile,
            "with_batch_dim": self.with_batch_dim,
            "jit_kwargs": self.jit_kwargs,
        }
        return config

    # NOTE(review): defining __eq__ without __hash__ sets __hash__ to None, so instances
    # are unhashable — confirm this is intended.
    def __eq__(self, other):
        """Check equality of two operations based on type, registry name and configuration."""
        if not isinstance(other, Operation):
            return False

        # Compare the class name and parameters
        if self.__class__.__name__ != other.__class__.__name__:
            return False

        # Compare the name assigned to the operation
        name = ops_registry.get_name(self)
        other_name = ops_registry.get_name(other)
        if name != other_name:
            return False

        # Compare the parameters of the operations
        if not deep_compare(self.get_dict(), other.get_dict()):
            return False

        return True
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
@ops_registry("pipeline")
|
|
409
|
-
class Pipeline:
|
|
410
|
-
"""Pipeline class for processing ultrasound data through a series of operations."""
|
|
411
|
-
|
|
412
|
-
def __init__(
    self,
    operations: List[Operation],
    with_batch_dim: bool = True,
    jit_options: Union[str, None] = "ops",
    jit_kwargs: dict | None = None,
    name="pipeline",
    validate=True,
    timed: bool = False,
):
    """
    Initialize a pipeline.

    Args:
        operations (list): A list of Operation instances representing the operations
            to be performed.
        with_batch_dim (bool, optional): Whether operations should expect a batch dimension.
            Defaults to True.
        jit_options (str, optional): The JIT options to use. Must be "pipeline", "ops", or None.

            - "pipeline": compiles the entire pipeline as a single function.
              This may be faster but does not preserve python control flow, such as caching.

            - "ops": compiles each operation separately. This preserves python control flow
              and caching functionality, while still speeding up the individual operations.

            - None: disables JIT compilation.

            Defaults to "ops".

        jit_kwargs (dict, optional): Additional keyword arguments for the JIT compiler.
            The dictionary is copied; on the jax backend ``static_argnames`` is set to
            the static parameters of the operations and all other entries are kept.
        name (str, optional): The name of the pipeline. Defaults to "pipeline".
        validate (bool, optional): Whether to validate the pipeline. Defaults to True.
        timed (bool, optional): Whether to time each operation. Defaults to False.

    Raises:
        ValueError: If ``jit_options`` is invalid, or if ``timed=True`` is combined
            with ``jit_options="pipeline"``.
    """
    self._call_pipeline = self.call
    self.name = name

    self._pipeline_layers = operations

    if jit_options not in ["pipeline", "ops", None]:
        raise ValueError("jit_options must be 'pipeline', 'ops', or None")

    self.with_batch_dim = with_batch_dim
    self._validate_flag = validate

    # Setup timer
    if jit_options == "pipeline" and timed:
        raise ValueError(
            "timed=True cannot be used with jit_options='pipeline' as the entire "
            "pipeline is compiled into a single function. Try setting jit_options to "
            "'ops' or None."
        )
    if timed:
        log.warning(
            "Timer has been initialized for the pipeline. To get an accurate timing estimate, "
            "the `block_until_ready()` is used, which will slow down the execution, so "
            "do not use for regular processing!"
        )
        self._callable_layers = self._get_timed_operations()
    else:
        self._callable_layers = self._pipeline_layers
    self._timed = timed

    if validate:
        self.validate()
    else:
        log.warning("Pipeline validation is disabled, make sure to validate manually.")

    # Copy so the caller's dict is never mutated, and merge the required static argument
    # names into the user's options instead of discarding them.
    jit_kwargs = dict(jit_kwargs) if jit_kwargs is not None else {}

    if keras.backend.backend() == "jax" and self.static_params:
        jit_kwargs["static_argnames"] = self.static_params

    self.jit_kwargs = jit_kwargs
    self.jit_options = jit_options  # the property setter performs the (un)jitting
|
|
490
|
-
|
|
491
|
-
def needs(self, key) -> bool:
    """Tell whether ``key`` has to be supplied as an input to the pipeline.

    Args:
        key: Name of the input key.

    Returns:
        bool: True when ``key`` is part of :attr:`needs_keys`.
    """
    required = self.needs_keys
    return key in required
|
|
494
|
-
|
|
495
|
-
@property
def output_keys(self) -> set:
    """All output keys the pipeline guarantees to produce.

    Returns:
        set: Union of the ``output_keys`` of every operation.
    """
    return set().union(*(op.output_keys for op in self.operations))
|
|
502
|
-
|
|
503
|
-
@property
def valid_keys(self) -> set:
    """Every key that may be passed to the pipeline as input.

    Returns:
        set: Union of the ``valid_keys`` of every operation.
    """
    return set().union(*(op.valid_keys for op in self.operations))
|
|
513
|
-
|
|
514
|
-
@property
def static_params(self) -> List[str]:
    """Unique static parameter names over all operations (order not guaranteed)."""
    collected = [param for op in self.operations for param in op.static_params]
    return list(set(collected))
|
|
521
|
-
|
|
522
|
-
@property
def needs_keys(self) -> set:
    """Input keys that must be supplied from outside the pipeline.

    Walks the operations in order. A key an operation needs only counts when no
    earlier operation already produces it among its ``output_keys``.
    """
    required = set()
    produced = set()
    for op in self.operations:
        required |= op.needs_keys - produced
        produced.update(op.output_keys)
    return required
|
|
537
|
-
|
|
538
|
-
@classmethod
def from_default(
    cls, num_patches=100, baseband=False, pfield=False, timed=False, **kwargs
) -> "Pipeline":
    """Build the standard beamforming-to-display pipeline.

    The layout is: optional demodulation, then TOF correction, optional pfield weighting
    and delay-and-sum (wrapped in a PatchedGrid when ``num_patches > 1``), then envelope
    detection, normalization and log compression.

    Args:
        num_patches (int): Number of patches for the PatchedGrid operation.
            Defaults to 100. If you get an out of memory error, try to increase this number.
        baseband (bool): If True, the input is treated as baseband (I/Q) data with 2
            channels in the last dim, and no Demodulate step is added. Defaults to False
            (RF data with a single channel dim, still on the carrier frequency).
        pfield (bool): If True, apply Pfield weighting. Defaults to False.
            This will calculate pressure field and only beamform the data to those locations.
        timed (bool, optional): Whether to time each operation. Defaults to False.
        **kwargs: Passed to the Pipeline constructor (and to PatchedGrid, when used).

    Returns:
        Pipeline: The assembled pipeline.
    """
    head = [] if baseband else [Demodulate()]

    tof, das = TOFCorrection(), DelayAndSum()
    beamformer = [tof, PfieldWeighting(), das] if pfield else [tof, das]
    if num_patches > 1:
        beamformer = [PatchedGrid(operations=beamformer, num_patches=num_patches, **kwargs)]

    tail = [EnvelopeDetect(), Normalize(), LogCompress()]

    return cls([*head, *beamformer, *tail], timed=timed, **kwargs)
|
|
584
|
-
|
|
585
|
-
def copy(self) -> "Pipeline":
    """Return a new Pipeline with identical settings.

    The operations list is shallow-copied, so the copy can be extended without
    affecting this pipeline. The operation objects themselves are shared.
    """
    layers = self._pipeline_layers.copy()
    settings = dict(
        with_batch_dim=self.with_batch_dim,
        jit_options=self.jit_options,
        jit_kwargs=self.jit_kwargs,
        name=self.name,
        validate=self._validate_flag,
        timed=self._timed,
    )
    return Pipeline(layers, **settings)
|
|
596
|
-
|
|
597
|
-
def reinitialize(self):
    """Re-run ``__init__`` in place with the current operations and settings.

    This re-validates the pipeline, rebuilds timers and reapplies the JIT options.
    """
    settings = dict(
        with_batch_dim=self.with_batch_dim,
        jit_options=self.jit_options,
        jit_kwargs=self.jit_kwargs,
        name=self.name,
        validate=self._validate_flag,
        timed=self._timed,
    )
    self.__init__(self._pipeline_layers, **settings)
|
|
608
|
-
|
|
609
|
-
def prepend(self, operation: Operation):
    """Put ``operation`` in front of all existing operations and reinitialize."""
    self._pipeline_layers[:0] = [operation]
    self.reinitialize()
|
|
613
|
-
|
|
614
|
-
def append(self, operation: Operation):
    """Add ``operation`` after all existing operations and reinitialize."""
    self._pipeline_layers.extend([operation])
    self.reinitialize()
|
|
618
|
-
|
|
619
|
-
def insert(self, index: int, operation: Operation):
    """Insert ``operation`` before position ``index`` and reinitialize.

    Raises:
        IndexError: If ``index`` is negative or larger than the number of operations.
    """
    if not 0 <= index <= len(self._pipeline_layers):
        raise IndexError("Index out of bounds for inserting operation.")
    self._pipeline_layers.insert(index, operation)
    self.reinitialize()
|
|
625
|
-
|
|
626
|
-
@property
def operations(self):
    """The list of operations in this pipeline (the internal ``_pipeline_layers`` list)."""
    layers = self._pipeline_layers
    return layers
|
|
630
|
-
|
|
631
|
-
def reset_timer(self):
    """Start a fresh timer for the operations; only warn if timing is disabled."""
    if not self._timed:
        log.warning(
            "Timer has not been initialized. Set timed=True when initializing the pipeline."
        )
        return
    self._callable_layers = self._get_timed_operations()
|
|
639
|
-
|
|
640
|
-
def _get_timed_operations(self):
    """Create a new ``FunctionTimer`` on ``self.timer`` and wrap each operation with it.

    Returns:
        list: The wrapped operations, each timed under its class name.
    """
    self.timer = FunctionTimer()
    wrapped = []
    for op in self._pipeline_layers:
        wrapped.append(self.timer(op, name=op.__class__.__name__))
    return wrapped
|
|
644
|
-
|
|
645
|
-
def call(self, **inputs):
    """Process input data through the pipeline.

    Each operation receives the dictionary returned by the previous one.

    Returns:
        dict: The output of the last operation, or ``inputs`` unchanged when the
        pipeline contains no operations.

    Raises:
        KeyError: If an operation is missing a required input key.
        RuntimeError: Wrapping (and chained to) any other exception raised by an operation.
    """
    # Start from the inputs so an empty pipeline acts as the identity instead of
    # failing with an UnboundLocalError on `outputs`.
    outputs = inputs
    for operation in self._callable_layers:
        try:
            outputs = operation(**inputs)
        except KeyError as exc:
            raise KeyError(
                f"[zea.Pipeline] Operation '{operation.__class__.__name__}' "
                f"requires input key '{exc.args[0]}', "
                "but it was not provided in the inputs.\n"
                "Check whether the objects (such as `zea.Scan`) passed to "
                "`pipeline.prepare_parameters()` contain all required keys.\n"
                f"Current list of all passed keys: {list(inputs.keys())}\n"
                f"Valid keys for this pipeline: {self.valid_keys}"
            ) from exc
        except Exception as exc:
            raise RuntimeError(
                f"[zea.Pipeline] Error in operation '{operation.__class__.__name__}': {exc}"
            ) from exc
        inputs = outputs
    return outputs
|
|
666
|
-
|
|
667
|
-
def __call__(self, return_numpy=False, **inputs):
    """Validate the inputs and process them through the pipeline.

    Args:
        return_numpy (bool, optional): If True, every tensor in the output dictionary is
            converted to a NumPy array; non-tensor values (including None) are returned
            unchanged. Defaults to False.
        **inputs: Input data and parameters, keyed by name.

    Returns:
        dict: The outputs of the pipeline.

    Raises:
        ValueError: If a ``probe``/``scan``/``config`` key or a zea ``Object`` is passed
            directly (these must go through ``prepare_parameters`` first), or if any
            input value is a string.
    """

    if any(key in inputs for key in ["probe", "scan", "config"]) or any(
        isinstance(arg, ZEAObject) for arg in inputs.values()
    ):
        raise ValueError(
            "Probe, Scan and Config objects should be first processed with "
            "`Pipeline.prepare_parameters` before calling the pipeline. "
            "e.g. inputs = Pipeline.prepare_parameters(probe, scan, config)"
        )

    if any(isinstance(arg, str) for arg in inputs.values()):
        raise ValueError(
            "Pipeline does not support string inputs. "
            "Please ensure all inputs are convertible to tensors."
        )

    ## PROCESSING
    outputs = self._call_pipeline(**inputs)

    ## PREPARE OUTPUT
    if return_numpy:
        # Convert tensors to numpy arrays; leave everything else (e.g. None) untouched.
        outputs = {
            k: ops.convert_to_numpy(v) if ops.is_tensor(v) else v for k, v in outputs.items()
        }

    return outputs
|
|
697
|
-
|
|
698
|
-
@property
def jit_options(self):
    """The JIT mode of the pipeline: ``"pipeline"``, ``"ops"`` or ``None``."""
    return self._jit_options

@jit_options.setter
def jit_options(self, value: Union[str, None]):
    """Store the JIT mode and (re)configure compilation accordingly.

    ``"pipeline"`` requires every operation to be jittable and compiles :meth:`call`
    as a whole; per-operation settings are left untouched in that case. For any other
    value the pipeline-level compilation is removed. Nested pipelines then receive the
    same value. Jittable operations whose ``_jit_compile`` flag is still set are
    compiled only when ``value == "ops"``.

    NOTE(review): ``set_jit(False)`` clears ``_jit_compile``, so once disabled an
    operation is not re-enabled by a later ``"ops"`` — confirm this is intended.
    """
    self._jit_options = value

    if value == "pipeline":
        assert self.jittable, log.error(
            "jit_options 'pipeline' cannot be used as the entire pipeline is not jittable. "
            "The following operations are not jittable: "
            f"{self.unjitable_ops}. "
            "Try setting jit_options to 'ops' or None."
        )
        self.jit()
        return

    self.unjit()
    compile_ops = value == "ops"
    for operation in self.operations:
        if isinstance(operation, Pipeline):
            operation.jit_options = value
        elif operation.jittable and operation._jit_compile:
            operation.set_jit(compile_ops)
|
|
725
|
-
|
|
726
|
-
def jit(self):
    """Route pipeline calls through a JIT-compiled version of :meth:`call`."""
    compiled = jit(self.call, **self.jit_kwargs)
    self._call_pipeline = compiled
|
|
729
|
-
|
|
730
|
-
def unjit(self):
    """Route pipeline calls through the plain, uncompiled :meth:`call`."""
    plain_call = self.call
    self._call_pipeline = plain_call
|
|
733
|
-
|
|
734
|
-
@property
def jittable(self):
    """True when no operation in the pipeline reports itself as non-jittable."""
    return not any(not op.jittable for op in self.operations)
|
|
738
|
-
|
|
739
|
-
@property
def unjitable_ops(self):
    """Operations that cannot be JIT compiled, in pipeline order."""
    return list(filter(lambda op: not op.jittable, self.operations))
|
|
743
|
-
|
|
744
|
-
@property
def with_batch_dim(self):
    """Whether the pipeline's inputs carry a leading batch dimension."""
    return self._with_batch_dim


@with_batch_dim.setter
def with_batch_dim(self, value):
    """Store the batch-dimension flag and copy it onto every operation."""
    self._with_batch_dim = value
    # Keep each operation in sync with the pipeline-level flag.
    for child in self.operations:
        setattr(child, "with_batch_dim", value)
|
|
755
|
-
|
|
756
|
-
@property
def input_data_type(self):
    """Data type expected at the pipeline input, i.e. that of the first operation."""
    first_operation = self.operations[0]
    return first_operation.input_data_type
|
|
760
|
-
|
|
761
|
-
@property
def output_data_type(self):
    """Data type produced at the pipeline output, i.e. that of the last operation."""
    last_operation = self.operations[-1]
    return last_operation.output_data_type
|
|
765
|
-
|
|
766
|
-
def validate(self):
    """Check that consecutive operations have compatible data types.

    For every adjacent pair of operations, the output data type of the first must
    equal the input data type of the second. Pairs where either side is ``None``
    are not checked.

    Raises:
        ValueError: For the first incompatible pair found.
    """
    operations = self.operations
    for current, following in zip(operations[:-1], operations[1:]):
        if current.output_data_type is None or following.input_data_type is None:
            continue
        if current.output_data_type == following.input_data_type:
            continue
        raise ValueError(
            f"Operation {current.__class__.__name__} output data type "
            f"({current.output_data_type}) is not compatible "
            f"with the input data type ({following.input_data_type}) "
            f"of operation {following.__class__.__name__}"
        )
|
|
781
|
-
|
|
782
|
-
def set_params(self, **params):
    """Add parameters to the input cache of every operation that accepts them.

    Each operation receives only the subset of ``params`` whose keys are in its
    ``valid_keys``. Operations with no matching keys are left untouched.
    """
    for op in self.operations:
        accepted = op.valid_keys
        matching = {name: val for name, val in params.items() if name in accepted}
        if not matching:
            continue
        op.set_input_cache(matching)
|
|
790
|
-
|
|
791
|
-
def get_params(self, per_operation: bool = False):
    """Return a snapshot of the operations' cached input parameters.

    Args:
        per_operation (bool): If True, return one dictionary (a shallow copy) per
            operation, in pipeline order. If False, return a single merged
            dictionary in which later operations override earlier ones.
    """
    caches = [op._input_cache for op in self.operations]
    if per_operation:
        return [cache.copy() for cache in caches]

    combined = {}
    for cache in caches:
        combined.update(cache)
    return combined
|
|
805
|
-
|
|
806
|
-
def __str__(self):
    """String representation of the pipeline.

    Returns the class names of the operations joined by `` -> ``.

    A two-line layout is produced only when an operation named in
    ``split_operations`` occurs. The second line marks where the chain splits
    and where a merge operation (``Stack``) joins it back. ``split_operations``
    is currently an empty list, so that branch is never taken and the single-line
    form is always returned.
    """
    # Class names that start / end a parallel branch in the printed layout.
    # NOTE(review): empty, so the two-line branch below is currently unreachable.
    split_operations = []
    merge_operations = ["Stack"]

    operations = [operation.__class__.__name__ for operation in self.operations]
    string = " -> ".join(operations)

    if any(operation in split_operations for operation in operations):
        # a second line is needed with same length as the first line
        split_line = " " * len(string)
        # find the splitting operation and index and print \-> instead of -> after
        split_detected = False
        merge_detected = False
        split_operation = None
        for operation in operations:
            if operation in split_operations:
                # NOTE(review): str.index returns the first occurrence, so repeated
                # operation names all map to the same column.
                index = string.index(operation)
                index = index + len(operation)
                split_line = split_line[:index] + "\\->" + split_line[index + len("\\->") :]
                split_detected = True
                merge_detected = False
                split_operation = operation
                continue

            if operation in merge_operations:
                index = string.index(operation)
                # Step back 4 characters, the length of the " -> " separator
                # that precedes the merge operation.
                index = index - 4
                split_line = split_line[:index] + "/" + split_line[index + 1 :]
                split_detected = False
                merge_detected = True
                continue

            if split_detected:
                # print all operations in the second line
                index = string.index(operation)
                split_line = (
                    split_line[:index]
                    + operation
                    + " -> "
                    + split_line[index + len(operation) + len(" -> ") :]
                )
        # A split must always be closed by a merge operation.
        assert merge_detected is True, log.error(
            "Pipeline was never merged back together (with Stack operation), even "
            f"though it was split with {split_operation}. "
            "Please properly define your operation chain."
        )
        return f"\n{string}\n{split_line}\n"

    return string
|
|
861
|
-
|
|
862
|
-
def __repr__(self):
    """Compact representation listing the operation class names.

    Nested pipelines are shown with their own ``repr``.
    """
    parts = [
        repr(op) if isinstance(op, Pipeline) else op.__class__.__name__
        for op in self.operations
    ]
    return f"<Pipeline {self.name}=({', '.join(parts)})>"
|
|
871
|
-
|
|
872
|
-
@classmethod
def load(cls, file_path: str, **kwargs) -> "Pipeline":
    """Load a pipeline from a JSON or YAML file.

    Args:
        file_path (str or os.PathLike): Path to a ``.json``, ``.yaml`` or ``.yml``
            file. ``pathlib.Path`` objects are accepted as well as strings.
        **kwargs: Additional keyword arguments forwarded to the pipeline.

    Returns:
        Pipeline: The pipeline described by the file.

    Raises:
        ValueError: If the file extension is not ``.json``, ``.yaml`` or ``.yml``.
    """
    import os

    # Normalize path-like objects (e.g. pathlib.Path) so ``endswith`` works.
    path = os.fspath(file_path)
    if path.endswith(".json"):
        with open(path, "r", encoding="utf-8") as f:
            json_str = f.read()
        return pipeline_from_json(json_str, **kwargs)
    if path.endswith((".yaml", ".yml")):
        return pipeline_from_yaml(path, **kwargs)
    raise ValueError("File must have extension .json, .yaml, or .yml")
|
|
883
|
-
|
|
884
|
-
def get_dict(self) -> dict:
    """Serialize the pipeline into a plain dictionary.

    Returns:
        dict: ``{"name": ..., "operations": [...], "params": {...}}``. ``name`` is the
        registry name of the pipeline and ``operations`` holds the ``get_dict()``
        output of each operation. ``params`` holds the pipeline-level settings
        ``with_batch_dim``, ``jit_options`` and ``jit_kwargs``.
    """
    registry_name = ops_registry.get_name(self)
    serialized_ops = self._pipeline_to_list(self)
    pipeline_params = {
        "with_batch_dim": self.with_batch_dim,
        "jit_options": self.jit_options,
        "jit_kwargs": self.jit_kwargs,
    }
    return {
        "name": registry_name,
        "operations": serialized_ops,
        "params": pipeline_params,
    }
|
|
895
|
-
|
|
896
|
-
@staticmethod
def _pipeline_to_list(pipeline):
    """Return the ``get_dict()`` serialization of each operation, in order."""
    return list(map(lambda op: op.get_dict(), pipeline.operations))
|
|
903
|
-
|
|
904
|
-
@classmethod
def from_config(cls, config: Dict, **kwargs) -> "Pipeline":
    """Create a pipeline from a dictionary or ``zea.Config`` object.

    Args:
        config (dict or Config): Configuration dictionary or ``zea.Config`` object.
            It must have a top-level ``operations`` key holding the list of
            operations. Any other top-level keys are forwarded to the pipeline.
        **kwargs: Additional keyword arguments to be passed to the pipeline. They
            take precedence over keys of the same name in ``config``.

    Returns:
        Pipeline: The constructed pipeline.

    Example:
        .. doctest::

            >>> from zea import Config, Pipeline
            >>> config = Config(
            ...     {
            ...         "operations": [
            ...             "identity",
            ...         ],
            ...     }
            ... )
            >>> pipeline = Pipeline.from_config(config)
    """
    return pipeline_from_config(Config(config), **kwargs)
|
|
929
|
-
|
|
930
|
-
@classmethod
def from_yaml(cls, file_path: str, **kwargs) -> "Pipeline":
    """Create a pipeline from a YAML file.

    Args:
        file_path (str): Path to the YAML file.
        **kwargs: Additional keyword arguments to be passed to the pipeline.

    Note:
        The file must have a top-level ``operations`` key. Only that key is read;
        any other top-level entries in the file are ignored.

    Example:
        .. doctest::

            >>> import os
            >>> import tempfile
            >>> import yaml
            >>> from zea.ops import Pipeline
            >>> # Create a sample pipeline YAML file in a temporary directory
            >>> pipeline_dict = {
            ...     "operations": [
            ...         "identity",
            ...     ],
            ... }
            >>> path = os.path.join(tempfile.mkdtemp(), "pipeline.yaml")
            >>> with open(path, "w") as f:
            ...     yaml.dump(pipeline_dict, f)
            >>> pipeline = Pipeline.from_yaml(path, jit_options=None)
    """
    return pipeline_from_yaml(file_path, **kwargs)
|
|
958
|
-
|
|
959
|
-
@classmethod
def from_json(cls, json_string: str, **kwargs) -> "Pipeline":
    """Build a pipeline from a JSON string.

    Args:
        json_string (str): JSON text describing the pipeline. It must contain a
            top-level ``operations`` key.
        **kwargs: Additional keyword arguments to be passed to the pipeline.

    Example:
        ```python
        json_string = '{"operations": ["identity"]}'
        pipeline = Pipeline.from_json(json_string)
        ```
    """
    pipeline = pipeline_from_json(json_string, **kwargs)
    return pipeline
|
|
977
|
-
|
|
978
|
-
def to_config(self) -> Config:
    """Return this pipeline serialized as a ``zea.Config`` object."""
    config = pipeline_to_config(self)
    return config
|
|
981
|
-
|
|
982
|
-
def to_json(self) -> str:
    """Return this pipeline serialized as a JSON string."""
    serialized = pipeline_to_json(self)
    return serialized
|
|
985
|
-
|
|
986
|
-
def to_yaml(self, file_path: str) -> None:
    """Write this pipeline to ``file_path`` as YAML."""
    return pipeline_to_yaml(self, file_path)
|
|
989
|
-
|
|
990
|
-
@property
def key(self) -> str:
    """Input key consumed by the pipeline, taken from its first operation."""
    head = self.operations[0]
    return head.key
|
|
994
|
-
|
|
995
|
-
@property
def output_key(self) -> str:
    """Output key produced by the pipeline, taken from its last operation."""
    tail = self.operations[-1]
    return tail.output_key
|
|
999
|
-
|
|
1000
|
-
def __eq__(self, other):
    """Two pipelines are equal when they hold pairwise-equal operations in order.

    Comparing against anything that is not a ``Pipeline`` returns False.
    NOTE(review): defining ``__eq__`` without ``__hash__`` makes instances
    unhashable unless ``__hash__`` is defined elsewhere in the class; confirm.
    """
    if not isinstance(other, Pipeline):
        return False

    mine, theirs = self.operations, other.operations
    if len(mine) != len(theirs):
        return False

    return all(first == second for first, second in zip(mine, theirs))
|
|
1014
|
-
|
|
1015
|
-
def prepare_parameters(
    self,
    probe: Probe = None,
    scan: Scan = None,
    config: Config = None,
    **kwargs,
):
    """Prepare Probe, Scan and Config objects for the pipeline.

    Serializes `zea.core.Object` instances and converts them, together with any
    extra keyword arguments, to a single dictionary of tensors.
    ``self.static_params`` is passed as ``keep_as_is`` to every conversion
    (presumably so those keys are left unconverted).

    Args:
        probe: Probe object, or None to skip it.
        scan: Scan object, or None to skip it. Converted with
            ``include=self.needs_keys``, presumably limiting it to the keys the
            pipeline needs.
        config: Config object, or None to skip it.
        **kwargs: Additional keyword arguments to be included in the inputs.

    Returns:
        dict: Dictionary of inputs converted to tensors. When a key comes from
        several sources, the priority is kwargs > config > scan > probe.

    Raises:
        AssertionError: If ``probe``, ``scan`` or ``config`` is given but is not an
            instance of ``Probe``, ``Scan`` or ``Config`` respectively.
    """
    # Initialize dictionaries for probe, scan, and config
    probe_dict, scan_dict, config_dict = {}, {}, {}

    # Process args to extract Probe, Scan, and Config objects
    if probe is not None:
        assert isinstance(probe, Probe), (
            f"Expected an instance of `zea.probes.Probe`, got {type(probe)}"
        )
        probe_dict = probe.to_tensor(keep_as_is=self.static_params)

    if scan is not None:
        assert isinstance(scan, Scan), (
            f"Expected an instance of `zea.scan.Scan`, got {type(scan)}"
        )
        scan_dict = scan.to_tensor(include=self.needs_keys, keep_as_is=self.static_params)

    if config is not None:
        assert isinstance(config, Config), (
            f"Expected an instance of `zea.config.Config`, got {type(config)}"
        )
        config_dict.update(config.to_tensor(keep_as_is=self.static_params))

    # Convert all kwargs to tensors
    tensor_kwargs = dict_to_tensor(kwargs, keep_as_is=self.static_params)

    # combine probe, scan, config and kwargs
    # explicitly so we know which keys overwrite which
    # kwargs > config > scan > probe
    inputs = {
        **probe_dict,
        **scan_dict,
        **config_dict,
        **tensor_kwargs,
    }

    return inputs
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
def make_operation_chain(
    operation_chain: List[Union[str, Dict, Config, Operation, Pipeline]],
) -> List[Operation]:
    """Make an operation chain from a custom list of operations.

    Args:
        operation_chain (list): List of operations to be performed.
            Each operation can be:
            - A string: operation initialized with default parameters
            - A dictionary: operation initialized with parameters in the dictionary
            - A Config object: converted to a dictionary and initialized
            - An Operation/Pipeline instance: used as-is

            Dictionaries may carry nested operations under ``operations`` (for
            pipeline-like operations), under ``params.operations`` (for
            ``patched_grid``), or as a ``branches`` mapping (for
            ``branched_pipeline``). The caller's dictionaries are never modified.

    Returns:
        list: List of operations to be performed.

    Example:
        .. doctest::

            >>> from zea.ops import make_operation_chain, LogCompress
            >>> SomeCustomOperation = LogCompress  # just for demonstration
            >>> chain = make_operation_chain(
            ...     [
            ...         "envelope_detect",
            ...         {"name": "normalize", "params": {"output_range": (0, 1)}},
            ...         SomeCustomOperation(),
            ...     ]
            ... )
    """
    chain = []
    for operation in operation_chain:
        # Handle already instantiated Operation or Pipeline objects
        if isinstance(operation, (Operation, Pipeline)):
            chain.append(operation)
            continue

        assert isinstance(operation, (str, dict, Config)), (
            f"Operation {operation} should be a string, dict, Config object, Operation, or Pipeline"
        )

        if isinstance(operation, str):
            chain.append(get_ops(operation)())
            continue

        if isinstance(operation, Config):
            operation = operation.serialize()

        # Copy the params so popping nested entries below never mutates the caller's
        # config. ``or {}`` also covers an empty YAML ``params:`` entry (None).
        params = dict(operation.get("params") or {})
        op_name = operation.get("name")
        operation_cls = get_ops(op_name)

        if op_name == "branched_pipeline" and "branches" in operation:
            # Each branch configuration becomes its own operation (chain).
            branches = [
                _make_operation_branch(branch_config)
                for branch_config in operation.get("branches", {}).values()
            ]
            operation_instance = operation_cls(branches=branches, **params)
        elif "operations" in operation:
            # Pipeline-like operation with nested operations next to its params.
            nested_operations = make_operation_chain(operation["operations"])
            operation_instance = operation_cls(operations=nested_operations, **params)
        elif op_name == "patched_grid":
            # Nested operations stored inside the params (popped from the copy).
            nested_operations = make_operation_chain(params.pop("operations"))
            operation_instance = operation_cls(operations=nested_operations, **params)
        else:
            operation_instance = operation_cls(**params)

        chain.append(operation_instance)

    return chain


def _make_operation_branch(branch_config):
    """Build one branch of a branched pipeline from its configuration."""
    if isinstance(branch_config, (Operation, Pipeline)):
        # Already instantiated: used as-is, like in the main chain.
        return branch_config
    if isinstance(branch_config, (list, np.ndarray)):
        # This is a list of operations
        return make_operation_chain(branch_config)
    if isinstance(branch_config, str):
        # A single operation referenced by name, with default parameters.
        return get_ops(branch_config)()
    if "operations" in branch_config:
        # This is a pipeline-like branch
        return make_operation_chain(branch_config["operations"])
    # This is a single operation branch
    branch_op_cls = get_ops(branch_config["name"])
    return branch_op_cls(**(branch_config.get("params") or {}))
|
|
1172
|
-
|
|
1173
|
-
|
|
1174
|
-
def pipeline_from_config(config: Config, **kwargs) -> Pipeline:
    """
    Build a ``Pipeline`` from a ``Config`` object.

    The ``operations`` entry is turned into an operation chain. Every other
    top-level entry of ``config`` is forwarded to the ``Pipeline`` constructor,
    with ``kwargs`` taking precedence on name clashes.
    """
    assert "operations" in config, (
        "Config object must have an 'operations' key for pipeline creation."
    )
    assert isinstance(config.operations, (list, np.ndarray)), (
        "Config object must have a list or numpy array of operations for pipeline creation."
    )

    chain = make_operation_chain(config.operations)

    # Work on a deep copy so removing "operations" leaves the caller's config intact.
    remaining = copy.deepcopy(config)
    remaining.pop("operations")

    merged_kwargs = {**remaining, **kwargs}
    return Pipeline(operations=chain, **merged_kwargs)
|
|
1193
|
-
|
|
1194
|
-
|
|
1195
|
-
def pipeline_from_json(json_string: str, **kwargs) -> Pipeline:
    """
    Build a ``Pipeline`` from a JSON string decoded with ``ZEADecoderJSON``.
    """
    decoded = json.loads(json_string, cls=ZEADecoderJSON)
    config = Config(decoded)
    return pipeline_from_config(config, **kwargs)
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
def pipeline_from_yaml(yaml_path: str, **kwargs) -> Pipeline:
    """
    Build a ``Pipeline`` from a YAML file.

    Only the top-level ``operations`` entry of the file is used. Other entries are
    ignored.
    """
    with open(yaml_path, "r", encoding="utf-8") as stream:
        loaded = yaml.safe_load(stream)
    operation_list = loaded["operations"]
    return pipeline_from_config(Config({"operations": operation_list}), **kwargs)
|
|
1211
|
-
|
|
1212
|
-
|
|
1213
|
-
def pipeline_to_config(pipeline: Pipeline) -> Config:
    """
    Convert a ``Pipeline`` into a ``Config`` with a top-level ``operations`` list.

    If the pipeline's registry name is ``"pipeline"``, its own operation list is
    used directly instead of being wrapped as a single nested operation.
    """
    # TODO: we currently add the full pipeline as 1 operation to the config.
    # In another PR we should add a "pipeline" entry to the config instead of the "operations"
    # entry. This allows us to also have non-default pipeline classes as top level op.
    top_level = [pipeline.get_dict()]

    # HACK: If the top level operation is a single pipeline, collapse it into the operations list.
    if len(top_level) == 1 and top_level[0]["name"] == "pipeline":
        top_level = top_level[0]["operations"]

    return Config({"operations": top_level})
|
|
1228
|
-
|
|
1229
|
-
|
|
1230
|
-
def pipeline_to_json(pipeline: Pipeline) -> str:
    """
    Convert a ``Pipeline`` into an indented JSON string (encoded with ``ZEAEncoderJSON``).

    If the pipeline's registry name is ``"pipeline"``, its own operation list is
    written directly under ``operations`` instead of being wrapped.
    """
    top_level = [pipeline.get_dict()]

    # HACK: If the top level operation is a single pipeline, collapse it into the operations list.
    if len(top_level) == 1 and top_level[0]["name"] == "pipeline":
        top_level = top_level[0]["operations"]

    return json.dumps({"operations": top_level}, cls=ZEAEncoderJSON, indent=4)
|
|
1242
|
-
|
|
1243
|
-
|
|
1244
|
-
def pipeline_to_yaml(pipeline: Pipeline, file_path: str) -> None:
    """
    Convert a Pipeline instance into a YAML file.

    The file holds ``pipeline.get_dict()``. If the pipeline contains exactly one
    operation and that operation is itself a plain ``"pipeline"``, only the nested
    pipeline's operation list is written, under ``operations``.
    """
    pipeline_dict = pipeline.get_dict()

    # HACK: If the top level operation is a single pipeline, collapse it into the operations list.
    ops = pipeline_dict["operations"]
    # Check the length first so an empty operation list does not raise IndexError.
    if len(ops) == 1 and ops[0]["name"] == "pipeline":
        pipeline_dict = {"operations": ops[0]["operations"]}

    with open(file_path, "w", encoding="utf-8") as f:
        yaml.dump(pipeline_dict, f, Dumper=yaml.Dumper, indent=4)
|
|
1257
|
-
|
|
1258
|
-
|
|
1259
|
-
@ops_registry("patched_grid")
class PatchedGrid(Pipeline):
    """
    With this class you can form a pipeline that will be applied to patches of the grid.
    This is useful to avoid OOM errors when processing large grids.

    Some things to NOTE about this class:

    - The ops have to use flatgrid and flat_pfield as inputs, these will be patched.

    - Changing anything other than `self.output_data_type` in the dict will not be propagated!

    - Will be jitted as a single operation, not the individual operations.

    - This class handles the batching: the contained operations always get
      ``with_batch_dim=False`` and batches are mapped over in ``jittable_call``.

    Args (constructor):
        num_patches (int): Passed to ``vmap`` as ``chunks`` when processing the
            flattened grid. Defaults to 10.
    """

    def __init__(self, *args, num_patches=10, **kwargs):
        super().__init__(*args, name="patched_grid", **kwargs)
        self.num_patches = num_patches

        # Beamforming runs on flattened grid patches; the reshape back to
        # (grid_size_z, grid_size_x, ...) is done at the end of ``call_item``.
        for operation in self.operations:
            if isinstance(operation, DelayAndSum):
                operation.reshape_grid = False

        # NOTE(review): if ``super().__init__`` assigns ``jit_options`` (whose setter
        # below calls ``self.jit()``), this line replaces the compiled
        # ``_jittable_call`` with the plain method again. Confirm this is intended.
        # Chunk-level jitting in ``call_item`` still follows ``jit_options``.
        self._jittable_call = self.jittable_call

    @property
    def jit_options(self):
        """Get the jit_options property of the pipeline."""
        return self._jit_options

    @jit_options.setter
    def jit_options(self, value):
        """Set the jit_options property of the pipeline.

        Both ``"pipeline"`` and ``"ops"`` compile the patched call as a single unit;
        any other value removes compilation.
        """
        self._jit_options = value
        if value in ["pipeline", "ops"]:
            self.jit()
        else:
            self.unjit()

    def jit(self):
        """JIT compile the pipeline (wraps ``jittable_call``, not ``call``)."""
        self._jittable_call = jit(self.jittable_call, **self.jit_kwargs)

    def unjit(self):
        """Un-JIT compile the pipeline."""
        self._jittable_call = self.jittable_call
        self._call_pipeline = self.call

    @property
    def with_batch_dim(self):
        """Get the with_batch_dim property of the pipeline."""
        return self._with_batch_dim

    @with_batch_dim.setter
    def with_batch_dim(self, value):
        """Set the with_batch_dim property of the pipeline.
        The class handles the batching so the operations have to be set to False."""
        self._with_batch_dim = value
        for operation in self.operations:
            operation.with_batch_dim = False

    @property
    def _extra_keys(self):
        # Keys read directly by ``call_item`` (not necessarily by the inner ops).
        return {"flatgrid", "grid_size_x", "grid_size_z"}

    @property
    def valid_keys(self) -> set:
        """Get a set of valid keys for the pipeline.
        Adds the parameters that PatchedGrid itself operates on (even if not used by operations
        inside it)."""
        return super().valid_keys.union(self._extra_keys)

    @property
    def needs_keys(self) -> set:
        """Get a set of all input keys needed by the pipeline.
        Adds the parameters that PatchedGrid itself operates on (even if not used by operations
        inside it)."""
        return super().needs_keys.union(self._extra_keys)

    def call_item(self, inputs):
        """Process a single (unbatched) item in patches of the flattened grid.

        ``flatgrid`` and ``flat_pfield`` (if present) are removed from ``inputs`` and
        split into ``self.num_patches`` chunks by ``vmap``. The remaining inputs are
        passed unchanged to every chunk.

        Returns:
            The ``self.output_key`` result reshaped to
            ``(grid_size_z, grid_size_x, ...)``.
        """
        # Extract necessary parameters
        # make sure to add those as valid keys above!
        grid_size_x = inputs["grid_size_x"]
        grid_size_z = inputs["grid_size_z"]
        flatgrid = inputs.pop("flatgrid")

        # flat_pfield is optional; it is patched together with flatgrid when present.
        flat_pfield = inputs.pop("flat_pfield", None)

        def patched_call(flatgrid, flat_pfield):
            # Two-argument super() is needed inside this nested function.
            out = super(PatchedGrid, self).call(
                flatgrid=flatgrid, flat_pfield=flat_pfield, **inputs
            )
            return out[self.output_key]

        out = vmap(
            patched_call,
            chunks=self.num_patches,
            fn_supports_batch=True,
            disable_jit=not bool(self.jit_options),
        )(flatgrid, flat_pfield)

        # Assumes the leading axis of ``out`` has grid_size_z * grid_size_x entries.
        return ops.reshape(out, (grid_size_z, grid_size_x, *ops.shape(out)[1:]))

    def jittable_call(self, **inputs):
        """Process input data through the pipeline.

        When ``_with_batch_dim`` is set, ``call_item`` is mapped over the batch
        (presumably its leading axis) of ``inputs[self.key]``. Otherwise the inputs
        are processed as a single item.

        Returns:
            dict: ``{self.output_key: output}``.
        """
        if self._with_batch_dim:
            input_data = inputs.pop(self.key)
            output = ops.map(
                lambda x: self.call_item({self.key: x, **inputs}),
                input_data,
            )
        else:
            output = self.call_item(inputs)

        return {self.output_key: output}

    def call(self, **inputs):
        """Process input data through the pipeline.

        Returns the inputs updated with ``{self.output_key: output}``.
        """
        output = self._jittable_call(**inputs)
        inputs.update(output)
        return inputs

    def get_dict(self):
        """Get the configuration of the pipeline (adds ``num_patches`` to the params)."""
        config = super().get_dict()
        config.update({"name": "patched_grid"})
        config["params"].update({"num_patches": self.num_patches})
        return config
|
|
1392
|
-
|
|
1393
|
-
|
|
1394
|
-
## Base Operations
|
|
1395
|
-
|
|
1396
|
-
|
|
1397
|
-
@ops_registry("identity")
class Identity(Operation):
    """Identity operation: its ``call`` produces no new outputs."""

    def call(self, **kwargs) -> Dict:
        """Ignore all inputs and return an empty output dictionary."""
        no_outputs = {}
        return no_outputs
|
|
1404
|
-
|
|
1405
|
-
|
|
1406
|
-
@ops_registry("merge")
class Merge(Operation):
    """Operation that merges several input dictionaries into one."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Flag read by the base machinery (presumably lets several inputs through).
        self.allow_multiple_inputs = True

    def call(self, *args, **kwargs) -> Dict:
        """
        Merge the positional dictionaries, giving priority to later ones.

        Keyword arguments are ignored.

        Raises:
            TypeError: If any positional input is not a dict.
        """
        merged = {}
        for candidate in args:
            if isinstance(candidate, dict):
                merged.update(candidate)
                continue
            raise TypeError("All inputs must be dictionaries.")
        return merged
|
|
1424
|
-
|
|
1425
|
-
|
|
1426
|
-
@ops_registry("split")
class Split(Operation):
    """Operation that splits an input dictionary into ``n`` copies."""

    def __init__(self, n: int, **kwargs):
        super().__init__(**kwargs)
        # Number of copies produced by ``call``.
        self.n = n

    def call(self, **kwargs) -> List[Dict]:
        """
        Return ``n`` independent shallow copies of the input dictionary.
        """
        return [dict(kwargs) for _ in range(self.n)]
|
|
1439
|
-
|
|
1440
|
-
|
|
1441
|
-
@ops_registry("stack")
class Stack(Operation):
    """Stack multiple data arrays along a new axis.

    Useful to merge data from parallel pipelines.
    """

    def __init__(
        self,
        keys: Union[str, List[str], None],
        axes: Union[int, List[int], None],
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.keys, self.axes = _assert_keys_and_axes(keys, axes)

    def call(self, **kwargs) -> Dict:
        """
        Stack the inputs of all ``self.keys`` along a new axis.

        For every key, the original values of all configured keys are stacked along
        that key's axis, and the result is stored under that key.
        If a list of axes is provided, the length must match the number of keys.
        """
        # Snapshot the inputs first. Writing the stacked result back into ``kwargs``
        # inside the loop would otherwise feed an already-stacked array (with an
        # extra dimension) into later iterations and fail on the shape mismatch.
        values = [kwargs[key] for key in self.keys]
        for key, axis in zip(self.keys, self.axes):
            kwargs[key] = keras.ops.stack(values, axis=axis)
        return kwargs
|
|
1465
|
-
|
|
1466
|
-
|
|
1467
|
-
@ops_registry("mean")
class Mean(Operation):
    """Take the mean of the input data along a specific axis."""

    def __init__(self, keys, axes, **kwargs):
        super().__init__(**kwargs)

        self.keys, self.axes = _assert_keys_and_axes(keys, axes)

    def call(self, **kwargs):
        """Replace each configured key by its mean along the paired axis."""
        for name, axis in zip(self.keys, self.axes):
            reduced = ops.mean(kwargs[name], axis=axis)
            kwargs[name] = reduced
        return kwargs
|
|
1481
|
-
|
|
1482
|
-
|
|
1483
|
-
@ops_registry("simulate_rf")
class Simulate(Operation):
    """Simulate RF data from scatterer positions and magnitudes."""

    # Define operation-specific static parameters
    STATIC_PARAMS = ["n_ax", "apply_lens_correction"]

    def __init__(self, **kwargs):
        super().__init__(
            output_data_type=DataTypes.RAW_DATA,
            additional_output_keys=["n_ch"],
            **kwargs,
        )

    def call(
        self,
        scatterer_positions,
        scatterer_magnitudes,
        probe_geometry,
        apply_lens_correction,
        lens_thickness,
        lens_sound_speed,
        sound_speed,
        n_ax,
        center_frequency,
        sampling_frequency,
        t0_delays,
        initial_times,
        element_width,
        attenuation_coef,
        tx_apodizations,
        **kwargs,
    ):
        """Run ``simulate_rf`` with the given scatterers and acquisition parameters.

        Returns:
            dict: The simulated data under ``self.output_key`` and ``n_ch`` set to 1.
        """
        positions = ops.convert_to_tensor(scatterer_positions)
        magnitudes = ops.convert_to_tensor(scatterer_magnitudes)
        rf_data = simulate_rf(
            positions,
            magnitudes,
            probe_geometry=probe_geometry,
            apply_lens_correction=apply_lens_correction,
            lens_thickness=lens_thickness,
            lens_sound_speed=lens_sound_speed,
            sound_speed=sound_speed,
            n_ax=n_ax,
            center_frequency=center_frequency,
            sampling_frequency=sampling_frequency,
            t0_delays=t0_delays,
            initial_times=initial_times,
            element_width=element_width,
            attenuation_coef=attenuation_coef,
            tx_apodizations=tx_apodizations,
        )
        # Simulate always returns RF data (so single channel)
        return {self.output_key: rf_data, "n_ch": 1}
|
|
1536
|
-
|
|
1537
|
-
|
|
1538
|
-
@ops_registry("tof_correction")
class TOFCorrection(Operation):
    """Time-of-flight correction operation for ultrasound data."""

    # Parameters that the operation treats as static.
    STATIC_PARAMS = [
        "f_number",
        "apply_lens_correction",
        "grid_size_x",
        "grid_size_z",
    ]

    def __init__(self, **kwargs):
        super().__init__(
            input_data_type=DataTypes.RAW_DATA,
            output_data_type=DataTypes.ALIGNED_DATA,
            **kwargs,
        )

    def call(
        self,
        flatgrid,
        sound_speed,
        polar_angles,
        focus_distances,
        sampling_frequency,
        f_number,
        demodulation_frequency,
        t0_delays,
        tx_apodizations,
        initial_times,
        probe_geometry,
        t_peak,
        tx_waveform_indices,
        apply_lens_correction=None,
        lens_thickness=None,
        lens_sound_speed=None,
        **kwargs,
    ):
        """Perform time-of-flight correction on raw data.

        The raw data itself is read from ``kwargs[self.key]``. When the operation
        runs with a batch dimension, ``tof_correction`` is mapped over the leading
        (batch) axis with ``ops.map``; otherwise it is applied once.

        Args:
            flatgrid (ops.Tensor): Grid points at which to evaluate the time-of-flight.
            sound_speed (float): Sound speed in the medium.
            polar_angles (ops.Tensor): Polar angles for scan lines.
            focus_distances (ops.Tensor): Focus distances for scan lines.
            sampling_frequency (float): Sampling frequency.
            f_number (float): F-number for apodization.
            demodulation_frequency (float): Demodulation frequency.
            t0_delays (ops.Tensor): T0 delays.
            tx_apodizations (ops.Tensor): Transmit apodizations.
            initial_times (ops.Tensor): Initial times.
            probe_geometry (ops.Tensor): Probe element positions.
            t_peak (float): Time to peak of the transmit pulse.
            tx_waveform_indices (ops.Tensor): Index of the transmit waveform for each
                transmit. (All zero if there is only one waveform)
            apply_lens_correction (bool): Whether to apply lens correction.
            lens_thickness (float): Lens thickness.
            lens_sound_speed (float): Sound speed in the lens.

        Returns:
            dict: The TOF-corrected data under ``self.output_key``.
        """
        raw_data = kwargs[self.key]

        tof_kwargs = dict(
            flatgrid=flatgrid,
            t0_delays=t0_delays,
            tx_apodizations=tx_apodizations,
            sound_speed=sound_speed,
            probe_geometry=probe_geometry,
            initial_times=initial_times,
            sampling_frequency=sampling_frequency,
            demodulation_frequency=demodulation_frequency,
            f_number=f_number,
            polar_angles=polar_angles,
            focus_distances=focus_distances,
            t_peak=t_peak,
            tx_waveform_indices=tx_waveform_indices,
            apply_lens_correction=apply_lens_correction,
            lens_thickness=lens_thickness,
            lens_sound_speed=lens_sound_speed,
        )

        def _correct(frame):
            # Correct a single (unbatched) item with the shared parameters.
            return tof_correction(frame, **tof_kwargs)

        if self.with_batch_dim:
            corrected = ops.map(_correct, raw_data)
        else:
            corrected = _correct(raw_data)

        return {self.output_key: corrected}
|
|
1633
|
-
|
|
1634
|
-
|
|
1635
|
-
@ops_registry("pfield_weighting")
class PfieldWeighting(Operation):
    """Weighting aligned data with the pressure field."""

    def __init__(self, **kwargs):
        super().__init__(
            input_data_type=DataTypes.ALIGNED_DATA,
            output_data_type=DataTypes.ALIGNED_DATA,
            **kwargs,
        )

    def call(self, flat_pfield=None, **kwargs):
        """Multiply the data element-wise with a pressure-field weight mask.

        Args:
            flat_pfield (ops.Tensor, optional): Pressure field weight mask of shape
                (n_pix, n_tx). When None, the data is returned unchanged.

        Returns:
            dict: The (possibly) weighted data under ``self.output_key``.
        """
        data = kwargs[self.key]
        if flat_pfield is None:
            return {self.output_key: data}

        # Reorder to (n_tx, n_pix) so the mask lines up with the data's leading axes.
        weights = ops.swapaxes(flat_pfield, 0, 1)
        if self.with_batch_dim:
            # Leading singleton axis for the batch dimension.
            weights = ops.expand_dims(weights, axis=0)
        # Two trailing singleton axes so the mask broadcasts over the last two data axes.
        weights = weights[..., None, None]

        return {self.output_key: data * weights}
|
|
1674
|
-
|
|
1675
|
-
|
|
1676
|
-
@ops_registry("delay_and_sum")
class DelayAndSum(Operation):
    """Sums time-delayed signals along channels and transmits."""

    def __init__(
        self,
        reshape_grid=True,
        **kwargs,
    ):
        super().__init__(
            input_data_type=DataTypes.ALIGNED_DATA,
            output_data_type=DataTypes.BEAMFORMED_DATA,
            **kwargs,
        )
        self.reshape_grid = reshape_grid

    def process_image(self, data):
        """Beamform one (unbatched) TOF-corrected frame.

        Args:
            data (ops.Tensor): The TOF corrected input of shape `(n_tx, n_pix, n_el, n_ch)`

        Returns:
            ops.Tensor: The beamformed data of shape `(n_pix, n_ch)`
        """
        # DAS: collapse the element axis (second to last).
        per_transmit = ops.sum(data, -2)
        # Compounding: collapse the transmit axis (first).
        return ops.sum(per_transmit, 0)

    def call(self, grid=None, **kwargs):
        """Performs DAS beamforming on tof-corrected input.

        The input is read from ``kwargs[self.key]`` and has shape
        `(n_tx, grid_size_z*grid_size_x, n_el, n_ch)`, with an optional leading batch
        dimension. When ``reshape_grid`` is True, ``grid`` must be given: the flat pixel
        axis is reshaped to ``grid.shape[:2]``.

        Returns:
            dict: Beamformed data under ``self.output_key``, of shape
            `(grid_size_z*grid_size_x, n_ch)` when reshape_grid is False
            or `(grid_size_z, grid_size_x, n_ch)` when reshape_grid is True,
            with optional batch dimension.
        """
        tof_data = kwargs[self.key]

        if self.with_batch_dim:
            # Beamform each batch item separately.
            beamformed = ops.map(self.process_image, tof_data)
        else:
            beamformed = self.process_image(tof_data)

        if self.reshape_grid:
            pixel_axis = int(self.with_batch_dim)
            beamformed = reshape_axis(beamformed, grid.shape[:2], axis=pixel_axis)

        return {self.output_key: beamformed}
|
|
1736
|
-
|
|
1737
|
-
|
|
1738
|
-
def envelope_detect(data, axis=-3):
    """Envelope detection of RF or IQ signals.

    If the last axis has size 2, the data is treated as IQ stored in two real
    channels and converted to complex with ``channels_to_complex``. Otherwise the
    Hilbert transform is applied along ``axis``; the FFT length is zero-padded to the
    next power of two, the result is cropped back to the original length, and the
    trailing channel axis is squeezed. In both cases the magnitude of the complex
    signal is returned as float32.

    Args:
        data (Tensor): The beamformed data of shape (..., grid_size_z, grid_size_x, n_ch).
        axis (int): Axis along which to apply the Hilbert transform. Defaults to -3.

    Returns:
        Tensor: The envelope detected data of shape (..., grid_size_z, grid_size_x).
    """
    if data.shape[-1] == 2:
        analytic = channels_to_complex(data)
    else:
        num_samples = ops.shape(data)[axis]

        # FFT length M = 2^ceil(log2(n_ax)), the next power of 2.
        # see https://github.com/tue-bmd/zea/discussions/147
        exponent = ops.ceil(ops.log2(ops.cast(num_samples, "float32")))
        fft_length = ops.cast(2**exponent, "int32")

        analytic = hilbert(data, N=fft_length, axis=axis)
        # Crop the zero-padded transform back to the original number of samples.
        analytic = ops.take(analytic, ops.arange(num_samples), axis=axis)
        analytic = ops.squeeze(analytic, axis=-1)

    magnitude = ops.sqrt(ops.real(analytic) ** 2 + ops.imag(analytic) ** 2)
    return ops.cast(magnitude, "float32")
|
|
1776
|
-
|
|
1777
|
-
|
|
1778
|
-
@ops_registry("envelope_detect")
class EnvelopeDetect(Operation):
    """Envelope detection of RF signals (operation wrapper around ``envelope_detect``)."""

    def __init__(
        self,
        axis=-3,
        **kwargs,
    ):
        super().__init__(
            input_data_type=DataTypes.BEAMFORMED_DATA,
            output_data_type=DataTypes.ENVELOPE_DATA,
            **kwargs,
        )
        # Axis along which the Hilbert transform is applied for real-valued input.
        self.axis = axis

    def call(self, **kwargs):
        """Detect the envelope of the data stored under ``self.key``.

        The input has shape (..., grid_size_z, grid_size_x, n_ch).

        Returns:
            dict: The envelope data of shape (..., grid_size_z, grid_size_x) under
            ``self.output_key``.
        """
        envelope = envelope_detect(kwargs[self.key], axis=self.axis)
        return {self.output_key: envelope}
|
|
1807
|
-
|
|
1808
|
-
|
|
1809
|
-
@ops_registry("upmix")
class UpMix(Operation):
    """Upmix IQ data to RF data."""

    def __init__(
        self,
        upsampling_rate=1,
        **kwargs,
    ):
        """Initialize the UpMix operation.

        Args:
            upsampling_rate (int): Upsampling rate forwarded to ``upmix``. Defaults to 1.
        """
        super().__init__(
            **kwargs,
        )
        self.upsampling_rate = upsampling_rate

    def call(
        self,
        sampling_frequency=None,
        center_frequency=None,
        **kwargs,
    ):
        """Upmix the data stored under ``self.key``.

        A last axis of size 1 is taken to be RF data already: a warning is logged and the
        data is passed through unchanged. A last axis of size 2 is converted to complex
        with ``channels_to_complex`` first. Any other input is passed to ``upmix`` as-is
        (presumably already complex IQ data — TODO confirm).

        Args:
            sampling_frequency (float): Sampling frequency, forwarded to ``upmix``.
            center_frequency (float): Center frequency, forwarded to ``upmix``.

        Returns:
            dict: The upmixed data, with a trailing singleton channel axis, under
            ``self.output_key``.
        """
        data = kwargs[self.key]

        if data.shape[-1] == 1:
            log.warning("Upmixing is not applicable to RF data.")
            # Every ``call`` must return an output dict; returning the bare tensor here
            # would break the pipeline's output handling.
            return {self.output_key: data}
        if data.shape[-1] == 2:
            data = channels_to_complex(data)

        data = upmix(data, sampling_frequency, center_frequency, self.upsampling_rate)
        data = ops.expand_dims(data, axis=-1)
        return {self.output_key: data}
|
|
1840
|
-
|
|
1841
|
-
|
|
1842
|
-
def log_compress(data, eps=1e-16):
    """Apply logarithmic compression to data: ``20 * log10(data)``.

    Entries that are exactly zero are replaced by ``eps`` (cast to the data's dtype)
    before taking the logarithm, so they map to a finite value instead of ``-inf``.
    """
    floor = ops.convert_to_tensor(eps, dtype=data.dtype)
    safe_data = ops.where(data == 0, floor, data)
    return 20 * keras.ops.log10(safe_data)
|
|
1847
|
-
|
|
1848
|
-
|
|
1849
|
-
@ops_registry("log_compress")
class LogCompress(Operation):
    """Logarithmic compression of data."""

    def __init__(self, clip: bool = True, **kwargs):
        """Initialize the LogCompress operation.

        Args:
            clip (bool): Whether to clip the output to a dynamic range. Defaults to True.
        """
        super().__init__(
            input_data_type=DataTypes.ENVELOPE_DATA,
            output_data_type=DataTypes.IMAGE,
            **kwargs,
        )
        self.clip = clip

    def call(self, dynamic_range=None, **kwargs):
        """Log-compress the data stored under ``self.key``.

        Args:
            dynamic_range (tuple, optional): ``(low, high)`` dynamic range in dB used for
                clipping when ``self.clip`` is True. Defaults to ``DEFAULT_DYNAMIC_RANGE``.

        Returns:
            dict: The log-compressed data under ``self.output_key``.
        """
        envelope = kwargs[self.key]

        if dynamic_range is None:
            dynamic_range = ops.array(DEFAULT_DYNAMIC_RANGE)
        dynamic_range = ops.cast(dynamic_range, envelope.dtype)

        compressed = log_compress(envelope)
        if self.clip:
            low, high = dynamic_range[0], dynamic_range[1]
            compressed = ops.clip(compressed, low, high)

        return {self.output_key: compressed}
|
|
1886
|
-
|
|
1887
|
-
|
|
1888
|
-
def normalize(data, output_range, input_range=None):
    """Normalize data to a given range.

    Equivalent to `translate` with clipping: the data is first clipped to the input
    range and then linearly mapped onto ``output_range``.

    Args:
        data (ops.Tensor): Input data to normalize.
        output_range (tuple): Range to which data should be mapped, e.g., (0, 1).
        input_range (tuple, optional): ``(minval, maxval)`` range of the input data.
            Either bound (or the whole tuple) may be None, in which case it is
            computed from the data. Defaults to None.
    """
    minval, maxval = (None, None) if input_range is None else input_range

    if minval is None:
        minval = ops.min(data)
    if maxval is None:
        maxval = ops.max(data)

    clipped = ops.clip(data, minval, maxval)
    return translate(clipped, (minval, maxval), output_range)
|
|
1910
|
-
|
|
1911
|
-
|
|
1912
|
-
@ops_registry("normalize")
class Normalize(Operation):
    """Normalize data to a given range."""

    def __init__(self, output_range=None, input_range=None, **kwargs):
        """Initialize the Normalize operation.

        Args:
            output_range (tuple, optional): Range to which data is mapped.
                Defaults to (0, 1).
            input_range (tuple, optional): Fixed ``(minval, maxval)`` input range. If
                None, the range is taken from the ``minval``/``maxval`` inputs or
                computed from the data at call time.
        """
        super().__init__(additional_output_keys=["minval", "maxval"], **kwargs)
        output_range = (0, 1) if output_range is None else output_range
        self.output_range = self.to_float32(output_range)
        self.input_range = self.to_float32(input_range)
        assert output_range is None or len(output_range) == 2
        assert input_range is None or len(input_range) == 2

    @staticmethod
    def to_float32(data):
        """Converts an iterable to float32 and leaves None values as is."""
        if data is None:
            return None
        return [None if value is None else np.float32(value) for value in data]

    def call(self, **kwargs):
        """Normalize the data stored under ``self.key``.

        The input range is resolved in this order: the ``input_range`` given at
        construction; otherwise the ``minval``/``maxval`` entries of the inputs (which
        allows reusing the range of an earlier frame to avoid flicker); any bound
        still missing is computed from the data.

        Returns:
            dict: Dictionary containing normalized data, along with the computed
            or provided input range (minval and maxval).
        """
        data = kwargs[self.key]

        if self.input_range is not None:
            minval, maxval = self.input_range
        else:
            maxval = kwargs.get("maxval")
            minval = kwargs.get("minval")

        if minval is None:
            minval = ops.min(data)
        if maxval is None:
            maxval = ops.max(data)

        normalized = normalize(data, output_range=self.output_range, input_range=(minval, maxval))
        return {self.output_key: normalized, "minval": minval, "maxval": maxval}
|
|
1967
|
-
|
|
1968
|
-
|
|
1969
|
-
@ops_registry("scan_convert")
class ScanConvert(Operation):
    """Scan convert images to cartesian coordinates."""

    STATIC_PARAMS = ["fill_value"]

    def __init__(self, order=1, **kwargs):
        """Initialize the ScanConvert operation.

        Args:
            order (int, optional): Interpolation order. Defaults to 1. Currently only
                GPU support for order=1; for ``order > 1`` the operation is marked as
                not jittable and a warning is logged.
        """
        if order > 1:
            jittable = False
            log.warning(
                "GPU support for order > 1 is not available. " + "Disabling jit for ScanConvert."
            )
        else:
            jittable = True

        super().__init__(
            input_data_type=DataTypes.IMAGE,
            output_data_type=DataTypes.IMAGE_SC,
            jittable=jittable,
            additional_output_keys=[
                "resolution",
                "x_lim",
                "y_lim",
                "z_lim",
                "rho_range",
                "theta_range",
                "phi_range",
                "d_rho",
                "d_theta",
                "d_phi",
            ],
            **kwargs,
        )
        self.order = order

    def call(
        self,
        rho_range=None,
        theta_range=None,
        phi_range=None,
        resolution=None,
        coordinates=None,
        fill_value=None,
        **kwargs,
    ):
        """Scan convert images to cartesian coordinates.

        Args:
            rho_range (Tuple): Range of the rho axis in the polar coordinate system.
                Defined in meters.
            theta_range (Tuple): Range of the theta axis in the polar coordinate system.
                Defined in radians.
            phi_range (Tuple): Range of the phi axis in the polar coordinate system.
                Defined in radians.
            resolution (float): Resolution of the output image in meters per pixel.
                if None, the resolution is computed based on the input data.
            coordinates (Tensor): Coordinates for scan conversion. If None, will be
                computed based on rho_range, theta_range, phi_range and resolution.
                Must be provided when the operation is jit-compiled.
            fill_value (float): Value to fill the image with outside the defined region.
                Defaults to NaN.

        Returns:
            dict: The scan-converted data under ``self.output_key`` plus the parameters
            returned by ``scan_convert``.
        """
        if fill_value is None:
            fill_value = np.nan

        data = kwargs[self.key]

        if self._jit_compile and self.jittable:
            # Separate the two sentences with a space (previously "conversion.You").
            assert coordinates is not None, (
                "coordinates must be provided to jit scan conversion. "
                "You can set ScanConvert(jit_compile=False) to disable jitting."
            )

        data_out, parameters = scan_convert(
            data,
            rho_range,
            theta_range,
            phi_range,
            resolution,
            coordinates,
            fill_value,
            self.order,
            with_batch_dim=self.with_batch_dim,
        )

        return {self.output_key: data_out, **parameters}
|
|
2061
|
-
|
|
2062
|
-
|
|
2063
|
-
@ops_registry("gaussian_blur")
class GaussianBlur(Operation):
    """
    GaussianBlur is an operation that applies a Gaussian blur to an input image.
    Uses scipy.ndimage.gaussian_filter to create a kernel. Each channel is blurred
    independently.
    """

    def __init__(
        self,
        sigma: float,
        kernel_size: int | None = None,
        pad_mode="symmetric",
        truncate=4.0,
        **kwargs,
    ):
        """
        Args:
            sigma (float): Standard deviation for Gaussian kernel.
            kernel_size (int, optional): The size of the kernel. If None, the kernel
                size is calculated based on the sigma and truncate. Default is None.
            pad_mode (str): Padding mode for the input image. Default is 'symmetric'.
            truncate (float): Truncate the filter at this many standard deviations.
        """
        super().__init__(**kwargs)
        if kernel_size is None:
            radius = round(truncate * sigma)
            self.kernel_size = 2 * radius + 1
        else:
            self.kernel_size = kernel_size
        self.sigma = sigma
        self.pad_mode = pad_mode
        self.radius = self.kernel_size // 2
        self.kernel = self.get_kernel()

    def get_kernel(self):
        """
        Create a gaussian kernel for blurring.

        The kernel is the response of ``scipy.ndimage.gaussian_filter`` to a centered
        impulse, renormalized to sum to 1.

        Returns:
            kernel (Tensor): A gaussian kernel for blurring.
                Shape is (kernel_size, kernel_size, 1, 1).
        """
        impulse = np.zeros((self.kernel_size, self.kernel_size))
        impulse[self.radius, self.radius] = 1
        kernel = scipy.ndimage.gaussian_filter(impulse, sigma=self.sigma, mode="constant")
        # When the window is smaller than scipy's own support, the taps that fall
        # outside it are lost and the kernel sums to < 1 (darkening the output).
        # Renormalize; this is a no-op (up to rounding) when nothing is truncated.
        kernel = (kernel / kernel.sum()).astype(np.float32)
        kernel = kernel[:, :, None, None]
        return ops.convert_to_tensor(kernel)

    def call(self, **kwargs):
        """Blur the data stored under ``self.key``.

        The data is channels-last: (batch, height, width, channels) when
        ``with_batch_dim`` is True, otherwise (height, width, channels).

        Returns:
            dict: The blurred data, same shape as the input, under ``self.output_key``.
        """
        data = kwargs[self.key]

        # Add batch dimension if not present
        if not self.with_batch_dim:
            data = data[None]

        # Build a diagonal (k, k, C, C) kernel so each channel is blurred on its own.
        # Tiling the (k, k, 1, 1) kernel over both channel axes would make every
        # output channel the sum of all blurred input channels.
        n_channels = data.shape[-1]
        kernel = self.kernel * ops.eye(n_channels, dtype="float32")

        # Pad the input image according to the padding mode
        padded = ops.pad(
            data,
            [[0, 0], [self.radius, self.radius], [self.radius, self.radius], [0, 0]],
            mode=self.pad_mode,
        )

        # Apply the gaussian kernel to the padded image
        out = ops.conv(padded, kernel, padding="valid", data_format="channels_last")

        # Crop to the input size (matters for even kernel sizes)
        out = ops.slice(
            out,
            [0, 0, 0, 0],
            [out.shape[0], data.shape[1], data.shape[2], data.shape[3]],
        )

        # Remove batch dimension if it was not present before
        if not self.with_batch_dim:
            out = ops.squeeze(out, axis=0)

        return {self.output_key: out}
|
|
2145
|
-
|
|
2146
|
-
|
|
2147
|
-
@ops_registry("lee_filter")
class LeeFilter(Operation):
    """
    The Lee filter is a speckle reduction filter commonly used in synthetic aperture radar (SAR)
    and ultrasound image processing. It smooths the image while preserving edges and details.
    This implementation uses Gaussian filter for local statistics and treats channels independently.

    Lee, J.S. (1980). Digital image enhancement and noise filtering by use of local statistics.
    IEEE Transactions on Pattern Analysis and Machine Intelligence, (2), 165-168.
    """

    def __init__(self, sigma=3, kernel_size=None, pad_mode="symmetric", **kwargs):
        """
        Args:
            sigma (float): Standard deviation for Gaussian kernel. Default is 3.
            kernel_size (int, optional): Size of the Gaussian kernel. If None,
                it will be calculated based on sigma.
            pad_mode (str): Padding mode to be used for Gaussian blur. Default is "symmetric".
        """
        super().__init__(**kwargs)
        self.sigma = sigma
        self.kernel_size = kernel_size
        self.pad_mode = pad_mode

        # Create a GaussianBlur instance for computing local statistics
        self.gaussian_blur = GaussianBlur(
            sigma=self.sigma,
            kernel_size=self.kernel_size,
            pad_mode=self.pad_mode,
            with_batch_dim=self.with_batch_dim,
            jittable=self._jittable,
            key=self.key,
        )

    @property
    def with_batch_dim(self):
        """Get the with_batch_dim property of the LeeFilter operation."""
        return self._with_batch_dim

    @with_batch_dim.setter
    def with_batch_dim(self, value):
        """Set the with_batch_dim property of the LeeFilter operation."""
        self._with_batch_dim = value
        # Keep the inner blur in sync; it may not exist yet during base-class init.
        if hasattr(self, "gaussian_blur"):
            self.gaussian_blur.with_batch_dim = value

    def call(self, **kwargs):
        """Apply the Lee filter to the data stored under ``self.key``.

        The data is channels-last: (batch, height, width, channels) when
        ``with_batch_dim`` is True, otherwise (height, width, channels).

        Returns:
            dict: The filtered data under ``self.output_key``.
        """
        data = kwargs[self.key]

        # Apply Gaussian blur to get local mean
        img_mean = self.gaussian_blur.call(**kwargs)[self.gaussian_blur.output_key]

        # Apply Gaussian blur to squared data to get local squared mean
        data_squared = data**2
        kwargs[self.gaussian_blur.key] = data_squared
        img_sqr_mean = self.gaussian_blur.call(**kwargs)[self.gaussian_blur.output_key]

        # Calculate local variance
        img_variance = img_sqr_mean - img_mean**2

        # Global variance per image and per channel: reduce over the spatial axes.
        # The channel axis is last with or without a batch axis, so the spatial axes
        # are (-3, -2) in both cases (previously (-2, -1) was used without a batch
        # axis, which reduced over width and channels instead).
        overall_variance = ops.var(data, axis=(-3, -2), keepdims=True)

        # Calculate adaptive weights
        # NOTE(review): a perfectly constant image gives 0 / 0 here.
        img_weights = img_variance / (img_variance + overall_variance)

        # Apply Lee filter formula
        img_output = img_mean + img_weights * (data - img_mean)

        return {self.output_key: img_output}
|
|
2220
|
-
|
|
2221
|
-
|
|
2222
|
-
@ops_registry("demodulate")
class Demodulate(Operation):
    """Demodulates the input data to baseband. After this operation, the carrier frequency
    is removed (0 Hz) and the data is in IQ format stored in two real valued channels."""

    def __init__(self, axis=-3, **kwargs):
        super().__init__(
            input_data_type=DataTypes.RAW_DATA,
            output_data_type=DataTypes.RAW_DATA,
            jittable=True,
            additional_output_keys=["demodulation_frequency", "center_frequency", "n_ch"],
            **kwargs,
        )
        # Axis forwarded to ``demodulate``.
        self.axis = axis

    def call(self, center_frequency=None, sampling_frequency=None, **kwargs):
        """Demodulate the data stored under ``self.key``.

        Returns:
            dict: The two-channel IQ data under ``self.output_key``, together with
            updated metadata: ``demodulation_frequency`` (the incoming center
            frequency), ``center_frequency`` reset to 0.0 and ``n_ch`` set to 2.
        """
        rf_data = kwargs[self.key]

        iq_two_channel = demodulate(
            data=rf_data,
            center_frequency=center_frequency,
            sampling_frequency=sampling_frequency,
            axis=self.axis,
        )

        return {
            self.output_key: iq_two_channel,
            "demodulation_frequency": center_frequency,
            "center_frequency": 0.0,
            "n_ch": 2,
        }
|
|
2256
|
-
|
|
2257
|
-
|
|
2258
|
-
@ops_registry("lambda")
class Lambda(Operation):
    """Use any function as an operation.

    Keyword arguments naming a parameter of ``func`` are bound to it with
    ``functools.partial``; all other keyword arguments configure the Operation.
    After binding, ``func`` must take exactly one required positional argument:
    the data stored under ``self.key``.
    """

    def __init__(self, func, **kwargs):
        # Split kwargs into kwargs for partial and __init__. Routing is based on the
        # signature's named parameters: ``func.__code__.co_varnames`` also lists local
        # variables (which would then be passed to ``func`` and fail at call time)
        # and does not exist for partials or other callable objects.
        param_names = Lambda._parameter_names(func)
        op_kwargs = {k: v for k, v in kwargs.items() if k not in param_names}
        func_kwargs = {k: v for k, v in kwargs.items() if k in param_names}
        Lambda._check_if_unary(func, **func_kwargs)

        super().__init__(**op_kwargs)
        self.func = partial(func, **func_kwargs)

    @staticmethod
    def _parameter_names(func):
        """Return the names of ``func``'s explicitly named (non-variadic) parameters."""
        return {
            name
            for name, p in inspect.signature(func).parameters.items()
            if p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
        }

    @staticmethod
    def _check_if_unary(func, **kwargs):
        """Checks if the kwargs are sufficient to call the function as a unary operation.

        Raises:
            ValueError: if, after removing the parameters supplied in ``kwargs``, the
                number of required positional parameters is not exactly one.
        """
        sig = inspect.signature(func)
        # Remove arguments that are already provided in func_kwargs
        params = list(sig.parameters.values())
        remaining = [p for p in params if p.name not in kwargs]
        # Count required positional arguments (excluding self/cls)
        required_positional = [
            p
            for p in remaining
            if p.default is p.empty and p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        ]
        if len(required_positional) != 1:
            # partials and callable objects have no __name__.
            func_name = getattr(func, "__name__", repr(func))
            raise ValueError(
                f"Partial of {func_name} must be callable with exactly one required "
                f"positional argument, we still need: {required_positional}."
            )

    def call(self, **kwargs):
        """Apply the wrapped function to the data stored under ``self.key``."""
        data = kwargs[self.key]
        data = self.func(data)
        return {self.output_key: data}
|
|
2294
|
-
|
|
2295
|
-
|
|
2296
|
-
@ops_registry("pad")
class Pad(Operation, DataLayer):
    """Pad layer for padding tensors to a specified shape.

    Args:
        target_shape (list or tuple): Shape to pad the input to. When ``axis`` is
            given, it only contains the sizes for those axes.
        uniform (bool, optional): If True (default), the padding of every axis is
            split over both sides; if False, all padding is added at the end.
        axis (int or list of int, optional): Axis or axes that ``target_shape``
            refers to. Defaults to None, meaning ``target_shape`` lists every axis.
        fail_on_bigger_shape (bool, optional): If True (default), raise a
            ValueError when the input is larger than ``target_shape`` along any
            axis; if False, such axes are left unchanged.
        pad_kwargs (dict, optional): Extra keyword arguments forwarded to
            ``self.backend.numpy.pad``. Defaults to an empty dict.
    """

    def __init__(
        self,
        target_shape: list | tuple,
        uniform: bool = True,
        axis: Union[int, List[int]] = None,
        fail_on_bigger_shape: bool = True,
        pad_kwargs: dict = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.target_shape = target_shape
        self.uniform = uniform
        self.axis = axis
        self.pad_kwargs = pad_kwargs or {}
        self.fail_on_bigger_shape = fail_on_bigger_shape

    @staticmethod
    def _format_target_shape(shape_array, target_shape, axis):
        """Expand a per-axis ``target_shape`` into a full-rank target shape.

        Axes that are not listed in ``axis`` keep their current size taken from
        ``shape_array``. Negative axes are mapped to positive indices first.
        """
        if isinstance(axis, int):
            axis = [axis]
        assert len(axis) == len(target_shape), (
            "The length of axis must be equal to the length of target_shape."
        )
        axis = map_negative_indices(axis, len(shape_array))

        target_shape = [
            target_shape[axis.index(i)] if i in axis else shape_array[i]
            for i in range(len(shape_array))
        ]
        return target_shape

    def pad(
        self,
        z,
        target_shape: list | tuple,
        uniform: bool = True,
        axis: Union[int, List[int]] = None,
        fail_on_bigger_shape: bool = True,
        **kwargs,
    ):
        """
        Pads the input tensor `z` to the specified shape.

        Parameters:
            z (tensor): The input tensor to be padded.
            target_shape (list or tuple): The target shape to pad the tensor to.
            uniform (bool, optional): If True, the padding of each axis is split over
                both sides (the extra element of an odd amount goes at the end).
                If False, all padding is added at the end. Default is True.
            axis (int or list of int, optional): The axis or axes along which `target_shape`
                was specified. If None, `len(target_shape) == len(ops.shape(z))` must hold.
                Default is None.
            fail_on_bigger_shape (bool, optional): If True (default), raises an error if any
                target dimension is smaller than the input shape; if False, pads only where
                the target shape exceeds the input shape and leaves other dimensions unchanged.
            kwargs: Additional keyword arguments to pass to the padding function.

        Returns:
            tensor: The padded tensor with the specified shape.

        Raises:
            ValueError: If a target dimension is smaller than the input dimension
                while `fail_on_bigger_shape` is True.
        """
        shape_array = self.backend.shape(z)

        # When axis is provided, convert target_shape to a full-rank shape
        if axis is not None:
            target_shape = self._format_target_shape(shape_array, target_shape, axis)

        if not fail_on_bigger_shape:
            # Never shrink: axes already larger than the target keep their size.
            target_shape = [max(target_shape[i], shape_array[i]) for i in range(len(shape_array))]

        # Compute the padding required for each dimension
        pad_shape = np.array(target_shape) - shape_array

        # Create the (before, after) paddings array, one row per dimension
        if uniform:
            # Split the padding over both sides. For an odd amount the extra
            # element goes at the END: before = p // 2, after = p - p // 2.
            # NOTE(review): keras CenterCrop documents left-padding by one pixel
            # for odd differences; this implementation pads the extra element at
            # the end instead — confirm which side is intended.
            # https://keras.io/api/layers/preprocessing_layers/image_preprocessing/center_crop/
            pad_before = pad_shape // 2
            pad_after = pad_shape - pad_before
            paddings = np.stack([pad_before, pad_after], axis=1)
        else:
            paddings = np.stack([np.zeros_like(pad_shape), pad_shape], axis=1)

        if np.any(paddings < 0):
            raise ValueError(
                f"Target shape {target_shape} must be greater than or equal "
                f"to the input shape {shape_array}."
            )

        return self.backend.numpy.pad(z, paddings, **kwargs)

    def call(self, **kwargs):
        """Pad ``kwargs[self.key]`` using the settings given at construction."""
        data = kwargs[self.key]
        padded_data = self.pad(
            data,
            self.target_shape,
            self.uniform,
            self.axis,
            self.fail_on_bigger_shape,
            **self.pad_kwargs,
        )
        return {self.output_key: padded_data}
|
|
2400
|
-
|
|
2401
|
-
|
|
2402
|
-
@ops_registry("companding")
class Companding(Operation):
    """Companding according to the A- or μ-law algorithm.

    Invertible compressing operation. Used to compress
    dynamic range of input data (and subsequently expand).
    The input is clipped to [-1, 1] before compressing or expanding.

    μ-law companding:
    https://en.wikipedia.org/wiki/%CE%9C-law_algorithm
    A-law companding:
    https://en.wikipedia.org/wiki/A-law_algorithm

    Args:
        expand (bool, optional): If set to False (default),
            data is compressed, else expanded.
        comp_type (str): either `a` or `mu` (case-insensitive). Defaults to `mu`.

    The compression parameters are arguments of :meth:`call`:
        mu (float, optional): μ-law compression parameter. Defaults to 255.
        A (float, optional): A-law compression parameter. Defaults to 87.6.
    """

    def __init__(self, expand=False, comp_type="mu", **kwargs):
        super().__init__(**kwargs)
        self.expand = expand
        self.comp_type = comp_type.lower()
        if self.comp_type not in ["mu", "a"]:
            raise ValueError("comp_type must be 'mu' or 'a'.")

        # Select the companding function once so `call` does not need to branch.
        if self.comp_type == "mu":
            self._compand_func = self._mu_law_expand if self.expand else self._mu_law_compress
        else:
            self._compand_func = self._a_law_expand if self.expand else self._a_law_compress

    @staticmethod
    def _mu_law_compress(x, mu=255, **kwargs):
        """μ-law compression: sign(x) * ln(1 + mu * |x|) / ln(1 + mu)."""
        x = ops.clip(x, -1, 1)
        return ops.sign(x) * ops.log(1.0 + mu * ops.abs(x)) / ops.log(1.0 + mu)

    @staticmethod
    def _mu_law_expand(y, mu=255, **kwargs):
        """Inverse of :meth:`_mu_law_compress`."""
        y = ops.clip(y, -1, 1)
        return ops.sign(y) * ((1.0 + mu) ** ops.abs(y) - 1.0) / mu

    @staticmethod
    def _a_law_compress(x, A=87.6, **kwargs):
        """A-law compression (piecewise linear below 1/A, logarithmic above)."""
        x = ops.clip(x, -1, 1)
        x_sign = ops.sign(x)
        x_abs = ops.abs(x)
        A_log = ops.log(A)
        val1 = x_sign * A * x_abs / (1.0 + A_log)
        # Clamp the log argument to >= 1. Wherever val2 is selected
        # (|x| >= 1/A, i.e. A * |x| >= 1) this changes nothing, but it avoids
        # log(0) = -inf in the unselected branch, which would otherwise turn into
        # NaN gradients through `ops.where`.
        val2 = x_sign * (1.0 + ops.log(ops.maximum(A * x_abs, 1.0))) / (1.0 + A_log)
        return ops.where(x_abs < (1.0 / A), val1, val2)

    @staticmethod
    def _a_law_expand(y, A=87.6, **kwargs):
        """Inverse of :meth:`_a_law_compress`."""
        y = ops.clip(y, -1, 1)
        y_sign = ops.sign(y)
        y_abs = ops.abs(y)
        A_log = ops.log(A)
        val1 = y_sign * y_abs * (1.0 + A_log) / A
        val2 = y_sign * ops.exp(y_abs * (1.0 + A_log) - 1.0) / A
        return ops.where(y_abs < (1.0 / (1.0 + A_log)), val1, val2)

    def call(self, mu=255, A=87.6, **kwargs):
        """Compress or expand ``kwargs[self.key]``.

        Args:
            mu (float, optional): μ-law parameter (used when comp_type is `mu`).
            A (float, optional): A-law parameter (used when comp_type is `a`).
        """
        data = kwargs[self.key]

        # Match the parameter dtype to the data to avoid mixed-dtype arithmetic.
        mu = ops.cast(mu, data.dtype)
        A = ops.cast(A, data.dtype)

        data_out = self._compand_func(data, mu=mu, A=A)
        return {self.output_key: data_out}
|
|
2474
|
-
|
|
2475
|
-
|
|
2476
|
-
@ops_registry("downsample")
class Downsample(Operation):
    """Downsample data along a specific axis.

    Keeps every ``factor``-th sample along ``axis``, starting at index ``phase``.
    When passed to :meth:`call`, ``sampling_frequency`` and ``n_ax`` are updated
    accordingly and returned as additional outputs.

    Args:
        factor (int, optional): Downsampling factor. Defaults to 1 (no-op).
        phase (int, optional): Index of the first kept sample. Defaults to 0.
        axis (int, optional): Axis to downsample. Defaults to -3.
    """

    def __init__(self, factor: int = 1, phase: int = 0, axis: int = -3, **kwargs):
        super().__init__(
            additional_output_keys=["sampling_frequency", "n_ax"],
            **kwargs,
        )
        self.factor = factor
        self.phase = phase
        self.axis = axis

    def call(self, sampling_frequency=None, n_ax=None, **kwargs):
        data = kwargs[self.key]
        length = ops.shape(data)[self.axis]
        sample_idx = ops.arange(self.phase, length, self.factor)
        data_downsampled = ops.take(data, sample_idx, axis=self.axis)

        output = {self.output_key: data_downsampled}
        # downsampling also affects the sampling frequency
        if sampling_frequency is not None:
            output["sampling_frequency"] = sampling_frequency / self.factor
        if n_ax is not None:
            # Number of indices in range(phase, n_ax, factor), i.e.
            # ceil((n_ax - phase) / factor). A plain `n_ax // factor` undercounts
            # whenever n_ax is not a multiple of factor (n_ax=9, factor=2 keeps 5
            # samples, not 4). Assumes n_ax is the length of the downsampled
            # axis — TODO confirm for a non-default `axis`.
            output["n_ax"] = (n_ax - self.phase + self.factor - 1) // self.factor
        return output
|
|
2504
|
-
|
|
2505
|
-
|
|
2506
|
-
@ops_registry("branched_pipeline")
class BranchedPipeline(Operation):
    """Operation that feeds the same inputs through several parallel branches.

    Each branch receives all input keyword arguments. The per-branch output
    dictionaries are then combined with the configured merge strategy.
    """

    def __init__(self, branches=None, merge_strategy="nested", **kwargs):
        """Initialize a branched pipeline.

        Args:
            branches (List[Union[List, Pipeline, Operation]]): Branch specifications.
                A list is turned into an operation chain; a Pipeline or Operation is
                used as-is. Branches are named ``branch_1``, ``branch_2``, ...
            merge_strategy (str or callable): How to merge the outputs from branches:
                - "nested" (default): Return outputs as a dictionary keyed by branch name
                - "flatten": Flatten outputs by prefixing keys with the branch name
                - "suffix": Flatten outputs by suffixing keys with the branch name
                - callable: A custom merge function that accepts the branch outputs dict
            **kwargs: Additional arguments for the Operation base class

        Raises:
            ValueError: If a branch has an unsupported type or the merge strategy
                is unknown or of an invalid type.
        """
        super().__init__(**kwargs)

        specs = [] if branches is None else branches

        self.branches = {}
        for number, spec in enumerate(specs, start=1):
            name = f"branch_{number}"
            if isinstance(spec, list):
                self.branches[name] = make_operation_chain(spec)
            elif isinstance(spec, (Pipeline, Operation)):
                self.branches[name] = spec
            else:
                raise ValueError(
                    f"Branch must be a list, Pipeline, or Operation, got {type(spec)}"
                )

        self.merge_strategy = merge_strategy
        if isinstance(merge_strategy, str):
            named_strategies = {
                "nested": lambda outputs: outputs,
                "flatten": self.flatten_outputs,
                "suffix": self.suffix_merge_outputs,
            }
            if merge_strategy not in named_strategies:
                raise ValueError(f"Unknown merge_strategy: {merge_strategy}")
            self._merge_function = named_strategies[merge_strategy]
        elif callable(merge_strategy):
            self._merge_function = merge_strategy
        else:
            raise ValueError("Invalid merge_strategy type provided.")

    def call(self, **kwargs):
        """Run every branch on the inputs and merge their outputs.

        Args:
            **kwargs: Input keyword arguments, handed to every branch.

        Returns:
            dict: Branch outputs combined according to the merge strategy.
        """
        # `**kwargs` unpacking gives each branch its own (shallow) argument dict,
        # so branches cannot interfere with each other's inputs.
        branch_outputs = {name: branch(**kwargs) for name, branch in self.branches.items()}
        return self._merge_function(branch_outputs)

    @staticmethod
    def _rename_and_merge(outputs, make_key):
        """Merge per-branch dicts into one flat dict using ``make_key(branch, key)``."""
        merged = {}
        for branch_name, branch_dict in outputs.items():
            for key, value in branch_dict.items():
                new_key = make_key(branch_name, key)
                if new_key in merged:
                    raise ValueError(f"Key collision detected for {new_key}")
                merged[new_key] = value
        return merged

    def flatten_outputs(self, outputs: dict) -> dict:
        """
        Flatten a nested dictionary by prefixing keys with the branch name.
        For each branch, the resulting key is "{branch_name}_{original_key}".
        """
        return self._rename_and_merge(outputs, lambda branch, key: f"{branch}_{key}")

    def suffix_merge_outputs(self, outputs: dict) -> dict:
        """
        Flatten a nested dictionary by suffixing keys with the branch name.
        For each branch, the resulting key is "{original_key}_{branch_name}".
        """
        return self._rename_and_merge(outputs, lambda branch, key: f"{key}_{branch}")

    def get_config(self):
        """Return the config dictionary for serialization."""
        config = super().get_config()

        def branch_config(branch):
            if isinstance(branch, Pipeline):
                return branch.get_config()
            if isinstance(branch, list):
                return {"operations": [op.get_config() for op in branch]}
            return branch.get_config()

        branch_configs = {name: branch_config(branch) for name, branch in self.branches.items()}

        # Custom merge functions are stored by name (or "custom" if nameless).
        if isinstance(self.merge_strategy, str):
            merge_strategy_config = self.merge_strategy
        else:
            merge_strategy_config = getattr(self.merge_strategy, "__name__", "custom")

        config.update(
            {
                "branches": branch_configs,
                "merge_strategy": merge_strategy_config,
            }
        )
        return config

    def get_dict(self):
        """Get the configuration of the operation, including nested branches."""
        config = super().get_dict()
        config["name"] = "branched_pipeline"

        def branch_dict(branch):
            if isinstance(branch, Pipeline):
                return branch.get_dict()
            if isinstance(branch, list):
                return [op.get_dict() for op in branch]
            return branch.get_dict()

        config["branches"] = {name: branch_dict(branch) for name, branch in self.branches.items()}
        config["merge_strategy"] = self.merge_strategy
        return config
|
|
2669
|
-
|
|
2670
|
-
|
|
2671
|
-
@ops_registry("threshold")
class Threshold(Operation):
    """Threshold an array, setting values below/above a threshold to a fill value.

    Args:
        threshold_type (str, optional): ``"hard"`` (default) or ``"soft"``.
            Hard thresholding replaces the values on the thresholded side by the
            fill value. Soft thresholding returns
            ``maximum(data - threshold, 0) + fill`` when ``below_threshold`` is True
            and ``minimum(data - threshold, 0) + fill`` otherwise.
        below_threshold (bool, optional): If True (default), values below the
            threshold are affected; otherwise values above it.
        fill_value (str or number, optional): ``"min"`` (default) or ``"max"``
            (minimum / maximum of the input data), ``"threshold"`` (the threshold
            itself), or a number. Python and NumPy scalar numbers are accepted.

    Raises:
        ValueError: If ``threshold_type`` or ``fill_value`` is not recognised.
    """

    _FILL_VALUE_MODES = ("min", "max", "threshold")

    def __init__(
        self,
        threshold_type="hard",
        below_threshold=True,
        fill_value="min",
        **kwargs,
    ):
        super().__init__(**kwargs)
        if threshold_type not in ("hard", "soft"):
            raise ValueError("threshold_type must be 'hard' or 'soft'")
        # Validate eagerly so an invalid config fails at construction time
        # instead of at the first call.
        is_mode = isinstance(fill_value, str) and fill_value in self._FILL_VALUE_MODES
        if not (self._is_number(fill_value) or is_mode):
            raise ValueError("Unknown fill_value")
        self.threshold_type = threshold_type
        self.below_threshold = below_threshold
        self._fill_value_type = fill_value

        # Define threshold function at init
        if threshold_type == "hard":
            if below_threshold:
                self._threshold_func = lambda data, threshold, fill: ops.where(
                    data < threshold, fill, data
                )
            else:
                self._threshold_func = lambda data, threshold, fill: ops.where(
                    data > threshold, fill, data
                )
        else:  # soft
            if below_threshold:
                self._threshold_func = (
                    lambda data, threshold, fill: ops.maximum(data - threshold, 0) + fill
                )
            else:
                self._threshold_func = (
                    lambda data, threshold, fill: ops.minimum(data - threshold, 0) + fill
                )

    @staticmethod
    def _is_number(value):
        """Return True for Python and NumPy scalar numbers.

        ``np.float32`` / ``np.int64`` are not subclasses of ``float`` / ``int``,
        so they are listed explicitly.
        """
        return isinstance(value, (int, float, np.integer, np.floating))

    def _resolve_fill_value(self, data, threshold):
        """Get the fill value based on the fill_value_type."""
        fv = self._fill_value_type
        if self._is_number(fv):
            return ops.convert_to_tensor(fv, dtype=data.dtype)
        elif fv == "min":
            return ops.min(data)
        elif fv == "max":
            return ops.max(data)
        elif fv == "threshold":
            return threshold
        else:
            raise ValueError("Unknown fill_value")

    def call(
        self,
        threshold=None,
        percentile=None,
        **kwargs,
    ):
        """Threshold the input data.

        Exactly one of ``threshold`` and ``percentile`` must be given.

        Args:
            threshold: Numeric threshold.
            percentile: Percentile (0-100) of the input data used as threshold.
        Returns:
            Tensor with thresholding applied.
        Raises:
            ValueError: If both or neither of threshold and percentile are given.
        """
        data = kwargs[self.key]
        if (threshold is None) == (percentile is None):
            raise ValueError("Pass either threshold or percentile, not both or neither.")

        if percentile is not None:
            # Convert percentile to quantile value (0-1 range)
            threshold = ops.quantile(data, percentile / 100.0)

        fill_value = self._resolve_fill_value(data, threshold)
        result = self._threshold_func(data, threshold, fill_value)
        return {self.output_key: result}
|
|
2748
|
-
|
|
2749
|
-
|
|
2750
|
-
@ops_registry("anisotropic_diffusion")
class AnisotropicDiffusion(Operation):
    """Speckle Reducing Anisotropic Diffusion (SRAD) filter.

    Reference:
    - https://www.researchgate.net/publication/5602035_Speckle_reducing_anisotropic_diffusion
    - https://nl.mathworks.com/matlabcentral/fileexchange/54044-image-despeckle-filtering-toolbox
    """

    def call(self, niter=100, lmbda=0.1, rect=None, eps=1e-6, **kwargs):
        """Anisotropic diffusion filter.

        Assumes input data is non-negative. The data is exponentiated before the
        diffusion and log-transformed afterwards, so it is presumably
        log-compressed image data — TODO confirm.

        Args:
            niter: Number of iterations.
            lmbda: Lambda parameter (time step); each iteration adds
                ``lmbda / 4 * D``.
            rect: Rectangle [x1, y1, x2, y2] of a homogeneous speckle region used
                to estimate the speckle scale ``q0`` (optional). It is applied as
                ``image[x1:x2, y1:y2]``, i.e. x indexes rows and y columns. If None,
                the whole image is used as reference region.
            eps: Small epsilon for stability.
        Returns:
            Filtered image (2D tensor or batch of images).
        """
        data = kwargs[self.key]

        if not self.with_batch_dim:
            data = ops.expand_dims(data, axis=0)

        batch_size = ops.shape(data)[0]

        results = []
        for i in range(batch_size):
            image = data[i]
            image_out = self._anisotropic_diffusion_single(image, niter, lmbda, rect, eps)
            results.append(image_out)

        result = ops.stack(results, axis=0)

        if not self.with_batch_dim:
            result = ops.squeeze(result, axis=0)

        return {self.output_key: result}

    def _anisotropic_diffusion_single(self, image, niter, lmbda, rect, eps):
        """Apply anisotropic diffusion to a single image (2D).

        The image is exponentiated, diffused for ``niter`` iterations with
        replicated (zero-flux) borders and log-transformed back.
        """
        image = ops.exp(image)
        if rect is not None:
            x1, y1, x2, y2 = rect

        for _ in range(niter):
            # Neighbour images with replicated borders: the difference to the
            # outside of the image is zero. Zero padding would treat the outside
            # as intensity 0 and drain the border pixels at every iteration.
            iN = ops.concatenate([image[1:], image[-1:]], axis=0)
            iS = ops.concatenate([image[:1], image[:-1]], axis=0)
            jW = ops.concatenate([image[:, 1:], image[:, -1:]], axis=1)
            jE = ops.concatenate([image[:, :1], image[:, :-1]], axis=1)

            # Speckle scale q0 from the homogeneous reference region. Without a
            # rectangle the whole image is used, so the diffusion coefficient is
            # always driven by q0 (previously q0 was never derived without rect).
            reference = image if rect is None else image[x1:x2, y1:y2]
            q0_squared = (ops.std(reference) / (ops.mean(reference) + eps)) ** 2

            dN = iN - image
            dS = iS - image
            dW = jW - image
            dE = jE - image

            # Instantaneous coefficient of variation q.
            G2 = (dN**2 + dS**2 + dW**2 + dE**2) / (image**2 + eps)
            L = (dN + dS + dW + dE) / (image + eps)
            num = (0.5 * G2) - ((1 / 16) * (L**2))
            den = (1 + ((1 / 4) * L)) ** 2
            q_squared = num / (den + eps)

            # Diffusion coefficient comparing q to the speckle scale q0.
            den = (q_squared - q0_squared) / (q0_squared * (1 + q0_squared) + eps)
            c = 1.0 / (1 + den)
            # Coefficients of the neighbours used in dS / dE (borders replicated;
            # dS and dE are zero there anyway).
            cS = ops.concatenate([c[:1], c[:-1]], axis=0)
            cE = ops.concatenate([c[:, :1], c[:, :-1]], axis=1)

            D = (cS * dS) + (c * dN) + (cE * dE) + (c * dW)
            image = image + (lmbda / 4) * D

        return ops.log(image)
|
|
2830
|
-
|
|
2831
|
-
|
|
2832
|
-
@ops_registry("channels_to_complex")
class ChannelsToComplex(Operation):
    """Combine a trailing (real, imaginary) channel pair into complex data.

    Thin wrapper around :func:`channels_to_complex`, which expects the last
    axis of the input to have size 2.
    """

    def call(self, **kwargs):
        complex_data = channels_to_complex(kwargs[self.key])
        return {self.output_key: complex_data}
|
|
2838
|
-
|
|
2839
|
-
|
|
2840
|
-
@ops_registry("complex_to_channels")
class ComplexToChannels(Operation):
    """Split complex data into separate real and imaginary channels.

    Thin wrapper around :func:`complex_to_channels`.

    Args:
        axis (int, optional): Axis at which the (real, imaginary) channel pair is
            inserted. Defaults to -1.
    """

    def __init__(self, axis=-1, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, **kwargs):
        source = kwargs[self.key]
        channels = complex_to_channels(source, axis=self.axis)
        return {self.output_key: channels}
|
|
2850
|
-
|
|
2851
|
-
|
|
2852
|
-
def demodulate_not_jitable(
    rf_data,
    sampling_frequency=None,
    center_frequency=None,
    bandwidth=None,
    filter_coeff=None,
):
    """Demodulates an RF signal to complex base-band (IQ).

    Demodulates the radiofrequency (RF) bandpass signals and returns the
    Inphase/Quadrature (I/Q) components. IQ is a complex whose real (imaginary)
    part contains the in-phase (quadrature) component.

    This function operates (i.e. demodulates) on the RF signal over the
    (fast-) time axis, which is assumed to be the second to last axis. It runs
    on NumPy/SciPy and is therefore not jittable.

    Args:
        rf_data (ndarray): real valued input array of size [..., n_ax, n_el].
            The second to last axis is the fast-time axis.
        sampling_frequency (float): the sampling frequency of the RF signals (in Hz).
            May only be omitted when filter_coeff is provided.
        center_frequency (float, optional): represents the center frequency (in Hz).
            If None, it is estimated as the spectral centroid of (at most 100
            randomly selected) scanlines. Defaults to None.
        bandwidth (float, optional): Bandwidth of RF signal in % of center
            frequency. Defaults to None, in which case the normalized cutoff
            frequency is min(2 * center_frequency / sampling_frequency, 0.5).
            The bandwidth in % is defined by:
                B = Bandwidth_in_% = Bandwidth_in_Hz*(100/center_frequency).
            The cutoff frequency:
                Wn = Bandwidth_in_Hz/sampling_frequency, i.e:
                Wn = B*(center_frequency/100)/sampling_frequency.
        filter_coeff (list, optional): (b, a), numerator and denominator coefficients
            of a filter that is applied with ``scipy.signal.lfilter`` directly to
            ``rf_data`` along the fast-time axis; no down-mixing is done in that
            case. All other parameters are ignored if filter_coeff is provided.
            If not provided, a filter is derived from the other params
            (sampling_frequency, center_frequency, bandwidth).
            see https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.lfilter.html

    Returns:
        iq_data (ndarray): complex valued base-band signal.
    """
    rf_data = ops.convert_to_numpy(rf_data)
    assert np.isreal(rf_data).all(), f"RF must contain real RF signals, got {rf_data.dtype}"

    # Works for 2D input as well as for any number of leading (batch) axes.
    *_, n_ax, n_el = rf_data.shape
    n_dim = rf_data.ndim

    if filter_coeff is not None:
        b, a = filter_coeff
        return scipy.signal.lfilter(b, a, rf_data, axis=-2) * 2

    assert sampling_frequency is not None, "provide sampling_frequency when no filter is given."
    # Time vector
    t = np.arange(n_ax) / sampling_frequency

    # Estimate center frequency
    if center_frequency is None:
        # Keep a maximum of 100 randomly selected scanlines
        idx = np.arange(n_el)
        if n_el > 100:
            idx = np.random.permutation(idx)[:100]
        # Power spectrum along fast time, summed over the selected scanlines and
        # all leading (batch) axes -> shape (n_ax,).
        power = np.abs(np.fft.fft(np.take(rf_data, idx, axis=-1), axis=-2)) ** 2
        power = np.moveaxis(power, -2, 0).reshape(n_ax, -1).sum(axis=1)
        power = power[: n_ax // 2]
        # Carrier frequency: spectral centroid of the positive-frequency half
        centroid_bin = np.sum(np.arange(n_ax // 2) * power) / np.sum(power)
        center_frequency = centroid_bin * sampling_frequency / n_ax

    # Normalized cut-off frequency
    if bandwidth is None:
        Wn = min(2 * center_frequency / sampling_frequency, 0.5)
        bandwidth = center_frequency * Wn
    else:
        assert np.isscalar(bandwidth), "The signal bandwidth (in %) must be a scalar."
        assert (bandwidth > 0) & (bandwidth <= 200), (
            "The signal bandwidth (in %) must be within the interval of ]0,200]."
        )
        # bandwidth in Hz
        bandwidth = center_frequency * bandwidth / 100
        Wn = bandwidth / sampling_frequency
    assert (Wn > 0) & (Wn <= 1), (
        "The normalized cutoff frequency is not within the interval of (0,1). "
        "Check the input parameters!"
    )

    # Down-mixing of the RF signals
    carrier = np.exp(-1j * 2 * np.pi * center_frequency * t)
    # add the singleton dimensions so the carrier broadcasts over [..., n_ax, n_el]
    carrier = np.reshape(carrier, (*[1] * (n_dim - 2), n_ax, 1))
    iq_data = rf_data * carrier

    # Low-pass filter
    N = 5
    b, a = scipy.signal.butter(N, Wn, "low")

    # factor 2: to preserve the envelope amplitude
    iq_data = scipy.signal.filtfilt(b, a, iq_data, axis=-2) * 2

    # Display a warning message if harmful aliasing is suspected
    # the RF signal is undersampled
    if sampling_frequency < (2 * center_frequency + bandwidth):
        # lower and higher frequencies of the bandpass signal
        fL = center_frequency - bandwidth / 2
        fH = center_frequency + bandwidth / 2
        n = int(fH // (fH - fL))
        # Bandpass sampling: aliasing is harmless if, for some integer k = 1..n,
        # 2 * fH / k <= fs <= 2 * fL / (k - 1), with an unbounded upper limit for
        # k = 1 (the ranges 1:n and 0:n-1 used by MUST's rf2iq).
        k = np.arange(1, n + 1)
        upper = np.full(k.shape, np.inf)
        upper[1:] = 2 * fL / (k[1:] - 1)
        harmless_aliasing = np.any(
            (2 * fH / k <= sampling_frequency) & (sampling_frequency <= upper)
        )
        if not harmless_aliasing:
            log.warning(
                "rf2iq:harmful_aliasing Harmful aliasing is present: the aliases"
                " are not mutually exclusive!"
            )

    return iq_data
|
|
2977
|
-
|
|
2978
|
-
|
|
2979
|
-
def upmix(iq_data, sampling_frequency, center_frequency, upsampling_rate=6):
    """Upsamples and upmixes complex base-band signals (IQ) to RF.

    The IQ data is resampled by ``upsampling_rate`` along the fast-time axis,
    multiplied by a complex carrier at ``center_frequency``, and the real part,
    scaled by sqrt(2), is returned.

    Args:
        iq_data (ndarray): complex valued input array of size [..., n_ax, n_el].
            The second to last axis is the fast-time axis.
        sampling_frequency (float): the sampling frequency of the input IQ signal
            (in Hz). The resulting sampling frequency of the RF data is
            ``upsampling_rate`` times higher.
        center_frequency (float): center (carrier) frequency in Hz.
        upsampling_rate (int, optional): factor by which the fast-time axis is
            upsampled. Defaults to 6.

    Returns:
        rf_data (ndarray): real valued float32 RF data of size
            [..., n_ax * upsampling_rate, n_el].
    """
    # Backend-agnostic complex check: numpy/jax print "complex64", torch prints
    # "torch.complex64" and tensorflow "<dtype: 'complex64'>". Comparing the
    # dtype against plain strings fails for torch dtypes.
    assert "complex" in str(iq_data.dtype), "IQ must contain all complex signals."

    input_shape = iq_data.shape
    n_dim = len(input_shape)
    *_, n_ax, _ = input_shape

    # Time vector
    n_ax_up = n_ax * upsampling_rate
    sampling_frequency_up = sampling_frequency * upsampling_rate
    t = ops.arange(n_ax_up, dtype="float32") / sampling_frequency_up

    iq_data_upsampled = resample(
        iq_data,
        n_samples=n_ax_up,
        axis=-2,
        order=1,
    )

    # Up-mixing of the IQ signals
    t = ops.cast(t, dtype="complex64")
    center_frequency = ops.cast(center_frequency, dtype="complex64")
    carrier = ops.exp(1j * 2 * np.pi * center_frequency * t)
    # add the singleton dimensions so the carrier broadcasts over [..., n_ax_up, n_el]
    carrier = ops.reshape(carrier, (*[1] * (n_dim - 2), n_ax_up, 1))

    rf_data = iq_data_upsampled * carrier
    rf_data = ops.real(rf_data) * ops.sqrt(2)

    return ops.cast(rf_data, "float32")
|
|
3029
|
-
|
|
3030
|
-
|
|
3031
|
-
def get_band_pass_filter(num_taps, sampling_frequency, f1, f2):
    """Design a linear-phase FIR band-pass filter with ``scipy.signal.firwin``.

    Args:
        num_taps (int): number of filter taps.
        sampling_frequency (float): sample frequency in Hz.
        f1 (float): lower band edge (cutoff) in Hz.
        f2 (float): upper band edge (cutoff) in Hz.

    Returns:
        ndarray: the ``num_taps`` filter coefficients.
    """
    band_edges = [f1, f2]
    return scipy.signal.firwin(
        numtaps=num_taps,
        cutoff=band_edges,
        pass_zero=False,
        fs=sampling_frequency,
    )
|
|
3045
|
-
|
|
3046
|
-
|
|
3047
|
-
def get_low_pass_iq_filter(num_taps, sampling_frequency, f, bw):
    """Design a complex filter by shifting a low-pass prototype to ``f``.

    A real low-pass FIR prototype with cutoff ``bw / 2`` is designed with
    ``scipy.signal.firwin`` and multiplied by a complex exponential at ``f``
    (tap times ``k / sampling_frequency``), which moves its pass band to
    roughly ``f - bw / 2 .. f + bw / 2``.

    Args:
        num_taps (int): number of filter taps.
        sampling_frequency (float): sample frequency.
        f (float): center frequency.
        bw (float): bandwidth in Hz.

    Raises:
        ValueError: if the cutoff frequency (bw / 2) is not within
            (0, sampling_frequency / 2).

    Returns:
        ndarray: complex-valued filter taps of length ``num_taps``.
    """
    nyquist = sampling_frequency / 2
    cutoff = bw / 2
    if not (0 < cutoff < nyquist):
        raise ValueError(
            "Cutoff frequency must be within (0, sampling_frequency / 2), "
            f"got {cutoff} Hz, must be within (0, {nyquist}) Hz"
        )

    prototype = scipy.signal.firwin(num_taps, cutoff, pass_zero=True, fs=sampling_frequency)
    tap_times = np.arange(num_taps) / sampling_frequency
    modulation = np.exp(1j * 2 * np.pi * f * tap_times)
    return prototype * modulation
|
|
3076
|
-
|
|
3077
|
-
|
|
3078
|
-
def complex_to_channels(complex_data, axis=-1):
    """Split complex data into a real channel and an imaginary channel.

    Args:
        complex_data (complex ndarray): complex input data.
        axis (int, optional): position of the new channel axis. Defaults to -1.

    Returns:
        ndarray: real array with a new axis of size 2 at ``axis``. Index 0 holds
            the real part and index 1 the imaginary part.
    """
    real_channel = ops.expand_dims(ops.real(complex_data), axis=axis)
    imag_channel = ops.expand_dims(ops.imag(complex_data), axis=axis)
    return ops.concatenate((real_channel, imag_channel), axis=axis)
|
|
3098
|
-
|
|
3099
|
-
|
|
3100
|
-
def channels_to_complex(data):
    """Merge a trailing (real, imaginary) channel pair into a complex array.

    Args:
        data (ndarray): input data whose last axis has size 2. Index 0 is the
            real component and index 1 the imaginary component.

    Returns:
        ndarray: complex64 array with the last axis removed.
    """
    assert data.shape[-1] == 2, "Data must have two channels."
    as_complex = ops.cast(data, "complex64")
    real_part, imag_part = as_complex[..., 0], as_complex[..., 1]
    return real_part + 1j * imag_part
|
|
3114
|
-
|
|
3115
|
-
|
|
3116
|
-
def hilbert(x, N: int = None, axis=-1):
    """Compute the analytic signal of ``x`` in the Fourier domain.

    Note:
        Despite its name, this does NOT return the mathematical Hilbert
        transform. Like ``scipy.signal.hilbert``, it returns the analytic
        signal: negative-frequency components are zeroed and positive-frequency
        components are doubled. DC, and Nyquist for even ``N``, keep unit gain.

    Args:
        x (ndarray): input data of any shape.
        N (int, optional): number of points in the FFT. If larger than the
            length of ``axis``, ``x`` is zero-padded along ``axis`` to ``N``.
            Defaults to None (use the length of ``axis``).
        axis (int, optional): axis to operate on. Defaults to -1.

    Raises:
        ValueError: if ``N`` is smaller than the length of ``axis``.

    Returns:
        ndarray: complex64 analytic signal with the same shape as ``x``,
            except that ``axis`` has length ``N``.
    """
    shape = x.shape
    ndim = len(shape)
    axis_len = shape[axis]
    if axis < 0:
        axis += ndim

    if N is None:
        N = axis_len
    else:
        if N < axis_len:
            raise ValueError("N must be greater or equal to n_ax.")
        # Zero-pad along `axis` only, by concatenating an explicit block of zeros.
        pad_shape = shape[:axis] + (N - axis_len,) + shape[axis + 1 :]
        x = ops.concatenate((x, ops.zeros(pad_shape)), axis=axis)

    # Spectral weights: keep DC, double positive frequencies, zero negative ones.
    weights = np.zeros(N)
    weights[0] = 1
    half = N // 2
    if N % 2:
        weights[1 : (N + 1) // 2] = 2
    else:
        weights[half] = 1  # Nyquist bin
        weights[1:half] = 2

    # ops.fft operates on the last axis, so move `axis` to the end.
    forward_order = [d for d in range(ndim) if d != axis] + [axis]
    x = ops.transpose(x, forward_order)

    if x.ndim > 1:
        # Broadcast the weights over all leading dimensions.
        weights = weights[(np.newaxis,) * (x.ndim - 1) + (slice(None),)]
    weights = ops.cast(ops.convert_to_tensor(weights), "complex64")

    # Forward FFT of the real input; ops.fft takes and returns (real, imag) pairs.
    spec_re, spec_im = ops.fft((x, ops.zeros_like(x)))
    spectrum = ops.cast(spec_re, "complex64") + 1j * ops.cast(spec_im, "complex64")
    spectrum = spectrum * weights

    # Inverse FFT via ifft(X) = conj(fft(conj(X))) / N.
    inv_re, inv_im = ops.fft((ops.real(spectrum), -ops.imag(spectrum)))
    inv_re = ops.cast(inv_re, "complex64")
    inv_im = ops.cast(inv_im, "complex64")
    analytic = inv_re / N + 1j * (-inv_im / N)

    # Inverse permutation: put the last axis back at position `axis`.
    restore_order = list(range(ndim - 1))
    restore_order.insert(axis, ndim - 1)
    return ops.transpose(analytic, restore_order)
|
|
3205
|
-
|
|
3206
|
-
|
|
3207
|
-
def demodulate(data, center_frequency, sampling_frequency, axis=-3):
    """Demodulate the input data to complex baseband IQ data.

    The analytic signal (negative frequencies removed) is computed with
    ``hilbert`` along ``axis``. It is then multiplied by
    ``exp(-2j * pi * center_frequency * n / sampling_frequency)``, where ``n`` is
    the sample index along ``axis``. This moves the spectrum from around
    ``center_frequency`` to around 0 Hz. The complex result is returned as two
    real-valued channels.

    Args:
        data (ops.Tensor): The input data to demodulate of shape `(..., axis, ..., 1)`.
            The trailing singleton axis is squeezed before splitting into channels.
        center_frequency (float): The center frequency of the signal.
        sampling_frequency (float): The sampling frequency of the signal.
        axis (int, optional): The axis along which to demodulate. Defaults to -3.

    Returns:
        ops.Tensor: The demodulated IQ data of shape `(..., axis, ..., 2)`.
    """
    analytic = hilbert(data, axis=axis)

    # Sample index n along `axis`, reshaped so it broadcasts against the data.
    expand = [None] * data.ndim
    expand[axis] = slice(None)
    sample_index = ops.arange(analytic.shape[axis])[tuple(expand)]

    fc = ops.cast(center_frequency, dtype="complex64")
    fs = ops.cast(sampling_frequency, dtype="complex64")
    sample_index = ops.cast(sample_index, dtype="complex64")

    # Complex mixing to baseband.
    baseband_phasor = ops.exp(-1j * 2 * np.pi * fc * sample_index / fs)
    baseband = analytic * baseband_phasor

    return complex_to_channels(ops.squeeze(baseband, axis=-1))
|
|
3251
|
-
|
|
3252
|
-
|
|
3253
|
-
def compute_time_to_peak_stack(waveforms, center_frequencies, waveform_sampling_frequency=250e6):
    """Compute the time to peak of every waveform in a stack.

    Each waveform is processed independently with ``compute_time_to_peak``.

    Args:
        waveforms (ndarray): The waveforms of shape (n_waveforms, n_samples).
        center_frequencies (ndarray): The center frequencies of the waveforms in Hz of shape
            (n_waveforms,) or a scalar if all waveforms have the same center frequency.
        waveform_sampling_frequency (float): The sampling frequency of the waveforms in Hz.

    Returns:
        ndarray: The time to peak for each waveform in seconds, shape (n_waveforms,).
    """
    n_waveforms = waveforms.shape[0]
    # Broadcast a scalar center frequency to one value per waveform.
    per_waveform_fc = center_frequencies * ops.ones((n_waveforms,))
    peak_times = [
        compute_time_to_peak(wf, fc, waveform_sampling_frequency)
        for wf, fc in zip(waveforms, per_waveform_fc)
    ]
    return ops.stack(peak_times)
|
|
3270
|
-
|
|
3271
|
-
|
|
3272
|
-
def compute_time_to_peak(waveform, center_frequency, waveform_sampling_frequency=250e6):
    """Compute the time at which the envelope of the waveform peaks.

    The waveform is demodulated along its sample axis. The envelope is the
    magnitude of the resulting IQ signal, and the index of its maximum is
    converted to seconds.

    Args:
        waveform (ndarray): The waveform of shape (n_samples,).
        center_frequency (float): The center frequency of the waveform in Hz.
        waveform_sampling_frequency (float): The sampling frequency of the waveform in Hz.
            Defaults to 250e6.

    Raises:
        ValueError: if the waveform has zero samples.

    Returns:
        float: The time to peak for the waveform in seconds, measured from the
            first sample.
    """
    n_samples = waveform.shape[0]
    if n_samples == 0:
        raise ValueError("Waveform has zero samples.")

    # `demodulate` expects a trailing singleton channel axis, so the input becomes
    # (n_samples, 1) and the sample axis is -2. Demodulating along -1 would act on
    # the length-1 channel axis: the Hilbert step is then the identity and the phasor
    # is exp(0) = 1, so the "envelope" would degenerate to |waveform|.
    waveforms_iq_complex_channels = demodulate(
        waveform[..., None], center_frequency, waveform_sampling_frequency, axis=-2
    )
    waveforms_iq_complex = channels_to_complex(waveforms_iq_complex_channels)
    envelope = ops.abs(waveforms_iq_complex)
    peak_idx = ops.argmax(envelope, axis=-1)
    t_peak = ops.cast(peak_idx, dtype="float32") / waveform_sampling_frequency
    return t_peak
|