geoai-py 0.8.3__py2.py3-none-any.whl → 0.9.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/extract.py CHANGED
@@ -3,6 +3,7 @@
3
3
  # Standard Library
4
4
  import os
5
5
  import time
6
+ from typing import Any, Dict, Generator, List, Optional, Tuple, Union
6
7
 
7
8
  # Third-Party Libraries
8
9
  import cv2
@@ -64,13 +65,13 @@ class CustomDataset(NonGeoDataset):
64
65
 
65
66
  def __init__(
66
67
  self,
67
- raster_path,
68
- chip_size=(512, 512),
69
- overlap=0.5,
70
- transforms=None,
71
- band_indexes=None,
72
- verbose=False,
73
- ):
68
+ raster_path: str,
69
+ chip_size: Tuple[int, int] = (512, 512),
70
+ overlap: float = 0.5,
71
+ transforms: Optional[Any] = None,
72
+ band_indexes: Optional[List[int]] = None,
73
+ verbose: bool = False,
74
+ ) -> None:
74
75
  """
75
76
  Initialize the dataset with overlapping tiles.
76
77
 
@@ -163,7 +164,7 @@ class CustomDataset(NonGeoDataset):
163
164
  # Get raster stats
164
165
  self.raster_stats = get_raster_stats(raster_path, divide_by=255)
165
166
 
166
- def __getitem__(self, idx):
167
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
167
168
  """
168
169
  Get an image chip from the dataset by index.
169
170
 
@@ -255,7 +256,7 @@ class CustomDataset(NonGeoDataset):
255
256
  ), # Consistent format
256
257
  }
257
258
 
258
- def __len__(self):
259
+ def __len__(self) -> int:
259
260
  """
260
261
  Return the number of samples in the dataset.
261
262
 
@@ -271,8 +272,13 @@ class ObjectDetector:
271
272
  """
272
273
 
273
274
  def __init__(
274
- self, model_path=None, repo_id=None, model=None, num_classes=2, device=None
275
- ):
275
+ self,
276
+ model_path: Optional[str] = None,
277
+ repo_id: Optional[str] = None,
278
+ model: Optional[Any] = None,
279
+ num_classes: int = 2,
280
+ device: Optional[str] = None,
281
+ ) -> None:
276
282
  """
277
283
  Initialize the object extractor.
278
284
 
@@ -312,7 +318,9 @@ class ObjectDetector:
312
318
  # Set model to evaluation mode
313
319
  self.model.eval()
314
320
 
315
- def download_model_from_hf(self, model_path=None, repo_id=None):
321
+ def download_model_from_hf(
322
+ self, model_path: Optional[str] = None, repo_id: Optional[str] = None
323
+ ) -> str:
316
324
  """
317
325
  Download the object detection model from Hugging Face.
318
326
 
@@ -345,7 +353,7 @@ class ObjectDetector:
345
353
  print("Please specify a local model path or ensure internet connectivity.")
346
354
  raise
347
355
 
348
- def initialize_model(self, model, num_classes=2):
356
+ def initialize_model(self, model: Optional[Any], num_classes: int = 2) -> Any:
349
357
  """Initialize a deep learning model for object detection.
350
358
 
351
359
  Args:
@@ -375,7 +383,7 @@ class ObjectDetector:
375
383
  model.to(self.device)
376
384
  return model
377
385
 
378
- def load_weights(self, model_path):
386
+ def load_weights(self, model_path: str) -> None:
379
387
  """
380
388
  Load weights from file with error handling for different formats.
381
389
 
@@ -417,7 +425,7 @@ class ObjectDetector:
417
425
  except Exception as e:
418
426
  raise RuntimeError(f"Failed to load model: {e}")
419
427
 
