eoml 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eoml/__init__.py +74 -0
- eoml/automation/__init__.py +7 -0
- eoml/automation/configuration.py +105 -0
- eoml/automation/dag.py +233 -0
- eoml/automation/experience.py +618 -0
- eoml/automation/tasks.py +825 -0
- eoml/bin/__init__.py +6 -0
- eoml/bin/clean_checkpoint.py +146 -0
- eoml/bin/land_cover_mapping_toml.py +435 -0
- eoml/bin/mosaic_images.py +137 -0
- eoml/data/__init__.py +7 -0
- eoml/data/basic_geo_data.py +214 -0
- eoml/data/dataset_utils.py +98 -0
- eoml/data/persistence/__init__.py +7 -0
- eoml/data/persistence/generic.py +253 -0
- eoml/data/persistence/lmdb.py +379 -0
- eoml/data/persistence/serializer.py +82 -0
- eoml/raster/__init__.py +7 -0
- eoml/raster/band.py +141 -0
- eoml/raster/dataset/__init__.py +6 -0
- eoml/raster/dataset/extractor.py +604 -0
- eoml/raster/raster_reader.py +602 -0
- eoml/raster/raster_utils.py +116 -0
- eoml/torch/__init__.py +7 -0
- eoml/torch/cnn/__init__.py +7 -0
- eoml/torch/cnn/augmentation.py +150 -0
- eoml/torch/cnn/dataset_evaluator.py +68 -0
- eoml/torch/cnn/db_dataset.py +605 -0
- eoml/torch/cnn/map_dataset.py +579 -0
- eoml/torch/cnn/map_dataset_const_mem.py +135 -0
- eoml/torch/cnn/outputs_transformer.py +130 -0
- eoml/torch/cnn/torch_utils.py +404 -0
- eoml/torch/cnn/training_dataset.py +241 -0
- eoml/torch/cnn/windows_dataset.py +120 -0
- eoml/torch/dataset/__init__.py +6 -0
- eoml/torch/dataset/shade_dataset_tester.py +46 -0
- eoml/torch/dataset/shade_tree_dataset_creators.py +537 -0
- eoml/torch/model_low_use.py +507 -0
- eoml/torch/models.py +282 -0
- eoml/torch/resnet.py +437 -0
- eoml/torch/sample_statistic.py +260 -0
- eoml/torch/trainer.py +782 -0
- eoml/torch/trainer_v2.py +253 -0
- eoml-0.9.0.dist-info/METADATA +93 -0
- eoml-0.9.0.dist-info/RECORD +47 -0
- eoml-0.9.0.dist-info/WHEEL +4 -0
- eoml-0.9.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from eoml.torch.cnn.map_dataset import MapResultAggregator, BatchMeta, IterableMapDataset
|
|
3
|
+
from rasterio.windows import Window
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ConstMemBatchMeta(BatchMeta):
    """Batch metadata that additionally records how many samples are valid.

    Entries of a batch beyond ``count`` (if any) are padding; the matching
    aggregator keeps only ``values[:count]``.

    Attributes:
        window: Target window the batch belongs to (may be ``None`` until the
            dataset iterator assigns it).
        is_finished: True for the last batch produced for a target window.
        count: Number of valid (non-padding) samples in the batch.
        worker: Identifier of the worker that produced the batch.
    """

    def __init__(self, window, is_finished, count, worker):
        super().__init__(window, is_finished, worker)
        # Assigned here as well as in the parent so the attributes exist no
        # matter how BatchMeta stores its arguments.
        self.window, self.is_finished = window, is_finished
        self.count = count
        self.worker = worker
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Buffer:
    """Holder for the source raster block currently being cut into samples.

    Each call to :meth:`store` allocates a new tensor on ``device`` instead of
    overwriting the previous one, so a block that may still be referenced by a
    pending computation is never modified in place.

    Args:
        bands: Number of bands reported by :attr:`shape`.
        height: Not used by this class; kept for interface compatibility.
        width: Not used by this class; kept for interface compatibility.
        device: Torch device on which stored data is allocated.
    """

    def __init__(self, bands, height, width, device):
        self.buffer = None
        self.device = device
        self.stored_height = 0
        self.stored_width = 0
        self.bands = bands

    def store(self, data):
        """Copy ``data`` (a ``(channel, height, width)`` tensor) into a new tensor on ``self.device``.

        The new tensor is created with ``torch.empty`` without an explicit dtype,
        so the copied values are cast to torch's default dtype.
        """
        channel, height, width = data.shape
        # A fresh allocation avoids writing into a buffer that may still be used
        # by a computation on the GPU; keeping one buffer per prefetched batch
        # would avoid reallocating for every block.
        self.buffer = torch.empty((channel, height, width), device=self.device)
        self.buffer[:, 0:height, 0:width] = data
        self.stored_height = height
        self.stored_width = width

    def __getitem__(self, item):
        """Return ``self.buffer[item]``."""
        return self.buffer[item]

    def __setitem__(self, key, value):
        """Assign ``value`` to ``self.buffer[key]``."""
        self.buffer[key] = value

    @property
    def shape(self):
        """``(bands, stored_height, stored_width)`` of the currently stored block."""
        return self.bands, self.stored_height, self.stored_width
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class IterableMapDatasetConstMem(IterableMapDataset):
    """Iterable map dataset that yields fixed-size batches of sliding-window samples.

    For each target window, the source block (target window enlarged by the
    kernel footprint) is read once into a :class:`Buffer`. Samples are
    ``size x size`` ``narrow`` views into that block, and every batch is padded
    up to ``batch_size`` so the model always receives tensors of the same
    shape. The number of real samples is carried in ``ConstMemBatchMeta.count``.

    Original author's note: creates an aligned raster with cropped borders to
    take the convolution into account; if stride > 1, the window starting at
    the top-left corner of size stride x stride is filled with the value
    returned by the NN.
    NOTE(review): ``extract_tensor_iter`` below steps one pixel at a time and
    does not read ``self.stride``; confirm whether stride is handled elsewhere.

    Attributes such as ``size``, ``half_size``, ``off_x``, ``off_y``,
    ``worker_id``, ``target_windows``, ``batch_size`` and ``read_windows`` are
    presumably provided by ``IterableMapDataset`` (not visible here).
    """

    def __init__(self, raster_reader, kernel_size, target_windows, off_x, off_y, stride=1, batch_size=1024,
                 device="cpu"):

        super().__init__(raster_reader, kernel_size, target_windows, off_x, off_y, stride, batch_size, device)

        self.max_width, self.max_height = self._max_win_size(target_windows, kernel_size)
        # Buffer takes (bands, height, width); the sizes are informational only.
        self.buffer = Buffer(raster_reader.n_band, self.max_height, self.max_width, device)

    def _max_win_size(self, windows, size):
        """Return ``(width, height)`` of the largest target window, each enlarged by ``size``.

        :param windows: iterable of ``((i, j), window)`` tuples
        :param size: kernel size added to each dimension
        """
        width = max(entry[1].width for entry in windows) + size
        height = max(entry[1].height for entry in windows) + size
        return width, height

    def __iter__(self):
        """Iterate over the dataset.

        Yields at most ``batch_size`` valid samples per batch; a batch never
        spans two target windows.

        :return: generator of ``(samples, meta)`` where ``samples`` stacks
            ``batch_size`` patches of shape ``(bands, size, size)`` and ``meta``
            is a :class:`ConstMemBatchMeta` (target window, is_block_finished,
            valid count, worker id)
        """
        for _, window in self.target_windows:
            (col_off, row_off, w_width, w_height) = window.flatten()
            # Source window: the target window shifted by the raster offset and
            # enlarged so every target pixel has a full kernel around it.
            window_source = Window(col_off + self.off_x - self.half_size, row_off + self.off_y - self.half_size,
                                   w_width + self.size - 1, w_height + self.size - 1)

            np_buff = self.read_windows(window_source)

            self.buffer.store(torch.from_numpy(np_buff))

            for sample, meta in self.extract_tensor_iter(self.buffer, self.batch_size):
                meta.window = window
                yield sample, meta

    def extract_tensor_iter(self, data, batch_size):
        """Yield batches of ``size x size`` patches for every position of ``data``.

        :param data: the :class:`Buffer` holding the current source block (a
            ``(channel, height, width)`` tensor is also accepted)
        :param batch_size: number of samples in every yielded batch
        :return: generator of ``(batch, ConstMemBatchMeta)``. Every batch holds
            exactly ``batch_size`` samples; only the last one of the block is
            marked finished, and its entries past ``meta.count`` are padding.
        """
        source = data.buffer if isinstance(data, Buffer) else data
        _, height, width = data.shape

        # Number of positions where a full kernel fits inside the block.
        n_rows = height - self.size + 1
        n_cols = width - self.size + 1

        samples = []
        for i in range(n_rows):
            for j in range(n_cols):
                if len(samples) == batch_size:
                    yield torch.stack(samples, dim=0), ConstMemBatchMeta(None, False, batch_size, self.worker_id)
                    samples = []
                samples.append(source.narrow(1, i, self.size).narrow(2, j, self.size))

        valid_count = len(samples)
        # Pad the final batch to batch_size with a view of the top-left patch
        # (cheaper than allocating empty tensors) so every batch has the same
        # shape; consumers keep only the first `valid_count` results.
        if valid_count < batch_size:
            padding = source.narrow(1, 0, self.size).narrow(2, 0, self.size)
            samples.extend([padding] * (batch_size - valid_count))

        yield torch.stack(samples, dim=0), ConstMemBatchMeta(None, True, valid_count, self.worker_id)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class MapResultAggregatorConstMem(MapResultAggregator):
    """Result aggregator for batches described by :class:`ConstMemBatchMeta`.

    Only the first ``meta.count`` results of each batch are forwarded to
    :class:`MapResultAggregator`; the remaining rows are padding.
    """

    def __init__(self, path_out, transform_result_f, n_windows, write_profile):
        super().__init__(path_out, transform_result_f, n_windows, write_profile)

    def submit_result(self, values, meta: ConstMemBatchMeta):
        """Drop padding rows from ``values`` and submit the rest."""
        valid_values = values[:meta.count]
        super().submit_result(valid_values, meta)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
