geoai-py 0.8.3__py2.py3-none-any.whl → 0.9.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/classify.py CHANGED
@@ -1,50 +1,51 @@
1
1
  """The module for training semantic segmentation models for classifying remote sensing imagery."""
2
2
 
3
3
  import os
4
+ from typing import Any, Dict, List, Optional, Union
4
5
 
5
6
  import numpy as np
6
7
 
7
8
 
8
9
  def train_classifier(
9
- image_root,
10
- label_root,
11
- output_dir="output",
12
- in_channels=4,
13
- num_classes=14,
14
- epochs=20,
15
- img_size=256,
16
- batch_size=8,
17
- sample_size=500,
18
- model="unet",
19
- backbone="resnet50",
20
- weights=True,
21
- num_filters=3,
22
- loss="ce",
23
- class_weights=None,
24
- ignore_index=None,
25
- lr=0.001,
26
- patience=10,
27
- freeze_backbone=False,
28
- freeze_decoder=False,
29
- transforms=None,
30
- use_augmentation=False,
31
- seed=42,
32
- train_val_test_split=(0.6, 0.2, 0.2),
33
- accelerator="auto",
34
- devices="auto",
35
- logger=None,
36
- callbacks=None,
37
- log_every_n_steps=10,
38
- use_distributed_sampler=False,
39
- monitor_metric="val_loss",
40
- mode="min",
41
- save_top_k=1,
42
- save_last=True,
43
- checkpoint_filename="best_model",
44
- checkpoint_path=None,
45
- every_n_epochs=1,
46
- **kwargs,
47
- ):
10
+ image_root: str,
11
+ label_root: str,
12
+ output_dir: str = "output",
13
+ in_channels: int = 4,
14
+ num_classes: int = 14,
15
+ epochs: int = 20,
16
+ img_size: int = 256,
17
+ batch_size: int = 8,
18
+ sample_size: int = 500,
19
+ model: str = "unet",
20
+ backbone: str = "resnet50",
21
+ weights: bool = True,
22
+ num_filters: int = 3,
23
+ loss: str = "ce",
24
+ class_weights: Optional[List[float]] = None,
25
+ ignore_index: Optional[int] = None,
26
+ lr: float = 0.001,
27
+ patience: int = 10,
28
+ freeze_backbone: bool = False,
29
+ freeze_decoder: bool = False,
30
+ transforms: Optional[Any] = None,
31
+ use_augmentation: bool = False,
32
+ seed: int = 42,
33
+ train_val_test_split: tuple = (0.6, 0.2, 0.2),
34
+ accelerator: str = "auto",
35
+ devices: str = "auto",
36
+ logger: Optional[Any] = None,
37
+ callbacks: Optional[List[Any]] = None,
38
+ log_every_n_steps: int = 10,
39
+ use_distributed_sampler: bool = False,
40
+ monitor_metric: str = "val_loss",
41
+ mode: str = "min",
42
+ save_top_k: int = 1,
43
+ save_last: bool = True,
44
+ checkpoint_filename: str = "best_model",
45
+ checkpoint_path: Optional[str] = None,
46
+ every_n_epochs: int = 1,
47
+ **kwargs: Any,
48
+ ) -> Any:
48
49
  """Train a semantic segmentation model on geospatial imagery.
49
50
 
50
51
  This function sets up datasets, model, trainer, and executes the training process
@@ -584,15 +585,15 @@ def _classify_image(
584
585
 
585
586
 
586
587
  def classify_image(
587
- image_path,
588
- model_path,
589
- output_path=None,
590
- chip_size=1024,
591
- overlap=256,
592
- batch_size=4,
593
- colormap=None,
594
- **kwargs,
595
- ):
588
+ image_path: str,
589
+ model_path: str,
590
+ output_path: Optional[str] = None,
591
+ chip_size: int = 1024,
592
+ overlap: int = 256,
593
+ batch_size: int = 4,
594
+ colormap: Optional[Dict] = None,
595
+ **kwargs: Any,
596
+ ) -> str:
596
597
  """
597
598
  Classify a geospatial image using a trained semantic segmentation model.
598
599
 
