geoai-py 0.4.3__py2.py3-none-any.whl → 0.5.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/classify.py ADDED
@@ -0,0 +1,933 @@
1
+ """The module for training semantic segmentation models for classifying remote sensing imagery."""
2
+
3
+ import os
4
+ import numpy as np
5
+
6
+
7
def train_classifier(
    image_root,
    label_root,
    output_dir="output",
    in_channels=4,
    num_classes=14,
    epochs=20,
    img_size=256,
    batch_size=8,
    sample_size=500,
    model="unet",
    backbone="resnet50",
    weights=True,
    num_filters=3,
    loss="ce",
    class_weights=None,
    ignore_index=None,
    lr=0.001,
    patience=10,
    freeze_backbone=False,
    freeze_decoder=False,
    transforms=None,
    use_augmentation=False,
    seed=42,
    train_val_test_split=(0.6, 0.2, 0.2),
    accelerator="auto",
    devices="auto",
    logger=None,
    callbacks=None,
    log_every_n_steps=10,
    use_distributed_sampler=False,
    monitor_metric="val_loss",
    mode="min",
    save_top_k=1,
    save_last=True,
    checkpoint_filename="best_model",
    checkpoint_path=None,
    every_n_epochs=1,
    **kwargs,
):
    """Train a semantic segmentation model on geospatial imagery.

    This function sets up datasets, model, trainer, and executes the training
    process for semantic segmentation tasks using geospatial data. It supports
    training from scratch or resuming from a checkpoint if available.

    Args:
        image_root (str): Path to directory containing imagery, or a
            RasterDataset instance.
        label_root (str): Path to directory containing land cover labels, or a
            RasterDataset instance.
        output_dir (str, optional): Directory to save model outputs and
            checkpoints. Defaults to "output".
        in_channels (int, optional): Number of input channels. Defaults to 4.
        num_classes (int, optional): Number of segmentation classes.
            Defaults to 14.
        epochs (int, optional): Number of training epochs. Defaults to 20.
        img_size (int, optional): Size of image patches for training.
            Defaults to 256.
        batch_size (int, optional): Batch size for training. Defaults to 8.
        sample_size (int, optional): Number of samples per epoch.
            Defaults to 500.
        model (str, optional): Model architecture to use. Defaults to "unet".
        backbone (str, optional): Backbone network. Defaults to "resnet50".
        weights (bool, optional): Whether to use pretrained weights.
            Defaults to True.
        num_filters (int, optional): Number of filters for the model.
            Defaults to 3.
        loss (str, optional): Loss function ('ce', 'jaccard', or 'focal').
            Defaults to "ce".
        class_weights (list, optional): Class weights for the loss function.
            Defaults to None.
        ignore_index (int, optional): Index ignored in the loss.
            Defaults to None.
        lr (float, optional): Learning rate. Defaults to 0.001.
        patience (int, optional): Epochs with no improvement before the LR
            scheduler reacts. Defaults to 10.
        freeze_backbone (bool, optional): Freeze the backbone. Defaults to False.
        freeze_decoder (bool, optional): Freeze the decoder. Defaults to False.
        transforms (callable, optional): Transforms applied to the imagery;
            overridden when use_augmentation is True. Defaults to None.
        use_augmentation (bool, optional): Apply the built-in albumentations
            augmentation pipeline. Defaults to False.
        seed (int, optional): Random seed for the train/val/test split.
            Defaults to 42.
        train_val_test_split (tuple, optional): Proportions for the split.
            Defaults to (0.6, 0.2, 0.2).
        accelerator (str, optional): Lightning accelerator ('cpu', 'gpu', ...).
            Defaults to "auto".
        devices (str, optional): Number of devices. Defaults to "auto".
        logger (object, optional): Lightning logger; a CSVLogger is created
            when None. Defaults to None.
        callbacks (list, optional): Trainer callbacks; a ModelCheckpoint is
            created when None. Defaults to None.
        log_every_n_steps (int, optional): Logging frequency. Defaults to 10.
        use_distributed_sampler (bool, optional): Use distributed sampling.
            Defaults to False.
        monitor_metric (str, optional): Metric monitored for checkpointing.
            Defaults to "val_loss".
        mode (str, optional): 'min' for losses, 'max' for accuracy-like
            metrics. Defaults to "min".
        save_top_k (int, optional): Number of best models to keep.
            Defaults to 1.
        save_last (bool, optional): Also save the last epoch's model.
            Defaults to True.
        checkpoint_filename (str, optional): Filename pattern for saved
            checkpoints. Defaults to "best_model".
        checkpoint_path (str, optional): Checkpoint to resume from; when None,
            "<output_dir>/models/last.ckpt" is tried. Defaults to None.
        every_n_epochs (int, optional): Save a checkpoint every N epochs.
            Defaults to 1.
        **kwargs: Additional keyword arguments forwarded to the datasets.

    Returns:
        object: Trained SemanticSegmentationTask model.
    """
    import multiprocessing as mp
    import timeit

    import albumentations as A
    import lightning.pytorch as pl
    import torch
    from lightning.pytorch.callbacks import ModelCheckpoint
    from lightning.pytorch.loggers import CSVLogger
    from torchgeo.datamodules import GeoDataModule
    from torchgeo.datasets import RasterDataset, stack_samples
    from torchgeo.datasets.splits import random_bbox_assignment
    from torchgeo.samplers import GridGeoSampler, RandomBatchGeoSampler
    from torchgeo.trainers import SemanticSegmentationTask

    class AlbumentationsWrapper:
        """Adapt an albumentations transform to TorchGeo's sample-dict format."""

        def __init__(self, transform):
            self.transform = transform

        def __call__(self, sample):
            # Pass through samples that lack either component.
            if "image" not in sample or "mask" not in sample:
                return sample

            image = sample["image"]
            mask = sample["mask"]

            # Albumentations expects channels-last; TorchGeo uses channels-first.
            image_np = image.permute(1, 2, 0).numpy()
            mask_np = mask.squeeze(0).numpy() if mask.dim() > 2 else mask.numpy()

            transformed = self.transform(image=image_np, mask=mask_np)

            # Convert back to channels-first tensors and keep other keys intact.
            result = sample.copy()
            result["image"] = torch.from_numpy(transformed["image"]).permute(2, 0, 1)
            result["mask"] = torch.from_numpy(transformed["mask"]).unsqueeze(0)
            return result

    # Built-in augmentation pipeline replaces any user-supplied transforms.
    if use_augmentation:
        aug_transforms = A.Compose(
            [
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.RandomRotate90(p=0.5),
                A.ShiftScaleRotate(
                    p=0.5, shift_limit=0.0625, scale_limit=0.1, rotate_limit=45
                ),
                A.RandomBrightnessContrast(
                    p=0.5, brightness_limit=0.2, contrast_limit=0.2
                ),
                A.GaussianBlur(p=0.3),
                A.GaussNoise(p=0.3),
                A.CoarseDropout(p=0.3, max_holes=8, max_height=32, max_width=32),
            ]
        )
        transforms = AlbumentationsWrapper(aug_transforms)

    workers = mp.cpu_count()

    # Raster dataset wrappers for imagery and labels.
    class ImageDatasetClass(RasterDataset):
        filename_glob = "*.tif"
        is_image = True
        separate_files = False

    class LabelDatasetClass(RasterDataset):
        filename_glob = "*.tif"
        is_image = False
        separate_files = False

    # Prepare output directory for checkpoints and logs.
    test_dir = os.path.join(output_dir, "models")
    os.makedirs(test_dir, exist_ok=True)

    if logger is None:
        logger = CSVLogger(test_dir, name="lightning_logs")

    if callbacks is None:
        callbacks = [
            ModelCheckpoint(
                dirpath=test_dir,
                filename=checkpoint_filename,
                save_top_k=save_top_k,
                monitor=monitor_metric,
                mode=mode,
                save_last=save_last,
                every_n_epochs=every_n_epochs,
                verbose=True,
            )
        ]

    # Initialize the segmentation task.
    task = SemanticSegmentationTask(
        model=model,
        backbone=backbone,
        weights=weights,
        in_channels=in_channels,
        num_classes=num_classes,
        num_filters=num_filters,
        loss=loss,
        class_weights=class_weights,
        ignore_index=ignore_index,
        lr=lr,
        patience=patience,
        freeze_backbone=freeze_backbone,
        freeze_decoder=freeze_decoder,
    )

    # NOTE(review): **kwargs is forwarded to both the Trainer and the datasets
    # below, so any kwarg valid for only one of them will raise. Kept as-is
    # for backward compatibility — confirm intended routing with callers.
    trainer = pl.Trainer(
        accelerator=accelerator,
        devices=devices,
        max_epochs=epochs,
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=log_every_n_steps,
        use_distributed_sampler=use_distributed_sampler,
        **kwargs,
    )

    # Accept either ready-made RasterDatasets or directory paths.
    if isinstance(image_root, RasterDataset):
        images = image_root
    else:
        images = ImageDatasetClass(paths=image_root, transforms=transforms, **kwargs)

    if isinstance(label_root, RasterDataset):
        labels = label_root
    else:
        labels = LabelDatasetClass(paths=label_root, **kwargs)

    # Intersection dataset pairs imagery with its labels.
    dataset = images & labels

    class CustomGeoDataModule(GeoDataModule):
        """GeoDataModule that splits the closed-over intersection dataset."""

        def setup(self, stage: str) -> None:
            """Set up datasets and samplers.

            Args:
                stage: Either 'fit', 'validate', 'test', or 'predict'.
            """
            self.dataset = self.dataset_class(**self.kwargs)

            # Deterministic train/val/test split by bounding box.
            generator = torch.Generator().manual_seed(seed)
            (
                self.train_dataset,
                self.val_dataset,
                self.test_dataset,
            ) = random_bbox_assignment(dataset, train_val_test_split, generator)

            if stage in ["fit"]:
                self.train_batch_sampler = RandomBatchGeoSampler(
                    self.train_dataset, self.patch_size, self.batch_size, self.length
                )
            if stage in ["fit", "validate"]:
                self.val_sampler = GridGeoSampler(
                    self.val_dataset, self.patch_size, self.patch_size
                )
            if stage in ["test"]:
                self.test_sampler = GridGeoSampler(
                    self.test_dataset, self.patch_size, self.patch_size
                )

    datamodule = CustomGeoDataModule(
        dataset_class=type(dataset),
        batch_size=batch_size,
        patch_size=img_size,
        length=sample_size,
        num_workers=workers,
        dataset1=images,
        dataset2=labels,
        collate_fn=stack_samples,
    )

    start = timeit.default_timer()

    # Resume from an explicit checkpoint, or from last.ckpt if present.
    if checkpoint_path is not None:
        checkpoint_file = os.path.abspath(checkpoint_path)
    else:
        checkpoint_file = os.path.join(test_dir, "last.ckpt")

    if os.path.isfile(checkpoint_file):
        print("Resuming training from previous checkpoint...")
        trainer.fit(model=task, datamodule=datamodule, ckpt_path=checkpoint_file)
    else:
        print("Starting training from scratch...")
        trainer.fit(model=task, datamodule=datamodule)

    training_time = timeit.default_timer() - start
    print(f"The time taken to train was: {training_time:.2f} seconds")

    # Fix: the original read `checkpoint_callback.best_model_path`, a local
    # only bound when callbacks was None, causing a NameError for callers who
    # supplied their own callbacks. Look up a ModelCheckpoint instead.
    checkpoint_cb = next(
        (cb for cb in callbacks if isinstance(cb, ModelCheckpoint)), None
    )
    if checkpoint_cb is not None:
        print(f"Best model saved at: {checkpoint_cb.best_model_path}")

    # Evaluate on the held-out test split.
    trainer.test(model=task, datamodule=datamodule)

    return task
