geoai_py-0.14.0-py2.py3-none-any.whl → geoai_py-0.16.0-py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +21 -1
- geoai/change_detection.py +16 -6
- geoai/geoai.py +3 -0
- geoai/timm_segment.py +1097 -0
- geoai/timm_train.py +658 -0
- geoai/train.py +796 -107
- geoai/utils.py +1427 -245
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/METADATA +9 -1
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/RECORD +13 -11
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/licenses/LICENSE +1 -2
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/WHEEL +0 -0
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/top_level.txt +0 -0
geoai/timm_train.py
ADDED
@@ -0,0 +1,658 @@
```python
"""Module for training and fine-tuning models using timm (PyTorch Image Models) with remote sensing imagery."""

import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

try:
    import timm

    TIMM_AVAILABLE = True
except ImportError:
    TIMM_AVAILABLE = False

try:
    import lightning.pytorch as pl
    from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
    from lightning.pytorch.loggers import CSVLogger

    LIGHTNING_AVAILABLE = True
except ImportError:
    LIGHTNING_AVAILABLE = False


def get_timm_model(
    model_name: str = "resnet50",
    num_classes: int = 10,
    in_channels: int = 3,
    pretrained: bool = True,
    features_only: bool = False,
    **kwargs: Any,
) -> nn.Module:
    """
    Create a timm model with custom input channels for remote sensing imagery.

    Args:
        model_name (str): Name of the timm model (e.g., 'resnet50', 'efficientnet_b0',
            'vit_base_patch16_224', 'convnext_base').
        num_classes (int): Number of output classes for classification.
        in_channels (int): Number of input channels (3 for RGB, 4 for RGBN, etc.).
        pretrained (bool): Whether to use pretrained weights.
        features_only (bool): If True, return feature extraction model without classifier.
        **kwargs: Additional arguments to pass to timm.create_model.

    Returns:
        nn.Module: Configured timm model.

    Raises:
        ImportError: If timm is not installed.
        ValueError: If model_name is not available in timm.
    """
    if not TIMM_AVAILABLE:
        raise ImportError("timm is required. Install it with: pip install timm")

    # Check if model exists
    if model_name not in timm.list_models():
        available_models = timm.list_models(pretrained=True)[:10]
        raise ValueError(
            f"Model '{model_name}' not found in timm. "
            f"First 10 available models: {available_models}. "
            f"See all models at: https://github.com/huggingface/pytorch-image-models"
        )

    # Create base model
    model = timm.create_model(
        model_name,
        pretrained=pretrained,
        num_classes=num_classes if not features_only else 0,
        in_chans=in_channels,
        features_only=features_only,
        **kwargs,
    )

    return model
```
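For reference, a minimal usage sketch for `get_timm_model` (illustrative, not part of the packaged file), assuming geoai 0.16.0, timm, and PyTorch are installed:

```python
import torch

from geoai.timm_train import get_timm_model

# Build a classifier for 4-band (e.g. RGB + NIR) chips; set pretrained=True to pull
# ImageNet weights (requires a download), which timm adapts to the 4-channel stem.
model = get_timm_model(
    model_name="efficientnet_b0",
    num_classes=5,
    in_channels=4,
    pretrained=False,
)

dummy = torch.randn(2, 4, 224, 224)  # two synthetic 4-band tiles
print(model(dummy).shape)  # torch.Size([2, 5])
```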
```python
def modify_first_conv_for_channels(
    model: nn.Module,
    in_channels: int,
    pretrained_channels: int = 3,
) -> nn.Module:
    """
    Modify the first convolutional layer of a model to accept different number of input channels.

    This is useful when you have a pretrained model with 3 input channels but want to use
    imagery with more channels (e.g., 4 for RGBN, or more for multispectral).

    Args:
        model (nn.Module): PyTorch model to modify.
        in_channels (int): Desired number of input channels.
        pretrained_channels (int): Number of channels in pretrained weights (usually 3).

    Returns:
        nn.Module: Modified model with updated first conv layer.
    """
    if in_channels == pretrained_channels:
        return model

    # Find the first conv layer (different models have different architectures)
    first_conv_name = None
    first_conv = None

    # Common patterns for first conv layers
    possible_names = ["conv1", "conv_stem", "patch_embed.proj", "stem.conv1"]

    for name in possible_names:
        try:
            parts = name.split(".")
            module = model
            for part in parts:
                module = getattr(module, part)
            if isinstance(module, nn.Conv2d):
                first_conv_name = name
                first_conv = module
                break
        except AttributeError:
            continue

    if first_conv is None:
        # Fallback: search recursively
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d):
                first_conv_name = name
                first_conv = module
                break

    if first_conv is None:
        raise ValueError("Could not find first convolutional layer in model")

    # Create new conv layer with desired input channels
    new_conv = nn.Conv2d(
        in_channels,
        first_conv.out_channels,
        kernel_size=first_conv.kernel_size,
        stride=first_conv.stride,
        padding=first_conv.padding,
        bias=first_conv.bias is not None,
    )

    # Initialize weights
    with torch.no_grad():
        if pretrained_channels == 3 and in_channels > 3:
            # Copy RGB weights
            new_conv.weight[:, :3, :, :] = first_conv.weight

            # Initialize additional channels with mean of RGB weights
            mean_weight = first_conv.weight.mean(dim=1, keepdim=True)
            for i in range(3, in_channels):
                new_conv.weight[:, i : i + 1, :, :] = mean_weight
        else:
            # Generic initialization
            nn.init.kaiming_normal_(
                new_conv.weight, mode="fan_out", nonlinearity="relu"
            )

        if first_conv.bias is not None:
            new_conv.bias = first_conv.bias

    # Replace the first conv layer
    parts = first_conv_name.split(".")
    if len(parts) == 1:
        setattr(model, first_conv_name, new_conv)
    else:
        parent = model
        for part in parts[:-1]:
            parent = getattr(parent, part)
        setattr(parent, parts[-1], new_conv)

    return model
```
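A short sketch of how `modify_first_conv_for_channels` might be used to widen a standard 3-channel backbone (illustrative; the model name and channel count are arbitrary). With pretrained weights, the RGB filters are copied and the extra planes are seeded with their per-filter mean; `pretrained=False` here simply keeps the sketch offline:

```python
import timm
import torch

from geoai.timm_train import modify_first_conv_for_channels

# Start from a standard 3-channel backbone, then widen its stem to 6 bands.
backbone = timm.create_model("resnet18", pretrained=False, num_classes=3)
backbone = modify_first_conv_for_channels(backbone, in_channels=6)

x = torch.randn(1, 6, 224, 224)
print(backbone(x).shape)  # torch.Size([1, 3])
```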
```python
class TimmClassifier(pl.LightningModule):
    """
    PyTorch Lightning module for image classification using timm models.
    """

    def __init__(
        self,
        model_name: str = "resnet50",
        num_classes: int = 10,
        in_channels: int = 3,
        pretrained: bool = True,
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-4,
        freeze_backbone: bool = False,
        loss_fn: Optional[nn.Module] = None,
        class_weights: Optional[torch.Tensor] = None,
        **model_kwargs: Any,
    ):
        """
        Initialize TimmClassifier.

        Args:
            model_name (str): Name of timm model.
            num_classes (int): Number of output classes.
            in_channels (int): Number of input channels.
            pretrained (bool): Use pretrained weights.
            learning_rate (float): Learning rate for optimizer.
            weight_decay (float): Weight decay for optimizer.
            freeze_backbone (bool): Freeze backbone weights during training.
            loss_fn (nn.Module, optional): Custom loss function. Defaults to CrossEntropyLoss.
            class_weights (torch.Tensor, optional): Class weights for loss function.
            **model_kwargs: Additional arguments for timm model.
        """
        super().__init__()

        if not TIMM_AVAILABLE:
            raise ImportError("timm is required. Install it with: pip install timm")

        self.save_hyperparameters()

        self.model = get_timm_model(
            model_name=model_name,
            num_classes=num_classes,
            in_channels=in_channels,
            pretrained=pretrained,
            **model_kwargs,
        )

        if freeze_backbone:
            self._freeze_backbone()

        # Set up loss function
        if loss_fn is not None:
            self.loss_fn = loss_fn
        elif class_weights is not None:
            self.loss_fn = nn.CrossEntropyLoss(weight=class_weights)
        else:
            self.loss_fn = nn.CrossEntropyLoss()

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def _freeze_backbone(self):
        """Freeze all layers except the classifier head."""
        for name, param in self.model.named_parameters():
            if "fc" not in name and "head" not in name and "classifier" not in name:
                param.requires_grad = False

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)

        # Calculate accuracy
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)

        # Calculate accuracy
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        self.log("val_acc", acc, on_epoch=True, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)

        # Calculate accuracy
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        self.log("test_loss", loss, on_epoch=True)
        self.log("test_acc", acc, on_epoch=True)

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.5, patience=5, verbose=True
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

    def predict_step(self, batch, batch_idx):
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        logits = self(x)
        probs = torch.softmax(logits, dim=1)
        preds = torch.argmax(probs, dim=1)
        return {"predictions": preds, "probabilities": probs}
```
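A hedged sketch exercising `TimmClassifier` outside a Lightning `Trainer`, just to show the tensors it expects (the hyperparameters are arbitrary, not from the package):

```python
import torch

from geoai.timm_train import TimmClassifier

# pretrained=False avoids a weight download in this sketch, and
# freeze_backbone=True leaves only the classifier head trainable.
clf = TimmClassifier(
    model_name="resnet18",
    num_classes=4,
    in_channels=3,
    pretrained=False,
    learning_rate=1e-4,
    freeze_backbone=True,
)

images = torch.randn(8, 3, 224, 224)
labels = torch.randint(0, 4, (8,))
logits = clf(images)  # forward pass through the wrapped timm backbone
print(logits.shape, float(clf.loss_fn(logits, labels)))
```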
```python
class RemoteSensingDataset(Dataset):
    """
    Dataset for remote sensing imagery classification.

    This dataset handles loading raster images and their corresponding labels
    for training classification models.
    """

    def __init__(
        self,
        image_paths: List[str],
        labels: List[int],
        transform: Optional[Callable] = None,
        num_channels: Optional[int] = None,
    ):
        """
        Initialize RemoteSensingDataset.

        Args:
            image_paths (List[str]): List of paths to image files.
            labels (List[int]): List of integer labels corresponding to images.
            transform (callable, optional): Transform to apply to images.
            num_channels (int, optional): Number of channels to use. If None, uses all.
        """
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.num_channels = num_channels

        if len(image_paths) != len(labels):
            raise ValueError("Number of images must match number of labels")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        import rasterio

        # Load image
        with rasterio.open(self.image_paths[idx]) as src:
            image = src.read()  # Shape: (C, H, W)

        # Handle channel selection
        if self.num_channels is not None and image.shape[0] != self.num_channels:
            if image.shape[0] > self.num_channels:
                image = image[: self.num_channels]
            else:
                # Pad with zeros if needed
                padded = np.zeros(
                    (self.num_channels, image.shape[1], image.shape[2])
                )
                padded[: image.shape[0]] = image
                image = padded

        # Normalize to [0, 1]
        if image.max() > 1.0:
            image = image / 255.0

        image = image.astype(np.float32)

        # Get label
        label = self.labels[idx]

        # Convert to tensor
        image = torch.from_numpy(image)
        label = torch.tensor(label, dtype=torch.long)

        # Apply transforms if provided
        if self.transform is not None:
            image = self.transform(image)

        return image, label
```
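A sketch of building a `RemoteSensingDataset` from chips on disk; the `chips/*.tif` directory and the all-zero label list are placeholders for real training chips and their class indices:

```python
from glob import glob

from torch.utils.data import DataLoader

from geoai.timm_train import RemoteSensingDataset

# Each chip is read with rasterio as a (C, H, W) array and normalized to [0, 1].
image_paths = sorted(glob("chips/*.tif"))
labels = [0] * len(image_paths)  # replace with real class indices
dataset = RemoteSensingDataset(image_paths, labels, num_channels=4)

loader = DataLoader(dataset, batch_size=16, shuffle=True)
for images, targets in loader:
    print(images.shape, targets.shape)  # e.g. torch.Size([16, 4, 256, 256]), torch.Size([16])
    break
```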
```python
def train_timm_classifier(
    train_dataset: Dataset,
    val_dataset: Optional[Dataset] = None,
    test_dataset: Optional[Dataset] = None,
    model_name: str = "resnet50",
    num_classes: int = 10,
    in_channels: int = 3,
    pretrained: bool = True,
    output_dir: str = "output",
    batch_size: int = 32,
    num_epochs: int = 50,
    learning_rate: float = 1e-3,
    weight_decay: float = 1e-4,
    num_workers: int = 4,
    freeze_backbone: bool = False,
    class_weights: Optional[List[float]] = None,
    accelerator: str = "auto",
    devices: str = "auto",
    monitor_metric: str = "val_loss",
    mode: str = "min",
    patience: int = 10,
    save_top_k: int = 1,
    checkpoint_path: Optional[str] = None,
    **kwargs: Any,
) -> TimmClassifier:
    """
    Train a timm-based classifier on remote sensing imagery.

    Args:
        train_dataset (Dataset): Training dataset.
        val_dataset (Dataset, optional): Validation dataset.
        test_dataset (Dataset, optional): Test dataset.
        model_name (str): Name of timm model to use.
        num_classes (int): Number of output classes.
        in_channels (int): Number of input channels.
        pretrained (bool): Use pretrained weights.
        output_dir (str): Directory to save outputs.
        batch_size (int): Batch size for training.
        num_epochs (int): Number of training epochs.
        learning_rate (float): Learning rate.
        weight_decay (float): Weight decay for optimizer.
        num_workers (int): Number of data loading workers.
        freeze_backbone (bool): Freeze backbone during training.
        class_weights (List[float], optional): Class weights for loss.
        accelerator (str): Accelerator type ('auto', 'gpu', 'cpu').
        devices (str): Devices to use.
        monitor_metric (str): Metric to monitor for checkpointing.
        mode (str): 'min' or 'max' for monitor_metric.
        patience (int): Early stopping patience.
        save_top_k (int): Number of best models to save.
        checkpoint_path (str, optional): Path to checkpoint to resume from.
        **kwargs: Additional arguments for PyTorch Lightning Trainer.

    Returns:
        TimmClassifier: Trained model.

    Raises:
        ImportError: If PyTorch Lightning is not installed.
    """
    if not LIGHTNING_AVAILABLE:
        raise ImportError(
            "PyTorch Lightning is required. Install it with: pip install lightning"
        )

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    model_dir = os.path.join(output_dir, "models")
    os.makedirs(model_dir, exist_ok=True)

    # Convert class weights to tensor if provided
    weight_tensor = None
    if class_weights is not None:
        weight_tensor = torch.tensor(class_weights, dtype=torch.float32)

    # Create model
    model = TimmClassifier(
        model_name=model_name,
        num_classes=num_classes,
        in_channels=in_channels,
        pretrained=pretrained,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        freeze_backbone=freeze_backbone,
        class_weights=weight_tensor,
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    val_loader = None
    if val_dataset is not None:
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
        )

    # Set up callbacks
    callbacks = []

    # Model checkpoint
    checkpoint_callback = ModelCheckpoint(
        dirpath=model_dir,
        filename=f"{model_name}_{{epoch:02d}}_{{val_loss:.4f}}",
        monitor=monitor_metric,
        mode=mode,
        save_top_k=save_top_k,
        save_last=True,
        verbose=True,
    )
    callbacks.append(checkpoint_callback)

    # Early stopping
    early_stop_callback = EarlyStopping(
        monitor=monitor_metric,
        patience=patience,
        mode=mode,
        verbose=True,
    )
    callbacks.append(early_stop_callback)

    # Set up logger
    logger = CSVLogger(model_dir, name="lightning_logs")

    # Create trainer
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        accelerator=accelerator,
        devices=devices,
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=10,
        **kwargs,
    )

    # Train model
    print(f"Training {model_name} for {num_epochs} epochs...")
    trainer.fit(
        model,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
        ckpt_path=checkpoint_path,
    )

    # Test if test dataset provided
    if test_dataset is not None:
        test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
        )
        print("\nTesting model on test set...")
        trainer.test(model, dataloaders=test_loader)

    print(f"\nBest model saved at: {checkpoint_callback.best_model_path}")

    return model
```
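A possible end-to-end training call (illustrative; the `chips/train`/`chips/val` layout, labels, and hyperparameters are assumptions, not from the package):

```python
from glob import glob

from geoai.timm_train import RemoteSensingDataset, train_timm_classifier

# Hypothetical chip directories with per-chip class indices.
train_paths = sorted(glob("chips/train/*.tif"))
val_paths = sorted(glob("chips/val/*.tif"))
train_labels = [0] * len(train_paths)  # replace with real labels
val_labels = [0] * len(val_paths)

train_ds = RemoteSensingDataset(train_paths, train_labels, num_channels=4)
val_ds = RemoteSensingDataset(val_paths, val_labels, num_channels=4)

model = train_timm_classifier(
    train_dataset=train_ds,
    val_dataset=val_ds,
    model_name="resnet50",
    num_classes=6,
    in_channels=4,
    batch_size=32,
    num_epochs=20,
    output_dir="output",  # checkpoints and CSV logs are written under output/models/
)
```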
```python
def predict_with_timm(
    model: Union[TimmClassifier, nn.Module],
    image_paths: List[str],
    batch_size: int = 32,
    num_workers: int = 4,
    device: Optional[str] = None,
    return_probabilities: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """
    Make predictions on images using a trained timm model.

    Args:
        model: Trained model (TimmClassifier or nn.Module).
        image_paths: List of paths to images.
        batch_size: Batch size for inference.
        num_workers: Number of data loading workers.
        device: Device to use ('cuda', 'cpu', etc.). Auto-detected if None.
        return_probabilities: If True, return both predictions and probabilities.

    Returns:
        predictions: Array of predicted class indices.
        probabilities (optional): Array of class probabilities if return_probabilities=True.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Create dummy labels for dataset
    dummy_labels = [0] * len(image_paths)
    dataset = RemoteSensingDataset(image_paths, dummy_labels)

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    model.eval()
    model = model.to(device)

    all_preds = []
    all_probs = []

    with torch.no_grad():
        for images, _ in tqdm(loader, desc="Making predictions"):
            images = images.to(device)

            if isinstance(model, TimmClassifier):
                logits = model(images)
            else:
                logits = model(images)

            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)

            all_preds.append(preds.cpu().numpy())
            if return_probabilities:
                all_probs.append(probs.cpu().numpy())

    predictions = np.concatenate(all_preds)

    if return_probabilities:
        probabilities = np.concatenate(all_probs)
        return predictions, probabilities

    return predictions
```
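A sketch of inference with `predict_with_timm`, assuming a checkpoint written by `train_timm_classifier` above (`last.ckpt` comes from `save_last=True`); the checkpoint path and chip directory are placeholders:

```python
from glob import glob

from geoai.timm_train import TimmClassifier, predict_with_timm

# Reload the most recent checkpoint and classify unlabeled chips.
model = TimmClassifier.load_from_checkpoint("output/models/last.ckpt")
new_chips = sorted(glob("chips/test/*.tif"))

preds, probs = predict_with_timm(
    model,
    new_chips,
    batch_size=64,
    return_probabilities=True,
)
print(preds[:10], probs.shape)  # class indices and an (N, num_classes) array
```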
```python
def list_timm_models(
    filter: str = "",
    pretrained: bool = False,
    limit: Optional[int] = None,
) -> List[str]:
    """
    List available timm models.

    Args:
        filter (str): Filter models by name pattern (e.g., 'resnet', 'efficientnet').
            The filter supports wildcards. If no wildcards are provided, '*' is added automatically.
        pretrained (bool): Only show models with pretrained weights.
        limit (int, optional): Maximum number of models to return.

    Returns:
        List of model names.

    Raises:
        ImportError: If timm is not installed.
    """
    if not TIMM_AVAILABLE:
        raise ImportError("timm is required. Install it with: pip install timm")

    # Add wildcards if not present in filter
    if filter and "*" not in filter:
        filter = f"*{filter}*"

    models = timm.list_models(filter=filter, pretrained=pretrained)

    if limit is not None:
        models = models[:limit]

    return models
```
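And a quick sketch of `list_timm_models` for browsing candidate backbones before training (the filter string is just an example; the wildcard wrapping is added automatically):

```python
from geoai.timm_train import list_timm_models

# List up to ten ConvNeXt variants that ship pretrained weights.
print(list_timm_models(filter="convnext", pretrained=True, limit=10))
```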