"""Output transformation classes for neural network predictions.
|
|
2
|
+
|
|
3
|
+
This module provides classes for transforming raw neural network outputs into
|
|
4
|
+
usable formats for geospatial mapping, including classification, regression,
|
|
5
|
+
and probability outputs.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import List
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class OutputTransformer:
    """Abstract base class for transforming neural network outputs.

    Defines the interface for converting raw NN outputs to map-ready values.
    Subclasses must implement :meth:`__call__`.

    Attributes:
        _shape: Shape of the output data (first entry is the number of bands).
        _dtype: Data type for output values.
        _nodata: No-data value for invalid outputs.
    """
    def __init__(self, shape, dtype, nodata):
        self._shape = shape
        self._dtype = dtype
        self._nodata = nodata

    def __call__(self, v):
        """Transform the raw network output ``v``; implemented by subclasses."""
        raise NotImplementedError(f"{type(self).__name__} must implement __call__")

    @property
    def shape(self):
        """Shape of the output."""
        return self._shape

    @property
    def bands(self):
        """Number of output bands (first entry of :attr:`shape`)."""
        return self.shape[0]

    @property
    def dtype(self):
        """Data type of the transformed output."""
        return self._dtype

    @property
    def nodata(self):
        """No-data value for invalid outputs."""
        return self._nodata
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class ArgMax(OutputTransformer):
    """Return, for each row of network output, the index of its highest value.

    Intended for classification: argmax is taken along axis 1 and the indices
    are cast to ``dtype``. The output shape is ``[1]`` (a single band).

    Attributes:
        dtype: Data type for output indices. Defaults to "int16".
        nodata: Value for invalid outputs. Defaults to -1.
    """
    def __init__(self, dtype="int16", nodata=-1):
        super().__init__([1], dtype, nodata)

    def __call__(self, vec):
        best_index = np.argmax(vec, axis=1)
        return best_index.astype(self.dtype)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class ArgMaxToCategory(ArgMax):
    """Transform neural network categories to map category values.

    Performs argmax to find the highest output, then maps each index to a
    category value through ``category_map``.

    Attributes:
        category_map: Sequence mapping NN output indices to category values
            (a ``dict`` keyed by index is also accepted).
        dtype: Data type for output values. Defaults to "int16".
        nodata: Value for invalid outputs. Defaults to -1.
    """
    def __init__(self, category_map: List, dtype="int16", nodata=-1):
        super().__init__(dtype, nodata)
        self.category_map = category_map

    def __call__(self, vec):
        indices = super().__call__(vec)
        if isinstance(self.category_map, dict):
            # Dictionary lookups cannot be vectorized; map element by element.
            return np.array([self.category_map[x] for x in indices], dtype=self.dtype)
        # One fancy-indexing lookup instead of a Python loop; this also works
        # when the argmax result has more than one dimension.
        lookup = np.asarray(self.category_map)
        return lookup[indices].astype(self.dtype, copy=False)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class Identity(OutputTransformer):
    """Pass the network output through, only casting it to ``dtype``.

    The output shape cannot be inferred, so it must be given explicitly.

    Attributes:
        shape: Shape of the output data.
        dtype: Data type for output values.
        nodata: Value for invalid outputs.
    """
    def __init__(self, shape, dtype, nodata):
        super().__init__(shape, dtype, nodata)

    def __call__(self, vec):
        target_dtype = self.dtype
        return vec.astype(target_dtype)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class ToPercentage(OutputTransformer):
    """Convert neural network output to percentage values.

    Multiplies the output by 100 and casts it to ``dtype``; useful for
    probability or confidence outputs.

    Attributes:
        shape: Shape of the output data. Defaults to ``[1]`` when omitted or None.
        dtype: Data type for output values. Defaults to "int16".
        nodata: Value for invalid outputs. Defaults to -255.
    """
    def __init__(self, shape=None, dtype="int16", nodata=-255):
        if shape is None:
            # A new list per instance; a mutable default argument would be
            # shared by every ToPercentage built with the default shape.
            shape = [1]
        super().__init__(shape, dtype, nodata)

    def __call__(self, vec):
        return (100 * vec).astype(self.dtype)
|
|
@@ -0,0 +1,404 @@
|
|
|
1
|
+
"""PyTorch utility functions for neural network operations.
|
|
2
|
+
|
|
3
|
+
This module provides helper functions for PyTorch-based deep learning, including
|
|
4
|
+
convolution size calculations, custom collation functions for data loaders, pixel
|
|
5
|
+
extraction utilities, and grid alignment functions for geospatial raster data.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
import math
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import torch
|
|
13
|
+
from rasterio.transform import xy, rowcol, guard_transform
|
|
14
|
+
from rasterio.warp import Affine
|
|
15
|
+
from rasterio.windows import Window, transform
|
|
16
|
+
from torch.utils.data import default_collate
|
|
17
|
+
|
|
18
|
+
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
def int_to_list(var, size):
    """Convert an integer to a list, or validate the length of an existing sequence.

    Used for managing convolution size inputs. An integer (Python ``int`` or
    numpy integer scalar) is repeated ``size`` times; any other value is
    treated as a sequence and returned unchanged if its length is ``size``.

    Args:
        var: Integer value to repeat, or sequence to validate.
        size: Target list size.

    Returns:
        list: List of ``size`` copies of ``var``, or ``var`` itself.

    Raises:
        Exception: If ``var`` is a sequence whose length differs from ``size``.
    """
    if isinstance(var, (int, np.integer)):
        return [var] * size

    if len(var) != size:
        raise Exception(f"Input should have size {size}, got size {len(var)}")
    return var
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def conv_out_size(in_size, conv, stride, padding):
    """Return the output length of a single convolution along one axis.

    Computes ``floor((in_size - conv + 2 * padding) / stride + 1)``.

    Args:
        in_size: Input size in pixels.
        conv: Convolution kernel size.
        stride: Convolution stride.
        padding: Padding size.

    Returns:
        int: Output size after convolution.
    """
    padded_extent = in_size - conv + 2 * padding
    return math.floor(padded_extent / stride + 1)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def conv_out_sizes(in_size, convs, strides, paddings):
    """Return the sizes produced by a stack of convolutions along one axis.

    Any of ``convs``, ``strides`` and ``paddings`` may be a single int or a
    sequence. The number of layers is the length of the last sequence among
    (convs, strides, paddings); if none is a sequence there is one layer.

    Args:
        in_size: Initial input size in pixels.
        convs: Kernel size per layer (or a single value).
        strides: Stride per layer (or a single value).
        paddings: Padding per layer (or a single value).

    Returns:
        list: ``[in_size, size_after_layer_1, ...]``.
    """
    n_layer = 1
    for arg in (convs, strides, paddings):
        if hasattr(arg, '__len__'):
            n_layer = len(arg)

    layer_convs = int_to_list(convs, n_layer)
    layer_strides = int_to_list(strides, n_layer)
    layer_paddings = int_to_list(paddings, n_layer)

    sizes = [in_size]
    for kernel, step, pad in zip(layer_convs, layer_strides, layer_paddings):
        sizes.append(conv_out_size(sizes[-1], kernel, step, pad))
    return sizes
