eoml 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eoml/__init__.py +74 -0
- eoml/automation/__init__.py +7 -0
- eoml/automation/configuration.py +105 -0
- eoml/automation/dag.py +233 -0
- eoml/automation/experience.py +618 -0
- eoml/automation/tasks.py +825 -0
- eoml/bin/__init__.py +6 -0
- eoml/bin/clean_checkpoint.py +146 -0
- eoml/bin/land_cover_mapping_toml.py +435 -0
- eoml/bin/mosaic_images.py +137 -0
- eoml/data/__init__.py +7 -0
- eoml/data/basic_geo_data.py +214 -0
- eoml/data/dataset_utils.py +98 -0
- eoml/data/persistence/__init__.py +7 -0
- eoml/data/persistence/generic.py +253 -0
- eoml/data/persistence/lmdb.py +379 -0
- eoml/data/persistence/serializer.py +82 -0
- eoml/raster/__init__.py +7 -0
- eoml/raster/band.py +141 -0
- eoml/raster/dataset/__init__.py +6 -0
- eoml/raster/dataset/extractor.py +604 -0
- eoml/raster/raster_reader.py +602 -0
- eoml/raster/raster_utils.py +116 -0
- eoml/torch/__init__.py +7 -0
- eoml/torch/cnn/__init__.py +7 -0
- eoml/torch/cnn/augmentation.py +150 -0
- eoml/torch/cnn/dataset_evaluator.py +68 -0
- eoml/torch/cnn/db_dataset.py +605 -0
- eoml/torch/cnn/map_dataset.py +579 -0
- eoml/torch/cnn/map_dataset_const_mem.py +135 -0
- eoml/torch/cnn/outputs_transformer.py +130 -0
- eoml/torch/cnn/torch_utils.py +404 -0
- eoml/torch/cnn/training_dataset.py +241 -0
- eoml/torch/cnn/windows_dataset.py +120 -0
- eoml/torch/dataset/__init__.py +6 -0
- eoml/torch/dataset/shade_dataset_tester.py +46 -0
- eoml/torch/dataset/shade_tree_dataset_creators.py +537 -0
- eoml/torch/model_low_use.py +507 -0
- eoml/torch/models.py +282 -0
- eoml/torch/resnet.py +437 -0
- eoml/torch/sample_statistic.py +260 -0
- eoml/torch/trainer.py +782 -0
- eoml/torch/trainer_v2.py +253 -0
- eoml-0.9.0.dist-info/METADATA +93 -0
- eoml-0.9.0.dist-info/RECORD +47 -0
- eoml-0.9.0.dist-info/WHEEL +4 -0
- eoml-0.9.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,604 @@
|
|
|
1
|
+
"""Raster data extraction module for creating labeled datasets.
|
|
2
|
+
|
|
3
|
+
This module provides classes for extracting windows of raster data around labeled
|
|
4
|
+
point locations. It supports various optimization strategies including block-level
|
|
5
|
+
reading, parallel processing with threads or processes, and efficient I/O operations.
|
|
6
|
+
|
|
7
|
+
The extractors are designed to work with geospatial vector files (containing labeled
|
|
8
|
+
points) and raster files, producing datasets suitable for machine learning training.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import copy
|
|
12
|
+
import logging
|
|
13
|
+
import math
|
|
14
|
+
import threading
|
|
15
|
+
from abc import ABC, abstractmethod
|
|
16
|
+
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
|
17
|
+
from multiprocessing import Process, Queue
|
|
18
|
+
from typing import List, Union
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
import fiona
|
|
23
|
+
import numpy as np
|
|
24
|
+
import rasterio
|
|
25
|
+
import rasterio.crs
|
|
26
|
+
import rasterio.warp
|
|
27
|
+
import shapely
|
|
28
|
+
from eoml.data.basic_geo_data import GeoDataHeader, BasicGeoData
|
|
29
|
+
from eoml.data.persistence.generic import GeoDataWriter
|
|
30
|
+
from eoml.data.persistence.lmdb import LMDBWriter
|
|
31
|
+
from eoml.raster.raster_reader import RasterReader, AbstractRasterReader
|
|
32
|
+
from rasterio.windows import Window
|
|
33
|
+
from rasterio.windows import round_window_to_full_blocks
|
|
34
|
+
from shapely.geometry import shape
|
|
35
|
+
from tqdm import tqdm
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class Header:
    """Container for sample metadata during extraction.

    Attributes:
        label: Class label for the sample.
        geometry: Shapely geometry (typically Point) for the sample location.
        idx: Unique identifier for the sample.
        window: Rasterio window defining the extraction area; ``None`` until
            the extractor computes it.
    """

    def __init__(self, label, geometry, idx=None):
        self.label = label
        self.geometry = geometry
        self.idx = idx
        self.window: Window | None = None

    @property
    def row(self):
        """Row offset of the extraction window, or ``None`` if no window is set yet."""
        return None if self.window is None else self.window.row_off

    @property
    def col(self):
        """Column offset of the extraction window, or ``None`` if no window is set yet."""
        return None if self.window is None else self.window.col_off
|
|
55
|
+
|
|
56
|
+
class AbstractExtractor(ABC):
    """Abstract base class for dataset extractors.

    Defines the interface for extracting labeled windows from raster data.
    The class derives from ``ABC`` so that the ``@abstractmethod`` markers are
    enforced: a subclass that leaves any of the methods below unimplemented
    cannot be instantiated.
    """

    @abstractmethod
    def _prepare(self):
        """Prepare headers and metadata for extraction.

        Returns:
            List of Header objects ready for extraction.
        """
        ...

    @abstractmethod
    def _extract(self, headers: "List[Header]"):
        """Extract data for the given headers.

        Args:
            headers: List of Header objects defining what to extract.
        """
        ...

    @abstractmethod
    def process(self):
        """Execute the complete extraction workflow."""
        ...
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def basic_extract_iter(samples: List[Header], location, reader):
    """Yield one sample per header, issuing a separate read for each window.

    Args:
        samples: Headers whose ``window`` attribute gives the area to read.
        location: Path to the vector file the samples come from; stored in
            each produced GeoDataHeader.
        reader: Raster reader exposing ``read_windows(window)``.

    Yields:
        BasicGeoData: Extracted raster data with metadata and label. Windows
        rejected by ``LabeledWindowsExtractor.is_valid`` are skipped.
    """
    for sample_header in samples:
        window_data = reader.read_windows(sample_header.window)

        if not LabeledWindowsExtractor.is_valid(window_data):
            continue

        geo_header = GeoDataHeader(sample_header.idx, sample_header.geometry, location)
        yield BasicGeoData(geo_header, window_data, sample_header.label)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def extract_blocks_iter(h_list, window_tuple, location, reader):
    """Yield samples cut out of one large block read.

    Reading the enclosing block once is cheaper than one read per sample when
    several samples lie close together.

    Args:
        h_list: Headers whose windows lie inside the block.
        window_tuple: Block window as ``(col_off, row_off, width, height)``.
        location: Path to the vector file the samples come from.
        reader: Raster reader exposing ``read_windows(window)``.

    Yields:
        BasicGeoData: Extracted raster data with metadata and label. Windows
        rejected by ``LabeledWindowsExtractor.is_valid`` are skipped.
    """
    block_window = Window(*window_tuple)

    # A single read covering every header in the block.
    block_data = reader.read_windows(block_window)

    for h in h_list:
        rows, cols = OptimiseLabeledWindowsExtractor._slice_to_read(h.window, block_window)
        sub_data = block_data[:, rows, cols]

        if LabeledWindowsExtractor.is_valid(sub_data):
            yield BasicGeoData(GeoDataHeader(h.idx, h.geometry, location), sub_data, h.label)
|
|
130
|
+
|
|
131
|
+
class LabeledWindowsExtractor(AbstractExtractor):
    """Extract labeled windows from raster data around point locations.

    Reads windows of specified size centered on labeled point locations from
    vector data. Uses floor operation for pixel indexing, ensuring the pixel
    containing the point is extracted.

    Todo:
        Save all metadata as a dictionary.

    Attributes:
        locations: Path to vector file with labeled points.
        writer: GeoDataWriter for saving extracted samples.
        raster_reader: Reader for the raster data.
        windows_size: Size of extraction windows in pixels.
        labelName: Name of the label field in vector data.
        id_field: Name of the ID field in vector data (the feature index is
            used when None).
        geometryName: Name of the geometry field.
        locationsCRS: Coordinate system of the location data.
        rasterCRS: Coordinate system of the raster data.
        show_progress: Whether to display progress bars.
        mask: Optional geometry to filter extraction locations.
    """

    def __init__(self,
                 locations: str,
                 writer: Union[GeoDataWriter, None],
                 raster_reader: RasterReader,
                 windows_size: int,
                 label_name: str = 'class',
                 id_field: str = None,
                 geometry_name: str = 'geometry',
                 mask_path: str = None,
                 show_progress: bool = True):
        """Open the location and reference raster files to record their CRS.

        Args:
            locations: Path to the vector file with labeled points.
            writer: Destination writer, used as a context manager in ``_extract``.
            raster_reader: Reader for the raster data.
            windows_size: Size of extraction windows in pixels.
            label_name: Property holding the class label.
            id_field: Property holding the sample id; the feature index is used if None.
            geometry_name: Key of the geometry in each feature.
            mask_path: Optional vector file; only points inside its polygons are kept.
            show_progress: Whether to display progress bars.
        """
        self.locations = locations

        self.writer: GeoDataWriter = writer

        self.raster_reader: AbstractRasterReader = raster_reader
        self.windows_size: int = windows_size
        self.labelName: str = label_name
        self.id_field: str = id_field
        self.geometryName: str = geometry_name

        with fiona.open(self.locations) as locs:
            # NOTE(review): recent fiona versions return a CRS object rather than a
            # dict here — confirm CRS.from_dict accepts it for the pinned fiona.
            self.locationsCRS = rasterio.crs.CRS.from_dict(locs.crs)

        with rasterio.open(self.raster_reader.ref_raster()) as src:
            self.rasterCRS = src.crs

        self.show_progress = show_progress

        self.mask = None

        if mask_path is not None:
            with fiona.open(mask_path) as mask:
                self.mask = shapely.geometry.MultiPolygon(
                    [shape(feature[self.geometryName]) for feature in mask])

    def _prepare(self):
        """Read the labeled points, compute their windows and drop unusable ones.

        Returns:
            List of Header objects inside the optional mask and whose window
            passes ``raster_reader.is_inside``.
        """
        headers = self._read_header(self.locations)
        self._reproject(headers)
        self._read_location(headers)

        if self.mask is not None:
            headers = self._filter_in_mask(headers)

        return self._filter_in_raster(headers)

    def _extract(self, samples, show_progress=True):
        """Read every sample window individually and save the valid ones.

        Args:
            samples: A Header or a list of Headers with their ``window`` set.
            show_progress: Whether to display a progress bar.
        """
        # A single header is wrapped in a list so subclasses can call this more easily.
        if not isinstance(samples, list):
            samples = [samples]

        num = len(samples)

        # Sort by column then row offset of the window; neighbouring reads may
        # benefit from caching. Header exposes the position through ``window``.
        samples = sorted(samples, key=lambda h: (h.window.col_off, h.window.row_off))

        with self.raster_reader as reader, self.writer as dst:
            for sample in tqdm(basic_extract_iter(samples, self.locations, reader),
                               total=num, disable=not show_progress):
                dst.save(sample)

    def process(self):
        """Prepare the headers, then extract and save all samples."""
        headers = self._prepare()
        self._extract(headers, self.show_progress)

    def _read_header(self, locations) -> List[Header]:
        """Build one Header per feature of the vector file ``locations``."""
        with fiona.open(locations) as locs:
            headers = []
            for i, feature in enumerate(locs):
                label = feature['properties'][self.labelName]
                geometry = shape(feature[self.geometryName])
                idx = feature['properties'][self.id_field] if self.id_field is not None else i

                headers.append(Header(label, geometry, idx))

        return headers

    def _reproject(self, headers):
        """Transform header geometries in place from the locations CRS to the raster CRS."""
        if self.rasterCRS == self.locationsCRS or not headers:
            return

        # transform_geom documents GeoJSON-like inputs, so convert the shapely
        # geometries to mappings before transforming.
        geoms = rasterio.warp.transform_geom(self.locationsCRS,
                                             self.rasterCRS,
                                             [shapely.geometry.mapping(h.geometry) for h in headers])
        for h, geo in zip(headers, geoms):
            h.geometry = shape(geo)

    def _read_location(self, headers: List[Header]):
        """Set ``header.window`` to the window of ``windows_size`` centred on each point."""
        for header in headers:
            # NOTE(review): pixel indices are derived with math.floor inside
            # windows_for_center; confirm its argument order (x, y) there.
            header.window = self.raster_reader.windows_for_center(header.geometry.x,
                                                                  header.geometry.y,
                                                                  self.windows_size,
                                                                  op=math.floor)

    def _filter_in_mask(self, headers):
        """Keep only headers whose geometry lies inside the mask."""
        return [h for h in headers if self.mask.contains(h.geometry)]

    def _filter_in_raster(self, headers: List[Header]):
        """Keep only headers whose window the raster reader reports as inside."""
        return [h for h in headers if self.raster_reader.is_inside(h.window)]

    @staticmethod
    def is_valid(data) -> bool:
        """Return True when ``data`` contains no NaN value."""
        return not np.isnan(data).any()
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
class OptimiseLabeledWindowsExtractor(LabeledWindowsExtractor):
    """Extract the data window of the given size around pixel points, grouped by raster block.

    Each sample window is rounded to the raster's full blocks. Samples sharing
    a block are read together. Blocks fully contained in another block are
    merged into it.
    """

    def __init__(self,
                 locations: str,
                 writer: GeoDataWriter,
                 raster_reader: RasterReader,
                 windows_size: int,
                 label_name: str = 'class',
                 geometry_name: str = 'geometry',
                 mask_path: str = None,
                 show_progress: bool = True,
                 id_field: str = None):
        """See LabeledWindowsExtractor; ``id_field`` is appended last for backward compatibility."""
        # The parent signature has ``id_field`` between ``label_name`` and
        # ``geometry_name``: pass the remaining arguments by keyword so the
        # geometry field name is not bound to ``id_field``.
        super().__init__(locations, writer, raster_reader, windows_size, label_name,
                         id_field=id_field, geometry_name=geometry_name,
                         mask_path=mask_path, show_progress=show_progress)

        # Number of samples to extract; set by _prepare and used as progress total.
        self.to_write = -1

    def _prepare(self):
        """Return the mapping ``block tuple -> list of headers`` to read."""
        header = super()._prepare()
        self.to_write = len(header)
        windows = self._list_windows(header)
        self._merge_windows(windows)
        return windows

    def _extract(self, w_to_load, show_progress=True):
        """Read and save every block of ``w_to_load``."""
        self._load_and_save(w_to_load, self.to_write, show_progress)

    @staticmethod
    def _slice_to_read(target: Window, src: Window):
        """Return ``(row_slice, col_slice)`` selecting ``target`` inside data read for ``src``.

        Based on rasterio's ``Window.toslices``, but relative to ``src``'s offsets.
        """
        row_off = target.row_off - src.row_off
        col_off = target.col_off - src.col_off

        range_w = ((row_off, row_off + target.height),
                   (col_off, col_off + target.width))

        return tuple(slice(*rng) for rng in range_w)

    def _load_and_save(self, window_header_map, total, show_progress):
        """Read each block (or its samples one by one) and save the valid samples.

        Args:
            window_header_map: Mapping ``block tuple -> list of headers``.
            total: Progress-bar total.
            show_progress: Whether to display a progress bar.
        """
        with self.raster_reader as reader, self.writer as writer:
            with tqdm(total=total, disable=not show_progress) as pbar:
                for window_tuple, h_list in window_header_map.items():

                    # quick test seems to show that loading 1 by one is faster on ssd when there
                    # are fewer than 5 samples in the block. More tests needed for the right value.
                    if len(h_list) < 5:
                        itera = basic_extract_iter(h_list, self.locations, reader)
                    else:
                        itera = extract_blocks_iter(h_list, window_tuple, self.locations, reader)
                    for sample in itera:
                        writer.save(sample)
                        pbar.update(1)

    def _list_windows(self, headers):
        """Group headers by their window rounded to full raster blocks.

        Returns:
            Dict mapping ``(col_off, row_off, width, height)`` to a list of headers.
        """
        with rasterio.open(self.raster_reader.ref_raster()) as src:
            block_shapes = src.block_shapes

        w_to_load = {}
        for h in headers:
            block = round_window_to_full_blocks(h.window, block_shapes)
            key = (block.col_off, block.row_off, block.width, block.height)
            w_to_load.setdefault(key, []).append(h)

        return w_to_load

    def _merge_windows(self, windows):
        """Fold block windows that lie entirely inside another block window into it.

        ``windows`` (block tuple -> header list) is modified in place: the headers
        of a contained window are appended to the containing window's list and the
        contained key is deleted.
        """
        # Iterate over a snapshot because keys are deleted from ``windows`` below.
        k1s = windows.copy()

        for w1, l1 in k1s.items():
            # Compare against the live dict so windows already merged away are ignored.
            for w2, l2 in windows.items():
                if w1 != w2 and OptimiseLabeledWindowsExtractor.windows_is_inside(*w1, *w2):
                    l2.extend(l1)
                    del windows[w1]
                    # Break immediately: continuing to iterate a dict after deleting from it raises.
                    break

    @staticmethod
    def windows_is_inside(col_off_1, row_off_1, width_1, height_1, col_off_2, row_off_2, width_2, height_2):
        """Return True if window 1 lies entirely inside window 2 (edges may coincide)."""
        # right1 <= right2 and left1 >= left2 and bottom1 <= bottom2 and top1 >= top2
        return (col_off_1 + width_1) <= (col_off_2 + width_2) and (col_off_1 >= col_off_2) and \
            (row_off_1 + height_1) <= (row_off_2 + height_2) and (row_off_1 >= row_off_2)
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
class AsyncKernelWriter(Process):
    """Daemon process that drains a bounded queue of samples into a database writer.

    Producers hand samples over with ``submit``. A ``None`` item marks one
    producer as finished, and the process returns once ``n_reader`` of them
    have arrived. The queue holds at most ``max_queue`` pending items.
    """

    def __init__(self, db_writer, n_reader, max_queue=100):
        super().__init__(daemon=True)
        self.queue = Queue(max_queue)
        self.db_writer: LMDBWriter = db_writer
        self.n_reader = n_reader

    def run(self):
        """Save queued items inside the writer's context until every producer is done."""
        with self.db_writer:
            while True:
                item = self.queue.get()
                if item is None:
                    self.n_reader -= 1
                    if self.n_reader == 0:
                        return
                else:
                    self.db_writer.save(item)

    def submit(self, kernel):
        """Enqueue ``kernel`` (or ``None`` as end marker); blocks while the queue is full."""
        self.queue.put(kernel)
