eoml 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eoml/__init__.py +74 -0
- eoml/automation/__init__.py +7 -0
- eoml/automation/configuration.py +105 -0
- eoml/automation/dag.py +233 -0
- eoml/automation/experience.py +618 -0
- eoml/automation/tasks.py +825 -0
- eoml/bin/__init__.py +6 -0
- eoml/bin/clean_checkpoint.py +146 -0
- eoml/bin/land_cover_mapping_toml.py +435 -0
- eoml/bin/mosaic_images.py +137 -0
- eoml/data/__init__.py +7 -0
- eoml/data/basic_geo_data.py +214 -0
- eoml/data/dataset_utils.py +98 -0
- eoml/data/persistence/__init__.py +7 -0
- eoml/data/persistence/generic.py +253 -0
- eoml/data/persistence/lmdb.py +379 -0
- eoml/data/persistence/serializer.py +82 -0
- eoml/raster/__init__.py +7 -0
- eoml/raster/band.py +141 -0
- eoml/raster/dataset/__init__.py +6 -0
- eoml/raster/dataset/extractor.py +604 -0
- eoml/raster/raster_reader.py +602 -0
- eoml/raster/raster_utils.py +116 -0
- eoml/torch/__init__.py +7 -0
- eoml/torch/cnn/__init__.py +7 -0
- eoml/torch/cnn/augmentation.py +150 -0
- eoml/torch/cnn/dataset_evaluator.py +68 -0
- eoml/torch/cnn/db_dataset.py +605 -0
- eoml/torch/cnn/map_dataset.py +579 -0
- eoml/torch/cnn/map_dataset_const_mem.py +135 -0
- eoml/torch/cnn/outputs_transformer.py +130 -0
- eoml/torch/cnn/torch_utils.py +404 -0
- eoml/torch/cnn/training_dataset.py +241 -0
- eoml/torch/cnn/windows_dataset.py +120 -0
- eoml/torch/dataset/__init__.py +6 -0
- eoml/torch/dataset/shade_dataset_tester.py +46 -0
- eoml/torch/dataset/shade_tree_dataset_creators.py +537 -0
- eoml/torch/model_low_use.py +507 -0
- eoml/torch/models.py +282 -0
- eoml/torch/resnet.py +437 -0
- eoml/torch/sample_statistic.py +260 -0
- eoml/torch/trainer.py +782 -0
- eoml/torch/trainer_v2.py +253 -0
- eoml-0.9.0.dist-info/METADATA +93 -0
- eoml-0.9.0.dist-info/RECORD +47 -0
- eoml-0.9.0.dist-info/WHEEL +4 -0
- eoml-0.9.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,605 @@
|
|
|
1
|
+
"""PyTorch datasets for reading training data from LMDB databases.
|
|
2
|
+
|
|
3
|
+
This module provides dataset classes that read image patches and labels from LMDB
|
|
4
|
+
databases, with support for data augmentation, label mapping, and multi-database access.
|
|
5
|
+
Includes utilities for mapping between database labels and neural network outputs.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import csv
|
|
9
|
+
import logging
|
|
10
|
+
from collections import Counter
|
|
11
|
+
from typing import List, Dict
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import torch
|
|
15
|
+
from eoml.data.persistence.lmdb import LMDBReader
|
|
16
|
+
from eoml.torch.cnn.outputs_transformer import ArgMaxToCategory, ArgMax
|
|
17
|
+
from torch.utils.data import Dataset
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
def sample_list(keys_out, mapper, filter_na=True):
    """Transform id:value pairs to (id, nn_output) pairs using ``mapper``.

    Args:
        keys_out (dict): Dictionary mapping sample IDs to database values.
        mapper: Mapper object; called on each value and exposing a ``no_target`` attribute.
        filter_na (bool, optional): Drop samples whose mapped output equals
            ``mapper.no_target``. Defaults to True.

    Returns:
        list: List of (id, nn_output) tuples, in ``keys_out`` iteration order.
    """
    # Map every value exactly once instead of once for the filter and once for the result.
    mapped = [(sample_id, mapper(value)) for sample_id, value in keys_out.items()]
    if not filter_na:
        return mapped

    no_target = mapper.no_target

    def is_no_target(output):
        # One-hot (vector) outputs must be compared as a whole: a plain `!=` on arrays gives an
        # element-wise result whose truth value is ambiguous.
        if isinstance(output, np.ndarray) or isinstance(no_target, np.ndarray):
            return np.array_equal(output, no_target)
        return output == no_target

    return [(sample_id, output) for sample_id, output in mapped if not is_no_target(output)]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def sample_list_id(keys_out, mapper, filter_na=True):
    """Return the list of sample IDs, optionally dropping samples with an invalid output.

    Args:
        keys_out (dict): Dictionary mapping sample IDs to database values.
        mapper: Mapper object; called on each value and exposing a ``no_target`` attribute.
        filter_na (bool, optional): Drop samples whose mapped output equals
            ``mapper.no_target``. Defaults to True.

    Returns:
        list: List of sample IDs, in ``keys_out`` iteration order.
    """
    if not filter_na:
        return list(keys_out)

    no_target = mapper.no_target

    def is_no_target(output):
        # One-hot (vector) outputs must be compared as a whole: a plain `!=` on arrays gives an
        # element-wise result whose truth value is ambiguous.
        if isinstance(output, np.ndarray) or isinstance(no_target, np.ndarray):
            return np.array_equal(output, no_target)
        return output == no_target

    return [sample_id for sample_id, value in keys_out.items() if not is_no_target(mapper(value))]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class NNOutput:
    """Represents one neural network output category and the value written to the map for it.

    Maps database label values to a neural network output index and a final map value.
    If no map output is specified, the map value is the network output index itself
    (i.e. the argmax value of the output).

    Attributes:
        name (str): Name of the output category.
        map_out (int): Integer value to write to output map.
        nn_out (int): Neural network output index for this category.
        labels_value (list): Database label values that map to this category.
        labels_name (list): Human-readable names for labels.
    """

    def __init__(self, name: str, labels_value: List, labels_name: List, nn_out: int, map_out: int = None):
        """Initialize NNOutput category.

        Args:
            name (str): Name of the output category.
            labels_value (List): Database label values that map to this category (any iterable).
            labels_name (List): Human-readable names for the labels (any iterable).
            nn_out (int): Neural network output index for this category.
            map_out (int, optional): Value to write to output map. Defaults to nn_out if None.
        """
        self.name = name
        # Fall back to the network output index, as documented above.
        self.map_out = nn_out if map_out is None else map_out
        self.nn_out = nn_out

        # Private list copies, so later changes to the caller's containers do not leak in.
        self.labels_value = list(labels_value)
        self.labels_name = list(labels_name)

    def __repr__(self):
        return (f'NNOutput(name: {self.name!r}, '
                f' nn_out: {self.nn_out!r}, '
                f' map_out:{self.map_out!r}, '
                f' labels_value: {self.labels_value!r}, '
                f'labels_name: {self.labels_name!r})')
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class Mapper:
    """Maps database values to neural network outputs.

    Builds a dictionary mapping database label values to neural network output indices.
    Supports grouping multiple database labels into single NN categories.

    Attributes:
        output_list (List[NNOutput]): Output categories, in network output order.
        dictionary (dict): Mapping from database values to NN outputs.
        no_target: Value returned for invalid/missing labels.
        vectorize (bool): Whether to use one-hot vector outputs. This class only stores it
            (and shows it in repr); ``_vectorize`` performs the actual switch to vectors.
        label_dictionary (Dict[str,int], optional): Maps label names to values.
    """

    def __init__(self, no_target=-1, vectorize=False, label_dictionary: Dict[str, int] = None):
        """Initialize Mapper.

        Args:
            no_target (int, optional): Value for invalid targets. Defaults to -1.
            vectorize (bool, optional): Use one-hot vectors instead of scalar outputs.
                Defaults to False.
            label_dictionary (Dict[str,int], optional): Maps label names to integer values.
                Defaults to None.
        """
        self.output_list: List[NNOutput] = []
        self.dictionary = {}
        self.no_target = no_target
        self.vectorize = vectorize

        self.label_dictionary = label_dictionary

    def __repr__(self):
        return (f'Mapper(output_list: {self.output_list!r}, '
                f'dictionary: {self.dictionary!r}, '
                f'no_target: {self.no_target!r}, '
                f'vectorize: {self.vectorize!r})')

    def load_dic_from_file(self, csv_path, value_type=str):
        """Load ``label_dictionary`` from a CSV file with one ``name,value`` pair per row.

        Args:
            csv_path: Path of the CSV file. Columns after the second are ignored and
                blank lines are skipped.
            value_type (callable, optional): Converter applied to the value column, e.g. ``int``
                so that values match integer database labels. Defaults to ``str``, which keeps
                values exactly as read.

        Raises:
            ValueError: If a non-empty row has fewer than two columns.
        """
        dictionary = {}
        # newline='' is how the csv module expects files it reads to be opened.
        with open(csv_path, mode='r', newline='') as infile:
            reader = csv.reader(infile)
            for row in reader:
                if not row:
                    # csv.reader yields an empty list for a blank line.
                    continue
                if len(row) < 2:
                    raise ValueError(f"{csv_path}:{reader.line_num}: expected 'name,value', got {row!r}")
                dictionary[row[0]] = value_type(row[1])
        self.label_dictionary = dictionary

    def map_value_names(self):
        """Return ``(map_out, name)`` pairs for every category."""
        return [(output.map_out, output.name) for output in self.output_list]

    def map_names(self):
        """Return the category names."""
        return [output.name for output in self.output_list]

    def map_values(self):
        """Return the map value of every category."""
        return [output.map_out for output in self.output_list]

    def nn_name(self):
        """Return the category names (same as ``map_names``)."""
        return [output.name for output in self.output_list]

    def add_category(self, name, labels, map_value=None):
        """Add an output category to the neural network.

        Args:
            name (str): Name of the category.
            labels (list): Labels grouped into this category. These are label names looked up in
                ``label_dictionary`` when it is set, otherwise the database values themselves.
            map_value (int, optional): Value written to the output map. Defaults to the
                category's network output index.

        Raises:
            KeyError: If ``label_dictionary`` is set and a label name is missing from it.
        """
        labels_name = labels

        if self.label_dictionary is not None:
            labels_values = [self.label_dictionary[label] for label in labels]
        else:
            labels_values = labels

        # The new category's output index is its position in output_list.
        nn_out = len(self)
        if map_value is None:
            map_value = nn_out

        category = NNOutput(name, labels_values, labels_name, nn_out, map_value)

        self.output_list.append(category)
        self._update_dictionary(category)

    def __len__(self):
        return len(self.output_list)

    def _vectorize(self, no_target=-1):
        """Switch from integer outputs to one-hot vector outputs.

        Call this after all categories have been added. Each category's ``nn_out`` becomes a
        one-hot vector, and ``dictionary`` is rebuilt so that ``__call__`` returns those vectors.

        Args:
            no_target (optional): Value used for missing labels. An object with a length is kept
                as-is; anything else is replaced by a zero vector.
        """
        size = len(self)
        for i, output in enumerate(self.output_list):
            one_hot = np.zeros(size)
            one_hot[i] = 1
            output.nn_out = one_hot

        # Rebuild the lookup; otherwise __call__ would keep returning the stale integer indices
        # stored by _update_dictionary. Walking categories in insertion order keeps the same
        # "last category wins" resolution for duplicated labels.
        self.dictionary = {}
        for output in self.output_list:
            for value in output.labels_value:
                self.dictionary[value] = output.nn_out

        # If no_target has a length, assume it is already a valid vector and keep it.
        if hasattr(no_target, '__len__'):
            self.no_target = no_target
        else:
            # set it to a vector of 0
            self.no_target = np.zeros(size)

    def __call__(self, value):
        """Return the NN output for a database value, or ``no_target`` if it is unmapped."""
        return self.dictionary.get(value, self.no_target)

    def map_output_transformer(self):
        """Return a transformer that turns NN outputs into map values (assuming argmax is used)."""
        nn_out_to_map = [out.map_out for out in self.output_list]
        # When every map value equals its output index, a plain argmax is enough.
        is_identity = all(i == map_out for i, map_out in enumerate(nn_out_to_map))

        return ArgMax() if is_identity else ArgMaxToCategory(nn_out_to_map)

    def _update_dictionary(self, output: "NNOutput"):
        """Map each label value of ``output`` to its NN output; a later category overrides an earlier one."""
        for value in output.labels_value:
            if value in self.dictionary:
                logger.warning(f"{value} appears twice in the label-value mapping. One value has been ignored.")
            self.dictionary[value] = output.nn_out
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def db_dataset_multi_proc_init(worker_id):
    """DataLoader ``worker_init_fn`` that creates the dataset's database reader inside the worker.

    The reader is created with ``keep_env_open=True`` so that the LMDB environment stays open
    for the whole life of the worker process.

    Used to fix:
    Issue with newer version and lmdb, keep the db env open for the worker
    https://github.com/jnwatson/py-lmdb/issues/340

    Args:
        worker_id (int): Worker id supplied by the DataLoader (unused).

    Raises:
        RuntimeError: If called outside a DataLoader worker process.
    """
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        # get_worker_info() returns None in the main process.
        raise RuntimeError("db_dataset_multi_proc_init must run inside a DataLoader worker "
                           "(pass it as the DataLoader worker_init_fn).")

    # Expected to be a DBDataset or MultiDBDataset, i.e. any dataset exposing init_db_environment.
    dataset = worker_info.dataset
    dataset.init_db_environment(True)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
class DBDataset(Dataset):
    """PyTorch Dataset for reading training samples from LMDB database.

    Reads image patches and labels from LMDB database with optional label mapping
    and data augmentation. Supports both single-process and multi-worker data loading.

    Attributes:
        multithread (bool): If False the reader is created in ``__init__``. If True it stays
            None until ``init_db_environment`` is called in each worker
            (see ``db_dataset_multi_proc_init``).
        db_path (str): Path to LMDB database.
        samples_list (np.ndarray): Array of sample keys to fetch.
        target_mapper: Mapper for transforming database labels to NN outputs.
        f_transform: Data augmentation function.
        transform_param (np.ndarray, optional): Parameters for augmentation per sample.
        reader (LMDBReader): Database reader instance, or None until initialised.
    """

    def __init__(self, db_path, samples_list, target_mapper=None, f_transform=None, transform_param=None,
                 multithread=True):
        """Initialize DBDataset.

        Args:
            db_path (str): Path to LMDB database file.
            samples_list (list): List of sample keys to include in dataset.
            target_mapper (Mapper, optional): Maps database labels to NN outputs.
                Defaults to None.
            f_transform (callable, optional): Data augmentation function. Defaults to None.
            transform_param (list, optional): Per-sample augmentation parameters, indexed by
                position in ``samples_list``. Must be numeric types to avoid memory leaks.
                Defaults to None.
            multithread (bool, optional): Enable multi-worker loading. Defaults to True.

        Note:
            - Use numpy arrays (not Python objects) for transform_param to avoid memory leaks
              (see https://github.com/pytorch/pytorch/issues/13246)
            - For multithread=True, use db_dataset_multi_proc_init as worker_init_fn
              (see https://github.com/jnwatson/py-lmdb/issues/340)
        """
        super().__init__()

        self.multithread = multithread

        self.db_path = db_path

        # normal list cause memory leak (see pytorch issue 13246), keep the keys in a numpy array
        self.samples_list = np.array(samples_list)
        self.target_mapper = target_mapper

        self.f_transform = f_transform

        self.transform_param = np.array(transform_param) if transform_param is not None else None

        self.reader = None
        if not multithread:
            self.init_db_environment(False)

    def init_db_environment(self, keep_env_open):
        """Create the database reader for the current process.

        Called from ``__init__`` when ``multithread`` is False, or from a DataLoader
        ``worker_init_fn`` (``db_dataset_multi_proc_init``) in each worker.
        """
        self.reader = LMDBReader(self.db_path, keep_env_open=keep_env_open)

    def __len__(self):
        return len(self.samples_list)

    def __getitem__(self, idx):
        """Fetch one sample, or a pre-collated batch.

        Args:
            idx: An integer (python or numpy), a slice, or an iterable of integer positions.

        Returns:
            ``((input,), target)`` for an integer. ``((inputs,), targets)`` for a slice or an
            iterable, or ``[]`` when that selection is empty.

        Raises:
            RuntimeError: If the reader was never initialised.
            TypeError: If ``idx`` has an unsupported type.
        """
        if self.reader is None:
            raise RuntimeError("DBDataset reader is not initialised: with multithread=True, pass "
                               "db_dataset_multi_proc_init as the DataLoader worker_init_fn "
                               "(or build the dataset with multithread=False).")

        if isinstance(idx, slice):
            # slice.indices resolves None/negative bounds and steps against the dataset length.
            idx = range(*idx.indices(len(self)))
        elif not isinstance(idx, (int, np.integer)) and not hasattr(idx, '__iter__'):
            raise TypeError(f"DBDataset indices must be integers, slices or iterables, not {type(idx).__name__}")

        with self.reader as db:
            if isinstance(idx, (int, np.integer)):
                return self._get_one_item(idx, db)
            return self._get_items(idx, db)

    def _map_target(self, target):
        """Apply ``target_mapper`` to a raw database target, or return it unchanged if none is set."""
        return target if self.target_mapper is None else self.target_mapper(target)

    def batch_statistic(self):
        """Compute a class-balancing weight for every mapped target value.

        Returns:
            dict: ``{target: count_of_most_common_target / count_of_target}``, or an empty dict
            when the dataset is empty.
        """
        reader = LMDBReader(self.db_path)

        with reader:
            outputs = [reader.get_output(s) for s in self.samples_list]

        counter = Counter(self._map_target(val) for val in outputs)
        logger.info(f"counter: {counter}")
        if not counter:
            return {}

        # count of the highest appearing value
        maximum = counter.most_common(1)[0][1]

        logger.info(f"most common category {maximum}")

        return {key: maximum / val for key, val in counter.items()}

    def weight_list(self):
        """Return the ``batch_statistic`` weight of each sample, in ``samples_list`` order."""
        weight_dic = self.batch_statistic()

        reader = LMDBReader(self.db_path)
        with reader:
            return [weight_dic[self._map_target(reader.get_output(s))] for s in self.samples_list]

    def _get_items_deprecated(self, iterable, reader):
        """Deprecated batch loader kept for backward compatibility; ``_get_items`` is used instead."""
        inputs = []
        targets = []

        for key in iterable:
            # _get_one_item wraps the input in a 1-element tuple.
            (one_input,), target = self._get_one_item(key, reader)
            inputs.append(one_input)
            targets.append(target)

        return torch.stack(inputs), torch.LongTensor(targets)

    def _get_items(self, iterable, reader):
        """Read several samples straight into pre-allocated batch tensors.

        Integer targets give a ``torch.long`` target tensor; any other target gives a
        ``float32`` tensor of shape ``(batch,) + target shape``.

        Returns:
            ``((inputs,), targets)``, or ``[]`` if ``iterable`` is empty.
        """
        # Materialise once, so generators work and the length is known.
        indices = list(iterable)
        if not indices:
            return []

        (first_input,), first_target = self._get_one_item(indices[0], reader)
        batch = len(indices)

        if isinstance(first_target, int):
            targets = torch.empty(batch, dtype=torch.long)
        else:
            targets = torch.empty((batch,) + tuple(np.shape(first_target)), dtype=torch.float32)

        inputs = torch.empty((batch,) + tuple(first_input.shape), dtype=torch.float32)

        inputs[0] = first_input
        targets[0] = first_target

        for i, key in enumerate(indices[1:], 1):
            # the nn takes one parameter, so unpack the 1-element tuple
            (inputs[i],), targets[i] = self._get_one_item(key, reader)

        # the nn takes one parameter, so return a 1-element tuple
        return (inputs,), targets

    def _get_one_item(self, idx, reader):
        """Read, augment and map the sample at position ``idx``; return ``((input,), target)``."""
        key = self.samples_list[idx]

        inputs, target = reader.get_data(int(key))

        # NOTE(review): copied before wrapping, presumably because the array returned by the
        # reader may be read-only or backed by a reused buffer; confirm against LMDBReader.
        inputs = torch.from_numpy(inputs.copy())

        if self.f_transform is not None:
            if self.transform_param is not None:
                inputs = self.f_transform(inputs, *self.transform_param[idx])
            else:
                inputs = self.f_transform(inputs)

        target = self._map_target(target)

        # the nn takes one parameter, so return a 1-element tuple
        return (inputs,), target
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
class DBDatasetMeta(DBDataset):
    """DBDataset variant that also returns the header stored with each sample.

    Items are ``(inputs, target, header)`` for an integer index and
    ``(inputs, targets, headers)`` for a slice or an iterable of indices.
    """

    def __init__(self, db_path, samples_list, target_mapper=None, f_transform=None, transform_param=None,
                 multithread=True):
        """
        :param db_path: path of the db to open
        :param samples_list: a list of key to fetch from the db
        :param target_mapper: map the db output to the nn output.
        :param f_transform: optional augmentation function.
        :param transform_param: optional per-sample augmentation parameters.
        :param multithread: forwarded to DBDataset. Defaults to True (the previously hard-wired
            value), in which case the reader must be created per worker with
            db_dataset_multi_proc_init.

        https://github.com/pytorch/pytorch/issues/13246 use numpy of NOT OBJ to solve memory leak
        transformation parameters should all be numbers!
        """
        super().__init__(db_path, samples_list, target_mapper, f_transform, transform_param, multithread)

    def __len__(self):
        return len(self.samples_list)

    def __getitem__(self, idx):
        """Fetch sample(s) together with their header(s).

        Raises:
            RuntimeError: If the reader was never initialised.
            TypeError: If ``idx`` is not an integer, slice or iterable.
        """
        if self.reader is None:
            raise RuntimeError("DBDatasetMeta reader is not initialised: with multithread=True, pass "
                               "db_dataset_multi_proc_init as the DataLoader worker_init_fn.")

        if isinstance(idx, slice):
            # slice.indices resolves None/negative bounds and steps against the dataset length.
            indices = range(*idx.indices(len(self)))
        elif isinstance(idx, (int, np.integer)):
            indices = None
        elif hasattr(idx, '__iter__'):
            # Materialise once: the indices are walked twice (data, then headers).
            indices = list(idx)
        else:
            raise TypeError(f"DBDatasetMeta indices must be integers, slices or iterables, not {type(idx).__name__}")

        with self.reader as db:
            if indices is None:
                _input, output = self._get_one_item(idx, db)
                header = self._get_header(idx, db)
                return _input, output, header

            inputs, outputs = self._get_items(indices, db)
            headers = self._get_headers(indices, db)
            return inputs, outputs, headers

    def _get_headers(self, iterable, reader):
        """Return the header of every sample position in ``iterable``."""
        return [self._get_header(h, reader) for h in iterable]

    def _get_header(self, idx, reader):
        """Return the header stored in the database for the sample at position ``idx``."""
        key = int(self.samples_list[idx])
        return reader.get_header(key)
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
class DBInfo:
    """Plain record describing one database consumed by ``MultiDBDataset``.

    Attributes:
        db_path: Path of the LMDB database.
        sample_list: Keys of the samples to read from this database.
        target_mapper: Callable applied to each raw target, or None to keep targets unchanged.
        f_transform (callable, optional): Augmentation applied to each input.
        transform_param (optional): Augmentation parameters, looked up per sample by
            ``MultiDBDataset``.
    """

    def __init__(self, db_path, sample_list, target_mapper, f_transform=None, transform_param=None):
        # Every argument is stored as given, without copying or conversion.
        self.db_path, self.sample_list, self.target_mapper = db_path, sample_list, target_mapper
        self.f_transform, self.transform_param = f_transform, transform_param
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
class MultiDBDataset(Dataset):
    """Dataset reading samples from several LMDB databases as one concatenated dataset.

    Position ``i`` of the dataset refers to the ``i``-th (database, sample key) pair, taking
    the databases in ``db_info`` order and each database's ``sample_list`` in order.
    Unlike ``DBDataset``, inputs are not wrapped in a 1-element tuple.

    TODO check if helper functions of torch (e.g. ConcatDataset) can replace this class.
    """

    def __init__(self, db_info: List["DBInfo"], multithread=True):
        """
        :param db_info: one DBInfo (path, sample keys, target mapper, transform) per database.
        :param multithread: if False the readers are created immediately; if True they must be
            created in each DataLoader worker with db_dataset_multi_proc_init.

        https://github.com/pytorch/pytorch/issues/13246 use numpy of NOT OBJ to solve memory leak
        (not implemented here: sample lists and transform parameters are used as given).
        transformation parameters should all be numbers!
        """
        super().__init__()

        self.db_info: List["DBInfo"] = db_info

        # Flat index: dataset position -> (database index, sample key in that database).
        self.samples_index = [(db_index, sample_key)
                              for db_index, db in enumerate(self.db_info)
                              for sample_key in db.sample_list]
        self.size = len(self.samples_index)

        self.readers = None

        if not multithread:
            self.init_db_environment(False)

    def init_db_environment(self, keep_env_open):
        """Create one database reader per database for the current process.

        Called from ``__init__`` when ``multithread`` is False, or from a DataLoader
        ``worker_init_fn`` (``db_dataset_multi_proc_init``) in each worker.
        """
        self.readers = [LMDBReader(db.db_path, keep_env_open=keep_env_open) for db in self.db_info]

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Fetch one sample ``(input, target)`` or a batch ``(inputs, targets)``.

        Args:
            idx: An integer (python or numpy), a slice, or an iterable of integer positions.

        Raises:
            RuntimeError: If the readers were never initialised.
            TypeError: If ``idx`` has an unsupported type.
        """
        if self.readers is None:
            raise RuntimeError("MultiDBDataset readers are not initialised: with multithread=True, pass "
                               "db_dataset_multi_proc_init as the DataLoader worker_init_fn "
                               "(or build the dataset with multithread=False).")

        if isinstance(idx, slice):
            # slice.indices resolves None/negative bounds and steps against the dataset length.
            idx = range(*idx.indices(len(self)))
        elif not isinstance(idx, (int, np.integer)) and not hasattr(idx, '__iter__'):
            raise TypeError(f"MultiDBDataset indices must be integers, slices or iterables, not {type(idx).__name__}")

        opened = []
        try:
            for reader in self.readers:
                reader.open()
                opened.append(reader)

            if isinstance(idx, (int, np.integer)):
                return self._get_one_item(idx, self.readers)
            return self._get_items(idx, self.readers)
        finally:
            # Close only the readers whose open() succeeded.
            for reader in opened:
                reader.close()

    def _get_items(self, iterable, readers):
        """Read several samples straight into pre-allocated batch tensors.

        Integer targets give a ``torch.long`` target tensor; any other target gives a
        ``float32`` tensor of shape ``(batch,) + target shape``.

        Returns:
            ``(inputs, targets)``, or ``[]`` if ``iterable`` is empty.
        """
        # Materialise once, so generators work and the length is known.
        indices = list(iterable)
        if not indices:
            return []

        first_input, first_target = self._get_one_item(indices[0], readers)
        batch = len(indices)

        if isinstance(first_target, int):
            targets = torch.empty(batch, dtype=torch.long)
        else:
            targets = torch.empty((batch,) + tuple(np.shape(first_target)), dtype=torch.float32)

        inputs = torch.empty((batch,) + tuple(first_input.shape), dtype=torch.float32)

        inputs[0] = first_input
        targets[0] = first_target

        for i, idx in enumerate(indices[1:], 1):
            inputs[i], targets[i] = self._get_one_item(idx, readers)

        return inputs, targets

    def _get_one_item(self, idx, readers):
        """Read, augment and map the sample at dataset position ``idx``; return ``(input, target)``."""
        db_index, id_sample = self.samples_index[idx]
        info = self.db_info[db_index]

        inputs, target = readers[db_index].get_data(id_sample)

        inputs = torch.from_numpy(inputs.copy())

        if info.f_transform is not None:
            if info.transform_param is not None:
                # NOTE(review): indexed by the sample key, whereas DBDataset indexes
                # transform_param by position; confirm this is intended.
                inputs = info.f_transform(inputs, *info.transform_param[id_sample])
            else:
                inputs = info.f_transform(inputs)

        if info.target_mapper is not None:
            target = info.target_mapper(target)

        return inputs, target
|