napari-tmidas 0.1.7.1__py3-none-any.whl → 0.1.8.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -29,6 +29,7 @@ from qtpy.QtWidgets import (
29
29
  QWidget,
30
30
  )
31
31
  from skimage.io import imread
32
+ from skimage.transform import resize # Added import for resize function
32
33
  from tifffile import imwrite
33
34
 
34
35
 
@@ -48,6 +49,7 @@ class BatchCropAnything:
48
49
  self.original_image = None
49
50
  self.segmentation_result = None
50
51
  self.current_image_for_segmentation = None
52
+ self.current_scale_factor = 1.0 # Added scale factor tracking
51
53
 
52
54
  # UI references
53
55
  self.image_layer = None
@@ -356,10 +358,41 @@ class BatchCropAnything:
356
358
  # Convert back to uint8
357
359
  image_gamma = (image_gamma * 255).astype(np.uint8)
358
360
 
361
+ # Check if the image is very large and needs downscaling
362
+ orig_shape = image_gamma.shape[:2] # (height, width)
363
+
364
+ # Calculate image size in megapixels
365
+ image_mp = (orig_shape[0] * orig_shape[1]) / 1e6
366
+
367
+ # If image is larger than 2 megapixels, downscale it
368
+ max_mp = 2.0 # Maximum image size in megapixels
369
+ scale_factor = 1.0
370
+
371
+ if image_mp > max_mp:
372
+ scale_factor = np.sqrt(max_mp / image_mp)
373
+ new_height = int(orig_shape[0] * scale_factor)
374
+ new_width = int(orig_shape[1] * scale_factor)
375
+
376
+ self.viewer.status = f"Downscaling image from {orig_shape} to {(new_height, new_width)} for processing (scale: {scale_factor:.2f})"
377
+
378
+ # Resize the image for processing
379
+ image_gamma_resized = resize(
380
+ image_gamma,
381
+ (new_height, new_width),
382
+ anti_aliasing=True,
383
+ preserve_range=True,
384
+ ).astype(np.uint8)
385
+
386
+ # Store scale factor for later use
387
+ self.current_scale_factor = scale_factor
388
+ else:
389
+ image_gamma_resized = image_gamma
390
+ self.current_scale_factor = 1.0
391
+
359
392
  self.viewer.status = f"Generating segmentation with sensitivity {self.sensitivity} (gamma={gamma:.2f})..."
360
393
 
361
- # Generate masks with gamma-corrected image
362
- masks = self.mask_generator.generate(image_gamma)
394
+ # Generate masks with gamma-corrected and potentially resized image
395
+ masks = self.mask_generator.generate(image_gamma_resized)
363
396
  self.viewer.status = f"Generated {len(masks)} masks"
364
397
 
365
398
  if not masks:
@@ -390,9 +423,16 @@ class BatchCropAnything:
390
423
  return
391
424
 
392
425
  # Process segmentation masks
393
- self._process_segmentation_masks(
394
- masks, self.current_image_for_segmentation.shape[:2]
395
- )
426
+ # If image was downscaled, we need to ensure masks are upscaled correctly
427
+ if self.current_scale_factor < 1.0:
428
+ # Upscale the segmentation masks to match the original image dimensions
429
+ self._process_segmentation_masks_with_scaling(
430
+ masks, self.current_image_for_segmentation.shape[:2]
431
+ )
432
+ else:
433
+ self._process_segmentation_masks(
434
+ masks, self.current_image_for_segmentation.shape[:2]
435
+ )
396
436
 
397
437
  # Clear selected labels since segmentation has changed
398
438
  self.selected_labels = set()
@@ -475,6 +515,98 @@ class BatchCropAnything:
475
515
  # image_name = os.path.basename(self.images[self.current_index])
476
516
  self.viewer.status = f"Loaded image {self.current_index + 1}/{len(self.images)} - Found {len(masks)} segments"
477
517
 