|
|
394
|
+
|
|
395
|
+
class AbstractPooledWindowsExtractor(OptimiseLabeledWindowsExtractor):
    """Base class for block extractors that read through a worker pool.

    Samples returned by the workers are forwarded to an ``AsyncKernelWriter``
    process. A semaphore limits tasks submitted but not yet finished to
    ``worker * prefetch``, so the pool cannot run arbitrarily far ahead of the
    writer.
    """

    def __init__(self,
                 locations: str,
                 writer: GeoDataWriter,
                 raster_reader: RasterReader,
                 windows_size: int,
                 label_name: str = 'class',
                 geometry_name: str = 'geometry',
                 mask_path: str = None,
                 show_progress: bool = True,
                 worker=4,
                 prefetch=3):
        super().__init__(locations, writer, raster_reader, windows_size, label_name, geometry_name,
                         mask_path=mask_path, show_progress=show_progress)

        self.worker = worker
        # The caller's writer is wrapped in a separate process; subclasses start it.
        self.writer = AsyncKernelWriter(writer, 1)
        # threading.Semaphore: acquired in submit_proxy and released from the
        # futures' done-callbacks, which run in threads of this process.
        # The number is the max number of simultaneous cells processed.
        self.semaphore = threading.Semaphore(worker * prefetch)
        self.pbar = None

    def save_callback(self, future):
        """Forward a finished task's samples to the writer and free its slot.

        The slot is released in ``finally`` so a failed task cannot keep its
        permit forever. Otherwise enough failures would block ``submit_proxy``
        indefinitely, because concurrent.futures only logs exceptions raised
        from done-callbacks.
        """
        try:
            result = future.result()
        except Exception as e:
            logger.error(f"Extractor failed: {e}")
            raise
        else:
            for sample in result:
                self.writer.submit(sample)
                self.pbar.update(1)
        finally:
            self.semaphore.release()

    def submit_proxy(self, executor, function, *args, **kwargs):
        """Submit ``function(*args, **kwargs)`` once a slot is free; blocks otherwise.

        Returns:
            The Future returned by ``executor.submit``.
        """
        self.semaphore.acquire()
        try:
            future = executor.submit(function, *args, **kwargs)
        except BaseException:
            # The task was never scheduled, so no callback will release its slot.
            self.semaphore.release()
            raise
        future.add_done_callback(self.save_callback)
        return future
