eoml 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eoml/__init__.py +74 -0
- eoml/automation/__init__.py +7 -0
- eoml/automation/configuration.py +105 -0
- eoml/automation/dag.py +233 -0
- eoml/automation/experience.py +618 -0
- eoml/automation/tasks.py +825 -0
- eoml/bin/__init__.py +6 -0
- eoml/bin/clean_checkpoint.py +146 -0
- eoml/bin/land_cover_mapping_toml.py +435 -0
- eoml/bin/mosaic_images.py +137 -0
- eoml/data/__init__.py +7 -0
- eoml/data/basic_geo_data.py +214 -0
- eoml/data/dataset_utils.py +98 -0
- eoml/data/persistence/__init__.py +7 -0
- eoml/data/persistence/generic.py +253 -0
- eoml/data/persistence/lmdb.py +379 -0
- eoml/data/persistence/serializer.py +82 -0
- eoml/raster/__init__.py +7 -0
- eoml/raster/band.py +141 -0
- eoml/raster/dataset/__init__.py +6 -0
- eoml/raster/dataset/extractor.py +604 -0
- eoml/raster/raster_reader.py +602 -0
- eoml/raster/raster_utils.py +116 -0
- eoml/torch/__init__.py +7 -0
- eoml/torch/cnn/__init__.py +7 -0
- eoml/torch/cnn/augmentation.py +150 -0
- eoml/torch/cnn/dataset_evaluator.py +68 -0
- eoml/torch/cnn/db_dataset.py +605 -0
- eoml/torch/cnn/map_dataset.py +579 -0
- eoml/torch/cnn/map_dataset_const_mem.py +135 -0
- eoml/torch/cnn/outputs_transformer.py +130 -0
- eoml/torch/cnn/torch_utils.py +404 -0
- eoml/torch/cnn/training_dataset.py +241 -0
- eoml/torch/cnn/windows_dataset.py +120 -0
- eoml/torch/dataset/__init__.py +6 -0
- eoml/torch/dataset/shade_dataset_tester.py +46 -0
- eoml/torch/dataset/shade_tree_dataset_creators.py +537 -0
- eoml/torch/model_low_use.py +507 -0
- eoml/torch/models.py +282 -0
- eoml/torch/resnet.py +437 -0
- eoml/torch/sample_statistic.py +260 -0
- eoml/torch/trainer.py +782 -0
- eoml/torch/trainer_v2.py +253 -0
- eoml-0.9.0.dist-info/METADATA +93 -0
- eoml-0.9.0.dist-info/RECORD +47 -0
- eoml-0.9.0.dist-info/WHEEL +4 -0
- eoml-0.9.0.dist-info/entry_points.txt +3 -0
eoml/torch/resnet.py
ADDED
@@ -0,0 +1,437 @@
"""ResNet architecture implementations for PyTorch.

This module provides ResNet (Residual Network) implementations adapted from
https://colab.research.google.com/github/seyrankhademi/ResNet_CIFAR10/blob/master/CIFAR10_ResNet.ipynb

Includes ResNet variants (ResNet-20, ResNet-32, ResNet-56) and a year-aware variant
that incorporates temporal information as an additional input feature.
"""

import numpy as np
import torch
import torch.nn.functional as F
from eoml.torch.cnn.torch_utils import conv_out_sizes
from torch import nn
from torch.nn import init


# taken from https://colab.research.google.com/github/seyrankhademi/ResNet_CIFAR10/blob/master/CIFAR10_ResNet.ipynb#scrollTo=V9Y2hYRwB-qg

# We define all the classes and functions for the ResNet architecture in this code cell
#__all__ = ['resnet20']
#'ResNet','resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202'


def _weights_init(m):
    """Initialize CNN weights using Kaiming normal initialization.

    Applies to Linear and Conv2d layers.

    Args:
        m: PyTorch module/layer to initialize.
    """
    classname = m.__class__.__name__
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)


class LambdaLayer(nn.Module):
    """Lambda layer for identity mapping between ResNet blocks with different feature map sizes.

    Used for handling dimension changes in skip connections when feature map sizes differ.

    Attributes:
        lambd: Lambda function to apply to input.
    """

    def __init__(self, lambd):
        """Initialize LambdaLayer.

        Args:
            lambd: Lambda function for transforming input.
        """
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        """Forward pass applying the lambda function.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Transformed tensor.
        """
        return self.lambd(x)


# A basic block as shown in Fig. 3 (right) of the paper consists of two convolutional layers, each followed by a Batch-Norm layer.
# Every basic block gets a shortcut connection in the ResNet architecture to construct the f(x) + x module.
# Expansion for option 'A' in the paper is equal to identity with extra zero entries padded
# for increasing dimensions between layers with different feature map sizes. This option introduces no extra parameters.
class BasicBlock(nn.Module):
    """Basic residual block for ResNet architecture.

    Consists of two convolutional layers with batch normalization and a shortcut connection.
    Implements the f(x) + x residual mapping.

    Attributes:
        expansion (int): Expansion factor for output channels (always 1 for BasicBlock).
        conv1 (nn.Conv2d): First convolutional layer.
        bn1 (nn.BatchNorm2d): Batch normalization after first conv.
        conv2 (nn.Conv2d): Second convolutional layer.
        bn2 (nn.BatchNorm2d): Batch normalization after second conv.
        shortcut (nn.Sequential): Shortcut connection for identity mapping.
    """
    # the output of a block keeps the same size
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        """Initialize BasicBlock.

        Args:
            in_planes (int): Number of input channels.
            planes (int): Number of output channels.
            stride (int, optional): Stride for first convolution. Defaults to 1.
            option (str, optional): Shortcut option - 'A' for padding (CIFAR10) or 'B' for
                projection. Defaults to 'A'.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                """
                For the CIFAR10 experiment, the ResNet paper uses option A.
                """
                # we collapse the spatial size with the stride of 2 (taking every other pixel),
                # then pad the channel dimension: output width (2x the input B in the initial config) => B + (2B/4)*2 = 2B
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes // 4, planes // 4),
                                                  "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes)
                )

    def forward(self, x):
        """Forward pass through the basic block.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after residual connection and activation.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


# A stack of 3 times 2*n layers (n is the number of basic blocks per stage) is used to build the ResNet model,
# where each group of 2n layers has feature maps of size {16, 32, 64}, respectively.
# The subsampling is performed by convolutions with a stride of 2.
class ResNet(nn.Module):
    """ResNet architecture for image classification.

    Implements ResNet with 3 stages of residual blocks. The number of blocks in each stage
    determines the depth (e.g., ResNet-20, ResNet-32, ResNet-56).

    Attributes:
        in_planes (int): Current number of input planes, updated as layers are built.
        conv1 (nn.Conv2d): Initial convolutional layer.
        bn1 (nn.BatchNorm2d): Batch normalization after initial conv.
        layer1 (nn.Sequential): First stage of residual blocks.
        layer2 (nn.Sequential): Second stage of residual blocks.
        layer3 (nn.Sequential): Third stage of residual blocks.
        linear (nn.Linear): Final fully connected layer for classification.
    """
    # TODO check the feature size before the dense layer
    #(in_size, n_bands, n_out, BasicBlock, [3, 3, 3])
    def __init__(self, size, in_band, n_out, block, num_blocks):
        """Initialize ResNet model.

        Args:
            size (int): Input image size (not currently used in implementation).
            in_band (int): Number of input channels/bands.
            n_out (int): Number of output classes.
            block: Block class to use (typically BasicBlock).
            num_blocks (list): List of integers specifying number of blocks in each stage.
        """

        super().__init__()
        # in_planes is updated by _make_layer (multiplied by the block expansion factor)
        #self.in_planes = 16
        self.in_planes = 2*32
        # go from in_band to in_planes == input of the first block
        self.conv1 = nn.Conv2d(in_band, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_planes)
        # 3 stages of n blocks with 2 convs each = 6n layers, +1 input conv +1 output dense = 6n+2 layers
        self.layer1 = self._make_layer(block, self.in_planes, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 2*64, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 2*128, num_blocks[2], stride=2)
        self.linear = nn.Linear(2*128, n_out)
        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Create a stage of residual blocks.

        Args:
            block: Block class to instantiate.
            planes (int): Number of output channels for blocks in this stage.
            num_blocks (int): Number of blocks in this stage.
            stride (int): Stride for first block (subsequent blocks use stride=1).

        Returns:
            nn.Sequential: Sequential container of residual blocks.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion  # new in_planes = planes * expansion factor (1 in this case)

        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward pass through ResNet.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_band, height, width).

        Returns:
            torch.Tensor: Output logits of shape (batch_size, n_out).
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # TODO/FIXME: check this step
        # global average pooling collapses each feature map to 1x1 (one value per channel)
        out = F.avg_pool2d(out, (out.size(2), out.size(3)))
        # flatten to (batch_size, n_features)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def resnet20(in_size, n_bands, n_out):
    """Create ResNet-20 model (20 layers: 6*3 + 2).

    Args:
        in_size (int): Input image size.
        n_bands (int): Number of input channels.
        n_out (int): Number of output classes.

    Returns:
        ResNet: ResNet-20 model.
    """
    return ResNet(in_size, n_bands, n_out, BasicBlock, [3, 3, 3])


def resnet32(in_size, n_bands, n_out):
    """Create ResNet-32 model (32 layers: 6*5 + 2).

    Args:
        in_size (int): Input image size.
        n_bands (int): Number of input channels.
        n_out (int): Number of output classes.

    Returns:
        ResNet: ResNet-32 model.
    """
    return ResNet(in_size, n_bands, n_out, BasicBlock, [5, 5, 5])


def resnet44(in_size, n_bands, n_out):
    """Create ResNet-44 model (44 layers: 6*7 + 2).

    Args:
        in_size (int): Input image size.
        n_bands (int): Number of input channels.
        n_out (int): Number of output classes.

    Returns:
        ResNet: ResNet-44 model.
    """
    return ResNet(in_size, n_bands, n_out, BasicBlock, [7, 7, 7])


def resnet56(in_size, n_bands, n_out):
    """Create ResNet-56 model (56 layers: 6*9 + 2).

    Args:
        in_size (int): Input image size.
        n_bands (int): Number of input channels.
        n_out (int): Number of output classes.

    Returns:
        ResNet: ResNet-56 model.
    """
    return ResNet(in_size, n_bands, n_out, BasicBlock, [9, 9, 9])


#def resnet110():
#    return ResNet(BasicBlock, [18, 18, 18])


#def resnet1202():
#    return ResNet(BasicBlock, [200, 200, 200])


#def test(net):
#    total_params = 0
#
#    for x in filter(lambda p: p.requires_grad, net.parameters()):
#        total_params += np.prod(x.data.numpy().shape)
#    print("Total number of params", total_params)
#    print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size()) > 1, net.parameters()))))


#if __name__ == "__main__":
#    for net_name in __all__:
#        if net_name.startswith('resnet'):
#            print(net_name)
#            test(globals()[net_name]())
#            print()


class ResNetYear(nn.Module):
    """ResNet architecture with year as additional input feature.

    Similar to ResNet but accepts a year value as additional input, which is concatenated
    with the conv features before the final linear layer. Useful for temporal classification tasks.

    Attributes:
        in_planes (int): Current number of input planes, updated as layers are built.
        conv1 (nn.Conv2d): Initial convolutional layer.
        bn1 (nn.BatchNorm2d): Batch normalization after initial conv.
        layer1 (nn.Sequential): First stage of residual blocks.
        layer2 (nn.Sequential): Second stage of residual blocks.
        layer3 (nn.Sequential): Third stage of residual blocks.
        linear (nn.Linear): Final fully connected layer (takes 2*128+1 inputs for year feature).
    """
    # TODO check the feature size before the dense layer
    #(in_size, n_bands, n_out, BasicBlock, [3, 3, 3])
    def __init__(self, size, in_band, n_out, block, num_blocks):
        """Initialize ResNetYear model.

        Args:
            size (int): Input image size (not currently used in implementation).
            in_band (int): Number of input channels/bands.
            n_out (int): Number of output classes.
            block: Block class to use (typically BasicBlock).
            num_blocks (list): List of integers specifying number of blocks in each stage.
        """

        super().__init__()
        # in_planes is updated by _make_layer (multiplied by the block expansion factor)
        #self.in_planes = 16
        self.in_planes = 2*32
        # go from in_band to in_planes == input of the first block
        self.conv1 = nn.Conv2d(in_band, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_planes)
        # 3 stages of n blocks with 2 convs each = 6n layers, +1 input conv +1 output dense = 6n+2 layers
        self.layer1 = self._make_layer(block, self.in_planes, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 2*64, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 2*128, num_blocks[2], stride=2)
        self.linear = nn.Linear(2*128+1, n_out)
        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Create a stage of residual blocks.

        Args:
            block: Block class to instantiate.
            planes (int): Number of output channels for blocks in this stage.
            num_blocks (int): Number of blocks in this stage.
            stride (int): Stride for first block (subsequent blocks use stride=1).

        Returns:
            nn.Sequential: Sequential container of residual blocks.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion  # new in_planes = planes * expansion factor (1 in this case)

        return nn.Sequential(*layers)

    def forward(self, x, year):
        """Forward pass through ResNetYear with year input.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_band, height, width).
            year (torch.Tensor): Year tensor of shape (batch_size, 1) representing temporal information.

        Returns:
            torch.Tensor: Output logits of shape (batch_size, n_out).
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # TODO/FIXME: check this step
        # global average pooling collapses each feature map to 1x1 (one value per channel)
        out = F.avg_pool2d(out, (out.size(2), out.size(3)))
        # flatten to (batch_size, n_features)
        out = out.view(out.size(0), -1)
        out = torch.cat((out, year), dim=1)
        out = self.linear(out)
        return out


def resnet_year_20(in_size, n_bands, n_out):
    """Create ResNetYear-20 model with year input.

    Args:
        in_size (int): Input image size.
        n_bands (int): Number of input channels.
        n_out (int): Number of output classes.

    Returns:
        ResNetYear: ResNetYear-20 model.
    """
    return ResNetYear(in_size, n_bands, n_out, BasicBlock, [3, 3, 3])


def resnet_year_32(in_size, n_bands, n_out):
    """Create ResNetYear-32 model with year input.

    Args:
        in_size (int): Input image size.
        n_bands (int): Number of input channels.
        n_out (int): Number of output classes.

    Returns:
        ResNetYear: ResNetYear-32 model.
    """
    return ResNetYear(in_size, n_bands, n_out, BasicBlock, [5, 5, 5])


def resnet_year_44(in_size, n_bands, n_out):
    """Create ResNetYear-44 model with year input.

    Args:
        in_size (int): Input image size.
        n_bands (int): Number of input channels.
        n_out (int): Number of output classes.

    Returns:
        ResNetYear: ResNetYear-44 model.
    """
    return ResNetYear(in_size, n_bands, n_out, BasicBlock, [7, 7, 7])


def resnet_year_56(in_size, n_bands, n_out):
    """Create ResNetYear-56 model with year input.

    Args:
        in_size (int): Input image size.
        n_bands (int): Number of input channels.
        n_out (int): Number of output classes.

    Returns:
        ResNetYear: ResNetYear-56 model.
    """
    return ResNetYear(in_size, n_bands, n_out, BasicBlock, [9, 9, 9])
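Usage sketch (illustrative only, not part of the wheel): the patch size of 32, the 4 input bands, and the 5 output classes below are assumptions chosen to show how the factory functions above are called and how the year-aware variant takes its extra input.

import torch
from eoml.torch.resnet import resnet20, resnet_year_20

# Hypothetical example: a ResNet-20 for 4-band imagery with 5 classes.
model = resnet20(in_size=32, n_bands=4, n_out=5)
x = torch.randn(8, 4, 32, 32)      # (batch, bands, height, width)
logits = model(x)                  # shape: (8, 5)

# The year-aware variant concatenates a (batch, 1) year tensor with the pooled
# features before the final linear layer, hence the 2*128+1 input size.
year_model = resnet_year_20(in_size=32, n_bands=4, n_out=5)
year = torch.full((8, 1), 2020.0)
logits_year = year_model(x, year)  # shape: (8, 5)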
eoml/torch/sample_statistic.py
ADDED
@@ -0,0 +1,260 @@
"""Classification statistics and evaluation utilities for PyTorch models.

This module provides classes for computing classification metrics and exporting
misclassified samples to geospatial formats for analysis.
"""

import logging
import math
import os
from collections import OrderedDict
from typing import Optional, Literal, List

import fiona
import numpy as np
import torch
from fiona.crs import from_epsg
from shapely.geometry import mapping
from torch.utils.data import DataLoader
from torchmetrics import F1Score, Accuracy, Recall, Precision
from torchmetrics.classification import MulticlassConfusionMatrix
from tqdm import tqdm

logger = logging.getLogger(__name__)


class ClassificationStats:
    """Compute and display classification metrics for multi-class problems.

    Calculates F1, accuracy, precision, recall, and confusion matrix using torchmetrics.
    Supports multiple averaging modes (micro, macro, weighted, none).

    Attributes:
        f1_type (List): List of averaging modes for F1 score.
        accuracy_type (List): List of averaging modes for accuracy.
        precision_type (List): List of averaging modes for precision.
        recall_type (List): List of averaging modes for recall.
        f1_f (dict): Dictionary of F1Score metric objects.
        confusion_f (MulticlassConfusionMatrix): Confusion matrix metric.
        accuracy_f (dict): Dictionary of Accuracy metric objects.
        recall_f (dict): Dictionary of Recall metric objects.
        precision_f (dict): Dictionary of Precision metric objects.
        computed (bool): Flag indicating if metrics have been computed.
        num_class (int): Number of classes.
        category_name (list, optional): Names of categories for display.
    """

    def __init__(self, num_class, device, category_name=None):
        """Initialize ClassificationStats.

        Args:
            num_class (int): Number of classes in the classification problem.
            device (str): Device to run computations on ('cpu' or 'cuda').
            category_name (list, optional): List of category names for display. Defaults to None.
        """
        #, "micro", "macro", "weighted"
        self.f1_type: List[Optional[Literal["micro", "macro", "weighted", "none"]]] = ["none", "micro"]
        self.accuracy_type: List[Optional[Literal["micro", "macro", "weighted", "none"]]] = ["micro"]
        self.precision_type: List[Optional[Literal["micro", "macro", "weighted", "none"]]] = ["none"]
        self.recall_type: List[Optional[Literal["micro", "macro", "weighted", "none"]]] = ["none"]

        self.f1_f = {t: F1Score(task="multiclass", average=t, num_classes=num_class).to(device) for t in self.f1_type}
        self.confusion_f = MulticlassConfusionMatrix(num_classes=num_class).to(device)

        self.accuracy_f = {t: Accuracy(task="multiclass", average=t, num_classes=num_class).to(device) for t in self.accuracy_type}
        self.recall_f = {t: Recall(task="multiclass", average=t, num_classes=num_class).to(device) for t in self.recall_type}
        self.precision_f = {t: Precision(task="multiclass", average=t, num_classes=num_class).to(device) for t in self.precision_type}

        self.computed = False

        self.num_class = num_class

        self.category_name = category_name

    def compute(self, model, dataloader: DataLoader, device="cpu"):
        """Compute classification metrics on a dataset.

        Runs model inference on dataloader and accumulates metrics. Resets all metrics
        before computation.

        Args:
            model (torch.nn.Module): Model to evaluate.
            dataloader (DataLoader): DataLoader providing (inputs, labels, meta) batches.
            device (str, optional): Device to run on. Defaults to "cpu".
        """
        model = model.to(device)
        model.eval()

        [self.f1_f[t].reset() for t in self.f1_type]
        self.confusion_f.reset()
        [self.accuracy_f[t].reset() for t in self.accuracy_type]
        [self.precision_f[t].reset() for t in self.precision_type]
        [self.recall_f[t].reset() for t in self.recall_type]

        with torch.inference_mode():
            with tqdm(total=len(dataloader), desc="Batch") as pbar:
                for i, data in enumerate(dataloader):
                    # Every data instance is an input + label pair
                    inputs, labels, meta = data

                    if device is not None:
                        if isinstance(inputs, (list, tuple)):
                            inputs = tuple(map(lambda x: x.to(device), inputs))  # tracing needs a tuple of inputs
                        else:
                            inputs = inputs.to(device)

                    labels = labels.to(device, non_blocking=True)

                    output = model(*inputs)

                    [self.f1_f[t](output, labels) for t in self.f1_type]
                    self.confusion_f(output, labels)
                    [self.accuracy_f[t](output, labels) for t in self.accuracy_type]
                    [self.precision_f[t](output, labels) for t in self.precision_type]
                    [self.recall_f[t](output, labels).detach() for t in self.recall_type]

        self.computed = True

    def display(self):
        """Display computed metrics to console.

        Prints category names (if provided), F1 scores, accuracy, precision, recall,
        and confusion matrix.
        """
        if self.category_name is not None:
            for i, name in enumerate(self.category_name):
                logger.info(f'{i}: {name}')

        for t, val in self.f1_f.items():
            logger.info(f'f1 {t}: {val.compute().detach().cpu().numpy()}')

        for t, val in self.accuracy_f.items():
            logger.info(f'accuracy {t}: {val.compute().detach().cpu().numpy()}')

        for t, val in self.precision_f.items():
            logger.info(f'precision {t}: {val.compute().detach().cpu().numpy()}')

        for t, val in self.recall_f.items():
            logger.info(f'recall {t}: {val.compute().detach().cpu().numpy()}')

        logger.info(f'Confusion matrix:\n{self.confusion_f.compute()}')

    def to_file(self, path):
        """Write computed metrics to a file.

        Args:
            path (str): Output file path.
        """
        with open(path, 'w') as f:

            if self.category_name is not None:
                for i, name in enumerate(self.category_name):
                    f.write(f'{i}: {name}\n')

            for t, val in self.f1_f.items():
                f.write(f'f1 {t}: {val.compute().detach().cpu().numpy()}\n')

            for t, val in self.accuracy_f.items():
                f.write(f'accuracy {t}: {val.compute().detach().cpu().numpy()}\n')

            for t, val in self.precision_f.items():
                f.write(f'precision {t}: {val.compute().detach().cpu().numpy()}\n')

            for t, val in self.recall_f.items():
                f.write(f'recall {t}: {val.compute().detach().cpu().numpy()}\n')

            confusion = self.confusion_f.compute().detach().cpu().numpy()
            n_digit = math.ceil(math.log10(confusion.max())) + 1
            np.savetxt(f, confusion, fmt=f'%{n_digit}.0d', delimiter=' ', newline=os.linesep)


class BadlyClassifyToGPKG:
    """Export misclassified samples to GeoPackage format for spatial analysis.

    Identifies samples where model predictions don't match reference labels and exports
    them as point geometries with prediction and reference label attributes.

    Attributes:
        results (list): List of misclassified sample records with geometry and properties.
    """

    def __init__(self):
        """Initialize BadlyClassifyToGPKG with empty results list."""
        self.results = []

    def compute(self, model, dataloader: DataLoader, device="cpu"):
        """Identify misclassified samples from model predictions.

        Runs inference on dataloader and stores records for samples where prediction
        differs from reference label.

        Args:
            model (torch.nn.Module): Model to evaluate.
            dataloader (DataLoader): DataLoader providing (inputs, labels, meta) batches.
                Meta must contain geometry information.
            device (str, optional): Device to run on. Defaults to "cpu".
        """

        self.results = []

        model = model.to(device)
        model.eval()

        with torch.inference_mode():
            with tqdm(total=len(dataloader), desc="Batch") as pbar:
                for i, data in enumerate(dataloader):
                    # Every data instance is an input + label pair
                    inputs, labels, meta = data

                    if device is not None:
                        if isinstance(inputs, (list, tuple)):
                            inputs = tuple(map(lambda x: x.to(device), inputs))  # tracing needs a tuple of inputs
                        else:
                            inputs = inputs.to(device)

                    labels = labels.to(device, non_blocking=True)

                    output = model(*inputs)

                    output = torch.argmax(output, dim=1)
                    output = output.detach().cpu().numpy()
                    labels = labels.detach().cpu().numpy()

                    for o, l, m in zip(output, labels, meta):
                        if o != l:
                            rec = {'geometry': mapping(m.geometry),
                                   'properties': OrderedDict([
                                       ('Ref_label', int(l)),
                                       ('Pred_label', int(o)),
                                   ])
                                   }
                            self.results.append(rec)

    def to_file(self, path, crs="4326"):
        """Write misclassified samples to GeoPackage file.

        Args:
            path (str): Output GeoPackage file path.
            crs (str, optional): EPSG code for coordinate reference system. Defaults to "4326".
        """
        #('Class', 'float:16')

        schema = {'geometry': 'Point',
                  'properties': OrderedDict([('Ref_label', 'int'),
                                             ('Pred_label', 'int')])
                  }
        crs = from_epsg(crs)

        with fiona.open(path, 'w',
                        driver='GPKG',
                        schema=schema,
                        crs=crs) as src:

            for record in self.results:
                src.write(record)
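End-to-end sketch (illustrative only, not shipped in the package): it shows the (inputs, labels, meta) batch contract these classes appear to expect. The toy Meta, ToyDataset, collate, and nn.Linear model below are placeholders standing in for the real eoml models and datasets.

import torch
from shapely.geometry import Point
from torch import nn
from torch.utils.data import DataLoader, Dataset

from eoml.torch.sample_statistic import BadlyClassifyToGPKG, ClassificationStats


class Meta:
    """Minimal stand-in for the metadata object; only .geometry is used here."""
    def __init__(self, geometry):
        self.geometry = geometry


class ToyDataset(Dataset):
    """Yields (inputs, label, meta) with inputs as a 1-tuple, matching model(*inputs)."""
    def __len__(self):
        return 16

    def __getitem__(self, i):
        return (torch.randn(4),), i % 3, Meta(Point(float(i), float(i)))


def collate(batch):
    # Keep inputs as a tuple of batched tensors and meta as a plain list of objects.
    xs = torch.stack([b[0][0] for b in batch])
    labels = torch.tensor([b[1] for b in batch])
    metas = [b[2] for b in batch]
    return (xs,), labels, metas


model = nn.Linear(4, 3)  # placeholder classifier with 3 classes
loader = DataLoader(ToyDataset(), batch_size=4, collate_fn=collate)

stats = ClassificationStats(num_class=3, device="cpu", category_name=["a", "b", "c"])
stats.compute(model, loader, device="cpu")
stats.display()                       # logs F1 / accuracy / precision / recall / confusion matrix
stats.to_file("metrics.txt")

errors = BadlyClassifyToGPKG()
errors.compute(model, loader, device="cpu")
errors.to_file("misclassified.gpkg")  # points carrying Ref_label / Pred_label attributes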