eoml 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eoml/__init__.py +74 -0
- eoml/automation/__init__.py +7 -0
- eoml/automation/configuration.py +105 -0
- eoml/automation/dag.py +233 -0
- eoml/automation/experience.py +618 -0
- eoml/automation/tasks.py +825 -0
- eoml/bin/__init__.py +6 -0
- eoml/bin/clean_checkpoint.py +146 -0
- eoml/bin/land_cover_mapping_toml.py +435 -0
- eoml/bin/mosaic_images.py +137 -0
- eoml/data/__init__.py +7 -0
- eoml/data/basic_geo_data.py +214 -0
- eoml/data/dataset_utils.py +98 -0
- eoml/data/persistence/__init__.py +7 -0
- eoml/data/persistence/generic.py +253 -0
- eoml/data/persistence/lmdb.py +379 -0
- eoml/data/persistence/serializer.py +82 -0
- eoml/raster/__init__.py +7 -0
- eoml/raster/band.py +141 -0
- eoml/raster/dataset/__init__.py +6 -0
- eoml/raster/dataset/extractor.py +604 -0
- eoml/raster/raster_reader.py +602 -0
- eoml/raster/raster_utils.py +116 -0
- eoml/torch/__init__.py +7 -0
- eoml/torch/cnn/__init__.py +7 -0
- eoml/torch/cnn/augmentation.py +150 -0
- eoml/torch/cnn/dataset_evaluator.py +68 -0
- eoml/torch/cnn/db_dataset.py +605 -0
- eoml/torch/cnn/map_dataset.py +579 -0
- eoml/torch/cnn/map_dataset_const_mem.py +135 -0
- eoml/torch/cnn/outputs_transformer.py +130 -0
- eoml/torch/cnn/torch_utils.py +404 -0
- eoml/torch/cnn/training_dataset.py +241 -0
- eoml/torch/cnn/windows_dataset.py +120 -0
- eoml/torch/dataset/__init__.py +6 -0
- eoml/torch/dataset/shade_dataset_tester.py +46 -0
- eoml/torch/dataset/shade_tree_dataset_creators.py +537 -0
- eoml/torch/model_low_use.py +507 -0
- eoml/torch/models.py +282 -0
- eoml/torch/resnet.py +437 -0
- eoml/torch/sample_statistic.py +260 -0
- eoml/torch/trainer.py +782 -0
- eoml/torch/trainer_v2.py +253 -0
- eoml-0.9.0.dist-info/METADATA +93 -0
- eoml-0.9.0.dist-info/RECORD +47 -0
- eoml-0.9.0.dist-info/WHEEL +4 -0
- eoml-0.9.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
"""Generic data persistence interfaces for geospatial machine learning datasets.
|
|
2
|
+
|
|
3
|
+
This module defines abstract base classes and interfaces for storing and retrieving
|
|
4
|
+
geospatial training data. It provides a data access object (DAO) pattern for managing
|
|
5
|
+
labeled geospatial samples with associated metadata, supporting various backend
|
|
6
|
+
storage implementations (e.g., LMDB, HDF5, etc.).
|
|
7
|
+
|
|
8
|
+
The module includes serialization interfaces for both data and headers, enabling
|
|
9
|
+
flexible storage formats while maintaining a consistent access API.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from abc import ABC, abstractmethod
|
|
13
|
+
|
|
14
|
+
from eoml.data.basic_geo_data import BasicGeoData, GeoDataHeader
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class BasicGeoDataDAO(ABC):
    """Abstract data access object for geospatial sample databases.

    Defines the minimum interface required for reading from and writing to
    a geospatial sample database. Implementations should provide efficient
    storage and retrieval of raster data with associated labels and metadata.

    Every accessor below is abstract: a concrete backend has to implement all
    of them before it can be instantiated.
    """

    @abstractmethod
    def n_sample(self):
        """
        Number of samples in the db
        :return: the sample count
        """
        pass

    @abstractmethod
    def get(self, num: int) -> BasicGeoData:
        """
        Get the sample num
        :param num: index of the sample to get
        :return: Full sample with header
        """
        pass

    # Abstract like the other accessors: GeoDataReader dereferences the
    # returned header (``get_header(i).idx``), so a silent ``None`` from a
    # forgotten override would only fail later and far from the cause.
    @abstractmethod
    def get_header(self, key: int) -> GeoDataHeader:
        """
        Get the header of the sample stored under key
        :param key: index of the sample whose header is wanted
        :return: header of the sample
        """
        pass

    @abstractmethod
    def get_data(self, num: int):
        """
        Only get the input and expected output of the nn
        :param num: index of the sample to get
        :return: the input and the expected output of the sample (no header)
        """
        pass

    @abstractmethod
    def get_output(self, num: int):
        """
        Get the output for num
        :param num: index of the sample to get
        :return: the expected output of the sample
        """
        pass

    @abstractmethod
    def get_header_key(self, header: GeoDataHeader):
        """
        Look up the db key of the sample described by header
        :param header: header of the sample to look up
        :return: the db key of that sample
        """
        pass

    @abstractmethod
    def open(self):
        """Open the underlying storage; called when a reader/writer is entered."""
        pass

    @abstractmethod
    def close(self):
        """Close the underlying storage; called when a reader/writer is exited."""
        pass

    # @abstractmethod
    # def remove(self, value):
    #     pass

    @abstractmethod
    def save(self, geodata: BasicGeoData):
        """
        Insert sample in the db
        :param geodata: data to save in the db
        :return: nothing
        """
        pass

    @abstractmethod
    def __len__(self):
        """
        Number of samples, as used by ``len(dao)``
        :return: the sample count
        """
        pass
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class GeoDataReader:
    """Reader class for accessing geospatial dataset databases.

    Provides high-level read operations on geospatial sample databases,
    wrapping a DAO implementation with convenient access methods. The reader
    is a context manager: entering it opens the DAO, leaving it closes it.

    Attributes:
        dao: Data access object providing low-level database operations.
    """

    def __init__(self, dao: BasicGeoDataDAO):
        self.dao: BasicGeoDataDAO = dao

    def n_sample(self):
        """Return the number of samples reported by the DAO."""
        return self.dao.n_sample()

    def get(self, key: int) -> BasicGeoData:
        """
        Get the full sample stored under key
        :param key: db key of the sample
        :return: the sample with its header
        """
        return self.dao.get(key)

    def get_header(self, key: int) -> GeoDataHeader:
        """
        Get the header of the sample stored under key
        :param key: db key of the sample
        :return: header of the sample
        """
        return self.dao.get_header(key)

    def get_output(self, num: int):
        """
        Get the output for num
        :param num: db key of the sample
        :return: the expected output of the sample
        """
        return self.dao.get_output(num)

    def get_data(self, key: int):
        """
        Get the input and expected output of the sample stored under key
        :param key: db key of the sample
        :return: whatever the DAO's get_data returns for that key
        """
        return self.dao.get_data(key)

    def get_header_key(self, header: GeoDataHeader):
        """
        Look up the db key of the sample described by header
        :param header: header of the sample to look up
        :return: the db key of that sample
        """
        return self.dao.get_header_key(header)

    def get_output_dic(self):
        """return a dictionary with db_key:output as entries"""
        return {i: self.dao.get_output(i) for i in range(len(self.dao))}

    def def_get_output_dic(self):
        """Backward-compatible alias of get_output_dic (historical misspelled name)."""
        return self.get_output_dic()

    def get_sample_id_output_dic(self):
        """return a dictionary with id IN THE GEOPACKAGE and the output"""
        return {self.dao.get_header(i).idx: self.dao.get_output(i) for i in range(len(self.dao))}

    def get_sample_id_db_key_dic(self):
        """return a dictionary with header idx:db_key as entries"""
        return {self.dao.get_header(i).idx: i for i in range(len(self.dao))}

    def _check_db_match(self, db_reader2):
        """
        Check that this database and db_reader2 hold the same samples.

        Both readers are opened for the duration of the check. Samples are
        paired through their header ``idx``, so the db keys may differ between
        the two databases.
        :param db_reader2: the reader to compare against
        :return: True when both hold the same number of samples and every
            sample of this reader has an equal counterpart in db_reader2
        """
        with self, db_reader2:

            if self.n_sample() != db_reader2.n_sample():
                return False

            id_key1 = self.get_sample_id_db_key_dic()
            id_key2 = db_reader2.get_sample_id_db_key_dic()

            for idx, k1 in id_key1.items():
                k2 = id_key2.get(idx)
                # idx unknown to the other database
                if k2 is None:
                    return False

                # NOTE(review): relies on the sample type's own comparison
                # operators - confirm BasicGeoData defines equality by value.
                if self.get(k1) != db_reader2.get(k2):
                    return False
            return True

    def open(self):
        """Open the underlying DAO."""
        self.dao.open()

    def close(self):
        """Close the underlying DAO."""
        self.dao.close()

    def __enter__(self):
        self.open()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.close()
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
class GeoDataWriter:
    """Writer class for saving geospatial samples to databases.

    A thin facade over a DAO: every operation is forwarded to it. The writer
    is a context manager that opens the DAO on entry and closes it on exit.

    Attributes:
        dao: Data access object providing low-level database operations.
    """

    def __init__(self, dao: BasicGeoDataDAO):
        self.dao = dao

    def __enter__(self):
        # Entering the context opens the DAO and hands back the writer itself.
        self.open()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Always close; nothing is returned, so exceptions keep propagating.
        self.close()

    def open(self):
        """Open the underlying DAO."""
        self.dao.open()

    def close(self):
        """Close the underlying DAO."""
        self.dao.close()

    def save(self, geodata):
        """
        Store one sample through the DAO
        :param geodata: sample to store
        :return: nothing
        """
        self.dao.save(geodata)
|
|
205
|
+
|
|
206
|
+
class GeoDataHeaderSerializer(ABC):
    """Abstract serializer for geospatial data headers.

    Interface for turning GeoDataHeader objects into a storable byte
    representation and back.
    """

    @abstractmethod
    def serialize(self, header: GeoDataHeader):
        """
        Serialize the header to bytes
        :param header: header to serialize
        :return: bytes representation of the header
        """

    @abstractmethod
    def deserialize(self, msg):
        """
        Rebuild a header from its serialized form
        :param msg: data to deserialize
        :return: a GeoDataHeader
        """
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
class GeoDataSerializer(ABC):
    """Abstract serializer for geospatial dataset arrays.

    Interface for turning data arrays (typically numpy arrays) into a
    storable byte representation and back.
    """

    # NOTE(review): the parameter is annotated BasicGeoData although the class
    # is described as serializing data arrays - confirm the intended type.
    @abstractmethod
    def serialize(self, data: BasicGeoData):
        """
        Serialize the data to bytes
        :param data: dataset to be serialized
        :return: bytes representation of the data
        """

    @abstractmethod
    def deserialize(self, msg):
        """
        Rebuild the data from its serialized form
        :param msg: serialised representation
        :return: the geodataset
        """
|
|
@@ -0,0 +1,379 @@
|
|
|
1
|
+
"""LMDB-based persistence for geospatial machine learning datasets.
|
|
2
|
+
|
|
3
|
+
This module provides LMDB (Lightning Memory-Mapped Database) implementations
|
|
4
|
+
for storing and retrieving geospatial training data. LMDB offers fast random
|
|
5
|
+
access and efficient storage for large datasets, making it ideal for machine
|
|
6
|
+
learning workflows.
|
|
7
|
+
|
|
8
|
+
The implementation uses multiple named databases within a single LMDB environment
|
|
9
|
+
to organize different types of data (metadata, headers, inputs, outputs, and indices).
|
|
10
|
+
|
|
11
|
+
Reference:
|
|
12
|
+
https://stackoverflow.com/questions/32489778/how-do-i-count-and-enumerate-the-keys-in-an-lmdb-with-python
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from typing import Literal
|
|
16
|
+
|
|
17
|
+
import lmdb
|
|
18
|
+
from eoml.data.basic_geo_data import GeoDataHeader, BasicGeoData
|
|
19
|
+
from eoml.data.persistence.generic import BasicGeoDataDAO, GeoDataReader, GeoDataWriter
|
|
20
|
+
from eoml.data.persistence.serializer import MsgpackGeoDataHeaderSerializer, MsgpackGeoDataSerializer
|
|
21
|
+
|
|
22
|
+
# Stores geospatial training data in LMDB format with the following structure:
|
|
23
|
+
# - index_db: Maps "<feature idx>-<file name>" index keys to sample keys
|
|
24
|
+
# - meta_db: Stores metadata (total samples, samples per category, etc.)
|
|
25
|
+
# - header_db: MessagePack-serialized headers (geometry, source file, source ID)
|
|
26
|
+
# - data_db: Raster data matrices
|
|
27
|
+
# - out_db: Target outputs/categories
|
|
28
|
+
|
|
29
|
+
class LMDBasicGeoDataDAO(BasicGeoDataDAO):
    """LMDB-based data access object for geospatial samples.

    Stores geospatial training data in LMDB format with the following structure:
    - index_db: Maps "<header idx>-<header file_name>" strings to sample keys
    - meta_db: Stores metadata (currently the total number of samples)
    - header_db: Serialized headers
    - data_db: Serialized input data
    - out_db: Serialized target outputs/categories

    Sample keys are the sample number encoded as 4 big-endian bytes. The
    environment and its named databases only exist between open() and close().

    Attributes:
        db_path: Path to LMDB database directory.
        map_size_limit: Maximum size of LMDB map in bytes.
        header_serializer: Serializer for GeoDataHeader objects.
        data_serializer: Serializer for the input and output data.
        index_db: Database for sample index mapping.
        meta_db: Database for metadata storage.
        header_db: Database for sample headers.
        data_db: Database for input data.
        category_db: Database for output labels.
        num_sample: Number of samples, cached from the db on open() and
            updated by save().
    """

    def __init__(self, db_path, map_size_limit=int(1e+11), header_serializer=None, data_serializer=None):
        """
        :param db_path: path of the LMDB directory
        :param map_size_limit: maximum size of the LMDB map in bytes
        :param header_serializer: header serializer, msgpack based when None
        :param data_serializer: data serializer, msgpack based when None
        """
        super().__init__()

        if header_serializer is None:
            header_serializer = MsgpackGeoDataHeaderSerializer()

        if data_serializer is None:
            data_serializer = MsgpackGeoDataSerializer()

        self.db_path = db_path

        # Handles are filled in by open()
        self._lmdb_env = None
        self.category_db = None
        self.data_db = None
        self.header_db = None
        self.meta_db = None
        self.index_db = None

        self.map_size_limit = map_size_limit
        self.header_serializer = header_serializer
        self.data_serializer = data_serializer

        self.index_db_name = self.str_encode("index_db")
        self.meta_db_name = self.str_encode("meta_db")
        self.header_db_name = self.str_encode("header_db")
        self.input_db_name = self.str_encode("data_db")
        self.output_db_name = self.str_encode("out_db")
        # Number of named databases above; passed to lmdb as max_dbs
        self.num_db = 5
        self.num_sample = 0
        self.num_sample_key = self.str_encode("num_sample")

    def str_encode(self, text, encoding="utf-8", errors="strict"):
        """Encode string for db"""
        return text.encode(encoding=encoding, errors=errors)

    def str_decode(self, obj, encoding="utf-8", errors="strict"):
        """Decode bytes read from the db to a string"""
        return obj.decode(encoding=encoding, errors=errors)

    def int_encode(self, vale: int, byteorder: Literal["little", "big"] = 'big'):
        """
        Encode an integer on 4 bytes
        :param vale: non-negative integer fitting in 4 bytes
        :param byteorder: byte order of the encoding
        :return: the 4 bytes
        :raises OverflowError: if vale is negative or does not fit in 4 bytes
        """
        return vale.to_bytes(4, byteorder)

    def int_decode(self, val, byteorder: Literal["little", "big"] = 'big'):
        """
        Decode bytes produced by int_encode
        :param val: bytes to decode
        :param byteorder: byte order used when encoding
        :return: the integer
        """
        return int.from_bytes(val, byteorder)

    def n_sample(self):
        """Number of samples as currently stored in the db (read from meta_db)"""
        return self.fetch_num_sample()

    def sample_number_to_key(self, num):
        """transform the sample id number to a db key (i.e. byte), if bytes received assume it is already encoded
        and do nothing
        """
        # if already bytes. we assume the key is already encoded
        if isinstance(num, bytes):
            return num
        return self.int_encode(num)

    def header_to_index_key(self, geo_header: GeoDataHeader):
        """
        Build the index_db key of a header
        :param geo_header: header of the sample
        :return: the encoded "<idx>-<file_name>" string
        """
        return self.str_encode(f"{geo_header.idx}-{geo_header.file_name}")

    def _read_raw(self, db, key):
        """
        Read the raw value stored under key in the named database db
        :param db: named database handle
        :param key: encoded key
        :return: the stored bytes, or None when the key is absent
        """
        with self._lmdb_env.begin(write=False, db=db) as txn:
            return txn.get(key)

    def _put(self, db, key, value, txn=None):
        """
        Store value under key in the named database db
        :param db: named database handle
        :param key: encoded key
        :param value: encoded value
        :param txn: write transaction to reuse; when None a dedicated write
            transaction is opened and committed for this single put
        :return: nothing
        """
        if txn is not None:
            txn.put(key, value, db=db)
            return

        with self._lmdb_env.begin(write=True, db=db) as own_txn:
            own_txn.put(key, value)

    def get(self, num: int):
        """
        Get the sample number "num"
        :param num: sample number (or already encoded key)
        :return: the full sample: header, input and output
        """
        key = self.sample_number_to_key(num)
        data, target = self.get_data(key)
        header = self.get_header(key)

        return BasicGeoData(header, data, target)

    def get_header(self, num: int) -> GeoDataHeader:
        """
        get the header for num
        :param num: sample number (or already encoded key)
        :return: the deserialized header
        """
        key = self.sample_number_to_key(num)
        # NOTE: an unknown key hands None to the serializer
        return self.header_serializer.deserialize(self._read_raw(self.header_db, key))

    def get_output(self, num: int):
        """
        get the output for num
        :param num: sample number (or already encoded key)
        :return: the deserialized output
        """
        key = self.sample_number_to_key(num)
        return self.data_serializer.deserialize(self._read_raw(self.category_db, key))

    def get_data(self, num):
        """
        Only get the data (input and output) for the sample num
        :param num: sample number (or already encoded key)
        :return: tuple (input, output)
        """
        key = self.sample_number_to_key(num)

        # Get the input
        data = self.data_serializer.deserialize(self._read_raw(self.data_db, key))
        # The output
        target = self.get_output(num)

        return data, target

    def get_header_key(self, header: GeoDataHeader):
        """
        Get the sample number registered in the index for header
        :param header: header of the sample to look up
        :return: the sample number as an int
        """
        key = self.header_to_index_key(header)
        # The index stores the encoded sample key; give it back as an int
        return self.int_decode(self._read_raw(self.index_db, key))

    def save(self, geodata: BasicGeoData):
        """
        Save the sample to the database.

        Index, header, input, output and the sample counter are written in a
        single write transaction, so a failure leaves the database unchanged
        instead of holding a partially written sample.
        :param geodata: sample to append; it is stored under the next free
            sample number
        :return: nothing
        """
        key = self.sample_number_to_key(self.num_sample)
        new_count = self.num_sample + 1

        with self._lmdb_env.begin(write=True) as txn:
            self._save_index(geodata.header, key, txn=txn)

            self._save_header(geodata.header, key, txn=txn)
            self._save_data(geodata.data, key, txn=txn)
            self._save_target(geodata.target, key, txn=txn)

            self._write_num_sample(new_count, txn=txn)

        # Only advance the cached counter once the transaction is committed
        self.num_sample = new_count

    def fetch_num_sample(self):
        """
        get the number of sample in the database from the saved db entry
        :return: the stored count, 0 when nothing was saved yet
        """
        # Pure read: a read-only transaction avoids taking the writer lock
        with self._lmdb_env.begin(write=False, db=self.meta_db) as txn:
            msg = txn.get(self.num_sample_key, default=self.int_encode(0))

        return self.int_decode(msg)

    def _write_num_sample(self, num_sample: int, txn=None):
        """
        Write the number of sample into the db
        :param num_sample: count to store
        :param txn: optional write transaction to reuse
        :return: nothing
        """
        self._put(self.meta_db, self.num_sample_key, self.int_encode(num_sample), txn)

    def _save_index(self, geo_header: GeoDataHeader, key: bytes, txn=None):
        """
        Save the file index i.e. the number linked to a header
        :param geo_header: header of the sample
        :param key: encoded sample key
        :param txn: optional write transaction to reuse
        :return: nothing
        """
        self._put(self.index_db, self.header_to_index_key(geo_header), key, txn)

    def _save_header(self, geo_header: GeoDataHeader, key: bytes, txn=None):
        """
        Save only the header to the db (function called by save)
        :param geo_header: header of the sample
        :param key: encoded sample key
        :param txn: optional write transaction to reuse
        :return: nothing
        """
        self._put(self.header_db, key, self.header_serializer.serialize(geo_header), txn)

    def _save_data(self, geodata, key: bytes, txn=None):
        """
        Save only the input part to the db (function called by save)
        :param geodata: input payload of the sample (save passes ``geodata.data``)
        :param key: encoded sample key
        :param txn: optional write transaction to reuse
        :return: nothing
        """
        self._put(self.data_db, key, self.data_serializer.serialize(geodata), txn)

    def _save_target(self, geodata, key: bytes, txn=None):
        """
        Save only the output part to the db (function called by save)
        :param geodata: output payload of the sample (save passes ``geodata.target``)
        :param key: encoded sample key
        :param txn: optional write transaction to reuse
        :return: nothing
        """
        self._put(self.category_db, key, self.data_serializer.serialize(geodata), txn)

    def open(self):
        """Open the db for transaction"""
        # Open LMDB environment
        self._lmdb_env = lmdb.open(self.db_path, map_size=self.map_size_limit, max_dbs=self.num_db)

        self.index_db = self._lmdb_env.open_db(self.index_db_name)
        self.meta_db = self._lmdb_env.open_db(self.meta_db_name)
        self.header_db = self._lmdb_env.open_db(self.header_db_name)
        self.data_db = self._lmdb_env.open_db(self.input_db_name)
        self.category_db = self._lmdb_env.open_db(self.output_db_name)

        # Refresh the cached counter from what is stored
        self.num_sample = self.fetch_num_sample()

    def close(self):
        """Close the db"""
        self._lmdb_env.close()

    def __len__(self):
        """Return the number of sample in the database

        :return: the cached counter (refreshed by open(), advanced by save())
        """
        return self.num_sample