518
+ # New method for handling scaled segmentation masks
519
def _process_segmentation_masks_with_scaling(self, masks, original_shape):
    """Build a full-resolution label layer from masks predicted on a downscaled image.

    Each mask is painted into a label image at the masks' own resolution
    (label IDs start at 1; where masks overlap, the later mask wins). The
    label image is then upscaled to ``original_shape`` with nearest-neighbour
    interpolation and shown in the viewer, replacing any existing Labels
    layer whose name contains "Segmentation".

    Parameters
    ----------
    masks : list of dict
        Mask records. Each must hold an array under ``"segmentation"``
        (assumed to be a 2-D boolean mask -- TODO confirm against the mask
        generator) and may hold a ``"stability_score"``.
    original_shape : tuple of int
        ``(height, width)`` the label image is upscaled to.

    Returns
    -------
    None
        Results are stored on ``self.label_info``,
        ``self.segmentation_result`` and ``self.label_layer``. Nothing is
        changed when ``masks`` is empty.
    """
    if not masks:
        return

    mask_shape = masks[0]["segmentation"].shape

    # Map low-resolution measurements back with the ACTUAL size ratio per
    # axis. The downscaled size was integer-truncated, so the nominal
    # self.current_scale_factor is only approximately the ratio that the
    # upscaling below really applies.
    y_ratio = original_shape[0] / mask_shape[0]
    x_ratio = original_shape[1] / mask_shape[1]
    area_ratio = y_ratio * x_ratio

    # Label image at the resolution the masks were predicted at
    downscaled_labels = np.zeros(mask_shape, dtype=np.uint32)
    label_info = {}

    for label_id, mask_data in enumerate(masks, start=1):
        mask = mask_data["segmentation"]
        downscaled_labels[mask] = label_id

        y_indices, x_indices = np.nonzero(mask)
        if y_indices.size:
            center_y = y_indices.mean() * y_ratio
            center_x = x_indices.mean() * x_ratio
        else:
            # An empty mask has no centroid; keep the original's 0 placeholder
            center_y = center_x = 0

        label_info[label_id] = {
            # Approximate area in original-resolution pixels
            "area": y_indices.size * area_ratio,
            "center_y": center_y,
            "center_x": center_x,
            "score": mask_data.get("stability_score", 0),
        }

    # order=0 (nearest neighbour) without anti-aliasing keeps label IDs
    # intact instead of blending neighbouring IDs into meaningless values
    upscaled_labels = resize(
        downscaled_labels,
        original_shape,
        order=0,
        preserve_range=True,
        anti_aliasing=False,
    ).astype(np.uint32)

    # Keep label info ordered by area, largest first
    self.label_info = dict(
        sorted(
            label_info.items(),
            key=lambda item: item[1]["area"],
            reverse=True,
        )
    )

    self.segmentation_result = upscaled_labels

    # Drop any previous segmentation layer before adding the new one
    for layer in list(self.viewer.layers):
        if isinstance(layer, Labels) and "Segmentation" in layer.name:
            self.viewer.layers.remove(layer)

    self.label_layer = self.viewer.add_labels(
        upscaled_labels,
        name=f"Segmentation ({os.path.basename(self.images[self.current_index])})",
        opacity=0.7,
    )

    # Make the label layer active by default
    self.viewer.layers.selection.active = self.label_layer

    # Clear whatever drag callbacks the layer already carries so the
    # selection handler is the only one attached
    if self.label_layer is not None and hasattr(
        self.label_layer, "mouse_drag_callbacks"
    ):
        for callback in list(self.label_layer.mouse_drag_callbacks):
            self.label_layer.mouse_drag_callbacks.remove(callback)

    # Connect mouse click event to label selection
    self.label_layer.mouse_drag_callbacks.append(self._on_label_clicked)

    self.viewer.status = f"Loaded image {self.current_index + 1}/{len(self.images)} - Found {len(masks)} segments"
609
+
478
610
  # --------------------------------------------------
479
611
  # Label Selection and UI Elements
480
612
  # --------------------------------------------------
@@ -1105,19 +1105,21 @@ class ConversionWorker(QThread):
1105
1105
  )
1106
1106
  file_size_GB = estimated_size_bytes / (1024**3)
1107
1107
 
1108
- # If file is very large (>4GB), force zarr format regardless of setting
1108
+ # Determine format
1109
1109
  use_zarr = self.use_zarr
