geoai-py 0.14.0__py2.py3-none-any.whl → 0.16.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/timm_train.py ADDED
@@ -0,0 +1,658 @@
1
+ """Module for training and fine-tuning models using timm (PyTorch Image Models) with remote sensing imagery."""
2
+
3
+ import os
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.utils.data import Dataset, DataLoader
10
+ from tqdm import tqdm
11
+
12
+ try:
13
+ import timm
14
+
15
+ TIMM_AVAILABLE = True
16
+ except ImportError:
17
+ TIMM_AVAILABLE = False
18
+
19
+ try:
20
+ import lightning.pytorch as pl
21
+ from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
22
+ from lightning.pytorch.loggers import CSVLogger
23
+
24
+ LIGHTNING_AVAILABLE = True
25
+ except ImportError:
26
+ LIGHTNING_AVAILABLE = False
27
+
28
+
29
def get_timm_model(
    model_name: str = "resnet50",
    num_classes: int = 10,
    in_channels: int = 3,
    pretrained: bool = True,
    features_only: bool = False,
    **kwargs: Any,
) -> nn.Module:
    """
    Create a timm model with custom input channels for remote sensing imagery.

    Args:
        model_name (str): Name of the timm model (e.g., 'resnet50', 'efficientnet_b0',
            'vit_base_patch16_224', 'convnext_base'). A pretrained tag suffix
            (e.g., 'resnet50.a1_in1k') is accepted. Names that carry a source
            prefix (anything containing ':', such as 'hf_hub:org/name') are not
            validated here and are handed to timm.create_model unchanged.
        num_classes (int): Number of output classes for classification.
        in_channels (int): Number of input channels (3 for RGB, 4 for RGBN, etc.).
        pretrained (bool): Whether to use pretrained weights.
        features_only (bool): If True, return feature extraction model without classifier.
        **kwargs: Additional arguments to pass to timm.create_model.

    Returns:
        nn.Module: Configured timm model.

    Raises:
        ImportError: If timm is not installed.
        ValueError: If model_name (without a source prefix) is not registered in timm.
    """
    if not TIMM_AVAILABLE:
        raise ImportError("timm is required. Install it with: pip install timm")

    # timm.is_model() is a registry lookup that understands "arch.tag" names;
    # the previous `model_name not in timm.list_models()` test built the full
    # model list on every call and rejected tagged names. Prefixed names are
    # resolved by timm.create_model itself, so they are not checked here.
    has_source_prefix = ":" in model_name
    if not has_source_prefix and not timm.is_model(model_name):
        available_models = timm.list_models(pretrained=True)[:10]
        raise ValueError(
            f"Model '{model_name}' not found in timm. "
            f"First 10 available models: {available_models}. "
            f"See all models at: https://github.com/huggingface/pytorch-image-models"
        )

    # A feature extractor has no classification head, so num_classes is
    # forced to 0 in that case.
    return timm.create_model(
        model_name,
        pretrained=pretrained,
        num_classes=0 if features_only else num_classes,
        in_chans=in_channels,
        features_only=features_only,
        **kwargs,
    )
