geoai-py 0.8.3__py2.py3-none-any.whl → 0.9.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/change_detection.py +1568 -0
- geoai/classify.py +58 -57
- geoai/detectron2.py +466 -0
- geoai/download.py +74 -68
- geoai/extract.py +186 -141
- geoai/geoai.py +13 -11
- geoai/hf.py +14 -12
- geoai/segment.py +44 -39
- geoai/segmentation.py +10 -9
- geoai/train.py +372 -241
- geoai/utils.py +198 -123
- {geoai_py-0.8.3.dist-info → geoai_py-0.9.1.dist-info}/METADATA +5 -1
- geoai_py-0.9.1.dist-info/RECORD +19 -0
- geoai_py-0.8.3.dist-info/RECORD +0 -17
- {geoai_py-0.8.3.dist-info → geoai_py-0.9.1.dist-info}/WHEEL +0 -0
- {geoai_py-0.8.3.dist-info → geoai_py-0.9.1.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.8.3.dist-info → geoai_py-0.9.1.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.8.3.dist-info → geoai_py-0.9.1.dist-info}/top_level.txt +0 -0
geoai/classify.py
CHANGED
@@ -1,50 +1,51 @@
|
|
1
1
|
"""The module for training semantic segmentation models for classifying remote sensing imagery."""
|
2
2
|
|
3
3
|
import os
|
4
|
+
from typing import Any, Dict, List, Optional, Union
|
4
5
|
|
5
6
|
import numpy as np
|
6
7
|
|
7
8
|
|
8
9
|
def train_classifier(
|
9
|
-
image_root,
|
10
|
-
label_root,
|
11
|
-
output_dir="output",
|
12
|
-
in_channels=4,
|
13
|
-
num_classes=14,
|
14
|
-
epochs=20,
|
15
|
-
img_size=256,
|
16
|
-
batch_size=8,
|
17
|
-
sample_size=500,
|
18
|
-
model="unet",
|
19
|
-
backbone="resnet50",
|
20
|
-
weights=True,
|
21
|
-
num_filters=3,
|
22
|
-
loss="ce",
|
23
|
-
class_weights=None,
|
24
|
-
ignore_index=None,
|
25
|
-
lr=0.001,
|
26
|
-
patience=10,
|
27
|
-
freeze_backbone=False,
|
28
|
-
freeze_decoder=False,
|
29
|
-
transforms=None,
|
30
|
-
use_augmentation=False,
|
31
|
-
seed=42,
|
32
|
-
train_val_test_split=(0.6, 0.2, 0.2),
|
33
|
-
accelerator="auto",
|
34
|
-
devices="auto",
|
35
|
-
logger=None,
|
36
|
-
callbacks=None,
|
37
|
-
log_every_n_steps=10,
|
38
|
-
use_distributed_sampler=False,
|
39
|
-
monitor_metric="val_loss",
|
40
|
-
mode="min",
|
41
|
-
save_top_k=1,
|
42
|
-
save_last=True,
|
43
|
-
checkpoint_filename="best_model",
|
44
|
-
checkpoint_path=None,
|
45
|
-
every_n_epochs=1,
|
46
|
-
**kwargs,
|
47
|
-
):
|
10
|
+
image_root: str,
|
11
|
+
label_root: str,
|
12
|
+
output_dir: str = "output",
|
13
|
+
in_channels: int = 4,
|
14
|
+
num_classes: int = 14,
|
15
|
+
epochs: int = 20,
|
16
|
+
img_size: int = 256,
|
17
|
+
batch_size: int = 8,
|
18
|
+
sample_size: int = 500,
|
19
|
+
model: str = "unet",
|
20
|
+
backbone: str = "resnet50",
|
21
|
+
weights: bool = True,
|
22
|
+
num_filters: int = 3,
|
23
|
+
loss: str = "ce",
|
24
|
+
class_weights: Optional[List[float]] = None,
|
25
|
+
ignore_index: Optional[int] = None,
|
26
|
+
lr: float = 0.001,
|
27
|
+
patience: int = 10,
|
28
|
+
freeze_backbone: bool = False,
|
29
|
+
freeze_decoder: bool = False,
|
30
|
+
transforms: Optional[Any] = None,
|
31
|
+
use_augmentation: bool = False,
|
32
|
+
seed: int = 42,
|
33
|
+
train_val_test_split: tuple = (0.6, 0.2, 0.2),
|
34
|
+
accelerator: str = "auto",
|
35
|
+
devices: str = "auto",
|
36
|
+
logger: Optional[Any] = None,
|
37
|
+
callbacks: Optional[List[Any]] = None,
|
38
|
+
log_every_n_steps: int = 10,
|
39
|
+
use_distributed_sampler: bool = False,
|
40
|
+
monitor_metric: str = "val_loss",
|
41
|
+
mode: str = "min",
|
42
|
+
save_top_k: int = 1,
|
43
|
+
save_last: bool = True,
|
44
|
+
checkpoint_filename: str = "best_model",
|
45
|
+
checkpoint_path: Optional[str] = None,
|
46
|
+
every_n_epochs: int = 1,
|
47
|
+
**kwargs: Any,
|
48
|
+
) -> Any:
|
48
49
|
"""Train a semantic segmentation model on geospatial imagery.
|
49
50
|
|
50
51
|
This function sets up datasets, model, trainer, and executes the training process
|
@@ -584,15 +585,15 @@ def _classify_image(
|
|
584
585
|
|
585
586
|
|
586
587
|
def classify_image(
|
587
|
-
image_path,
|
588
|
-
model_path,
|
589
|
-
output_path=None,
|
590
|
-
chip_size=1024,
|
591
|
-
overlap=256,
|
592
|
-
batch_size=4,
|
593
|
-
colormap=None,
|
594
|
-
**kwargs,
|
595
|
-
):
|
588
|
+
image_path: str,
|
589
|
+
model_path: str,
|
590
|
+
output_path: Optional[str] = None,
|
591
|
+
chip_size: int = 1024,
|
592
|
+
overlap: int = 256,
|
593
|
+
batch_size: int = 4,
|
594
|
+
colormap: Optional[Dict] = None,
|
595
|
+
**kwargs: Any,
|
596
|
+
) -> str:
|
596
597
|
"""
|
597
598
|
Classify a geospatial image using a trained semantic segmentation model.
|
598
599
|
|
@@ -826,15 +827,15 @@ def classify_image(
|
|
826
827
|
|
827
828
|
|
828
829
|
def classify_images(
|
829
|
-
image_paths,
|
830
|
-
model_path,
|
831
|
-
output_dir=None,
|
832
|
-
chip_size=1024,
|
833
|
-
batch_size=4,
|
834
|
-
colormap=None,
|
835
|
-
file_extension=".tif",
|
836
|
-
**kwargs,
|
837
|
-
):
|
830
|
+
image_paths: Union[str, List[str]],
|
831
|
+
model_path: str,
|
832
|
+
output_dir: Optional[str] = None,
|
833
|
+
chip_size: int = 1024,
|
834
|
+
batch_size: int = 4,
|
835
|
+
colormap: Optional[Dict] = None,
|
836
|
+
file_extension: str = ".tif",
|
837
|
+
**kwargs: Any,
|
838
|
+
) -> List[str]:
|
838
839
|
"""
|
839
840
|
Classify multiple geospatial images using a trained semantic segmentation model.
|
840
841
|
|
geoai/detectron2.py
ADDED
@@ -0,0 +1,466 @@
|
|
1
|
+
"""Detectron2 integration for remote sensing image segmentation.
|
2
|
+
See https://github.com/facebookresearch/detectron2 for more details.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import os
|
6
|
+
import warnings
|
7
|
+
from typing import Dict, List, Optional, Tuple, Union
|
8
|
+
|
9
|
+
import cv2
|
10
|
+
import numpy as np
|
11
|
+
import rasterio
|
12
|
+
import torch
|
13
|
+
from rasterio.crs import CRS
|
14
|
+
from rasterio.transform import from_bounds
|
15
|
+
|
16
|
+
try:
|
17
|
+
import detectron2
|
18
|
+
from detectron2 import model_zoo
|
19
|
+
from detectron2.config import LazyConfig, get_cfg
|
20
|
+
from detectron2.data import MetadataCatalog
|
21
|
+
from detectron2.engine import DefaultPredictor
|
22
|
+
from detectron2.utils.visualizer import Visualizer
|
23
|
+
|
24
|
+
HAS_DETECTRON2 = True
|
25
|
+
except ImportError:
|
26
|
+
HAS_DETECTRON2 = False
|
27
|
+
warnings.warn("Detectron2 not found. Please install detectron2 to use this module.")
|
28
|
+
|
29
|
+
try:
|
30
|
+
from .utils import get_device
|
31
|
+
except ImportError:
|
32
|
+
# Fallback device detection if utils is not available
|
33
|
+
def get_device():
|
34
|
+
try:
|
35
|
+
import torch
|
36
|
+
|
37
|
+
return "cuda" if torch.cuda.is_available() else "cpu"
|
38
|
+
except ImportError:
|
39
|
+
return "cpu"
|
40
|
+
|
41
|
+
|
42
|
+
def check_detectron2():
    """Ensure the optional detectron2 dependency was importable.

    Public entry points in this module call this first, so a missing install
    surfaces as an actionable ImportError rather than a failure on an
    undefined detectron2 name later on.

    Raises:
        ImportError: If the module-level detectron2 import failed.
    """
    if HAS_DETECTRON2:
        return
    raise ImportError(
        "Detectron2 is required. Please install it with: pip install detectron2"
    )