1110
- if file_size_GB > 4:
1111
- use_zarr = True
1112
- if not self.use_zarr:
1113
- print(
1114
- f"File size ({file_size_GB:.2f}GB) exceeds 4GB limit for TIF, automatically using ZARR format"
1115
- )
1116
- self.file_done.emit(
1117
- filepath,
1118
- True,
1119
- f"File size ({file_size_GB:.2f}GB) exceeds 4GB, using ZARR format",
1120
- )
1110
+ # If file is very large (>4GB) and the user kept TIF (ZARR not selected),
1111
+ # stay on TIF and report that it will be written as BigTIFF
1112
+ if file_size_GB > 4 and not self.use_zarr:
1113
+ # Recommend ZARR format but respect user's choice by still allowing TIF
1114
+ print(
1115
+ f"File size ({file_size_GB:.2f}GB) exceeds 4GB, ZARR format is recommended but using TIF with BigTIFF format"
1116
+ )
1117
+ self.file_done.emit(
1118
+ filepath,
1119
+ True,
1120
+ f"File size ({file_size_GB:.2f}GB) exceeds 4GB, using TIF with BigTIFF format",
1121
+ )
1122
+
1121
1123
  # Set up the output path
1122
1124
  if use_zarr:
1123
1125
  output_path = os.path.join(
@@ -1171,12 +1173,22 @@ class ConversionWorker(QThread):
1171
1173
  def _save_tif(
1172
1174
  self, image_data: np.ndarray, output_path: str, metadata: dict = None
1173
1175
  ):
1174
- """Enhanced TIF saving with proper dimension handling"""
1176
+ """Enhanced TIF saving with proper dimension handling and BigTIFF support"""
1175
1177
  import tifffile
1176
1178
 
1177
1179
  print(f"Saving TIF file: {output_path}")
1178
1180
  print(f"Image data shape: {image_data.shape}")
1179
1181
 
1182
+ # Check if this is a large file that needs BigTIFF
1183
+ estimated_size_bytes = np.prod(image_data.shape) * image_data.itemsize
1184
+ file_size_GB = estimated_size_bytes / (1024**3)
1185
+ use_bigtiff = file_size_GB > 4
1186
+
1187
+ if use_bigtiff:
1188
+ print(
1189
+ f"File size ({file_size_GB:.2f}GB) exceeds 4GB, using BigTIFF format"
1190
+ )
1191
+
1180
1192
  if metadata:
1181
1193
  print(f"Metadata keys: {list(metadata.keys())}")
1182
1194
  if "axes" in metadata:
@@ -1198,7 +1210,12 @@ class ConversionWorker(QThread):
1198
1210
  # Basic save if no metadata
1199
1211
  if metadata is None:
1200
1212
  print("No metadata provided, using basic save")
1201
- tifffile.imwrite(output_path, image_data, compression="zstd")
1213
+ tifffile.imwrite(
1214
+ output_path,
1215
+ image_data,
1216
+ compression="zlib",
1217
+ bigtiff=use_bigtiff,
1218
+ )
1202
1219
  return
1203
1220
 
1204
1221
  # Get image dimensions and axis order
@@ -1261,7 +1278,10 @@ class ConversionWorker(QThread):
1261
1278
  print(f"Error reordering dimensions: {e}")
1262
1279
  # Fall back to simple save without reordering
1263
1280
  tifffile.imwrite(
1264
- output_path, image_data, compression="zstd"
1281
+ output_path,
1282
+ image_data,
1283
+ compression="zlib",
1284
+ bigtiff=use_bigtiff,
1265
1285
  )
1266
1286
  return
1267
1287
 
@@ -1287,7 +1307,8 @@ class ConversionWorker(QThread):
1287
1307
  output_path,
1288
1308
  image_data,
1289
1309
  resolution=resolution,
1290
- compression="zstd",
1310
+ compression="zlib",
1311
+ bigtiff=use_bigtiff,
1291
1312
  )
1292
1313
  else:
1293
1314
  # Hyperstack case
@@ -1306,14 +1327,15 @@ class ConversionWorker(QThread):
1306
1327
  imagej=True,
1307
1328
  resolution=resolution,
1308
1329
  metadata=imagej_metadata,
1309
- compression="zstd",
1330
+ compression="zlib",
1331
+ bigtiff=use_bigtiff,
1310
1332
  )
1311
1333
 
1312
1334
  print(f"Successfully saved TIF file: {output_path}")
