geoai-py 0.3.2__py2.py3-none-any.whl → 0.3.4__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  __author__ = """Qiusheng Wu"""
4
4
  __email__ = "giswqs@gmail.com"
5
- __version__ = "0.3.2"
5
+ __version__ = "0.3.4"
6
6
 
7
7
 
8
8
  import os
geoai/extract.py CHANGED
@@ -13,7 +13,8 @@ import rasterio
13
13
  from rasterio.windows import Window
14
14
  from rasterio.features import shapes
15
15
  from huggingface_hub import hf_hub_download
16
- from .preprocess import get_raster_stats
16
+ import scipy.ndimage as ndimage
17
+ from .utils import get_raster_stats
17
18
 
18
19
  try:
19
20
  from torchgeo.datasets import NonGeoDataset
@@ -25,34 +26,74 @@ except ImportError as e:
25
26
 
26
27
  class CustomDataset(NonGeoDataset):
27
28
  """
28
- A TorchGeo dataset for object extraction.
29
- Using NonGeoDataset to avoid spatial indexing issues.
29
+ A TorchGeo dataset for object extraction with overlapping tiles support.
30
+
31
+ This dataset class creates overlapping image tiles for object detection,
32
+ ensuring complete coverage of the input raster including right and bottom edges.
33
+ It inherits from NonGeoDataset to avoid spatial indexing issues.
34
+
35
+ Attributes:
36
+ raster_path: Path to the input raster file.
37
+ chip_size: Size of image chips to extract (height, width).
38
+ overlap: Amount of overlap between adjacent tiles (0.0-1.0).
39
+ transforms: Transforms to apply to the image.
40
+ verbose: Whether to print detailed processing information.
41
+ stride_x: Horizontal stride between tiles based on overlap.
42
+ stride_y: Vertical stride between tiles based on overlap.
43
+ row_starts: Starting Y positions for each row of tiles.
44
+ col_starts: Starting X positions for each column of tiles.
45
+ crs: Coordinate reference system of the raster.
46
+ transform: Affine transform of the raster.
47
+ height: Height of the raster in pixels.
48
+ width: Width of the raster in pixels.
49
+ count: Number of bands in the raster.
50
+ bounds: Geographic bounds of the raster (west, south, east, north).
51
+ roi: Shapely box representing the region of interest.
52
+ rows: Number of rows of tiles.
53
+ cols: Number of columns of tiles.
54
+ raster_stats: Statistics of the raster.
30
55
  """
31
56
 
32
57
  def __init__(
33
- self, raster_path, chip_size=(512, 512), transforms=None, verbose=False
58
+ self,
59
+ raster_path,
60
+ chip_size=(512, 512),
61
+ overlap=0.5,
62
+ transforms=None,
63
+ verbose=False,
34
64
  ):
35
65
  """
36
- Initialize the dataset.
66
+ Initialize the dataset with overlapping tiles.
37
67
 
38
68
  Args:
39
- raster_path: Path to the input raster file
40
- chip_size: Size of image chips to extract (height, width)
41
- transforms: Transforms to apply to the image
42
- verbose: Whether to print detailed processing information
69
+ raster_path: Path to the input raster file.
70
+ chip_size: Size of image chips to extract (height, width). Default is (512, 512).
71
+ overlap: Amount of overlap between adjacent tiles (0.0-1.0). Default is 0.5 (50%).
72
+ transforms: Transforms to apply to the image. Default is None.
73
+ verbose: Whether to print detailed processing information. Default is False.
74
+
75
+ Raises:
76
+ ValueError: If overlap is too high resulting in non-positive stride.
43
77
  """
44
78
  super().__init__()
45
79
 
46
80
  # Initialize parameters
47
81
  self.raster_path = raster_path
48
82
  self.chip_size = chip_size
83
+ self.overlap = overlap
49
84
  self.transforms = transforms
50
85
  self.verbose = verbose
51
-
52
- # For tracking warnings about multi-band images
53
86
  self.warned_about_bands = False
54
87
 
