zea-0.0.7-py3-none-any.whl → zea-0.0.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. zea/__init__.py +3 -3
  2. zea/agent/masks.py +2 -2
  3. zea/agent/selection.py +3 -3
  4. zea/backend/__init__.py +1 -1
  5. zea/backend/tensorflow/dataloader.py +1 -5
  6. zea/beamform/beamformer.py +4 -2
  7. zea/beamform/pfield.py +2 -2
  8. zea/beamform/pixelgrid.py +1 -1
  9. zea/data/__init__.py +0 -9
  10. zea/data/augmentations.py +222 -29
  11. zea/data/convert/__init__.py +1 -6
  12. zea/data/convert/__main__.py +164 -0
  13. zea/data/convert/camus.py +106 -40
  14. zea/data/convert/echonet.py +184 -83
  15. zea/data/convert/echonetlvh/README.md +2 -3
  16. zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
  17. zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
  18. zea/data/convert/echonetlvh/precompute_crop.py +43 -64
  19. zea/data/convert/picmus.py +37 -40
  20. zea/data/convert/utils.py +86 -0
  21. zea/data/convert/verasonics.py +1247 -0
  22. zea/data/data_format.py +124 -6
  23. zea/data/dataloader.py +12 -7
  24. zea/data/datasets.py +109 -70
  25. zea/data/file.py +119 -82
  26. zea/data/file_operations.py +496 -0
  27. zea/data/preset_utils.py +2 -2
  28. zea/display.py +8 -9
  29. zea/doppler.py +5 -5
  30. zea/func/__init__.py +109 -0
  31. zea/{tensor_ops.py → func/tensor.py} +113 -69
  32. zea/func/ultrasound.py +500 -0
  33. zea/internal/_generate_keras_ops.py +5 -5
  34. zea/internal/checks.py +6 -12
  35. zea/internal/operators.py +4 -0
  36. zea/io_lib.py +108 -160
  37. zea/metrics.py +6 -5
  38. zea/models/__init__.py +1 -1
  39. zea/models/diffusion.py +63 -12
  40. zea/models/echonetlvh.py +1 -1
  41. zea/models/gmm.py +1 -1
  42. zea/models/lv_segmentation.py +2 -0
  43. zea/ops/__init__.py +188 -0
  44. zea/ops/base.py +442 -0
  45. zea/{keras_ops.py → ops/keras_ops.py} +2 -2
  46. zea/ops/pipeline.py +1472 -0
  47. zea/ops/tensor.py +356 -0
  48. zea/ops/ultrasound.py +890 -0
  49. zea/probes.py +2 -10
  50. zea/scan.py +35 -28
  51. zea/tools/fit_scan_cone.py +90 -160
  52. zea/tools/selection_tool.py +1 -1
  53. zea/tracking/__init__.py +16 -0
  54. zea/tracking/base.py +94 -0
  55. zea/tracking/lucas_kanade.py +474 -0
  56. zea/tracking/segmentation.py +110 -0
  57. zea/utils.py +11 -2
  58. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/METADATA +5 -1
  59. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/RECORD +62 -48
  60. zea/data/convert/matlab.py +0 -1237
  61. zea/ops.py +0 -3294
  62. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/WHEEL +0 -0
  63. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/entry_points.txt +0 -0
  64. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/licenses/LICENSE +0 -0
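
Migration note: zea/ops.py (removed below, −3294 lines) is split in 0.0.9 into the zea/ops/ package (base.py, pipeline.py, tensor.py, ultrasound.py, plus the relocated keras_ops.py), and zea/tensor_ops.py is renamed to zea/func/tensor.py. A minimal sketch of the import changes this refactor likely implies, assuming the new zea/ops/__init__.py re-exports the same public names (an assumption; the file list alone does not confirm it):

    # zea 0.0.7: all operations lived in the single module zea/ops.py.
    # zea 0.0.9: zea.ops is a package; assuming its __init__.py re-exports the
    # public classes, the old import style should keep working unchanged:
    from zea.ops import Pipeline, EnvelopeDetect, Normalize, LogCompress

    # zea/tensor_ops.py was renamed to zea/func/tensor.py, so imports of the
    # old module path need updating (assumed new location of these helpers):
    # 0.0.7: from zea.tensor_ops import resample, vmap
    from zea.func.tensor import resample, vmap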
zea/ops.py DELETED
@@ -1,3294 +0,0 @@
- """Operations and Pipelines for ultrasound data processing.
-
- This module contains two important classes, :class:`Operation` and :class:`Pipeline`,
- which are used to process ultrasound data. A pipeline is a sequence of operations
- that are applied to the data in a specific order.
-
- Stand-alone manual usage
- ------------------------
-
- Operations can be run on their own:
-
- Examples
- ^^^^^^^^
- .. doctest::
-
- >>> import numpy as np
- >>> from zea.ops import EnvelopeDetect
- >>> data = np.random.randn(2000, 128, 1)
- >>> # static arguments are passed in the constructor
- >>> envelope_detect = EnvelopeDetect(axis=-1)
- >>> # other parameters can be passed here along with the data
- >>> envelope_data = envelope_detect(data=data)
-
- Using a pipeline
- ----------------
-
- You can initialize with a default pipeline or create your own custom pipeline.
-
- .. doctest::
-
- >>> from zea.ops import Pipeline, EnvelopeDetect, Normalize, LogCompress
- >>> pipeline = Pipeline.from_default()
-
- >>> operations = [
- ...     EnvelopeDetect(),
- ...     Normalize(),
- ...     LogCompress(),
- ... ]
- >>> pipeline_custom = Pipeline(operations)
-
- One can also load a pipeline from a config or yaml/json file:
-
- .. doctest::
-
- >>> from zea import Pipeline
-
- >>> # From JSON string
- >>> json_string = '{"operations": ["identity"]}'
- >>> pipeline = Pipeline.from_json(json_string)
-
- >>> # from yaml file
- >>> import yaml
- >>> from zea import Config
- >>> # Create a sample pipeline YAML file
- >>> pipeline_dict = {
- ...     "operations": [
- ...         {"name": "identity"},
- ...     ]
- ... }
- >>> with open("pipeline.yaml", "w") as f:
- ...     yaml.dump(pipeline_dict, f)
- >>> yaml_file = "pipeline.yaml"
- >>> pipeline = Pipeline.from_yaml(yaml_file)
-
- .. testcleanup::
-
-     import os
-
-     os.remove("pipeline.yaml")
-
- Example of a yaml file:
-
- .. code-block:: yaml
-
-     pipeline:
-       operations:
-         - name: demodulate
-         - name: "patched_grid"
-           params:
-             operations:
-               - name: tof_correction
-               - name: pfield_weighting
-               - name: delay_and_sum
-             num_patches: 100
-         - name: envelope_detect
-         - name: normalize
-         - name: log_compress
-
- """
-
- import copy
- import hashlib
- import inspect
- import json
- from functools import partial
- from typing import Any, Dict, List, Union
-
- import keras
- import numpy as np
- import scipy
- import yaml
- from keras import ops
- from keras.src.layers.preprocessing.data_layer import DataLayer
-
- from zea import log
- from zea.backend import jit
- from zea.beamform.beamformer import tof_correction
- from zea.config import Config
- from zea.display import scan_convert
- from zea.internal.checks import _assert_keys_and_axes
- from zea.internal.core import (
-     DEFAULT_DYNAMIC_RANGE,
-     DataTypes,
-     ZEADecoderJSON,
-     ZEAEncoderJSON,
-     dict_to_tensor,
- )
- from zea.internal.core import Object as ZEAObject
- from zea.internal.registry import ops_registry
- from zea.probes import Probe
- from zea.scan import Scan
- from zea.simulator import simulate_rf
- from zea.tensor_ops import resample, reshape_axis, translate, vmap
- from zea.utils import (
-     FunctionTimer,
-     deep_compare,
-     map_negative_indices,
- )
-
-
- def get_ops(ops_name):
132
- """Get the operation from the registry."""
133
- return ops_registry[ops_name]
134
-
135
-
136
- class Operation(keras.Operation):
137
- """
138
- A base abstract class for operations in the pipeline with caching functionality.
139
- """
140
-
141
- def __init__(
142
- self,
143
- input_data_type: Union[DataTypes, None] = None,
144
- output_data_type: Union[DataTypes, None] = None,
145
- key: Union[str, None] = "data",
146
- output_key: Union[str, None] = None,
147
- cache_inputs: Union[bool, List[str]] = False,
148
- cache_outputs: bool = False,
149
- jit_compile: bool = True,
150
- with_batch_dim: bool = True,
151
- jit_kwargs: dict | None = None,
152
- jittable: bool = True,
153
- additional_output_keys: List[str] = None,
154
- **kwargs,
155
- ):
156
- """
157
- Args:
158
- input_data_type (DataTypes): The data type of the input data
159
- output_data_type (DataTypes): The data type of the output data
160
- key: The key for the input data (operation will operate on this key)
161
- Defaults to "data".
162
- output_key: The key for the output data (operation will output to this key)
163
- Defaults to the same as the input key. If you want to store intermediate
164
- results, you can set this to a different key. But make sure to update the
165
- input key of the next operation to match the output key of this operation.
166
- cache_inputs: A list of input keys to cache or True to cache all inputs
167
- cache_outputs: A list of output keys to cache or True to cache all outputs
168
- jit_compile: Whether to JIT compile the 'call' method for faster execution
169
- with_batch_dim: Whether operations should expect a batch dimension in the input
170
- jit_kwargs: Additional keyword arguments for the JIT compiler
171
- jittable: Whether the operation can be JIT compiled
172
- additional_output_keys: A list of additional output keys produced by the operation.
173
- These are used to track if all keys are available for downstream operations.
174
- If the operation has a conditional output, it is best to add all possible
175
- output keys here.
176
- """
177
- super().__init__(**kwargs)
178
-
179
- self.input_data_type = input_data_type
180
- self.output_data_type = output_data_type
181
-
182
- self.key = key # Key for input data
183
- self.output_key = output_key # Key for output data
184
- if self.output_key is None:
185
- self.output_key = self.key
186
- self.additional_output_keys = additional_output_keys or []
187
-
188
- self.inputs = [] # Source(s) of input data (name of a previous operation)
189
- self.allow_multiple_inputs = False # Only single input allowed by default
190
-
191
- self.cache_inputs = cache_inputs
192
- self.cache_outputs = cache_outputs
193
-
194
- # Initialize input and output caches
195
- self._input_cache = {}
196
- self._output_cache = {}
197
-
198
- # Obtain the input signature of the `call` method
199
- self._trace_signatures()
200
-
201
- if jit_kwargs is None:
202
- jit_kwargs = {}
203
-
204
- if keras.backend.backend() == "jax" and self.static_params:
205
- jit_kwargs |= {"static_argnames": self.static_params}
206
-
207
- self.jit_kwargs = jit_kwargs
208
-
209
- self.with_batch_dim = with_batch_dim
210
- self._jittable = jittable
211
-
212
- # Set the jit compilation flag and compile the `call` method
213
- # Set zea logger level to suppress warnings regarding
214
- # torch not being able to compile the function
215
- with log.set_level("ERROR"):
216
- self.set_jit(jit_compile)
217
-
218
- @property
219
- def output_keys(self) -> List[str]:
220
- """Get the output keys of the operation."""
221
- return [self.output_key] + self.additional_output_keys
222
-
223
- @property
224
- def static_params(self):
225
- """Get the static parameters of the operation."""
226
- return getattr(self.__class__, "STATIC_PARAMS", [])
227
-
228
- def set_jit(self, jit_compile: bool):
229
- """Set the JIT compilation flag and set the `_call` method accordingly."""
230
- self._jit_compile = jit_compile
231
- if self._jit_compile and self.jittable:
232
- self._call = jit(self.call, **self.jit_kwargs)
233
- else:
234
- self._call = self.call
235
-
236
- def _trace_signatures(self):
237
- """
238
- Analyze and store the input/output signatures of the `call` method.
239
- """
240
- self._input_signature = inspect.signature(self.call)
241
- self._valid_keys = set(self._input_signature.parameters.keys())
242
-
243
- @property
244
- def valid_keys(self) -> set:
245
- """Get the valid keys for the `call` method."""
246
- return self._valid_keys
247
-
248
- @property
249
- def needs_keys(self) -> set:
250
- """Get a set of all input keys needed by the operation."""
251
- return self.valid_keys
252
-
253
- @property
254
- def jittable(self):
255
- """Check if the operation can be JIT compiled."""
256
- return self._jittable
257
-
258
- def call(self, **kwargs):
259
- """
260
- Abstract method that defines the processing logic for the operation.
261
- Subclasses must implement this method.
262
- """
263
- raise NotImplementedError
264
-
265
- def set_input_cache(self, input_cache: Dict[str, Any]):
266
- """
267
- Set a cache for inputs, then retrace the function if necessary.
268
-
269
- Args:
270
- input_cache: A dictionary containing cached inputs.
271
- """
272
- self._input_cache.update(input_cache)
273
- self._trace_signatures() # Retrace after updating cache to ensure correctness.
274
-
275
- def set_output_cache(self, output_cache: Dict[str, Any]):
276
- """
277
- Set a cache for outputs, then retrace the function if necessary.
278
-
279
- Args:
280
- output_cache: A dictionary containing cached outputs.
281
- """
282
- self._output_cache.update(output_cache)
283
- self._trace_signatures() # Retrace after updating cache to ensure correctness.
284
-
285
- def clear_cache(self):
286
- """
287
- Clear the input and output caches.
288
- """
289
- self._input_cache.clear()
290
- self._output_cache.clear()
291
-
292
- def _hash_inputs(self, kwargs: Dict) -> str:
293
- """
294
- Generate a hash for the given inputs to use as a cache key.
295
-
296
- Args:
297
- kwargs: Keyword arguments.
298
-
299
- Returns:
300
- A unique hash representing the inputs.
301
- """
302
- input_json = json.dumps(kwargs, sort_keys=True, default=str)
303
- return hashlib.md5(input_json.encode()).hexdigest()
304
-
305
- def __call__(self, *args, **kwargs) -> Dict:
306
- """
307
- Process the input keyword arguments and return the processed results.
308
-
309
- Args:
310
- kwargs: Keyword arguments to be processed.
311
-
312
- Returns:
313
- Combined input and output as kwargs.
314
- """
315
- if args:
316
- example_usage = f" result = {ops_registry.get_name(self)}({self.key}=my_data"
317
- valid_keys_no_kwargs = self.valid_keys - {"kwargs"}
318
- if valid_keys_no_kwargs:
319
- example_usage += f", {list(valid_keys_no_kwargs)[0]}=param1, ..., **kwargs)"
320
- else:
321
- example_usage += ", **kwargs)"
322
- raise TypeError(
323
- f"{self.__class__.__name__}.__call__() only accepts keyword arguments. "
324
- "Positional arguments are not allowed.\n"
325
- f"Received positional arguments: {args}\n"
326
- "Example usage:\n"
327
- f"{example_usage}"
328
- )
329
-
330
- # Merge cached inputs with provided ones
331
- merged_kwargs = {**self._input_cache, **kwargs}
332
-
333
- # Return cached output if available
334
- if self.cache_outputs:
335
- cache_key = self._hash_inputs(merged_kwargs)
336
- if cache_key in self._output_cache:
337
- return {**merged_kwargs, **self._output_cache[cache_key]}
338
-
339
- # Filter kwargs to match the valid keys of the `call` method
340
- if "kwargs" not in self.valid_keys:
341
- filtered_kwargs = {k: v for k, v in merged_kwargs.items() if k in self.valid_keys}
342
- else:
343
- filtered_kwargs = merged_kwargs
344
-
345
- # Call the processing function
346
- # If you want to jump in with debugger please set `jit_compile=False`
347
- # when initializing the pipeline.
348
- processed_output = self._call(**filtered_kwargs)
349
-
350
- # Ensure the output is always a dictionary
351
- if not isinstance(processed_output, dict):
352
- raise TypeError(
353
- f"The `call` method must return a dictionary. Got {type(processed_output)}."
354
- )
355
-
356
- # Merge outputs with inputs
357
- combined_kwargs = {**merged_kwargs, **processed_output}
358
-
359
- # Cache the result if caching is enabled
360
- if self.cache_outputs:
361
- if isinstance(self.cache_outputs, list):
362
- cached_output = {
363
- k: v for k, v in processed_output.items() if k in self.cache_outputs
364
- }
365
- else:
366
- cached_output = processed_output
367
- self._output_cache[cache_key] = cached_output
368
-
369
- return combined_kwargs
370
-
371
- def get_dict(self):
372
- """Get the configuration of the operation. Inherit from keras.Operation."""
373
- config = {}
374
- config.update({"name": ops_registry.get_name(self)})
375
- config["params"] = {
376
- "key": self.key,
377
- "output_key": self.output_key,
378
- "cache_inputs": self.cache_inputs,
379
- "cache_outputs": self.cache_outputs,
380
- "jit_compile": self._jit_compile,
381
- "with_batch_dim": self.with_batch_dim,
382
- "jit_kwargs": self.jit_kwargs,
383
- }
384
- return config
385
-
386
- def __eq__(self, other):
387
- """Check equality of two operations based on type and configuration."""
388
- if not isinstance(other, Operation):
389
- return False
390
-
391
- # Compare the class name and parameters
392
- if self.__class__.__name__ != other.__class__.__name__:
393
- return False
394
-
395
- # Compare the name assigned to the operation
396
- name = ops_registry.get_name(self)
397
- other_name = ops_registry.get_name(other)
398
- if name != other_name:
399
- return False
400
-
401
- # Compare the parameters of the operations
402
- if not deep_compare(self.get_dict(), other.get_dict()):
403
- return False
404
-
405
- return True
406
-
407
-
408
- @ops_registry("pipeline")
409
- class Pipeline:
410
- """Pipeline class for processing ultrasound data through a series of operations."""
411
-
412
- def __init__(
413
- self,
414
- operations: List[Operation],
415
- with_batch_dim: bool = True,
416
- jit_options: Union[str, None] = "ops",
417
- jit_kwargs: dict | None = None,
418
- name="pipeline",
419
- validate=True,
420
- timed: bool = False,
421
- ):
422
- """
423
- Initialize a pipeline.
424
-
425
- Args:
426
- operations (list): A list of Operation instances representing the operations
427
- to be performed.
428
- with_batch_dim (bool, optional): Whether operations should expect a batch dimension.
429
- Defaults to True.
430
- jit_options (str, optional): The JIT options to use. Must be "pipeline", "ops", or None.
431
-
432
- - "pipeline": compiles the entire pipeline as a single function.
433
- This may be faster but does not preserve python control flow, such as caching.
434
-
435
- - "ops": compiles each operation separately. This preserves python control flow and
436
- caching functionality, but speeds up the operations.
437
-
438
- - None: disables JIT compilation.
439
-
440
- Defaults to "ops".
441
-
442
- jit_kwargs (dict, optional): Additional keyword arguments for the JIT compiler.
443
- name (str, optional): The name of the pipeline. Defaults to "pipeline".
444
- validate (bool, optional): Whether to validate the pipeline. Defaults to True.
445
- timed (bool, optional): Whether to time each operation. Defaults to False.
446
-
447
- """
448
- self._call_pipeline = self.call
449
- self.name = name
450
-
451
- self._pipeline_layers = operations
452
-
453
- if jit_options not in ["pipeline", "ops", None]:
454
- raise ValueError("jit_options must be 'pipeline', 'ops', or None")
455
-
456
- self.with_batch_dim = with_batch_dim
457
- self._validate_flag = validate
458
-
459
- # Setup timer
460
- if jit_options == "pipeline" and timed:
461
- raise ValueError(
462
- "timed=True cannot be used with jit_options='pipeline' as the entire "
463
- "pipeline is compiled into a single function. Try setting jit_options to "
464
- "'ops' or None."
465
- )
466
- if timed:
467
- log.warning(
468
- "Timer has been initialized for the pipeline. To get an accurate timing estimate, "
469
- "the `block_until_ready()` is used, which will slow down the execution, so "
470
- "do not use for regular processing!"
471
- )
472
- self._callable_layers = self._get_timed_operations()
473
- else:
474
- self._callable_layers = self._pipeline_layers
475
- self._timed = timed
476
-
477
- if validate:
478
- self.validate()
479
- else:
480
- log.warning("Pipeline validation is disabled, make sure to validate manually.")
481
-
482
- if jit_kwargs is None:
483
- jit_kwargs = {}
484
-
485
- if keras.backend.backend() == "jax" and self.static_params != []:
486
- jit_kwargs = {"static_argnames": self.static_params}
487
-
488
- self.jit_kwargs = jit_kwargs
489
- self.jit_options = jit_options # will handle the jit compilation
490
-
491
- def needs(self, key) -> bool:
492
- """Check if the pipeline needs a specific key at the input."""
493
- return key in self.needs_keys
494
-
495
- @property
496
- def output_keys(self) -> set:
497
- """All output keys the pipeline guarantees to produce."""
498
- output_keys = set()
499
- for operation in self.operations:
500
- output_keys.update(operation.output_keys)
501
- return output_keys
502
-
503
- @property
504
- def valid_keys(self) -> set:
505
- """Get a set of valid keys for the pipeline.
506
-
507
- This is all keys that can be passed to the pipeline as input.
508
- """
509
- valid_keys = set()
510
- for operation in self.operations:
511
- valid_keys.update(operation.valid_keys)
512
- return valid_keys
513
-
514
- @property
515
- def static_params(self) -> List[str]:
516
- """Get a list of static parameters for the pipeline."""
517
- static_params = []
518
- for operation in self.operations:
519
- static_params.extend(operation.static_params)
520
- return list(set(static_params))
521
-
522
- @property
523
- def needs_keys(self) -> set:
524
- """Get a set of all input keys needed by the pipeline.
525
-
526
- Will keep track of keys that are already provided by previous operations.
527
- """
528
- needs = set()
529
- has_so_far = set()
530
- previous_operation = None
531
- for operation in self.operations:
532
- if previous_operation is not None:
533
- has_so_far.update(previous_operation.output_keys)
534
- needs.update(operation.needs_keys - has_so_far)
535
- previous_operation = operation
536
- return needs
537
-
538
- @classmethod
539
- def from_default(
540
- cls, num_patches=100, baseband=False, pfield=False, timed=False, **kwargs
541
- ) -> "Pipeline":
542
- """Create a default pipeline.
543
-
544
- Args:
545
- num_patches (int): Number of patches for the PatchedGrid operation.
546
- Defaults to 100. If you get an out of memory error, try to increase this number.
547
- baseband (bool): If True, assume the input data is baseband (I/Q) data,
548
- which has 2 channels (last dim). Defaults to False, which assumes RF data,
549
- so input signal has a single channel dim and is still on carrier frequency.
550
- pfield (bool): If True, apply Pfield weighting. Defaults to False.
551
- This will calculate pressure field and only beamform the data to those locations.
552
- timed (bool, optional): Whether to time each operation. Defaults to False.
553
- **kwargs: Additional keyword arguments to be passed to the Pipeline constructor.
554
-
555
- """
556
- operations = []
557
-
558
- # Add the demodulate operation
559
- if not baseband:
560
- operations.append(Demodulate())
561
-
562
- # Get beamforming ops
563
- beamforming = [
564
- TOFCorrection(),
565
- DelayAndSum(),
566
- ]
567
- if pfield:
568
- beamforming.insert(1, PfieldWeighting())
569
-
570
- # Optionally add patching
571
- if num_patches > 1:
572
- beamforming = [PatchedGrid(operations=beamforming, num_patches=num_patches, **kwargs)]
573
-
574
- # Add beamforming ops
575
- operations += beamforming
576
-
577
- # Add display ops
578
- operations += [
579
- EnvelopeDetect(),
580
- Normalize(),
581
- LogCompress(),
582
- ]
583
- return cls(operations, timed=timed, **kwargs)
584
-
585
- def copy(self) -> "Pipeline":
586
- """Create a copy of the pipeline."""
587
- return Pipeline(
588
- self._pipeline_layers.copy(),
589
- with_batch_dim=self.with_batch_dim,
590
- jit_options=self.jit_options,
591
- jit_kwargs=self.jit_kwargs,
592
- name=self.name,
593
- validate=self._validate_flag,
594
- timed=self._timed,
595
- )
596
-
597
- def reinitialize(self):
598
- """Reinitialize the pipeline in place."""
599
- self.__init__(
600
- self._pipeline_layers,
601
- with_batch_dim=self.with_batch_dim,
602
- jit_options=self.jit_options,
603
- jit_kwargs=self.jit_kwargs,
604
- name=self.name,
605
- validate=self._validate_flag,
606
- timed=self._timed,
607
- )
608
-
609
- def prepend(self, operation: Operation):
610
- """Prepend an operation to the pipeline."""
611
- self._pipeline_layers.insert(0, operation)
612
- self.reinitialize()
613
-
614
- def append(self, operation: Operation):
615
- """Append an operation to the pipeline."""
616
- self._pipeline_layers.append(operation)
617
- self.reinitialize()
618
-
619
- def insert(self, index: int, operation: Operation):
620
- """Insert an operation at a specific index in the pipeline."""
621
- if index < 0 or index > len(self._pipeline_layers):
622
- raise IndexError("Index out of bounds for inserting operation.")
623
- self._pipeline_layers.insert(index, operation)
624
- self.reinitialize()
625
-
626
- @property
627
- def operations(self):
628
- """Alias for self.layers to match the zea naming convention"""
629
- return self._pipeline_layers
630
-
631
- def reset_timer(self):
632
- """Reset the timer for timed operations."""
633
- if self._timed:
634
- self._callable_layers = self._get_timed_operations()
635
- else:
636
- log.warning(
637
- "Timer has not been initialized. Set timed=True when initializing the pipeline."
638
- )
639
-
640
- def _get_timed_operations(self):
641
- """Get a list of timed operations."""
642
- self.timer = FunctionTimer()
643
- return [self.timer(op, name=op.__class__.__name__) for op in self._pipeline_layers]
644
-
645
- def call(self, **inputs):
646
- """Process input data through the pipeline."""
647
- for operation in self._callable_layers:
648
- try:
649
- outputs = operation(**inputs)
650
- except KeyError as exc:
651
- raise KeyError(
652
- f"[zea.Pipeline] Operation '{operation.__class__.__name__}' "
653
- f"requires input key '{exc.args[0]}', "
654
- "but it was not provided in the inputs.\n"
655
- "Check whether the objects (such as `zea.Scan`) passed to "
656
- "`pipeline.prepare_parameters()` contain all required keys.\n"
657
- f"Current list of all passed keys: {list(inputs.keys())}\n"
658
- f"Valid keys for this pipeline: {self.valid_keys}"
659
- ) from exc
660
- except Exception as exc:
661
- raise RuntimeError(
662
- f"[zea.Pipeline] Error in operation '{operation.__class__.__name__}': {exc}"
663
- )
664
- inputs = outputs
665
- return outputs
666
-
667
- def __call__(self, return_numpy=False, **inputs):
668
- """Process input data through the pipeline."""
669
-
670
- if any(key in inputs for key in ["probe", "scan", "config"]) or any(
671
- isinstance(arg, ZEAObject) for arg in inputs.values()
672
- ):
673
- raise ValueError(
674
- "Probe, Scan and Config objects should be first processed with "
675
- "`Pipeline.prepare_parameters` before calling the pipeline. "
676
- "e.g. inputs = Pipeline.prepare_parameters(probe, scan, config)"
677
- )
678
-
679
- if any(isinstance(arg, str) for arg in inputs.values()):
680
- raise ValueError(
681
- "Pipeline does not support string inputs. "
682
- "Please ensure all inputs are convertible to tensors."
683
- )
684
-
685
- ## PROCESSING
686
- outputs = self._call_pipeline(**inputs)
687
-
688
- ## PREPARE OUTPUT
689
- if return_numpy:
690
- # Convert tensors to numpy arrays but preserve None values
691
- outputs = {
692
- k: ops.convert_to_numpy(v) if v is ops.is_tensor(v) else v
693
- for k, v in outputs.items()
694
- }
695
-
696
- return outputs
697
-
698
- @property
699
- def jit_options(self):
700
- """Get the jit_options property of the pipeline."""
701
- return self._jit_options
702
-
703
- @jit_options.setter
704
- def jit_options(self, value: Union[str, None]):
705
- """Set the jit_options property of the pipeline."""
706
- self._jit_options = value
707
- if value == "pipeline":
708
- assert self.jittable, log.error(
709
- "jit_options 'pipeline' cannot be used as the entire pipeline is not jittable. "
710
- "The following operations are not jittable: "
711
- f"{self.unjitable_ops}. "
712
- "Try setting jit_options to 'ops' or None."
713
- )
714
- self.jit()
715
- return
716
- else:
717
- self.unjit()
718
-
719
- for operation in self.operations:
720
- if isinstance(operation, Pipeline):
721
- operation.jit_options = value
722
- else:
723
- if operation.jittable and operation._jit_compile:
724
- operation.set_jit(value == "ops")
725
-
726
- def jit(self):
727
- """JIT compile the pipeline."""
728
- self._call_pipeline = jit(self.call, **self.jit_kwargs)
729
-
730
- def unjit(self):
731
- """Un-JIT compile the pipeline."""
732
- self._call_pipeline = self.call
733
-
734
- @property
735
- def jittable(self):
736
- """Check if all operations in the pipeline are jittable."""
737
- return all(operation.jittable for operation in self.operations)
738
-
739
- @property
740
- def unjitable_ops(self):
741
- """Get a list of operations that are not jittable."""
742
- return [operation for operation in self.operations if not operation.jittable]
743
-
744
- @property
745
- def with_batch_dim(self):
746
- """Get the with_batch_dim property of the pipeline."""
747
- return self._with_batch_dim
748
-
749
- @with_batch_dim.setter
750
- def with_batch_dim(self, value):
751
- """Set the with_batch_dim property of the pipeline."""
752
- self._with_batch_dim = value
753
- for operation in self.operations:
754
- operation.with_batch_dim = value
755
-
756
- @property
757
- def input_data_type(self):
758
- """Get the input_data_type property of the pipeline."""
759
- return self.operations[0].input_data_type
760
-
761
- @property
762
- def output_data_type(self):
763
- """Get the output_data_type property of the pipeline."""
764
- return self.operations[-1].output_data_type
765
-
766
- def validate(self):
767
- """Validate the pipeline by checking the compatibility of the operations."""
768
- operations = self.operations
769
- for i in range(len(operations) - 1):
770
- if operations[i].output_data_type is None:
771
- continue
772
- if operations[i + 1].input_data_type is None:
773
- continue
774
- if operations[i].output_data_type != operations[i + 1].input_data_type:
775
- raise ValueError(
776
- f"Operation {operations[i].__class__.__name__} output data type "
777
- f"({operations[i].output_data_type}) is not compatible "
778
- f"with the input data type ({operations[i + 1].input_data_type}) "
779
- f"of operation {operations[i + 1].__class__.__name__}"
780
- )
781
-
782
- def set_params(self, **params):
783
- """Set parameters for the operations in the pipeline by adding them to the cache."""
784
- for operation in self.operations:
785
- operation_params = {
786
- key: value for key, value in params.items() if key in operation.valid_keys
787
- }
788
- if operation_params:
789
- operation.set_input_cache(operation_params)
790
-
791
- def get_params(self, per_operation: bool = False):
792
- """Get a snapshot of the current parameters of the operations in the pipeline.
793
-
794
- Args:
795
- per_operation (bool): If True, return a list of dictionaries for each operation.
796
- If False, return a single dictionary with all parameters combined.
797
- """
798
- if per_operation:
799
- return [operation._input_cache.copy() for operation in self.operations]
800
- else:
801
- params = {}
802
- for operation in self.operations:
803
- params.update(operation._input_cache)
804
- return params
805
-
806
- def __str__(self):
807
- """String representation of the pipeline.
808
-
809
- Will print on two parallel pipeline lines if it detects a splitting operations
810
- (such as multi_bandpass_filter)
811
- Will merge the pipeline lines if it detects a stacking operation (such as stack)
812
- """
813
- split_operations = []
814
- merge_operations = ["Stack"]
815
-
816
- operations = [operation.__class__.__name__ for operation in self.operations]
817
- string = " -> ".join(operations)
818
-
819
- if any(operation in split_operations for operation in operations):
820
- # a second line is needed with same length as the first line
821
- split_line = " " * len(string)
822
- # find the splitting operation and index and print \-> instead of -> after
823
- split_detected = False
824
- merge_detected = False
825
- split_operation = None
826
- for operation in operations:
827
- if operation in split_operations:
828
- index = string.index(operation)
829
- index = index + len(operation)
830
- split_line = split_line[:index] + "\\->" + split_line[index + len("\\->") :]
831
- split_detected = True
832
- merge_detected = False
833
- split_operation = operation
834
- continue
835
-
836
- if operation in merge_operations:
837
- index = string.index(operation)
838
- index = index - 4
839
- split_line = split_line[:index] + "/" + split_line[index + 1 :]
840
- split_detected = False
841
- merge_detected = True
842
- continue
843
-
844
- if split_detected:
845
- # print all operations in the second line
846
- index = string.index(operation)
847
- split_line = (
848
- split_line[:index]
849
- + operation
850
- + " -> "
851
- + split_line[index + len(operation) + len(" -> ") :]
852
- )
853
- assert merge_detected is True, log.error(
854
- "Pipeline was never merged back together (with Stack operation), even "
855
- f"though it was split with {split_operation}. "
856
- "Please properly define your operation chain."
857
- )
858
- return f"\n{string}\n{split_line}\n"
859
-
860
- return string
861
-
862
- def __repr__(self):
863
- """String representation of the pipeline."""
864
- operations = []
865
- for operation in self.operations:
866
- if isinstance(operation, Pipeline):
867
- operations.append(repr(operation))
868
- else:
869
- operations.append(operation.__class__.__name__)
870
- return f"<Pipeline {self.name}=({', '.join(operations)})>"
871
-
872
- @classmethod
873
- def load(cls, file_path: str, **kwargs) -> "Pipeline":
874
- """Load a pipeline from a JSON or YAML file."""
875
- if file_path.endswith(".json"):
876
- with open(file_path, "r", encoding="utf-8") as f:
877
- json_str = f.read()
878
- return pipeline_from_json(json_str, **kwargs)
879
- elif file_path.endswith(".yaml") or file_path.endswith(".yml"):
880
- return pipeline_from_yaml(file_path, **kwargs)
881
- else:
882
- raise ValueError("File must have extension .json, .yaml, or .yml")
883
-
884
- def get_dict(self) -> dict:
885
- """Convert the pipeline to a dictionary."""
886
- config = {}
887
- config["name"] = ops_registry.get_name(self)
888
- config["operations"] = self._pipeline_to_list(self)
889
- config["params"] = {
890
- "with_batch_dim": self.with_batch_dim,
891
- "jit_options": self.jit_options,
892
- "jit_kwargs": self.jit_kwargs,
893
- }
894
- return config
895
-
896
- @staticmethod
897
- def _pipeline_to_list(pipeline):
898
- """Convert the pipeline to a list of operations."""
899
- ops_list = []
900
- for op in pipeline.operations:
901
- ops_list.append(op.get_dict())
902
- return ops_list
903
-
904
- @classmethod
905
- def from_config(cls, config: Dict, **kwargs) -> "Pipeline":
906
- """Create a pipeline from a dictionary or ``zea.Config`` object.
907
-
908
- Args:
909
- config (dict or Config): Configuration dictionary or ``zea.Config`` object.
910
- **kwargs: Additional keyword arguments to be passed to the pipeline.
911
-
912
- Note:
913
- Must have a ``pipeline`` key with a subkey ``operations``.
914
-
915
- Example:
916
- .. doctest::
917
-
918
- >>> from zea import Config, Pipeline
919
- >>> config = Config(
920
- ... {
921
- ... "operations": [
922
- ... "identity",
923
- ... ],
924
- ... }
925
- ... )
926
- >>> pipeline = Pipeline.from_config(config)
927
- """
928
- return pipeline_from_config(Config(config), **kwargs)
929
-
930
- @classmethod
931
- def from_yaml(cls, file_path: str, **kwargs) -> "Pipeline":
932
- """Create a pipeline from a YAML file.
933
-
934
- Args:
935
- file_path (str): Path to the YAML file.
936
- **kwargs: Additional keyword arguments to be passed to the pipeline.
937
-
938
- Note:
939
- Must have the a `pipeline` key with a subkey `operations`.
940
-
941
- Example:
942
- .. doctest::
943
-
944
- >>> import yaml
945
- >>> from zea import Config
946
- >>> # Create a sample pipeline YAML file
947
- >>> pipeline_dict = {
948
- ... "operations": [
949
- ... "identity",
950
- ... ],
951
- ... }
952
- >>> with open("pipeline.yaml", "w") as f:
953
- ... yaml.dump(pipeline_dict, f)
954
- >>> from zea.ops import Pipeline
955
- >>> pipeline = Pipeline.from_yaml("pipeline.yaml", jit_options=None)
956
- """
957
- return pipeline_from_yaml(file_path, **kwargs)
958
-
959
- @classmethod
960
- def from_json(cls, json_string: str, **kwargs) -> "Pipeline":
961
- """Create a pipeline from a JSON string.
962
-
963
- Args:
964
- json_string (str): JSON string representing the pipeline.
965
- **kwargs: Additional keyword arguments to be passed to the pipeline.
966
-
967
- Note:
968
- Must have the `operations` key.
969
-
970
- Example:
971
- ```python
972
- json_string = '{"operations": ["identity"]}'
973
- pipeline = Pipeline.from_json(json_string)
974
- ```
975
- """
976
- return pipeline_from_json(json_string, **kwargs)
977
-
978
- def to_config(self) -> Config:
979
- """Convert the pipeline to a `zea.Config` object."""
980
- return pipeline_to_config(self)
981
-
982
- def to_json(self) -> str:
983
- """Convert the pipeline to a JSON string."""
984
- return pipeline_to_json(self)
985
-
986
- def to_yaml(self, file_path: str) -> None:
987
- """Convert the pipeline to a YAML file."""
988
- pipeline_to_yaml(self, file_path)
989
-
990
- @property
991
- def key(self) -> str:
992
- """Input key of the pipeline."""
993
- return self.operations[0].key
994
-
995
- @property
996
- def output_key(self) -> str:
997
- """Output key of the pipeline."""
998
- return self.operations[-1].output_key
999
-
1000
- def __eq__(self, other):
1001
- """Check if two pipelines are equal."""
1002
- if not isinstance(other, Pipeline):
1003
- return False
1004
-
1005
- # Compare the operations in both pipelines
1006
- if len(self.operations) != len(other.operations):
1007
- return False
1008
-
1009
- for op1, op2 in zip(self.operations, other.operations):
1010
- if not op1 == op2:
1011
- return False
1012
-
1013
- return True
1014
-
1015
- def prepare_parameters(
1016
- self,
1017
- probe: Probe = None,
1018
- scan: Scan = None,
1019
- config: Config = None,
1020
- **kwargs,
1021
- ):
1022
- """Prepare Probe, Scan and Config objects for the pipeline.
1023
-
1024
- Serializes `zea.core.Object` instances and converts them to
1025
- dictionary of tensors.
1026
-
1027
- Args:
1028
- probe: Probe object.
1029
- scan: Scan object.
1030
- config: Config object.
1031
- include (None, "all", or list): Only include these parameter/computed property names.
1032
- If None or "all", include all.
1033
- exclude (None or list): Exclude these parameter/computed property names.
1034
- If provided, these keys will be excluded from the output.
1035
- Only one of include or exclude can be set.
1036
-
1037
- **kwargs: Additional keyword arguments to be included in the inputs.
1038
-
1039
- Returns:
1040
- dict: Dictionary of inputs with all values as tensors.
1041
- """
1042
- # Initialize dictionaries for probe, scan, and config
1043
- probe_dict, scan_dict, config_dict = {}, {}, {}
1044
-
1045
- # Process args to extract Probe, Scan, and Config objects
1046
- if probe is not None:
1047
- assert isinstance(probe, Probe), (
1048
- f"Expected an instance of `zea.probes.Probe`, got {type(probe)}"
1049
- )
1050
- probe_dict = probe.to_tensor(keep_as_is=self.static_params)
1051
-
1052
- if scan is not None:
1053
- assert isinstance(scan, Scan), (
1054
- f"Expected an instance of `zea.scan.Scan`, got {type(scan)}"
1055
- )
1056
- scan_dict = scan.to_tensor(include=self.needs_keys, keep_as_is=self.static_params)
1057
-
1058
- if config is not None:
1059
- assert isinstance(config, Config), (
1060
- f"Expected an instance of `zea.config.Config`, got {type(config)}"
1061
- )
1062
- config_dict.update(config.to_tensor(keep_as_is=self.static_params))
1063
-
1064
- # Convert all kwargs to tensors
1065
- tensor_kwargs = dict_to_tensor(kwargs, keep_as_is=self.static_params)
1066
-
1067
- # combine probe, scan, config and kwargs
1068
- # explicitly so we know which keys overwrite which
1069
- # kwargs > config > scan > probe
1070
- inputs = {
1071
- **probe_dict,
1072
- **scan_dict,
1073
- **config_dict,
1074
- **tensor_kwargs,
1075
- }
1076
-
1077
- return inputs
1078
-
1079
-
1080
- def make_operation_chain(
1081
- operation_chain: List[Union[str, Dict, Config, Operation, Pipeline]],
1082
- ) -> List[Operation]:
1083
- """Make an operation chain from a custom list of operations.
1084
-
1085
- Args:
1086
- operation_chain (list): List of operations to be performed.
1087
- Each operation can be:
1088
- - A string: operation initialized with default parameters
1089
- - A dictionary: operation initialized with parameters in the dictionary
1090
- - A Config object: converted to a dictionary and initialized
1091
- - An Operation/Pipeline instance: used as-is
1092
-
1093
- Returns:
1094
- list: List of operations to be performed.
1095
-
1096
- Example:
1097
- .. doctest::
1098
-
1099
- >>> from zea.ops import make_operation_chain, LogCompress
1100
- >>> SomeCustomOperation = LogCompress # just for demonstration
1101
- >>> chain = make_operation_chain(
1102
- ... [
1103
- ... "envelope_detect",
1104
- ... {"name": "normalize", "params": {"output_range": (0, 1)}},
1105
- ... SomeCustomOperation(),
1106
- ... ]
1107
- ... )
1108
- """
1109
- chain = []
1110
- for operation in operation_chain:
1111
- # Handle already instantiated Operation or Pipeline objects
1112
- if isinstance(operation, (Operation, Pipeline)):
1113
- chain.append(operation)
1114
- continue
1115
-
1116
- assert isinstance(operation, (str, dict, Config)), (
1117
- f"Operation {operation} should be a string, dict, Config object, Operation, or Pipeline"
1118
- )
1119
-
1120
- if isinstance(operation, str):
1121
- operation_instance = get_ops(operation)()
1122
-
1123
- else:
1124
- if isinstance(operation, Config):
1125
- operation = operation.serialize()
1126
-
1127
- params = operation.get("params", {})
1128
- op_name = operation.get("name")
1129
- operation_cls = get_ops(op_name)
1130
-
1131
- # Handle branches for branched pipeline
1132
- if op_name == "branched_pipeline" and "branches" in operation:
1133
- branch_configs = operation.get("branches", {})
1134
- branches = []
1135
-
1136
- # Convert each branch configuration to an operation chain
1137
- for _, branch_config in branch_configs.items():
1138
- if isinstance(branch_config, (list, np.ndarray)):
1139
- # This is a list of operations
1140
- branch = make_operation_chain(branch_config)
1141
- elif "operations" in branch_config:
1142
- # This is a pipeline-like branch
1143
- branch = make_operation_chain(branch_config["operations"])
1144
- else:
1145
- # This is a single operation branch
1146
- branch_op_cls = get_ops(branch_config["name"])
1147
- branch_params = branch_config.get("params", {})
1148
- branch = branch_op_cls(**branch_params)
1149
-
1150
- branches.append(branch)
1151
-
1152
- # Create the branched pipeline instance
1153
- operation_instance = operation_cls(branches=branches, **params)
1154
- # Check for nested operations at the same level as params
1155
- elif "operations" in operation:
1156
- nested_operations = make_operation_chain(operation["operations"])
1157
-
1158
- # Instantiate pipeline-type operations with nested operations
1159
- if issubclass(operation_cls, Pipeline):
1160
- operation_instance = operation_cls(operations=nested_operations, **params)
1161
- else:
1162
- operation_instance = operation_cls(operations=nested_operations, **params)
1163
- elif operation["name"] in ["patched_grid"]:
1164
- nested_operations = make_operation_chain(operation["params"].pop("operations"))
1165
- operation_instance = operation_cls(operations=nested_operations, **params)
1166
- else:
1167
- operation_instance = operation_cls(**params)
1168
-
1169
- chain.append(operation_instance)
1170
-
1171
- return chain
1172
-
1173
-
1174
- def pipeline_from_config(config: Config, **kwargs) -> Pipeline:
1175
- """
1176
- Create a Pipeline instance from a Config object.
1177
- """
1178
- assert "operations" in config, (
1179
- "Config object must have an 'operations' key for pipeline creation."
1180
- )
1181
- assert isinstance(config.operations, (list, np.ndarray)), (
1182
- "Config object must have a list or numpy array of operations for pipeline creation."
1183
- )
1184
-
1185
- operations = make_operation_chain(config.operations)
1186
-
1187
- # merge pipeline config without operations with kwargs
1188
- pipeline_config = copy.deepcopy(config)
1189
- pipeline_config.pop("operations")
1190
-
1191
- kwargs = {**pipeline_config, **kwargs}
1192
- return Pipeline(operations=operations, **kwargs)
1193
-
1194
-
1195
- def pipeline_from_json(json_string: str, **kwargs) -> Pipeline:
1196
- """
1197
- Create a Pipeline instance from a JSON string.
1198
- """
1199
- pipeline_config = Config(json.loads(json_string, cls=ZEADecoderJSON))
1200
- return pipeline_from_config(pipeline_config, **kwargs)
1201
-
1202
-
1203
- def pipeline_from_yaml(yaml_path: str, **kwargs) -> Pipeline:
1204
- """
1205
- Create a Pipeline instance from a YAML file.
1206
- """
1207
- with open(yaml_path, "r", encoding="utf-8") as f:
1208
- pipeline_config = yaml.safe_load(f)
1209
- operations = pipeline_config["operations"]
1210
- return pipeline_from_config(Config({"operations": operations}), **kwargs)
1211
-
1212
-
1213
- def pipeline_to_config(pipeline: Pipeline) -> Config:
1214
- """
1215
- Convert a Pipeline instance into a Config object.
1216
- """
1217
- # TODO: we currently add the full pipeline as 1 operation to the config.
1218
- # In another PR we should add a "pipeline" entry to the config instead of the "operations"
1219
- # entry. This allows us to also have non-default pipeline classes as top level op.
1220
- pipeline_dict = {"operations": [pipeline.get_dict()]}
1221
-
1222
- # HACK: If the top level operation is a single pipeline, collapse it into the operations list.
1223
- ops = pipeline_dict["operations"]
1224
- if ops[0]["name"] == "pipeline" and len(ops) == 1:
1225
- pipeline_dict = {"operations": ops[0]["operations"]}
1226
-
1227
- return Config(pipeline_dict)
1228
-
1229
-
1230
- def pipeline_to_json(pipeline: Pipeline) -> str:
1231
- """
1232
- Convert a Pipeline instance into a JSON string.
1233
- """
1234
- pipeline_dict = {"operations": [pipeline.get_dict()]}
1235
-
1236
- # HACK: If the top level operation is a single pipeline, collapse it into the operations list.
1237
- ops = pipeline_dict["operations"]
1238
- if ops[0]["name"] == "pipeline" and len(ops) == 1:
1239
- pipeline_dict = {"operations": ops[0]["operations"]}
1240
-
1241
- return json.dumps(pipeline_dict, cls=ZEAEncoderJSON, indent=4)
1242
-
1243
-
1244
- def pipeline_to_yaml(pipeline: Pipeline, file_path: str) -> None:
1245
- """
1246
- Convert a Pipeline instance into a YAML file.
1247
- """
1248
- pipeline_dict = pipeline.get_dict()
1249
-
1250
- # HACK: If the top level operation is a single pipeline, collapse it into the operations list.
1251
- ops = pipeline_dict["operations"]
1252
- if ops[0]["name"] == "pipeline" and len(ops) == 1:
1253
- pipeline_dict = {"operations": ops[0]["operations"]}
1254
-
1255
- with open(file_path, "w", encoding="utf-8") as f:
1256
- yaml.dump(pipeline_dict, f, Dumper=yaml.Dumper, indent=4)
1257
-
1258
-
1259
- @ops_registry("patched_grid")
1260
- class PatchedGrid(Pipeline):
1261
- """
1262
- With this class you can form a pipeline that will be applied to patches of the grid.
1263
- This is useful to avoid OOM errors when processing large grids.
1264
-
1265
- Some things to NOTE about this class:
1266
-
1267
- - The ops have to use flatgrid and flat_pfield as inputs, these will be patched.
1268
-
1269
- - Changing anything other than `self.output_data_type` in the dict will not be propagated!
1270
-
1271
- - Will be jitted as a single operation, not the individual operations.
1272
-
1273
- - This class handles the batching.
1274
-
1275
- """
1276
-
1277
- def __init__(self, *args, num_patches=10, **kwargs):
1278
- super().__init__(*args, name="patched_grid", **kwargs)
1279
- self.num_patches = num_patches
1280
-
1281
- for operation in self.operations:
1282
- if isinstance(operation, DelayAndSum):
1283
- operation.reshape_grid = False
1284
-
1285
- self._jittable_call = self.jittable_call
1286
-
1287
- @property
1288
- def jit_options(self):
1289
- """Get the jit_options property of the pipeline."""
1290
- return self._jit_options
1291
-
1292
- @jit_options.setter
1293
- def jit_options(self, value):
1294
- """Set the jit_options property of the pipeline."""
1295
- self._jit_options = value
1296
- if value in ["pipeline", "ops"]:
1297
- self.jit()
1298
- else:
1299
- self.unjit()
1300
-
1301
- def jit(self):
1302
- """JIT compile the pipeline."""
1303
- self._jittable_call = jit(self.jittable_call, **self.jit_kwargs)
1304
-
1305
- def unjit(self):
1306
- """Un-JIT compile the pipeline."""
1307
- self._jittable_call = self.jittable_call
1308
- self._call_pipeline = self.call
1309
-
1310
- @property
1311
- def with_batch_dim(self):
1312
- """Get the with_batch_dim property of the pipeline."""
1313
- return self._with_batch_dim
1314
-
1315
- @with_batch_dim.setter
1316
- def with_batch_dim(self, value):
1317
- """Set the with_batch_dim property of the pipeline.
1318
- The class handles the batching so the operations have to be set to False."""
1319
- self._with_batch_dim = value
1320
- for operation in self.operations:
1321
- operation.with_batch_dim = False
1322
-
1323
- @property
1324
- def _extra_keys(self):
1325
- return {"flatgrid", "grid_size_x", "grid_size_z"}
1326
-
1327
- @property
1328
- def valid_keys(self) -> set:
1329
- """Get a set of valid keys for the pipeline.
1330
- Adds the parameters that PatchedGrid itself operates on (even if not used by operations
1331
- inside it)."""
1332
- return super().valid_keys.union(self._extra_keys)
1333
-
1334
- @property
1335
- def needs_keys(self) -> set:
1336
- """Get a set of all input keys needed by the pipeline.
1337
- Adds the parameters that PatchedGrid itself operates on (even if not used by operations
1338
- inside it)."""
1339
- return super().needs_keys.union(self._extra_keys)
1340
-
1341
- def call_item(self, inputs):
1342
- """Process data in patches."""
1343
- # Extract necessary parameters
1344
- # make sure to add those as valid keys above!
1345
- grid_size_x = inputs["grid_size_x"]
1346
- grid_size_z = inputs["grid_size_z"]
1347
- flatgrid = inputs.pop("flatgrid")
1348
-
1349
- # Define a list of keys to look up for patching
1350
- flat_pfield = inputs.pop("flat_pfield", None)
1351
-
1352
- def patched_call(flatgrid, flat_pfield):
1353
- out = super(PatchedGrid, self).call(
1354
- flatgrid=flatgrid, flat_pfield=flat_pfield, **inputs
1355
- )
1356
- return out[self.output_key]
1357
-
1358
- out = vmap(
1359
- patched_call,
1360
- chunks=self.num_patches,
1361
- fn_supports_batch=True,
1362
- disable_jit=not bool(self.jit_options),
1363
- )(flatgrid, flat_pfield)
1364
-
1365
- return ops.reshape(out, (grid_size_z, grid_size_x, *ops.shape(out)[1:]))
1366
-
1367
- def jittable_call(self, **inputs):
1368
- """Process input data through the pipeline."""
1369
- if self._with_batch_dim:
1370
- input_data = inputs.pop(self.key)
1371
- output = ops.map(
1372
- lambda x: self.call_item({self.key: x, **inputs}),
1373
- input_data,
1374
- )
1375
- else:
1376
- output = self.call_item(inputs)
1377
-
1378
- return {self.output_key: output}
1379
-
1380
- def call(self, **inputs):
1381
- """Process input data through the pipeline."""
1382
- output = self._jittable_call(**inputs)
1383
- inputs.update(output)
1384
- return inputs
1385
-
1386
- def get_dict(self):
1387
- """Get the configuration of the pipeline."""
1388
- config = super().get_dict()
1389
- config.update({"name": "patched_grid"})
1390
- config["params"].update({"num_patches": self.num_patches})
1391
- return config
1392
-
1393
-
1394
- ## Base Operations
1395
-
1396
-
1397
- @ops_registry("identity")
1398
- class Identity(Operation):
1399
- """Identity operation."""
1400
-
1401
- def call(self, **kwargs) -> Dict:
1402
- """Returns the input as is."""
1403
- return {}
1404
-
1405
-
1406
- @ops_registry("merge")
1407
- class Merge(Operation):
1408
- """Operation that merges sets of input dictionaries."""
1409
-
1410
- def __init__(self, **kwargs):
1411
- super().__init__(**kwargs)
1412
- self.allow_multiple_inputs = True
1413
-
1414
- def call(self, *args, **kwargs) -> Dict:
1415
- """
1416
- Merges the input dictionaries. Priority is given to the last input.
1417
- """
1418
- merged = {}
1419
- for arg in args:
1420
- if not isinstance(arg, dict):
1421
- raise TypeError("All inputs must be dictionaries.")
1422
- merged.update(arg)
1423
- return merged
1424
-
1425
-
1426
- @ops_registry("split")
1427
- class Split(Operation):
1428
- """Operation that splits an input dictionary n copies."""
1429
-
1430
- def __init__(self, n: int, **kwargs):
1431
- super().__init__(**kwargs)
1432
- self.n = n
1433
-
1434
- def call(self, **kwargs) -> List[Dict]:
1435
- """
1436
- Splits the input dictionary into n copies.
1437
- """
1438
- return [kwargs.copy() for _ in range(self.n)]
1439
-
1440
-
1441
- @ops_registry("stack")
1442
- class Stack(Operation):
1443
- """Stack multiple data arrays along a new axis.
1444
- Useful to merge data from parallel pipelines.
1445
- """
1446
-
1447
- def __init__(
1448
- self,
1449
- keys: Union[str, List[str], None],
1450
- axes: Union[int, List[int], None],
1451
- **kwargs,
1452
- ):
1453
- super().__init__(**kwargs)
1454
-
1455
- self.keys, self.axes = _assert_keys_and_axes(keys, axes)
1456
-
1457
- def call(self, **kwargs) -> Dict:
1458
- """
1459
- Stacks the inputs corresponding to the specified keys along the specified axis.
1460
- If a list of axes is provided, the length must match the number of keys.
1461
- """
1462
- for key, axis in zip(self.keys, self.axes):
1463
- kwargs[key] = keras.ops.stack([kwargs[key] for key in self.keys], axis=axis)
1464
- return kwargs
1465
-
1466
-
1467
- @ops_registry("mean")
1468
- class Mean(Operation):
1469
- """Take the mean of the input data along a specific axis."""
1470
-
1471
- def __init__(self, keys, axes, **kwargs):
1472
- super().__init__(**kwargs)
1473
-
1474
- self.keys, self.axes = _assert_keys_and_axes(keys, axes)
1475
-
1476
- def call(self, **kwargs):
1477
- for key, axis in zip(self.keys, self.axes):
1478
- kwargs[key] = ops.mean(kwargs[key], axis=axis)
1479
-
1480
- return kwargs
1481
-
1482
-
1483
- @ops_registry("simulate_rf")
1484
- class Simulate(Operation):
1485
- """Simulate RF data."""
1486
-
1487
- # Define operation-specific static parameters
1488
- STATIC_PARAMS = ["n_ax", "apply_lens_correction"]
1489
-
1490
- def __init__(self, **kwargs):
1491
- super().__init__(
1492
- output_data_type=DataTypes.RAW_DATA,
1493
- additional_output_keys=["n_ch"],
1494
- **kwargs,
1495
- )
1496
-
1497
- def call(
1498
- self,
1499
- scatterer_positions,
1500
- scatterer_magnitudes,
1501
- probe_geometry,
1502
- apply_lens_correction,
1503
- lens_thickness,
1504
- lens_sound_speed,
1505
- sound_speed,
1506
- n_ax,
1507
- center_frequency,
1508
- sampling_frequency,
1509
- t0_delays,
1510
- initial_times,
1511
- element_width,
1512
- attenuation_coef,
1513
- tx_apodizations,
1514
- **kwargs,
1515
- ):
1516
- return {
1517
- self.output_key: simulate_rf(
1518
- ops.convert_to_tensor(scatterer_positions),
1519
- ops.convert_to_tensor(scatterer_magnitudes),
1520
- probe_geometry=probe_geometry,
1521
- apply_lens_correction=apply_lens_correction,
1522
- lens_thickness=lens_thickness,
1523
- lens_sound_speed=lens_sound_speed,
1524
- sound_speed=sound_speed,
1525
- n_ax=n_ax,
1526
- center_frequency=center_frequency,
1527
- sampling_frequency=sampling_frequency,
1528
- t0_delays=t0_delays,
1529
- initial_times=initial_times,
1530
- element_width=element_width,
1531
- attenuation_coef=attenuation_coef,
1532
- tx_apodizations=tx_apodizations,
1533
- ),
1534
- "n_ch": 1, # Simulate always returns RF data (so single channel)
1535
- }
1536
-
1537
-
1538
- @ops_registry("tof_correction")
1539
- class TOFCorrection(Operation):
1540
- """Time-of-flight correction operation for ultrasound data."""
1541
-
1542
- # Define operation-specific static parameters
1543
- STATIC_PARAMS = [
1544
- "f_number",
1545
- "apply_lens_correction",
1546
- "grid_size_x",
1547
- "grid_size_z",
1548
- ]
1549
-
1550
- def __init__(self, **kwargs):
1551
- super().__init__(
1552
- input_data_type=DataTypes.RAW_DATA,
1553
- output_data_type=DataTypes.ALIGNED_DATA,
1554
- **kwargs,
1555
- )
1556
-
1557
- def call(
1558
- self,
1559
- flatgrid,
1560
- sound_speed,
1561
- polar_angles,
1562
- focus_distances,
1563
- sampling_frequency,
1564
- f_number,
1565
- demodulation_frequency,
1566
- t0_delays,
1567
- tx_apodizations,
1568
- initial_times,
1569
- probe_geometry,
1570
- t_peak,
1571
- tx_waveform_indices,
1572
- apply_lens_correction=None,
1573
- lens_thickness=None,
1574
- lens_sound_speed=None,
1575
- **kwargs,
1576
- ):
1577
- """Perform time-of-flight correction on raw RF data.
1578
-
1579
- Args:
1580
- raw_data (ops.Tensor): Raw RF data to correct
1581
- flatgrid (ops.Tensor): Grid points at which to evaluate the time-of-flight
1582
- sound_speed (float): Sound speed in the medium
1583
- polar_angles (ops.Tensor): Polar angles for scan lines
1584
- focus_distances (ops.Tensor): Focus distances for scan lines
1585
- sampling_frequency (float): Sampling frequency
1586
- f_number (float): F-number for apodization
1587
- demodulation_frequency (float): Demodulation frequency
1588
- t0_delays (ops.Tensor): T0 delays
1589
- tx_apodizations (ops.Tensor): Transmit apodizations
1590
- initial_times (ops.Tensor): Initial times
1591
- probe_geometry (ops.Tensor): Probe element positions
1592
- t_peak (float): Time to peak of the transmit pulse
1593
- tx_waveform_indices (ops.Tensor): Index of the transmit waveform for each
1594
- transmit. (All zero if there is only one waveform)
1595
- apply_lens_correction (bool): Whether to apply lens correction
1596
- lens_thickness (float): Lens thickness
1597
- lens_sound_speed (float): Sound speed in the lens
1598
-
1599
- Returns:
1600
- dict: Dictionary containing tof_corrected_data
1601
- """
1602
-
1603
- raw_data = kwargs[self.key]
1604
-
1605
- tof_kwargs = {
1606
- "flatgrid": flatgrid,
1607
- "t0_delays": t0_delays,
1608
- "tx_apodizations": tx_apodizations,
1609
- "sound_speed": sound_speed,
1610
- "probe_geometry": probe_geometry,
1611
- "initial_times": initial_times,
1612
- "sampling_frequency": sampling_frequency,
1613
- "demodulation_frequency": demodulation_frequency,
1614
- "f_number": f_number,
1615
- "polar_angles": polar_angles,
1616
- "focus_distances": focus_distances,
1617
- "t_peak": t_peak,
1618
- "tx_waveform_indices": tx_waveform_indices,
1619
- "apply_lens_correction": apply_lens_correction,
1620
- "lens_thickness": lens_thickness,
1621
- "lens_sound_speed": lens_sound_speed,
1622
- }
1623
-
1624
- if not self.with_batch_dim:
1625
- tof_corrected = tof_correction(raw_data, **tof_kwargs)
1626
- else:
1627
- tof_corrected = ops.map(
1628
- lambda data: tof_correction(data, **tof_kwargs),
1629
- raw_data,
1630
- )
1631
-
1632
- return {self.output_key: tof_corrected}
1633
-
1634
-
1635
- @ops_registry("pfield_weighting")
1636
- class PfieldWeighting(Operation):
1637
- """Weighting aligned data with the pressure field."""
1638
-
1639
- def __init__(self, **kwargs):
1640
- super().__init__(
1641
- input_data_type=DataTypes.ALIGNED_DATA,
1642
- output_data_type=DataTypes.ALIGNED_DATA,
1643
- **kwargs,
1644
- )
1645
-
1646
- def call(self, flat_pfield=None, **kwargs):
1647
- """Weight data with pressure field.
1648
-
1649
- Args:
1650
- flat_pfield (ops.Tensor): Pressure field weight mask of shape (n_pix, n_tx)
1651
-
1652
- Returns:
1653
- dict: Dictionary containing weighted data
1654
- """
1655
- data = kwargs[self.key]
1656
-
1657
- if flat_pfield is None:
1658
- return {self.output_key: data}
1659
-
1660
- # Swap (n_pix, n_tx) to (n_tx, n_pix)
1661
- flat_pfield = ops.swapaxes(flat_pfield, 0, 1)
1662
-
1663
- # Perform element-wise multiplication with the pressure weight mask
1664
- # Also add the required dimensions for broadcasting
1665
- if self.with_batch_dim:
1666
- pfield_expanded = ops.expand_dims(flat_pfield, axis=0)
1667
- else:
1668
- pfield_expanded = flat_pfield
1669
-
1670
- pfield_expanded = pfield_expanded[..., None, None]
1671
- weighted_data = data * pfield_expanded
1672
-
1673
- return {self.output_key: weighted_data}
1674
-
1675
-
1676
- @ops_registry("delay_and_sum")
1677
- class DelayAndSum(Operation):
1678
- """Sums time-delayed signals along channels and transmits."""
1679
-
1680
- def __init__(
1681
- self,
1682
- reshape_grid=True,
1683
- **kwargs,
1684
- ):
1685
- super().__init__(
1686
- input_data_type=DataTypes.ALIGNED_DATA,
1687
- output_data_type=DataTypes.BEAMFORMED_DATA,
1688
- **kwargs,
1689
- )
1690
- self.reshape_grid = reshape_grid
1691
-
1692
- def process_image(self, data):
1693
- """Performs DAS beamforming on tof-corrected input.
1694
-
1695
- Args:
1696
- data (ops.Tensor): The TOF corrected input of shape `(n_tx, n_pix, n_el, n_ch)`
1697
-
1698
- Returns:
1699
- ops.Tensor: The beamformed data of shape `(n_pix, n_ch)`
1700
- """
1701
- # Sum over the channels, i.e. DAS
1702
- data = ops.sum(data, -2)
1703
-
1704
- # Sum over transmits, i.e. Compounding
1705
- data = ops.sum(data, 0)
1706
-
1707
- return data
1708
-
1709
- def call(self, grid=None, **kwargs):
1710
- """Performs DAS beamforming on tof-corrected input.
1711
-
1712
- Args:
1713
- tof_corrected_data (ops.Tensor): The TOF corrected input of shape
1714
- `(n_tx, grid_size_z*grid_size_x, n_el, n_ch)` with optional batch dimension.
1715
-
1716
- Returns:
1717
- dict: Dictionary containing beamformed_data
1718
- of shape `(grid_size_z*grid_size_x, n_ch)` when reshape_grid is False
1719
- or `(grid_size_z, grid_size_x, n_ch)` when reshape_grid is True,
1720
- with optional batch dimension.
1721
- """
1722
- data = kwargs[self.key]
1723
-
1724
- if not self.with_batch_dim:
1725
- beamformed_data = self.process_image(data)
1726
- else:
1727
- # Apply process_image to each item in the batch
1728
- beamformed_data = ops.map(self.process_image, data)
1729
-
1730
- if self.reshape_grid:
1731
- beamformed_data = reshape_axis(
1732
- beamformed_data, grid.shape[:2], axis=int(self.with_batch_dim)
1733
- )
1734
-
1735
- return {self.output_key: beamformed_data}
1736
-
1737
-
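The two summations in DelayAndSum above reduce aligned data of shape (n_tx, n_pix, n_el, n_ch) to (n_pix, n_ch); a minimal shape check, illustrative only, using keras.ops as elsewhere in this module:

from keras import ops

aligned = ops.ones((5, 4096, 128, 1))  # (n_tx, n_pix, n_el, n_ch), dummy data
das = ops.sum(ops.sum(aligned, axis=-2), axis=0)  # sum over elements, then compound over transmits
assert das.shape == (4096, 1)  # (n_pix, n_ch)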
1738
- def envelope_detect(data, axis=-3):
1739
- """Envelope detection of RF signals.
1740
-
1741
- If the input data is real, it first applies the Hilbert transform along the specified axis
1742
- and then computes the magnitude of the resulting complex signal.
1743
- If the input data is complex, it computes the magnitude directly.
1744
-
1745
- Args:
1746
- - data (Tensor): The beamformed data of shape (..., grid_size_z, grid_size_x, n_ch).
1747
- - axis (int): Axis along which to apply the Hilbert transform. Defaults to -3.
1748
-
1749
- Returns:
1750
- - envelope_data (Tensor): The envelope detected data
1751
- of shape (..., grid_size_z, grid_size_x).
1752
- """
1753
- if data.shape[-1] == 2:
1754
- data = channels_to_complex(data)
1755
- else:
1756
- n_ax = ops.shape(data)[axis]
1757
- n_ax_float = ops.cast(n_ax, "float32")
1758
-
1759
- # Calculate next power of 2: M = 2^ceil(log2(n_ax))
1760
- # see https://github.com/tue-bmd/zea/discussions/147
1761
- log2_n_ax = ops.log2(n_ax_float)
1762
- M = ops.cast(2 ** ops.ceil(log2_n_ax), "int32")
1763
-
1764
- data = hilbert(data, N=M, axis=axis)
1765
- indices = ops.arange(n_ax)
1766
-
1767
- data = ops.take(data, indices, axis=axis)
1768
- data = ops.squeeze(data, axis=-1)
1769
-
1770
- # data = ops.abs(data)
1771
- real = ops.real(data)
1772
- imag = ops.imag(data)
1773
- data = ops.sqrt(real**2 + imag**2)
1774
- data = ops.cast(data, "float32")
1775
- return data
1776
-
1777
-
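The FFT length M used above is the next power of two of the axial length; for example, n_ax = 1500 is padded to M = 2048. A small numpy sketch of the same computation:

import numpy as np

n_ax = 1500
M = int(2 ** np.ceil(np.log2(n_ax)))
assert M == 2048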
1778
- @ops_registry("envelope_detect")
1779
- class EnvelopeDetect(Operation):
1780
- """Envelope detection of RF signals."""
1781
-
1782
- def __init__(
1783
- self,
1784
- axis=-3,
1785
- **kwargs,
1786
- ):
1787
- super().__init__(
1788
- input_data_type=DataTypes.BEAMFORMED_DATA,
1789
- output_data_type=DataTypes.ENVELOPE_DATA,
1790
- **kwargs,
1791
- )
1792
- self.axis = axis
1793
-
1794
- def call(self, **kwargs):
1795
- """
1796
- Args:
1797
- - data (Tensor): The beamformed data of shape (..., grid_size_z, grid_size_x, n_ch).
1798
- Returns:
1799
- - envelope_data (Tensor): The envelope detected data
1800
- of shape (..., grid_size_z, grid_size_x).
1801
- """
1802
- data = kwargs[self.key]
1803
-
1804
- data = envelope_detect(data, axis=self.axis)
1805
-
1806
- return {self.output_key: data}
1807
-
1808
-
1809
- @ops_registry("upmix")
1810
- class UpMix(Operation):
1811
- """Upmix IQ data to RF data."""
1812
-
1813
- def __init__(
1814
- self,
1815
- upsampling_rate=1,
1816
- **kwargs,
1817
- ):
1818
- super().__init__(
1819
- **kwargs,
1820
- )
1821
- self.upsampling_rate = upsampling_rate
1822
-
1823
- def call(
1824
- self,
1825
- sampling_frequency=None,
1826
- center_frequency=None,
1827
- **kwargs,
1828
- ):
1829
- data = kwargs[self.key]
1830
-
1831
- if data.shape[-1] == 1:
1832
- log.warning("Upmixing is not applicable to RF data.")
1833
- return {self.output_key: data}
1834
- elif data.shape[-1] == 2:
1835
- data = channels_to_complex(data)
1836
-
1837
- data = upmix(data, sampling_frequency, center_frequency, self.upsampling_rate)
1838
- data = ops.expand_dims(data, axis=-1)
1839
- return {self.output_key: data}
1840
-
1841
-
1842
- def log_compress(data, eps=1e-16):
1843
- """Apply logarithmic compression to data."""
1844
- eps = ops.convert_to_tensor(eps, dtype=data.dtype)
1845
- data = ops.where(data == 0, eps, data) # Avoid log(0)
1846
- return 20 * keras.ops.log10(data)
1847
-
1848
-
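As a quick sanity check of the compression law, an envelope amplitude of 1.0 maps to 0 dB, 0.1 to -20 dB, and 0.01 to -40 dB (numpy sketch):

import numpy as np

envelope = np.array([1.0, 0.1, 0.01])
db = 20 * np.log10(envelope)  # -> [  0., -20., -40.]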
1849
- @ops_registry("log_compress")
1850
- class LogCompress(Operation):
1851
- """Logarithmic compression of data."""
1852
-
1853
- def __init__(self, clip: bool = True, **kwargs):
1854
- """Initialize the LogCompress operation.
1855
-
1856
- Args:
1857
- clip (bool): Whether to clip the output to a dynamic range. Defaults to True.
1858
- """
1859
- super().__init__(
1860
- input_data_type=DataTypes.ENVELOPE_DATA,
1861
- output_data_type=DataTypes.IMAGE,
1862
- **kwargs,
1863
- )
1864
- self.clip = clip
1865
-
1866
- def call(self, dynamic_range=None, **kwargs):
1867
- """Apply logarithmic compression to data.
1868
-
1869
- Args:
1870
- dynamic_range (tuple, optional): Dynamic range in dB. Defaults to (-60, 0).
1871
-
1872
- Returns:
1873
- dict: Dictionary containing log-compressed data
1874
- """
1875
- data = kwargs[self.key]
1876
-
1877
- if dynamic_range is None:
1878
- dynamic_range = ops.array(DEFAULT_DYNAMIC_RANGE)
1879
- dynamic_range = ops.cast(dynamic_range, data.dtype)
1880
-
1881
- compressed_data = log_compress(data)
1882
- if self.clip:
1883
- compressed_data = ops.clip(compressed_data, dynamic_range[0], dynamic_range[1])
1884
-
1885
- return {self.output_key: compressed_data}
1886
-
1887
-
1888
- def normalize(data, output_range, input_range=None):
1889
- """Normalize data to a given range.
1890
-
1891
- Equivalent to `translate` with clipping.
1892
-
1893
- Args:
1894
- data (ops.Tensor): Input data to normalize.
1895
- output_range (tuple): Range to which data should be mapped, e.g., (0, 1).
1896
- input_range (tuple, optional): Range of input data.
1897
- If None, the range will be computed from the data.
1898
- Defaults to None.
1899
- """
1900
- if input_range is None:
1901
- input_range = (None, None)
1902
- minval, maxval = input_range
1903
- if minval is None:
1904
- minval = ops.min(data)
1905
- if maxval is None:
1906
- maxval = ops.max(data)
1907
- data = ops.clip(data, minval, maxval)
1908
- normalized_data = translate(data, (minval, maxval), output_range)
1909
- return normalized_data
1910
-
1911
-
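For example, log-compressed values in the range (-60, 0) dB map linearly onto (0, 1) after clipping; the same arithmetic in plain numpy (a sketch, not using the helper above):

import numpy as np

data = np.array([-80.0, -60.0, -30.0, 0.0])
clipped = np.clip(data, -60.0, 0.0)
normalized = (clipped + 60.0) / 60.0  # -> [0. , 0. , 0.5, 1. ]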
1912
- @ops_registry("normalize")
1913
- class Normalize(Operation):
1914
- """Normalize data to a given range."""
1915
-
1916
- def __init__(self, output_range=None, input_range=None, **kwargs):
1917
- super().__init__(additional_output_keys=["minval", "maxval"], **kwargs)
1918
- if output_range is None:
1919
- output_range = (0, 1)
1920
- self.output_range = self.to_float32(output_range)
1921
- self.input_range = self.to_float32(input_range)
1922
- assert output_range is None or len(output_range) == 2
1923
- assert input_range is None or len(input_range) == 2
1924
-
1925
- @staticmethod
1926
- def to_float32(data):
1927
- """Converts an iterable to float32 and leaves None values as is."""
1928
- return (
1929
- [np.float32(x) if x is not None else None for x in data] if data is not None else None
1930
- )
1931
-
1932
- def call(self, **kwargs):
1933
- """Normalize data to a given range.
1934
-
1935
- Args:
1936
- output_range (tuple, optional): Range to which data should be mapped.
1937
- Defaults to (0, 1).
1938
- input_range (tuple, optional): Range of input data. If None, the range
1939
- of the input data will be computed. Defaults to None.
1940
-
1941
- Returns:
1942
- dict: Dictionary containing normalized data, along with the computed
1943
- or provided input range (minval and maxval).
1944
- """
1945
- data = kwargs[self.key]
1946
-
1947
- # If input_range is not provided, try to get it from kwargs
1948
- # This allows you to normalize based on the first frame in a sequence and avoid flicker
1949
- if self.input_range is None:
1950
- maxval = kwargs.get("maxval", None)
1951
- minval = kwargs.get("minval", None)
1952
- # If input_range is provided, use it
1953
- else:
1954
- minval, maxval = self.input_range
1955
-
1956
- # If input_range is still not provided, compute it from the data
1957
- if minval is None:
1958
- minval = ops.min(data)
1959
- if maxval is None:
1960
- maxval = ops.max(data)
1961
-
1962
- normalized_data = normalize(
1963
- data, output_range=self.output_range, input_range=(minval, maxval)
1964
- )
1965
-
1966
- return {self.output_key: normalized_data, "minval": minval, "maxval": maxval}
1967
-
1968
-
1969
- @ops_registry("scan_convert")
1970
- class ScanConvert(Operation):
1971
- """Scan convert images to cartesian coordinates."""
1972
-
1973
- STATIC_PARAMS = ["fill_value"]
1974
-
1975
- def __init__(self, order=1, **kwargs):
1976
- """Initialize the ScanConvert operation.
1977
-
1978
- Args:
1979
- order (int, optional): Interpolation order. Defaults to 1. Currently only
1980
- GPU support for order=1.
1981
- """
1982
- if order > 1:
1983
- jittable = False
1984
- log.warning(
1985
- "GPU support for order > 1 is not available. " + "Disabling jit for ScanConvert."
1986
- )
1987
- else:
1988
- jittable = True
1989
-
1990
- super().__init__(
1991
- input_data_type=DataTypes.IMAGE,
1992
- output_data_type=DataTypes.IMAGE_SC,
1993
- jittable=jittable,
1994
- additional_output_keys=[
1995
- "resolution",
1996
- "x_lim",
1997
- "y_lim",
1998
- "z_lim",
1999
- "rho_range",
2000
- "theta_range",
2001
- "phi_range",
2002
- "d_rho",
2003
- "d_theta",
2004
- "d_phi",
2005
- ],
2006
- **kwargs,
2007
- )
2008
- self.order = order
2009
-
2010
- def call(
2011
- self,
2012
- rho_range=None,
2013
- theta_range=None,
2014
- phi_range=None,
2015
- resolution=None,
2016
- coordinates=None,
2017
- fill_value=None,
2018
- **kwargs,
2019
- ):
2020
- """Scan convert images to cartesian coordinates.
2021
-
2022
- Args:
2023
- rho_range (Tuple): Range of the rho axis in the polar coordinate system.
2024
- Defined in meters.
2025
- theta_range (Tuple): Range of the theta axis in the polar coordinate system.
2026
- Defined in radians.
2027
- phi_range (Tuple): Range of the phi axis in the polar coordinate system.
2028
- Defined in radians.
2029
- resolution (float): Resolution of the output image in meters per pixel.
2030
- if None, the resolution is computed based on the input data.
2031
- coordinates (Tensor): Coordinates for scan conversion. If None, will be computed
2032
- based on rho_range, theta_range, phi_range and resolution. If provided, this
2033
- operation can be jitted.
2034
- fill_value (float): Value to fill the image with outside the defined region.
2035
-
2036
- """
2037
- if fill_value is None:
2038
- fill_value = np.nan
2039
-
2040
- data = kwargs[self.key]
2041
-
2042
- if self._jit_compile and self.jittable:
2043
- assert coordinates is not None, (
2044
- "coordinates must be provided to jit scan conversion."
2045
- "You can set ScanConvert(jit_compile=False) to disable jitting."
2046
- )
2047
-
2048
- data_out, parameters = scan_convert(
2049
- data,
2050
- rho_range,
2051
- theta_range,
2052
- phi_range,
2053
- resolution,
2054
- coordinates,
2055
- fill_value,
2056
- self.order,
2057
- with_batch_dim=self.with_batch_dim,
2058
- )
2059
-
2060
- return {self.output_key: data_out, **parameters}
2061
-
2062
-
2063
- @ops_registry("gaussian_blur")
2064
- class GaussianBlur(Operation):
2065
- """
2066
- GaussianBlur is an operation that applies a Gaussian blur to an input image.
2067
- Uses scipy.ndimage.gaussian_filter to create a kernel.
2068
- """
2069
-
2070
- def __init__(
2071
- self,
2072
- sigma: float,
2073
- kernel_size: int | None = None,
2074
- pad_mode="symmetric",
2075
- truncate=4.0,
2076
- **kwargs,
2077
- ):
2078
- """
2079
- Args:
2080
- sigma (float): Standard deviation for Gaussian kernel.
2081
- kernel_size (int, optional): The size of the kernel. If None, the kernel
2082
- size is calculated based on the sigma and truncate. Default is None.
2083
- pad_mode (str): Padding mode for the input image. Default is 'symmetric'.
2084
- truncate (float): Truncate the filter at this many standard deviations.
2085
- """
2086
- super().__init__(**kwargs)
2087
- if kernel_size is None:
2088
- radius = round(truncate * sigma)
2089
- self.kernel_size = 2 * radius + 1
2090
- else:
2091
- self.kernel_size = kernel_size
2092
- self.sigma = sigma
2093
- self.pad_mode = pad_mode
2094
- self.radius = self.kernel_size // 2
2095
- self.kernel = self.get_kernel()
2096
-
2097
- def get_kernel(self):
2098
- """
2099
- Create a gaussian kernel for blurring.
2100
-
2101
- Returns:
2102
- kernel (Tensor): A gaussian kernel for blurring.
2103
- Shape is (kernel_size, kernel_size, 1, 1).
2104
- """
2105
- n = np.zeros((self.kernel_size, self.kernel_size))
2106
- n[self.radius, self.radius] = 1
2107
- kernel = scipy.ndimage.gaussian_filter(n, sigma=self.sigma, mode="constant").astype(
2108
- np.float32
2109
- )
2110
- kernel = kernel[:, :, None, None]
2111
- return ops.convert_to_tensor(kernel)
2112
-
2113
- def call(self, **kwargs):
2114
- data = kwargs[self.key]
2115
-
2116
- # Add batch dimension if not present
2117
- if not self.with_batch_dim:
2118
- data = data[None]
2119
-
2120
- # Add channel dimension to kernel
2121
- kernel = ops.tile(self.kernel, (1, 1, data.shape[-1], data.shape[-1]))
2122
-
2123
- # Pad the input image according to the padding mode
2124
- padded = ops.pad(
2125
- data,
2126
- [[0, 0], [self.radius, self.radius], [self.radius, self.radius], [0, 0]],
2127
- mode=self.pad_mode,
2128
- )
2129
-
2130
- # Apply the gaussian kernel to the padded image
2131
- out = ops.conv(padded, kernel, padding="valid", data_format="channels_last")
2132
-
2133
- # Remove padding
2134
- out = ops.slice(
2135
- out,
2136
- [0, 0, 0, 0],
2137
- [out.shape[0], data.shape[1], data.shape[2], data.shape[3]],
2138
- )
2139
-
2140
- # Remove batch dimension if it was not present before
2141
- if not self.with_batch_dim:
2142
- out = ops.squeeze(out, axis=0)
2143
-
2144
- return {self.output_key: out}
2145
-
2146
-
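The kernel built in get_kernel is a Gaussian-filtered impulse, so it sums to (approximately) one and convolving with it preserves mean brightness; a small check of that construction (sketch, scipy/numpy only, with illustrative parameter values):

import numpy as np
import scipy.ndimage

sigma, truncate = 2.0, 4.0
radius = round(truncate * sigma)
impulse = np.zeros((2 * radius + 1, 2 * radius + 1))
impulse[radius, radius] = 1.0
kernel = scipy.ndimage.gaussian_filter(impulse, sigma=sigma, mode="constant")
assert abs(kernel.sum() - 1.0) < 1e-3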
2147
- @ops_registry("lee_filter")
2148
- class LeeFilter(Operation):
2149
- """
2150
- The Lee filter is a speckle reduction filter commonly used in synthetic aperture radar (SAR)
2151
- and ultrasound image processing. It smooths the image while preserving edges and details.
2152
- This implementation uses a Gaussian filter for local statistics and treats channels independently.
2153
-
2154
- Lee, J.S. (1980). Digital image enhancement and noise filtering by use of local statistics.
2155
- IEEE Transactions on Pattern Analysis and Machine Intelligence, (2), 165-168.
2156
- """
2157
-
2158
- def __init__(self, sigma=3, kernel_size=None, pad_mode="symmetric", **kwargs):
2159
- """
2160
- Args:
2161
- sigma (float): Standard deviation for Gaussian kernel. Default is 3.
2162
- kernel_size (int, optional): Size of the Gaussian kernel. If None,
2163
- it will be calculated based on sigma.
2164
- pad_mode (str): Padding mode to be used for Gaussian blur. Default is "symmetric".
2165
- """
2166
- super().__init__(**kwargs)
2167
- self.sigma = sigma
2168
- self.kernel_size = kernel_size
2169
- self.pad_mode = pad_mode
2170
-
2171
- # Create a GaussianBlur instance for computing local statistics
2172
- self.gaussian_blur = GaussianBlur(
2173
- sigma=self.sigma,
2174
- kernel_size=self.kernel_size,
2175
- pad_mode=self.pad_mode,
2176
- with_batch_dim=self.with_batch_dim,
2177
- jittable=self._jittable,
2178
- key=self.key,
2179
- )
2180
-
2181
- @property
2182
- def with_batch_dim(self):
2183
- """Get the with_batch_dim property of the LeeFilter operation."""
2184
- return self._with_batch_dim
2185
-
2186
- @with_batch_dim.setter
2187
- def with_batch_dim(self, value):
2188
- """Set the with_batch_dim property of the LeeFilter operation."""
2189
- self._with_batch_dim = value
2190
- if hasattr(self, "gaussian_blur"):
2191
- self.gaussian_blur.with_batch_dim = value
2192
-
2193
- def call(self, **kwargs):
2194
- data = kwargs[self.key]
2195
-
2196
- # Apply Gaussian blur to get local mean
2197
- img_mean = self.gaussian_blur.call(**kwargs)[self.gaussian_blur.output_key]
2198
-
2199
- # Apply Gaussian blur to squared data to get local squared mean
2200
- data_squared = data**2
2201
- kwargs[self.gaussian_blur.key] = data_squared
2202
- img_sqr_mean = self.gaussian_blur.call(**kwargs)[self.gaussian_blur.output_key]
2203
-
2204
- # Calculate local variance
2205
- img_variance = img_sqr_mean - img_mean**2
2206
-
2207
- # Calculate global variance (per channel)
2208
- if self.with_batch_dim:
2209
- overall_variance = ops.var(data, axis=(-3, -2), keepdims=True)
2210
- else:
2211
- overall_variance = ops.var(data, axis=(-2, -1), keepdims=True)
2212
-
2213
- # Calculate adaptive weights
2214
- img_weights = img_variance / (img_variance + overall_variance)
2215
-
2216
- # Apply Lee filter formula
2217
- img_output = img_mean + img_weights * (data - img_mean)
2218
-
2219
- return {self.output_key: img_output}
2220
-
2221
-
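The adaptive weighting pulls smooth regions (local variance much smaller than the global variance) toward the local mean while leaving high-variance regions such as edges mostly untouched; a scalar sketch of the update formula used above, with hypothetical local statistics for a single pixel:

pixel, local_mean, local_var, global_var = 5.0, 4.0, 0.25, 1.0  # illustration only
weight = local_var / (local_var + global_var)           # 0.2 -> strong smoothing
filtered = local_mean + weight * (pixel - local_mean)   # 4.2, close to the local mean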
2222
- @ops_registry("demodulate")
2223
- class Demodulate(Operation):
2224
- """Demodulates the input data to baseband. After this operation, the carrier frequency
2225
- is removed (0 Hz) and the data is in IQ format stored in two real valued channels."""
2226
-
2227
- def __init__(self, axis=-3, **kwargs):
2228
- super().__init__(
2229
- input_data_type=DataTypes.RAW_DATA,
2230
- output_data_type=DataTypes.RAW_DATA,
2231
- jittable=True,
2232
- additional_output_keys=["demodulation_frequency", "center_frequency", "n_ch"],
2233
- **kwargs,
2234
- )
2235
- self.axis = axis
2236
-
2237
- def call(self, center_frequency=None, sampling_frequency=None, **kwargs):
2238
- data = kwargs[self.key]
2239
-
2240
- demodulation_frequency = center_frequency
2241
-
2242
- # Split the complex signal into two channels
2243
- iq_data_two_channel = demodulate(
2244
- data=data,
2245
- center_frequency=center_frequency,
2246
- sampling_frequency=sampling_frequency,
2247
- axis=self.axis,
2248
- )
2249
-
2250
- return {
2251
- self.output_key: iq_data_two_channel,
2252
- "demodulation_frequency": demodulation_frequency,
2253
- "center_frequency": 0.0,
2254
- "n_ch": 2,
2255
- }
2256
-
2257
-
2258
- @ops_registry("lambda")
2259
- class Lambda(Operation):
2260
- """Use any function as an operation."""
2261
-
2262
- def __init__(self, func, **kwargs):
2263
- # Split kwargs into kwargs for partial and __init__
2264
- op_kwargs = {k: v for k, v in kwargs.items() if k not in func.__code__.co_varnames}
2265
- func_kwargs = {k: v for k, v in kwargs.items() if k in func.__code__.co_varnames}
2266
- Lambda._check_if_unary(func, **func_kwargs)
2267
-
2268
- super().__init__(**op_kwargs)
2269
- self.func = partial(func, **func_kwargs)
2270
-
2271
- @staticmethod
2272
- def _check_if_unary(func, **kwargs):
2273
- """Checks if the kwargs are sufficient to call the function as a unary operation."""
2274
- sig = inspect.signature(func)
2275
- # Remove arguments that are already provided in func_kwargs
2276
- params = list(sig.parameters.values())
2277
- remaining = [p for p in params if p.name not in kwargs]
2278
- # Count required positional arguments (excluding self/cls)
2279
- required_positional = [
2280
- p
2281
- for p in remaining
2282
- if p.default is p.empty and p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
2283
- ]
2284
- if len(required_positional) != 1:
2285
- raise ValueError(
2286
- f"Partial of {func.__name__} must be callable with exactly one required "
2287
- f"positional argument, we still need: {required_positional}."
2288
- )
2289
-
2290
- def call(self, **kwargs):
2291
- data = kwargs[self.key]
2292
- data = self.func(data)
2293
- return {self.output_key: data}
2294
-
2295
-
2296
- @ops_registry("pad")
2297
- class Pad(Operation, DataLayer):
2298
- """Pad layer for padding tensors to a specified shape."""
2299
-
2300
- def __init__(
2301
- self,
2302
- target_shape: list | tuple,
2303
- uniform: bool = True,
2304
- axis: Union[int, List[int]] = None,
2305
- fail_on_bigger_shape: bool = True,
2306
- pad_kwargs: dict = None,
2307
- **kwargs,
2308
- ):
2309
- super().__init__(**kwargs)
2310
- self.target_shape = target_shape
2311
- self.uniform = uniform
2312
- self.axis = axis
2313
- self.pad_kwargs = pad_kwargs or {}
2314
- self.fail_on_bigger_shape = fail_on_bigger_shape
2315
-
2316
- @staticmethod
2317
- def _format_target_shape(shape_array, target_shape, axis):
2318
- if isinstance(axis, int):
2319
- axis = [axis]
2320
- assert len(axis) == len(target_shape), (
2321
- "The length of axis must be equal to the length of target_shape."
2322
- )
2323
- axis = map_negative_indices(axis, len(shape_array))
2324
-
2325
- target_shape = [
2326
- target_shape[axis.index(i)] if i in axis else shape_array[i]
2327
- for i in range(len(shape_array))
2328
- ]
2329
- return target_shape
2330
-
2331
- def pad(
2332
- self,
2333
- z,
2334
- target_shape: list | tuple,
2335
- uniform: bool = True,
2336
- axis: Union[int, List[int]] = None,
2337
- fail_on_bigger_shape: bool = True,
2338
- **kwargs,
2339
- ):
2340
- """
2341
- Pads the input tensor `z` to the specified shape.
2342
-
2343
- Parameters:
2344
- z (tensor): The input tensor to be padded.
2345
- target_shape (list or tuple): The target shape to pad the tensor to.
2346
- uniform (bool, optional): If True, ensures that padding is uniform (even on both sides).
2347
- Default is True.
2348
- axis (int or list of int, optional): The axis or axes along which `target_shape` was
2349
- specified. If None, `len(target_shape) == len(ops.shape(z))` must hold.
2350
- Default is None.
2351
- fail_on_bigger_shape (bool, optional): If True (default), raises an error if any target
2352
- dimension is smaller than the input shape; if False, pads only where the
2353
- target shape exceeds the input shape and leaves other dimensions unchanged.
2354
- kwargs: Additional keyword arguments to pass to the padding function.
2355
-
2356
- Returns:
2357
- tensor: The padded tensor with the specified shape.
2358
- """
2359
- shape_array = self.backend.shape(z)
2360
-
2361
- # When axis is provided, convert target_shape
2362
- if axis is not None:
2363
- target_shape = self._format_target_shape(shape_array, target_shape, axis)
2364
-
2365
- if not fail_on_bigger_shape:
2366
- target_shape = [max(target_shape[i], shape_array[i]) for i in range(len(shape_array))]
2367
-
2368
- # Compute the padding required for each dimension
2369
- pad_shape = np.array(target_shape) - shape_array
2370
-
2371
- # Create the paddings array
2372
- if uniform:
2373
- # if odd, pad more on the left, same as:
2374
- # https://keras.io/api/layers/preprocessing_layers/image_preprocessing/center_crop/
2375
- right_pad = pad_shape // 2
2376
- left_pad = pad_shape - right_pad
2377
- paddings = np.stack([right_pad, left_pad], axis=1)
2378
- else:
2379
- paddings = np.stack([np.zeros_like(pad_shape), pad_shape], axis=1)
2380
-
2381
- if np.any(paddings < 0):
2382
- raise ValueError(
2383
- f"Target shape {target_shape} must be greater than or equal "
2384
- f"to the input shape {shape_array}."
2385
- )
2386
-
2387
- return self.backend.numpy.pad(z, paddings, **kwargs)
2388
-
2389
- def call(self, **kwargs):
2390
- data = kwargs[self.key]
2391
- padded_data = self.pad(
2392
- data,
2393
- self.target_shape,
2394
- self.uniform,
2395
- self.axis,
2396
- self.fail_on_bigger_shape,
2397
- **self.pad_kwargs,
2398
- )
2399
- return {self.output_key: padded_data}
2400
-
2401
-
2402
- @ops_registry("companding")
2403
- class Companding(Operation):
2404
- """Companding according to the A- or μ-law algorithm.
2405
-
2406
- Invertible compressing operation. Used to compress
2407
- dynamic range of input data (and subsequently expand).
2408
-
2409
- μ-law companding:
2410
- https://en.wikipedia.org/wiki/%CE%9C-law_algorithm
2411
- A-law companding:
2412
- https://en.wikipedia.org/wiki/A-law_algorithm
2413
-
2414
- Args:
2415
- expand (bool, optional): If set to False (default),
2416
- data is compressed, else expanded.
2417
- comp_type (str): either `a` or `mu`.
2418
- mu (float, optional): compression parameter. Defaults to 255.
2419
- A (float, optional): compression parameter. Defaults to 87.6.
2420
- """
2421
-
2422
- def __init__(self, expand=False, comp_type="mu", **kwargs):
2423
- super().__init__(**kwargs)
2424
- self.expand = expand
2425
- self.comp_type = comp_type.lower()
2426
- if self.comp_type not in ["mu", "a"]:
2427
- raise ValueError("comp_type must be 'mu' or 'a'.")
2428
-
2429
- if self.comp_type == "mu":
2430
- self._compand_func = self._mu_law_expand if self.expand else self._mu_law_compress
2431
- else:
2432
- self._compand_func = self._a_law_expand if self.expand else self._a_law_compress
2433
-
2434
- @staticmethod
2435
- def _mu_law_compress(x, mu=255, **kwargs):
2436
- x = ops.clip(x, -1, 1)
2437
- return ops.sign(x) * ops.log(1.0 + mu * ops.abs(x)) / ops.log(1.0 + mu)
2438
-
2439
- @staticmethod
2440
- def _mu_law_expand(y, mu=255, **kwargs):
2441
- y = ops.clip(y, -1, 1)
2442
- return ops.sign(y) * ((1.0 + mu) ** ops.abs(y) - 1.0) / mu
2443
-
2444
- @staticmethod
2445
- def _a_law_compress(x, A=87.6, **kwargs):
2446
- x = ops.clip(x, -1, 1)
2447
- x_sign = ops.sign(x)
2448
- x_abs = ops.abs(x)
2449
- A_log = ops.log(A)
2450
- val1 = x_sign * A * x_abs / (1.0 + A_log)
2451
- val2 = x_sign * (1.0 + ops.log(A * x_abs)) / (1.0 + A_log)
2452
- y = ops.where((x_abs >= 0) & (x_abs < (1.0 / A)), val1, val2)
2453
- return y
2454
-
2455
- @staticmethod
2456
- def _a_law_expand(y, A=87.6, **kwargs):
2457
- y = ops.clip(y, -1, 1)
2458
- y_sign = ops.sign(y)
2459
- y_abs = ops.abs(y)
2460
- A_log = ops.log(A)
2461
- val1 = y_sign * y_abs * (1.0 + A_log) / A
2462
- val2 = y_sign * ops.exp(y_abs * (1.0 + A_log) - 1.0) / A
2463
- x = ops.where((y_abs >= 0) & (y_abs < (1.0 / (1.0 + A_log))), val1, val2)
2464
- return x
2465
-
2466
- def call(self, mu=255, A=87.6, **kwargs):
2467
- data = kwargs[self.key]
2468
-
2469
- mu = ops.cast(mu, data.dtype)
2470
- A = ops.cast(A, data.dtype)
2471
-
2472
- data_out = self._compand_func(data, mu=mu, A=A)
2473
- return {self.output_key: data_out}
2474
-
2475
-
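Compression and expansion are exact inverses on [-1, 1]; the μ-law round trip can be verified directly (numpy sketch mirroring the formulas above):

import numpy as np

x = np.linspace(-1.0, 1.0, 11)
mu = 255.0
y = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)    # compress
x_rec = np.sign(y) * ((1.0 + mu) ** np.abs(y) - 1.0) / mu   # expand
assert np.allclose(x, x_rec)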
2476
- @ops_registry("downsample")
2477
- class Downsample(Operation):
2478
- """Downsample data along a specific axis."""
2479
-
2480
- def __init__(self, factor: int = 1, phase: int = 0, axis: int = -3, **kwargs):
2481
- super().__init__(
2482
- additional_output_keys=["sampling_frequency", "n_ax"],
2483
- **kwargs,
2484
- )
2485
- self.factor = factor
2486
- self.phase = phase
2487
- self.axis = axis
2488
-
2489
- def call(self, sampling_frequency=None, n_ax=None, **kwargs):
2490
- data = kwargs[self.key]
2491
- length = ops.shape(data)[self.axis]
2492
- sample_idx = ops.arange(self.phase, length, self.factor)
2493
- data_downsampled = ops.take(data, sample_idx, axis=self.axis)
2494
-
2495
- output = {self.output_key: data_downsampled}
2496
- # downsampling also affects the sampling frequency
2497
- if sampling_frequency is not None:
2498
- sampling_frequency = sampling_frequency / self.factor
2499
- output["sampling_frequency"] = sampling_frequency
2500
- if n_ax is not None:
2501
- n_ax = n_ax // self.factor
2502
- output["n_ax"] = n_ax
2503
- return output
2504
-
2505
-
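Downsampling keeps every factor-th sample starting at phase, and the sampling frequency drops by the same factor; the index selection is equivalent to (numpy sketch with illustrative numbers):

import numpy as np

length, factor, phase = 12, 4, 1
sample_idx = np.arange(phase, length, factor)  # -> [1, 5, 9]
new_sampling_frequency = 40e6 / factor         # e.g. 40 MHz -> 10 MHz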
2506
- @ops_registry("branched_pipeline")
2507
- class BranchedPipeline(Operation):
2508
- """Operation that processes data through multiple branches.
2509
-
2510
- This operation takes input data, processes it through multiple parallel branches,
2511
- and then merges the results from those branches using the specified merge strategy.
2512
- """
2513
-
2514
- def __init__(self, branches=None, merge_strategy="nested", **kwargs):
2515
- """Initialize a branched pipeline.
2516
-
2517
- Args:
2518
- branches (List[Union[List, Pipeline, Operation]]): List of branch operations
2519
- merge_strategy (str or callable): How to merge the outputs from branches:
2520
- - "nested" (default): Return outputs as a dictionary keyed by branch name
2521
- - "flatten": Flatten outputs by prefixing keys with the branch name
2522
- - "suffix": Flatten outputs by suffixing keys with the branch name
2523
- - callable: A custom merge function that accepts the branch outputs dict
2524
- **kwargs: Additional arguments for the Operation base class
2525
- """
2526
- super().__init__(**kwargs)
2527
-
2528
- # Convert branch specifications to operation chains
2529
- if branches is None:
2530
- branches = []
2531
-
2532
- self.branches = {}
2533
- for i, branch in enumerate(branches, start=1):
2534
- branch_name = f"branch_{i}"
2535
- # Convert different branch specification types
2536
- if isinstance(branch, list):
2537
- # Convert list to operation chain
2538
- self.branches[branch_name] = make_operation_chain(branch)
2539
- elif isinstance(branch, (Pipeline, Operation)):
2540
- # Already a pipeline or operation
2541
- self.branches[branch_name] = branch
2542
- else:
2543
- raise ValueError(
2544
- f"Branch must be a list, Pipeline, or Operation, got {type(branch)}"
2545
- )
2546
-
2547
- # Set merge strategy
2548
- self.merge_strategy = merge_strategy
2549
- if isinstance(merge_strategy, str):
2550
- if merge_strategy == "nested":
2551
- self._merge_function = lambda outputs: outputs
2552
- elif merge_strategy == "flatten":
2553
- self._merge_function = self.flatten_outputs
2554
- elif merge_strategy == "suffix":
2555
- self._merge_function = self.suffix_merge_outputs
2556
- else:
2557
- raise ValueError(f"Unknown merge_strategy: {merge_strategy}")
2558
- elif callable(merge_strategy):
2559
- self._merge_function = merge_strategy
2560
- else:
2561
- raise ValueError("Invalid merge_strategy type provided.")
2562
-
2563
- def call(self, **kwargs):
2564
- """Process input through branches and merge results.
2565
-
2566
- Args:
2567
- **kwargs: Input keyword arguments
2568
-
2569
- Returns:
2570
- dict: Merged outputs from all branches according to merge strategy
2571
- """
2572
- branch_outputs = {}
2573
- for branch_name, branch in self.branches.items():
2574
- # Each branch gets a fresh copy of kwargs to avoid interference
2575
- branch_kwargs = kwargs.copy()
2576
-
2577
- # Process through the branch
2578
- branch_result = branch(**branch_kwargs)
2579
-
2580
- # Store branch outputs
2581
- branch_outputs[branch_name] = branch_result
2582
-
2583
- # Apply merge strategy to combine outputs
2584
- merged_outputs = self._merge_function(branch_outputs)
2585
-
2586
- return merged_outputs
2587
-
2588
- def flatten_outputs(self, outputs: dict) -> dict:
2589
- """
2590
- Flatten a nested dictionary by prefixing keys with the branch name.
2591
- For each branch, the resulting key is "{branch_name}_{original_key}".
2592
- """
2593
- flat = {}
2594
- for branch_name, branch_dict in outputs.items():
2595
- for key, value in branch_dict.items():
2596
- new_key = f"{branch_name}_{key}"
2597
- if new_key in flat:
2598
- raise ValueError(f"Key collision detected for {new_key}")
2599
- flat[new_key] = value
2600
- return flat
2601
-
2602
- def suffix_merge_outputs(self, outputs: dict) -> dict:
2603
- """
2604
- Flatten a nested dictionary by suffixing keys with the branch name.
2605
- For each branch, the resulting key is "{original_key}_{branch_name}".
2606
- """
2607
- flat = {}
2608
- for branch_name, branch_dict in outputs.items():
2609
- for key, value in branch_dict.items():
2610
- new_key = f"{key}_{branch_name}"
2611
- if new_key in flat:
2612
- raise ValueError(f"Key collision detected for {new_key}")
2613
- flat[new_key] = value
2614
- return flat
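For two branches that each produce an "image" key, the three merge strategies yield dictionaries shaped roughly as follows (illustrative sketch of the helpers above, with dummy values):

branch_outputs = {"branch_1": {"image": 1}, "branch_2": {"image": 2}}
# "nested" returns branch_outputs unchanged.
flattened = {f"{b}_{k}": v for b, d in branch_outputs.items() for k, v in d.items()}
# -> {"branch_1_image": 1, "branch_2_image": 2}
suffixed = {f"{k}_{b}": v for b, d in branch_outputs.items() for k, v in d.items()}
# -> {"image_branch_1": 1, "image_branch_2": 2}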
2615
-
2616
- def get_config(self):
2617
- """Return the config dictionary for serialization."""
2618
- config = super().get_config()
2619
-
2620
- # Add branch configurations
2621
- branch_configs = {}
2622
- for branch_name, branch in self.branches.items():
2623
- if isinstance(branch, Pipeline):
2624
- # Get the operations list from the Pipeline
2625
- branch_configs[branch_name] = branch.get_config()
2626
- elif isinstance(branch, list):
2627
- # Convert list of operations to list of operation configs
2628
- branch_op_configs = []
2629
- for op in branch:
2630
- branch_op_configs.append(op.get_config())
2631
- branch_configs[branch_name] = {"operations": branch_op_configs}
2632
- else:
2633
- # Single operation
2634
- branch_configs[branch_name] = branch.get_config()
2635
-
2636
- # Add merge strategy
2637
- if isinstance(self.merge_strategy, str):
2638
- merge_strategy_config = self.merge_strategy
2639
- else:
2640
- # For custom functions, use the name if available
2641
- merge_strategy_config = getattr(self.merge_strategy, "__name__", "custom")
2642
-
2643
- config.update(
2644
- {
2645
- "branches": branch_configs,
2646
- "merge_strategy": merge_strategy_config,
2647
- }
2648
- )
2649
-
2650
- return config
2651
-
2652
- def get_dict(self):
2653
- """Get the configuration of the operation."""
2654
- config = super().get_dict()
2655
- config.update({"name": "branched_pipeline"})
2656
-
2657
- # Add branches (recursively) to the config
2658
- branches = {}
2659
- for branch_name, branch in self.branches.items():
2660
- if isinstance(branch, Pipeline):
2661
- branches[branch_name] = branch.get_dict()
2662
- elif isinstance(branch, list):
2663
- branches[branch_name] = [op.get_dict() for op in branch]
2664
- else:
2665
- branches[branch_name] = branch.get_dict()
2666
- config["branches"] = branches
2667
- config["merge_strategy"] = self.merge_strategy
2668
- return config
2669
-
2670
-
2671
- @ops_registry("threshold")
2672
- class Threshold(Operation):
2673
- """Threshold an array, setting values below/above a threshold to a fill value."""
2674
-
2675
- def __init__(
2676
- self,
2677
- threshold_type="hard",
2678
- below_threshold=True,
2679
- fill_value="min",
2680
- **kwargs,
2681
- ):
2682
- super().__init__(**kwargs)
2683
- if threshold_type not in ("hard", "soft"):
2684
- raise ValueError("threshold_type must be 'hard' or 'soft'")
2685
- self.threshold_type = threshold_type
2686
- self.below_threshold = below_threshold
2687
- self._fill_value_type = fill_value
2688
-
2689
- # Define threshold function at init
2690
- if threshold_type == "hard":
2691
- if below_threshold:
2692
- self._threshold_func = lambda data, threshold, fill: ops.where(
2693
- data < threshold, fill, data
2694
- )
2695
- else:
2696
- self._threshold_func = lambda data, threshold, fill: ops.where(
2697
- data > threshold, fill, data
2698
- )
2699
- else: # soft
2700
- if below_threshold:
2701
- self._threshold_func = (
2702
- lambda data, threshold, fill: ops.maximum(data - threshold, 0) + fill
2703
- )
2704
- else:
2705
- self._threshold_func = (
2706
- lambda data, threshold, fill: ops.minimum(data - threshold, 0) + fill
2707
- )
2708
-
2709
- def _resolve_fill_value(self, data, threshold):
2710
- """Get the fill value based on the fill_value_type."""
2711
- fv = self._fill_value_type
2712
- if isinstance(fv, (int, float)):
2713
- return ops.convert_to_tensor(fv, dtype=data.dtype)
2714
- elif fv == "min":
2715
- return ops.min(data)
2716
- elif fv == "max":
2717
- return ops.max(data)
2718
- elif fv == "threshold":
2719
- return threshold
2720
- else:
2721
- raise ValueError("Unknown fill_value")
2722
-
2723
- def call(
2724
- self,
2725
- threshold=None,
2726
- percentile=None,
2727
- **kwargs,
2728
- ):
2729
- """Threshold the input data.
2730
-
2731
- Args:
2732
- threshold: Numeric threshold.
2733
- percentile: Percentile to derive threshold from.
2734
- Returns:
2735
- Tensor with thresholding applied.
2736
- """
2737
- data = kwargs[self.key]
2738
- if (threshold is None) == (percentile is None):
2739
- raise ValueError("Pass either threshold or percentile, not both or neither.")
2740
-
2741
- if percentile is not None:
2742
- # Convert percentile to quantile value (0-1 range)
2743
- threshold = ops.quantile(data, percentile / 100.0)
2744
-
2745
- fill_value = self._resolve_fill_value(data, threshold)
2746
- result = self._threshold_func(data, threshold, fill_value)
2747
- return {self.output_key: result}
2748
-
2749
-
2750
- @ops_registry("anisotropic_diffusion")
2751
- class AnisotropicDiffusion(Operation):
2752
- """Speckle Reducing Anisotropic Diffusion (SRAD) filter.
2753
-
2754
- Reference:
2755
- - https://www.researchgate.net/publication/5602035_Speckle_reducing_anisotropic_diffusion
2756
- - https://nl.mathworks.com/matlabcentral/fileexchange/54044-image-despeckle-filtering-toolbox
2757
- """
2758
-
2759
- def call(self, niter=100, lmbda=0.1, rect=None, eps=1e-6, **kwargs):
2760
- """Anisotropic diffusion filter.
2761
-
2762
- Assumes input data is non-negative.
2763
-
2764
- Args:
2765
- niter: Number of iterations.
2766
- lmbda: Lambda parameter.
2767
- rect: Rectangle [x1, y1, x2, y2] for homogeneous noise (optional).
2768
- eps: Small epsilon for stability.
2769
- Returns:
2770
- Filtered image (2D tensor or batch of images).
2771
- """
2772
- data = kwargs[self.key]
2773
-
2774
- if not self.with_batch_dim:
2775
- data = ops.expand_dims(data, axis=0)
2776
-
2777
- batch_size = ops.shape(data)[0]
2778
-
2779
- results = []
2780
- for i in range(batch_size):
2781
- image = data[i]
2782
- image_out = self._anisotropic_diffusion_single(image, niter, lmbda, rect, eps)
2783
- results.append(image_out)
2784
-
2785
- result = ops.stack(results, axis=0)
2786
-
2787
- if not self.with_batch_dim:
2788
- result = ops.squeeze(result, axis=0)
2789
-
2790
- return {self.output_key: result}
2791
-
2792
- def _anisotropic_diffusion_single(self, image, niter, lmbda, rect, eps):
2793
- """Apply anisotropic diffusion to a single image (2D)."""
2794
- image = ops.exp(image)
2795
- M, N = image.shape
2796
-
2797
- for _ in range(niter):
2798
- iN = ops.concatenate([image[1:], ops.zeros((1, N), dtype=image.dtype)], axis=0)
2799
- iS = ops.concatenate([ops.zeros((1, N), dtype=image.dtype), image[:-1]], axis=0)
2800
- jW = ops.concatenate([image[:, 1:], ops.zeros((M, 1), dtype=image.dtype)], axis=1)
2801
- jE = ops.concatenate([ops.zeros((M, 1), dtype=image.dtype), image[:, :-1]], axis=1)
2802
-
2803
- if rect is not None:
2804
- x1, y1, x2, y2 = rect
2805
- imageuniform = image[x1:x2, y1:y2]
2806
- q0_squared = (ops.std(imageuniform) / (ops.mean(imageuniform) + eps)) ** 2
2807
-
2808
- dN = iN - image
2809
- dS = iS - image
2810
- dW = jW - image
2811
- dE = jE - image
2812
-
2813
- G2 = (dN**2 + dS**2 + dW**2 + dE**2) / (image**2 + eps)
2814
- L = (dN + dS + dW + dE) / (image + eps)
2815
- num = (0.5 * G2) - ((1 / 16) * (L**2))
2816
- den = (1 + ((1 / 4) * L)) ** 2
2817
- q_squared = num / (den + eps)
2818
-
2819
- if rect is not None:
2820
- den = (q_squared - q0_squared) / (q0_squared * (1 + q0_squared) + eps)
2821
- c = 1.0 / (1 + den)
2822
- cS = ops.concatenate([ops.zeros((1, N), dtype=image.dtype), c[:-1]], axis=0)
2823
- cE = ops.concatenate([ops.zeros((M, 1), dtype=image.dtype), c[:, :-1]], axis=1)
2824
-
2825
- D = (cS * dS) + (c * dN) + (cE * dE) + (c * dW)
2826
- image = image + (lmbda / 4) * D
2827
-
2828
- result = ops.log(image)
2829
- return result
2830
-
2831
-
2832
- @ops_registry("channels_to_complex")
2833
- class ChannelsToComplex(Operation):
2834
- def call(self, **kwargs):
2835
- data = kwargs[self.key]
2836
- output = channels_to_complex(data)
2837
- return {self.output_key: output}
2838
-
2839
-
2840
- @ops_registry("complex_to_channels")
2841
- class ComplexToChannels(Operation):
2842
- def __init__(self, axis=-1, **kwargs):
2843
- super().__init__(**kwargs)
2844
- self.axis = axis
2845
-
2846
- def call(self, **kwargs):
2847
- data = kwargs[self.key]
2848
- output = complex_to_channels(data, axis=self.axis)
2849
- return {self.output_key: output}
2850
-
2851
-
2852
- def demodulate_not_jitable(
2853
- rf_data,
2854
- sampling_frequency=None,
2855
- center_frequency=None,
2856
- bandwidth=None,
2857
- filter_coeff=None,
2858
- ):
2859
- """Demodulates an RF signal to complex base-band (IQ).
2860
-
2861
- Demodulates the radiofrequency (RF) bandpass signals and returns the
2862
- In-phase/Quadrature (I/Q) components. IQ is a complex signal whose real (imaginary)
2863
- part contains the in-phase (quadrature) component.
2864
-
2865
- This function operates (i.e. demodulates) on the RF signal over the
2866
- (fast-) time axis, which is assumed to be the second-to-last axis.
2867
-
2868
- Args:
2869
- rf_data (ndarray): real valued input array of size [..., n_ax, n_el].
2870
- second to last axis is fast-time axis.
2871
- sampling_frequency (float): the sampling frequency of the RF signals (in Hz).
2872
- Not required when filter_coeff is provided.
2873
- center_frequency (float, optional): represents the center frequency (in Hz).
2874
- Defaults to None.
2875
- bandwidth (float, optional): Bandwidth of RF signal in % of center
2876
- frequency. Defaults to None.
2877
- The bandwidth in % is defined by:
2878
- B = Bandwidth_in_% = Bandwidth_in_Hz*(100/center_frequency).
2879
- The cutoff frequency:
2880
- Wn = Bandwidth_in_Hz/sampling_frequency, i.e:
2881
- Wn = B*(center_frequency/100)/sampling_frequency.
2882
- filter_coeff (list, optional): (b, a), numerator and denominator coefficients
2883
- of FIR filter for quadratic band pass filter. All other parameters are ignored
2884
- if filter_coeff are provided. Instead the given filter_coeff is directly used.
2885
- If not provided, a filter is derived from the other params (sampling_frequency,
2886
- center_frequency, bandwidth).
2887
- see https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.lfilter.html
2888
-
2889
- Returns:
2890
- iq_data (ndarray): complex valued base-band signal.
2891
-
2892
- """
2893
- rf_data = ops.convert_to_numpy(rf_data)
2894
- assert np.isreal(rf_data).all(), f"RF must contain real RF signals, got {rf_data.dtype}"
2895
-
2896
- input_shape = rf_data.shape
2897
- n_dim = len(input_shape)
2898
- if n_dim > 2:
2899
- *_, n_ax, n_el = input_shape
2900
- else:
2901
- n_ax, n_el = input_shape
2902
-
2903
- if filter_coeff is None:
2904
- assert sampling_frequency is not None, "provide sampling_frequency when no filter is given."
2905
- # Time vector
2906
- t = np.arange(n_ax) / sampling_frequency
2907
- t0 = 0
2908
- t = t + t0
2909
-
2910
- # Estimate center frequency
2911
- if center_frequency is None:
2912
- # Keep a maximum of 100 randomly selected scanlines
2913
- idx = np.arange(n_el)
2914
- if n_el > 100:
2915
- idx = np.random.permutation(idx)[:100]
2916
- # Power Spectrum
2917
- P = np.sum(
2918
- np.abs(np.fft.fft(np.take(rf_data, idx, axis=-1), axis=-2)) ** 2,
2919
- axis=-1,
2920
- )
2921
- P = P[: n_ax // 2]
2922
- # Carrier frequency
2923
- idx = np.sum(np.arange(n_ax // 2) * P) / np.sum(P)
2924
- center_frequency = idx * sampling_frequency / n_ax
2925
-
2926
- # Normalized cut-off frequency
2927
- if bandwidth is None:
2928
- Wn = min(2 * center_frequency / sampling_frequency, 0.5)
2929
- bandwidth = center_frequency * Wn
2930
- else:
2931
- assert np.isscalar(bandwidth), "The signal bandwidth (in %) must be a scalar."
2932
- assert (bandwidth > 0) & (bandwidth <= 200), (
2933
- "The signal bandwidth (in %) must be within the interval of ]0,200]."
2934
- )
2935
- # bandwidth in Hz
2936
- bandwidth = center_frequency * bandwidth / 100
2937
- Wn = bandwidth / sampling_frequency
2938
- assert (Wn > 0) & (Wn <= 1), (
2939
- "The normalized cutoff frequency is not within the interval of (0,1). "
2940
- "Check the input parameters!"
2941
- )
2942
-
2943
- # Down-mixing of the RF signals
2944
- carrier = np.exp(-1j * 2 * np.pi * center_frequency * t)
2945
- # add the singleton dimensions
2946
- carrier = np.reshape(carrier, (*[1] * (n_dim - 2), n_ax, 1))
2947
- iq_data = rf_data * carrier
2948
-
2949
- # Low-pass filter
2950
- N = 5
2951
- b, a = scipy.signal.butter(N, Wn, "low")
2952
-
2953
- # factor 2: to preserve the envelope amplitude
2954
- iq_data = scipy.signal.filtfilt(b, a, iq_data, axis=-2) * 2
2955
-
2956
- # Display a warning message if harmful aliasing is suspected
2957
- # the RF signal is undersampled
2958
- if sampling_frequency < (2 * center_frequency + bandwidth):
2959
- # lower and higher frequencies of the bandpass signal
2960
- fL = center_frequency - bandwidth / 2
2961
- fH = center_frequency + bandwidth / 2
2962
- n = fH // (fH - fL)
2963
- harmless_aliasing = any(
2964
- (2 * fH / np.arange(1, n) <= sampling_frequency)
2965
- & (sampling_frequency <= 2 * fL / np.arange(1, n))
2966
- )
2967
- if not harmless_aliasing:
2968
- log.warning(
2969
- "rf2iq:harmful_aliasing Harmful aliasing is present: the aliases"
2970
- " are not mutually exclusive!"
2971
- )
2972
- else:
2973
- b, a = filter_coeff
2974
- iq_data = scipy.signal.lfilter(b, a, rf_data, axis=-2) * 2
2975
-
2976
- return iq_data
2977
-
2978
-
2979
- def upmix(iq_data, sampling_frequency, center_frequency, upsampling_rate=6):
2980
- """Upsamples and upmixes complex base-band signals (IQ) to RF.
2981
-
2982
- Args:
2983
- iq_data (ndarray): complex valued input array of size [..., n_ax, n_el]. second
2984
- to last axis is fast-time axis.
2985
- sampling_frequency (float): the sampling frequency of the input IQ signal (in Hz).
2986
- resulting sampling_frequency of RF data is upsampling_rate times higher.
2987
- center_frequency (float): represents the center frequency (in Hz).
- upsampling_rate (int, optional): upsampling factor along the fast-time axis. Defaults to 6.
2988
-
2989
- Returns:
2990
- rf_data (ndarray): output real valued rf data.
2991
- """
2992
- assert iq_data.dtype in [
2993
- "complex64",
2994
- "complex128",
2995
- ], "IQ must contain all complex signals."
2996
-
2997
- input_shape = iq_data.shape
2998
- n_dim = len(input_shape)
2999
- if n_dim > 2:
3000
- *_, n_ax, _ = input_shape
3001
- else:
3002
- n_ax, _ = input_shape
3003
-
3004
- # Time vector
3005
- n_ax_up = n_ax * upsampling_rate
3006
- sampling_frequency_up = sampling_frequency * upsampling_rate
3007
-
3008
- t = ops.arange(n_ax_up, dtype="float32") / sampling_frequency_up
3009
- t0 = 0
3010
- t = t + t0
3011
-
3012
- iq_data_upsampled = resample(
3013
- iq_data,
3014
- n_samples=n_ax_up,
3015
- axis=-2,
3016
- order=1,
3017
- )
3018
-
3019
- # Up-mixing of the IQ signals
3020
- t = ops.cast(t, dtype="complex64")
3021
- center_frequency = ops.cast(center_frequency, dtype="complex64")
3022
- carrier = ops.exp(1j * 2 * np.pi * center_frequency * t)
3023
- carrier = ops.reshape(carrier, (*[1] * (n_dim - 2), n_ax_up, 1))
3024
-
3025
- rf_data = iq_data_upsampled * carrier
3026
- rf_data = ops.real(rf_data) * ops.sqrt(2)
3027
-
3028
- return ops.cast(rf_data, "float32")
3029
-
3030
-
3031
- def get_band_pass_filter(num_taps, sampling_frequency, f1, f2):
3032
- """Band pass filter
3033
-
3034
- Args:
3035
- num_taps (int): number of taps in filter.
3036
- sampling_frequency (float): sample frequency in Hz.
3037
- f1 (float): cutoff frequency in Hz of left band edge.
3038
- f2 (float): cutoff frequency in Hz of right band edge.
3039
-
3040
- Returns:
3041
- ndarray: band pass filter
3042
- """
3043
- bpf = scipy.signal.firwin(num_taps, [f1, f2], pass_zero=False, fs=sampling_frequency)
3044
- return bpf
3045
-
3046
-
3047
- def get_low_pass_iq_filter(num_taps, sampling_frequency, f, bw):
3048
- """Design complex low-pass filter.
3049
-
3050
- The filter is a low-pass FIR filter modulated to the center frequency.
3051
-
3052
- Args:
3053
- num_taps (int): number of taps in filter.
3054
- sampling_frequency (float): sample frequency.
3055
- f (float): center frequency.
3056
- bw (float): bandwidth in Hz.
3057
-
3058
- Raises:
3059
- ValueError: if cutoff frequency (bw / 2) is not within (0, sampling_frequency / 2)
3060
-
3061
- Returns:
3062
- ndarray: Complex-valued low-pass filter
3063
- """
3064
- cutoff = bw / 2
3065
- if not (0 < cutoff < sampling_frequency / 2):
3066
- raise ValueError(
3067
- f"Cutoff frequency must be within (0, sampling_frequency / 2), "
3068
- f"got {cutoff} Hz, must be within (0, {sampling_frequency / 2}) Hz"
3069
- )
3070
- # Design real-valued low-pass filter
3071
- lpf = scipy.signal.firwin(num_taps, cutoff, pass_zero=True, fs=sampling_frequency)
3072
- # Modulate to center frequency to make it complex
3073
- time_points = np.arange(num_taps) / sampling_frequency
3074
- lpf_complex = lpf * np.exp(1j * 2 * np.pi * f * time_points)
3075
- return lpf_complex
3076
-
3077
-
3078
- def complex_to_channels(complex_data, axis=-1):
3079
- """Unroll complex data to separate channels.
3080
-
3081
- Args:
3082
- complex_data (complex ndarray): complex input data.
3083
- axis (int, optional): on which axis to extend. Defaults to -1.
3084
-
3085
- Returns:
3086
- ndarray: real array with real and imaginary components
3087
- unrolled over two channels at axis.
3088
- """
3089
- # assert ops.iscomplex(complex_data).any()
3090
- q_data = ops.imag(complex_data)
3091
- i_data = ops.real(complex_data)
3092
-
3093
- i_data = ops.expand_dims(i_data, axis=axis)
3094
- q_data = ops.expand_dims(q_data, axis=axis)
3095
-
3096
- iq_data = ops.concatenate((i_data, q_data), axis=axis)
3097
- return iq_data
3098
-
3099
-
3100
- def channels_to_complex(data):
3101
- """Convert array with real and imaginary components at
3102
- different channels to complex data array.
3103
-
3104
- Args:
3105
- data (ndarray): input data, with at 0 index of axis
3106
- real component and 1 index of axis the imaginary.
3107
-
3108
- Returns:
3109
- ndarray: complex array with real and imaginary components.
3110
- """
3111
- assert data.shape[-1] == 2, "Data must have two channels."
3112
- data = ops.cast(data, "complex64")
3113
- return data[..., 0] + 1j * data[..., 1]
3114
-
3115
-
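complex_to_channels and channels_to_complex are inverses of each other; the same round trip in plain numpy (real part in channel 0, imaginary part in channel 1):

import numpy as np

iq = np.array([1.0 + 2.0j, 3.0 - 4.0j], dtype=np.complex64)
two_channel = np.stack([iq.real, iq.imag], axis=-1)          # shape (2, 2)
recovered = two_channel[..., 0] + 1j * two_channel[..., 1]
assert np.allclose(recovered, iq)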
3116
- def hilbert(x, N: int = None, axis=-1):
3117
- """Manual implementation of the Hilbert transform function. The function
3118
- returns the analytical signal.
3119
-
3120
- Operated in the Fourier domain.
3121
-
3122
- Note:
3123
- THIS IS NOT THE MATHEMATICAL HILBERT TRANSFORM as you will find it on
3124
- wikipedia, but computes the analytical signal. The implementation reproduces
3125
- the behavior of the `scipy.signal.hilbert` function.
3126
-
3127
- Args:
3128
- x (ndarray): input data of any shape.
3129
- N (int, optional): number of points in the FFT. Defaults to None.
3130
- axis (int, optional): axis to operate on. Defaults to -1.
3131
- Returns:
3132
- x (ndarray): complex IQ data of any shape.
3133
-
3134
- """
3135
- input_shape = x.shape
3136
- n_dim = len(input_shape)
3137
-
3138
- n_ax = input_shape[axis]
3139
-
3140
- if axis < 0:
3141
- axis = n_dim + axis
3142
-
3143
- if N is not None:
3144
- if N < n_ax:
3145
- raise ValueError("N must be greater or equal to n_ax.")
3146
- # only pad along the axis, use manual padding
3147
- pad = N - n_ax
3148
- zeros = ops.zeros(
3149
- input_shape[:axis] + (pad,) + input_shape[axis + 1 :],
3150
- )
3151
-
3152
- x = ops.concatenate((x, zeros), axis=axis)
3153
- else:
3154
- N = n_ax
3155
-
3156
- # Create filter to zero out negative frequencies
3157
- h = np.zeros(N)
3158
- if N % 2 == 0:
3159
- h[0] = h[N // 2] = 1
3160
- h[1 : N // 2] = 2
3161
- else:
3162
- h[0] = 1
3163
- h[1 : (N + 1) // 2] = 2
3164
-
3165
- idx = list(range(n_dim))
3166
- # make sure axis gets to the end for fft (operates on last axis)
3167
- idx.remove(axis)
3168
- idx.append(axis)
3169
- x = ops.transpose(x, idx)
3170
-
3171
- if x.ndim > 1:
3172
- ind = [np.newaxis] * x.ndim
3173
- ind[-1] = slice(None)
3174
- h = h[tuple(ind)]
3175
-
3176
- h = ops.convert_to_tensor(h)
3177
- h = ops.cast(h, "complex64")
3178
- h = h + 1j * ops.zeros_like(h)
3179
-
3180
- Xf_r, Xf_i = ops.fft((x, ops.zeros_like(x)))
3181
-
3182
- Xf_r = ops.cast(Xf_r, "complex64")
3183
- Xf_i = ops.cast(Xf_i, "complex64")
3184
-
3185
- Xf = Xf_r + 1j * Xf_i
3186
- Xf = Xf * h
3187
-
3188
- # x = np.fft.ifft(Xf)
3189
- # do manual ifft using fft
3190
- Xf_r = ops.real(Xf)
3191
- Xf_i = ops.imag(Xf)
3192
- Xf_r_inv, Xf_i_inv = ops.fft((Xf_r, -Xf_i))
3193
-
3194
- Xf_i_inv = ops.cast(Xf_i_inv, "complex64")
3195
- Xf_r_inv = ops.cast(Xf_r_inv, "complex64")
3196
-
3197
- x = Xf_r_inv / N
3198
- x = x + 1j * (-Xf_i_inv / N)
3199
-
3200
- # switch back to original shape
3201
- idx = list(range(n_dim))
3202
- idx.insert(axis, idx.pop(-1))
3203
- x = ops.transpose(x, idx)
3204
- return x
3205
-
3206
-
3207
- def demodulate(data, center_frequency, sampling_frequency, axis=-3):
3208
- """Demodulates the input data to baseband. The function computes the analytical
3209
- signal (the signal with negative frequencies removed) and then shifts the spectrum
3210
- of the signal to baseband by multiplying with a complex exponential. Where the
3211
- spectrum was centered around `center_frequency` before, it is now centered around
3212
- 0 Hz. The baseband IQ data are complex-valued. The real and imaginary parts
3213
- are stored in two real-valued channels.
3214
-
3215
- Args:
3216
- data (ops.Tensor): The input data to demodulate of shape `(..., axis, ..., 1)`.
3217
- center_frequency (float): The center frequency of the signal.
3218
- sampling_frequency (float): The sampling frequency of the signal.
3219
- axis (int, optional): The axis along which to demodulate. Defaults to -3.
3220
-
3221
- Returns:
3222
- ops.Tensor: The demodulated IQ data of shape `(..., axis, ..., 2)`.
3223
- """
3224
- # Compute the analytical signal
3225
- analytical_signal = hilbert(data, axis=axis)
3226
-
3227
- # Define frequency indices
3228
- frequency_indices = ops.arange(analytical_signal.shape[axis])
3229
-
3230
- # Expand the frequency indices to match the shape of the RF data
3231
- indexing = [None] * data.ndim
3232
- indexing[axis] = slice(None)
3233
- indexing = tuple(indexing)
3234
- frequency_indices_shaped_like_rf = frequency_indices[indexing]
3235
-
3236
- # Cast to complex64
3237
- center_frequency = ops.cast(center_frequency, dtype="complex64")
3238
- sampling_frequency = ops.cast(sampling_frequency, dtype="complex64")
3239
- frequency_indices_shaped_like_rf = ops.cast(frequency_indices_shaped_like_rf, dtype="complex64")
3240
-
3241
- # Shift to baseband
3242
- phasor_exponent = (
3243
- -1j * 2 * np.pi * center_frequency * frequency_indices_shaped_like_rf / sampling_frequency
3244
- )
3245
- iq_data_signal_complex = analytical_signal * ops.exp(phasor_exponent)
3246
-
3247
- # Split the complex signal into two channels
3248
- iq_data_two_channel = complex_to_channels(ops.squeeze(iq_data_signal_complex, axis=-1))
3249
-
3250
- return iq_data_two_channel
3251
-
3252
-
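The mixing step is a multiplication with a complex exponential at -center_frequency, which shifts the spectral peak from center_frequency down to 0 Hz; a minimal numpy sketch of that step alone, with illustrative frequencies (the analytic-signal computation and channel split are omitted):

import numpy as np

fs, fc, n = 40e6, 5e6, 128
t = np.arange(n) / fs
rf = np.cos(2 * np.pi * fc * t)
baseband = rf * np.exp(-1j * 2 * np.pi * fc * t)  # DC component plus an image at -2*fc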
3253
- def compute_time_to_peak_stack(waveforms, center_frequencies, waveform_sampling_frequency=250e6):
3254
- """Compute the time of the peak of each waveform in a stack of waveforms.
3255
-
3256
- Args:
3257
- waveforms (ndarray): The waveforms of shape (n_waveforms, n_samples).
3258
- center_frequencies (ndarray): The center frequencies of the waveforms in Hz of shape
3259
- (n_waveforms,) or a scalar if all waveforms have the same center frequency.
3260
- waveform_sampling_frequency (float): The sampling frequency of the waveforms in Hz.
3261
-
3262
- Returns:
3263
- ndarray: The time to peak for each waveform in seconds.
3264
- """
3265
- t_peak = []
3266
- center_frequencies = center_frequencies * ops.ones((waveforms.shape[0],))
3267
- for waveform, center_frequency in zip(waveforms, center_frequencies):
3268
- t_peak.append(compute_time_to_peak(waveform, center_frequency, waveform_sampling_frequency))
3269
- return ops.stack(t_peak)
3270
-
3271
-
3272
- def compute_time_to_peak(waveform, center_frequency, waveform_sampling_frequency=250e6):
3273
- """Compute the time of the peak of the waveform.
3274
-
3275
- Args:
3276
- waveform (ndarray): The waveform of shape (n_samples).
3277
- center_frequency (float): The center frequency of the waveform in Hz.
3278
- waveform_sampling_frequency (float): The sampling frequency of the waveform in Hz.
3279
-
3280
- Returns:
3281
- float: The time to peak for the waveform in seconds.
3282
- """
3283
- n_samples = waveform.shape[0]
3284
- if n_samples == 0:
3285
- raise ValueError("Waveform has zero samples.")
3286
-
3287
- waveforms_iq_complex_channels = demodulate(
3288
- waveform[..., None], center_frequency, waveform_sampling_frequency, axis=-1
3289
- )
3290
- waveforms_iq_complex = channels_to_complex(waveforms_iq_complex_channels)
3291
- envelope = ops.abs(waveforms_iq_complex)
3292
- peak_idx = ops.argmax(envelope, axis=-1)
3293
- t_peak = ops.cast(peak_idx, dtype="float32") / waveform_sampling_frequency
3294
- return t_peak