1313
1335
  except (ValueError, FileNotFoundError) as e:
1314
1336
  print(f"Error saving TIF file: {e}")
1315
1337
  # Try simple save as fallback
1316
- tifffile.imwrite(output_path, image_data)
1338
+ tifffile.imwrite(output_path, image_data, bigtiff=use_bigtiff)
1317
1339
 
1318
1340
  def _save_zarr(
1319
1341
  self, image_data: np.ndarray, output_path: str, metadata: dict = None
@@ -603,15 +603,29 @@ class ProcessingWorker(QThread):
603
603
  self.processing_finished.emit()
604
604
 
605
605
  def process_file(self, filepath):
606
- """Process a single file"""
606
+ """Process a single file with support for large TIFF files and removal of all singleton dimensions"""
607
607
  try:
608
608
  # Load the image
609
609
  image = imread(filepath)
610
610
  image_dtype = image.dtype
611
611
 
612
+ print(f"Original image shape: {image.shape}, dtype: {image_dtype}")
613
+
612
614
  # Apply processing with parameters
613
615
  processed_image = self.processing_func(image, **self.param_values)
614
616
 
617
+ print(
618
+ f"Processed image shape before removing singletons: {processed_image.shape}, dtype: {processed_image.dtype}"
619
+ )
620
+
621
+ # Remove ALL singleton dimensions from the processed image
622
+ # This will keep only dimensions with size > 1
623
+ processed_image = np.squeeze(processed_image)
624
+
625
+ print(
626
+ f"Processed image shape after removing singletons: {processed_image.shape}"
627
+ )
628
+
615
629
  # Generate new filename base
616
630
  filename = os.path.basename(filepath)
617
631
  name, ext = os.path.splitext(filename)
@@ -619,33 +633,91 @@ class ProcessingWorker(QThread):
619
633
  name.replace(self.input_suffix, "") + self.output_suffix
620
634
  )
621
635
 
622
- # Check if the processed image is a stacked array
623
- if processed_image.ndim > image.ndim:
636
+ # Check if the first dimension should be treated as channels.
637
+ # After squeezing, the result qualifies when it has at least as many
638
+ # dimensions as the original image; its first axis must also be <= 10 long
639
+ is_multi_channel = (processed_image.ndim > image.ndim - 1) or (
640
+ processed_image.ndim == image.ndim
641
+ and processed_image.shape[0] <= 10
642
+ )
643
+
644
+ if (
645
+ is_multi_channel and processed_image.shape[0] <= 10
646
+ ): # Reasonable number of channels
624
647
  # Save each channel as a separate image
625
648
  processed_files = []
626
- for i in range(processed_image.shape[0]):
649
+
650
+ num_channels = processed_image.shape[0]
651
+ print(
652
+ f"Treating first dimension as channels. Saving {num_channels} separate channel files"
653
+ )
654
+
655
+ for i in range(num_channels):
627
656
  channel_filename = f"{new_filename_base}_channel_{i}{ext}"
628
657
  channel_filepath = os.path.join(
629
658
  self.output_folder, channel_filename
630
659
  )
631
660
 
661
+ # Extract channel data and remove any remaining singleton dimensions
662
+ channel_image = np.squeeze(processed_image[i])
663
+
664
+ print(f"Channel {i} shape: {channel_image.shape}")
665
+
666
+ # Calculate approx file size in GB
667
+ size_gb = (
668
+ channel_image.size * channel_image.itemsize / (1024**3)
669
+ )
670
+ print(f"Estimated file size: {size_gb:.2f} GB")
671
+
672
+ # Check data range
673
+ data_min = (
674
+ np.min(channel_image) if channel_image.size > 0 else 0
675
+ )
676
+ data_max = (
677
+ np.max(channel_image) if channel_image.size > 0 else 0
678
+ )
679
+ print(f"Channel {i} data range: {data_min} to {data_max}")
680
+
681
+ # For very large files, we need to use BigTIFF format
682
+ use_bigtiff = (
683
+ size_gb > 2.0
684
+ ) # Use BigTIFF for files over 2GB
685
+
632
686
  if (
633
687
  "labels" in channel_filename
634
688
  or "semantic" in channel_filename
635
689
  ):