420
- def mask_to_polygons(self, mask, **kwargs):
428
+ def mask_to_polygons(self, mask: np.ndarray, **kwargs: Any) -> List[Polygon]:
421
429
  """
422
430
  Convert binary mask to polygon contours using OpenCV.
423
431
 
@@ -474,7 +482,9 @@ class ObjectDetector:
474
482
 
475
483
  return polygons
476
484
 
477
- def filter_overlapping_polygons(self, gdf, **kwargs):
485
+ def filter_overlapping_polygons(
486
+ self, gdf: gpd.GeoDataFrame, **kwargs: Any
487
+ ) -> gpd.GeoDataFrame:
478
488
  """
479
489
  Filter overlapping polygons using non-maximum suppression.
480
490
 
@@ -531,7 +541,9 @@ class ObjectDetector:
531
541
 
532
542
  return gdf.iloc[keep_indices]
533
543
 
534
- def filter_edge_objects(self, gdf, raster_path, edge_buffer=10):
544
+ def filter_edge_objects(
545
+ self, gdf: gpd.GeoDataFrame, raster_path: str, edge_buffer: int = 10
546
+ ) -> gpd.GeoDataFrame:
535
547
  """
536
548
  Filter out object detections that fall in padding/edge areas of the image.
537
549
 
@@ -596,17 +608,17 @@ class ObjectDetector:
596
608
 
597
609
  def masks_to_vector(
598
610
  self,
599
- mask_path,
600
- output_path=None,
601
- simplify_tolerance=None,
602
- mask_threshold=None,
603
- min_object_area=None,
604
- max_object_area=None,
605
- nms_iou_threshold=None,
606
- regularize=True,
607
- angle_threshold=15,
608
- rectangularity_threshold=0.7,
609
- ):
611
+ mask_path: str,
612
+ output_path: Optional[str] = None,
613
+ simplify_tolerance: Optional[float] = None,
614
+ mask_threshold: Optional[float] = None,
615
+ min_object_area: Optional[int] = None,
616
+ max_object_area: Optional[int] = None,
617
+ nms_iou_threshold: Optional[float] = None,
618
+ regularize: bool = True,
619
+ angle_threshold: int = 15,
620
+ rectangularity_threshold: float = 0.7,
621
+ ) -> gpd.GeoDataFrame:
610
622
  """
611
623
  Convert an object mask GeoTIFF to vector polygons and save as GeoJSON.
612
624
 
@@ -808,14 +820,14 @@ class ObjectDetector:
808
820
  @torch.no_grad()
809
821
  def process_raster(
810
822
  self,
811
- raster_path,
812
- output_path=None,
813
- batch_size=4,
814
- filter_edges=True,
815
- edge_buffer=20,
816
- band_indexes=None,
817
- **kwargs,
818
- ):
823
+ raster_path: str,
824
+ output_path: Optional[str] = None,
825
+ batch_size: int = 4,
826
+ filter_edges: bool = True,
827
+ edge_buffer: int = 20,
828
+ band_indexes: Optional[List[int]] = None,
829
+ **kwargs: Any,
830
+ ) -> "gpd.GeoDataFrame":
819
831
  """
820
832
  Process a raster file to extract objects with customizable parameters.
821
833
 
@@ -874,7 +886,7 @@ class ObjectDetector:
874
886
  self.raster_stats = dataset.raster_stats
875
887
 
876
888
  # Custom collate function to handle Shapely objects
877
- def custom_collate(batch):
889
+ def custom_collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
878
890
  """
879
891
  Custom collate function that handles Shapely geometries
880
892
  by keeping them as Python objects rather than trying to collate them.
@@ -1056,8 +1068,13 @@ class ObjectDetector:
1056
1068
  return gdf
1057
1069
 