@@ -826,15 +827,15 @@ def classify_image(
826
827
 
827
828
 
828
829
  def classify_images(
829
- image_paths,
830
- model_path,
831
- output_dir=None,
832
- chip_size=1024,
833
- batch_size=4,
834
- colormap=None,
835
- file_extension=".tif",
836
- **kwargs,
837
- ):
830
+ image_paths: Union[str, List[str]],
831
+ model_path: str,
832
+ output_dir: Optional[str] = None,
833
+ chip_size: int = 1024,
834
+ batch_size: int = 4,
835
+ colormap: Optional[Dict] = None,
836
+ file_extension: str = ".tif",
837
+ **kwargs: Any,
838
+ ) -> List[str]:
838
839
  """
839
840
  Classify multiple geospatial images using a trained semantic segmentation model.
840
841
 
geoai/detectron2.py ADDED
@@ -0,0 +1,466 @@
1
+ """Detectron2 integration for remote sensing image segmentation.
2
+ See https://github.com/facebookresearch/detectron2 for more details.
3
+ """
4
+
5
+ import os
6
+ import warnings
7
+ from typing import Dict, List, Optional, Tuple, Union
8
+
9
+ import cv2
10
+ import numpy as np
11
+ import rasterio
12
+ import torch
13
+ from rasterio.crs import CRS
14
+ from rasterio.transform import from_bounds
15
+
16
+ try:
17
+ import detectron2
18
+ from detectron2 import model_zoo
19
+ from detectron2.config import LazyConfig, get_cfg
20
+ from detectron2.data import MetadataCatalog
21
+ from detectron2.engine import DefaultPredictor
22
+ from detectron2.utils.visualizer import Visualizer
23
+
24
+ HAS_DETECTRON2 = True
25
+ except ImportError:
26
+ HAS_DETECTRON2 = False
27
+ warnings.warn("Detectron2 not found. Please install detectron2 to use this module.")
28
+
29
try:
    from .utils import get_device
except ImportError:
    # geoai.utils could not be imported; provide a minimal local replacement.

    def get_device():
        """Pick a compute device name without relying on geoai.utils.

        Returns:
            str: "cuda" when torch reports CUDA as available, otherwise "cpu"
            (including the case where torch itself cannot be imported).
        """
        try:
            import torch

            cuda_ok = torch.cuda.is_available()
        except ImportError:
            return "cpu"
        if cuda_ok:
            return "cuda"
        return "cpu"
40
+
41
+
42
def check_detectron2():
    """Ensure detectron2 is importable before using it.

    Functions in this module call this first, so that a missing installation
    produces an actionable error instead of a NameError on an undefined name.

    Raises:
        ImportError: If the module-level detectron2 import failed
            (``HAS_DETECTRON2`` is False).
    """
    if HAS_DETECTRON2:
        return
    message = "Detectron2 is required. Please install it with: pip install detectron2"
    raise ImportError(message)
48
+
49
+
50
def load_detectron2_model(
    model_config: str = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml",
    model_weights: Optional[str] = None,
    score_threshold: float = 0.5,
    device: Optional[str] = None,
    num_classes: Optional[int] = None,
) -> "DefaultPredictor":
    """
    Load a Detectron2 model for instance segmentation.

    Args:
        model_config: Path to a local config file, or a config name from the
            Detectron2 model zoo (e.g.
            "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"). A file
            that exists on disk takes precedence over a model-zoo name.
        model_weights: Path (or URL) of the model weights. If None, the
            model-zoo checkpoint for ``model_config`` is used, which assumes
            ``model_config`` is a model-zoo name.
        score_threshold: Confidence threshold for predictions
            (``MODEL.ROI_HEADS.SCORE_THRESH_TEST``).
        device: Device to use ('cpu', 'cuda', 'cuda:1', a torch.device, or
            None for auto-detection via ``get_device``).
        num_classes: Number of classes for custom models
            (``MODEL.ROI_HEADS.NUM_CLASSES``); left unchanged when None.

    Returns:
        DefaultPredictor: Configured Detectron2 predictor.

    Raises:
        ImportError: If detectron2 is not installed.
    """
    # The return annotation is a string so that defining this function does not
    # raise NameError when the optional detectron2 import failed.
    check_detectron2()

    cfg = get_cfg()

    # Load model configuration: an existing file on disk wins; anything else is
    # resolved against the model zoo. (Previously every ".yaml" string was sent
    # to the model zoo, so local YAML configs could not be loaded.)
    if os.path.exists(model_config):
        cfg.merge_from_file(model_config)
    else:
        cfg.merge_from_file(model_zoo.get_config_file(model_config))

    # Set model weights
    if model_weights is None:
        cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(model_config)
    else:
        cfg.MODEL.WEIGHTS = model_weights

    # Set score threshold
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_threshold

    # Set device
    if device is None:
        device = get_device()

    # detectron2 expects a string. str() keeps any device index
    # (torch.device("cuda:1") -> "cuda:1"), which `device.type` used to drop.
    if not isinstance(device, str):
        device = str(device)

    cfg.MODEL.DEVICE = device

    # Set number of classes if specified
    if num_classes is not None:
        cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes

    return DefaultPredictor(cfg)
106
+
107
+
108
def detectron2_segment(
    image_path: str,
    output_dir: str = ".",
    model_config: str = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml",
    model_weights: Optional[str] = None,
    score_threshold: float = 0.5,
    device: Optional[str] = None,
    save_masks: bool = True,
    save_probability: bool = True,
    mask_prefix: str = "instance_masks",
    prob_prefix: str = "probability_mask",
    predictor: Optional["DefaultPredictor"] = None,
) -> Dict:
    """
    Perform instance segmentation on a remote sensing image using Detectron2.

    Args:
        image_path: Path to input image
        output_dir: Directory to save output files
        model_config: Model configuration file path or name from model zoo
        model_weights: Path to model weights file. If None, uses model zoo weights
        score_threshold: Confidence threshold for predictions
        device: Device to use ('cpu', 'cuda', or None for auto-detection)
        save_masks: Whether to save instance masks as GeoTIFF
        save_probability: Whether to save probability masks as GeoTIFF
        mask_prefix: Prefix for instance mask output file
        prob_prefix: Prefix for probability mask output file
        predictor: Optional already-loaded predictor (e.g. from
            ``load_detectron2_model``). When given it is used as-is and
            ``model_config``, ``model_weights``, ``score_threshold`` and
            ``device`` are ignored, so the model is not reloaded.

    Returns:
        Dict with ``masks``, ``scores``, ``classes``, ``boxes`` (numpy arrays),
        ``num_instances`` and, when files were written,
        ``instance_mask_path`` / ``probability_mask_path``.

    Raises:
        ImportError: If detectron2 is not installed.
        ValueError: If the image cannot be read.
    """
    check_detectron2()

    # Only build a model when the caller did not supply one.
    if predictor is None:
        predictor = load_detectron2_model(
            model_config=model_config,
            model_weights=model_weights,
            score_threshold=score_threshold,
            device=device,
        )

    # cv2.imread signals failure by returning None rather than raising.
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not read image from {image_path}")

    # Run inference and move predictions to numpy on the CPU.
    instances = predictor(image)["instances"].to("cpu")
    masks = instances.pred_masks.numpy()
    scores = instances.scores.numpy()
    classes = instances.pred_classes.numpy()
    boxes = instances.pred_boxes.tensor.numpy()

    results = {
        "masks": masks,
        "scores": scores,
        "classes": classes,
        "boxes": boxes,
        "num_instances": len(masks),
    }

    # Reuse the source raster's georeferencing when rasterio can open it.
    try:
        with rasterio.open(image_path) as src:
            transform = src.transform
            crs = src.crs
    except Exception:
        # Fallback: one map unit per pixel over the image extent.
        # NOTE(review): tagging pixel coordinates as EPSG:4326 is only a
        # placeholder CRS, not real georeferencing.
        height, width = image.shape[:2]
        transform = from_bounds(0, 0, width, height, width, height)
        crs = CRS.from_epsg(4326)

    # Save instance masks as GeoTIFF
    if save_masks and len(masks) > 0:
        instance_mask_path = os.path.join(output_dir, f"{mask_prefix}.tif")
        save_geotiff_mask(
            create_instance_mask(masks),
            instance_mask_path,
            transform,
            crs,
            dtype="uint16",
        )
        results["instance_mask_path"] = instance_mask_path

    # Save probability masks as GeoTIFF
    if save_probability and len(masks) > 0:
        prob_mask_path = os.path.join(output_dir, f"{prob_prefix}.tif")
        save_geotiff_mask(
            create_probability_mask(masks, scores),
            prob_mask_path,
            transform,
            crs,
            dtype="float32",
        )
        results["probability_mask_path"] = prob_mask_path

    return results


def create_instance_mask(masks: np.ndarray) -> np.ndarray:
    """
    Create an instance mask from individual binary masks.

    Args:
        masks: Array of binary masks with shape (num_instances, height, width).
            Boolean or numeric; any nonzero value counts as foreground.

    Returns:
        uint16 array of shape (height, width): 0 for background and
        ``i + 1`` for pixels of the i-th mask. Where masks overlap, the later
        mask's ID wins.

    Raises:
        ValueError: If ``masks`` is not 3-D or has more instances than uint16
            IDs can represent.
    """
    masks = np.asarray(masks)
    if masks.ndim != 3:
        raise ValueError("masks must have shape (num_instances, height, width)")
    if len(masks) > np.iinfo(np.uint16).max:
        raise ValueError("Too many instances to encode as uint16 IDs")

    instance_mask = np.zeros(masks.shape[1:], dtype=np.uint16)
    for instance_id, mask in enumerate(masks, start=1):
        # Cast to bool: indexing with a 0/1 integer array would select whole
        # rows 0 and 1 instead of the foreground pixels.
        instance_mask[mask.astype(bool)] = instance_id

    return instance_mask


def create_probability_mask(masks: np.ndarray, scores: np.ndarray) -> np.ndarray:
    """
    Create a probability mask from individual binary masks and their confidence scores.

    Args:
        masks: Array of binary masks with shape (num_instances, height, width).
            Boolean or numeric; any nonzero value counts as foreground.
        scores: Confidence score for each mask (same length as ``masks``).

    Returns:
        float32 array of shape (height, width) holding, per pixel, the maximum
        score of the masks covering it (0 where no mask applies).

    Raises:
        ValueError: If ``masks`` is not 3-D or ``scores`` has a different length.
    """
    masks = np.asarray(masks)
    if masks.ndim != 3:
        raise ValueError("masks must have shape (num_instances, height, width)")
    if len(scores) != len(masks):
        raise ValueError("scores must contain one value per mask")

    probability_mask = np.zeros(masks.shape[1:], dtype=np.float32)
    for mask, score in zip(masks, scores):
        # Raise covered pixels to this score where it beats the current value;
        # in-place assignment keeps the float32 dtype.
        update = mask.astype(bool) & (score > probability_mask)
        probability_mask[update] = score

    return probability_mask


def save_geotiff_mask(
    mask: np.ndarray,
    output_path: str,
    transform: "rasterio.transform.Affine",
    crs: "CRS",
    dtype: str = "uint16",
) -> None:
    """
    Save a mask as a single-band, LZW-compressed GeoTIFF file.

    Args:
        mask: 2D numpy array representing the mask
        output_path: Path to save the GeoTIFF file; its parent directory is
            created if needed.
        transform: Rasterio transform for georeferencing
        crs: Coordinate reference system
        dtype: "uint16" or "float32". Any other value falls back to uint16
            with a warning.
    """
    # A bare filename has no directory part; os.makedirs("") would raise.
    parent = os.path.dirname(output_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    if dtype == "float32":
        np_dtype = np.float32
    else:
        if dtype != "uint16":
            warnings.warn(f"Unsupported dtype {dtype!r}; saving as uint16.")
        np_dtype = np.uint16

    mask = mask.astype(np_dtype)

    with rasterio.open(
        output_path,
        "w",
        driver="GTiff",
        height=mask.shape[0],
        width=mask.shape[1],
        count=1,
        dtype=np_dtype,
        crs=crs,
        transform=transform,
        compress="lzw",
    ) as dst:
        dst.write(mask, 1)


def visualize_detectron2_results(
    image_path: str,
    results: Dict,
    output_path: Optional[str] = None,
    show_scores: bool = True,
    show_classes: bool = True,
) -> np.ndarray:
    """
    Visualize Detectron2 segmentation results on the original image.

    Args:
        image_path: Path to the original image
        results: Results dictionary from detectron2_segment
        output_path: Path to save the visualization (optional)
        show_scores: Whether to show confidence scores
        show_classes: Whether to show class labels

    Returns:
        Visualization image as an RGB numpy array

    Raises:
        ImportError: If detectron2 is not installed.
        ValueError: If the image cannot be read.
        OSError: If ``output_path`` is given but the image cannot be written.
    """
    check_detectron2()
    from detectron2.structures import Boxes, Instances

    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not read image from {image_path}")

    # Visualizer works on RGB; cv2 loads BGR.
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    instances = Instances((image.shape[0], image.shape[1]))
    instances.pred_masks = torch.from_numpy(results["masks"])
    instances.pred_boxes = Boxes(torch.from_numpy(results["boxes"]))
    # draw_instance_predictions only labels with the fields that are present,
    # so omitting them honours show_scores / show_classes.
    if show_scores:
        instances.scores = torch.from_numpy(results["scores"])
    if show_classes:
        instances.pred_classes = torch.from_numpy(results["classes"])

    visualizer = Visualizer(image_rgb, scale=1.0)
    vis_image = visualizer.draw_instance_predictions(instances).get_image()

    if output_path is not None:
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(output_path, cv2.cvtColor(vis_image, cv2.COLOR_RGB2BGR)):
            raise OSError(f"Could not write visualization to {output_path}")

    return vis_image


def get_detectron2_models() -> List[str]:
    """
    Get the configuration names registered in the Detectron2 model zoo.

    The list is not filtered by task, so it is not limited to instance
    segmentation models; filter by prefix (e.g.
    "COCO-InstanceSegmentation/") if needed.

    Returns:
        List of model configuration names, each with a ".yaml" suffix.

    Raises:
        ImportError: If detectron2 is not installed.
    """
    check_detectron2()
    # NOTE: _ModelZooUrls is a private detectron2 class and may change between
    # releases.
    from detectron2.model_zoo.model_zoo import _ModelZooUrls

    return [f"{config}.yaml" for config in _ModelZooUrls.CONFIG_PATH_TO_URL_SUFFIX]


def batch_detectron2_segment(
    image_paths: List[str],
    output_dir: str = ".",
    model_config: str = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml",
    model_weights: Optional[str] = None,
    score_threshold: float = 0.5,
    device: Optional[str] = None,
    save_masks: bool = True,
    save_probability: bool = True,
) -> List[Dict]:
    """
    Perform batch instance segmentation on multiple images.

    The model is loaded once and reused for every image. Output files are
    named after each image's base name, so images that share a file name in
    different directories overwrite each other's outputs.

    Args:
        image_paths: List of paths to input images
        output_dir: Directory to save output files
        model_config: Model configuration file path or name from model zoo
        model_weights: Path to model weights file. If None, uses model zoo weights
        score_threshold: Confidence threshold for predictions
        device: Device to use ('cpu', 'cuda', or None for auto-detection)
        save_masks: Whether to save instance masks as GeoTIFF
        save_probability: Whether to save probability masks as GeoTIFF

    Returns:
        List of results dictionaries for each image (each with "image_path");
        failed images get {"image_path": ..., "error": ...} instead.
    """
    check_detectron2()

    # Load the model once; every per-image call reuses it.
    predictor = load_detectron2_model(
        model_config=model_config,
        model_weights=model_weights,
        score_threshold=score_threshold,
        device=device,
    )

    results = []
    total = len(image_paths)

    for index, image_path in enumerate(image_paths, start=1):
        try:
            base_name = os.path.splitext(os.path.basename(image_path))[0]
            result = detectron2_segment(
                image_path=image_path,
                output_dir=output_dir,
                save_masks=save_masks,
                save_probability=save_probability,
                mask_prefix=f"{base_name}_instance_masks",
                prob_prefix=f"{base_name}_probability_mask",
                predictor=predictor,
            )

            result["image_path"] = image_path
            results.append(result)

            print(f"Processed {index}/{total}: {image_path}")

        except Exception as e:
            # Best effort: record the failure and continue with the next image.
            print(f"Error processing {image_path}: {str(e)}")
            results.append({"image_path": image_path, "error": str(e)})

    return results
434
+
435
+
436
def get_class_id_name_mapping(config_path: str, lazy: bool = False) -> Dict[int, str]:
    """
    Get class ID to name mapping from a Detectron2 model config.

    The first training dataset named in the config is looked up in
    ``MetadataCatalog``; its ``thing_classes`` (or, if empty,
    ``stuff_classes``) are returned.

    Args:
        config_path (str): Path to the config file or model_zoo config name.
            An existing file on disk takes precedence over a model-zoo name.
        lazy (bool): Whether the config is a LazyConfig (i.e., .py). Paths
            ending in ".py" are always treated as LazyConfigs.

    Returns:
        dict: Mapping from class ID (int) to class name (str). Empty if the
        config names no training dataset or the dataset has no class names.

    Raises:
        ImportError: If detectron2 is not installed.
    """
    check_detectron2()

    resolved_path = (
        config_path
        if os.path.exists(config_path)
        else model_zoo.get_config_file(config_path)
    )

    if lazy or config_path.endswith(".py"):
        cfg = LazyConfig.load(resolved_path)
        # LazyConfigs keep dataset names on the train dataset node
        # (dataloader.train.dataset.names), not under the mapper; the value
        # may be a single string rather than a list.
        names = cfg.dataloader.train.dataset.names
    else:
        cfg = get_cfg()
        cfg.merge_from_file(resolved_path)
        names = cfg.DATASETS.TRAIN

    # Indexing a plain string with [0] would yield its first character.
    if isinstance(names, str):
        names = [names]
    if not names:
        return {}

    metadata = MetadataCatalog.get(names[0])

    classes = metadata.get("thing_classes", []) or metadata.get("stuff_classes", [])
    return dict(enumerate(classes))