napari-tmidas 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. napari_tmidas/__init__.py +35 -5
  2. napari_tmidas/_crop_anything.py +1520 -609
  3. napari_tmidas/_env_manager.py +76 -0
  4. napari_tmidas/_file_conversion.py +1646 -1131
  5. napari_tmidas/_file_selector.py +1455 -216
  6. napari_tmidas/_label_inspection.py +83 -8
  7. napari_tmidas/_processing_worker.py +309 -0
  8. napari_tmidas/_reader.py +6 -10
  9. napari_tmidas/_registry.py +2 -2
  10. napari_tmidas/_roi_colocalization.py +1221 -84
  11. napari_tmidas/_tests/test_crop_anything.py +123 -0
  12. napari_tmidas/_tests/test_env_manager.py +89 -0
  13. napari_tmidas/_tests/test_grid_view_overlay.py +193 -0
  14. napari_tmidas/_tests/test_init.py +98 -0
  15. napari_tmidas/_tests/test_intensity_label_filter.py +222 -0
  16. napari_tmidas/_tests/test_label_inspection.py +86 -0
  17. napari_tmidas/_tests/test_processing_basic.py +500 -0
  18. napari_tmidas/_tests/test_processing_worker.py +142 -0
  19. napari_tmidas/_tests/test_regionprops_analysis.py +547 -0
  20. napari_tmidas/_tests/test_registry.py +70 -2
  21. napari_tmidas/_tests/test_scipy_filters.py +168 -0
  22. napari_tmidas/_tests/test_skimage_filters.py +259 -0
  23. napari_tmidas/_tests/test_split_channels.py +217 -0
  24. napari_tmidas/_tests/test_spotiflow.py +87 -0
  25. napari_tmidas/_tests/test_tyx_display_fix.py +142 -0
  26. napari_tmidas/_tests/test_ui_utils.py +68 -0
  27. napari_tmidas/_tests/test_widget.py +30 -0
  28. napari_tmidas/_tests/test_windows_basic.py +66 -0
  29. napari_tmidas/_ui_utils.py +57 -0
  30. napari_tmidas/_version.py +16 -3
  31. napari_tmidas/_widget.py +41 -4
  32. napari_tmidas/processing_functions/basic.py +557 -20
  33. napari_tmidas/processing_functions/careamics_env_manager.py +72 -99
  34. napari_tmidas/processing_functions/cellpose_env_manager.py +415 -112
  35. napari_tmidas/processing_functions/cellpose_segmentation.py +132 -191
  36. napari_tmidas/processing_functions/colocalization.py +513 -56
  37. napari_tmidas/processing_functions/grid_view_overlay.py +703 -0
  38. napari_tmidas/processing_functions/intensity_label_filter.py +422 -0
  39. napari_tmidas/processing_functions/regionprops_analysis.py +1280 -0
  40. napari_tmidas/processing_functions/sam2_env_manager.py +53 -69
  41. napari_tmidas/processing_functions/sam2_mp4.py +274 -195
  42. napari_tmidas/processing_functions/scipy_filters.py +403 -8
  43. napari_tmidas/processing_functions/skimage_filters.py +424 -212
  44. napari_tmidas/processing_functions/spotiflow_detection.py +949 -0
  45. napari_tmidas/processing_functions/spotiflow_env_manager.py +591 -0
  46. napari_tmidas/processing_functions/timepoint_merger.py +334 -86
  47. {napari_tmidas-0.2.2.dist-info → napari_tmidas-0.2.4.dist-info}/METADATA +70 -30
  48. napari_tmidas-0.2.4.dist-info/RECORD +63 -0
  49. napari_tmidas/_tests/__init__.py +0 -0
  50. napari_tmidas-0.2.2.dist-info/RECORD +0 -40
  51. {napari_tmidas-0.2.2.dist-info → napari_tmidas-0.2.4.dist-info}/WHEEL +0 -0
  52. {napari_tmidas-0.2.2.dist-info → napari_tmidas-0.2.4.dist-info}/entry_points.txt +0 -0
  53. {napari_tmidas-0.2.2.dist-info → napari_tmidas-0.2.4.dist-info}/licenses/LICENSE +0 -0
  54. {napari_tmidas-0.2.2.dist-info → napari_tmidas-0.2.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,168 @@
1
+ # src/napari_tmidas/_tests/test_scipy_filters.py
2
+ import numpy as np
3
+ import pytest
4
+
5
+ from napari_tmidas.processing_functions import scipy_filters
6
+ from napari_tmidas.processing_functions.scipy_filters import gaussian_blur
7
+
8
+
9
class TestScipyFilters:
    """Tests for the label and image filters in ``scipy_filters``."""

    def test_resize_labels(self):
        """Test resizing label objects while maintaining original array dimensions."""
        from napari_tmidas.processing_functions.scipy_filters import (
            resize_labels,
        )

        label_image = np.zeros((10, 10), dtype=np.uint8)
        label_image[2:8, 2:8] = 3

        # Test with float - dimensions should stay the same
        scale_factor = 0.5
        scaled = resize_labels(label_image, scale_factor=scale_factor)
        # Function maintains original dimensions
        assert scaled.shape == label_image.shape
        assert set(np.unique(scaled)).issubset({0, 3})
        # Objects should be smaller (fewer pixels with label 3)
        assert np.sum(scaled == 3) > 0
        assert np.sum(scaled == 3) < np.sum(label_image == 3)

        # The scale factor may also be given as a string
        scale_factor_str = "0.5"
        scaled_str = resize_labels(label_image, scale_factor=scale_factor_str)
        assert scaled_str.shape == label_image.shape
        assert set(np.unique(scaled_str)).issubset({0, 3})
        assert np.sum(scaled_str == 3) > 0

    def test_gaussian_blur_basic(self):
        """Test basic gaussian blur functionality"""
        # Seeded generator keeps the input reproducible between runs
        rng = np.random.default_rng(0)
        image = rng.random((100, 100))

        # Test with default parameters
        result = gaussian_blur(image)
        assert result.shape == image.shape
        assert result.dtype == image.dtype

    def test_gaussian_blur_with_sigma(self):
        """Test gaussian blur with custom sigma"""
        rng = np.random.default_rng(1)
        image = rng.random((50, 50))

        # Test with sigma parameter
        result = gaussian_blur(image, sigma=2.0)
        assert result.shape == image.shape
        assert result.dtype == image.dtype

    def test_gaussian_blur_3d(self):
        """Test gaussian blur on 3D image"""
        rng = np.random.default_rng(2)
        image = rng.random((20, 20, 20))

        result = gaussian_blur(image, sigma=1.0)
        assert result.shape == image.shape
        assert result.dtype == image.dtype

    @pytest.mark.skipif(
        not scipy_filters.SCIPY_AVAILABLE, reason="SciPy is required"
    )
    def test_subdivide_labels_3layers_combined_output(self):
        """A single square label is split into layer ids 1, 2 and 3."""
        from napari_tmidas.processing_functions.scipy_filters import (
            subdivide_labels_3layers,
        )

        label_image = np.zeros((9, 9), dtype=np.uint16)
        label_image[2:7, 2:7] = 1

        result = subdivide_labels_3layers(label_image)

        assert result.shape == label_image.shape
        unique_ids = set(np.unique(result))
        assert unique_ids.issubset({0, 1, 2, 3})
        assert {1, 2, 3}.issubset(unique_ids)

    @pytest.mark.skipif(
        not scipy_filters.SCIPY_AVAILABLE, reason="SciPy is required"
    )
    def test_subdivide_labels_3layers_dtype_promotion(self):
        """Layer ids that overflow uint8 force a wider output dtype."""
        from napari_tmidas.processing_functions.scipy_filters import (
            subdivide_labels_3layers,
        )

        label_image = np.zeros((9, 9), dtype=np.uint8)
        label_image[2:7, 2:7] = 200

        result = subdivide_labels_3layers(label_image)

        assert result.dtype in (np.uint32, np.uint64)
        assert result.max() == 200 + 2 * 200

    @pytest.mark.skipif(
        not scipy_filters.SCIPY_AVAILABLE, reason="SciPy is required"
    )
    def test_subdivide_labels_3layers_empty(self):
        """An all-background volume is returned unchanged."""
        from napari_tmidas.processing_functions.scipy_filters import (
            subdivide_labels_3layers,
        )

        label_image = np.zeros((5, 5, 5), dtype=np.uint16)

        result = subdivide_labels_3layers(label_image)

        np.testing.assert_array_equal(result, label_image)

    @pytest.mark.skipif(
        not scipy_filters.SCIPY_AVAILABLE, reason="SciPy is required"
    )
    def test_subdivide_labels_3layers_half_body(self):
        """Half-body mode works on an object with a flat cut face."""
        from napari_tmidas.processing_functions.scipy_filters import (
            subdivide_labels_3layers,
        )

        # Build a half-sphere (radius 5) whose flat face lies on the z=10
        # plane: the sphere is centred on that plane and only z >= 10 is
        # kept, so z=10 is a full disc rather than a single touching voxel.
        label_image = np.zeros((20, 20, 20), dtype=np.uint16)
        zz, yy, xx = np.ogrid[:20, :20, :20]
        inside_sphere = (zz - 10) ** 2 + (yy - 10) ** 2 + (xx - 10) ** 2 <= 25
        label_image[inside_sphere & (zz >= 10)] = 1

        # Test with half-body mode disabled (default)
        result_normal = subdivide_labels_3layers(
            label_image, is_half_body=False
        )
        assert result_normal.shape == label_image.shape
        unique_normal = np.unique(result_normal)
        assert len(unique_normal) > 1  # Should have background + layers

        # Test with half-body mode enabled (cut along Z-axis = 0)
        result_half_body = subdivide_labels_3layers(
            label_image, is_half_body=True, cut_axis=0
        )
        assert result_half_body.shape == label_image.shape
        unique_half = np.unique(result_half_body)
        assert len(unique_half) > 1  # Should have background + layers

        # Both modes should produce layered results. In normal mode the flat
        # face at z=10 may show fewer layers, since it is not treated as a cut.
        cut_surface_normal = result_normal[10, :, :]
        cut_surface_half = result_half_body[10, :, :]

        # Both should have some non-zero values at the cut surface
        assert np.sum(cut_surface_normal > 0) > 0
        assert np.sum(cut_surface_half > 0) > 0

        # Half-body mode should ideally show more variety, but the exact
        # behavior depends on the implementation details. Just verify both work.
        unique_layers_normal = len(
            np.unique(cut_surface_normal[cut_surface_normal > 0])
        )
        unique_layers_half = len(
            np.unique(cut_surface_half[cut_surface_half > 0])
        )
        assert unique_layers_normal >= 1
        assert unique_layers_half >= 1

        # Test invalid cut_axis
        with pytest.raises(ValueError):
            subdivide_labels_3layers(
                label_image, is_half_body=True, cut_axis=5
            )
@@ -0,0 +1,259 @@
1
+ # src/napari_tmidas/_tests/test_skimage_filters.py
2
+ import numpy as np
3
+
4
+ from napari_tmidas.processing_functions.skimage_filters import (
5
+ adaptive_threshold_bright,
6
+ equalize_histogram,
7
+ invert_image,
8
+ percentile_threshold,
9
+ rolling_ball_background,
10
+ simple_thresholding,
11
+ )
12
+
13
+
14
class TestSkimageFilters:
    """Tests for image inversion and manual thresholding in ``skimage_filters``."""

    def test_invert_image_basic(self):
        """Test basic image inversion functionality"""
        # Seeded generator keeps the input reproducible between runs
        rng = np.random.default_rng(0)
        image = rng.random((100, 100))

        # Test with default parameters
        result = invert_image(image)
        assert result.shape == image.shape
        assert result.dtype == image.dtype

    def test_invert_image_binary(self):
        """Test image inversion on binary image"""
        image = np.array([[0, 1], [1, 0]], dtype=np.uint8)

        result = invert_image(image)
        # skimage.util.invert inverts all bits, so 0->255, 1->254 for uint8
        expected = np.array([[255, 254], [254, 255]], dtype=np.uint8)
        np.testing.assert_array_equal(result, expected)

    def test_invert_image_3d(self):
        """Test image inversion on 3D image"""
        rng = np.random.default_rng(1)
        image = rng.random((20, 20, 20))

        result = invert_image(image)
        assert result.shape == image.shape
        assert result.dtype == image.dtype

    def test_simple_thresholding_returns_uint8(self):
        """Test that manual thresholding returns uint8 with value 255 for proper display"""
        image = np.array([[0, 100, 200], [50, 150, 255]], dtype=np.uint8)

        result = simple_thresholding(image, threshold=128)

        # Check dtype is uint8
        assert result.dtype == np.uint8

        # Check values are binary (0 or 255)
        assert set(np.unique(result)).issubset({0, 255})

        # Check correct thresholding
        expected = np.array([[0, 0, 255], [0, 255, 255]], dtype=np.uint8)
        np.testing.assert_array_equal(result, expected)

    def test_simple_thresholding_different_thresholds(self):
        """Test manual thresholding with different threshold values"""
        # Every uint8 value 0..255 exactly once
        image = np.arange(0, 256, dtype=np.uint8).reshape(16, 16)

        # Low threshold: most pixels lie above 50
        result_low = simple_thresholding(image, threshold=50)
        assert result_low.dtype == np.uint8
        assert np.sum(result_low == 255) > np.prod(result_low.shape) * 0.8

        # High threshold: most pixels lie below 200
        result_high = simple_thresholding(image, threshold=200)
        assert result_high.dtype == np.uint8
        assert np.sum(result_high == 255) < np.prod(result_high.shape) * 0.3
75
+
76
+
77
class TestBrightRegionExtraction:
    """Test suite for bright region extraction functions"""

    def test_percentile_threshold_original(self):
        """Test percentile thresholding with original values"""
        # Create image with gradient
        image = np.arange(0, 256, dtype=np.uint8).reshape(16, 16)

        result = percentile_threshold(
            image, percentile=90, output_type="original"
        )

        # Only top 10% should remain
        assert result.shape == image.shape
        assert np.sum(result > 0) < image.size * 0.15  # Allow some margin
        assert result.max() == image.max()  # Original max value preserved

    def test_percentile_threshold_binary(self):
        """Test percentile thresholding with binary output"""
        # Seeded generator keeps the input reproducible between runs
        rng = np.random.default_rng(0)
        image = rng.integers(0, 256, size=(50, 50), dtype=np.uint8)

        result = percentile_threshold(
            image, percentile=80, output_type="binary"
        )

        # Should be binary
        assert result.dtype == np.uint8
        assert set(np.unique(result)).issubset({0, 255})

    def test_rolling_ball_background_subtraction(self):
        """Test rolling ball background subtraction"""
        # Create image with uneven background and bright spot
        x, y = np.meshgrid(np.arange(100), np.arange(100))
        # The sinusoidal background dips to about -10 near one corner. Clip
        # before casting: converting negative floats to uint8 is undefined
        # and typically wraps around to bright values (~246).
        background = np.clip(
            50 + 30 * np.sin(x / 20) + 30 * np.sin(y / 20), 0, 255
        ).astype(np.uint8)
        image = background.copy()
        # Add bright feature; the background there is at most ~104, so the
        # in-place uint8 addition stays <= 255
        image[40:60, 40:60] += 150

        result = rolling_ball_background(image, radius=30)

        # Background should be reduced
        assert result.shape == image.shape
        # Center of bright spot should be brighter in result than in corners
        assert result[50, 50] > result[10, 10]

    def test_adaptive_threshold_bright(self):
        """Test adaptive thresholding with bright bias"""
        # Create image with varying brightness
        rng = np.random.default_rng(1)
        image = rng.integers(0, 256, size=(100, 100), dtype=np.uint8)

        result = adaptive_threshold_bright(image, block_size=35, offset=-10.0)

        # Should be binary
        assert result.dtype == np.uint8
        assert set(np.unique(result)).issubset({0, 255})
        assert result.shape == image.shape

    def test_adaptive_threshold_even_blocksize(self):
        """Test that even block size is handled correctly"""
        rng = np.random.default_rng(2)
        image = rng.integers(0, 256, size=(50, 50), dtype=np.uint8)

        # Should handle even block size by making it odd
        result = adaptive_threshold_bright(image, block_size=34, offset=0)

        assert result.shape == image.shape
        assert result.dtype == np.uint8
144
+
145
+
146
class TestCLAHE:
    """Test suite for CLAHE (Contrast Limited Adaptive Histogram Equalization)"""

    def test_clahe_basic(self):
        """Test basic CLAHE functionality"""
        # Create a dark image with weak bright features
        image = np.zeros((100, 100), dtype=np.float32)
        image[40:60, 40:60] = 0.1  # Weak bright region

        result = equalize_histogram(image)

        # Output should be same shape
        assert result.shape == image.shape
        # Output should be normalized to [0, 1] range
        assert result.min() >= 0
        assert result.max() <= 1
        # Contrast should be enhanced (std deviation should increase)
        assert result.std() > image.std()

    def test_clahe_dark_with_membranes(self):
        """Test CLAHE on dark images with weak bright membranes (the use case that failed)"""
        # Use a local generator instead of np.random.seed so this test does
        # not mutate the global RNG state seen by other tests.
        rng = np.random.default_rng(42)
        image = rng.normal(0.05, 0.01, (200, 200))  # Dark background
        image = np.clip(image, 0, 1)

        # Add weak membrane-like structures
        image[50:55, :] += 0.1  # Horizontal membrane
        image[:, 100:105] += 0.1  # Vertical membrane
        image = np.clip(image, 0, 1)

        result = equalize_histogram(image, clip_limit=0.01)

        # Should not produce black image
        assert result.max() > 0.1, "CLAHE should not produce near-black images"
        # Should enhance contrast
        assert result.std() > image.std()
        # Membranes should be more visible (higher values)
        membrane_region = result[50:55, :]
        background_region = result[10:20, 10:20]
        assert membrane_region.mean() > background_region.mean()

    def test_clahe_custom_kernel_size(self):
        """Test CLAHE with custom kernel size"""
        rng = np.random.default_rng(0)
        image = rng.random((256, 256))

        result = equalize_histogram(image, kernel_size=64)

        assert result.shape == image.shape
        assert result.min() >= 0
        assert result.max() <= 1

    def test_clahe_auto_kernel_size(self):
        """Test CLAHE with automatic kernel size calculation"""
        rng = np.random.default_rng(1)

        # Small image
        small_image = rng.random((128, 128))
        result_small = equalize_histogram(small_image, kernel_size=0)
        assert result_small.shape == small_image.shape

        # Large image
        large_image = rng.random((1024, 1024))
        result_large = equalize_histogram(large_image, kernel_size=0)
        assert result_large.shape == large_image.shape

    def test_clahe_different_clip_limits(self):
        """Test CLAHE with different clip limit values"""
        rng = np.random.default_rng(2)
        image = rng.random((100, 100)) * 0.2  # Dark image

        # Low clip limit (less contrast enhancement)
        result_low = equalize_histogram(image, clip_limit=0.005)

        # High clip limit (more contrast enhancement)
        result_high = equalize_histogram(image, clip_limit=0.05)

        # Both should enhance contrast compared to original
        assert result_low.std() > image.std()
        assert result_high.std() > image.std()
        # Higher clip limit typically gives more contrast (not guaranteed),
        # so allow some tolerance
        assert result_high.max() >= result_low.max() * 0.8

    def test_clahe_3d_image(self):
        """Test CLAHE on 3D image (should work on last 2 dimensions)"""
        # Create 3D image (e.g., time series or z-stack)
        rng = np.random.default_rng(3)
        image_3d = rng.random((10, 100, 100)) * 0.3

        result = equalize_histogram(image_3d)

        assert result.shape == image_3d.shape
        # Each slice should be enhanced independently
        assert result.std() > image_3d.std()

    def test_clahe_preserves_dtype(self):
        """Test that CLAHE preserves the original dtype"""
        rng = np.random.default_rng(4)

        # Test uint8
        img_uint8 = rng.integers(0, 256, (100, 100), dtype=np.uint8)
        result_uint8 = equalize_histogram(img_uint8)
        assert result_uint8.dtype == np.uint8
        assert result_uint8.max() <= 255
        assert result_uint8.min() >= 0

        # Test uint16
        img_uint16 = rng.integers(0, 65536, (100, 100), dtype=np.uint16)
        result_uint16 = equalize_histogram(img_uint16)
        assert result_uint16.dtype == np.uint16
        assert result_uint16.max() <= 65535

        # Test float32
        img_float32 = rng.random((100, 100)).astype(np.float32)
        result_float32 = equalize_histogram(img_float32)
        assert result_float32.dtype == np.float32
        assert result_float32.max() <= 1.0
        assert result_float32.min() >= 0.0
@@ -0,0 +1,217 @@
1
+ """Test for split_channels function with various image formats"""
2
+
3
+ import os
4
+ import sys
5
+
6
+ import numpy as np
7
+ import pytest
8
+
9
# Make ``napari_tmidas`` importable when this module is run directly:
# prepend the directory two levels above this test file (os.pardir is "..").
sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
11
+
12
+ from napari_tmidas.processing_functions.basic import split_channels
13
+
14
+
15
class TestSplitChannels:
    """Check ``split_channels`` on channel-first and channel-last layouts."""

    @staticmethod
    def _assert_split_shape(result, expected):
        """Assert the stacked output shape, then the shape of each channel."""
        assert (
            result.shape == expected
        ), f"Expected shape {expected}, got {result.shape}"
        per_channel = expected[1:]
        for channel in range(expected[0]):
            assert (
                result[channel].shape == per_channel
            ), f"Channel {channel} has incorrect shape {result[channel].shape}"

    def test_split_tcyx_python_format(self):
        """Split a TCYX stack (time, channel, Y, X) with the python layout."""
        # 5 timepoints, 3 channels, 100x100 pixels
        stack = np.random.rand(5, 3, 100, 100)

        out = split_channels(
            stack, num_channels=3, time_steps=5, output_format="python"
        )

        # One (5, 100, 100) series per channel, stacked on a leading axis
        self._assert_split_shape(out, (3, 5, 100, 100))

    def test_split_tcyx_fiji_format(self):
        """Split a TCYX stack with the Fiji layout."""
        # 5 timepoints, 3 channels, 100x100 pixels
        stack = np.random.rand(5, 3, 100, 100)

        out = split_channels(
            stack, num_channels=3, time_steps=5, output_format="fiji"
        )

        self._assert_split_shape(out, (3, 5, 100, 100))

    def test_split_yxc_image(self):
        """Split a plain RGB image stored as YXC."""
        # 100x100 pixels, 3 trailing channels
        rgb = np.random.rand(100, 100, 3)

        out = split_channels(
            rgb, num_channels=3, time_steps=0, output_format="python"
        )

        self._assert_split_shape(out, (3, 100, 100))

    def test_split_zyxc_image(self):
        """Split a colour Z-stack stored as ZYXC."""
        # 10 z-slices, 100x100 pixels, 3 trailing channels
        zstack = np.random.rand(10, 100, 100, 3)

        out = split_channels(
            zstack, num_channels=3, time_steps=0, output_format="python"
        )

        self._assert_split_shape(out, (3, 10, 100, 100))

    def test_split_tzyxc_image(self):
        """Split a time series of colour Z-stacks stored as TZYXC."""
        # 5 timepoints, 10 z-slices, 100x100 pixels, 3 trailing channels
        series = np.random.rand(5, 10, 100, 100, 3)

        out = split_channels(
            series, num_channels=3, time_steps=5, output_format="python"
        )

        self._assert_split_shape(out, (3, 5, 10, 100, 100))

    def test_split_channels_with_4_channels(self):
        """Split a four-channel (RGBA) YXC image."""
        rgba = np.random.rand(100, 100, 4)

        out = split_channels(
            rgba, num_channels=4, time_steps=0, output_format="python"
        )

        self._assert_split_shape(out, (4, 100, 100))

    def test_split_channels_verifies_data_integrity(self):
        """Each output channel holds exactly the data of its input channel."""
        fills = (1.0, 2.0, 3.0)
        # 2 timepoints, 3 channels; every channel gets its own constant value
        stack = np.zeros((2, 3, 10, 10))
        for channel, value in enumerate(fills):
            stack[:, channel, :, :] = value

        out = split_channels(
            stack, num_channels=3, time_steps=2, output_format="python"
        )

        assert out.shape == (3, 2, 10, 10)
        for channel, value in enumerate(fills):
            assert np.allclose(
                out[channel], value
            ), f"Channel {channel} data incorrect"

    def test_split_channels_auto_detect_mismatch(self):
        """A wrong ``num_channels`` is overridden by the actual channel count."""
        # 5 timepoints, 4 real channels, 100x100 pixels
        stack = np.random.rand(5, 4, 100, 100)

        # Deliberately claim 3 channels
        out = split_channels(
            stack, num_channels=3, time_steps=5, output_format="python"
        )

        expected = (4, 5, 100, 100)
        assert (
            out.shape == expected
        ), f"Expected shape {expected}, got {out.shape}"

    def test_split_channels_dimension_error(self):
        """A 2D input is rejected with a ValueError."""
        flat = np.random.rand(100, 100)

        with pytest.raises(ValueError, match="at least 3 dimensions"):
            split_channels(flat, num_channels=3, time_steps=0)
213
+
214
+
215
if __name__ == "__main__":
    # Run this module's tests directly. Propagate pytest's exit code so a
    # failing run is reported as a non-zero process status instead of 0.
    sys.exit(pytest.main([__file__, "-v"]))