1058
1070
  def save_masks_as_geotiff(
1059
- self, raster_path, output_path=None, batch_size=4, verbose=False, **kwargs
1060
- ):
1071
+ self,
1072
+ raster_path: str,
1073
+ output_path: Optional[str] = None,
1074
+ batch_size: int = 4,
1075
+ verbose: bool = False,
1076
+ **kwargs: Any,
1077
+ ) -> str:
1061
1078
  """
1062
1079
  Process a raster file to extract object masks and save as GeoTIFF.
1063
1080
 
@@ -1125,7 +1142,7 @@ class ObjectDetector:
1125
1142
  mask_array = np.zeros((src.height, src.width), dtype=np.uint8)
1126
1143
 
1127
1144
  # Custom collate function to handle Shapely objects
1128
- def custom_collate(batch):
1145
+ def custom_collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
1129
1146
  """Custom collate function for DataLoader"""
1130
1147
  elem = batch[0]
1131
1148
  if isinstance(elem, dict):
@@ -1291,12 +1308,12 @@ class ObjectDetector:
1291
1308
 
1292
1309
  def regularize_objects(
1293
1310
  self,
1294
- gdf,
1295
- min_area=10,
1296
- angle_threshold=15,
1297
- orthogonality_threshold=0.3,
1298
- rectangularity_threshold=0.7,
1299
- ):
1311
+ gdf: gpd.GeoDataFrame,
1312
+ min_area: int = 10,
1313
+ angle_threshold: int = 15,
1314
+ orthogonality_threshold: float = 0.3,
1315
+ rectangularity_threshold: float = 0.7,
1316
+ ) -> gpd.GeoDataFrame:
1300
1317
  """
1301
1318
  Regularize objects to enforce right angles and rectangular shapes.
1302
1319
 
@@ -1319,7 +1336,9 @@ class ObjectDetector:
1319
1336
  from shapely.geometry import MultiPolygon, Polygon, box
1320
1337
  from tqdm import tqdm
1321
1338
 
1322
- def get_angle(p1, p2, p3):
1339
+ def get_angle(
1340
+ p1: Tuple[float, float], p2: Tuple[float, float], p3: Tuple[float, float]
1341
+ ) -> float:
1323
1342
  """Calculate angle between three points in degrees (0-180)"""
1324
1343
  a = np.array(p1)
1325
1344
  b = np.array(p2)
@@ -1335,11 +1354,11 @@ class ObjectDetector:
1335
1354
 
1336
1355
  return angle
1337
1356
 
1338
- def is_orthogonal(angle, threshold=angle_threshold):
1357
+ def is_orthogonal(angle: float, threshold: int = angle_threshold) -> bool:
1339
1358
  """Check if angle is close to 90 degrees"""
1340
1359
  return abs(angle - 90) <= threshold
1341
1360
 
1342
- def calculate_dominant_direction(polygon):
1361
+ def calculate_dominant_direction(polygon: Polygon) -> float:
1343
1362
  """Find the dominant direction of a polygon using PCA"""
1344
1363
  # Extract coordinates
1345
1364
  coords = np.array(polygon.exterior.coords)
@@ -1368,7 +1387,7 @@ class ObjectDetector:
1368
1387
 
1369
1388
  return angle_deg
1370
1389
 
1371
- def create_oriented_envelope(polygon, angle_deg):
1390
+ def create_oriented_envelope(polygon: Polygon, angle_deg: float) -> Polygon:
1372
1391
  """Create an oriented minimum area rectangle for the polygon"""
1373
1392
  # Create a rotated rectangle using OpenCV method (more robust than Shapely methods)