|
48
|
+
|
49
|
+
|
50
|
+
def load_detectron2_model(
    model_config: str = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml",
    model_weights: Optional[str] = None,
    score_threshold: float = 0.5,
    device: Optional[str] = None,
    num_classes: Optional[int] = None,
) -> "DefaultPredictor":
    """
    Load a Detectron2 model for instance segmentation.

    Args:
        model_config: Model configuration file path or name from model zoo.
            An existing local file is loaded directly; otherwise a name ending
            in ".yaml" is resolved through the Detectron2 model zoo.
        model_weights: Path to model weights file. If None, uses model zoo
            weights for ``model_config`` (which must then be a model-zoo name).
        score_threshold: Confidence threshold for predictions
        device: Device to use ('cpu', 'cuda', 'cuda:1', a torch.device, or
            None for auto-detection)
        num_classes: Number of classes for custom models

    Returns:
        DefaultPredictor: Configured Detectron2 predictor

    Raises:
        ImportError: If detectron2 is not installed.
    """
    check_detectron2()

    cfg = get_cfg()

    # Model-zoo names look like relative ".yaml" paths; only route through the
    # zoo when no such file exists locally, so user-provided YAML files work.
    is_zoo_config = model_config.endswith(".yaml") and not os.path.isfile(
        model_config
    )
    if is_zoo_config:
        cfg.merge_from_file(model_zoo.get_config_file(model_config))
    else:
        cfg.merge_from_file(model_config)

    # Set model weights
    if model_weights is None:
        cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(model_config)
    else:
        cfg.MODEL.WEIGHTS = model_weights

    # Set score threshold
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_threshold

    # Set device
    if device is None:
        device = get_device()

    # cfg.MODEL.DEVICE must be a string. str() on a torch.device keeps any
    # index ("cuda:1"), whereas the bare .type attribute would drop it.
    if not isinstance(device, str):
        device = str(device)

    cfg.MODEL.DEVICE = device

    # Set number of classes if specified
    if num_classes is not None:
        cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes

    return DefaultPredictor(cfg)
