geoai-py 0.4.1__py2.py3-none-any.whl → 0.4.3__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/extract.py CHANGED
@@ -19,6 +19,7 @@ from torchvision.models.detection import (
     maskrcnn_resnet50_fpn,
 )
 from tqdm import tqdm
+import time
 
 # Local Imports
 from .utils import get_raster_stats
@@ -2117,6 +2118,7 @@ class ObjectDetector:
         confidence_threshold=0.5,
         min_object_area=100,
         max_object_area=None,
+        n_workers=None,
         **kwargs,
     ):
         """
@@ -2128,14 +2130,103 @@ class ObjectDetector:
             confidence_threshold: Minimum confidence score (0.0-1.0). Default: 0.5
             min_object_area: Minimum area in pixels to keep an object. Default: 100
             max_object_area: Maximum area in pixels to keep an object. Default: None
+            n_workers: int, default=None
+                Number of worker threads to use.
+                None means single-threaded processing.
+                -1 means use all available CPU cores.
+                A positive integer means use that many threads.
             **kwargs: Additional parameters
 
         Returns:
             GeoDataFrame with car detections and confidence values
         """
 
+        def _process_single_component(
+            component_mask,
+            conf_data,
+            transform,
+            confidence_threshold,
+            min_object_area,
+            max_object_area,
+        ):
+            # Get confidence value (mean of non-zero values in this region,
+            # converted back to the 0-1 range)
+            conf_region = conf_data[component_mask > 0]
+            if len(conf_region) > 0:
+                confidence = np.mean(conf_region) / 255.0
+            else:
+                confidence = 0.0
+
+            # Skip if confidence is below threshold
+            if confidence < confidence_threshold:
+                return None
+
+            # Find contours
+            contours, _ = cv2.findContours(
+                component_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+            )
+
+            results = []
+
+            for contour in contours:
+                # Filter by size
+                area = cv2.contourArea(contour)
+                if area < min_object_area:
+                    continue
+
+                if max_object_area is not None and area > max_object_area:
+                    continue
+
+                # Get minimum area rectangle
+                rect = cv2.minAreaRect(contour)
+                box_points = cv2.boxPoints(rect)
+
+                # Convert to geographic coordinates
+                geo_points = []
+                for x, y in box_points:
+                    gx, gy = transform * (x, y)
+                    geo_points.append((gx, gy))
+
+                # Create polygon
+                poly = Polygon(geo_points)
+                results.append((poly, confidence, area))
+
+            return results
+
+        import concurrent.futures
+        from functools import partial
+
+        def process_component(args):
+            """
+            Helper function to process a single component.
+            """
+            (
+                label,
+                labeled_mask,
+                conf_data,
+                transform,
+                confidence_threshold,
+                min_object_area,
+                max_object_area,
+            ) = args
+
+            # Create mask for this component
+            component_mask = (labeled_mask == label).astype(np.uint8)
+
+            return _process_single_component(
+                component_mask,
+                conf_data,
+                transform,
+                confidence_threshold,
+                min_object_area,
+                max_object_area,
+            )
+
+        start_time = time.time()
         print(f"Processing masks from: {masks_path}")
 
+        if n_workers == -1:
+            n_workers = os.cpu_count()
+
         with rasterio.open(masks_path) as src:
             # Read mask and confidence bands
             mask_data = src.read(1)
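
For reference, here is a minimal standalone sketch of the pixel-to-map conversion that _process_single_component performs, assuming a rasterio Affine transform and a toy mask (the grid origin and 10 m pixel size below are illustrative, not from the package):

    import cv2
    import numpy as np
    from rasterio.transform import from_origin
    from shapely.geometry import Polygon

    # Hypothetical north-up grid: 10 m pixels anchored at (500000, 4100000)
    transform = from_origin(500000, 4100000, 10, 10)

    mask = np.zeros((100, 100), dtype=np.uint8)
    mask[20:40, 30:70] = 1  # one rectangular "object"

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    rect = cv2.minAreaRect(contours[0])  # rotated bounding box in pixel space
    box_points = cv2.boxPoints(rect)     # four corners as (col, row) floats

    # Affine * (col, row) -> (x, y) in the raster's CRS, as in the hunk above
    geo_points = [transform * (x, y) for x, y in box_points]
    poly = Polygon(geo_points)
    print(poly.bounds)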
@@ -2155,56 +2246,68 @@ class ObjectDetector:
         confidences = []
         pixels = []
 
-        # Add progress bar
-        for label in tqdm(range(1, num_features + 1), desc="Processing components"):
-            # Create mask for this component
-            component_mask = (labeled_mask == label).astype(np.uint8)
-
-            # Get confidence value (mean of non-zero values in this region)
-            conf_region = conf_data[component_mask > 0]
-            if len(conf_region) > 0:
-                confidence = (
-                    np.mean(conf_region) / 255.0
-                )  # Convert back to 0-1 range
-            else:
-                confidence = 0.0
-
-            # Skip if confidence is below threshold
-            if confidence < confidence_threshold:
-                continue
-
-            # Find contours
-            contours, _ = cv2.findContours(
-                component_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+        if n_workers is None or n_workers == 1:
+            print(
+                "Using single-threaded processing; set n_workers > 1 to speed up processing"
             )
+            # Add progress bar
+            for label in tqdm(
+                range(1, num_features + 1), desc="Processing components"
+            ):
+                # Create mask for this component
+                component_mask = (labeled_mask == label).astype(np.uint8)
+
+                result = _process_single_component(
+                    component_mask,
+                    conf_data,
+                    transform,
+                    confidence_threshold,
+                    min_object_area,
+                    max_object_area,
+                )
 
-            for contour in contours:
-                # Filter by size
-                area = cv2.contourArea(contour)
-                if area < min_object_area:
-                    continue
-
-                if max_object_area is not None:
-                    if area > max_object_area:
-                        continue
-
-                # Get minimum area rectangle
-                rect = cv2.minAreaRect(contour)
-                box_points = cv2.boxPoints(rect)
-
-                # Convert to geographic coordinates
-                geo_points = []
-                for x, y in box_points:
-                    gx, gy = transform * (x, y)
-                    geo_points.append((gx, gy))
+                if result:
+                    for poly, confidence, area in result:
+                        # Add to lists
+                        polygons.append(poly)
+                        confidences.append(confidence)
+                        pixels.append(area)
 
-                # Create polygon
-                poly = Polygon(geo_points)
+        else:
+            # Process components in parallel
+            print(f"Using {n_workers} workers for parallel processing")
+
+            process_args = [
+                (
+                    label,
+                    labeled_mask,
+                    conf_data,
+                    transform,
+                    confidence_threshold,
+                    min_object_area,
+                    max_object_area,
+                )
+                for label in range(1, num_features + 1)
+            ]
+
+            with concurrent.futures.ThreadPoolExecutor(
+                max_workers=n_workers
+            ) as executor:
+                results = list(
+                    tqdm(
+                        executor.map(process_component, process_args),
+                        total=num_features,
+                        desc="Processing components",
+                    )
+                )
 
-                # Add to lists
-                polygons.append(poly)
-                confidences.append(confidence)
-                pixels.append(area)
+            for result in results:
+                if result:
+                    for poly, confidence, area in result:
+                        # Add to lists
+                        polygons.append(poly)
+                        confidences.append(confidence)
+                        pixels.append(area)
 
         # Create GeoDataFrame
         if polygons:
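
The parallel branch follows the standard ThreadPoolExecutor pattern: executor.map preserves input order, so results can be merged back deterministically, and threads are a reasonable fit here since NumPy and OpenCV release the GIL for much of this work. A minimal sketch of the same pattern, with a stand-in worker function:

    import concurrent.futures
    from tqdm import tqdm

    def process(item):
        # Stand-in for process_component; must be thread-safe
        return item * item

    items = list(range(1, 101))
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        # Results arrive in the same order as `items`
        results = list(tqdm(executor.map(process, items), total=len(items)))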
@@ -2223,8 +2326,12 @@ class ObjectDetector:
             gdf.to_file(output_path, driver="GeoJSON")
             print(f"Saved {len(gdf)} objects with confidence to {output_path}")
 
+            end_time = time.time()
+            print(f"Total processing time: {end_time - start_time:.2f} seconds")
             return gdf
         else:
+            end_time = time.time()
+            print(f"Total processing time: {end_time - start_time:.2f} seconds")
             print("No valid polygons found")
             return None
 
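Taken together, these hunks make the mask-to-polygon step optionally parallel and report wall-clock time. A hedged usage sketch follows; `process_masks` is a placeholder name, since the method being modified is not named in the hunk context shown here, and the real constructor arguments are omitted:

    from geoai.extract import ObjectDetector

    detector = ObjectDetector()  # illustrative; real constructor args omitted

    gdf = detector.process_masks(  # placeholder name for the method changed above
        masks_path="masks.tif",
        confidence_threshold=0.5,
        min_object_area=100,
        n_workers=-1,  # -1 resolves to os.cpu_count() threads
    )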
@@ -2397,3 +2504,521 @@ class ParkingSplotDetector(ObjectDetector):
             num_classes=num_classes,
             device=device,
         )
+
+
+class AgricultureFieldDelineator(ObjectDetector):
+    """
+    Agricultural field boundary delineation using a pre-trained Mask R-CNN model.
+
+    This class extends the ObjectDetector class to specifically handle Sentinel-2
+    imagery with 12 spectral bands for agricultural field boundary detection.
+
+    Attributes:
+        band_selection: List of band indices to use for prediction (default: RGB)
+        sentinel_band_stats: Per-band statistics for Sentinel-2 data
+        use_ndvi: Whether to calculate and include NDVI as an additional channel
+    """
+
+    def __init__(
+        self,
+        model_path="field_boundary_detector.pth",
+        repo_id=None,
+        model=None,
+        device=None,
+        band_selection=None,
+        use_ndvi=False,
+    ):
+        """
+        Initialize the field boundary delineator.
+
+        Args:
+            model_path: Path to the .pth model file.
+            repo_id: Repo ID for loading models from the Hub.
+            model: Custom model to use for inference.
+            device: Device to use for inference ('cuda:0', 'cpu', etc.).
+            band_selection: List of Sentinel-2 band indices to use (None = adapt based on model).
+            use_ndvi: Whether to calculate and include NDVI as an additional channel.
+        """
+        # Save parameters before calling parent constructor
+        self.custom_band_selection = band_selection
+        self.use_ndvi = use_ndvi
+
+        # Set device (copied from parent init to ensure it's set before
+        # initialize_sentinel2_model runs)
+        if device is None:
+            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        else:
+            self.device = torch.device(device)
+
+        # Initialize the model differently for multi-spectral input
+        model = self.initialize_sentinel2_model(model)
+
+        # Call parent but with our custom model
+        super().__init__(
+            model_path=model_path, repo_id=repo_id, model=model, device=device
+        )
+
+        # Default Sentinel-2 band statistics (can be overridden with actual stats)
+        # Band order (12 bands): [B1, B2, B3, B4, B5, B6, B7, B8, B8A, B9, B11, B12]
+        self.sentinel_band_stats = {
+            "means": [
+                0.0975,
+                0.0476,
+                0.0598,
+                0.0697,
+                0.1077,
+                0.1859,
+                0.2378,
+                0.2061,
+                0.2598,
+                0.4120,
+                0.1956,
+                0.1410,
+            ],
+            "stds": [
+                0.0551,
+                0.0290,
+                0.0298,
+                0.0479,
+                0.0506,
+                0.0505,
+                0.0747,
+                0.0642,
+                0.0782,
+                0.1187,
+                0.0651,
+                0.0679,
+            ],
+        }
+
+        # Set default band selection (RGB - typically B4, B3, B2 for Sentinel-2)
+        self.band_selection = (
+            self.custom_band_selection
+            if self.custom_band_selection is not None
+            else [3, 2, 1]
+        )  # R, G, B bands
+
+        # Customize parameters for field delineation
+        self.confidence_threshold = 0.5  # Default confidence threshold
+        self.overlap = 0.5  # Higher overlap for field boundary detection
+        self.min_object_area = 1000  # Minimum area in pixels for field detection
+        self.simplify_tolerance = 2.0  # Higher tolerance for field boundaries
+
+    def initialize_sentinel2_model(self, model=None):
+        """
+        Initialize a Mask R-CNN model with a modified first layer to accept Sentinel-2 data.
+
+        Args:
+            model: Pre-initialized model (optional)
+
+        Returns:
+            Modified model with appropriate input channels
+        """
+        from torchvision.models.detection import MaskRCNN
+        from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
+
+        if model is not None:
+            return model
+
+        # Determine the number of input channels based on band selection and NDVI
+        num_input_channels = (
+            len(self.custom_band_selection)
+            if self.custom_band_selection is not None
+            else 3
+        )
+        if self.use_ndvi:
+            num_input_channels += 1
+
+        print(f"Initializing Mask R-CNN model with {num_input_channels} input channels")
+
+        # Create a ResNet50 backbone with an FPN
+        backbone = resnet_fpn_backbone(backbone_name="resnet50", weights=None)
+
+        # Replace the first conv layer to accept multi-spectral input
+        original_conv = backbone.body.conv1
+        backbone.body.conv1 = torch.nn.Conv2d(
+            num_input_channels,
+            original_conv.out_channels,
+            kernel_size=original_conv.kernel_size,
+            stride=original_conv.stride,
+            padding=original_conv.padding,
+            bias=original_conv.bias is not None,
+        )
+
+        # Create Mask R-CNN with the modified backbone. MaskRCNN is used
+        # directly because maskrcnn_resnet50_fpn() builds its own backbone
+        # and does not accept one as an argument.
+        model = MaskRCNN(
+            backbone,
+            num_classes=2,  # Background + field
+            image_mean=[0.485] * num_input_channels,  # Extend mean to all channels
+            image_std=[0.229] * num_input_channels,  # Extend std to all channels
+        )
+
+        model.to(self.device)
+        return model
+
+    def preprocess_sentinel_bands(self, image_data, band_selection=None, use_ndvi=None):
+        """
+        Preprocess Sentinel-2 band data for model input.
+
+        Args:
+            image_data: Raw Sentinel-2 image data as numpy array [bands, height, width]
+            band_selection: List of band indices to use (overrides instance default if provided)
+            use_ndvi: Whether to include NDVI (overrides instance default if provided)
+
+        Returns:
+            Processed tensor ready for model input
+        """
+        # Use instance defaults if not specified
+        band_selection = (
+            band_selection if band_selection is not None else self.band_selection
+        )
+        use_ndvi = use_ndvi if use_ndvi is not None else self.use_ndvi
+
+        # Select bands
+        selected_bands = image_data[band_selection]
+
+        # Calculate NDVI if requested (using B8 and B4, which are indices 7 and 3);
+        # only possible when the input actually contains the NIR band
+        ndvi_added = use_ndvi and image_data.shape[0] > 7
+        if ndvi_added:
+            nir = image_data[7].astype(np.float32)  # B8 (NIR)
+            red = image_data[3].astype(np.float32)  # B4 (Red)
+
+            # Avoid division by zero
+            denominator = nir + red
+            ndvi = np.zeros_like(nir)
+            valid_mask = denominator > 0
+            ndvi[valid_mask] = (nir[valid_mask] - red[valid_mask]) / denominator[
+                valid_mask
+            ]
+
+            # Rescale NDVI from [-1, 1] to [0, 1]
+            ndvi = (ndvi + 1) / 2
+
+            # Add NDVI as an additional channel
+            selected_bands = np.vstack([selected_bands, ndvi[np.newaxis, :, :]])
+
+        # Convert to tensor
+        image_tensor = torch.from_numpy(selected_bands).float()
+
+        # Normalize using band statistics
+        for i, band_idx in enumerate(band_selection):
+            # Make sure band_idx is within range of our statistics
+            if band_idx < len(self.sentinel_band_stats["means"]):
+                mean = self.sentinel_band_stats["means"][band_idx]
+                std = self.sentinel_band_stats["stds"][band_idx]
+                image_tensor[i] = (image_tensor[i] - mean) / std
+
+        # If NDVI was actually added, normalize the last channel too; NDVI is
+        # already in [0, 1], so just center and scale it
+        if ndvi_added:
+            image_tensor[-1] = (image_tensor[-1] - 0.5) / 0.5
+
+        return image_tensor
+
+    def update_band_stats(self, raster_path, band_selection=None, sample_size=1000):
+        """
+        Update band statistics from the input Sentinel-2 raster.
+
+        Args:
+            raster_path: Path to the Sentinel-2 raster file
+            band_selection: Specific bands to update (None = update all available)
+            sample_size: Number of random pixels to sample for statistics calculation
+
+        Returns:
+            Updated band statistics dictionary
+        """
+        with rasterio.open(raster_path) as src:
+            # Check if this is likely a Sentinel-2 product
+            band_count = src.count
+            if band_count < 3:
+                print(
+                    f"Warning: Raster has only {band_count} bands, may not be Sentinel-2 data"
+                )
+
+            # Get dimensions
+            height, width = src.height, src.width
+
+            # Determine which bands to analyze
+            if band_selection is None:
+                band_selection = list(range(1, band_count + 1))  # 1-indexed
+
+            # Initialize arrays for band statistics
+            means = []
+            stds = []
+
+            # Sample random pixels
+            np.random.seed(42)  # For reproducibility
+            sample_rows = np.random.randint(0, height, sample_size)
+            sample_cols = np.random.randint(0, width, sample_size)
+
+            # Calculate statistics for each band
+            for band in band_selection:
+                # Read band data
+                band_data = src.read(band)
+
+                # Sample values
+                sample_values = band_data[sample_rows, sample_cols]
+
+                # Remove invalid values (e.g., nodata)
+                valid_samples = sample_values[np.isfinite(sample_values)]
+
+                # Calculate statistics
+                mean = float(np.mean(valid_samples))
+                std = float(np.std(valid_samples))
+
+                # Store results
+                means.append(mean)
+                stds.append(std)
+
+                print(f"Band {band}: mean={mean:.4f}, std={std:.4f}")
+
+        # Update instance variables
+        self.sentinel_band_stats = {"means": means, "stds": stds}
+
+        return self.sentinel_band_stats
+
+    def process_sentinel_raster(
+        self,
+        raster_path,
+        output_path=None,
+        batch_size=4,
+        band_selection=None,
+        use_ndvi=None,
+        filter_edges=True,
+        edge_buffer=20,
+        **kwargs,
+    ):
+        """
+        Process a Sentinel-2 raster to extract field boundaries.
+
+        Args:
+            raster_path: Path to Sentinel-2 raster file
+            output_path: Path to output GeoJSON or Parquet file (optional)
+            batch_size: Batch size for processing
+            band_selection: List of bands to use (None = use instance default)
+            use_ndvi: Whether to include NDVI (None = use instance default)
+            filter_edges: Whether to filter out objects at the edges of the image
+            edge_buffer: Size of edge buffer in pixels to filter out objects
+            **kwargs: Additional parameters for processing
+
+        Returns:
+            GeoDataFrame with field boundaries
+        """
+        # Use instance defaults if not specified
+        band_selection = (
+            band_selection if band_selection is not None else self.band_selection
+        )
+        use_ndvi = use_ndvi if use_ndvi is not None else self.use_ndvi
+
+        # Get parameters from kwargs or use instance defaults
+        confidence_threshold = kwargs.get(
+            "confidence_threshold", self.confidence_threshold
+        )
+        overlap = kwargs.get("overlap", self.overlap)
+        chip_size = kwargs.get("chip_size", self.chip_size)
+        nms_iou_threshold = kwargs.get("nms_iou_threshold", self.nms_iou_threshold)
+        mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
+        min_object_area = kwargs.get("min_object_area", self.min_object_area)
+        simplify_tolerance = kwargs.get("simplify_tolerance", self.simplify_tolerance)
+
+        # Update band statistics if not already done
+        if kwargs.get("update_stats", True):
+            self.update_band_stats(raster_path, band_selection)
+
+        print("Processing with parameters:")
+        print(f"- Using bands: {band_selection}")
+        print(f"- Include NDVI: {use_ndvi}")
+        print(f"- Confidence threshold: {confidence_threshold}")
+        print(f"- Tile overlap: {overlap}")
+        print(f"- Chip size: {chip_size}")
+        print(f"- Filter edge objects: {filter_edges}")
+
+        # Create a custom Sentinel-2 dataset class
+        class Sentinel2Dataset(torch.utils.data.Dataset):
+            def __init__(
+                self,
+                raster_path,
+                chip_size,
+                stride_x,
+                stride_y,
+                band_selection,
+                use_ndvi,
+                field_delineator,
+            ):
+                self.raster_path = raster_path
+                self.chip_size = chip_size
+                self.stride_x = stride_x
+                self.stride_y = stride_y
+                self.band_selection = band_selection
+                self.use_ndvi = use_ndvi
+                self.field_delineator = field_delineator
+
+                with rasterio.open(self.raster_path) as src:
+                    self.height = src.height
+                    self.width = src.width
+                    self.count = src.count
+                    self.crs = src.crs
+                    self.transform = src.transform
+
+                # Calculate row_starts and col_starts
+                self.row_starts = []
+                self.col_starts = []
+
+                # Normal row starts using stride
+                for r in range((self.height - 1) // self.stride_y):
+                    self.row_starts.append(r * self.stride_y)
+
+                # Add a special last row that ensures we reach the bottom edge
+                if self.height > self.chip_size[0]:
+                    self.row_starts.append(max(0, self.height - self.chip_size[0]))
+                else:
+                    # If the image is smaller than the chip size, just start at 0
+                    if not self.row_starts:
+                        self.row_starts.append(0)
+
+                # Normal column starts using stride
+                for c in range((self.width - 1) // self.stride_x):
+                    self.col_starts.append(c * self.stride_x)
+
+                # Add a special last column that ensures we reach the right edge
+                if self.width > self.chip_size[1]:
+                    self.col_starts.append(max(0, self.width - self.chip_size[1]))
+                else:
+                    # If the image is smaller than the chip size, just start at 0
+                    if not self.col_starts:
+                        self.col_starts.append(0)
+
+                # Calculate number of tiles
+                self.rows = len(self.row_starts)
+                self.cols = len(self.col_starts)
+
+                print(
+                    f"Dataset initialized with {self.rows} rows and {self.cols} columns of chips"
+                )
+                print(f"Image dimensions: {self.width} x {self.height} pixels")
+                print(f"Chip size: {self.chip_size[1]} x {self.chip_size[0]} pixels")
+
+            def __len__(self):
+                return self.rows * self.cols
+
+            def __getitem__(self, idx):
+                # Convert flat index to grid position
+                row = idx // self.cols
+                col = idx % self.cols
+
+                # Get pre-calculated starting positions
+                j = self.row_starts[row]
+                i = self.col_starts[col]
+
+                # Read window from raster
+                with rasterio.open(self.raster_path) as src:
+                    # Make sure we don't read outside the image
+                    width = min(self.chip_size[1], self.width - i)
+                    height = min(self.chip_size[0], self.height - j)
+
+                    window = Window(i, j, width, height)
+
+                    # Read all bands
+                    image = src.read(window=window)
+
+                # Handle partial windows at edges by padding
+                if (
+                    image.shape[1] != self.chip_size[0]
+                    or image.shape[2] != self.chip_size[1]
+                ):
+                    temp = np.zeros(
+                        (image.shape[0], self.chip_size[0], self.chip_size[1]),
+                        dtype=image.dtype,
+                    )
+                    temp[:, : image.shape[1], : image.shape[2]] = image
+                    image = temp
+
+                # Preprocess bands for the model
+                image_tensor = self.field_delineator.preprocess_sentinel_bands(
+                    image, self.band_selection, self.use_ndvi
+                )
+
+                # Get geographic bounds for the window
+                with rasterio.open(self.raster_path) as src:
+                    window_transform = src.window_transform(window)
+                    minx, miny = window_transform * (0, height)
+                    maxx, maxy = window_transform * (width, 0)
+                    bbox = [minx, miny, maxx, maxy]
+
+                return {
+                    "image": image_tensor,
+                    "bbox": bbox,
+                    "coords": torch.tensor([i, j], dtype=torch.long),
+                    "window_size": torch.tensor([width, height], dtype=torch.long),
+                }
+
+        # Calculate stride based on overlap
+        stride_x = int(chip_size[1] * (1 - overlap))
+        stride_y = int(chip_size[0] * (1 - overlap))
+
+        # Create dataset
+        dataset = Sentinel2Dataset(
+            raster_path=raster_path,
+            chip_size=chip_size,
+            stride_x=stride_x,
+            stride_y=stride_y,
+            band_selection=band_selection,
+            use_ndvi=use_ndvi,
+            field_delineator=self,
+        )
+
+        # Define a custom collate function
+        def custom_collate(batch):
+            elem = batch[0]
+            if isinstance(elem, dict):
+                result = {}
+                for key in elem:
+                    if key == "bbox":
+                        # Don't collate bbox objects; keep them as a list
+                        result[key] = [d[key] for d in batch]
+                    else:
+                        # For tensors and other collatable types
+                        try:
+                            result[key] = (
+                                torch.utils.data._utils.collate.default_collate(
+                                    [d[key] for d in batch]
+                                )
+                            )
+                        except TypeError:
+                            # Fall back to a list for non-collatable types
+                            result[key] = [d[key] for d in batch]
+                return result
+            else:
+                # Default collate for non-dict types
+                return torch.utils.data._utils.collate.default_collate(batch)
+
+        # Create dataloader
+        dataloader = torch.utils.data.DataLoader(
+            dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=0,
+            collate_fn=custom_collate,
+        )
+
+        # Process batches by delegating to the parent class's process_raster
+        # method, adapted to work with the Sentinel2Dataset defined above
+        results = super().process_raster(
+            raster_path=raster_path,
+            output_path=output_path,
+            batch_size=batch_size,
+            filter_edges=filter_edges,
+            edge_buffer=edge_buffer,
+            confidence_threshold=confidence_threshold,
+            overlap=overlap,
+            chip_size=chip_size,
+            nms_iou_threshold=nms_iou_threshold,
+            mask_threshold=mask_threshold,
+            min_object_area=min_object_area,
+            simplify_tolerance=simplify_tolerance,
+        )
+
+        return results
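
A hedged usage sketch for the new class, based only on the signatures above (paths and band choices are illustrative; band indices are 0-based positions into the raster's band axis):

    from geoai.extract import AgricultureFieldDelineator

    delineator = AgricultureFieldDelineator(
        model_path="field_boundary_detector.pth",
        device="cuda:0",
        band_selection=[3, 2, 1, 7],  # B4 (Red), B3 (Green), B2 (Blue), B8 (NIR)
        use_ndvi=True,
    )

    fields = delineator.process_sentinel_raster(
        raster_path="sentinel2_scene.tif",
        output_path="field_boundaries.geojson",
        batch_size=4,
        filter_edges=True,
        edge_buffer=20,
    )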