geoai-py 0.4.1__py2.py3-none-any.whl → 0.4.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/extract.py CHANGED
@@ -2397,3 +2397,521 @@ class ParkingSplotDetector(ObjectDetector):
2397
2397
  num_classes=num_classes,
2398
2398
  device=device,
2399
2399
  )
2400
+
2401
+
2402
class AgricultureFieldDelineator(ObjectDetector):
    """
    Agricultural field boundary delineation using a pre-trained Mask R-CNN model.

    This class extends the ObjectDetector class to specifically handle Sentinel-2
    imagery with 12 spectral bands for agricultural field boundary detection.

    All band indices handled by this class (``band_selection``, the NDVI bands
    and the positions in ``sentinel_band_stats``) are 0-based positions along the
    band axis of the raster; the corresponding rasterio band number is index + 1.

    Attributes:
        band_selection: List of 0-based band indices to use for prediction
            (default: [3, 2, 1], labelled R, G, B).
        sentinel_band_stats: Dict with "means" and "stds" lists indexed by
            0-based band position, used to standardize each selected band.
        use_ndvi: Whether to calculate and include NDVI as an additional channel.
    """

    # 0-based band positions used for NDVI: B8 (NIR) and B4 (Red).
    _NIR_BAND_INDEX = 7
    _RED_BAND_INDEX = 3

    def __init__(
        self,
        model_path="field_boundary_detector.pth",
        repo_id=None,
        model=None,
        device=None,
        band_selection=None,
        use_ndvi=False,
    ):
        """
        Initialize the field boundary delineator.

        Args:
            model_path: Path to the .pth model file.
            repo_id: Repo ID for loading models from the Hub.
            model: Custom model to use for inference. When None, a Mask R-CNN
                whose first convolution accepts the selected number of input
                channels is built by ``initialize_sentinel2_model``.
            device: Device to use for inference ('cuda:0', 'cpu', etc.).
            band_selection: List of 0-based Sentinel-2 band indices to use
                (None = [3, 2, 1]).
            use_ndvi: Whether to calculate and include NDVI as an additional channel
        """
        # Both attributes are read by initialize_sentinel2_model(), so they must be
        # set before the model is built. Copy into a list so a tuple selection is
        # not later misread by numpy as a multi-axis index.
        self.custom_band_selection = (
            list(band_selection) if band_selection is not None else None
        )
        self.use_ndvi = use_ndvi

        # Set device (copied from parent init) so the freshly built model can be
        # moved to it before the parent constructor runs.
        if device is None:
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        # Build a model whose first layer matches the multi-spectral input.
        model = self.initialize_sentinel2_model(model)

        super().__init__(
            model_path=model_path, repo_id=repo_id, model=model, device=device
        )

        # Default Sentinel-2 band statistics, indexed by 0-based band position
        # (overwritten by update_band_stats()).
        # NOTE(review): an earlier comment listed 13 band names [B1, B2, B3, B4,
        # B5, B6, B7, B8, B8A, B9, B10, B11, B12] but there are only 12 values, so
        # one band is absent (B10 is typically missing from L2A products) -- confirm
        # the band order of the data these defaults were derived from.
        self.sentinel_band_stats = {
            "means": [
                0.0975,
                0.0476,
                0.0598,
                0.0697,
                0.1077,
                0.1859,
                0.2378,
                0.2061,
                0.2598,
                0.4120,
                0.1956,
                0.1410,
            ],
            "stds": [
                0.0551,
                0.0290,
                0.0298,
                0.0479,
                0.0506,
                0.0505,
                0.0747,
                0.0642,
                0.0782,
                0.1187,
                0.0651,
                0.0679,
            ],
        }

        # Default band selection (R, G, B -- B4, B3, B2 for Sentinel-2).
        self.band_selection = (
            self.custom_band_selection
            if self.custom_band_selection is not None
            else [3, 2, 1]
        )

        # Customize parameters for field delineation
        self.confidence_threshold = 0.5  # Default confidence threshold
        self.overlap = 0.5  # Higher overlap for field boundary detection
        self.min_object_area = 1000  # Minimum area in pixels for field detection
        self.simplify_tolerance = 2.0  # Higher tolerance for field boundaries

    def initialize_sentinel2_model(self, model=None):
        """
        Initialize a Mask R-CNN model with a modified first layer to accept Sentinel-2 data.

        The number of input channels is ``len(custom_band_selection)`` (3 when no
        selection was given), plus one when NDVI is enabled.

        Args:
            model: Pre-initialized model (optional). Returned unchanged if given.

        Returns:
            Modified model with appropriate input channels, moved to ``self.device``.
        """
        from torchvision.models.detection import MaskRCNN
        from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

        if model is not None:
            return model

        # Determine number of input channels based on band selection and NDVI
        num_input_channels = (
            len(self.custom_band_selection)
            if self.custom_band_selection is not None
            else 3
        )
        if self.use_ndvi:
            num_input_channels += 1

        print(f"Initializing Mask R-CNN model with {num_input_channels} input channels")

        # ResNet50-FPN backbone without pretrained weights.
        backbone = resnet_fpn_backbone(backbone_name="resnet50", weights=None)

        # Replace the first conv layer so it accepts multi-spectral input.
        original_conv = backbone.body.conv1
        backbone.body.conv1 = torch.nn.Conv2d(
            num_input_channels,
            original_conv.out_channels,
            kernel_size=original_conv.kernel_size,
            stride=original_conv.stride,
            padding=original_conv.padding,
            bias=original_conv.bias is not None,
        )

        # maskrcnn_resnet50_fpn() builds its own backbone and forwards extra kwargs
        # to MaskRCNN(backbone, ...), so passing backbone= to it raises a
        # "multiple values for argument 'backbone'" TypeError. Build MaskRCNN
        # directly around the modified backbone instead.
        # NOTE(review): MaskRCNN's transform normalizes again with image_mean /
        # image_std on top of preprocess_sentinel_bands(); kept as-is, presumably
        # matching how saved weights were trained -- confirm.
        model = MaskRCNN(
            backbone,
            num_classes=2,  # Background + field
            image_mean=[0.485] * num_input_channels,  # Extend mean to all channels
            image_std=[0.229] * num_input_channels,  # Extend std to all channels
        )

        model.to(self.device)
        return model

    def preprocess_sentinel_bands(self, image_data, band_selection=None, use_ndvi=None):
        """
        Preprocess Sentinel-2 band data for model input.

        Selected bands are standardized with ``sentinel_band_stats`` (bands without
        statistics, or with a non-positive std, are left unscaled). When NDVI is
        requested and the image has enough bands, NDVI from bands 7 (B8) and 3 (B4)
        is appended as a last channel scaled to roughly [-1, 1].

        Args:
            image_data: Raw Sentinel-2 image data as numpy array [bands, height, width]
            band_selection: List of 0-based band indices to use (overrides
                instance default if provided)
            use_ndvi: Whether to include NDVI (overrides instance default if provided)

        Returns:
            Processed float32 tensor ready for model input
        """
        import warnings

        # Use instance defaults if not specified
        if band_selection is None:
            band_selection = self.band_selection
        if use_ndvi is None:
            use_ndvi = self.use_ndvi
        # A list (not a tuple) so numpy performs fancy indexing along axis 0.
        band_selection = list(band_selection)

        # Cast up front: torch.from_numpy rejects uint16 arrays (a common
        # Sentinel-2 dtype) on many torch versions.
        selected_bands = image_data[band_selection].astype(np.float32)

        ndvi_added = False
        if use_ndvi:
            required_bands = max(self._NIR_BAND_INDEX, self._RED_BAND_INDEX) + 1
            if image_data.shape[0] >= required_bands:
                nir = image_data[self._NIR_BAND_INDEX].astype(np.float32)  # B8 (NIR)
                red = image_data[self._RED_BAND_INDEX].astype(np.float32)  # B4 (Red)

                # Avoid division by zero: pixels with nir + red <= 0 get NDVI 0.
                denominator = nir + red
                ndvi = np.zeros_like(nir)
                valid_mask = denominator > 0
                ndvi[valid_mask] = (nir[valid_mask] - red[valid_mask]) / denominator[
                    valid_mask
                ]

                # Rescale NDVI from [-1, 1] to [0, 1]
                ndvi = (ndvi + 1) / 2

                selected_bands = np.concatenate(
                    [selected_bands, ndvi[np.newaxis, :, :]], axis=0
                )
                ndvi_added = True
            else:
                # Previously this case silently re-normalized the last spectral
                # band as if it were NDVI.
                warnings.warn(
                    f"NDVI requested but image has only {image_data.shape[0]} bands "
                    f"(needs at least {required_bands}); NDVI channel not added."
                )

        image_tensor = torch.from_numpy(selected_bands).float()

        # Standardize each selected band with the statistics of its source band.
        means = self.sentinel_band_stats["means"]
        stds = self.sentinel_band_stats["stds"]
        num_stats = min(len(means), len(stds))
        for channel, band_idx in enumerate(band_selection):
            if band_idx < num_stats and stds[band_idx] > 0:
                image_tensor[channel] = (
                    image_tensor[channel] - means[band_idx]
                ) / stds[band_idx]

        # Only the NDVI channel (when actually appended) gets the fixed scaling.
        if ndvi_added:
            image_tensor[-1] = (image_tensor[-1] - 0.5) / 0.5

        return image_tensor

    def update_band_stats(self, raster_path, band_selection=None, sample_size=1000):
        """
        Update band statistics from the input Sentinel-2 raster.

        Statistics are estimated from ``sample_size`` random pixel positions
        (fixed seed 42), ignoring non-finite values and the raster's nodata value.

        Args:
            raster_path: Path to the Sentinel-2 raster file
            band_selection: 0-based indices of the bands to update, matching the
                ``band_selection`` convention of this class. Statistics of other
                bands are kept. None = recompute every band in the raster and
                resize the stored statistics to the raster's band count.
            sample_size: Number of random pixels to sample for statistics calculation

        Returns:
            Updated band statistics dictionary
        """
        with rasterio.open(raster_path) as src:
            # Check if this is likely a Sentinel-2 product
            band_count = src.count
            if band_count < 3:
                print(
                    f"Warning: Raster has only {band_count} bands, may not be Sentinel-2 data"
                )

            height, width = src.height, src.width
            nodata = src.nodata

            if band_selection is None:
                bands = list(range(band_count))
                target_len = band_count
            else:
                bands = [int(b) for b in band_selection]
                # Grow (never shrink) so stats stay addressable by band index.
                target_len = max(
                    [len(self.sentinel_band_stats["means"])] + [b + 1 for b in bands]
                )

            # Start from existing stats; new slots default to identity scaling.
            means = (list(self.sentinel_band_stats["means"]) + [0.0] * target_len)[
                :target_len
            ]
            stds = (list(self.sentinel_band_stats["stds"]) + [1.0] * target_len)[
                :target_len
            ]

            # Local RNG: same draws as np.random.seed(42) without touching the
            # global numpy random state.
            rng = np.random.RandomState(42)
            sample_rows = rng.randint(0, height, sample_size)
            sample_cols = rng.randint(0, width, sample_size)

            for band in bands:
                band_data = src.read(band + 1)  # rasterio band numbers are 1-based
                sample_values = band_data[sample_rows, sample_cols]

                # Remove invalid values (non-finite and nodata)
                valid = np.isfinite(sample_values)
                if nodata is not None:
                    valid &= sample_values != nodata
                valid_samples = sample_values[valid]

                if valid_samples.size == 0:
                    print(
                        f"Warning: Band {band + 1} has no valid sampled pixels; "
                        "keeping previous statistics"
                    )
                    continue

                mean = float(np.mean(valid_samples))
                std = float(np.std(valid_samples))
                means[band] = mean
                stds[band] = std

                print(f"Band {band + 1}: mean={mean:.4f}, std={std:.4f}")

        self.sentinel_band_stats = {"means": means, "stds": stds}

        return self.sentinel_band_stats

    @staticmethod
    def _compute_tile_starts(length, chip, stride):
        """
        Return unique tile start offsets covering ``length`` pixels.

        Tiles start every ``stride`` pixels (at least 1) below ``length - chip``,
        and a final tile is anchored at ``length - chip`` so the far edge is
        covered. If ``length`` does not exceed ``chip``, a single start of 0 is
        returned.
        """
        if length <= chip:
            return [0]
        last = length - chip
        return list(range(0, last, max(1, int(stride)))) + [last]

    def process_sentinel_raster(
        self,
        raster_path,
        output_path=None,
        batch_size=4,
        band_selection=None,
        use_ndvi=None,
        filter_edges=True,
        edge_buffer=20,
        **kwargs,
    ):
        """
        Process a Sentinel-2 raster to extract field boundaries.

        Args:
            raster_path: Path to Sentinel-2 raster file
            output_path: Path to output GeoJSON or Parquet file (optional)
            batch_size: Batch size for processing
            band_selection: List of 0-based bands to use (None = use instance default)
            use_ndvi: Whether to include NDVI (None = use instance default)
            filter_edges: Whether to filter out objects at the edges of the image
            edge_buffer: Size of edge buffer in pixels to filter out objects
            **kwargs: Additional parameters: confidence_threshold, overlap,
                chip_size, nms_iou_threshold, mask_threshold, min_object_area,
                simplify_tolerance (each defaults to the instance attribute) and
                update_stats (default True: recompute band statistics first).

        Returns:
            GeoDataFrame with field boundaries
        """
        # Use instance defaults if not specified
        band_selection = list(
            band_selection if band_selection is not None else self.band_selection
        )
        use_ndvi = use_ndvi if use_ndvi is not None else self.use_ndvi

        # Get parameters from kwargs or use instance defaults
        confidence_threshold = kwargs.get(
            "confidence_threshold", self.confidence_threshold
        )
        overlap = kwargs.get("overlap", self.overlap)
        chip_size = kwargs.get("chip_size", self.chip_size)
        nms_iou_threshold = kwargs.get("nms_iou_threshold", self.nms_iou_threshold)
        mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
        min_object_area = kwargs.get("min_object_area", self.min_object_area)
        simplify_tolerance = kwargs.get("simplify_tolerance", self.simplify_tolerance)

        # band_selection is 0-based, matching update_band_stats().
        if kwargs.get("update_stats", True):
            self.update_band_stats(raster_path, band_selection)

        print("Processing with parameters:")
        print(f"- Using bands: {band_selection}")
        print(f"- Include NDVI: {use_ndvi}")
        print(f"- Confidence threshold: {confidence_threshold}")
        print(f"- Tile overlap: {overlap}")
        print(f"- Chip size: {chip_size}")
        print(f"- Filter edge objects: {filter_edges}")

        class Sentinel2Dataset(torch.utils.data.Dataset):
            """Tile a raster into chips preprocessed by the field delineator."""

            def __init__(
                self,
                raster_path,
                chip_size,
                stride_x,
                stride_y,
                band_selection,
                use_ndvi,
                field_delineator,
            ):
                self.raster_path = raster_path
                self.chip_size = chip_size
                self.stride_x = stride_x
                self.stride_y = stride_y
                self.band_selection = band_selection
                self.use_ndvi = use_ndvi
                self.field_delineator = field_delineator

                with rasterio.open(self.raster_path) as src:
                    self.height = src.height
                    self.width = src.width
                    self.count = src.count
                    self.crs = src.crs
                    self.transform = src.transform

                # Unique starts, with the last tile anchored to the image edge.
                self.row_starts = field_delineator._compute_tile_starts(
                    self.height, self.chip_size[0], self.stride_y
                )
                self.col_starts = field_delineator._compute_tile_starts(
                    self.width, self.chip_size[1], self.stride_x
                )

                self.rows = len(self.row_starts)
                self.cols = len(self.col_starts)

                print(
                    f"Dataset initialized with {self.rows} rows and {self.cols} columns of chips"
                )
                print(f"Image dimensions: {self.width} x {self.height} pixels")
                print(f"Chip size: {self.chip_size[1]} x {self.chip_size[0]} pixels")

            def __len__(self):
                return self.rows * self.cols

            def __getitem__(self, idx):
                # Convert flat index to grid position
                row, col = divmod(idx, self.cols)
                j = self.row_starts[row]
                i = self.col_starts[col]

                # Clip the window so we never read outside the image
                width = min(self.chip_size[1], self.width - i)
                height = min(self.chip_size[0], self.height - j)
                window = Window(i, j, width, height)

                with rasterio.open(self.raster_path) as src:
                    image = src.read(window=window)
                    window_transform = src.window_transform(window)

                # Zero-pad partial windows at the edges up to the full chip size
                if (
                    image.shape[1] != self.chip_size[0]
                    or image.shape[2] != self.chip_size[1]
                ):
                    padded = np.zeros(
                        (image.shape[0], self.chip_size[0], self.chip_size[1]),
                        dtype=image.dtype,
                    )
                    padded[:, : image.shape[1], : image.shape[2]] = image
                    image = padded

                image_tensor = self.field_delineator.preprocess_sentinel_bands(
                    image, self.band_selection, self.use_ndvi
                )

                # Geographic bounds of the (unpadded) window
                minx, miny = window_transform * (0, height)
                maxx, maxy = window_transform * (width, 0)

                return {
                    "image": image_tensor,
                    "bbox": [minx, miny, maxx, maxy],
                    "coords": torch.tensor([i, j], dtype=torch.long),
                    "window_size": torch.tensor([width, height], dtype=torch.long),
                }

        # Stride from overlap; at least 1 so overlap == 1 cannot yield a zero step.
        stride_x = max(1, int(chip_size[1] * (1 - overlap)))
        stride_y = max(1, int(chip_size[0] * (1 - overlap)))

        dataset = Sentinel2Dataset(
            raster_path=raster_path,
            chip_size=chip_size,
            stride_x=stride_x,
            stride_y=stride_y,
            band_selection=band_selection,
            use_ndvi=use_ndvi,
            field_delineator=self,
        )

        def custom_collate(batch):
            """Collate dict samples per key; "bbox" and non-collatable values stay lists."""
            first = batch[0]
            if not isinstance(first, dict):
                return torch.utils.data.default_collate(batch)
            collated = {}
            for key in first:
                values = [sample[key] for sample in batch]
                if key == "bbox":
                    collated[key] = values
                    continue
                try:
                    collated[key] = torch.utils.data.default_collate(values)
                except TypeError:
                    collated[key] = values
            return collated

        # NOTE(review): this dataloader is never consumed -- super().process_raster
        # below only receives raster_path, so whether band_selection / NDVI /
        # Sentinel-2 normalization are applied during inference depends on the
        # parent implementation. TODO: verify and route this dataloader through
        # the parent's inference loop.
        dataloader = torch.utils.data.DataLoader(  # noqa: F841
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            collate_fn=custom_collate,
        )

        results = super().process_raster(
            raster_path=raster_path,
            output_path=output_path,
            batch_size=batch_size,
            filter_edges=filter_edges,
            edge_buffer=edge_buffer,
            confidence_threshold=confidence_threshold,
            overlap=overlap,
            chip_size=chip_size,
            nms_iou_threshold=nms_iou_threshold,
            mask_threshold=mask_threshold,
            min_object_area=min_object_area,
            simplify_tolerance=simplify_tolerance,
        )

        return results
geoai/geoai.py CHANGED
@@ -1,7 +1,15 @@
1
1
  """Main module."""
2
2
 
3
+ from .download import (
4
+ download_naip,
5
+ download_overture_buildings,
6
+ download_pc_stac_item,
7
+ pc_collection_list,
8
+ pc_stac_search,
9
+ pc_stac_download,
10
+ )
3
11
  from .extract import *
4
12
  from .hf import *
5
13
  from .segment import *
14
+ from .train import object_detection, train_MaskRCNN_model
6
15
  from .utils import *
7
- from .train import train_MaskRCNN_model, object_detection