geoai-py 0.14.0__py2.py3-none-any.whl → 0.16.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +21 -1
- geoai/change_detection.py +16 -6
- geoai/geoai.py +3 -0
- geoai/timm_segment.py +1097 -0
- geoai/timm_train.py +658 -0
- geoai/train.py +796 -107
- geoai/utils.py +1427 -245
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/METADATA +9 -1
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/RECORD +13 -11
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/licenses/LICENSE +1 -2
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/WHEEL +0 -0
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/top_level.txt +0 -0
geoai/timm_segment.py
ADDED
@@ -0,0 +1,1097 @@
|
|
1
|
+
"""Module for training semantic segmentation models using timm encoders with PyTorch Lightning."""
|
2
|
+
|
3
|
+
import os
|
4
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
import torch
|
8
|
+
import torch.nn as nn
|
9
|
+
from torch.utils.data import Dataset, DataLoader
|
10
|
+
from tqdm import tqdm
|
11
|
+
|
12
|
+
try:
|
13
|
+
import timm
|
14
|
+
|
15
|
+
TIMM_AVAILABLE = True
|
16
|
+
except ImportError:
|
17
|
+
TIMM_AVAILABLE = False
|
18
|
+
|
19
|
+
try:
|
20
|
+
import segmentation_models_pytorch as smp
|
21
|
+
|
22
|
+
SMP_AVAILABLE = True
|
23
|
+
except ImportError:
|
24
|
+
SMP_AVAILABLE = False
|
25
|
+
|
26
|
+
try:
|
27
|
+
import lightning.pytorch as pl
|
28
|
+
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
|
29
|
+
from lightning.pytorch.loggers import CSVLogger
|
30
|
+
|
31
|
+
LIGHTNING_AVAILABLE = True
|
32
|
+
except ImportError:
|
33
|
+
LIGHTNING_AVAILABLE = False
|
34
|
+
|
35
|
+
|
36
|
+
class TimmSegmentationModel(pl.LightningModule):
    """
    PyTorch Lightning module for semantic segmentation using timm encoders with SMP decoders,
    or pure timm models from Hugging Face Hub.

    With ``use_timm_model=False`` (the default) the network is built by
    ``smp.create_model`` from ``architecture`` and ``encoder_name``. With
    ``use_timm_model=True`` the whole network is built by ``timm.create_model``
    from ``timm_model_name`` (falling back to ``encoder_name``).

    Training, validation and test batches are ``(images, masks)`` pairs; each
    step logs the loss and the mean IoU of the arg-max prediction.
    """

    # Architectures listed in the error message when SMP model creation fails.
    _SMP_ARCHITECTURES = (
        "unet",
        "unetplusplus",
        "manet",
        "linknet",
        "fpn",
        "pspnet",
        "deeplabv3",
        "deeplabv3plus",
        "pan",
        "upernet",
    )

    def __init__(
        self,
        encoder_name: str = "resnet50",
        architecture: str = "unet",
        num_classes: int = 2,
        in_channels: int = 3,
        encoder_weights: Optional[str] = "imagenet",
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-4,
        freeze_encoder: bool = False,
        loss_fn: Optional[nn.Module] = None,
        class_weights: Optional[torch.Tensor] = None,
        use_timm_model: bool = False,
        timm_model_name: Optional[str] = None,
        **decoder_kwargs: Any,
    ):
        """
        Initialize TimmSegmentationModel.

        Args:
            encoder_name (str): Name of encoder (e.g., 'resnet50', 'efficientnet_b0').
            architecture (str): Segmentation architecture ('unet', 'unetplusplus', 'deeplabv3',
                'deeplabv3plus', 'fpn', 'pspnet', 'linknet', 'manet', 'pan', 'upernet').
                Ignored if use_timm_model=True.
            num_classes (int): Number of output classes.
            in_channels (int): Number of input channels.
            encoder_weights (str, optional): Pretrained weights for encoder
                ('imagenet', 'ssl', 'swsl', None). For pure timm models any truthy
                value means ``pretrained=True``.
            learning_rate (float): Learning rate for optimizer.
            weight_decay (float): Weight decay for optimizer.
            freeze_encoder (bool): Freeze encoder weights during training.
            loss_fn (nn.Module, optional): Custom loss function. Defaults to CrossEntropyLoss.
            class_weights (torch.Tensor, optional): Class weights for the default
                CrossEntropyLoss. Ignored when ``loss_fn`` is given.
            use_timm_model (bool): If True, load a complete segmentation model from timm/HF Hub
                instead of using SMP architecture. Defaults to False.
            timm_model_name (str, optional): Name or path of timm model from HF Hub
                (e.g., 'hf-hub:timm/segformer_b0.ade_512x512' or 'nvidia/mit-b0').
                Only used if use_timm_model=True.
            **decoder_kwargs: Additional arguments for decoder (only used with SMP).

        Raises:
            ImportError: If timm (or, for the SMP path, segmentation-models-pytorch)
                is not installed.
            ValueError: If the requested model cannot be created.
        """
        super().__init__()

        if not TIMM_AVAILABLE:
            raise ImportError("timm is required. Install it with: pip install timm")

        # Stores the constructor arguments in ``self.hparams``; ``num_classes`` and
        # ``use_timm_model`` are read back from there by later methods.
        self.save_hyperparameters()

        if use_timm_model:
            # Pure timm model (supports HF Hub with the 'hf-hub:' prefix).
            # NOTE(review): the training steps assume the model returns per-pixel
            # logits of shape (B, num_classes, H, W); confirm for the chosen model.
            if timm_model_name is None:
                timm_model_name = encoder_name
            try:
                self.model = timm.create_model(
                    timm_model_name,
                    pretrained=bool(encoder_weights),
                    num_classes=num_classes,
                    in_chans=in_channels,
                )
            except Exception as e:
                raise ValueError(
                    f"Failed to load timm model '{timm_model_name}'. "
                    f"Error: {str(e)}. "
                    "For HF Hub models, use format 'hf-hub:username/model-name' or 'hf_hub:username/model-name'."
                ) from e
            print(f"Loaded timm model: {timm_model_name}")
        else:
            # SMP architecture wrapped around a timm-compatible encoder.
            if not SMP_AVAILABLE:
                raise ImportError(
                    "segmentation-models-pytorch is required. "
                    "Install it with: pip install segmentation-models-pytorch"
                )
            try:
                self.model = smp.create_model(
                    arch=architecture,
                    encoder_name=encoder_name,
                    encoder_weights=encoder_weights,
                    in_channels=in_channels,
                    classes=num_classes,
                    **decoder_kwargs,
                )
            except Exception as e:
                raise ValueError(
                    f"Failed to create model with architecture '{architecture}' and encoder '{encoder_name}'. "
                    f"Error: {str(e)}. "
                    f"Available architectures include: {', '.join(self._SMP_ARCHITECTURES)}. "
                    "Please check the segmentation-models-pytorch documentation for supported combinations."
                ) from e

        if freeze_encoder:
            self._freeze_encoder()

        # Loss precedence: explicit loss_fn > weighted CE > plain CE.
        if loss_fn is not None:
            self.loss_fn = loss_fn
        elif class_weights is not None:
            self.loss_fn = nn.CrossEntropyLoss(weight=class_weights)
        else:
            self.loss_fn = nn.CrossEntropyLoss()

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def _freeze_encoder(self):
        """Disable gradients for ``self.model.encoder``.

        Models without an ``encoder`` attribute are left trainable. For pure
        timm models a warning is emitted; for SMP models this is an error.

        Raises:
            ValueError: If an SMP model has no ``encoder`` attribute.
        """
        if hasattr(self.model, "encoder"):
            for param in self.model.encoder.parameters():
                param.requires_grad = False
        elif self.hparams.use_timm_model:
            import warnings

            warnings.warn(
                "freeze_encoder=True has no effect: the timm model has no "
                "'encoder' attribute, so all parameters remain trainable."
            )
        else:
            raise ValueError("Model does not have an encoder attribute to freeze")

    def forward(self, x):
        """Return the raw logits of the wrapped model for input batch ``x``."""
        return self.model(x)

    def _shared_step(self, batch):
        """Run the model on an ``(x, y)`` batch and return ``(loss, mean_iou)``."""
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        pred = torch.argmax(logits, dim=1)
        iou = self._compute_iou(pred, y)
        return loss, iou

    def training_step(self, batch, batch_idx):
        loss, iou = self._shared_step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_iou", iou, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, iou = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        self.log("val_iou", iou, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss, iou = self._shared_step(batch)
        self.log("test_loss", loss, on_epoch=True)
        self.log("test_iou", iou, on_epoch=True)
        return loss

    def _compute_iou(self, pred, target, smooth=1e-6):
        """Compute mean IoU across classes.

        Classes absent from both ``pred`` and ``target`` are skipped so they do
        not inflate the mean; returns 0.0 if every class was skipped.
        """
        num_classes = self.hparams.num_classes
        ious = []

        for cls in range(num_classes):
            pred_cls = pred == cls
            target_cls = target == cls

            intersection = (pred_cls & target_cls).float().sum()
            union = (pred_cls | target_cls).float().sum()

            if union == 0:
                continue

            ious.append((intersection + smooth) / (union + smooth))

        return (
            torch.stack(ious).mean() if ious else torch.tensor(0.0, device=pred.device)
        )

    def configure_optimizers(self):
        """AdamW plus a ReduceLROnPlateau scheduler driven by ``val_loss``."""
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )

        # ``verbose`` is intentionally not passed: it is deprecated/removed in
        # recent PyTorch releases.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.5, patience=5
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

    def predict_step(self, batch, batch_idx):
        """Return class predictions and softmax probabilities for a batch.

        ``batch`` may be a bare tensor or a list/tuple whose first item is the
        image tensor.
        """
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        logits = self(x)
        probs = torch.softmax(logits, dim=1)
        preds = torch.argmax(probs, dim=1)
        return {"predictions": preds, "probabilities": probs}
