geoai-py 0.14.0__py2.py3-none-any.whl → 0.15.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/train.py CHANGED
@@ -35,6 +35,89 @@ except ImportError:
35
35
  SMP_AVAILABLE = False
36
36
 
37
37
 
38
def parse_coco_annotations(
    coco_json_path: str, images_dir: str, labels_dir: str
) -> Tuple[List[str], List[str]]:
    """
    Parse a COCO annotations file and return paired image and label mask paths.

    Only the top-level ``images`` section of the COCO file is read; the
    per-object ``annotations`` section is not used. Each image's ``file_name``
    is resolved against ``images_dir``, and a label mask with the *same* file
    name is expected in ``labels_dir``. Entries whose image or label file does
    not exist on disk are skipped, as are repeated ``file_name`` entries.

    Args:
        coco_json_path (str): Path to COCO annotations JSON file (instances.json).
        images_dir (str): Directory containing image files.
        labels_dir (str): Directory containing label mask files, named
            identically to the corresponding images.

    Returns:
        Tuple[List[str], List[str]]: Image paths and label paths, index-aligned
        and in the order the images appear in the COCO file.

    Raises:
        ValueError: If the JSON document has no top-level ``images`` section.
    """
    import json

    # JSON text is UTF-8 by specification; do not depend on the platform's
    # default locale encoding when reading it.
    with open(coco_json_path, "r", encoding="utf-8") as f:
        coco_data = json.load(f)

    if not isinstance(coco_data, dict) or "images" not in coco_data:
        raise ValueError(
            f"COCO annotations file has no 'images' section: {coco_json_path}"
        )

    image_files: List[str] = []
    label_files: List[str] = []
    seen_filenames = set()

    for img_info in coco_data["images"]:
        img_filename = img_info["file_name"]
        # A repeated file_name would duplicate a sample (and could land it in
        # both train and validation splits downstream), so keep the first only.
        if img_filename in seen_filenames:
            continue
        seen_filenames.add(img_filename)

        img_path = os.path.join(images_dir, img_filename)
        # Label masks are expected to share the image's file name.
        label_path = os.path.join(labels_dir, img_filename)

        if os.path.exists(img_path) and os.path.exists(label_path):
            image_files.append(img_path)
            label_files.append(label_path)

    return image_files, label_files
73
+
74
+
75
def parse_yolo_annotations(
    data_dir: str, images_subdir: str = "images", labels_subdir: str = "labels"
) -> Tuple[List[str], List[str]]:
    """
    Collect paired image and label-mask paths from a YOLO-style directory layout.

    Expected structure::

        data_dir/<images_subdir>/  image files (.tif, .tiff, .png, .jpg, .jpeg)
        data_dir/<labels_subdir>/  label masks named identically to the images

    Only raster label masks are used: an image is kept only when a file with
    exactly the same name exists in the labels directory. YOLO ``.txt``
    annotation files and ``classes.txt`` are not read by this function.

    Args:
        data_dir (str): Root directory containing YOLO-format data.
        images_subdir (str): Subdirectory name for images. Defaults to 'images'.
        labels_subdir (str): Subdirectory name for labels. Defaults to 'labels'.

    Returns:
        Tuple[List[str], List[str]]: Image paths and label paths, index-aligned
        and sorted by file name.

    Raises:
        FileNotFoundError: If the images or labels directory does not exist
            (or is not a directory).
    """
    images_dir = os.path.join(data_dir, images_subdir)
    labels_dir = os.path.join(data_dir, labels_subdir)

    # isdir rather than exists: a stray regular file at these paths should fail
    # here with a clear message, not later as NotADirectoryError from listdir.
    if not os.path.isdir(images_dir):
        raise FileNotFoundError(f"Images directory not found: {images_dir}")
    if not os.path.isdir(labels_dir):
        raise FileNotFoundError(f"Labels directory not found: {labels_dir}")

    image_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
    image_files: List[str] = []
    label_files: List[str] = []

    # Sort the file names once and build both lists together, so the
    # image/label pairing never depends on two independent sorts agreeing.
    for img_file in sorted(os.listdir(images_dir)):
        if not img_file.lower().endswith(image_extensions):
            continue

        img_path = os.path.join(images_dir, img_file)
        # The label mask is expected to have exactly the same file name.
        label_path = os.path.join(labels_dir, img_file)

        # isfile skips directory entries that merely look like image names.
        if os.path.isfile(img_path) and os.path.isfile(label_path):
            image_files.append(img_path)
            label_files.append(label_path)

    return image_files, label_files
119
+
120
+
38
121
  def get_instance_segmentation_model(
39
122
  num_classes: int = 2, num_channels: int = 3, pretrained: bool = True
40
123
  ) -> torch.nn.Module:
