napari-tmidas 0.1.9__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,8 +1,14 @@
1
1
  # processing_functions/basic.py
2
2
  """
3
- Basic image processing functions that don't require additional dependencies.
3
+ Basic image processing functions
4
4
  """
5
+ import concurrent.futures
6
+ import os
7
+ import traceback
8
+
9
+ import dask.array as da
5
10
  import numpy as np
11
+ import tifffile
6
12
 
7
13
  from napari_tmidas._registry import BatchProcessingRegistry
8
14
 
@@ -101,24 +107,107 @@ def max_z_projection(image: np.ndarray) -> np.ndarray:
101
107
 
102
108
 
103
109
@BatchProcessingRegistry.register(
    name="Max Z Projection (TZYX)",
    suffix="_maxZ_tzyx",
    description="Maximum intensity projection along the Z-axis for TZYX data",
    parameters={},  # fully automatic, nothing to configure
)
def max_z_projection_tzyx(image: np.ndarray) -> np.ndarray:
    """
    Memory-efficient maximum intensity projection along Z for TZYX data.

    Each time point is projected independently, so only one TZ-slab is
    resident beyond the input at any moment.

    Parameters:
    -----------
    image : numpy.ndarray
        Input 4D image with TZYX dimensions

    Returns:
    --------
    numpy.ndarray
        3D image with TYX dimensions after max projection
    """
    # Reject anything that is not a 4D stack up front.
    if image.ndim != 4:
        raise ValueError(f"Expected 4D image (TZYX), got {image.ndim}D image")

    frame_count = image.shape[0]

    # Output drops the Z axis but keeps T, Y, X and the input dtype.
    projected = np.zeros((frame_count,) + image.shape[2:], dtype=image.dtype)

    # Project one time point at a time to keep peak memory low.
    for frame_idx in range(frame_count):
        # ndarray.max over axis 0 (Z) yields the same values as an
        # element-wise running np.maximum across all Z planes.
        projected[frame_idx] = image[frame_idx].max(axis=0)

    return projected
167
+
168
+
169
+ @BatchProcessingRegistry.register(
170
+ name="Split Color Channels",
171
+ suffix="_split_color_channels",
106
172
  description="Splits the color channels of the image",
107
173
  parameters={
108
174
  "num_channels": {
109
- "type": "integer",
175
+ "type": int,
110
176
  "default": 3,
177
+ "min": 2,
178
+ "max": 4,
111
179
  "description": "Number of color channels in the image",
112
- }
180
+ },
181
+ "time_steps": {
182
+ "type": int,
183
+ "default": 0,
184
+ "min": 0,
185
+ "max": 1000,
186
+ "description": "Number of time steps (leave 0 if not a time series)",
187
+ },
188
+ "output_format": {
189
+ "type": str,
190
+ "default": "python",
191
+ "options": ["python", "fiji"],
192
+ "description": "Output dimension order: python (standard) or fiji (ImageJ/Fiji compatible)",
193
+ },
113
194
  },
114
195
  )