1374
1393
  coords = np.array(polygon.exterior.coords)[:-1].astype(
@@ -1384,13 +1403,13 @@ class ObjectDetector:
1384
1403
 
1385
1404
  return oriented_box
1386
1405
 
1387
- def get_rectangularity(polygon, oriented_box):
1406
+ def get_rectangularity(polygon: Polygon, oriented_box: Polygon) -> float:
1388
1407
  """Calculate the rectangularity (area ratio to its oriented bounding box)"""
1389
1408
  if oriented_box.area == 0:
1390
1409
  return 0
1391
1410
  return polygon.area / oriented_box.area
1392
1411
 
1393
- def check_orthogonality(polygon):
1412
+ def check_orthogonality(polygon: Polygon) -> float:
1394
1413
  """Check what percentage of angles in the polygon are orthogonal"""
1395
1414
  coords = list(polygon.exterior.coords)
1396
1415
  if len(coords) <= 4: # Triangle or point
@@ -1413,7 +1432,7 @@ class ObjectDetector:
1413
1432
 
1414
1433
  return orthogonal_count / total_angles
1415
1434
 
1416
- def simplify_to_rectangle(polygon):
1435
+ def simplify_to_rectangle(polygon: Polygon) -> Polygon:
1417
1436
  """Simplify a polygon to a rectangle using its oriented bounding box"""
1418
1437
  # Get dominant direction
1419
1438
  angle = calculate_dominant_direction(polygon)
@@ -1506,8 +1525,12 @@ class ObjectDetector:
1506
1525
  return result_gdf
1507
1526
 
1508
1527
  def visualize_results(
1509
- self, raster_path, gdf=None, output_path=None, figsize=(12, 12)
1510
- ):
1528
+ self,
1529
+ raster_path: str,
1530
+ gdf: Optional[gpd.GeoDataFrame] = None,
1531
+ output_path: Optional[str] = None,
1532
+ figsize: Tuple[int, int] = (12, 12),
1533
+ ) -> bool:
1511
1534
  """
1512
1535
  Visualize object detection results with proper coordinate transformation.
1513
1536
 
@@ -1653,7 +1676,9 @@ class ObjectDetector:
1653
1676
  has_confidence = False
1654
1677
 
1655
1678
  # Function to convert coordinates
1656
- def geo_to_pixel(geometry, transform):
1679
+ def geo_to_pixel(
1680
+ geometry: Any, transform: Any
1681
+ ) -> Optional[Tuple[List[float], List[float]]]:
1657
1682
  """Convert geometry to pixel coordinates using the provided transform."""
1658
1683
  if geometry.is_empty:
1659
1684
  return None
@@ -1913,18 +1938,18 @@ class ObjectDetector:
1913
1938
 
1914
1939
  def generate_masks(
1915
1940
  self,
1916
- raster_path,
1917
- output_path=None,
1918
- confidence_threshold=None,
1919
- mask_threshold=None,
1920
- min_object_area=10,
1921
- max_object_area=float("inf"),
1922
- overlap=0.25,
1923
- batch_size=4,
1924
- band_indexes=None,
1925
- verbose=False,
1926
- **kwargs,
1927
- ):
1941
+ raster_path: str,
1942
+ output_path: Optional[str] = None,
1943
+ confidence_threshold: Optional[float] = None,
1944
+ mask_threshold: Optional[float] = None,
1945
+ min_object_area: int = 10,
1946
+ max_object_area: float = float("inf"),
1947
+ overlap: float = 0.25,
1948
+ batch_size: int = 4,
1949
+ band_indexes: Optional[List[int]] = None,
1950
+ verbose: bool = False,
1951
+ **kwargs: Any,
1952
+ ) -> str:
1928
1953
  """
1929
1954
  Save masks with confidence values as a multi-band GeoTIFF.
1930
1955
 
@@ -1983,7 +2008,7 @@ class ObjectDetector:
1983
2008
  conf_array = np.zeros((src.height, src.width), dtype=np.uint8)
1984
2009
 
1985
2010
  # Define custom collate function to handle Shapely objects
