zea 0.0.8__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +3 -3
- zea/agent/masks.py +2 -2
- zea/agent/selection.py +3 -3
- zea/backend/__init__.py +1 -1
- zea/backend/tensorflow/dataloader.py +1 -1
- zea/beamform/beamformer.py +4 -2
- zea/beamform/pfield.py +2 -2
- zea/data/augmentations.py +1 -1
- zea/data/convert/__main__.py +93 -52
- zea/data/convert/camus.py +8 -2
- zea/data/convert/echonet.py +1 -1
- zea/data/convert/echonetlvh/__init__.py +1 -1
- zea/data/convert/verasonics.py +810 -772
- zea/data/data_format.py +0 -2
- zea/data/file.py +28 -0
- zea/data/preset_utils.py +1 -1
- zea/display.py +1 -1
- zea/doppler.py +5 -5
- zea/func/__init__.py +109 -0
- zea/{tensor_ops.py → func/tensor.py} +32 -8
- zea/func/ultrasound.py +500 -0
- zea/internal/_generate_keras_ops.py +5 -5
- zea/metrics.py +6 -5
- zea/models/diffusion.py +1 -1
- zea/models/echonetlvh.py +1 -1
- zea/models/gmm.py +1 -1
- zea/ops/__init__.py +188 -0
- zea/ops/base.py +442 -0
- zea/{keras_ops.py → ops/keras_ops.py} +2 -2
- zea/ops/pipeline.py +1472 -0
- zea/ops/tensor.py +356 -0
- zea/ops/ultrasound.py +890 -0
- zea/probes.py +2 -10
- zea/scan.py +17 -20
- zea/tools/fit_scan_cone.py +1 -1
- zea/tools/selection_tool.py +1 -1
- zea/tracking/lucas_kanade.py +1 -1
- zea/tracking/segmentation.py +1 -1
- {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/METADATA +3 -1
- {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/RECORD +43 -37
- zea/ops.py +0 -3534
- {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/WHEEL +0 -0
- {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/entry_points.txt +0 -0
- {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/licenses/LICENSE +0 -0
zea/ops/pipeline.py
ADDED
|
@@ -0,0 +1,1472 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import Dict, List, Union
|
|
3
|
+
|
|
4
|
+
import keras
|
|
5
|
+
import numpy as np
|
|
6
|
+
import yaml
|
|
7
|
+
from keras import ops
|
|
8
|
+
|
|
9
|
+
from zea import log
|
|
10
|
+
from zea.backend import jit
|
|
11
|
+
from zea.config import Config
|
|
12
|
+
from zea.func.tensor import (
|
|
13
|
+
vmap,
|
|
14
|
+
)
|
|
15
|
+
from zea.func.ultrasound import channels_to_complex, complex_to_channels
|
|
16
|
+
from zea.internal.core import (
|
|
17
|
+
DataTypes,
|
|
18
|
+
ZEADecoderJSON,
|
|
19
|
+
ZEAEncoderJSON,
|
|
20
|
+
dict_to_tensor,
|
|
21
|
+
)
|
|
22
|
+
from zea.internal.core import Object as ZEAObject
|
|
23
|
+
from zea.internal.registry import ops_registry
|
|
24
|
+
from zea.ops.base import (
|
|
25
|
+
Operation,
|
|
26
|
+
get_ops,
|
|
27
|
+
)
|
|
28
|
+
from zea.ops.tensor import Normalize
|
|
29
|
+
from zea.ops.ultrasound import (
|
|
30
|
+
Demodulate,
|
|
31
|
+
EnvelopeDetect,
|
|
32
|
+
LogCompress,
|
|
33
|
+
PfieldWeighting,
|
|
34
|
+
ReshapeGrid,
|
|
35
|
+
TOFCorrection,
|
|
36
|
+
)
|
|
37
|
+
from zea.probes import Probe
|
|
38
|
+
from zea.scan import Scan
|
|
39
|
+
from zea.utils import (
|
|
40
|
+
FunctionTimer,
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@ops_registry("pipeline")
class Pipeline:
    """Pipeline class for processing ultrasound data through a series of operations.

    Operations are called in order: the output dictionary of one operation is
    passed as keyword arguments to the next. Operations may themselves be nested
    ``Pipeline`` instances.
    """

    def __init__(
        self,
        operations: List[Operation],
        with_batch_dim: bool = True,
        jit_options: Union[str, None] = "ops",
        jit_kwargs: dict | None = None,
        name="pipeline",
        validate=True,
        timed: bool = False,
    ):
        """
        Initialize a pipeline.

        Args:
            operations (list): A list of Operation instances representing the operations
                to be performed.
            with_batch_dim (bool, optional): Whether operations should expect a batch dimension.
                Defaults to True.
            jit_options (str, optional): The JIT options to use. Must be "pipeline", "ops", or None.

                - "pipeline": compiles the entire pipeline as a single function.
                  This may be faster but does not preserve python control flow, such as caching.

                - "ops": compiles each operation separately. This preserves python control
                  flow and caching functionality, while still speeding up the operations.

                - None: disables JIT compilation.

                Defaults to "ops".

            jit_kwargs (dict, optional): Additional keyword arguments for the JIT compiler.
                The dict is copied, never mutated. On the JAX backend, when the pipeline
                has static parameters, its ``static_argnames`` entry is set to those
                parameters; all other entries are kept.
            name (str, optional): The name of the pipeline. Defaults to "pipeline".
            validate (bool, optional): Whether to validate the pipeline. Defaults to True.
            timed (bool, optional): Whether to time each operation. Defaults to False.

        Raises:
            ValueError: If ``jit_options`` is invalid, if ``timed=True`` is combined with
                ``jit_options="pipeline"``, or if validation finds incompatible operations.
        """
        self._call_pipeline = self.call
        self.name = name

        self._pipeline_layers = operations

        if jit_options not in ["pipeline", "ops", None]:
            raise ValueError("jit_options must be 'pipeline', 'ops', or None")

        self.with_batch_dim = with_batch_dim
        self._validate_flag = validate

        # Setup timer
        if jit_options == "pipeline" and timed:
            raise ValueError(
                "timed=True cannot be used with jit_options='pipeline' as the entire "
                "pipeline is compiled into a single function. Try setting jit_options to "
                "'ops' or None."
            )
        if timed:
            log.warning(
                "Timer has been initialized for the pipeline. To get an accurate timing estimate, "
                "the `block_until_ready()` is used, which will slow down the execution, so "
                "do not use for regular processing!"
            )
            self._callable_layers = self._get_timed_operations()
        else:
            self._callable_layers = self._pipeline_layers
        self._timed = timed

        if validate:
            self.validate()
        else:
            log.warning("Pipeline validation is disabled, make sure to validate manually.")

        # Work on a copy so the caller's dict is never mutated, and so that any
        # user-provided compiler options survive (they used to be discarded on JAX).
        jit_kwargs = {} if jit_kwargs is None else dict(jit_kwargs)

        static_params = self.static_params
        if keras.backend.backend() == "jax" and static_params:
            # Recomputed from the current operations, so reinitialize() after
            # append/insert/prepend picks up newly added static parameters.
            jit_kwargs["static_argnames"] = static_params

        self.jit_kwargs = jit_kwargs
        self.jit_options = jit_options  # will handle the jit compilation

        self._logged_difference_keys = False

        # Do not log again for nested pipelines
        for nested_pipeline in self._nested_pipelines:
            nested_pipeline._logged_difference_keys = True

    def needs(self, key) -> bool:
        """Check if the pipeline needs a specific key at the input."""
        return key in self.needs_keys

    @property
    def _nested_pipelines(self):
        """Operations of this pipeline that are themselves ``Pipeline`` instances."""
        return [operation for operation in self.operations if isinstance(operation, Pipeline)]

    @property
    def output_keys(self) -> set:
        """All output keys the pipeline guarantees to produce."""
        output_keys = set()
        for operation in self.operations:
            output_keys.update(operation.output_keys)
        return output_keys

    @property
    def valid_keys(self) -> set:
        """Get a set of valid keys for the pipeline.

        This is all keys that can be passed to the pipeline as input.
        """
        valid_keys = set()
        for operation in self.operations:
            valid_keys.update(operation.valid_keys)
        return valid_keys

    @property
    def static_params(self) -> List[str]:
        """Get a sorted list of the unique static parameters of all operations.

        Sorting makes the result (and hence ``jit_kwargs`` and serialized configs)
        independent of set iteration order.
        """
        static_params = set()
        for operation in self.operations:
            static_params.update(operation.static_params)
        return sorted(static_params)

    @property
    def needs_keys(self) -> set:
        """Get a set of all input keys needed by the pipeline.

        Keys produced by an earlier operation are not counted as needed for
        later operations.
        """
        needs = set()
        provided = set()
        for operation in self.operations:
            needs.update(operation.needs_keys - provided)
            # Outputs of this operation are available to every later operation.
            provided.update(operation.output_keys)
        return needs

    @classmethod
    def from_default(
        cls,
        beamformer="delay_and_sum",
        num_patches=100,
        baseband=False,
        enable_pfield=False,
        timed=False,
        **kwargs,
    ) -> "Pipeline":
        """Create a default pipeline.

        The pipeline is: ``Demodulate`` (RF input only) -> ``Beamform`` ->
        ``EnvelopeDetect`` -> ``Normalize`` -> ``LogCompress``.

        Args:
            beamformer (str): Type of beamformer to use. Currently supporting
                "delay_and_sum" and "delay_multiply_and_sum". Defaults to "delay_and_sum".
            num_patches (int): Number of patches for the PatchedGrid operation.
                Defaults to 100. If you get an out of memory error, try to increase this number.
            baseband (bool): If True, assume the input data is baseband (I/Q) data,
                which has 2 channels (last dim). Defaults to False, which assumes RF data,
                so input signal has a single channel dim and is still on carrier frequency.
            enable_pfield (bool): If True, apply PfieldWeighting. Defaults to False.
                This will calculate pressure field and only beamform the data to those locations.
            timed (bool, optional): Whether to time each operation. Defaults to False.
            **kwargs: Additional keyword arguments to be passed to the Pipeline constructor.

        Returns:
            Pipeline: The constructed pipeline.
        """
        operations = []

        # RF data must be demodulated to baseband first.
        if not baseband:
            operations.append(Demodulate())

        # Add beamforming ops
        operations.append(
            Beamform(
                beamformer=beamformer,
                num_patches=num_patches,
                enable_pfield=enable_pfield,
            ),
        )

        # Add display ops
        operations += [
            EnvelopeDetect(),
            Normalize(),
            LogCompress(),
        ]
        return cls(operations, timed=timed, **kwargs)

    def copy(self) -> "Pipeline":
        """Create a copy of the pipeline.

        The list of operations is copied shallowly: the operation objects
        themselves are shared with the original pipeline.
        """
        return Pipeline(
            self._pipeline_layers.copy(),
            with_batch_dim=self.with_batch_dim,
            jit_options=self.jit_options,
            jit_kwargs=self.jit_kwargs,
            name=self.name,
            validate=self._validate_flag,
            timed=self._timed,
        )

    def reinitialize(self):
        """Reinitialize the pipeline in place, e.g. after its operations changed."""
        self.__init__(
            self._pipeline_layers,
            with_batch_dim=self.with_batch_dim,
            jit_options=self.jit_options,
            jit_kwargs=self.jit_kwargs,
            name=self.name,
            validate=self._validate_flag,
            timed=self._timed,
        )

    def prepend(self, operation: Operation):
        """Prepend an operation to the pipeline."""
        self._pipeline_layers.insert(0, operation)
        self.reinitialize()

    def append(self, operation: Operation):
        """Append an operation to the pipeline."""
        self._pipeline_layers.append(operation)
        self.reinitialize()

    def insert(self, index: int, operation: Operation):
        """Insert an operation at a specific index in the pipeline.

        Raises:
            IndexError: If ``index`` is negative or larger than the number of operations.
        """
        if index < 0 or index > len(self._pipeline_layers):
            raise IndexError("Index out of bounds for inserting operation.")
        self._pipeline_layers.insert(index, operation)
        self.reinitialize()

    @property
    def operations(self):
        """The list of operations in the pipeline (``self._pipeline_layers``).

        This is the underlying list, not a copy. Use ``append``/``insert``/``prepend``
        to modify it so that the pipeline is reinitialized.
        """
        return self._pipeline_layers

    def reset_timer(self):
        """Reset the timer for timed operations."""
        if self._timed:
            self._callable_layers = self._get_timed_operations()
        else:
            log.warning(
                "Timer has not been initialized. Set timed=True when initializing the pipeline."
            )

    def _get_timed_operations(self):
        """Create a fresh timer and wrap every operation with it."""
        self.timer = FunctionTimer()
        return [self.timer(op, name=op.__class__.__name__) for op in self._pipeline_layers]

    def call(self, **inputs):
        """Process input data through the pipeline.

        Each operation receives the previous operation's outputs as keyword
        arguments. An empty pipeline returns its inputs unchanged.

        Raises:
            KeyError: If an operation is missing a required input key.
            RuntimeError: If an operation raises any other exception (the original
                exception is attached as ``__cause__``).
        """
        for operation in self._callable_layers:
            try:
                outputs = operation(**inputs)
            except KeyError as exc:
                raise KeyError(
                    f"[zea.Pipeline] Operation '{operation.__class__.__name__}' "
                    f"requires input key '{exc.args[0]}', "
                    "but it was not provided in the inputs.\n"
                    "Check whether the objects (such as `zea.Scan`) passed to "
                    "`pipeline.prepare_parameters()` contain all required keys.\n"
                    f"Current list of all passed keys: {list(inputs.keys())}\n"
                    f"Valid keys for this pipeline: {self.valid_keys}"
                ) from exc
            except Exception as exc:
                raise RuntimeError(
                    f"[zea.Pipeline] Error in operation '{operation.__class__.__name__}': {exc}"
                ) from exc
            inputs = outputs
        # After the loop ``inputs`` holds the last operation's outputs; returning it
        # (rather than ``outputs``) also works when there are no operations.
        return inputs

    def __call__(self, return_numpy=False, **inputs):
        """Process input data through the pipeline.

        Args:
            return_numpy (bool, optional): If True, tensor outputs are converted to numpy
                arrays; non-tensor values (such as None) are returned as-is.
                Defaults to False.
            **inputs: Inputs for the pipeline, typically created with
                :meth:`prepare_parameters`.

        Returns:
            dict: Outputs of the last operation (the inputs if the pipeline is empty).

        Raises:
            ValueError: If ``probe``/``scan``/``config`` keys or ``zea`` objects are
                passed directly, or if any input value is a string.
        """

        if any(key in inputs for key in ["probe", "scan", "config"]) or any(
            isinstance(arg, ZEAObject) for arg in inputs.values()
        ):
            raise ValueError(
                "Probe, Scan and Config objects should be first processed with "
                "`Pipeline.prepare_parameters` before calling the pipeline. "
                "e.g. inputs = Pipeline.prepare_parameters(probe, scan, config)"
            )

        if any(isinstance(arg, str) for arg in inputs.values()):
            raise ValueError(
                "Pipeline does not support string inputs. "
                "Please ensure all inputs are convertible to tensors."
            )

        if not self._logged_difference_keys:
            difference_keys = set(inputs.keys()) - self.valid_keys
            if difference_keys:
                log.debug(
                    f"[zea.Pipeline] The following input keys are not used by the pipeline: "
                    f"{difference_keys}. Make sure this is intended. "
                    "This warning will only be shown once."
                )
                self._logged_difference_keys = True

        ## PROCESSING
        outputs = self._call_pipeline(**inputs)

        ## PREPARE OUTPUT
        if return_numpy:
            # Convert tensors to numpy arrays but preserve None values
            outputs = {
                k: ops.convert_to_numpy(v) if ops.is_tensor(v) else v for k, v in outputs.items()
            }

        return outputs

    @property
    def jit_options(self):
        """Get the jit_options property of the pipeline."""
        return self._jit_options

    @jit_options.setter
    def jit_options(self, value: Union[str, None]):
        """Set the jit_options property of the pipeline and (un)compile accordingly.

        Raises:
            ValueError: If ``value`` is not "pipeline", "ops" or None, or if
                "pipeline" is requested for a timed pipeline.
            AssertionError: If "pipeline" is requested but some operations are not
                jittable (kept as AssertionError for backward compatibility).
        """
        if value not in ["pipeline", "ops", None]:
            raise ValueError("jit_options must be 'pipeline', 'ops', or None")

        if value == "pipeline":
            # Same restriction as in __init__, also enforced for later assignments.
            if getattr(self, "_timed", False):
                raise ValueError(
                    "timed=True cannot be used with jit_options='pipeline' as the entire "
                    "pipeline is compiled into a single function. Try setting jit_options to "
                    "'ops' or None."
                )
            # Explicit check instead of `assert`: it must not be stripped under -O, and
            # the state is only updated once the check has passed.
            if not self.jittable:
                message = (
                    "jit_options 'pipeline' cannot be used as the entire pipeline is not jittable. "
                    "The following operations are not jittable: "
                    f"{self.unjitable_ops}. "
                    "Try setting jit_options to 'ops' or None."
                )
                log.error(message)
                raise AssertionError(message)
            self._jit_options = value
            self.jit()
            return

        self._jit_options = value
        self.unjit()

        for operation in self.operations:
            if isinstance(operation, Pipeline):
                operation.jit_options = value
            elif operation.jittable and operation._jit_compile:
                operation.set_jit(value == "ops")

    def jit(self):
        """JIT compile the pipeline."""
        self._call_pipeline = jit(self.call, **self.jit_kwargs)

    def unjit(self):
        """Un-JIT compile the pipeline."""
        self._call_pipeline = self.call

    @property
    def jittable(self):
        """Check if all operations in the pipeline are jittable."""
        return all(operation.jittable for operation in self.operations)

    @property
    def unjitable_ops(self):
        """Get a list of operations that are not jittable."""
        return [operation for operation in self.operations if not operation.jittable]

    @property
    def with_batch_dim(self):
        """Get the with_batch_dim property of the pipeline."""
        return self._with_batch_dim

    @with_batch_dim.setter
    def with_batch_dim(self, value):
        """Set the with_batch_dim property of the pipeline and all its operations."""
        self._with_batch_dim = value
        for operation in self.operations:
            operation.with_batch_dim = value

    @property
    def input_data_type(self):
        """Get the input_data_type property of the pipeline (from its first operation)."""
        return self.operations[0].input_data_type

    @property
    def output_data_type(self):
        """Get the output_data_type property of the pipeline (from its last operation)."""
        return self.operations[-1].output_data_type

    def validate(self):
        """Validate the pipeline by checking the compatibility of the operations.

        Consecutive operations must agree on data type, unless either side
        declares ``None``, which disables the check for that pair.

        Raises:
            ValueError: If two consecutive operations have incompatible data types.
        """
        operations = self.operations
        for current, following in zip(operations, operations[1:]):
            if current.output_data_type is None:
                continue
            if following.input_data_type is None:
                continue
            if current.output_data_type != following.input_data_type:
                raise ValueError(
                    f"Operation {current.__class__.__name__} output data type "
                    f"({current.output_data_type}) is not compatible "
                    f"with the input data type ({following.input_data_type}) "
                    f"of operation {following.__class__.__name__}"
                )

    def set_params(self, **params):
        """Set parameters for the operations in the pipeline by adding them to the cache.

        Each operation only receives the parameters listed in its ``valid_keys``.
        """
        for operation in self.operations:
            operation_params = {
                key: value for key, value in params.items() if key in operation.valid_keys
            }
            if operation_params:
                operation.set_input_cache(operation_params)

    def get_params(self, per_operation: bool = False):
        """Get a snapshot of the current parameters of the operations in the pipeline.

        Args:
            per_operation (bool): If True, return a list of dictionaries for each operation.
                If False, return a single dictionary with all parameters combined
                (later operations override earlier ones on duplicate keys).
        """
        if per_operation:
            return [operation._input_cache.copy() for operation in self.operations]
        params = {}
        for operation in self.operations:
            params.update(operation._input_cache)
        return params

    def __str__(self):
        """String representation of the pipeline, e.g. ``A -> B -> C``."""
        operations = []
        for operation in self.operations:
            if isinstance(operation, Pipeline):
                operations.append(f"{operation.__class__.__name__}({str(operation)})")
            else:
                operations.append(operation.__class__.__name__)
        return " -> ".join(operations)

    def __repr__(self):
        """Debug representation of the pipeline, including nested pipelines."""
        operations = []
        for operation in self.operations:
            if isinstance(operation, Pipeline):
                operations.append(repr(operation))
            else:
                operations.append(operation.__class__.__name__)
        return f"<Pipeline {self.name}=({', '.join(operations)})>"

    @classmethod
    def load(cls, file_path: str, **kwargs) -> "Pipeline":
        """Load a pipeline from a JSON or YAML file.

        Args:
            file_path (str): Path to a ``.json``, ``.yaml`` or ``.yml`` file.
                ``pathlib.Path`` objects are accepted as well.
            **kwargs: Additional keyword arguments to be passed to the pipeline.

        Raises:
            ValueError: If the file extension is not supported.
        """
        # str() lets path-like objects use the suffix checks below.
        path = str(file_path)
        if path.endswith(".json"):
            with open(path, "r", encoding="utf-8") as f:
                json_str = f.read()
            return pipeline_from_json(json_str, **kwargs)
        if path.endswith((".yaml", ".yml")):
            return pipeline_from_yaml(path, **kwargs)
        raise ValueError("File must have extension .json, .yaml, or .yml")

    def get_dict(self) -> dict:
        """Convert the pipeline to a dictionary."""
        config = {}
        config["name"] = ops_registry.get_name(self)
        config["operations"] = self._pipeline_to_list(self)
        config["params"] = {
            "with_batch_dim": self.with_batch_dim,
            "jit_options": self.jit_options,
            "jit_kwargs": self.jit_kwargs,
        }
        return config

    @staticmethod
    def _pipeline_to_list(pipeline):
        """Convert the pipeline to a list of operation dictionaries."""
        return [op.get_dict() for op in pipeline.operations]

    @classmethod
    def from_config(cls, config: Dict, **kwargs) -> "Pipeline":
        """Create a pipeline from a dictionary or ``zea.Config`` object.

        Args:
            config (dict or Config): Configuration dictionary or ``zea.Config`` object.
            **kwargs: Additional keyword arguments to be passed to the pipeline.

        Note:
            Expects an ``operations`` key, as in the example below.

        Example:
            .. doctest::

                >>> from zea import Config, Pipeline
                >>> config = Config(
                ...     {
                ...         "operations": [
                ...             "identity",
                ...         ],
                ...     }
                ... )
                >>> pipeline = Pipeline.from_config(config)
        """
        return pipeline_from_config(Config(config), **kwargs)

    @classmethod
    def from_yaml(cls, file_path: str, **kwargs) -> "Pipeline":
        """Create a pipeline from a YAML file.

        Args:
            file_path (str): Path to the YAML file.
            **kwargs: Additional keyword arguments to be passed to the pipeline.

        Note:
            Expects an ``operations`` key, as in the example below.

        Example:
            .. doctest::

                >>> import yaml
                >>> from zea import Config
                >>> # Create a sample pipeline YAML file
                >>> pipeline_dict = {
                ...     "operations": [
                ...         "identity",
                ...     ],
                ... }
                >>> with open("pipeline.yaml", "w") as f:
                ...     yaml.dump(pipeline_dict, f)
                >>> from zea.ops import Pipeline
                >>> pipeline = Pipeline.from_yaml("pipeline.yaml", jit_options=None)
        """
        return pipeline_from_yaml(file_path, **kwargs)

    @classmethod
    def from_json(cls, json_string: str, **kwargs) -> "Pipeline":
        """Create a pipeline from a JSON string.

        Args:
            json_string (str): JSON string representing the pipeline.
            **kwargs: Additional keyword arguments to be passed to the pipeline.

        Note:
            Must have the `operations` key.

        Example:
            ```python
            json_string = '{"operations": ["identity"]}'
            pipeline = Pipeline.from_json(json_string)
            ```
        """
        return pipeline_from_json(json_string, **kwargs)

    def to_config(self) -> Config:
        """Convert the pipeline to a `zea.Config` object."""
        return pipeline_to_config(self)

    def to_json(self) -> str:
        """Convert the pipeline to a JSON string."""
        return pipeline_to_json(self)

    def to_yaml(self, file_path: str) -> None:
        """Convert the pipeline to a YAML file."""
        pipeline_to_yaml(self, file_path)

    @property
    def key(self) -> str:
        """Input key of the pipeline (of its first operation)."""
        return self.operations[0].key

    @property
    def output_key(self) -> str:
        """Output key of the pipeline (of its last operation)."""
        return self.operations[-1].output_key

    def __eq__(self, other):
        """Check if two pipelines have equal operations, in the same order."""
        if not isinstance(other, Pipeline):
            return False

        if len(self.operations) != len(other.operations):
            return False

        return all(op1 == op2 for op1, op2 in zip(self.operations, other.operations))

    def prepare_parameters(
        self,
        probe: Probe = None,
        scan: Scan = None,
        config: Config = None,
        **kwargs,
    ):
        """Prepare Probe, Scan and Config objects for the pipeline.

        Serializes `zea.core.Object` instances and converts them to
        dictionary of tensors. Static parameters of the pipeline are kept as-is.

        Args:
            probe: Probe object.
            scan: Scan object. Only the keys the pipeline needs are included.
            config: Config object.
            **kwargs: Additional keyword arguments to be included in the inputs.

        Returns:
            dict: Dictionary of inputs with all values as tensors. On duplicate keys,
            precedence is kwargs > config > scan > probe.
        """
        # Computed once; both properties iterate over all operations.
        static_params = self.static_params

        probe_dict, scan_dict, config_dict = {}, {}, {}

        if probe is not None:
            assert isinstance(probe, Probe), (
                f"Expected an instance of `zea.probes.Probe`, got {type(probe)}"
            )
            probe_dict = probe.to_tensor(keep_as_is=static_params)

        if scan is not None:
            assert isinstance(scan, Scan), (
                f"Expected an instance of `zea.scan.Scan`, got {type(scan)}"
            )
            scan_dict = scan.to_tensor(include=self.needs_keys, keep_as_is=static_params)

        if config is not None:
            assert isinstance(config, Config), (
                f"Expected an instance of `zea.config.Config`, got {type(config)}"
            )
            config_dict.update(config.to_tensor(keep_as_is=static_params))

        # Convert all kwargs to tensors
        tensor_kwargs = dict_to_tensor(kwargs, keep_as_is=static_params)

        # combine probe, scan, config and kwargs
        # explicitly so we know which keys overwrite which
        # kwargs > config > scan > probe
        return {
            **probe_dict,
            **scan_dict,
            **config_dict,
            **tensor_kwargs,
        }
|
|
684
|
+
|
|
685
|
+
|
|
686
|
+
@ops_registry("branched_pipeline")
|
|
687
|
+
class BranchedPipeline(Operation):
|
|
688
|
+
"""Operation that processes data through multiple branches.
|
|
689
|
+
|
|
690
|
+
This operation takes input data, processes it through multiple parallel branches,
|
|
691
|
+
and then merges the results from those branches using the specified merge strategy.
|
|
692
|
+
"""
|
|
693
|
+
|
|
694
|
+
def __init__(self, branches=None, merge_strategy="nested", **kwargs):
    """Set up the parallel branches and the strategy that merges their outputs.

    Args:
        branches (List[Union[List, Pipeline, Operation]]): Branch specifications.
            A list is turned into an operation chain with ``make_operation_chain``;
            a Pipeline or Operation is used directly. Branches are named
            ``branch_1``, ``branch_2``, ... in the given order.
        merge_strategy (str or callable): How to merge the outputs from branches:
            - "nested" (default): Return outputs as a dictionary keyed by branch name
            - "flatten": Flatten outputs by prefixing keys with the branch name
            - "suffix": Flatten outputs by suffixing keys with the branch name
            - callable: A custom merge function that accepts the branch outputs dict
        **kwargs: Additional arguments for the Operation base class

    Raises:
        ValueError: If a branch has an unsupported type, or the merge strategy is an
            unknown name or neither a string nor a callable.
    """
    super().__init__(**kwargs)

    specs = [] if branches is None else branches

    self.branches = {}
    for position, spec in enumerate(specs, start=1):
        if isinstance(spec, list):
            chain = make_operation_chain(spec)
        elif isinstance(spec, (Pipeline, Operation)):
            chain = spec
        else:
            raise ValueError(
                f"Branch must be a list, Pipeline, or Operation, got {type(spec)}"
            )
        self.branches[f"branch_{position}"] = chain

    self.merge_strategy = merge_strategy

    if isinstance(merge_strategy, str):
        # Dispatch table for the built-in, name-based strategies.
        named_strategies = {
            "nested": lambda outputs: outputs,
            "flatten": self.flatten_outputs,
            "suffix": self.suffix_merge_outputs,
        }
        chosen = named_strategies.get(merge_strategy)
        if chosen is None:
            raise ValueError(f"Unknown merge_strategy: {merge_strategy}")
        self._merge_function = chosen
    elif callable(merge_strategy):
        self._merge_function = merge_strategy
    else:
        raise ValueError("Invalid merge_strategy type provided.")
|
|
742
|
+
|
|
743
|
+
def call(self, **kwargs):
    """Run every branch on the same inputs and merge the branch results.

    Args:
        **kwargs: Input keyword arguments. Each branch is called with its own
            shallow copy, so branches do not see each other's modifications.

    Returns:
        dict: The ``{branch_name: branch_output}`` mapping after it has been
        passed through ``self._merge_function`` (the configured merge strategy).
    """
    per_branch = {
        branch_name: branch(**dict(kwargs))
        for branch_name, branch in self.branches.items()
    }
    return self._merge_function(per_branch)
|
|
767
|
+
|
|
768
|
+
def flatten_outputs(self, outputs: dict) -> dict:
    """Merge per-branch outputs into one flat dict, prefixing with the branch name.

    Every ``outputs[branch_name][key]`` is stored under ``"{branch_name}_{key}"``.

    Args:
        outputs (dict): Mapping of branch name to that branch's output dict.

    Returns:
        dict: The flattened outputs.

    Raises:
        ValueError: If two entries produce the same prefixed key.
    """
    prefixed_items = (
        (f"{branch_name}_{key}", value)
        for branch_name, branch_dict in outputs.items()
        for key, value in branch_dict.items()
    )
    merged = {}
    for merged_key, value in prefixed_items:
        if merged_key in merged:
            raise ValueError(f"Key collision detected for {merged_key}")
        merged[merged_key] = value
    return merged
|
|
781
|
+
|
|
782
|
+
def suffix_merge_outputs(self, outputs: dict) -> dict:
    """Merge per-branch outputs into one flat dict, suffixing with the branch name.

    Every ``outputs[branch_name][key]`` is stored under ``"{key}_{branch_name}"``.

    Args:
        outputs (dict): Mapping of branch name to that branch's output dict.

    Returns:
        dict: The flattened outputs.

    Raises:
        ValueError: If two entries produce the same suffixed key.
    """
    suffixed_items = (
        (f"{key}_{branch_name}", value)
        for branch_name, branch_dict in outputs.items()
        for key, value in branch_dict.items()
    )
    merged = {}
    for merged_key, value in suffixed_items:
        if merged_key in merged:
            raise ValueError(f"Key collision detected for {merged_key}")
        merged[merged_key] = value
    return merged
|
|
795
|
+
|
|
796
|
+
def get_config(self):
    """Return the config dictionary for serialization.

    Extends the base ``get_config()`` with:

    - ``"branches"``: one entry per branch. Pipeline and single-operation
      branches use their own ``get_config()``; a branch stored as a plain list
      of operations (what ``make_operation_chain`` returns) becomes
      ``{"operations": [op.get_config(), ...]}``.
    - ``"merge_strategy"``: the strategy string, or for a custom callable its
      ``__name__`` (``"custom"`` when it has none).
    """
    config = super().get_config()

    def _branch_config(branch):
        # Only a raw list needs wrapping; everything else serializes itself.
        if not isinstance(branch, Pipeline) and isinstance(branch, list):
            return {"operations": [op.get_config() for op in branch]}
        return branch.get_config()

    branch_configs = {name: _branch_config(branch) for name, branch in self.branches.items()}

    strategy = self.merge_strategy
    strategy_config = (
        strategy if isinstance(strategy, str) else getattr(strategy, "__name__", "custom")
    )

    config.update({"branches": branch_configs, "merge_strategy": strategy_config})
    return config
|
|
831
|
+
|
|
832
|
+
def get_dict(self):
    """Get the configuration of the operation.

    Starts from the base ``get_dict()``, sets ``"name"`` to
    ``"branched_pipeline"`` and adds two top-level entries:

    - ``"branches"``: each branch described recursively (a list branch becomes
      a list of per-operation dicts).
    - ``"merge_strategy"``: stored as-is, so a custom callable is kept as the
      callable object itself.
    """
    config = super().get_dict()
    config.update({"name": "branched_pipeline"})

    described = {}
    for branch_name, branch in self.branches.items():
        if not isinstance(branch, Pipeline) and isinstance(branch, list):
            described[branch_name] = [op.get_dict() for op in branch]
        else:
            described[branch_name] = branch.get_dict()

    config["branches"] = described
    config["merge_strategy"] = self.merge_strategy
    return config
|
|
849
|
+
|
|
850
|
+
|
|
851
|
+
@ops_registry("map")
class Map(Pipeline):
    """
    A pipeline that maps its operations over specified input arguments.

    This can be used to reduce memory usage by processing data in chunks.

    Notes
    -----
    - When `chunks` and `batch_size` are both None (default), this behaves like a normal Pipeline.
    - Changing anything other than ``self.output_key`` in the dict will not be propagated.
    - Will be jitted as a single operation, not the individual operations.
    - This class handles the batching.

    For more information on how to use ``in_axes``, ``out_axes``, `see the documentation for
    jax.vmap <https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html>`_.

    Example
    -------
    .. doctest::

        >>> from zea.ops import Map, Pipeline, Demodulate, TOFCorrection

        >>> # apply operations in batches of 8
        >>> # in this case, over the first axis of "data"
        >>> # or more specifically, process 8 transmits at a time

        >>> pipeline_mapped = Map(
        ...     [
        ...         Demodulate(),
        ...         TOFCorrection(),
        ...     ],
        ...     argnames="data",
        ...     batch_size=8,
        ... )

        >>> # you can also map a subset of the operations
        >>> # for example, demodulate in 4 chunks
        >>> # or more specifically, split the transmit axis into 4 parts

        >>> pipeline_mapped = Pipeline(
        ...     [
        ...         Map([Demodulate()], argnames="data", chunks=4),
        ...         TOFCorrection(),
        ...     ],
        ... )
    """

    def __init__(
        self,
        operations: List[Operation],
        argnames: List[str] | str,
        in_axes: List[Union[int, None]] | int = 0,
        out_axes: List[Union[int, None]] | int = 0,
        chunks: int | None = None,
        batch_size: int | None = None,
        **kwargs,
    ):
        """
        Args:
            operations (list): List of operations to be performed.
            argnames (str or list): List of argument names (or keys) to map over.
                Can also be a single string if only one argument is mapped over.
                A mapped argument that is missing from the inputs is passed as None.
            in_axes (int or list): Axes to map over for each argument.
                If a single int is provided, it is used for all arguments.
            out_axes (int or list): Axes to map over for each output.
                If a single int is provided, it is used for all outputs.
            chunks (int, optional): Number of chunks to split the input data into.
                If None, no chunking is performed. Mutually exclusive with ``batch_size``.
            batch_size (int, optional): Size of batches to process at once.
                If None, no batching is performed. Mutually exclusive with ``chunks``.
            **kwargs: Forwarded to :class:`Pipeline`.

        Raises:
            ValueError: If both ``chunks`` and ``batch_size`` are given, or if
                either of them is not a positive integer.
        """
        # Validate the mapping arguments before building the underlying pipeline,
        # so invalid settings fail fast instead of after the (possibly jitted)
        # parent pipeline has already been constructed.
        if batch_size is not None and chunks is not None:
            raise ValueError(
                "batch_size and chunks are mutually exclusive. Please specify only one."
            )

        if batch_size is not None and batch_size <= 0:
            raise ValueError("batch_size must be a positive integer.")

        if chunks is not None and chunks <= 0:
            raise ValueError("chunks must be a positive integer.")

        super().__init__(operations, **kwargs)

        if isinstance(argnames, str):
            argnames = [argnames]

        self.argnames = argnames
        self.in_axes = in_axes
        self.out_axes = out_axes
        self.chunks = chunks
        self.batch_size = batch_size

        if chunks is None and batch_size is None:
            log.warning(
                "[zea.ops.Map] Both `chunks` and `batch_size` are None. "
                "This will behave like a normal Pipeline. "
                "Consider setting one of them to process data in chunks or batches."
            )

        def call_item(**inputs):
            """Run the wrapped operations vmapped over ``argnames`` for one item.

            The mapped arguments are removed from ``inputs`` (missing ones become
            None); all remaining inputs are passed unchanged to every mapped call.
            Only ``self.output_key`` of the wrapped pipeline's result is returned.
            """
            mapped_args = [inputs.pop(argname, None) for argname in argnames]

            def patched_call(*args):
                out = super(Map, self).call(**dict(zip(argnames, args)), **inputs)

                # TODO: maybe it is possible to output everything?
                # e.g. prepend a empty dimension to all inputs and just map over everything?
                return out[self.output_key]

            return vmap(
                patched_call,
                in_axes=in_axes,
                out_axes=out_axes,
                chunks=chunks,
                batch_size=batch_size,
                fn_supports_batch=True,
                # Read at call time so later changes to jit_options are honoured.
                disable_jit=not bool(self.jit_options),
            )(*mapped_args)

        self.call_item = call_item

    @property
    def jit_options(self):
        """Get the jit_options property of the pipeline."""
        return self._jit_options

    @jit_options.setter
    def jit_options(self, value):
        """Set the jit_options property of the pipeline.

        "pipeline" and "ops" both jit the whole mapped call; any other value unjits it.
        """
        self._jit_options = value
        if value in ["pipeline", "ops"]:
            self.jit()
        else:
            self.unjit()

    def jit(self):
        """JIT compile the pipeline (the mapped call as a single unit)."""
        self._jittable_call = jit(self.jittable_call, **self.jit_kwargs)

    def unjit(self):
        """Un-JIT compile the pipeline."""
        self._jittable_call = self.jittable_call
        self._call_pipeline = self.call

    @property
    def with_batch_dim(self):
        """Get the with_batch_dim property of the pipeline."""
        return self._with_batch_dim

    @with_batch_dim.setter
    def with_batch_dim(self, value):
        """Set the with_batch_dim property of the pipeline.

        The class handles the batching so the operations have to be set to False."""
        self._with_batch_dim = value
        for operation in self.operations:
            operation.with_batch_dim = False

    def jittable_call(self, **inputs):
        """Process input data through the pipeline.

        With a batch dimension, ``ops.map`` runs ``call_item`` once per item of
        ``inputs[self.key]``; other inputs are shared by all items.
        """
        if self._with_batch_dim:
            input_data = inputs.pop(self.key)
            output = ops.map(
                lambda x: self.call_item(**{self.key: x, **inputs}),
                input_data,
            )
        else:
            output = self.call_item(**inputs)

        return {self.output_key: output}

    def call(self, **inputs):
        """Process input data through the pipeline.

        Returns the inputs updated with ``self.output_key``.
        """
        output = self._jittable_call(**inputs)
        inputs.update(output)
        return inputs

    def get_dict(self):
        """Get the configuration of the pipeline, including the mapping parameters."""
        config = super().get_dict()
        config.update({"name": "map"})

        config["params"].update(
            {
                "argnames": self.argnames,
                "in_axes": self.in_axes,
                "out_axes": self.out_axes,
                "chunks": self.chunks,
                "batch_size": self.batch_size,
            }
        )
        return config
|
|
1050
|
+
|
|
1051
|
+
|
|
1052
|
+
@ops_registry("patched_grid")
class PatchedGrid(Map):
    """
    A pipeline that maps its operations over `flatgrid` and `flat_pfield` keys.

    This can be used to reduce memory usage by processing data in chunks.

    For more information and flexibility, see :class:`zea.ops.Map`.
    """

    def __init__(self, *args, num_patches=10, **kwargs):
        """Create a :class:`Map` over ``flatgrid`` and ``flat_pfield``.

        Args:
            *args: Positional arguments forwarded to :class:`Map` (the operations).
            num_patches (int): Number of chunks the mapped grid arguments are
                split into; forwarded to :class:`Map` as ``chunks``.
            **kwargs: Remaining keyword arguments forwarded to :class:`Map`.
        """
        grid_keys = ["flatgrid", "flat_pfield"]
        super().__init__(*args, argnames=grid_keys, chunks=num_patches, **kwargs)
        self.num_patches = num_patches

    def get_dict(self):
        """Get the configuration of the pipeline.

        ``argnames`` and ``chunks`` are fixed/derived by this class, so they are
        dropped from the params and replaced by ``num_patches``.
        """
        config = super().get_dict()
        config.update({"name": "patched_grid"})

        params = config["params"]
        for derived_key in ("argnames", "chunks"):
            params.pop(derived_key)
        params.update({"num_patches": self.num_patches})
        return config
|
|
1075
|
+
|
|
1076
|
+
|
|
1077
|
+
@ops_registry("beamform")
class Beamform(Pipeline):
    """Classical beamforming pipeline for ultrasound image formation.

    Expected input data type is `DataTypes.RF_DATA` which has shape `(n_tx, n_ax, n_el, n_ch)`.

    Will run the following operations in sequence:
    - TOFCorrection (output type `DataTypes.ALIGNED_DATA`: `(n_tx, n_ax, n_el, n_ch)`)
    - PfieldWeighting (optional, output type `DataTypes.ALIGNED_DATA`: `(n_tx, n_ax, n_el, n_ch)`)
    - Sum over channels (DAS)
    - Sum over transmits (Compounding) (output type `DataTypes.BEAMFORMED_DATA`: `(grid_size_z, grid_size_x, n_ch)`)
    - ReshapeGrid (flattened grid is also reshaped to `(grid_size_z, grid_size_x)`)
    """  # noqa: E501

    def __init__(self, beamformer="delay_and_sum", num_patches=100, enable_pfield=False, **kwargs):
        """Initialize a Delay-and-Sum beamforming `zea.Pipeline`.

        Args:
            beamformer (str): Type of beamformer to use. Currently supporting,
                "delay_and_sum" and "delay_multiply_and_sum". The deprecated short
                names "das" and "dmas" are still accepted.
            num_patches (int): Number of patches to split the grid into for patch-wise
                beamforming. If 1 (or less), no patching is performed.
            enable_pfield (bool): Whether to include pressure field weighting in the beamforming.
            **kwargs: Forwarded both to the inner ``PatchedGrid`` (when patching) and
                to the ``Pipeline`` base class.

        Raises:
            ValueError: If the beamformer type is not supported.
        """
        self.beamformer_type = beamformer
        self.num_patches = num_patches
        self.enable_pfield = enable_pfield

        # Short names kept for backwards compatibility.
        legacy_names = {
            "das": "delay_and_sum",
            "dmas": "delay_multiply_and_sum",
        }
        if beamformer in legacy_names:
            replacement = legacy_names[beamformer]
            log.deprecated(
                f"Beamformer name '{beamformer}' is deprecated. "
                f"Please use '{replacement}' instead."
            )
            self.beamformer_type = replacement

        supported_types = ("delay_and_sum", "delay_multiply_and_sum")
        if self.beamformer_type not in supported_types:
            raise ValueError(
                f"Unsupported beamformer type: {self.beamformer_type}. "
                "Supported types are 'delay_and_sum' and 'delay_multiply_and_sum'."
            )

        # Instantiate in the original order: TOF correction, summation, then pfield.
        tof_correction = TOFCorrection()
        summation = get_ops(self.beamformer_type)()
        weighting = [PfieldWeighting()] if self.enable_pfield else []
        chain = [tof_correction, *weighting, summation]

        # Optionally run the per-pixel part in patches over the flattened grid.
        if self.num_patches > 1:
            chain = [
                PatchedGrid(
                    operations=chain,
                    num_patches=self.num_patches,
                    **kwargs,
                )
            ]

        # The final reshape to image shape also defines the pipeline output type.
        reshape = ReshapeGrid()
        chain.append(reshape)
        chain[-1].output_data_type = DataTypes.BEAMFORMED_DATA

        super().__init__(operations=chain, **kwargs)

    def __repr__(self):
        """String representation of the pipeline."""
        parts = [
            repr(op) if isinstance(op, Pipeline) else op.__class__.__name__
            for op in self.operations
        ]
        return f"<Beamform {self.name}=({', '.join(parts)})>"

    def get_dict(self) -> dict:
        """Convert the pipeline to a dictionary, including the constructor params."""
        config = super().get_dict()
        config.update({"name": "beamform"})
        extra_params = {
            "beamformer": self.beamformer_type,
            "num_patches": self.num_patches,
            "enable_pfield": self.enable_pfield,
        }
        config["params"].update(extra_params)
        return config
|
|
1176
|
+
|
|
1177
|
+
|
|
1178
|
+
@ops_registry("delay_and_sum")
class DelayAndSum(Operation):
    """Sums time-delayed signals along channels and transmits."""

    def __init__(self, **kwargs):
        super().__init__(
            input_data_type=DataTypes.ALIGNED_DATA,
            output_data_type=DataTypes.BEAMFORMED_DATA,
            **kwargs,
        )

    def call(self, grid=None, **kwargs):
        """Performs DAS beamforming on tof-corrected input.

        Args:
            grid: Unused.
            **kwargs: Must contain ``self.key``: the TOF corrected data of shape
                `(n_tx, grid_size_z*grid_size_x, n_el, n_ch)` with optional batch dimension.

        Returns:
            dict: ``{self.output_key: beamformed_data}`` where beamformed_data has
            shape `(grid_size_z*grid_size_x, n_ch)` with optional batch dimension.
        """
        aligned = kwargs[self.key]

        # Axis -2 is n_el: summing it is the delay-and-sum step. Afterwards axis -3
        # is n_tx: summing that compounds the transmits. Negative axes keep this
        # valid with or without a leading batch dimension.
        compounded = ops.sum(ops.sum(aligned, axis=-2), axis=-3)

        return {self.output_key: compounded}
|
|
1209
|
+
|
|
1210
|
+
|
|
1211
|
+
@ops_registry("delay_multiply_and_sum")
class DelayMultiplyAndSum(Operation):
    """Performs the operations for the Delay-Multiply-and-Sum beamformer except the delay.
    The delay should be performed by the TOF correction operation.
    """

    def __init__(self, **kwargs):
        super().__init__(
            input_data_type=DataTypes.ALIGNED_DATA,
            output_data_type=DataTypes.BEAMFORMED_DATA,
            **kwargs,
        )

    def process_image(self, data):
        """Performs DMAS beamforming on tof-corrected input.

        Args:
            data (ops.Tensor): The TOF corrected input of shape `(n_tx, n_pix, n_el, n_ch)`
                with ``n_ch == 2`` (I and Q).

        Returns:
            ops.Tensor: The beamformed data of shape `(n_pix, n_ch)`

        Raises:
            ValueError: If the last axis does not have size 2.
        """
        if data.shape[-1] != 2:
            raise ValueError(
                "MultiplyAndSum operation requires IQ data with 2 channels. "
                f"Got data with shape {data.shape}."
            )

        # Presumably channels_to_complex folds the I/Q axis into complex values,
        # giving (n_tx, n_pix, n_el) -- TODO confirm.
        iq = channels_to_complex(data)

        # Pairwise signed-sqrt products -> (n_tx, n_pix, n_el, n_el), keep each
        # distinct element pair once, then sum over transmits and both element axes.
        pair_terms = self._select_lower_triangle(self._multiply(iq))
        summed = ops.sum(pair_terms, axis=(0, 2, 3))

        return complex_to_channels(summed)

    def _select_lower_triangle(self, data):
        """Keep each distinct element pair exactly once.

        Zeroes the diagonal and halves every off-diagonal entry. Because the
        product matrix from ``_multiply`` is symmetric (``a_i * a_j == a_j * a_i``),
        summing the result equals summing only the strict lower triangle.
        """
        n_el = data.shape[3]
        off_diagonal = ops.ones((n_el, n_el), dtype=data.dtype) - ops.eye(n_el, dtype=data.dtype)
        return data * off_diagonal[None, None, :, :] / 2

    def _multiply(self, data):
        """Apply the DMAS multiplication step: sign(a_i a_j) * sqrt(|a_i a_j|)."""
        products = data[:, :, :, None] * data[:, :, None, :]

        signs = ops.sign(products)
        root_magnitudes = ops.cast(ops.sqrt(ops.abs(products)), data.dtype)
        return signs * root_magnitudes

    def call(self, grid=None, **kwargs):
        """Performs DMAS beamforming on tof-corrected input.

        Args:
            grid: Unused.
            **kwargs: Must contain ``self.key``: the TOF corrected data of shape
                `(n_tx, grid_size_z*grid_size_x, n_el, n_ch)` with optional batch dimension.

        Returns:
            dict: ``{self.output_key: beamformed_data}`` where beamformed_data has
            shape `(grid_size_z*grid_size_x, n_ch)` with optional batch dimension.
        """
        data = kwargs[self.key]

        if self.with_batch_dim:
            # process_image handles one item, so map it over the batch axis.
            beamformed = ops.map(self.process_image, data)
        else:
            beamformed = self.process_image(data)

        return {self.output_key: beamformed}
|
|
1288
|
+
|
|
1289
|
+
|
|
1290
|
+
def make_operation_chain(
    operation_chain: List[Union[str, Dict, Config, Operation, Pipeline]],
) -> List[Operation]:
    """Make an operation chain from a custom list of operations.

    The input configs are never modified.

    Args:
        operation_chain (list): List of operations to be performed.
            Each operation can be:
            - A string: operation initialized with default parameters
            - A dictionary: operation initialized with parameters in the dictionary
            - A Config object: converted to a dictionary and initialized
            - An Operation/Pipeline instance: used as-is

    Returns:
        list: List of operations to be performed.

    Raises:
        AssertionError: If an entry is not one of the supported types.

    Example:
        .. doctest::

            >>> from zea.ops import make_operation_chain, LogCompress
            >>> SomeCustomOperation = LogCompress  # just for demonstration
            >>> chain = make_operation_chain(
            ...     [
            ...         "envelope_detect",
            ...         {"name": "normalize", "params": {"output_range": (0, 1)}},
            ...         SomeCustomOperation(),
            ...     ]
            ... )
    """
    chain = []
    for operation in operation_chain:
        # Handle already instantiated Operation or Pipeline objects
        if isinstance(operation, (Operation, Pipeline)):
            chain.append(operation)
            continue

        assert isinstance(operation, (str, dict, Config)), (
            f"Operation {operation} should be a string, dict, Config object, Operation, or Pipeline"
        )

        if isinstance(operation, str):
            chain.append(get_ops(operation)())
            continue

        if isinstance(operation, Config):
            operation = operation.serialize()

        # Work on a copy so popping keys below never mutates the caller's config
        # (previously a second call on the same config failed); ``or {}`` also
        # tolerates an explicit ``params: null`` entry.
        params = dict(operation.get("params") or {})
        op_name = operation.get("name")
        operation_cls = get_ops(op_name)

        # Handle branches for branched pipeline
        if op_name == "branched_pipeline" and "branches" in operation:
            branches = []
            # Convert each branch configuration to an operation chain
            for branch_config in operation.get("branches", {}).values():
                if isinstance(branch_config, (list, np.ndarray)):
                    # This is a list of operations
                    branch = make_operation_chain(branch_config)
                elif "operations" in branch_config:
                    # This is a pipeline-like branch
                    branch = make_operation_chain(branch_config["operations"])
                else:
                    # This is a single operation branch
                    branch_op_cls = get_ops(branch_config["name"])
                    branch = branch_op_cls(**branch_config.get("params", {}))
                branches.append(branch)

            # The branched pipeline's get_dict() stores merge_strategy next to
            # "branches" instead of inside "params"; forward it so the strategy
            # is not silently reset to the default when rebuilding.
            if "merge_strategy" in operation and "merge_strategy" not in params:
                params["merge_strategy"] = operation["merge_strategy"]

            operation_instance = operation_cls(branches=branches, **params)

        # Check for nested operations at the same level as params
        elif "operations" in operation:
            if issubclass(operation_cls, Beamform):
                # Some pipelines, such as `zea.ops.Beamform`, are initialized not
                # with a list of operations but with other parameters that then
                # internally create the operations, so the nested list is ignored.
                operation_instance = operation_cls(**params)
            else:
                # Pipeline-type operations are initialized with their nested operations.
                nested_operations = make_operation_chain(operation["operations"])
                operation_instance = operation_cls(operations=nested_operations, **params)

        elif op_name == "patched_grid":
            # Here the nested operations are stored inside params.
            nested_operations = make_operation_chain(params.pop("operations"))
            operation_instance = operation_cls(operations=nested_operations, **params)

        else:
            operation_instance = operation_cls(**params)

        chain.append(operation_instance)

    return chain
|
|
1388
|
+
|
|
1389
|
+
|
|
1390
|
+
def pipeline_from_config(config: Config, **kwargs) -> Pipeline:
    """
    Create a Pipeline instance from a Config object.

    Args:
        config (Config): Must contain an ``operations`` entry (a list or numpy
            array) describing the operations. Every other entry is passed to
            ``Pipeline`` as a keyword argument.
        **kwargs: Extra ``Pipeline`` keyword arguments; they override entries
            of the same name in ``config``.

    Returns:
        Pipeline: The constructed pipeline.

    Raises:
        AssertionError: If ``operations`` is missing or is not a list/array.
    """
    assert "operations" in config, (
        "Config object must have an 'operations' key for pipeline creation."
    )
    assert isinstance(config.operations, (list, np.ndarray)), (
        "Config object must have a list or numpy array of operations for pipeline creation."
    )

    chain = make_operation_chain(config.operations)

    # Everything except the operations list becomes a Pipeline keyword argument.
    leftover = config.copy()
    leftover.pop("operations")

    return Pipeline(operations=chain, **{**leftover, **kwargs})
|
|
1409
|
+
|
|
1410
|
+
|
|
1411
|
+
def pipeline_from_json(json_string: str, **kwargs) -> Pipeline:
    """
    Create a Pipeline instance from a JSON string.

    The string is decoded with ``ZEADecoderJSON``, wrapped in a ``Config`` and
    handed to ``pipeline_from_config`` together with ``kwargs``.
    """
    decoded = json.loads(json_string, cls=ZEADecoderJSON)
    return pipeline_from_config(Config(decoded), **kwargs)
|
|
1417
|
+
|
|
1418
|
+
|
|
1419
|
+
def pipeline_from_yaml(yaml_path: str, **kwargs) -> Pipeline:
    """
    Create a Pipeline instance from a YAML file.

    Only the top-level ``operations`` entry of the file is used; any other
    top-level keys are ignored. ``kwargs`` are forwarded to ``pipeline_from_config``.
    """
    with open(yaml_path, "r", encoding="utf-8") as f:
        loaded = yaml.safe_load(f)

    only_operations = Config({"operations": loaded["operations"]})
    return pipeline_from_config(only_operations, **kwargs)
|
|
1427
|
+
|
|
1428
|
+
|
|
1429
|
+
def pipeline_to_config(pipeline: Pipeline) -> Config:
    """
    Convert a Pipeline instance into a Config object.

    The result has a single ``operations`` entry. A plain ``Pipeline`` is
    unpacked into its own operations; any other pipeline class is kept as one
    operation so its parameters are preserved.
    """
    # TODO: we currently add the full pipeline as 1 operation to the config.
    # In another PR we should add a "pipeline" entry to the config instead of the "operations"
    # entry. This allows us to also have non-default pipeline classes as top level op.
    top_level = [pipeline.get_dict()]
    first = top_level[0]

    # HACK: If the top level operation is a single pipeline, collapse it into the operations list.
    if first["name"] == "pipeline" and len(top_level) == 1:
        return Config({"operations": first["operations"]})

    return Config({"operations": top_level})
|
|
1444
|
+
|
|
1445
|
+
|
|
1446
|
+
def pipeline_to_json(pipeline: Pipeline) -> str:
    """
    Convert a Pipeline instance into a JSON string.

    The JSON object has a single ``operations`` entry. A plain ``Pipeline`` is
    unpacked into its own operations; any other pipeline class is kept as one
    operation. Encoded with ``ZEAEncoderJSON`` and indented by 4 spaces.
    """
    top_level = [pipeline.get_dict()]
    first = top_level[0]

    # HACK: If the top level operation is a single pipeline, collapse it into the operations list.
    if first["name"] == "pipeline" and len(top_level) == 1:
        payload = {"operations": first["operations"]}
    else:
        payload = {"operations": top_level}

    return json.dumps(payload, cls=ZEAEncoderJSON, indent=4)
|
|
1458
|
+
|
|
1459
|
+
|
|
1460
|
+
def pipeline_to_yaml(pipeline: Pipeline, file_path: str) -> None:
    """
    Convert a Pipeline instance into a YAML file.

    The file has a single top-level ``operations`` list, the same layout as
    ``pipeline_to_json`` / ``pipeline_to_config``. A plain ``Pipeline`` is written
    as its own operations. Any other pipeline class (e.g. ``map`` or ``beamform``)
    is kept as one operation entry, so its own params survive ``pipeline_from_yaml``,
    which only reads the ``operations`` entry.

    Args:
        pipeline (Pipeline): The pipeline to serialize.
        file_path (str): Destination path; the file is overwritten.
    """
    # Wrap the pipeline itself (as the other serializers do). Previously the raw
    # get_dict() was used, so the collapse check below looked at the pipeline's
    # first nested op instead of the pipeline. That dropped the params of
    # non-plain top-level pipelines and raised IndexError for an empty pipeline.
    pipeline_dict = {"operations": [pipeline.get_dict()]}

    # HACK: If the top level operation is a single pipeline, collapse it into the operations list.
    ops = pipeline_dict["operations"]
    if ops[0]["name"] == "pipeline" and len(ops) == 1:
        pipeline_dict = {"operations": ops[0]["operations"]}

    # NOTE(review): yaml.Dumper may emit Python-specific tags (e.g. !!python/tuple)
    # that yaml.safe_load in pipeline_from_yaml cannot read back; consider
    # converting such values or using yaml.SafeDumper.
    with open(file_path, "w", encoding="utf-8") as f:
        yaml.dump(pipeline_dict, f, Dumper=yaml.Dumper, indent=4)
|