geoai-py 0.4.0__py2.py3-none-any.whl → 0.4.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +76 -14
- geoai/download.py +644 -0
- geoai/extract.py +518 -0
- geoai/geoai.py +9 -1
- geoai/train.py +98 -12
- geoai/utils.py +260 -21
- {geoai_py-0.4.0.dist-info → geoai_py-0.4.2.dist-info}/METADATA +8 -13
- geoai_py-0.4.2.dist-info/RECORD +15 -0
- {geoai_py-0.4.0.dist-info → geoai_py-0.4.2.dist-info}/WHEEL +1 -1
- geoai_py-0.4.0.dist-info/RECORD +0 -15
- {geoai_py-0.4.0.dist-info → geoai_py-0.4.2.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.4.0.dist-info → geoai_py-0.4.2.dist-info/licenses}/LICENSE +0 -0
- {geoai_py-0.4.0.dist-info → geoai_py-0.4.2.dist-info}/top_level.txt +0 -0
geoai/extract.py
CHANGED
|
@@ -2397,3 +2397,521 @@ class ParkingSplotDetector(ObjectDetector):
|
|
|
2397
2397
|
num_classes=num_classes,
|
|
2398
2398
|
device=device,
|
|
2399
2399
|
)
|
|
2400
|
+
|
|
2401
|
+
|
|
2402
|
+
class AgricultureFieldDelineator(ObjectDetector):
    """
    Agricultural field boundary delineation using a Mask R-CNN model.

    This class extends the ObjectDetector class to handle multi-band Sentinel-2
    imagery (the default statistics cover 12 bands) for agricultural field
    boundary detection.

    Band indices used throughout this class (``band_selection``, the NDVI input
    bands, and the positions in ``sentinel_band_stats``) are 0-based positions
    in the raster's band stack, i.e. rasterio band ``index + 1``.

    Attributes:
        band_selection: List of 0-based band indices to use for prediction
            (default: [3, 2, 1], intended as R, G, B).
        sentinel_band_stats: Dict with "means" and "stds" lists, indexed by
            0-based band index, used to standardize each selected band.
        use_ndvi: Whether to calculate and include NDVI as an additional channel.
    """

    def __init__(
        self,
        model_path="field_boundary_detector.pth",
        repo_id=None,
        model=None,
        device=None,
        band_selection=None,
        use_ndvi=False,
    ):
        """
        Initialize the field boundary delineator.

        Args:
            model_path: Path to the .pth model file.
            repo_id: Repo ID for loading models from the Hub.
            model: Custom model to use for inference. If None, a Mask R-CNN
                whose first convolution matches the configured number of
                input channels is built by initialize_sentinel2_model().
            device: Device to use for inference ('cuda:0', 'cpu', etc.).
                None selects 'cuda:0' when CUDA is available, else 'cpu'.
            band_selection: List of 0-based band indices to use
                (None = [3, 2, 1]).
            use_ndvi: Whether to calculate and include NDVI as an additional
                channel (adds one model input channel).
        """
        # Must exist before initialize_sentinel2_model() runs: it derives the
        # number of model input channels from these two settings.
        self.custom_band_selection = band_selection
        self.use_ndvi = use_ndvi

        # initialize_sentinel2_model() also needs the device, and it runs before
        # the parent constructor, so mirror the parent's device selection here.
        if device is None:
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        # Build (or pass through) a model sized for multi-spectral input.
        model = self.initialize_sentinel2_model(model)

        # Let the parent finish setup with our custom model.
        super().__init__(
            model_path=model_path, repo_id=repo_id, model=model, device=device
        )

        # Default Sentinel-2 band statistics (can be overridden with actual
        # stats via update_band_stats). There are 12 values per list; presumably
        # the L2A band order [B1, B2, B3, B4, B5, B6, B7, B8, B8A, B9, B11, B12]
        # with B10 absent -- TODO confirm. Index 3 (red) and 7 (NIR) are the
        # bands used for NDVI.
        self.sentinel_band_stats = {
            "means": [
                0.0975,
                0.0476,
                0.0598,
                0.0697,
                0.1077,
                0.1859,
                0.2378,
                0.2061,
                0.2598,
                0.4120,
                0.1956,
                0.1410,
            ],
            "stds": [
                0.0551,
                0.0290,
                0.0298,
                0.0479,
                0.0506,
                0.0505,
                0.0747,
                0.0642,
                0.0782,
                0.1187,
                0.0651,
                0.0679,
            ],
        }

        # Default band selection is RGB (B4, B3, B2 for Sentinel-2).
        self.band_selection = (
            self.custom_band_selection
            if self.custom_band_selection is not None
            else [3, 2, 1]
        )  # R, G, B bands

        # Customize parameters for field delineation
        self.confidence_threshold = 0.5  # Default confidence threshold
        self.overlap = 0.5  # Higher overlap for field boundary detection
        self.min_object_area = 1000  # Minimum area in pixels for field detection
        self.simplify_tolerance = 2.0  # Higher tolerance for field boundaries

    def initialize_sentinel2_model(self, model=None):
        """
        Build a Mask R-CNN whose first convolution accepts the configured bands.

        Args:
            model: Pre-initialized model (optional). When given, it is returned
                unchanged.

        Returns:
            A Mask R-CNN (ResNet-50 FPN backbone, 2 classes: background + field)
            moved to ``self.device``, with ``len(band_selection)`` input
            channels (3 when no selection was given) plus one if ``use_ndvi``.
        """
        if model is not None:
            return model

        from torchvision.models.detection import MaskRCNN
        from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

        # Determine number of input channels based on band selection and NDVI
        num_input_channels = (
            len(self.custom_band_selection)
            if self.custom_band_selection is not None
            else 3
        )
        if self.use_ndvi:
            num_input_channels += 1

        print(f"Initializing Mask R-CNN model with {num_input_channels} input channels")

        # Untrained ResNet-50 FPN backbone; weights come from model_path later.
        backbone = resnet_fpn_backbone(backbone_name="resnet50", weights=None)

        # Replace the first conv layer so it accepts num_input_channels inputs,
        # keeping every other hyper-parameter of the original layer.
        original_conv = backbone.body.conv1
        backbone.body.conv1 = torch.nn.Conv2d(
            num_input_channels,
            original_conv.out_channels,
            kernel_size=original_conv.kernel_size,
            stride=original_conv.stride,
            padding=original_conv.padding,
            bias=original_conv.bias is not None,
        )

        # Build MaskRCNN directly around our backbone. maskrcnn_resnet50_fpn()
        # constructs its own backbone and forwards **kwargs to MaskRCNN, so
        # passing backbone= to it raises "multiple values for argument
        # 'backbone'". The transform's mean/std are extended to every channel.
        model = MaskRCNN(
            backbone,
            num_classes=2,  # Background + field
            image_mean=[0.485] * num_input_channels,
            image_std=[0.229] * num_input_channels,
        )

        model.to(self.device)
        return model

    def preprocess_sentinel_bands(self, image_data, band_selection=None, use_ndvi=None):
        """
        Preprocess Sentinel-2 band data for model input.

        Args:
            image_data: Raw image data as a numpy array [bands, height, width].
            band_selection: 0-based band indices to use (overrides the instance
                default if provided).
            use_ndvi: Whether to append NDVI as an extra channel (overrides the
                instance default if provided). NDVI is computed from band
                index 7 (B8, NIR) and index 3 (B4, red), so it needs at least
                8 bands; with fewer bands a warning is issued and no NDVI
                channel is added.

        Returns:
            float32 tensor [len(band_selection) (+1 if NDVI added), height, width].
            Each selected band is standardized with ``sentinel_band_stats``
            when stats exist for its index. NDVI is mapped from [-1, 1] to
            [0, 1] and then to roughly [-1, 1].
        """
        import warnings

        # Use instance defaults if not specified
        if band_selection is None:
            band_selection = self.band_selection
        if use_ndvi is None:
            use_ndvi = self.use_ndvi

        # Why a list: numpy treats a list as fancy indexing along the band axis,
        # whereas a tuple would index several axes.
        # Why float32: torch.from_numpy accepts only a limited set of dtypes,
        # and uint16 is not supported in older torch releases.
        band_selection = list(band_selection)
        selected_bands = image_data[band_selection].astype(np.float32)

        ndvi_added = False
        if use_ndvi:
            if image_data.shape[0] > 7:
                nir = image_data[7].astype(np.float32)  # B8 (NIR)
                red = image_data[3].astype(np.float32)  # B4 (Red)

                # Leave NDVI at 0 where nir + red <= 0 to avoid division by zero.
                denominator = nir + red
                ndvi = np.zeros_like(nir)
                valid_mask = denominator > 0
                ndvi[valid_mask] = (nir[valid_mask] - red[valid_mask]) / denominator[
                    valid_mask
                ]

                # Rescale NDVI from [-1, 1] to [0, 1]
                ndvi = (ndvi + 1) / 2

                selected_bands = np.concatenate(
                    [selected_bands, ndvi[np.newaxis, :, :]], axis=0
                )
                ndvi_added = True
            else:
                warnings.warn(
                    f"NDVI requested but image has only {image_data.shape[0]} bands "
                    "(needs band indices 3 and 7); NDVI channel not added."
                )

        image_tensor = torch.from_numpy(selected_bands)

        # Normalize each selected band with the stats stored at its band index.
        means = self.sentinel_band_stats["means"]
        stds = self.sentinel_band_stats["stds"]
        for channel, band_idx in enumerate(band_selection):
            if band_idx < min(len(means), len(stds)):
                image_tensor[channel] = (
                    image_tensor[channel] - means[band_idx]
                ) / stds[band_idx]

        # Only touch the last channel if it really is NDVI.
        if ndvi_added:
            image_tensor[-1] = (image_tensor[-1] - 0.5) / 0.5

        return image_tensor

    def update_band_stats(self, raster_path, band_selection=None, sample_size=1000):
        """
        Update band statistics from randomly sampled pixels of a raster.

        Statistics for the requested bands are written into
        ``self.sentinel_band_stats`` at their 0-based band index. This way they
        line up with the lookups in preprocess_sentinel_bands(). Bands that are
        not requested keep their previous statistics.

        Args:
            raster_path: Path to the Sentinel-2 raster file
            band_selection: 0-based indices of the bands to update
                (None = every band in the raster)
            sample_size: Number of random pixels to sample per band

        Returns:
            Updated band statistics dictionary ({"means": [...], "stds": [...]})
        """
        existing = getattr(self, "sentinel_band_stats", {})
        # Work on copies so the stored stats are replaced only at the end.
        means = list(existing.get("means", []))
        stds = list(existing.get("stds", []))

        with rasterio.open(raster_path) as src:
            # Check if this is likely a Sentinel-2 product
            band_count = src.count
            if band_count < 3:
                print(
                    f"Warning: Raster has only {band_count} bands, may not be Sentinel-2 data"
                )

            height, width = src.height, src.width
            nodata = src.nodata

            if band_selection is None:
                band_selection = list(range(band_count))

            # A local RNG gives the same samples as np.random.seed(42) +
            # np.random.randint, without resetting NumPy's global random state.
            rng = np.random.RandomState(42)
            sample_rows = rng.randint(0, height, sample_size)
            sample_cols = rng.randint(0, width, sample_size)

            for band_idx in band_selection:
                if not 0 <= band_idx < band_count:
                    print(
                        f"Warning: band index {band_idx} is outside 0..{band_count - 1}; skipping"
                    )
                    continue

                band_data = src.read(band_idx + 1)  # rasterio bands are 1-based
                sample_values = band_data[sample_rows, sample_cols].astype(np.float64)

                # Drop non-finite values and the raster's declared nodata value.
                valid = np.isfinite(sample_values)
                if nodata is not None:
                    valid &= sample_values != nodata
                valid_samples = sample_values[valid]
                if valid_samples.size == 0:
                    print(
                        f"Warning: Band {band_idx + 1} has no valid sampled pixels; "
                        "keeping previous statistics"
                    )
                    continue

                mean = float(np.mean(valid_samples))
                std = float(np.std(valid_samples))
                if std == 0:
                    # Constant band: avoid dividing by zero during normalization.
                    std = 1.0

                # Grow the lists (identity stats) if the raster has more bands
                # than the current statistics cover.
                while len(means) <= band_idx:
                    means.append(0.0)
                while len(stds) <= band_idx:
                    stds.append(1.0)
                means[band_idx] = mean
                stds[band_idx] = std

                print(f"Band {band_idx + 1}: mean={mean:.4f}, std={std:.4f}")

        self.sentinel_band_stats = {"means": means, "stds": stds}
        return self.sentinel_band_stats

    def process_sentinel_raster(
        self,
        raster_path,
        output_path=None,
        batch_size=4,
        band_selection=None,
        use_ndvi=None,
        filter_edges=True,
        edge_buffer=20,
        **kwargs,
    ):
        """
        Process a Sentinel-2 raster to extract field boundaries.

        Args:
            raster_path: Path to Sentinel-2 raster file
            output_path: Path to output GeoJSON or Parquet file (optional)
            batch_size: Batch size for processing
            band_selection: 0-based band indices to use (None = instance default)
            use_ndvi: Whether to include NDVI (None = instance default)
            filter_edges: Whether to filter out objects at the edges of the image
            edge_buffer: Size of edge buffer in pixels to filter out objects
            **kwargs: Optional overrides.
                - confidence_threshold, overlap, chip_size, nms_iou_threshold,
                  mask_threshold, min_object_area, simplify_tolerance:
                  default to the matching instance attributes.
                - update_stats (default True): recompute band statistics
                  from this raster first.

        Returns:
            GeoDataFrame with field boundaries, as returned by the parent
            class's process_raster.
        """
        # Use instance defaults if not specified
        if band_selection is None:
            band_selection = self.band_selection
        if use_ndvi is None:
            use_ndvi = self.use_ndvi

        # Get parameters from kwargs or use instance defaults
        confidence_threshold = kwargs.get(
            "confidence_threshold", self.confidence_threshold
        )
        overlap = kwargs.get("overlap", self.overlap)
        chip_size = kwargs.get("chip_size", self.chip_size)
        nms_iou_threshold = kwargs.get("nms_iou_threshold", self.nms_iou_threshold)
        mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
        min_object_area = kwargs.get("min_object_area", self.min_object_area)
        simplify_tolerance = kwargs.get("simplify_tolerance", self.simplify_tolerance)

        # Refresh band statistics (0-based indices, same as band_selection).
        if kwargs.get("update_stats", True):
            self.update_band_stats(raster_path, band_selection)

        print("Processing with parameters:")
        print(f"- Using bands: {band_selection}")
        print(f"- Include NDVI: {use_ndvi}")
        print(f"- Confidence threshold: {confidence_threshold}")
        print(f"- Tile overlap: {overlap}")
        print(f"- Chip size: {chip_size}")
        print(f"- Filter edge objects: {filter_edges}")

        class Sentinel2Dataset(torch.utils.data.Dataset):
            """Tiles a raster into chips preprocessed by the field delineator."""

            def __init__(
                self,
                raster_path,
                chip_size,
                stride_x,
                stride_y,
                band_selection,
                use_ndvi,
                field_delineator,
            ):
                self.raster_path = raster_path
                self.chip_size = chip_size
                self.stride_x = stride_x
                self.stride_y = stride_y
                self.band_selection = band_selection
                self.use_ndvi = use_ndvi
                self.field_delineator = field_delineator

                with rasterio.open(self.raster_path) as src:
                    self.height = src.height
                    self.width = src.width
                    self.count = src.count
                    self.crs = src.crs
                    self.transform = src.transform

                # Chip origins: regular strides plus one final chip flush with
                # the bottom/right edge, so the whole image is covered once.
                self.row_starts = self._tile_starts(
                    self.height, self.chip_size[0], self.stride_y
                )
                self.col_starts = self._tile_starts(
                    self.width, self.chip_size[1], self.stride_x
                )

                self.rows = len(self.row_starts)
                self.cols = len(self.col_starts)

                print(
                    f"Dataset initialized with {self.rows} rows and {self.cols} columns of chips"
                )
                print(f"Image dimensions: {self.width} x {self.height} pixels")
                print(f"Chip size: {self.chip_size[1]} x {self.chip_size[0]} pixels")

            @staticmethod
            def _tile_starts(length, chip, stride):
                """Return sorted, de-duplicated chip start offsets along one axis."""
                last = max(length - chip, 0)
                starts = list(range(0, last + 1, stride))
                if starts[-1] != last:
                    starts.append(last)
                return starts

            def __len__(self):
                return self.rows * self.cols

            def __getitem__(self, idx):
                # Convert flat index to grid position
                row, col = divmod(idx, self.cols)
                j = self.row_starts[row]
                i = self.col_starts[col]

                # Clip the window so it never reads outside the image.
                width = min(self.chip_size[1], self.width - i)
                height = min(self.chip_size[0], self.height - j)
                window = Window(i, j, width, height)

                with rasterio.open(self.raster_path) as src:
                    image = src.read(window=window)  # all bands
                    window_transform = src.window_transform(window)

                # Zero-pad partial windows so every chip has the full chip size.
                if (
                    image.shape[1] != self.chip_size[0]
                    or image.shape[2] != self.chip_size[1]
                ):
                    padded = np.zeros(
                        (image.shape[0], self.chip_size[0], self.chip_size[1]),
                        dtype=image.dtype,
                    )
                    padded[:, : image.shape[1], : image.shape[2]] = image
                    image = padded

                image_tensor = self.field_delineator.preprocess_sentinel_bands(
                    image, self.band_selection, self.use_ndvi
                )

                # Geographic bounds of the unpadded window. For a north-up
                # raster, (0, height) maps to the lower-left corner and
                # (width, 0) to the upper-right corner.
                minx, miny = window_transform * (0, height)
                maxx, maxy = window_transform * (width, 0)
                bbox = [minx, miny, maxx, maxy]

                return {
                    "image": image_tensor,
                    "bbox": bbox,
                    "coords": torch.tensor([i, j], dtype=torch.long),
                    "window_size": torch.tensor([width, height], dtype=torch.long),
                }

        # Stride from overlap, clamped to >= 1 so overlap >= 1 cannot yield a
        # zero stride.
        stride_x = max(1, int(chip_size[1] * (1 - overlap)))
        stride_y = max(1, int(chip_size[0] * (1 - overlap)))

        dataset = Sentinel2Dataset(
            raster_path=raster_path,
            chip_size=chip_size,
            stride_x=stride_x,
            stride_y=stride_y,
            band_selection=band_selection,
            use_ndvi=use_ndvi,
            field_delineator=self,
        )

        def custom_collate(batch):
            """Collate dict samples; keep "bbox" and non-collatable values as lists."""
            default_collate = torch.utils.data.default_collate
            elem = batch[0]
            if not isinstance(elem, dict):
                return default_collate(batch)

            result = {}
            for key in elem:
                values = [sample[key] for sample in batch]
                if key == "bbox":
                    result[key] = values
                    continue
                try:
                    result[key] = default_collate(values)
                except TypeError:
                    # Fall back to list for non-collatable types
                    result[key] = values
            return result

        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            collate_fn=custom_collate,
        )

        # NOTE(review): `dataloader` is not passed to the parent call below, so
        # the parent builds its own dataset from raster_path. As a result,
        # band_selection / use_ndvi preprocessing presumably does not affect the
        # detections, and the model's input channel count must match whatever
        # the parent feeds it. TODO: run inference over `dataloader` instead.
        results = super().process_raster(
            raster_path=raster_path,
            output_path=output_path,
            batch_size=batch_size,
            filter_edges=filter_edges,
            edge_buffer=edge_buffer,
            confidence_threshold=confidence_threshold,
            overlap=overlap,
            chip_size=chip_size,
            nms_iou_threshold=nms_iou_threshold,
            mask_threshold=mask_threshold,
            min_object_area=min_object_area,
            simplify_tolerance=simplify_tolerance,
        )

        return results
|
geoai/geoai.py
CHANGED
|
@@ -1,7 +1,15 @@
|
|
|
1
1
|
"""Main module."""
|
|
2
2
|
|
|
3
|
+
from .download import (
|
|
4
|
+
download_naip,
|
|
5
|
+
download_overture_buildings,
|
|
6
|
+
download_pc_stac_item,
|
|
7
|
+
pc_collection_list,
|
|
8
|
+
pc_stac_search,
|
|
9
|
+
pc_stac_download,
|
|
10
|
+
)
|
|
3
11
|
from .extract import *
|
|
4
12
|
from .hf import *
|
|
5
13
|
from .segment import *
|
|
14
|
+
from .train import object_detection, train_MaskRCNN_model
|
|
6
15
|
from .utils import *
|
|
7
|
-
from .train import train_MaskRCNN_model, object_detection
|