1986
- def custom_collate(batch):
2011
+ def custom_collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
1987
2012
  """
1988
2013
  Custom collate function that handles Shapely geometries
1989
2014
  by keeping them as Python objects rather than trying to collate them.
@@ -2113,14 +2138,14 @@ class ObjectDetector:
2113
2138
 
2114
2139
  def vectorize_masks(
2115
2140
  self,
2116
- masks_path,
2117
- output_path=None,
2118
- confidence_threshold=0.5,
2119
- min_object_area=100,
2120
- max_object_area=None,
2121
- n_workers=None,
2122
- **kwargs,
2123
- ):
2141
+ masks_path: str,
2142
+ output_path: Optional[str] = None,
2143
+ confidence_threshold: float = 0.5,
2144
+ min_object_area: int = 100,
2145
+ max_object_area: Optional[int] = None,
2146
+ n_workers: Optional[int] = None,
2147
+ **kwargs: Any,
2148
+ ) -> gpd.GeoDataFrame:
2124
2149
  """
2125
2150
  Convert masks with confidence to vector polygons.
2126
2151
 
@@ -2142,13 +2167,13 @@ class ObjectDetector:
2142
2167
  """
2143
2168
 
2144
2169
  def _process_single_component(
2145
- component_mask,
2146
- conf_data,
2147
- transform,
2148
- confidence_threshold,
2149
- min_object_area,
2150
- max_object_area,
2151
- ):
2170
+ component_mask: np.ndarray,
2171
+ conf_data: np.ndarray,
2172
+ transform: Any,
2173
+ confidence_threshold: float,
2174
+ min_object_area: int,
2175
+ max_object_area: Optional[int],
2176
+ ) -> Optional[Dict[str, Any]]:
2152
2177
  # Get confidence value
2153
2178
  conf_region = conf_data[component_mask > 0]
2154
2179
  if len(conf_region) > 0:
@@ -2195,7 +2220,9 @@ class ObjectDetector:
2195
2220
  import concurrent.futures
2196
2221
  from functools import partial
2197
2222
 
2198
- def process_component(args):
2223
+ def process_component(
2224
+ args: Tuple[int, np.ndarray, np.ndarray, Any, float, int, Optional[int]],
2225
+ ) -> Optional[Dict[str, Any]]:
2199
2226
  """
2200
2227
  Helper function to process a single component
2201
2228
  """
@@ -2346,11 +2373,11 @@ class BuildingFootprintExtractor(ObjectDetector):
2346
2373
 
2347
2374
  def __init__(
2348
2375
  self,
2349
- model_path="building_footprints_usa.pth",
2350
- repo_id=None,
2351
- model=None,
2352
- device=None,
2353
- ):
2376
+ model_path: str = "building_footprints_usa.pth",
2377
+ repo_id: Optional[str] = None,
2378
+ model: Optional[Any] = None,
2379
+ device: Optional[str] = None,
2380
+ ) -> None:
2354
2381
  """
2355
2382
  Initialize the object extractor.
2356
2383
 
@@ -2366,12 +2393,12 @@ class BuildingFootprintExtractor(ObjectDetector):
2366
2393
 
2367
2394
  def regularize_buildings(
2368
2395
  self,
2369
- gdf,
2370
- min_area=10,
2371
- angle_threshold=15,
2372
- orthogonality_threshold=0.3,
2373
- rectangularity_threshold=0.7,
2374
- ):
2396
+ gdf: gpd.GeoDataFrame,
2397
+ min_area: int = 10,
2398
+ angle_threshold: int = 15,
2399
+ orthogonality_threshold: float = 0.3,
2400
+ rectangularity_threshold: float = 0.7,
2401
+ ) -> gpd.GeoDataFrame:
2375
2402
  """
2376
2403
  Regularize building footprints to enforce right angles and rectangular shapes.
2377
2404
 
@@ -2402,8 +2429,12 @@ class CarDetector(ObjectDetector):
2402
2429
  """
2403
2430
 
2404
2431
  def __init__(
2405
- self, model_path="car_detection_usa.pth", repo_id=None, model=None, device=None
2406
- ):
2432
+ self,
2433
+ model_path: str = "car_detection_usa.pth",
2434
+ repo_id: Optional[str] = None,
2435
+ model: Optional[Any] = None,
2436
+ device: Optional[str] = None,
2437
+ ) -> None:
2407
2438
  """
2408
2439
  Initialize the object extractor.
2409
2440
 
@@ -2427,8 +2458,12 @@ class ShipDetector(ObjectDetector):
2427
2458
  """
