eoml-0.9.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. eoml/__init__.py +74 -0
  2. eoml/automation/__init__.py +7 -0
  3. eoml/automation/configuration.py +105 -0
  4. eoml/automation/dag.py +233 -0
  5. eoml/automation/experience.py +618 -0
  6. eoml/automation/tasks.py +825 -0
  7. eoml/bin/__init__.py +6 -0
  8. eoml/bin/clean_checkpoint.py +146 -0
  9. eoml/bin/land_cover_mapping_toml.py +435 -0
  10. eoml/bin/mosaic_images.py +137 -0
  11. eoml/data/__init__.py +7 -0
  12. eoml/data/basic_geo_data.py +214 -0
  13. eoml/data/dataset_utils.py +98 -0
  14. eoml/data/persistence/__init__.py +7 -0
  15. eoml/data/persistence/generic.py +253 -0
  16. eoml/data/persistence/lmdb.py +379 -0
  17. eoml/data/persistence/serializer.py +82 -0
  18. eoml/raster/__init__.py +7 -0
  19. eoml/raster/band.py +141 -0
  20. eoml/raster/dataset/__init__.py +6 -0
  21. eoml/raster/dataset/extractor.py +604 -0
  22. eoml/raster/raster_reader.py +602 -0
  23. eoml/raster/raster_utils.py +116 -0
  24. eoml/torch/__init__.py +7 -0
  25. eoml/torch/cnn/__init__.py +7 -0
  26. eoml/torch/cnn/augmentation.py +150 -0
  27. eoml/torch/cnn/dataset_evaluator.py +68 -0
  28. eoml/torch/cnn/db_dataset.py +605 -0
  29. eoml/torch/cnn/map_dataset.py +579 -0
  30. eoml/torch/cnn/map_dataset_const_mem.py +135 -0
  31. eoml/torch/cnn/outputs_transformer.py +130 -0
  32. eoml/torch/cnn/torch_utils.py +404 -0
  33. eoml/torch/cnn/training_dataset.py +241 -0
  34. eoml/torch/cnn/windows_dataset.py +120 -0
  35. eoml/torch/dataset/__init__.py +6 -0
  36. eoml/torch/dataset/shade_dataset_tester.py +46 -0
  37. eoml/torch/dataset/shade_tree_dataset_creators.py +537 -0
  38. eoml/torch/model_low_use.py +507 -0
  39. eoml/torch/models.py +282 -0
  40. eoml/torch/resnet.py +437 -0
  41. eoml/torch/sample_statistic.py +260 -0
  42. eoml/torch/trainer.py +782 -0
  43. eoml/torch/trainer_v2.py +253 -0
  44. eoml-0.9.0.dist-info/METADATA +93 -0
  45. eoml-0.9.0.dist-info/RECORD +47 -0
  46. eoml-0.9.0.dist-info/WHEEL +4 -0
  47. eoml-0.9.0.dist-info/entry_points.txt +3 -0
eoml/torch/cnn/map_dataset.py
@@ -0,0 +1,579 @@
+"""Map dataset utilities for applying neural networks to raster imagery.
+
+This module provides tools for iterating over raster data in a sliding window fashion,
+applying neural network predictions, and aggregating results back to raster format.
+Supports parallel processing and various optimization strategies.
+"""
+
+import copy
+import logging
+import math
+from typing import Union
+
+import numpy as np
+import rasterio
+import torch
+import torch.multiprocessing as mp
+
+from eoml import get_write_profile
+from eoml.torch.cnn.outputs_transformer import OutputTransformer
+from eoml.torch.cnn.torch_utils import align_grid
+from rasterio.windows import Window
+from shapely.geometry import box
+from torch.utils.data import IterableDataset, DataLoader
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)
+
+class BatchMeta:
+    """Metadata for a batch of data windows.
+
+    Attributes:
+        window (Window, optional): Rasterio window specification.
+        is_finished (bool): Whether this is the last batch for the current window.
+        worker (int): Worker ID that processed this batch.
+    """
+
+    def __init__(self, window: Union[Window, None], is_finished: bool, worker: int):
+        """Initialize BatchMeta.
+
+        Args:
+            window (Window, optional): Rasterio window for this batch.
+            is_finished (bool): Whether this completes processing of the current window.
+            worker (int): Worker ID that processed this batch.
+        """
+        self.window = window
+        self.is_finished = is_finished
+        self.worker = worker
+
+def continuous_split(dataset, worker_id: int, n_workers: int):
+    """Split dataset windows continuously across workers.
+
+    Each worker receives an adjacent contiguous block. Better for locality when
+    overlapping cells may be cached by GDAL.
+
+    Args:
+        dataset: Dataset with target_windows attribute.
+        worker_id (int): ID of current worker.
+        n_workers (int): Total number of workers.
+
+    Returns:
+        list: Subset of target_windows for this worker.
+    """
+
+    size, remainder = divmod(len(dataset.target_windows), n_workers)
+
+    if worker_id < remainder:
+        size = size + 1
+        start = worker_id * size
+        end = start + size
+    else:
+        # the remainder is consumed by the previous workers
+        start = worker_id * size + remainder
+        end = start + size
+
+    return dataset.target_windows[start:end]
+
+
+def jumped_split(dataset, worker_id: int, n_workers: int):
+    """Split dataset windows in interleaved fashion across workers.
+
+    Workers process blocks i, i+n_workers, i+2*n_workers, etc. Better for seeing
+    overall progress as work is distributed across the full spatial extent.
+
+    Args:
+        dataset: Dataset with target_windows attribute.
+        worker_id (int): ID of current worker.
+        n_workers (int): Total number of workers.
+
+    Returns:
+        list: Subset of target_windows for this worker (every n_workers-th window).
+    """
+    return dataset.target_windows[worker_id::n_workers]
+
+
+def windows_in_mask(window: Window, transform, mask):
+    """Check if a raster window intersects the spatial mask.
+
+    Args:
+        window (Window): Rasterio window to check.
+        transform: Affine transform for the raster.
+        mask: List of shapely geometries defining the mask.
+
+    Returns:
+        bool: True if the window intersects any geometry in the mask.
+    """
+    bounds = box(*rasterio.windows.bounds(window, transform))
+    return any(map(lambda shape: shape.intersects(bounds), mask))
+
+# remove soon
+class IterableMapDataset(IterableDataset):
+    """Iterable dataset for applying CNNs to raster imagery using sliding windows.
+
+    Reads raster data in windows and extracts overlapping patches for CNN inference.
+    Creates an aligned output raster with convolution border handling. When stride > 1,
+    windows starting at the top-left corner and of size stride x stride are filled with the NN output.
+
+    Attributes:
+        raster_reader: Reader for input raster data.
+        size (int): Kernel/window size for CNN.
+        half_size (int): Half of kernel size (for padding calculations).
+        target_windows (list, optional): List of windows to process.
+        off_x (float): X offset for window alignment.
+        off_y (float): Y offset for window alignment.
+        device (str): Device for tensor operations.
+        stride (int): Stride for window extraction.
+        batch_size (int): Number of samples per batch.
+        worker_id (int): ID of current worker process.
+    """
+
+    def __init__(self, raster_reader, kernel_size, target_windows=None, off_x=None, off_y=None, stride=1,
+                 batch_size=1024, device="cpu"):
+        """Initialize IterableMapDataset.
+
+        Args:
+            raster_reader: RasterReader instance for input data.
+            kernel_size (int): Size of CNN input window (must be odd).
+            target_windows (list, optional): Pre-defined windows to process. Defaults to None.
+            off_x (float, optional): X offset for alignment. Defaults to kernel_size/2.
+            off_y (float, optional): Y offset for alignment. Defaults to kernel_size/2.
+            stride (int, optional): Stride for window extraction. Defaults to 1.
+            batch_size (int, optional): Batch size for processing. Defaults to 1024.
+            device (str, optional): PyTorch device. Defaults to "cpu".
+
+        Raises:
+            ValueError: If kernel_size is even (only odd kernels are supported).
+        """
+
+        super().__init__()
+
+        self.raster_reader = raster_reader
+
+        self.size = kernel_size
+        self.half_size = math.floor(kernel_size / 2)
+
+        self.target_windows = target_windows
+
+        if kernel_size % 2 == 0:
+            raise ValueError("even kernel sizes are not supported yet")
+
+        if off_x is None:
+            off_x = kernel_size / 2
+
+        if off_y is None:
+            off_y = kernel_size / 2
+
+        self.off_x = off_x
+        self.off_y = off_y
+
+        self.device = device
+
+        self.stride = stride
+
+        # Processing plan: return the list of blocks with the corresponding target input,
+        # filter the list based on a shapefile if needed, assign each worker its list of
+        # blocks, load a block in memory, then read the block n lines at a time and create
+        # inputs returned with cell id and line id.
+
+        self.batch_size = batch_size
+        self.worker_id = 0
+
+
+    def __iter__(self):
+        """
+        Iterate over the dataset, yielding at most batch_size samples, or the number of samples needed
+        to finish the current block of data.
+        :return: data, (target_windows, is_block_finished, worker_id)
+        """
+
+        with self.raster_reader as reader:
+            for ji, window in self.target_windows:
+
+                (col_off, row_off, w_width, w_height) = window.flatten()
+                # compute the source window
+                window_source = Window(col_off + self.off_x - self.half_size, row_off + self.off_y - self.half_size,
+                                       w_width + self.size - 1, w_height + self.size - 1)
+
+                a = reader.read_windows(window_source)
+                buffer = torch.from_numpy(a).to(self.device)
+
+                for sample, meta in self.extract_tensor_iter(buffer, self.batch_size):
+                    meta.window = window
+                    yield sample, meta
+
+    def extract_tensor_iter(self, data, batch_size):
+        """
+        Read the nn windows from the given data
+        :param data:
+        :param batch_size:
+        :return:
+        """
+        _, height, width = data.shape
+
+        height = height - self.size + 1
+        width = width - self.size + 1
+
+        samples = []
+        count = 0
+
+        for i in range(0, height, self.stride):
+            for j in range(0, width, self.stride):
+                if count == batch_size:
+                    yield torch.stack(samples, dim=0), BatchMeta(None, False, self.worker_id)
+                    samples = []
+                    count = 0
+                source_w = data.narrow(1, i, self.size).narrow(2, j, self.size)
+                samples.append(source_w)
+                count += 1
+
+        yield torch.stack(samples, dim=0), BatchMeta(None, True, self.worker_id)
+
+
+    @staticmethod
+    def basic_worker_init_fn(worker_id, splitting_f=jumped_split):
+        """
+        A basic function assigning each worker its share of the dataset. Makes deep copies where
+        needed to try to avoid the shared-memory copy issue with multiple workers (testing needed).
+        :param worker_id:
+        :param splitting_f:
+        :return:
+        """
+        worker_info = torch.utils.data.get_worker_info()
+
+        dataset = worker_info.dataset  # the dataset copy in this worker process
+        n_workers = worker_info.num_workers
+
+        # make a deep copy to try to avoid the shared memory copy issue
+        dataset.target_windows = copy.deepcopy(splitting_f(dataset, worker_id, n_workers))
+        dataset.worker_id = worker_id
+
+        # make a copy of the raster_reader for thread safety
+        dataset.raster_reader = copy.copy(dataset.raster_reader)
+
+class IterableYearMapDataset(IterableMapDataset):
+    """IterableMapDataset variant that pairs each batch with its normalised acquisition year."""
+
+    def __init__(self, raster_reader, year, kernel_size, target_windows=None, off_x=None, off_y=None, stride=1,
+                 batch_size=1024, device="cpu", year_normalisation=2500):
+        super().__init__(raster_reader, kernel_size, target_windows, off_x, off_y, stride, batch_size, device)
+
+        self.year = year
+        self.year_normalisation = year_normalisation
+        logger.info("initialize IterableYearMapDataset")
+
+    def extract_tensor_iter(self, data, batch_size):
+        """
+        Read the nn windows from the given data
+        :param data:
+        :param batch_size:
+        :return:
+        """
+        _, height, width = data.shape
+
+        height = height - self.size + 1
+        width = width - self.size + 1
+
+        samples = []
+        count = 0
+
+        for i in range(0, height, self.stride):
+            for j in range(0, width, self.stride):
+                if count == batch_size:
+                    year_column = np.array([[self.year / self.year_normalisation] for _ in range(len(samples))],
+                                           dtype=np.float32)
+                    yield (torch.stack(samples, dim=0), year_column), BatchMeta(None, False, self.worker_id)
+
+                    samples = []
+                    count = 0
+                source_w = data.narrow(1, i, self.size).narrow(2, j, self.size)
+                samples.append(source_w)
+                count += 1
+
+        year_column = np.array([[self.year / self.year_normalisation] for _ in range(len(samples))],
+                               dtype=np.float32)
+        yield (torch.stack(samples, dim=0), year_column), BatchMeta(None, True, self.worker_id)
+
+class MapResultAggregator:
+    """
+    Receive the results back from the processing and write them to a map.
+    TODO manage encoder decoder
+    """
+
+    def __init__(self, path_out, output_transformer: OutputTransformer, n_windows, write_profile):
+
+        self.bands = output_transformer.bands
+        self.write_profile = copy.deepcopy(write_profile)
+        self.write_profile.update({"dtype": output_transformer.dtype,
+                                   "count": output_transformer.bands})
+        self.result_cache = {}
+
+        self.path_out = path_out
+        self.output_transformer = output_transformer
+        self.n_windows = n_windows
+
+    def submit_result(self, values, meta: BatchMeta):
+
+        values = self.output_transformer(values)
+        # values = values.reshape((self.n_band, windows.width, windows.height))
+        cached = self.result_cache.setdefault(meta.worker, [])
+
+        cached.append(values)
+
+        if meta.is_finished:
+            values = self.reshape(cached, meta.window)
+            cached.clear()
+            self.write(values, meta.window)
+
+    def reshape(self, data, windows):
+        """
+        Take a list of n arrays of arbitrary length and n_bands depth and return a window of size
+        (n_channels, height, width).
+        :param data:
+        :param windows:
+        :return:
+        """
+        # out = np.empty((self.n_bands, windows.height, windows.width), dtype=self.d_type)
+
+        width = windows.width
+        height = windows.height
+        # concatenate makes one array; then we reshape and move the band axis from the last position to the first
+        return np.moveaxis(np.concatenate(data).reshape((height, width, self.bands)), 2, 0)
+
+    def write(self, data, windows):
+        """
+        TODO currently flushes after each window; may have a performance cost.
+        :param data:
+        :param windows:
+        :return:
+        """
+        # for some reason, when not in threading spawn mode, setting any other option
+        # (e.g. num_threads=4) causes a deadlock
+        with rasterio.open(self.path_out, "r+", sharing=False) as writer:
+            writer.write(data, window=windows)
+
+
+class AsyncAggregator(mp.Process):
+    """
+    Wrapper around an aggregator which performs the aggregation asynchronously in a separate process.
+    """
+
+    def __init__(self, aggregator: MapResultAggregator, max_queue=5):
+        super().__init__(daemon=True)
+        self.queue = mp.Queue(max_queue)
+        self.aggregator = copy.deepcopy(aggregator)
+        self.windows_left = aggregator.n_windows
+
+    def run(self):
+        while True:
+            data, meta = self.queue.get()
+            self.aggregator.submit_result(data, meta)
+            del data
+            if meta.is_finished:
+                self.windows_left -= 1
+                if self.windows_left == 0:
+                    return
+
+    def submit_result(self, values, meta):
+        self.queue.put((values, meta))
+
+
+class GenericMapper:
+    """Apply a generic mapping function (allows running random forests and the like).
+    Use no_collate to work on numpy arrays.
+    """
+
+    def __init__(self,
+                 mapper,
+                 stride=1,
+                 loader_device='cpu',
+                 mapper_device='cpu',
+                 pin_memory=False,
+                 num_worker=0,
+                 prefetch_factor=2,
+                 custom_collate=None,
+                 worker_init_fn=None,
+                 aggregator=MapResultAggregator,
+                 write_profile=None,
+                 async_agr=True):
+
+        logger.info("setting model to eval mode")
+
+        self.mapper = mapper.to(mapper_device)
+        self.mapper.eval()
+        self.stride = stride
+
+        self.loader_device = loader_device
+        self.mapper_device = mapper_device
+        self.pin_memory = pin_memory
+        self.num_worker = num_worker
+        self.prefetch_factor = prefetch_factor
+
+        self.custom_collate = custom_collate
+
+        self.worker_init_fn = worker_init_fn
+
+        if num_worker > 0 and worker_init_fn is None:
+            raise Exception("A custom worker_init_fn is needed for the map iterator when parallel mode is used")
+
+        # a factory should be used instead, but for now
+        self.aggregator = aggregator
+        self.async_agr = async_agr
+
+        self.write_profile = write_profile
+
+    def map(self,
+            out_path,
+            map_iterator,
+            output_transformer: OutputTransformer,
+            bounds=None,
+            mask=None):
+
+        map_iterator, aggregator = self.mapping_generator(map_iterator,
+                                                          out_path,
+                                                          output_transformer,
+                                                          bounds=bounds,
+                                                          mask=mask,
+                                                          write_profile=self.write_profile)
+
+        dl = DataLoader(map_iterator, collate_fn=self.custom_collate, pin_memory=self.pin_memory,
+                        num_workers=self.num_worker, prefetch_factor=self.prefetch_factor,
+                        worker_init_fn=self.worker_init_fn)
+
+        if self.async_agr:
+            aggregator = AsyncAggregator(aggregator, 10)
+            aggregator.start()
+
+        with torch.inference_mode():
+            with tqdm(total=len(map_iterator.target_windows), desc='Map') as pbar:
+
+                for inputs, meta in dl:
+                    # increment when windows are finished
+
+                    # nn_inputs = inputs[0].to(device)
+                    # release cuda memory as soon as possible
+
+                    out = self.process_batch(inputs)
+
+                    aggregator.submit_result(out, meta[0])
+
+                    if meta[0].is_finished:
+                        # to monitor cuda
+                        # if nn_device == "cuda":
+                        #     pbar.set_postfix({'allocated': torch.cuda.memory_allocated(),
+                        #                       'max allocated': torch.cuda.max_memory_allocated(),
+                        #                       'reserved': torch.cuda.memory_reserved(),
+                        #                       'max reserved': torch.cuda.max_memory_reserved()}, refresh=False)
+                        pbar.update(1)
+
+        if self.async_agr:
+            aggregator.join()
+
+    def process_batch(self, inputs):
+
+        if isinstance(inputs, (list, tuple)):
+            inputs = [x.to(self.mapper_device, non_blocking=True) for x in inputs]
+            # if the model is not traced we need to pass the inputs unpacked
+            outputs = self.mapper(*inputs)
+        else:
+            inputs = inputs.to(self.mapper_device, non_blocking=True)
+            outputs = self.mapper(inputs)
+
+        del inputs
+        s = outputs.detach().cpu().numpy()
+        del outputs
+        return s
+
+
+    def mapping_generator(self,
+                          map_iterator,
+                          path_out,
+                          output_transformer: OutputTransformer,
+                          bounds=None,
+                          mask=None,
+                          write_profile=None):
+        """
+        Generate a map iterator and an aggregator to be used for mapping.
+        :param map_iterator:
+        :param output_transformer:
+        :param path_out:
+        :param bounds:
+        :param mask:
+        :param write_profile:
+        :return:
+        """
+        # TODO adapt the code for when windows are not blocks and to take into account encoder decoder
+
+        ref_raster = map_iterator.raster_reader.ref_raster()
+
+        with rasterio.open(ref_raster, mode="r") as raster_source:
+            src = raster_source
+
+            if write_profile is None:
+                write_profile = get_write_profile()
+
+            if map_iterator.size % 2 == 0:
+                raise ValueError("even kernel sizes are not supported yet")
+
+            if not bounds:
+                bounds = src.bounds
+
+            # this function is based on the center pixel; needs adjustment for fully convolutional models
+            transform, width, height, off_x, off_y = align_grid(src.transform, bounds, src.width, src.height,
+                                                                map_iterator.size)
+
+            write_profile.update({'dtype': output_transformer.dtype,
+                                  'crs': src.crs,
+                                  'transform': transform,
+                                  'width': width,
+                                  'height': height,
+                                  'nodata': output_transformer.nodata,
+                                  'count': output_transformer.bands})
+            # out_profile = out_meta(src.meta, 5, 1)
+            # create an empty raster
+            with rasterio.open(path_out, mode="w", **write_profile) as dst:
+                windows = list(dst.block_windows())
+
+            if mask is not None:
+                windows = list(filter(lambda x: windows_in_mask(x[1], transform, mask), windows))
+
+            # set the right parameters on the map iterator
+            map_iterator.target_windows = windows
+            map_iterator.off_x = off_x
+            map_iterator.off_y = off_y
+
+            aggregator = self.aggregator(path_out, output_transformer, len(windows), write_profile)
+
+        return map_iterator, aggregator
+
+
+class NNMapper(GenericMapper):
+    """
+    Mapper specialised to run neural networks.
+    """
+
+    def __init__(self,
+                 model,
+                 stride=1,
+                 loader_device='cpu',
+                 mapper_device='cpu',
+                 pin_memory=False,
+                 num_worker=0,
+                 prefetch_factor=2,
+                 custom_collate=None,
+                 worker_init_fn=None,
+                 aggregator=MapResultAggregator,
+                 write_profile=None,
+                 async_agr=True):
+
+        super().__init__(model, stride, loader_device, mapper_device, pin_memory, num_worker,
+                         prefetch_factor, custom_collate, worker_init_fn, aggregator, write_profile, async_agr)
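
For reference, a minimal usage sketch of the API added in this file; it is not part of the package itself. It wires an IterableMapDataset to an NNMapper with two loader workers split via continuous_split. The reader, model, and transformer objects, the pair_collate helper, and the output path are illustrative assumptions; the package's own collate utilities (such as the no_collate mentioned in the GenericMapper docstring) may be the intended choice in practice.

import functools

import torch

from eoml.torch.cnn.map_dataset import IterableMapDataset, NNMapper, continuous_split

# Hypothetical inputs: `reader` is an eoml RasterReader over the input imagery,
# `model` is a trained torch.nn.Module, and `transformer` is an OutputTransformer
# configured with the output dtype, nodata value, and band count.

def pair_collate(batch):
    # illustrative collate: samples are already stacked by the dataset, so with the
    # DataLoader's default batch_size=1 we just unzip the (sample, BatchMeta) pairs
    samples, metas = zip(*batch)
    return torch.cat(samples), list(metas)

dataset = IterableMapDataset(reader, kernel_size=9, batch_size=1024)

mapper = NNMapper(model,
                  mapper_device="cuda" if torch.cuda.is_available() else "cpu",
                  num_worker=2,
                  custom_collate=pair_collate,
                  worker_init_fn=functools.partial(IterableMapDataset.basic_worker_init_fn,
                                                   splitting_f=continuous_split))

# target_windows, off_x and off_y are filled in by mapping_generator() inside map()
mapper.map("out.tif", dataset, transformer)

With num_worker > 0, a worker_init_fn is mandatory (GenericMapper raises otherwise); continuous_split favours GDAL cache locality, while the default jumped_split spreads progress across the full extent.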