eoml 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. eoml/__init__.py +74 -0
  2. eoml/automation/__init__.py +7 -0
  3. eoml/automation/configuration.py +105 -0
  4. eoml/automation/dag.py +233 -0
  5. eoml/automation/experience.py +618 -0
  6. eoml/automation/tasks.py +825 -0
  7. eoml/bin/__init__.py +6 -0
  8. eoml/bin/clean_checkpoint.py +146 -0
  9. eoml/bin/land_cover_mapping_toml.py +435 -0
  10. eoml/bin/mosaic_images.py +137 -0
  11. eoml/data/__init__.py +7 -0
  12. eoml/data/basic_geo_data.py +214 -0
  13. eoml/data/dataset_utils.py +98 -0
  14. eoml/data/persistence/__init__.py +7 -0
  15. eoml/data/persistence/generic.py +253 -0
  16. eoml/data/persistence/lmdb.py +379 -0
  17. eoml/data/persistence/serializer.py +82 -0
  18. eoml/raster/__init__.py +7 -0
  19. eoml/raster/band.py +141 -0
  20. eoml/raster/dataset/__init__.py +6 -0
  21. eoml/raster/dataset/extractor.py +604 -0
  22. eoml/raster/raster_reader.py +602 -0
  23. eoml/raster/raster_utils.py +116 -0
  24. eoml/torch/__init__.py +7 -0
  25. eoml/torch/cnn/__init__.py +7 -0
  26. eoml/torch/cnn/augmentation.py +150 -0
  27. eoml/torch/cnn/dataset_evaluator.py +68 -0
  28. eoml/torch/cnn/db_dataset.py +605 -0
  29. eoml/torch/cnn/map_dataset.py +579 -0
  30. eoml/torch/cnn/map_dataset_const_mem.py +135 -0
  31. eoml/torch/cnn/outputs_transformer.py +130 -0
  32. eoml/torch/cnn/torch_utils.py +404 -0
  33. eoml/torch/cnn/training_dataset.py +241 -0
  34. eoml/torch/cnn/windows_dataset.py +120 -0
  35. eoml/torch/dataset/__init__.py +6 -0
  36. eoml/torch/dataset/shade_dataset_tester.py +46 -0
  37. eoml/torch/dataset/shade_tree_dataset_creators.py +537 -0
  38. eoml/torch/model_low_use.py +507 -0
  39. eoml/torch/models.py +282 -0
  40. eoml/torch/resnet.py +437 -0
  41. eoml/torch/sample_statistic.py +260 -0
  42. eoml/torch/trainer.py +782 -0
  43. eoml/torch/trainer_v2.py +253 -0
  44. eoml-0.9.0.dist-info/METADATA +93 -0
  45. eoml-0.9.0.dist-info/RECORD +47 -0
  46. eoml-0.9.0.dist-info/WHEEL +4 -0
  47. eoml-0.9.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,507 @@