115
- def split_channels(image: np.ndarray, num_channels: int = 3) -> np.ndarray:
196
+ def split_channels(
197
+ image: np.ndarray,
198
+ num_channels: int = 3,
199
+ time_steps: int = 0,
200
+ output_format: str = "python",
201
+ ) -> np.ndarray:
116
202
  """
117
203
  Split the image into separate channels based on the specified number of channels.
204
+ Can handle various dimensional orderings including time series data.
118
205
 
119
206
  Args:
120
207
  image: Input image array (at least 3D: XYC or higher dimensions)
121
208
  num_channels: Number of channels in the image (default: 3)
209
+ time_steps: Number of time steps if time series (default: 0, meaning not a time series)
210
+ output_format: Dimension order format, either "python" (standard) or "fiji" (ImageJ compatible)
122
211
 
123
212
  Returns:
124
213
  Stacked array of channels with shape (num_channels, ...)
@@ -130,30 +219,153 @@ def split_channels(image: np.ndarray, num_channels: int = 3) -> np.ndarray:
130
219
  )
131
220
 
132
221
  print(f"Image shape: {image.shape}")
133
- num_channels = int(num_channels)
134
- # Identify the channel axis
135
- possible_axes = [
136
- axis
137
- for axis, dim_size in enumerate(image.shape)
138
- if dim_size == num_channels
139
- ]
140
- # print(f"Possible axes: {possible_axes}")
141
- if len(possible_axes) != 1:
142
-
222
+ is_timelapse = time_steps > 0
223
+ is_3d = (
224
+ image.ndim > 3
225
+ ) # More than 3 dimensions likely means 3D + channels or time series
226
+
227
+ # Find channel axis based on provided channel count
228
+ channel_axis = None
229
+ for axis, dim_size in enumerate(image.shape):
230
+ if dim_size == num_channels:
231
+ # Found a dimension matching the specified channel count
232
+ channel_axis = axis
233
+ # If we have multiple matching dimensions, prefer the one that's not likely spatial
234
+ if (
235
+ axis < image.ndim - 2
236
+ ): # Not one of the last two dimensions (likely spatial)
237
+ break
238
+
239
+ # If channel axis is not found with exact match, look for other possibilities
240
+ if channel_axis is None:
241
+ # Try to infer channel axis using heuristics
242
+ ndim = image.ndim
243
+
244
+ # Check dimensions for a small value (1-16) that could be channels
245
+ for i, dim_size in enumerate(image.shape):
246
+ # Skip dimensions that are likely spatial (Y,X) - typically the last two
247
+ if i >= ndim - 2:
248
+ continue
249
+ # Skip first dimension if this is a time series
250
+ if is_timelapse and i == 0:
251
+ continue
252
+ # A dimension with size 1-16 is likely channels
253
+ if 1 <= dim_size <= 16:
254
+ channel_axis = i
255
+ break
256
+
257
+ # If still not found, check even the spatial dimensions (for RGB images)
258
+ if channel_axis is None and image.shape[-1] <= 16:
259
+ channel_axis = ndim - 1
260
+
261
+ if channel_axis is None:
143
262
  raise ValueError(
144
- f"Could not uniquely identify a channel axis with {num_channels} channels. "
145
- f"Found {len(possible_axes)} possible axes: {possible_axes}. "
146
- f"Image shape: {image.shape}"
263
+ f"Could not identify a channel axis. Please check if the number of channels ({num_channels}) "
264
+ f"matches any dimension in your image shape {image.shape}"
147
265
  )
148
266
 
149
- channel_axis = possible_axes[0]
150
267
  print(f"Channel axis identified: {channel_axis}")
151
268
 
152
- # Split and process channels
269
+ # Generate dimensional understanding for better handling
270
+ # Create axes string to understand dimension ordering
271
+ axes = [""] * image.ndim
272
+
273
+ # Assign channel axis
274
+ axes[channel_axis] = "C"
275
+
276
+ # Assign time axis if present
277
+ if is_timelapse and 0 not in [
278
+ channel_axis
279
+ ]: # If channel is not at position 0
280
+ axes[0] = "T"
281
+
282
+ # Assign remaining spatial dimensions
283
+ remaining_dims = [i for i in range(image.ndim) if axes[i] == ""]
284
+ spatial_axes = []
285
+ if is_3d and len(remaining_dims) > 2:
286
+ # We have Z dimension
287
+ spatial_axes.append("Z")
288
+
289
+ # Add Y and X
290
+ spatial_axes.extend(["Y", "X"])
291
+
292
+ # Assign remaining dimensions
293
+ for i, dim in enumerate(remaining_dims):
294
+ if i < len(spatial_axes):
295
+ axes[dim] = spatial_axes[i]
296
+ else:
297
+ axes[dim] = "A" # Anonymous dimension
298
+
299
+ axes_str = "".join(axes)
300
+ print(f"Inferred dimension order: {axes_str}")
301
+
302
+ # Split along the channel axis
303
+ actual_channels = image.shape[channel_axis]
304
+ if actual_channels != num_channels:
305
+ print(
306
+ f"Warning: Specified {num_channels} channels but found {actual_channels} in the data. Using {actual_channels}."
307
+ )
308
+ num_channels = actual_channels
309
+
310
+ # Split channels
153
311
  channels = np.split(image, num_channels, axis=channel_axis)
154
- # channels = [np.squeeze(ch, axis=channel_axis) for ch in channels]
155
312
 
156
- return np.stack(channels, axis=0)
313
+ # Process output format
314
+ result_channels = []
315
+ for i, channel_img in enumerate(channels):
316
+ # Get original axes without channel
317
+ axes_without_channel = axes.copy()
318
+ del axes_without_channel[channel_axis]
319
+ axes_without_channel_str = "".join(axes_without_channel)
320
+
321
+ # For fiji format, reorganize dimensions to TZYX order
322
+ if output_format.lower() == "fiji":
323
+ # Map dimensions to positions
324
+ dim_indices = {
325
+ dim: i for i, dim in enumerate(axes_without_channel_str)
326
+ }
327
+
328
+ # Build target order and transpose indices
329
+ target_order = ""
330
+ transpose_indices = []
331
+
332
+ # Add T if exists
333
+ if "T" in dim_indices:
334
+ target_order += "T"
335
+ transpose_indices.append(dim_indices["T"])
336
+
337
+ # Add Z if exists
338
+ if "Z" in dim_indices:
339
+ target_order += "Z"
340
+ transpose_indices.append(dim_indices["Z"])
341
+
342
+ # Add Y and X (should always exist)
343
+ if "Y" in dim_indices and "X" in dim_indices:
344
+ target_order += "YX"
345
+ transpose_indices.append(dim_indices["Y"])
346
+ transpose_indices.append(dim_indices["X"])
347
+
348
+ # Only transpose if order is different and we have enough dimensions
349
+ if (
350
+ axes_without_channel_str != target_order
351
+ and len(transpose_indices) > 1
352
+ and len(transpose_indices) == len(axes_without_channel)
353
+ ):
354
+ print(
355
+ f"Channel {i}: Transposing from {axes_without_channel_str} to {target_order}"
356
+ )
357
+ result_channels.append(
358
+ np.transpose(channel_img, transpose_indices)
359
+ )
360
+ else:
361
+ # Keep as is
362
+ result_channels.append(channel_img)
363
+ else:
364
+ # For python format, keep as is
365
+ result_channels.append(channel_img)
366
+
367
+ # Stack channels along a new first dimension
368
+ return np.stack(result_channels, axis=0)
157
369
 
158
370
 
159
371
  @BatchProcessingRegistry.register(
@@ -227,3 +439,322 @@ def rgb_to_labels(
227
439
  label_image[mask] = label
228
440
  # Return the label image
229
441
  return label_image
442
+
443
+
444
+ @BatchProcessingRegistry.register(
445
+ name="Split TZYX into ZYX TIFs",
446
+ suffix="_split",
447
+ description="Splits a 4D TZYX image stack into separate 3D ZYX TIFs for each time point using parallel processing",
448
+ parameters={
449
+ "output_name_format": {
450
+ "type": str,
451
+ "default": "{basename}_t{timepoint:03d}",
452
+ "description": "Format for output filenames. Use {basename} and {timepoint} as placeholders",
453
+ },
454
+ "preserve_scale": {
455
+ "type": bool,
456
+ "default": True,
457
+ "description": "Preserve scale/resolution metadata when saving",
458
+ },
459
+ "use_compression": {
460
+ "type": bool,
461
+ "default": True,
462
+ "description": "Apply zlib compression to output files",
463
+ },
464
+ "num_workers": {
465
+ "type": int,
466
+ "default": 4,
467
+ "min": 1,
468
+ "max": 16,
469
+ "description": "Number of worker processes for parallel processing",
470
+ },
471
+ },
472
+ )
473
+ def split_tzyx_stack(
474
+ image: np.ndarray,
475
+ output_name_format: str = "{basename}_t{timepoint:03d}",
476
+ preserve_scale: bool = True,
477
+ use_compression: bool = True,
478
+ num_workers: int = 4,
479
+ ) -> np.ndarray:
480
+ """
481
+ Split a 4D TZYX stack into separate 3D ZYX TIF files using parallel processing.
482
+
483
+ This function takes a 4D TZYX image stack and saves each time point as
484
+ a separate 3D ZYX TIF file. Files are processed in parallel for better performance.
485
+ The original 4D stack is returned unchanged.
486
+
487
+ Parameters:
488
+ -----------
489
+ image : numpy.ndarray
490
+ Input 4D TZYX image stack
491
+ output_name_format : str
492
+ Format string for output filenames. Use {basename} and {timepoint} as placeholders.
493
+ Default: "{basename}_t{timepoint:03d}"
494
+ preserve_scale : bool
495
+ Whether to preserve scale/resolution metadata when saving
496
+ use_compression : bool
497
+ Whether to apply zlib compression to output files
498
+ num_workers : int
499
+ Number of worker processes for parallel file saving
500
+
501
+ Returns:
502
+ --------
503
+ numpy.ndarray
504
+ The original image (unchanged)
505
+ """
506
+ # Validate input dimensions
507
+ if image.ndim != 4:
508
+ print(
509
+ f"Warning: Expected 4D TZYX input, got {image.ndim}D. Returning original image."
510
+ )
511
+ return image
512
+
513
+ # Use dask array to optimize memory usage when processing slices
514
+ chunks = (1,) + image.shape[1:] # Each timepoint is a chunk
515
+ dask_image = da.from_array(image, chunks=chunks)
516
+
517
+ # Store processing parameters for post-processing
518
+ split_tzyx_stack.dask_image = dask_image
519
+ split_tzyx_stack.output_name_format = output_name_format
520
+ split_tzyx_stack.preserve_scale = preserve_scale
521
+ split_tzyx_stack.use_compression = use_compression
522
+ split_tzyx_stack.num_workers = min(
523
+ num_workers, image.shape[0]
524
+ ) # Limit workers to number of timepoints
525
+
526
+ # Mark for post-processing with multiple output files
527
+ split_tzyx_stack.requires_post_processing = True
528
+ split_tzyx_stack.produces_multiple_files = True
529
+ # Tell the processing system to skip creating the original output file
530
+ split_tzyx_stack.skip_original_output = True
531
+
532
+ # Get dimensions for informational purposes
533
+ t_size, z_size, y_size, x_size = image.shape
534
+ print(f"TZYX stack dimensions: {image.shape}, dtype: {image.dtype}")
535
+ print(f"Will generate {t_size} separate ZYX files")
536
+ print(f"Parallelization: {split_tzyx_stack.num_workers} workers")
537
+
538
+ # The actual file saving will happen in the post-processing step
539
+ return image
540
+
541
+
542
+ # Monkey patch ProcessingWorker.process_file to handle parallel TZYX splitting
543
+ try:
544
+ # Import tifffile here to ensure it's available for the monkey patch
545
+ import tifffile
546
+
547
+ from napari_tmidas._file_selector import ProcessingWorker
548
+
549
+ # Define function to save a single timepoint
550
+ def save_timepoint(
551
+ t: int,
552
+ data: np.ndarray,
553
+ output_filepath: str,
554
+ resolution=None,
555
+ use_compression=True,
556
+ ) -> str:
557
+ """
558
+ Save a single timepoint to disk.
559
+
560
+ Parameters:
561
+ -----------
562
+ t : int
563
+ Timepoint index for logging
564
+ data : np.ndarray
565
+ 3D ZYX data to save
566
+ output_filepath : str
567
+ Path to save the file
568
+ resolution : tuple, optional
569
+ Resolution metadata to preserve
570
+ use_compression : bool
571
+ Whether to use compression
572
+
573
+ Returns:
574
+ --------
575
+ str
576
+ Path to the saved file
577
+ """
578
+ try:
579
+ # Create output directory if it doesn't exist
580
+ os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
581
+
582
+ # Determine the appropriate compression parameter
583
+ # Note: tifffile uses 'compression', not 'compress'
584
+ compression_arg = "zlib" if use_compression else None
585
+
586
+ # Calculate approximate file size for BigTIFF decision
587
+ size_gb = (data.size * data.itemsize) / (1024**3)
588
+ use_bigtiff = size_gb > 4.0
589
+
590
+ # Save the file with proper parameters
591
+ tifffile.imwrite(
592
+ output_filepath,
593
+ data,
594
+ resolution=resolution,
595
+ compression=compression_arg,
596
+ bigtiff=use_bigtiff,
597
+ )
598
+
599
+ print(f"✓ Saved timepoint {t} to {output_filepath}")
600
+ return output_filepath
601
+ except Exception as e:
602
+ print(f"✘ Error saving timepoint {t}: {str(e)}")
603
+ traceback.print_exc()
604
+ raise
605
+
606
+ # Store the original process_file function
607
+ original_process_file = ProcessingWorker.process_file
608
+
609
+ # Define the custom process_file function
610
+ def process_file_with_tzyx_splitting(self, filepath):
611
+ """Modified process_file function that handles parallel TZYX splitting."""
612
+ # First call the original function to get the initial result
613
+ result = original_process_file(self, filepath)
614
+
615
+ # Skip further processing if there's no result or no processed_file
616
+ if not isinstance(result, dict) or "processed_file" not in result:
617
+ return result
618
+
619
+ # Get the output path from the original processing
620
+ output_path = result["processed_file"]
621
+ processing_func = self.processing_func
622
+
623
+ # Check if our function has the required attributes for TZYX splitting
624
+ if (
625
+ hasattr(processing_func, "requires_post_processing")
626
+ and processing_func.requires_post_processing
627
+ and hasattr(processing_func, "dask_image")
628
+ and hasattr(processing_func, "produces_multiple_files")
629
+ and processing_func.produces_multiple_files
630
+ ):
631
+ try:
632
+ # Get the Dask image and processing parameters
633
+ dask_image = processing_func.dask_image
634
+ output_name_format = processing_func.output_name_format
635
+ preserve_scale = processing_func.preserve_scale
636
+ use_compression = processing_func.use_compression
637
+ num_workers = processing_func.num_workers
638
+
639
+ # Extract base filename without extension
640
+ basename = os.path.splitext(os.path.basename(output_path))[0]
641
+ dirname = os.path.dirname(output_path)
642
+
643
+ # Try to get scale info from original file if needed
644
+ resolution = None
645
+ if preserve_scale:
646
+ try:
647
+ with tifffile.TiffFile(filepath) as tif:
648
+ if hasattr(tif, "pages") and tif.pages:
649
+ page = tif.pages[0]
650
+ if hasattr(page, "resolution"):
651
+ resolution = page.resolution
652
+ except (OSError, AttributeError, KeyError) as e:
653
+
654
+ print(
655
+ f"Warning: Could not read original resolution: {e}"
656
+ )
657
+
658
+ # Get number of timepoints
659
+ t_size = dask_image.shape[0]
660
+ print(f"Processing {t_size} timepoints in parallel...")
661
+
662
+ # Prepare output paths for each timepoint
663
+ output_filepaths = []
664
+ for t in range(t_size):
665
+ # Format the output filename
666
+ output_filename = output_name_format.format(
667
+ basename=basename, timepoint=t
668
+ )
669
+ # Add extension
670
+ output_filepath = os.path.join(
671
+ dirname, f"{output_filename}.tif"
672
+ )
673
+ output_filepaths.append(output_filepath)
674
+
675
+ # Process timepoints in parallel
676
+ processed_files = []
677
+
678
+ # Use ThreadPoolExecutor for parallel file saving
679
+ with concurrent.futures.ThreadPoolExecutor(
680
+ max_workers=num_workers
681
+ ) as executor:
682
+ # Submit tasks for each timepoint
683
+ future_to_timepoint = {}
684
+ for t in range(t_size):
685
+ # Extract this timepoint's data using Dask
686
+ timepoint_array = dask_image[t].compute()
687
+
688
+ # Submit the task to save this timepoint
689
+ future = executor.submit(
690
+ save_timepoint,
691
+ t,
692
+ timepoint_array,
693
+ output_filepaths[t],
694
+ resolution,
695
+ use_compression,
696
+ )
697
+ future_to_timepoint[future] = t
698
+
699
+ total = len(future_to_timepoint)
700
+ for completed, future in enumerate(
701
+ concurrent.futures.as_completed(future_to_timepoint),
702
+ start=1,
703
+ ):
704
+ t = future_to_timepoint[future]
705
+ try:
706
+ output_filepath = future.result()
707
+ processed_files.append(output_filepath)
708
+ except (OSError, concurrent.futures.TimeoutError) as e:
709
+ print(f"Failed to save timepoint {t}: {e}")
710
+
711
+ # Update progress
712
+ if completed % 5 == 0 or completed == total:
713
+ percent = int(completed * 100 / total)
714
+ print(
715
+ f"Progress: {completed}/{total} timepoints ({percent}%)"
716
+ )
717
+
718
+ # Update the result with the list of processed files
719
+ if processed_files:
720
+ print(
721
+ f"Successfully generated {len(processed_files)} ZYX files from TZYX stack"
722
+ )
723
+ result["processed_files"] = processed_files
724
+
725
+ # Skip creating the original consolidated output file if requested
726
+ if (
727
+ hasattr(processing_func, "skip_original_output")
728
+ and processing_func.skip_original_output
729
+ ):
730
+ # Remove the original file if it was already created
731
+ if os.path.exists(output_path):
732
+ try:
733
+ os.remove(output_path)
734
+ print(
735
+ f"Removed unnecessary consolidated file: {output_path}"
736
+ )
737
+ except OSError as e:
738
+ print(
739
+ f"Warning: Could not remove consolidated file: {e}"
740
+ )
741
+
742
+ # Remove the entry from the result to prevent its display
743
+ if "processed_file" in result:
744
+ del result["processed_file"]
745
+
746
+ else:
747
+ print("Warning: No ZYX files were successfully generated")
748
+
749
+ except (OSError, ValueError, RuntimeError) as e:
750
+
751
+ traceback.print_exc()
752
+ print(f"Error in TZYX splitting post-processing: {e}")
753
+
754
+ return result
755
+
756
+ # Apply the monkey patch
757
+ ProcessingWorker.process_file = process_file_with_tzyx_splitting
758
+
759
+ except (NameError, AttributeError) as e:
760
+ print(f"Warning: Could not apply TZYX splitting patch: {e}")