|
|
440
|
+
|
|
441
|
+
def init_pool(raster_r):
    """Process-pool initializer: give the worker its own copy of the raster reader.

    The deep copy is stored in the module-level ``glob_raster_reader`` that
    ``extract_many_pool`` and ``extract_few_pool`` read from.
    """
    global glob_raster_reader
    reader_copy = copy.deepcopy(raster_r)
    glob_raster_reader = reader_copy
|
|
446
|
+
|
|
447
|
+
# proxy for submitting tasks that imposes a limit on the queue size
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
def extract_many(h_list, window_tuple, location, raster_reader, read_lock=None):
    """List-returning counterpart of ``extract_blocks_iter`` for use in worker pools.

    Generators cannot be pickled back from a process pool, so the block logic
    is repeated here and the samples are collected in a list.

    Args:
        h_list: Headers whose windows lie inside the block.
        window_tuple: Block window as ``(col_off, row_off, width, height)``.
        location: Path to the vector file the samples come from.
        raster_reader: Reader exposing ``read_windows(window)``.
        read_lock: Optional lock held around the read.

    Returns:
        List of BasicGeoData for the valid windows.
    """
    block_window = Window(*window_tuple)

    if read_lock is None:
        block_data = raster_reader.read_windows(block_window)
    else:
        with read_lock:
            block_data = raster_reader.read_windows(block_window)

    samples = []
    for h in h_list:
        rows, cols = OptimiseLabeledWindowsExtractor._slice_to_read(h.window, block_window)
        sub_data = block_data[:, rows, cols]

        if LabeledWindowsExtractor.is_valid(sub_data):
            samples.append(BasicGeoData(GeoDataHeader(h.idx, h.geometry, location), sub_data, h.label))
    return samples