2428
2459
 
2429
2460
  def __init__(
2430
- self, model_path="ship_detection.pth", repo_id=None, model=None, device=None
2431
- ):
2461
+ self,
2462
+ model_path: str = "ship_detection.pth",
2463
+ repo_id: Optional[str] = None,
2464
+ model: Optional[Any] = None,
2465
+ device: Optional[str] = None,
2466
+ ) -> None:
2432
2467
  """
2433
2468
  Initialize the object extractor.
2434
2469
 
@@ -2453,11 +2488,11 @@ class SolarPanelDetector(ObjectDetector):
2453
2488
 
2454
2489
  def __init__(
2455
2490
  self,
2456
- model_path="solar_panel_detection.pth",
2457
- repo_id=None,
2458
- model=None,
2459
- device=None,
2460
- ):
2491
+ model_path: str = "solar_panel_detection.pth",
2492
+ repo_id: Optional[str] = None,
2493
+ model: Optional[Any] = None,
2494
+ device: Optional[str] = None,
2495
+ ) -> None:
2461
2496
  """
2462
2497
  Initialize the object extractor.
2463
2498
 
@@ -2481,12 +2516,12 @@ class ParkingSplotDetector(ObjectDetector):
2481
2516
 
2482
2517
  def __init__(
2483
2518
  self,
2484
- model_path="parking_spot_detection.pth",
2485
- repo_id=None,
2486
- model=None,
2487
- num_classes=3,
2488
- device=None,
2489
- ):
2519
+ model_path: str = "parking_spot_detection.pth",
2520
+ repo_id: Optional[str] = None,
2521
+ model: Optional[Any] = None,
2522
+ num_classes: int = 3,
2523
+ device: Optional[str] = None,
2524
+ ) -> None:
2490
2525
  """
2491
2526
  Initialize the object extractor.
2492
2527
 
@@ -2521,13 +2556,13 @@ class AgricultureFieldDelineator(ObjectDetector):
2521
2556
 
2522
2557
  def __init__(
2523
2558
  self,
2524
- model_path="field_boundary_detector.pth",
2525
- repo_id=None,
2526
- model=None,
2527
- device=None,
2528
- band_selection=None,
2529
- use_ndvi=False,
2530
- ):
2559
+ model_path: str = "field_boundary_detector.pth",
2560
+ repo_id: Optional[str] = None,
2561
+ model: Optional[Any] = None,
2562
+ device: Optional[str] = None,
2563
+ band_selection: Optional[List[int]] = None,
2564
+ use_ndvi: bool = False,
2565
+ ) -> None:
2531
2566
  """
2532
2567
  Initialize the field boundary delineator.
2533
2568
 
@@ -2603,7 +2638,7 @@ class AgricultureFieldDelineator(ObjectDetector):
2603
2638
  self.min_object_area = 1000 # Minimum area in pixels for field detection
2604
2639
  self.simplify_tolerance = 2.0 # Higher tolerance for field boundaries
2605
2640
 
2606
- def initialize_sentinel2_model(self, model=None):
2641
+ def initialize_sentinel2_model(self, model: Optional[Any] = None) -> Any:
2607
2642
  """
2608
2643
  Initialize a Mask R-CNN model with a modified first layer to accept Sentinel-2 data.
2609
2644
 
@@ -2656,7 +2691,12 @@ class AgricultureFieldDelineator(ObjectDetector):
2656
2691
  model.to(self.device)
2657
2692
  return model
2658
2693
 
2659
- def preprocess_sentinel_bands(self, image_data, band_selection=None, use_ndvi=None):
2694
+ def preprocess_sentinel_bands(
2695
+ self,
2696
+ image_data: np.ndarray,
2697
+ band_selection: Optional[List[int]] = None,
2698
+ use_ndvi: Optional[bool] = None,
2699
+ ) -> torch.Tensor:
2660
2700
  """