690
+ # Choose appropriate integer type based on data range
691
+ if data_max <= 255:
692
+ save_dtype = np.uint8
693
+ elif data_max <= 65535:
694
+ save_dtype = np.uint16
695
+ else:
696
+ save_dtype = np.uint32
697
+
698
+ print(
699
+ f"Label image detected, saving as {save_dtype.__name__} with bigtiff={use_bigtiff}"
700
+ )
636
701
  tifffile.imwrite(
637
702
  channel_filepath,
638
- processed_image[i].astype(np.uint32),
703
+ channel_image.astype(save_dtype),
639
704
  compression="zlib",
705
+ bigtiff=use_bigtiff,
640
706
  )
641
707
  else:
642
- # First remove singletons
643
- channel_image = np.squeeze(processed_image[i])
708
+ # Handle large images with bigtiff format
709
+ print(
710
+ f"Regular image channel, saving with dtype {image_dtype} and bigtiff={use_bigtiff}"
711
+ )
712
+
713
+ # Save with original dtype and bigtiff format if needed
644
714
  tifffile.imwrite(
645
715
  channel_filepath,
646
716
  channel_image.astype(image_dtype),
647
717
  compression="zlib",
718
+ bigtiff=use_bigtiff,
648
719
  )
720
+
649
721
  processed_files.append(channel_filepath)
650
722
 
651
723
  # Return processing info
@@ -654,25 +726,61 @@ class ProcessingWorker(QThread):
654
726
  "processed_files": processed_files,
655
727
  }
656
728
  else:
657
- # Save as a single image (original behavior)
729
+ # Save as a single image
658
730
  new_filepath = os.path.join(
659
731
  self.output_folder, new_filename_base + ext
660
732
  )
661
733
 
734
+ print(f"Single output image shape: {processed_image.shape}")
735
+
736
+ # Calculate approx file size in GB
737
+ size_gb = (
738
+ processed_image.size * processed_image.itemsize / (1024**3)
739
+ )
740
+ print(f"Estimated file size: {size_gb:.2f} GB")
741
+
742
+ # For very large files, we need to use BigTIFF format
743
+ use_bigtiff = size_gb > 2.0 # Use BigTIFF for files over 2GB
744
+
745
+ # Check data range
746
+ data_min = (
747
+ np.min(processed_image) if processed_image.size > 0 else 0
748
+ )
749
+ data_max = (
750
+ np.max(processed_image) if processed_image.size > 0 else 0
751
+ )
752
+ print(f"Data range: {data_min} to {data_max}")
753
+
662
754
  if (
663
755
  "labels" in new_filename_base
664
756
  or "semantic" in new_filename_base
665
757
  ):
758
+ # Choose appropriate integer type based on data range
759
+ if data_max <= 255:
760
+ save_dtype = np.uint8
761
+ elif data_max <= 65535:
762
+ save_dtype = np.uint16
763
+ else:
764
+ save_dtype = np.uint32
765
+
766
+ print(
767
+ f"Saving label image as {save_dtype.__name__} with bigtiff={use_bigtiff}"
768
+ )
666
769
  tifffile.imwrite(
667
770
  new_filepath,
668
- processed_image.astype(np.uint32),
771
+ processed_image.astype(save_dtype),
669
772
  compression="zlib",
773
+ bigtiff=use_bigtiff,
670
774
  )
671
775
  else:
776
+ print(
777
+ f"Saving image with dtype {image_dtype} and bigtiff={use_bigtiff}"
778
+ )
672
779
  tifffile.imwrite(
673
780
  new_filepath,
674
781
  processed_image.astype(image_dtype),
675
782
  compression="zlib",
783
+ bigtiff=use_bigtiff,
676
784
  )
677
785
 
678
786
  # Return processing info
@@ -684,6 +792,9 @@ class ProcessingWorker(QThread):
684
792
  except Exception as e:
685
793
  # Log the error and re-raise to be caught by the executor
686
794
  print(f"Error processing {filepath}: {e}")
795
+ import traceback
796
+
797
+ traceback.print_exc()
687
798
  raise
688
799
  finally:
689
800
  # Explicit cleanup to help with memory management
@@ -692,10 +803,6 @@ class ProcessingWorker(QThread):
692
803
  if "processed_image" in locals():
693
804
  del processed_image
694
805
 