|
|
85
|
+
|
|
86
|
+
class PixelAt:
    """Callable that extracts the value(s) at fixed ``(c, h, w)`` indices of an array.

    Lists may be given for any index to extract several values at once; see
    :func:`pixel_at`.

    Attributes:
        c: Channel index (or indices).
        h: Row index (or indices).
        w: Column index (or indices).
    """

    def __init__(self, c, h, w):
        self.c, self.h, self.w = c, h, w

    def __call__(self, array):
        """Return ``pixel_at(array, self.c, self.h, self.w)``."""
        return pixel_at(array, self.c, self.h, self.w)
|
|
104
|
+
|
|
105
|
+
class PixelAtBand:
    """Callable returning every band's value at fixed ``(h, w)`` indices; see :func:`pixel_at_band`."""

    def __init__(self, h, w):
        self.h, self.w = h, w

    def __call__(self, array):
        """Return ``pixel_at_band(array, self.h, self.w)``."""
        return pixel_at_band(array, self.h, self.w)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class PixelAtBandSkipValue:
    """Like :class:`PixelAtBand`, but returns ``None`` when any extracted value equals ``skip``."""

    def __init__(self, h, w, skip):
        """
        :param h: row index (or indices) to extract
        :param w: column index (or indices) to extract
        :param skip: if any extracted value compares equal to it, ``None`` is returned
        """
        self.h, self.w, self.skip = h, w, skip

    def __call__(self, array):
        pixel = pixel_at_band(array, self.h, self.w)
        return None if (pixel == self.skip).any() else pixel