341
+
342
+
343
def _classify_image(
    image_path,
    model_path,
    output_path=None,
    chip_size=1024,
    batch_size=4,
    colormap=None,
    num_workers=2,
    **kwargs,
):
    """
    Classify a geospatial image using a trained semantic segmentation model.
    This version has tile edge artifacts; see classify_image for the
    overlap-blended variant.

    Pipeline:
    1. Loads the image and model
    2. Cuts the image into tiles/chips
    3. Makes predictions on each chip
    4. Georeferences each prediction
    5. Merges all predictions into a single georeferenced output

    Parameters:
        image_path (str): Path to the input GeoTIFF image.
        model_path (str): Path to the trained model checkpoint.
        output_path (str, optional): Path to save the output classified image.
            Defaults to "<input_name>_classified.tif".
        chip_size (int, optional): Size of chips for processing. Defaults to 1024.
        batch_size (int, optional): Batch size for inference. Defaults to 4.
        colormap (dict, optional): Colormap to apply to the output image.
            Defaults to None.
        num_workers (int, optional): Number of workers for DataLoader. Defaults to 2.
        **kwargs: Additional keyword arguments for DataLoader.

    Returns:
        str: Path to the saved classified image.
    """
    import timeit

    import numpy as np
    import rasterio
    import torch
    from rasterio.io import MemoryFile
    from rasterio.merge import merge
    from rasterio.transform import from_origin
    from torch.utils.data import DataLoader
    from torchgeo.datasets import RasterDataset, stack_samples
    from torchgeo.samplers import GridGeoSampler
    from torchgeo.trainers import SemanticSegmentationTask
    from tqdm import tqdm

    # Derive a default output path next to the input image.
    if output_path is None:
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        output_path = f"{base_name}_classified.tif"

    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Fix: fall back to CPU when CUDA is unavailable (the original called
    # .cuda() unconditionally and crashed on CPU-only machines).
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Loading model from {model_path}...")
    task = SemanticSegmentationTask.load_from_checkpoint(model_path)
    task.model.eval()
    task.model.to(device)

    print(f"Loading image from {image_path}...")
    dataset = RasterDataset(paths=image_path)

    original_bounds = dataset.bounds
    # dataset.res may be a scalar or an (xres, yres) pair depending on the
    # torchgeo version; normalize to two values.
    try:
        xres, yres = dataset.res
    except TypeError:
        xres = yres = dataset.res
    crs = dataset.crs.to_epsg()

    # Non-overlapping grid of chip-sized tiles.
    sampler = GridGeoSampler(dataset, chip_size, chip_size)

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        collate_fn=stack_samples,
        num_workers=num_workers,
        **kwargs,
    )

    print(f"Processing image in {len(dataloader)} batches...")

    def create_in_memory_geochip(predicted_chip, geotransform, crs):
        """Write a prediction array to an in-memory georeferenced GeoTIFF."""
        photometric = "MINISBLACK"

        # Ensure (bands, height, width) layout.
        if len(predicted_chip.shape) == 2:
            predicted_chip = predicted_chip[np.newaxis, :, :]

        memfile = MemoryFile()
        chip_ds = memfile.open(
            driver="GTiff",
            height=predicted_chip.shape[1],
            width=predicted_chip.shape[2],
            count=predicted_chip.shape[0],
            dtype=np.uint8,
            crs=crs,
            transform=geotransform,
            photometric=photometric,
        )

        # rasterio band indices are 1-based.
        for band_idx in range(predicted_chip.shape[0]):
            chip_ds.write(predicted_chip[band_idx], band_idx + 1)

        return chip_ds

    def clip_to_original_bounds(tif_path, original_bounds, colormap=None):
        """Clip a GeoTIFF in place so it matches the original image bounds."""
        with rasterio.open(tif_path) as src:
            window = rasterio.windows.from_bounds(
                original_bounds.minx,
                original_bounds.miny,
                original_bounds.maxx,
                original_bounds.maxy,
                transform=src.transform,
            )

            data = src.read(window=window)
            transform = rasterio.windows.transform(window, src.transform)

            meta = src.meta.copy()
            meta.update(
                {
                    "height": window.height,
                    "width": window.width,
                    "transform": transform,
                    "compress": "deflate",
                }
            )

        # Fix: write after the read handle is closed — the original opened the
        # same file for writing while it was still open for reading.
        with rasterio.open(tif_path, "w", **meta) as dst:
            dst.write(data)
            if isinstance(colormap, dict):
                dst.write_colormap(1, colormap)

    start_time = timeit.default_timer()
    georref_chips_list = []

    progress_bar = tqdm(total=len(dataloader), desc="Processing tiles", unit="batch")

    for batch in dataloader:
        images = batch["image"]
        bounds_list = batch["bounds"]

        # Scale raw pixel values to [0, 1] — assumes 8-bit imagery.
        images = images / 255.0

        with torch.no_grad():
            predictions = task.model.predict(images.to(device))
            predictions = torch.softmax(predictions, dim=1)
            predictions = torch.argmax(predictions, dim=1)

        for i in range(len(predictions)):
            bounds = bounds_list[i]

            # Anchor the chip's geotransform at its upper-left corner.
            geotransform = from_origin(bounds.minx, bounds.maxy, xres, yres)

            pred = predictions[i].cpu().numpy().astype(np.uint8)
            if len(pred.shape) == 2:
                pred = pred[np.newaxis, :, :]

            georref_chips_list.append(
                create_in_memory_geochip(pred, geotransform, crs)
            )

        progress_bar.update(1)

    progress_bar.close()

    prediction_time = timeit.default_timer() - start_time
    print(f"Prediction complete in {prediction_time:.2f} seconds")
    print(f"Produced {len(georref_chips_list)} georeferenced chips")

    print("Merging predictions...")
    merge_start = timeit.default_timer()

    # Mosaic all chips into one array.
    merged, merged_transform = merge(georref_chips_list)
    rows, cols = merged.shape[1], merged.shape[2]

    merged_metadata = georref_chips_list[0].meta
    merged_metadata.update(
        {"height": rows, "width": cols, "transform": merged_transform}
    )

    with rasterio.open(output_path, "w", **merged_metadata) as dst:
        dst.write(merged)

    print("Clipping to original image bounds...")
    clip_to_original_bounds(output_path, original_bounds, colormap)

    # Release the in-memory chip datasets.
    for chip in tqdm(georref_chips_list, desc="Cleaning up", unit="chip"):
        chip.close()

    merge_time = timeit.default_timer() - merge_start
    total_time = timeit.default_timer() - start_time

    print(f"Merge and save complete in {merge_time:.2f} seconds")
    print(f"Total processing time: {total_time:.2f} seconds")
    print(f"Successfully saved classified image to {output_path}")

    return output_path