@@ -617,6 +700,7 @@ def train_MaskRCNN_model(
617
700
  images_dir: str,
618
701
  labels_dir: str,
619
702
  output_dir: str,
703
+ input_format: str = "directory",
620
704
  num_channels: int = 3,
621
705
  model: Optional[torch.nn.Module] = None,
622
706
  pretrained: bool = True,
@@ -640,9 +724,17 @@ def train_MaskRCNN_model(
640
724
  the backbone or to continue training from a specific checkpoint.
641
725
 
642
726
  Args:
643
- images_dir (str): Directory containing image GeoTIFF files.
644
- labels_dir (str): Directory containing label GeoTIFF files.
727
+ images_dir (str): Directory containing image GeoTIFF files (for 'directory' format),
728
+ or root directory containing images/ subdirectory (for 'yolo' format),
729
+ or directory containing images (for 'coco' format).
730
+ labels_dir (str): Directory containing label GeoTIFF files (for 'directory' format),
731
+ or path to COCO annotations JSON file (for 'coco' format),
732
+ or not used (for 'yolo' format - labels are in images_dir/labels/).
645
733
  output_dir (str): Directory to save model checkpoints and results.
734
+ input_format (str): Input data format - 'directory' (default), 'coco', or 'yolo'.
735
+ - 'directory': Standard directory structure with separate images_dir and labels_dir
736
+ - 'coco': COCO JSON format (labels_dir should be path to instances.json)
737
+ - 'yolo': YOLO format (images_dir is root with images/ and labels/ subdirectories)
646
738
  num_channels (int, optional): Number of input channels. If None, auto-detected.
647
739
  Defaults to 3.
648
740
  model (torch.nn.Module, optional): Predefined model. If None, a new model is created.
@@ -688,45 +780,63 @@ def train_MaskRCNN_model(
688
780
  device = get_device()
689
781
  print(f"Using device: {device}")
690
782
 
691
- # Get all image and label files
692
- # Support multiple image formats: GeoTIFF, PNG, JPG, JPEG, TIF, TIFF
693
- image_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
694
- label_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
783
+ # Get all image and label files based on input format
784
+ if input_format.lower() == "coco":
785
+ # Parse COCO format annotations
786
+ if verbose:
787
+ print(f"Loading COCO format annotations from {labels_dir}")
788
+ # For COCO format, labels_dir is path to instances.json
789
+ # Labels are typically in a "labels" directory parallel to "annotations"
790
+ coco_root = os.path.dirname(os.path.dirname(labels_dir)) # Go up two levels
791
+ labels_directory = os.path.join(coco_root, "labels")
792
+ image_files, label_files = parse_coco_annotations(
793
+ labels_dir, images_dir, labels_directory
794
+ )
795
+ elif input_format.lower() == "yolo":
796
+ # Parse YOLO format annotations
797
+ if verbose:
798
+ print(f"Loading YOLO format data from {images_dir}")
799
+ image_files, label_files = parse_yolo_annotations(images_dir)
800
+ else:
801
+ # Default: directory format
802
+ # Support multiple image formats: GeoTIFF, PNG, JPG, JPEG, TIF, TIFF
803
+ image_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
804
+ label_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
805
+
806
+ image_files = sorted(
807
+ [
808
+ os.path.join(images_dir, f)
809
+ for f in os.listdir(images_dir)
810
+ if f.lower().endswith(image_extensions)
811
+ ]
812
+ )
813
+ label_files = sorted(
814
+ [
815
+ os.path.join(labels_dir, f)
816
+ for f in os.listdir(labels_dir)
817
+ if f.lower().endswith(label_extensions)
818
+ ]
819
+ )
695
820
 
696
- image_files = sorted(
697
- [
698
- os.path.join(images_dir, f)
699
- for f in os.listdir(images_dir)
700
- if f.lower().endswith(image_extensions)
701
- ]
702
- )
703
- label_files = sorted(
704
- [
705
- os.path.join(labels_dir, f)
706
- for f in os.listdir(labels_dir)
707
- if f.lower().endswith(label_extensions)
708
- ]
709
- )
821
+ # Ensure matching files
822
+ if len(image_files) != len(label_files):
823
+ print("Warning: Number of image files and label files don't match!")
824
+ # Find matching files by basename
825
+ basenames = [os.path.basename(f) for f in image_files]
826
+ label_files = [
827
+ os.path.join(labels_dir, os.path.basename(f))
828
+ for f in image_files
829
+ if os.path.exists(os.path.join(labels_dir, os.path.basename(f)))
830
+ ]
831
+ image_files = [
832
+ f
833
+ for f, b in zip(image_files, basenames)
834
+ if os.path.exists(os.path.join(labels_dir, b))
835
+ ]
836
+ print(f"Using {len(image_files)} matching files")
710
837
 
711
838
  print(f"Found {len(image_files)} image files and {len(label_files)} label files")
712
839
 
713
- # Ensure matching files
714
- if len(image_files) != len(label_files):
715
- print("Warning: Number of image files and label files don't match!")
716
- # Find matching files by basename
717
- basenames = [os.path.basename(f) for f in image_files]
718
- label_files = [
719
- os.path.join(labels_dir, os.path.basename(f))
720
- for f in image_files
721
- if os.path.exists(os.path.join(labels_dir, os.path.basename(f)))
722
- ]
723
- image_files = [
724
- f
725
- for f, b in zip(image_files, basenames)
726
- if os.path.exists(os.path.join(labels_dir, b))
727
- ]
728
- print(f"Using {len(image_files)} matching files")
729
-
730
840
  # Split data into train and validation sets
731
841
  train_imgs, val_imgs, train_labels, val_labels = train_test_split(
732
842
  image_files, label_files, test_size=val_split, random_state=seed
@@ -2129,6 +2239,7 @@ def train_segmentation_model(
2129
2239
  images_dir: str,
2130
2240
  labels_dir: str,
2131
2241
  output_dir: str,
2242
+ input_format: str = "directory",
2132
2243
  architecture: str = "unet",
2133
2244
  encoder_name: str = "resnet34",
2134
2245
  encoder_weights: Optional[str] = "imagenet",
@@ -2160,9 +2271,17 @@ def train_segmentation_model(
2160
2271
  this approach treats the task as pixel-level binary classification.
2161
2272
 
2162
2273
  Args:
2163
- images_dir (str): Directory containing image GeoTIFF files.
2164
- labels_dir (str): Directory containing label GeoTIFF files.
2274
+ images_dir (str): Directory containing image GeoTIFF files (for 'directory' format),
2275
+ or root directory containing images/ subdirectory (for 'yolo' format),
2276
+ or directory containing images (for 'coco' format).
2277
+ labels_dir (str): Directory containing label GeoTIFF files (for 'directory' format),
2278
+ or path to COCO annotations JSON file (for 'coco' format),
2279
+ or not used (for 'yolo' format - labels are in images_dir/labels/).
2165
2280
  output_dir (str): Directory to save model checkpoints and results.
2281
+ input_format (str): Input data format - 'directory' (default), 'coco', or 'yolo'.
2282
+ - 'directory': Standard directory structure with separate images_dir and labels_dir
2283
+ - 'coco': COCO JSON format (labels_dir should be path to instances.json)
2284
+ - 'yolo': YOLO format (images_dir is root with images/ and labels/ subdirectories)
2166
2285
  architecture (str): Model architecture ('unet', 'deeplabv3', 'deeplabv3plus', 'fpn',
2167
2286
  'pspnet', 'linknet', 'manet'). Defaults to 'unet'.
2168
2287
  encoder_name (str): Encoder backbone name (e.g., 'resnet34', 'resnet50', 'efficientnet-b0').
@@ -2225,45 +2344,63 @@ def train_segmentation_model(
2225
2344
  device = get_device()
2226
2345
  print(f"Using device: {device}")
2227
2346
 
2228
- # Get all image and label files
2229
- # Support multiple image formats: GeoTIFF, PNG, JPG, JPEG, TIF, TIFF
2230
- image_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
2231
- label_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
2347
+ # Get all image and label files based on input format
2348
+ if input_format.lower() == "coco":
2349
+ # Parse COCO format annotations
2350
+ if verbose:
2351
+ print(f"Loading COCO format annotations from {labels_dir}")
2352
+ # For COCO format, labels_dir is path to instances.json
2353
+ # Labels are typically in a "labels" directory parallel to "annotations"
2354
+ coco_root = os.path.dirname(os.path.dirname(labels_dir)) # Go up two levels
2355
+ labels_directory = os.path.join(coco_root, "labels")
2356
+ image_files, label_files = parse_coco_annotations(
2357
+ labels_dir, images_dir, labels_directory
2358
+ )
2359
+ elif input_format.lower() == "yolo":
2360
+ # Parse YOLO format annotations
2361
+ if verbose:
2362
+ print(f"Loading YOLO format data from {images_dir}")
2363
+ image_files, label_files = parse_yolo_annotations(images_dir)
2364
+ else:
2365
+ # Default: directory format
2366
+ # Support multiple image formats: GeoTIFF, PNG, JPG, JPEG, TIF, TIFF
2367
+ image_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
2368
+ label_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
2369
+
2370
+ image_files = sorted(
2371
+ [
2372
+ os.path.join(images_dir, f)
2373
+ for f in os.listdir(images_dir)
2374
+ if f.lower().endswith(image_extensions)
2375
+ ]
2376
+ )
2377
+ label_files = sorted(
2378
+ [
2379
+ os.path.join(labels_dir, f)
2380
+ for f in os.listdir(labels_dir)
2381
+ if f.lower().endswith(label_extensions)
2382
+ ]
2383
+ )
2232
2384
 
2233
- image_files = sorted(
2234
- [
2235
- os.path.join(images_dir, f)
2236
- for f in os.listdir(images_dir)
2237
- if f.lower().endswith(image_extensions)
2238
- ]
2239
- )
2240
- label_files = sorted(
2241
- [
2242
- os.path.join(labels_dir, f)
2243
- for f in os.listdir(labels_dir)
2244
- if f.lower().endswith(label_extensions)
2245
- ]
2246
- )
2385
+ # Ensure matching files
2386
+ if len(image_files) != len(label_files):
2387
+ print("Warning: Number of image files and label files don't match!")
2388
+ # Find matching files by basename
2389
+ basenames = [os.path.basename(f) for f in image_files]
2390
+ label_files = [
2391
+ os.path.join(labels_dir, os.path.basename(f))
2392
+ for f in image_files
2393
+ if os.path.exists(os.path.join(labels_dir, os.path.basename(f)))
2394
+ ]
2395
+ image_files = [
2396
+ f
2397
+ for f, b in zip(image_files, basenames)
2398
+ if os.path.exists(os.path.join(labels_dir, b))
2399
+ ]
2400
+ print(f"Using {len(image_files)} matching files")
2247
2401
 
2248
2402
  print(f"Found {len(image_files)} image files and {len(label_files)} label files")
2249
2403
 
2250
- # Ensure matching files
2251
- if len(image_files) != len(label_files):
2252
- print("Warning: Number of image files and label files don't match!")
2253
- # Find matching files by basename
2254
- basenames = [os.path.basename(f) for f in image_files]
2255
- label_files = [
2256
- os.path.join(labels_dir, os.path.basename(f))
2257
- for f in image_files
2258
- if os.path.exists(os.path.join(labels_dir, os.path.basename(f)))
2259
- ]
2260
- image_files = [
2261
- f
2262
- for f, b in zip(image_files, basenames)
2263
- if os.path.exists(os.path.join(labels_dir, b))
2264
- ]
2265
- print(f"Using {len(image_files)} matching files")
2266
-
2267
2404
  if len(image_files) == 0:
2268
2405
  raise FileNotFoundError("No matching image and label files found")
2269
2406
 
@@ -3527,6 +3664,7 @@ def train_instance_segmentation_model(
3527
3664
  images_dir: str,
3528
3665
  labels_dir: str,
3529
3666
  output_dir: str,
3667
+ input_format: str = "directory",
3530
3668
  num_classes: int = 2,
3531
3669
  num_channels: int = 3,
3532
3670
  batch_size: int = 4,
@@ -3545,9 +3683,17 @@ def train_instance_segmentation_model(
3545
3683
  This is a wrapper function for train_MaskRCNN_model with clearer naming.
3546
3684
 
3547
3685
  Args:
3548
- images_dir (str): Directory containing image GeoTIFF files.
3549
- labels_dir (str): Directory containing label GeoTIFF files.
3686
+ images_dir (str): Directory containing image GeoTIFF files (for 'directory' format),
3687
+ or root directory containing images/ subdirectory (for 'yolo' format),
3688
+ or directory containing images (for 'coco' format).
3689
+ labels_dir (str): Directory containing label GeoTIFF files (for 'directory' format),
3690
+ or path to COCO annotations JSON file (for 'coco' format),
3691
+ or not used (for 'yolo' format - labels are in images_dir/labels/).
3550
3692
  output_dir (str): Directory to save model checkpoints and results.
3693
+ input_format (str): Input data format - 'directory' (default), 'coco', or 'yolo'.
3694
+ - 'directory': Standard directory structure with separate images_dir and labels_dir
3695
+ - 'coco': COCO JSON format (labels_dir should be path to instances.json)
3696
+ - 'yolo': YOLO format (images_dir is root with images/ and labels/ subdirectories)
3551
3697
  num_classes (int): Number of classes (including background). Defaults to 2.
3552
3698
  num_channels (int): Number of input channels. Defaults to 3.
3553
3699
  batch_size (int): Batch size for training. Defaults to 4.
@@ -3572,6 +3718,7 @@ def train_instance_segmentation_model(
3572
3718
  images_dir=images_dir,
3573
3719
  labels_dir=labels_dir,
3574
3720
  output_dir=output_dir,
3721
+ input_format=input_format,
3575
3722
  num_channels=num_channels,
3576
3723
  model=model,
3577
3724
  batch_size=batch_size,