|
|
134
|
+
|
|
135
|
+
def center_pixel(array):
    """Return the values at the central pixel of an array, for all bands.

    The last two axes are taken as (height, width); for a ``(bands, h, w)``
    array the result holds one value per band. :func:`pixel_at` is more
    error-prone but more efficient when the position is known in advance.

    :param array: array whose last two dimensions are odd-sized (height, width)
    :return: ``array[..., h // 2, w // 2]``
    :raises ValueError: if height or width is even (no single centre pixel)
    """
    h, w = array.shape[-2:]

    if h % 2 != 1 or w % 2 != 1:
        raise ValueError("h,w has no clear center, size should be odd")

    # For an odd length n the centre index is n // 2 (e.g. index 1 for n == 3).
    return array[..., h // 2, w // 2]
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def pixel_at(array, c, h, w):
    """Return ``array[c, h, w]`` as a numpy array.

    A non-array result (a scalar) is wrapped in a one-element array. Lists may be
    given for any index to extract several values at once.

    :param array: array to extract values from
    :param c: channel index (or list of indices)
    :param h: row index (or list of indices)
    :param w: column index (or list of indices)
    :return: numpy array of the selected values
    """
    selected = array[c, h, w]
    return selected if isinstance(selected, np.ndarray) else np.array([selected])
|
|
166
|
+
|
|
167
|
+
def pixel_at_band(array, h, w):
    """Return every band's value at row ``h``, column ``w`` (``array[:, h, w]``) as a numpy array.

    A non-array result is wrapped in a one-element array.

    :param array: array to extract values from, indexed bands first
    :param h: row index (or list of indices)
    :param w: column index (or list of indices)
    :return: numpy array of the selected values
    """
    selected = array[:, h, w]
    if isinstance(selected, np.ndarray):
        return selected
    return np.array([selected])
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def multi_input_training_collate(batch):
    """Collate samples made of several inputs followed by one target.

    Every field position is collated independently with ``default_collate``.

    :param batch: list of tuples ``(input_1, ..., input_k, target)``
    :return: ``(inputs, target)`` where ``inputs`` is a list of the k collated
        inputs and ``target`` is the collated last field
    """
    columns = [[] for _ in range(len(batch[0]))]

    for sample in batch:
        for position, entry in enumerate(sample):
            columns[position].append(entry)

    collated = [default_collate(column) for column in columns]

    # Inputs are every field but the last; the last field is the target.
    return collated[:-1], collated[-1]
|
|
202
|
+
|
|
203
|
+
def batch_collate(batch):
    """Collate function for loaders whose dataset already yields whole batches.

    Only the first element of ``batch`` is returned, untouched (presumably the
    loader is driven by a batch sampler so that element is the full batch —
    confirm at the call site).

    :param batch: list whose first element is the pre-built batch
    :return: ``batch[0]``
    """
    prebuilt = batch[0]
    return prebuilt
|
|
210
|
+
|
|
211
|
+
def meta_data_collate(batch):
    """Collate ``(data, metadata)`` pairs, leaving the metadata untouched.

    The data parts are merged with ``default_collate``; the metadata parts are
    returned as a plain list in batch order.

    :param batch: list of ``(data, metadata)`` pairs
    :return: ``(collated_data, list_of_metadata)``
    """
    samples = [item[0] for item in batch]
    metas = [item[1] for item in batch]
    return torch.utils.data.default_collate(samples), metas
|
|
226
|
+
|
|
227
|
+
def multi_input_meta_data_collate(batch):
    """Collate one ``(inputs, metadata)`` item whose inputs are several arrays.

    Only ``batch[0]`` is used. Each of its inputs is passed through
    ``default_collate`` on its own (as a one-element list, then the added
    leading dimension is removed with ``[0]``); the metadata is left untouched.

    :param batch: list whose first element is ``(inputs_tuple, metadata)``
    :return: ``(list_of_collated_inputs, [metadata])``
    """
    first_sample, first_meta = batch[0]
    collated = [default_collate([entry])[0] for entry in first_sample]
    return collated, [first_meta]
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def no_collate(batch):
    """Collate without building tensors: flatten each sample to a 2-D numpy array.

    Each ``b[0]`` (expected to be a torch tensor) is converted with ``.numpy()``,
    reshaped to ``(len(b[0]), len(b[0][0]))`` and cleaned in place with
    ``np.nan_to_num`` (NaN -> 0, +/-inf -> large finite values).

    NOTE: ``.numpy()`` shares memory with the tensor, so when the reshape is a
    view the in-place cleanup also modifies the caller's tensor.

    :param batch: iterable of ``(tensor, metadata)`` pairs
    :return: ``(data, meta)``: lists of cleaned numpy arrays and metadata, in batch order
    """
    data = []
    meta = []

    for b in batch:
        d = b[0].numpy().reshape(len(b[0]), len(b[0][0]))
        np.nan_to_num(d, copy=False)
        # Keep the cleaned array itself rather than re-deriving it from the
        # tensor, so the NaN replacement is never lost if reshape had to copy.
        data.append(d)
        meta.append(b[1])

    return data, meta
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def align_grid_deprecated(source_meta, bounds, size):
    """Deprecated predecessor of :func:`align_grid`.

    Align ``bounds`` to the best matching source pixels, keeping a margin of
    ``size // 2`` pixels from the raster edges so the convolution window of the
    centre pixel always stays inside. TODO (original): add offset?

    :param source_meta: mapping providing ``"transform"``, ``"width"`` and ``"height"``
    :param bounds: object with ``left``, ``bottom``, ``right`` and ``top`` attributes
    :param size: convolution kernel size
    :return: ``(target_transform, width, height, left, top)``; ``left``/``top`` are the
        column/row offsets of the aligned grid in the source raster
    """
    half_size = size // 2

    src_transform = source_meta["transform"]

    # Only rasters with a negative y pixel size (inverted y axis) are handled.
    assert src_transform.e < 0

    # Grid bounds in source pixel coordinates (rowcol returns (row, col));
    # fractional indices are floored.
    (bottom, left) = rowcol(src_transform, bounds.left, bounds.bottom, op=math.floor)
    (top, right) = rowcol(src_transform, bounds.right, bounds.top, op=math.floor)

    # Keep a half-kernel margin from the raster edges.
    left = max(0, left - half_size) + half_size
    bottom = min(bottom + half_size, source_meta["height"]) - half_size

    top = max(0, top - half_size) + half_size
    right = min(right + half_size, source_meta["width"]) - half_size

    # Dimensions of the aligned grid.
    width = right - left
    height = bottom - top

    # Upper-left corner of the grid's top-left pixel.
    (west, north) = xy(src_transform, top, left, offset="ul")

    # Same pixel size as the source, translated to the aligned origin.
    target_transform = Affine.translation(west, north) * Affine.scale(src_transform.a, src_transform.e)

    # left/top are the offsets into the source raster.
    return target_transform, width, height, left, top
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def align_grid(src_transform, bounds, r_width, r_height, size, shrink_for_conv=False, precision=0.01):
    """Align ``bounds`` to the pixel grid of ``src_transform``.

    The bounds are first snapped to source pixels with :func:`aligned_bound`.
    When ``shrink_for_conv`` is true, each side is then adjusted by
    ``size // 2`` pixels against the raster extent (``r_width`` x ``r_height``)
    to account for the convolution border. TODO (original): add offset?
    NOTE(review): the ``+ 1`` applied to bottom/right (but not left/top) looks
    asymmetric; confirm the intended inclusive/exclusive convention.

    :param src_transform: affine transform of the source raster
    :param bounds: object with ``left``, ``bottom``, ``right`` and ``top`` attributes
    :param r_width: source raster width in pixels (only used when shrinking)
    :param r_height: source raster height in pixels (only used when shrinking)
    :param size: convolution kernel size
    :param shrink_for_conv: whether to adjust the grid for the convolution border
    :param precision: coverage tolerance forwarded to :func:`aligned_bound`
    :return: ``(target_transform, width, height, left, top)``; ``left``/``top`` are the
        column/row offsets of the grid in the source raster
    """
    half = size // 2

    # Grid bounds in source pixel coordinates.
    snapped = aligned_bound(bounds.left, bounds.bottom, bounds.right, bounds.top, src_transform,
                            precision=precision)
    left, top = snapped.col_off, snapped.row_off
    right, bottom = left + snapped.width, top + snapped.height

    if shrink_for_conv:
        left = max(0, left - half) + half
        bottom = min(bottom + half, r_height) - half + 1
        top = max(0, top - half) + half
        right = min(right + half, r_width) - half + 1

    # Window of the aligned grid, and the transform of that window.
    grid_window = Window(left, top, right - left, bottom - top)
    width, height = right - left, bottom - top
    checked_transform = guard_transform(src_transform)
    target_transform = transform(grid_window, checked_transform)

    return target_transform, width, height, left, top
|
|
363
|
+
|
|
364
|
+
def aligned_bound(left, bottom, right, top, transform, precision=0.01):
    """Compute the pixel window covered by a bounding box, shrunk by one pixel.

    Pixels are treated as areas: a pixel index refers to the pixel's top-left
    corner, and the bounding box spans indices ``0`` to ``length``, so the
    last real pixel is at ``length - 1``.

    A border pixel is kept when it is covered by the box up to ``precision``
    (a fraction of a pixel); see :func:`_round_high` / :func:`_round_low`.

    :return: ``Window(first_col, first_row, last_col - first_col, last_row - first_row)``
    """
    # rowcol returns (row, col); an identity op keeps fractional indices so the
    # rounding below can apply the coverage tolerance.
    bottom, left = rowcol(transform, left, bottom, op=lambda value: value)
    top, right = rowcol(transform, right, top, op=lambda value: value)

    # Top-left side: round up unless the index is within `precision` above an integer.
    first_col = _round_high(left, precision)
    first_row = _round_high(top, precision)

    # Bottom-right side: round down unless within `precision` below an integer;
    # the bound sits at index + 1, hence the - 1.
    last_row = _round_low(bottom, precision) - 1
    last_col = _round_low(right, precision) - 1

    return Window(first_col, first_row, last_col - first_col, last_row - first_row)
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def _round_low(value, precision):
|
|
393
|
+
|
|
394
|
+
if math.ceil(value)-value < precision:
|
|
395
|
+
value = math.ceil(value)
|
|
396
|
+
return math.floor(value)
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
def _round_high(value, precision):
|
|
400
|
+
# if pixel of input almost covered we use it anyway (the the left and top pixel. covered mean close to floor)
|
|
401
|
+
if value - math.floor(value) < precision:
|
|
402
|
+
value = math.floor(value)
|
|
403
|
+
|
|
404
|
+
return math.ceil(value)
|