55
- # Open raster and get metadata
88
+ # Calculate stride based on overlap
89
+ self.stride_x = int(chip_size[1] * (1 - overlap))
90
+ self.stride_y = int(chip_size[0] * (1 - overlap))
91
+
92
+ if self.stride_x <= 0 or self.stride_y <= 0:
93
+ raise ValueError(
94
+ f"Overlap {overlap} is too high, resulting in non-positive stride"
95
+ )
96
+
56
97
  with rasterio.open(self.raster_path) as src:
57
98
  self.crs = src.crs
58
99
  self.transform = src.transform
@@ -63,43 +104,78 @@ class CustomDataset(NonGeoDataset):
63
104
  # Define the bounds of the dataset
64
105
  west, south, east, north = src.bounds
65
106
  self.bounds = (west, south, east, north)
66
-
67
- # Define the ROI for the dataset
68
107
  self.roi = box(*self.bounds)
69
108
 
70
- # Calculate number of chips in each dimension
71
- # Use ceil division to ensure we cover the entire image
72
- self.rows = (self.height + self.chip_size[0] - 1) // self.chip_size[0]
73
- self.cols = (self.width + self.chip_size[1] - 1) // self.chip_size[1]
109
+ # Calculate starting positions for each tile
110
+ self.row_starts = []
111
+ self.col_starts = []
112
+
113
+ # Normal row starts using stride
114
+ for r in range((self.height - 1) // self.stride_y):
115
+ self.row_starts.append(r * self.stride_y)
116
+
117
+ # Add a special last row that ensures we reach the bottom edge
118
+ if self.height > self.chip_size[0]:
119
+ self.row_starts.append(max(0, self.height - self.chip_size[0]))
120
+ else:
121
+ # If the image is smaller than chip size, just start at 0
122
+ if not self.row_starts:
123
+ self.row_starts.append(0)
124
+
125
+ # Normal column starts using stride
126
+ for c in range((self.width - 1) // self.stride_x):
127
+ self.col_starts.append(c * self.stride_x)
128
+
129
+ # Add a special last column that ensures we reach the right edge
130
+ if self.width > self.chip_size[1]:
131
+ self.col_starts.append(max(0, self.width - self.chip_size[1]))
132
+ else:
133
+ # If the image is smaller than chip size, just start at 0
134
+ if not self.col_starts:
135
+ self.col_starts.append(0)
136
+
137
+ # Update rows and cols based on actual starting positions
138
+ self.rows = len(self.row_starts)
139
+ self.cols = len(self.col_starts)
74
140
 
75
141
  print(
76
142
  f"Dataset initialized with {self.rows} rows and {self.cols} columns of chips"
77
143
  )
78
144
  print(f"Image dimensions: {self.width} x {self.height} pixels")
79
145
  print(f"Chip size: {self.chip_size[1]} x {self.chip_size[0]} pixels")
146
+ print(
147
+ f"Overlap: {overlap*100}% (stride_x={self.stride_x}, stride_y={self.stride_y})"
148
+ )
80
149
  if src.crs:
81
150
  print(f"CRS: {src.crs}")
82
151
 
83
- # get raster stats
152
+ # Get raster stats
84
153
  self.raster_stats = get_raster_stats(raster_path, divide_by=255)
85
154
 
86
155
  def __getitem__(self, idx):
87
156
  """
88
157
  Get an image chip from the dataset by index.
89
158
 
159
+ Retrieves an image tile with the specified overlap pattern, ensuring
160
+ proper coverage of the entire raster including edges.
161
+
90
162
  Args:
91
- idx: Index of the chip
163
+ idx: Index of the chip to retrieve.
92
164
 
93
165
  Returns:
94
- Dict containing image tensor
166
+ dict: Dictionary containing:
167
+ - image: Image tensor.
168
+ - bbox: Geographic bounding box for the window.
169
+ - coords: Pixel coordinates as tensor [i, j].
170
+ - window_size: Window size as tensor [width, height].
95
171
  """
96
172
  # Convert flat index to grid position
97
173
  row = idx // self.cols
98
174
  col = idx % self.cols
99
175
 
100
- # Calculate pixel coordinates
101
- i = col * self.chip_size[1]
102
- j = row * self.chip_size[0]
176
+ # Get pre-calculated starting positions
177
+ j = self.row_starts[row]
178
+ i = self.col_starts[col]
103
179
 
104
180
  # Read window from raster
105
181
  with rasterio.open(self.raster_path) as src:
@@ -166,7 +242,12 @@ class CustomDataset(NonGeoDataset):
166
242
  }
167
243
 
168
244
  def __len__(self):
169
- """Return the number of samples in the dataset."""
245
+ """
246
+ Return the number of samples in the dataset.
247
+
248
+ Returns:
249
+ int: Total number of tiles in the dataset.
250
+ """
170
251
  return self.rows * self.cols
171
252
 
172
253
 
@@ -196,7 +277,8 @@ class ObjectDetector:
196
277
  self.overlap = 0.25 # Default overlap between tiles
197
278
  self.confidence_threshold = 0.5 # Default confidence threshold
198
279
  self.nms_iou_threshold = 0.5 # IoU threshold for non-maximum suppression
199
- self.small_object_area = 100 # Minimum area in pixels to keep an object
280
+ self.min_object_area = 100 # Minimum area in pixels to keep an object
281
+ self.max_object_area = None # Maximum area in pixels to keep an object
200
282
  self.mask_threshold = 0.5 # Threshold for mask binarization
201
283
  self.simplify_tolerance = 1.0 # Tolerance for polygon simplification
202
284
 
@@ -326,7 +408,8 @@ class ObjectDetector:
326
408
  **kwargs: Optional parameters:
327
409
  simplify_tolerance: Tolerance for polygon simplification
328
410
  mask_threshold: Threshold for mask binarization
329
- small_object_area: Minimum area in pixels to keep an object
411
+ min_object_area: Minimum area in pixels to keep an object
412
+ max_object_area: Maximum area in pixels to keep an object
330
413
 
331
414
  Returns:
332
415
  List of polygons as lists of (x, y) coordinates
@@ -335,7 +418,8 @@ class ObjectDetector:
335
418
  # Get parameters from kwargs or use instance defaults
336
419
  simplify_tolerance = kwargs.get("simplify_tolerance", self.simplify_tolerance)
337
420
  mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
338
- small_object_area = kwargs.get("small_object_area", self.small_object_area)
421
+ min_object_area = kwargs.get("min_object_area", self.min_object_area)
422
+ max_object_area = kwargs.get("max_object_area", self.max_object_area)
339
423
 
340
424
  # Ensure binary mask
341
425
  mask = (mask > mask_threshold).astype(np.uint8)
@@ -351,7 +435,14 @@ class ObjectDetector:
351
435
  polygons = []
352
436
  for contour in contours:
353
437
  # Filter out too small contours
354
- if contour.shape[0] < 3 or cv2.contourArea(contour) < small_object_area:
438
+ if contour.shape[0] < 3 or cv2.contourArea(contour) < min_object_area:
439
+ continue
440
+
441
+ # Filter out too large contours
442
+ if (
443
+ max_object_area is not None
444
+ and cv2.contourArea(contour) > max_object_area
445
+ ):
355
446
  continue
356
447
 
357
448
  # Simplify contour if it has many points
@@ -491,7 +582,8 @@ class ObjectDetector:
491
582
  output_path=None,
492
583
  simplify_tolerance=None,
493
584
  mask_threshold=None,
494
- small_object_area=None,
585
+ min_object_area=None,
586
+ max_object_area=None,
495
587
  nms_iou_threshold=None,
496
588
  regularize=True,
497
589
  angle_threshold=15,
@@ -505,7 +597,8 @@ class ObjectDetector:
505
597
  output_path: Path to save the output GeoJSON (default: mask_path with .geojson extension)
506
598
  simplify_tolerance: Tolerance for polygon simplification (default: self.simplify_tolerance)
507
599
  mask_threshold: Threshold for mask binarization (default: self.mask_threshold)
508
- small_object_area: Minimum area in pixels to keep an object (default: self.small_object_area)
600
+ min_object_area: Minimum area in pixels to keep an object (default: self.min_object_area)
601
+ max_object_area: Maximum area in pixels to keep an object (default: self.max_object_area)
509
602
  nms_iou_threshold: IoU threshold for non-maximum suppression (default: self.nms_iou_threshold)
510
603
  regularize: Whether to regularize objects to right angles (default: True)
511
604
  angle_threshold: Maximum deviation from 90 degrees for regularization (default: 15)
@@ -523,10 +616,11 @@ class ObjectDetector:
523
616
  mask_threshold = (
524
617
  mask_threshold if mask_threshold is not None else self.mask_threshold
525
618
  )
526
- small_object_area = (
527
- small_object_area
528
- if small_object_area is not None
529
- else self.small_object_area
619
+ min_object_area = (
620
+ min_object_area if min_object_area is not None else self.min_object_area
621
+ )
622
+ max_object_area = (
623
+ max_object_area if max_object_area is not None else self.max_object_area
530
624
  )
531
625
  nms_iou_threshold = (
532
626
  nms_iou_threshold
@@ -540,7 +634,8 @@ class ObjectDetector:
540
634
 
541
635
  print(f"Converting mask to GeoJSON with parameters:")
542
636
  print(f"- Mask threshold: {mask_threshold}")
543
- print(f"- Min object area: {small_object_area}")
637
+ print(f"- Min object area: {min_object_area}")
638
+ print(f"- Max object area: {max_object_area}")
544
639
  print(f"- Simplify tolerance: {simplify_tolerance}")
545
640
  print(f"- NMS IoU threshold: {nms_iou_threshold}")
546
641
  print(f"- Regularize objects: {regularize}")
@@ -586,7 +681,11 @@ class ObjectDetector:
586
681
  area = stats[i, cv2.CC_STAT_AREA]
587
682
 
588
683
  # Skip if too small
589
- if area < small_object_area:
684
+ if area < min_object_area:
685
+ continue
686
+
687
+ # Skip if too large
688
+ if max_object_area is not None and area > max_object_area:
590
689
  continue
591
690
 
592
691
  # Create a mask for this object
@@ -710,7 +809,7 @@ class ObjectDetector:
710
809
  chip_size: Size of image chips for processing (height, width)
711
810
  nms_iou_threshold: IoU threshold for non-maximum suppression (0.0-1.0)
712
811
  mask_threshold: Threshold for mask binarization (0.0-1.0)
713
- small_object_area: Minimum area in pixels to keep an object
812
+ min_object_area: Minimum area in pixels to keep an object
714
813
  simplify_tolerance: Tolerance for polygon simplification
715
814
 
716
815
  Returns:
@@ -724,7 +823,8 @@ class ObjectDetector:
724
823
  chip_size = kwargs.get("chip_size", self.chip_size)
725
824
  nms_iou_threshold = kwargs.get("nms_iou_threshold", self.nms_iou_threshold)
726
825
  mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
727
- small_object_area = kwargs.get("small_object_area", self.small_object_area)
826
+ min_object_area = kwargs.get("min_object_area", self.min_object_area)
827
+ max_object_area = kwargs.get("max_object_area", self.max_object_area)
728
828
  simplify_tolerance = kwargs.get("simplify_tolerance", self.simplify_tolerance)
729
829
 
730
830
  # Print parameters being used
@@ -734,14 +834,17 @@ class ObjectDetector:
734
834
  print(f"- Chip size: {chip_size}")
735
835
  print(f"- NMS IoU threshold: {nms_iou_threshold}")
736
836
  print(f"- Mask threshold: {mask_threshold}")
737
- print(f"- Min object area: {small_object_area}")
837
+ print(f"- Min object area: {min_object_area}")
838
+ print(f"- Max object area: {max_object_area}")
738
839
  print(f"- Simplify tolerance: {simplify_tolerance}")
739
840
  print(f"- Filter edge objects: {filter_edges}")
740
841
  if filter_edges:
741
842
  print(f"- Edge buffer size: {edge_buffer} pixels")
742
843
 
743
844
  # Create dataset
744
- dataset = CustomDataset(raster_path=raster_path, chip_size=chip_size)
845
+ dataset = CustomDataset(
846
+ raster_path=raster_path, chip_size=chip_size, overlap=overlap
847
+ )
745
848
  self.raster_stats = dataset.raster_stats
746
849
 
747
850
  # Custom collate function to handle Shapely objects
@@ -865,7 +968,8 @@ class ObjectDetector:
865
968
  binary_mask,
866
969
  simplify_tolerance=simplify_tolerance,
867
970
  mask_threshold=mask_threshold,
868
- small_object_area=small_object_area,
971
+ min_object_area=min_object_area,
972
+ max_object_area=max_object_area,
869
973
  )
870
974
 
871
975
  # Skip if no valid polygons
@@ -948,6 +1052,7 @@ class ObjectDetector:
948
1052
  )
949
1053
  chip_size = kwargs.get("chip_size", self.chip_size)
950
1054
  mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
1055
+ overlap = kwargs.get("overlap", self.overlap)
951
1056
 
952
1057
  # Set default output path if not provided
953
1058
  if output_path is None:
@@ -961,7 +1066,10 @@ class ObjectDetector:
961
1066
 
962
1067
  # Create dataset
963
1068
  dataset = CustomDataset(
964
- raster_path=raster_path, chip_size=chip_size, verbose=verbose
1069
+ raster_path=raster_path,
1070
+ chip_size=chip_size,
1071
+ overlap=overlap,
1072
+ verbose=verbose,
965
1073
  )
966
1074
 
967
1075
  # Store a flag to avoid repetitive messages
@@ -1836,8 +1944,7 @@ class CarDetector(ObjectDetector):
1836
1944
  """
1837
1945
  Car detection using a pre-trained Mask R-CNN model.
1838
1946
 
1839
- This class extends the
1840
- `ObjectDetector` class with additional methods for car detection."
1947
+ This class extends the `ObjectDetector` class with additional methods for car detection.
1841
1948
  """
1842
1949
 
1843
1950
  def __init__(
@@ -1856,6 +1963,274 @@ class CarDetector(ObjectDetector):
1856
1963
  model_path=model_path, repo_id=repo_id, model=model, device=device
1857
1964
  )
1858
1965
 
1966
+ def generate_masks(
1967
+ self,
1968
+ raster_path,
1969
+ output_path=None,
1970
+ confidence_threshold=None,
1971
+ mask_threshold=None,
1972
+ overlap=0.25,
1973
+ batch_size=4,
1974
+ verbose=False,
1975
+ ):
1976
+ """
1977
+ Save masks with confidence values as a multi-band GeoTIFF.
1978
+
1979
+ Args:
1980
+ raster_path: Path to input raster
1981
+ output_path: Path for output GeoTIFF
1982
+ confidence_threshold: Minimum confidence score (0.0-1.0)
1983
+ mask_threshold: Threshold for mask binarization (0.0-1.0)
1984
+ overlap: Overlap between tiles (0.0-1.0)
1985
+ batch_size: Batch size for processing
1986
+ verbose: Whether to print detailed processing information
1987
+
1988
+ Returns:
1989
+ Path to the saved GeoTIFF
1990
+ """
1991
+ # Use provided thresholds or fall back to instance defaults
1992
+ if confidence_threshold is None:
1993
+ confidence_threshold = self.confidence_threshold
1994
+ if mask_threshold is None:
1995
+ mask_threshold = self.mask_threshold
1996
+
1997
+ # Default output path
1998
+ if output_path is None:
1999
+ output_path = os.path.splitext(raster_path)[0] + "_masks_conf.tif"
2000
+
2001
+ # Process the raster to get individual masks with confidence
2002
+ with rasterio.open(raster_path) as src:
2003
+ # Create dataset with the specified overlap
2004
+ dataset = CustomDataset(
2005
+ raster_path=raster_path,
2006
+ chip_size=self.chip_size,
2007
+ overlap=overlap,
2008
+ verbose=verbose,
2009
+ )
2010
+
2011
+ # Create output profile
2012
+ output_profile = src.profile.copy()
2013
+ output_profile.update(
2014
+ dtype=rasterio.uint8,
2015
+ count=2, # Two bands: mask and confidence
2016
+ compress="lzw",
2017
+ nodata=0,
2018
+ )
2019
+
2020
+ # Initialize mask and confidence arrays
2021
+ mask_array = np.zeros((src.height, src.width), dtype=np.uint8)
2022
+ conf_array = np.zeros((src.height, src.width), dtype=np.uint8)
2023
+
2024
+ # Define custom collate function to handle Shapely objects
2025
+ def custom_collate(batch):
2026
+ """
2027
+ Custom collate function that handles Shapely geometries
2028
+ by keeping them as Python objects rather than trying to collate them.
2029
+ """
2030
+ elem = batch[0]
2031
+ if isinstance(elem, dict):
2032
+ result = {}
2033
+ for key in elem:
2034
+ if key == "bbox":
2035
+ # Don't collate shapely objects, keep as list
2036
+ result[key] = [d[key] for d in batch]
2037
+ else:
2038
+ # For tensors and other collatable types
2039
+ try:
2040
+ result[key] = (
2041
+ torch.utils.data._utils.collate.default_collate(
2042
+ [d[key] for d in batch]
2043
+ )
2044
+ )
2045
+ except TypeError:
2046
+ # Fall back to list for non-collatable types
2047
+ result[key] = [d[key] for d in batch]
2048
+ return result
2049
+ else:
2050
+ # Default collate for non-dict types
2051
+ return torch.utils.data._utils.collate.default_collate(batch)
2052
+
2053
+ # Create dataloader with custom collate function
2054
+ dataloader = torch.utils.data.DataLoader(
2055
+ dataset,
2056
+ batch_size=batch_size,
2057
+ shuffle=False,
2058
+ num_workers=0,
2059
+ collate_fn=custom_collate,
2060
+ )
2061
+
2062
+ # Process batches
2063
+ print(f"Processing raster with {len(dataloader)} batches")
2064
+ for batch in tqdm(dataloader):
2065
+ # Move images to device
2066
+ images = batch["image"].to(self.device)
2067
+ coords = batch["coords"] # Tensor of shape [batch_size, 2]
2068
+
2069
+ # Run inference
2070
+ with torch.no_grad():
2071
+ predictions = self.model(images)
2072
+
2073
+ # Process predictions
2074
+ for idx, prediction in enumerate(predictions):
2075
+ masks = prediction["masks"].cpu().numpy()
2076
+ scores = prediction["scores"].cpu().numpy()
2077
+
2078
+ # Filter by confidence threshold
2079
+ valid_indices = scores >= confidence_threshold
2080
+ masks = masks[valid_indices]
2081
+ scores = scores[valid_indices]
2082
+
2083
+ # Skip if no valid predictions
2084
+ if len(masks) == 0:
2085
+ continue
2086
+
2087
+ # Get window coordinates
2088
+ i, j = coords[idx].cpu().numpy()
2089
+
2090
+ # Process each mask
2091
+ for mask_idx, mask in enumerate(masks):
2092
+ # Convert to binary mask
2093
+ binary_mask = (mask[0] > mask_threshold).astype(np.uint8) * 255
2094
+ conf_value = int(scores[mask_idx] * 255) # Scale to 0-255
2095
+
2096
+ # Update the mask and confidence arrays
2097
+ h, w = binary_mask.shape
2098
+ valid_h = min(h, src.height - j)
2099
+ valid_w = min(w, src.width - i)
2100
+
2101
+ if valid_h > 0 and valid_w > 0:
2102
+ # Use maximum for overlapping regions in the mask
2103
+ mask_array[j : j + valid_h, i : i + valid_w] = np.maximum(
2104
+ mask_array[j : j + valid_h, i : i + valid_w],
2105
+ binary_mask[:valid_h, :valid_w],
2106
+ )
2107
+
2108
+ # For confidence, only update where mask is positive
2109
+ # and confidence is higher than existing
2110
+ mask_region = binary_mask[:valid_h, :valid_w] > 0
2111
+ if np.any(mask_region):
2112
+ # Only update where mask is positive and new confidence is higher
2113
+ current_conf = conf_array[
2114
+ j : j + valid_h, i : i + valid_w
2115
+ ]
2116
+
2117
+ # Where to update confidence (mask positive & higher confidence)
2118
+ update_mask = np.logical_and(
2119
+ mask_region,
2120
+ np.logical_or(
2121
+ current_conf == 0, current_conf < conf_value
2122
+ ),
2123
+ )
2124
+
2125
+ if np.any(update_mask):
2126
+ conf_array[j : j + valid_h, i : i + valid_w][
2127
+ update_mask
2128
+ ] = conf_value
2129
+
2130
+ # Write to GeoTIFF
2131
+ with rasterio.open(output_path, "w", **output_profile) as dst:
2132
+ dst.write(mask_array, 1)
2133
+ dst.write(conf_array, 2)
2134
+
2135
+ print(f"Masks with confidence values saved to {output_path}")
2136
+ return output_path
2137
+
2138
+ def vectorize_masks(self, masks_path, output_path=None, **kwargs):
2139
+ """
2140
+ Convert masks with confidence to vector polygons.
2141
+
2142
+ Args:
2143
+ masks_path: Path to masks GeoTIFF with confidence band
2144
+ output_path: Path for output GeoJSON
2145
+ **kwargs: Additional parameters
2146
+
2147
+ Returns:
2148
+ GeoDataFrame with car detections and confidence values
2149
+ """
2150
+
2151
+ print(f"Processing masks from: {masks_path}")
2152
+
2153
+ with rasterio.open(masks_path) as src:
2154
+ # Read mask and confidence bands
2155
+ mask_data = src.read(1)
2156
+ conf_data = src.read(2)
2157
+ transform = src.transform
2158
+ crs = src.crs
2159
+
2160
+ # Convert to binary mask
2161
+ binary_mask = mask_data > 0
2162
+
2163
+ # Find connected components
2164
+ labeled_mask, num_features = ndimage.label(binary_mask)
2165
+ print(f"Found {num_features} connected components")
2166
+
2167
+ # Process each component
2168
+ car_polygons = []
2169
+ car_confidences = []
2170
+
2171
+ # Add progress bar
2172
+ for label in tqdm(range(1, num_features + 1), desc="Processing components"):
2173
+ # Create mask for this component
2174
+ component_mask = (labeled_mask == label).astype(np.uint8)
2175
+
2176
+ # Get confidence value (mean of non-zero values in this region)
2177
+ conf_region = conf_data[component_mask > 0]
2178
+ if len(conf_region) > 0:
2179
+ confidence = (
2180
+ np.mean(conf_region) / 255.0
2181
+ ) # Convert back to 0-1 range
2182
+ else:
2183
+ confidence = 0.0
2184
+
2185
+ # Find contours
2186
+ contours, _ = cv2.findContours(
2187
+ component_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
2188
+ )
2189
+
2190
+ for contour in contours:
2191
+ # Filter by size
2192
+ area = cv2.contourArea(contour)
2193
+ if area < kwargs.get("min_object_area", 100):
2194
+ continue
2195
+
2196
+ # Get minimum area rectangle
2197
+ rect = cv2.minAreaRect(contour)
2198
+ box_points = cv2.boxPoints(rect)
2199
+
2200
+ # Convert to geographic coordinates
2201
+ geo_points = []
2202
+ for x, y in box_points:
2203
+ gx, gy = transform * (x, y)
2204
+ geo_points.append((gx, gy))
2205
+
2206
+ # Create polygon
2207
+ poly = Polygon(geo_points)
2208
+
2209
+ # Add to lists
2210
+ car_polygons.append(poly)
2211
+ car_confidences.append(confidence)
2212
+
2213
+ # Create GeoDataFrame
2214
+ if car_polygons:
2215
+ gdf = gpd.GeoDataFrame(
2216
+ {
2217
+ "geometry": car_polygons,
2218
+ "confidence": car_confidences,
2219
+ "class": [1] * len(car_polygons),
2220
+ },
2221
+ crs=crs,
2222
+ )
2223
+
2224
+ # Save to file if requested
2225
+ if output_path:
2226
+ gdf.to_file(output_path, driver="GeoJSON")
2227
+ print(f"Saved {len(gdf)} cars with confidence to {output_path}")
2228
+
2229
+ return gdf
2230
+ else:
2231
+ print("No valid car polygons found")
2232
+ return None
2233
+
1859
2234
 
1860
2235
  class ShipDetector(ObjectDetector):
1861
2236
  """
geoai/geoai.py CHANGED
@@ -1,5 +1,4 @@
1
1
  """Main module."""
2
2
 
3
3
  from .utils import *
4
- from .preprocess import *
5
4
  from .extract import *