eoml-0.9.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. eoml/__init__.py +74 -0
  2. eoml/automation/__init__.py +7 -0
  3. eoml/automation/configuration.py +105 -0
  4. eoml/automation/dag.py +233 -0
  5. eoml/automation/experience.py +618 -0
  6. eoml/automation/tasks.py +825 -0
  7. eoml/bin/__init__.py +6 -0
  8. eoml/bin/clean_checkpoint.py +146 -0
  9. eoml/bin/land_cover_mapping_toml.py +435 -0
  10. eoml/bin/mosaic_images.py +137 -0
  11. eoml/data/__init__.py +7 -0
  12. eoml/data/basic_geo_data.py +214 -0
  13. eoml/data/dataset_utils.py +98 -0
  14. eoml/data/persistence/__init__.py +7 -0
  15. eoml/data/persistence/generic.py +253 -0
  16. eoml/data/persistence/lmdb.py +379 -0
  17. eoml/data/persistence/serializer.py +82 -0
  18. eoml/raster/__init__.py +7 -0
  19. eoml/raster/band.py +141 -0
  20. eoml/raster/dataset/__init__.py +6 -0
  21. eoml/raster/dataset/extractor.py +604 -0
  22. eoml/raster/raster_reader.py +602 -0
  23. eoml/raster/raster_utils.py +116 -0
  24. eoml/torch/__init__.py +7 -0
  25. eoml/torch/cnn/__init__.py +7 -0
  26. eoml/torch/cnn/augmentation.py +150 -0
  27. eoml/torch/cnn/dataset_evaluator.py +68 -0
  28. eoml/torch/cnn/db_dataset.py +605 -0
  29. eoml/torch/cnn/map_dataset.py +579 -0
  30. eoml/torch/cnn/map_dataset_const_mem.py +135 -0
  31. eoml/torch/cnn/outputs_transformer.py +130 -0
  32. eoml/torch/cnn/torch_utils.py +404 -0
  33. eoml/torch/cnn/training_dataset.py +241 -0
  34. eoml/torch/cnn/windows_dataset.py +120 -0
  35. eoml/torch/dataset/__init__.py +6 -0
  36. eoml/torch/dataset/shade_dataset_tester.py +46 -0
  37. eoml/torch/dataset/shade_tree_dataset_creators.py +537 -0
  38. eoml/torch/model_low_use.py +507 -0
  39. eoml/torch/models.py +282 -0
  40. eoml/torch/resnet.py +437 -0
  41. eoml/torch/sample_statistic.py +260 -0
  42. eoml/torch/trainer.py +782 -0
  43. eoml/torch/trainer_v2.py +253 -0
  44. eoml-0.9.0.dist-info/METADATA +93 -0
  45. eoml-0.9.0.dist-info/RECORD +47 -0
  46. eoml-0.9.0.dist-info/WHEEL +4 -0
  47. eoml-0.9.0.dist-info/entry_points.txt +3 -0
eoml/torch/resnet.py ADDED
@@ -0,0 +1,437 @@
+ """ResNet architecture implementations for PyTorch.
+
+ This module provides ResNet (Residual Network) implementations adapted from
+ https://colab.research.google.com/github/seyrankhademi/ResNet_CIFAR10/blob/master/CIFAR10_ResNet.ipynb
+
+ Includes ResNet variants (ResNet-20, ResNet-32, ResNet-56) and a year-aware variant
+ that incorporates temporal information as an additional input feature.
+ """
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from eoml.torch.cnn.torch_utils import conv_out_sizes
+ from torch import nn
+ from torch.nn import init
+
+
+ # taken from https://colab.research.google.com/github/seyrankhademi/ResNet_CIFAR10/blob/master/CIFAR10_ResNet.ipynb#scrollTo=V9Y2hYRwB-qg
+
+ # All classes and functions for the ResNet architecture are defined below.
+ #__all__ = ['resnet20']
+ #'ResNet','resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202'
+
+
+ def _weights_init(m):
+     """Initialize CNN weights using Kaiming normal initialization.
+
+     Applies to Linear and Conv2d layers.
+
+     Args:
+         m: PyTorch module/layer to initialize.
+     """
+     classname = m.__class__.__name__
+     if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
+         init.kaiming_normal_(m.weight)
+
+
+ class LambdaLayer(nn.Module):
+     """Lambda layer for identity mapping between ResNet blocks with different feature map sizes.
+
+     Used for handling dimension changes in skip connections when feature map sizes differ.
+
+     Attributes:
+         lambd: Lambda function to apply to input.
+     """
+
+     def __init__(self, lambd):
+         """Initialize LambdaLayer.
+
+         Args:
+             lambd: Lambda function for transforming input.
+         """
+         super().__init__()
+         self.lambd = lambd
+
+     def forward(self, x):
+         """Forward pass applying the lambda function.
+
+         Args:
+             x (torch.Tensor): Input tensor.
+
+         Returns:
+             torch.Tensor: Transformed tensor.
+         """
+         return self.lambd(x)
+
+
+ # A basic block, as shown in Fig. 3 (right) of the paper, consists of two convolutional layers, each followed by a batch-norm layer.
+ # Every basic block has a shortcut connection in the ResNet architecture to construct the f(x) + x module.
+ # Expansion for option 'A' in the paper is an identity mapping with extra zero entries padded
+ # for increasing dimensions between layers with different feature map sizes. This option introduces no extra parameters.
+ class BasicBlock(nn.Module):
+     """Basic residual block for ResNet architecture.
+
+     Consists of two convolutional layers with batch normalization and a shortcut connection.
+     Implements the f(x) + x residual mapping.
+
+     Attributes:
+         expansion (int): Expansion factor for output channels (always 1 for BasicBlock).
+         conv1 (nn.Conv2d): First convolutional layer.
+         bn1 (nn.BatchNorm2d): Batch normalization after first conv.
+         conv2 (nn.Conv2d): Second convolutional layer.
+         bn2 (nn.BatchNorm2d): Batch normalization after second conv.
+         shortcut (nn.Sequential): Shortcut connection for identity mapping.
+     """
+     # the output of a block keeps the same size
+     expansion = 1
+
+     def __init__(self, in_planes, planes, stride=1, option='A'):
+         """Initialize BasicBlock.
+
+         Args:
+             in_planes (int): Number of input channels.
+             planes (int): Number of output channels.
+             stride (int, optional): Stride for first convolution. Defaults to 1.
+             option (str, optional): Shortcut option - 'A' for padding (CIFAR10) or 'B' for
+                 projection. Defaults to 'A'.
+         """
+         super().__init__()
+         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_planes != planes:
+             if option == 'A':
+                 # For the CIFAR10 experiment, the ResNet paper uses option A.
+                 # Subsample the spatial dimensions with a stride of 2, then zero-pad the channel
+                 # dimension by planes // 4 on each side so the shortcut matches the doubled
+                 # channel count: for an input with B channels, B + 2 * (2B / 4) = 2B.
+                 self.shortcut = LambdaLayer(
+                     lambda x: F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0))
+             elif option == 'B':
+                 self.shortcut = nn.Sequential(
+                     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
+                     nn.BatchNorm2d(self.expansion * planes)
+                 )
+
+     def forward(self, x):
+         """Forward pass through the basic block.
+
+         Args:
+             x (torch.Tensor): Input tensor.
+
+         Returns:
+             torch.Tensor: Output tensor after residual connection and activation.
+         """
+         out = F.relu(self.bn1(self.conv1(x)))
+         out = self.bn2(self.conv2(out))
+         out += self.shortcut(x)
+         out = F.relu(out)
+         return out
+
+
+ # A stack of 3 times 2*n layers (n is the number of basic blocks) is used to build the ResNet model,
+ # where each group of 2n layers has feature maps of size {16, 32, 64}, respectively.
+ # The subsampling is performed by convolutions with a stride of 2.
+ class ResNet(nn.Module):
+     """ResNet architecture for image classification.
+
+     Implements ResNet with 3 stages of residual blocks. The number of blocks in each stage
+     determines the depth (e.g., ResNet-20, ResNet-32, ResNet-56).
+
+     Attributes:
+         in_planes (int): Current number of input planes, updated as layers are built.
+         conv1 (nn.Conv2d): Initial convolutional layer.
+         bn1 (nn.BatchNorm2d): Batch normalization after initial conv.
+         layer1 (nn.Sequential): First stage of residual blocks.
+         layer2 (nn.Sequential): Second stage of residual blocks.
+         layer3 (nn.Sequential): Third stage of residual blocks.
+         linear (nn.Linear): Final fully connected layer for classification.
+     """
+     # TODO: check the feature size before the dense layer
+     # (in_size, n_bands, n_out, BasicBlock, [3, 3, 3])
+     def __init__(self, size, in_band, n_out, block, num_blocks):
+         """Initialize ResNet model.
+
+         Args:
+             size (int): Input image size (not currently used in implementation).
+             in_band (int): Number of input channels/bands.
+             n_out (int): Number of output classes.
+             block: Block class to use (typically BasicBlock).
+             num_blocks (list): List of integers specifying number of blocks in each stage.
+         """
+         super().__init__()
+         # in_planes is updated in _make_layer (multiplied by the block expansion)
+         #self.in_planes = 16
+         self.in_planes = 2 * 32
+         # go from in_band to in_planes == input of the first block
+         self.conv1 = nn.Conv2d(in_band, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(self.in_planes)
+         # 3 stages of n blocks with 2 convs each: 3*n*2 = 6n layers, + 1 input conv + 1 output dense = 6n + 2 layers
+         self.layer1 = self._make_layer(block, self.in_planes, num_blocks[0], stride=1)
+         self.layer2 = self._make_layer(block, 2 * 64, num_blocks[1], stride=2)
+         self.layer3 = self._make_layer(block, 2 * 128, num_blocks[2], stride=2)
+         self.linear = nn.Linear(2 * 128, n_out)
+         self.apply(_weights_init)
+
+     def _make_layer(self, block, planes, num_blocks, stride):
+         """Create a stage of residual blocks.
+
+         Args:
+             block: Block class to instantiate.
+             planes (int): Number of output channels for blocks in this stage.
+             num_blocks (int): Number of blocks in this stage.
+             stride (int): Stride for first block (subsequent blocks use stride=1).
+
+         Returns:
+             nn.Sequential: Sequential container of residual blocks.
+         """
+         strides = [stride] + [1] * (num_blocks - 1)
+         layers = []
+         for stride in strides:
+             layers.append(block(self.in_planes, planes, stride))
+             self.in_planes = planes * block.expansion  # new in_planes = planes * expansion factor (1 in this case)
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         """Forward pass through ResNet.
+
+         Args:
+             x (torch.Tensor): Input tensor of shape (batch_size, in_band, height, width).
+
+         Returns:
+             torch.Tensor: Output logits of shape (batch_size, n_out).
+         """
+         out = F.relu(self.bn1(self.conv1(x)))
+         out = self.layer1(out)
+         out = self.layer2(out)
+         out = self.layer3(out)
+         # TODO/FIXME: verify this; global average pooling collapses each feature map to 1x1
+         out = F.avg_pool2d(out, (out.size(2), out.size(3)))
+         # flatten to (batch_size, num_features)
+         out = out.view(out.size(0), -1)
+         out = self.linear(out)
+         return out
+
+
+ def resnet20(in_size, n_bands, n_out):
+     """Create ResNet-20 model (20 layers: 6*3 + 2).
+
+     Args:
+         in_size (int): Input image size.
+         n_bands (int): Number of input channels.
+         n_out (int): Number of output classes.
+
+     Returns:
+         ResNet: ResNet-20 model.
+     """
+     return ResNet(in_size, n_bands, n_out, BasicBlock, [3, 3, 3])
+
+
+ def resnet32(in_size, n_bands, n_out):
+     """Create ResNet-32 model (32 layers: 6*5 + 2).
+
+     Args:
+         in_size (int): Input image size.
+         n_bands (int): Number of input channels.
+         n_out (int): Number of output classes.
+
+     Returns:
+         ResNet: ResNet-32 model.
+     """
+     return ResNet(in_size, n_bands, n_out, BasicBlock, [5, 5, 5])
+
+
+ def resnet44(in_size, n_bands, n_out):
+     """Create ResNet-44 model (44 layers: 6*7 + 2).
+
+     Args:
+         in_size (int): Input image size.
+         n_bands (int): Number of input channels.
+         n_out (int): Number of output classes.
+
+     Returns:
+         ResNet: ResNet-44 model.
+     """
+     return ResNet(in_size, n_bands, n_out, BasicBlock, [7, 7, 7])
+
+
+ def resnet56(in_size, n_bands, n_out):
+     """Create ResNet-56 model (56 layers: 6*9 + 2).
+
+     Args:
+         in_size (int): Input image size.
+         n_bands (int): Number of input channels.
+         n_out (int): Number of output classes.
+
+     Returns:
+         ResNet: ResNet-56 model.
+     """
+     return ResNet(in_size, n_bands, n_out, BasicBlock, [9, 9, 9])
+
+
+ #def resnet110():
+ #    return ResNet(BasicBlock, [18, 18, 18])
+
+
+ #def resnet1202():
+ #    return ResNet(BasicBlock, [200, 200, 200])
+
+
+ #def test(net):
+ #    total_params = 0
+ #
+ #    for x in filter(lambda p: p.requires_grad, net.parameters()):
+ #        total_params += np.prod(x.data.numpy().shape)
+ #    print("Total number of params", total_params)
+ #    print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size()) > 1, net.parameters()))))
+
+
+ #if __name__ == "__main__":
+ #    for net_name in __all__:
+ #        if net_name.startswith('resnet'):
+ #            print(net_name)
+ #            test(globals()[net_name]())
+ #            print()
+
+
+ class ResNetYear(nn.Module):
+     """ResNet architecture with year as additional input feature.
+
+     Similar to ResNet but accepts a year value as additional input, which is concatenated
+     with the conv features before the final linear layer. Useful for temporal classification tasks.
+
+     Attributes:
+         in_planes (int): Current number of input planes, updated as layers are built.
+         conv1 (nn.Conv2d): Initial convolutional layer.
+         bn1 (nn.BatchNorm2d): Batch normalization after initial conv.
+         layer1 (nn.Sequential): First stage of residual blocks.
+         layer2 (nn.Sequential): Second stage of residual blocks.
+         layer3 (nn.Sequential): Third stage of residual blocks.
+         linear (nn.Linear): Final fully connected layer (takes 2*128+1 inputs for year feature).
+     """
+     # TODO: check the feature size before the dense layer
+     # (in_size, n_bands, n_out, BasicBlock, [3, 3, 3])
+     def __init__(self, size, in_band, n_out, block, num_blocks):
+         """Initialize ResNetYear model.
+
+         Args:
+             size (int): Input image size (not currently used in implementation).
+             in_band (int): Number of input channels/bands.
+             n_out (int): Number of output classes.
+             block: Block class to use (typically BasicBlock).
+             num_blocks (list): List of integers specifying number of blocks in each stage.
+         """
+         super().__init__()
+         # in_planes is updated in _make_layer (multiplied by the block expansion)
+         #self.in_planes = 16
+         self.in_planes = 2 * 32
+         # go from in_band to in_planes == input of the first block
+         self.conv1 = nn.Conv2d(in_band, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(self.in_planes)
+         # 3 stages of n blocks with 2 convs each: 3*n*2 = 6n layers, + 1 input conv + 1 output dense = 6n + 2 layers
+         self.layer1 = self._make_layer(block, self.in_planes, num_blocks[0], stride=1)
+         self.layer2 = self._make_layer(block, 2 * 64, num_blocks[1], stride=2)
+         self.layer3 = self._make_layer(block, 2 * 128, num_blocks[2], stride=2)
+         self.linear = nn.Linear(2 * 128 + 1, n_out)
+         self.apply(_weights_init)
+
+     def _make_layer(self, block, planes, num_blocks, stride):
+         """Create a stage of residual blocks.
+
+         Args:
+             block: Block class to instantiate.
+             planes (int): Number of output channels for blocks in this stage.
+             num_blocks (int): Number of blocks in this stage.
+             stride (int): Stride for first block (subsequent blocks use stride=1).
+
+         Returns:
+             nn.Sequential: Sequential container of residual blocks.
+         """
+         strides = [stride] + [1] * (num_blocks - 1)
+         layers = []
+         for stride in strides:
+             layers.append(block(self.in_planes, planes, stride))
+             self.in_planes = planes * block.expansion  # new in_planes = planes * expansion factor (1 in this case)
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x, year):
+         """Forward pass through ResNetYear with year input.
+
+         Args:
+             x (torch.Tensor): Input tensor of shape (batch_size, in_band, height, width).
+             year (torch.Tensor): Year tensor of shape (batch_size, 1) representing temporal information.
+
+         Returns:
+             torch.Tensor: Output logits of shape (batch_size, n_out).
+         """
+         out = F.relu(self.bn1(self.conv1(x)))
+         out = self.layer1(out)
+         out = self.layer2(out)
+         out = self.layer3(out)
+         # TODO/FIXME: verify this; global average pooling collapses each feature map to 1x1
+         out = F.avg_pool2d(out, (out.size(2), out.size(3)))
+         # flatten to (batch_size, num_features)
+         out = out.view(out.size(0), -1)
+         out = torch.cat((out, year), dim=1)
+         out = self.linear(out)
+         return out
+
+
+ def resnet_year_20(in_size, n_bands, n_out):
+     """Create ResNetYear-20 model with year input.
+
+     Args:
+         in_size (int): Input image size.
+         n_bands (int): Number of input channels.
+         n_out (int): Number of output classes.
+
+     Returns:
+         ResNetYear: ResNetYear-20 model.
+     """
+     return ResNetYear(in_size, n_bands, n_out, BasicBlock, [3, 3, 3])
+
+
+ def resnet_year_32(in_size, n_bands, n_out):
+     """Create ResNetYear-32 model with year input.
+
+     Args:
+         in_size (int): Input image size.
+         n_bands (int): Number of input channels.
+         n_out (int): Number of output classes.
+
+     Returns:
+         ResNetYear: ResNetYear-32 model.
+     """
+     return ResNetYear(in_size, n_bands, n_out, BasicBlock, [5, 5, 5])
+
+
+ def resnet_year_44(in_size, n_bands, n_out):
+     """Create ResNetYear-44 model with year input.
+
+     Args:
+         in_size (int): Input image size.
+         n_bands (int): Number of input channels.
+         n_out (int): Number of output classes.
+
+     Returns:
+         ResNetYear: ResNetYear-44 model.
+     """
+     return ResNetYear(in_size, n_bands, n_out, BasicBlock, [7, 7, 7])
+
+
+ def resnet_year_56(in_size, n_bands, n_out):
+     """Create ResNetYear-56 model with year input.
+
+     Args:
+         in_size (int): Input image size.
+         n_bands (int): Number of input channels.
+         n_out (int): Number of output classes.
+
+     Returns:
+         ResNetYear: ResNetYear-56 model.
+     """
+     return ResNetYear(in_size, n_bands, n_out, BasicBlock, [9, 9, 9])
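For orientation, a minimal usage sketch of the factory functions above (not part of the packaged diff). The band count, patch size, and class count are arbitrary illustrative values, and in_size is accepted but not used by the constructors:

    import torch
    from eoml.torch.resnet import resnet20, resnet_year_20

    # 6-band 32x32 patches, 5 output classes (illustrative values)
    model = resnet20(in_size=32, n_bands=6, n_out=5)
    x = torch.randn(4, 6, 32, 32)      # (batch, bands, height, width)
    logits = model(x)                  # shape: (4, 5)

    # the year-aware variant takes an extra (batch, 1) tensor that is concatenated
    # with the pooled features before the final linear layer
    model_y = resnet_year_20(in_size=32, n_bands=6, n_out=5)
    year = torch.full((4, 1), 2020.0)
    logits_y = model_y(x, year)        # shape: (4, 5)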
eoml/torch/sample_statistic.py ADDED
@@ -0,0 +1,260 @@
+ """Classification statistics and evaluation utilities for PyTorch models.
+
+ This module provides classes for computing classification metrics and exporting
+ misclassified samples to geospatial formats for analysis.
+ """
+
+ import logging
+ import math
+ import os
+ from collections import OrderedDict
+ from typing import Optional, Literal, List
+
+ import fiona
+ import numpy as np
+ import torch
+ from fiona.crs import from_epsg
+ from shapely.geometry import mapping
+ from torch.utils.data import DataLoader
+ from torchmetrics import F1Score, Accuracy, Recall, Precision
+ from torchmetrics.classification import MulticlassConfusionMatrix
+ from tqdm import tqdm
+
+ logger = logging.getLogger(__name__)
+
+
+ class ClassificationStats:
+     """Compute and display classification metrics for multi-class problems.
+
+     Calculates F1, accuracy, precision, recall, and confusion matrix using torchmetrics.
+     Supports multiple averaging modes (micro, macro, weighted, none).
+
+     Attributes:
+         f1_type (List): List of averaging modes for F1 score.
+         accuracy_type (List): List of averaging modes for accuracy.
+         precision_type (List): List of averaging modes for precision.
+         recall_type (List): List of averaging modes for recall.
+         f1_f (dict): Dictionary of F1Score metric objects.
+         confusion_f (MulticlassConfusionMatrix): Confusion matrix metric.
+         accuracy_f (dict): Dictionary of Accuracy metric objects.
+         recall_f (dict): Dictionary of Recall metric objects.
+         precision_f (dict): Dictionary of Precision metric objects.
+         computed (bool): Flag indicating if metrics have been computed.
+         num_class (int): Number of classes.
+         category_name (list, optional): Names of categories for display.
+     """
+
+     def __init__(self, num_class, device, category_name=None):
+         """Initialize ClassificationStats.
+
+         Args:
+             num_class (int): Number of classes in the classification problem.
+             device (str): Device to run computations on ('cpu' or 'cuda').
+             category_name (list, optional): List of category names for display. Defaults to None.
+         """
+         # available averaging modes: "micro", "macro", "weighted", "none"
+         self.f1_type: List[Optional[Literal["micro", "macro", "weighted", "none"]]] = ["none", "micro"]
+         self.accuracy_type: List[Optional[Literal["micro", "macro", "weighted", "none"]]] = ["micro"]
+         self.precision_type: List[Optional[Literal["micro", "macro", "weighted", "none"]]] = ["none"]
+         self.recall_type: List[Optional[Literal["micro", "macro", "weighted", "none"]]] = ["none"]
+
+         self.f1_f = {t: F1Score(task="multiclass", average=t, num_classes=num_class).to(device) for t in self.f1_type}
+         self.confusion_f = MulticlassConfusionMatrix(num_classes=num_class).to(device)
+
+         self.accuracy_f = {t: Accuracy(task="multiclass", average=t, num_classes=num_class).to(device) for t in self.accuracy_type}
+         self.recall_f = {t: Recall(task="multiclass", average=t, num_classes=num_class).to(device) for t in self.recall_type}
+         self.precision_f = {t: Precision(task="multiclass", average=t, num_classes=num_class).to(device) for t in self.precision_type}
+
+         self.computed = False
+
+         self.num_class = num_class
+
+         self.category_name = category_name
+
+     def compute(self, model, dataloader: DataLoader, device="cpu"):
+         """Compute classification metrics on a dataset.
+
+         Runs model inference on dataloader and accumulates metrics. Resets all metrics
+         before computation.
+
+         Args:
+             model (torch.nn.Module): Model to evaluate.
+             dataloader (DataLoader): DataLoader providing (inputs, labels, meta) batches.
+             device (str, optional): Device to run on. Defaults to "cpu".
+         """
+         model = model.to(device)
+         model.eval()
+
+         [self.f1_f[t].reset() for t in self.f1_type]
+         self.confusion_f.reset()
+         [self.accuracy_f[t].reset() for t in self.accuracy_type]
+         [self.precision_f[t].reset() for t in self.precision_type]
+         [self.recall_f[t].reset() for t in self.recall_type]
+
+         with torch.inference_mode():
+             with tqdm(total=len(dataloader), desc="Batch") as pbar:
+                 for i, data in enumerate(dataloader):
+                     # Every data instance is an (inputs, labels, meta) triple
+                     inputs, labels, meta = data
+
+                     if device is not None:
+                         if isinstance(inputs, (list, tuple)):
+                             inputs = tuple(map(lambda x: x.to(device), inputs))  # tracing needs a tuple of inputs
+                         else:
+                             inputs = inputs.to(device)
+
+                     labels = labels.to(device, non_blocking=True)
+
+                     # unpack multi-tensor inputs; pass a single tensor directly
+                     output = model(*inputs) if isinstance(inputs, tuple) else model(inputs)
+
+                     [self.f1_f[t](output, labels) for t in self.f1_type]
+                     self.confusion_f(output, labels)
+                     [self.accuracy_f[t](output, labels) for t in self.accuracy_type]
+                     [self.precision_f[t](output, labels) for t in self.precision_type]
+                     [self.recall_f[t](output, labels) for t in self.recall_type]
+
+                     pbar.update(1)
+
+         self.computed = True
+
+     def display(self):
+         """Display computed metrics to console.
+
+         Prints category names (if provided), F1 scores, accuracy, precision, recall,
+         and confusion matrix.
+         """
+         if self.category_name is not None:
+             for i, name in enumerate(self.category_name):
+                 logger.info(f'{i}: {name}')
+
+         for t, val in self.f1_f.items():
+             logger.info(f'f1 {t}: {val.compute().detach().cpu().numpy()}')
+
+         for t, val in self.accuracy_f.items():
+             logger.info(f'accuracy {t}: {val.compute().detach().cpu().numpy()}')
+
+         for t, val in self.precision_f.items():
+             logger.info(f'precision {t}: {val.compute().detach().cpu().numpy()}')
+
+         for t, val in self.recall_f.items():
+             logger.info(f'recall {t}: {val.compute().detach().cpu().numpy()}')
+
+         logger.info(f'Confusion matrix:\n{self.confusion_f.compute()}')
+
+     def to_file(self, path):
+         """Write computed metrics to a file.
+
+         Args:
+             path (str): Output file path.
+         """
+         with open(path, 'w') as f:
+
+             if self.category_name is not None:
+                 for i, name in enumerate(self.category_name):
+                     f.write(f'{i}: {name}\n')
+
+             for t, val in self.f1_f.items():
+                 f.write(f'f1 {t}: {val.compute().detach().cpu().numpy()}\n')
+
+             for t, val in self.accuracy_f.items():
+                 f.write(f'accuracy {t}: {val.compute().detach().cpu().numpy()}\n')
+
+             for t, val in self.precision_f.items():
+                 f.write(f'precision {t}: {val.compute().detach().cpu().numpy()}\n')
+
+             for t, val in self.recall_f.items():
+                 f.write(f'recall {t}: {val.compute().detach().cpu().numpy()}\n')
+
+             confusion = self.confusion_f.compute().detach().cpu().numpy()
+             n_digit = math.ceil(math.log10(confusion.max())) + 1
+             np.savetxt(f, confusion, fmt=f'%{n_digit}.0d', delimiter=' ', newline=os.linesep)
+
+
+ class BadlyClassifyToGPKG:
+     """Export misclassified samples to GeoPackage format for spatial analysis.
+
+     Identifies samples where model predictions don't match reference labels and exports
+     them as point geometries with prediction and reference label attributes.
+
+     Attributes:
+         results (list): List of misclassified sample records with geometry and properties.
+     """
+
+     def __init__(self):
+         """Initialize BadlyClassifyToGPKG with an empty results list."""
+         self.results = []
+
+     def compute(self, model, dataloader: DataLoader, device="cpu"):
+         """Identify misclassified samples from model predictions.
+
+         Runs inference on dataloader and stores records for samples where the prediction
+         differs from the reference label.
+
+         Args:
+             model (torch.nn.Module): Model to evaluate.
+             dataloader (DataLoader): DataLoader providing (inputs, labels, meta) batches.
+                 Meta must contain geometry information.
+             device (str, optional): Device to run on. Defaults to "cpu".
+         """
+         self.results = []
+
+         model = model.to(device)
+         model.eval()
+
+         with torch.inference_mode():
+             with tqdm(total=len(dataloader), desc="Batch") as pbar:
+                 for i, data in enumerate(dataloader):
+                     # Every data instance is an (inputs, labels, meta) triple
+                     inputs, labels, meta = data
+
+                     if device is not None:
+                         if isinstance(inputs, (list, tuple)):
+                             inputs = tuple(map(lambda x: x.to(device), inputs))  # tracing needs a tuple of inputs
+                         else:
+                             inputs = inputs.to(device)
+
+                     labels = labels.to(device, non_blocking=True)
+
+                     # unpack multi-tensor inputs; pass a single tensor directly
+                     output = model(*inputs) if isinstance(inputs, tuple) else model(inputs)
+
+                     output = torch.argmax(output, dim=1)
+                     output = output.detach().cpu().numpy()
+                     labels = labels.detach().cpu().numpy()
+
+                     for o, l, m in zip(output, labels, meta):
+                         if o != l:
+                             rec = {'geometry': mapping(m.geometry),
+                                    'properties': OrderedDict([
+                                        ('Ref_label', int(l)),
+                                        ('Pred_label', int(o)),
+                                    ])
+                                    }
+                             self.results.append(rec)
+
+                     pbar.update(1)
+
+     def to_file(self, path, crs="4326"):
+         """Write misclassified samples to a GeoPackage file.
+
+         Args:
+             path (str): Output GeoPackage file path.
+             crs (str, optional): EPSG code for the coordinate reference system. Defaults to "4326".
+         """
+         #('Class', 'float:16')
+         schema = {'geometry': 'Point',
+                   'properties': OrderedDict([('Ref_label', 'int'),
+                                              ('Pred_label', 'int')])
+                   }
+         crs = from_epsg(crs)
+
+         with fiona.open(path, 'w',
+                         driver='GPKG',
+                         schema=schema,
+                         crs=crs) as src:
+
+             for record in self.results:
+                 src.write(record)
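To close, a hedged usage sketch of the two classes above (not part of the packaged diff). Here model and eval_loader are placeholders for a trained network and a DataLoader that yields (inputs, labels, meta) batches, where each meta item is assumed to carry a .geometry attribute (e.g. a shapely Point) for the GeoPackage export; class names and file paths are illustrative only:

    import torch
    from eoml.torch.sample_statistic import ClassificationStats, BadlyClassifyToGPKG

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # accumulate and report metrics over the evaluation set
    stats = ClassificationStats(num_class=5, device=device,
                                category_name=["water", "forest", "crop", "urban", "bare"])
    stats.compute(model, eval_loader, device=device)
    stats.display()                  # logs F1, accuracy, precision, recall, confusion matrix
    stats.to_file("metrics.txt")

    # export misclassified samples as points for inspection in a GIS
    errors = BadlyClassifyToGPKG()
    errors.compute(model, eval_loader, device=device)
    errors.to_file("misclassified.gpkg", crs="4326")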