|
106
|
+
|
107
|
+
|
108
|
+
def detectron2_segment(
    image_path: str,
    output_dir: str = ".",
    model_config: str = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml",
    model_weights: Optional[str] = None,
    score_threshold: float = 0.5,
    device: Optional[str] = None,
    save_masks: bool = True,
    save_probability: bool = True,
    mask_prefix: str = "instance_masks",
    prob_prefix: str = "probability_mask",
    predictor: Optional["DefaultPredictor"] = None,
) -> Dict:
    """
    Perform instance segmentation on a remote sensing image using Detectron2.

    Args:
        image_path: Path to input image
        output_dir: Directory to save output files
        model_config: Model configuration file path or name from model zoo
        model_weights: Path to model weights file. If None, uses model zoo weights
        score_threshold: Confidence threshold for predictions
        device: Device to use ('cpu', 'cuda', or None for auto-detection)
        save_masks: Whether to save instance masks as GeoTIFF
        save_probability: Whether to save probability masks as GeoTIFF
        mask_prefix: Prefix for instance mask output file
        prob_prefix: Prefix for probability mask output file
        predictor: Optional already-loaded predictor (e.g. from
            ``load_detectron2_model``). When given, no model is loaded and
            ``model_config``, ``model_weights``, ``score_threshold`` and
            ``device`` are ignored, so one model can serve many images.

    Returns:
        Dict with keys ``masks``, ``scores``, ``classes``, ``boxes`` and
        ``num_instances``. ``instance_mask_path`` / ``probability_mask_path``
        are added only when the corresponding file was written, which happens
        only if at least one instance was detected.

    Raises:
        ImportError: If detectron2 is not installed.
        ValueError: If OpenCV cannot read ``image_path``.
    """
    check_detectron2()

    if predictor is None:
        predictor = load_detectron2_model(
            model_config=model_config,
            model_weights=model_weights,
            score_threshold=score_threshold,
            device=device,
        )

    # cv2.imread signals failure by returning None rather than raising.
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not read image from {image_path}")

    # Move predictions to the CPU so they can be converted to NumPy arrays.
    instances = predictor(image)["instances"].to("cpu")
    masks = instances.pred_masks.numpy()
    scores = instances.scores.numpy()

    results = {
        "masks": masks,
        "scores": scores,
        "classes": instances.pred_classes.numpy(),
        "boxes": instances.pred_boxes.tensor.numpy(),
        "num_instances": len(masks),
    }

    write_instances = save_masks and len(masks) > 0
    write_probability = save_probability and len(masks) > 0
    if not (write_instances or write_probability):
        return results

    transform, crs = _get_georeference(image_path, image.shape[:2])

    if write_instances:
        instance_mask_path = os.path.join(output_dir, f"{mask_prefix}.tif")
        save_geotiff_mask(
            create_instance_mask(masks),
            instance_mask_path,
            transform,
            crs,
            dtype="uint16",
        )
        results["instance_mask_path"] = instance_mask_path

    if write_probability:
        prob_mask_path = os.path.join(output_dir, f"{prob_prefix}.tif")
        save_geotiff_mask(
            create_probability_mask(masks, scores),
            prob_mask_path,
            transform,
            crs,
            dtype="float32",
        )
        results["probability_mask_path"] = prob_mask_path

    return results


