geoai-py 0.8.3__py2.py3-none-any.whl → 0.9.1__py2.py3-none-any.whl

@@ -0,0 +1,1568 @@
1
+ """Change detection module for remote sensing imagery using torchange."""
2
+
3
+ import os
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+
6
+ import cv2
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import rasterio
10
+ from rasterio.windows import from_bounds
11
+ from skimage.transform import resize
12
+ from torchange.models.segment_any_change import AnyChange, show_change_masks
13
+
14
+ from .utils import download_file
15
+
16
+
17
+ class ChangeDetection:
18
+ """A class for change detection on geospatial imagery using torchange and SAM."""
19
+
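+ # Example (hypothetical paths; a minimal usage sketch):
+ #
+ #   cd = ChangeDetection(sam_model_type="vit_b")
+ #   masks, img1, img2 = cd.detect_changes("t1.tif", "t2.tif")
+ #   fig = cd.visualize_changes("t1.tif", "t2.tif")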
20
+ def __init__(self, sam_model_type="vit_h", sam_checkpoint=None):
21
+ """
22
+ Initialize the ChangeDetection class.
23
+
24
+ Args:
25
+ sam_model_type (str): SAM model type ('vit_h', 'vit_l', 'vit_b')
26
+ sam_checkpoint (str): Path to SAM checkpoint file. If None, the checkpoint is downloaded automatically.
27
+ """
28
+ self.sam_model_type = sam_model_type
29
+ self.sam_checkpoint = sam_checkpoint
30
+ self.model = None
31
+ self._init_model()
32
+
33
+ def _init_model(self):
34
+ """Initialize the AnyChange model."""
35
+ if self.sam_checkpoint is None:
36
+ self.sam_checkpoint = download_checkpoint(self.sam_model_type)
37
+
38
+ self.model = AnyChange(self.sam_model_type, sam_checkpoint=self.sam_checkpoint)
39
+
40
+ # Set default hyperparameters
41
+ self.model.make_mask_generator(
42
+ points_per_side=32,
43
+ stability_score_thresh=0.95,
44
+ )
45
+ self.model.set_hyperparameters(
46
+ change_confidence_threshold=145,
47
+ use_normalized_feature=True,
48
+ bitemporal_match=True,
49
+ )
50
+
51
+ def set_hyperparameters(
52
+ self,
53
+ change_confidence_threshold: int = 155,
54
+ auto_threshold: bool = False,
55
+ use_normalized_feature: bool = True,
56
+ area_thresh: float = 0.8,
57
+ match_hist: bool = False,
58
+ object_sim_thresh: int = 60,
59
+ bitemporal_match: bool = True,
60
+ **kwargs: Any,
61
+ ) -> None:
62
+ """
63
+ Set hyperparameters for the change detection model.
64
+
65
+ Args:
66
+ change_confidence_threshold (int): Change confidence threshold for SAM
67
+ auto_threshold (bool): Whether to determine the change threshold automatically
68
+ use_normalized_feature (bool): Whether to use normalized feature for SAM
69
+ area_thresh (float): Area threshold for filtering masks
70
+ match_hist (bool): Whether to apply histogram matching between the two images
71
+ object_sim_thresh (int): Object similarity threshold for SAM
72
+ bitemporal_match (bool): Whether to use bitemporal matching
73
+ **kwargs: Keyword arguments for model hyperparameters
74
+ """
75
+ if self.model:
76
+ self.model.set_hyperparameters(
77
+ change_confidence_threshold=change_confidence_threshold,
78
+ auto_threshold=auto_threshold,
79
+ use_normalized_feature=use_normalized_feature,
80
+ area_thresh=area_thresh,
81
+ match_hist=match_hist,
82
+ object_sim_thresh=object_sim_thresh,
83
+ bitemporal_match=bitemporal_match,
84
+ **kwargs,
85
+ )
86
+
87
+ def set_mask_generator_params(
88
+ self,
89
+ points_per_side: int = 32,
90
+ points_per_batch: int = 64,
91
+ pred_iou_thresh: float = 0.5,
92
+ stability_score_thresh: float = 0.95,
93
+ stability_score_offset: float = 1.0,
94
+ box_nms_thresh: float = 0.7,
95
+ point_grids: Optional[List] = None,
96
+ min_mask_region_area: int = 0,
97
+ **kwargs: Any,
98
+ ) -> None:
99
+ """
100
+ Set mask generator parameters.
101
+
102
+ Args:
103
+ points_per_side (int): Number of points per side for SAM
104
+ points_per_batch (int): Number of points per batch for SAM
105
+ pred_iou_thresh (float): Predicted IoU threshold for filtering masks
106
+ stability_score_thresh (float): Stability score threshold for SAM
107
+ stability_score_offset (float): Stability score offset for SAM
108
+ box_nms_thresh (float): Box non-maximum suppression (NMS) threshold
109
+ point_grids (list): Point grids for SAM
110
+ min_mask_region_area (int): Minimum mask region area for SAM
111
+ **kwargs: Keyword arguments for mask generator
112
+ """
113
+ if self.model:
114
+ self.model.make_mask_generator(
115
+ points_per_side=points_per_side,
116
+ points_per_batch=points_per_batch,
117
+ pred_iou_thresh=pred_iou_thresh,
118
+ stability_score_thresh=stability_score_thresh,
119
+ stability_score_offset=stability_score_offset,
120
+ box_nms_thresh=box_nms_thresh,
121
+ point_grids=point_grids,
122
+ min_mask_region_area=min_mask_region_area,
123
+ **kwargs,
124
+ )
125
+
126
+ def _read_and_align_images(self, image1_path, image2_path, target_size=1024):
127
+ """
128
+ Read and align two GeoTIFF images, handling different extents and projections.
129
+
130
+ Args:
131
+ image1_path (str): Path to first image
132
+ image2_path (str): Path to second image
133
+ target_size (int): Target size for processing (default 1024 for torchange)
134
+
135
+ Returns:
136
+ tuple: (aligned_img1, aligned_img2, transform, crs, original_shape)
137
+ """
138
+ with rasterio.open(image1_path) as src1, rasterio.open(image2_path) as src2:
139
+ # Get the intersection of bounds
140
+ bounds1 = src1.bounds
141
+ bounds2 = src2.bounds
142
+
143
+ # Calculate intersection bounds
144
+ left = max(bounds1.left, bounds2.left)
145
+ bottom = max(bounds1.bottom, bounds2.bottom)
146
+ right = min(bounds1.right, bounds2.right)
147
+ top = min(bounds1.top, bounds2.top)
148
+
149
+ if left >= right or bottom >= top:
150
+ raise ValueError("Images do not overlap")
151
+
152
+ intersection_bounds = (left, bottom, right, top)
153
+
154
+ # Read the intersecting area from both images
155
+ window1 = from_bounds(*intersection_bounds, src1.transform)
156
+ window2 = from_bounds(*intersection_bounds, src2.transform)
157
+
158
+ # Read data
159
+ img1_data = src1.read(window=window1)
160
+ img2_data = src2.read(window=window2)
161
+
162
+ # Get transform for the intersecting area
163
+ transform = src1.window_transform(window1)
164
+ crs = src1.crs
165
+
166
+ # Convert from (bands, height, width) to (height, width, bands)
167
+ img1_data = np.transpose(img1_data, (1, 2, 0))
168
+ img2_data = np.transpose(img2_data, (1, 2, 0))
169
+
170
+ # Use only RGB bands (first 3 channels) for torchange
171
+ if img1_data.shape[2] >= 3:
172
+ img1_data = img1_data[:, :, :3]
173
+ if img2_data.shape[2] >= 3:
174
+ img2_data = img2_data[:, :, :3]
175
+
176
+ # Normalize to 0-255 range if needed
177
+ if img1_data.dtype != np.uint8:
178
+ img1_data = (
179
+ (img1_data - img1_data.min())
180
+ / (img1_data.max() - img1_data.min())
181
+ * 255
182
+ ).astype(np.uint8)
183
+ if img2_data.dtype != np.uint8:
184
+ img2_data = (
185
+ (img2_data - img2_data.min())
186
+ / (img2_data.max() - img2_data.min())
187
+ * 255
188
+ ).astype(np.uint8)
189
+
190
+ # Store original size for later use
191
+ original_shape = img1_data.shape[:2]
192
+
193
+ # Resize to target size for torchange processing
194
+ if img1_data.shape[0] != target_size or img1_data.shape[1] != target_size:
195
+ img1_resized = resize(
196
+ img1_data, (target_size, target_size), preserve_range=True
197
+ ).astype(np.uint8)
198
+ img2_resized = resize(
199
+ img2_data, (target_size, target_size), preserve_range=True
200
+ ).astype(np.uint8)
201
+ else:
202
+ img1_resized = img1_data
203
+ img2_resized = img2_data
204
+
205
+ return (img1_resized, img2_resized, transform, crs, original_shape)
206
+
207
+ def detect_changes(
208
+ self,
209
+ image1_path: str,
210
+ image2_path: str,
211
+ output_path: Optional[str] = None,
212
+ target_size: int = 1024,
213
+ return_results: bool = True,
214
+ export_probability: bool = False,
215
+ probability_output_path: Optional[str] = None,
216
+ export_instance_masks: bool = False,
217
+ instance_masks_output_path: Optional[str] = None,
218
+ return_detailed_results: bool = False,
219
+ ) -> Union[Tuple[Any, np.ndarray, np.ndarray], Dict[str, Any], None]:
220
+ """
221
+ Detect changes between two GeoTIFF images with instance segmentation.
222
+
223
+ Args:
224
+ image1_path (str): Path to first image
225
+ image2_path (str): Path to second image
226
+ output_path (str): Optional path to save binary change mask as GeoTIFF
227
+ target_size (int): Target size for processing
228
+ return_results (bool): Whether to return results
229
+ export_probability (bool): Whether to export probability mask
230
+ probability_output_path (str): Path to save probability mask (required if export_probability=True)
231
+ export_instance_masks (bool): Whether to export instance segmentation masks
232
+ instance_masks_output_path (str): Path to save instance masks (required if export_instance_masks=True)
233
+ return_detailed_results (bool): Whether to return detailed mask information
234
+
235
+ Returns:
236
+ tuple: (change_masks, img1, img2) if return_results=True
237
+ dict: Detailed results if return_detailed_results=True (takes precedence over return_results)
238
+ """
239
+ # Read and align images
240
+ (img1, img2, transform, crs, original_shape) = self._read_and_align_images(
241
+ image1_path, image2_path, target_size
242
+ )
243
+
244
+ # Detect changes
245
+ change_masks, _, _ = self.model.forward(img1, img2)
246
+
247
+ # If output path specified, save binary mask as GeoTIFF
248
+ if output_path:
249
+ self._save_change_mask(
250
+ change_masks, output_path, transform, crs, original_shape, target_size
251
+ )
252
+
253
+ # If probability export requested, save probability mask
254
+ if export_probability:
255
+ if probability_output_path is None:
256
+ raise ValueError(
257
+ "probability_output_path must be specified when export_probability=True"
258
+ )
259
+ self._save_probability_mask(
260
+ change_masks,
261
+ probability_output_path,
262
+ transform,
263
+ crs,
264
+ original_shape,
265
+ target_size,
266
+ )
267
+
268
+ # If instance masks export requested, save instance segmentation masks
269
+ if export_instance_masks:
270
+ if instance_masks_output_path is None:
271
+ raise ValueError(
272
+ "instance_masks_output_path must be specified when export_instance_masks=True"
273
+ )
274
+ num_instances = self._save_instance_segmentation_masks(
275
+ change_masks,
276
+ instance_masks_output_path,
277
+ transform,
278
+ crs,
279
+ original_shape,
280
+ target_size,
281
+ )
282
+
283
+ # Also save per-instance confidence scores alongside the instance masks
284
+ scores_path = instance_masks_output_path.replace(".tif", "_scores.tif")
285
+ self._save_instance_scores_mask(
286
+ change_masks,
287
+ scores_path,
288
+ transform,
289
+ crs,
290
+ original_shape,
291
+ target_size,
292
+ )
293
+
294
+ # Return detailed results if requested
295
+ if return_detailed_results:
296
+ return self._extract_detailed_results(
297
+ change_masks, transform, crs, original_shape, target_size
298
+ )
299
+
300
+ if return_results:
301
+ return change_masks, img1, img2
302
+
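+ # Hedged usage sketch (hypothetical paths), given a ChangeDetection
+ # instance `cd`: export every product in one call and inspect the
+ # detailed results dict.
+ #
+ #   results = cd.detect_changes(
+ #       "t1.tif", "t2.tif",
+ #       output_path="binary.tif",
+ #       export_probability=True, probability_output_path="prob.tif",
+ #       export_instance_masks=True, instance_masks_output_path="inst.tif",
+ #       return_detailed_results=True,
+ #   )
+ #   print(results["summary"]["total_masks"])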
303
+ def _save_change_mask(
304
+ self, change_masks, output_path, transform, crs, original_shape, target_size
305
+ ):
306
+ """
307
+ Save change masks as a GeoTIFF with proper georeferencing.
308
+
309
+ Args:
310
+ change_masks: Change detection masks (MaskData object)
311
+ output_path (str): Output file path
312
+ transform: Rasterio transform
313
+ crs: Coordinate reference system
314
+ original_shape (tuple): Original image shape
315
+ target_size (int): Processing target size
316
+ """
317
+ # Convert MaskData to binary mask by decoding RLE masks
318
+ combined_mask = np.zeros((target_size, target_size), dtype=bool)
319
+
320
+ # Extract RLE masks from MaskData object
321
+ mask_items = dict(change_masks.items())
322
+ if "rles" in mask_items:
323
+ rles = mask_items["rles"]
324
+ for rle in rles:
325
+ if isinstance(rle, dict) and "size" in rle and "counts" in rle:
326
+ try:
327
+ # Decode RLE to binary mask
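+ # COCO-style RLE: counts are run lengths over the flattened,
+ # column-major mask, starting with a run of zeros; e.g. counts
+ # [2, 3, 4] decode to 0 0 1 1 1 0 0 0 0.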
328
+ size = rle["size"]
329
+ counts = rle["counts"]
330
+
331
+ # Create binary mask from RLE counts
332
+ mask = np.zeros(size[0] * size[1], dtype=np.uint8)
333
+ pos = 0
334
+ value = 0
335
+
336
+ for count in counts:
337
+ if pos + count <= len(mask):
338
+ if value == 1:
339
+ mask[pos : pos + count] = 1
340
+ pos += count
341
+ value = 1 - value # Toggle between 0 and 1
342
+ else:
343
+ break
344
+
345
+ # RLE is column-major, reshape and transpose
346
+ mask = mask.reshape(size).T
347
+ if mask.shape == (target_size, target_size):
348
+ combined_mask = np.logical_or(
349
+ combined_mask, mask.astype(bool)
350
+ )
351
+
352
+ except Exception as e:
353
+ print(f"Warning: Failed to decode RLE mask: {e}")
354
+ continue
355
+
356
+ # Convert to uint8 first, then resize if needed
357
+ combined_mask_uint8 = combined_mask.astype(np.uint8) * 255
358
+
359
+ # Resize back to original shape if needed
360
+ if original_shape != (target_size, target_size):
361
+ # Use precise resize
362
+ combined_mask_resized = resize(
363
+ combined_mask_uint8.astype(np.float32),
364
+ original_shape,
365
+ preserve_range=True,
366
+ anti_aliasing=False,
367
+ order=0,
368
+ )
369
+ combined_mask = (combined_mask_resized > 127).astype(np.uint8) * 255
370
+ else:
371
+ combined_mask = combined_mask_uint8
372
+
373
+ # Save as GeoTIFF
374
+ with rasterio.open(
375
+ output_path,
376
+ "w",
377
+ driver="GTiff",
378
+ height=combined_mask.shape[0],
379
+ width=combined_mask.shape[1],
380
+ count=1,
381
+ dtype=combined_mask.dtype,
382
+ crs=crs,
383
+ transform=transform,
384
+ compress="lzw",
385
+ ) as dst:
386
+ dst.write(combined_mask, 1)
387
+
388
+ def _save_probability_mask(
389
+ self, change_masks, output_path, transform, crs, original_shape, target_size
390
+ ):
391
+ """
392
+ Save probability masks as a GeoTIFF with proper georeferencing.
393
+
394
+ Args:
395
+ change_masks: Change detection masks (MaskData object)
396
+ output_path (str): Output file path
397
+ transform: Rasterio transform
398
+ crs: Coordinate reference system
399
+ original_shape (tuple): Original image shape
400
+ target_size (int): Processing target size
401
+ """
402
+ # Extract mask components for probability calculation
403
+ mask_items = dict(change_masks.items())
404
+ rles = mask_items.get("rles", [])
405
+ iou_preds = mask_items.get("iou_preds", None)
406
+ stability_scores = mask_items.get("stability_score", None)
407
+ change_confidence = mask_items.get("change_confidence", None)
408
+ areas = mask_items.get("areas", None)
409
+
410
+ # Convert tensors to numpy if needed
411
+ if iou_preds is not None:
412
+ iou_preds = iou_preds.detach().cpu().numpy()
413
+ if stability_scores is not None:
414
+ stability_scores = stability_scores.detach().cpu().numpy()
415
+ if change_confidence is not None:
416
+ change_confidence = change_confidence.detach().cpu().numpy()
417
+ if areas is not None:
418
+ areas = areas.detach().cpu().numpy()
419
+
420
+ # Create probability mask
421
+ probability_mask = np.zeros((target_size, target_size), dtype=np.float32)
422
+
423
+ # Process each mask with probability weighting
424
+ for i, rle in enumerate(rles):
425
+ if isinstance(rle, dict) and "size" in rle and "counts" in rle:
426
+ try:
427
+ # Decode RLE to binary mask
428
+ size = rle["size"]
429
+ counts = rle["counts"]
430
+
431
+ mask = np.zeros(size[0] * size[1], dtype=np.uint8)
432
+ pos = 0
433
+ value = 0
434
+
435
+ for count in counts:
436
+ if pos + count <= len(mask):
437
+ if value == 1:
438
+ mask[pos : pos + count] = 1
439
+ pos += count
440
+ value = 1 - value
441
+ else:
442
+ break
443
+
444
+ mask = mask.reshape(size).T
445
+ if mask.shape != (target_size, target_size):
446
+ continue
447
+
448
+ mask_bool = mask.astype(bool)
449
+
450
+ # Calculate probability using multiple factors
451
+ prob_components = []
452
+
453
+ # IoU prediction (0-1, higher is better)
454
+ if iou_preds is not None and i < len(iou_preds):
455
+ iou_score = float(iou_preds[i])
456
+ prob_components.append(("iou", iou_score))
457
+ else:
458
+ prob_components.append(("iou", 0.8))
459
+
460
+ # Stability score (0-1, higher is better)
461
+ if stability_scores is not None and i < len(stability_scores):
462
+ stability = float(stability_scores[i])
463
+ prob_components.append(("stability", stability))
464
+ else:
465
+ prob_components.append(("stability", 0.8))
466
+
467
+ # Change confidence (normalize based on threshold)
468
+ if change_confidence is not None and i < len(change_confidence):
469
+ conf = float(change_confidence[i])
470
+ # Normalize confidence: threshold is 145, values above indicate higher confidence
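+ # e.g. conf=72.5 -> 0.25, conf=145 -> 0.5, conf>=217.5 -> capped at 1.0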
471
+ if conf >= 145:
472
+ conf_normalized = 0.5 + min(0.5, (conf - 145) / 145)
473
+ else:
474
+ conf_normalized = max(0.0, conf / 145 * 0.5)
475
+ prob_components.append(("confidence", conf_normalized))
476
+ else:
477
+ prob_components.append(("confidence", 0.5))
478
+
479
+ # Area-based weight (normalize using log scale)
480
+ if areas is not None and i < len(areas):
481
+ area = float(areas[i])
482
+ area_normalized = 0.2 + 0.8 * min(1.0, np.log(area + 1) / 15.0)
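+ # e.g. area=1000 px -> 0.2 + 0.8*ln(1001)/15 ≈ 0.57; saturates near e^15 ≈ 3.3M px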
483
+ prob_components.append(("area", area_normalized))
484
+ else:
485
+ prob_components.append(("area", 0.6))
486
+
487
+ # Calculate weighted probability
488
+ weights = {
489
+ "iou": 0.3,
490
+ "stability": 0.3,
491
+ "confidence": 0.35,
492
+ "area": 0.05,
493
+ }
494
+ prob_weight = sum(
495
+ weights[name] * value for name, value in prob_components
496
+ )
497
+ prob_weight = np.clip(prob_weight, 0.0, 1.0)
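+ # e.g. iou=0.9, stability=0.95, confidence=0.8, area=0.6 gives
+ # 0.3*0.9 + 0.3*0.95 + 0.35*0.8 + 0.05*0.6 = 0.865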
498
+
499
+ # Add to probability mask (take maximum where masks overlap)
500
+ current_prob = probability_mask[mask_bool]
501
+ new_prob = np.maximum(current_prob, prob_weight)
502
+ probability_mask[mask_bool] = new_prob
503
+
504
+ except Exception as e:
505
+ print(f"Warning: Failed to process probability mask {i}: {e}")
506
+ continue
507
+
508
+ # Resize back to original shape if needed
509
+ if original_shape != (target_size, target_size):
510
+ prob_resized = resize(
511
+ probability_mask,
512
+ original_shape,
513
+ preserve_range=True,
514
+ anti_aliasing=True,
515
+ order=1,
516
+ )
517
+ prob_final = np.clip(prob_resized, 0.0, 1.0)
518
+ else:
519
+ prob_final = probability_mask
520
+
521
+ # Save as float32 GeoTIFF
522
+ with rasterio.open(
523
+ output_path,
524
+ "w",
525
+ driver="GTiff",
526
+ height=prob_final.shape[0],
527
+ width=prob_final.shape[1],
528
+ count=1,
529
+ dtype=rasterio.float32,
530
+ crs=crs,
531
+ transform=transform,
532
+ compress="lzw",
533
+ ) as dst:
534
+ dst.write(prob_final.astype(np.float32), 1)
535
+
536
+ def visualize_changes(
537
+ self, image1_path: str, image2_path: str, figsize: Tuple[int, int] = (15, 5)
538
+ ) -> plt.Figure:
539
+ """
540
+ Visualize change detection results.
541
+
542
+ Args:
543
+ image1_path (str): Path to first image
544
+ image2_path (str): Path to second image
545
+ figsize (tuple): Figure size
546
+
547
+ Returns:
548
+ matplotlib.figure.Figure: The figure object
549
+ """
550
+ change_masks, img1, img2 = self.detect_changes(
551
+ image1_path, image2_path, return_results=True
552
+ )
553
+
554
+ # Use torchange's visualization function
555
+ fig, _ = show_change_masks(img1, img2, change_masks)
556
+ fig.set_size_inches(figsize)
557
+
558
+ return fig
559
+
560
+ def visualize_results(self, image1_path, image2_path, binary_path, prob_path):
561
+ """Create enhanced visualization with probability analysis."""
562
+
563
+ # Load data
564
+ with rasterio.open(image1_path) as src:
565
+ img1 = src.read([1, 2, 3])
566
+ img1 = np.transpose(img1, (1, 2, 0))
567
+
568
+ with rasterio.open(image2_path) as src:
569
+ img2 = src.read([1, 2, 3])
570
+ img2 = np.transpose(img2, (1, 2, 0))
571
+
572
+ with rasterio.open(binary_path) as src:
573
+ binary_mask = src.read(1)
574
+
575
+ with rasterio.open(prob_path) as src:
576
+ prob_mask = src.read(1)
577
+
578
+ # Create comprehensive visualization
579
+ fig, axes = plt.subplots(2, 4, figsize=(24, 12))
580
+
581
+ # Crop for better visualization
582
+ h, w = img1.shape[:2]
583
+ y1, y2 = h // 4, 3 * h // 4
584
+ x1, x2 = w // 4, 3 * w // 4
585
+
586
+ img1_crop = img1[y1:y2, x1:x2]
587
+ img2_crop = img2[y1:y2, x1:x2]
588
+ binary_crop = binary_mask[y1:y2, x1:x2]
589
+ prob_crop = prob_mask[y1:y2, x1:x2]
590
+
591
+ # Row 1: Original and overlays
592
+ axes[0, 0].imshow(img1_crop)
593
+ axes[0, 0].set_title("2019 Image", fontweight="bold")
594
+ axes[0, 0].axis("off")
595
+
596
+ axes[0, 1].imshow(img2_crop)
597
+ axes[0, 1].set_title("2022 Image", fontweight="bold")
598
+ axes[0, 1].axis("off")
599
+
600
+ # Binary overlay
601
+ overlay_binary = img2_crop.copy()
602
+ overlay_binary[binary_crop > 0] = [255, 0, 0]
603
+ axes[0, 2].imshow(overlay_binary)
604
+ axes[0, 2].set_title("Binary Changes\n(Red = Change)", fontweight="bold")
605
+ axes[0, 2].axis("off")
606
+
607
+ # Probability heatmap
608
+ im1 = axes[0, 3].imshow(prob_crop, cmap="hot", vmin=0, vmax=1)
609
+ axes[0, 3].set_title(
610
+ "Probability Heatmap\n(White = High Confidence)", fontweight="bold"
611
+ )
612
+ axes[0, 3].axis("off")
613
+ plt.colorbar(im1, ax=axes[0, 3], shrink=0.8)
614
+
615
+ # Row 2: Detailed probability analysis
616
+ # Confidence levels overlay
617
+ overlay_conf = img2_crop.copy()
618
+ high_conf = prob_crop > 0.7
619
+ med_conf = (prob_crop > 0.4) & (prob_crop <= 0.7)
620
+ low_conf = (prob_crop > 0.1) & (prob_crop <= 0.4)
621
+
622
+ overlay_conf[high_conf] = [255, 0, 0] # Red for high
623
+ overlay_conf[med_conf] = [255, 165, 0] # Orange for medium
624
+ overlay_conf[low_conf] = [255, 255, 0] # Yellow for low
625
+
626
+ axes[1, 0].imshow(overlay_conf)
627
+ axes[1, 0].set_title(
628
+ "Confidence Levels\n(Red>0.7, Orange>0.4, Yellow>0.1)", fontweight="bold"
629
+ )
630
+ axes[1, 0].axis("off")
631
+
632
+ # Thresholded probability (>0.5)
633
+ overlay_thresh = img2_crop.copy()
634
+ high_prob = prob_crop > 0.5
635
+ overlay_thresh[high_prob] = [255, 0, 0]
636
+ axes[1, 1].imshow(overlay_thresh)
637
+ axes[1, 1].set_title(
638
+ "High Confidence Only\n(Probability > 0.5)", fontweight="bold"
639
+ )
640
+ axes[1, 1].axis("off")
641
+
642
+ # Probability histogram
643
+ prob_values = prob_crop[prob_crop > 0]
644
+ if len(prob_values) > 0:
645
+ axes[1, 2].hist(
646
+ prob_values, bins=50, alpha=0.7, color="red", edgecolor="black"
647
+ )
648
+ axes[1, 2].axvline(
649
+ x=0.5, color="blue", linestyle="--", label="0.5 threshold"
650
+ )
651
+ axes[1, 2].axvline(
652
+ x=0.7, color="green", linestyle="--", label="0.7 threshold"
653
+ )
654
+ axes[1, 2].set_xlabel("Change Probability")
655
+ axes[1, 2].set_ylabel("Pixel Count")
656
+ axes[1, 2].set_title(
657
+ f"Probability Distribution\n({len(prob_values):,} pixels)"
658
+ )
659
+ axes[1, 2].legend()
660
+ axes[1, 2].grid(True, alpha=0.3)
661
+
662
+ # Statistics text
663
+ stats_text = f"""Probability Statistics:
664
+ Min: {np.min(prob_values):.3f}
665
+ Max: {np.max(prob_values):.3f}
666
+ Mean: {np.mean(prob_values):.3f}
667
+ Median: {np.median(prob_values):.3f}
668
+
669
+ Confidence Levels:
670
+ High (>0.7): {np.sum(prob_crop > 0.7):,}
671
+ Med (0.4-0.7): {np.sum((prob_crop > 0.4) & (prob_crop <= 0.7)):,}
672
+ Low (0.1-0.4): {np.sum((prob_crop > 0.1) & (prob_crop <= 0.4)):,}"""
673
+
674
+ axes[1, 3].text(
675
+ 0.05,
676
+ 0.95,
677
+ stats_text,
678
+ transform=axes[1, 3].transAxes,
679
+ fontsize=11,
680
+ verticalalignment="top",
681
+ fontfamily="monospace",
682
+ )
683
+ axes[1, 3].set_xlim(0, 1)
684
+ axes[1, 3].set_ylim(0, 1)
685
+ axes[1, 3].axis("off")
686
+ axes[1, 3].set_title("Statistics Summary", fontweight="bold")
687
+
688
+ plt.tight_layout()
689
+ plt.suptitle(
690
+ "Enhanced Probability-Based Change Detection",
691
+ fontsize=16,
692
+ fontweight="bold",
693
+ y=0.98,
694
+ )
695
+
696
+ plt.savefig("enhanced_probability_results.png", dpi=150, bbox_inches="tight")
697
+ plt.show()
698
+
699
+ print("💾 Enhanced visualization saved as 'enhanced_probability_results.png'")
700
+
701
+ def create_split_comparison(
702
+ self,
703
+ image1_path,
704
+ image2_path,
705
+ binary_path,
706
+ prob_path,
707
+ output_path="split_comparison.png",
708
+ ):
709
+ """Create a split comparison visualization showing before/after with change overlay."""
710
+
711
+ # Load data
712
+ with rasterio.open(image1_path) as src:
713
+ img1 = src.read([1, 2, 3])
714
+ img1 = np.transpose(img1, (1, 2, 0))
715
+ if img1.dtype != np.uint8:
716
+ img1 = ((img1 - img1.min()) / (img1.max() - img1.min()) * 255).astype(
717
+ np.uint8
718
+ )
719
+
720
+ with rasterio.open(image2_path) as src:
721
+ img2 = src.read([1, 2, 3])
722
+ img2 = np.transpose(img2, (1, 2, 0))
723
+ if img2.dtype != np.uint8:
724
+ img2 = ((img2 - img2.min()) / (img2.max() - img2.min()) * 255).astype(
725
+ np.uint8
726
+ )
727
+
728
+ with rasterio.open(prob_path) as src:
729
+ prob_mask = src.read(1)
730
+
731
+ # Ensure all arrays have the same shape
732
+ h, w = img1.shape[:2]
733
+ if prob_mask.shape != (h, w):
734
+ prob_mask = resize(
735
+ prob_mask, (h, w), preserve_range=True, anti_aliasing=True, order=1
736
+ )
737
+
738
+ # Create split comparison
739
+ fig, ax = plt.subplots(1, 1, figsize=(15, 10))
740
+
741
+ # Create combined image - left half is image 1, right half is image 2
742
+ combined_img = np.zeros_like(img1)
743
+ combined_img[:, : w // 2] = img1[:, : w // 2]
744
+ combined_img[:, w // 2 :] = img2[:, w // 2 :]
745
+
746
+ # Create overlay with changes - ensure prob_mask is 2D and matches image dimensions
747
+ overlay = combined_img.copy()
748
+ high_conf_changes = prob_mask > 0.5
749
+
750
+ # Apply overlay only where changes are detected
751
+ if len(overlay.shape) == 3: # RGB image
752
+ overlay[high_conf_changes] = [255, 0, 0] # Red for high confidence changes
753
+
754
+ # Blend overlay with original
755
+ blended = cv2.addWeighted(combined_img, 0.7, overlay, 0.3, 0)
756
+
757
+ ax.imshow(blended)
758
+ ax.axvline(x=w // 2, color="white", linewidth=3, linestyle="--", alpha=0.8)
759
+ ax.text(
760
+ w // 4,
761
+ 50,
762
+ "2019",
763
+ fontsize=20,
764
+ color="white",
765
+ ha="center",
766
+ bbox={"boxstyle": "round,pad=0.3", "facecolor": "black", "alpha": 0.8},
767
+ )
768
+ ax.text(
769
+ 3 * w // 4,
770
+ 50,
771
+ "2022",
772
+ fontsize=20,
773
+ color="white",
774
+ ha="center",
775
+ bbox={"boxstyle": "round,pad=0.3", "facecolor": "black", "alpha": 0.8},
776
+ )
777
+
778
+ ax.set_title(
779
+ "Split Comparison with Change Detection\n(Red = High Confidence Changes)",
780
+ fontsize=16,
781
+ fontweight="bold",
782
+ pad=20,
783
+ )
784
+ ax.axis("off")
785
+
786
+ plt.tight_layout()
787
+ plt.savefig(output_path, dpi=150, bbox_inches="tight")
788
+ plt.show()
789
+
790
+ print(f"💾 Split comparison saved as '{output_path}'")
791
+
792
+ def analyze_instances(
793
+ self, instance_mask_path, scores_path, output_path="instance_analysis.png"
794
+ ):
795
+ """Analyze and visualize instance segmentation results."""
796
+
797
+ # Load instance mask and scores
798
+ with rasterio.open(instance_mask_path) as src:
799
+ instance_mask = src.read(1)
800
+
801
+ with rasterio.open(scores_path) as src:
802
+ scores_mask = src.read(1)
803
+
804
+ # Get unique instances (excluding background)
805
+ unique_instances = np.unique(instance_mask)
806
+ unique_instances = unique_instances[unique_instances > 0]
807
+
808
+ # Calculate statistics for each instance
809
+ instance_stats = []
810
+ for instance_id in unique_instances:
811
+ mask = instance_mask == instance_id
812
+ area = np.sum(mask)
813
+ score = np.mean(scores_mask[mask])
814
+ instance_stats.append({"id": instance_id, "area": area, "score": score})
815
+
816
+ # Sort by score
817
+ instance_stats.sort(key=lambda x: x["score"], reverse=True)
818
+
819
+ # Create visualization
820
+ fig, axes = plt.subplots(2, 2, figsize=(16, 12))
821
+
822
+ # 1. Instance segmentation visualization
823
+ colored_mask = np.zeros((*instance_mask.shape, 3), dtype=np.uint8)
824
+ colors = plt.cm.Set3(np.linspace(0, 1, len(unique_instances)))
825
+
826
+ for i, instance_id in enumerate(unique_instances):
827
+ mask = instance_mask == instance_id
828
+ colored_mask[mask] = (colors[i][:3] * 255).astype(np.uint8)
829
+
830
+ axes[0, 0].imshow(colored_mask)
831
+ axes[0, 0].set_title(
832
+ f"Instance Segmentation\n({len(unique_instances)} instances)",
833
+ fontweight="bold",
834
+ )
835
+ axes[0, 0].axis("off")
836
+
837
+ # 2. Scores heatmap
838
+ im = axes[0, 1].imshow(scores_mask, cmap="viridis", vmin=0, vmax=1)
839
+ axes[0, 1].set_title("Instance Confidence Scores", fontweight="bold")
840
+ axes[0, 1].axis("off")
841
+ plt.colorbar(im, ax=axes[0, 1], shrink=0.8)
842
+
843
+ # 3. Score distribution
844
+ all_scores = [stat["score"] for stat in instance_stats]
845
+ axes[1, 0].hist(
846
+ all_scores, bins=20, alpha=0.7, color="skyblue", edgecolor="black"
847
+ )
848
+ axes[1, 0].axvline(
849
+ x=np.mean(all_scores),
850
+ color="red",
851
+ linestyle="--",
852
+ label=f"Mean: {np.mean(all_scores):.3f}",
853
+ )
854
+ axes[1, 0].set_xlabel("Confidence Score")
855
+ axes[1, 0].set_ylabel("Instance Count")
856
+ axes[1, 0].set_title("Score Distribution", fontweight="bold")
857
+ axes[1, 0].legend()
858
+ axes[1, 0].grid(True, alpha=0.3)
859
+
860
+ # 4. Top instances by score
861
+ top_instances = instance_stats[:10]
862
+ instance_ids = [stat["id"] for stat in top_instances]
863
+ scores = [stat["score"] for stat in top_instances]
864
+ areas = [stat["area"] for stat in top_instances]
865
+
866
+ bars = axes[1, 1].bar(
867
+ range(len(top_instances)), scores, color="coral", alpha=0.7
868
+ )
869
+ axes[1, 1].set_xlabel("Top 10 Instances")
870
+ axes[1, 1].set_ylabel("Confidence Score")
871
+ axes[1, 1].set_title("Top Instances by Confidence", fontweight="bold")
872
+ axes[1, 1].set_xticks(range(len(top_instances)))
873
+ axes[1, 1].set_xticklabels([f"#{id}" for id in instance_ids], rotation=45)
874
+
875
+ # Add area info as text on bars
876
+ for i, (bar, area) in enumerate(zip(bars, areas)):
877
+ height = bar.get_height()
878
+ axes[1, 1].text(
879
+ bar.get_x() + bar.get_width() / 2.0,
880
+ height,
881
+ f"{area}px",
882
+ ha="center",
883
+ va="bottom",
884
+ fontsize=8,
885
+ )
886
+
887
+ plt.tight_layout()
888
+ plt.savefig(output_path, dpi=150, bbox_inches="tight")
889
+ plt.show()
890
+
891
+ # Print summary statistics
892
+ print(f"\n📊 Instance Analysis Summary:")
893
+ print(f" Total instances: {len(unique_instances)}")
894
+ print(f" Average confidence: {np.mean(all_scores):.3f}")
895
+ print(f" Score range: {np.min(all_scores):.3f} - {np.max(all_scores):.3f}")
896
+ print(f" Total change area: {sum(areas):,} pixels")
897
+
898
+ print(f"\n💾 Instance analysis saved as '{output_path}'")
899
+
900
+ return instance_stats
901
+
902
+ def create_comprehensive_report(
903
+ self, results_dict, output_path="comprehensive_report.png"
904
+ ):
905
+ """Create a comprehensive visualization report from detailed results."""
906
+
907
+ if not results_dict or "masks" not in results_dict:
908
+ print("❌ No detailed results provided")
909
+ return
910
+
911
+ masks = results_dict["masks"]
912
+ stats = results_dict["statistics"]
913
+
914
+ # Create comprehensive visualization
915
+ fig, axes = plt.subplots(2, 3, figsize=(18, 12))
916
+
917
+ # 1. Score distributions
918
+ if "iou_predictions" in stats:
919
+ iou_scores = [
920
+ mask["iou_pred"] for mask in masks if mask["iou_pred"] is not None
921
+ ]
922
+ axes[0, 0].hist(
923
+ iou_scores, bins=20, alpha=0.7, color="lightblue", edgecolor="black"
924
+ )
925
+ axes[0, 0].axvline(
926
+ x=stats["iou_predictions"]["mean"],
927
+ color="red",
928
+ linestyle="--",
929
+ label=f"Mean: {stats['iou_predictions']['mean']:.3f}",
930
+ )
931
+ axes[0, 0].set_xlabel("IoU Score")
932
+ axes[0, 0].set_ylabel("Count")
933
+ axes[0, 0].set_title("IoU Predictions Distribution", fontweight="bold")
934
+ axes[0, 0].legend()
935
+ axes[0, 0].grid(True, alpha=0.3)
936
+
937
+ # 2. Stability scores
938
+ if "stability_scores" in stats:
939
+ stability_scores = [
940
+ mask["stability_score"]
941
+ for mask in masks
942
+ if mask["stability_score"] is not None
943
+ ]
944
+ axes[0, 1].hist(
945
+ stability_scores,
946
+ bins=20,
947
+ alpha=0.7,
948
+ color="lightgreen",
949
+ edgecolor="black",
950
+ )
951
+ axes[0, 1].axvline(
952
+ x=stats["stability_scores"]["mean"],
953
+ color="red",
954
+ linestyle="--",
955
+ label=f"Mean: {stats['stability_scores']['mean']:.3f}",
956
+ )
957
+ axes[0, 1].set_xlabel("Stability Score")
958
+ axes[0, 1].set_ylabel("Count")
959
+ axes[0, 1].set_title("Stability Scores Distribution", fontweight="bold")
960
+ axes[0, 1].legend()
961
+ axes[0, 1].grid(True, alpha=0.3)
962
+
963
+ # 3. Change confidence
964
+ if "change_confidence" in stats:
965
+ change_conf = [
966
+ mask["change_confidence"]
967
+ for mask in masks
968
+ if mask["change_confidence"] is not None
969
+ ]
970
+ axes[0, 2].hist(
971
+ change_conf, bins=20, alpha=0.7, color="lightyellow", edgecolor="black"
972
+ )
973
+ axes[0, 2].axvline(
974
+ x=stats["change_confidence"]["mean"],
975
+ color="red",
976
+ linestyle="--",
977
+ label=f"Mean: {stats['change_confidence']['mean']:.1f}",
978
+ )
979
+ axes[0, 2].set_xlabel("Change Confidence")
980
+ axes[0, 2].set_ylabel("Count")
981
+ axes[0, 2].set_title("Change Confidence Distribution", fontweight="bold")
982
+ axes[0, 2].legend()
983
+ axes[0, 2].grid(True, alpha=0.3)
984
+
985
+ # 4. Area distribution
986
+ if "areas" in stats:
987
+ areas = [mask["area"] for mask in masks if mask["area"] is not None]
988
+ axes[1, 0].hist(
989
+ areas, bins=20, alpha=0.7, color="lightcoral", edgecolor="black"
990
+ )
991
+ axes[1, 0].axvline(
992
+ x=stats["areas"]["mean"],
993
+ color="red",
994
+ linestyle="--",
995
+ label=f"Mean: {stats['areas']['mean']:.1f}",
996
+ )
997
+ axes[1, 0].set_xlabel("Area (pixels)")
998
+ axes[1, 0].set_ylabel("Count")
999
+ axes[1, 0].set_title("Area Distribution", fontweight="bold")
1000
+ axes[1, 0].legend()
1001
+ axes[1, 0].grid(True, alpha=0.3)
1002
+
1003
+ # 5. Combined confidence vs area scatter
1004
+ combined_conf = [
1005
+ mask["combined_confidence"]
1006
+ for mask in masks
1007
+ if "combined_confidence" in mask
1008
+ ]
1009
+ areas_for_scatter = [
1010
+ mask["area"]
1011
+ for mask in masks
1012
+ if "combined_confidence" in mask and mask["area"] is not None
1013
+ ]
1014
+
1015
+ if combined_conf and areas_for_scatter:
1016
+ scatter = axes[1, 1].scatter(
1017
+ areas_for_scatter,
1018
+ combined_conf,
1019
+ alpha=0.6,
1020
+ c=combined_conf,
1021
+ cmap="viridis",
1022
+ s=50,
1023
+ )
1024
+ axes[1, 1].set_xlabel("Area (pixels)")
1025
+ axes[1, 1].set_ylabel("Combined Confidence")
1026
+ axes[1, 1].set_title("Confidence vs Area", fontweight="bold")
1027
+ axes[1, 1].grid(True, alpha=0.3)
1028
+ plt.colorbar(scatter, ax=axes[1, 1], shrink=0.8)
1029
+
1030
+ # 6. Summary statistics text
1031
+ summary_text = f"""Detection Summary:
1032
+ Total Instances: {len(masks)}
1033
+ Processing Size: {results_dict['summary']['target_size']}
1034
+ Original Shape: {results_dict['summary']['original_shape']}
1035
+
1036
+ Quality Metrics:"""
1037
+
1038
+ if "iou_predictions" in stats:
1039
+ summary_text += f"""
1040
+ IoU Predictions:
1041
+ Mean: {stats['iou_predictions']['mean']:.3f}
1042
+ Range: {stats['iou_predictions']['min']:.3f} - {stats['iou_predictions']['max']:.3f}"""
1043
+
1044
+ if "stability_scores" in stats:
1045
+ summary_text += f"""
1046
+ Stability Scores:
1047
+ Mean: {stats['stability_scores']['mean']:.3f}
1048
+ Range: {stats['stability_scores']['min']:.3f} - {stats['stability_scores']['max']:.3f}"""
1049
+
1050
+ if "change_confidence" in stats:
1051
+ summary_text += f"""
1052
+ Change Confidence:
1053
+ Mean: {stats['change_confidence']['mean']:.1f}
1054
+ Range: {stats['change_confidence']['min']:.1f} - {stats['change_confidence']['max']:.1f}"""
1055
+
1056
+ if "areas" in stats:
1057
+ summary_text += f"""
1058
+ Areas:
1059
+ Mean: {stats['areas']['mean']:.1f}
1060
+ Total: {stats['areas']['total']:,.0f} pixels"""
1061
+
1062
+ axes[1, 2].text(
1063
+ 0.05,
1064
+ 0.95,
1065
+ summary_text,
1066
+ transform=axes[1, 2].transAxes,
1067
+ fontsize=10,
1068
+ verticalalignment="top",
1069
+ fontfamily="monospace",
1070
+ )
1071
+ axes[1, 2].set_xlim(0, 1)
1072
+ axes[1, 2].set_ylim(0, 1)
1073
+ axes[1, 2].axis("off")
1074
+ axes[1, 2].set_title("Summary Statistics", fontweight="bold")
1075
+
1076
+ plt.tight_layout()
1077
+ plt.suptitle(
1078
+ "Comprehensive Change Detection Report",
1079
+ fontsize=16,
1080
+ fontweight="bold",
1081
+ y=0.98,
1082
+ )
1083
+ plt.savefig(output_path, dpi=150, bbox_inches="tight")
1084
+ plt.show()
1085
+
1086
+ print(f"💾 Comprehensive report saved as '{output_path}'")
1087
+
1088
+ def run_complete_analysis(
1089
+ self, image1_path, image2_path, output_dir="change_detection_results"
1090
+ ):
1091
+ """Run complete change detection analysis with all outputs and visualizations."""
1092
+
1093
+ # Create output directory
1094
+ os.makedirs(output_dir, exist_ok=True)
1095
+
1096
+ # Define output paths
1097
+ binary_path = os.path.join(output_dir, "binary_mask.tif")
1098
+ prob_path = os.path.join(output_dir, "probability_mask.tif")
1099
+ instance_path = os.path.join(output_dir, "instance_masks.tif")
1100
+
1101
+ print("🔍 Running complete change detection analysis...")
1102
+
1103
+ # Run detection with all outputs
1104
+ results = self.detect_changes(
1105
+ image1_path,
1106
+ image2_path,
1107
+ output_path=binary_path,
1108
+ export_probability=True,
1109
+ probability_output_path=prob_path,
1110
+ export_instance_masks=True,
1111
+ instance_masks_output_path=instance_path,
1112
+ return_detailed_results=True,
1113
+ return_results=False,
1114
+ )
1115
+
1116
+ print("📊 Creating visualizations...")
1117
+
1118
+ # Create all visualizations
1119
+ self.visualize_results(image1_path, image2_path, binary_path, prob_path)
1120
+
1121
+ self.create_split_comparison(
1122
+ image1_path,
1123
+ image2_path,
1124
+ binary_path,
1125
+ prob_path,
1126
+ os.path.join(output_dir, "split_comparison.png"),
1127
+ )
1128
+
1129
+ scores_path = instance_path.replace(".tif", "_scores.tif")
1130
+ self.analyze_instances(
1131
+ instance_path,
1132
+ scores_path,
1133
+ os.path.join(output_dir, "instance_analysis.png"),
1134
+ )
1135
+
1136
+ self.create_comprehensive_report(
1137
+ results, os.path.join(output_dir, "comprehensive_report.png")
1138
+ )
1139
+
1140
+ print(f"✅ Complete analysis finished! Results saved to: {output_dir}")
1141
+ return results
1142
+
1143
+ def _save_instance_segmentation_masks(
1144
+ self, change_masks, output_path, transform, crs, original_shape, target_size
1145
+ ):
1146
+ """
1147
+ Save instance segmentation masks as a single GeoTIFF where each instance has a unique ID.
1148
+
1149
+ Args:
1150
+ change_masks: Change detection masks (MaskData object)
1151
+ output_path (str): Output path for instance segmentation GeoTIFF
1152
+ transform: Rasterio transform
1153
+ crs: Coordinate reference system
1154
+ original_shape (tuple): Original image shape
1155
+ target_size (int): Processing target size
1156
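+ 
+ Returns:
+ int: Number of instance masks written.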
+ """
1157
+ # Extract mask components
1158
+ mask_items = dict(change_masks.items())
1159
+ rles = mask_items.get("rles", [])
1160
+
1161
+ # Create instance segmentation mask (each instance gets unique ID)
1162
+ instance_mask = np.zeros((target_size, target_size), dtype=np.uint16)
1163
+
1164
+ # Process each mask and assign unique instance ID
1165
+ for instance_id, rle in enumerate(rles, start=1):
1166
+ if isinstance(rle, dict) and "size" in rle and "counts" in rle:
1167
+ try:
1168
+ # Decode RLE to binary mask
1169
+ size = rle["size"]
1170
+ counts = rle["counts"]
1171
+
1172
+ mask = np.zeros(size[0] * size[1], dtype=np.uint8)
1173
+ pos = 0
1174
+ value = 0
1175
+
1176
+ for count in counts:
1177
+ if pos + count <= len(mask):
1178
+ if value == 1:
1179
+ mask[pos : pos + count] = 1
1180
+ pos += count
1181
+ value = 1 - value
1182
+ else:
1183
+ break
1184
+
1185
+ # RLE is column-major, reshape and transpose
1186
+ mask = mask.reshape(size).T
1187
+ if mask.shape != (target_size, target_size):
1188
+ continue
1189
+
1190
+ # Assign instance ID to this mask
1191
+ instance_mask[mask.astype(bool)] = instance_id
1192
+
1193
+ except Exception as e:
1194
+ print(f"Warning: Failed to process mask {instance_id}: {e}")
1195
+ continue
1196
+
1197
+ # Resize back to original shape if needed
1198
+ if original_shape != (target_size, target_size):
1199
+ instance_mask_resized = resize(
1200
+ instance_mask.astype(np.float32),
1201
+ original_shape,
1202
+ preserve_range=True,
1203
+ anti_aliasing=False,
1204
+ order=0,
1205
+ )
1206
+ instance_mask_final = np.round(instance_mask_resized).astype(np.uint16)
1207
+ else:
1208
+ instance_mask_final = instance_mask
1209
+
1210
+ # Save as GeoTIFF
1211
+ with rasterio.open(
1212
+ output_path,
1213
+ "w",
1214
+ driver="GTiff",
1215
+ height=instance_mask_final.shape[0],
1216
+ width=instance_mask_final.shape[1],
1217
+ count=1,
1218
+ dtype=instance_mask_final.dtype,
1219
+ crs=crs,
1220
+ transform=transform,
1221
+ compress="lzw",
1222
+ ) as dst:
1223
+ dst.write(instance_mask_final, 1)
1224
+
1225
+ # Add metadata
1226
+ dst.update_tags(
1227
+ description="Instance segmentation mask with unique IDs for each change object",
1228
+ total_instances=str(len(rles)),
1229
+ background_value="0",
1230
+ instance_range=f"1-{len(rles)}",
1231
+ )
1232
+
1233
+ print(
1234
+ f"Saved instance segmentation mask with {len(rles)} instances to {output_path}"
1235
+ )
1236
+ return len(rles)
1237
+
1238
+ def _save_instance_scores_mask(
1239
+ self, change_masks, output_path, transform, crs, original_shape, target_size
1240
+ ):
1241
+ """
1242
+ Save instance scores/probability mask as a GeoTIFF where each instance has its confidence score.
1243
+
1244
+ Args:
1245
+ change_masks: Change detection masks (MaskData object)
1246
+ output_path (str): Output path for instance scores GeoTIFF
1247
+ transform: Rasterio transform
1248
+ crs: Coordinate reference system
1249
+ original_shape (tuple): Original image shape
1250
+ target_size (int): Processing target size
1251
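+ 
+ Returns:
+ int: Number of instance masks written.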
+ """
1252
+ # Extract mask components
1253
+ mask_items = dict(change_masks.items())
1254
+ rles = mask_items.get("rles", [])
1255
+ iou_preds = mask_items.get("iou_preds", None)
1256
+ stability_scores = mask_items.get("stability_score", None)
1257
+ change_confidence = mask_items.get("change_confidence", None)
1258
+
1259
+ # Convert tensors to numpy if needed
1260
+ if iou_preds is not None:
1261
+ iou_preds = iou_preds.detach().cpu().numpy()
1262
+ if stability_scores is not None:
1263
+ stability_scores = stability_scores.detach().cpu().numpy()
1264
+ if change_confidence is not None:
1265
+ change_confidence = change_confidence.detach().cpu().numpy()
1266
+
1267
+ # Create instance scores mask
1268
+ scores_mask = np.zeros((target_size, target_size), dtype=np.float32)
1269
+
1270
+ # Process each mask and assign confidence score
1271
+ for instance_id, rle in enumerate(rles):
1272
+ if isinstance(rle, dict) and "size" in rle and "counts" in rle:
1273
+ try:
1274
+ # Decode RLE to binary mask
1275
+ size = rle["size"]
1276
+ counts = rle["counts"]
1277
+
1278
+ mask = np.zeros(size[0] * size[1], dtype=np.uint8)
1279
+ pos = 0
1280
+ value = 0
1281
+
1282
+ for count in counts:
1283
+ if pos + count <= len(mask):
1284
+ if value == 1:
1285
+ mask[pos : pos + count] = 1
1286
+ pos += count
1287
+ value = 1 - value
1288
+ else:
1289
+ break
1290
+
1291
+ # RLE is column-major, reshape and transpose
1292
+ mask = mask.reshape(size).T
1293
+ if mask.shape != (target_size, target_size):
1294
+ continue
1295
+
1296
+ # Calculate combined confidence score
1297
+ confidence_score = 0.5 # Default
1298
+ if iou_preds is not None and instance_id < len(iou_preds):
1299
+ iou_score = float(iou_preds[instance_id])
1300
+
1301
+ if stability_scores is not None and instance_id < len(
1302
+ stability_scores
1303
+ ):
1304
+ stability_score = float(stability_scores[instance_id])
1305
+
1306
+ if change_confidence is not None and instance_id < len(
1307
+ change_confidence
1308
+ ):
1309
+ change_conf = float(change_confidence[instance_id])
1310
+ # Normalize change confidence (typically around 145 threshold)
1311
+ change_conf_norm = max(
1312
+ 0.0, min(1.0, abs(change_conf) / 200.0)
1313
+ )
1314
+
1315
+ # Weighted combination of scores
1316
+ confidence_score = (
1317
+ 0.35 * iou_score
1318
+ + 0.35 * stability_score
1319
+ + 0.3 * change_conf_norm
1320
+ )
1321
+ else:
1322
+ confidence_score = 0.5 * (iou_score + stability_score)
1323
+ else:
1324
+ confidence_score = iou_score
1325
+
1326
+ # Assign confidence score to this mask
1327
+ scores_mask[mask.astype(bool)] = confidence_score
1328
+
1329
+ except Exception as e:
1330
+ print(
1331
+ f"Warning: Failed to process scores for mask {instance_id}: {e}"
1332
+ )
1333
+ continue
1334
+
1335
+ # Resize back to original shape if needed
1336
+ if original_shape != (target_size, target_size):
1337
+ scores_mask_resized = resize(
1338
+ scores_mask,
1339
+ original_shape,
1340
+ preserve_range=True,
1341
+ anti_aliasing=True,
1342
+ order=1,
1343
+ )
1344
+ scores_mask_final = np.clip(scores_mask_resized, 0.0, 1.0).astype(
1345
+ np.float32
1346
+ )
1347
+ else:
1348
+ scores_mask_final = scores_mask
1349
+
1350
+ # Save as GeoTIFF
1351
+ with rasterio.open(
1352
+ output_path,
1353
+ "w",
1354
+ driver="GTiff",
1355
+ height=scores_mask_final.shape[0],
1356
+ width=scores_mask_final.shape[1],
1357
+ count=1,
1358
+ dtype=scores_mask_final.dtype,
1359
+ crs=crs,
1360
+ transform=transform,
1361
+ compress="lzw",
1362
+ ) as dst:
1363
+ dst.write(scores_mask_final, 1)
1364
+
1365
+ # Add metadata
1366
+ dst.update_tags(
1367
+ description="Instance scores mask with confidence values for each change object",
1368
+ total_instances=str(len(rles)),
1369
+ background_value="0.0",
1370
+ score_range="0.0-1.0",
1371
+ )
1372
+
1373
+ print(f"Saved instance scores mask with {len(rles)} instances to {output_path}")
1374
+ return len(rles)
1375
+
1376
+ def _extract_detailed_results(
1377
+ self, change_masks, transform, crs, original_shape, target_size
1378
+ ):
1379
+ """
1380
+ Extract detailed results from change masks.
1381
+
1382
+ Args:
1383
+ change_masks: Change detection masks (MaskData object)
1384
+ transform: Rasterio transform
1385
+ crs: Coordinate reference system
1386
+ original_shape (tuple): Original image shape
1387
+ target_size (int): Processing target size
1388
+
1389
+ Returns:
1390
+ dict: Detailed results with mask information and statistics
1391
+ """
1392
+ # Extract mask components
1393
+ mask_items = dict(change_masks.items())
1394
+ rles = mask_items.get("rles", [])
1395
+ iou_preds = mask_items.get("iou_preds", None)
1396
+ stability_scores = mask_items.get("stability_score", None)
1397
+ change_confidence = mask_items.get("change_confidence", None)
1398
+ areas = mask_items.get("areas", None)
1399
+ boxes = mask_items.get("boxes", None)
1400
+ points = mask_items.get("points", None)
1401
+
1402
+ # Convert tensors to numpy if needed
1403
+ if iou_preds is not None:
1404
+ iou_preds = iou_preds.detach().cpu().numpy()
1405
+ if stability_scores is not None:
1406
+ stability_scores = stability_scores.detach().cpu().numpy()
1407
+ if change_confidence is not None:
1408
+ change_confidence = change_confidence.detach().cpu().numpy()
1409
+ if areas is not None:
1410
+ areas = areas.detach().cpu().numpy()
1411
+ if boxes is not None:
1412
+ boxes = boxes.detach().cpu().numpy()
1413
+ if points is not None:
1414
+ points = points.detach().cpu().numpy()
1415
+
1416
+ # Calculate statistics
1417
+ results = {
1418
+ "summary": {
1419
+ "total_masks": len(rles),
1420
+ "target_size": target_size,
1421
+ "original_shape": original_shape,
1422
+ "crs": str(crs),
1423
+ "transform": transform.to_gdal(),
1424
+ },
1425
+ "statistics": {},
1426
+ "masks": [],
1427
+ }
1428
+
1429
+ # Calculate statistics for each metric
1430
+ if iou_preds is not None and len(iou_preds) > 0:
1431
+ results["statistics"]["iou_predictions"] = {
1432
+ "mean": float(np.mean(iou_preds)),
1433
+ "std": float(np.std(iou_preds)),
1434
+ "min": float(np.min(iou_preds)),
1435
+ "max": float(np.max(iou_preds)),
1436
+ "median": float(np.median(iou_preds)),
1437
+ }
1438
+
1439
+ if stability_scores is not None and len(stability_scores) > 0:
1440
+ results["statistics"]["stability_scores"] = {
1441
+ "mean": float(np.mean(stability_scores)),
1442
+ "std": float(np.std(stability_scores)),
1443
+ "min": float(np.min(stability_scores)),
1444
+ "max": float(np.max(stability_scores)),
1445
+ "median": float(np.median(stability_scores)),
1446
+ }
1447
+
1448
+ if change_confidence is not None and len(change_confidence) > 0:
1449
+ results["statistics"]["change_confidence"] = {
1450
+ "mean": float(np.mean(change_confidence)),
1451
+ "std": float(np.std(change_confidence)),
1452
+ "min": float(np.min(change_confidence)),
1453
+ "max": float(np.max(change_confidence)),
1454
+ "median": float(np.median(change_confidence)),
1455
+ }
1456
+
1457
+ if areas is not None and len(areas) > 0:
1458
+ results["statistics"]["areas"] = {
1459
+ "mean": float(np.mean(areas)),
1460
+ "std": float(np.std(areas)),
1461
+ "min": float(np.min(areas)),
1462
+ "max": float(np.max(areas)),
1463
+ "median": float(np.median(areas)),
1464
+ "total": float(np.sum(areas)),
1465
+ }
1466
+
1467
+ # Extract individual mask details
1468
+ for i in range(len(rles)):
1469
+ mask_info = {
1470
+ "mask_id": i,
1471
+ "iou_pred": (
1472
+ float(iou_preds[i])
1473
+ if iou_preds is not None and i < len(iou_preds)
1474
+ else None
1475
+ ),
1476
+ "stability_score": (
1477
+ float(stability_scores[i])
1478
+ if stability_scores is not None and i < len(stability_scores)
1479
+ else None
1480
+ ),
1481
+ "change_confidence": (
1482
+ float(change_confidence[i])
1483
+ if change_confidence is not None and i < len(change_confidence)
1484
+ else None
1485
+ ),
1486
+ "area": int(areas[i]) if areas is not None and i < len(areas) else None,
1487
+ "bbox": (
1488
+ boxes[i].tolist() if boxes is not None and i < len(boxes) else None
1489
+ ),
1490
+ "center_point": (
1491
+ points[i].tolist()
1492
+ if points is not None and i < len(points)
1493
+ else None
1494
+ ),
1495
+ }
1496
+
1497
+ # Calculate combined confidence score
1498
+ if all(
1499
+ v is not None
1500
+ for v in [
1501
+ mask_info["iou_pred"],
1502
+ mask_info["stability_score"],
1503
+ mask_info["change_confidence"],
1504
+ ]
1505
+ ):
1506
+ # Normalize change confidence (145 is typical threshold)
1507
+ conf_norm = max(0.0, min(1.0, mask_info["change_confidence"] / 145.0))
1508
+ combined_score = (
1509
+ 0.3 * mask_info["iou_pred"]
1510
+ + 0.3 * mask_info["stability_score"]
1511
+ + 0.4 * conf_norm
1512
+ )
1513
+ mask_info["combined_confidence"] = float(combined_score)
1514
+
1515
+ results["masks"].append(mask_info)
1516
+
1517
+ # Sort masks by combined confidence if available
1518
+ if results["masks"] and "combined_confidence" in results["masks"][0]:
1519
+ results["masks"].sort(key=lambda x: x["combined_confidence"], reverse=True)
1520
+
1521
+ return results
1522
+
1523
+
1524
+ def download_checkpoint(
1525
+ model_type: str = "vit_h", checkpoint_dir: Optional[str] = None
1526
+ ) -> str:
1527
+ """Download the SAM model checkpoint.
1528
+
1529
+ Args:
1530
+ model_type (str, optional): The model type. Can be one of ['vit_h', 'vit_l', 'vit_b'].
1531
+ Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
1532
+ checkpoint_dir (str, optional): The checkpoint directory. Defaults to None,
1533
+ which uses the TORCH_HOME environment variable if set, otherwise "~/.cache/torch/hub/checkpoints".
1534
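+ 
+ Returns:
+ str: Path to the SAM checkpoint file.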
+ """
1535
+
1536
+ model_types = {
1537
+ "vit_h": {
1538
+ "name": "sam_vit_h_4b8939.pth",
1539
+ "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
1540
+ },
1541
+ "vit_l": {
1542
+ "name": "sam_vit_l_0b3195.pth",
1543
+ "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
1544
+ },
1545
+ "vit_b": {
1546
+ "name": "sam_vit_b_01ec64.pth",
1547
+ "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
1548
+ },
1549
+ }
1550
+
1551
+ if model_type not in model_types:
1552
+ raise ValueError(
1553
+ f"Invalid model_type: {model_type}. It must be one of {', '.join(model_types)}"
1554
+ )
1555
+
1556
+ if checkpoint_dir is None:
1557
+ checkpoint_dir = os.environ.get(
1558
+ "TORCH_HOME", os.path.expanduser("~/.cache/torch/hub/checkpoints")
1559
+ )
1560
+
1561
+ checkpoint = os.path.join(checkpoint_dir, model_types[model_type]["name"])
1562
+ if not os.path.exists(checkpoint):
1563
+ print(f"Model checkpoint for {model_type} not found.")
1564
+ url = model_types[model_type]["url"]
1565
+ if isinstance(url, str):
1566
+ download_file(url, checkpoint)
1567
+
1568
+ return checkpoint
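+ 
+ 
+ # A minimal end-to-end sketch (hypothetical paths):
+ #
+ #   detector = ChangeDetection(sam_model_type="vit_h")
+ #   detector.set_hyperparameters(change_confidence_threshold=155)
+ #   results = detector.run_complete_analysis("before.tif", "after.tif", "results")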