|
|
469
|
+
|
|
470
|
+
def extract_many_pool(h_list, window_tuple, location):
    """Process-pool task: run ``extract_many`` with the worker's raster reader copy."""
    reader = glob_raster_reader
    with reader:
        return extract_many(h_list, window_tuple, location, reader)
|
|
473
|
+
|
|
474
|
+
def extract_few(samples, location, raster_reader, read_lock=None):
    """List-returning counterpart of ``basic_extract_iter`` for use in worker pools.

    Generators cannot be pickled back from a process pool, so each window is
    read separately and the valid samples are collected in a list.

    Args:
        samples: Headers whose ``window`` gives the area to read.
        location: Path to the vector file the samples come from.
        raster_reader: Reader exposing ``read_windows(window)``.
        read_lock: Optional lock held around each read.

    Returns:
        List of BasicGeoData for the valid windows.
    """
    collected = []
    for sample_header in samples:
        if read_lock is None:
            window_data = raster_reader.read_windows(sample_header.window)
        else:
            with read_lock:
                window_data = raster_reader.read_windows(sample_header.window)

        if not LabeledWindowsExtractor.is_valid(window_data):
            continue

        geo_header = GeoDataHeader(sample_header.idx, sample_header.geometry, location)
        collected.append(BasicGeoData(geo_header, window_data, sample_header.label))
    return collected