def _get_georeference(image_path: str, shape: Tuple[int, int]):
    """Return ``(transform, crs)`` for ``image_path`` or a pixel-space fallback.

    ``shape`` is the ``(height, width)`` of the decoded image; it is only used
    when rasterio cannot open the file.
    """
    try:
        with rasterio.open(image_path) as src:
            return src.transform, src.crs
    except Exception:
        # Best-effort fallback for files rasterio cannot open: map pixel
        # coordinates 1:1. NOTE(review): this labels pixel units as EPSG:4326,
        # which is not geographically meaningful; kept for compatibility.
        height, width = shape
        return (
            from_bounds(0, 0, width, height, width, height),
            CRS.from_epsg(4326),
        )


def create_instance_mask(masks: np.ndarray) -> np.ndarray:
    """
    Create an instance mask from individual binary masks.

    Args:
        masks: Array of masks with shape (num_instances, height, width). Any
            dtype is accepted; non-zero pixels count as foreground.

    Returns:
        uint16 array of shape (height, width): 0 is background and the mask at
        position ``i`` in ``masks`` is labelled ``i + 1``. Where instances
        overlap, the later mask in ``masks`` wins.
    """
    masks = np.asarray(masks)
    instance_mask = np.zeros(masks.shape[1:3], dtype=np.uint16)

    for instance_id, mask in enumerate(masks, start=1):
        # Cast to bool so integer/float masks select pixels instead of being
        # interpreted as fancy (row) indices.
        instance_mask[mask.astype(bool)] = instance_id

    return instance_mask


