eotdl 2024.10.7__py3-none-any.whl → 2025.3.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. eotdl/__init__.py +1 -1
  2. eotdl/access/search.py +0 -2
  3. eotdl/access/sentinelhub/parameters.py +1 -1
  4. eotdl/cli.py +2 -2
  5. eotdl/commands/datasets.py +28 -31
  6. eotdl/commands/models.py +27 -30
  7. eotdl/commands/stac.py +57 -0
  8. eotdl/curation/__init__.py +0 -8
  9. eotdl/curation/stac/__init__.py +1 -8
  10. eotdl/curation/stac/api.py +58 -0
  11. eotdl/curation/stac/stac.py +31 -341
  12. eotdl/datasets/__init__.py +1 -1
  13. eotdl/datasets/ingest.py +28 -159
  14. eotdl/datasets/retrieve.py +0 -9
  15. eotdl/datasets/stage.py +64 -0
  16. eotdl/files/__init__.py +0 -2
  17. eotdl/files/ingest.bck +178 -0
  18. eotdl/files/ingest.py +229 -164
  19. eotdl/{datasets → files}/metadata.py +16 -17
  20. eotdl/models/__init__.py +1 -1
  21. eotdl/models/ingest.py +28 -159
  22. eotdl/models/stage.py +60 -0
  23. eotdl/repos/APIRepo.py +1 -1
  24. eotdl/repos/DatasetsAPIRepo.py +56 -43
  25. eotdl/repos/FilesAPIRepo.py +260 -167
  26. eotdl/repos/STACAPIRepo.py +40 -0
  27. eotdl/repos/__init__.py +1 -0
  28. eotdl/tools/geo_utils.py +7 -2
  29. {eotdl-2024.10.7.dist-info → eotdl-2025.3.25.dist-info}/METADATA +5 -4
  30. eotdl-2025.3.25.dist-info/RECORD +65 -0
  31. {eotdl-2024.10.7.dist-info → eotdl-2025.3.25.dist-info}/WHEEL +1 -1
  32. eotdl/curation/stac/assets.py +0 -110
  33. eotdl/curation/stac/dataframe.py +0 -172
  34. eotdl/curation/stac/dataframe_bck.py +0 -253
  35. eotdl/curation/stac/dataframe_labeling.py +0 -63
  36. eotdl/curation/stac/extensions/__init__.py +0 -23
  37. eotdl/curation/stac/extensions/base.py +0 -30
  38. eotdl/curation/stac/extensions/dem.py +0 -18
  39. eotdl/curation/stac/extensions/eo.py +0 -117
  40. eotdl/curation/stac/extensions/label/__init__.py +0 -7
  41. eotdl/curation/stac/extensions/label/base.py +0 -136
  42. eotdl/curation/stac/extensions/label/image_name_labeler.py +0 -203
  43. eotdl/curation/stac/extensions/label/scaneo.py +0 -219
  44. eotdl/curation/stac/extensions/ml_dataset.py +0 -648
  45. eotdl/curation/stac/extensions/projection.py +0 -44
  46. eotdl/curation/stac/extensions/raster.py +0 -53
  47. eotdl/curation/stac/extensions/sar.py +0 -55
  48. eotdl/curation/stac/extent.py +0 -158
  49. eotdl/curation/stac/parsers.py +0 -61
  50. eotdl/datasets/download.py +0 -104
  51. eotdl/files/list_files.py +0 -13
  52. eotdl/models/download.py +0 -101
  53. eotdl/models/metadata.py +0 -43
  54. eotdl/wrappers/utils.py +0 -35
  55. eotdl-2024.10.7.dist-info/RECORD +0 -82
  56. {eotdl-2024.10.7.dist-info → eotdl-2025.3.25.dist-info}/entry_points.txt +0 -0
