napari-tmidas 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions released to a supported public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
@@ -1,8 +1,14 @@
  # processing_functions/basic.py
  """
- Basic image processing functions that don't require additional dependencies.
+ Basic image processing functions
  """
+ import concurrent.futures
+ import os
+ import traceback
+
+ import dask.array as da
  import numpy as np
+ import tifffile

  from napari_tmidas._registry import BatchProcessingRegistry

@@ -161,24 +167,47 @@ def max_z_projection_tzyx(image: np.ndarray) -> np.ndarray:


  @BatchProcessingRegistry.register(
-     name="Split Channels",
-     suffix="_split_channels",
+     name="Split Color Channels",
+     suffix="_split_color_channels",
      description="Splits the color channels of the image",
      parameters={
          "num_channels": {
-             "type": "integer",
+             "type": int,
              "default": 3,
+             "min": 2,
+             "max": 4,
              "description": "Number of color channels in the image",
-         }
+         },
+         "time_steps": {
+             "type": int,
+             "default": 0,
+             "min": 0,
+             "max": 1000,
+             "description": "Number of time steps (leave 0 if not a time series)",
+         },
+         "output_format": {
+             "type": str,
+             "default": "python",
+             "options": ["python", "fiji"],
+             "description": "Output dimension order: python (standard) or fiji (ImageJ/Fiji compatible)",
+         },
      },
  )
- def split_channels(image: np.ndarray, num_channels: int = 3) -> np.ndarray:
+ def split_channels(
+     image: np.ndarray,
+     num_channels: int = 3,
+     time_steps: int = 0,
+     output_format: str = "python",
+ ) -> np.ndarray:
      """
      Split the image into separate channels based on the specified number of channels.
+     Can handle various dimensional orderings including time series data.

      Args:
          image: Input image array (at least 3D: XYC or higher dimensions)
          num_channels: Number of channels in the image (default: 3)
+         time_steps: Number of time steps if time series (default: 0, meaning not a time series)
+         output_format: Dimension order format, either "python" (standard) or "fiji" (ImageJ compatible)

      Returns:
          Stacked array of channels with shape (num_channels, ...)
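Aside (not part of the diff): the updated registration switches the parameter "type" fields from strings to Python types and adds "min"/"max" bounds and an "options" list. Below is a minimal sketch of registering a hypothetical function with the same schema, assuming the decorator forwards parameter values as keyword arguments the way it does for split_channels; the function name and parameter are invented for illustration.

import numpy as np

from napari_tmidas._registry import BatchProcessingRegistry


@BatchProcessingRegistry.register(
    name="Add Offset",  # hypothetical example entry
    suffix="_offset",
    description="Adds a constant offset to all pixel values",
    parameters={
        "offset": {
            "type": int,
            "default": 10,
            "min": 0,
            "max": 255,
            "description": "Constant added to every pixel",
        },
    },
)
def add_offset(image: np.ndarray, offset: int = 10) -> np.ndarray:
    # Parameter values are assumed to arrive as keyword arguments.
    return image + offset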
@@ -190,30 +219,153 @@ def split_channels(image: np.ndarray, num_channels: int = 3) -> np.ndarray:
      )

      print(f"Image shape: {image.shape}")
-     num_channels = int(num_channels)
-     # Identify the channel axis
-     possible_axes = [
-         axis
-         for axis, dim_size in enumerate(image.shape)
-         if dim_size == num_channels
-     ]
-     # print(f"Possible axes: {possible_axes}")
-     if len(possible_axes) != 1:
-
+     is_timelapse = time_steps > 0
+     is_3d = (
+         image.ndim > 3
+     )  # More than 3 dimensions likely means 3D + channels or time series
+
+     # Find channel axis based on provided channel count
+     channel_axis = None
+     for axis, dim_size in enumerate(image.shape):
+         if dim_size == num_channels:
+             # Found a dimension matching the specified channel count
+             channel_axis = axis
+             # If we have multiple matching dimensions, prefer the one that's not likely spatial
+             if (
+                 axis < image.ndim - 2
+             ):  # Not one of the last two dimensions (likely spatial)
+                 break
+
+     # If channel axis is not found with exact match, look for other possibilities
+     if channel_axis is None:
+         # Try to infer channel axis using heuristics
+         ndim = image.ndim
+
+         # Check dimensions for a small value (1-16) that could be channels
+         for i, dim_size in enumerate(image.shape):
+             # Skip dimensions that are likely spatial (Y,X) - typically the last two
+             if i >= ndim - 2:
+                 continue
+             # Skip first dimension if this is a time series
+             if is_timelapse and i == 0:
+                 continue
+             # A dimension with size 1-16 is likely channels
+             if 1 <= dim_size <= 16:
+                 channel_axis = i
+                 break
+
+         # If still not found, check even the spatial dimensions (for RGB images)
+         if channel_axis is None and image.shape[-1] <= 16:
+             channel_axis = ndim - 1
+
+     if channel_axis is None:
          raise ValueError(
-             f"Could not uniquely identify a channel axis with {num_channels} channels. "
-             f"Found {len(possible_axes)} possible axes: {possible_axes}. "
-             f"Image shape: {image.shape}"
+             f"Could not identify a channel axis. Please check if the number of channels ({num_channels}) "
+             f"matches any dimension in your image shape {image.shape}"
          )

-     channel_axis = possible_axes[0]
      print(f"Channel axis identified: {channel_axis}")

-     # Split and process channels
+     # Generate dimensional understanding for better handling
+     # Create axes string to understand dimension ordering
+     axes = [""] * image.ndim
+
+     # Assign channel axis
+     axes[channel_axis] = "C"
+
+     # Assign time axis if present
+     if is_timelapse and 0 not in [
+         channel_axis
+     ]:  # If channel is not at position 0
+         axes[0] = "T"
+
+     # Assign remaining spatial dimensions
+     remaining_dims = [i for i in range(image.ndim) if axes[i] == ""]
+     spatial_axes = []
+     if is_3d and len(remaining_dims) > 2:
+         # We have Z dimension
+         spatial_axes.append("Z")
+
+     # Add Y and X
+     spatial_axes.extend(["Y", "X"])
+
+     # Assign remaining dimensions
+     for i, dim in enumerate(remaining_dims):
+         if i < len(spatial_axes):
+             axes[dim] = spatial_axes[i]
+         else:
+             axes[dim] = "A"  # Anonymous dimension
+
+     axes_str = "".join(axes)
+     print(f"Inferred dimension order: {axes_str}")
+
+     # Split along the channel axis
+     actual_channels = image.shape[channel_axis]
+     if actual_channels != num_channels:
+         print(
+             f"Warning: Specified {num_channels} channels but found {actual_channels} in the data. Using {actual_channels}."
+         )
+         num_channels = actual_channels
+
+     # Split channels
      channels = np.split(image, num_channels, axis=channel_axis)
-     # channels = [np.squeeze(ch, axis=channel_axis) for ch in channels]

-     return np.stack(channels, axis=0)
+     # Process output format
+     result_channels = []
+     for i, channel_img in enumerate(channels):
+         # Get original axes without channel
+         axes_without_channel = axes.copy()
+         del axes_without_channel[channel_axis]
+         axes_without_channel_str = "".join(axes_without_channel)
+
+         # For fiji format, reorganize dimensions to TZYX order
+         if output_format.lower() == "fiji":
+             # Map dimensions to positions
+             dim_indices = {
+                 dim: i for i, dim in enumerate(axes_without_channel_str)
+             }
+
+             # Build target order and transpose indices
+             target_order = ""
+             transpose_indices = []
+
+             # Add T if exists
+             if "T" in dim_indices:
+                 target_order += "T"
+                 transpose_indices.append(dim_indices["T"])
+
+             # Add Z if exists
+             if "Z" in dim_indices:
+                 target_order += "Z"
+                 transpose_indices.append(dim_indices["Z"])
+
+             # Add Y and X (should always exist)
+             if "Y" in dim_indices and "X" in dim_indices:
+                 target_order += "YX"
+                 transpose_indices.append(dim_indices["Y"])
+                 transpose_indices.append(dim_indices["X"])
+
+             # Only transpose if order is different and we have enough dimensions
+             if (
+                 axes_without_channel_str != target_order
+                 and len(transpose_indices) > 1
+                 and len(transpose_indices) == len(axes_without_channel)
+             ):
+                 print(
+                     f"Channel {i}: Transposing from {axes_without_channel_str} to {target_order}"
+                 )
+                 result_channels.append(
+                     np.transpose(channel_img, transpose_indices)
+                 )
+             else:
+                 # Keep as is
+                 result_channels.append(channel_img)
+         else:
+             # For python format, keep as is
+             result_channels.append(channel_img)
+
+     # Stack channels along a new first dimension
+     return np.stack(result_channels, axis=0)


  @BatchProcessingRegistry.register(
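Aside (not part of the diff): a usage sketch of the reworked split_channels, assuming the module is importable as napari_tmidas.processing_functions.basic and that the registry decorator returns the function unchanged; the array shape is hypothetical.

import numpy as np

from napari_tmidas.processing_functions.basic import split_channels

# Hypothetical 2-channel time series: (T, C, Z, Y, X) = (10, 2, 5, 64, 64)
stack = np.random.rand(10, 2, 5, 64, 64)

# The channel axis (size 2) is detected automatically; time_steps marks axis 0 as T.
channels = split_channels(stack, num_channels=2, time_steps=10)

# np.split keeps a singleton axis where C was, so the stacked result is
# (num_channels, T, 1, Z, Y, X) = (2, 10, 1, 5, 64, 64).
print(channels.shape)

# output_format="fiji" additionally transposes each channel toward TZYX order
# when the inferred order differs from it.
channels_fiji = split_channels(
    stack, num_channels=2, time_steps=10, output_format="fiji"
)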
@@ -287,3 +439,322 @@ def rgb_to_labels(
          label_image[mask] = label
      # Return the label image
      return label_image
+
+
+ @BatchProcessingRegistry.register(
+     name="Split TZYX into ZYX TIFs",
+     suffix="_split",
+     description="Splits a 4D TZYX image stack into separate 3D ZYX TIFs for each time point using parallel processing",
+     parameters={
+         "output_name_format": {
+             "type": str,
+             "default": "{basename}_t{timepoint:03d}",
+             "description": "Format for output filenames. Use {basename} and {timepoint} as placeholders",
+         },
+         "preserve_scale": {
+             "type": bool,
+             "default": True,
+             "description": "Preserve scale/resolution metadata when saving",
+         },
+         "use_compression": {
+             "type": bool,
+             "default": True,
+             "description": "Apply zlib compression to output files",
+         },
+         "num_workers": {
+             "type": int,
+             "default": 4,
+             "min": 1,
+             "max": 16,
+             "description": "Number of worker processes for parallel processing",
+         },
+     },
+ )
+ def split_tzyx_stack(
+     image: np.ndarray,
+     output_name_format: str = "{basename}_t{timepoint:03d}",
+     preserve_scale: bool = True,
+     use_compression: bool = True,
+     num_workers: int = 4,
+ ) -> np.ndarray:
+     """
+     Split a 4D TZYX stack into separate 3D ZYX TIF files using parallel processing.
+
+     This function takes a 4D TZYX image stack and saves each time point as
+     a separate 3D ZYX TIF file. Files are processed in parallel for better performance.
+     The original 4D stack is returned unchanged.
+
+     Parameters:
+     -----------
+     image : numpy.ndarray
+         Input 4D TZYX image stack
+     output_name_format : str
+         Format string for output filenames. Use {basename} and {timepoint} as placeholders.
+         Default: "{basename}_t{timepoint:03d}"
+     preserve_scale : bool
+         Whether to preserve scale/resolution metadata when saving
+     use_compression : bool
+         Whether to apply zlib compression to output files
+     num_workers : int
+         Number of worker processes for parallel file saving
+
+     Returns:
+     --------
+     numpy.ndarray
+         The original image (unchanged)
+     """
+     # Validate input dimensions
+     if image.ndim != 4:
+         print(
+             f"Warning: Expected 4D TZYX input, got {image.ndim}D. Returning original image."
+         )
+         return image
+
+     # Use dask array to optimize memory usage when processing slices
+     chunks = (1,) + image.shape[1:]  # Each timepoint is a chunk
+     dask_image = da.from_array(image, chunks=chunks)
+
+     # Store processing parameters for post-processing
+     split_tzyx_stack.dask_image = dask_image
+     split_tzyx_stack.output_name_format = output_name_format
+     split_tzyx_stack.preserve_scale = preserve_scale
+     split_tzyx_stack.use_compression = use_compression
+     split_tzyx_stack.num_workers = min(
+         num_workers, image.shape[0]
+     )  # Limit workers to number of timepoints
+
+     # Mark for post-processing with multiple output files
+     split_tzyx_stack.requires_post_processing = True
+     split_tzyx_stack.produces_multiple_files = True
+     # Tell the processing system to skip creating the original output file
+     split_tzyx_stack.skip_original_output = True
+
+     # Get dimensions for informational purposes
+     t_size, z_size, y_size, x_size = image.shape
+     print(f"TZYX stack dimensions: {image.shape}, dtype: {image.dtype}")
+     print(f"Will generate {t_size} separate ZYX files")
+     print(f"Parallelization: {split_tzyx_stack.num_workers} workers")
+
+     # The actual file saving will happen in the post-processing step
+     return image
+
+
+ # Monkey patch ProcessingWorker.process_file to handle parallel TZYX splitting
+ try:
+     # Import tifffile here to ensure it's available for the monkey patch
+     import tifffile
+
+     from napari_tmidas._file_selector import ProcessingWorker
+
+     # Define function to save a single timepoint
+     def save_timepoint(
+         t: int,
+         data: np.ndarray,
+         output_filepath: str,
+         resolution=None,
+         use_compression=True,
+     ) -> str:
+         """
+         Save a single timepoint to disk.
+
+         Parameters:
+         -----------
+         t : int
+             Timepoint index for logging
+         data : np.ndarray
+             3D ZYX data to save
+         output_filepath : str
+             Path to save the file
+         resolution : tuple, optional
+             Resolution metadata to preserve
+         use_compression : bool
+             Whether to use compression
+
+         Returns:
+         --------
+         str
+             Path to the saved file
+         """
+         try:
+             # Create output directory if it doesn't exist
+             os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
+
+             # Determine the appropriate compression parameter
+             # Note: tifffile uses 'compression', not 'compress'
+             compression_arg = "zlib" if use_compression else None
+
+             # Calculate approximate file size for BigTIFF decision
+             size_gb = (data.size * data.itemsize) / (1024**3)
+             use_bigtiff = size_gb > 4.0
+
+             # Save the file with proper parameters
+             tifffile.imwrite(
+                 output_filepath,
+                 data,
+                 resolution=resolution,
+                 compression=compression_arg,
+                 bigtiff=use_bigtiff,
+             )
+
+             print(f"✓ Saved timepoint {t} to {output_filepath}")
+             return output_filepath
+         except Exception as e:
+             print(f"✘ Error saving timepoint {t}: {str(e)}")
+             traceback.print_exc()
+             raise
+
+     # Store the original process_file function
+     original_process_file = ProcessingWorker.process_file
+
+     # Define the custom process_file function
+     def process_file_with_tzyx_splitting(self, filepath):
+         """Modified process_file function that handles parallel TZYX splitting."""
+         # First call the original function to get the initial result
+         result = original_process_file(self, filepath)
+
+         # Skip further processing if there's no result or no processed_file
+         if not isinstance(result, dict) or "processed_file" not in result:
+             return result
+
+         # Get the output path from the original processing
+         output_path = result["processed_file"]
+         processing_func = self.processing_func
+
+         # Check if our function has the required attributes for TZYX splitting
+         if (
+             hasattr(processing_func, "requires_post_processing")
+             and processing_func.requires_post_processing
+             and hasattr(processing_func, "dask_image")
+             and hasattr(processing_func, "produces_multiple_files")
+             and processing_func.produces_multiple_files
+         ):
+             try:
+                 # Get the Dask image and processing parameters
+                 dask_image = processing_func.dask_image
+                 output_name_format = processing_func.output_name_format
+                 preserve_scale = processing_func.preserve_scale
+                 use_compression = processing_func.use_compression
+                 num_workers = processing_func.num_workers
+
+                 # Extract base filename without extension
+                 basename = os.path.splitext(os.path.basename(output_path))[0]
+                 dirname = os.path.dirname(output_path)
+
+                 # Try to get scale info from original file if needed
+                 resolution = None
+                 if preserve_scale:
+                     try:
+                         with tifffile.TiffFile(filepath) as tif:
+                             if hasattr(tif, "pages") and tif.pages:
+                                 page = tif.pages[0]
+                                 if hasattr(page, "resolution"):
+                                     resolution = page.resolution
+                     except (OSError, AttributeError, KeyError) as e:
+
+                         print(
+                             f"Warning: Could not read original resolution: {e}"
+                         )
+
+                 # Get number of timepoints
+                 t_size = dask_image.shape[0]
+                 print(f"Processing {t_size} timepoints in parallel...")
+
+                 # Prepare output paths for each timepoint
+                 output_filepaths = []
+                 for t in range(t_size):
+                     # Format the output filename
+                     output_filename = output_name_format.format(
+                         basename=basename, timepoint=t
+                     )
+                     # Add extension
+                     output_filepath = os.path.join(
+                         dirname, f"{output_filename}.tif"
+                     )
+                     output_filepaths.append(output_filepath)
+
+                 # Process timepoints in parallel
+                 processed_files = []
+
+                 # Use ThreadPoolExecutor for parallel file saving
+                 with concurrent.futures.ThreadPoolExecutor(
+                     max_workers=num_workers
+                 ) as executor:
+                     # Submit tasks for each timepoint
+                     future_to_timepoint = {}
+                     for t in range(t_size):
+                         # Extract this timepoint's data using Dask
+                         timepoint_array = dask_image[t].compute()
+
+                         # Submit the task to save this timepoint
+                         future = executor.submit(
+                             save_timepoint,
+                             t,
+                             timepoint_array,
+                             output_filepaths[t],
+                             resolution,
+                             use_compression,
+                         )
+                         future_to_timepoint[future] = t
+
+                     total = len(future_to_timepoint)
+                     for completed, future in enumerate(
+                         concurrent.futures.as_completed(future_to_timepoint),
+                         start=1,
+                     ):
+                         t = future_to_timepoint[future]
+                         try:
+                             output_filepath = future.result()
+                             processed_files.append(output_filepath)
+                         except (OSError, concurrent.futures.TimeoutError) as e:
+                             print(f"Failed to save timepoint {t}: {e}")
+
+                         # Update progress
+                         if completed % 5 == 0 or completed == total:
+                             percent = int(completed * 100 / total)
+                             print(
+                                 f"Progress: {completed}/{total} timepoints ({percent}%)"
+                             )
+
+                 # Update the result with the list of processed files
+                 if processed_files:
+                     print(
+                         f"Successfully generated {len(processed_files)} ZYX files from TZYX stack"
+                     )
+                     result["processed_files"] = processed_files
+
+                     # Skip creating the original consolidated output file if requested
+                     if (
+                         hasattr(processing_func, "skip_original_output")
+                         and processing_func.skip_original_output
+                     ):
+                         # Remove the original file if it was already created
+                         if os.path.exists(output_path):
+                             try:
+                                 os.remove(output_path)
+                                 print(
+                                     f"Removed unnecessary consolidated file: {output_path}"
+                                 )
+                             except OSError as e:
+                                 print(
+                                     f"Warning: Could not remove consolidated file: {e}"
+                                 )
+
+                         # Remove the entry from the result to prevent its display
+                         if "processed_file" in result:
+                             del result["processed_file"]
+
+                 else:
+                     print("Warning: No ZYX files were successfully generated")
+
+             except (OSError, ValueError, RuntimeError) as e:
+
+                 traceback.print_exc()
+                 print(f"Error in TZYX splitting post-processing: {e}")
+
+         return result
+
+     # Apply the monkey patch
+     ProcessingWorker.process_file = process_file_with_tzyx_splitting
+
+ except (NameError, AttributeError) as e:
+     print(f"Warning: Could not apply TZYX splitting patch: {e}")