geoai-py 0.3.6__py2.py3-none-any.whl → 0.4.0__py2.py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
geoai/__init__.py CHANGED
@@ -2,7 +2,7 @@
  
  __author__ = """Qiusheng Wu"""
  __email__ = "giswqs@gmail.com"
- __version__ = "0.3.6"
+ __version__ = "0.4.0"
  
  
  import os
geoai/download.py CHANGED
@@ -1,18 +1,19 @@
  """This module provides functions to download data, including NAIP imagery and building data from Overture Maps."""
  
+ import logging
  import os
- from typing import List, Tuple, Optional, Dict, Any
- import rioxarray
- import numpy as np
+ import subprocess
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import geopandas as gpd
  import matplotlib.pyplot as plt
- from pystac_client import Client
+ import numpy as np
  import planetary_computer as pc
- import geopandas as gpd
+ import requests
+ import rioxarray
+ from pystac_client import Client
  from shapely.geometry import box
  from tqdm import tqdm
- import requests
- import subprocess
- import logging
  
  # Configure logging
  logging.basicConfig(
geoai/extract.py CHANGED
@@ -14,16 +14,15 @@ import torch
  from huggingface_hub import hf_hub_download
  from rasterio.windows import Window
  from shapely.geometry import Polygon, box
- from tqdm import tqdm
  from torchvision.models.detection import (
-     maskrcnn_resnet50_fpn,
      fasterrcnn_resnet50_fpn_v2,
+     maskrcnn_resnet50_fpn,
  )
+ from tqdm import tqdm
  
  # Local Imports
  from .utils import get_raster_stats
  
-
  try:
      from torchgeo.datasets import NonGeoDataset
  except ImportError as e:
@@ -270,7 +269,9 @@ class ObjectDetector:
      Object extraction using Mask R-CNN with TorchGeo.
      """
  
-     def __init__(self, model_path=None, repo_id=None, model=None, device=None):
+     def __init__(
+         self, model_path=None, repo_id=None, model=None, num_classes=2, device=None
+     ):
          """
          Initialize the object extractor.
  
@@ -278,6 +279,7 @@ class ObjectDetector:
              model_path: Path to the .pth model file.
              repo_id: Hugging Face repository ID for model download.
              model: Pre-initialized model object (optional).
+             num_classes: Number of classes for detection (default: 2).
              device: Device to use for inference ('cuda:0', 'cpu', etc.).
          """
          # Set device
@@ -297,7 +299,7 @@ class ObjectDetector:
          self.simplify_tolerance = 1.0  # Tolerance for polygon simplification
  
          # Initialize model
-         self.model = self.initialize_model(model)
+         self.model = self.initialize_model(model, num_classes=num_classes)
  
          # Download model if needed
          if model_path is None or (not os.path.exists(model_path)):
@@ -342,11 +344,12 @@ class ObjectDetector:
              print("Please specify a local model path or ensure internet connectivity.")
              raise
  
-     def initialize_model(self, model):
+     def initialize_model(self, model, num_classes=2):
          """Initialize a deep learning model for object detection.
  
          Args:
              model (torch.nn.Module): A pre-initialized model object.
+             num_classes (int): Number of classes for detection.
  
          Returns:
              torch.nn.Module: A deep learning model for object detection.
@@ -361,7 +364,7 @@ class ObjectDetector:
          model = maskrcnn_resnet50_fpn(
              weights=None,
              progress=False,
-             num_classes=2,  # Background + object
+             num_classes=num_classes,  # Background + object
              weights_backbone=None,
              # These parameters ensure consistent normalization
              image_mean=image_mean,
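
With the constructor now accepting num_classes, a detector for a multi-class checkpoint can be initialized directly. A minimal sketch; the weight file path is a placeholder, not an artifact shipped with the package, and per the inline comment above, num_classes counts the background class:

    from geoai.extract import ObjectDetector

    detector = ObjectDetector(
        model_path="my_model.pth",  # hypothetical local checkpoint
        num_classes=3,              # background + 2 object classes
        device="cuda:0",            # or None to use the class's default device handling
    )
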
@@ -1306,13 +1309,14 @@ class ObjectDetector:
          Returns:
              GeoDataFrame with regularized objects
          """
+         import math
+ 
+         import cv2
+         import geopandas as gpd
          import numpy as np
-         from shapely.geometry import Polygon, MultiPolygon, box
          from shapely.affinity import rotate, translate
-         import geopandas as gpd
-         import math
+         from shapely.geometry import MultiPolygon, Polygon, box
          from tqdm import tqdm
-         import cv2
  
          def get_angle(p1, p2, p3):
              """Calculate angle between three points in degrees (0-180)"""
@@ -2112,7 +2116,7 @@ class ObjectDetector:
          output_path=None,
          confidence_threshold=0.5,
          min_object_area=100,
-         max_object_size=None,
+         max_object_area=None,
          **kwargs,
      ):
          """
@@ -2123,7 +2127,7 @@ class ObjectDetector:
              output_path: Path for output GeoJSON.
              confidence_threshold: Minimum confidence score (0.0-1.0). Default: 0.5
              min_object_area: Minimum area in pixels to keep an object. Default: 100
-             max_object_size: Maximum area in pixels to keep an object. Default: None
+             max_object_area: Maximum area in pixels to keep an object. Default: None
              **kwargs: Additional parameters
  
          Returns:
@@ -2147,8 +2151,9 @@ class ObjectDetector:
          print(f"Found {num_features} connected components")
  
          # Process each component
-         car_polygons = []
-         car_confidences = []
+         polygons = []
+         confidences = []
+         pixels = []
  
          # Add progress bar
          for label in tqdm(range(1, num_features + 1), desc="Processing components"):
@@ -2179,8 +2184,8 @@ class ObjectDetector:
              if area < min_object_area:
                  continue
  
-             if max_object_size is not None:
-                 if area > max_object_size:
+             if max_object_area is not None:
+                 if area > max_object_area:
                      continue
  
              # Get minimum area rectangle
@@ -2197,16 +2202,18 @@ class ObjectDetector:
              poly = Polygon(geo_points)
  
              # Add to lists
-             car_polygons.append(poly)
-             car_confidences.append(confidence)
+             polygons.append(poly)
+             confidences.append(confidence)
+             pixels.append(area)
  
          # Create GeoDataFrame
-         if car_polygons:
+         if polygons:
              gdf = gpd.GeoDataFrame(
                  {
-                     "geometry": car_polygons,
-                     "confidence": car_confidences,
-                     "class": [1] * len(car_polygons),
+                     "geometry": polygons,
+                     "confidence": confidences,
+                     "class": [1] * len(polygons),
+                     "pixels": pixels,
                  },
                  crs=crs,
              )
@@ -2218,7 +2225,7 @@ class ObjectDetector:
  
              return gdf
          else:
-             print("No valid car polygons found")
+             print("No valid polygons found")
              return None
  
  
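The rename from max_object_size to max_object_area (now matching min_object_area) is a breaking change for callers that passed the old keyword. A hedged sketch of the call site after upgrading; the method name is illustrative, since the hunk headers do not show which ObjectDetector method owns these parameters:

    # Hypothetical call site; only the keyword names come from this diff.
    gdf = detector.process_masks(      # "process_masks" is a placeholder name
        "masks.tif",
        output_path="objects.geojson",
        confidence_threshold=0.5,
        min_object_area=100,   # unchanged
        max_object_area=5000,  # was max_object_size in 0.3.6
    )
    # The result also gains a "pixels" column recording each object's area.
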
@@ -2356,3 +2363,37 @@ class SolarPanelDetector(ObjectDetector):
          super().__init__(
              model_path=model_path, repo_id=repo_id, model=model, device=device
          )
+ 
+ 
+ class ParkingSplotDetector(ObjectDetector):
+     """
+     Car detection using a pre-trained Mask R-CNN model.
+ 
+     This class extends the `ObjectDetector` class with additional methods for car detection.
+     """
+ 
+     def __init__(
+         self,
+         model_path="parking_spot_detection.pth",
+         repo_id=None,
+         model=None,
+         num_classes=3,
+         device=None,
+     ):
+         """
+         Initialize the object extractor.
+ 
+         Args:
+             model_path: Path to the .pth model file.
+             repo_id: Repo ID for loading models from the Hub.
+             model: Custom model to use for inference.
+             num_classes: Number of classes for the model. Default: 3
+             device: Device to use for inference ('cuda:0', 'cpu', etc.).
+         """
+         super().__init__(
+             model_path=model_path,
+             repo_id=repo_id,
+             model=model,
+             num_classes=num_classes,
+             device=device,
+         )
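
The new detector (spelled ParkingSplotDetector in this release) is a thin wrapper that defaults model_path to parking_spot_detection.pth and num_classes to 3. A minimal sketch; the repo_id value is a placeholder, since the diff does not show where the default weights are hosted:

    from geoai.extract import ParkingSplotDetector

    detector = ParkingSplotDetector(
        repo_id="user/parking-spot-model",  # hypothetical Hub repo hosting the weights
        device="cuda:0",
    )
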
geoai/geoai.py CHANGED
@@ -1,5 +1,7 @@
  """Main module."""
  
- from .utils import *
  from .extract import *
+ from .hf import *
  from .segment import *
+ from .utils import *
+ from .train import train_MaskRCNN_model, object_detection
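
These re-exports make the training helpers and the new hf utilities reachable from the main module. Assuming no __all__ restricts the star imports, the following now works (the functions' signatures are not shown in this diff, so only the imports are illustrated):

    from geoai.geoai import train_MaskRCNN_model, object_detection, image_segmentation
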
geoai/hf.py ADDED
@@ -0,0 +1,447 @@
+ """This module contains utility functions for working with Hugging Face models."""
+ 
+ import csv
+ import os
+ from typing import Dict, List, Optional, Tuple, Union
+ 
+ import numpy as np
+ import pandas as pd
+ import rasterio
+ from PIL import Image
+ from tqdm import tqdm
+ from transformers import AutoConfig, AutoModelForMaskedImageModeling, pipeline
+ 
+ 
+ def get_model_config(model_id):
+     """
+     Get the model configuration for a Hugging Face model.
+ 
+     Args:
+         model_id (str): The Hugging Face model ID.
+ 
+     Returns:
+         transformers.configuration_utils.PretrainedConfig: The model configuration.
+     """
+     return AutoConfig.from_pretrained(model_id)
+ 
+ 
+ def get_model_input_channels(model_id):
+     """
+     Check the number of input channels supported by a Hugging Face model.
+ 
+     Args:
+         model_id (str): The Hugging Face model ID.
+ 
+     Returns:
+         int: The number of input channels the model accepts.
+ 
+     Raises:
+         ValueError: If unable to determine the number of input channels.
+     """
+     # Load the model configuration
+     config = AutoConfig.from_pretrained(model_id)
+ 
+     # For Mask2Former models
+     if hasattr(config, "backbone_config"):
+         if hasattr(config.backbone_config, "num_channels"):
+             return config.backbone_config.num_channels
+ 
+     # Try to load the model and inspect its architecture
+     try:
+         model = AutoModelForMaskedImageModeling.from_pretrained(model_id)
+ 
+         # For Swin Transformer-based models like Mask2Former
+         if hasattr(model, "backbone") and hasattr(model.backbone, "embeddings"):
+             if hasattr(model.backbone.embeddings, "patch_embeddings"):
+                 # Swin models typically have patch embeddings that indicate channel count
+                 return model.backbone.embeddings.patch_embeddings.in_channels
+     except Exception as e:
+         print(f"Couldn't inspect model architecture: {e}")
+ 
+     # Default for most vision models
+     return 3
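
A quick way to exercise the two config helpers above; the model ID shown is the module's own default segmentation model, but any Hugging Face vision checkpoint can be substituted:

    from geoai.hf import get_model_config, get_model_input_channels

    model_id = "facebook/mask2former-swin-large-cityscapes-semantic"
    config = get_model_config(model_id)            # a transformers PretrainedConfig
    channels = get_model_input_channels(model_id)  # typically 3 for RGB models
    print(type(config).__name__, channels)
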
+ 
+ 
+ def image_segmentation(
+     tif_path,
+     output_path,
+     labels_to_extract=None,
+     dtype="uint8",
+     model_name=None,
+     segmenter_args=None,
+     **kwargs,
+ ):
+     """
+     Segments an image with a Hugging Face segmentation model and saves the results
+     as a single georeferenced image where each class has a unique integer value.
+ 
+     Args:
+         tif_path (str): Path to the input georeferenced TIF file.
+         output_path (str): Path where the output georeferenced segmentation will be saved.
+         labels_to_extract (list, optional): List of labels to extract. If None, extracts all labels.
+         dtype (str, optional): Data type to use for the output mask. Defaults to "uint8".
+         model_name (str, optional): Name of the Hugging Face model to use for segmentation,
+             such as "facebook/mask2former-swin-large-cityscapes-semantic". Defaults to None.
+             See https://huggingface.co/models?pipeline_tag=image-segmentation&sort=trending for options.
+         segmenter_args (dict, optional): Additional arguments to pass to the segmenter.
+             Defaults to None.
+         **kwargs: Additional keyword arguments to pass to the segmentation pipeline
+ 
+     Returns:
+         tuple: (Path to saved image, dictionary mapping label names to their assigned values,
+             dictionary mapping label names to confidence scores)
+     """
+     # Load the original georeferenced image to extract metadata
+     with rasterio.open(tif_path) as src:
+         # Save the metadata for later use
+         meta = src.meta.copy()
+         # Get the dimensions
+         height = src.height
+         width = src.width
+         # Get the transform and CRS for georeferencing
+         # transform = src.transform
+         # crs = src.crs
+ 
+     # Initialize the segmentation pipeline
+     if model_name is None:
+         model_name = "facebook/mask2former-swin-large-cityscapes-semantic"
+ 
+     kwargs["task"] = "image-segmentation"
+ 
+     segmenter = pipeline(model=model_name, **kwargs)
+ 
+     # Run the segmentation on the GeoTIFF
+     if segmenter_args is None:
+         segmenter_args = {}
+ 
+     segments = segmenter(tif_path, **segmenter_args)
+ 
+     # If no specific labels are requested, extract all available ones
+     if labels_to_extract is None:
+         labels_to_extract = [segment["label"] for segment in segments]
+ 
+     # Create an empty mask to hold all the labels
+     # Using uint8 for up to 255 classes, switch to uint16 for more
+     combined_mask = np.zeros((height, width), dtype=np.uint8)
+ 
+     # Create a dictionary to map labels to values and store scores
+     label_to_value = {}
+     label_to_score = {}
+ 
+     # Process each segment we want to keep
+     for i, segment in enumerate(
+         [s for s in segments if s["label"] in labels_to_extract]
+     ):
+         # Assign a unique value to each label (starting from 1)
+         value = i + 1
+         label = segment["label"]
+         score = segment["score"]
+ 
+         label_to_value[label] = value
+         label_to_score[label] = score
+ 
+         # Convert PIL image to numpy array
+         mask = np.array(segment["mask"])
+ 
+         # Apply a threshold if it's a probability mask (not binary)
+         if mask.dtype == float:
+             mask = (mask > 0.5).astype(np.uint8)
+ 
+         # Resize if needed to match original dimensions
+         if mask.shape != (height, width):
+             mask_img = Image.fromarray(mask)
+             mask_img = mask_img.resize((width, height))
+             mask = np.array(mask_img)
+ 
+         # Add this class to the combined mask
+         # Only overwrite if the pixel isn't already assigned to another class
+         # This handles overlapping segments by giving priority to earlier segments
+         combined_mask = np.where(
+             (mask > 0) & (combined_mask == 0), value, combined_mask
+         )
+ 
+     # Update metadata for the output raster
+     meta.update(
+         {
+             "count": 1,  # One band for the mask
+             "dtype": dtype,  # Use uint8 for up to 255 classes
+             "nodata": 0,  # 0 represents no class
+         }
+     )
+ 
+     # Save the mask as a new georeferenced GeoTIFF
+     with rasterio.open(output_path, "w", **meta) as dst:
+         dst.write(combined_mask[np.newaxis, :, :])  # Add channel dimension
+ 
+     # Create a CSV colormap file with scores included
+     csv_path = os.path.splitext(output_path)[0] + "_colormap.csv"
+     with open(csv_path, "w", newline="") as csvfile:
+         fieldnames = ["ClassValue", "ClassName", "ConfidenceScore"]
+         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+ 
+         writer.writeheader()
+         for label, value in label_to_value.items():
+             writer.writerow(
+                 {
+                     "ClassValue": value,
+                     "ClassName": label,
+                     "ConfidenceScore": f"{label_to_score[label]:.4f}",
+                 }
+             )
+ 
+     return output_path, label_to_value, label_to_score
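
End-to-end usage of image_segmentation as defined above; the file paths are placeholders, and the default Cityscapes Mask2Former model is assumed to fit the scene's classes:

    from geoai.hf import image_segmentation

    out_path, label_to_value, label_to_score = image_segmentation(
        tif_path="scene.tif",            # hypothetical input GeoTIFF
        output_path="scene_segments.tif",
        labels_to_extract=["road", "car"],  # None keeps every predicted label
    )
    # A sidecar scene_segments_colormap.csv maps class values to names and scores.
    print(label_to_value, label_to_score)
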
+ 
+ 
+ def mask_generation(
+     input_path: str,
+     output_mask_path: str,
+     output_csv_path: str,
+     model: str = "facebook/sam-vit-base",
+     confidence_threshold: float = 0.5,
+     points_per_side: int = 32,
+     crop_size: Optional[int] = None,
+     batch_size: int = 1,
+     band_indices: Optional[List[int]] = None,
+     min_object_size: int = 0,
+     generator_kwargs: Optional[Dict] = None,
+     **kwargs,
+ ) -> Tuple[str, str]:
+     """
+     Process a GeoTIFF using SAM mask generation and save results as a GeoTIFF and CSV.
+ 
+     The function reads a GeoTIFF image, applies the SAM mask generator from the
+     Hugging Face transformers pipeline, rasterizes the resulting masks to create
+     a labeled mask GeoTIFF, and saves mask scores and geometries to a CSV file.
+ 
+     Args:
+         input_path: Path to the input GeoTIFF image.
+         output_mask_path: Path where the output mask GeoTIFF will be saved.
+         output_csv_path: Path where the mask scores CSV will be saved.
+         model: HuggingFace model checkpoint for the SAM model.
+         confidence_threshold: Minimum confidence score for masks to be included.
+         points_per_side: Number of points to sample along each side of the image.
+         crop_size: Size of image crops for processing. If None, process the full image.
+         band_indices: List of band indices to use. If None, use all bands.
+         batch_size: Batch size for inference.
+         min_object_size: Minimum size in pixels for objects to be included. Smaller masks will be filtered out.
+         generator_kwargs: Additional keyword arguments to pass to the mask generator.
+ 
+     Returns:
+         Tuple containing the paths to the saved mask GeoTIFF and CSV file.
+ 
+     Raises:
+         ValueError: If the input file cannot be opened or processed.
+         RuntimeError: If mask generation fails.
+     """
+     # Set up the mask generator
+     print("Setting up mask generator...")
+     mask_generator = pipeline(model=model, task="mask-generation", **kwargs)
+ 
+     # Open the GeoTIFF file
+     try:
+         print(f"Reading input GeoTIFF: {input_path}")
+         with rasterio.open(input_path) as src:
+             # Read metadata
+             profile = src.profile
+             # transform = src.transform
+             # crs = src.crs
+ 
+             # Read the image data
+             if band_indices is not None:
+                 print(f"Using specified bands: {band_indices}")
+                 image_data = np.stack([src.read(i + 1) for i in band_indices])
+             else:
+                 print("Using all bands")
+                 image_data = src.read()
+ 
+             # Handle image with more than 3 bands (convert to RGB for visualization)
+             if image_data.shape[0] > 3:
+                 print(
+                     f"Converting {image_data.shape[0]} bands to RGB (using first 3 bands)"
+                 )
+                 # Select first three bands or perform other band combination
+                 image_data = image_data[:3]
+             elif image_data.shape[0] == 1:
+                 print("Duplicating single band to create 3-band image")
+                 # Duplicate single band to create a 3-band image
+                 image_data = np.vstack([image_data] * 3)
+ 
+             # Transpose to HWC format for the model
+             image_data = np.transpose(image_data, (1, 2, 0))
+ 
+             # Normalize the image if needed
+             if image_data.dtype != np.uint8:
+                 print(f"Normalizing image from {image_data.dtype} to uint8")
+                 image_data = (image_data / image_data.max() * 255).astype(np.uint8)
+     except Exception as e:
+         raise ValueError(f"Failed to open or process input GeoTIFF: {e}")
+ 
+     # Process the image with the mask generator
+     try:
+         # Convert numpy array to PIL Image for the pipeline
+         # Ensure the array is in the right format (HWC and uint8)
+         if image_data.dtype != np.uint8:
+             image_data = (image_data / image_data.max() * 255).astype(np.uint8)
+ 
+         # Create a PIL Image from the numpy array
+         print("Converting to PIL Image for mask generation")
+         pil_image = Image.fromarray(image_data)
+ 
+         # Use the SAM pipeline for mask generation
+         if generator_kwargs is None:
+             generator_kwargs = {}
+ 
+         print("Running mask generation...")
+         mask_results = mask_generator(
+             pil_image,
+             points_per_side=points_per_side,
+             crop_n_points_downscale_factor=1 if crop_size is None else 2,
+             point_grids=None,
+             pred_iou_thresh=confidence_threshold,
+             stability_score_thresh=confidence_threshold,
+             crops_n_layers=0 if crop_size is None else 1,
+             crop_overlap_ratio=0.5,
+             batch_size=batch_size,
+             **generator_kwargs,
+         )
+ 
+         print(
+             f"Number of initial masks: {len(mask_results['masks']) if isinstance(mask_results, dict) and 'masks' in mask_results else len(mask_results)}"
+         )
+ 
+     except Exception as e:
+         raise RuntimeError(f"Mask generation failed: {e}")
+ 
+     # Create a mask raster with unique IDs for each mask
+     mask_raster = np.zeros((image_data.shape[0], image_data.shape[1]), dtype=np.uint32)
+     mask_records = []
+ 
+     # Process each mask based on the structure of mask_results
+     if (
+         isinstance(mask_results, dict)
+         and "masks" in mask_results
+         and "scores" in mask_results
+     ):
+         # Handle dictionary with 'masks' and 'scores' lists
+         print("Processing masks...")
+         total_masks = len(mask_results["masks"])
+ 
+         # Create progress bar
+         for i, (mask_data, score) in enumerate(
+             tqdm(
+                 zip(mask_results["masks"], mask_results["scores"]),
+                 total=total_masks,
+                 desc="Processing masks",
+             )
+         ):
+             mask_id = i + 1  # Start IDs at 1
+ 
+             # Convert to numpy if not already
+             if not isinstance(mask_data, np.ndarray):
+                 # Try to convert from tensor or other format if needed
+                 try:
+                     mask_data = np.array(mask_data)
+                 except:
+                     print(f"Could not convert mask at index {i} to numpy array")
+                     continue
+ 
+             mask_binary = mask_data.astype(bool)
+             area_pixels = np.sum(mask_binary)
+ 
+             # Skip if mask is smaller than the minimum size
+             if area_pixels < min_object_size:
+                 continue
+ 
+             # Add the mask to the raster with a unique ID
+             mask_raster[mask_binary] = mask_id
+ 
+             # Create a record for the CSV - without geometry calculation
+             mask_records.append(
+                 {"mask_id": mask_id, "score": float(score), "area_pixels": area_pixels}
+             )
+     elif isinstance(mask_results, list):
+         # Handle list of dictionaries format (SAM original format)
+         print("Processing masks...")
+         total_masks = len(mask_results)
+ 
+         # Create progress bar
+         for i, mask_result in enumerate(tqdm(mask_results, desc="Processing masks")):
+             mask_id = i + 1  # Start IDs at 1
+ 
+             # Try different possible key names for masks and scores
+             mask_data = None
+             score = None
+ 
+             if isinstance(mask_result, dict):
+                 # Try to find mask data
+                 if "segmentation" in mask_result:
+                     mask_data = mask_result["segmentation"]
+                 elif "mask" in mask_result:
+                     mask_data = mask_result["mask"]
+ 
+                 # Try to find score
+                 if "score" in mask_result:
+                     score = mask_result["score"]
+                 elif "predicted_iou" in mask_result:
+                     score = mask_result["predicted_iou"]
+                 elif "stability_score" in mask_result:
+                     score = mask_result["stability_score"]
+                 else:
+                     score = 1.0  # Default score if none found
+             else:
+                 # If mask_result is not a dict, it might be the mask directly
+                 try:
+                     mask_data = np.array(mask_result)
+                     score = 1.0  # Default score
+                 except:
+                     print(f"Could not process mask at index {i}")
+                     continue
+ 
+             if mask_data is not None:
+                 # Convert to numpy if not already
+                 if not isinstance(mask_data, np.ndarray):
+                     try:
+                         mask_data = np.array(mask_data)
+                     except:
+                         print(f"Could not convert mask at index {i} to numpy array")
+                         continue
+ 
+                 mask_binary = mask_data.astype(bool)
+                 area_pixels = np.sum(mask_binary)
+ 
+                 # Skip if mask is smaller than the minimum size
+                 if area_pixels < min_object_size:
+                     continue
+ 
+                 # Add the mask to the raster with a unique ID
+                 mask_raster[mask_binary] = mask_id
+ 
+                 # Create a record for the CSV - without geometry calculation
+                 mask_records.append(
+                     {
+                         "mask_id": mask_id,
+                         "score": float(score),
+                         "area_pixels": area_pixels,
+                     }
+                 )
+     else:
+         # If we couldn't figure out the format, raise an error
+         raise ValueError(f"Unexpected format for mask_results: {type(mask_results)}")
+ 
+     print(f"Number of final masks (after size filtering): {len(mask_records)}")
+ 
+     # Save the mask raster as a GeoTIFF
+     print(f"Saving mask GeoTIFF to {output_mask_path}")
+     output_profile = profile.copy()
+     output_profile.update(dtype=rasterio.uint32, count=1, compress="lzw", nodata=0)
+ 
+     with rasterio.open(output_mask_path, "w", **output_profile) as dst:
+         dst.write(mask_raster.astype(rasterio.uint32), 1)
+ 
+     # Save the mask data as a CSV
+     print(f"Saving mask metadata to {output_csv_path}")
+     mask_df = pd.DataFrame(mask_records)
+     mask_df.to_csv(output_csv_path, index=False)
+ 
+     print("Processing complete!")
+     return output_mask_path, output_csv_path
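
A minimal sketch of mask_generation with the defaults shown above; the input path is a placeholder, and min_object_size is raised here to filter out speckle:

    from geoai.hf import mask_generation

    mask_tif, mask_csv = mask_generation(
        input_path="scene.tif",           # hypothetical input GeoTIFF
        output_mask_path="sam_masks.tif", # uint32 raster of mask IDs (0 = nodata)
        output_csv_path="sam_masks.csv",  # columns: mask_id, score, area_pixels
        model="facebook/sam-vit-base",
        min_object_size=50,               # drop masks smaller than 50 px
    )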