1
+ """Alternative CNN architectures for image classification.
2
+
3
+ This module provides various CNN architectures with different configurations
4
+ of convolutional and dense layers. Includes models with and without batch
5
+ normalization, dropout, and max pooling variations.
6
+ """
7
+
8
+ import logging
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from eoml.torch.cnn.torch_utils import conv_out_sizes
12
+ from torch import nn
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class Conv2NormPlanet(nn.Module):
    """Two conv/batch-norm stages, a max-pool, and a four-layer dense head.

    Layout: conv(7x7, stride 2) -> BN -> ReLU -> conv(5x5, stride 1) -> BN ->
    ReLU -> max-pool(3x3, stride 2) -> flatten -> three
    (Linear -> ReLU -> Dropout) layers of width 4096, 4096, 2048 ->
    Linear to ``n_out``.

    Args:
        in_size: Spatial size of the input patch. ``fc1`` is sized with the
            last entry of ``input_sizes`` squared, so square inputs are
            assumed.
        n_bands: Number of input channels.
        n_out: Number of outputs; ``forward`` returns raw ``fc4`` values
            (no softmax).
        p_drop: Dropout probability after each hidden dense layer.
    """

    def __init__(self, in_size, n_bands, n_out, p_drop=0.4):
        super().__init__()

        # The first two kernel/stride entries describe the convolutions, the
        # last pair describes the max-pool.
        kernels = [7, 5, 3]
        strides = [2, 1, 2]
        filters = [256, 3 * 128]
        widths = [2 * 2048, 2 * 2048, 2048]

        self.in_size = in_size
        self.n_bands = n_bands
        self.conv = kernels
        self.pad = 0
        self.stride = strides
        self.n_filter = filters

        # Project helper; its last entry is used below as the side length of
        # the pooled feature map (NOTE(review): confirm conv_out_sizes semantics).
        self.input_sizes = conv_out_sizes(in_size, kernels, strides, self.pad)
        logger.debug(f"Conv2NormPlanet input sizes: {self.input_sizes}")

        self.denses = widths

        side = self.input_sizes[-1]
        flat_features = filters[-1] * side * side

        self.conv1 = nn.Conv2d(n_bands, filters[0], kernels[0],
                               stride=strides[0], padding=self.pad)
        self.conv1_bn = nn.BatchNorm2d(filters[0])
        self.conv2 = nn.Conv2d(filters[0], filters[1], kernels[1],
                               stride=strides[1], padding=self.pad)
        self.conv2_bn = nn.BatchNorm2d(filters[1])
        self.pool1 = nn.MaxPool2d(kernels[-1], stride=strides[2])
        self.fc1 = nn.Linear(flat_features, widths[0])
        self.drop1 = nn.Dropout(p_drop)
        self.fc2 = nn.Linear(widths[0], widths[1])
        self.drop2 = nn.Dropout(p_drop)
        self.fc3 = nn.Linear(widths[1], widths[2])
        self.drop3 = nn.Dropout(p_drop)
        self.fc4 = nn.Linear(widths[2], n_out)

    def forward(self, x):
        h = F.relu(self.conv1_bn(self.conv1(x)))
        h = F.relu(self.conv2_bn(self.conv2(h)))
        # Pool, then flatten everything except the batch dimension.
        h = torch.flatten(self.pool1(h), 1)
        h = self.drop1(F.relu(self.fc1(h)))
        h = self.drop2(F.relu(self.fc2(h)))
        h = self.drop3(F.relu(self.fc3(h)))
        # Raw scores; no softmax is applied here.
        return self.fc4(h)