@@ -1,648 +0,0 @@
1
- """Implements the :stac-ext:`Machine Learning Dataset Extension <ml-dataset>`."""
2
-
3
- from typing import Any, Dict, List, Optional, Generic, TypeVar, Union, Set
4
- from shutil import rmtree
5
- from os.path import dirname, exists
6
-
7
- import traceback
8
- import json
9
- import random
10
-
11
- import pystac
12
-
13
- from tqdm import tqdm
14
- from pystac.extensions.base import ExtensionManagementMixin, PropertiesExtension
15
- from pystac.extensions.label import LabelExtension
16
- from pystac import STACValidationError
17
- from pystac.cache import ResolvedObjectCache
18
- from pystac.extensions.hooks import ExtensionHooks
19
- from ....tools import make_links_relative_to_path
20
-
21
- T = TypeVar("T", pystac.Item, pystac.Collection, pystac.Catalog)
22
-
23
-
24
- SCHEMA_URI: str = "https://raw.githubusercontent.com/earthpulse/ml-dataset/main/json-schema/schema.json"
25
- PREFIX: str = "ml-dataset:"
26
-
27
-
28
- class MLDatasetExtension(
29
- pystac.Catalog,
30
- Generic[T],
31
- PropertiesExtension,
32
- ExtensionManagementMixin[
33
- Union[pystac.item.Item, pystac.collection.Collection, pystac.catalog.Catalog]
34
- ],
35
- ):
36
- """An abstract class that can be used to extend the properties of a
37
- :class:`~pystac.Collection`, :class:`~pystac.Item`, or :class:`~pystac.Catalog` with
38
- properties from the :stac-ext:`Machine Learning Dataset Extension <ml-dataset>`. This class is
39
- generic over the type of STAC Object to be extended (e.g. :class:`~pystac.Item`,
40
- :class:`~pystac.Asset`).
41
-
42
- To create a concrete instance of :class:`MLDatasetExtension`, use the
43
- :meth:`MLDatasetExtension.ext` method. For example:
44
-
45
- .. code-block:: python
46
-
47
- >>> item: pystac.Item = ...
48
- >>> ml_ext = MLDatasetExtension.ext(item)
49
- """
50
-
51
- catalog: pystac.Catalog
52
- """The :class:`~pystac.Catalog` being extended."""
53
-
54
- properties: Dict[str, Any]
55
- """The :class:`~pystac.Catalog` extra fields, including extension properties."""
56
-
57
- links: List[pystac.Link]
58
- """The list of :class:`~pystac.Link` objects associated with the
59
- :class:`~pystac.Catalog` being extended, including links added by this extension.
60
- """
61
-
62
- def __init__(self, catalog: pystac.Catalog):
63
- super().__init__(id=catalog.id, description=catalog.description)
64
- self._catalog = catalog
65
- self.id = catalog.id
66
- self.description = catalog.description
67
- self.title = catalog.title if catalog.title else None
68
- self.stac_extensions = (
69
- catalog.stac_extensions if catalog.stac_extensions else []
70
- )
71
- self.extra_fields = self.properties = (
72
- catalog.extra_fields if catalog.extra_fields else {}
73
- )
74
- self.links = catalog.links
75
- self._resolved_objects = ResolvedObjectCache()
76
-
77
- def apply(self, name: str = None) -> None:
78
- """
79
- Applies the :stac-ext:`Machine Learning Dataset Extension <ml-dataset>` to the extended
80
- :class:`~pystac.Catalog`.
81
- """
82
- self.name = name
83
-
84
- @property
85
- def name(self) -> str:
86
- """
87
- Name of the ML Dataset.
88
- """
89
- return self.extra_fields[f"{PREFIX}name"]
90
-
91
- @name.setter
92
- def name(self, v: str) -> None:
93
- """
94
- Set the name of the ML Dataset.
95
- """
96
- self.extra_fields[f"{PREFIX}name"] = v
97
-
98
- @property
99
- def tasks(self) -> List:
100
- """
101
- Tasks of the ML Dataset.
102
- """
103
- return self.extra_fields[f"{PREFIX}tasks"]
104
-
105
- @tasks.setter
106
- def tasks(self, v: Union[list, tuple]) -> None:
107
- """
108
- Set the tasks of the ML Dataset.
109
- """
110
- self.extra_fields[f"{PREFIX}tasks"] = v
111
-
112
- @property
113
- def type(self) -> str:
114
- """
115
- Type of the ML Dataset.
116
- """
117
- return self.extra_fields[f"{PREFIX}type"]
118
-
119
- @type.setter
120
- def type(self, v: str) -> None:
121
- """
122
- Set the type of the ML Dataset.
123
- """
124
- self.extra_fields[f"{PREFIX}type"] = v
125
-
126
- @property
127
- def inputs_type(self) -> str:
128
- """
129
- Inputs type of the ML Dataset.
130
- """
131
- return self.extra_fields[f"{PREFIX}inputs-type"]
132
-
133
- @inputs_type.setter
134
- def inputs_type(self, v: str) -> None:
135
- """
136
- Set the inputs type of the ML Dataset.
137
- """
138
- self.extra_fields[f"{PREFIX}inputs-type"] = v
139
-
140
- @property
141
- def annotations_type(self) -> str:
142
- """
143
- Annotations type of the ML Dataset.
144
- """
145
- return self.extra_fields[f"{PREFIX}annotations-type"]
146
-
147
- @annotations_type.setter
148
- def annotations_type(self, v: str) -> None:
149
- """
150
- Set the annotations type of the ML Dataset.
151
- """
152
- self.extra_fields[f"{PREFIX}annotations-type"] = v
153
-
154
- @property
155
- def splits(self) -> List[str]:
156
- """
157
- Splits of the ML Dataset.
158
- """
159
- return self.extra_fields[f"{PREFIX}splits"]
160
-
161
- @splits.setter
162
- def splits(self, v: dict) -> None:
163
- """
164
- Set the splits of the ML Dataset.
165
- """
166
- self.extra_fields[f"{PREFIX}splits"] = v
167
-
168
- @property
169
- def quality_metrics(self) -> List[dict]:
170
- """
171
- Quality metrics of the ML Dataset.
172
- """
173
- return self.extra_fields[f"{PREFIX}quality-metrics"]
174
-
175
- @quality_metrics.setter
176
- def quality_metrics(self, v: dict) -> None:
177
- """
178
- Set the quality metrics of the ML Dataset.
179
- """
180
- self.extra_fields[f"{PREFIX}quality-metrics"] = v
181
-
182
- @property
183
- def version(self) -> str:
184
- """
185
- Version of the ML Dataset.
186
- """
187
- return self.extra_fields[f"{PREFIX}version"]
188
-
189
- @version.setter
190
- def version(self, v: str) -> None:
191
- """
192
- Set the version of the ML Dataset.
193
- """
194
- self.extra_fields[f"{PREFIX}version"] = v
195
-
196
- @classmethod
197
- def get_schema_uri(cls) -> str:
198
- """
199
- Get the JSON Schema URI that validates the extended object.
200
- """
201
- return SCHEMA_URI
202
-
203
- def add_metric(self, metric: dict) -> None:
204
- """Add a metric to this object's set of metrics.
205
-
206
- Args:
207
- metric : The metric to add.
208
- """
209
- if not self.extra_fields.get(f"{PREFIX}quality-metrics"):
210
- self.extra_fields[f"{PREFIX}quality-metrics"] = []
211
-
212
- if metric not in self.extra_fields[f"{PREFIX}quality-metrics"]:
213
- self.extra_fields[f"{PREFIX}quality-metrics"].append(metric)
214
-
215
- def add_metrics(self, metrics: List[dict]) -> None:
216
- """Add a list of metrics to this object's set of metrics.
217
-
218
- Args:
219
- metrics : The metrics to add.
220
- """
221
- for metric in metrics:
222
- self.add_metric(metric)
223
-
224
- @classmethod
225
- def ext(cls, obj: T, add_if_missing: bool = False):
226
- """Extends the given STAC Object with properties from the
227
- :stac-ext:`Machine Learning Dataset Extension <ml-dataset>`.
228
-
229
- This extension can be applied to instances of :class:`~pystac.Catalog`,
230
- :class:`~pystac.Collection` or :class:`~pystac.Item`.
231
-
232
- Raises:
233
- pystac.ExtensionTypeError : If an invalid object type is passed.
234
- """
235
- if isinstance(obj, pystac.Collection):
236
- cls.validate_has_extension(obj, add_if_missing)
237
- return CollectionMLDatasetExtension(obj)
238
- elif isinstance(obj, pystac.Catalog):
239
- cls.validate_has_extension(obj, add_if_missing)
240
- return MLDatasetExtension(obj)
241
- elif isinstance(obj, pystac.Item):
242
- cls.validate_has_extension(obj, add_if_missing)
243
- return ItemMLDatasetExtension(obj)
244
- else:
245
- raise pystac.ExtensionTypeError(cls._ext_error_message(obj))
246
-
247
-
248
- class CollectionMLDatasetExtension(MLDatasetExtension[pystac.Collection]):
249
- """A concrete implementation of :class:`MLDatasetExtension` on an
250
- :class:`~pystac.Collection` that extends the properties of the Collection to include
251
- properties defined in the :stac-ext:`Machine Learning Dataset Extension <ml-dataset>`.
252
- """
253
-
254
- collection: pystac.Collection
255
- properties: Dict[str, Any]
256
-
257
- def __init__(self, collection: pystac.Collection):
258
- self.collection = collection
259
- self.properties = collection.extra_fields
260
- self.properties[f"{PREFIX}split-items"] = []
261
-
262
- def __repr__(self) -> str:
263
- return f"<CollectionMLDatasetExtension Item id={self.collection.id}>"
264
-
265
- @property
266
- def splits(self) -> List[dict]:
267
- """
268
- Splits of the ML Dataset.
269
- """
270
- return self.extra_fields[f"{PREFIX}splits"]
271
-
272
- @splits.setter
273
- def splits(self, v: dict) -> None:
274
- """
275
- Set the splits of the ML Dataset.
276
- """
277
- self.properties[f"{PREFIX}split-items"] = v
278
-
279
- def add_split(self, v: dict) -> None:
280
- """
281
- Add a split to the ML Dataset.
282
- """
283
- self.properties[f"{PREFIX}split-items"].append(v)
284
-
285
- def create_and_add_split(
286
- self, split_data: List[pystac.Item], split_type: str
287
- ) -> None:
288
- """
289
- Create and add a split to the ML Dataset.
290
- """
291
- items_ids = [item.id for item in split_data]
292
- items_ids.sort()
293
-
294
- split = {"name": split_type, "items": items_ids}
295
-
296
- if not self.properties.get(f"{PREFIX}split-items"):
297
- self.properties[f"{PREFIX}split-items"] = []
298
-
299
- self.add_split(split)
300
- print(f"Generating {split_type} split...")
301
- for _item in tqdm(split_data):
302
- item = self.collection.get_item(_item.id)
303
- if item:
304
- item_ml = MLDatasetExtension.ext(item, add_if_missing=True)
305
- item_ml.split = split_type
306
-
307
-
308
- class ItemMLDatasetExtension(MLDatasetExtension[pystac.Item]):
309
- """A concrete implementation of :class:`MLDatasetExtension` on an
310
- :class:`~pystac.Item` that extends the properties of the Item to include properties
311
- defined in the :stac-ext:`Machine Learning Dataset Extension <ml-dataset>`.
312
-
313
- This class should generally not be instantiated directly. Instead, call
314
- :meth:`MLDatasetExtension.ext` on an :class:`~pystac.Item` to extend it.
315
- """
316
-
317
- item: pystac.Item
318
- properties: Dict[str, Any]
319
-
320
- def __init__(self, item: pystac.Item):
321
- self.item = item
322
- self.properties = item.properties
323
-
324
- @property
325
- def split(self) -> str:
326
- """
327
- Split of the ML Dataset.
328
- """
329
- return self.properties[f"{PREFIX}split"]
330
-
331
- @split.setter
332
- def split(self, v: str) -> None:
333
- """
334
- Set the split of the ML Dataset.
335
- """
336
- self.properties[f"{PREFIX}split"] = v
337
-
338
- def __repr__(self) -> str:
339
- return f"<ItemMLDatasetExtension Item id={self.item.id}>"
340
-
341
-
342
- class MLDatasetQualityMetrics:
343
- """
344
- ML Dataset Quality Metrics
345
- """
346
-
347
- @classmethod
348
- def calculate(cls, catalog: Union[pystac.Catalog, str]) -> None:
349
- """
350
- Calculate the quality metrics of the catalog
351
- """
352
- if isinstance(catalog, str):
353
- catalog = MLDatasetExtension(pystac.read_file(catalog))
354
- elif isinstance(catalog, pystac.Catalog):
355
- catalog = MLDatasetExtension(catalog)
356
- # Check the catalog has the extension
357
- if not MLDatasetExtension.has_extension(catalog):
358
- raise pystac.ExtensionNotImplemented(
359
- f"MLDatasetExtension does not apply to type '{type(catalog).__name__}'"
360
- )
361
-
362
- try:
363
- catalog.add_metric(cls._search_spatial_duplicates(catalog))
364
- catalog.add_metric(cls._get_classes_balance(catalog))
365
- except AttributeError as exc:
366
- raise pystac.ExtensionNotImplemented(
367
- f"The catalog does not have the required properties or the ML-Dataset extension to calculate the metrics: {exc}"
368
- )
369
- finally:
370
- catalog.make_all_asset_hrefs_relative()
371
-
372
- try:
373
- print("Validating and saving...")
374
- catalog.validate()
375
- destination = dirname(catalog.get_self_href())
376
- rmtree(
377
- destination
378
- ) # Remove the old catalog and replace it with the new one
379
- catalog.set_root(catalog)
380
- catalog.normalize_and_save(root_href=destination)
381
- print("Success!")
382
- except STACValidationError:
383
- # Return full callback
384
- traceback.print_exc()
385
-
386
- @staticmethod
387
- def _search_spatial_duplicates(catalog: pystac.Catalog):
388
- """
389
- Search for spatial duplicates in the catalog
390
- """
391
- items = list(
392
- set(
393
- [
394
- item
395
- for item in tqdm(
396
- catalog.get_items(recursive=True),
397
- desc="Looking for spatial duplicates...",
398
- )
399
- if not LabelExtension.has_extension(item)
400
- ]
401
- )
402
- )
403
-
404
- # Initialize the spatial duplicates dict
405
- spatial_duplicates = {"name": "spatial-duplicates", "values": [], "total": 0}
406
-
407
- items_bboxes = {}
408
- for item in items:
409
- # Get the item bounding box
410
- bbox = str(item.bbox)
411
- # If the bounding box is not in the items dict, add it
412
- if bbox not in items_bboxes.keys():
413
- items_bboxes[bbox] = item.id
414
- # If the bounding box is already in the items dict, add it to the duplicates dict
415
- else:
416
- spatial_duplicates["values"].append(
417
- {"item": item.id, "duplicate": items_bboxes[bbox]}
418
- )
419
- spatial_duplicates["total"] += 1
420
-
421
- return spatial_duplicates
422
-
423
- @staticmethod
424
- def _get_classes_balance(catalog: pystac.Catalog) -> dict:
425
- """
426
- Get the classes balance of the catalog
427
- """
428
-
429
- def get_label_properties(items: List[pystac.Item]) -> List:
430
- """
431
- Get the label properties of the catalog
432
- """
433
- label_properties = []
434
- for label in items:
435
- label_ext = LabelExtension.ext(label)
436
- for prop in label_ext.label_properties:
437
- if prop not in label_properties:
438
- label_properties.append(prop)
439
-
440
- return label_properties
441
-
442
- catalog.make_all_asset_hrefs_absolute()
443
-
444
- labels = list(
445
- set(
446
- [
447
- item
448
- for item in tqdm(
449
- catalog.get_items(recursive=True),
450
- desc="Calculating classes balance...",
451
- )
452
- if LabelExtension.has_extension(item)
453
- ]
454
- )
455
- )
456
-
457
- # Initialize the classes balance dict
458
- classes_balance = {"name": "classes-balance", "values": []}
459
- label_properties = get_label_properties(labels)
460
-
461
- for prop in label_properties:
462
- property_balance = {"name": prop, "values": []}
463
- properties = {}
464
- for label in labels:
465
- if 'labels' not in label.assets:
466
- continue
467
- asset_path = label.assets["labels"].href
468
- # Open the linked geoJSON to obtain the label properties
469
- try:
470
- with open(asset_path, mode="r", encoding="utf-8") as f:
471
- label_data = json.load(f)
472
- except FileNotFoundError:
473
- raise FileNotFoundError(
474
- f"The file {asset_path} does not exist. Make sure the assets hrefs are correct"
475
- )
476
- # Get the property
477
- for feature in label_data["features"]:
478
- if prop in feature["properties"]:
479
- property_value = feature["properties"][prop]
480
- else:
481
- if feature["properties"]["labels"]:
482
- property_value = feature["properties"]["labels"][0]
483
- else:
484
- continue
485
- if property_value not in properties:
486
- properties[property_value] = 0
487
- properties[property_value] += 1
488
-
489
- # Create the property balance dict
490
- total_labels = sum(properties.values())
491
- for key, value in properties.items():
492
- property_balance["values"].append(
493
- {
494
- "class": key,
495
- "total": value,
496
- "percentage": int(value / total_labels * 100),
497
- }
498
- )
499
-
500
- classes_balance["values"].append(property_balance)
501
-
502
- catalog.make_all_asset_hrefs_relative()
503
-
504
- return classes_balance
505
-
506
-
507
- class MLDatasetExtensionHooks(ExtensionHooks):
508
- """
509
- ML Dataset Extension Hooks
510
- """
511
- schema_uri: str = SCHEMA_URI
512
- prev_extension_ids: Set[str] = set()
513
- stac_object_types = {
514
- pystac.STACObjectType.CATALOG,
515
- pystac.STACObjectType.COLLECTION,
516
- pystac.STACObjectType.ITEM,
517
- }
518
-
519
-
520
- STORAGE_EXTENSION_HOOKS: ExtensionHooks = MLDatasetExtensionHooks()
521
-
522
-
523
- def add_ml_extension(
524
- catalog: Union[pystac.Catalog, str],
525
- destination: Optional[str] = None,
526
- splits: Optional[bool] = False,
527
- splits_collection_id: Optional[str] = "labels",
528
- splits_names: Optional[list] = ("Training", "Validation", "Test"),
529
- split_proportions: Optional[List[int]] = (80, 10, 10),
530
- catalog_type: Optional[pystac.CatalogType] = pystac.CatalogType.SELF_CONTAINED,
531
- **kwargs,
532
- ) -> None:
533
- """
534
- Adds the ML Dataset extension to a STAC catalog.
535
- """
536
- if not isinstance(catalog, pystac.Catalog) and isinstance(catalog, str):
537
- catalog = pystac.read_file(catalog)
538
- elif isinstance(catalog, pystac.Catalog):
539
- pass
540
- else:
541
- raise pystac.ExtensionTypeError(
542
- f"MLDatasetExtension does not apply to type '{type(catalog).__name__}'"
543
- )
544
-
545
- catalog_ml_dataset = MLDatasetExtension.ext(catalog, add_if_missing=True)
546
- if destination:
547
- catalog_ml_dataset.set_self_href(destination + "/catalog.json")
548
- else:
549
- destination = dirname(catalog.get_self_href())
550
- catalog_ml_dataset.set_root(catalog_ml_dataset)
551
-
552
- # Set extension properties
553
- for key, value in kwargs.items():
554
- setattr(catalog_ml_dataset, key, value)
555
-
556
- # Make splits if needed
557
- if splits:
558
- catalog_ml_dataset.splits = splits_names # Add the splits names to the catalog
559
- train_size, test_size, val_size = split_proportions
560
- splits_collection = catalog.get_child(
561
- splits_collection_id
562
- ) # Get the collection to split
563
- if not splits_collection:
564
- raise AttributeError(
565
- f"The catalog does not have a collection with the id {splits_collection_id}"
566
- )
567
- make_splits(
568
- splits_collection,
569
- train_size=train_size,
570
- test_size=test_size,
571
- val_size=val_size,
572
- **kwargs,
573
- )
574
-
575
- # Normalize the ref on the same folder
576
- if destination:
577
- catalog_ml_dataset = make_links_relative_to_path(
578
- destination, catalog_ml_dataset
579
- )
580
-
581
- try:
582
- print("Validating and saving...")
583
- catalog_ml_dataset.validate()
584
- rmtree(destination) if exists(
585
- destination
586
- ) else None # Remove the old catalog and replace it with the new one
587
- catalog_ml_dataset.normalize_and_save(
588
- root_href=destination, catalog_type=catalog_type
589
- )
590
- print("Success!")
591
- except STACValidationError:
592
- # Return full callback
593
- traceback.print_exc()
594
-
595
-
596
- def make_splits(
597
- labels_collection: Union[CollectionMLDatasetExtension, pystac.Collection, str],
598
- splits_names: Optional[List[str]] = ("Training", "Validation", "Test"),
599
- splits_proportions: Optional[List[int]] = (80, 10, 10),
600
- verbose: Optional[bool] = True,
601
- **kwargs: Optional[dict],
602
- ) -> None:
603
- """
604
- Makes the splits of the labels collection.
605
- """
606
- if isinstance(labels_collection, str):
607
- labels_collection = pystac.read_file(labels_collection)
608
-
609
- train_size, test_size, val_size = splits_proportions
610
-
611
- if train_size + test_size + val_size != 100:
612
- raise ValueError("The sum of the splits must be 100")
613
-
614
- # Get all items in the labels collection
615
- items = list(labels_collection.get_items(recursive=True))
616
-
617
- # Calculate indices to split the items
618
- length = len(items)
619
- idx_train = int(train_size / 100 * length)
620
- idx_test = int(test_size / 100 * length)
621
- if val_size:
622
- idx_val = int(val_size / 100 * length)
623
-
624
- print("Generating splits...")
625
- if verbose:
626
- print(f"Total size: {length}")
627
- print(f"Train size: {idx_train}")
628
- print(f"Test size: {idx_test}")
629
- if val_size:
630
- print(f"Validation size: {idx_val}")
631
-
632
- # Make sure the items are shuffled
633
- random.shuffle(items)
634
-
635
- # Split the items
636
- train_items = items[:idx_train]
637
- test_items = items[idx_train: idx_train + idx_test]
638
- if val_size:
639
- val_items = items[idx_train + idx_test: idx_train + idx_test + idx_val]
640
-
641
- # Create the splits in the collection
642
- labels_collection = MLDatasetExtension.ext(labels_collection, add_if_missing=True)
643
- for split_type, split_data in zip(
644
- splits_names, [train_items, test_items, val_items]
645
- ):
646
- labels_collection.create_and_add_split(split_data, split_type)
647
-
648
- print("Success on splits generation!")
@@ -1,44 +0,0 @@
1
- """
2
- Module for projection STAC extensions object
3
- """
4
-
5
- from typing import Union
6
-
7
- import pystac
8
- import pandas as pd
9
- import rasterio
10
-
11
- from pystac.extensions.projection import ProjectionExtension
12
- from .base import STACExtensionObject
13
-
14
-
15
- class ProjExtensionObject(STACExtensionObject):
16
- """
17
- Projection extension object
18
- """
19
- def __init__(self) -> None:
20
- super().__init__()
21
-
22
- def add_extension_to_object(
23
- self, obj: Union[pystac.Item, pystac.Asset], obj_info: pd.DataFrame
24
- ) -> Union[pystac.Item, pystac.Asset]:
25
- """
26
- Add the extension to the given object
27
-
28
- :param obj: object to add the extension
29
- :param obj_info: object info from the STACDataFrame
30
- """
31
- # Add raster extension to the item
32
- if isinstance(obj, pystac.Asset):
33
- return obj
34
- elif isinstance(obj, pystac.Item):
35
- proj_ext = ProjectionExtension.ext(obj, add_if_missing=True)
36
- ds = rasterio.open(obj_info["image"].values[0])
37
- # Assume all the bands have the same projection
38
- proj_ext.apply(
39
- epsg=ds.crs.to_epsg(),
40
- transform=ds.transform,
41
- shape=ds.shape,
42
- )
43
-
44
- return obj