eotdl 2023.7.19.post4__py3-none-any.whl → 2023.9.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eotdl/commands/datasets.py +15 -29
- eotdl/curation/__init__.py +5 -5
- eotdl/curation/formatters.py +0 -2
- eotdl/curation/metadata.py +34 -9
- eotdl/curation/stac/assets.py +127 -0
- eotdl/curation/stac/dataframe.py +8 -4
- eotdl/curation/stac/extensions.py +295 -46
- eotdl/curation/stac/extent.py +130 -0
- eotdl/curation/stac/ml_dataset.py +509 -0
- eotdl/curation/stac/parsers.py +2 -0
- eotdl/curation/stac/stac.py +309 -286
- eotdl/curation/stac/utils.py +47 -1
- eotdl/datasets/__init__.py +2 -2
- eotdl/datasets/download.py +16 -3
- eotdl/datasets/ingest.py +21 -10
- eotdl/datasets/retrieve.py +10 -2
- eotdl/src/repos/APIRepo.py +42 -18
- eotdl/src/repos/AuthRepo.py +3 -3
- eotdl/src/usecases/auth/IsLogged.py +5 -3
- eotdl/src/usecases/datasets/DownloadDataset.py +35 -6
- eotdl/src/usecases/datasets/DownloadFileURL.py +22 -0
- eotdl/src/usecases/datasets/IngestFile.py +48 -28
- eotdl/src/usecases/datasets/IngestSTAC.py +43 -8
- eotdl/src/usecases/datasets/RetrieveDatasets.py +3 -2
- eotdl/src/usecases/datasets/__init__.py +1 -0
- eotdl/tools/sen12floods/tools.py +3 -3
- eotdl/tools/stac.py +8 -2
- {eotdl-2023.7.19.post4.dist-info → eotdl-2023.9.14.dist-info}/METADATA +2 -1
- {eotdl-2023.7.19.post4.dist-info → eotdl-2023.9.14.dist-info}/RECORD +31 -27
- {eotdl-2023.7.19.post4.dist-info → eotdl-2023.9.14.dist-info}/WHEEL +1 -1
- {eotdl-2023.7.19.post4.dist-info → eotdl-2023.9.14.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,509 @@
|
|
1
|
+
"""Implements the :stac-ext:`Machine Learning Dataset Extension <ml-dataset>`."""
|
2
|
+
|
3
|
+
import traceback
|
4
|
+
import random
|
5
|
+
|
6
|
+
import pystac
|
7
|
+
from tqdm import tqdm
|
8
|
+
from pystac.extensions.base import ExtensionManagementMixin, PropertiesExtension
|
9
|
+
from pystac.extensions.label import LabelExtension
|
10
|
+
from pystac import STACValidationError
|
11
|
+
from shutil import rmtree
|
12
|
+
from os.path import dirname
|
13
|
+
from pystac.cache import ResolvedObjectCache
|
14
|
+
from pystac.extensions.hooks import ExtensionHooks
|
15
|
+
from typing import Any, Dict, List, Optional, Generic, TypeVar, Union, Set
|
16
|
+
|
17
|
+
# Type variable over the STAC object kinds MLDatasetExtension is generic over.
T = TypeVar("T", pystac.Item, pystac.Collection, pystac.Catalog)


# Schema URI of the ML Dataset extension; returned by
# MLDatasetExtension.get_schema_uri() and used as the hooks' ``schema_uri``.
SCHEMA_URI: str = "https://raw.githubusercontent.com/earthpulse/ml-dataset/main/json-schema/schema.json"
# Namespace prefix for every property key this extension reads/writes
# (e.g. "ml-dataset:name", "ml-dataset:split-items").
PREFIX: str = "ml-dataset:"
|
22
|
+
|
23
|
+
|
24
|
+
class MLDatasetExtension(
    pystac.Catalog,
    Generic[T],
    PropertiesExtension,
    ExtensionManagementMixin[
        Union[pystac.item.Item, pystac.collection.Collection, pystac.catalog.Catalog]
    ],
):
    """An abstract class that can be used to extend the properties of a
    :class:`~pystac.Collection`, :class:`~pystac.Item`, or :class:`~pystac.Catalog` with
    properties from the :stac-ext:`Machine Learning Dataset Extension <ml-dataset>`. This class is
    generic over the type of STAC Object to be extended (:class:`~pystac.Catalog`,
    :class:`~pystac.Collection` or :class:`~pystac.Item`).

    When built directly from a Catalog, the instance is itself a Catalog that
    carries over the wrapped catalog's id, description, title, extensions,
    extra fields and (the same list of) links.

    To create a concrete instance of :class:`MLDatasetExtension`, use the
    :meth:`MLDatasetExtension.ext` method. For example:

    .. code-block:: python

       >>> item: pystac.Item = ...
       >>> ml_ext = MLDatasetExtension.ext(item)
    """

    # The :class:`~pystac.Catalog` being extended.
    catalog: pystac.Catalog

    # The :class:`~pystac.Catalog` extra fields, including extension properties.
    properties: Dict[str, Any]

    # The list of :class:`~pystac.Link` objects associated with the
    # :class:`~pystac.Catalog` being extended, including links added by this extension.
    links: List[pystac.Link]

    def __init__(self, catalog: pystac.Catalog):
        super().__init__(id=catalog.id, description=catalog.description)
        self._catalog = catalog
        self.id = catalog.id
        self.description = catalog.description
        self.title = catalog.title if catalog.title else None
        self.stac_extensions = (
            catalog.stac_extensions if catalog.stac_extensions else []
        )
        # ``extra_fields`` and ``properties`` are one dict, so values written
        # through either name land in the same place.
        self.extra_fields = self.properties = (
            catalog.extra_fields if catalog.extra_fields else {}
        )
        # Shared (not copied) with the wrapped catalog.
        self.links = catalog.links
        self._resolved_objects = ResolvedObjectCache()

    def apply(self, name: Optional[str] = None) -> None:
        """Set the extension's ``name`` property."""
        self.name = name

    @property
    def name(self) -> str:
        """Dataset name. Raises ``KeyError`` if unset."""
        return self.extra_fields[f"{PREFIX}name"]

    @name.setter
    def name(self, v: str) -> None:
        self.extra_fields[f"{PREFIX}name"] = v

    @property
    def tasks(self) -> List:
        """Dataset tasks. Raises ``KeyError`` if unset."""
        return self.extra_fields[f"{PREFIX}tasks"]

    @tasks.setter
    def tasks(self, v: Union[list, tuple]) -> None:
        self.extra_fields[f"{PREFIX}tasks"] = v

    @property
    def type(self) -> str:
        """Dataset type. Raises ``KeyError`` if unset."""
        return self.extra_fields[f"{PREFIX}type"]

    @type.setter
    def type(self, v: str) -> None:
        self.extra_fields[f"{PREFIX}type"] = v

    @property
    def inputs_type(self) -> str:
        """Inputs type. Raises ``KeyError`` if unset."""
        return self.extra_fields[f"{PREFIX}inputs-type"]

    @inputs_type.setter
    def inputs_type(self, v: str) -> None:
        self.extra_fields[f"{PREFIX}inputs-type"] = v

    @property
    def annotations_type(self) -> str:
        """Annotations type. Raises ``KeyError`` if unset."""
        return self.extra_fields[f"{PREFIX}annotations-type"]

    @annotations_type.setter
    def annotations_type(self, v: str) -> None:
        self.extra_fields[f"{PREFIX}annotations-type"] = v

    @property
    def splits(self) -> Optional[List[str]]:
        """Names of the dataset splits, or ``None`` if unset."""
        # ``.get`` keeps the previous no-raise contract for unset values.
        return self.extra_fields.get(f"{PREFIX}splits")

    @splits.setter
    def splits(self, v: dict) -> None:
        self.extra_fields[f"{PREFIX}splits"] = v

    @property
    def quality_metrics(self) -> Optional[List[dict]]:
        """Quality metrics recorded on the dataset, or ``None`` if unset."""
        return self.extra_fields.get(f"{PREFIX}quality-metrics")

    @quality_metrics.setter
    def quality_metrics(self, v: dict) -> None:
        self.extra_fields[f"{PREFIX}quality-metrics"] = v

    @property
    def version(self) -> Optional[str]:
        """Dataset version, or ``None`` if unset."""
        return self.extra_fields.get(f"{PREFIX}version")

    @version.setter
    def version(self, v: str) -> None:
        self.extra_fields[f"{PREFIX}version"] = v

    @classmethod
    def get_schema_uri(cls) -> str:
        return SCHEMA_URI

    def add_metric(self, metric: dict) -> None:
        """Add a metric to this object's set of metrics.

        The metric is appended only if an equal metric is not already present.

        Args:
            metric : The metric to add.
        """
        if not self.extra_fields.get(f'{PREFIX}quality-metrics'):
            self.extra_fields[f'{PREFIX}quality-metrics'] = []

        if metric not in self.extra_fields[f'{PREFIX}quality-metrics']:
            self.extra_fields[f'{PREFIX}quality-metrics'].append(metric)

    def add_metrics(self, metrics: List[dict]) -> None:
        """Add a list of metrics to this object's set of metrics.

        Args:
            metrics : The metrics to add.
        """
        for metric in metrics:
            self.add_metric(metric)

    @classmethod
    def ext(cls, obj: T, add_if_missing: bool = False):
        """Extends the given STAC Object with properties from the
        :stac-ext:`Machine Learning Dataset Extension <ml-dataset>`.

        This extension can be applied to instances of :class:`~pystac.Catalog`,
        :class:`~pystac.Collection` or :class:`~pystac.Item`.

        Args:
            obj : The STAC object to extend.
            add_if_missing : Add the schema URI to ``obj`` if it is not declared.

        Raises:
            pystac.ExtensionTypeError : If an invalid object type is passed.
        """
        # Collection must be tested before Catalog because it subclasses it.
        if isinstance(obj, pystac.Collection):
            cls.validate_has_extension(obj, add_if_missing)
            return CollectionMLDatasetExtension(obj)
        elif isinstance(obj, pystac.Catalog):
            cls.validate_has_extension(obj, add_if_missing)
            return MLDatasetExtension(obj)
        elif isinstance(obj, pystac.Item):
            cls.validate_has_extension(obj, add_if_missing)
            return ItemMLDatasetExtension(obj)
        else:
            raise pystac.ExtensionTypeError(cls._ext_error_message(obj))
|
187
|
+
|
188
|
+
|
189
|
+
class CollectionMLDatasetExtension(MLDatasetExtension[pystac.Collection]):
    """A concrete implementation of :class:`MLDatasetExtension` on an
    :class:`~pystac.Collection` that extends the properties of the Collection to include
    properties defined in the :stac-ext:`Machine Learning Dataset Extension <ml-dataset>`.
    """

    collection: pystac.Collection
    properties: Dict[str, Any]

    def __init__(self, collection: pystac.Collection):
        # The MLDatasetExtension (Catalog) initializer is not run; this wrapper
        # works directly on the collection's ``extra_fields`` dict.
        self.collection = collection
        self.properties = collection.extra_fields
        # NOTE(review): wrapping always resets the stored split list. This keeps
        # make_splits re-runs from accumulating entries, but also discards any
        # splits already on the collection — confirm this is intended.
        self.properties[f"{PREFIX}split-items"] = []

    def __repr__(self) -> str:
        return "<CollectionMLDatasetExtension Collection id={}>".format(
            self.collection.id
        )

    @property
    def splits(self) -> Optional[List[dict]]:
        """Split entries (``{"name": ..., "items": [...]}``) on the collection."""
        return self.properties.get(f"{PREFIX}split-items")

    @splits.setter
    def splits(self, v: dict) -> None:
        self.properties[f"{PREFIX}split-items"] = v

    def add_split(self, v: dict) -> None:
        """Append one split entry to the collection's split list."""
        self.properties[f"{PREFIX}split-items"].append(v)

    def create_and_add_split(
        self, split_data: List[pystac.Item], split_type: str
    ) -> None:
        """Record a split on the collection and tag each of its items.

        Args:
            split_data : Items that belong to the split.
            split_type : Name of the split (e.g. ``"Training"``); stored as the
                entry's ``name`` and as each item's ``ml-dataset:split``.
        """
        items_ids = sorted(item.id for item in split_data)

        split = {"name": split_type, "items": items_ids}

        if not self.properties.get(f"{PREFIX}split-items"):
            self.properties[f"{PREFIX}split-items"] = []

        self.add_split(split)
        print(f"Generating {split_type} split...")
        for _item in tqdm(split_data):
            # Re-fetch through the collection so the tag is written on the
            # collection's own Item object.
            item = self.collection.get_item(_item.id)
            if item:
                item_ml = MLDatasetExtension.ext(item, add_if_missing=True)
                item_ml.split = split_type
236
|
+
|
237
|
+
|
238
|
+
|
239
|
+
class ItemMLDatasetExtension(MLDatasetExtension[pystac.Item]):
    """A concrete implementation of :class:`MLDatasetExtension` on an
    :class:`~pystac.Item` that extends the properties of the Item to include properties
    defined in the :stac-ext:`Machine Learning Dataset Extension <ml-dataset>`.

    This class should generally not be instantiated directly. Instead, call
    :meth:`MLDatasetExtension.ext` on an :class:`~pystac.Item` to extend it.
    """

    item: pystac.Item
    properties: Dict[str, Any]

    def __init__(self, item: pystac.Item):
        # Reads/writes go straight to the wrapped item's ``properties`` dict;
        # the MLDatasetExtension (Catalog) initializer is not run.
        self.item = item
        self.properties = item.properties

    @property
    def split(self) -> Optional[str]:
        """Name of the split this item belongs to, or ``None`` if unset."""
        return self.properties.get(f"{PREFIX}split")

    @split.setter
    def split(self, v: str) -> None:
        self.properties[f"{PREFIX}split"] = v

    def __repr__(self) -> str:
        return "<ItemMLDatasetExtension Item id={}>".format(self.item.id)
|
265
|
+
|
266
|
+
|
267
|
+
class MLDatasetQualityMetrics:
    """Computes dataset quality metrics and stores them in the catalog's
    ``ml-dataset:quality-metrics`` property."""

    @classmethod
    def calculate(cls, catalog: Union[pystac.Catalog, str]) -> None:
        """Compute quality metrics for ``catalog`` and save it back in place.

        Adds a spatial-duplicates metric and a classes-balance metric, then
        validates the catalog, deletes its directory and re-saves it there.
        Validation failures are printed (with traceback) rather than raised.

        Args:
            catalog : Catalog object, or path to a catalog JSON file.

        Raises:
            pystac.ExtensionNotImplemented : If the catalog does not declare the
                ML Dataset extension or lacks what the metrics need.
        """
        if isinstance(catalog, str):
            catalog = MLDatasetExtension(pystac.read_file(catalog))
        # Check the catalog has the extension
        if not MLDatasetExtension.has_extension(catalog):
            raise pystac.ExtensionNotImplemented(
                f"MLDatasetExtension does not apply to type '{type(catalog).__name__}'"
            )
        # A plain Catalog has no add_metric(); wrap it like the path input.
        if not isinstance(catalog, MLDatasetExtension):
            catalog = MLDatasetExtension(catalog)

        try:
            catalog.add_metric(cls._search_spatial_duplicates(catalog))
            catalog.add_metric(cls._get_classes_balance(catalog))
        except AttributeError as err:
            raise pystac.ExtensionNotImplemented(
                "The catalog does not have the required properties or the ML-Dataset extension to calculate the metrics"
            ) from err

        try:
            print("Validating and saving...")
            catalog.validate()
            destination = dirname(catalog.get_self_href())
            # Remove the old catalog and replace it with the new one.
            # NOTE(review): this deletes the whole directory; files there that
            # save() does not rewrite (e.g. local assets) would be lost.
            rmtree(destination)
            catalog.save(dest_href=destination)
            print("Success!")
        except STACValidationError:
            # Report the full traceback instead of propagating.
            traceback.print_exc()

    @staticmethod
    def _search_spatial_duplicates(catalog: pystac.Catalog) -> dict:
        """Find non-label items whose bounding boxes are identical.

        Returns:
            ``{"name": "spatial-duplicates", "values": [...], "total": int}``,
            each value being ``{"item": id, "duplicate": first_id_with_same_bbox}``.
        """
        # TODO test this method
        print("Looking for spatial duplicates...")
        items = [
            item
            for item in tqdm(catalog.get_all_items())
            if not LabelExtension.has_extension(item)
        ]

        spatial_duplicates = {"name": "spatial-duplicates", "values": [], "total": 0}

        # Maps the string form of a bbox to the first item id seen with it.
        items_bboxes = {}
        for item in items:
            bbox = str(item.bbox)
            if bbox not in items_bboxes:
                items_bboxes[bbox] = item.id
            else:
                spatial_duplicates["values"].append(
                    {"item": item.id, "duplicate": items_bboxes[bbox]}
                )
                spatial_duplicates["total"] += 1

        return spatial_duplicates

    @staticmethod
    def _get_classes_balance(catalog: pystac.Catalog) -> dict:
        """Count label-class occurrences across label items.

        Returns:
            ``{"name": "classes-balance", "values": [...]}``, each value being
            ``{"class": c, "total": count, "percentage": truncated int percent}``.
        """
        print("Calculating classes balance...")
        labels = [
            item
            for item in tqdm(catalog.get_all_items())
            if LabelExtension.has_extension(item)
        ]

        classes_balance = {"name": "classes-balance", "values": []}

        classes = {}
        for label in labels:
            for label_class_obj in LabelExtension.ext(label).label_classes:
                for single_class in label_class_obj.classes:
                    classes[single_class] = classes.get(single_class, 0) + 1

        # No division happens when ``classes`` is empty, so a zero total is safe.
        total_labels = sum(classes.values())
        for key, value in classes.items():
            classes_balance["values"].append(
                {
                    "class": key,
                    "total": value,
                    "percentage": int(value / total_labels * 100),
                }
            )

        return classes_balance
|
370
|
+
|
371
|
+
|
372
|
+
class MLDatasetExtensionHooks(ExtensionHooks):
    """pystac extension hooks describing the ML Dataset extension.

    Declares its schema URI, that it has no previous extension ids, and the
    STAC object types (items, collections, catalogs) it applies to.
    """

    stac_object_types = {
        pystac.STACObjectType.ITEM,
        pystac.STACObjectType.COLLECTION,
        pystac.STACObjectType.CATALOG,
    }
    prev_extension_ids: Set[str] = set()
    schema_uri: str = SCHEMA_URI
|
380
|
+
|
381
|
+
|
382
|
+
# Module-level instance of the ML Dataset extension hooks.
# NOTE(review): the name looks copied from pystac's storage extension; it holds
# MLDatasetExtensionHooks, not storage hooks. It is not registered with pystac
# (e.g. via pystac.EXTENSION_HOOKS.add_extension_hooks) in the code shown here;
# confirm whether registration happens elsewhere.
STORAGE_EXTENSION_HOOKS: ExtensionHooks = MLDatasetExtensionHooks()
|
383
|
+
|
384
|
+
|
385
|
+
def add_ml_extension(
    catalog: Union[pystac.Catalog, str],
    destination: Optional[str] = None,
    splits: Optional[bool] = False,
    splits_collection_id: Optional[str] = "labels",
    splits_names: Optional[list] = ("Training", "Validation", "Test"),
    split_proportions: Optional[List[int]] = (80, 10, 10),
    **kwargs,
) -> None:
    """
    Adds the ML Dataset extension to a STAC catalog and saves it.

    Args:
        catalog : The STAC catalog (or path to its JSON file) to add the extension to.
        destination : Directory to save the catalog to. Defaults to the
            catalog's current directory. An existing directory there is
            deleted before saving.
        splits : Whether to split the items of ``splits_collection_id``.
        splits_collection_id : Id of the child collection whose items are split.
        splits_names : Names of the splits, matched positionally with ``split_proportions``.
        split_proportions : Percentages of the splits; must sum to 100.
        **kwargs : Extension properties (e.g. ``name``, ``tasks``) set on the
            catalog; also forwarded to :func:`make_splits` (e.g. ``verbose``).

    Raises:
        pystac.ExtensionTypeError : If ``catalog`` is neither a Catalog nor a str.
        ValueError : If ``splits`` is set and the splits collection is not found.
    """
    from os.path import isdir

    if isinstance(catalog, str):
        catalog = pystac.read_file(catalog)
    elif not isinstance(catalog, pystac.Catalog):
        raise pystac.ExtensionTypeError(
            f"MLDatasetExtension does not apply to type '{type(catalog).__name__}'"
        )

    catalog_ml_dataset = MLDatasetExtension.ext(catalog, add_if_missing=True)

    # Set extension properties
    for key, value in kwargs.items():
        setattr(catalog_ml_dataset, key, value)

    # Make splits if needed
    if splits:
        catalog_ml_dataset.splits = splits_names  # Add the splits names to the catalog
        splits_collection = catalog.get_child(splits_collection_id)
        if splits_collection is None:
            raise ValueError(
                f"Collection '{splits_collection_id}' not found in the catalog; cannot make splits"
            )
        # Pass names/proportions under make_splits' own parameter names so the
        # caller's choices are actually used.
        make_splits(
            splits_collection,
            splits_names=splits_names,
            splits_proportions=split_proportions,
            **kwargs,
        )
    # Normalize the ref on the same folder
    catalog_ml_dataset.normalize_hrefs(root_href=dirname(catalog.get_self_href()))

    try:
        print("Validating and saving...")
        catalog_ml_dataset.validate()
        if not destination:
            destination = dirname(catalog.get_self_href())
        # Remove the old catalog and replace it with the new one. A fresh
        # destination does not exist yet, and rmtree would raise on it.
        if isdir(destination):
            rmtree(destination)
        catalog_ml_dataset.save(dest_href=destination)
        print("Success!")
    except STACValidationError:
        # Report the full traceback instead of propagating.
        traceback.print_exc()
|
448
|
+
|
449
|
+
|
450
|
+
def make_splits(
    labels_collection: Union[CollectionMLDatasetExtension, pystac.Collection, str],
    splits_names: Optional[List[str]] = ("Training", "Validation", "Test"),
    splits_proportions: Optional[List[int]] = (80, 10, 10),
    verbose: Optional[bool] = True,
    **kwargs,
) -> None:
    """
    Makes the splits of the labels collection.

    Items are shuffled and cut into consecutive chunks sized by
    ``splits_proportions``; chunk ``i`` is recorded under ``splits_names[i]``.
    Items left over after rounding down go to the last split with a non-zero
    proportion, so every item lands in some split.

    Args:
        labels_collection : The STAC Collection (or path to it) to make the splits on.
        splits_names : Names of the splits, matched positionally with the proportions.
        splits_proportions : Three non-negative percentages summing to 100. A
            third proportion of 0 means no third split is created.
        verbose : Whether to print the sizes of the splits.
        **kwargs : Ignored; accepted so callers can forward extra options.

    Raises:
        ValueError : If the proportions do not sum to 100 or any is negative.
    """
    if isinstance(labels_collection, str):
        labels_collection = pystac.read_file(labels_collection)

    train_size, test_size, val_size = splits_proportions
    proportions = (train_size, test_size, val_size)

    if train_size + test_size + val_size != 100:
        raise ValueError("The sum of the splits must be 100")
    if min(proportions) < 0:
        raise ValueError("The splits proportions must not be negative")

    # Get all items in the labels collection
    items = list(labels_collection.get_all_items())

    # Integer arithmetic avoids float truncation (int(29 / 100 * 100) == 28).
    length = len(items)
    sizes = [int(p * length // 100) for p in proportions]
    last_nonzero = max(i for i, p in enumerate(proportions) if p)
    sizes[last_nonzero] += length - sum(sizes)
    idx_train, idx_test, idx_val = sizes

    print("Generating splits...")
    if verbose:
        print(f"Total size: {length}")
        print(f"Train size: {idx_train}")
        print(f"Test size: {idx_test}")
        if val_size:
            print(f"Validation size: {idx_val}")

    # Make sure the items are shuffled
    random.shuffle(items)

    # Split the items
    train_items = items[:idx_train]
    test_items = items[idx_train : idx_train + idx_test]
    split_groups = [train_items, test_items]
    if val_size:
        split_groups.append(items[idx_train + idx_test : idx_train + idx_test + idx_val])

    # Create the splits in the collection
    labels_collection = MLDatasetExtension.ext(labels_collection, add_if_missing=True)
    for split_type, split_data in zip(splits_names, split_groups):
        labels_collection.create_and_add_split(split_data, split_type)

    print("Success on splits generation!")
|