eoml-0.9.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. eoml/__init__.py +74 -0
  2. eoml/automation/__init__.py +7 -0
  3. eoml/automation/configuration.py +105 -0
  4. eoml/automation/dag.py +233 -0
  5. eoml/automation/experience.py +618 -0
  6. eoml/automation/tasks.py +825 -0
  7. eoml/bin/__init__.py +6 -0
  8. eoml/bin/clean_checkpoint.py +146 -0
  9. eoml/bin/land_cover_mapping_toml.py +435 -0
  10. eoml/bin/mosaic_images.py +137 -0
  11. eoml/data/__init__.py +7 -0
  12. eoml/data/basic_geo_data.py +214 -0
  13. eoml/data/dataset_utils.py +98 -0
  14. eoml/data/persistence/__init__.py +7 -0
  15. eoml/data/persistence/generic.py +253 -0
  16. eoml/data/persistence/lmdb.py +379 -0
  17. eoml/data/persistence/serializer.py +82 -0
  18. eoml/raster/__init__.py +7 -0
  19. eoml/raster/band.py +141 -0
  20. eoml/raster/dataset/__init__.py +6 -0
  21. eoml/raster/dataset/extractor.py +604 -0
  22. eoml/raster/raster_reader.py +602 -0
  23. eoml/raster/raster_utils.py +116 -0
  24. eoml/torch/__init__.py +7 -0
  25. eoml/torch/cnn/__init__.py +7 -0
  26. eoml/torch/cnn/augmentation.py +150 -0
  27. eoml/torch/cnn/dataset_evaluator.py +68 -0
  28. eoml/torch/cnn/db_dataset.py +605 -0
  29. eoml/torch/cnn/map_dataset.py +579 -0
  30. eoml/torch/cnn/map_dataset_const_mem.py +135 -0
  31. eoml/torch/cnn/outputs_transformer.py +130 -0
  32. eoml/torch/cnn/torch_utils.py +404 -0
  33. eoml/torch/cnn/training_dataset.py +241 -0
  34. eoml/torch/cnn/windows_dataset.py +120 -0
  35. eoml/torch/dataset/__init__.py +6 -0
  36. eoml/torch/dataset/shade_dataset_tester.py +46 -0
  37. eoml/torch/dataset/shade_tree_dataset_creators.py +537 -0
  38. eoml/torch/model_low_use.py +507 -0
  39. eoml/torch/models.py +282 -0
  40. eoml/torch/resnet.py +437 -0
  41. eoml/torch/sample_statistic.py +260 -0
  42. eoml/torch/trainer.py +782 -0
  43. eoml/torch/trainer_v2.py +253 -0
  44. eoml-0.9.0.dist-info/METADATA +93 -0
  45. eoml-0.9.0.dist-info/RECORD +47 -0
  46. eoml-0.9.0.dist-info/WHEEL +4 -0
  47. eoml-0.9.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,825 @@
1
+ """Automation tasks for machine learning workflows with Earth observation data.
2
+
3
+ This module provides high-level functions and utilities for automating common machine learning
4
+ workflows including data extraction, model training, validation, and inference mapping. It serves
5
+ as the orchestration layer that connects data preparation, model creation, training loops, and
6
+ prediction generation.
7
+
8
+ Key functionality:
9
+ - Sample extraction from raster data and vector labels
10
+ - K-fold cross-validation setup for training/validation splits
11
+ - Dataset and dataloader configuration with augmentation
12
+ - Model instantiation and initialization
13
+ - Training orchestration with multiple datasets
14
+ - Map generation using trained models
15
+ - Statistics computation for trained models
16
+ """
17
+
18
+ import itertools
19
+ import logging
20
+ from os import path
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ import fiona
25
+ import torch
26
+
27
+ from eoml import get_read_profile, get_write_profile
28
+ from eoml.data.dataset_utils import k_fold_sample, random_split
29
+ from eoml.data.persistence.generic import GeoDataWriter
30
+ from eoml.data.persistence.lmdb import LMDBasicGeoDataDAO, LMDBReader
31
+ from eoml.raster.dataset.extractor import ThreadedOptimiseLabeledWindowsExtractor
32
+ from eoml.torch.cnn.augmentation import rotate_flip_transform
33
+ from eoml.torch.cnn.db_dataset import sample_list_id, DBDataset, DBDatasetMeta, DBInfo, MultiDBDataset, \
34
+ db_dataset_multi_proc_init
35
+ from eoml.torch.cnn.map_dataset import IterableMapDataset, NNMapper
36
+ from eoml.torch.cnn.torch_utils import meta_data_collate, batch_collate
37
+ from eoml.torch.models import initialize_weights, ModelFactory
38
+ from eoml.torch.sample_statistic import ClassificationStats, BadlyClassifyToGPKG
39
+ from eoml.torch.trainer import GradNormClipper, agressive_train_labeling
40
+
41
+ from rasterio.enums import Resampling
42
+ from shapely.geometry import shape
43
+ from torch.utils.data import RandomSampler, BatchSampler, DataLoader, WeightedRandomSampler
44
+
45
+ from rasterop.tiled_op.tiled_raster_op import tiled_op
46
+
47
+
48
+ class KFoldIterator:
49
+ """Iterator for K-fold cross-validation splits.
50
+
51
+ Constructs sequences of training and validation folds based on a list of folds
52
+ and split specifications. Each split is defined as ((train_fold_indices), (validation_fold_indices)).
53
+
54
+ Attributes:
55
+ folds: List of fold data, where each fold contains sample identifiers
56
+ folds_split: List of tuples defining the train/validation split for each iteration.
57
+ Each tuple is ((train_indices,), (val_indices,))
58
+ """
59
+
60
+ def __init__(self, folds, folds_split):
61
+ """Initialize the K-fold iterator.
62
+
63
+ Args:
64
+ folds: List of folds, where each fold is a list of sample identifiers
65
+ folds_split: List of split specifications, where each split is a tuple
66
+ ((train_fold_indices), (validation_fold_indices))
67
+ """
68
+ self.folds = folds
69
+ self.folds_split = folds_split
70
+
71
+ def __iter__(self):
72
+ """Iterate through all fold splits.
73
+
74
+ Yields:
75
+ Tuple of (train_fold, validation_fold), where each fold is a flat list
76
+ of sample identifiers
77
+ """
78
+ for split in self.folds_split:
79
+ train_fold = []
80
+ validation_fold = []
81
+ for s in split[0]:
82
+ train_fold.extend(self.folds[s])
83
+ for s in split[1]:
84
+ validation_fold.extend(self.folds[s])
85
+
86
+ yield train_fold, validation_fold
87
+
88
+ def __len__(self):
89
+ """Return the number of splits.
90
+
91
+ Returns:
92
+ Number of train/validation splits
93
+ """
94
+ return len(self.folds_split)
95
+
96
+
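+ # Illustrative fold structures accepted by KFoldIterator (the ids are placeholders):
+ #   folds = [["id_0", "id_1"], ["id_2", "id_3"], ["id_4", "id_5"]]
+ #   folds_split = [((0, 1), (2,)), ((0, 2), (1,)), ((1, 2), (0,))]
+ #   for train_ids, valid_ids in KFoldIterator(folds, folds_split):
+ #       ...  # first split: train_ids == ["id_0", "id_1", "id_2", "id_3"], valid_ids == ["id_4", "id_5"]
+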
97
+ # TODO: use TorchScript for the model
98
+ def extract_sample(gps_path, raster_reader, db_path, windows_size, label_name="Class", id_field=None, mask_path=None, force_write=False):
99
+ """Extract training samples from raster data based on GPS/vector labels.
100
+
101
+ Extracts image windows from raster data at locations specified by vector geometries
102
+ (GPS points) and stores them in an LMDB database for efficient training data access.
103
+
104
+ Args:
105
+ gps_path: Path to vector file (GeoPackage, Shapefile, etc.) containing sample locations
106
+ raster_reader: RasterReader instance for reading the source imagery
107
+ db_path: Output path for the LMDB database
108
+ windows_size: Size of the extraction window (in pixels)
109
+ label_name: Name of the attribute field containing class labels. Defaults to "Class"
110
+ id_field: Name of the unique identifier field. Defaults to None (uses feature index)
111
+ mask_path: Optional path to mask polygon restricting extraction area
112
+ force_write: If True, overwrite existing database. Defaults to False
113
+
114
+ Returns:
115
+ None. Writes extracted samples to database at db_path
116
+ """
117
+ if force_write or not path.exists(db_path):
118
+ writer = GeoDataWriter(LMDBasicGeoDataDAO(db_path))
119
+
120
+ #extractor = LabeledWindowsExtractor(gps_path, writer, raster_reader, windows_size, label_name, id_field)
121
+
122
+ #extractor = LabeledWindowsExtractor(gps_path, writer, raster_reader, windows_size, label_name)
123
+ #extractor = OptimiseLabeledWindowsExtractor(gps_path, writer, raster_reader, windows_size,
124
+ # label_name, id_field, mask_path=mask_path)
125
+ extractor = ThreadedOptimiseLabeledWindowsExtractor(gps_path, writer, raster_reader, windows_size, label_name,
126
+ id_field, mask_path=mask_path, worker=4, prefetch=3)
127
+ #extractor = ProcessOptimiseLabeledWindowsExtractor(gps_path, writer, raster_reader, windows_size, label_name,
128
+ # id_field, mask_path=mask_path, worker=4, prefetch=3)
129
+
130
+ extractor.process()
131
+ else:
132
+ logger.info(f"{db_path} exists, skipping extraction")
133
+
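+ # Example call (illustrative; the GeoPackage path, field names and raster_reader are assumptions):
+ #   extract_sample("plots.gpkg", raster_reader, "plots.lmdb", windows_size=33,
+ #                  label_name="Class", id_field="plot_id", mask_path=None, force_write=False)
+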
134
+ def multi_samples_k_fold_setup(db_path, mapper, n_fold=2):
135
+ """Set up K-fold cross-validation splits for multiple databases.
136
+
137
+ Args:
138
+ db_path: List of paths to LMDB databases
139
+ mapper: List of output mappers corresponding to each database
140
+ n_fold: Number of folds for cross-validation. Defaults to 2
141
+
142
+ Returns:
143
+ Zipped iterator of K-fold splits for each database
144
+ """
145
+ iterator = [samples_k_fold_setup(path, mapp, n_fold=n_fold) for path, mapp in zip(db_path, mapper)]
146
+ return zip(*iterator)
147
+
148
+ def multi_samples_yearly_k_fold_setup(db_path, mapper, n_fold=2):
149
+ """Set up K-fold cross-validation for multi-year datasets.
150
+
151
+ Creates K-fold splits based on unique geopackage identifiers that appear across
152
+ multiple years (one database per year). This ensures samples from the same location
153
+ across different years are grouped together, maintaining temporal consistency in splits.
154
+
155
+ Note: Weighting of repeating samples is not currently implemented.
156
+
157
+ Args:
158
+ db_path: List of paths to LMDB databases, one per year
159
+ mapper: List of output mappers for each database
160
+ n_fold: Number of folds for cross-validation. Defaults to 2
161
+
162
+ Returns:
163
+ Zipped iterator of KFoldIterator objects, one for each database
164
+
165
+ Raises:
166
+ None
167
+ """
168
+
169
+ key_set = set()
170
+ for db, mapp in zip(db_path, mapper):
171
+ db_reader = LMDBReader(db)
172
+ with db_reader:
173
+ id_out = db_reader.get_sample_id_output_dic()
174
+
175
+ sample_idx = sample_list_id(id_out, mapp)
176
+ key_set.update(sample_idx)
177
+
178
+ key_set = list(key_set)
179
+ folds_idx_ref, folds_split = k_fold_sample(key_set, n_fold)
180
+
181
+ iterators=[]
182
+ for db, mapp in zip(db_path, mapper):
183
+ db_reader = LMDBReader(db)
184
+ with db_reader:
185
+ id_key = db_reader.get_sample_id_db_key_dic()
186
+
187
+ fold_i = []
188
+ for folds in folds_idx_ref:
189
+ f = []
190
+ for idx in folds:
191
+ key = id_key.get(idx, None)
192
+ # check that the key exists and that its label maps to a cover class we are mapping
193
+ if key is not None and mapp(db_reader.get_output(key)) != mapp.no_target:
194
+ f.append(key)
195
+
196
+ fold_i.append(f)
197
+
198
+ iterators.append(KFoldIterator(fold_i, folds_split))
199
+
200
+ return zip(*iterators)
201
+
202
+
203
+ def samples_k_fold_setup(db_path, mapper, n_fold=2):
204
+ """Set up K-fold cross-validation splits for a single database.
205
+
206
+ Args:
207
+ db_path: Path to LMDB database
208
+ mapper: Output mapper for converting raw labels
209
+ n_fold: Number of folds for cross-validation. Defaults to 2
210
+
211
+ Returns:
212
+ KFoldIterator object containing train/validation splits
213
+ """
214
+ db_reader = LMDBReader(db_path)
215
+
216
+ with db_reader:
217
+ keys_out = db_reader.def_get_output_dic()
218
+
219
+ id_list = sample_list_id(keys_out, mapper)
220
+
221
+ folds, folds_split = k_fold_sample(id_list, n_fold)
222
+
223
+ return KFoldIterator(folds, folds_split)
224
+
225
+
226
+ def samples_split_setup(db_path, mapper, split=None):
227
+ """Set up a single train/validation split of samples.
228
+
229
+ Args:
230
+ db_path: Path to LMDB database
231
+ mapper: Output mapper for converting raw labels
232
+ split: List defining train/validation split ratios [train_frac, val_frac].
233
+ Defaults to [0.8, 0.2]
234
+
235
+ Returns:
236
+ List containing a single train/validation split
237
+ """
238
+ if split is None:
239
+ split = [0.8, 0.2]
240
+
241
+ db_reader = LMDBReader(db_path)
242
+ with db_reader:
243
+ keys_out = db_reader.def_get_output_dic()
244
+
245
+ id_list = sample_list_id(keys_out, mapper)
246
+
247
+ train_id, validation_id = random_split(id_list, split, True)
248
+ return [[train_id, validation_id]]
249
+
250
+
251
+ def augmentation_setup(samples, angles=None, flip=None):
252
+ """Set up data augmentation parameters for a set of samples.
253
+
254
+ Args:
255
+ samples: List of sample identifiers
256
+ angles: List of rotation angles in degrees. Defaults to [0, 90, 180, -90]
257
+ flip: List of boolean flags for horizontal flipping. Defaults to [False, True]
258
+
259
+ Returns:
260
+ Tuple of (augmented_samples, augmentation_parameters)
261
+ """
262
+ if angles is None:
263
+ angles = [0, 90, 180, -90]
264
+
265
+ if flip is None:
266
+ flip = [False, True]
267
+
268
+ t_param_list = list(itertools.product(angles, flip))
269
+
270
+ t_params = []
271
+ samples_split = []
272
+ for k in samples:
273
+ for p in t_param_list:
274
+ t_params.append(p)
275
+ samples_split.append(k)
276
+
277
+ return samples_split, t_params
278
+
279
+
280
+ def dataset_setup(train_id_split, validation_split, augmentation_param, db_path, mapper, db_type=DBDataset):
281
+ """Set up training and validation datasets with augmentation.
282
+
283
+ Args:
284
+ train_id_split: List of training sample IDs
285
+ validation_split: List of validation sample IDs
286
+ augmentation_param: Dict containing augmentation settings
287
+ db_path: Path to LMDB database
288
+ mapper: Output mapper for converting raw labels
289
+ db_type: Dataset class to use. Defaults to DBDataset
290
+
291
+ Returns:
292
+ Tuple of (train_dataset, validation_dataset)
293
+ """
294
+ # we load a full batch at once to limit the opening and closing of the db
295
+
296
+ # dataset_train = DBDataset(gps_db, train_id, mapper_vector)
297
+ transform_param = None
298
+ transform = None
299
+ transform_valid = None
300
+ if augmentation_param["methode"] == "fix":
301
+ train_id_split, transform_param = augmentation_setup(train_id_split, **augmentation_param["parameters"])
302
+ transform = rotate_flip_transform
303
+
304
+ if augmentation_param["methode"] == "no_dep":
305
+ transform = augmentation_param["transform_train"]
306
+ transform_valid = augmentation_param["transform_valid"]
307
+
308
+ dataset_train = db_type(db_path, train_id_split, mapper, f_transform=transform,
309
+ transform_param=transform_param)
310
+ dataset_valid = db_type(db_path, validation_split, mapper, f_transform=transform_valid)
311
+
312
+ return dataset_train, dataset_valid
313
+
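+ # Illustrative augmentation_param layouts accepted by dataset_setup (values are assumptions):
+ #   fixed geometric augmentation, expanded by augmentation_setup() into per-sample (angle, flip) pairs:
+ #     {"methode": "fix", "parameters": {"angles": [0, 90, 180, -90], "flip": [False, True]}}
+ #   caller-supplied transforms, applied independently to each sample:
+ #     {"methode": "no_dep", "transform_train": train_transform, "transform_valid": valid_transform}
+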
314
+
315
+ def dataloader_setup(dataset_train, dataset_valid, batch_size, balance_sample, num_worker,
316
+ prefetch, device, persistent_workers):
317
+ """Set up training and validation data loaders.
318
+
319
+ Args:
320
+ dataset_train: Training dataset
321
+ dataset_valid: Validation dataset
322
+ batch_size: Number of samples per batch
323
+ balance_sample: Whether to use weighted sampling for class balance
324
+ num_worker: Number of worker processes for data loading
325
+ prefetch: Number of batches to prefetch
326
+ device: Device to use ('cuda' or 'cpu')
327
+ persistent_workers: Whether to maintain persistent worker processes
328
+
329
+ Returns:
330
+ Tuple of (train_dataloader, validation_dataloader)
331
+ """
332
+ # TODO: balanced sampling will not work for multi-sample datasets
333
+ if balance_sample:
334
+ train_weight = dataset_train.weight_list()
335
+ #test_weight = dataset_valid.weight_list()
336
+
337
+ #print(train_weight)
338
+ #print(test_weight)
339
+
340
+ train_rsampler = WeightedRandomSampler(train_weight, len(train_weight), replacement=True)
341
+ else:
342
+ train_rsampler = RandomSampler(dataset_train, replacement=False, num_samples=None, generator=None)
343
+
344
+ # TODO: also balance the validation sampler; left unbalanced for now for comparison
345
+ valid_rsampler = RandomSampler(dataset_valid, replacement=False, num_samples=None, generator=None)
346
+
347
+
348
+ train_bsampler = BatchSampler(train_rsampler, batch_size=batch_size, drop_last=False)
349
+
350
+
351
+ valid_bsampler = BatchSampler(valid_rsampler, batch_size=batch_size, drop_last=False)
352
+
353
+ if device == "cuda":
354
+ pin_memory = True
355
+ else:
356
+ pin_memory = False
357
+
358
+ train_dataloader = DataLoader(dataset_train, sampler=train_bsampler, collate_fn=batch_collate,
359
+ num_workers=num_worker, prefetch_factor=prefetch, pin_memory=pin_memory,
360
+ persistent_workers=persistent_workers, worker_init_fn=db_dataset_multi_proc_init)
361
+ validation_dataloader = DataLoader(dataset_valid, sampler=valid_bsampler, collate_fn=batch_collate,
362
+ num_workers=num_worker, prefetch_factor=prefetch, pin_memory=pin_memory,
363
+ persistent_workers=persistent_workers, worker_init_fn=db_dataset_multi_proc_init)
364
+
365
+ return train_dataloader, validation_dataloader
366
+
367
+ def model_setup(model_name, type, path, device, nn_parameter):
368
+ """Set up and initialize a neural network model.
369
+
370
+ Args:
371
+ model_name: Name of the model architecture
372
+ type: Type of model
373
+ path: Path to model weights/checkpoints
374
+ device: Device to use ('cuda' or 'cpu')
375
+ nn_parameter: Dict of model hyperparameters
376
+
377
+ Returns:
378
+ Initialized neural network model
379
+ """
380
+ # ----------------------------------------
381
+ # Architecture
382
+ # ----------------------------------------
383
+
384
+ factory = ModelFactory()
385
+
386
+ net = factory(model_name, type=type, path=path, model_args=nn_parameter)
387
+
388
+ # net = Conv2Dense3(size, 65, n_out)
389
+ #net = ConvJavaSmall(**nn_parameter)
390
+ net.apply(initialize_weights)
391
+ net.to(device)
392
+
393
+ return net
394
+
395
+
396
+ def optimizer_setup(net, loss, optimizer, optimizer_parameter, scheduler_mode,
397
+ scheduler_parameter=None, data_loader=None, epoch=None):
398
399
+ """Set up optimizer and learning rate scheduler.
400
+
401
+ Args:
402
+ net: Neural network model
403
+ loss: Loss function
404
+ optimizer: Optimizer class
405
+ optimizer_parameter: Dict of optimizer parameters
406
+ scheduler_mode: Learning rate scheduler mode
407
+ scheduler_parameter: Dict of scheduler parameters. Defaults to None
408
+ data_loader: DataLoader for scheduler steps. Defaults to None
409
+ epoch: Number of epochs for scheduler. Defaults to None
410
+
411
+ Returns:
412
+ Tuple of (optimizer, loss_function, scheduler)
413
+ """
414
+ if scheduler_mode is None:
415
+ return optimizer, loss, None
416
+
417
+ if scheduler_mode == "cycle":
418
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, steps_per_epoch=len(data_loader), epochs=epoch, **scheduler_parameter)
419
+ return optimizer, loss, scheduler
420
+
421
+ #if scheduler_mode == "plateau":
422
+ # print("metric need in step and other stuff")
423
+ # torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=10, threshold=0.0001,
424
+ # threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose=True)
425
+
426
+
427
+ raise Exception("Unknown scheduler_mode")
428
+
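+ # Example call (illustrative; the loss, optimizer class and learning rates are assumptions,
+ # and net / train_dataloader come from model_setup() and dataloader_setup()):
+ #   optimizer, loss, scheduler = optimizer_setup(
+ #       net=net, loss=torch.nn.CrossEntropyLoss(), optimizer=torch.optim.AdamW,
+ #       optimizer_parameter={"lr": 1e-3, "weight_decay": 1e-4},
+ #       scheduler_mode="cycle", scheduler_parameter={"max_lr": 1e-2},
+ #       data_loader=train_dataloader, epoch=50)
+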
429
+ def generate_map_iterator(jit_model_path,
430
+ map_path,
431
+ map_iterator,
432
+ transformer,
433
+ mask, #path or geom
434
+ bounds,
435
+ mode,
436
+ num_worker,
437
+ prefetch,
438
+ custom_collate=meta_data_collate,
439
+ worker_init_fn=IterableMapDataset.basic_worker_init_fn):
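+ """Run a (jitted) model over a map iterator and write the prediction raster.
+
+ Args:
+ jit_model_path: Path to a TorchScript model, or an already loaded model object
+ map_path: Output path of the generated map
+ map_iterator: Iterable dataset yielding the windows to predict (e.g. IterableMapDataset)
+ transformer: Output transformer applied to the network predictions
+ mask: Path to a vector file, or geometries restricting the mapped area
+ bounds: Bounds of the output raster
+ mode: 0 for full CPU, 1 for pinned memory with inference on the GPU
+ num_worker: Number of loader workers
+ prefetch: Number of batches to prefetch per worker
+ custom_collate: Collate function for the loader. Defaults to meta_data_collate
+ worker_init_fn: Worker initialisation function. Defaults to IterableMapDataset.basic_worker_init_fn
+
+ Returns:
+ The path of the written map (map_path)
+ """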
440
+ # mode 0: full CPU, no pinning
+ # mode 1: pinned memory in the loader, moved asynchronously to the GPU (more loaders needed; seems to favour large batches)
+ # mode 2: start CUDA in each loader thread and prepare most of the samples directly on the GPU (each loader thread
+ # uses ~1 GB of GPU memory and requires torch.multiprocessing.set_start_method('spawn'))
+ # TODO: test the constant-memory dataset for memory-limited GPUs
445
+ write_profile = get_write_profile()
446
+
447
+
448
+ # if a path is given
449
+ if isinstance(mask, str):
450
+ with fiona.open(mask) as layer:
451
+ features = [shape(feature["geometry"]) for feature in layer]
452
+ else:
453
+ # mask is already a geometry or list of geometries
454
+ features = mask
455
+
456
+ is_pinned = False
457
+ loader_device = "cpu"
458
+ nn_device = "cpu"
459
+
460
+ if isinstance(jit_model_path, str):
461
+ model = torch.jit.load(jit_model_path)
462
+ else:
463
+ model = jit_model_path
464
+
465
+ if mode == 0:
466
+ is_pinned = False
467
+ loader_device = "cpu"
468
+ nn_device = "cpu"
469
+ if mode == 1:
470
+ is_pinned = True
471
+ loader_device = "cpu"
472
+ nn_device = "cuda"
473
+
474
+
475
+ mapper = NNMapper(model,
476
+ stride=1,
477
+ loader_device=loader_device,
478
+ mapper_device=nn_device,
479
+ pin_memory=is_pinned,
480
+ num_worker=num_worker,
481
+ prefetch_factor=prefetch,
482
+ custom_collate=custom_collate,
483
+ worker_init_fn=worker_init_fn,
484
+ write_profile=write_profile,
485
+ async_agr=True)
486
+
487
+ # deprecated, it is not recommended to use cuda in workers
488
+ # if mode == 2:
489
+ # is_pinned = False
490
+ # loader_device = "cuda"
491
+ # nn_device = "cuda"
492
+
493
+ # if mode == 2:
494
+ # mapper = NNMapper(model,
495
+ # windows_size,
496
+ # stride=1,
497
+ # batch_size=batch_size,
498
+ # loader_device=loader_device,
499
+ # mapper_device=nn_device,
500
+ # pin_memory=is_pinned,
501
+ # num_worker=num_worker,
502
+ # prefetch_factor=prefetch,
503
+ # map_dataset_class=IterableMapDatasetConstMem,
504
+ # custom_collate=meta_data_collate,
505
+ # worker_init_fn=worker_init_fn,
506
+ # aggregator=MapResultAggregatorConstMem,
507
+ # write_profile=write_profile,
508
+ # async_agr=True)
509
+
510
+ mapper.map(map_path, map_iterator, transformer, bounds=bounds, mask=features)
511
+
512
+ return map_path
513
+
514
+
515
+ def training_step(net, optimizer, loss, scheduler, train_dataloader, validation_dataloader, max_epochs,
516
+ run_stats_dir, model_base_path, model_tag, grad_clip_value, device):
517
+ """do the training step"""
518
+
519
+ grad_f = None
520
+ if grad_clip_value is not None:
521
+ grad_f = GradNormClipper(grad_clip_value)
522
+
523
+ return agressive_train_labeling(max_epochs, net, optimizer, loss, scheduler, train_dataloader,
524
+ validation_dataloader,
525
+ writer_base_path=run_stats_dir, model_base_path=model_base_path,
526
+ model_tag=model_tag,
527
+ grad_f=grad_f, device=device)
528
+
529
+
530
+ def generate_map(jit_model_path,
531
+ map_path,
532
+ raster_reader,
533
+ windows_size,
534
+ batch_size,
535
+ transformer,
536
+ mask, #path or geom
537
+ bounds,
538
+ mode,
539
+ num_worker,
540
+ prefetch,
541
+ custom_collate=meta_data_collate,
542
+ worker_init_fn=IterableMapDataset.basic_worker_init_fn):
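+ """Run a (jitted) model over a full raster and write the prediction map.
+
+ Convenience wrapper around generate_map_iterator that builds an IterableMapDataset
+ from the given raster_reader, windows_size and batch_size.
+
+ Returns:
+ The path of the written map (map_path)
+ """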
543
+
544
+
545
+
546
+ map_iterator = IterableMapDataset(raster_reader, windows_size, batch_size=batch_size)
547
+
548
+ return generate_map_iterator(jit_model_path,
549
+ map_path,
550
+ map_iterator,
551
+ transformer,
552
+ mask, #path or geom
553
+ bounds,
554
+ mode,
555
+ num_worker,
556
+ prefetch,
557
+ custom_collate=custom_collate,
558
+ worker_init_fn=worker_init_fn)
559
+
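+ # Example call (illustrative; paths, the transformer, window and batch sizes are assumptions):
+ #   generate_map("model_jit.pt", "prediction.tif", raster_reader, windows_size=33,
+ #                batch_size=256, transformer=transformer, mask="area.gpkg",
+ #                bounds=None, mode=1, num_worker=4, prefetch=2)
+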
560
+
561
+
562
+ def train(sample_param, augmentation_param, dataset_parameter, dataloader_parameter, model_parameter, optimizer_parameter, train_nn_parameter):
563
+ """
564
+ Map and train in function of the parameter
565
+ :param sample_param:
566
+ :param augmentation_param:
567
+ :param dataset_parameter:
568
+ :param model_parameter:
569
+ :param train_nn_parameter:
570
+ :param map_parameter:
571
+ :return:
572
+ """
573
+
574
+ experiment_iterator = sample_param["methode"](**sample_param["param"])
575
+
576
+ out = []
577
+ for (i, (train_id, validation_id)) in enumerate(experiment_iterator):
578
+
579
+ train_dataset, validation_dataset = dataset_setup(train_id, validation_id, augmentation_param,
580
+ **dataset_parameter)
581
+
582
+ train_dataloader, validation_dataloader = dataloader_setup(train_dataset, validation_dataset,
583
+ **dataloader_parameter)
584
+
585
+ net = model_setup(**model_parameter)
586
+
587
+ optimizer, loss, scheduler = optimizer_setup(net=net, data_loader=train_dataloader, epoch=train_nn_parameter["max_epochs"],
588
+ **optimizer_parameter)
589
+
590
+ base_model_dir, best_model_path, model_path_jitted, model_name = training_step(net, optimizer, loss, scheduler, train_dataloader,
591
+ validation_dataloader, **train_nn_parameter)
592
+
593
+ del train_dataloader, validation_dataloader, net, loss, optimizer, train_id
594
+
595
+ out.append((base_model_dir, best_model_path, model_path_jitted, model_name))
596
+
597
+ return out
598
+
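+ # Illustrative parameter layout for train() (paths and values are assumptions; the mapper,
+ # loss object and model factory arguments must be supplied by the caller):
+ #   sample_param = {"methode": samples_k_fold_setup,
+ #                   "param": {"db_path": "plots.lmdb", "mapper": mapper, "n_fold": 5}}
+ #   dataset_parameter = {"db_path": "plots.lmdb", "mapper": mapper}
+ #   dataloader_parameter = {"batch_size": 64, "balance_sample": True, "num_worker": 4,
+ #                           "prefetch": 2, "device": "cuda", "persistent_workers": True}
+ #   model_parameter = {"model_name": ..., "type": ..., "path": None,
+ #                      "device": "cuda", "nn_parameter": {...}}
+ #   optimizer_parameter = {"loss": torch.nn.CrossEntropyLoss(), "optimizer": torch.optim.AdamW,
+ #                          "optimizer_parameter": {"lr": 1e-3}, "scheduler_mode": "cycle",
+ #                          "scheduler_parameter": {"max_lr": 1e-2}}
+ #   train_nn_parameter = {"max_epochs": 50, "run_stats_dir": "runs", "model_base_path": "models",
+ #                         "model_tag": "fold", "grad_clip_value": 1.0, "device": "cuda"}
+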
599
+ def multi_db_merge(ids, augmentation_param, dataset_parameter):
600
+ """Merge multiple databases into combined train/validation datasets.
601
+
602
+ Args:
603
+ ids: List of (train_ids, val_ids) tuples for each database
604
+ augmentation_param: Dict of augmentation parameters
605
+ dataset_parameter: List of dataset parameters for each database
606
+
607
+ Returns:
608
+ Tuple of (merged_train_dataset, merged_validation_dataset)
609
+ """
610
+
611
+ train_data = []
612
+ valid_data = []
613
+ for (train_id, validation_id), data_param in zip(ids, dataset_parameter):
614
+
615
+ train_dataset, validation_dataset = dataset_setup(train_id, validation_id, augmentation_param, db_type= DBInfo,
616
+ **data_param)
617
+
618
+
619
+ train_data.append(train_dataset)
620
+ valid_data.append(validation_dataset)
621
+
622
+ return MultiDBDataset(train_data), MultiDBDataset(valid_data)
623
+
624
+ def multi_train_and_map(sample_param, augmentation_param, dataset_parameter, dataloader_parameter,model_parameter, optimizer_parameter, train_nn_parameter,
625
+ map_parameter):
626
+ """
627
+ Map and train in function of the parameter
628
+ :param sample_param:
629
+ :param augmentation_param:
630
+ :param dataset_parameter:
631
+ :param model_parameter:
632
+ :param train_nn_parameter:
633
+ :param map_parameter:
634
+ :return:
635
+ """
636
+
637
+ experiment_iterator = sample_param["methode"](**sample_param["param"])
638
+
639
+ out = {}
640
+ for (i, ids) in enumerate(experiment_iterator):
641
+
642
+ train_dataset, validation_dataset = multi_db_merge(ids, augmentation_param, dataset_parameter)
643
+
644
+ train_dataloader, validation_dataloader = dataloader_setup(train_dataset, validation_dataset,
645
+ **dataloader_parameter)
646
+
647
+ net = model_setup(**model_parameter)
648
+
649
+ optimizer, loss, scheduler = optimizer_setup(net=net, data_loader=train_dataloader, epoch=train_nn_parameter["max_epochs"],
650
+ **optimizer_parameter)
651
+
652
+ base_model_dir, best_model_path, model_path_jitted, model_name = training_step(net, optimizer, loss, scheduler, train_dataloader,
653
+ validation_dataloader, **train_nn_parameter)
654
+
655
+ # TODO add again for later
656
+ #dataset_stats(f'{base_model_dir}' ,model_path_jitted, train_id, validation_id, dataset_parameter["db_path"],
657
+ # dataset_parameter["batch_size"], dataset_parameter["mapper"], dataset_parameter["num_worker"],
658
+ # dataset_parameter["prefetch"], train_nn_parameter["device"])
659
+
660
+ del train_dataloader, validation_dataloader, net, loss, optimizer, ids
661
+
662
+
663
+ for i, map_param in enumerate(map_parameter):
664
+ map_path = f'{base_model_dir}/{map_param["map_tag"]}_{model_name}_{i}.tif'
665
+ map_param["map_path"] = map_path
666
+ current_map_parameter = map_param.copy()
667
+ del current_map_parameter['map_tag'] # removed as it is not an input of the mapping function
668
+
669
+ m_out = out.setdefault(i, [])
670
+ m_out.append(generate_map(model_path_jitted, **current_map_parameter))
671
+
672
+ return out
673
+
674
+
675
+ def train_and_map(sample_param, augmentation_param, dataset_parameter, dataloader_parameter, model_parameter, optimizer_parameter, train_nn_parameter,
676
+ map_parameter):
677
+ """
678
+ Map and train in function of the parameter
679
+ :param sample_param:
680
+ :param augmentation_param:
681
+ :param dataset_parameter:
682
+ :param model_parameter:
683
+ :param train_nn_parameter:
684
+ :param map_parameter:
685
+ :return:
686
+ """
687
+
688
+ experiment_iterator = sample_param["methode"](**sample_param["param"])
689
+
690
+ out = []
691
+ for (i, (train_id, validation_id)) in enumerate(experiment_iterator):
692
+
693
+ train_dataset, validation_dataset = dataset_setup(train_id, validation_id, augmentation_param,
694
+ **dataset_parameter)
695
+
696
+ train_dataloader, validation_dataloader = dataloader_setup(train_dataset, validation_dataset,
697
+ **dataloader_parameter)
698
+
699
+ net = model_setup(**model_parameter)
700
+
701
+ optimizer, loss, scheduler = optimizer_setup(net=net, data_loader=train_dataloader,
702
+ epoch=train_nn_parameter["max_epochs"],
703
+ **optimizer_parameter)
704
+
705
+ base_model_dir, best_model_path, model_path_jitted, model_name = training_step(net,
706
+ optimizer,
707
+ loss,
708
+ scheduler,
709
+ train_dataloader,
710
+ validation_dataloader,
711
+ **train_nn_parameter)
712
+
713
+ dataset_stats(f'{base_model_dir}', model_path_jitted, train_id, validation_id, augmentation_param["transform_valid"], dataset_parameter["db_path"],
714
+ dataloader_parameter["batch_size"], dataset_parameter["mapper"], dataloader_parameter["num_worker"],
715
+ dataloader_parameter["prefetch"], train_nn_parameter["device"])
716
+
717
+ del train_dataloader, validation_dataloader, net, loss, optimizer, train_id
718
+
719
+ map_path = f'{base_model_dir}/{map_parameter["map_tag"]}_{model_name}_{i}.tif'
720
+ map_parameter["map_path"] = map_path
721
+ current_map_parameter = map_parameter.copy()
722
+ del current_map_parameter['map_tag'] # removed as it is not an input of the mapping function
723
+ out.append(generate_map(model_path_jitted, **current_map_parameter))
724
+
725
+
726
+ return out
727
+
728
+
729
+ def dataset_stats(path, jit_model_path, train_id_split, validation_split, transform, db_path, batch_size, mapper, num_worker,
730
+ prefetch, device):
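+ """Compute and save classification statistics for a trained (jitted) model.
+
+ Rebuilds train and validation dataloaders from the database, writes per-set
+ classification statistics to text files and exports badly classified samples
+ to GeoPackage files in the given directory.
+ """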
731
+
732
+ model = torch.jit.load(jit_model_path)
733
+ model = model.to(device)
734
+
735
+ # we load a full batch at once to limit the opening and closing of the db
736
+
737
+ # dataset_train = DBDataset(gps_db, train_id, mapper_vector)
738
+ dataset_train = DBDatasetMeta(db_path, train_id_split, mapper, f_transform=transform)
739
+ dataset_valid = DBDatasetMeta(db_path, validation_split, mapper, f_transform=transform)
740
+
741
+ train_rsampler = RandomSampler(dataset_train, replacement=False, num_samples=None, generator=None)
742
+ valid_rsampler = RandomSampler(dataset_valid, replacement=False, num_samples=None, generator=None)
743
+
744
+ train_bsampler = BatchSampler(train_rsampler, batch_size=batch_size, drop_last=False)
745
+ valid_bsampler = BatchSampler(valid_rsampler, batch_size=batch_size, drop_last=False)
746
+
747
+ if device == "cuda":
748
+ pin_memory = True
749
+ else:
750
+ pin_memory = False
751
+
752
+
753
+ train_dataloader = DataLoader(dataset_train, sampler=train_bsampler, collate_fn=batch_collate,
754
+ num_workers=num_worker, prefetch_factor=prefetch, pin_memory=pin_memory,
755
+ worker_init_fn=db_dataset_multi_proc_init)
756
+ validation_dataloader = DataLoader(dataset_valid, sampler=valid_bsampler, collate_fn=batch_collate,
757
+ num_workers=num_worker, prefetch_factor=prefetch, pin_memory=pin_memory,
758
+ worker_init_fn=db_dataset_multi_proc_init)
759
+
760
+ stats = ClassificationStats(len(mapper), device, mapper.nn_name())
761
+
762
+ stats.compute(model, train_dataloader, device)
763
+ stats.display()
764
+ stats.to_file(f"{path}/train.txt")
765
+
766
+ stats.compute(model, validation_dataloader, device)
767
+ stats.display()
768
+ stats.to_file(f"{path}/train.txt")
769
+
770
+ bctg = BadlyClassifyToGPKG()
771
+
772
+ bctg.compute(model, train_dataloader, device)
773
+ bctg.to_file(f"{path}/train.gpkg")
774
+
775
+ bctg.compute(model, validation_dataloader, device)
776
+ bctg.to_file(f"{path}/validation.gpkg")
777
+
778
+ def tiled_task(maps, raster_out, operation, bounds=None, res=None,
779
+ resampling=Resampling.nearest, target_aligned_pixels=False, indexes=None, src_kwds=None,
780
+ dst_kwds=None, num_workers=8):
781
+ """
782
+ Execute the specified tiled task with the specified arguments
783
+ :param maps:
784
+ :param raster_out:
785
+ :param operation:
786
+ :param operation_args:
787
+ :param bounds:
788
+ :param res:
789
+ :param nodata:
790
+ :param resampling:
791
+ :param target_aligned_pixels:
792
+ :param indexes:
793
+ :param src_kwds:
794
+ :param dst_kwds:
795
+ :param num_workers:
796
+ :return:
797
+ """
798
+
799
+
800
+ if src_kwds is None:
801
+ src_kwds = get_read_profile()
802
+
803
+ if dst_kwds is None:
804
+ dst_kwds = get_write_profile()
805
+
806
+ tiled_op(
807
+ maps,
808
+ operation,
809
+ operation.n_band_out,
810
+ raster_out,
811
+ bounds=bounds,
812
+ res=res,
813
+ nodata=operation.nodata,
814
+ dtype=operation.dtype,
815
+ indexes=indexes,
816
+ resampling=resampling,
817
+ target_aligned_pixels=target_aligned_pixels,
818
+ dst_kwds=dst_kwds,
819
+ src_kwds=src_kwds,
820
+ num_workers=num_workers)
821
+
822
+
823
+
824
+
825
+ # only run if main script (important as data loading workers are started with the 'spawn' method)