eoml 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. eoml/__init__.py +74 -0
  2. eoml/automation/__init__.py +7 -0
  3. eoml/automation/configuration.py +105 -0
  4. eoml/automation/dag.py +233 -0
  5. eoml/automation/experience.py +618 -0
  6. eoml/automation/tasks.py +825 -0
  7. eoml/bin/__init__.py +6 -0
  8. eoml/bin/clean_checkpoint.py +146 -0
  9. eoml/bin/land_cover_mapping_toml.py +435 -0
  10. eoml/bin/mosaic_images.py +137 -0
  11. eoml/data/__init__.py +7 -0
  12. eoml/data/basic_geo_data.py +214 -0
  13. eoml/data/dataset_utils.py +98 -0
  14. eoml/data/persistence/__init__.py +7 -0
  15. eoml/data/persistence/generic.py +253 -0
  16. eoml/data/persistence/lmdb.py +379 -0
  17. eoml/data/persistence/serializer.py +82 -0
  18. eoml/raster/__init__.py +7 -0
  19. eoml/raster/band.py +141 -0
  20. eoml/raster/dataset/__init__.py +6 -0
  21. eoml/raster/dataset/extractor.py +604 -0
  22. eoml/raster/raster_reader.py +602 -0
  23. eoml/raster/raster_utils.py +116 -0
  24. eoml/torch/__init__.py +7 -0
  25. eoml/torch/cnn/__init__.py +7 -0
  26. eoml/torch/cnn/augmentation.py +150 -0
  27. eoml/torch/cnn/dataset_evaluator.py +68 -0
  28. eoml/torch/cnn/db_dataset.py +605 -0
  29. eoml/torch/cnn/map_dataset.py +579 -0
  30. eoml/torch/cnn/map_dataset_const_mem.py +135 -0
  31. eoml/torch/cnn/outputs_transformer.py +130 -0
  32. eoml/torch/cnn/torch_utils.py +404 -0
  33. eoml/torch/cnn/training_dataset.py +241 -0
  34. eoml/torch/cnn/windows_dataset.py +120 -0
  35. eoml/torch/dataset/__init__.py +6 -0
  36. eoml/torch/dataset/shade_dataset_tester.py +46 -0
  37. eoml/torch/dataset/shade_tree_dataset_creators.py +537 -0
  38. eoml/torch/model_low_use.py +507 -0
  39. eoml/torch/models.py +282 -0
  40. eoml/torch/resnet.py +437 -0
  41. eoml/torch/sample_statistic.py +260 -0
  42. eoml/torch/trainer.py +782 -0
  43. eoml/torch/trainer_v2.py +253 -0
  44. eoml-0.9.0.dist-info/METADATA +93 -0
  45. eoml-0.9.0.dist-info/RECORD +47 -0
  46. eoml-0.9.0.dist-info/WHEEL +4 -0
  47. eoml-0.9.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,137 @@
1
+ import typer
2
+ from pathlib import Path
3
+
4
+ from rasterio.enums import Resampling
5
+
6
+ from eoml import get_read_profile, get_write_profile
7
+ from eoml.automation.tasks import tiled_task
8
+ from rasterop.tiled_op.operation import CopyFirstNonNullOP
9
+ from rasterop.tiled_op.tiled_raster_op import get_image_file, TiledOPExecutor
10
+
11
# Typer application object; the command(s) of this script are registered on it.
app = typer.Typer(
    help=(
        "Raster merging utility that take all the TIFF file in the input directories sorted by"
        " alphabetical order and copies the first non-nan value to the final TIFF"
    )
)
13
+
14
def parse_resampling(value: str) -> Resampling:
    """Convert a user-supplied string into a ``Resampling`` enum member.

    The lookup is case-insensitive and ignores surrounding whitespace. A
    member can be selected either by name (``"nearest"``, ``"bilinear"``,
    ...) or by its numeric enum value (``"0"``, ``"1"``, ...).

    Args:
        value: Resampling name or numeric value as typed on the command line.

    Returns:
        The matching ``Resampling`` member.

    Raises:
        typer.BadParameter: If ``value`` matches neither a member name nor a
            valid numeric member value.
    """
    value_norm = value.strip().lower()

    by_name = {m.name.lower(): m for m in Resampling}
    if value_norm in by_name:
        return by_name[value_norm]

    # Optional: allow numeric values too (handy for backwards-compat)
    if value_norm.isdigit():
        try:
            return Resampling(int(value_norm))
        except ValueError:
            # Unknown enum value (or a digit-like character int() rejects):
            # fall through so the user gets the same BadParameter message as
            # for any other invalid input instead of a raw ValueError.
            pass

    raise typer.BadParameter(
        f"Invalid resampling '{value}'. Choose one of: {', '.join(sorted(by_name))}"
    )