695
- def stop(self):
696
- """Request worker to stop processing"""
697
- self.stop_requested = True
698
-
699
806
 
700
807
  class FileResultsWidget(QWidget):
701
808
  """
napari_tmidas/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.1.7.1'
21
- __version_tuple__ = version_tuple = (0, 1, 7, 1)
20
+ __version__ = version = '0.1.8.5'
21
+ __version_tuple__ = version_tuple = (0, 1, 8, 5)
@@ -7,6 +7,37 @@ import numpy as np
7
7
  from napari_tmidas._registry import BatchProcessingRegistry
8
8
 
9
9
 
10
+ @BatchProcessingRegistry.register(
11
+ name="Labels to Binary",
12
+ suffix="_binary",
13
+ description="Convert multi-label images to binary masks (all non-zero labels become 1)",
14
+ )
15
def labels_to_binary(image: np.ndarray) -> np.ndarray:
    """
    Convert a multi-label image to a binary mask.

    Every element whose label value is greater than zero becomes 1; all
    other elements (background 0 and any negative value) become 0. The
    input array is not modified.

    Parameters:
    -----------
    image : numpy.ndarray
        Input label image array

    Returns:
    --------
    numpy.ndarray
        uint32 array with the same shape as ``image``: 1 where the label is
        positive, 0 elsewhere
    """
    # The comparison already allocates a new boolean array, so copying the
    # input first would only raise peak memory on large images
    return (image > 0).astype(np.uint32)
39
+
40
+
10
41
  @BatchProcessingRegistry.register(
11
42
  name="Gamma Correction",
12
43
  suffix="_gamma",
@@ -123,3 +154,76 @@ def split_channels(image: np.ndarray, num_channels: int = 3) -> np.ndarray:
123
154
  # channels = [np.squeeze(ch, axis=channel_axis) for ch in channels]
124
155
 
125
156
  return np.stack(channels, axis=0)
157
+
158
+
159
+ @BatchProcessingRegistry.register(
160
+ name="RGB to Labels",
161
+ suffix="_labels",
162
+ description="Convert RGB images to label images using a color map",
163
+ parameters={
164
+ "blue_label": {
165
+ "type": int,
166
+ "default": 1,
167
+ "min": 0,
168
+ "max": 255,
169
+ "description": "Label value for blue objects",
170
+ },
171
+ "green_label": {
172
+ "type": int,
173
+ "default": 2,
174
+ "min": 0,
175
+ "max": 255,
176
+ "description": "Label value for green objects",
177
+ },
178
+ "red_label": {
179
+ "type": int,
180
+ "default": 3,
181
+ "min": 0,
182
+ "max": 255,
183
+ "description": "Label value for red objects",
184
+ },
185
+ },
186
+ )
187
def rgb_to_labels(
    image: np.ndarray,
    blue_label: int = 1,
    green_label: int = 2,
    red_label: int = 3,
) -> np.ndarray:
    """
    Map pure blue, green and red pixels of an RGB image to label values.

    Only exact matches of (0, 0, 255), (0, 255, 0) and (255, 0, 0) are
    labelled; every other pixel is left at 0.

    Parameters:
    -----------
    image : numpy.ndarray
        Input RGB image array whose last axis has length 3
    blue_label : int
        Label value for blue objects (default: 1)
    green_label : int
        Label value for green objects (default: 2)
    red_label : int
        Label value for red objects (default: 3)

    Returns:
    --------
    numpy.ndarray
        uint32 label image with the shape of ``image`` minus its last axis

    Raises:
    -------
    ValueError
        If ``image`` has fewer than 3 dimensions or its last axis is not 3 long
    """
    has_rgb_axis = image.ndim >= 3 and image.shape[-1] == 3
    if not has_rgb_axis:
        raise ValueError("Input must be an RGB image with 3 channels")

    # (R, G, B) triplets paired with the label each one receives
    palette = (
        ((0, 0, 255), blue_label),
        ((0, 255, 0), green_label),
        ((255, 0, 0), red_label),
    )

    labels = np.zeros(image.shape[:-1], dtype=np.uint32)
    for rgb, value in palette:
        # A pixel matches only when all three channels equal the triplet
        matches = (image == rgb).all(axis=-1)
        labels[matches] = value

    return labels