eoml 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eoml/__init__.py +74 -0
- eoml/automation/__init__.py +7 -0
- eoml/automation/configuration.py +105 -0
- eoml/automation/dag.py +233 -0
- eoml/automation/experience.py +618 -0
- eoml/automation/tasks.py +825 -0
- eoml/bin/__init__.py +6 -0
- eoml/bin/clean_checkpoint.py +146 -0
- eoml/bin/land_cover_mapping_toml.py +435 -0
- eoml/bin/mosaic_images.py +137 -0
- eoml/data/__init__.py +7 -0
- eoml/data/basic_geo_data.py +214 -0
- eoml/data/dataset_utils.py +98 -0
- eoml/data/persistence/__init__.py +7 -0
- eoml/data/persistence/generic.py +253 -0
- eoml/data/persistence/lmdb.py +379 -0
- eoml/data/persistence/serializer.py +82 -0
- eoml/raster/__init__.py +7 -0
- eoml/raster/band.py +141 -0
- eoml/raster/dataset/__init__.py +6 -0
- eoml/raster/dataset/extractor.py +604 -0
- eoml/raster/raster_reader.py +602 -0
- eoml/raster/raster_utils.py +116 -0
- eoml/torch/__init__.py +7 -0
- eoml/torch/cnn/__init__.py +7 -0
- eoml/torch/cnn/augmentation.py +150 -0
- eoml/torch/cnn/dataset_evaluator.py +68 -0
- eoml/torch/cnn/db_dataset.py +605 -0
- eoml/torch/cnn/map_dataset.py +579 -0
- eoml/torch/cnn/map_dataset_const_mem.py +135 -0
- eoml/torch/cnn/outputs_transformer.py +130 -0
- eoml/torch/cnn/torch_utils.py +404 -0
- eoml/torch/cnn/training_dataset.py +241 -0
- eoml/torch/cnn/windows_dataset.py +120 -0
- eoml/torch/dataset/__init__.py +6 -0
- eoml/torch/dataset/shade_dataset_tester.py +46 -0
- eoml/torch/dataset/shade_tree_dataset_creators.py +537 -0
- eoml/torch/model_low_use.py +507 -0
- eoml/torch/models.py +282 -0
- eoml/torch/resnet.py +437 -0
- eoml/torch/sample_statistic.py +260 -0
- eoml/torch/trainer.py +782 -0
- eoml/torch/trainer_v2.py +253 -0
- eoml-0.9.0.dist-info/METADATA +93 -0
- eoml-0.9.0.dist-info/RECORD +47 -0
- eoml-0.9.0.dist-info/WHEEL +4 -0
- eoml-0.9.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,537 @@
|
|
|
1
|
+
"""Dataset creation utilities for shade and tree detection from satellite imagery.
|
|
2
|
+
|
|
3
|
+
This module provides tools for creating training datasets that combine high-resolution
|
|
4
|
+
shade/tree annotations with lower-resolution satellite imagery. Handles spatial alignment,
|
|
5
|
+
temporal matching, and preprocessing of multi-source geospatial data.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import contextlib
|
|
9
|
+
import logging
|
|
10
|
+
import os
|
|
11
|
+
import shutil
|
|
12
|
+
import tempfile
|
|
13
|
+
from datetime import datetime
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
import rasterio
|
|
18
|
+
|
|
19
|
+
from libterra_gis.environement import get_write_profile, get_read_profile
|
|
20
|
+
from libterra_gis.raster_utils import RasterImage
|
|
21
|
+
from libterra_gis.shade.heuristic import extract_soil, post_process_hand_cleaned
|
|
22
|
+
from eoml.raster.band import Band
|
|
23
|
+
from eoml.raster.raster_reader import AbstractRasterReader, append_raster_reader, RasterReader
|
|
24
|
+
from eoml.torch.cnn.torch_utils import aligned_bound
|
|
25
|
+
from eoml.torch.cnn.training_dataset import BasicTrainingDataset, BasicYearTrainingDataset
|
|
26
|
+
from rasterio.enums import Resampling
|
|
27
|
+
from rasterio.windows import Window
|
|
28
|
+
from torch.utils.data import Dataset, random_split
|
|
29
|
+
|
|
30
|
+
# Module-level logger used by the classes and functions of this module.
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class ShadeMatchNNInput(Dataset):
    """Match and align high-resolution annotations with satellite imagery for CNN training.

    Creates training rasters by spatially aligning high-resolution shade/tree annotations
    with the corresponding satellite imagery, handling different resolutions and projections.
    Supports temporal matching when one reader per acquisition year is available.

    Attributes:
        input_raster_reader (AbstractRasterReader or dict): Reader(s) for satellite imagery.
            Either a single reader or a dict mapping years to readers for temporal matching.
        size (int): Kernel/window size of the CNN input.
    """

    def __init__(self,
                 input_raster_reader: "AbstractRasterReader | dict[int, AbstractRasterReader]",
                 size: int):
        """Initialize ShadeMatchNNInput.

        Args:
            input_raster_reader (AbstractRasterReader or dict): Reader for the satellite imagery,
                or a dict mapping years to readers for temporal datasets.
            size (int): Size of the CNN input window.
        """
        self.input_raster_reader = input_raster_reader
        self.size: int = size

    def _input_windows(self, input, raw, precision=0.01):
        """Compute the window of ``input`` that covers the footprint of ``raw``.

        The bounds of ``raw`` are reprojected into the CRS of ``input`` and snapped to the
        pixel grid of ``input`` by ``aligned_bound``.

        Pixels are treated as areas: a pixel's coordinate is its top-left corner and the
        bounding box is the real extent. The box therefore spans pixel (0, 0) to
        (length, length), while the last array pixel is (length - 1, length - 1).

        A pixel is included if ``precision`` percent of it is covered by the ``raw`` raster
        (as interpreted by ``aligned_bound``).
        """
        # ``rasterio.warp`` is a submodule: a bare ``import rasterio`` does not guarantee it is loaded.
        from rasterio.warp import transform_bounds

        left, bottom, right, top = transform_bounds(raw.crs, input.crs, *raw.bounds)
        return aligned_bound(left, bottom, right, top, input.transform, precision)

    def mask(self, value, bound, mask: "AbstractRasterReader", transformer):
        """Apply ``transformer`` to the stacked raster values, optionally with mask data.

        :param value: values read from the stacked raster reader.
        :param bound: bounds to read from ``mask``.
        :param mask: optional reader providing extra masking information; may be None.
        :param transformer: callable ``transformer(value, mask_data)`` returning the masked values.
            ``mask_data`` is None when ``mask`` is None.
        :return: whatever ``transformer`` returns.
        """
        mask_data = None
        if mask is not None:
            with mask:
                mask_data = mask.read_bound(bound)

        return transformer(value, mask_data)

    def generate_target_raster(self, nn_out_readers, nn_input_raster, path, mask=None, mask_f=None, save_all=False,
                               precision=0.01):
        """Generate the target raster (the ideal output of the network) on the input imagery grid.

        The annotation raster is stacked onto the grid of ``nn_input_raster``. The window covering
        the annotation is padded by the CNN kernel size, and the result is written to ``path``.

        :param nn_out_readers: reader for the annotation (network output) raster.
        :param nn_input_raster: reader for the satellite imagery used as network input; its grid
            and CRS define the output grid.
        :param path: output raster path.
        :param mask: extra information used to mask given pixels (passed to :meth:`mask`).
        :param mask_f: function used to mask pixels; masking is skipped when None.
        :param save_all: if True keep the input bands in front of the annotation bands, otherwise
            keep only the last ``nn_out_readers.n_band`` bands.
        :param precision: pixel coverage precision forwarded to :meth:`_input_windows`.
        :return: a one-element list containing the written values.
        """
        half = self.size // 2

        # grid (e.g. Sentinel) we reproject to
        target = nn_input_raster.ref_raster_info()
        # annotation / network output raster
        src = nn_out_readers.ref_raster_info()
        n_bands = nn_out_readers.n_band

        core = self._input_windows(target, src, precision=precision)

        sample_raster_reader = append_raster_reader([nn_input_raster, nn_out_readers],
                                                    0,
                                                    nn_input_raster.read_profile)

        # pad the window by the kernel size (``half`` pixels on the top/left side)
        window = Window(core.col_off - half, core.row_off - half,
                        core.width + self.size, core.height + self.size)

        with sample_raster_reader:
            values = sample_raster_reader.read_windows(window)

        # remove the input bands, keep only the output (annotation) bands
        if not save_all:
            values = values[-n_bands:]

        if mask_f is not None:
            bounds = target.window_bounds(window)
            values = self.mask(values, bounds, mask, mask_f)

        self.export_result(values, target.window_transform(window), target.crs, path)

        return [values]

    def export_result(self, values, transform, crs, out_path):
        """Write ``values`` (bands, rows, cols) to ``out_path`` as a float32 raster with nodata 255."""
        write_profile = get_write_profile()
        write_profile.update({
            'transform': transform,
            'width': values.shape[2],
            'height': values.shape[1],
            'count': values.shape[0],
            'dtype': "float32",
            'crs': crs,
            'nodata': 255})

        with rasterio.open(out_path, 'w', **write_profile) as dst:
            dst.write(values)

    def create_nn_output(self,
                         in_folder,
                         out_folder,
                         transformer=None,
                         mask=None,
                         mask_f=None,
                         save_all=True,
                         read_profile=None,
                         sharing=True,
                         precision=0.01):
        """Create one aligned target raster per annotation raster found in ``in_folder``.

        Each ``.tif`` in ``in_folder`` must have a sidecar ``.txt`` with a ``date: dd/mm/YYYY`` line.
        When ``input_raster_reader`` is a dict, annotations whose year has no reader are skipped.

        :param in_folder: folder containing the annotation rasters (top level only).
        :param out_folder: folder where aligned rasters are written (created if missing).
        :param transformer: transformer passed to the annotation ``RasterReader``.
        :param mask: see :meth:`generate_target_raster`.
        :param mask_f: see :meth:`generate_target_raster`.
        :param save_all: see :meth:`generate_target_raster`.
        :param read_profile: read profile for the annotation readers; defaults to ``get_read_profile()``.
        :param sharing: sharing flag passed to the annotation ``RasterReader``.
        :param precision: see :meth:`generate_target_raster`.
        :return: (list of written raster paths, list of the matching years).
        """
        if read_profile is None:
            read_profile = get_read_profile()

        output_rasters, dates = self.get_outputs_raster(in_folder)

        Path(out_folder).mkdir(parents=True, exist_ok=True)

        file_out = []
        date_out = []
        for out_raster, date in zip(output_rasters, dates):
            if isinstance(self.input_raster_reader, dict):
                # if imagery is not available for that year, skip the annotation
                nn_input_raster = self.input_raster_reader.get(date)
            else:
                nn_input_raster = self.input_raster_reader

            if nn_input_raster is None:
                logger.info("skipping %s at year %s", out_raster, date)
                continue

            out_raster_reader = RasterReader(out_raster,
                                             Band.from_file(out_raster),
                                             transformer,
                                             Resampling.average,
                                             read_profile,
                                             sharing)

            out_p = f"{out_folder}/{Path(out_raster).stem}.tif"
            self.generate_target_raster(out_raster_reader,
                                        nn_input_raster,
                                        out_p,
                                        mask=mask, mask_f=mask_f, save_all=save_all, precision=precision)

            file_out.append(out_p)
            date_out.append(date)

        return file_out, date_out

    def get_outputs_raster(self, in_folder) -> tuple[list[str], list[int]]:
        """Return the ``.tif`` rasters of ``in_folder`` (sorted) and the year read from each ``.txt`` sidecar."""
        with os.scandir(in_folder) as entries:
            tifs = sorted(entry.path for entry in entries if entry.name.endswith('.tif'))

        dates = []
        for tif in tifs:
            logger.info("reading date for %s", tif)
            dates.append(self._read_date(f'{tif[:-4]}.txt'))

        return tifs, dates

    def _read_date(self, date):
        """Return the year of the ``date: dd/mm/YYYY`` entry of the text file at path ``date``.

        :param date: path of the sidecar text file.
        :return: the year as an ``int``.
        :raises ValueError: if no ``date`` entry is found or it does not match ``%d/%m/%Y``.
        """
        with open(date, "r") as f:
            for line in f:
                key, _, value = line.partition(":")
                if key.strip() == 'date':
                    return datetime.strptime(value.strip(), '%d/%m/%Y').year

        raise ValueError(f"no date found in {date}")
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
class DatasetPostProcessor:
    """Apply an image post-processor to every sample sub-folder of a dataset folder.

    Each direct sub-folder of the input folder is one sample. The ``image_post_processor``
    is called with the sample's target files and reference file, and the returned raster
    is saved into the output folder.
    """

    _MODES = ("match", "one")

    def __init__(self, image_post_processor):
        """
        :param image_post_processor: object exposing ``process(target_paths, reference_path)`` that
            returns an object with a ``save(path)`` method.
        """
        self.image_post_processor = image_post_processor

    def process_folder(self,
                       in_folder,
                       out_folder,
                       targets=('shade.tif', 'soil.tif'),
                       reference='google.tif',
                       mode="match",
                       info=None,
                       outname=None, ):
        """Post-process every sample sub-folder of ``in_folder``.

        :param in_folder: folder containing one sub-folder per sample (only direct sub-folders
            are scanned, not recursively).
        :param out_folder: destination folder.
        :param targets: target file name, or list/tuple of names, inside each sample folder.
        :param reference: file inside each sample folder containing the geo-location information.
        :param mode: ``"match"`` reproduces one sub-folder per sample in ``out_folder`` (file named
            ``outname`` or the sample folder name); ``"one"`` puts every result directly in
            ``out_folder`` (file named ``<sample>_<outname>`` or the sample folder name).
        :param info: optional file name inside each sample folder, copied next to the result as
            ``<name>.txt``.
        :param outname: optional base name of the produced files.
        :raises ValueError: if ``mode`` is not ``"match"`` or ``"one"``.
        """
        if mode not in self._MODES:
            raise ValueError(f"mode should be match or one, got {mode!r}")

        target_names = list(targets) if isinstance(targets, (list, tuple)) else [targets]

        # "one": create the destination and put all the images in it
        if mode == "one":
            Path(out_folder).mkdir(parents=True, exist_ok=True)

        # keep only the names of the direct sub-folders
        with os.scandir(in_folder) as entries:
            sample_dirs = sorted(entry.name for entry in entries if entry.is_dir())

        for sample_dir in sample_dirs:
            logger.info("processing %s", sample_dir)
            sample_path = os.path.join(in_folder, sample_dir)

            if mode == "match":
                # one output folder per sample; name is outname or the folder name
                save_folder = os.path.join(out_folder, sample_dir)
                Path(save_folder).mkdir(parents=True, exist_ok=True)
                f_outname = f'{outname}' if outname is not None else sample_dir
            else:
                # everything in out_folder: prefix with the sample name to keep files apart
                save_folder = out_folder
                f_outname = f'{sample_dir}_{outname}' if outname is not None else sample_dir

            targets_full_p = [os.path.join(sample_path, name) for name in target_names]
            reference_full_p = os.path.join(sample_path, reference)
            out_path = os.path.join(save_folder, f'{f_outname}.tif')

            logger.info("saving at %s", out_path)
            raster = self.image_post_processor.process(targets_full_p, reference_full_p)
            raster.save(out_path)

            if info is not None:
                shutil.copyfile(os.path.join(sample_path, info), os.path.join(save_folder, f'{f_outname}.txt'))
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
class SimpleShadeImagePostProcessor:
    """Turn hand-annotated target images into a single stacked float32 raster of mask bands.

    Pipeline: copy the reference metadata onto each target, binarise each target with
    ``post_process_hand_cleaned``, make the bands mutually exclusive, optionally append
    a ``1 - sum(bands)`` band, then stack and cast to float32.
    """

    def __init__(self, threshold=10, value=1, add_constraint_band=True, overwrite=True):
        """
        :param threshold: stored as ``self.threshold`` (not read by the methods of this class).
        :param value: forwarded to ``post_process_hand_cleaned`` as ``value``.
        :param add_constraint_band: append the ``1 - sum(bands)`` band after stacking.
        :param overwrite: re-save each target file after its metadata has been replaced.
        """
        self.threshold, self.value = threshold, value
        self.add_constraint_band, self.overwrite = add_constraint_band, overwrite

    def process(self, targets_path, reference):
        """Load the targets and reference, run the pipeline and return the stacked raster."""
        rasters = [RasterImage.from_file(target_file) for target_file in targets_path]
        reference = RasterImage.from_file(reference)

        self.glue_geo_info(rasters, reference, targets_path)

        self.threshold_bl_value(rasters)
        result = self.stack_image(rasters)

        # interpolation during later reprojection requires floating point data
        result.data = result.data.astype('float32')

        return result

    def glue_geo_info(self, targets, reference, targets_path):
        """Give every target the reference metadata and, if ``overwrite``, save it back to its path."""
        shared_meta = reference.meta
        for raster, raster_path in zip(targets, targets_path):
            raster.meta = shared_meta
            if not self.overwrite:
                continue
            raster.save(raster_path)

    def threshold_bl_value(self, target):
        """Binarise each raster's data in place using ``post_process_hand_cleaned``."""
        for raster in target:
            cleaned = post_process_hand_cleaned(raster.data, value=self.value)
            raster.data = cleaned

    def stack_image(self, targets):
        """Stack the target bands into a new RasterImage carrying the first target's metadata."""
        result = RasterImage(None, None)
        bands = [raster.data for raster in targets]

        _fix_overlaping_pixel(bands)

        if self.add_constraint_band:
            bands = _add_constraint_band(bands)

        result.data = np.array(bands)
        result.meta = targets[0].meta

        return result