def create_probability_mask(masks: np.ndarray, scores: np.ndarray) -> np.ndarray:
    """
    Create a probability mask from individual binary masks and their confidence scores.

    Args:
        masks: Array of masks with shape (num_instances, height, width); any
            non-zero pixel counts as foreground.
        scores: Array of confidence scores for each mask

    Returns:
        float32 array of shape (height, width) holding, per pixel, the maximum
        score of the masks covering it (0 where no mask covers it).
    """
    masks = np.asarray(masks)
    probability_mask = np.zeros(masks.shape[1:3], dtype=np.float32)

    for mask, score in zip(masks, scores):
        # fmax keeps the existing value when the candidate is NaN, so a NaN
        # score never overwrites a valid one.
        np.fmax(
            probability_mask,
            np.where(mask.astype(bool), score, 0),
            out=probability_mask,
        )

    return probability_mask


def save_geotiff_mask(
    mask: np.ndarray,
    output_path: str,
    transform: "rasterio.transform.Affine",
    crs: "CRS",
    dtype: str = "uint16",
) -> None:
    """
    Save a mask as a single-band, LZW-compressed GeoTIFF file.

    Args:
        mask: 2D numpy array representing the mask
        output_path: Path to save the GeoTIFF file; missing parent
            directories are created.
        transform: Rasterio transform for georeferencing
        crs: Coordinate reference system
        dtype: Data type for the output file. One of uint8, int16, uint16,
            int32, uint32, float32, float64; anything else falls back to
            uint16.
    """
    # A bare filename has no parent directory, and os.makedirs("") raises.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    supported = ("uint8", "int16", "uint16", "int32", "uint32", "float32", "float64")
    try:
        np_dtype = np.dtype(dtype)
    except TypeError:
        np_dtype = np.dtype(np.uint16)
    if np_dtype.name not in supported:
        np_dtype = np.dtype(np.uint16)

    mask = np.asarray(mask).astype(np_dtype)

    with rasterio.open(
        output_path,
        "w",
        driver="GTiff",
        height=mask.shape[0],
        width=mask.shape[1],
        count=1,
        dtype=np_dtype.name,
        crs=crs,
        transform=transform,
        compress="lzw",
    ) as dst:
        dst.write(mask, 1)


def visualize_detectron2_results(
    image_path: str,
    results: Dict,
    output_path: Optional[str] = None,
    show_scores: bool = True,
    show_classes: bool = True,
) -> np.ndarray:
    """
    Visualize Detectron2 segmentation results on the original image.

    Args:
        image_path: Path to the original image
        results: Results dictionary from detectron2_segment
        output_path: Path to save the visualization (optional)
        show_scores: Whether to show confidence scores in the labels
        show_classes: Whether to show class labels

    Returns:
        Visualization image as an RGB numpy array

    Raises:
        ImportError: If detectron2 is not installed.
        ValueError: If OpenCV cannot read ``image_path``.
    """
    check_detectron2()
    from detectron2.structures import Boxes, Instances

    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not read image from {image_path}")
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    instances = Instances((image.shape[0], image.shape[1]))
    instances.pred_masks = torch.from_numpy(results["masks"])
    instances.pred_boxes = Boxes(torch.from_numpy(results["boxes"]))
    # The Visualizer builds text labels from whichever of these fields are
    # present, so omitting a field hides that part of the label.
    if show_scores:
        instances.scores = torch.from_numpy(results["scores"])
    if show_classes:
        instances.pred_classes = torch.from_numpy(results["classes"])

    visualizer = Visualizer(image_rgb, scale=1.0)
    vis_image = visualizer.draw_instance_predictions(instances).get_image()

    if output_path is not None:
        cv2.imwrite(output_path, cv2.cvtColor(vis_image, cv2.COLOR_RGB2BGR))

    return vis_image


def get_detectron2_models() -> List[str]:
    """
    List the config names registered in the Detectron2 model zoo.

    NOTE(review): this returns every model-zoo entry (not only instance
    segmentation models), each with ".yaml" appended.

    Returns:
        List of model configuration names

    Raises:
        ImportError: If detectron2 is not installed.
    """
    check_detectron2()
    from detectron2.model_zoo.model_zoo import _ModelZooUrls

    return [f"{config}.yaml" for config in _ModelZooUrls.CONFIG_PATH_TO_URL_SUFFIX]