28
+
29
+
30
@app.command()
def merge_rasters(
    input_dir: Path = typer.Argument(
        ...,
        help="Input directory containing TIFF files",
        exists=True,
        dir_okay=True,
        file_okay=False
    ),
    output_file: Path = typer.Argument(
        ...,
        help="Output TIFF file path"
    ),
    num_threads: str = typer.Option(
        "all_cpus",
        "--threads", "-t",
        help="Number of threads to use by gdal for compression"
    ),
    block_size: int = typer.Option(
        256,
        "--block-size", "-b",
        help="Block size for x and y dimensions for the geotiff internal structure"
    ),
    tile_size: int = typer.Option(
        2028,
        "--tile-size", "-T",
        help="Block size for the operation"
    ),
    num_workers: int = typer.Option(
        8,
        "--workers", "-w",
        help="Number of workers for processing"
    ),
    resampling: Resampling = typer.Option(
        Resampling.nearest,  # internal default as enum
        "--resampling", "-r",
        callback=lambda v: parse_resampling(v) if isinstance(v, str) else v,
        help="Resampling method by name (nearest, bilinear, cubic, ...)",
    )
):
    """
    Merge multiple raster files by copying the first non-nan value to the final TIFF.

    The TIFF files found in ``input_dir`` are sorted and handed, together with a
    ``CopyFirstNonNullOP`` operator built from the first raster, to ``tiled_task``.
    Any failure is reported on stderr and the command exits with status 1.

    Args:
        input_dir: Existing directory that is searched for TIFF files.
        output_file: Path of the raster to write.
        num_threads: Value stored under ``num_threads`` in the read and write profiles.
        block_size: Value stored under ``blockxsize`` and ``blockysize`` in the write profile.
        tile_size: Accepted on the command line but currently not forwarded anywhere.
        num_workers: Forwarded to ``tiled_task`` as ``num_workers``.
        resampling: Forwarded to ``tiled_task`` as ``resampling``.

    Raises:
        typer.Exit: With code 1 when no raster is found or the merge fails.
    """
    # NOTE(review): tile_size is never used below. The signature of tiled_task
    # is not visible from this file, so it is not wired through here -- confirm
    # whether tiled_task exposes a window/tile size parameter.
    try:
        # Get the list of raster files
        rasters = get_image_file(input_dir, extension=["tif", "tiff", "TIF", "TIFF"])
        rasters.sort()
        if not rasters:
            raise typer.BadParameter(f"No raster files found in {input_dir}")

        # Set up writing and reading profiles
        read_profile = get_read_profile()
        profile = get_write_profile()

        # NOTE(review): read_profile is configured but not passed on (src_kwds
        # stays None, as in the original) -- confirm whether that is intended.
        read_profile.update({'num_threads': num_threads})
        profile.update({
            "driver": "COG",
            'num_threads': num_threads,
            'blockxsize': block_size,
            'blockysize': block_size
        })

        # Set up operation parameters
        default_op_param = {
            "bounds": None,
            "res": None,
            "resampling": resampling,
            "target_aligned_pixels": False,
            "indexes": None,
            "src_kwds": None,
            "dst_kwds": None,
            "num_workers": num_workers
        }

        # Create operator and set parameters. The defaults are expanded FIRST so
        # that the specific entries below win; previously the defaults were
        # applied last and silently reset "dst_kwds" to None, discarding the
        # write profile configured above.
        operator = CopyFirstNonNullOP.same_as(rasters[0])
        operator_param = {
            **default_op_param,
            "maps": rasters,
            "raster_out": str(output_file),
            "operation": operator,
            "dst_kwds": profile
        }

        # Execute tiled task
        typer.echo(f"Merging {len(rasters)} raster files...")
        tiled_task(**operator_param)
        typer.echo(f"Successfully merged rasters to {output_file}")

    except Exception as e:
        # Top-level CLI boundary: report the problem and exit with a failure code.
        typer.echo(f"Error: {str(e)}", err=True)
        raise typer.Exit(1) from e