2661
2701
  Preprocess Sentinel-2 band data for model input.
2662
2702
 
@@ -2718,7 +2758,12 @@ class AgricultureFieldDelineator(ObjectDetector):
2718
2758
 
2719
2759
  return image_tensor
2720
2760
 
2721
- def update_band_stats(self, raster_path, band_selection=None, sample_size=1000):
2761
+ def update_band_stats(
2762
+ self,
2763
+ raster_path: str,
2764
+ band_selection: Optional[List[int]] = None,
2765
+ sample_size: int = 1000,
2766
+ ) -> Dict[str, List[float]]:
2722
2767
  """
2723
2768
  Update band statistics from the input Sentinel-2 raster.
2724
2769
 
@@ -2782,15 +2827,15 @@ class AgricultureFieldDelineator(ObjectDetector):
2782
2827
 
2783
2828
  def process_sentinel_raster(
2784
2829
  self,
2785
- raster_path,
2786
- output_path=None,
2787
- batch_size=4,
2788
- band_selection=None,
2789
- use_ndvi=None,
2790
- filter_edges=True,
2791
- edge_buffer=20,
2792
- **kwargs,
2793
- ):
2830
+ raster_path: str,
2831
+ output_path: Optional[str] = None,
2832
+ batch_size: int = 4,
2833
+ band_selection: Optional[List[int]] = None,
2834
+ use_ndvi: Optional[bool] = None,
2835
+ filter_edges: bool = True,
2836
+ edge_buffer: int = 20,
2837
+ **kwargs: Any,
2838
+ ) -> gpd.GeoDataFrame:
2794
2839
  """
2795
2840
  Process a Sentinel-2 raster to extract field boundaries.
2796
2841
 
@@ -2840,14 +2885,14 @@ class AgricultureFieldDelineator(ObjectDetector):
2840
2885
  class Sentinel2Dataset(torch.utils.data.Dataset):
2841
2886
  def __init__(
2842
2887
  self,
2843
- raster_path,
2844
- chip_size,
2845
- stride_x,
2846
- stride_y,
2847
- band_selection,
2848
- use_ndvi,
2849
- field_delineator,
2850
- ):
2888
+ raster_path: str,
2889
+ chip_size: Tuple[int, int],
2890
+ stride_x: int,
2891
+ stride_y: int,
2892
+ band_selection: List[int],
2893
+ use_ndvi: bool,
2894
+ field_delineator: Any,
2895
+ ) -> None:
2851
2896
  self.raster_path = raster_path
2852
2897
  self.chip_size = chip_size
2853
2898
  self.stride_x = stride_x
@@ -2901,10 +2946,10 @@ class AgricultureFieldDelineator(ObjectDetector):
2901
2946
  print(f"Image dimensions: {self.width} x {self.height} pixels")
2902
2947
  print(f"Chip size: {self.chip_size[1]} x {self.chip_size[0]} pixels")
2903
2948
 
2904
- def __len__(self):
2949
+ def __len__(self) -> int:
2905
2950
  return self.rows * self.cols
2906
2951
 
2907
- def __getitem__(self, idx):
2952
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
2908
2953
  # Convert flat index to grid position
2909
2954
  row = idx // self.cols
2910
2955
  col = idx % self.cols
@@ -2971,7 +3016,7 @@ class AgricultureFieldDelineator(ObjectDetector):
2971
3016
  )
2972
3017
 
2973
3018
  # Define custom collate function
2974
- def custom_collate(batch):
3019
+ def custom_collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
2975
3020
  elem = batch[0]
2976
3021
  if isinstance(elem, dict):
2977
3022
  result = {}