def batch_detectron2_segment(
    image_paths: List[str],
    output_dir: str = ".",
    model_config: str = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml",
    model_weights: Optional[str] = None,
    score_threshold: float = 0.5,
    device: Optional[str] = None,
    save_masks: bool = True,
    save_probability: bool = True,
) -> List[Dict]:
    """
    Perform batch instance segmentation on multiple images.

    The model is loaded once and reused for every image. Failures on one
    image are reported and recorded without aborting the batch.

    Args:
        image_paths: List of paths to input images
        output_dir: Directory to save output files
        model_config: Model configuration file path or name from model zoo
        model_weights: Path to model weights file. If None, uses model zoo weights
        score_threshold: Confidence threshold for predictions
        device: Device to use ('cpu', 'cuda', or None for auto-detection)
        save_masks: Whether to save instance masks as GeoTIFF
        save_probability: Whether to save probability masks as GeoTIFF

    Returns:
        List of results dictionaries for each image (in input order). Each has
        an "image_path" key; failed images carry an "error" key instead of
        segmentation results.
    """
    check_detectron2()

    predictor = load_detectron2_model(
        model_config=model_config,
        model_weights=model_weights,
        score_threshold=score_threshold,
        device=device,
    )

    results = []
    for i, image_path in enumerate(image_paths):
        try:
            # Per-image prefixes keep outputs from overwriting each other.
            base_name = os.path.splitext(os.path.basename(image_path))[0]

            result = detectron2_segment(
                image_path=image_path,
                output_dir=output_dir,
                save_masks=save_masks,
                save_probability=save_probability,
                mask_prefix=f"{base_name}_instance_masks",
                prob_prefix=f"{base_name}_probability_mask",
                predictor=predictor,
            )
            result["image_path"] = image_path
            results.append(result)

            print(f"Processed {i+1}/{len(image_paths)}: {image_path}")

        except Exception as e:
            # Best-effort batch: record the failure and keep going.
            print(f"Error processing {image_path}: {str(e)}")
            results.append({"image_path": image_path, "error": str(e)})

    return results
|
434
|
+
|
435
|
+
|
436
|
+
def get_class_id_name_mapping(config_path: str, lazy: bool = False) -> Dict[int, str]:
    """
    Get class ID to name mapping from a Detectron2 model config.

    Args:
        config_path (str): Path to the config file or model_zoo config name.
        lazy (bool): Whether the config is a LazyConfig (i.e., .py). Paths
            ending in ".py" are always treated as LazyConfigs.

    Returns:
        dict: Mapping from class ID (int) to class name (str). Uses the
        training dataset's ``thing_classes``, falling back to
        ``stuff_classes``; empty if neither is registered.

    Raises:
        ImportError: If detectron2 is not installed.
    """
    check_detectron2()

    # Local files are used as-is; anything else is treated as a model-zoo name.
    resolved_path = (
        config_path
        if os.path.exists(config_path)
        else model_zoo.get_config_file(config_path)
    )

    if lazy or config_path.endswith(".py"):
        cfg = LazyConfig.load(resolved_path)
        train_cfg = cfg.dataloader.train
        # Detectron2's lazy configs declare training data under
        # dataloader.train.dataset; the mapper path is kept as a fallback for
        # configs that nest it there.
        dataset_cfg = (
            train_cfg.dataset if "dataset" in train_cfg else train_cfg.mapper.dataset
        )
        names = dataset_cfg.names
        # `names` may be a single dataset name or a sequence of names;
        # indexing a plain string would return only its first character.
        dataset_name = names if isinstance(names, str) else names[0]
    else:
        cfg = get_cfg()
        cfg.merge_from_file(resolved_path)
        dataset_name = cfg.DATASETS.TRAIN[0]

    metadata = MetadataCatalog.get(dataset_name)

    classes = metadata.get("thing_classes", []) or metadata.get("stuff_classes", [])
    return dict(enumerate(classes))
|