eoml 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eoml/__init__.py +74 -0
- eoml/automation/__init__.py +7 -0
- eoml/automation/configuration.py +105 -0
- eoml/automation/dag.py +233 -0
- eoml/automation/experience.py +618 -0
- eoml/automation/tasks.py +825 -0
- eoml/bin/__init__.py +6 -0
- eoml/bin/clean_checkpoint.py +146 -0
- eoml/bin/land_cover_mapping_toml.py +435 -0
- eoml/bin/mosaic_images.py +137 -0
- eoml/data/__init__.py +7 -0
- eoml/data/basic_geo_data.py +214 -0
- eoml/data/dataset_utils.py +98 -0
- eoml/data/persistence/__init__.py +7 -0
- eoml/data/persistence/generic.py +253 -0
- eoml/data/persistence/lmdb.py +379 -0
- eoml/data/persistence/serializer.py +82 -0
- eoml/raster/__init__.py +7 -0
- eoml/raster/band.py +141 -0
- eoml/raster/dataset/__init__.py +6 -0
- eoml/raster/dataset/extractor.py +604 -0
- eoml/raster/raster_reader.py +602 -0
- eoml/raster/raster_utils.py +116 -0
- eoml/torch/__init__.py +7 -0
- eoml/torch/cnn/__init__.py +7 -0
- eoml/torch/cnn/augmentation.py +150 -0
- eoml/torch/cnn/dataset_evaluator.py +68 -0
- eoml/torch/cnn/db_dataset.py +605 -0
- eoml/torch/cnn/map_dataset.py +579 -0
- eoml/torch/cnn/map_dataset_const_mem.py +135 -0
- eoml/torch/cnn/outputs_transformer.py +130 -0
- eoml/torch/cnn/torch_utils.py +404 -0
- eoml/torch/cnn/training_dataset.py +241 -0
- eoml/torch/cnn/windows_dataset.py +120 -0
- eoml/torch/dataset/__init__.py +6 -0
- eoml/torch/dataset/shade_dataset_tester.py +46 -0
- eoml/torch/dataset/shade_tree_dataset_creators.py +537 -0
- eoml/torch/model_low_use.py +507 -0
- eoml/torch/models.py +282 -0
- eoml/torch/resnet.py +437 -0
- eoml/torch/sample_statistic.py +260 -0
- eoml/torch/trainer.py +782 -0
- eoml/torch/trainer_v2.py +253 -0
- eoml-0.9.0.dist-info/METADATA +93 -0
- eoml-0.9.0.dist-info/RECORD +47 -0
- eoml-0.9.0.dist-info/WHEEL +4 -0
- eoml-0.9.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,579 @@
"""Map dataset utilities for applying neural networks to raster imagery.

This module provides tools for iterating over raster data in a sliding-window
fashion, applying neural network predictions, and aggregating results back to
raster format. Supports parallel processing and various optimization strategies.
"""

import copy
import logging
import math
from typing import Union

import numpy as np
import rasterio
import torch
import torch.multiprocessing as mp

from eoml import get_write_profile
from eoml.torch.cnn.outputs_transformer import OutputTransformer
from eoml.torch.cnn.torch_utils import align_grid
from rasterio.windows import Window
from shapely.geometry import box
from torch.utils.data import IterableDataset, DataLoader
from tqdm import tqdm

logger = logging.getLogger(__name__)

class BatchMeta:
    """Metadata for a batch of data windows.

    Attributes:
        window (Window, optional): Rasterio window specification.
        is_finished (bool): Whether this is the last batch for the current window.
        worker (int): Worker ID that processed this batch.
    """

    def __init__(self, window: Union[Window, None], is_finished: bool, worker: int):
        """Initialize BatchMeta.

        Args:
            window (Window, optional): Rasterio window for this batch.
            is_finished (bool): Whether this completes processing of the current window.
            worker (int): Worker ID that processed this batch.
        """
        self.window = window
        self.is_finished = is_finished
        self.worker = worker

def continuous_split(dataset, worker_id: int, n_workers: int):
    """Split dataset windows contiguously across workers.

    Each worker receives an adjacent, contiguous block of windows. Better for
    locality when overlapping cells may be cached by GDAL.

    Args:
        dataset: Dataset with a target_windows attribute.
        worker_id (int): ID of the current worker.
        n_workers (int): Total number of workers.

    Returns:
        list: Subset of target_windows for this worker.
    """
    size, remainder = divmod(len(dataset.target_windows), n_workers)

    if worker_id < remainder:
        # the first `remainder` workers each take one extra window
        size = size + 1
        start = worker_id * size
        end = start + size
    else:
        # the remainder has been consumed by the previous workers
        start = worker_id * size + remainder
        end = start + size

    return dataset.target_windows[start:end]


def jumped_split(dataset, worker_id: int, n_workers: int):
    """Split dataset windows in an interleaved fashion across workers.

    Workers process windows i, i + n_workers, i + 2 * n_workers, etc. Better
    for seeing overall progress, as work is distributed across the full
    spatial extent.

    Args:
        dataset: Dataset with a target_windows attribute.
        worker_id (int): ID of the current worker.
        n_workers (int): Total number of workers.

    Returns:
        list: Subset of target_windows for this worker (every n_workers-th window).
    """
    return dataset.target_windows[worker_id::n_workers]

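# Illustration (not part of the API): how the two strategies partition ten
# windows across three workers, assuming `ds.target_windows == list(range(10))`:
#
#     continuous_split(ds, 0, 3)  # -> [0, 1, 2, 3]
#     continuous_split(ds, 1, 3)  # -> [4, 5, 6]
#     continuous_split(ds, 2, 3)  # -> [7, 8, 9]
#     jumped_split(ds, 0, 3)      # -> [0, 3, 6, 9]
#     jumped_split(ds, 1, 3)      # -> [1, 4, 7]
#     jumped_split(ds, 2, 3)      # -> [2, 5, 8]
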
def windows_in_mask(window: Window, transform, mask):
    """Check if a raster window intersects a spatial mask.

    Args:
        window (Window): Rasterio window to check.
        transform: Affine transform for the raster.
        mask: List of shapely geometries defining the mask.

    Returns:
        bool: True if the window intersects any geometry in the mask.
    """
    bounds = box(*rasterio.windows.bounds(window, transform))
    return any(map(lambda shape: shape.intersects(bounds), mask))

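# Illustrative check with hypothetical values: a 256 x 256 window at the
# origin of a 10 m grid against a single square geometry.
#
#     from rasterio.transform import from_origin
#     t = from_origin(500_000, 4_000_000, 10, 10)
#     windows_in_mask(Window(0, 0, 256, 256), t,
#                     [box(500_000, 3_990_000, 510_000, 4_000_000)])  # True
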
# remove soon
class IterableMapDataset(IterableDataset):
    """Iterable dataset for applying CNNs to raster imagery using sliding windows.

    Reads raster data in windows and extracts overlapping patches for CNN
    inference. Creates an aligned output raster with convolution border
    handling. When stride > 1, windows starting at the top-left corner with
    size stride x stride are filled with the NN output.

    Attributes:
        raster_reader: Reader for input raster data.
        size (int): Kernel/window size for the CNN.
        half_size (int): Half of the kernel size (for padding calculations).
        target_windows (list, optional): List of windows to process.
        off_x (float): X offset for window alignment.
        off_y (float): Y offset for window alignment.
        device (str): Device for tensor operations.
        stride (int): Stride for window extraction.
        batch_size (int): Number of samples per batch.
        worker_id (int): ID of the current worker process.
    """

    def __init__(self, raster_reader, kernel_size, target_windows=None, off_x=None, off_y=None, stride=1,
                 batch_size=1024, device="cpu"):
        """Initialize IterableMapDataset.

        Args:
            raster_reader: RasterReader instance for input data.
            kernel_size (int): Size of the CNN input window (must be odd).
            target_windows (list, optional): Pre-defined windows to process. Defaults to None.
            off_x (float, optional): X offset for alignment. Defaults to kernel_size / 2.
            off_y (float, optional): Y offset for alignment. Defaults to kernel_size / 2.
            stride (int, optional): Stride for window extraction. Defaults to 1.
            batch_size (int, optional): Batch size for processing. Defaults to 1024.
            device (str, optional): PyTorch device. Defaults to "cpu".

        Raises:
            ValueError: If kernel_size is even (only odd kernels are supported).
        """
        super().__init__()

        self.raster_reader = raster_reader

        self.size = kernel_size
        self.half_size = math.floor(kernel_size / 2)

        self.target_windows = target_windows

        if kernel_size % 2 == 0:
            raise ValueError("even kernel sizes are not supported yet")

        if off_x is None:
            off_x = kernel_size / 2

        if off_y is None:
            off_y = kernel_size / 2

        self.off_x = off_x
        self.off_y = off_y

        self.device = device

        self.stride = stride

        # Return the list of blocks with the corresponding target input;
        # filter the list based on a shapefile if needed. Each worker has a
        # list of blocks assigned: load a block in memory, read it line by
        # line, and create inputs tagged with cell id and line id.

        self.batch_size = batch_size
        self.worker_id = 0

    def __iter__(self):
        """Iterate over the dataset.

        Yields at most batch_size samples at a time, or however many samples
        are needed to finish the current block of data.

        :return: data, (target_windows, is_block_finished, worker_id)
        """
        with self.raster_reader as reader:
            for ji, window in self.target_windows:

                (col_off, row_off, w_width, w_height) = window.flatten()
                # compute the source window, padded by half the kernel on each side
                window_source = Window(col_off + self.off_x - self.half_size, row_off + self.off_y - self.half_size,
                                       w_width + self.size - 1, w_height + self.size - 1)

                a = reader.read_windows(window_source)
                buffer = torch.from_numpy(a).to(self.device)

                for sample, meta in self.extract_tensor_iter(buffer, self.batch_size):
                    meta.window = window
                    yield sample, meta

    def extract_tensor_iter(self, data, batch_size):
        """Read the NN windows from the given data buffer.

        :param data: tensor of shape (bands, height, width), already padded for the kernel
        :param batch_size: maximum number of patches per yielded batch
        :return: iterator of (batch tensor, BatchMeta)
        """
        _, height, width = data.shape

        height = height - self.size + 1
        width = width - self.size + 1

        samples = []

        count = 0

        for i in range(0, height, self.stride):
            for j in range(0, width, self.stride):
                if count == batch_size:
                    yield torch.stack(samples, dim=0), BatchMeta(None, False, self.worker_id)
                    samples = []
                    count = 0
                source_w = data.narrow(1, i, self.size).narrow(2, j, self.size)
                samples.append(source_w)
                count += 1

        yield torch.stack(samples, dim=0), BatchMeta(None, True, self.worker_id)

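    # Patch-count arithmetic (illustrative): for a target window of w x h
    # pixels, the padded buffer read in __iter__ is (w + size - 1) x
    # (h + size - 1), so the loops above visit ceil(h / stride) *
    # ceil(w / stride) positions. With a 256 x 256 block, size=15 and
    # stride=1, that is 65,536 patches, i.e. 64 batches of 1024.
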
    @staticmethod
    def basic_worker_init_fn(worker_id, splitting_f=jumped_split):
        """Assign a share of the dataset's windows to a DataLoader worker.

        Makes deep copies where needed to avoid the shared-memory issue when
        multiple workers are used (testing still needed).

        :param worker_id: ID of the worker being initialized
        :param splitting_f: splitting strategy, e.g. jumped_split or continuous_split
        :return: None
        """
        worker_info = torch.utils.data.get_worker_info()

        dataset = worker_info.dataset  # the dataset copy in this worker process
        n_workers = worker_info.num_workers

        # make a deep copy to try to avoid the shared memory copy issue
        dataset.target_windows = copy.deepcopy(splitting_f(dataset, worker_id, n_workers))
        dataset.worker_id = worker_id

        # make a copy of the raster reader for thread safety
        dataset.raster_reader = copy.copy(dataset.raster_reader)

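
# Usage sketch (illustrative): pin a splitting strategy when running with
# several DataLoader workers; `dataset` is a configured IterableMapDataset.
#
#     from functools import partial
#
#     init_fn = partial(IterableMapDataset.basic_worker_init_fn,
#                       splitting_f=continuous_split)
#     loader = DataLoader(dataset, num_workers=4, worker_init_fn=init_fn)
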

class IterableYearMapDataset(IterableMapDataset):
    """IterableMapDataset variant that pairs each patch with a normalised acquisition year."""

    def __init__(self, raster_reader, year, kernel_size, target_windows=None, off_x=None, off_y=None, stride=1,
                 batch_size=1024, device="cpu", year_normalisation=2500):
        super().__init__(raster_reader, kernel_size, target_windows, off_x, off_y, stride, batch_size, device)

        self.year = year
        self.year_normalisation = year_normalisation
        logger.info("initialize IterableYearMapDataset")

    def extract_tensor_iter(self, data, batch_size):
        """Read the NN windows from the given data buffer, adding the year input.

        :param data: tensor of shape (bands, height, width), already padded for the kernel
        :param batch_size: maximum number of patches per yielded batch
        :return: iterator of ((batch tensor, year array), BatchMeta)
        """
        _, height, width = data.shape

        height = height - self.size + 1
        width = width - self.size + 1

        samples = []

        count = 0

        for i in range(0, height, self.stride):
            for j in range(0, width, self.stride):
                if count == batch_size:
                    yield (torch.stack(samples, dim=0),
                           np.array([[self.year / self.year_normalisation] for _ in range(len(samples))],
                                    dtype=np.float32)), \
                          BatchMeta(None, False, self.worker_id)

                    samples = []
                    count = 0
                source_w = data.narrow(1, i, self.size).narrow(2, j, self.size)
                samples.append(source_w)
                count += 1

        yield (torch.stack(samples, dim=0),
               np.array([[self.year / self.year_normalisation] for _ in range(len(samples))], dtype=np.float32)), \
              BatchMeta(None, True, self.worker_id)

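# Shape sketch (illustrative): with two input bands, kernel_size=5, year=2020
# and the default year_normalisation=2500, a full batch yields
# ((patches, years), meta) where patches has shape (batch_size, 2, 5, 5) and
# years has shape (batch_size, 1), filled with 2020 / 2500 = 0.808.
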

class MapResultAggregator:
    """Receive results back from the processing and write them to a map.

    TODO: manage encoder/decoder.
    """

    def __init__(self, path_out, output_transformer: OutputTransformer, n_windows, write_profile):

        self.bands = output_transformer.bands
        self.write_profile = copy.deepcopy(write_profile)
        self.write_profile.update({"dtype": output_transformer.dtype,
                                   "count": output_transformer.bands})
        self.result_cache = {}

        self.path_out = path_out
        self.output_transformer = output_transformer
        self.n_windows = n_windows

    def submit_result(self, values, meta: BatchMeta):

        values = self.output_transformer(values)
        cached = self.result_cache.setdefault(meta.worker, [])

        cached.append(values)

        if meta.is_finished:
            values = self.reshape(cached, meta.window)
            cached.clear()
            self.write(values, meta.window)

    def reshape(self, data, windows):
        """
        Take a list of n arrays of arbitrary length and n_bands depth and
        return a window of shape (n_channels, height, width).
        :param data:
        :param windows:
        :return:
        """
        width = windows.width
        height = windows.height
        # concatenate makes one array; then we reshape and move the band axis,
        # which is in the last position, to the first
        return np.moveaxis(np.concatenate(data).reshape((height, width, self.bands)), 2, 0)

    def write(self, data, windows):
        """
        TODO: currently flushes after each window, which may have a performance cost.
        :param data:
        :param windows:
        :return:
        """
        # for some reason, if not in threading spawn mode, setting any other
        # option (e.g. num_threads=4) causes a deadlock
        with rasterio.open(self.path_out, "r+", sharing=False) as writer:
            writer.write(data, window=windows)

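# Layout sketch (illustrative): why reshape() moves the band axis. For a
# 3 x 4 window and 2 output bands, the cached batches concatenate to 12 pixel
# results of depth 2, and rasterio's write() expects (bands, height, width):
#
#     parts = [np.zeros((5, 2)), np.ones((7, 2))]    # 12 pixels, 2 bands
#     out = np.moveaxis(np.concatenate(parts).reshape((3, 4, 2)), 2, 0)
#     out.shape  # (2, 3, 4)
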

class AsyncAggregator(mp.Process):
    """Wrapper around an aggregator that performs the write operations asynchronously."""

    def __init__(self, aggregator: MapResultAggregator, max_queue=5):
        super().__init__(daemon=True)
        self.queue = mp.Queue(max_queue)
        self.aggregator = copy.deepcopy(aggregator)
        self.windows_left = aggregator.n_windows

    def run(self):
        while True:
            data, meta = self.queue.get()
            self.aggregator.submit_result(data, meta)
            del data
            if meta.is_finished:
                self.windows_left -= 1
                if self.windows_left == 0:
                    return

    def submit_result(self, values, meta):
        self.queue.put((values, meta))

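
# Usage sketch (illustrative): run an aggregator in a separate process so
# raster writes do not block inference; `agg` is a configured
# MapResultAggregator. This is what GenericMapper.map() does when async_agr=True.
#
#     async_agg = AsyncAggregator(agg, max_queue=10)
#     async_agg.start()
#     ...  # submit_result() for every batch
#     async_agg.join()  # returns once every window has been written
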

class GenericMapper:
    """Apply a generic mapping function (allows running random forests and the like).

    Use a no-op collate function to work directly on numpy arrays.
    """

    def __init__(self,
                 mapper,
                 stride=1,
                 loader_device='cpu',
                 mapper_device='cpu',
                 pin_memory=False,
                 num_worker=0,
                 prefetch_factor=2,
                 custom_collate=None,
                 worker_init_fn=None,
                 aggregator=MapResultAggregator,
                 write_profile=None,
                 async_agr=True):

        logger.info("setting model to eval mode")

        self.mapper = mapper.to(mapper_device)
        self.mapper.eval()
        self.stride = stride

        self.loader_device = loader_device
        self.mapper_device = mapper_device
        self.pin_memory = pin_memory
        self.num_worker = num_worker
        self.prefetch_factor = prefetch_factor

        self.custom_collate = custom_collate

        self.worker_init_fn = worker_init_fn

        if num_worker > 0 and worker_init_fn is None:
            raise Exception("A custom worker_init_fn is needed for the map iterator when parallel mode is used")

        # a factory should be used instead, but good enough for now
        self.aggregator = aggregator
        self.async_agr = async_agr

        self.write_profile = write_profile

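    # Collate sketch (an assumption, not the library's documented API): each
    # dataset item is already a full (samples, BatchMeta) batch, so a
    # pass-through collate that keeps meta as a one-element list matches the
    # `meta[0]` access in map() below:
    #
    #     def identity_collate(batch):
    #         samples, meta = batch[0]
    #         return samples, [meta]
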
    def map(self,
            out_path,
            map_iterator,
            output_transformer: OutputTransformer,
            bounds=None,
            mask=None):

        map_iterator, aggregator = self.mapping_generator(map_iterator,
                                                          out_path,
                                                          output_transformer,
                                                          bounds=bounds,
                                                          mask=mask,
                                                          write_profile=self.write_profile)

        dl = DataLoader(map_iterator, collate_fn=self.custom_collate, pin_memory=self.pin_memory,
                        num_workers=self.num_worker, prefetch_factor=self.prefetch_factor,
                        worker_init_fn=self.worker_init_fn)

        if self.async_agr:
            aggregator = AsyncAggregator(aggregator, 10)
            aggregator.start()

        with torch.inference_mode():
            with tqdm(total=len(map_iterator.target_windows), desc='Map') as pbar:

                for inputs, meta in dl:

                    out = self.process_batch(inputs)

                    aggregator.submit_result(out, meta[0])

                    # increment the progress bar when a window is finished;
                    # cuda memory stats (allocated/reserved) could be reported
                    # on the bar here when running on gpu
                    if meta[0].is_finished:
                        pbar.update(1)

        if self.async_agr:
            aggregator.join()

    def process_batch(self, inputs):

        # move the inputs to the mapper device, releasing cuda memory as soon
        # as possible afterwards
        if isinstance(inputs, (list, tuple)):
            inputs = [x.to(self.mapper_device, non_blocking=True) for x in inputs]
            # if the model is not traced with a single input, unpack the inputs
            outputs = self.mapper(*inputs)
        else:
            inputs = inputs.to(self.mapper_device, non_blocking=True)
            outputs = self.mapper(inputs)
        del inputs
        s = outputs.detach().cpu().numpy()
        del outputs
        return s

    def mapping_generator(self,
                          map_iterator,
                          path_out,
                          output_transformer: OutputTransformer,
                          bounds=None,
                          mask=None,
                          write_profile=None):
        """
        Generate a map iterator and an aggregator to be used for mapping.
        :param map_iterator:
        :param path_out:
        :param output_transformer:
        :param bounds:
        :param mask:
        :param write_profile:
        :return:
        """
        # TODO adapt the code for when windows are not blocks, and take encoder/decoder into account

        ref_raster = map_iterator.raster_reader.ref_raster()

        with rasterio.open(ref_raster, mode="r") as src:

            if write_profile is None:
                write_profile = get_write_profile()

            if map_iterator.size % 2 == 0:
                raise ValueError("even kernel sizes are not supported yet")

            if not bounds:
                bounds = src.bounds

            # this function is based on the center pixel; it needs adjustment
            # for fully convolutional models
            transform, width, height, off_x, off_y = align_grid(src.transform, bounds, src.width, src.height,
                                                                map_iterator.size)

            write_profile.update({'dtype': output_transformer.dtype,
                                  'crs': src.crs,
                                  'transform': transform,
                                  'width': width,
                                  'height': height,
                                  'nodata': output_transformer.nodata,
                                  'count': output_transformer.bands})

        # create an empty raster
        with rasterio.open(path_out, mode="w", **write_profile) as dst:
            windows = list(dst.block_windows())

        if mask is not None:
            windows = list(filter(lambda x: windows_in_mask(x[1], transform, mask), windows))

        # set the right parameters on the map iterator
        map_iterator.target_windows = windows
        map_iterator.off_x = off_x
        map_iterator.off_y = off_y

        aggregator = self.aggregator(path_out, output_transformer, len(windows), write_profile)

        return map_iterator, aggregator


class NNMapper(GenericMapper):
    """Mapper specialised to run neural networks."""

    def __init__(self,
                 model,
                 stride=1,
                 loader_device='cpu',
                 mapper_device='cpu',
                 pin_memory=False,
                 num_worker=0,
                 prefetch_factor=2,
                 custom_collate=None,
                 worker_init_fn=None,
                 aggregator=MapResultAggregator,
                 write_profile=None,
                 async_agr=True):

        super().__init__(model, stride, loader_device, mapper_device, pin_memory, num_worker,
                         prefetch_factor, custom_collate, worker_init_fn, aggregator, write_profile, async_agr)
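
# End-to-end usage sketch (illustrative; `reader`, `model` and `transformer`
# are assumed to be a configured RasterReader, a trained torch model and an
# OutputTransformer — they are not defaults shipped with the library):
#
#     dataset = IterableMapDataset(reader, kernel_size=15, batch_size=512)
#     mapper = NNMapper(model, mapper_device="cuda", num_worker=0)
#     mapper.map("prediction.tif", dataset, transformer)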