|
|
351
|
+
|
|
352
|
+
class SoilShadeMaskImagePostProcessor(SimpleShadeImagePostProcessor):
    """Variant of ``SimpleShadeImagePostProcessor`` whose last target is an extra constraint band.

    All targets except the last go through the parent pipeline. The last target is only
    binarised with ``threshold_bl_value`` and appended as the final band. It therefore does
    not take part in the overlap fix or in the computed constraint band.
    """

    def __init__(self, threshold=10, value=1, add_constraint_band=True, overwrite=True):
        super().__init__(threshold, value, add_constraint_band, overwrite)

    def process(self, targets, reference):
        """Process ``targets[:-1]`` with the parent pipeline and append ``targets[-1]`` as a band.

        :param targets: target file paths; the last one is the constraint raster.
        :param reference: file containing the geo-location information.
        :return: the stacked raster.
        :raises ValueError: if fewer than two targets are given.
        """
        # explicit check instead of ``assert``, which is stripped when running with -O
        if len(targets) < 2:
            raise ValueError(f"expected at least two targets (annotations + constraint), got {len(targets)}")

        # run the simple image post processor on the annotation targets
        raster = super().process(targets[:-1], reference)

        # add the constraint band, keeping the parent's dtype instead of letting concatenate upcast
        constraint = RasterImage.from_file(targets[-1])
        self.threshold_bl_value([constraint])
        constraint_band = np.asarray(constraint.data, dtype=raster.data.dtype)[np.newaxis, ...]
        raster.data = np.concatenate((raster.data, constraint_band), axis=0)
        return raster