|
|
290
|
+
|
|
291
|
+
# Stores geospatial training data in LMDB format with the following structure:
|
|
292
|
+
# - index_db: Maps "<feature idx>-<file name>" index keys to sample keys
|
|
293
|
+
# - meta_db: Stores metadata (total samples, samples per category, etc.)
|
|
294
|
+
# - header_db: MessagePack-serialized headers (geometry, source file, source ID)
|
|
295
|
+
# - data_db: Raster data matrices
|
|
296
|
+
# - out_db: Target outputs/categories
|
|
297
|
+
|
|
298
|
+
class LMDBKeepENVDAO(LMDBasicGeoDataDAO):
    """LMDB DAO that keeps the environment open for performance.

    The environment and its named databases are opened once, in the
    constructor. open() and close() are deliberately no-ops, so entering and
    leaving a reader/writer ``with`` block (which calls them) neither reopens
    nor closes the environment.

    Warning:
        Because close() does nothing, the environment stays open until
        force_close() is called explicitly.
    """

    def __init__(self, db_path, map_size_limit=int(1e+11), header_serializer=None, data_serializer=None):
        """
        :param db_path: path of the LMDB directory
        :param map_size_limit: maximum size of the LMDB map in bytes
        :param header_serializer: header serializer, parent default when None
        :param data_serializer: data serializer, parent default when None
        """
        super().__init__(db_path, map_size_limit, header_serializer, data_serializer)
        # Reuse the parent's opening sequence (environment, named databases,
        # cached sample count) instead of duplicating it; the open() override
        # below is a no-op, hence the explicit super() call.
        super().open()

    def open(self):
        """Do nothing: the environment was opened by the constructor"""
        pass

    def close(self):
        """Do nothing: the environment is kept open, see force_close()"""
        pass

    def force_close(self):
        """Really close the LMDB environment; the DAO must not be used afterwards"""
        super().close()
