zea 0.0.8__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +3 -3
- zea/agent/masks.py +2 -2
- zea/agent/selection.py +3 -3
- zea/backend/__init__.py +1 -1
- zea/backend/tensorflow/dataloader.py +1 -1
- zea/beamform/beamformer.py +4 -2
- zea/beamform/pfield.py +2 -2
- zea/data/augmentations.py +1 -1
- zea/data/convert/__main__.py +93 -52
- zea/data/convert/camus.py +8 -2
- zea/data/convert/echonet.py +1 -1
- zea/data/convert/echonetlvh/__init__.py +1 -1
- zea/data/convert/verasonics.py +810 -772
- zea/data/data_format.py +0 -2
- zea/data/file.py +28 -0
- zea/data/preset_utils.py +1 -1
- zea/display.py +1 -1
- zea/doppler.py +5 -5
- zea/func/__init__.py +109 -0
- zea/{tensor_ops.py → func/tensor.py} +32 -8
- zea/func/ultrasound.py +500 -0
- zea/internal/_generate_keras_ops.py +5 -5
- zea/metrics.py +6 -5
- zea/models/diffusion.py +1 -1
- zea/models/echonetlvh.py +1 -1
- zea/models/gmm.py +1 -1
- zea/ops/__init__.py +188 -0
- zea/ops/base.py +442 -0
- zea/{keras_ops.py → ops/keras_ops.py} +2 -2
- zea/ops/pipeline.py +1472 -0
- zea/ops/tensor.py +356 -0
- zea/ops/ultrasound.py +890 -0
- zea/probes.py +2 -10
- zea/scan.py +17 -20
- zea/tools/fit_scan_cone.py +1 -1
- zea/tools/selection_tool.py +1 -1
- zea/tracking/lucas_kanade.py +1 -1
- zea/tracking/segmentation.py +1 -1
- {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/METADATA +3 -1
- {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/RECORD +43 -37
- zea/ops.py +0 -3534
- {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/WHEEL +0 -0
- {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/entry_points.txt +0 -0
- {zea-0.0.8.dist-info → zea-0.0.9.dist-info}/licenses/LICENSE +0 -0
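
The headline change in this release is structural: the monolithic `zea/ops.py` (removed, -3534 lines) is split into the new `zea/ops/` package (`base`, `pipeline`, `tensor`, `ultrasound`, plus the relocated `keras_ops`), and `zea/tensor_ops.py` moves to `zea/func/tensor.py` under the new `zea/func/` package. Below is a rough sketch of what the renames listed above mean for import paths; it assumes only the re-exports shown in the new `zea/ops/__init__.py` further down, and anything beyond those names is an assumption, not a statement about the package.

```python
# Import-path view of the 0.0.8 -> 0.0.9 restructuring (sketch, not exhaustive).

# zea.ops is now a package; in 0.0.9 these names are re-exported from
# zea/ops/__init__.py (shown below), so this spelling resolves:
from zea.ops import Pipeline, EnvelopeDetect, Normalize, LogCompress

# Moved: zea.keras_ops -> zea.ops.keras_ops (rename listed above):
from zea.ops.keras_ops import Squeeze

# Moved: zea.tensor_ops -> zea.func.tensor (the re-exported names are not
# shown in this diff, so only the module path is stated here):
from zea.func import tensor
```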
zea/ops/__init__.py
ADDED
@@ -0,0 +1,188 @@
"""Operations and Pipelines for ultrasound data processing.

This module contains two important classes, :class:`Operation` and :class:`Pipeline`,
which are used to process ultrasound data. A pipeline is a sequence of operations
that are applied to the data in a specific order.

We implement a range of common
operations for ultrasound data processing (:mod:`zea.ops.ultrasound`), but also support
a variety of basic tensor operations (:mod:`zea.ops.tensor`). Lastly, all existing Keras
operations (see `Keras Ops API <https://keras.io/api/ops/>`_) are available as `zea`
operations as well (see :mod:`zea.ops.keras_ops`).

Stand-alone manual usage
------------------------

Operations can be run on their own:

Examples
^^^^^^^^
.. doctest::

    >>> import numpy as np
    >>> from zea.ops import EnvelopeDetect
    >>> data = np.random.randn(2000, 128, 1)
    >>> # static arguments are passed in the constructor
    >>> envelope_detect = EnvelopeDetect(axis=-1)
    >>> # other parameters can be passed here along with the data
    >>> envelope_data = envelope_detect(data=data)

Using a pipeline
----------------

You can initialize with a default pipeline or create your own custom pipeline.

.. doctest::

    >>> from zea.ops import Pipeline, EnvelopeDetect, Normalize, LogCompress
    >>> pipeline = Pipeline.from_default()

    >>> operations = [
    ...     EnvelopeDetect(),
    ...     Normalize(),
    ...     LogCompress(),
    ... ]
    >>> pipeline_custom = Pipeline(operations)

One can also load a pipeline from a config or yaml/json file:

.. doctest::

    >>> from zea import Pipeline

    >>> # From JSON string
    >>> json_string = '{"operations": ["identity"]}'
    >>> pipeline = Pipeline.from_json(json_string)

    >>> # from yaml file
    >>> import yaml
    >>> from zea import Config
    >>> # Create a sample pipeline YAML file
    >>> pipeline_dict = {
    ...     "operations": [
    ...         {"name": "identity"},
    ...     ]
    ... }
    >>> with open("pipeline.yaml", "w") as f:
    ...     yaml.dump(pipeline_dict, f)
    >>> yaml_file = "pipeline.yaml"
    >>> pipeline = Pipeline.from_yaml(yaml_file)

.. testcleanup::

    import os

    os.remove("pipeline.yaml")

Example of a yaml file:

.. code-block:: yaml

    pipeline:
      operations:
        - name: demodulate
        - name: beamform
          params:
            type: das
            pfield: false
            num_patches: 100
        - name: envelope_detect
        - name: normalize
        - name: log_compress

"""

from zea.internal.registry import ops_registry
from zea.ops import keras_ops

from .base import (
    Identity,
    ImageOperation,
    Lambda,
    Mean,
    Merge,
    Operation,
    Stack,
    get_ops,
)
from .pipeline import (
    Beamform,
    BranchedPipeline,
    DelayAndSum,
    DelayMultiplyAndSum,
    Map,
    PatchedGrid,
    Pipeline,
)
from .tensor import (
    GaussianBlur,
    Normalize,
    Pad,
    Threshold,
)
from .ultrasound import (
    AnisotropicDiffusion,
    ChannelsToComplex,
    Companding,
    ComplexToChannels,
    Demodulate,
    Downsample,
    EnvelopeDetect,
    FirFilter,
    LeeFilter,
    LogCompress,
    LowPassFilter,
    PfieldWeighting,
    ReshapeGrid,
    ScanConvert,
    Simulate,
    TOFCorrection,
    UpMix,
)

__all__ = [
    # Registry
    "ops_registry",
    # Base operations
    "Identity",
    "ImageOperation",
    "Lambda",
    "Mean",
    "Merge",
    "Operation",
    "Stack",
    "get_ops",
    # Pipeline
    "DelayAndSum",
    "DelayMultiplyAndSum",
    "Beamform",
    "BranchedPipeline",
    "Map",
    "PatchedGrid",
    "Pipeline",
    # Tensor operations
    "GaussianBlur",
    "Normalize",
    "Pad",
    "Threshold",
    # Ultrasound operations
    "AnisotropicDiffusion",
    "ChannelsToComplex",
    "Companding",
    "ComplexToChannels",
    "Demodulate",
    "Downsample",
    "EnvelopeDetect",
    "FirFilter",
    "LeeFilter",
    "LogCompress",
    "LowPassFilter",
    "PfieldWeighting",
    "ReshapeGrid",
    "ScanConvert",
    "Simulate",
    "TOFCorrection",
    "UpMix",
    # Keras operations
    "keras_ops",
]
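
The string names used in the JSON/YAML examples of the docstring above ("identity", "beamform", ...) resolve through `ops_registry`, and `get_ops` (re-exported here) is the lookup helper defined in `zea/ops/base.py` below. A minimal sketch of that mechanism and of the keyword-dictionary calling convention, using only names confirmed by this diff:

```python
import numpy as np

from zea.ops import get_ops

# "identity" is the name registered with @ops_registry("identity") in
# zea/ops/base.py; get_ops simply indexes the registry with that string.
Identity = get_ops("identity")

op = Identity()                   # defaults: key="data", output_key="data"
out = op(data=np.zeros((4, 4)))   # keyword-only, per Operation.__call__
# Identity.call returns {}, so __call__ merely merges the inputs back in:
# out == {"data": <the zeros array>}
```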
zea/ops/base.py
ADDED
@@ -0,0 +1,442 @@
import hashlib
import inspect
import json
from functools import partial
from typing import Any, Dict, List, Union

import keras
from keras import ops

from zea import log
from zea.backend import jit
from zea.internal.checks import _assert_keys_and_axes
from zea.internal.core import (
    DataTypes,
)
from zea.internal.registry import ops_registry
from zea.utils import (
    deep_compare,
)


def get_ops(ops_name):
    """Get the operation from the registry."""
    return ops_registry[ops_name]


class Operation(keras.Operation):
    """
    A base abstract class for operations in the pipeline with caching functionality.
    """

    def __init__(
        self,
        input_data_type: Union[DataTypes, None] = None,
        output_data_type: Union[DataTypes, None] = None,
        key: Union[str, None] = "data",
        output_key: Union[str, None] = None,
        cache_inputs: Union[bool, List[str]] = False,
        cache_outputs: bool = False,
        jit_compile: bool = True,
        with_batch_dim: bool = True,
        jit_kwargs: dict | None = None,
        jittable: bool = True,
        additional_output_keys: List[str] = None,
        **kwargs,
    ):
        """
        Args:
            input_data_type (DataTypes): The data type of the input data
            output_data_type (DataTypes): The data type of the output data
            key: The key for the input data (operation will operate on this key)
                Defaults to "data".
            output_key: The key for the output data (operation will output to this key)
                Defaults to the same as the input key. If you want to store intermediate
                results, you can set this to a different key. But make sure to update the
                input key of the next operation to match the output key of this operation.
            cache_inputs: A list of input keys to cache or True to cache all inputs
            cache_outputs: A list of output keys to cache or True to cache all outputs
            jit_compile: Whether to JIT compile the 'call' method for faster execution
            with_batch_dim: Whether operations should expect a batch dimension in the input
            jit_kwargs: Additional keyword arguments for the JIT compiler
            jittable: Whether the operation can be JIT compiled
            additional_output_keys: A list of additional output keys produced by the operation.
                These are used to track if all keys are available for downstream operations.
                If the operation has a conditional output, it is best to add all possible
                output keys here.
        """
        super().__init__(**kwargs)

        self.input_data_type = input_data_type
        self.output_data_type = output_data_type

        self.key = key  # Key for input data
        self.output_key = output_key  # Key for output data
        if self.output_key is None:
            self.output_key = self.key
        self.additional_output_keys = additional_output_keys or []

        self.inputs = []  # Source(s) of input data (name of a previous operation)
        self.allow_multiple_inputs = False  # Only single input allowed by default

        self.cache_inputs = cache_inputs
        self.cache_outputs = cache_outputs

        # Initialize input and output caches
        self._input_cache = {}
        self._output_cache = {}

        # Obtain the input signature of the `call` method
        self._trace_signatures()

        if jit_kwargs is None:
            jit_kwargs = {}

        if keras.backend.backend() == "jax" and self.static_params:
            jit_kwargs |= {"static_argnames": self.static_params}

        self.jit_kwargs = jit_kwargs

        self.with_batch_dim = with_batch_dim
        self._jittable = jittable

        # Set the jit compilation flag and compile the `call` method
        # Set zea logger level to suppress warnings regarding
        # torch not being able to compile the function
        with log.set_level("ERROR"):
            self.set_jit(jit_compile)

    @property
    def output_keys(self) -> List[str]:
        """Get the output keys of the operation."""
        return [self.output_key] + self.additional_output_keys

    @property
    def static_params(self):
        """Get the static parameters of the operation."""
        return getattr(self.__class__, "STATIC_PARAMS", [])

    def set_jit(self, jit_compile: bool):
        """Set the JIT compilation flag and set the `_call` method accordingly."""
        self._jit_compile = jit_compile
        if self._jit_compile and self.jittable:
            self._call = jit(self.call, **self.jit_kwargs)
        else:
            self._call = self.call

    def _trace_signatures(self):
        """
        Analyze and store the input/output signatures of the `call` method.
        """
        self._input_signature = inspect.signature(self.call)
        self._valid_keys = set(self._input_signature.parameters.keys()) | {self.key}

    @property
    def valid_keys(self) -> set:
        """Get the valid keys for the `call` method."""
        return self._valid_keys

    @property
    def needs_keys(self) -> set:
        """Get a set of all input keys needed by the operation."""
        return self.valid_keys

    @property
    def jittable(self):
        """Check if the operation can be JIT compiled."""
        return self._jittable

    def call(self, **kwargs):
        """
        Abstract method that defines the processing logic for the operation.
        Subclasses must implement this method.
        """
        raise NotImplementedError

    def set_input_cache(self, input_cache: Dict[str, Any]):
        """
        Set a cache for inputs, then retrace the function if necessary.

        Args:
            input_cache: A dictionary containing cached inputs.
        """
        self._input_cache.update(input_cache)
        self._trace_signatures()  # Retrace after updating cache to ensure correctness.

    def set_output_cache(self, output_cache: Dict[str, Any]):
        """
        Set a cache for outputs, then retrace the function if necessary.

        Args:
            output_cache: A dictionary containing cached outputs.
        """
        self._output_cache.update(output_cache)
        self._trace_signatures()  # Retrace after updating cache to ensure correctness.

    def clear_cache(self):
        """
        Clear the input and output caches.
        """
        self._input_cache.clear()
        self._output_cache.clear()

    def _hash_inputs(self, kwargs: Dict) -> str:
        """
        Generate a hash for the given inputs to use as a cache key.

        Args:
            kwargs: Keyword arguments.

        Returns:
            A unique hash representing the inputs.
        """
        input_json = json.dumps(kwargs, sort_keys=True, default=str)
        return hashlib.md5(input_json.encode()).hexdigest()

    def __call__(self, *args, **kwargs) -> Dict:
        """
        Process the input keyword arguments and return the processed results.

        Args:
            kwargs: Keyword arguments to be processed.

        Returns:
            Combined input and output as kwargs.
        """
        if args:
            example_usage = f" result = {ops_registry.get_name(self)}({self.key}=my_data"
            valid_keys_no_kwargs = self.valid_keys - {"kwargs"}
            if valid_keys_no_kwargs:
                example_usage += f", {list(valid_keys_no_kwargs)[0]}=param1, ..., **kwargs)"
            else:
                example_usage += ", **kwargs)"
            raise TypeError(
                f"{self.__class__.__name__}.__call__() only accepts keyword arguments. "
                "Positional arguments are not allowed.\n"
                f"Received positional arguments: {args}\n"
                "Example usage:\n"
                f"{example_usage}"
            )

        # Merge cached inputs with provided ones
        merged_kwargs = {**self._input_cache, **kwargs}

        # Return cached output if available
        if self.cache_outputs:
            cache_key = self._hash_inputs(merged_kwargs)
            if cache_key in self._output_cache:
                return {**merged_kwargs, **self._output_cache[cache_key]}

        # Filter kwargs to match the valid keys of the `call` method
        if "kwargs" not in self.valid_keys:
            filtered_kwargs = {k: v for k, v in merged_kwargs.items() if k in self.valid_keys}
        else:
            filtered_kwargs = merged_kwargs

        # Call the processing function
        # If you want to jump in with debugger please set `jit_compile=False`
        # when initializing the pipeline.
        processed_output = self._call(**filtered_kwargs)

        # Ensure the output is always a dictionary
        if not isinstance(processed_output, dict):
            raise TypeError(
                f"The `call` method must return a dictionary. Got {type(processed_output)}."
            )

        # Merge outputs with inputs
        combined_kwargs = {**merged_kwargs, **processed_output}

        # Cache the result if caching is enabled
        if self.cache_outputs:
            if isinstance(self.cache_outputs, list):
                cached_output = {
                    k: v for k, v in processed_output.items() if k in self.cache_outputs
                }
            else:
                cached_output = processed_output
            self._output_cache[cache_key] = cached_output

        return combined_kwargs

    def get_dict(self):
        """Get the configuration of the operation. Inherit from keras.Operation."""
        config = {}
        config.update({"name": ops_registry.get_name(self)})
        config["params"] = {
            "key": self.key,
            "output_key": self.output_key,
            "cache_inputs": self.cache_inputs,
            "cache_outputs": self.cache_outputs,
            "jit_compile": self._jit_compile,
            "with_batch_dim": self.with_batch_dim,
            "jit_kwargs": self.jit_kwargs,
        }
        return config

    def __eq__(self, other):
        """Check equality of two operations based on type and configuration."""
        if not isinstance(other, Operation):
            return False

        # Compare the class name and parameters
        if self.__class__.__name__ != other.__class__.__name__:
            return False

        # Compare the name assigned to the operation
        name = ops_registry.get_name(self)
        other_name = ops_registry.get_name(other)
        if name != other_name:
            return False

        # Compare the parameters of the operations
        if not deep_compare(self.get_dict(), other.get_dict()):
            return False

        return True


@ops_registry("identity")
class Identity(Operation):
    """Identity operation."""

    def call(self, **kwargs) -> Dict:
        """Returns the input as is."""
        return {}


class ImageOperation(Operation):
    """
    Base class for image processing operations.

    This class extends the Operation class to provide a common interface
    for operations that process image data, with shape (batch, height, width, channels)
    or (height, width, channels) if batch dimension is not present.

    Subclasses should implement the `call` method to define the image processing logic, and call
    ``super().call(**kwargs)`` to validate the input data shape.
    """

    def call(self, **kwargs):
        """
        Validate input data shape for image operations.

        Args:
            **kwargs: Keyword arguments containing input data.

        Raises:
            AssertionError: If input data does not have the expected number of dimensions.
        """
        data = kwargs[self.key]

        if self.with_batch_dim:
            assert ops.ndim(data) == 4, "Input data must have 4 dimensions (b, h, w, c)."
        else:
            assert ops.ndim(data) == 3, "Input data must have 3 dimensions (h, w, c)."


@ops_registry("lambda")
class Lambda(Operation):
    """Use any function as an operation."""

    def __init__(self, func, **kwargs):
        # Split kwargs into kwargs for partial and __init__
        sig = inspect.signature(func)
        func_params = set(sig.parameters.keys())

        func_kwargs = {k: v for k, v in kwargs.items() if k in func_params}
        op_kwargs = {k: v for k, v in kwargs.items() if k not in func_params}

        Lambda._check_if_unary(func, **func_kwargs)

        super().__init__(**op_kwargs)
        self.func = partial(func, **func_kwargs)

    @staticmethod
    def _check_if_unary(func, **kwargs):
        """Checks if the kwargs are sufficient to call the function as a unary operation."""
        sig = inspect.signature(func)
        # Remove arguments that are already provided in func_kwargs
        params = list(sig.parameters.values())
        remaining = [p for p in params if p.name not in kwargs]
        # Count required positional arguments (excluding self/cls)
        required_positional = [
            p
            for p in remaining
            if p.default is p.empty and p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        ]
        if len(required_positional) != 1:
            raise ValueError(
                f"Partial of {func.__name__} must be callable with exactly one required "
                f"positional argument, we still need: {required_positional}."
            )

    def call(self, **kwargs):
        data = kwargs[self.key]
        if self.with_batch_dim:
            data = ops.map(self.func, data)
        else:
            data = self.func(data)
        return {self.output_key: data}


@ops_registry("mean")
class Mean(Operation):
    """Take the mean of the input data along a specific axis."""

    def __init__(self, keys, axes, **kwargs):
        super().__init__(**kwargs)

        self.keys, self.axes = _assert_keys_and_axes(keys, axes)

    def call(self, **kwargs):
        for key, axis in zip(self.keys, self.axes):
            kwargs[key] = ops.mean(kwargs[key], axis=axis)

        return kwargs


@ops_registry("merge")
class Merge(Operation):
    """Operation that merges sets of input dictionaries."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.allow_multiple_inputs = True

    def call(self, *args, **kwargs) -> Dict:
        """
        Merges the input dictionaries. Priority is given to the last input.
        """
        merged = {}
        for arg in args:
            if not isinstance(arg, dict):
                raise TypeError("All inputs must be dictionaries.")
            merged.update(arg)
        return merged


@ops_registry("stack")
class Stack(Operation):
    """Stack multiple data arrays along a new axis.
    Useful to merge data from parallel pipelines.
    """

    def __init__(
        self,
        keys: Union[str, List[str], None],
        axes: Union[int, List[int], None],
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.keys, self.axes = _assert_keys_and_axes(keys, axes)

    def call(self, **kwargs) -> Dict:
        """
        Stacks the inputs corresponding to the specified keys along the specified axis.
        If a list of axes is provided, the length must match the number of keys.
        """
        for key, axis in zip(self.keys, self.axes):
            kwargs[key] = keras.ops.stack([kwargs[key] for key in self.keys], axis=axis)
        return kwargs
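
`zea/ops/base.py` above fixes the contract for new operations: subclass `Operation`, implement `call` so that it returns a dictionary keyed by `self.output_key`, and optionally register the class with `@ops_registry(...)` so it can be referenced by name in pipeline configs; `Lambda` covers the case where a plain unary function is enough. A hedged sketch of that contract, using a hypothetical `Gain` operation that is not part of zea:

```python
from zea.ops import Lambda, Operation, ops_registry


@ops_registry("gain")  # hypothetical name; registering allows {"name": "gain"} in configs
class Gain(Operation):
    """Multiply the input data by a constant gain factor (illustrative only)."""

    def __init__(self, factor=2.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, data=None):
        # `call` must return a dict; Operation.__call__ merges it with the inputs
        # under self.output_key (which defaults to the input key, "data").
        return {self.output_key: data * self.factor}


# For simple cases, Lambda wraps any function that is unary after binding its
# keyword arguments, e.g. (assuming keras is installed as zea's backend):
# from keras import ops
# abs_op = Lambda(ops.abs, with_batch_dim=False)
```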
zea/{keras_ops.py → ops/keras_ops.py}

@@ -5,7 +5,7 @@ They can be used in zea pipelines like any other :class:`zea.Operation`, for exa
 
 .. doctest::
 
-    >>> from zea.keras_ops import Squeeze
+    >>> from zea.ops.keras_ops import Squeeze
 
     >>> op = Squeeze(axis=1)
 
@@ -16,7 +16,7 @@ Generated with Keras 3.12.0
 import keras
 
 from zea.internal.registry import ops_registry
-from zea.ops import Lambda
+from zea.ops.base import Lambda
 
 class MissingKerasOps(ValueError):
     def __init__(self, class_name: str, func: str):