|
|
369
|
+
|
|
370
|
+
def _fix_overlaping_pixel(stacked):
|
|
371
|
+
"""Assume that a pixel is 100% on category. Set the pixel in the following band to 0"""
|
|
372
|
+
mask = np.zeros_like(stacked[0], dtype=bool)
|
|
373
|
+
for b in stacked:
|
|
374
|
+
b[...] = b * np.logical_not(mask)
|
|
375
|
+
mask += np.logical_or(mask, b == 1)
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
def _add_constraint_band(stacked):
|
|
379
|
+
"""add a band with a value such that Sum(band)==1
|
|
380
|
+
Maybe need to be done after reprojection"""
|
|
381
|
+
last_b = np.ones_like(stacked[0])
|
|
382
|
+
|
|
383
|
+
for b in stacked:
|
|
384
|
+
last_b = last_b - b
|
|
385
|
+
stacked.append(last_b)
|
|
386
|
+
return stacked
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
class ShadeDatasSetCreator:
    """Create PyTorch datasets for shade/tree detection from preprocessed imagery.

    Orchestrates the full pipeline: preprocessing raw annotations, spatial alignment with
    satellite imagery, and creation of PyTorch training/validation datasets. Supports
    temporal datasets with year as additional input.

    Attributes:
        training_date (list): Years corresponding to training rasters (set by ``prepare_dataset``).
        training_raster (list): Paths to processed training raster files (set by ``prepare_dataset``).
        in_folder (str): Input folder containing raw annotations.
        preprocessor (DatasetPostProcessor): Handles annotation preprocessing.
        dataset_creator (ShadeMatchNNInput): Creates spatially-aligned datasets.
        with_year (bool): Whether to include year as neural network input.
        year_normalisation (int): Normalization divisor for year values.
    """

    def __init__(self,
                 in_folder: str,
                 post_process_f,
                 dataset_creator: "ShadeMatchNNInput",
                 with_year: bool = False,
                 year_normalisation: int = 2500):
        """Initialize ShadeDatasSetCreator.

        Args:
            in_folder (str): Input folder containing raw annotation images.
            post_process_f: Image post-processor used to clean/transform annotations.
            dataset_creator (ShadeMatchNNInput): Object for creating aligned datasets.
            with_year (bool, optional): Include year as NN input feature. Defaults to False.
            year_normalisation (int, optional): Divisor for normalizing year values.
                Input becomes year/year_normalisation. Defaults to 2500.
        """
        # populated by prepare_dataset()
        self.training_date = None
        self.training_raster = None
        self.in_folder = in_folder
        # post-processor that cleans the hand-annotated images
        self.preprocessor = DatasetPostProcessor(post_process_f)
        self.dataset_creator = dataset_creator

        self.with_year = with_year
        self.year_normalisation = year_normalisation

    def _require_prepared(self):
        """Raise if ``prepare_dataset`` has not populated the training rasters yet."""
        if self.training_raster is None:
            raise RuntimeError("prepare_dataset() must be called before using the training rasters")

    def prepare_dataset(self,
                        mask_raster,
                        mask_f,
                        training_folder,
                        targets=('shade.tif', 'soil.tif'),
                        reference='google.tif',
                        mode="one",
                        info='info.txt',
                        save_all=True,
                        outname="out",
                        pre_process_folder=None):
        """Post-process the raw annotations and align them with the satellite imagery.

        Sets ``training_raster`` and ``training_date``.

        :param mask_raster: given as an input to the dataset creator to further mask pixels.
        :param mask_f: used by the dataset creator to mask pixels.
        :param training_folder: folder where the aligned training rasters are written.
        :param targets: annotation file name(s) inside each sample folder.
        :param reference: file inside each sample folder holding the geo-location information.
        :param mode: folder layout passed to ``DatasetPostProcessor.process_folder``.
            NOTE(review): the dataset creator then looks for rasters only at the top level of the
            pre-process folder, which matches the ``"one"`` layout.
        :param info: per-sample text file copied next to each processed raster; presumably the
            sidecar holding the ``date`` entry read by the dataset creator.
        :param save_all: forwarded to the dataset creator (keep input bands in the output).
        :param outname: base name of the post-processed files.
        :param pre_process_folder: where to keep the post-processed annotations; if None, a
            temporary directory is used and deleted afterwards.
        """
        context = (tempfile.TemporaryDirectory() if pre_process_folder is None
                   else contextlib.nullcontext(pre_process_folder))

        with context as pre_process_dir:
            self.preprocessor.process_folder(in_folder=self.in_folder,
                                             out_folder=pre_process_dir,
                                             targets=targets,
                                             reference=reference,
                                             mode=mode,
                                             info=info,
                                             outname=outname)

            # align the post-processed annotations with the network input imagery
            self.training_raster, self.training_date = self.dataset_creator.create_nn_output(pre_process_dir,
                                                                                            training_folder,
                                                                                            mask=mask_raster,
                                                                                            mask_f=mask_f,
                                                                                            save_all=save_all)

    def nn_training_dataset(self,
                            width,
                            n_out,
                            out_function,
                            test_size=0.15,
                            validation_size=0.15,
                            transformer=None,
                            generator=None):
        """Build one dataset over all training rasters and randomly split it into train/test/val.

        :param width: window width passed to the training dataset.
        :param n_out: number of outputs passed to the training dataset.
        :param out_function: output function passed to the training dataset.
        :param test_size: fraction of samples in the test subset.
        :param validation_size: fraction of samples in the validation subset.
        :param transformer: optional ``f_transform`` for the training dataset.
        :param generator: optional ``torch.Generator`` to make the split reproducible.
        :return: dict with keys ``'train'``, ``'test'`` and ``'val'`` mapping to the subsets.
        :raises RuntimeError: if ``prepare_dataset`` has not been called.
        """
        self._require_prepared()

        if self.with_year:
            dataset = BasicYearTrainingDataset(self.training_raster, self.training_date, width, 1, n_out, out_function,
                                               f_transform=transformer, year_normalisation=self.year_normalisation)
        else:
            dataset = BasicTrainingDataset(self.training_raster, width, 1, n_out, out_function, f_transform=transformer)

        # fractional lengths are supported by random_split since torch 1.13
        fractions = [1 - test_size - validation_size, test_size, validation_size]
        split_kwargs = {} if generator is None else {'generator': generator}
        train_set, test_set, val_set = random_split(dataset, fractions, **split_kwargs)

        return {'train': train_set, 'test': test_set, "val": val_set}

    def nn_per_image_dataset(self,
                             width,
                             n_out,
                             out_function,
                             transformer=None):
        """Return one dataset per training raster.

        :return: list of ``(dataset, raster_path, year)`` tuples when ``with_year`` is set,
            otherwise list of ``(dataset, raster_path)`` tuples.
        :raises RuntimeError: if ``prepare_dataset`` has not been called.
        """
        self._require_prepared()

        if self.with_year:
            return [(BasicYearTrainingDataset([raster], [date], width, 1, n_out, out_function,
                                              f_transform=transformer, year_normalisation=self.year_normalisation),
                     raster, date)
                    for raster, date in zip(self.training_raster, self.training_date)]

        return [(BasicTrainingDataset([raster], width, 1, n_out, out_function, f_transform=transformer), raster)
                for raster in self.training_raster]

    def write_mask(self, mask_function, out_folder):
        """Apply ``mask_function`` to every training raster and save the results in ``out_folder``.

        This can be used to map a validation dataset and compare the results.

        :raises RuntimeError: if ``prepare_dataset`` has not been called.
        """
        self._require_prepared()
        Path(out_folder).mkdir(parents=True, exist_ok=True)

        for raster_path in self.training_raster:
            raster = RasterImage.from_file(raster_path)
            # assumes bands are on axis 0: mask_function receives each pixel's band vector
            raster.data = np.apply_along_axis(mask_function, 0, raster.data)

            raster.save(f"{out_folder}/{Path(raster_path).name}")
|