geoai-py 0.4.3__py2.py3-none-any.whl → 0.5.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +6 -1
- geoai/classify.py +933 -0
- geoai/download.py +119 -80
- geoai/geoai.py +93 -1
- geoai/train.py +115 -6
- geoai/utils.py +196 -0
- {geoai_py-0.4.3.dist-info → geoai_py-0.5.1.dist-info}/METADATA +9 -1
- geoai_py-0.5.1.dist-info/RECORD +16 -0
- {geoai_py-0.4.3.dist-info → geoai_py-0.5.1.dist-info}/WHEEL +1 -1
- geoai_py-0.4.3.dist-info/RECORD +0 -15
- {geoai_py-0.4.3.dist-info → geoai_py-0.5.1.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.4.3.dist-info → geoai_py-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.4.3.dist-info → geoai_py-0.5.1.dist-info}/top_level.txt +0 -0
geoai/classify.py
ADDED
@@ -0,0 +1,933 @@
"""The module for training semantic segmentation models for classifying remote sensing imagery."""

import os
import numpy as np


def train_classifier(
    image_root,
    label_root,
    output_dir="output",
    in_channels=4,
    num_classes=14,
    epochs=20,
    img_size=256,
    batch_size=8,
    sample_size=500,
    model="unet",
    backbone="resnet50",
    weights=True,
    num_filters=3,
    loss="ce",
    class_weights=None,
    ignore_index=None,
    lr=0.001,
    patience=10,
    freeze_backbone=False,
    freeze_decoder=False,
    transforms=None,
    use_augmentation=False,
    seed=42,
    train_val_test_split=(0.6, 0.2, 0.2),
    accelerator="auto",
    devices="auto",
    logger=None,
    callbacks=None,
    log_every_n_steps=10,
    use_distributed_sampler=False,
    monitor_metric="val_loss",
    mode="min",
    save_top_k=1,
    save_last=True,
    checkpoint_filename="best_model",
    checkpoint_path=None,
    every_n_epochs=1,
    **kwargs,
):
    """Train a semantic segmentation model on geospatial imagery.

    This function sets up the datasets, model, and trainer, and executes the
    training process for semantic segmentation tasks using geospatial data.
    It supports training from scratch or resuming from a checkpoint if one is
    available.

    Args:
        image_root (str): Path to the directory containing imagery.
        label_root (str): Path to the directory containing land cover labels.
        output_dir (str, optional): Directory to save model outputs and checkpoints.
            Defaults to "output".
        in_channels (int, optional): Number of input channels in the imagery.
            Defaults to 4.
        num_classes (int, optional): Number of classes in the segmentation task.
            Defaults to 14.
        epochs (int, optional): Number of training epochs. Defaults to 20.
        img_size (int, optional): Size of image patches for training. Defaults to 256.
        batch_size (int, optional): Batch size for training. Defaults to 8.
        sample_size (int, optional): Number of samples per epoch. Defaults to 500.
        model (str, optional): Model architecture to use. Defaults to "unet".
        backbone (str, optional): Backbone network for the model. Defaults to "resnet50".
        weights (bool, optional): Whether to use pretrained weights. Defaults to True.
        num_filters (int, optional): Number of filters for the model. Defaults to 3.
        loss (str, optional): Loss function to use ('ce', 'jaccard', or 'focal').
            Defaults to "ce".
        class_weights (list, optional): Class weights for the loss function.
            Defaults to None.
        ignore_index (int, optional): Index to ignore in the loss calculation.
            Defaults to None.
        lr (float, optional): Learning rate. Defaults to 0.001.
        patience (int, optional): Number of epochs with no improvement after which
            training will stop. Defaults to 10.
        freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
        freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False.
        transforms (callable, optional): Transforms to apply to the data. Defaults to None.
        use_augmentation (bool, optional): Whether to apply data augmentation.
            Defaults to False.
        seed (int, optional): Random seed for reproducibility. Defaults to 42.
        train_val_test_split (tuple, optional): Proportions for the train/val/test split.
            Defaults to (0.6, 0.2, 0.2).
        accelerator (str, optional): Accelerator to use for training ('cpu', 'gpu', etc.).
            Defaults to "auto".
        devices (str, optional): Number of devices to use for training. Defaults to "auto".
        logger (object, optional): Logger for tracking training progress. Defaults to None.
        callbacks (list, optional): List of callbacks for the trainer. Defaults to None.
        log_every_n_steps (int, optional): Frequency of logging training progress.
            Defaults to 10.
        use_distributed_sampler (bool, optional): Whether to use distributed sampling.
            Defaults to False.
        monitor_metric (str, optional): Metric to monitor for saving the best model.
            Defaults to "val_loss".
        mode (str, optional): Mode for monitoring the metric ('min' or 'max').
            Use 'min' for losses and 'max' for metrics like accuracy.
            Defaults to "min".
        save_top_k (int, optional): Number of best models to save.
            Defaults to 1.
        save_last (bool, optional): Whether to save the model from the last epoch.
            Defaults to True.
        checkpoint_filename (str, optional): Filename pattern for saved checkpoints.
            Defaults to "best_model".
        checkpoint_path (str, optional): Path to a checkpoint file to resume training
            from. Defaults to None.
        every_n_epochs (int, optional): Save a checkpoint every N epochs.
            Defaults to 1.
        **kwargs: Additional keyword arguments passed to the datasets and the trainer.

    Returns:
        object: Trained SemanticSegmentationTask model.
    """
    import lightning.pytorch as pl
    from torch.utils.data import DataLoader
    from torchgeo.datasets import stack_samples, RasterDataset
    from torchgeo.datasets.splits import random_bbox_assignment
    from torchgeo.samplers import (
        RandomGeoSampler,
        RandomBatchGeoSampler,
        GridGeoSampler,
    )
    import torch
    import multiprocessing as mp
    import timeit
    import albumentations as A
    from torchgeo.datamodules import GeoDataModule
    from torchgeo.trainers import SemanticSegmentationTask
    from lightning.pytorch.callbacks import ModelCheckpoint
    from lightning.pytorch.loggers import CSVLogger

    # Create a wrapper class so albumentations transforms work with the
    # TorchGeo sample format
    class AlbumentationsWrapper:
        def __init__(self, transform):
            self.transform = transform

        def __call__(self, sample):
            # Extract the image and mask from the TorchGeo sample format
            if "image" not in sample or "mask" not in sample:
                return sample

            image = sample["image"]
            mask = sample["mask"]

            # Albumentations expects channels last, but TorchGeo uses channels first.
            # Convert (C, H, W) to (H, W, C) for the image.
            image_np = image.permute(1, 2, 0).numpy()
            mask_np = mask.squeeze(0).numpy() if mask.dim() > 2 else mask.numpy()

            # Apply the transformation with named arguments
            transformed = self.transform(image=image_np, mask=mask_np)

            # Convert back to PyTorch tensors with channels first
            transformed_image = torch.from_numpy(transformed["image"]).permute(2, 0, 1)
            transformed_mask = torch.from_numpy(transformed["mask"]).unsqueeze(0)

            # Update the sample dictionary
            result = sample.copy()
            result["image"] = transformed_image
            result["mask"] = transformed_mask

            return result

    # Set up data augmentation if requested
    if use_augmentation:
        aug_transforms = A.Compose(
            [
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.RandomRotate90(p=0.5),
                A.ShiftScaleRotate(
                    p=0.5, shift_limit=0.0625, scale_limit=0.1, rotate_limit=45
                ),
                A.RandomBrightnessContrast(
                    p=0.5, brightness_limit=0.2, contrast_limit=0.2
                ),
                A.GaussianBlur(p=0.3),
                A.GaussNoise(p=0.3),
                A.CoarseDropout(p=0.3, max_holes=8, max_height=32, max_width=32),
            ]
        )
        # Wrap the albumentations transforms
        transforms = AlbumentationsWrapper(aug_transforms)

    # # Set up device configuration
    # device, num_devices = (
    #     ("cuda", torch.cuda.device_count())
    #     if torch.cuda.is_available()
    #     else ("cpu", mp.cpu_count())
    # )
    workers = mp.cpu_count()
    # print(f"Running on {num_devices} {device}(s)")

    # Define datasets
    class ImageDatasetClass(RasterDataset):
        filename_glob = "*.tif"
        is_image = True
        separate_files = False

    class LabelDatasetClass(RasterDataset):
        filename_glob = "*.tif"
        is_image = False
        separate_files = False

    # Prepare the output directory
    test_dir = os.path.join(output_dir, "models")
    if not os.path.exists(test_dir):
        os.makedirs(test_dir)

    # Set up the logger and checkpoint callback
    if logger is None:
        logger = CSVLogger(test_dir, name="lightning_logs")

    # Track the checkpoint callback only when this function creates it; a
    # user-supplied callbacks list may not include a ModelCheckpoint.
    checkpoint_callback = None
    if callbacks is None:
        checkpoint_callback = ModelCheckpoint(
            dirpath=test_dir,
            filename=checkpoint_filename,
            save_top_k=save_top_k,
            monitor=monitor_metric,
            mode=mode,
            save_last=save_last,
            every_n_epochs=every_n_epochs,
            verbose=True,
        )
        callbacks = [checkpoint_callback]

    # Initialize the segmentation task
    task = SemanticSegmentationTask(
        model=model,
        backbone=backbone,
        weights=weights,
        in_channels=in_channels,
        num_classes=num_classes,
        num_filters=num_filters,
        loss=loss,
        class_weights=class_weights,
        ignore_index=ignore_index,
        lr=lr,
        patience=patience,
        freeze_backbone=freeze_backbone,
        freeze_decoder=freeze_decoder,
    )

    # Set up the trainer
    trainer = pl.Trainer(
        accelerator=accelerator,
        devices=devices,
        max_epochs=epochs,
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=log_every_n_steps,
        use_distributed_sampler=use_distributed_sampler,
        **kwargs,  # Pass any additional kwargs to the trainer
    )

    # Load the datasets, with transforms if augmentation is enabled
    if isinstance(image_root, RasterDataset):
        images = image_root
    else:
        images = ImageDatasetClass(paths=image_root, transforms=transforms, **kwargs)

    if isinstance(label_root, RasterDataset):
        labels = label_root
    else:
        labels = LabelDatasetClass(paths=label_root, **kwargs)

    # Create the intersection dataset
    dataset = images & labels

    # Define a custom datamodule for training
    class CustomGeoDataModule(GeoDataModule):
        def setup(self, stage: str) -> None:
            """Set up datasets.

            Args:
                stage: Either 'fit', 'validate', 'test', or 'predict'.
            """
            self.dataset = self.dataset_class(**self.kwargs)

            generator = torch.Generator().manual_seed(seed)
            (
                self.train_dataset,
                self.val_dataset,
                self.test_dataset,
            ) = random_bbox_assignment(dataset, train_val_test_split, generator)

            if stage in ["fit"]:
                self.train_batch_sampler = RandomBatchGeoSampler(
                    self.train_dataset, self.patch_size, self.batch_size, self.length
                )
            if stage in ["fit", "validate"]:
                self.val_sampler = GridGeoSampler(
                    self.val_dataset, self.patch_size, self.patch_size
                )
            if stage in ["test"]:
                self.test_sampler = GridGeoSampler(
                    self.test_dataset, self.patch_size, self.patch_size
                )

    # Create the datamodule
    datamodule = CustomGeoDataModule(
        dataset_class=type(dataset),
        batch_size=batch_size,
        patch_size=img_size,
        length=sample_size,
        num_workers=workers,
        dataset1=images,
        dataset2=labels,
        collate_fn=stack_samples,
    )

    # Start the training timer
    start = timeit.default_timer()

    # Check for an existing checkpoint
    if checkpoint_path is not None:
        checkpoint_file = os.path.abspath(checkpoint_path)
    else:
        checkpoint_file = os.path.join(test_dir, "last.ckpt")

    if os.path.isfile(checkpoint_file):
        print("Resuming training from previous checkpoint...")
        trainer.fit(model=task, datamodule=datamodule, ckpt_path=checkpoint_file)
    else:
        print("Starting training from scratch...")
        trainer.fit(
            model=task,
            datamodule=datamodule,
        )

    training_time = timeit.default_timer() - start
    print(f"The time taken to train was: {training_time:.2f} seconds")

    if checkpoint_callback is not None:
        best_model_path = checkpoint_callback.best_model_path
        print(f"Best model saved at: {best_model_path}")

    # Test the model
    trainer.test(model=task, datamodule=datamodule)

    return task

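# Illustrative usage sketch (not part of the released file). Assumes a folder
# of 4-band GeoTIFF tiles and a matching folder of single-band label rasters;
# the paths and class count below are placeholders:
#
#     task = train_classifier(
#         image_root="data/images",
#         label_root="data/labels",
#         output_dir="output",
#         in_channels=4,
#         num_classes=14,
#         epochs=20,
#         use_augmentation=True,
#     )
#
# With the defaults above, checkpoints land in output/models/: the best model
# as best_model.ckpt and the most recent epoch as last.ckpt, which a later
# call can resume from automatically.
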
def _classify_image(
    image_path,
    model_path,
    output_path=None,
    chip_size=1024,
    batch_size=4,
    colormap=None,
    num_workers=2,
    **kwargs,
):
    """Classify a geospatial image using a trained semantic segmentation model.

    This version can produce visible artifacts along tile edges; prefer
    classify_image, which blends overlapping tiles.

    This function handles the full image classification pipeline:
    1. Loads the image and model
    2. Cuts the image into tiles/chips
    3. Makes predictions on each chip
    4. Georeferences each prediction
    5. Merges all predictions into a single georeferenced output

    Parameters:
        image_path (str): Path to the input GeoTIFF image.
        model_path (str): Path to the trained model checkpoint.
        output_path (str, optional): Path to save the output classified image.
            Defaults to "[input_name]_classified.tif".
        chip_size (int, optional): Size of chips for processing. Defaults to 1024.
        batch_size (int, optional): Batch size for inference. Defaults to 4.
        colormap (dict, optional): Colormap to apply to the output image.
            Defaults to None.
        num_workers (int, optional): Number of workers for the DataLoader. Defaults to 2.
        **kwargs: Additional keyword arguments for the DataLoader.

    Returns:
        str: Path to the saved classified image.
    """
    import numpy as np
    import timeit
    from tqdm import tqdm

    import torch
    from torch.utils.data import DataLoader

    from torchgeo.datasets import RasterDataset, stack_samples
    from torchgeo.samplers import GridGeoSampler
    from torchgeo.trainers import SemanticSegmentationTask

    import rasterio
    from rasterio.transform import from_origin
    from rasterio.io import MemoryFile
    from rasterio.merge import merge

    # Set a default output path if one is not provided
    if output_path is None:
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        output_path = f"{base_name}_classified.tif"

    # Make sure the output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Load the model
    print(f"Loading model from {model_path}...")
    task = SemanticSegmentationTask.load_from_checkpoint(model_path)
    task.model.eval()
    task.model.cuda()

    # Set up the dataset and sampler
    print(f"Loading image from {image_path}...")
    dataset = RasterDataset(paths=image_path)

    # Get the bounds and resolution of the dataset
    original_bounds = dataset.bounds
    pixel_size = dataset.res
    crs = dataset.crs.to_epsg()

    # Use a GridGeoSampler to sample the image in tiles
    sampler = GridGeoSampler(dataset, chip_size, chip_size)

    # Create the DataLoader
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        collate_fn=stack_samples,
        num_workers=num_workers,
        **kwargs,
    )

    print(f"Processing image in {len(dataloader)} batches...")

    # Helper function to create in-memory GeoTIFFs for chips
    def create_in_memory_geochip(predicted_chip, geotransform, crs):
        """Create an in-memory georeferenced chip."""
        photometric = "MINISBLACK"

        # Ensure predicted_chip has shape (bands, height, width)
        if len(predicted_chip.shape) == 2:
            predicted_chip = predicted_chip[np.newaxis, :, :]

        memfile = MemoryFile()
        dataset = memfile.open(
            driver="GTiff",
            height=predicted_chip.shape[1],
            width=predicted_chip.shape[2],
            count=predicted_chip.shape[0],  # Number of bands
            dtype=np.uint8,
            crs=crs,
            transform=geotransform,
            photometric=photometric,
        )

        # Write all bands
        for band_idx in range(predicted_chip.shape[0]):
            dataset.write(
                predicted_chip[band_idx], band_idx + 1
            )  # Band indices are 1-based in rasterio

        return dataset

    # Helper function to clip to the original bounds
    def clip_to_original_bounds(tif_path, original_bounds, colormap=None):
        """Clip a GeoTIFF to match the original bounds."""
        with rasterio.open(tif_path) as src:
            # Create a window that matches the original bounds
            window = rasterio.windows.from_bounds(
                original_bounds.minx,
                original_bounds.miny,
                original_bounds.maxx,
                original_bounds.maxy,
                transform=src.transform,
            )

            # Read the data within the window
            data = src.read(window=window)

            # Update the transform
            transform = rasterio.windows.transform(window, src.transform)

            # Create new metadata
            meta = src.meta.copy()
            meta.update(
                {
                    "height": window.height,
                    "width": window.width,
                    "transform": transform,
                    "compress": "deflate",
                }
            )

        # Write the clipped data back to the same file
        with rasterio.open(tif_path, "w", **meta) as dst:
            dst.write(data)
            if isinstance(colormap, dict):
                dst.write_colormap(1, colormap)

    # Run inference on all chips
    start_time = timeit.default_timer()
    georref_chips_list = []

    # Progress bar for processing chips
    progress_bar = tqdm(total=len(dataloader), desc="Processing tiles", unit="batch")

    for batch in dataloader:
        # Get the images and bounds
        images = batch["image"]
        bounds_list = batch["bounds"]

        # Normalize the images
        images = images / 255.0

        # Make predictions
        with torch.no_grad():
            predictions = task.model.predict(images.cuda())
            predictions = torch.softmax(predictions, dim=1)
            predictions = torch.argmax(predictions, dim=1)

        # Process each prediction in the batch
        for i in range(len(predictions)):
            # Get the bounds for this chip
            bounds = bounds_list[i]

            # Create the geotransform
            geotransform = from_origin(bounds.minx, bounds.maxy, pixel_size, pixel_size)

            # Convert the prediction to a numpy array
            pred = predictions[i].cpu().numpy().astype(np.uint8)
            if len(pred.shape) == 2:
                pred = pred[np.newaxis, :, :]

            # Create a georeferenced chip
            georref_chips_list.append(create_in_memory_geochip(pred, geotransform, crs))

        # Update the progress bar
        progress_bar.update(1)

    progress_bar.close()

    prediction_time = timeit.default_timer() - start_time
    print(f"Prediction complete in {prediction_time:.2f} seconds")
    print(f"Produced {len(georref_chips_list)} georeferenced chips")

    # Merge all georeferenced chips into a single output
    print("Merging predictions...")
    merge_start = timeit.default_timer()

    # Merge the chips using rasterio's merge function
    merged, merged_transform = merge(georref_chips_list)

    # Calculate the number of rows and columns for the merged output
    rows, cols = merged.shape[1], merged.shape[2]

    # Update the metadata of the merged dataset
    merged_metadata = georref_chips_list[0].meta
    merged_metadata.update(
        {"height": rows, "width": cols, "transform": merged_transform}
    )

    # Write the merged array to the output file
    with rasterio.open(output_path, "w", **merged_metadata) as dst:
        dst.write(merged)
        # if isinstance(colormap, dict):
        #     dst.write_colormap(1, colormap)

    # Clip to the original bounds
    print("Clipping to original image bounds...")
    clip_to_original_bounds(output_path, original_bounds, colormap)

    # Close all chip datasets
    for chip in tqdm(georref_chips_list, desc="Cleaning up", unit="chip"):
        chip.close()

    merge_time = timeit.default_timer() - merge_start
    total_time = timeit.default_timer() - start_time

    print(f"Merge and save complete in {merge_time:.2f} seconds")
    print(f"Total processing time: {total_time:.2f} seconds")
    print(f"Successfully saved classified image to {output_path}")

    return output_path

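# Illustrative sketch of the colormap argument used by the classifiers in this
# module (not part of the released file). rasterio's write_colormap expects a
# dict mapping class values to RGB or RGBA tuples; the class meanings below
# are placeholders:
#
#     colormap = {
#         0: (0, 0, 0),       # background
#         1: (34, 139, 34),   # e.g., vegetation
#         2: (0, 0, 255),     # e.g., water
#     }
#     classify_image("scene.tif", "output/models/best_model.ckpt",
#                    colormap=colormap)
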
def classify_image(
    image_path,
    model_path,
    output_path=None,
    chip_size=1024,
    overlap=256,
    batch_size=4,
    colormap=None,
    **kwargs,
):
    """Classify a geospatial image using a trained semantic segmentation model.

    This function handles the full image classification pipeline with special
    attention to edge handling:
    1. Processes the image in a grid pattern with overlapping tiles
    2. Uses the central regions of tiles for interior parts
    3. Handles edges specially to ensure complete coverage
    4. Merges the results into a single georeferenced output

    Parameters:
        image_path (str): Path to the input GeoTIFF image.
        model_path (str): Path to the trained model checkpoint.
        output_path (str, optional): Path to save the output classified image.
            Defaults to "[input_name]_classified.tif".
        chip_size (int, optional): Size of chips for processing. Defaults to 1024.
        overlap (int, optional): Overlap size between adjacent tiles. Defaults to 256.
        batch_size (int, optional): Batch size for inference. Defaults to 4.
        colormap (dict, optional): Colormap to apply to the output image.
            Defaults to None.
        **kwargs: Additional keyword arguments for the DataLoader.

    Returns:
        str: Path to the saved classified image.
    """
    import timeit

    import torch
    from torchgeo.trainers import SemanticSegmentationTask

    import rasterio

    import warnings
    from rasterio.errors import NotGeoreferencedWarning

    # Disable specific GDAL/rasterio warnings
    warnings.filterwarnings("ignore", category=UserWarning, module="rasterio._.*")
    warnings.filterwarnings("ignore", category=UserWarning, module="rasterio")
    warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)

    # Also suppress GDAL error reports
    import logging

    logging.getLogger("rasterio").setLevel(logging.ERROR)

    # Set a default output path if one is not provided
    if output_path is None:
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        output_path = f"{base_name}_classified.tif"

    # Make sure the output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Load the model
    print(f"Loading model from {model_path}...")
    task = SemanticSegmentationTask.load_from_checkpoint(model_path)
    task.model.eval()
    task.model.cuda()

    # Process the image using a modified tiling approach
    with rasterio.open(image_path) as src:
        # Get the image dimensions and metadata
        height = src.height
        width = src.width
        profile = src.profile.copy()

        # Prepare output arrays for the final result
        output_image = np.zeros((height, width), dtype=np.uint8)
        confidence_map = np.zeros((height, width), dtype=np.float32)

        # Calculate the number of tiles needed, with overlap, ensuring that
        # tiles specifically cover the edges
        effective_stride = chip_size - overlap

        # Calculate x positions, ensuring the leftmost and rightmost edges are covered
        x_positions = []
        # Always include the leftmost position
        x_positions.append(0)
        # Add regular grid positions
        for x in range(effective_stride, width - chip_size, effective_stride):
            x_positions.append(x)
        # Always include the rightmost position that still fits
        if width > chip_size and x_positions[-1] + chip_size < width:
            x_positions.append(width - chip_size)

        # Calculate y positions, ensuring the top and bottom edges are covered
        y_positions = []
        # Always include the topmost position
        y_positions.append(0)
        # Add regular grid positions
        for y in range(effective_stride, height - chip_size, effective_stride):
            y_positions.append(y)
        # Always include the bottommost position that still fits
        if height > chip_size and y_positions[-1] + chip_size < height:
            y_positions.append(height - chip_size)

        # Create a list of all tile positions
        tile_positions = []
        for y in y_positions:
            for x in x_positions:
                y_end = min(y + chip_size, height)
                x_end = min(x + chip_size, width)
                tile_positions.append((y, x, y_end, x_end))

        # Print information about the tiling
        print(
            f"Processing {len(tile_positions)} patches covering an image of size {height}x{width}..."
        )
        start_time = timeit.default_timer()

        # Process tiles in batches
        for batch_start in range(0, len(tile_positions), batch_size):
            batch_end = min(batch_start + batch_size, len(tile_positions))
            batch_positions = tile_positions[batch_start:batch_end]
            batch_data = []

            # Load the data for the current batch
            for y_start, x_start, y_end, x_end in batch_positions:
                # Calculate the actual tile size
                actual_height = y_end - y_start
                actual_width = x_end - x_start

                # Read the tile data
                tile_data = src.read(window=((y_start, y_end), (x_start, x_end)))

                # Handle different sized tiles by padding if necessary
                if tile_data.shape[1] != chip_size or tile_data.shape[2] != chip_size:
                    padded_data = np.zeros(
                        (tile_data.shape[0], chip_size, chip_size),
                        dtype=tile_data.dtype,
                    )
                    padded_data[:, : tile_data.shape[1], : tile_data.shape[2]] = (
                        tile_data
                    )
                    tile_data = padded_data

                # Convert to a tensor
                tile_tensor = torch.from_numpy(tile_data).float() / 255.0
                batch_data.append(tile_tensor)

            # Convert the batch to a tensor
            batch_tensor = torch.stack(batch_data)

            # Run inference
            with torch.no_grad():
                logits = task.model.predict(batch_tensor.cuda())
                probs = torch.softmax(logits, dim=1)
                confidence, predictions = torch.max(probs, dim=1)
                predictions = predictions.cpu().numpy()
                confidence = confidence.cpu().numpy()

            # Process each prediction
            for idx, (y_start, x_start, y_end, x_end) in enumerate(batch_positions):
                pred = predictions[idx]
                conf = confidence[idx]

                # Calculate the actual tile size
                actual_height = y_end - y_start
                actual_width = x_end - x_start

                # Get the actual prediction (removing padding if needed)
                valid_pred = pred[:actual_height, :actual_width]
                valid_conf = conf[:actual_height, :actual_width]

                # Create confidence weights that favor the central parts of tiles
                # but still allow edge tiles to contribute fully at the image edges
                is_edge_x = (x_start == 0) or (x_end == width)
                is_edge_y = (y_start == 0) or (y_end == height)

                # Create a mask that gives higher weight to central regions
                # but ensures proper edge handling for boundary tiles
                weight_mask = np.ones((actual_height, actual_width), dtype=np.float32)

                # Only apply central weighting if not at an image edge
                border = overlap // 2
                if not is_edge_x and actual_width > 2 * border:
                    # Apply a linear horizontal edge falloff
                    for i in range(border):
                        # Left edge
                        weight_mask[:, i] = (i + 1) / (border + 1)
                        # Right edge (if not at the image edge)
                        if i < actual_width - border:
                            weight_mask[:, actual_width - i - 1] = (i + 1) / (
                                border + 1
                            )

                if not is_edge_y and actual_height > 2 * border:
                    # Apply a linear vertical edge falloff
                    for i in range(border):
                        # Top edge
                        weight_mask[i, :] = (i + 1) / (border + 1)
                        # Bottom edge (if not at the image edge)
                        if i < actual_height - border:
                            weight_mask[actual_height - i - 1, :] = (i + 1) / (
                                border + 1
                            )

                # Combine with the prediction confidence
                final_weight = weight_mask * valid_conf

                # Update the output image based on confidence
                current_conf = confidence_map[y_start:y_end, x_start:x_end]
                update_mask = final_weight > current_conf

                if np.any(update_mask):
                    # Update only the pixels where this prediction has higher confidence
                    output_image[y_start:y_end, x_start:x_end][update_mask] = (
                        valid_pred[update_mask]
                    )
                    confidence_map[y_start:y_end, x_start:x_end][update_mask] = (
                        final_weight[update_mask]
                    )

    # Update the profile for the output
    profile.update({"count": 1, "dtype": "uint8", "nodata": 0})

    # Save the result
    print(f"Saving classified image to {output_path}...")
    with rasterio.open(output_path, "w", **profile) as dst:
        dst.write(output_image[np.newaxis, :, :])
        if isinstance(colormap, dict):
            dst.write_colormap(1, colormap)

    # Calculate timing
    total_time = timeit.default_timer() - start_time
    print(f"Total processing time: {total_time:.2f} seconds")
    print(f"Successfully saved classified image to {output_path}")

    return output_path

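# Illustrative usage sketch (not part of the released file). chip_size and
# overlap trade throughput against seam quality: a larger overlap gives the
# confidence-weighted blending more context at tile borders. The paths are
# placeholders, and a CUDA-capable GPU is required since the model is moved
# to .cuda():
#
#     classify_image(
#         image_path="scene.tif",
#         model_path="output/models/best_model.ckpt",
#         output_path="scene_classified.tif",
#         chip_size=1024,
#         overlap=256,
#         batch_size=4,
#     )
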
def classify_images(
    image_paths,
    model_path,
    output_dir=None,
    chip_size=1024,
    batch_size=4,
    colormap=None,
    file_extension=".tif",
    **kwargs,
):
    """Classify multiple geospatial images using a trained semantic segmentation model.

    This function accepts either a list of image paths or a directory containing
    images, applies the classify_image function to each image, and saves the
    results in the specified output directory.

    Parameters:
        image_paths (str or list): Either a directory path containing images or a
            list of paths to input GeoTIFF images.
        model_path (str): Path to the trained model checkpoint.
        output_dir (str, optional): Directory to save the output classified images.
            Defaults to None (the same directory as the input images for a list, or
            a new "classified" subdirectory for a directory input).
        chip_size (int, optional): Size of chips for processing. Defaults to 1024.
        batch_size (int, optional): Batch size for inference. Defaults to 4.
        colormap (dict, optional): Colormap to apply to the output images.
            Defaults to None.
        file_extension (str, optional): File extension to filter by when image_paths
            is a directory. Defaults to ".tif".
        **kwargs: Additional keyword arguments for the classify_image function.

    Returns:
        list: List of paths to the saved classified images.
    """
    # Import required libraries
    from tqdm import tqdm
    import glob

    # Process directory input
    if isinstance(image_paths, str) and os.path.isdir(image_paths):
        # Set a default output directory if one is not provided
        if output_dir is None:
            output_dir = os.path.join(image_paths, "classified")

        # Get all images with the specified extension
        image_path_list = glob.glob(os.path.join(image_paths, f"*{file_extension}"))

        # Check whether any images were found
        if not image_path_list:
            print(f"No files with extension '{file_extension}' found in {image_paths}")
            return []

        print(f"Found {len(image_path_list)} images in directory {image_paths}")

    # Process list input
    elif isinstance(image_paths, list):
        image_path_list = image_paths

        # Set a default output directory if one is not provided
        if output_dir is None and len(image_path_list) > 0:
            output_dir = os.path.dirname(image_path_list[0])

    # Invalid input
    else:
        raise ValueError(
            "image_paths must be either a directory path or a list of file paths"
        )

    # Create the output directory if it doesn't exist
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    classified_image_paths = []

    # Classify each image with a progress bar
    for image_path in tqdm(image_path_list, desc="Classifying images", unit="image"):
        try:
            # Get just the filename without the extension
            base_filename = os.path.splitext(os.path.basename(image_path))[0]

            # Create the output path within output_dir
            output_path = os.path.join(
                output_dir, f"{base_filename}_classified{file_extension}"
            )

            # Perform the classification
            classified_image_path = classify_image(
                image_path,
                model_path,
                output_path=output_path,
                chip_size=chip_size,
                batch_size=batch_size,
                colormap=colormap,
                **kwargs,
            )
            classified_image_paths.append(classified_image_path)
        except Exception as e:
            print(f"Error processing {image_path}: {str(e)}")

    print(
        f"Classification complete. Processed {len(classified_image_paths)} images successfully."
    )
    return classified_image_paths
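# Illustrative usage sketch (not part of the released file). Batch-classifying
# every .tif in a directory; outputs go to an auto-created "classified"
# subdirectory. The directory name is a placeholder:
#
#     results = classify_images(
#         image_paths="data/scenes",
#         model_path="output/models/best_model.ckpt",
#         batch_size=4,
#     )
#     print(f"{len(results)} images classified")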