|
|
487
|
+
|
|
488
|
+
def extract_few_pool(samples, location):
    """Process-pool task: run ``extract_few`` with the worker's raster reader copy."""
    reader = glob_raster_reader
    with reader:
        return extract_few(samples, location, reader)
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
class ProcessOptimiseLabeledWindowsExtractor(AbstractPooledWindowsExtractor):
    """
    Extract the data window of the given size around pixel points, reading blocks
    through a pool of ``worker`` processes. Samples are written asynchronously to
    the database.

    The bounded submission is based on
    https://superfastpython.com/threadpoolexecutor-limit-pending-tasks/ to limit the
    number of pending tasks. A thread pool may give similar performance with less
    complexity.
    """
    def __init__(self,
                 locations: str,
                 writer: GeoDataWriter,
                 raster_reader: RasterReader,
                 windows_size: int,
                 label_name: str = 'class',
                 geometry_name: str = 'geometry',
                 mask_path: str = None,
                 show_progress: bool = True,
                 worker=4,
                 prefetch=3):
        super().__init__(locations, writer, raster_reader, windows_size, label_name, geometry_name, worker=worker,
                         prefetch=prefetch, mask_path=mask_path, show_progress=show_progress)

    def _load_and_save(self, window_header_map, total, show_progress):
        """Dispatch every block to the process pool and wait for the writer to finish.

        Args:
            window_header_map: Mapping ``block tuple -> list of headers``.
            total: Progress-bar total.
            show_progress: Whether to display a progress bar.
        """
        self.writer.start()
        self.pbar = tqdm(total=total, disable=not show_progress)

        try:
            # Each worker gets its own reader copy through init_pool.
            with ProcessPoolExecutor(max_workers=self.worker, initializer=init_pool,
                                     initargs=(self.raster_reader,)) as executor:
                for window_tuple, h_list in window_header_map.items():
                    if len(h_list) < 5:
                        self.submit_proxy(executor, extract_few_pool, h_list, self.locations)
                    else:
                        self.submit_proxy(executor, extract_many_pool, h_list, window_tuple, self.locations)
                # Leaving the ``with`` waits for all submitted tasks (shutdown(wait=True)).
        finally:
            # Send the poison pill even on failure so the writer can close its
            # database cleanly instead of being killed as a daemon.
            if self.writer.is_alive():
                self.writer.submit(None)
            self.writer.join()
            self.pbar.close()
