eoml 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eoml/__init__.py +74 -0
- eoml/automation/__init__.py +7 -0
- eoml/automation/configuration.py +105 -0
- eoml/automation/dag.py +233 -0
- eoml/automation/experience.py +618 -0
- eoml/automation/tasks.py +825 -0
- eoml/bin/__init__.py +6 -0
- eoml/bin/clean_checkpoint.py +146 -0
- eoml/bin/land_cover_mapping_toml.py +435 -0
- eoml/bin/mosaic_images.py +137 -0
- eoml/data/__init__.py +7 -0
- eoml/data/basic_geo_data.py +214 -0
- eoml/data/dataset_utils.py +98 -0
- eoml/data/persistence/__init__.py +7 -0
- eoml/data/persistence/generic.py +253 -0
- eoml/data/persistence/lmdb.py +379 -0
- eoml/data/persistence/serializer.py +82 -0
- eoml/raster/__init__.py +7 -0
- eoml/raster/band.py +141 -0
- eoml/raster/dataset/__init__.py +6 -0
- eoml/raster/dataset/extractor.py +604 -0
- eoml/raster/raster_reader.py +602 -0
- eoml/raster/raster_utils.py +116 -0
- eoml/torch/__init__.py +7 -0
- eoml/torch/cnn/__init__.py +7 -0
- eoml/torch/cnn/augmentation.py +150 -0
- eoml/torch/cnn/dataset_evaluator.py +68 -0
- eoml/torch/cnn/db_dataset.py +605 -0
- eoml/torch/cnn/map_dataset.py +579 -0
- eoml/torch/cnn/map_dataset_const_mem.py +135 -0
- eoml/torch/cnn/outputs_transformer.py +130 -0
- eoml/torch/cnn/torch_utils.py +404 -0
- eoml/torch/cnn/training_dataset.py +241 -0
- eoml/torch/cnn/windows_dataset.py +120 -0
- eoml/torch/dataset/__init__.py +6 -0
- eoml/torch/dataset/shade_dataset_tester.py +46 -0
- eoml/torch/dataset/shade_tree_dataset_creators.py +537 -0
- eoml/torch/model_low_use.py +507 -0
- eoml/torch/models.py +282 -0
- eoml/torch/resnet.py +437 -0
- eoml/torch/sample_statistic.py +260 -0
- eoml/torch/trainer.py +782 -0
- eoml/torch/trainer_v2.py +253 -0
- eoml-0.9.0.dist-info/METADATA +93 -0
- eoml-0.9.0.dist-info/RECORD +47 -0
- eoml-0.9.0.dist-info/WHEEL +4 -0
- eoml-0.9.0.dist-info/entry_points.txt +3 -0
eoml/automation/tasks.py
ADDED
@@ -0,0 +1,825 @@
"""Automation tasks for machine learning workflows with Earth observation data.

This module provides high-level functions and utilities for automating common machine learning
workflows including data extraction, model training, validation, and inference mapping. It serves
as the orchestration layer that connects data preparation, model creation, training loops, and
prediction generation.

Key functionality:
- Sample extraction from raster data and vector labels
- K-fold cross-validation setup for training/validation splits
- Dataset and dataloader configuration with augmentation
- Model instantiation and initialization
- Training orchestration with multiple datasets
- Map generation using trained models
- Statistics computation for trained models
"""

import itertools
import logging
from os import path

logger = logging.getLogger(__name__)

import fiona
import torch

from eoml import get_read_profile, get_write_profile
from eoml.data.dataset_utils import k_fold_sample, random_split
from eoml.data.persistence.generic import GeoDataWriter
from eoml.data.persistence.lmdb import LMDBasicGeoDataDAO, LMDBReader
from eoml.raster.dataset.extractor import ThreadedOptimiseLabeledWindowsExtractor
from eoml.torch.cnn.augmentation import rotate_flip_transform
from eoml.torch.cnn.db_dataset import sample_list_id, DBDataset, DBDatasetMeta, DBInfo, MultiDBDataset, \
    db_dataset_multi_proc_init
from eoml.torch.cnn.map_dataset import IterableMapDataset, NNMapper
from eoml.torch.cnn.torch_utils import meta_data_collate, batch_collate
from eoml.torch.models import initialize_weights, ModelFactory
from eoml.torch.sample_statistic import ClassificationStats, BadlyClassifyToGPKG
from eoml.torch.trainer import GradNormClipper, agressive_train_labeling

from rasterio.enums import Resampling
from shapely.geometry import shape
from torch.utils.data import RandomSampler, BatchSampler, DataLoader, WeightedRandomSampler

from rasterop.tiled_op.tiled_raster_op import tiled_op


class KFoldIterator:
    """Iterator for K-fold cross-validation splits.

    Constructs sequences of training and validation folds based on a list of folds
    and split specifications. Each split is defined as ((train_fold_indices), (validation_fold_indices)).

    Attributes:
        folds: List of fold data, where each fold contains sample identifiers
        folds_split: List of tuples defining the train/validation split for each iteration.
            Each tuple is ((train_indices,), (val_indices,))
    """

    def __init__(self, folds, folds_split):
        """Initialize the K-fold iterator.

        Args:
            folds: List of folds, where each fold is a list of sample identifiers
            folds_split: List of split specifications, where each split is a tuple
                ((train_fold_indices), (validation_fold_indices))
        """
        self.folds = folds
        self.folds_split = folds_split

    def __iter__(self):
        """Iterate through all fold splits.

        Yields:
            Tuple of (train_fold, validation_fold), where each fold is a flat list
            of sample identifiers
        """
        for split in self.folds_split:
            train_fold = []
            validation_fold = []
            for s in split[0]:
                train_fold.extend(self.folds[s])
            for s in split[1]:
                validation_fold.extend(self.folds[s])

            yield train_fold, validation_fold

    def __len__(self):
        """Return the number of splits.

        Returns:
            Number of train/validation splits
        """
        return len(self.folds_split)

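# Illustrative sketch (added for documentation, not part of the original module):
# how KFoldIterator flattens fold indices into train/validation sample lists.
# The fold contents and split layout below are made-up toy values.
def _example_kfold_iterator():
    folds = [["a", "b"], ["c", "d"], ["e", "f"]]
    folds_split = [((0, 1), (2,)), ((0, 2), (1,)), ((1, 2), (0,))]
    for train_fold, validation_fold in KFoldIterator(folds, folds_split):
        # first iteration yields ['a', 'b', 'c', 'd'] / ['e', 'f']
        print(train_fold, validation_fold)
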
# todo use torch script for model
def extract_sample(gps_path, raster_reader, db_path, windows_size, label_name="Class", id_field=None, mask_path=None, force_write=False):
    """Extract training samples from raster data based on GPS/vector labels.

    Extracts image windows from raster data at locations specified by vector geometries
    (GPS points) and stores them in an LMDB database for efficient training data access.

    Args:
        gps_path: Path to vector file (GeoPackage, Shapefile, etc.) containing sample locations
        raster_reader: RasterReader instance for reading the source imagery
        db_path: Output path for the LMDB database
        windows_size: Size of the extraction window (in pixels)
        label_name: Name of the attribute field containing class labels. Defaults to "Class"
        id_field: Name of the unique identifier field. Defaults to None (uses feature index)
        mask_path: Optional path to mask polygon restricting extraction area
        force_write: If True, overwrite existing database. Defaults to False

    Returns:
        None. Writes extracted samples to database at db_path
    """
    if force_write or not path.exists(db_path):
        writer = GeoDataWriter(LMDBasicGeoDataDAO(db_path))

        #extractor = LabeledWindowsExtractor(gps_path, writer, raster_reader, windows_size, label_name, id_field)

        #extractor = LabeledWindowsExtractor(gps_path, writer, raster_reader, windows_size, label_name)
        #extractor = OptimiseLabeledWindowsExtractor(gps_path, writer, raster_reader, windows_size,
        #                                            label_name, id_field, mask_path=mask_path)
        extractor = ThreadedOptimiseLabeledWindowsExtractor(gps_path, writer, raster_reader, windows_size, label_name,
                                                            id_field, mask_path=mask_path, worker=4, prefetch=3)
        #extractor = ProcessOptimiseLabeledWindowsExtractor(gps_path, writer, raster_reader, windows_size, label_name,
        #                                                   id_field, mask_path=mask_path, worker=4, prefetch=3)

        extractor.process()
    else:
        logger.info(f"{db_path} exists, skipping extraction")

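# Illustrative usage sketch (added for documentation): the vector path, database
# path and window size below are hypothetical placeholders; only the keyword
# names match the function above.
def _example_extract_sample(raster_reader):
    extract_sample("samples.gpkg", raster_reader, "samples_lmdb",
                   windows_size=33, label_name="Class", id_field="id")
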
def multi_samples_k_fold_setup(db_path, mapper, n_fold=2):
    """Set up K-fold cross-validation splits for multiple databases.

    Args:
        db_path: List of paths to LMDB databases
        mapper: List of output mappers corresponding to each database
        n_fold: Number of folds for cross-validation. Defaults to 2

    Returns:
        Zipped iterator of K-fold splits for each database
    """
    iterator = [samples_k_fold_setup(path, mapp, n_fold=n_fold) for path, mapp in zip(db_path, mapper)]
    return zip(*iterator)


def multi_samples_yearly_k_fold_setup(db_path, mapper, n_fold=2):
    """Set up K-fold cross-validation for multi-year datasets.

    Creates K-fold splits based on unique geopackage identifiers that appear across
    multiple years (one database per year). This ensures samples from the same location
    across different years are grouped together, maintaining temporal consistency in splits.

    Note: Weighting of repeating samples is not currently implemented.

    Args:
        db_path: List of paths to LMDB databases, one per year
        mapper: List of output mappers for each database
        n_fold: Number of folds for cross-validation. Defaults to 2

    Returns:
        Zipped iterator of KFoldIterator objects, one for each database
    """

    key_set = set()
    for db, mapp in zip(db_path, mapper):
        db_reader = LMDBReader(db)
        with db_reader:
            id_out = db_reader.get_sample_id_output_dic()

        sample_idx = sample_list_id(id_out, mapp)
        key_set.update(sample_idx)

    key_set = list(key_set)
    folds_idx_ref, folds_split = k_fold_sample(key_set, n_fold)

    iterators = []
    for db, mapp in zip(db_path, mapper):
        db_reader = LMDBReader(db)
        with db_reader:
            id_key = db_reader.get_sample_id_db_key_dic()

            fold_i = []
            for folds in folds_idx_ref:
                f = []
                for idx in folds:
                    key = id_key.get(idx, None)
                    # check if the key exists and if the class is a cover we are mapping
                    if key is not None and mapp(db_reader.get_output(key)) != mapp.no_target:
                        f.append(key)

                fold_i.append(f)

        iterators.append(KFoldIterator(fold_i, folds_split))

    return zip(*iterators)


def samples_k_fold_setup(db_path, mapper, n_fold=2):
    """Set up K-fold cross-validation splits for a single database.

    Args:
        db_path: Path to LMDB database
        mapper: Output mapper for converting raw labels
        n_fold: Number of folds for cross-validation. Defaults to 2

    Returns:
        KFoldIterator object containing train/validation splits
    """
    db_reader = LMDBReader(db_path)

    with db_reader:
        keys_out = db_reader.def_get_output_dic()

    id_list = sample_list_id(keys_out, mapper)

    folds, folds_split = k_fold_sample(id_list, n_fold)

    return KFoldIterator(folds, folds_split)


def samples_split_setup(db_path, mapper, split=None):
    """Set up a single train/validation split of samples.

    Args:
        db_path: Path to LMDB database
        mapper: Output mapper for converting raw labels
        split: List defining train/validation split ratios [train_frac, val_frac].
            Defaults to [0.8, 0.2]

    Returns:
        List containing a single train/validation split
    """
    if split is None:
        split = [0.8, 0.2]

    db_reader = LMDBReader(db_path)
    with db_reader:
        keys_out = db_reader.def_get_output_dic()

    id_list = sample_list_id(keys_out, mapper)

    train_id, validation_id = random_split(id_list, split, True)
    return [[train_id, validation_id]]


def augmentation_setup(samples, angles=None, flip=None):
    """Set up data augmentation parameters for a set of samples.

    Args:
        samples: List of sample identifiers
        angles: List of rotation angles in degrees. Defaults to [0, 90, 180, -90]
        flip: List of boolean flags for horizontal flipping. Defaults to [False, True]

    Returns:
        Tuple of (augmented_samples, augmentation_parameters)
    """
    if angles is None:
        angles = [0, 90, 180, -90]

    if flip is None:
        flip = [False, True]

    t_param_list = list(itertools.product(angles, flip))

    t_params = []
    samples_split = []
    for k in samples:
        for p in t_param_list:
            t_params.append(p)
            samples_split.append(k)

    return samples_split, t_params

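# Illustrative sketch (added for documentation): with the defaults, every sample
# id is expanded into one copy per (angle, flip) combination, so two ids become
# 2 * 4 * 2 = 16 (sample, parameter) pairs.
def _example_augmentation_setup():
    ids, params = augmentation_setup(["s1", "s2"])
    assert len(ids) == len(params) == 16
    # params starts with (0, False), (0, True), (90, False), ...
    return ids, params
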
def dataset_setup(train_id_split, validation_split, augmentation_param, db_path, mapper, db_type=DBDataset):
    """Set up training and validation datasets with augmentation.

    Args:
        train_id_split: List of training sample IDs
        validation_split: List of validation sample IDs
        augmentation_param: Dict containing augmentation settings
        db_path: Path to LMDB database
        mapper: Output mapper for converting raw labels
        db_type: Dataset class to use. Defaults to DBDataset

    Returns:
        Tuple of (train_dataset, validation_dataset)
    """
    # we load a full batch at once to limit the opening and closing of the db

    # dataset_train = DBDataset(gps_db, train_id, mapper_vector)
    transform_param = None
    transform = None
    transform_valid = None
    if augmentation_param["methode"] == "fix":
        train_id_split, transform_param = augmentation_setup(train_id_split, **augmentation_param["parameters"])
        transform = rotate_flip_transform

    if augmentation_param["methode"] == "no_dep":
        transform = augmentation_param["transform_train"]
        transform_valid = augmentation_param["transform_valid"]

    dataset_train = db_type(db_path, train_id_split, mapper, f_transform=transform,
                            transform_param=transform_param)
    dataset_valid = db_type(db_path, validation_split, mapper, f_transform=transform_valid)

    return dataset_train, dataset_valid

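# Illustrative sketch (added for documentation) of the two augmentation_param
# layouts handled above; the transform callables are placeholders (assumptions),
# only the dictionary keys come from the code.
_EXAMPLE_FIX_AUGMENTATION = {
    "methode": "fix",
    "parameters": {"angles": [0, 90, 180, -90], "flip": [False, True]},
}
_EXAMPLE_NO_DEP_AUGMENTATION = {
    "methode": "no_dep",
    "transform_train": None,   # e.g. a callable applied to each training window
    "transform_valid": None,   # e.g. a callable applied to each validation window
}
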
def dataloader_setup(dataset_train, dataset_valid, batch_size, balance_sample, num_worker,
                     prefetch, device, persistent_workers):
    """Set up training and validation data loaders.

    Args:
        dataset_train: Training dataset
        dataset_valid: Validation dataset
        batch_size: Number of samples per batch
        balance_sample: Whether to use weighted sampling for class balance
        num_worker: Number of worker processes for data loading
        prefetch: Number of batches to prefetch
        device: Device to use ('cuda' or 'cpu')
        persistent_workers: Whether to maintain persistent worker processes

    Returns:
        Tuple of (train_dataloader, validation_dataloader)
    """
    # TODO balanced sample will not work for multi sample
    if balance_sample:
        train_weight = dataset_train.weight_list()
        #test_weight = dataset_valid.weight_list()

        #print(train_weight)
        #print(test_weight)

        train_rsampler = WeightedRandomSampler(train_weight, len(train_weight), replacement=True)
    else:
        train_rsampler = RandomSampler(dataset_train, replacement=False, num_samples=None, generator=None)

    # TODO also balance but not for now for comparison
    valid_rsampler = RandomSampler(dataset_valid, replacement=False, num_samples=None, generator=None)

    train_bsampler = BatchSampler(train_rsampler, batch_size=batch_size, drop_last=False)
    valid_bsampler = BatchSampler(valid_rsampler, batch_size=batch_size, drop_last=False)

    if device == "cuda":
        pin_memory = True
    else:
        pin_memory = False

    train_dataloader = DataLoader(dataset_train, sampler=train_bsampler, collate_fn=batch_collate,
                                  num_workers=num_worker, prefetch_factor=prefetch, pin_memory=pin_memory,
                                  persistent_workers=persistent_workers, worker_init_fn=db_dataset_multi_proc_init)
    validation_dataloader = DataLoader(dataset_valid, sampler=valid_bsampler, collate_fn=batch_collate,
                                       num_workers=num_worker, prefetch_factor=prefetch, pin_memory=pin_memory,
                                       persistent_workers=persistent_workers, worker_init_fn=db_dataset_multi_proc_init)

    return train_dataloader, validation_dataloader

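# Illustrative sketch (added for documentation) of the class balancing used
# above: WeightedRandomSampler draws indices with probability proportional to
# the per-sample weights, so inverse-frequency weights roughly balance the
# classes seen per epoch. The toy labels are made up.
def _example_balanced_sampler():
    labels = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1])       # imbalanced toy labels
    class_count = torch.bincount(labels).float()           # tensor([6., 2.])
    weights = (1.0 / class_count)[labels]                  # per-sample weights
    sampler = WeightedRandomSampler(weights, len(weights), replacement=True)
    return list(sampler)                                    # indices, classes roughly balanced
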
def model_setup(model_name, type, path, device, nn_parameter):
    """Set up and initialize a neural network model.

    Args:
        model_name: Name of the model architecture
        type: Type of model
        path: Path to model weights/checkpoints
        device: Device to use ('cuda' or 'cpu')
        nn_parameter: Dict of model hyperparameters

    Returns:
        Initialized neural network model
    """
    # ----------------------------------------
    # Architecture
    # ----------------------------------------

    factory = ModelFactory()

    net = factory(model_name, type=type, path=path, model_args=nn_parameter)

    # net = Conv2Dense3(size, 65, n_out)
    #net = ConvJavaSmall(**nn_parameter)
    net.apply(initialize_weights)
    net.to(device)

    return net

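# Illustrative sketch (added for documentation): a model_parameter dict as it is
# passed to model_setup(**model_parameter) by the training drivers below. Every
# value here (architecture name, type, path, hyperparameter names) is a
# hypothetical placeholder; only the keyword names match the function signature.
_EXAMPLE_MODEL_PARAMETER = {
    "model_name": "resnet",                    # assumed name known to ModelFactory
    "type": None,                              # model type understood by the factory
    "path": None,                              # optional checkpoint/weights location
    "device": "cuda",
    "nn_parameter": {"n_in": 4, "n_out": 8},   # assumed hyperparameter names
}
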
def optimizer_setup(net, loss, optimizer, optimizer_parameter, scheduler_mode,
                    scheduler_parameter=None, data_loader=None, epoch=None):
    """Set up optimizer and learning rate scheduler.

    Args:
        net: Neural network model
        loss: Loss function
        optimizer: Optimizer class
        optimizer_parameter: Dict of optimizer parameters
        scheduler_mode: Learning rate scheduler mode
        scheduler_parameter: Dict of scheduler parameters. Defaults to None
        data_loader: DataLoader for scheduler steps. Defaults to None
        epoch: Number of epochs for scheduler. Defaults to None

    Returns:
        Tuple of (optimizer, loss_function, scheduler)
    """
    optimizer = optimizer(net.parameters(), **optimizer_parameter)

    if scheduler_mode is None:
        return optimizer, loss, None

    if scheduler_mode == "cycle":
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, steps_per_epoch=len(data_loader), epochs=epoch, **scheduler_parameter)
        return optimizer, loss, scheduler

    #if scheduler_mode == "plateau":
    #    print("metric need in step and other stuff")
    #    torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=10, threshold=0.0001,
    #                                               threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose=True)

    raise Exception("Unknown scheduler_mode")

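# Illustrative sketch (added for documentation) of calling optimizer_setup with
# the "cycle" scheduler; the optimizer class, learning rates and epoch count are
# arbitrary example values, not defaults taken from the package.
def _example_optimizer_setup(net, train_dataloader, max_epochs=10):
    return optimizer_setup(net,
                           loss=torch.nn.CrossEntropyLoss(),
                           optimizer=torch.optim.AdamW,
                           optimizer_parameter={"lr": 1e-3, "weight_decay": 1e-4},
                           scheduler_mode="cycle",
                           scheduler_parameter={"max_lr": 1e-2},
                           data_loader=train_dataloader,
                           epoch=max_epochs)
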
def generate_map_iterator(jit_model_path,
                          map_path,
                          map_iterator,
                          transformer,
                          mask,  # path or geom
                          bounds,
                          mode,
                          num_worker,
                          prefetch,
                          custom_collate=meta_data_collate,
                          worker_init_fn=IterableMapDataset.basic_worker_init_fn):
    # mode 0: full cpu, no pinning
    # mode 1: pinned memory in loader, moved asynchronously (more loaders needed) to the gpu (seems to like big batches)
    # mode 2: start cuda in each thread of the loader and prepare most of the samples directly on the gpu (each thread/
    #         loader uses ~1gb graphic memory and needs torch.multiprocessing.set_start_method('spawn') to be used)
    # TODO test const mem for memory limited gpu
    write_profile = get_write_profile()

    # if a path is given
    if isinstance(mask, str):
        with fiona.open(mask) as layer:
            features = [shape(feature["geometry"]) for feature in layer]
    else:
        # mask is already a geometry
        features = mask

    is_pinned = False
    loader_device = "cpu"
    nn_device = "cpu"

    if isinstance(jit_model_path, str):
        model = torch.jit.load(jit_model_path)
    else:
        model = jit_model_path

    if mode == 0:
        is_pinned = False
        loader_device = "cpu"
        nn_device = "cpu"
    if mode == 1:
        is_pinned = True
        loader_device = "cpu"
        nn_device = "cuda"

    mapper = NNMapper(model,
                      stride=1,
                      loader_device=loader_device,
                      mapper_device=nn_device,
                      pin_memory=is_pinned,
                      num_worker=num_worker,
                      prefetch_factor=prefetch,
                      custom_collate=custom_collate,
                      worker_init_fn=worker_init_fn,
                      write_profile=write_profile,
                      async_agr=True)

    # deprecated, it is not recommended to use cuda in workers
    # if mode == 2:
    #     is_pinned = False
    #     loader_device = "cuda"
    #     nn_device = "cuda"

    # if mode == 2:
    #     mapper = NNMapper(model,
    #                       windows_size,
    #                       stride=1,
    #                       batch_size=batch_size,
    #                       loader_device=loader_device,
    #                       mapper_device=nn_device,
    #                       pin_memory=is_pinned,
    #                       num_worker=num_worker,
    #                       prefetch_factor=prefetch,
    #                       map_dataset_class=IterableMapDatasetConstMem,
    #                       custom_collate=meta_data_collate,
    #                       worker_init_fn=worker_init_fn,
    #                       aggregator=MapResultAggregatorConstMem,
    #                       write_profile=write_profile,
    #                       async_agr=True)

    mapper.map(map_path, map_iterator, transformer, bounds=bounds, mask=features)

    return map_path


def training_step(net, optimizer, loss, scheduler, train_dataloader, validation_dataloader, max_epochs,
                  run_stats_dir, model_base_path, model_tag, grad_clip_value, device):
    """Run the training step."""

    grad_f = None
    if grad_clip_value is not None:
        grad_f = GradNormClipper(grad_clip_value)

    return agressive_train_labeling(max_epochs, net, optimizer, loss, scheduler, train_dataloader,
                                    validation_dataloader,
                                    writer_base_path=run_stats_dir, model_base_path=model_base_path,
                                    model_tag=model_tag,
                                    grad_f=grad_f, device=device)


def generate_map(jit_model_path,
                 map_path,
                 raster_reader,
                 windows_size,
                 batch_size,
                 transformer,
                 mask,  # path or geom
                 bounds,
                 mode,
                 num_worker,
                 prefetch,
                 custom_collate=meta_data_collate,
                 worker_init_fn=IterableMapDataset.basic_worker_init_fn):

    map_iterator = IterableMapDataset(raster_reader, windows_size, batch_size=batch_size)

    return generate_map_iterator(jit_model_path,
                                 map_path,
                                 map_iterator,
                                 transformer,
                                 mask,  # path or geom
                                 bounds,
                                 mode,
                                 num_worker,
                                 prefetch,
                                 custom_collate=custom_collate,
                                 worker_init_fn=worker_init_fn)


def train(sample_param, augmentation_param, dataset_parameter, dataloader_parameter, model_parameter, optimizer_parameter, train_nn_parameter):
    """
    Train a model for each train/validation split produced by the sampling method.

    :param sample_param:
    :param augmentation_param:
    :param dataset_parameter:
    :param dataloader_parameter:
    :param model_parameter:
    :param optimizer_parameter:
    :param train_nn_parameter:
    :return: list of (base_model_dir, best_model_path, model_path_jitted, model_name) tuples
    """

    experiment_iterator = sample_param["methode"](**sample_param["param"])

    out = []
    for (i, (train_id, validation_id)) in enumerate(experiment_iterator):

        train_dataset, validation_dataset = dataset_setup(train_id, validation_id, augmentation_param,
                                                          **dataset_parameter)

        train_dataloader, validation_dataloader = dataloader_setup(train_dataset, validation_dataset,
                                                                   **dataloader_parameter)

        net = model_setup(**model_parameter)

        optimizer, loss, scheduler = optimizer_setup(net=net, data_loader=train_dataloader, epoch=train_nn_parameter["max_epochs"],
                                                     **optimizer_parameter)

        base_model_dir, best_model_path, model_path_jitted, model_name = training_step(net, optimizer, loss, scheduler, train_dataloader,
                                                                                       validation_dataloader, **train_nn_parameter)

        del train_dataloader, validation_dataloader, net, loss, optimizer, train_id

        out.append((base_model_dir, best_model_path, model_path_jitted, model_name))

    return out

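# Illustrative sketch (added for documentation) of the sample_param dict expected
# by train() and the *_and_map drivers: "methode" is one of the setup functions
# defined above and "param" holds its keyword arguments. The database path and
# mapper are placeholders (assumptions).
def _example_sample_param(db_path, mapper, n_fold=5):
    return {"methode": samples_k_fold_setup,
            "param": {"db_path": db_path, "mapper": mapper, "n_fold": n_fold}}
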
def multi_db_merge(ids, augmentation_param, dataset_parameter):
    """Merge multiple databases into combined train/validation datasets.

    Args:
        ids: List of (train_ids, val_ids) tuples for each database
        augmentation_param: Dict of augmentation parameters
        dataset_parameter: List of dataset parameters for each database

    Returns:
        Tuple of (merged_train_dataset, merged_validation_dataset)
    """

    train_data = []
    valid_data = []
    for (train_id, validation_id), data_param in zip(ids, dataset_parameter):

        train_dataset, validation_dataset = dataset_setup(train_id, validation_id, augmentation_param, db_type=DBInfo,
                                                          **data_param)

        train_data.append(train_dataset)
        valid_data.append(validation_dataset)

    return MultiDBDataset(train_data), MultiDBDataset(valid_data)

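# Illustrative sketch (added for documentation): each element yielded by
# multi_samples_k_fold_setup is one (train_ids, validation_ids) pair per
# database, which multi_db_merge turns into a single MultiDBDataset pair.
# The database paths, mappers and parameter lists are placeholders.
def _example_multi_db_merge(db_paths, mappers, augmentation_param, dataset_parameter):
    for ids in multi_samples_k_fold_setup(db_paths, mappers, n_fold=5):
        yield multi_db_merge(ids, augmentation_param, dataset_parameter)
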
def multi_train_and_map(sample_param, augmentation_param, dataset_parameter, dataloader_parameter, model_parameter, optimizer_parameter, train_nn_parameter,
                        map_parameter):
    """
    Train a model on the merged multi-database splits and generate the requested maps.

    :param sample_param:
    :param augmentation_param:
    :param dataset_parameter:
    :param dataloader_parameter:
    :param model_parameter:
    :param optimizer_parameter:
    :param train_nn_parameter:
    :param map_parameter:
    :return: dict mapping each map index to the list of generated map paths
    """

    experiment_iterator = sample_param["methode"](**sample_param["param"])

    out = {}
    for (i, ids) in enumerate(experiment_iterator):

        train_dataset, validation_dataset = multi_db_merge(ids, augmentation_param, dataset_parameter)

        train_dataloader, validation_dataloader = dataloader_setup(train_dataset, validation_dataset,
                                                                   **dataloader_parameter)

        net = model_setup(**model_parameter)

        optimizer, loss, scheduler = optimizer_setup(net=net, data_loader=train_dataloader, epoch=train_nn_parameter["max_epochs"],
                                                     **optimizer_parameter)

        base_model_dir, best_model_path, model_path_jitted, model_name = training_step(net, optimizer, loss, scheduler, train_dataloader,
                                                                                       validation_dataloader, **train_nn_parameter)

        # TODO add again for later
        #dataset_stats(f'{base_model_dir}', model_path_jitted, train_id, validation_id, dataset_parameter["db_path"],
        #              dataset_parameter["batch_size"], dataset_parameter["mapper"], dataset_parameter["num_worker"],
        #              dataset_parameter["prefetch"], train_nn_parameter["device"])

        del train_dataloader, validation_dataloader, net, loss, optimizer, ids

        for j, map_param in enumerate(map_parameter):
            map_path = f'{base_model_dir}/{map_param["map_tag"]}_{model_name}_{j}.tif'
            map_param["map_path"] = map_path
            current_map_parameter = map_param.copy()
            del current_map_parameter['map_tag']  # remove as not an input of the mapping function

            m_out = out.setdefault(j, [])
            m_out.append(generate_map(model_path_jitted, **current_map_parameter))

    return out


def train_and_map(sample_param, augmentation_param, dataset_parameter, dataloader_parameter, model_parameter, optimizer_parameter, train_nn_parameter,
                  map_parameter):
    """
    Train a model for each train/validation split and generate the corresponding map.

    :param sample_param:
    :param augmentation_param:
    :param dataset_parameter:
    :param dataloader_parameter:
    :param model_parameter:
    :param optimizer_parameter:
    :param train_nn_parameter:
    :param map_parameter:
    :return: list of generated map paths, one per split
    """

    experiment_iterator = sample_param["methode"](**sample_param["param"])

    out = []
    for (i, (train_id, validation_id)) in enumerate(experiment_iterator):

        train_dataset, validation_dataset = dataset_setup(train_id, validation_id, augmentation_param,
                                                          **dataset_parameter)

        train_dataloader, validation_dataloader = dataloader_setup(train_dataset, validation_dataset,
                                                                   **dataloader_parameter)

        net = model_setup(**model_parameter)

        optimizer, loss, scheduler = optimizer_setup(net=net, data_loader=train_dataloader,
                                                     epoch=train_nn_parameter["max_epochs"],
                                                     **optimizer_parameter)

        base_model_dir, best_model_path, model_path_jitted, model_name = training_step(net,
                                                                                       optimizer,
                                                                                       loss,
                                                                                       scheduler,
                                                                                       train_dataloader,
                                                                                       validation_dataloader,
                                                                                       **train_nn_parameter)

        dataset_stats(f'{base_model_dir}', model_path_jitted, train_id, validation_id, augmentation_param["transform_valid"], dataset_parameter["db_path"],
                      dataloader_parameter["batch_size"], dataset_parameter["mapper"], dataloader_parameter["num_worker"],
                      dataloader_parameter["prefetch"], train_nn_parameter["device"])

        del train_dataloader, validation_dataloader, net, loss, optimizer, train_id

        map_path = f'{base_model_dir}/{map_parameter["map_tag"]}_{model_name}_{i}.tif'
        map_parameter["map_path"] = map_path
        current_map_parameter = map_parameter.copy()
        del current_map_parameter['map_tag']  # remove as not an input of the mapping function
        out.append(generate_map(model_path_jitted, **current_map_parameter))

    return out


def dataset_stats(path, jit_model_path, train_id_split, validation_split, transform, db_path, batch_size, mapper, num_worker,
                  prefetch, device):
    """Compute classification statistics for the train/validation splits and export misclassified samples to GeoPackage."""

    model = torch.jit.load(jit_model_path)
    model = model.to(device)

    # we load a full batch at once to limit the opening and closing of the db

    # dataset_train = DBDataset(gps_db, train_id, mapper_vector)
    dataset_train = DBDatasetMeta(db_path, train_id_split, mapper, f_transform=transform)
    dataset_valid = DBDatasetMeta(db_path, validation_split, mapper, f_transform=transform)

    train_rsampler = RandomSampler(dataset_train, replacement=False, num_samples=None, generator=None)
    valid_rsampler = RandomSampler(dataset_valid, replacement=False, num_samples=None, generator=None)

    train_bsampler = BatchSampler(train_rsampler, batch_size=batch_size, drop_last=False)
    valid_bsampler = BatchSampler(valid_rsampler, batch_size=batch_size, drop_last=False)

    if device == "cuda":
        pin_memory = True
    else:
        pin_memory = False

    train_dataloader = DataLoader(dataset_train, sampler=train_bsampler, collate_fn=batch_collate,
                                  num_workers=num_worker, prefetch_factor=prefetch, pin_memory=pin_memory,
                                  worker_init_fn=db_dataset_multi_proc_init)
    validation_dataloader = DataLoader(dataset_valid, sampler=valid_bsampler, collate_fn=batch_collate,
                                       num_workers=num_worker, prefetch_factor=prefetch, pin_memory=pin_memory,
                                       worker_init_fn=db_dataset_multi_proc_init)

    stats = ClassificationStats(len(mapper), device, mapper.nn_name())

    stats.compute(model, train_dataloader, device)
    stats.display()
    stats.to_file(f"{path}/train.txt")

    stats.compute(model, validation_dataloader, device)
    stats.display()
    stats.to_file(f"{path}/validation.txt")

    bctg = BadlyClassifyToGPKG()

    bctg.compute(model, train_dataloader, device)
    bctg.to_file(f"{path}/train.gpkg")

    bctg.compute(model, validation_dataloader, device)
    bctg.to_file(f"{path}/validation.gpkg")


def tiled_task(maps, raster_out, operation, bounds=None, res=None,
               resampling=Resampling.nearest, target_aligned_pixels=False, indexes=None, src_kwds=None,
               dst_kwds=None, num_workers=8):
    """
    Execute the specified tiled operation over the input maps and write the result to raster_out.

    :param maps:
    :param raster_out:
    :param operation:
    :param bounds:
    :param res:
    :param resampling:
    :param target_aligned_pixels:
    :param indexes:
    :param src_kwds:
    :param dst_kwds:
    :param num_workers:
    :return:
    """

    if src_kwds is None:
        src_kwds = get_read_profile()

    if dst_kwds is None:
        dst_kwds = get_write_profile()

    tiled_op(
        maps,
        operation,
        operation.n_band_out,
        raster_out,
        bounds=bounds,
        res=res,
        nodata=operation.nodata,
        dtype=operation.dtype,
        indexes=indexes,
        resampling=resampling,
        target_aligned_pixels=target_aligned_pixels,
        dst_kwds=dst_kwds,
        src_kwds=src_kwds,
        num_workers=num_workers)

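# Illustrative sketch (added for documentation) of the interface tiled_task
# expects from `operation`: an object exposing n_band_out, nodata and dtype and
# callable on the tile data handed to it by tiled_op. This class and its call
# signature are hypothetical placeholders, not part of the package.
class _ExampleIdentityOp:
    n_band_out = 1
    nodata = 255
    dtype = "uint8"

    def __call__(self, *tile_data):
        # assumed signature: pass the first input tile through unchanged
        return tile_data[0]
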

# only run if main script (important as we will use threading with the spawn method)