|
|
335
|
+
|
|
336
|
+
class LMDBReader(GeoDataReader):
    """High-level reader for LMDB geospatial databases.

    Builds the LMDB DAO matching the requested mode and hands it to
    GeoDataReader.

    Args:
        db_path: Path to LMDB database.
        header_serializer: Optional custom header serializer.
        data_serializer: Optional custom data serializer.
        keep_env_open: Whether to keep LMDB environment open
            (LMDBKeepENVDAO) instead of opening/closing it with the reader
            (LMDBasicGeoDataDAO).
    """

    def __init__(self, db_path, header_serializer=None, data_serializer=None, keep_env_open=False):
        # Both DAO classes take the same constructor arguments
        dao_class = LMDBKeepENVDAO if keep_env_open else LMDBasicGeoDataDAO
        dao = dao_class(db_path,
                        header_serializer=header_serializer,
                        data_serializer=data_serializer)
        super().__init__(dao)
|
|
356
|
+
|
|
357
|
+
class LMDBWriter(GeoDataWriter):
    """High-level writer for LMDB geospatial databases.

    Builds the LMDB DAO matching the requested mode and hands it to
    GeoDataWriter.

    Args:
        db_path: Path to LMDB database.
        header_serializer: Optional custom header serializer.
        data_serializer: Optional custom data serializer.
        keep_env_open: Whether to keep LMDB environment open
            (LMDBKeepENVDAO) instead of opening/closing it with the writer
            (LMDBasicGeoDataDAO).
    """

    def __init__(self, db_path, header_serializer=None, data_serializer=None, keep_env_open=False):
        # Both DAO classes take the same constructor arguments
        dao_class = LMDBKeepENVDAO if keep_env_open else LMDBasicGeoDataDAO
        dao = dao_class(db_path,
                        header_serializer=header_serializer,
                        data_serializer=data_serializer)
        super().__init__(dao)
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
|