zea-0.0.4-py3-none-any.whl → zea-0.0.5-py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
zea/models/__init__.py CHANGED
@@ -4,8 +4,9 @@
 
 Currently, the following models are available (all inherited from :class:`zea.models.BaseModel`):
 
-- :class:`zea.models.echonet.EchoNetDynamic`: A model for echocardiography segmentation.
+- :class:`zea.models.echonet.EchoNetDynamic`: A model for left ventricle segmentation.
 - :class:`zea.models.carotid_segmenter.CarotidSegmenter`: A model for carotid artery segmentation.
+- :class:`zea.models.echonetlvh.EchoNetLVH`: A model for left ventricle hypertrophy segmentation.
 - :class:`zea.models.unet.UNet`: A simple U-Net implementation.
 - :class:`zea.models.lpips.LPIPS`: A model implementing the perceptual similarity metric.
 - :class:`zea.models.taesd.TinyAutoencoder`: A tiny autoencoder model for image compression.
@@ -16,7 +17,7 @@ To use these models, you can import them directly from the :mod:`zea.models` mod
 
 .. code-block:: python
 
-    from zea.models import UNet
+    from zea.models.unet import UNet
 
     model = UNet.from_preset("unet-echonet-inpainter")
 
@@ -48,7 +49,7 @@ An example of how to use the :class:`zea.models.diffusion.DiffusionModel` is sho
 
 .. code-block:: python
 
-    from zea.models import DiffusionModel
+    from zea.models.diffusion import DiffusionModel
 
     model = DiffusionModel.from_preset("diffusion-echonet-dynamic")
     samples = model.sample(n_samples=4)
@@ -74,9 +75,11 @@ The following steps are recommended when adding a new model:
 
     from . import (
         carotid_segmenter,
+        deeplabv3,
         dense,
         diffusion,
         echonet,
+        echonetlvh,
         generative,
         gmm,
         layers,
zea/models/deeplabv3.py ADDED
@@ -0,0 +1,131 @@
+"""DeepLabV3+ architecture for multi-class segmentation. For more details see https://arxiv.org/abs/1802.02611."""
+
+import keras
+from keras import layers, ops
+
+
+def convolution_block(
+    block_input,
+    num_filters=256,
+    kernel_size=3,
+    dilation_rate=1,
+    use_bias=False,
+):
+    """
+    Create a convolution block with batch normalization and ReLU activation.
+
+    This is a standard building block used throughout the DeepLabV3+ architecture,
+    consisting of Conv2D -> BatchNormalization -> ReLU.
+
+    Args:
+        block_input (Tensor): Input tensor to the convolution block
+        num_filters (int): Number of output filters/channels. Defaults to 256.
+        kernel_size (int): Size of the convolution kernel. Defaults to 3.
+        dilation_rate (int): Dilation rate for dilated convolution. Defaults to 1.
+        use_bias (bool): Whether to use bias in the convolution layer. Defaults to False.
+
+    Returns:
+        Tensor: Output tensor after convolution, batch normalization, and ReLU
+    """
+    x = layers.Conv2D(
+        num_filters,
+        kernel_size=kernel_size,
+        dilation_rate=dilation_rate,
+        padding="same",
+        use_bias=use_bias,
+        kernel_initializer=keras.initializers.HeNormal(),
+    )(block_input)
+    x = layers.BatchNormalization()(x)
+    return ops.nn.relu(x)
+
+
+def DilatedSpatialPyramidPooling(dspp_input):
+    """
+    Implement Atrous Spatial Pyramid Pooling (ASPP) module.
+
+    ASPP captures multi-scale context by applying parallel atrous convolutions
+    with different dilation rates. This helps the model understand objects
+    at multiple scales.
+
+    The module consists of:
+    - Global average pooling branch
+    - 1x1 convolution branch
+    - 3x3 convolutions with dilation rates 6, 12, and 18
+
+    Reference: https://arxiv.org/abs/1706.05587
+
+    Args:
+        dspp_input (Tensor): Input feature tensor from encoder
+
+    Returns:
+        Tensor: Multi-scale feature representation
+    """
+    dims = dspp_input.shape
+    x = layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
+    x = convolution_block(x, kernel_size=1, use_bias=True)
+    out_pool = layers.UpSampling2D(
+        size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]),
+        interpolation="bilinear",
+    )(x)
+
+    out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
+    out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
+    out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
+    out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)
+
+    x = layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
+    output = convolution_block(x, kernel_size=1)
+    return output
+
+
+def DeeplabV3Plus(image_shape, num_classes, pretrained_weights=None):
+    """
+    Build DeepLabV3+ model for semantic segmentation.
+
+    DeepLabV3+ combines the benefits of spatial pyramid pooling and encoder-decoder
+    architecture. It uses a ResNet50 backbone as encoder, ASPP for multi-scale
+    feature extraction, and a simple decoder for recovering spatial details.
+
+    Architecture:
+    1. Encoder: ResNet50 backbone with atrous convolutions
+    2. ASPP: Multi-scale feature extraction
+    3. Decoder: Simple decoder with skip connections
+    4. Output: Final segmentation prediction
+
+    Reference: https://arxiv.org/abs/1802.02611
+
+    Args:
+        image_shape (tuple): Input image shape as (height, width, channels)
+        num_classes (int): Number of output classes for segmentation
+        pretrained_weights (str, optional): Pretrained weights for ResNet50 backbone.
+            Defaults to None.
+
+    Returns:
+        keras.Model: Complete DeepLabV3+ model
+    """
+    model_input = keras.Input(shape=image_shape)
+    # 3-channel grayscale as repeated single channel for ResNet50
+    model_input_3_channel = ops.concatenate([model_input, model_input, model_input], axis=-1)
+    preprocessed = keras.applications.resnet50.preprocess_input(model_input_3_channel)
+    resnet50 = keras.applications.ResNet50(
+        weights=pretrained_weights, include_top=False, input_tensor=preprocessed
+    )
+    x = resnet50.get_layer("conv4_block6_2_relu").output
+    x = DilatedSpatialPyramidPooling(x)
+
+    input_a = layers.UpSampling2D(
+        size=(image_shape[0] // 4 // x.shape[1], image_shape[1] // 4 // x.shape[2]),
+        interpolation="bilinear",
+    )(x)
+    input_b = resnet50.get_layer("conv2_block3_2_relu").output
+    input_b = convolution_block(input_b, num_filters=48, kernel_size=1)
+
+    x = layers.Concatenate(axis=-1)([input_a, input_b])
+    x = convolution_block(x)
+    x = convolution_block(x)
+    x = layers.UpSampling2D(
+        size=(image_shape[0] // x.shape[1], image_shape[1] // x.shape[2]),
+        interpolation="bilinear",
+    )(x)
+    model_output = layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
+    return keras.Model(inputs=model_input, outputs=model_output)
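
A rough usage sketch of the new builder (illustrative only; it assumes a 224x224 single-channel input, which DeeplabV3Plus repeats to three channels before the ResNet50 backbone):

    # Illustrative sketch, not part of the package diff itself.
    import numpy as np
    from zea.models.deeplabv3 import DeeplabV3Plus

    model = DeeplabV3Plus(image_shape=(224, 224, 1), num_classes=4)
    dummy = np.zeros((1, 224, 224, 1), dtype="float32")
    logits = model(dummy)  # expected shape: (1, 224, 224, 4)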
zea/models/diffusion.py CHANGED
@@ -728,15 +728,15 @@ class DiffusionModel(DeepGenerativeModel):
         )
         return next_noisy_images
 
-    def start_track_progress(self, diffusion_steps):
+    def start_track_progress(self, diffusion_steps, initial_step=0):
         """Initialize the progress tracking for the diffusion process.
-
         For diffusion animation we keep track of the diffusion progress.
         For large number of steps, we do not store all the images due to memory constraints.
         """
         self.track_progress = []
-        if diffusion_steps > 50:
-            self.track_progress_interval = diffusion_steps // 50
+        remaining = max(1, diffusion_steps - int(initial_step))
+        if remaining > 50:
+            self.track_progress_interval = remaining // 50
         else:
             self.track_progress_interval = 1
 
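
The new initial_step argument matters when sampling is resumed partway through the schedule: the storage interval is now derived from the remaining steps so that roughly 50 frames are kept. A standalone sketch of that rule (plain Python, illustrative only):

    def track_progress_interval(diffusion_steps, initial_step=0):
        # Mirror of the interval logic above: keep about 50 frames in total.
        remaining = max(1, diffusion_steps - int(initial_step))
        return remaining // 50 if remaining > 50 else 1

    assert track_progress_interval(500) == 10      # full run: every 10th step
    assert track_progress_interval(500, 400) == 2  # resumed run: 100 steps left
    assert track_progress_interval(30) == 1        # short runs keep every step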
zea/models/echonetlvh.py ADDED
@@ -0,0 +1,290 @@
+"""EchoNetLVH model for segmentation of PLAX view cardiac ultrasound. For more details see https://echonet.github.io/lvh/index.html."""
+
+import numpy as np
+from keras import ops
+
+from zea.internal.registry import model_registry
+from zea.models.base import BaseModel
+from zea.models.deeplabv3 import DeeplabV3Plus
+from zea.models.preset_utils import register_presets
+from zea.models.presets import echonet_lvh_presets
+from zea.utils import translate
+
+
+@model_registry(name="echonetlvh")
+class EchoNetLVH(BaseModel):
+    """
+    EchoNet Left Ventricular Hypertrophy (LVH) model for echocardiogram analysis.
+
+    This model performs semantic segmentation on echocardiogram images to identify
+    key anatomical landmarks for measuring left ventricular wall thickness:
+    - LVPWd_1: Left Ventricular Posterior Wall point 1
+    - LVPWd_2: Left Ventricular Posterior Wall point 2
+    - IVSd_1: Interventricular Septum point 1
+    - IVSd_2: Interventricular Septum point 2
+
+    The model outputs 4-channel logits corresponding to heatmaps for each landmark.
+
+    For more information, see the original project page at https://echonet.github.io/lvh/index.html
+    """
+
+    def __init__(self, **kwargs):
+        """
+        Initialize the EchoNetLVH model.
+
+        Args:
+            **kwargs: Additional keyword arguments passed to BaseModel
+        """
+        super().__init__(**kwargs)
+
+        # Scan conversion constants for echonet processing
+        self.rho_range = (0, 224)  # Radial distance range in pixels
+        self.theta_range = (np.deg2rad(-45), np.deg2rad(45))  # Angular range in radians
+        self.fill_value = -1.0  # Fill value for scan conversion
+        self.resolution = 1.0  # mm per pixel resolution
+
+        # Network input/output dimensions
+        self.n_rho = 224
+        self.n_theta = 224
+        self.output_shape = (224, 224, 4)
+
+        # Pre-computed coordinate grid for efficient processing
+        self.coordinate_grid = ops.stack(
+            ops.cast(ops.convert_to_tensor(np.indices((224, 224))), "float32"), axis=-1
+        )
+
+        # Initialize the underlying segmentation network
+        self.network = DeeplabV3Plus(image_shape=(224, 224, 3), num_classes=4)
+
+    def call(self, inputs):
+        """
+        Forward pass of the model.
+
+        Args:
+            inputs (Tensor): Input images of shape [B, H, W, C]. They should
+                be scan converted, with pixel values in range [0, 255].
+
+        Returns:
+            Tensor: Logits of shape [B, H, W, 4] with 4 channels for each landmark
+        """
+        assert len(ops.shape(inputs)) == 4
+
+        # Store original dimensions for output resizing
+        original_size = ops.shape(inputs)[1:3]
+
+        # Resize to network input size
+        inputs_resized = ops.image.resize(inputs, size=(224, 224))
+
+        # Get network predictions
+        logits = self.network(inputs_resized)
+
+        # Resize logits back to original input dimensions
+        logits_output = ops.image.resize(logits, original_size)
+        return logits_output
+
+    def extract_key_points_as_indices(self, logits):
+        """
+        Extract key point coordinates from logits using center-of-mass calculation.
+
+        Args:
+            logits (Tensor): Model output logits of shape [B, H, W, 4]
+
+        Returns:
+            Tensor: Key point coordinates of shape [B, 4, 2] where each point
+                is in (x, y) format
+        """
+        # Create coordinate grid for the current logit dimensions
+        input_shape = ops.shape(logits)[1:3]
+        input_space_coordinate_grid = ops.stack(
+            ops.cast(ops.convert_to_tensor(np.indices(input_shape)), "float32"), axis=-1
+        )
+
+        # Transpose logits to [B, 4, H, W] for vectorized processing
+        logits_batchified = ops.transpose(logits, (0, 3, 1, 2))
+
+        # Extract expected coordinates for each channel
+        return ops.flip(
+            ops.vectorized_map(
+                lambda logit: self.expected_coordinate(logit, input_space_coordinate_grid),
+                logits_batchified,
+            ),
+            axis=-1,  # Flip to convert from (y, x) to (x, y)
+        )
+
+    def expected_coordinate(self, mask, coordinate_grid=None):
+        """
+        Compute the expected coordinate (center-of-mass) of a heatmap.
+
+        This implements a differentiable version of taking the max of a heatmap
+        by computing the weighted average of coordinates.
+
+        Reference: https://arxiv.org/pdf/1711.08229
+
+        Args:
+            mask (Tensor): Heatmap of shape [B, H, W]
+            coordinate_grid (Tensor, optional): Grid of coordinates. If None,
+                uses self.coordinate_grid
+
+        Returns:
+            Tensor: Expected coordinates of shape [B, 2] in (x, y) format
+        """
+        if coordinate_grid is None:
+            coordinate_grid = self.coordinate_grid
+
+        # Ensure mask values are non-negative and normalized
+        mask_clipped = ops.clip(mask, 0, None)
+        mask_normed = mask_clipped / ops.max(mask_clipped)
+
+        def safe_normalize(m):
+            mask_sum = ops.sum(m)
+            return ops.where(mask_sum > 0, m / mask_sum, m)
+
+        coordinate_probabilities = ops.map(safe_normalize, mask_normed)
+
+        # Add dimension for broadcasting with coordinate grid
+        coordinate_probabilities = ops.expand_dims(coordinate_probabilities, axis=-1)
+
+        # Compute weighted average of coordinates
+        expected_coordinate = ops.sum(
+            ops.expand_dims(coordinate_grid, axis=0) * coordinate_probabilities,
+            axis=(1, 2),
+        )
+
+        # Flip to convert from (y, x) to (x, y) format for euclidean distance calculation
+        return ops.flip(expected_coordinate, axis=-1)
+
+    def overlay_labels_on_image(self, image, label, alpha=0.5):
+        """
+        Overlay predicted heatmaps and connecting lines on the input image.
+
+        Args:
+            image (Tensor): Input image of shape [H, W] or [H, W, C]
+            label (Tensor): Predicted logits of shape [H, W, 4]
+            alpha (float): Blending factor for overlay (0=transparent, 1=opaque)
+
+        Returns:
+            ndarray: Image with overlaid heatmaps and measurements of shape [H, W, 3]
+        """
+        try:
+            import cv2
+
+        except ImportError as exc:
+            raise ImportError(
+                "OpenCV is required for `EchoNetLVH.overlay_labels_on_image`. "
+                "Please install it with 'pip install opencv-python' or "
+                "'pip install opencv-python-headless'."
+            ) from exc
+
+        # Color scheme for each landmark
+        overlay_colors = np.array(
+            [
+                [1, 1, 0],  # Yellow (LVPWd_X1)
+                [1, 0, 1],  # Magenta (LVPWd_X2)
+                [0, 1, 1],  # Cyan (IVSd_X1)
+                [0, 1, 0],  # Green (IVSd_X2)
+            ],
+        )
+
+        # Convert to numpy and ensure RGB format
+        image = ops.convert_to_numpy(image)
+        label = ops.convert_to_numpy(label)
+
+        if image.ndim == 2:
+            image = np.stack([image] * 3, axis=-1)
+        elif image.shape[-1] == 1:
+            image = np.repeat(image, 3, axis=-1)
+        else:
+            image = image.copy()
+
+        # Normalize each channel to [0, 1] for proper visualization
+        label = np.clip(label, 0, None)
+        for ch in range(label.shape[-1]):
+            max_val = np.max(label[..., ch])
+            if max_val > 0:
+                label[..., ch] = label[..., ch] / max_val
+
+        # Initialize overlay and tracking variables
+        overlay = np.zeros_like(image, dtype=np.float32)
+        centers = []
+
+        # Process each landmark channel
+        for ch in range(4):
+            # Square the mask to enhance peak responses
+            mask = label[..., ch] ** 2
+            color = overlay_colors[ch]
+
+            # Find center of mass for this channel
+            center_coords = self.expected_coordinate(ops.expand_dims(mask, axis=0))
+            center_x = ops.convert_to_numpy(center_coords[0, 0])
+            center_y = ops.convert_to_numpy(center_coords[0, 1])
+
+            # Bounds check before conversion to int
+            if 0 <= center_x < image.shape[1] and 0 <= center_y < image.shape[0]:
+                center = (int(center_x), int(center_y))
+            else:
+                center = None
+
+            if center is not None:
+                # Blend heatmap with overlay
+                mask_alpha = mask * alpha
+                for c in range(3):
+                    overlay[..., c] += mask_alpha * color[c]
+            centers.append(center)
+
+        # Draw connecting lines between consecutive landmarks
+        for i in range(3):
+            pt1, pt2 = centers[i], centers[i + 1]
+            if pt1 is not None and pt2 is not None:
+                color = tuple(int(x) for x in overlay_colors[i])
+
+                # Create line mask
+                line_mask = np.zeros(image.shape[:2], dtype=np.uint8)
+                cv2.line(line_mask, pt1, pt2, color=1, thickness=2)
+
+                # Apply line to overlay
+                for c in range(3):
+                    overlay[..., c][line_mask.astype(bool)] = color[c] * alpha
+
+        # Blend overlay with original image
+        overlay = np.clip(overlay, 0, 1)
+        out = image.astype(np.float32)
+        blend_mask = np.any(overlay > 0.02, axis=-1)
+        out[blend_mask] = (1 - alpha) * out[blend_mask] + overlay[blend_mask]
+
+        return np.clip(out, 0, 1)
+
+    def visualize_logits(self, images, logits):
+        """
+        Create visualization of model predictions overlaid on input images.
+
+        Args:
+            images (Tensor): Input images of shape [B, H, W, C]
+            logits (Tensor): Model predictions of shape [B, H, W, 4]
+
+        Returns:
+            Tensor: Images with overlaid predictions of shape [B, H, W, 3]
+        """
+        # Store original dimensions for final output
+        original_size = ops.shape(images)[1:3]
+
+        # Resize to standard processing size
+        images_resized = ops.image.resize(images, size=(224, 224), interpolation="nearest")
+        logits_resized = ops.image.resize(logits, size=(224, 224), interpolation="nearest")
+
+        # Normalize images to [0, 1] range
+        images_clipped = ops.clip(images_resized, 0, 255)
+        images = translate(images_clipped, range_from=(0, 255), range_to=(0, 1))
+
+        # Generate overlays for each image in the batch
+        images_with_overlay = []
+        for img, logit_heatmap in zip(images, logits_resized):
+            overlay = self.overlay_labels_on_image(img, logit_heatmap)
+            images_with_overlay.append(overlay)
+
+        # Stack results and resize back to original dimensions
+        images_with_overlay = np.stack(images_with_overlay, axis=0)
+        return ops.image.resize(images_with_overlay, original_size)
+
+
+# Register model presets
+register_presets(echonet_lvh_presets, EchoNetLVH)
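
A hypothetical end-to-end sketch of the new model (illustrative only; it assumes the "echonetlvh" preset weights resolve via the hf_handle registered below and that inputs are 3-channel, matching the DeeplabV3Plus input shape used in __init__):

    import numpy as np
    from zea.models.echonetlvh import EchoNetLVH

    model = EchoNetLVH.from_preset("echonetlvh")
    frames = np.random.uniform(0, 255, size=(2, 224, 224, 3)).astype("float32")
    logits = model(frames)                                   # (2, H, W, 4)
    keypoints = model.extract_key_points_as_indices(logits)  # (2, 4, 2), (x, y)
    overlays = model.visualize_logits(frames, logits)        # (2, H, W, 3)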
zea/models/presets.py CHANGED
@@ -47,6 +47,20 @@ echonet_dynamic_presets = {
     },
 }
 
+echonet_lvh_presets = {
+    "echonetlvh": {
+        "metadata": {
+            "description": (
+                "EchoNetLVH segmentation model for PLAX-view cardiac ultrasound segmentation. "
+                "Trained on images of size (224, 224)."
+            ),
+            "params": 0,
+            "path": "echonetlvh",
+        },
+        "hf_handle": "hf://zeahub/echonetlvh",
+    },
+}
+
 lpips_presets = {
     "lpips": {
         "metadata": {
zea/ops.py CHANGED
@@ -1376,20 +1376,6 @@ class Mean(Operation):
         return kwargs
 
 
-@ops_registry("transpose")
-class Transpose(Operation):
-    """Transpose the input data along the specified axes."""
-
-    def __init__(self, axes, **kwargs):
-        super().__init__(**kwargs)
-        self.axes = axes
-
-    def call(self, **kwargs):
-        data = kwargs[self.key]
-        transposed_data = ops.transpose(data, axes=self.axes)
-        return {self.output_key: transposed_data}
-
-
 @ops_registry("simulate_rf")
 class Simulate(Operation):
     """Simulate RF data."""
@@ -1578,19 +1564,6 @@ class PfieldWeighting(Operation):
         return {self.output_key: weighted_data}
 
 
-@ops_registry("sum")
-class Sum(Operation):
-    """Sum data along a specific axis."""
-
-    def __init__(self, axis, **kwargs):
-        super().__init__(**kwargs)
-        self.axis = axis
-
-    def call(self, **kwargs):
-        data = kwargs[self.key]
-        return {self.output_key: ops.sum(data, axis=self.axis)}
-
-
 @ops_registry("delay_and_sum")
 class DelayAndSum(Operation):
     """Sums time-delayed signals along channels and transmits."""
@@ -2124,29 +2097,37 @@ class Demodulate(Operation):
 class Lambda(Operation):
     """Use any function as an operation."""
 
-    def __init__(self, func, func_kwargs=None, **kwargs):
-        super().__init__(**kwargs)
-        func_kwargs = func_kwargs or {}
-        self.func = partial(func, **func_kwargs)
-
-    def call(self, **kwargs):
-        data = kwargs[self.key]
-        data = self.func(data)
-        return {self.output_key: data}
-
+    def __init__(self, func, **kwargs):
+        # Split kwargs into kwargs for partial and __init__
+        op_kwargs = {k: v for k, v in kwargs.items() if k not in func.__code__.co_varnames}
+        func_kwargs = {k: v for k, v in kwargs.items() if k in func.__code__.co_varnames}
+        Lambda._check_if_unary(func, **func_kwargs)
 
-@ops_registry("clip")
-class Clip(Operation):
-    """Clip the input data to a given range."""
+        super().__init__(**op_kwargs)
+        self.func = partial(func, **func_kwargs)
 
-    def __init__(self, min_value=None, max_value=None, **kwargs):
-        super().__init__(**kwargs)
-        self.min_value = min_value
-        self.max_value = max_value
+    @staticmethod
+    def _check_if_unary(func, **kwargs):
+        """Checks if the kwargs are sufficient to call the function as a unary operation."""
+        sig = inspect.signature(func)
+        # Remove arguments that are already provided in func_kwargs
+        params = list(sig.parameters.values())
+        remaining = [p for p in params if p.name not in kwargs]
+        # Count required positional arguments (excluding self/cls)
+        required_positional = [
+            p
+            for p in remaining
+            if p.default is p.empty and p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
+        ]
+        if len(required_positional) != 1:
+            raise ValueError(
+                f"Partial of {func.__name__} must be callable with exactly one required "
+                f"positional argument, we still need: {required_positional}."
+            )
 
     def call(self, **kwargs):
         data = kwargs[self.key]
-        data = ops.clip(data, self.min_value, self.max_value)
+        data = self.func(data)
         return {self.output_key: data}
 
 
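
The reworked Lambda now splits its keyword arguments between the Operation constructor and the wrapped function, and requires the resulting partial to be unary; presumably the removed Clip, Sum and Transpose operations can now be expressed through Lambda with a suitable function. A self-contained sketch of the unary check (plain Python, not the zea API):

    import inspect
    from functools import partial

    import numpy as np

    def check_if_unary(func, **func_kwargs):
        # After binding the keyword arguments, exactly one required positional
        # parameter may remain.
        sig = inspect.signature(func)
        remaining = [p for p in sig.parameters.values() if p.name not in func_kwargs]
        required = [
            p for p in remaining
            if p.default is p.empty and p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        ]
        if len(required) != 1:
            raise ValueError(f"Expected exactly one required positional argument, got {required}")

    def normalize(x, lo, hi):
        return (np.clip(x, lo, hi) - lo) / (hi - lo)

    check_if_unary(normalize, lo=0.0, hi=60.0)  # passes: only `x` remains
    unary = partial(normalize, lo=0.0, hi=60.0)
    print(unary(np.array([-10.0, 30.0, 100.0])))  # [0.  0.5 1. ]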
@@ -2685,6 +2666,7 @@ class AnisotropicDiffusion(Operation):
         return result
 
 
+@ops_registry("channels_to_complex")
 class ChannelsToComplex(Operation):
     def call(self, **kwargs):
         data = kwargs[self.key]
@@ -2692,6 +2674,7 @@ class ChannelsToComplex(Operation):
         return {self.output_key: output}
 
 
+@ops_registry("complex_to_channels")
 class ComplexToChannels(Operation):
     def __init__(self, axis=-1, **kwargs):
         super().__init__(**kwargs)
zea/scan.py CHANGED
@@ -149,12 +149,13 @@ class Scan(Parameters):
             Defaults to 0.0.
         attenuation_coef (float, optional): Attenuation coefficient in dB/(MHz*cm).
             Defaults to 0.0.
-        selected_transmits (None, str, int, list, or np.ndarray, optional):
+        selected_transmits (None, str, int, list, slice, or np.ndarray, optional):
             Specifies which transmit events to select.
             - None or "all": Use all transmits.
             - "center": Use only the center transmit.
             - int: Select this many evenly spaced transmits.
             - list/array: Use these specific transmit indices.
+            - slice: Use transmits specified by the slice (e.g., slice(0, 10, 2)).
         grid_type (str, optional): Type of grid to use for beamforming.
             Can be "cartesian" or "polar". Defaults to "cartesian".
         dynamic_range (tuple, optional): Dynamic range for image display.
@@ -171,13 +172,14 @@ class Scan(Parameters):
         "pixels_per_wavelength": {"type": int, "default": 4},
         "pfield_kwargs": {"type": dict, "default": {}},
         "apply_lens_correction": {"type": bool, "default": False},
-        "lens_sound_speed": {"type": (float, int)},
+        "lens_sound_speed": {"type": float},
         "lens_thickness": {"type": float},
         "grid_type": {"type": str, "default": "cartesian"},
         "polar_limits": {"type": (tuple, list)},
         "dynamic_range": {"type": (tuple, list), "default": DEFAULT_DYNAMIC_RANGE},
+        "selected_transmits": {"type": (type(None), str, int, list, slice, np.ndarray)},
         # acquisition parameters
-        "sound_speed": {"type": (float, int), "default": 1540.0},
+        "sound_speed": {"type": float, "default": 1540.0},
         "sampling_frequency": {"type": float},
         "center_frequency": {"type": float},
         "n_el": {"type": int},
@@ -359,6 +361,7 @@ class Scan(Parameters):
                 - "center": Use only the center transmit
                 - int: Select this many evenly spaced transmits
                 - list/array: Use these specific transmit indices
+                - slice: Use transmits specified by the slice (e.g., slice(0, 10, 2))
 
         Returns:
             The current instance for method chaining.
@@ -416,6 +419,10 @@ class Scan(Parameters):
             self._invalidate_dependents("selected_transmits")
             return self
 
+        # Handle slice - convert to list of indices
+        if isinstance(selection, slice):
+            selection = list(range(n_tx_total))[selection]
+
         # Handle list of indices
         if isinstance(selection, list):
             # Validate indices
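
The slice handling above simply expands the slice against the total number of transmits, so passing slice(0, 10, 2) as selected_transmits picks every other one of the first ten transmit events. A quick plain-Python illustration of the conversion:

    n_tx_total = 11
    selection = slice(0, 10, 2)
    indices = list(range(n_tx_total))[selection]
    print(indices)  # [0, 2, 4, 6, 8]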