|
261
|
+
|
262
|
+
|
263
|
+
class SegmentationDataset(Dataset):
    """
    Dataset for semantic segmentation with remote sensing imagery.

    Each item is read with rasterio: the image (all bands) becomes a float32
    tensor and the mask (first band only) becomes an int64 tensor.
    """

    def __init__(
        self,
        image_paths: List[str],
        mask_paths: List[str],
        transform: Optional[Callable] = None,
        num_channels: Optional[int] = None,
    ):
        """
        Initialize SegmentationDataset.

        Args:
            image_paths (List[str]): List of paths to image files.
            mask_paths (List[str]): List of paths to mask files.
            transform (callable, optional): Called as ``transform(image, mask)`` on the
                tensors and must return the transformed ``(image, mask)`` pair.
            num_channels (int, optional): Number of channels to use. Extra bands are
                dropped and missing bands are zero-padded. If None, uses all.

        Raises:
            ValueError: If the number of images and masks differ.
        """
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.num_channels = num_channels

        if len(image_paths) != len(mask_paths):
            raise ValueError("Number of images must match number of masks")

    def __len__(self):
        return len(self.image_paths)

    def _prepare_image(self, image: np.ndarray) -> np.ndarray:
        """Select/pad channels of a ``(C, H, W)`` array and scale it to float32.

        Integer rasters are always divided by 255, so every tile of a dataset is
        scaled the same way regardless of its content. Float rasters are divided
        by 255 only when their maximum exceeds 1.0.
        """
        # Capture the source dtype before padding converts the array to float.
        is_integer = np.issubdtype(image.dtype, np.integer)

        if self.num_channels is not None and image.shape[0] != self.num_channels:
            if image.shape[0] > self.num_channels:
                image = image[: self.num_channels]
            else:
                # Pad missing channels with zeros.
                padded = np.zeros(
                    (self.num_channels, image.shape[1], image.shape[2])
                )
                padded[: image.shape[0]] = image
                image = padded

        if is_integer or image.max() > 1.0:
            image = image / 255.0

        return image.astype(np.float32)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` tensors for sample ``idx``."""
        import rasterio

        with rasterio.open(self.image_paths[idx]) as src:
            image = src.read()  # (C, H, W)
        image = self._prepare_image(image)

        with rasterio.open(self.mask_paths[idx]) as src:
            mask = src.read(1)  # (H, W)
        mask = mask.astype(np.int64)

        image = torch.from_numpy(image)
        mask = torch.from_numpy(mask)

        if self.transform is not None:
            image, mask = self.transform(image, mask)

        return image, mask
|
334
|
+
|
335
|
+
|
336
|
+
def _history_from_metrics_csv(metrics_csv: str) -> Dict[str, list]:
    """Build an epoch-level training history from a CSVLogger ``metrics.csv`` file.

    Returns a dict with ``train_loss``, ``val_loss``, ``val_iou`` and ``epochs``
    lists; entries stay empty when the file or the column is missing.
    """
    import pandas as pd

    history = {
        "train_loss": [],
        "val_loss": [],
        "val_iou": [],
        "epochs": [],
    }
    if not os.path.exists(metrics_csv):
        return history

    df = pd.read_csv(metrics_csv)
    if "epoch" not in df.columns:
        return history

    # Several rows are logged per epoch; keep the last non-null value per column.
    epoch_data = df.groupby("epoch").last().reset_index()

    column_for_key = {
        "train_loss": "train_loss_epoch",
        "val_loss": "val_loss",
        "val_iou": "val_iou",
        "epochs": "epoch",
    }
    for key, column in column_for_key.items():
        if column in epoch_data.columns:
            history[key] = epoch_data[column].dropna().tolist()
    return history


def train_timm_segmentation(
    train_dataset: Dataset,
    val_dataset: Optional[Dataset] = None,
    test_dataset: Optional[Dataset] = None,
    encoder_name: str = "resnet50",
    architecture: str = "unet",
    num_classes: int = 2,
    in_channels: int = 3,
    encoder_weights: str = "imagenet",
    output_dir: str = "output",
    batch_size: int = 8,
    num_epochs: int = 50,
    learning_rate: float = 1e-3,
    weight_decay: float = 1e-4,
    num_workers: int = 4,
    freeze_encoder: bool = False,
    class_weights: Optional[List[float]] = None,
    accelerator: str = "auto",
    devices: str = "auto",
    monitor_metric: str = "val_loss",
    mode: str = "min",
    patience: int = 10,
    save_top_k: int = 1,
    checkpoint_path: Optional[str] = None,
    use_timm_model: bool = False,
    timm_model_name: Optional[str] = None,
    **kwargs: Any,
) -> TimmSegmentationModel:
    """
    Train a semantic segmentation model using timm encoder.

    Checkpoints, CSV logs and ``training_history.pth`` are written under
    ``<output_dir>/models``.

    Args:
        train_dataset (Dataset): Training dataset.
        val_dataset (Dataset, optional): Validation dataset. NOTE: the checkpoint and
            early-stopping callbacks monitor ``monitor_metric`` and the model's LR
            scheduler monitors ``val_loss``; these are only logged during validation.
        test_dataset (Dataset, optional): Test dataset, evaluated after training.
        encoder_name (str): Name of timm encoder.
        architecture (str): Segmentation architecture.
        num_classes (int): Number of output classes.
        in_channels (int): Number of input channels.
        encoder_weights (str): Pretrained weights for encoder.
        output_dir (str): Directory to save outputs.
        batch_size (int): Batch size for training.
        num_epochs (int): Number of training epochs.
        learning_rate (float): Learning rate.
        weight_decay (float): Weight decay for optimizer.
        num_workers (int): Number of data loading workers.
        freeze_encoder (bool): Freeze encoder during training.
        class_weights (List[float], optional): Class weights for loss.
        accelerator (str): Accelerator type ('auto', 'gpu', 'cpu').
        devices (str): Devices to use.
        monitor_metric (str): Metric to monitor for checkpointing.
        mode (str): 'min' or 'max' for monitor_metric.
        patience (int): Early stopping patience.
        save_top_k (int): Number of best models to save.
        checkpoint_path (str, optional): Path to checkpoint to resume from.
        use_timm_model (bool): Load complete segmentation model from timm/HF Hub.
        timm_model_name (str, optional): Model name from HF Hub (e.g., 'hf-hub:nvidia/mit-b0').
        **kwargs: Additional arguments for PyTorch Lightning Trainer.

    Returns:
        TimmSegmentationModel: The in-memory model after the final training epoch.
        The best checkpoint is saved on disk (its path is printed).
    """
    if not LIGHTNING_AVAILABLE:
        raise ImportError(
            "PyTorch Lightning is required. Install it with: pip install lightning"
        )

    os.makedirs(output_dir, exist_ok=True)
    model_dir = os.path.join(output_dir, "models")
    os.makedirs(model_dir, exist_ok=True)

    weight_tensor = None
    if class_weights is not None:
        weight_tensor = torch.tensor(class_weights, dtype=torch.float32)

    model = TimmSegmentationModel(
        encoder_name=encoder_name,
        architecture=architecture,
        num_classes=num_classes,
        in_channels=in_channels,
        encoder_weights=encoder_weights,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        freeze_encoder=freeze_encoder,
        class_weights=weight_tensor,
        use_timm_model=use_timm_model,
        timm_model_name=timm_model_name,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    val_loader = None
    if val_dataset is not None:
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
        )

    checkpoint_callback = ModelCheckpoint(
        dirpath=model_dir,
        filename=f"{encoder_name}_{architecture}_{{epoch:02d}}_{{val_loss:.4f}}",
        monitor=monitor_metric,
        mode=mode,
        save_top_k=save_top_k,
        save_last=True,
        verbose=True,
    )
    early_stop_callback = EarlyStopping(
        monitor=monitor_metric,
        patience=patience,
        mode=mode,
        verbose=True,
    )
    callbacks = [checkpoint_callback, early_stop_callback]

    logger = CSVLogger(model_dir, name="lightning_logs")

    trainer = pl.Trainer(
        max_epochs=num_epochs,
        accelerator=accelerator,
        devices=devices,
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=10,
        **kwargs,
    )

    print(f"Training {encoder_name} {architecture} for {num_epochs} epochs...")
    trainer.fit(
        model,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
        ckpt_path=checkpoint_path,
    )

    if test_dataset is not None:
        test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
        )
        print("\nTesting model on test set...")
        trainer.test(model, dataloaders=test_loader)

    print(f"\nBest model saved at: {checkpoint_callback.best_model_path}")

    # Read only THIS run's metrics file: CSVLogger creates a new version_N
    # directory per run, so globbing lightning_logs could pick up a stale run.
    history = _history_from_metrics_csv(os.path.join(logger.log_dir, "metrics.csv"))

    history_path = os.path.join(model_dir, "training_history.pth")
    torch.save(history, history_path)
    print(f"Training history saved to: {history_path}")

    return model
|
545
|
+
|
546
|
+
|
547
|
+
def predict_segmentation(
    model: Union[TimmSegmentationModel, nn.Module],
    image_paths: List[str],
    batch_size: int = 8,
    num_workers: int = 4,
    device: Optional[str] = None,
    return_probabilities: bool = False,
    num_channels: Optional[int] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """
    Make predictions on images using a trained segmentation model.

    Args:
        model: Trained model; called directly on image batches and expected to
            return logits with the class axis at dim 1.
        image_paths: List of paths to images.
        batch_size: Batch size for inference.
        num_workers: Number of data loading workers.
        device: Device to use ('cuda', 'cpu', etc.). Auto-detected if None.
        return_probabilities: If True, return both predictions and probabilities.
        num_channels: Number of channels to feed the model (extra bands dropped,
            missing bands zero-padded). If None, all bands are used.

    Returns:
        predictions: Array of predicted segmentation masks.
        probabilities (optional): Array of class probabilities if return_probabilities=True.

    Raises:
        ValueError: If ``image_paths`` is empty.
    """
    if len(image_paths) == 0:
        raise ValueError("image_paths must contain at least one image")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # SegmentationDataset requires a mask per image; reuse the image paths as
    # placeholders and discard the resulting "masks" below.
    dataset = SegmentationDataset(image_paths, image_paths, num_channels=num_channels)

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    model.eval()
    model = model.to(device)

    all_preds = []
    all_probs = []

    with torch.no_grad():
        for images, _ in tqdm(loader, desc="Making predictions"):
            images = images.to(device)
            logits = model(images)

            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)

            all_preds.append(preds.cpu().numpy())
            if return_probabilities:
                all_probs.append(probs.cpu().numpy())

    predictions = np.concatenate(all_preds)

    if return_probabilities:
        probabilities = np.concatenate(all_probs)
        return predictions, probabilities

    return predictions
|
614
|
+
|
615
|
+
|
616
|
+
def train_timm_segmentation_model(
    images_dir: str,
    labels_dir: str,
    output_dir: str,
    input_format: str = "directory",
    encoder_name: str = "resnet50",
    architecture: str = "unet",
    encoder_weights: str = "imagenet",
    num_channels: int = 3,
    num_classes: int = 2,
    batch_size: int = 8,
    num_epochs: int = 50,
    learning_rate: float = 0.001,
    weight_decay: float = 1e-4,
    val_split: float = 0.2,
    seed: int = 42,
    num_workers: int = 4,
    freeze_encoder: bool = False,
    monitor_metric: str = "val_iou",
    mode: str = "max",
    patience: int = 10,
    save_top_k: int = 1,
    verbose: bool = True,
    device: Optional[str] = None,
    use_timm_model: bool = False,
    timm_model_name: Optional[str] = None,
    **kwargs: Any,
) -> torch.nn.Module:
    """
    Train a semantic segmentation model using timm encoder (simplified interface).

    This is a simplified function that takes image and label directories and handles
    the dataset creation automatically, similar to train_segmentation_model.

    Args:
        images_dir (str): Directory containing image GeoTIFF files (for 'directory' format),
            or root directory containing images/ subdirectory (for 'yolo' format),
            or directory containing images (for 'coco' format).
        labels_dir (str): Directory containing label GeoTIFF files (for 'directory' format),
            or path to COCO annotations JSON file (for 'coco' format),
            or not used (for 'yolo' format - labels are in images_dir/labels/).
        output_dir (str): Directory to save model checkpoints and results.
        input_format (str): Input data format - 'directory' (default), 'coco', or 'yolo'
            (case-insensitive). Any other value is treated as 'directory'.
            - 'directory': Standard directory structure with separate images_dir and labels_dir;
              images and labels are paired by sorted filename order.
            - 'coco': COCO JSON format (labels_dir should be path to instances.json)
            - 'yolo': YOLO format (images_dir is root with images/ and labels/ subdirectories)
        encoder_name (str): Name of timm encoder (e.g., 'resnet50', 'efficientnet_b3').
        architecture (str): Segmentation architecture ('unet', 'unetplusplus', 'deeplabv3',
            'deeplabv3plus', 'fpn', 'pspnet', 'linknet', 'manet', 'pan').
        encoder_weights (str): Pretrained weights ('imagenet', 'ssl', 'swsl', None).
        num_channels (int): Number of input channels.
        num_classes (int): Number of output classes.
        batch_size (int): Batch size for training.
        num_epochs (int): Number of training epochs.
        learning_rate (float): Learning rate.
        weight_decay (float): Weight decay for optimizer.
        val_split (float): Validation split ratio (0-1).
        seed (int): Random seed for reproducibility (torch, numpy and the split).
        num_workers (int): Number of data loading workers.
        freeze_encoder (bool): Freeze encoder during training.
        monitor_metric (str): Metric to monitor ('val_loss' or 'val_iou').
        mode (str): 'min' for loss, 'max' for metrics.
        patience (int): Early stopping patience.
        save_top_k (int): Number of best models to save.
        verbose (bool): Print training progress.
        device (str, optional): Forwarded to Lightning as the accelerator name
            ('auto' when None).
        use_timm_model (bool): Load complete segmentation model from timm/HF Hub.
        timm_model_name (str, optional): Model name from HF Hub (e.g., 'hf-hub:nvidia/mit-b0').
        **kwargs: Additional arguments forwarded to train_timm_segmentation.

    Returns:
        torch.nn.Module: The underlying trained network (not the Lightning wrapper).

    Raises:
        ImportError: If PyTorch Lightning is not installed.
        ValueError: If no images/labels are found or their counts differ.
    """
    import glob

    if not LIGHTNING_AVAILABLE:
        raise ImportError(
            "PyTorch Lightning is required. Install it with: pip install lightning"
        )

    torch.manual_seed(seed)
    np.random.seed(seed)

    fmt = input_format.lower()
    if fmt == "coco":
        from .train import parse_coco_annotations

        if verbose:
            print(f"Loading COCO format annotations from {labels_dir}")
        # labels_dir is the path to instances.json; label rasters are expected in
        # a "labels" directory two levels above it.
        coco_root = os.path.dirname(os.path.dirname(labels_dir))
        labels_directory = os.path.join(coco_root, "labels")
        image_paths, label_paths = parse_coco_annotations(
            labels_dir, images_dir, labels_directory
        )
    elif fmt == "yolo":
        from .train import parse_yolo_annotations

        if verbose:
            print(f"Loading YOLO format data from {images_dir}")
        image_paths, label_paths = parse_yolo_annotations(images_dir)
    else:
        # Directory format: pair images and labels by sorted filename order.
        image_paths = sorted(
            glob.glob(os.path.join(images_dir, "*.tif"))
            + glob.glob(os.path.join(images_dir, "*.tiff"))
        )
        label_paths = sorted(
            glob.glob(os.path.join(labels_dir, "*.tif"))
            + glob.glob(os.path.join(labels_dir, "*.tiff"))
        )

    if len(image_paths) == 0:
        raise ValueError(
            f"No images found (input_format='{input_format}', images_dir='{images_dir}')"
        )
    if len(label_paths) == 0:
        raise ValueError(
            f"No labels found (input_format='{input_format}', labels_dir='{labels_dir}')"
        )
    if len(image_paths) != len(label_paths):
        raise ValueError(
            f"Number of images ({len(image_paths)}) doesn't match "
            f"number of labels ({len(label_paths)})"
        )

    if verbose:
        print(f"Found {len(image_paths)} image-label pairs")

    from sklearn.model_selection import train_test_split

    train_images, val_images, train_labels, val_labels = train_test_split(
        image_paths, label_paths, test_size=val_split, random_state=seed
    )

    if verbose:
        print(f"Training samples: {len(train_images)}")
        print(f"Validation samples: {len(val_images)}")

    train_dataset = SegmentationDataset(
        image_paths=train_images,
        mask_paths=train_labels,
        num_channels=num_channels,
    )
    val_dataset = SegmentationDataset(
        image_paths=val_images,
        mask_paths=val_labels,
        num_channels=num_channels,
    )

    model = train_timm_segmentation(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        test_dataset=None,
        encoder_name=encoder_name,
        architecture=architecture,
        num_classes=num_classes,
        in_channels=num_channels,
        encoder_weights=encoder_weights,
        output_dir=output_dir,
        batch_size=batch_size,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        num_workers=num_workers,
        freeze_encoder=freeze_encoder,
        accelerator="auto" if device is None else device,
        monitor_metric=monitor_metric,
        mode=mode,
        patience=patience,
        save_top_k=save_top_k,
        use_timm_model=use_timm_model,
        timm_model_name=timm_model_name,
        **kwargs,
    )

    if verbose:
        print(f"\nTraining completed. Model saved to {output_dir}")

    return model.model  # Return the underlying model
|
796
|
+
|
797
|
+
|
798
|
+
def _timm_window_starts(length: int, window_size: int, overlap: int) -> List[int]:
    """Return the start offsets of sliding windows along one raster axis.

    Windows advance by ``window_size - overlap`` pixels. The window count is
    ``ceil((length - overlap) / stride)``, clamped to at least one, which
    guarantees that the last window reaches the end of the axis. Windows that
    run past the end are clipped by the caller and zero-padded before inference.

    Args:
        length (int): Size of the axis in pixels.
        window_size (int): Size of each window in pixels.
        overlap (int): Overlap between adjacent windows in pixels
            (``0 <= overlap < window_size``).

    Returns:
        List[int]: Window start offsets in increasing order.
    """
    stride = window_size - overlap
    # Clamp to one window: for axes no longer than `overlap` the bare formula
    # yields 0 (or a negative count) and the raster would never be processed.
    n_windows = max(1, int(np.ceil((length - overlap) / stride)))
    return [i * stride for i in range(n_windows)]


def _timm_prepare_tile(
    img: np.ndarray, num_channels: int, window_size: int
) -> np.ndarray:
    """Convert one raster window into a model-ready float32 tile.

    Args:
        img (np.ndarray): Window data of shape (bands, h, w), as returned by
            ``rasterio`` ``DatasetReader.read``.
        num_channels (int): Number of channels the model expects. Extra bands
            are dropped; missing bands are filled with zeros.
        window_size (int): Spatial size of the model input. Smaller windows
            are zero-padded at the bottom/right.

    Returns:
        np.ndarray: float32 array of shape (num_channels, window_size, window_size).
    """
    if img.shape[0] > num_channels:
        img = img[:num_channels]
    elif img.shape[0] < num_channels:
        padded = np.zeros((num_channels, img.shape[1], img.shape[2]))
        padded[: img.shape[0]] = img
        img = padded

    # NOTE(review): the decision to scale by 1/255 is made per window, so a
    # window whose values are all <= 1 is left unscaled. This preserves the
    # previous behaviour; confirm it matches the normalisation used in training.
    if img.max() > 1.0:
        img = img / 255.0
    img = img.astype(np.float32)

    h, w = img.shape[1], img.shape[2]
    if h < window_size or w < window_size:
        padded = np.zeros((num_channels, window_size, window_size), dtype=np.float32)
        padded[:, :h, :w] = img
        img = padded
    return img


def timm_semantic_segmentation(
    input_path: str,
    output_path: str,
    model_path: str,
    encoder_name: str = "resnet50",
    architecture: str = "unet",
    num_channels: int = 3,
    num_classes: int = 2,
    window_size: int = 512,
    overlap: int = 256,
    batch_size: int = 4,
    device: Optional[str] = None,
    quiet: bool = False,
    use_timm_model: bool = False,
    timm_model_name: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """
    Perform semantic segmentation on a raster using a trained timm model.

    This function performs inference on a GeoTIFF using a sliding window approach
    and saves the result as a georeferenced single-band uint8 raster of class
    indices. Where windows overlap, their per-class softmax probabilities are
    summed and each pixel receives the class with the highest total.

    Args:
        input_path (str): Path to input GeoTIFF file.
        output_path (str): Path to save output mask.
        model_path (str): Path to trained model checkpoint (.ckpt or .pth).
        encoder_name (str): Name of timm encoder used in training.
        architecture (str): Segmentation architecture used in training.
        num_channels (int): Number of input channels.
        num_classes (int): Number of output classes.
        window_size (int): Size of sliding window for inference.
        overlap (int): Overlap between adjacent windows
            (must satisfy ``0 <= overlap < window_size``).
        batch_size (int): Number of windows passed to the model per forward pass.
        device (str, optional): Device to use. Auto-detected if None.
        quiet (bool): If True, suppress progress messages.
        use_timm_model (bool): If True, model was trained with timm model from HF Hub.
        timm_model_name (str, optional): Model name from HF Hub used during training.
        **kwargs: Additional arguments (currently unused).

    Raises:
        ValueError: If ``window_size``/``overlap`` are invalid, the SMP model
            cannot be created, or the model output is not a per-pixel
            (N, num_classes, H, W) tensor.
        ImportError: If ``use_timm_model`` is set but timm is not installed.
    """
    # Validate tiling parameters before doing any expensive work.
    if window_size <= 0:
        raise ValueError(f"window_size must be positive, got {window_size}")
    if not 0 <= overlap < window_size:
        raise ValueError(
            "overlap must satisfy 0 <= overlap < window_size, got "
            f"overlap={overlap} and window_size={window_size}"
        )
    batch_size = max(1, int(batch_size))

    import rasterio
    from rasterio.windows import Window

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model
    if model_path.endswith(".ckpt"):
        model = TimmSegmentationModel.load_from_checkpoint(
            model_path,
            map_location=device,
            encoder_name=encoder_name,
            architecture=architecture,
            num_classes=num_classes,
            in_channels=num_channels,
            use_timm_model=use_timm_model,
            timm_model_name=timm_model_name,
        )
        model = model.model  # Get underlying model
    else:
        # Load state dict
        if use_timm_model:
            if not TIMM_AVAILABLE:
                raise ImportError(
                    "timm is required to load this model. "
                    "Install it with: pip install timm"
                )
            # Load pure timm model
            if timm_model_name is None:
                timm_model_name = encoder_name

            model = timm.create_model(
                timm_model_name,
                pretrained=False,
                num_classes=num_classes,
                in_chans=num_channels,
            )
        else:
            # Load SMP model
            import segmentation_models_pytorch as smp

            try:
                model = smp.create_model(
                    arch=architecture,
                    encoder_name=encoder_name,
                    encoder_weights=None,
                    in_channels=num_channels,
                    classes=num_classes,
                )
            except Exception as e:
                raise ValueError(
                    f"Failed to create model with architecture '{architecture}' and encoder '{encoder_name}'. "
                    f"Error: {str(e)}"
                ) from e

        # NOTE: torch.load unpickles the file; only load model files from
        # trusted sources.
        model.load_state_dict(torch.load(model_path, map_location=device))

    model.eval()
    model = model.to(device)

    # Read input raster
    with rasterio.open(input_path) as src:
        meta = src.meta.copy()
        height, width = src.shape

        row_starts = _timm_window_starts(height, window_size, overlap)
        col_starts = _timm_window_starts(width, window_size, overlap)
        # (row_start, col_start, row_end, col_end) for every window, row-major;
        # ends are clipped to the raster extent.
        bounds = [
            (r, c, min(r + window_size, height), min(c + window_size, width))
            for r in row_starts
            for c in col_starts
        ]

        if not quiet:
            print(
                f"Processing {len(row_starts)} x {len(col_starts)} = {len(bounds)} windows"
            )

        # Per-class probability sums. Argmax of the sum equals argmax of the
        # mean, so no per-pixel window count is needed.
        prob_sum = np.zeros((num_classes, height, width), dtype=np.float32)

        with torch.no_grad():
            for first in tqdm(
                range(0, len(bounds), batch_size),
                disable=quiet,
                desc="Processing windows",
            ):
                batch_bounds = bounds[first : first + batch_size]
                tiles = [
                    _timm_prepare_tile(
                        src.read(window=Window(c0, r0, c1 - c0, r1 - r0)),
                        num_channels,
                        window_size,
                    )
                    for r0, c0, r1, c1 in batch_bounds
                ]
                batch_tensor = torch.from_numpy(np.stack(tiles)).to(device)
                logits = model(batch_tensor)

                if logits.ndim != 4 or logits.shape[1] != num_classes:
                    raise ValueError(
                        "Expected per-pixel model output of shape "
                        f"(N, {num_classes}, H, W), got {tuple(logits.shape)}"
                    )

                probs = torch.softmax(logits.float(), dim=1).cpu().numpy()
                for (r0, c0, r1, c1), tile_probs in zip(batch_bounds, probs):
                    # Crop the padding away before accumulating.
                    prob_sum[:, r0:r1, c0:c1] += tile_probs[:, : r1 - r0, : c1 - c0]

    output = np.argmax(prob_sum, axis=0).astype(np.uint8)

    # Save output. The source nodata value is dropped: it may be out of range
    # for uint8 (e.g. nan or -9999), and 0 would mark class 0 as nodata.
    meta.update({"count": 1, "dtype": "uint8", "compress": "lzw", "nodata": None})

    with rasterio.open(output_path, "w", **meta) as dst:
        dst.write(output, 1)

    if not quiet:
        print(f"Segmentation saved to {output_path}")
|
970
|
+
|
971
|
+
|
972
|
+
def push_timm_model_to_hub(
    model_path: str,
    repo_id: str,
    encoder_name: str = "resnet50",
    architecture: str = "unet",
    num_channels: int = 3,
    num_classes: int = 2,
    use_timm_model: bool = False,
    timm_model_name: Optional[str] = None,
    commit_message: Optional[str] = None,
    private: bool = False,
    token: Optional[str] = None,
    **kwargs: Any,
) -> str:
    """
    Push a trained timm segmentation model to Hugging Face Hub.

    The model weights are uploaded as ``model.pth`` (a CPU state dict) together
    with a ``config.json`` describing how to rebuild the model.

    Args:
        model_path (str): Path to trained model checkpoint (.ckpt or .pth).
        repo_id (str): Repository ID on HF Hub (e.g., 'username/model-name').
        encoder_name (str): Name of timm encoder used in training.
        architecture (str): Segmentation architecture used in training.
        num_channels (int): Number of input channels.
        num_classes (int): Number of output classes.
        use_timm_model (bool): If True, model was trained with pure timm model.
        timm_model_name (str, optional): Model name from HF Hub used during training.
        commit_message (str, optional): Commit message for the upload.
        private (bool): Whether to make the repository private.
        token (str, optional): HuggingFace API token. If None, uses logged-in token.
        **kwargs: Additional arguments for ``HfApi.upload_folder``.

    Returns:
        str: URL of the uploaded model on HF Hub.

    Raises:
        ImportError: If huggingface_hub (or timm, when ``use_timm_model`` is
            set for a state-dict file) is not installed.
    """
    try:
        from huggingface_hub import HfApi, create_repo
    except ImportError as err:
        raise ImportError(
            "huggingface_hub is required to push models. "
            "Install it with: pip install huggingface-hub"
        ) from err

    # Load model. Weights are mapped to the CPU in both branches so the
    # uploaded state dict does not contain CUDA tensors, which a plain
    # torch.load on a CPU-only machine cannot deserialize.
    if model_path.endswith(".ckpt"):
        lightning_model = TimmSegmentationModel.load_from_checkpoint(
            model_path,
            map_location="cpu",
            encoder_name=encoder_name,
            architecture=architecture,
            num_classes=num_classes,
            in_channels=num_channels,
            use_timm_model=use_timm_model,
            timm_model_name=timm_model_name,
        )
        model = lightning_model.model
    else:
        # Load state dict
        if use_timm_model:
            if not TIMM_AVAILABLE:
                raise ImportError(
                    "timm is required to load this model. "
                    "Install it with: pip install timm"
                )
            if timm_model_name is None:
                timm_model_name = encoder_name

            model = timm.create_model(
                timm_model_name,
                pretrained=False,
                num_classes=num_classes,
                in_chans=num_channels,
            )
        else:
            import segmentation_models_pytorch as smp

            model = smp.create_model(
                arch=architecture,
                encoder_name=encoder_name,
                encoder_weights=None,
                in_channels=num_channels,
                classes=num_classes,
            )

        # NOTE: torch.load unpickles the file; only load model files from
        # trusted sources.
        model.load_state_dict(torch.load(model_path, map_location="cpu"))

    # Create repository if it doesn't exist. Failures here are reported but
    # not fatal: the upload below surfaces any real problem (e.g. auth).
    api = HfApi(token=token)
    try:
        create_repo(repo_id, private=private, token=token, exist_ok=True)
    except Exception as e:
        print(f"Repository creation note: {e}")

    # Save model configuration
    config = {
        "encoder_name": encoder_name,
        "architecture": architecture,
        "num_channels": num_channels,
        "num_classes": num_classes,
        "use_timm_model": use_timm_model,
        "timm_model_name": timm_model_name,
        "model_type": "timm_segmentation",
    }

    # Save model state dict to temporary file
    import tempfile
    import json

    with tempfile.TemporaryDirectory() as tmpdir:
        # Save model state dict
        model_save_path = os.path.join(tmpdir, "model.pth")
        torch.save(model.state_dict(), model_save_path)

        # Save config
        config_path = os.path.join(tmpdir, "config.json")
        with open(config_path, "w") as f:
            json.dump(config, f, indent=2)

        # Upload files
        if commit_message is None:
            commit_message = f"Upload {architecture} with {encoder_name} encoder"

        api.upload_folder(
            folder_path=tmpdir,
            repo_id=repo_id,
            commit_message=commit_message,
            token=token,
            **kwargs,
        )

    url = f"https://huggingface.co/{repo_id}"
    print(f"Model successfully pushed to: {url}")
    return url
|