584
+
585
+
586
def classify_image(
    image_path,
    model_path,
    output_path=None,
    chip_size=1024,
    overlap=256,
    batch_size=4,
    colormap=None,
    **kwargs,
):
    """
    Classify a geospatial image using a trained semantic segmentation model.

    This function handles the full image classification pipeline with special
    attention to edge handling:
    1. Process the image in a grid pattern with overlapping tiles
    2. Use central regions of tiles for interior parts
    3. Special handling for edges to ensure complete coverage
    4. Merge results into a single georeferenced output

    Parameters:
        image_path (str): Path to the input GeoTIFF image.
        model_path (str): Path to the trained model checkpoint.
        output_path (str, optional): Path to save the output classified image.
            Defaults to "[input_name]_classified.tif".
        chip_size (int, optional): Size of chips for processing. Defaults to 1024.
        overlap (int, optional): Overlap size between adjacent tiles; must be
            smaller than chip_size. Defaults to 256.
        batch_size (int, optional): Batch size for inference. Defaults to 4.
        colormap (dict, optional): Colormap to apply to the output image.
            Defaults to None.
        **kwargs: Additional keyword arguments (unused; kept for API
            compatibility).

    Returns:
        str: Path to the saved classified image.

    Raises:
        ValueError: If overlap >= chip_size (the tiling stride would be <= 0).
    """
    import logging
    import timeit
    import warnings

    import rasterio
    import torch
    from rasterio.errors import NotGeoreferencedWarning
    from torchgeo.trainers import SemanticSegmentationTask

    # Silence noisy GDAL/rasterio warnings about non-georeferenced chips.
    warnings.filterwarnings("ignore", category=UserWarning, module="rasterio._.*")
    warnings.filterwarnings("ignore", category=UserWarning, module="rasterio")
    warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)
    logging.getLogger("rasterio").setLevel(logging.ERROR)

    # Fix: a non-positive stride would make the tiling loop below blow up.
    if overlap >= chip_size:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chip_size ({chip_size})"
        )

    # Derive a default output path next to the input image.
    if output_path is None:
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        output_path = f"{base_name}_classified.tif"

    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Fix: fall back to CPU when CUDA is unavailable (the original called
    # .cuda() unconditionally and crashed on CPU-only machines).
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Loading model from {model_path}...")
    task = SemanticSegmentationTask.load_from_checkpoint(model_path)
    task.model.eval()
    task.model.to(device)

    with rasterio.open(image_path) as src:
        height = src.height
        width = src.width
        profile = src.profile.copy()

        # Final class map plus a per-pixel confidence used to resolve overlaps.
        output_image = np.zeros((height, width), dtype=np.uint8)
        confidence_map = np.zeros((height, width), dtype=np.float32)

        effective_stride = chip_size - overlap

        # x positions: leftmost tile, the regular grid, then the rightmost
        # tile that still fits, so both edges are fully covered.
        x_positions = [0]
        x_positions.extend(
            range(effective_stride, width - chip_size, effective_stride)
        )
        if width > chip_size and x_positions[-1] + chip_size < width:
            x_positions.append(width - chip_size)

        # Same scheme vertically.
        y_positions = [0]
        y_positions.extend(
            range(effective_stride, height - chip_size, effective_stride)
        )
        if height > chip_size and y_positions[-1] + chip_size < height:
            y_positions.append(height - chip_size)

        # All tile windows, clamped to the image extent.
        tile_positions = [
            (y, x, min(y + chip_size, height), min(x + chip_size, width))
            for y in y_positions
            for x in x_positions
        ]

        print(
            f"Processing {len(tile_positions)} patches covering an image of size {height}x{width}..."
        )
        start_time = timeit.default_timer()

        for batch_start in range(0, len(tile_positions), batch_size):
            batch_positions = tile_positions[batch_start : batch_start + batch_size]
            batch_data = []

            for y_start, x_start, y_end, x_end in batch_positions:
                tile_data = src.read(window=((y_start, y_end), (x_start, x_end)))

                # Pad partial tiles on the right/bottom to the full chip size.
                if tile_data.shape[1] != chip_size or tile_data.shape[2] != chip_size:
                    padded = np.zeros(
                        (tile_data.shape[0], chip_size, chip_size),
                        dtype=tile_data.dtype,
                    )
                    padded[:, : tile_data.shape[1], : tile_data.shape[2]] = tile_data
                    tile_data = padded

                # Scale raw pixel values to [0, 1] — assumes 8-bit imagery.
                batch_data.append(torch.from_numpy(tile_data).float() / 255.0)

            batch_tensor = torch.stack(batch_data)

            with torch.no_grad():
                logits = task.model.predict(batch_tensor.to(device))
                probs = torch.softmax(logits, dim=1)
                confidence, predictions = torch.max(probs, dim=1)
                predictions = predictions.cpu().numpy()
                confidence = confidence.cpu().numpy()

            for idx, (y_start, x_start, y_end, x_end) in enumerate(batch_positions):
                actual_height = y_end - y_start
                actual_width = x_end - x_start

                # Drop any padding added above.
                valid_pred = predictions[idx][:actual_height, :actual_width]
                valid_conf = confidence[idx][:actual_height, :actual_width]

                # Tiles touching a true image edge keep full weight there.
                is_edge_x = (x_start == 0) or (x_end == width)
                is_edge_y = (y_start == 0) or (y_end == height)

                weight_mask = np.ones(
                    (actual_height, actual_width), dtype=np.float32
                )
                border = overlap // 2

                # Linear falloff toward tile borders (vectorized version of
                # the original per-column/per-row loops; vertical assignment
                # intentionally overwrites the corner columns, as before).
                if not is_edge_x and actual_width > 2 * border:
                    ramp = np.arange(1, border + 1, dtype=np.float32) / (border + 1)
                    weight_mask[:, :border] = ramp
                    weight_mask[:, actual_width - border :] = ramp[::-1]

                if not is_edge_y and actual_height > 2 * border:
                    ramp = np.arange(1, border + 1, dtype=np.float32) / (border + 1)
                    weight_mask[:border, :] = ramp[:, np.newaxis]
                    weight_mask[actual_height - border :, :] = ramp[::-1, np.newaxis]

                # Keep, per pixel, the prediction with the highest weighted
                # confidence seen so far.
                final_weight = weight_mask * valid_conf
                current_conf = confidence_map[y_start:y_end, x_start:x_end]
                update_mask = final_weight > current_conf

                if np.any(update_mask):
                    output_image[y_start:y_end, x_start:x_end][update_mask] = (
                        valid_pred[update_mask]
                    )
                    confidence_map[y_start:y_end, x_start:x_end][update_mask] = (
                        final_weight[update_mask]
                    )

    # Single-band uint8 class map output.
    profile.update({"count": 1, "dtype": "uint8", "nodata": 0})

    print(f"Saving classified image to {output_path}...")
    with rasterio.open(output_path, "w", **profile) as dst:
        dst.write(output_image[np.newaxis, :, :])
        if isinstance(colormap, dict):
            dst.write_colormap(1, colormap)

    total_time = timeit.default_timer() - start_time
    print(f"Total processing time: {total_time:.2f} seconds")
    print(f"Successfully saved classified image to {output_path}")

    return output_path
828
+
829
+
830
def classify_images(
    image_paths,
    model_path,
    output_dir=None,
    chip_size=1024,
    batch_size=4,
    colormap=None,
    file_extension=".tif",
    **kwargs,
):
    """
    Classify multiple geospatial images using a trained semantic segmentation model.

    This function accepts either a list of image paths or a directory containing
    images and applies the classify_image function to each image, saving the
    results in the specified output directory. Failures on individual images are
    reported but do not abort the batch.

    Parameters:
        image_paths (str or list): Either a directory path containing images or
            a list of paths to input GeoTIFF images.
        model_path (str): Path to the trained model checkpoint.
        output_dir (str, optional): Directory to save the output classified
            images. Defaults to None (same directory as input images for a
            list, or a new "classified" subdirectory for a directory input).
        chip_size (int, optional): Size of chips for processing. Defaults to 1024.
        batch_size (int, optional): Batch size for inference. Defaults to 4.
        colormap (dict, optional): Colormap to apply to the output images.
            Defaults to None.
        file_extension (str, optional): File extension to filter by when
            image_paths is a directory. Defaults to ".tif".
        **kwargs: Additional keyword arguments for the classify_image function.

    Returns:
        list: List of paths to the saved classified images.

    Raises:
        ValueError: If image_paths is neither a directory path nor a list.
    """
    import glob

    from tqdm import tqdm

    if isinstance(image_paths, str) and os.path.isdir(image_paths):
        # Directory input: classified outputs go to a sibling subdirectory.
        if output_dir is None:
            output_dir = os.path.join(image_paths, "classified")

        image_path_list = glob.glob(os.path.join(image_paths, f"*{file_extension}"))

        if not image_path_list:
            print(f"No files with extension '{file_extension}' found in {image_paths}")
            return []

        print(f"Found {len(image_path_list)} images in directory {image_paths}")

    elif isinstance(image_paths, list):
        image_path_list = image_paths

        # Fix: the original fell through with output_dir=None on an empty
        # list and crashed in os.path.exists(None); bail out early instead.
        if not image_path_list:
            print("No image paths provided.")
            return []

        if output_dir is None:
            # Fix: dirname() is "" for bare filenames, and os.makedirs("")
            # raises; fall back to the current directory.
            output_dir = os.path.dirname(image_path_list[0]) or "."

    else:
        raise ValueError(
            "image_paths must be either a directory path or a list of file paths"
        )

    os.makedirs(output_dir, exist_ok=True)

    classified_image_paths = []

    for image_path in tqdm(image_path_list, desc="Classifying images", unit="image"):
        try:
            base_filename = os.path.splitext(os.path.basename(image_path))[0]

            output_path = os.path.join(
                output_dir, f"{base_filename}_classified{file_extension}"
            )

            classified_image_path = classify_image(
                image_path,
                model_path,
                output_path=output_path,
                chip_size=chip_size,
                batch_size=batch_size,
                colormap=colormap,
                **kwargs,
            )
            classified_image_paths.append(classified_image_path)
        except Exception as e:
            # Best-effort batch processing: report and continue.
            print(f"Error processing {image_path}: {str(e)}")

    print(
        f"Classification complete. Processed {len(classified_image_paths)} images successfully."
    )
    return classified_image_paths