133
+
134
+
135
# Start the Typer CLI only when this module is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    app()
137
+
eoml/data/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ """
2
+ Data Module for EOML.
3
+
4
+ This module provides data structures and utilities for handling geospatial
5
+ data in Earth observation applications. It includes basic geodata classes,
6
+ dataset utilities, and persistence mechanisms.
7
+ """
@@ -0,0 +1,214 @@
1
+ """Basic geographical data structures for storing raster samples with metadata.
2
+
3
+ This module defines core data structures for representing geospatial training samples,
4
+ including headers with geometry information and complete samples combining raster data
5
+ with labels and metadata.
6
+ """
7
+
8
+ import math
9
+
10
+ import numpy as np
11
+ import rasterio
12
+
13
+ from eoml import get_read_profile, get_write_profile
14
+
15
+
16
class GeoDataHeader:
    """Identifying metadata attached to one geospatial sample.

    Attributes:
        idx: Identifier of the sample (taken from the vector file or assigned).
        geometry: Geometry object locating the sample; ``__repr__`` reads its
            ``wkt`` attribute.
        file_name: Name of the file the sample originates from.
    """

    def __init__(self, idx, geometry, file_name):
        """Store the three header fields as given.

        Args:
            idx: Identifier of the sample.
            geometry: Geometry object describing where the sample lies.
            file_name: Name of the source file of the sample.
        """
        self.idx = idx
        self.geometry = geometry
        self.file_name = file_name

    def __eq__(self, other):
        """Compare two headers field by field.

        Args:
            other: Object to compare with.

        Returns:
            The result of comparing ``idx``, ``geometry`` and ``file_name`` in
            that order, or ``NotImplemented`` when ``other`` is not a
            ``GeoDataHeader``.
        """
        if not isinstance(other, GeoDataHeader):
            return NotImplemented

        return (
            self.idx == other.idx
            and self.geometry == other.geometry
            and self.file_name == other.file_name
        )

    def __repr__(self):
        """Build a printable description of the header.

        Returns:
            String holding the id, the geometry as WKT and the file name.
        """
        sample_id = self.idx
        geometry_wkt = self.geometry.wkt
        source = self.file_name
        return f"GeoDataHeader(id:{sample_id}, geometry:{geometry_wkt}, file_name:{source})"
60
+
61
+
62
class BasicGeoData:
    """Complete geospatial sample with header, raster data, and target label.

    Represents a training sample combining metadata (header), multi-band raster data
    (typically a small image window), and a target value (class label or regression value).

    Attributes:
        header: GeoDataHeader containing sample metadata
        data: NumPy array of raster data; ``to_file`` reads it as (bands, height, width)
        target: Target value (int for classification, float for regression, or array)
    """

    def __init__(self, header, data, target):
        """Initialize a BasicGeoData sample.

        Args:
            header: GeoDataHeader with sample metadata
            data: NumPy array containing raster data
            target: Target value(s) for supervised learning
        """
        self.header = header
        self.data = data
        self.target = target

    @property
    def header(self):
        """Get the sample header.

        Returns:
            GeoDataHeader instance
        """
        return self._header

    @header.setter
    def header(self, value):
        """Set the sample header.

        Args:
            value: GeoDataHeader instance
        """
        self._header = value

    @property
    def data(self):
        """Get the raster data array.

        Returns:
            NumPy array of raster data
        """
        return self._data

    @data.setter
    def data(self, value):
        """Set the raster data array.

        Args:
            value: NumPy array containing raster data
        """
        self._data = value

    @property
    def target(self):
        """Get the target label or value.

        Returns:
            Target value (scalar or array)
        """
        return self._target

    @target.setter
    def target(self, value):
        """Set the target label or value.

        Args:
            value: Target value for the sample
        """
        self._target = value

    def __eq__(self, other):
        """Check equality based on header, data, and target.

        Args:
            other: Another object to compare against

        Returns:
            True if header, data (NaN positions compare equal) and target all
            match, a falsy value otherwise; ``NotImplemented`` when ``other``
            is not a ``BasicGeoData``.
        """
        if not isinstance(other, BasicGeoData):
            return NotImplemented

        # The target may be an array: a plain ``==`` inside an ``and`` chain
        # would raise "truth value of an array is ambiguous". array_equal
        # handles scalars and arrays alike and gives the same answer as ``==``
        # for scalar targets (NaN targets still compare unequal).
        return (
            self.header == other.header
            and np.array_equal(self.data, other.data, equal_nan=True)
            and np.array_equal(self.target, other.target)
        )

    def to_file(self, path, ref):
        """Write the raster data to a GeoTIFF file with proper georeferencing.

        Exports the sample's raster data to a georeferenced GeoTIFF using the coordinate
        reference system and pixel grid of a reference raster. The output raster is
        centered on the sample's geometry point.

        Args:
            path: Output path for the GeoTIFF file
            ref: Path to reference raster file for CRS and transform information

        Returns:
            None. Writes GeoTIFF to specified path
        """
        # data is indexed as (bands, rows, columns) throughout this method.
        count, height, width = self.data.shape

        with rasterio.open(ref) as src:
            crs = src.crs
            point = self.header.geometry

            # Pixel of the reference raster that contains the sample point.
            row, col = src.index(point.x, point.y, op=math.floor)

            # Half extents along each axis. The column extent must come from
            # the width (shape[2]); using shape[1] for both axes is only
            # correct for square windows.
            half_height = height / 2
            half_width = width / 2

            # NOTE(review): src.xy is called with its default offset, as in the
            # original code -- confirm this is the intended pixel anchor for
            # the window bounds.
            west, north = src.xy(row - half_height, col - half_width)
            east, south = src.xy(row + half_height, col + half_width)

        # from_bounds expects (west, south, east, north, width, height).
        transform = rasterio.transform.from_bounds(west, south, east, north,
                                                   width, height)

        profile = get_write_profile()

        profile.update({"height": height,
                        "width": width,
                        "count": count,
                        "dtype": self.data.dtype,
                        "crs": crs,
                        "transform": transform})

        # Distinct name for the writer so it never shadows the reference reader.
        with rasterio.open(path, "w", **profile) as dst:
            dst.write(self.data)

    def __repr__(self):
        """Return string representation of the sample.

        Returns:
            String showing the repr of the header, the data and the target
        """
        return f"BasicGeoData(header:{self.header.__repr__()}, data:{self.data.__repr__()}, target:{self.target.__repr__()})"