79
+
80
+
81
def modify_first_conv_for_channels(
    model: nn.Module,
    in_channels: int,
    pretrained_channels: int = 3,
) -> nn.Module:
    """
    Modify the first convolutional layer of a model to accept different number of input channels.

    This is useful when you have a pretrained model with 3 input channels but want to use
    imagery with more channels (e.g., 4 for RGBN, or more for multispectral).

    The replacement layer keeps the original layer's output channels, kernel size,
    stride, padding, dilation, padding mode, bias setting, device and dtype.
    When ``in_channels`` is larger than ``pretrained_channels`` (and the existing
    layer really has ``pretrained_channels`` ungrouped inputs) the existing weights
    are copied and every extra channel is initialised with the mean of the existing
    channels. In every other case the weights are re-initialised with Kaiming normal
    initialisation. An existing bias is always copied.

    Args:
        model (nn.Module): PyTorch model to modify (modified in place).
        in_channels (int): Desired number of input channels.
        pretrained_channels (int): Number of channels in pretrained weights (usually 3).

    Returns:
        nn.Module: Modified model with updated first conv layer. If ``model`` is
        itself the convolution, the replacement layer is returned instead.

    Raises:
        ValueError: If ``in_channels`` is not positive or no Conv2d layer is found.
    """
    if in_channels == pretrained_channels:
        return model

    if in_channels < 1:
        raise ValueError("in_channels must be a positive integer")

    first_conv_name = None
    first_conv = None

    # Common attribute paths of the stem convolution in popular architectures.
    for candidate in ("conv1", "conv_stem", "patch_embed.proj", "stem.conv1"):
        module = model
        try:
            for part in candidate.split("."):
                module = getattr(module, part)
        except AttributeError:
            continue
        if isinstance(module, nn.Conv2d):
            first_conv_name, first_conv = candidate, module
            break

    if first_conv is None:
        # Fallback: first Conv2d in registration order (name is "" when the
        # model itself is the convolution).
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d):
                first_conv_name, first_conv = name, module
                break

    if first_conv is None:
        raise ValueError("Could not find first convolutional layer in model")

    old_weight = first_conv.weight

    # Mirror every geometric setting of the old layer and keep the new layer
    # on the same device/dtype so the model stays internally consistent.
    new_conv = nn.Conv2d(
        in_channels,
        first_conv.out_channels,
        kernel_size=first_conv.kernel_size,
        stride=first_conv.stride,
        padding=first_conv.padding,
        dilation=first_conv.dilation,
        bias=first_conv.bias is not None,
        padding_mode=first_conv.padding_mode,
        device=old_weight.device,
        dtype=old_weight.dtype,
    )

    # Existing weights can only be reused when the old layer is ungrouped and
    # actually has the advertised number of input channels.
    can_reuse_weights = (
        in_channels > pretrained_channels
        and first_conv.groups == 1
        and old_weight.shape[1] == pretrained_channels
    )

    with torch.no_grad():
        if can_reuse_weights:
            new_conv.weight[:, :pretrained_channels] = old_weight
            # Broadcast the per-filter channel mean over all extra channels.
            new_conv.weight[:, pretrained_channels:] = old_weight.mean(
                dim=1, keepdim=True
            )
        else:
            nn.init.kaiming_normal_(
                new_conv.weight, mode="fan_out", nonlinearity="relu"
            )

        if first_conv.bias is not None:
            # Copy values instead of sharing the Parameter with the old layer.
            new_conv.bias.copy_(first_conv.bias)

    if not first_conv_name:
        # The model is the convolution itself; there is no parent to patch.
        return new_conv

    parent_path, _, attr_name = first_conv_name.rpartition(".")
    parent = model
    if parent_path:
        for part in parent_path.split("."):
            parent = getattr(parent, part)
    setattr(parent, attr_name, new_conv)

    return model
