eoml 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. eoml/__init__.py +74 -0
  2. eoml/automation/__init__.py +7 -0
  3. eoml/automation/configuration.py +105 -0
  4. eoml/automation/dag.py +233 -0
  5. eoml/automation/experience.py +618 -0
  6. eoml/automation/tasks.py +825 -0
  7. eoml/bin/__init__.py +6 -0
  8. eoml/bin/clean_checkpoint.py +146 -0
  9. eoml/bin/land_cover_mapping_toml.py +435 -0
  10. eoml/bin/mosaic_images.py +137 -0
  11. eoml/data/__init__.py +7 -0
  12. eoml/data/basic_geo_data.py +214 -0
  13. eoml/data/dataset_utils.py +98 -0
  14. eoml/data/persistence/__init__.py +7 -0
  15. eoml/data/persistence/generic.py +253 -0
  16. eoml/data/persistence/lmdb.py +379 -0
  17. eoml/data/persistence/serializer.py +82 -0
  18. eoml/raster/__init__.py +7 -0
  19. eoml/raster/band.py +141 -0
  20. eoml/raster/dataset/__init__.py +6 -0
  21. eoml/raster/dataset/extractor.py +604 -0
  22. eoml/raster/raster_reader.py +602 -0
  23. eoml/raster/raster_utils.py +116 -0
  24. eoml/torch/__init__.py +7 -0
  25. eoml/torch/cnn/__init__.py +7 -0
  26. eoml/torch/cnn/augmentation.py +150 -0
  27. eoml/torch/cnn/dataset_evaluator.py +68 -0
  28. eoml/torch/cnn/db_dataset.py +605 -0
  29. eoml/torch/cnn/map_dataset.py +579 -0
  30. eoml/torch/cnn/map_dataset_const_mem.py +135 -0
  31. eoml/torch/cnn/outputs_transformer.py +130 -0
  32. eoml/torch/cnn/torch_utils.py +404 -0
  33. eoml/torch/cnn/training_dataset.py +241 -0
  34. eoml/torch/cnn/windows_dataset.py +120 -0
  35. eoml/torch/dataset/__init__.py +6 -0
  36. eoml/torch/dataset/shade_dataset_tester.py +46 -0
  37. eoml/torch/dataset/shade_tree_dataset_creators.py +537 -0
  38. eoml/torch/model_low_use.py +507 -0
  39. eoml/torch/models.py +282 -0
  40. eoml/torch/resnet.py +437 -0
  41. eoml/torch/sample_statistic.py +260 -0
  42. eoml/torch/trainer.py +782 -0
  43. eoml/torch/trainer_v2.py +253 -0
  44. eoml-0.9.0.dist-info/METADATA +93 -0
  45. eoml-0.9.0.dist-info/RECORD +47 -0
  46. eoml-0.9.0.dist-info/WHEEL +4 -0
  47. eoml-0.9.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,605 @@
+ """PyTorch datasets for reading training data from LMDB databases.
+
+ This module provides dataset classes that read image patches and labels from LMDB
+ databases, with support for data augmentation, label mapping, and multi-database access.
+ Includes utilities for mapping between database labels and neural network outputs.
+ """
+
+ import csv
+ import logging
+ from collections import Counter
+ from typing import Dict, List
+
+ import numpy as np
+ import torch
+ from torch.utils.data import Dataset
+
+ from eoml.data.persistence.lmdb import LMDBReader
+ from eoml.torch.cnn.outputs_transformer import ArgMax, ArgMaxToCategory
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def sample_list(keys_out, mapper, filter_na=True):
+     """Transform id:value pairs to (id, nn_output) tuples using mapper.
+
+     Args:
+         keys_out (dict): Dictionary mapping sample IDs to database values.
+         mapper: Mapper object with __call__ method for value transformation.
+         filter_na (bool, optional): Filter out samples with invalid output. Defaults to True.
+
+     Returns:
+         list: List of (id, nn_output) tuples.
+     """
+     if filter_na:
+         sample = [(id, mapper(val)) for id, val in keys_out.items() if mapper(val) != mapper.no_target]
+     else:
+         sample = [(id, mapper(val)) for id, val in keys_out.items()]
+
+     return sample
+
+
+ def sample_list_id(keys_out, mapper, filter_na=True):
+     """Return list of sample IDs, optionally filtering invalid outputs.
+
+     Args:
+         keys_out (dict): Dictionary mapping sample IDs to database values.
+         mapper: Mapper object with __call__ method for value transformation.
+         filter_na (bool, optional): Filter out samples with invalid output. Defaults to True.
+
+     Returns:
+         list: List of sample IDs.
+     """
+     if filter_na:
+         sample = [id for id, val in keys_out.items() if mapper(val) != mapper.no_target]
+     else:
+         sample = [id for id, val in keys_out.items()]
+
+     return sample
+
+
+ class NNOutput:
+     """Represents a neural network output category with associated database labels.
+
+     Maps database label values to neural network output indices and final map values.
+     If no map output is specified, it defaults to the network's argmax value for this
+     category.
+
+     Attributes:
+         name (str): Name of the output category.
+         map_out (int): Integer value to write to the output map.
+         nn_out (int): Neural network output index for this category.
+         labels_value (list): Database label values that map to this category.
+         labels_name (list): Human-readable names for the labels.
+     """
+
+     def __init__(self, name: str, labels_value: List, labels_name: List, nn_out: int, map_out: int = None):
+         """Initialize NNOutput category.
+
+         Args:
+             name (str): Name of the output category.
+             labels_value (List): Database label values that map to this category.
+             labels_name (List): Human-readable names for the labels.
+             nn_out (int): Neural network output index for this category.
+             map_out (int, optional): Value to write to the output map. Defaults to
+                 nn_out if None.
+         """
+         self.name = name
+         self.map_out = map_out if map_out is not None else nn_out
+         self.nn_out = nn_out
+
+         self.labels_value = labels_value.copy()
+         self.labels_name = labels_name.copy()
+
+     def __repr__(self):
+         return f'NNOutput(name: {repr(self.name)}, ' \
+                f'nn_out: {repr(self.nn_out)}, ' \
+                f'map_out: {repr(self.map_out)}, ' \
+                f'labels_value: {repr(self.labels_value)}, ' \
+                f'labels_name: {repr(self.labels_name)})'
+
+
+ class Mapper:
+     """Maps database values to neural network outputs.
+
+     Builds a dictionary mapping database label values to neural network output indices.
+     Supports grouping multiple database labels into single NN categories.
+
+     Attributes:
+         output_list (List[NNOutput]): List of output categories.
+         dictionary (dict): Mapping from database values to NN outputs.
+         no_target: Value returned for invalid/missing labels.
+         vectorize (bool): Whether to use one-hot vector outputs.
+         label_dictionary (Dict[str, int], optional): Maps label names to values.
+     """
+
+     def __init__(self, no_target=-1, vectorize=False, label_dictionary: Dict[str, int] = None):
+         """Initialize Mapper.
+
+         Args:
+             no_target (int, optional): Value for invalid targets. Defaults to -1.
+             vectorize (bool, optional): Use one-hot vectors instead of scalar outputs.
+                 Defaults to False.
+             label_dictionary (Dict[str, int], optional): Maps label names to integer
+                 values. Defaults to None.
+         """
+         self.output_list: List[NNOutput] = []
+         self.dictionary = {}
+         self.no_target = no_target
+         self.vectorize = vectorize
+
+         self.label_dictionary = label_dictionary
+
+     def __repr__(self):
+         return f'Mapper(output_list: {repr(self.output_list)}, ' \
+                f'dictionary: {repr(self.dictionary)}, ' \
+                f'no_target: {repr(self.no_target)}, ' \
+                f'vectorize: {repr(self.vectorize)})'
+
+     def load_dic_from_file(self, csv_path):
+         """Load the label name to value dictionary from a two-column CSV file."""
+         with open(csv_path, mode='r') as infile:
+             reader = csv.reader(infile)
+             self.label_dictionary = {row[0]: int(row[1]) for row in reader}
+
+     def map_value_names(self):
+         """Return (map_out, name) pairs for all output categories."""
+         return [(output.map_out, output.name) for output in self.output_list]
+
+     def map_names(self):
+         """Return the names of all output categories."""
+         return [output.name for output in self.output_list]
+
+     def map_values(self):
+         """Return the map values of all output categories."""
+         return [output.map_out for output in self.output_list]
+
+     def nn_name(self):
+         """Return the names of all output categories, in network output order."""
+         return [output.name for output in self.output_list]
+
+     def add_category(self, name, labels, map_value=None):
+         """Add an output category to the neural network mapping."""
+         labels_name = labels
+
+         if self.label_dictionary is not None:
+             labels_values = [self.label_dictionary[label] for label in labels]
+         else:
+             labels_values = labels
+
+         if map_value is None:
+             map_value = len(self)
+
+         category = NNOutput(name, labels_values, labels_name, len(self), map_value)
+
+         self.output_list.append(category)
+         self._update_dictionary(category)
+
+     def __len__(self):
+         return len(self.output_list)
+
+     def _vectorize(self, no_target=-1):
+         """Transform integer outputs to one-hot vector outputs."""
+         for i, value in enumerate(self.output_list):
+             out = np.zeros(len(self))
+             out[i] = 1
+             value.nn_out = out
+
+         # if no_target has a length, assume it is already a vector and leave it as-is
+         if hasattr(no_target, '__len__'):
+             self.no_target = no_target
+         else:
+             # otherwise set it to a vector of zeros
+             self.no_target = np.zeros(len(self))
+
+     def __call__(self, value):
+         """Map a database label value to its NN output, or no_target if unknown."""
+         return self.dictionary.get(value, self.no_target)
+
+     def map_output_transformer(self):
+         """Return a transformer converting the nn output (assuming argmax is used) to map values."""
+         is_identity = True
+         nn_out_to_map = []
+         for i, out in enumerate(self.output_list):
+             is_identity = is_identity and (i == out.map_out)
+             nn_out_to_map.append(out.map_out)
+
+         return ArgMax() if is_identity else ArgMaxToCategory(nn_out_to_map)
+
+     def _update_dictionary(self, output: NNOutput):
+         """Update the label-to-output dictionary with the labels of a new nn output."""
+         for value in output.labels_value:
+             if value in self.dictionary:
+                 logger.warning(f"{value} appears twice in the label-value mapping; "
+                                f"the earlier mapping is overwritten.")
+             self.dictionary[value] = output.nn_out
+
+
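+ # --- Usage sketch (illustrative, not part of the original file) ---
+ # A minimal example of wiring a Mapper to NN outputs, assuming integer database
+ # labels. The label values (1, 2, 3), sample IDs, and category names are hypothetical.
+ #
+ #     mapper = Mapper(no_target=-1)
+ #     mapper.add_category("forest", [1, 2])   # two db labels -> one nn output
+ #     mapper.add_category("water", [3])
+ #     mapper(2)   # -> 0 (nn output index of "forest")
+ #     mapper(9)   # -> -1 (unknown label -> no_target)
+ #
+ #     # keys_out maps sample IDs to db labels, e.g. {101: 1, 102: 3, 103: 9}
+ #     samples = sample_list({101: 1, 102: 3, 103: 9}, mapper)  # drops id 103
+
+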
+ def db_dataset_multi_proc_init(worker_id):
+     """Initialise the dataset so that the database reader environment is kept open for
+     the whole life of the worker process.
+
+     Used to fix an issue with newer lmdb versions by keeping the db environment open
+     for the worker:
+     https://github.com/jnwatson/py-lmdb/issues/340
+     """
+     worker_info = torch.utils.data.get_worker_info()
+     dataset: DBDataset = worker_info.dataset
+     dataset.init_db_environment(True)
+
+
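+ # --- DataLoader wiring sketch (illustrative, not part of the original file) ---
+ # When multithread=True, pass db_dataset_multi_proc_init as worker_init_fn so each
+ # worker process opens its own LMDB environment. The path, sample_ids, and batch
+ # parameters below are hypothetical.
+ #
+ #     from torch.utils.data import DataLoader
+ #
+ #     dataset = DBDataset("samples.lmdb", sample_ids, target_mapper=mapper)
+ #     loader = DataLoader(dataset, batch_size=32, num_workers=4,
+ #                         worker_init_fn=db_dataset_multi_proc_init)
+
+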
+ class DBDataset(Dataset):
+     """PyTorch Dataset for reading training samples from an LMDB database.
+
+     Reads image patches and labels from an LMDB database with optional label mapping
+     and data augmentation. Supports both single-process and multi-process data loading.
+
+     Attributes:
+         multithread (bool): Whether to use multi-process data loading.
+         db_path (str): Path to LMDB database.
+         samples_list (np.ndarray): Array of sample IDs to fetch.
+         target_mapper: Mapper for transforming database labels to NN outputs.
+         f_transform: Data augmentation function.
+         transform_param (np.ndarray, optional): Parameters for augmentation per sample.
+         reader (LMDBReader): Database reader instance.
+     """
+
+     def __init__(self, db_path, samples_list, target_mapper=None, f_transform=None,
+                  transform_param=None, multithread=True):
+         """Initialize DBDataset.
+
+         Args:
+             db_path (str): Path to LMDB database file.
+             samples_list (list): List of sample IDs to include in dataset.
+             target_mapper (Mapper, optional): Maps database labels to NN outputs.
+                 Defaults to None.
+             f_transform (callable, optional): Data augmentation function. Defaults to None.
+             transform_param (list, optional): Per-sample augmentation parameters.
+                 Must be numeric types to avoid memory leaks. Defaults to None.
+             multithread (bool, optional): Enable multi-process loading. Defaults to True.
+
+         Note:
+             - Use numpy arrays (not Python objects) for transform_param to avoid memory
+               leaks (see https://github.com/pytorch/pytorch/issues/13246)
+             - For multithread=True, use db_dataset_multi_proc_init as worker_init_fn
+               (see https://github.com/jnwatson/py-lmdb/issues/340)
+         """
+         super().__init__()
+
+         self.multithread = multithread
+
+         self.db_path = db_path
+
+         # a plain Python list would cause a memory leak with worker processes
+         self.samples_list = np.array(samples_list)
+         self.target_mapper = target_mapper
+
+         self.f_transform = f_transform
+
+         if transform_param is not None:
+             self.transform_param = np.array(transform_param)
+         else:
+             self.transform_param = None
+
+         self.reader = None
+         if not multithread:
+             self.init_db_environment(False)
+
+     def init_db_environment(self, keep_env_open):
+         """Init the db environment. Used to initialise it in a different process via
+         the worker_init_fn of the DataLoader."""
+         self.reader = LMDBReader(self.db_path, keep_env_open=keep_env_open)
+
+     def __len__(self):
+         return len(self.samples_list)
+
+     def __getitem__(self, idx):
+         # TODO we create the reader here for the multiprocessing env, but maybe it
+         # could be done differently
+         with self.reader as db:
+             if hasattr(idx, '__iter__'):
+                 return self._get_items(idx, db)
+
+             if isinstance(idx, int):
+                 return self._get_one_item(idx, db)
+
+             if isinstance(idx, slice):
+                 # resolve start, stop, and step from the slice (handles None values)
+                 return self._get_items(range(*idx.indices(len(self))), db)
+
+     def batch_statistic(self):
+         """Compute inverse-frequency class weights over the dataset samples."""
+         reader = LMDBReader(self.db_path)
+
+         with reader:
+             outputs = [reader.get_output(s) for s in self.samples_list]
+
+         outputs = [self.target_mapper(val) for val in outputs]
+
+         counter = Counter(outputs)
+         logger.info(f"counter: {counter}")
+
+         # count of the most common category
+         maximum = counter.most_common()[0][1]
+
+         logger.info(f"most common category count: {maximum}")
+
+         return {key: maximum / val for key, val in counter.items()}
+
+     def weight_list(self):
+         """Return one inverse-frequency weight per sample, e.g. for weighted sampling."""
+         weight_dic = self.batch_statistic()
+
+         reader = LMDBReader(self.db_path)
+         with reader:
+             weights = []
+             for s in self.samples_list:
+                 target_val = self.target_mapper(reader.get_output(s))
+                 weights.append(weight_dic[target_val])
+
+         return weights
+
+     def _get_items_deprecated(self, iterable, reader):
+         inputs = []
+         targets = []
+
+         # see default_collate for better memory management
+         for key in iterable:
+             inp, target = self._get_one_item(key, reader)
+             inputs.append(inp)
+             targets.append(target)
+
+         # taken from default_collate: allocate the space in shared memory to avoid an
+         # extra copy; could also be done for the targets
+         # batch = len(inputs)
+         # storage = inp.storage()._new_shared(len(inputs) * inp.numel(), device=inp.device)
+         # out = inp.new(storage).resize_(batch, *list(inp.size()))
+         # torch.stack(inputs, out=out)
+         # return out, torch.LongTensor(targets)
+
+         return torch.stack(inputs), torch.LongTensor(targets)
+
+     def _get_items(self, iterable, reader):
+         batch = len(iterable)
+
+         iterator = iter(iterable)
+
+         try:
+             (one_input,), target = self._get_one_item(next(iterator), reader)
+         except StopIteration:
+             return []
+
+         # compute the batch shapes from the first item
+         shape_in = (batch,) + one_input.shape
+
+         if isinstance(target, int):
+             shape_out = batch
+             targets = torch.empty(shape_out, dtype=torch.long)
+         else:
+             shape_out = (batch,) + target.shape
+             targets = torch.empty(shape_out, dtype=torch.float32)
+
+         inputs = torch.empty(shape_in, dtype=torch.float32)
+
+         inputs[0] = one_input
+         targets[0] = target
+
+         for i, key in enumerate(iterator, 1):
+             # the nn takes one parameter, so we unpack the 1-tuples into the batch
+             (inputs[i],), targets[i] = self._get_one_item(key, reader)
+         # the nn takes one parameter, so we return the inputs as a 1-element tuple
+         return (inputs,), targets
+
+     def _get_one_item(self, idx, reader):
+         key = self.samples_list[idx]
+
+         inputs, target = reader.get_data(int(key))
+
+         inputs = torch.from_numpy(inputs.copy())
+
+         if self.f_transform is not None:
+             if self.transform_param is not None:
+                 param = self.transform_param[idx]
+                 inputs = self.f_transform(inputs, *param)
+             else:
+                 inputs = self.f_transform(inputs)
+
+         if self.target_mapper is not None:
+             target = self.target_mapper(target)
+
+         # the nn takes one parameter, so we return the input as a 1-element tuple
+         return (inputs,), target
+
+
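+ # --- Weighted sampling sketch (illustrative, not part of the original file) ---
+ # weight_list() pairs naturally with torch's WeightedRandomSampler to rebalance
+ # classes during training. The batch size below is hypothetical.
+ #
+ #     from torch.utils.data import DataLoader, WeightedRandomSampler
+ #
+ #     weights = dataset.weight_list()
+ #     sampler = WeightedRandomSampler(weights, num_samples=len(weights))
+ #     loader = DataLoader(dataset, batch_size=32, sampler=sampler)
+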
+
+ class DBDatasetMeta(DBDataset):
+     """DBDataset variant that also returns each sample's header (metadata)."""
+
+     def __init__(self, db_path, samples_list, target_mapper=None, f_transform=None,
+                  transform_param=None, multithread=True):
+         """Initialize DBDatasetMeta.
+
+         Args:
+             db_path (str): Path of the db to open.
+             samples_list (list): List of keys to fetch from the db.
+             target_mapper (Mapper, optional): Maps the db output to the nn output.
+                 Defaults to None.
+             f_transform (callable, optional): Data augmentation function. Defaults to None.
+             transform_param (list, optional): Per-sample augmentation parameters; must be
+                 numeric (stored as non-object numpy arrays) to avoid memory leaks
+                 (see https://github.com/pytorch/pytorch/issues/13246). Defaults to None.
+             multithread (bool, optional): Enable multi-process loading. Defaults to True.
+         """
+         super().__init__(db_path, samples_list, target_mapper, f_transform, transform_param, multithread)
+
+     def __len__(self):
+         return len(self.samples_list)
+
+     def __getitem__(self, idx):
+         with self.reader as db:
+             if hasattr(idx, '__iter__'):
+                 inputs, outputs = self._get_items(idx, db)
+                 headers = self._get_headers(idx, db)
+                 return inputs, outputs, headers
+
+             if isinstance(idx, int):
+                 _input, output = self._get_one_item(idx, db)
+                 header = self._get_header(idx, db)
+                 return _input, output, header
+
+             if isinstance(idx, slice):
+                 # resolve start, stop, and step from the slice (handles None values)
+                 indices = range(*idx.indices(len(self)))
+                 inputs, outputs = self._get_items(indices, db)
+                 headers = self._get_headers(indices, db)
+                 return inputs, outputs, headers
+
+     def _get_headers(self, iterable, reader):
+         return [self._get_header(h, reader) for h in iterable]
+
+     def _get_header(self, idx, reader):
+         key = int(self.samples_list[idx])
+         return reader.get_header(key)
+
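+ # --- Metadata access sketch (illustrative, not part of the original file) ---
+ # DBDatasetMeta returns an (inputs, output, header) triple per sample, where the
+ # header is whatever metadata LMDBReader.get_header stores for the key. The path
+ # and sample_ids are hypothetical.
+ #
+ #     meta_ds = DBDatasetMeta("samples.lmdb", sample_ids, target_mapper=mapper,
+ #                             multithread=False)
+ #     (patch,), label, header = meta_ds[0]
+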
+
+ class DBInfo:
+     """Bundles one database path with its sample list, mapper, and transform."""
+
+     def __init__(self, db_path, sample_list, target_mapper, f_transform=None, transform_param=None):
+         self.db_path = db_path
+         self.sample_list = sample_list
+         self.target_mapper = target_mapper
+         self.f_transform = f_transform
+         self.transform_param = transform_param
+
+
+ class MultiDBDataset(Dataset):
+     """Read samples from multiple LMDB databases.
+
+     TODO: check whether a torch helper (e.g. ConcatDataset) could replace this.
+     TODO: check whether _get_one_item can be shared with DBDataset.
+     """
+
+     def __init__(self, db_info: List[DBInfo], multithread=True):
+         """Initialize MultiDBDataset.
+
+         Args:
+             db_info (List[DBInfo]): One DBInfo per database, each carrying the db path,
+                 the list of keys to fetch, and the mapper for that db.
+             multithread (bool, optional): Enable multi-process loading. Defaults to True.
+
+         Note:
+             The numpy-array trick against memory leaks
+             (https://github.com/pytorch/pytorch/issues/13246) is not implemented here;
+             transform parameters should all be numbers.
+         """
+         super().__init__()
+
+         self.db_info: List[DBInfo] = db_info
+
+         self.size = 0
+
+         # TODO improve
+         self.samples_index = []
+         for db_index, db in enumerate(self.db_info):
+             self.size += len(db.sample_list)
+
+             for sample_index in db.sample_list:
+                 # each entry is the pair (db index, sample key in that db)
+                 self.samples_index.append((db_index, sample_index))
+
+         self.readers = None
+
+         if not multithread:
+             self.init_db_environment(False)
+
+     def init_db_environment(self, keep_env_open):
+         """Init the db environments. Used to initialise them in a different process via
+         the worker_init_fn of the DataLoader."""
+         self.readers = [LMDBReader(db.db_path, keep_env_open=keep_env_open) for db in self.db_info]
+
+     def __len__(self):
+         return self.size
+
+     def __getitem__(self, idx):
+         # TODO we create the readers here for the multiprocessing env, but maybe it
+         # could be done differently
+         try:
+             for reader in self.readers:
+                 reader.open()
+
+             if hasattr(idx, '__iter__'):
+                 return self._get_items(idx, self.readers)
+
+             if isinstance(idx, int):
+                 return self._get_one_item(idx, self.readers)
+
+             if isinstance(idx, slice):
+                 # resolve start, stop, and step from the slice (handles None values)
+                 return self._get_items(range(*idx.indices(len(self))), self.readers)
+
+         finally:
+             for reader in self.readers:
+                 reader.close()
+
+     def _get_items(self, iterable, readers):
+         batch = len(iterable)
+
+         iterator = iter(iterable)
+
+         try:
+             one_input, target = self._get_one_item(next(iterator), readers)
+         except StopIteration:
+             return []
+
+         # compute the batch shapes from the first item
+         shape_in = (batch,) + one_input.shape
+
+         if isinstance(target, int):
+             shape_out = batch
+             targets = torch.empty(shape_out, dtype=torch.long)
+         else:
+             shape_out = (batch,) + target.shape
+             targets = torch.empty(shape_out, dtype=torch.float32)
+
+         inputs = torch.empty(shape_in, dtype=torch.float32)
+
+         inputs[0] = one_input
+         targets[0] = target
+
+         for i, idx in enumerate(iterator, 1):
+             inputs[i], targets[i] = self._get_one_item(idx, readers)
+
+         return inputs, targets
+
+     def _get_one_item(self, idx, readers):
+         db_index, id_sample = self.samples_index[idx]
+         inputs, target = readers[db_index].get_data(id_sample)
+
+         inputs = torch.from_numpy(inputs.copy())
+
+         if self.db_info[db_index].f_transform is not None:
+             if self.db_info[db_index].transform_param is not None:
+                 param = self.db_info[db_index].transform_param[id_sample]
+                 inputs = self.db_info[db_index].f_transform(inputs, *param)
+             else:
+                 inputs = self.db_info[db_index].f_transform(inputs)
+
+         if self.db_info[db_index].target_mapper is not None:
+             target = self.db_info[db_index].target_mapper(target)
+
+         return inputs, target
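+
+
+ # --- Multi-database assembly sketch (illustrative, not part of the original file) ---
+ # Each DBInfo pairs a database with its own sample keys and mapper; MultiDBDataset
+ # concatenates them into one dataset. The paths and key lists are hypothetical.
+ #
+ #     infos = [
+ #         DBInfo("region_a.lmdb", sample_ids_a, mapper),
+ #         DBInfo("region_b.lmdb", sample_ids_b, mapper),
+ #     ]
+ #     multi_ds = MultiDBDataset(infos, multithread=False)
+ #     patch, label = multi_ds[0]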