geoai-py 0.4.1__py2.py3-none-any.whl → 0.4.3__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/download.py +751 -3
- geoai/extract.py +671 -46
- geoai/geoai.py +22 -1
- geoai/train.py +98 -12
- geoai/utils.py +240 -8
- {geoai_py-0.4.1.dist-info → geoai_py-0.4.3.dist-info}/METADATA +6 -6
- geoai_py-0.4.3.dist-info/RECORD +15 -0
- {geoai_py-0.4.1.dist-info → geoai_py-0.4.3.dist-info}/WHEEL +1 -1
- geoai_py-0.4.1.dist-info/RECORD +0 -15
- {geoai_py-0.4.1.dist-info → geoai_py-0.4.3.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.4.1.dist-info → geoai_py-0.4.3.dist-info/licenses}/LICENSE +0 -0
- {geoai_py-0.4.1.dist-info → geoai_py-0.4.3.dist-info}/top_level.txt +0 -0
geoai/extract.py
CHANGED
|
@@ -19,6 +19,7 @@ from torchvision.models.detection import (
|
|
|
19
19
|
maskrcnn_resnet50_fpn,
|
|
20
20
|
)
|
|
21
21
|
from tqdm import tqdm
|
|
22
|
+
import time
|
|
22
23
|
|
|
23
24
|
# Local Imports
|
|
24
25
|
from .utils import get_raster_stats
|
|
@@ -2117,6 +2118,7 @@ class ObjectDetector:
|
|
|
2117
2118
|
confidence_threshold=0.5,
|
|
2118
2119
|
min_object_area=100,
|
|
2119
2120
|
max_object_area=None,
|
|
2121
|
+
n_workers=None,
|
|
2120
2122
|
**kwargs,
|
|
2121
2123
|
):
|
|
2122
2124
|
"""
|
|
@@ -2128,14 +2130,103 @@ class ObjectDetector:
|
|
|
2128
2130
|
confidence_threshold: Minimum confidence score (0.0-1.0). Default: 0.5
|
|
2129
2131
|
min_object_area: Minimum area in pixels to keep an object. Default: 100
|
|
2130
2132
|
max_object_area: Maximum area in pixels to keep an object. Default: None
|
|
2133
|
+
n_workers: int, default=None
|
|
2134
|
+
The number of worker threads to use.
|
|
2135
|
+
"None" means single-threaded processing.
|
|
2136
|
+
"-1" means using all available CPU processors.
|
|
2137
|
+
Positive integer means using that specific number of threads.
|
|
2131
2138
|
**kwargs: Additional parameters
|
|
2132
2139
|
|
|
2133
2140
|
Returns:
|
|
2134
2141
|
GeoDataFrame with car detections and confidence values
|
|
2135
2142
|
"""
|
|
2136
2143
|
|
|
2144
|
+
def _process_single_component(
    component_mask,
    conf_data,
    transform,
    confidence_threshold,
    min_object_area,
    max_object_area,
):
    """Vectorize one labeled mask component into oriented-box polygons.

    Returns a list of (polygon, confidence, pixel_area) tuples, or None
    when the component's mean confidence falls below the threshold.
    """
    # Mean confidence over the component's pixels, rescaled from the
    # uint8 [0, 255] encoding back to the [0, 1] range.
    region_vals = conf_data[component_mask > 0]
    confidence = np.mean(region_vals) / 255.0 if len(region_vals) > 0 else 0.0

    # Reject the whole component early when it is not confident enough.
    if confidence < confidence_threshold:
        return None

    # Trace the component outline(s).
    contours, _ = cv2.findContours(
        component_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    detections = []
    for outline in contours:
        pixel_area = cv2.contourArea(outline)

        # Discard outlines outside the configured size window.
        if pixel_area < min_object_area:
            continue
        if max_object_area is not None and pixel_area > max_object_area:
            continue

        # Minimum-area (rotated) bounding rectangle in pixel space.
        corners = cv2.boxPoints(cv2.minAreaRect(outline))

        # Project the four corners into geographic coordinates via the
        # raster affine transform.
        geo_corners = [tuple(transform * (px, py)) for px, py in corners]
        detections.append((Polygon(geo_corners), confidence, pixel_area))

    return detections
|
|
2194
|
+
|
|
2195
|
+
import concurrent.futures
|
|
2196
|
+
from functools import partial
|
|
2197
|
+
|
|
2198
|
+
def process_component(args):
    """Unpack one work item and delegate to _process_single_component.

    `args` bundles the component label with all shared processing state so
    this helper can be mapped over an executor.
    """
    label, labels_img, conf_img, tfm, conf_thr, min_area, max_area = args

    # Binary uint8 mask selecting only this labeled component's pixels.
    mask = (labels_img == label).astype(np.uint8)

    return _process_single_component(
        mask, conf_img, tfm, conf_thr, min_area, max_area
    )
|
|
2223
|
+
|
|
2224
|
+
start_time = time.time()
|
|
2137
2225
|
print(f"Processing masks from: {masks_path}")
|
|
2138
2226
|
|
|
2227
|
+
if n_workers == -1:
|
|
2228
|
+
n_workers = os.cpu_count()
|
|
2229
|
+
|
|
2139
2230
|
with rasterio.open(masks_path) as src:
|
|
2140
2231
|
# Read mask and confidence bands
|
|
2141
2232
|
mask_data = src.read(1)
|
|
@@ -2155,56 +2246,68 @@ class ObjectDetector:
|
|
|
2155
2246
|
confidences = []
|
|
2156
2247
|
pixels = []
|
|
2157
2248
|
|
|
2158
|
-
|
|
2159
|
-
|
|
2160
|
-
|
|
2161
|
-
component_mask = (labeled_mask == label).astype(np.uint8)
|
|
2162
|
-
|
|
2163
|
-
# Get confidence value (mean of non-zero values in this region)
|
|
2164
|
-
conf_region = conf_data[component_mask > 0]
|
|
2165
|
-
if len(conf_region) > 0:
|
|
2166
|
-
confidence = (
|
|
2167
|
-
np.mean(conf_region) / 255.0
|
|
2168
|
-
) # Convert back to 0-1 range
|
|
2169
|
-
else:
|
|
2170
|
-
confidence = 0.0
|
|
2171
|
-
|
|
2172
|
-
# Skip if confidence is below threshold
|
|
2173
|
-
if confidence < confidence_threshold:
|
|
2174
|
-
continue
|
|
2175
|
-
|
|
2176
|
-
# Find contours
|
|
2177
|
-
contours, _ = cv2.findContours(
|
|
2178
|
-
component_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
|
|
2249
|
+
if n_workers is None or n_workers == 1:
|
|
2250
|
+
print(
|
|
2251
|
+
"Using single-threaded processing, you can speed up processing by setting n_workers > 1"
|
|
2179
2252
|
)
|
|
2253
|
+
# Add progress bar
|
|
2254
|
+
for label in tqdm(
|
|
2255
|
+
range(1, num_features + 1), desc="Processing components"
|
|
2256
|
+
):
|
|
2257
|
+
# Create mask for this component
|
|
2258
|
+
component_mask = (labeled_mask == label).astype(np.uint8)
|
|
2259
|
+
|
|
2260
|
+
result = _process_single_component(
|
|
2261
|
+
component_mask,
|
|
2262
|
+
conf_data,
|
|
2263
|
+
transform,
|
|
2264
|
+
confidence_threshold,
|
|
2265
|
+
min_object_area,
|
|
2266
|
+
max_object_area,
|
|
2267
|
+
)
|
|
2180
2268
|
|
|
2181
|
-
|
|
2182
|
-
|
|
2183
|
-
|
|
2184
|
-
|
|
2185
|
-
|
|
2186
|
-
|
|
2187
|
-
if max_object_area is not None:
|
|
2188
|
-
if area > max_object_area:
|
|
2189
|
-
continue
|
|
2190
|
-
|
|
2191
|
-
# Get minimum area rectangle
|
|
2192
|
-
rect = cv2.minAreaRect(contour)
|
|
2193
|
-
box_points = cv2.boxPoints(rect)
|
|
2194
|
-
|
|
2195
|
-
# Convert to geographic coordinates
|
|
2196
|
-
geo_points = []
|
|
2197
|
-
for x, y in box_points:
|
|
2198
|
-
gx, gy = transform * (x, y)
|
|
2199
|
-
geo_points.append((gx, gy))
|
|
2269
|
+
if result:
|
|
2270
|
+
for poly, confidence, area in result:
|
|
2271
|
+
# Add to lists
|
|
2272
|
+
polygons.append(poly)
|
|
2273
|
+
confidences.append(confidence)
|
|
2274
|
+
pixels.append(area)
|
|
2200
2275
|
|
|
2201
|
-
|
|
2202
|
-
|
|
2276
|
+
else:
|
|
2277
|
+
# Process components in parallel
|
|
2278
|
+
print(f"Using {n_workers} workers for parallel processing")
|
|
2279
|
+
|
|
2280
|
+
process_args = [
|
|
2281
|
+
(
|
|
2282
|
+
label,
|
|
2283
|
+
labeled_mask,
|
|
2284
|
+
conf_data,
|
|
2285
|
+
transform,
|
|
2286
|
+
confidence_threshold,
|
|
2287
|
+
min_object_area,
|
|
2288
|
+
max_object_area,
|
|
2289
|
+
)
|
|
2290
|
+
for label in range(1, num_features + 1)
|
|
2291
|
+
]
|
|
2292
|
+
|
|
2293
|
+
with concurrent.futures.ThreadPoolExecutor(
|
|
2294
|
+
max_workers=n_workers
|
|
2295
|
+
) as executor:
|
|
2296
|
+
results = list(
|
|
2297
|
+
tqdm(
|
|
2298
|
+
executor.map(process_component, process_args),
|
|
2299
|
+
total=num_features,
|
|
2300
|
+
desc="Processing components",
|
|
2301
|
+
)
|
|
2302
|
+
)
|
|
2203
2303
|
|
|
2204
|
-
|
|
2205
|
-
|
|
2206
|
-
|
|
2207
|
-
|
|
2304
|
+
for result in results:
|
|
2305
|
+
if result:
|
|
2306
|
+
for poly, confidence, area in result:
|
|
2307
|
+
# Add to lists
|
|
2308
|
+
polygons.append(poly)
|
|
2309
|
+
confidences.append(confidence)
|
|
2310
|
+
pixels.append(area)
|
|
2208
2311
|
|
|
2209
2312
|
# Create GeoDataFrame
|
|
2210
2313
|
if polygons:
|
|
@@ -2223,8 +2326,12 @@ class ObjectDetector:
|
|
|
2223
2326
|
gdf.to_file(output_path, driver="GeoJSON")
|
|
2224
2327
|
print(f"Saved {len(gdf)} objects with confidence to {output_path}")
|
|
2225
2328
|
|
|
2329
|
+
end_time = time.time()
|
|
2330
|
+
print(f"Total processing time: {end_time - start_time:.2f} seconds")
|
|
2226
2331
|
return gdf
|
|
2227
2332
|
else:
|
|
2333
|
+
end_time = time.time()
|
|
2334
|
+
print(f"Total processing time: {end_time - start_time:.2f} seconds")
|
|
2228
2335
|
print("No valid polygons found")
|
|
2229
2336
|
return None
|
|
2230
2337
|
|
|
@@ -2397,3 +2504,521 @@ class ParkingSplotDetector(ObjectDetector):
|
|
|
2397
2504
|
num_classes=num_classes,
|
|
2398
2505
|
device=device,
|
|
2399
2506
|
)
|
|
2507
|
+
|
|
2508
|
+
|
|
2509
|
+
class AgricultureFieldDelineator(ObjectDetector):
|
|
2510
|
+
"""
|
|
2511
|
+
Agricultural field boundary delineation using a pre-trained Mask R-CNN model.
|
|
2512
|
+
|
|
2513
|
+
This class extends the ObjectDetector class to specifically handle Sentinel-2
|
|
2514
|
+
imagery with 12 spectral bands for agricultural field boundary detection.
|
|
2515
|
+
|
|
2516
|
+
Attributes:
|
|
2517
|
+
band_selection: List of band indices to use for prediction (default: RGB)
|
|
2518
|
+
sentinel_band_stats: Per-band statistics for Sentinel-2 data
|
|
2519
|
+
use_ndvi: Whether to calculate and include NDVI as an additional channel
|
|
2520
|
+
"""
|
|
2521
|
+
|
|
2522
|
+
def __init__(
    self,
    model_path="field_boundary_detector.pth",
    repo_id=None,
    model=None,
    device=None,
    band_selection=None,
    use_ndvi=False,
):
    """Initialize the field boundary delineator.

    Args:
        model_path: Path to the .pth model file.
        repo_id: Repo ID for loading models from the Hub.
        model: Custom model to use for inference.
        device: Device for inference ('cuda:0', 'cpu', ...); auto-detected
            when None.
        band_selection: Sentinel-2 band indices to use (None = RGB default).
        use_ndvi: Whether to append NDVI as an extra input channel.
    """
    # Stash the multispectral options before the parent constructor runs;
    # initialize_sentinel2_model reads them.
    self.custom_band_selection = band_selection
    self.use_ndvi = use_ndvi

    # Resolve the compute device up front (mirrors the parent __init__),
    # since the model is built before super().__init__ is called.
    if device is None:
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        self.device = torch.device(device)

    # Build a Mask R-CNN variant whose first conv accepts multispectral input.
    model = self.initialize_sentinel2_model(model)

    # Hand our customized model to the parent constructor.
    super().__init__(
        model_path=model_path, repo_id=repo_id, model=model, device=device
    )

    # Default per-band reflectance statistics for Sentinel-2 imagery;
    # call update_band_stats() for scene-specific values.
    self.sentinel_band_stats = {
        "means": [
            0.0975,
            0.0476,
            0.0598,
            0.0697,
            0.1077,
            0.1859,
            0.2378,
            0.2061,
            0.2598,
            0.4120,
            0.1956,
            0.1410,
        ],
        "stds": [
            0.0551,
            0.0290,
            0.0298,
            0.0479,
            0.0506,
            0.0505,
            0.0747,
            0.0642,
            0.0782,
            0.1187,
            0.0651,
            0.0679,
        ],
    }

    # Default to the RGB bands (Sentinel-2 B4, B3, B2 -> indices 3, 2, 1).
    if self.custom_band_selection is not None:
        self.band_selection = self.custom_band_selection
    else:
        self.band_selection = [3, 2, 1]

    # Field-delineation-friendly processing defaults.
    self.confidence_threshold = 0.5  # Default confidence threshold
    self.overlap = 0.5  # Higher overlap for field boundary detection
    self.min_object_area = 1000  # Minimum area in pixels for field detection
    self.simplify_tolerance = 2.0  # Higher tolerance for field boundaries
|
|
2605
|
+
|
|
2606
|
+
def initialize_sentinel2_model(self, model=None):
    """
    Initialize a Mask R-CNN model with a modified first layer to accept
    multi-spectral Sentinel-2 input.

    Args:
        model: Pre-initialized model (optional); returned unchanged if given.

    Returns:
        A Mask R-CNN model on self.device whose backbone stem accepts
        len(band_selection) (+1 when NDVI is enabled) input channels.
    """
    from torchvision.models.detection import MaskRCNN
    from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

    if model is not None:
        return model

    # Input channel count: the selected bands (default 3 = RGB) plus an
    # optional NDVI channel.
    num_input_channels = (
        len(self.custom_band_selection)
        if self.custom_band_selection is not None
        else 3
    )
    if self.use_ndvi:
        num_input_channels += 1

    print(f"Initializing Mask R-CNN model with {num_input_channels} input channels")

    # ResNet-50 FPN backbone, then widen its stem conv to the
    # multispectral channel count while keeping the other conv settings.
    backbone = resnet_fpn_backbone("resnet50", weights=None)
    original_conv = backbone.body.conv1
    backbone.body.conv1 = torch.nn.Conv2d(
        num_input_channels,
        original_conv.out_channels,
        kernel_size=original_conv.kernel_size,
        stride=original_conv.stride,
        padding=original_conv.padding,
        bias=original_conv.bias is not None,
    )

    # BUGFIX: maskrcnn_resnet50_fpn() constructs its own backbone and
    # forwards extra kwargs to MaskRCNN, so passing `backbone=` raises a
    # TypeError (duplicate 'backbone' argument). Build MaskRCNN directly
    # from our modified backbone instead.
    model = MaskRCNN(
        backbone,
        num_classes=2,  # Background + field
        image_mean=[0.485] * num_input_channels,  # Extend mean to all channels
        image_std=[0.229] * num_input_channels,  # Extend std to all channels
    )

    model.to(self.device)
    return model
|
|
2658
|
+
|
|
2659
|
+
def preprocess_sentinel_bands(self, image_data, band_selection=None, use_ndvi=None):
    """
    Preprocess Sentinel-2 band data for model input.

    Args:
        image_data: Raw Sentinel-2 image data as numpy array [bands, height, width]
        band_selection: List of band indices to use (overrides instance default if provided)
        use_ndvi: Whether to include NDVI (overrides instance default if provided)

    Returns:
        Float tensor shaped [channels, height, width], normalized per band.
    """
    # Use instance defaults if not specified.
    if band_selection is None:
        band_selection = self.band_selection
    if use_ndvi is None:
        use_ndvi = self.use_ndvi

    # Select the requested spectral bands.
    selected_bands = image_data[band_selection]

    # Append NDVI only when both NIR (index 7, B8) and Red (index 3, B4)
    # are actually present in the input stack.
    ndvi_appended = False
    if use_ndvi and image_data.shape[0] > 7:
        nir = image_data[7].astype(np.float32)  # B8 (NIR)
        red = image_data[3].astype(np.float32)  # B4 (Red)

        # Avoid division by zero on dark / nodata pixels.
        denominator = nir + red
        ndvi = np.zeros_like(nir)
        valid_mask = denominator > 0
        ndvi[valid_mask] = (nir[valid_mask] - red[valid_mask]) / denominator[valid_mask]

        # Rescale NDVI from [-1, 1] to [0, 1].
        ndvi = (ndvi + 1) / 2

        selected_bands = np.vstack([selected_bands, ndvi[np.newaxis, :, :]])
        ndvi_appended = True

    image_tensor = torch.from_numpy(selected_bands).float()

    # Standardize each spectral channel with its per-band statistics.
    for i, band_idx in enumerate(band_selection):
        # Make sure band_idx is within range of our statistics.
        if band_idx < len(self.sentinel_band_stats["means"]):
            mean = self.sentinel_band_stats["means"][band_idx]
            std = self.sentinel_band_stats["stds"][band_idx]
            image_tensor[i] = (image_tensor[i] - mean) / std

    # BUGFIX: only normalize the NDVI channel when it was actually added.
    # The original keyed this on `use_ndvi` alone, which double-transformed
    # the last spectral band whenever NDVI could not be computed.
    if ndvi_appended:
        # NDVI is already roughly in [0, 1]; standardize it slightly.
        image_tensor[-1] = (image_tensor[-1] - 0.5) / 0.5

    return image_tensor
|
|
2720
|
+
|
|
2721
|
+
def update_band_stats(self, raster_path, band_selection=None, sample_size=1000):
    """
    Estimate per-band mean/std from a random pixel sample of a raster.

    Args:
        raster_path: Path to the Sentinel-2 raster file.
        band_selection: Specific bands to update (None = all available).
        sample_size: Number of random pixels sampled per band.

    Returns:
        The updated band statistics dictionary.
    """
    with rasterio.open(raster_path) as src:
        # Sanity-check that this looks like multi-band Sentinel-2 data.
        band_count = src.count
        if band_count < 3:
            print(
                f"Warning: Raster has only {band_count} bands, may not be Sentinel-2 data"
            )

        height, width = src.height, src.width

        # Default to every band in the file (rasterio bands are 1-indexed).
        if band_selection is None:
            band_selection = list(range(1, band_count + 1))

        # Fixed seed keeps the sampled pixel set reproducible across runs.
        np.random.seed(42)
        sample_rows = np.random.randint(0, height, sample_size)
        sample_cols = np.random.randint(0, width, sample_size)

        means = []
        stds = []
        for band in band_selection:
            band_data = src.read(band)
            sample_values = band_data[sample_rows, sample_cols]

            # Drop nodata / non-finite samples before computing stats.
            valid_samples = sample_values[np.isfinite(sample_values)]

            band_mean = float(np.mean(valid_samples))
            band_std = float(np.std(valid_samples))
            means.append(band_mean)
            stds.append(band_std)

            print(f"Band {band}: mean={band_mean:.4f}, std={band_std:.4f}")

    # Replace the instance statistics with the freshly sampled ones.
    self.sentinel_band_stats = {"means": means, "stds": stds}
    return self.sentinel_band_stats
|
|
2782
|
+
|
|
2783
|
+
def process_sentinel_raster(
    self,
    raster_path,
    output_path=None,
    batch_size=4,
    band_selection=None,
    use_ndvi=None,
    filter_edges=True,
    edge_buffer=20,
    **kwargs,
):
    """
    Process a Sentinel-2 raster to extract field boundaries.

    Args:
        raster_path: Path to Sentinel-2 raster file.
        output_path: Path to output GeoJSON or Parquet file (optional).
        batch_size: Batch size for processing.
        band_selection: List of bands to use (None = instance default).
        use_ndvi: Whether to include NDVI (None = instance default).
        filter_edges: Whether to drop objects touching the image edges.
        edge_buffer: Edge buffer size in pixels for that filtering.
        **kwargs: Additional processing parameters.

    Returns:
        GeoDataFrame with field boundaries.
    """
    # Resolve per-call overrides against instance defaults.
    if band_selection is None:
        band_selection = self.band_selection
    if use_ndvi is None:
        use_ndvi = self.use_ndvi

    confidence_threshold = kwargs.get(
        "confidence_threshold", self.confidence_threshold
    )
    overlap = kwargs.get("overlap", self.overlap)
    chip_size = kwargs.get("chip_size", self.chip_size)
    nms_iou_threshold = kwargs.get("nms_iou_threshold", self.nms_iou_threshold)
    mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
    min_object_area = kwargs.get("min_object_area", self.min_object_area)
    simplify_tolerance = kwargs.get("simplify_tolerance", self.simplify_tolerance)

    # Refresh band statistics from this raster unless explicitly disabled.
    if kwargs.get("update_stats", True):
        self.update_band_stats(raster_path, band_selection)

    print(f"Processing with parameters:")
    print(f"- Using bands: {band_selection}")
    print(f"- Include NDVI: {use_ndvi}")
    print(f"- Confidence threshold: {confidence_threshold}")
    print(f"- Tile overlap: {overlap}")
    print(f"- Chip size: {chip_size}")
    print(f"- Filter edge objects: {filter_edges}")

    # Inline dataset that tiles the raster into overlapping chips and runs
    # the delineator's band preprocessing on each one.
    class Sentinel2Dataset(torch.utils.data.Dataset):
        def __init__(
            self,
            raster_path,
            chip_size,
            stride_x,
            stride_y,
            band_selection,
            use_ndvi,
            field_delineator,
        ):
            self.raster_path = raster_path
            self.chip_size = chip_size
            self.stride_x = stride_x
            self.stride_y = stride_y
            self.band_selection = band_selection
            self.use_ndvi = use_ndvi
            self.field_delineator = field_delineator

            with rasterio.open(self.raster_path) as src:
                self.height = src.height
                self.width = src.width
                self.count = src.count
                self.crs = src.crs
                self.transform = src.transform

            # Chip origins along each axis: regular stride steps plus one
            # final origin flush against the far edge of the image.
            self.row_starts = [
                r * self.stride_y for r in range((self.height - 1) // self.stride_y)
            ]
            if self.height > self.chip_size[0]:
                self.row_starts.append(max(0, self.height - self.chip_size[0]))
            elif not self.row_starts:
                # Image smaller than a chip: single origin at 0.
                self.row_starts.append(0)

            self.col_starts = [
                c * self.stride_x for c in range((self.width - 1) // self.stride_x)
            ]
            if self.width > self.chip_size[1]:
                self.col_starts.append(max(0, self.width - self.chip_size[1]))
            elif not self.col_starts:
                self.col_starts.append(0)

            self.rows = len(self.row_starts)
            self.cols = len(self.col_starts)

            print(
                f"Dataset initialized with {self.rows} rows and {self.cols} columns of chips"
            )
            print(f"Image dimensions: {self.width} x {self.height} pixels")
            print(f"Chip size: {self.chip_size[1]} x {self.chip_size[0]} pixels")

        def __len__(self):
            return self.rows * self.cols

        def __getitem__(self, idx):
            # Flat index -> grid position -> pre-computed pixel origin.
            j = self.row_starts[idx // self.cols]
            i = self.col_starts[idx % self.cols]

            with rasterio.open(self.raster_path) as src:
                # Clamp the window so it never reads past the raster edge.
                width = min(self.chip_size[1], self.width - i)
                height = min(self.chip_size[0], self.height - j)

                window = Window(i, j, width, height)
                image = src.read(window=window)

            # Zero-pad partial edge chips up to the full chip size.
            if (
                image.shape[1] != self.chip_size[0]
                or image.shape[2] != self.chip_size[1]
            ):
                padded = np.zeros(
                    (image.shape[0], self.chip_size[0], self.chip_size[1]),
                    dtype=image.dtype,
                )
                padded[:, : image.shape[1], : image.shape[2]] = image
                image = padded

            # Band selection / NDVI / normalization handled by the delineator.
            image_tensor = self.field_delineator.preprocess_sentinel_bands(
                image, self.band_selection, self.use_ndvi
            )

            # Geographic bounds of the (unpadded) window.
            with rasterio.open(self.raster_path) as src:
                window_transform = src.window_transform(window)
            minx, miny = window_transform * (0, height)
            maxx, maxy = window_transform * (width, 0)

            return {
                "image": image_tensor,
                "bbox": [minx, miny, maxx, maxy],
                "coords": torch.tensor([i, j], dtype=torch.long),
                "window_size": torch.tensor([width, height], dtype=torch.long),
            }

    # Stride derived from the requested tile overlap fraction.
    stride_x = int(chip_size[1] * (1 - overlap))
    stride_y = int(chip_size[0] * (1 - overlap))

    dataset = Sentinel2Dataset(
        raster_path=raster_path,
        chip_size=chip_size,
        stride_x=stride_x,
        stride_y=stride_y,
        band_selection=band_selection,
        use_ndvi=use_ndvi,
        field_delineator=self,
    )

    def custom_collate(batch):
        """Collate dict samples, keeping 'bbox' (and any non-collatable
        field) as a plain per-batch list."""
        sample = batch[0]
        if not isinstance(sample, dict):
            # Default collate for non-dict samples.
            return torch.utils.data._utils.collate.default_collate(batch)

        collated = {}
        for key in sample:
            values = [item[key] for item in batch]
            if key == "bbox":
                # Bounding boxes stay as a list.
                collated[key] = values
            else:
                try:
                    collated[key] = torch.utils.data._utils.collate.default_collate(
                        values
                    )
                except TypeError:
                    # Fall back to a list for non-collatable payloads.
                    collated[key] = values
        return collated

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=custom_collate,
    )

    # Delegate tiled inference and vectorization to the parent class's
    # process_raster implementation.
    results = super().process_raster(
        raster_path=raster_path,
        output_path=output_path,
        batch_size=batch_size,
        filter_edges=filter_edges,
        edge_buffer=edge_buffer,
        confidence_threshold=confidence_threshold,
        overlap=overlap,
        chip_size=chip_size,
        nms_iou_threshold=nms_iou_threshold,
        mask_threshold=mask_threshold,
        min_object_area=min_object_area,
        simplify_tolerance=simplify_tolerance,
    )

    return results
|