174
+
175
+
176
class TimmClassifier(pl.LightningModule if LIGHTNING_AVAILABLE else nn.Module):
    """
    PyTorch Lightning module for image classification using timm models.

    The base class falls back to ``nn.Module`` when Lightning is not installed so
    that importing this module does not fail; instantiating the class still
    requires Lightning and raises ``ImportError`` otherwise.
    """

    def __init__(
        self,
        model_name: str = "resnet50",
        num_classes: int = 10,
        in_channels: int = 3,
        pretrained: bool = True,
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-4,
        freeze_backbone: bool = False,
        loss_fn: Optional[nn.Module] = None,
        class_weights: Optional[torch.Tensor] = None,
        **model_kwargs: Any,
    ):
        """
        Initialize TimmClassifier.

        Args:
            model_name (str): Name of timm model.
            num_classes (int): Number of output classes.
            in_channels (int): Number of input channels.
            pretrained (bool): Use pretrained weights.
            learning_rate (float): Learning rate for optimizer.
            weight_decay (float): Weight decay for optimizer.
            freeze_backbone (bool): Freeze backbone weights during training.
            loss_fn (nn.Module, optional): Custom loss function. Defaults to CrossEntropyLoss.
            class_weights (torch.Tensor, optional): Class weights for loss function.
                Ignored when ``loss_fn`` is given.
            **model_kwargs: Additional arguments for timm model.

        Raises:
            ImportError: If PyTorch Lightning or timm is not installed.
        """
        if not LIGHTNING_AVAILABLE:
            raise ImportError(
                "PyTorch Lightning is required. Install it with: pip install lightning"
            )

        if not TIMM_AVAILABLE:
            raise ImportError("timm is required. Install it with: pip install timm")

        super().__init__()

        self.save_hyperparameters()

        self.model = get_timm_model(
            model_name=model_name,
            num_classes=num_classes,
            in_channels=in_channels,
            pretrained=pretrained,
            **model_kwargs,
        )

        if freeze_backbone:
            self._freeze_backbone()

        # An explicit loss wins; otherwise cross-entropy, optionally weighted.
        if loss_fn is not None:
            self.loss_fn = loss_fn
        elif class_weights is not None:
            self.loss_fn = nn.CrossEntropyLoss(weight=class_weights)
        else:
            self.loss_fn = nn.CrossEntropyLoss()

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def _freeze_backbone(self):
        """Freeze all layers except the classifier head.

        A parameter is treated as part of the head when it belongs to the
        module(s) returned by the model's ``get_classifier()`` (if the model
        provides one) or when the first component of its dotted name is
        ``fc``, ``head`` or ``classifier``. Matching on the first component
        instead of a substring keeps backbone layers whose names merely
        contain those words (for example ``...mlp.fc1`` or ``conv_head``)
        frozen.
        """
        head_param_ids = set()
        get_classifier = getattr(self.model, "get_classifier", None)
        if callable(get_classifier):
            heads = get_classifier()
            if not isinstance(heads, (tuple, list)):
                heads = (heads,)
            for head in heads:
                if isinstance(head, nn.Module):
                    head_param_ids.update(id(p) for p in head.parameters())

        head_roots = {"fc", "head", "classifier"}
        for name, param in self.model.named_parameters():
            in_head = (
                id(param) in head_param_ids or name.split(".", 1)[0] in head_roots
            )
            if not in_head:
                param.requires_grad = False

    def forward(self, x):
        """Return the raw class logits produced by the wrapped timm model."""
        return self.model(x)

    def _loss_and_accuracy(self, batch):
        """Run one ``(inputs, targets)`` batch and return ``(loss, accuracy)``."""
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/accuracy; return the loss."""
        loss, acc = self._loss_and_accuracy(batch)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss/accuracy; return the loss."""
        loss, acc = self._loss_and_accuracy(batch)

        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        self.log("val_acc", acc, on_epoch=True, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss/accuracy; return the loss."""
        loss, acc = self._loss_and_accuracy(batch)

        self.log("test_loss", loss, on_epoch=True)
        self.log("test_acc", acc, on_epoch=True)

        return loss

    def configure_optimizers(self):
        """Return AdamW plus a ReduceLROnPlateau scheduler driven by ``val_loss``."""
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )

        # `verbose` is intentionally not passed: it is deprecated in recent
        # PyTorch releases and no longer accepted by newer ones.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.5, patience=5
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
                # val_loss only exists when a validation loader is used;
                # with strict=False a missing metric produces a warning
                # instead of stopping the run.
                "strict": False,
            },
        }

    def predict_step(self, batch, batch_idx):
        """Return predicted class indices and softmax probabilities for a batch.

        ``batch`` may be a bare tensor or a list/tuple whose first item is the
        input tensor.
        """
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        logits = self(x)
        probs = torch.softmax(logits, dim=1)
        preds = torch.argmax(probs, dim=1)
        return {"predictions": preds, "probabilities": probs}
312
+
313
+
314
class RemoteSensingDataset(Dataset):
    """
    Classification dataset backed by raster image files.

    Each item is read from disk with rasterio, adjusted to the requested number
    of channels, scaled, converted to a float32 tensor and paired with its
    integer label.
    """

    def __init__(
        self,
        image_paths: List[str],
        labels: List[int],
        transform: Optional[Callable] = None,
        num_channels: Optional[int] = None,
    ):
        """
        Initialize RemoteSensingDataset.

        Args:
            image_paths (List[str]): Paths of the image files.
            labels (List[int]): Integer label for each image, in the same order.
            transform (callable, optional): Callable applied to the image tensor.
            num_channels (int, optional): Channel count to enforce. If None, all
                channels of the file are kept.

        Raises:
            ValueError: If ``image_paths`` and ``labels`` differ in length.
        """
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.num_channels = num_channels

        if len(labels) != len(image_paths):
            raise ValueError("Number of images must match number of labels")

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(image_tensor, label_tensor)``."""
        # Imported lazily so rasterio is only needed once items are read.
        import rasterio

        with rasterio.open(self.image_paths[idx]) as src:
            pixels = src.read()  # (C, H, W)

        wanted = self.num_channels
        if wanted is not None:
            bands = pixels.shape[0]
            if bands > wanted:
                # Too many bands: keep the leading ones.
                pixels = pixels[:wanted]
            elif bands < wanted:
                # Too few bands: append zero-filled bands.
                canvas = np.zeros((wanted, pixels.shape[1], pixels.shape[2]))
                canvas[:bands] = pixels
                pixels = canvas

        # Values above 1 are divided by 255.
        # NOTE(review): this presumes 8-bit imagery; confirm for 16-bit rasters.
        if pixels.max() > 1.0:
            pixels = pixels / 255.0

        sample = torch.from_numpy(pixels.astype(np.float32))
        target = torch.tensor(self.labels[idx], dtype=torch.long)

        if self.transform is not None:
            sample = self.transform(sample)

        return sample, target
386
+
387
+
388
def train_timm_classifier(
    train_dataset: Dataset,
    val_dataset: Optional[Dataset] = None,
    test_dataset: Optional[Dataset] = None,
    model_name: str = "resnet50",
    num_classes: int = 10,
    in_channels: int = 3,
    pretrained: bool = True,
    output_dir: str = "output",
    batch_size: int = 32,
    num_epochs: int = 50,
    learning_rate: float = 1e-3,
    weight_decay: float = 1e-4,
    num_workers: int = 4,
    freeze_backbone: bool = False,
    class_weights: Optional[List[float]] = None,
    accelerator: str = "auto",
    devices: Union[str, int, List[int]] = "auto",
    monitor_metric: str = "val_loss",
    mode: str = "min",
    patience: int = 10,
    save_top_k: int = 1,
    checkpoint_path: Optional[str] = None,
    **kwargs: Any,
) -> TimmClassifier:
    """
    Train a timm-based classifier on remote sensing imagery.

    Checkpoints and CSV logs are written below ``<output_dir>/models``.

    Args:
        train_dataset (Dataset): Training dataset.
        val_dataset (Dataset, optional): Validation dataset. When omitted and
            ``monitor_metric`` is a validation metric (name starts with 'val'),
            early stopping is disabled and checkpoints are saved without a
            monitored metric, because that metric is never logged.
        test_dataset (Dataset, optional): Test dataset.
        model_name (str): Name of timm model to use.
        num_classes (int): Number of output classes.
        in_channels (int): Number of input channels.
        pretrained (bool): Use pretrained weights.
        output_dir (str): Directory to save outputs.
        batch_size (int): Batch size for training.
        num_epochs (int): Number of training epochs.
        learning_rate (float): Learning rate.
        weight_decay (float): Weight decay for optimizer.
        num_workers (int): Number of data loading workers.
        freeze_backbone (bool): Freeze backbone during training.
        class_weights (List[float], optional): Class weights for loss.
        accelerator (str): Accelerator type ('auto', 'gpu', 'cpu').
        devices (str | int | List[int]): Devices to use.
        monitor_metric (str): Metric to monitor for checkpointing.
        mode (str): 'min' or 'max' for monitor_metric.
        patience (int): Early stopping patience.
        save_top_k (int): Number of best models to save.
        checkpoint_path (str, optional): Path to checkpoint to resume from.
        **kwargs: Additional arguments for PyTorch Lightning Trainer. ``callbacks``
            are appended to the built-in ones; ``logger`` and
            ``log_every_n_steps`` replace the defaults.

    Returns:
        TimmClassifier: Trained model (the in-memory model after the last epoch).

    Raises:
        ImportError: If PyTorch Lightning is not installed.
    """
    if not LIGHTNING_AVAILABLE:
        raise ImportError(
            "PyTorch Lightning is required. Install it with: pip install lightning"
        )

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    model_dir = os.path.join(output_dir, "models")
    os.makedirs(model_dir, exist_ok=True)

    # Convert class weights to tensor if provided
    weight_tensor = None
    if class_weights is not None:
        weight_tensor = torch.tensor(class_weights, dtype=torch.float32)

    model = TimmClassifier(
        model_name=model_name,
        num_classes=num_classes,
        in_channels=in_channels,
        pretrained=pretrained,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        freeze_backbone=freeze_backbone,
        class_weights=weight_tensor,
    )

    # Pinned host memory only helps host-to-GPU copies; requesting it on a
    # machine without CUDA just triggers a warning.
    pin_memory = torch.cuda.is_available()

    def _make_loader(dataset: Dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader with the shared batch/worker settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )

    train_loader = _make_loader(train_dataset, shuffle=True)
    val_loader = None
    if val_dataset is not None:
        val_loader = _make_loader(val_dataset, shuffle=False)

    # Validation metrics are only logged when a validation loader exists;
    # callbacks monitoring a missing metric would abort the run.
    can_monitor = val_dataset is not None or not monitor_metric.startswith("val")

    # Keep path separators / scheme markers out of the checkpoint file name.
    safe_name = model_name.replace("/", "_").replace(":", "_")

    if can_monitor:
        checkpoint_callback = ModelCheckpoint(
            dirpath=model_dir,
            # The file name reports the metric that is actually monitored.
            filename=f"{safe_name}_{{epoch:02d}}_{{{monitor_metric}:.4f}}",
            monitor=monitor_metric,
            mode=mode,
            save_top_k=save_top_k,
            save_last=True,
            verbose=True,
        )
        callbacks = [
            checkpoint_callback,
            EarlyStopping(
                monitor=monitor_metric,
                patience=patience,
                mode=mode,
                verbose=True,
            ),
        ]
    else:
        print(
            f"No validation dataset provided; '{monitor_metric}' cannot be "
            "monitored. Early stopping is disabled."
        )
        checkpoint_callback = ModelCheckpoint(
            dirpath=model_dir,
            filename=f"{safe_name}_{{epoch:02d}}",
            monitor=None,
            # Without a monitored metric only -1, 0 or 1 are valid here.
            save_top_k=save_top_k if save_top_k in (-1, 0, 1) else 1,
            save_last=True,
            verbose=True,
        )
        callbacks = [checkpoint_callback]

    # Merge user-supplied Trainer options instead of passing them next to the
    # same explicit keywords (which would raise TypeError).
    trainer_kwargs = dict(kwargs)
    extra_callbacks = trainer_kwargs.pop("callbacks", None)
    if extra_callbacks is not None:
        if isinstance(extra_callbacks, (list, tuple)):
            callbacks.extend(extra_callbacks)
        else:
            callbacks.append(extra_callbacks)
    if "logger" not in trainer_kwargs:
        trainer_kwargs["logger"] = CSVLogger(model_dir, name="lightning_logs")
    trainer_kwargs.setdefault("log_every_n_steps", 10)

    trainer = pl.Trainer(
        max_epochs=num_epochs,
        accelerator=accelerator,
        devices=devices,
        callbacks=callbacks,
        **trainer_kwargs,
    )

    # Train model
    print(f"Training {model_name} for {num_epochs} epochs...")
    trainer.fit(
        model,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
        ckpt_path=checkpoint_path,
    )

    # Test if test dataset provided
    if test_dataset is not None:
        test_loader = _make_loader(test_dataset, shuffle=False)
        print("\nTesting model on test set...")
        trainer.test(model, dataloaders=test_loader)

    print(f"\nBest model saved at: {checkpoint_callback.best_model_path}")

    return model
555
+
556
+
557
def predict_with_timm(
    model: Union[TimmClassifier, nn.Module],
    image_paths: List[str],
    batch_size: int = 32,
    num_workers: int = 4,
    device: Optional[str] = None,
    return_probabilities: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """
    Make predictions on images using a trained timm model.

    The model is switched to eval mode and moved to ``device`` as a side effect.

    Args:
        model: Trained model (TimmClassifier or nn.Module) that maps an image
            batch to class logits.
        image_paths: List of paths to images.
        batch_size: Batch size for inference.
        num_workers: Number of data loading workers.
        device: Device to use ('cuda', 'cpu', etc.). Auto-detected if None.
        return_probabilities: If True, return both predictions and probabilities.

    Returns:
        predictions: Array of predicted class indices.
        probabilities (optional): Array of class probabilities if return_probabilities=True.
        For an empty ``image_paths`` the arrays are empty (shapes ``(0,)`` and
        ``(0, 0)``).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Nothing to predict: np.concatenate would raise on an empty list of
    # batches, so return empty results directly.
    if len(image_paths) == 0:
        empty_predictions = np.empty((0,), dtype=np.int64)
        if return_probabilities:
            return empty_predictions, np.empty((0, 0), dtype=np.float32)
        return empty_predictions

    # The dataset requires labels; placeholders are used and ignored below.
    dummy_labels = [0] * len(image_paths)
    dataset = RemoteSensingDataset(image_paths, dummy_labels)

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        # Pinned memory is only useful for host-to-GPU transfers.
        pin_memory=str(device).startswith("cuda"),
    )

    model.eval()
    model = model.to(device)

    all_preds = []
    all_probs = []

    with torch.no_grad():
        for images, _ in tqdm(loader, desc="Making predictions"):
            images = images.to(device)

            # Both supported model types are called the same way.
            logits = model(images)

            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)

            all_preds.append(preds.cpu().numpy())
            if return_probabilities:
                all_probs.append(probs.cpu().numpy())

    predictions = np.concatenate(all_preds)

    if return_probabilities:
        probabilities = np.concatenate(all_probs)
        return predictions, probabilities

    return predictions
624
+
625
+
626
def list_timm_models(
    filter: str = "",
    pretrained: bool = False,
    limit: Optional[int] = None,
) -> List[str]:
    """
    List available timm models.

    Args:
        filter (str): Filter models by name pattern (e.g., 'resnet', 'efficientnet').
            The filter supports shell-style wildcards ('*', '?', '[...]'). If no
            wildcards are provided, '*' is added on both sides automatically so the
            text matches anywhere in the model name.
        pretrained (bool): Only show models with pretrained weights.
        limit (int, optional): Maximum number of models to return.

    Returns:
        List of model names.

    Raises:
        ImportError: If timm is not installed.
    """
    if not TIMM_AVAILABLE:
        raise ImportError("timm is required. Install it with: pip install timm")

    # Only plain text is turned into a substring match; a pattern that already
    # uses any wildcard character is passed through untouched.
    pattern = filter
    if pattern and not any(char in pattern for char in "*?["):
        pattern = f"*{pattern}*"

    models = timm.list_models(filter=pattern, pretrained=pretrained)

    if limit is not None:
        models = models[:limit]

    return models