@@ -0,0 +1,98 @@
1
+ """
2
+ Dataset utility functions for machine learning workflows.
3
+
4
+ This module provides utility functions for splitting and organizing datasets
5
+ for machine learning experiments, including random splitting and k-fold
6
+ cross-validation setup.
7
+ """
8
+
9
+ import math
10
+ import random
11
+
12
+ import numpy as np
13
+ import rasterio
14
+
15
+
16
def random_split(id_list, counts_list, relative=False) -> list:
    """
    Randomly split a list of IDs into multiple subsets.

    This function splits a list of identifiers into multiple subsets according
    to specified counts. Useful for creating train/validation/test splits.
    The input list is copied before shuffling and is never modified.

    Args:
        id_list (list): List of identifiers to split.
        counts_list (list): List of counts for each split. If relative=True,
            these are interpreted as proportions; otherwise as absolute counts.
        relative (bool, optional): If True, counts_list values are proportions
            of the total. If False, they are absolute counts. Defaults to False.

    Returns:
        list: List of lists, where each sublist contains IDs for one split.
        IDs beyond the requested total are left out of every split.

    Raises:
        ValueError: If the requested number of samples exceeds the list length.

    Examples:
        >>> ids = list(range(100))
        >>> train, val, test = random_split(ids, [0.7, 0.15, 0.15], relative=True)
        >>> # or with absolute counts:
        >>> train, val = random_split(ids, [80, 20], relative=False)
    """
    n_el = len(id_list)

    # End index of every split inside the shuffled list.
    boundaries = np.cumsum(counts_list).tolist()
    if relative:
        # Round the cumulative proportions rather than each proportion on its
        # own: per-split rounding can add up to more than n_el (n_el=3 with
        # [0.5, 0.5] gives 2 + 2) or silently drop samples.
        boundaries = [round(fraction * n_el) for fraction in boundaries]

    if boundaries and boundaries[-1] > n_el:
        raise ValueError("number of sample requested higher than list length")

    ids = id_list.copy()
    random.shuffle(ids)

    starts = [0] + boundaries[:-1]
    return [ids[start:end] for start, end in zip(starts, boundaries)]
63
+
64
+
65
def k_fold_sample(id_list, n_fold):
    """
    Create k-fold cross-validation splits from a list of IDs.

    This function creates n_fold partitions of the data and generates fold
    definitions for k-fold cross-validation, where each fold is used once
    as validation while the remaining folds are used for training.
    The input list is copied before shuffling and is never modified.

    Args:
        id_list (list): List of sample identifiers to split.
        n_fold (int): Number of folds to create.

    Returns:
        tuple: A tuple containing:
            - folds (list): List of n_fold lists, each containing sample IDs for that fold.
            - fold_id (list): List of tuples defining train/validation splits, where each
              tuple contains ([training_fold_indices], [validation_fold_index]).

    Examples:
        >>> ids = list(range(100))
        >>> folds, fold_splits = k_fold_sample(ids, n_fold=5)
        >>> # folds[0] contains ~20 samples, fold_splits[0] is ([1,2,3,4], [0])
    """
    # Shuffle a copy so the caller's list keeps its order (same convention as
    # random_split).
    shuffled = id_list.copy()
    random.shuffle(shuffled)

    # Deal the shuffled IDs round-robin into n_fold partitions.
    folds = [shuffled[offset::n_fold] for offset in range(n_fold)]

    # One entry per fold: every other fold index for training, this one for validation.
    fold_id = [
        ([j for j in range(n_fold) if j != i], [i])
        for i in range(n_fold)
    ]

    return folds, fold_id
@@ -0,0 +1,7 @@
1
+ """
2
+ Persistence Submodule for Data.
3
+
4
+ This submodule provides persistence mechanisms for geodata, including
5
+ database access objects (DAOs), readers, writers, and serializers for
6
+ efficient storage and retrieval of geospatial machine learning datasets.
7
+ """