67
+
68
+
69
+
70
+
71
class AlexNetMod(nn.Module):
    """AlexNet-style classifier with a shortened feature extractor.

    Layout: conv(5x5, stride 2) -> ReLU -> max-pool(3x3, stride 1) ->
    conv(5x5) -> ReLU -> max-pool(3x3, stride 1) -> conv(3x3) -> ReLU ->
    adaptive average pool to 6x6 -> Dropout -> Linear(4096) -> ReLU ->
    Dropout -> Linear(4096) -> ReLU -> Linear(num_classes).

    Args:
        num_classes: Number of output logits.
        bands: Number of input channels.
        dropout: Dropout probability used in the classifier.
    """

    def __init__(self, num_classes: int, bands, dropout: float = 0.5) -> None:
        super().__init__()
        # Channel widths of the three convolutions. Each conv's in_channels
        # must equal the previous conv's out_channels, and the classifier's
        # input must equal the last width. They were previously hard-coded
        # inconsistently (128 -> 64 and 384 -> 256), so forward() always
        # failed with a shape error.
        c1, c2, c3 = 128, 384, 384
        # Layer order (and thus the state_dict indices features.0/3/6) is kept.
        self.features = nn.Sequential(
            nn.Conv2d(bands, c1, kernel_size=5, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(c1, c2, kernel_size=5, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )
        # Always yields a 6x6 map, so the classifier size is independent of
        # the input patch size.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(c3 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        # Flatten all dimensions except batch.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
106
+
107
+
108
+
109
+
110
class Conv3Dense3(nn.Module):
    """Three unpadded convolutions followed by a four-layer dense head.

    Layout: conv(5x5) -> ReLU -> conv(5x5) -> ReLU -> conv(3x3) -> ReLU ->
    flatten -> three (Linear -> ReLU -> Dropout) layers of width 2048 ->
    Linear to ``n_out``. No normalization layers are used.

    (This class used to be defined twice; the first copy had no ``forward``
    and was silently shadowed by the second. Only one definition remains.)

    Args:
        in_size: Spatial size of the input patch. ``fc1`` is sized with the
            last entry of ``input_sizes`` squared, so square inputs are
            assumed.
        n_bands: Number of input channels.
        n_out: Number of outputs; ``forward`` returns raw ``fc4`` values
            (no softmax).
        p_drop: Dropout probability after each hidden dense layer.
    """

    def __init__(self, in_size, n_bands, n_out, p_drop=0.4):
        # nn.Module must be initialised before attributes are assigned.
        super().__init__()
        self.in_size = in_size
        self.n_bands = n_bands

        self.conv = [5, 5, 3]
        self.pad = 0
        self.stride = 1

        self.n_filter = [128, 128, 256]

        # Project helper; its last entry is used below as the side length of
        # the final feature map (NOTE(review): confirm conv_out_sizes semantics).
        self.input_sizes = conv_out_sizes(in_size, self.conv, self.stride, self.pad)

        self.denses = [2048, 2048, 2048]

        self.conv1 = nn.Conv2d(in_channels=n_bands, out_channels=self.n_filter[0],
                               kernel_size=self.conv[0], padding=self.pad)
        self.conv2 = nn.Conv2d(in_channels=self.n_filter[0], out_channels=self.n_filter[1],
                               kernel_size=self.conv[1], padding=self.pad)
        self.conv3 = nn.Conv2d(in_channels=self.n_filter[1], out_channels=self.n_filter[2],
                               kernel_size=self.conv[2], padding=self.pad)
        self.fc1 = nn.Linear(self.n_filter[-1] * self.input_sizes[-1] * self.input_sizes[-1], self.denses[0])
        self.drop1 = nn.Dropout(p_drop)
        self.fc2 = nn.Linear(self.denses[0], self.denses[1])
        self.drop2 = nn.Dropout(p_drop)
        self.fc3 = nn.Linear(self.denses[1], self.denses[2])
        self.drop3 = nn.Dropout(p_drop)
        self.fc4 = nn.Linear(self.denses[2], n_out)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # Flatten all dimensions except batch.
        x = torch.flatten(x, 1)
        x = self.drop1(F.relu(self.fc1(x)))
        x = self.drop2(F.relu(self.fc2(x)))
        x = self.drop3(F.relu(self.fc3(x)))
        return self.fc4(x)
190
+
191
class Conv3Dense3Norm(nn.Module):
    """Batch-normalized variant of a three-conv / four-dense classifier.

    Layout: conv(5x5) -> BN -> ReLU -> conv(5x5) -> BN -> ReLU ->
    conv(3x3) -> BN -> ReLU -> flatten -> three (Linear -> ReLU -> Dropout)
    layers of width 2048 -> Linear to ``n_out``. Convolutions are unpadded.

    Args:
        in_size: Spatial size of the input patch. ``fc1`` is sized with the
            last entry of ``input_sizes`` squared, so square inputs are
            assumed.
        n_bands: Number of input channels.
        n_out: Number of outputs; ``forward`` returns raw ``fc4`` values.
        p_drop: Dropout probability after each hidden dense layer.
    """

    def __init__(self, in_size, n_bands, n_out, p_drop=0.4):
        super().__init__()

        kernels = [5, 5, 3]
        filters = [128, 128, 256]
        widths = [2048, 2048, 2048]

        self.in_size = in_size
        self.n_bands = n_bands
        self.conv = kernels
        self.pad = 0
        self.stride = 1
        self.n_filter = filters

        # Project helper; the last entry is the side of the final feature map
        # as used for fc1 (NOTE(review): confirm conv_out_sizes semantics).
        self.input_sizes = conv_out_sizes(in_size, kernels, self.stride, self.pad)

        self.denses = widths

        side = self.input_sizes[-1]
        flat_features = filters[-1] * side * side

        self.conv1 = nn.Conv2d(n_bands, filters[0], kernels[0], padding=self.pad)
        self.conv1_bn = nn.BatchNorm2d(filters[0])
        self.conv2 = nn.Conv2d(filters[0], filters[1], kernels[1], padding=self.pad)
        self.conv2_bn = nn.BatchNorm2d(filters[1])
        self.conv3 = nn.Conv2d(filters[1], filters[2], kernels[2], padding=self.pad)
        self.conv3_bn = nn.BatchNorm2d(filters[2])
        self.fc1 = nn.Linear(flat_features, widths[0])
        self.drop1 = nn.Dropout(p_drop)
        self.fc2 = nn.Linear(widths[0], widths[1])
        self.drop2 = nn.Dropout(p_drop)
        self.fc3 = nn.Linear(widths[1], widths[2])
        self.drop3 = nn.Dropout(p_drop)
        self.fc4 = nn.Linear(widths[2], n_out)

    def forward(self, x):
        h = F.relu(self.conv1_bn(self.conv1(x)))
        h = F.relu(self.conv2_bn(self.conv2(h)))
        h = F.relu(self.conv3_bn(self.conv3(h)))
        # Keep the batch dimension; merge channels and spatial dims.
        h = torch.flatten(h, 1)
        h = self.drop1(F.relu(self.fc1(h)))
        h = self.drop2(F.relu(self.fc2(h)))
        h = self.drop3(F.relu(self.fc3(h)))
        return self.fc4(h)
242
+
243
+
244
class Conv2Dense3(nn.Module):
    """Two unpadded convolutions followed by a four-layer dense head.

    Layout: conv(4x4) -> ReLU -> conv(3x3) -> ReLU -> flatten -> three
    (Linear -> ReLU -> Dropout) layers of width 2048 -> Linear to ``n_out``.

    Args:
        in_size: Spatial size of the input patch. ``fc1`` is sized with the
            last entry of ``input_sizes`` squared, so square inputs are
            assumed.
        n_bands: Number of input channels.
        n_out: Number of outputs; ``forward`` returns raw ``fc4`` values.
        p_drop: Dropout probability after each hidden dense layer.
    """

    def __init__(self, in_size, n_bands, n_out, p_drop=0.2):
        super().__init__()

        kernels = [4, 3]
        filters = [64, 128]
        widths = [2048, 2048, 2048]

        self.in_size = in_size
        self.n_bands = n_bands
        self.conv = kernels
        self.pad = 0
        self.stride = 1
        self.n_filter = filters

        # Project helper; the last entry is the side of the final feature map
        # as used for fc1 (NOTE(review): confirm conv_out_sizes semantics).
        self.input_sizes = conv_out_sizes(in_size, kernels, self.stride, self.pad)

        self.denses = widths

        side = self.input_sizes[-1]
        flat_features = filters[-1] * side * side

        self.conv1 = nn.Conv2d(n_bands, filters[0], kernels[0], padding=self.pad)
        self.conv2 = nn.Conv2d(filters[0], filters[1], kernels[1], padding=self.pad)
        self.fc1 = nn.Linear(flat_features, widths[0])
        self.drop1 = nn.Dropout(p_drop)
        self.fc2 = nn.Linear(widths[0], widths[1])
        self.drop2 = nn.Dropout(p_drop)
        self.fc3 = nn.Linear(widths[1], widths[2])
        self.drop3 = nn.Dropout(p_drop)
        self.fc4 = nn.Linear(widths[2], n_out)

    def forward(self, x):
        h = F.relu(self.conv2(F.relu(self.conv1(x))))
        # Keep the batch dimension; merge channels and spatial dims.
        h = torch.flatten(h, 1)
        h = self.drop1(F.relu(self.fc1(h)))
        h = self.drop2(F.relu(self.fc2(h)))
        h = self.drop3(F.relu(self.fc3(h)))
        return self.fc4(h)
288
+
289
+
290
class Conv2DropDense3(nn.Module):
    """Two unpadded convolutions followed by a four-layer dense head.

    Layout: conv(4x4) -> ReLU -> conv(3x3) -> ReLU -> flatten -> three
    (Linear -> ReLU -> Dropout) layers of width 2048 -> Linear to ``n_out``.

    NOTE(review): despite the name, no dropout is applied to the convolutional
    stages; the layers and defaults are identical to ``Conv2Dense3``. Confirm
    whether conv-level dropout was intended.

    Args:
        in_size: Spatial size of the input patch. ``fc1`` is sized with the
            last entry of ``input_sizes`` squared, so square inputs are
            assumed.
        n_bands: Number of input channels.
        n_out: Number of outputs; ``forward`` returns raw ``fc4`` values.
        p_drop: Dropout probability after each hidden dense layer.
    """

    def __init__(self, in_size, n_bands, n_out, p_drop=0.2):
        super().__init__()

        kernels = [4, 3]
        filters = [64, 128]
        widths = [2048, 2048, 2048]

        self.in_size = in_size
        self.n_bands = n_bands
        self.conv = kernels
        self.pad = 0
        self.stride = 1
        self.n_filter = filters

        # Project helper; the last entry is the side of the final feature map
        # as used for fc1 (NOTE(review): confirm conv_out_sizes semantics).
        self.input_sizes = conv_out_sizes(in_size, kernels, self.stride, self.pad)

        self.denses = widths

        side = self.input_sizes[-1]
        flat_features = filters[-1] * side * side

        self.conv1 = nn.Conv2d(n_bands, filters[0], kernels[0], padding=self.pad)
        self.conv2 = nn.Conv2d(filters[0], filters[1], kernels[1], padding=self.pad)
        self.fc1 = nn.Linear(flat_features, widths[0])
        self.drop1 = nn.Dropout(p_drop)
        self.fc2 = nn.Linear(widths[0], widths[1])
        self.drop2 = nn.Dropout(p_drop)
        self.fc3 = nn.Linear(widths[1], widths[2])
        self.drop3 = nn.Dropout(p_drop)
        self.fc4 = nn.Linear(widths[2], n_out)

    def forward(self, x):
        feats = F.relu(self.conv1(x))
        feats = F.relu(self.conv2(feats))
        # Keep the batch dimension; merge channels and spatial dims.
        hidden = torch.flatten(feats, 1)
        hidden = self.drop1(F.relu(self.fc1(hidden)))
        hidden = self.drop2(F.relu(self.fc2(hidden)))
        hidden = self.drop3(F.relu(self.fc3(hidden)))
        return self.fc4(hidden)
334
+
335
+
336
class ConvJavaSmall(nn.Module):
    """Single convolution + max-pool followed by a four-layer dense head.

    Layout: conv(4x4, 128 filters) -> ReLU -> max-pool(2x2) -> flatten ->
    three (Linear -> ReLU -> Dropout) layers of width 2048 -> Linear to
    ``n_out``.

    Args:
        in_size: Spatial size of the input patch. ``fc1`` is sized with the
            last entry of ``input_sizes`` squared, so square inputs are
            assumed.
        n_bands: Number of input channels.
        n_out: Number of outputs; ``forward`` returns raw ``fc4`` values
            (no softmax).
        p_drop: Dropout probability after each hidden dense layer.
    """

    def __init__(self, in_size, n_bands, n_out, p_drop=0.4):
        super().__init__()

        # Entry 0 describes the convolution, entry 1 the max-pool.
        kernels = [4, 2]
        strides = [1, 2]
        filters = [128]
        widths = [2048, 2048, 2048]

        self.in_size = in_size
        self.n_bands = n_bands
        self.conv = kernels
        self.pad = 0
        self.stride = strides
        self.n_filter = filters

        # Project helper; the last entry is the side of the pooled feature map
        # as used for fc1 (NOTE(review): confirm conv_out_sizes semantics).
        self.input_sizes = conv_out_sizes(in_size, kernels, strides, self.pad)

        self.denses = widths

        side = self.input_sizes[-1]
        flat_features = filters[-1] * side * side

        self.conv1 = nn.Conv2d(n_bands, filters[0], kernels[0], padding=self.pad)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(flat_features, widths[0])
        self.drop1 = nn.Dropout(p_drop)
        self.fc2 = nn.Linear(widths[0], widths[1])
        self.drop2 = nn.Dropout(p_drop)
        self.fc3 = nn.Linear(widths[1], widths[2])
        self.drop3 = nn.Dropout(p_drop)
        self.fc4 = nn.Linear(widths[2], n_out)

    def forward(self, x):
        h = self.pool1(F.relu(self.conv1(x)))
        # Keep the batch dimension; merge channels and spatial dims.
        h = torch.flatten(h, 1)
        h = self.drop1(F.relu(self.fc1(h)))
        h = self.drop2(F.relu(self.fc2(h)))
        h = self.drop3(F.relu(self.fc3(h)))
        return self.fc4(h)
379
+
380
+
381
+
382
class ConvJavaSmallNorm(nn.Module):
    """Batch-normalized single-conv model with a four-layer dense head.

    Layout: conv(4x4, 128 filters) -> BN -> ReLU -> max-pool(2x2) ->
    flatten -> three (Linear -> ReLU -> Dropout) layers of width 2048 ->
    Linear to ``n_out``.

    Args:
        in_size: Spatial size of the input patch. ``fc1`` is sized with the
            last entry of ``input_sizes`` squared, so square inputs are
            assumed.
        n_bands: Number of input channels.
        n_out: Number of outputs; ``forward`` returns raw ``fc4`` values
            (no softmax).
        p_drop: Dropout probability after each hidden dense layer.
    """

    def __init__(self, in_size, n_bands, n_out, p_drop=0.4):
        super().__init__()

        # Entry 0 describes the convolution, entry 1 the max-pool.
        kernels = [4, 2]
        strides = [1, 2]
        filters = [128]
        widths = [2048, 2048, 2048]

        self.in_size = in_size
        self.n_bands = n_bands
        self.conv = kernels
        self.pad = 0
        self.stride = strides
        self.n_filter = filters

        # Project helper; the last entry is the side of the pooled feature map
        # as used for fc1 (NOTE(review): confirm conv_out_sizes semantics).
        self.input_sizes = conv_out_sizes(in_size, kernels, strides, self.pad)

        self.denses = widths

        side = self.input_sizes[-1]
        flat_features = filters[-1] * side * side

        self.conv1 = nn.Conv2d(n_bands, filters[0], kernels[0], padding=self.pad)
        self.conv1_bn = nn.BatchNorm2d(filters[0])
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(flat_features, widths[0])
        self.drop1 = nn.Dropout(p_drop)
        self.fc2 = nn.Linear(widths[0], widths[1])
        self.drop2 = nn.Dropout(p_drop)
        self.fc3 = nn.Linear(widths[1], widths[2])
        self.drop3 = nn.Dropout(p_drop)
        self.fc4 = nn.Linear(widths[2], n_out)

    def forward(self, x):
        normed = self.conv1_bn(self.conv1(x))
        h = self.pool1(F.relu(normed))
        # Keep the batch dimension; merge channels and spatial dims.
        h = torch.flatten(h, 1)
        h = self.drop1(F.relu(self.fc1(h)))
        h = self.drop2(F.relu(self.fc2(h)))
        h = self.drop3(F.relu(self.fc3(h)))
        return self.fc4(h)
426
+
427
class Conv2Norm(nn.Module):
    """Two conv/batch-norm stages, a max-pool, and a four-layer dense head.

    Layout: conv(5x5) -> BN -> ReLU -> conv(3x3) -> BN -> ReLU ->
    max-pool(2x2, stride 2) -> flatten -> three (Linear -> ReLU -> Dropout)
    layers of width 4096 -> Linear to ``n_out``.

    Args:
        in_size: Spatial size of the input patch. ``fc1`` is sized with the
            last entry of ``input_sizes`` squared, so square inputs are
            assumed.
        n_bands: Number of input channels.
        n_out: Number of outputs; ``forward`` returns raw ``fc4`` values
            (no softmax).
        p_drop: Dropout probability after each hidden dense layer.
    """

    def __init__(self, in_size, n_bands, n_out, p_drop=0.4):
        # nn.Module must be initialised before attributes are assigned.
        super().__init__()
        self.in_size = in_size
        self.n_bands = n_bands

        # Entries 0-1 describe the convolutions, entry 2 the max-pool.
        self.conv = [5, 3, 2]
        self.pad = 0
        self.stride = [1, 1, 2]

        self.n_filter = [256, 384]

        # Project helper; the last entry is the side of the pooled feature map
        # as used for fc1 (NOTE(review): confirm conv_out_sizes semantics).
        self.input_sizes = conv_out_sizes(in_size, self.conv, self.stride, self.pad)

        self.denses = [2 * 2048, 2 * 2048, 2 * 2048]

        self.conv1 = nn.Conv2d(in_channels=n_bands, out_channels=self.n_filter[0],
                               kernel_size=self.conv[0], padding=self.pad)
        self.conv1_bn = nn.BatchNorm2d(self.n_filter[0])
        # conv2 consumes conv1's feature maps. It used to take n_bands input
        # channels, which mismatched conv1's output whenever n_bands != 256.
        self.conv2 = nn.Conv2d(in_channels=self.n_filter[0], out_channels=self.n_filter[1],
                               kernel_size=self.conv[1], padding=self.pad)
        self.conv2_bn = nn.BatchNorm2d(self.n_filter[1])
        # Explicit stride equals MaxPool2d's default (the kernel size, 2).
        self.pool1 = nn.MaxPool2d(self.conv[-1], stride=self.stride[-1])
        self.fc1 = nn.Linear(self.n_filter[-1] * self.input_sizes[-1] * self.input_sizes[-1], self.denses[0])
        self.drop1 = nn.Dropout(p_drop)
        self.fc2 = nn.Linear(self.denses[0], self.denses[1])
        self.drop2 = nn.Dropout(p_drop)
        self.fc3 = nn.Linear(self.denses[1], self.denses[2])
        self.drop3 = nn.Dropout(p_drop)
        self.fc4 = nn.Linear(self.denses[2], n_out)

    # Previously missing: without it, calling the model raised
    # NotImplementedError from nn.Module.forward.
    def forward(self, x):
        x = F.relu(self.conv1_bn(self.conv1(x)))
        x = F.relu(self.conv2_bn(self.conv2(x)))
        x = self.pool1(x)
        # Flatten all dimensions except batch.
        x = torch.flatten(x, 1)
        x = self.drop1(F.relu(self.fc1(x)))
        x = self.drop2(F.relu(self.fc2(x)))
        x = self.drop3(F.relu(self.fc3(x)))
        return self.fc4(x)
457
+
458
class Conv2NormV2(nn.Module):
    """Two conv/batch-norm stages, a max-pool, and a four-layer dense head.

    Layout: conv(5x5) -> BN -> ReLU -> conv(3x3) -> BN -> ReLU ->
    max-pool(2x2, stride 2) -> flatten -> three (Linear -> ReLU -> Dropout)
    layers of width 4096 -> Linear to ``n_out``.

    Args:
        in_size: Spatial size of the input patch. ``fc1`` is sized with the
            last entry of ``input_sizes`` squared, so square inputs are
            assumed.
        n_bands: Number of input channels.
        n_out: Number of outputs; ``forward`` returns raw ``fc4`` values
            (no softmax).
        p_drop: Dropout probability after each hidden dense layer.
    """

    def __init__(self, in_size, n_bands, n_out, p_drop=0.4):
        super().__init__()

        # Entries 0-1 describe the convolutions, entry 2 the max-pool.
        kernels = [5, 3, 2]
        strides = [1, 1, 2]
        filters = [256, 384]
        widths = [2 * 2048, 2 * 2048, 2 * 2048]

        self.in_size = in_size
        self.n_bands = n_bands
        self.conv = kernels
        self.pad = 0
        self.stride = strides
        self.n_filter = filters

        # Project helper; the last entry is the side of the pooled feature map
        # as used for fc1 (NOTE(review): confirm conv_out_sizes semantics).
        self.input_sizes = conv_out_sizes(in_size, kernels, strides, self.pad)
        logger.debug(f"Input sizes: {self.input_sizes}")
        self.denses = widths

        side = self.input_sizes[-1]
        flat_features = filters[-1] * side * side

        self.conv1 = nn.Conv2d(n_bands, filters[0], kernels[0], padding=self.pad)
        self.conv1_bn = nn.BatchNorm2d(filters[0])
        self.conv2 = nn.Conv2d(filters[0], filters[1], kernels[1], padding=self.pad)
        self.conv2_bn = nn.BatchNorm2d(filters[1])
        self.pool1 = nn.MaxPool2d(kernels[-1], stride=strides[-1])
        self.fc1 = nn.Linear(flat_features, widths[0])
        self.drop1 = nn.Dropout(p_drop)
        self.fc2 = nn.Linear(widths[0], widths[1])
        self.drop2 = nn.Dropout(p_drop)
        self.fc3 = nn.Linear(widths[1], widths[2])
        self.drop3 = nn.Dropout(p_drop)
        self.fc4 = nn.Linear(widths[2], n_out)

    def forward(self, x):
        h = F.relu(self.conv1_bn(self.conv1(x)))
        h = F.relu(self.conv2_bn(self.conv2(h)))
        # Pool, then flatten everything except the batch dimension.
        h = torch.flatten(self.pool1(h), 1)
        h = self.drop1(F.relu(self.fc1(h)))
        h = self.drop2(F.relu(self.fc2(h)))
        h = self.drop3(F.relu(self.fc3(h)))
        # Raw scores; no softmax is applied here.
        return self.fc4(h)