|
|
550
|
+
|
|
551
|
+
|
|
552
|
+
class ThreadedOptimiseLabeledWindowsExtractor(AbstractPooledWindowsExtractor):
    """
    Extract the data window of the given size around pixel points, reading blocks
    through a pool of ``worker`` threads. Threading can help because rasterio
    releases the GIL during I/O. Samples are written asynchronously to the database.

    The bounded submission is based on
    https://superfastpython.com/threadpoolexecutor-limit-pending-tasks/ to limit the
    number of pending tasks.
    """

    def __init__(self,
                 locations: str,
                 writer: GeoDataWriter,
                 raster_reader: RasterReader,
                 windows_size: int,
                 label_name: str = 'class',
                 geometry_name: str = 'geometry',
                 mask_path: str = None,
                 show_progress: bool = True,
                 worker=4,
                 prefetch=3):
        super().__init__(locations, writer, raster_reader, windows_size, label_name, geometry_name,
                         mask_path=mask_path, show_progress=show_progress, worker=worker, prefetch=prefetch)

        # All threads share one reader; reads are serialized through this lock.
        self.reader_lock = threading.Lock()

    def _load_and_save(self, window_header_map, total, show_progress):
        """Dispatch every block to the thread pool and wait for the writer to finish.

        Args:
            window_header_map: Mapping ``block tuple -> list of headers``.
            total: Progress-bar total.
            show_progress: Whether to display a progress bar.
        """
        self.writer.start()
        self.pbar = tqdm(total=total, disable=not show_progress)

        try:
            # Entered before the pool so the reader is closed only after the pool is done executing.
            with self.raster_reader:
                with ThreadPoolExecutor(max_workers=self.worker) as executor:
                    for window_tuple, h_list in window_header_map.items():
                        if len(h_list) < 5:
                            self.submit_proxy(executor, extract_few, h_list, self.locations,
                                              self.raster_reader, self.reader_lock)
                        else:
                            self.submit_proxy(executor, extract_many, h_list, window_tuple, self.locations,
                                              self.raster_reader, self.reader_lock)
        finally:
            # Send the poison pill even on failure so the writer can close its
            # database cleanly instead of being killed as a daemon.
            if self.writer.is_alive():
                self.writer.submit(None)
            self.writer.join()
            self.pbar.close()
|
|
603
|
+
|
|
604
|
+
|