ssb-sgis 1.0.5__py3-none-any.whl → 1.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,6 +6,7 @@ import math
6
6
  import os
7
7
  import random
8
8
  import re
9
+ import time
9
10
  from collections.abc import Callable
10
11
  from collections.abc import Iterable
11
12
  from collections.abc import Iterator
@@ -27,8 +28,6 @@ from geopandas import GeoDataFrame
27
28
  from geopandas import GeoSeries
28
29
  from matplotlib.colors import LinearSegmentedColormap
29
30
  from rasterio.enums import MergeAlg
30
- from rtree.index import Index
31
- from rtree.index import Property
32
31
  from scipy import stats
33
32
  from scipy.ndimage import binary_dilation
34
33
  from scipy.ndimage import binary_erosion
@@ -49,24 +48,15 @@ except ImportError:
49
48
 
50
49
 
51
50
  try:
52
- from rioxarray.exceptions import NoDataInBounds
53
- from rioxarray.merge import merge_arrays
54
- from rioxarray.rioxarray import _generate_spatial_coords
55
- except ImportError:
56
- pass
57
- try:
58
- import xarray as xr
59
- from xarray import DataArray
51
+ from google.auth import exceptions
60
52
  except ImportError:
61
53
 
62
- class DataArray:
54
+ class exceptions:
63
55
  """Placeholder."""
64
56
 
57
+ class RefreshError:
58
+ """Placeholder."""
65
59
 
66
- try:
67
- import torch
68
- except ImportError:
69
- pass
70
60
 
71
61
  try:
72
62
  from gcsfs.core import GCSFile
@@ -77,26 +67,22 @@ except ImportError:
77
67
 
78
68
 
79
69
  try:
80
- from torchgeo.datasets.utils import disambiguate_timestamp
70
+ from rioxarray.exceptions import NoDataInBounds
71
+ from rioxarray.merge import merge_arrays
72
+ from rioxarray.rioxarray import _generate_spatial_coords
81
73
  except ImportError:
82
-
83
- class torch:
84
- """Placeholder."""
85
-
86
- class Tensor:
87
- """Placeholder to reference torch.Tensor."""
88
-
89
-
74
+ pass
90
75
  try:
91
- from torchgeo.datasets.utils import BoundingBox
76
+ import xarray as xr
77
+ from xarray import DataArray
78
+ from xarray import Dataset
92
79
  except ImportError:
93
80
 
94
- class BoundingBox:
81
+ class DataArray:
95
82
  """Placeholder."""
96
83
 
97
- def __init__(self, *args, **kwargs) -> None:
98
- """Placeholder."""
99
- raise ImportError("missing optional dependency 'torchgeo'")
84
+ class Dataset:
85
+ """Placeholder."""
100
86
 
101
87
 
102
88
  from ..geopandas_tools.bounds import get_total_bounds
@@ -115,6 +101,12 @@ from .base import _get_shape_from_bounds
115
101
  from .base import _get_transform_from_bounds
116
102
  from .base import get_index_mapper
117
103
  from .indices import ndvi
104
+ from .regex import _any_regex_matches
105
+ from .regex import _extract_regex_match_from_string
106
+ from .regex import _get_first_group_match
107
+ from .regex import _get_non_optional_groups
108
+ from .regex import _get_regexes_matches_for_df
109
+ from .regex import _RegexError
118
110
  from .zonal import _aggregate
119
111
  from .zonal import _make_geometry_iterrows
120
112
  from .zonal import _no_overlap_df
@@ -132,9 +124,6 @@ if is_dapla():
132
124
  def _open_func(*args, **kwargs) -> GCSFile:
133
125
  return dp.FileClient.get_gcs_file_system().open(*args, **kwargs)
134
126
 
135
- def _rm_file_func(*args, **kwargs) -> None:
136
- return dp.FileClient.get_gcs_file_system().rm_file(*args, **kwargs)
137
-
138
127
  def _read_parquet_func(*args, **kwargs) -> list[str]:
139
128
  return dp.read_pandas(*args, **kwargs)
140
129
 
@@ -142,22 +131,25 @@ else:
142
131
  _ls_func = functools.partial(get_all_files, recursive=False)
143
132
  _open_func = open
144
133
  _glob_func = glob.glob
145
- _rm_file_func = os.remove
146
134
  _read_parquet_func = pd.read_parquet
147
135
 
148
- TORCHGEO_RETURN_TYPE = dict[str, torch.Tensor | pyproj.CRS | BoundingBox]
136
+ DATE_RANGES_TYPE = (
137
+ tuple[str | pd.Timestamp | None, str | pd.Timestamp | None]
138
+ | tuple[tuple[str | pd.Timestamp | None, str | pd.Timestamp | None], ...]
139
+ )
140
+
149
141
  FILENAME_COL_SUFFIX = "_filename"
142
+
150
143
  DEFAULT_FILENAME_REGEX = r"""
151
144
  .*?
152
- (?:_(?P<date>\d{8}(?:T\d{6})?))? # Optional date group
145
+ (?:_?(?P<date>\d{8}(?:T\d{6})?))? # Optional underscore and date group
153
146
  .*?
154
- (?:_(?P<band>B\d{1,2}A|B\d{1,2}))? # Optional band group
147
+ (?:_?(?P<band>B\d{1,2}A|B\d{1,2}))? # Optional underscore and band group
155
148
  \.(?:tif|tiff|jp2)$ # End with .tif, .tiff, or .jp2
156
149
  """
157
150
  DEFAULT_IMAGE_REGEX = r"""
158
151
  .*?
159
- (?:_(?P<date>\d{8}(?:T\d{6})?))? # Optional date group
160
- (?:_(?P<band>B\d{1,2}A|B\d{1,2}))? # Optional band group
152
+ (?:_?(?P<date>\d{8}(?:T\d{6})?))? # Optional underscore and date group
161
153
  """
162
154
 
163
155
  ALLOWED_INIT_KWARGS = [
@@ -165,13 +157,12 @@ ALLOWED_INIT_KWARGS = [
165
157
  "band_class",
166
158
  "image_regexes",
167
159
  "filename_regexes",
168
- "date_format",
169
- "cloud_cover_regexes",
170
160
  "bounds_regexes",
171
161
  "all_bands",
172
162
  "crs",
173
163
  "masking",
174
164
  "_merged",
165
+ "_add_metadata_attributes",
175
166
  ]
176
167
 
177
168
 
@@ -293,9 +284,38 @@ class ImageCollectionGroupBy:
293
284
  return f"{self.__class__.__name__}({len(self)})"
294
285
 
295
286
 
287
+ def standardize_band_id(x: str) -> str:
288
+ return x.replace("B", "").replace("A", "").zfill(2)
289
+
290
+
291
+ class BandIdDict(dict):
292
+ """Dict that tells the band initialiser to get the dict value of the band_id."""
293
+
294
+ def __init__(self, data: dict | None = None, **kwargs) -> None:
295
+ """Add dicts or kwargs."""
296
+ self._standardized_keys = {}
297
+ for key, value in ((data or {}) | kwargs).items():
298
+ setattr(self, key, value)
299
+ self._standardized_keys[standardize_band_id(key)] = value
300
+
301
+ def __len__(self) -> int:
302
+ """Number of items."""
303
+ return len({key for key in self.__dict__ if key != "_standardized_keys"})
304
+
305
+ def __getitem__(self, item: str) -> Any:
306
+ """Get dict value from key."""
307
+ try:
308
+ return getattr(self, item)
309
+ except AttributeError as e:
310
+ try:
311
+ return self._standardized_keys[standardize_band_id(item)]
312
+ except KeyError:
313
+ raise KeyError(item, self.__dict__) from e
314
+
315
+
296
316
  @dataclass(frozen=True)
297
317
  class BandMasking:
298
- """Basically a frozen dict with forced keys."""
318
+ """Frozen dict with forced keys."""
299
319
 
300
320
  band_id: str
301
321
  values: tuple[int]
@@ -305,19 +325,27 @@ class BandMasking:
305
325
  return getattr(self, item)
306
326
 
307
327
 
328
+ class NoLevel:
329
+ """Equivalent to None."""
330
+
331
+
308
332
  class _ImageBase:
309
333
  image_regexes: ClassVar[str | None] = (DEFAULT_IMAGE_REGEX,)
310
334
  filename_regexes: ClassVar[str | tuple[str]] = (DEFAULT_FILENAME_REGEX,)
311
- date_format: ClassVar[str] = "%Y%m%d" # T%H%M%S"
335
+ metadata_attributes: ClassVar[dict | None] = None
312
336
  masking: ClassVar[BandMasking | None] = None
313
337
 
314
- def __init__(self, **kwargs) -> None:
338
+ def __init__(self, *, bbox=None, **kwargs) -> None:
315
339
 
316
340
  self._mask = None
317
341
  self._bounds = None
318
342
  self._merged = False
319
343
  self._from_array = False
320
344
  self._from_gdf = False
345
+ self.metadata_attributes = self.metadata_attributes or {}
346
+ self._path = None
347
+
348
+ self._bbox = to_bbox(bbox) if bbox is not None else None
321
349
 
322
350
  if self.filename_regexes:
323
351
  if isinstance(self.filename_regexes, str):
@@ -381,7 +409,6 @@ class _ImageBase:
381
409
  for pat in patterns:
382
410
  try:
383
411
  return _get_first_group_match(pat, self.name)[group]
384
- return re.match(pat, self.name).group(group)
385
412
  except (TypeError, KeyError):
386
413
  pass
387
414
  if not any(group in _get_non_optional_groups(pat) for pat in patterns):
@@ -394,18 +421,18 @@ class _ImageBase:
394
421
  """Create a dataframe with file paths and image paths that match regexes."""
395
422
  df = pd.DataFrame({"file_path": file_paths})
396
423
 
397
- df["filename"] = df["file_path"].apply(lambda x: _fix_path(Path(x).name))
424
+ df["file_path"] = df["file_path"].apply(_fix_path)
425
+ df["filename"] = df["file_path"].apply(lambda x: Path(x).name)
398
426
 
399
- if not self.single_banded:
400
- df["image_path"] = df["file_path"].apply(
401
- lambda x: _fix_path(str(Path(x).parent))
402
- )
403
- else:
404
- df["image_path"] = df["file_path"]
427
+ df["image_path"] = df["file_path"].apply(
428
+ lambda x: _fix_path(str(Path(x).parent))
429
+ )
405
430
 
406
431
  if not len(df):
407
432
  return df
408
433
 
434
+ df = df[~df["file_path"].isin(df["image_path"])]
435
+
409
436
  if self.filename_patterns:
410
437
  df = _get_regexes_matches_for_df(df, "filename", self.filename_patterns)
411
438
 
@@ -446,8 +473,19 @@ class _ImageBase:
446
473
  continue
447
474
  return copied
448
475
 
476
+ def equals(self, other) -> bool:
477
+ for key, value in self.__dict__.items():
478
+ if key.startswith("_"):
479
+ continue
480
+ if value != getattr(other, key):
481
+ print(key, value, getattr(other, key))
482
+ return False
483
+ return True
484
+
449
485
 
450
486
  class _ImageBandBase(_ImageBase):
487
+ """Common parent class of Image and Band."""
488
+
451
489
  def intersects(self, other: GeoDataFrame | GeoSeries | Geometry) -> bool:
452
490
  if hasattr(other, "crs") and not pyproj.CRS(self.crs).equals(
453
491
  pyproj.CRS(other.crs)
@@ -455,6 +493,12 @@ class _ImageBandBase(_ImageBase):
455
493
  raise ValueError(f"crs mismatch: {self.crs} and {other.crs}")
456
494
  return self.union_all().intersects(to_shapely(other))
457
495
 
496
+ def union_all(self) -> Polygon:
497
+ try:
498
+ return box(*self.bounds)
499
+ except TypeError:
500
+ return Polygon()
501
+
458
502
  @property
459
503
  def mask_percentage(self) -> float:
460
504
  return self.mask.values.sum() / (self.mask.width * self.mask.height) * 100
@@ -495,31 +539,57 @@ class _ImageBandBase(_ImageBase):
495
539
  def level(self) -> str:
496
540
  return self._name_regex_searcher("level", self.image_patterns)
497
541
 
498
- @property
499
- def mint(self) -> float:
500
- return disambiguate_timestamp(self.date, self.date_format)[0]
542
+ def _add_metadata_attributes(self):
501
543
 
502
- @property
503
- def maxt(self) -> float:
504
- return disambiguate_timestamp(self.date, self.date_format)[1]
544
+ missing_attributes = {}
545
+ for key, value in self.metadata_attributes.items():
546
+ if getattr(self, key) is None:
547
+ missing_attributes[key] = value
505
548
 
506
- def union_all(self) -> Polygon:
507
- try:
508
- return box(*self.bounds)
509
- except TypeError:
510
- return Polygon()
549
+ if not missing_attributes:
550
+ return
511
551
 
512
- @property
513
- def torch_bbox(self) -> BoundingBox:
514
- bounds = GeoSeries([self.union_all()]).bounds
515
- return BoundingBox(
516
- minx=bounds.minx[0],
517
- miny=bounds.miny[0],
518
- maxx=bounds.maxx[0],
519
- maxy=bounds.maxy[0],
520
- mint=self.mint,
521
- maxt=self.maxt,
522
- )
552
+ file_contents: list[str] = []
553
+ for path in self._all_file_paths:
554
+ if ".xml" not in path:
555
+ continue
556
+ with _open_func(path, "rb") as file:
557
+ file_contents.append(file.read().decode("utf-8"))
558
+
559
+ for key, value in missing_attributes.items():
560
+ results = None
561
+ for i, filetext in enumerate(file_contents):
562
+ if isinstance(value, str) and value in dir(self):
563
+ method = getattr(self, value)
564
+ try:
565
+ results = method(filetext)
566
+ except _RegexError as e:
567
+ if i == len(self._all_file_paths) - 1:
568
+ raise e
569
+ continue
570
+ if results is not None:
571
+ break
572
+
573
+ if callable(value):
574
+ try:
575
+ results = value(filetext)
576
+ except _RegexError as e:
577
+ if i == len(self._all_file_paths) - 1:
578
+ raise e
579
+ continue
580
+ if results is not None:
581
+ break
582
+
583
+ try:
584
+ results = _extract_regex_match_from_string(filetext, value)
585
+ except _RegexError as e:
586
+ if i == len(self._all_file_paths) - 1:
587
+ raise e
588
+
589
+ if isinstance(results, BandIdDict) and isinstance(self, Band):
590
+ results = results[self.band_id]
591
+
592
+ setattr(self, key, results)
523
593
 
524
594
 
525
595
  class Band(_ImageBandBase):
@@ -561,28 +631,36 @@ class Band(_ImageBandBase):
561
631
  res: int | None,
562
632
  crs: Any | None = None,
563
633
  bounds: tuple[float, float, float, float] | None = None,
564
- cmap: str | None = None,
565
- name: str | None = None,
634
+ nodata: int | None = None,
635
+ mask: "Band | None" = None,
566
636
  file_system: GCSFileSystem | None = None,
567
- band_id: str | None = None,
568
637
  processes: int = 1,
569
- bbox: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None = None,
570
- mask: "Band | None" = None,
571
- nodata: int | None = None,
638
+ name: str | None = None,
639
+ band_id: str | None = None,
640
+ cmap: str | None = None,
641
+ all_file_paths: list[str] | None = None,
572
642
  **kwargs,
573
643
  ) -> None:
574
644
  """Band initialiser."""
575
645
  super().__init__(**kwargs)
576
646
 
647
+ if isinstance(data, (str | Path | os.PathLike)) and any(
648
+ arg is not None for arg in [crs, bounds]
649
+ ):
650
+ raise ValueError("Can only specify 'bounds' and 'crs' if data is an array.")
651
+
577
652
  self._mask = mask
578
- self._bbox = to_bbox(bbox) if bbox is not None else None
579
653
  self._values = None
580
- self._crs = None
581
654
  self.nodata = nodata
582
-
655
+ self._crs = crs
583
656
  bounds = to_bbox(bounds) if bounds is not None else None
584
-
585
657
  self._bounds = bounds
658
+ self._all_file_paths = all_file_paths
659
+
660
+ self._image = None
661
+
662
+ for key in self.metadata_attributes:
663
+ setattr(self, key, None)
586
664
 
587
665
  if isinstance(data, np.ndarray):
588
666
  self.values = data
@@ -610,19 +688,34 @@ class Band(_ImageBandBase):
610
688
  self._band_id = band_id
611
689
  self.processes = processes
612
690
 
613
- # if self.filename_regexes:
614
- # if isinstance(self.filename_regexes, str):
615
- # self.filename_regexes = [self.filename_regexes]
616
- # self.filename_patterns = [
617
- # re.compile(pat, flags=re.VERBOSE) for pat in self.filename_regexes
618
- # ]
619
- # else:
620
- # self.filename_patterns = None
691
+ if (
692
+ kwargs.get("_add_metadata_attributes", True)
693
+ and self.metadata_attributes
694
+ and self.path is not None
695
+ ):
696
+ if self._all_file_paths is None:
697
+ self._all_file_paths = _get_all_file_paths(str(Path(self.path).parent))
698
+ self._add_metadata_attributes()
621
699
 
622
700
  def __lt__(self, other: "Band") -> bool:
623
701
  """Makes Bands sortable by band_id."""
624
702
  return self.band_id < other.band_id
625
703
 
704
+ # def __getattribute__(self, attr: str) -> Any:
705
+ # # try:
706
+ # # value =
707
+ # # except AttributeError:
708
+ # # value = None
709
+
710
+ # if (
711
+ # attr in (super().__getattribute__("metadata_attributes") or {})
712
+ # and super().__getattribute__(attr) is None
713
+ # ):
714
+ # if self._all_file_paths is None:
715
+ # self._all_file_paths = _get_all_file_paths(str(Path(self.path).parent))
716
+ # self._add_metadata_attributes()
717
+ # return super().__getattribute__(attr)
718
+
626
719
  @property
627
720
  def values(self) -> np.ndarray:
628
721
  """The numpy array, if loaded."""
@@ -688,24 +781,22 @@ class Band(_ImageBandBase):
688
781
  @property
689
782
  def crs(self) -> str | None:
690
783
  """Coordinate reference system."""
691
- if self._crs is not None:
692
- return self._crs
693
- with opener(self.path, file_system=self.file_system) as file:
694
- with rasterio.open(file) as src:
695
- # self._bounds = to_bbox(src.bounds)
696
- self._crs = src.crs
784
+ if self._crs is None:
785
+ self._add_crs_and_bounds()
697
786
  return self._crs
698
787
 
699
788
  @property
700
789
  def bounds(self) -> tuple[int, int, int, int] | None:
701
790
  """Bounds as tuple (minx, miny, maxx, maxy)."""
702
- if self._bounds is not None:
703
- return self._bounds
791
+ if self._bounds is None:
792
+ self._add_crs_and_bounds()
793
+ return self._bounds
794
+
795
+ def _add_crs_and_bounds(self) -> None:
704
796
  with opener(self.path, file_system=self.file_system) as file:
705
797
  with rasterio.open(file) as src:
706
798
  self._bounds = to_bbox(src.bounds)
707
799
  self._crs = src.crs
708
- return self._bounds
709
800
 
710
801
  def get_n_largest(
711
802
  self, n: int, precision: float = 0.000001, column: str = "value"
@@ -745,44 +836,29 @@ class Band(_ImageBandBase):
745
836
 
746
837
  bounds_was_none = bounds is None
747
838
 
748
- try:
749
- if not isinstance(self.values, np.ndarray):
750
- raise ValueError()
751
- has_array = True
752
- except ValueError: # also catches ArrayNotLoadedError
753
- has_array = False
754
-
755
- # get common bounds of function argument 'bounds' and previously set bbox
756
- if bounds is None and self._bbox is None:
757
- bounds = None
758
- elif bounds is not None and self._bbox is None:
759
- bounds = to_shapely(bounds).intersection(self.union_all())
760
- elif bounds is None and self._bbox is not None:
761
- bounds = to_shapely(self._bbox).intersection(self.union_all())
762
- else:
763
- bounds = to_shapely(bounds).intersection(to_shapely(self._bbox))
839
+ bounds = _get_bounds(bounds, self._bbox)
764
840
 
765
841
  should_return_empty: bool = bounds is not None and bounds.area == 0
766
842
  if should_return_empty:
767
843
  self._values = np.array([])
768
844
  if self.mask is not None and not self.is_mask:
769
845
  self._mask = self._mask.load()
770
- # self._mask = np.ma.array([], [])
771
846
  self._bounds = None
772
847
  self.transform = None
848
+ try:
849
+ self._image._mask = self._mask
850
+ except AttributeError:
851
+ pass
773
852
  return self
774
853
 
775
- if has_array and bounds_was_none:
854
+ if self.has_array and bounds_was_none:
776
855
  return self
777
856
 
778
857
  # round down/up to integer to avoid precision trouble
779
858
  if bounds is not None:
780
- # bounds = to_bbox(bounds)
781
859
  minx, miny, maxx, maxy = to_bbox(bounds)
782
860
  bounds = (int(minx), int(miny), math.ceil(maxx), math.ceil(maxy))
783
861
 
784
- boundless = False
785
-
786
862
  if indexes is None:
787
863
  indexes = 1
788
864
 
@@ -792,7 +868,7 @@ class Band(_ImageBandBase):
792
868
  # allow setting a fixed out_shape for the array, in order to make mask same shape as values
793
869
  out_shape = kwargs.pop("out_shape", None)
794
870
 
795
- if has_array:
871
+ if self.has_array:
796
872
  self.values = _clip_loaded_array(
797
873
  self.values, bounds, self.transform, self.crs, out_shape, **kwargs
798
874
  )
@@ -849,7 +925,7 @@ class Band(_ImageBandBase):
849
925
  self._values = src.read(
850
926
  indexes=indexes,
851
927
  window=window,
852
- boundless=boundless,
928
+ boundless=False,
853
929
  out_shape=out_shape,
854
930
  masked=masked,
855
931
  **kwargs,
@@ -884,13 +960,15 @@ class Band(_ImageBandBase):
884
960
  )
885
961
  mask_arr = self.mask.values
886
962
 
887
- # if self.masking:
888
- # mask_arr = np.isin(mask_arr, self.masking["values"])
889
-
890
963
  self._values = np.ma.array(
891
964
  self._values, mask=mask_arr, fill_value=self.nodata
892
965
  )
893
966
 
967
+ try:
968
+ self._image._mask = self._mask
969
+ except AttributeError:
970
+ pass
971
+
894
972
  return self
895
973
 
896
974
  @property
@@ -898,6 +976,16 @@ class Band(_ImageBandBase):
898
976
  """True if the band_id is equal to the masking band_id."""
899
977
  return self.band_id == self.masking["band_id"]
900
978
 
979
+ @property
980
+ def has_array(self) -> bool:
981
+ """Whether the array is loaded."""
982
+ try:
983
+ if not isinstance(self.values, np.ndarray):
984
+ raise ValueError()
985
+ return True
986
+ except ValueError: # also catches ArrayNotLoadedError
987
+ return False
988
+
901
989
  def write(
902
990
  self, path: str | Path, driver: str = "GTiff", compress: str = "LZW", **kwargs
903
991
  ) -> None:
@@ -1154,131 +1242,6 @@ class NDVIBand(Band):
1154
1242
  # return get_cmap(arr)
1155
1243
 
1156
1244
 
1157
- def get_cmap(arr: np.ndarray) -> LinearSegmentedColormap:
1158
-
1159
- # blue = [[i / 10 + 0.1, i / 10 + 0.1, 1 - (i / 10) + 0.1] for i in range(11)][1:]
1160
- blue = [
1161
- [0.1, 0.1, 1.0],
1162
- [0.2, 0.2, 0.9],
1163
- [0.3, 0.3, 0.8],
1164
- [0.4, 0.4, 0.7],
1165
- [0.6, 0.6, 0.6],
1166
- [0.6, 0.6, 0.6],
1167
- [0.7, 0.7, 0.7],
1168
- [0.8, 0.8, 0.8],
1169
- ]
1170
- # gray = list(reversed([[i / 10 - 0.1, i / 10, i / 10 - 0.1] for i in range(11)][1:]))
1171
- gray = [
1172
- [0.6, 0.6, 0.6],
1173
- [0.6, 0.6, 0.6],
1174
- [0.6, 0.6, 0.6],
1175
- [0.6, 0.6, 0.6],
1176
- [0.6, 0.6, 0.6],
1177
- [0.4, 0.7, 0.4],
1178
- [0.3, 0.7, 0.3],
1179
- [0.2, 0.8, 0.2],
1180
- ]
1181
- # gray = [[0.6, 0.6, 0.6] for i in range(10)]
1182
- # green = [[0.2 + i/20, i / 10 - 0.1, + i/20] for i in range(11)][1:]
1183
- green = [
1184
- [0.25, 0.0, 0.05],
1185
- [0.3, 0.1, 0.1],
1186
- [0.35, 0.2, 0.15],
1187
- [0.4, 0.3, 0.2],
1188
- [0.45, 0.4, 0.25],
1189
- [0.5, 0.5, 0.3],
1190
- [0.55, 0.6, 0.35],
1191
- [0.7, 0.9, 0.5],
1192
- ]
1193
- green = [
1194
- [0.6, 0.6, 0.6],
1195
- [0.4, 0.7, 0.4],
1196
- [0.3, 0.8, 0.3],
1197
- [0.25, 0.4, 0.25],
1198
- [0.2, 0.5, 0.2],
1199
- [0.10, 0.7, 0.10],
1200
- [0, 0.9, 0],
1201
- ]
1202
-
1203
- def get_start(arr):
1204
- min_value = np.min(arr)
1205
- if min_value < -0.75:
1206
- return 0
1207
- if min_value < -0.5:
1208
- return 1
1209
- if min_value < -0.25:
1210
- return 2
1211
- if min_value < 0:
1212
- return 3
1213
- if min_value < 0.25:
1214
- return 4
1215
- if min_value < 0.5:
1216
- return 5
1217
- if min_value < 0.75:
1218
- return 6
1219
- return 7
1220
-
1221
- def get_stop(arr):
1222
- max_value = np.max(arr)
1223
- if max_value <= 0.05:
1224
- return 0
1225
- if max_value < 0.175:
1226
- return 1
1227
- if max_value < 0.25:
1228
- return 2
1229
- if max_value < 0.375:
1230
- return 3
1231
- if max_value < 0.5:
1232
- return 4
1233
- if max_value < 0.75:
1234
- return 5
1235
- return 6
1236
-
1237
- cmap_name = "blue_gray_green"
1238
-
1239
- start = get_start(arr)
1240
- stop = get_stop(arr)
1241
- blue = blue[start]
1242
- gray = gray[start]
1243
- # green = green[start]
1244
- green = green[stop]
1245
-
1246
- # green[0] = np.arange(0, 1, 0.1)[::-1][stop]
1247
- # green[1] = np.arange(0, 1, 0.1)[stop]
1248
- # green[2] = np.arange(0, 1, 0.1)[::-1][stop]
1249
-
1250
- print(green)
1251
- print(start, stop)
1252
- print("blue gray green")
1253
- print(blue)
1254
- print(gray)
1255
- print(green)
1256
-
1257
- # Define the segments of the colormap
1258
- cdict = {
1259
- "red": [
1260
- (0.0, blue[0], blue[0]),
1261
- (0.3, gray[0], gray[0]),
1262
- (0.7, gray[0], gray[0]),
1263
- (1.0, green[0], green[0]),
1264
- ],
1265
- "green": [
1266
- (0.0, blue[1], blue[1]),
1267
- (0.3, gray[1], gray[1]),
1268
- (0.7, gray[1], gray[1]),
1269
- (1.0, green[1], green[1]),
1270
- ],
1271
- "blue": [
1272
- (0.0, blue[2], blue[2]),
1273
- (0.3, gray[2], gray[2]),
1274
- (0.7, gray[2], gray[2]),
1275
- (1.0, green[2], green[2]),
1276
- ],
1277
- }
1278
-
1279
- return LinearSegmentedColormap(cmap_name, segmentdata=cdict, N=50)
1280
-
1281
-
1282
1245
  def median_as_int_and_minimum_dtype(arr: np.ndarray) -> np.ndarray:
1283
1246
  arr = np.median(arr, axis=0).astype(int)
1284
1247
  min_dtype = rasterio.dtypes.get_minimum_dtype(arr)
@@ -1288,21 +1251,17 @@ def median_as_int_and_minimum_dtype(arr: np.ndarray) -> np.ndarray:
1288
1251
  class Image(_ImageBandBase):
1289
1252
  """Image consisting of one or more Bands."""
1290
1253
 
1291
- cloud_cover_regexes: ClassVar[tuple[str] | None] = None
1292
1254
  band_class: ClassVar[Band] = Band
1293
1255
 
1294
1256
  def __init__(
1295
1257
  self,
1296
1258
  data: str | Path | Sequence[Band],
1297
1259
  res: int | None = None,
1298
- crs: Any | None = None,
1299
- single_banded: bool = False,
1300
1260
  file_system: GCSFileSystem | None = None,
1301
- df: pd.DataFrame | None = None,
1302
- all_file_paths: list[str] | None = None,
1303
1261
  processes: int = 1,
1304
- bbox: GeoDataFrame | GeoSeries | Geometry | tuple | None = None,
1262
+ df: pd.DataFrame | None = None,
1305
1263
  nodata: int | None = None,
1264
+ all_file_paths: list[str] | None = None,
1306
1265
  **kwargs,
1307
1266
  ) -> None:
1308
1267
  """Image initialiser."""
@@ -1310,18 +1269,14 @@ class Image(_ImageBandBase):
1310
1269
 
1311
1270
  self.nodata = nodata
1312
1271
  self._res = res
1313
- self._crs = crs
1272
+ self._crs = None
1314
1273
  self.file_system = file_system
1315
- self._bbox = to_bbox(bbox) if bbox is not None else None
1316
- # self._mask = _mask
1317
- self.single_banded = single_banded
1318
1274
  self.processes = processes
1319
- self._all_file_paths = all_file_paths
1320
1275
 
1321
1276
  if hasattr(data, "__iter__") and all(isinstance(x, Band) for x in data):
1322
1277
  self._bands = list(data)
1323
1278
  if res is None:
1324
- res = list({band.res for band in self._bands})
1279
+ res = list({band.res for band in self.bands})
1325
1280
  if len(res) == 1:
1326
1281
  self._res = res[0]
1327
1282
  else:
@@ -1334,25 +1289,23 @@ class Image(_ImageBandBase):
1334
1289
  raise TypeError("'data' must be string, Path-like or a sequence of Band.")
1335
1290
 
1336
1291
  self._bands = None
1337
- self._path = str(data)
1292
+ self._path = _fix_path(data) # str(data).rstrip("/").rstrip(r"\"")
1293
+
1294
+ if all_file_paths is None and self.path:
1295
+ self._all_file_paths = _get_all_file_paths(self.path)
1296
+ elif self.path:
1297
+ self._all_file_paths = [
1298
+ x for x in all_file_paths if self.path in _fix_path(x)
1299
+ ]
1300
+ else:
1301
+ self._all_file_paths = None
1338
1302
 
1339
1303
  if df is None:
1340
- if is_dapla():
1341
- file_paths = list(sorted(set(_glob_func(self.path + "/**"))))
1342
- else:
1343
- file_paths = list(
1344
- sorted(
1345
- set(
1346
- _glob_func(self.path + "/**/**")
1347
- + _glob_func(self.path + "/**/**/**")
1348
- + _glob_func(self.path + "/**/**/**/**")
1349
- + _glob_func(self.path + "/**/**/**/**/**")
1350
- )
1351
- )
1352
- )
1353
- if not file_paths:
1354
- file_paths = [self.path]
1355
- df = self._create_metadata_df(file_paths)
1304
+ # file_paths = _get_all_file_paths(self.path)
1305
+
1306
+ if not self._all_file_paths:
1307
+ self._all_file_paths = [self.path]
1308
+ df = self._create_metadata_df(self._all_file_paths)
1356
1309
 
1357
1310
  df["image_path"] = df["image_path"].astype(str)
1358
1311
 
@@ -1368,27 +1321,24 @@ class Image(_ImageBandBase):
1368
1321
  df = df.explode(col)
1369
1322
  df = df.loc[lambda x: ~x["filename"].duplicated()].reset_index(drop=True)
1370
1323
 
1371
- df = df.loc[lambda x: x["image_path"].str.contains(_fix_path(self.path))]
1372
-
1373
- if self.cloud_cover_regexes:
1374
- if all_file_paths is None:
1375
- file_paths = _ls_func(self.path)
1376
- else:
1377
- file_paths = [path for path in all_file_paths if self.name in path]
1378
- self.cloud_coverage_percentage = float(
1379
- _get_regex_match_from_xml_in_local_dir(
1380
- file_paths, regexes=self.cloud_cover_regexes
1381
- )
1382
- )
1383
- else:
1384
- self.cloud_coverage_percentage = None
1324
+ df = df.loc[lambda x: x["image_path"] == _fix_path(self.path)]
1385
1325
 
1386
1326
  self._df = df
1387
1327
 
1328
+ for key in self.metadata_attributes:
1329
+ setattr(self, key, None)
1330
+
1331
+ if self.metadata_attributes:
1332
+ self._add_metadata_attributes()
1333
+
1388
1334
  @property
1389
1335
  def values(self) -> np.ndarray:
1390
1336
  """3 dimensional numpy array."""
1391
- return np.array([band.values for band in self])
1337
+ values = [band.values for band in self]
1338
+ if self.mask is not None:
1339
+ mask = [band.mask.values for band in self]
1340
+ return np.ma.array(values, mask=mask, fill_value=self.nodata)
1341
+ return np.array(values)
1392
1342
 
1393
1343
  def ndvi(self, red_band: str, nir_band: str, copy: bool = True) -> NDVIBand:
1394
1344
  """Calculate the NDVI for the Image."""
@@ -1398,13 +1348,6 @@ class Image(_ImageBandBase):
1398
1348
 
1399
1349
  arr: np.ndarray | np.ma.core.MaskedArray = ndvi(red.values, nir.values)
1400
1350
 
1401
- # if self.nodata is not None and not np.isnan(self.nodata):
1402
- # try:
1403
- # arr.data[arr.mask] = self.nodata
1404
- # arr = arr.copy()
1405
- # except AttributeError:
1406
- # pass
1407
-
1408
1351
  return NDVIBand(
1409
1352
  arr,
1410
1353
  bounds=red.bounds,
@@ -1445,10 +1388,30 @@ class Image(_ImageBandBase):
1445
1388
  **self._common_init_kwargs,
1446
1389
  )
1447
1390
 
1391
+ def to_xarray(self) -> DataArray:
1392
+ """Convert the raster to an xarray.DataArray."""
1393
+ name = self.name or self.__class__.__name__.lower()
1394
+ coords = _generate_spatial_coords(
1395
+ self[0].transform, self[0].width, self[0].height
1396
+ )
1397
+ dims = ["band", "y", "x"]
1398
+ return xr.DataArray(
1399
+ self.values,
1400
+ coords=coords,
1401
+ dims=dims,
1402
+ name=name,
1403
+ attrs={"crs": self.crs},
1404
+ )
1405
+
1448
1406
  @property
1449
1407
  def mask(self) -> Band | None:
1450
1408
  """Mask Band."""
1451
1409
  if self._mask is not None:
1410
+ # if not self._mask.has_array:
1411
+ # try:
1412
+ # self._mask.values = self[0]._mask.values
1413
+ # except Exception:
1414
+ # pass
1452
1415
  return self._mask
1453
1416
  if self.masking is None:
1454
1417
  return None
@@ -1465,6 +1428,7 @@ class Image(_ImageBandBase):
1465
1428
  )
1466
1429
  self._mask = self.band_class(
1467
1430
  mask_paths[0],
1431
+ _add_metadata_attributes=False,
1468
1432
  **self._common_init_kwargs,
1469
1433
  )
1470
1434
 
@@ -1506,34 +1470,11 @@ class Image(_ImageBandBase):
1506
1470
  if self._bands is not None:
1507
1471
  return self._bands
1508
1472
 
1509
- # if self.masking:
1510
- # mask_band_id = self.masking["band_id"]
1511
- # mask_paths = [
1512
- # path for path in self._df["file_path"] if mask_band_id in path
1513
- # ]
1514
- # if len(mask_paths) > 1:
1515
- # raise ValueError(
1516
- # f"Multiple file_paths match mask band_id {mask_band_id}"
1517
- # )
1518
- # elif not mask_paths:
1519
- # raise ValueError(f"No file_paths match mask band_id {mask_band_id}")
1520
- # arr = (
1521
- # self.band_class(
1522
- # mask_paths[0],
1523
- # # mask=self.mask,
1524
- # **self._common_init_kwargs,
1525
- # )
1526
- # .load()
1527
- # .values
1528
- # )
1529
- # self._mask = np.ma.array(
1530
- # arr, mask=np.isin(arr, self.masking["values"]), fill_value=None
1531
- # )
1532
-
1533
1473
  self._bands = [
1534
1474
  self.band_class(
1535
1475
  path,
1536
1476
  mask=self.mask,
1477
+ _add_metadata_attributes=False,
1537
1478
  **self._common_init_kwargs,
1538
1479
  )
1539
1480
  for path in (self._df["file_path"])
@@ -1557,11 +1498,7 @@ class Image(_ImageBandBase):
1557
1498
  self._bands = [
1558
1499
  band
1559
1500
  for band in self._bands
1560
- if any(
1561
- # _get_first_group_match(pat, band.name)
1562
- re.search(pat, band.name)
1563
- for pat in self.filename_patterns
1564
- )
1501
+ if any(re.search(pat, band.name) for pat in self.filename_patterns)
1565
1502
  ]
1566
1503
 
1567
1504
  if self.image_patterns:
@@ -1570,7 +1507,6 @@ class Image(_ImageBandBase):
1570
1507
  for band in self._bands
1571
1508
  if any(
1572
1509
  re.search(pat, Path(band.path).parent.name)
1573
- # _get_first_group_match(pat, Path(band.path).parent.name)
1574
1510
  for pat in self.image_patterns
1575
1511
  )
1576
1512
  ]
@@ -1578,6 +1514,21 @@ class Image(_ImageBandBase):
1578
1514
  if self._should_be_sorted:
1579
1515
  self._bands = list(sorted(self._bands))
1580
1516
 
1517
+ for key in self.metadata_attributes:
1518
+ for band in self:
1519
+ value = getattr(self, key)
1520
+ if value is None:
1521
+ continue
1522
+ if isinstance(value, BandIdDict):
1523
+ try:
1524
+ value = value[band.band_id]
1525
+ except KeyError:
1526
+ continue
1527
+ setattr(band, key, value)
1528
+
1529
+ for band in self:
1530
+ band._image = self
1531
+
1581
1532
  return self._bands
1582
1533
 
1583
1534
  @property
@@ -1621,7 +1572,14 @@ class Image(_ImageBandBase):
1621
1572
  @property
1622
1573
  def bounds(self) -> tuple[int, int, int, int] | None:
1623
1574
  """Bounds of the Image (minx, miny, maxx, maxy)."""
1624
- return get_total_bounds([band.bounds for band in self])
1575
+ try:
1576
+ return get_total_bounds([band.bounds for band in self])
1577
+ except exceptions.RefreshError:
1578
+ bounds = []
1579
+ for band in self:
1580
+ time.sleep(0.1)
1581
+ bounds.append(band.bounds)
1582
+ return get_total_bounds(bounds)
1625
1583
 
1626
1584
  def to_gdf(self, column: str = "value") -> GeoDataFrame:
1627
1585
  """Convert the array to a GeoDataFrame of grid polygons and values."""
@@ -1647,7 +1605,7 @@ class Image(_ImageBandBase):
1647
1605
  def __getitem__(
1648
1606
  self, band: str | int | Sequence[str] | Sequence[int]
1649
1607
  ) -> "Band | Image":
1650
- """Get bands by band_id or integer index.
1608
+ """Get bands by band_id or integer index or a sequence of such.
1651
1609
 
1652
1610
  Returns a Band if a string or int is passed,
1653
1611
  returns an Image if a sequence of strings or integers is passed.
@@ -1743,34 +1701,29 @@ class ImageCollection(_ImageBase):
1743
1701
 
1744
1702
  image_class: ClassVar[Image] = Image
1745
1703
  band_class: ClassVar[Band] = Band
1704
+ _metadata_attribute_collection_type: ClassVar[type] = pd.Series
1746
1705
 
1747
1706
  def __init__(
1748
1707
  self,
1749
- data: str | Path | Sequence[Image],
1708
+ data: str | Path | Sequence[Image] | Sequence[str | Path],
1750
1709
  res: int,
1751
- level: str | None,
1752
- crs: Any | None = None,
1753
- single_banded: bool = False,
1710
+ level: str | None = NoLevel,
1754
1711
  processes: int = 1,
1755
1712
  file_system: GCSFileSystem | None = None,
1756
- df: pd.DataFrame | None = None,
1757
- bbox: Any | None = None,
1758
- nodata: int | None = None,
1759
1713
  metadata: str | dict | pd.DataFrame | None = None,
1714
+ nodata: int | None = None,
1760
1715
  **kwargs,
1761
1716
  ) -> None:
1762
1717
  """Initialiser."""
1763
1718
  super().__init__(**kwargs)
1764
1719
 
1765
1720
  self.nodata = nodata
1766
- self.level = level
1767
- self._crs = crs
1721
+ self.level = level if not isinstance(level, NoLevel) else None
1768
1722
  self.processes = processes
1769
1723
  self.file_system = file_system
1770
1724
  self._res = res
1771
- self._bbox = to_bbox(bbox) if bbox is not None else None
1772
1725
  self._band_ids = None
1773
- self.single_banded = single_banded
1726
+ self._crs = None # crs
1774
1727
 
1775
1728
  if metadata is not None:
1776
1729
  if isinstance(metadata, (str | Path | os.PathLike)):
@@ -1780,45 +1733,43 @@ class ImageCollection(_ImageBase):
1780
1733
  else:
1781
1734
  self.metadata = metadata
1782
1735
 
1783
- if hasattr(data, "__iter__") and all(isinstance(x, Image) for x in data):
1736
+ self._df = None
1737
+ self._all_file_paths = None
1738
+ self._images = None
1739
+
1740
+ if hasattr(data, "__iter__") and not isinstance(data, str):
1784
1741
  self._path = None
1785
- self.images = [x.copy() for x in data]
1786
- return
1787
- else:
1788
- self._images = None
1742
+ if all(isinstance(x, Image) for x in data):
1743
+ self.images = [x.copy() for x in data]
1744
+ return
1745
+ elif all(isinstance(x, (str | Path | os.PathLike)) for x in data):
1746
+ self._all_file_paths = list(
1747
+ itertools.chain.from_iterable(
1748
+ _get_all_file_paths(str(path)) for path in data
1749
+ )
1750
+ )
1751
+ self._df = self._create_metadata_df([str(x) for x in data])
1752
+ return
1789
1753
 
1790
1754
  if not isinstance(data, (str | Path | os.PathLike)):
1791
1755
  raise TypeError("'data' must be string, Path-like or a sequence of Image.")
1792
1756
 
1793
1757
  self._path = str(data)
1794
1758
 
1795
- if is_dapla():
1796
- self._all_file_paths = list(sorted(set(_glob_func(self.path + "/**"))))
1797
- else:
1798
- self._all_file_paths = list(
1799
- sorted(
1800
- set(
1801
- _glob_func(self.path + "/**/**")
1802
- + _glob_func(self.path + "/**/**/**")
1803
- + _glob_func(self.path + "/**/**/**/**")
1804
- + _glob_func(self.path + "/**/**/**/**/**")
1805
- )
1806
- )
1807
- )
1759
+ self._all_file_paths = _get_all_file_paths(self.path)
1808
1760
 
1809
1761
  if self.level:
1810
1762
  self._all_file_paths = [
1811
1763
  path for path in self._all_file_paths if self.level in path
1812
1764
  ]
1813
1765
 
1814
- if df is not None:
1815
- self._df = df
1816
- else:
1817
- self._df = self._create_metadata_df(self._all_file_paths)
1766
+ self._df = self._create_metadata_df(self._all_file_paths)
1818
1767
 
1819
1768
  @property
1820
1769
  def values(self) -> np.ndarray:
1821
1770
  """4 dimensional numpy array."""
1771
+ if isinstance(self[0].values, np.ma.core.MaskedArray):
1772
+ return np.ma.array([img.values for img in self])
1822
1773
  return np.array([img.values for img in self])
1823
1774
 
1824
1775
  @property
@@ -1826,21 +1777,6 @@ class ImageCollection(_ImageBase):
1826
1777
  """4 dimensional numpy array."""
1827
1778
  return np.array([img.mask.values for img in self])
1828
1779
 
1829
- # def ndvi(
1830
- # self, red_band: str, nir_band: str, copy: bool = True
1831
- # ) -> "ImageCollection":
1832
- # # copied = self.copy() if copy else self
1833
-
1834
- # with joblib.Parallel(n_jobs=self.processes, backend="loky") as parallel:
1835
- # ndvi_images = parallel(
1836
- # joblib.delayed(_img_ndvi)(
1837
- # img, red_band=red_band, nir_band=nir_band, copy=False
1838
- # )
1839
- # for img in self
1840
- # )
1841
-
1842
- # return ImageCollection(ndvi_images, single_banded=True)
1843
-
1844
1780
  def groupby(self, by: str | list[str], **kwargs) -> ImageCollectionGroupBy:
1845
1781
  """Group the Collection by Image or Band attribute(s)."""
1846
1782
  df = pd.DataFrame(
@@ -1882,7 +1818,6 @@ class ImageCollection(_ImageBase):
1882
1818
  copied.images = [
1883
1819
  self.image_class(
1884
1820
  [band],
1885
- single_banded=True,
1886
1821
  masking=self.masking,
1887
1822
  band_class=self.band_class,
1888
1823
  **self._common_init_kwargs,
@@ -1892,6 +1827,60 @@ class ImageCollection(_ImageBase):
1892
1827
  for img in self
1893
1828
  for band in img
1894
1829
  ]
1830
+ for img in copied:
1831
+ assert len(img) == 1
1832
+ try:
1833
+ img._path = img[0].path
1834
+ except PathlessImageError:
1835
+ pass
1836
+ return copied
1837
+
1838
+ def apply(self, func: Callable, **kwargs) -> "ImageCollection":
1839
+ """Apply a function to all bands in each image of the collection."""
1840
+ for img in self:
1841
+ img._bands = [func(band, **kwargs) for band in img]
1842
+ return self
1843
+
1844
+ def get_unique_band_ids(self) -> list[str]:
1845
+ """Get a list of unique band_ids across all images."""
1846
+ return list({band.band_id for img in self for band in img})
1847
+
1848
+ def filter(
1849
+ self,
1850
+ bands: str | list[str] | None = None,
1851
+ date_ranges: DATE_RANGES_TYPE = None,
1852
+ bbox: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None = None,
1853
+ intersects: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None = None,
1854
+ max_cloud_coverage: int | None = None,
1855
+ copy: bool = True,
1856
+ ) -> "ImageCollection":
1857
+ """Filter images and bands in the collection."""
1858
+ copied = self.copy() if copy else self
1859
+
1860
+ if date_ranges:
1861
+ copied = copied._filter_dates(date_ranges)
1862
+
1863
+ if max_cloud_coverage is not None:
1864
+ copied.images = [
1865
+ image
1866
+ for image in copied.images
1867
+ if image.cloud_coverage_percentage < max_cloud_coverage
1868
+ ]
1869
+
1870
+ if bbox is not None:
1871
+ copied = copied._filter_bounds(bbox)
1872
+ copied._set_bbox(bbox)
1873
+
1874
+ if intersects is not None:
1875
+ copied = copied._filter_bounds(intersects)
1876
+
1877
+ if bands is not None:
1878
+ if isinstance(bands, str):
1879
+ bands = [bands]
1880
+ bands = set(bands)
1881
+ copied._band_ids = bands
1882
+ copied.images = [img[bands] for img in copied.images if bands in img]
1883
+
1895
1884
  return copied
1896
1885
 
1897
1886
  def merge(
@@ -1903,7 +1892,10 @@ class ImageCollection(_ImageBase):
1903
1892
  **kwargs,
1904
1893
  ) -> Band:
1905
1894
  """Merge all areas and all bands to a single Band."""
1906
- bounds = to_bbox(bounds) if bounds is not None else self._bbox
1895
+ bounds = _get_bounds(bounds, self._bbox)
1896
+ if bounds is not None:
1897
+ bounds = to_bbox(bounds)
1898
+
1907
1899
  crs = self.crs
1908
1900
 
1909
1901
  if indexes is None:
@@ -1971,7 +1963,9 @@ class ImageCollection(_ImageBase):
1971
1963
  **kwargs,
1972
1964
  ) -> Image:
1973
1965
  """Merge all areas to a single tile, one band per band_id."""
1974
- bounds = to_bbox(bounds) if bounds is not None else self._bbox
1966
+ bounds = _get_bounds(bounds, self._bbox)
1967
+ if bounds is not None:
1968
+ bounds = to_bbox(bounds)
1975
1969
  bounds = self.bounds if bounds is None else bounds
1976
1970
  out_bounds = bounds
1977
1971
  crs = self.crs
@@ -2027,11 +2021,12 @@ class ImageCollection(_ImageBase):
2027
2021
  bounds=out_bounds,
2028
2022
  crs=crs,
2029
2023
  band_id=band_id,
2024
+ _add_metadata_attributes=False,
2030
2025
  **self._common_init_kwargs,
2031
2026
  )
2032
2027
  )
2033
2028
 
2034
- # return self.image_class(
2029
+ # return self.image_class( # TODO
2035
2030
  image = Image(
2036
2031
  bands,
2037
2032
  band_class=self.band_class,
@@ -2129,14 +2124,13 @@ class ImageCollection(_ImageBase):
2129
2124
  **kwargs,
2130
2125
  ) -> "ImageCollection":
2131
2126
  """Load all image Bands with threading."""
2127
+ if (
2128
+ bounds is None
2129
+ and indexes is None
2130
+ and all(band.has_array for img in self for band in img)
2131
+ ):
2132
+ return self
2132
2133
  with joblib.Parallel(n_jobs=self.processes, backend="threading") as parallel:
2133
- if self.masking:
2134
- parallel(
2135
- joblib.delayed(_load_band)(
2136
- img.mask, bounds=bounds, indexes=indexes, **kwargs
2137
- )
2138
- for img in self
2139
- )
2140
2134
  parallel(
2141
2135
  joblib.delayed(_load_band)(
2142
2136
  band, bounds=bounds, indexes=indexes, **kwargs
@@ -2147,7 +2141,7 @@ class ImageCollection(_ImageBase):
2147
2141
 
2148
2142
  return self
2149
2143
 
2150
- def set_bbox(
2144
+ def _set_bbox(
2151
2145
  self, bbox: GeoDataFrame | GeoSeries | Geometry | tuple[float]
2152
2146
  ) -> "ImageCollection":
2153
2147
  """Set the mask to be used to clip the images to."""
@@ -2156,86 +2150,18 @@ class ImageCollection(_ImageBase):
2156
2150
  if self._images is not None:
2157
2151
  for img in self._images:
2158
2152
  img._bbox = self._bbox
2159
- if img._bands is not None:
2160
- for band in img:
2161
- band._bbox = self._bbox
2162
- bounds = box(*band._bbox).intersection(box(*band.bounds))
2163
- band._bounds = to_bbox(bounds) if not bounds.is_empty else None
2164
-
2165
- return self
2153
+ if img.bands is None:
2154
+ continue
2155
+ for band in img:
2156
+ band._bbox = self._bbox
2157
+ bounds = box(*band._bbox).intersection(box(*band.bounds))
2158
+ band._bounds = to_bbox(bounds) if not bounds.is_empty else None
2166
2159
 
2167
- def apply(self, func: Callable, **kwargs) -> "ImageCollection":
2168
- """Apply a function to all bands in each image of the collection."""
2169
- for img in self:
2170
- img.bands = [func(band, **kwargs) for band in img]
2171
2160
  return self
2172
2161
 
2173
- def filter(
2174
- self,
2175
- bands: str | list[str] | None = None,
2176
- exclude_bands: str | list[str] | None = None,
2177
- date_ranges: (
2178
- tuple[str | None, str | None]
2179
- | tuple[tuple[str | None, str | None], ...]
2180
- | None
2181
- ) = None,
2182
- bbox: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None = None,
2183
- intersects: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None = None,
2184
- max_cloud_coverage: int | None = None,
2185
- copy: bool = True,
2186
- ) -> "ImageCollection":
2187
- """Filter images and bands in the collection."""
2188
- copied = self.copy() if copy else self
2189
-
2190
- if isinstance(bbox, BoundingBox):
2191
- date_ranges = (bbox.mint, bbox.maxt)
2192
-
2193
- if date_ranges:
2194
- copied = copied._filter_dates(date_ranges)
2195
-
2196
- if max_cloud_coverage is not None:
2197
- copied.images = [
2198
- image
2199
- for image in copied.images
2200
- if image.cloud_coverage_percentage < max_cloud_coverage
2201
- ]
2202
-
2203
- if bbox is not None:
2204
- copied = copied._filter_bounds(bbox)
2205
- copied.set_bbox(bbox)
2206
-
2207
- if intersects is not None:
2208
- copied = copied._filter_bounds(intersects)
2209
-
2210
- if bands is not None:
2211
- if isinstance(bands, str):
2212
- bands = [bands]
2213
- bands = set(bands)
2214
- copied._band_ids = bands
2215
- copied.images = [img[bands] for img in copied.images if bands in img]
2216
-
2217
- if exclude_bands is not None:
2218
- if isinstance(exclude_bands, str):
2219
- exclude_bands = {exclude_bands}
2220
- else:
2221
- exclude_bands = set(exclude_bands)
2222
- include_bands: list[list[str]] = [
2223
- [band_id for band_id in img.band_ids if band_id not in exclude_bands]
2224
- for img in copied
2225
- ]
2226
- copied.images = [
2227
- img[bands]
2228
- for img, bands in zip(copied.images, include_bands, strict=False)
2229
- if bands
2230
- ]
2231
-
2232
- return copied
2233
-
2234
2162
  def _filter_dates(
2235
2163
  self,
2236
- date_ranges: (
2237
- tuple[str | None, str | None] | tuple[tuple[str | None, str | None], ...]
2238
- ),
2164
+ date_ranges: DATE_RANGES_TYPE = None,
2239
2165
  ) -> "ImageCollection":
2240
2166
  if not isinstance(date_ranges, (tuple, list)):
2241
2167
  raise TypeError(
@@ -2247,13 +2173,7 @@ class ImageCollection(_ImageBase):
2247
2173
  "Cannot set date_ranges when the class's image_regexes attribute is None"
2248
2174
  )
2249
2175
 
2250
- self.images = [
2251
- img
2252
- for img in self
2253
- if _date_is_within(
2254
- img.path, date_ranges, self.image_patterns, self.date_format
2255
- )
2256
- ]
2176
+ self.images = [img for img in self if _date_is_within(img.date, date_ranges)]
2257
2177
  return self
2258
2178
 
2259
2179
  def _filter_bounds(
@@ -2277,6 +2197,38 @@ class ImageCollection(_ImageBase):
2277
2197
  ]
2278
2198
  return self
2279
2199
 
2200
+ def to_xarray(self, **kwargs) -> DataArray:
2201
+ """Convert the raster to an xarray.DataArray."""
2202
+ # arrs = []
2203
+ # for img in self:
2204
+ # for band in img:
2205
+ # arr = band.load(**kwargs).values
2206
+ # arrs.append(arr)
2207
+
2208
+ # n_images = len(self)
2209
+ # n_bands = len(img)
2210
+ # height, width = arr.shape
2211
+
2212
+ # arr_4d = np.array(arrs).reshape(n_images, n_bands, height, width)
2213
+
2214
+ try:
2215
+ name = Path(self.path).stem
2216
+ except TypeError:
2217
+ name = self.__class__.__name__.lower()
2218
+
2219
+ first_band = self[0][0]
2220
+ coords = _generate_spatial_coords(
2221
+ first_band.transform, first_band.width, first_band.height
2222
+ )
2223
+ dims = ["image", "band", "y", "x"]
2224
+ return xr.DataArray(
2225
+ self.values,
2226
+ coords=coords,
2227
+ dims=dims,
2228
+ name=name,
2229
+ attrs={"crs": self.crs},
2230
+ )
2231
+
2280
2232
  def to_gdfs(self, column: str = "value") -> dict[str, GeoDataFrame]:
2281
2233
  """Convert each band in each Image to a GeoDataFrame."""
2282
2234
  out = {}
@@ -2289,12 +2241,10 @@ class ImageCollection(_ImageBase):
2289
2241
  except AttributeError:
2290
2242
  name = f"{self.__class__.__name__}({i})"
2291
2243
 
2292
- band.load()
2244
+ # band.load()
2293
2245
 
2294
2246
  if name not in out:
2295
2247
  out[name] = band.to_gdf(column=column)
2296
- # else:
2297
- # out[name] = f"{self.__class__.__name__}({i})"
2298
2248
  return out
2299
2249
 
2300
2250
  def sample(self, n: int = 1, size: int = 500) -> "ImageCollection":
@@ -2363,11 +2313,16 @@ class ImageCollection(_ImageBase):
2363
2313
  """Number of images."""
2364
2314
  return len(self.images)
2365
2315
 
2366
- def __getitem__(
2367
- self,
2368
- item: int | slice | Sequence[int | bool] | BoundingBox | Sequence[BoundingBox],
2369
- ) -> Image | TORCHGEO_RETURN_TYPE:
2370
- """Select one Image by integer index, or multiple Images by slice, list of int or torchgeo.BoundingBox."""
2316
+ def __getattr__(self, attr: str) -> Any:
2317
+ """Make iterable of metadata_attribute."""
2318
+ if attr in (self.metadata_attributes or {}):
2319
+ return self._metadata_attribute_collection_type(
2320
+ [getattr(img, attr) for img in self]
2321
+ )
2322
+ return super().__getattribute__(attr)
2323
+
2324
+ def __getitem__(self, item: int | slice | Sequence[int | bool]) -> Image:
2325
+ """Select one Image by integer index, or multiple Images by slice, list of int."""
2371
2326
  if isinstance(item, int):
2372
2327
  return self.images[item]
2373
2328
 
@@ -2392,90 +2347,23 @@ class ImageCollection(_ImageBase):
2392
2347
  ]
2393
2348
  return copied
2394
2349
 
2395
- if not isinstance(item, BoundingBox) and not (
2396
- isinstance(item, Iterable)
2397
- and len(item)
2398
- and all(isinstance(x, BoundingBox) for x in item)
2399
- ):
2400
- copied = self.copy()
2401
- if callable(item):
2402
- item = [item(img) for img in copied]
2403
-
2404
- # check for base bool and numpy bool
2405
- if all("bool" in str(type(x)) for x in item):
2406
- copied.images = [img for x, img in zip(item, copied, strict=True) if x]
2350
+ copied = self.copy()
2351
+ if callable(item):
2352
+ item = [item(img) for img in copied]
2407
2353
 
2408
- else:
2409
- copied.images = [copied.images[i] for i in item]
2410
- return copied
2354
+ # check for base bool and numpy bool
2355
+ if all("bool" in str(type(x)) for x in item):
2356
+ copied.images = [img for x, img in zip(item, copied, strict=True) if x]
2411
2357
 
2412
- if isinstance(item, BoundingBox):
2413
- date_ranges: tuple[str] = (item.mint, item.maxt)
2414
- data: torch.Tensor = numpy_to_torch(
2415
- np.array(
2416
- [
2417
- band.values
2418
- for band in self.filter(
2419
- bbox=item, date_ranges=date_ranges
2420
- ).merge_by_band(bounds=item)
2421
- ]
2422
- )
2423
- )
2424
2358
  else:
2425
- bboxes: list[Polygon] = [to_bbox(x) for x in item]
2426
- date_ranges: list[list[str, str]] = [(x.mint, x.maxt) for x in item]
2427
- data: torch.Tensor = torch.cat(
2428
- [
2429
- numpy_to_torch(
2430
- np.array(
2431
- [
2432
- band.values
2433
- for band in self.filter(
2434
- bbox=bbox, date_ranges=date_range
2435
- ).merge_by_band(bounds=bbox)
2436
- ]
2437
- )
2438
- )
2439
- for bbox, date_range in zip(bboxes, date_ranges, strict=True)
2440
- ]
2441
- )
2442
-
2443
- crs = get_common_crs(self.images)
2444
-
2445
- key = "image" # if self.is_image else "mask"
2446
- sample = {key: data, "crs": crs, "bbox": item}
2447
-
2448
- return sample
2449
-
2450
- @property
2451
- def mint(self) -> float:
2452
- """Min timestamp of the images combined."""
2453
- return min(img.mint for img in self)
2454
-
2455
- @property
2456
- def maxt(self) -> float:
2457
- """Max timestamp of the images combined."""
2458
- return max(img.maxt for img in self)
2459
-
2460
- @property
2461
- def band_ids(self) -> list[str]:
2462
- """Sorted list of unique band_ids."""
2463
- return list(sorted({band.band_id for img in self for band in img}))
2464
-
2465
- @property
2466
- def file_paths(self) -> list[str]:
2467
- """Sorted list of all file paths, meaning all band paths."""
2468
- return list(sorted({band.path for img in self for band in img}))
2359
+ copied.images = [copied.images[i] for i in item]
2360
+ return copied
2469
2361
 
2470
2362
  @property
2471
2363
  def dates(self) -> list[str]:
2472
2364
  """List of image dates."""
2473
2365
  return [img.date for img in self]
2474
2366
 
2475
- def dates_as_int(self) -> list[int]:
2476
- """List of image dates as 8-length integers."""
2477
- return [int(img.date[:8]) for img in self]
2478
-
2479
2367
  @property
2480
2368
  def image_paths(self) -> list[str]:
2481
2369
  """List of image paths."""
@@ -2510,14 +2398,21 @@ class ImageCollection(_ImageBase):
2510
2398
  image._bands = [band for band in image if band.band_id is not None]
2511
2399
 
2512
2400
  if self.metadata is not None:
2401
+ attributes_to_add = ["crs", "bounds"] + list(self.metadata_attributes)
2513
2402
  for img in self:
2514
2403
  for band in img:
2515
- for key in ["crs", "bounds"]:
2404
+ for key in attributes_to_add:
2516
2405
  try:
2517
2406
  value = self.metadata[band.path][key]
2518
2407
  except KeyError:
2519
- value = self.metadata[key][band.path]
2520
- setattr(band, f"_{key}", value)
2408
+ try:
2409
+ value = self.metadata[key][band.path]
2410
+ except KeyError:
2411
+ continue
2412
+ try:
2413
+ setattr(band, key, value)
2414
+ except Exception:
2415
+ setattr(band, f"_{key}", value)
2521
2416
 
2522
2417
  self._images = [img for img in self if len(img)]
2523
2418
 
@@ -2552,28 +2447,6 @@ class ImageCollection(_ImageBase):
2552
2447
  if not all(isinstance(x, Image) for x in self._images):
2553
2448
  raise TypeError("images should be a sequence of Image.")
2554
2449
 
2555
- @property
2556
- def index(self) -> Index:
2557
- """Spatial index that makes torchgeo think this class is a RasterDataset."""
2558
- try:
2559
- if len(self) == len(self._index):
2560
- return self._index
2561
- except AttributeError:
2562
- self._index = Index(interleaved=False, properties=Property(dimension=3))
2563
-
2564
- for i, img in enumerate(self.images):
2565
- if img.date:
2566
- try:
2567
- mint, maxt = disambiguate_timestamp(img.date, self.date_format)
2568
- except (NameError, TypeError):
2569
- mint, maxt = 0, 1
2570
- else:
2571
- mint, maxt = 0, 1
2572
- # important: torchgeo has a different order of the bbox than shapely and geopandas
2573
- minx, miny, maxx, maxy = img.bounds
2574
- self._index.insert(i, (minx, maxx, miny, maxy, mint, maxt))
2575
- return self._index
2576
-
2577
2450
  def __repr__(self) -> str:
2578
2451
  """String representation."""
2579
2452
  return f"{self.__class__.__name__}({len(self)}, path='{self.path}')"
@@ -2603,6 +2476,7 @@ class ImageCollection(_ImageBase):
2603
2476
  p: float = 0.95,
2604
2477
  ylim: tuple[float, float] | None = None,
2605
2478
  figsize: tuple[int] = (20, 8),
2479
+ rounding: int = 3,
2606
2480
  ) -> None:
2607
2481
  """Plot each individual pixel in a dotplot for all dates.
2608
2482
 
@@ -2616,6 +2490,7 @@ class ImageCollection(_ImageBase):
2616
2490
  p: p-value for the confidence interval.
2617
2491
  ylim: Limits of the y-axis.
2618
2492
  figsize: Figure size as tuple (width, height).
2493
+ rounding: rounding of title n
2619
2494
 
2620
2495
  """
2621
2496
  if by is None and all(band.band_id is not None for img in self for band in img):
@@ -2625,13 +2500,16 @@ class ImageCollection(_ImageBase):
2625
2500
 
2626
2501
  alpha = 1 - p
2627
2502
 
2628
- for img in self:
2629
- for band in img:
2630
- band.load()
2503
+ # for img in self:
2504
+ # for band in img:
2505
+ # band.load()
2631
2506
 
2632
2507
  for group_values, subcollection in self.groupby(by):
2633
2508
  print("group_values:", *group_values)
2634
2509
 
2510
+ if "date" in x_var and subcollection._should_be_sorted:
2511
+ subcollection._images = list(sorted(subcollection._images))
2512
+
2635
2513
  y = np.array([band.values for img in subcollection for band in img])
2636
2514
  if "date" in x_var and subcollection._should_be_sorted:
2637
2515
  x = np.array(
@@ -2685,6 +2563,10 @@ class ImageCollection(_ImageBase):
2685
2563
  )[0]
2686
2564
  predicted = np.array([intercept + coef * x for x in this_x])
2687
2565
 
2566
+ predicted_start = predicted[0]
2567
+ predicted_end = predicted[-1]
2568
+ predicted_change = predicted_end - predicted_start
2569
+
2688
2570
  # Degrees of freedom
2689
2571
  dof = len(this_x) - 2
2690
2572
 
@@ -2708,8 +2590,6 @@ class ImageCollection(_ImageBase):
2708
2590
  ci_lower = predicted - t_val * pred_stderr
2709
2591
  ci_upper = predicted + t_val * pred_stderr
2710
2592
 
2711
- rounding = int(np.log(1 / abs(coef)))
2712
-
2713
2593
  fig = plt.figure(figsize=figsize)
2714
2594
  ax = fig.add_subplot(1, 1, 1)
2715
2595
 
@@ -2723,22 +2603,170 @@ class ImageCollection(_ImageBase):
2723
2603
  alpha=0.2,
2724
2604
  label=f"{int(alpha*100)}% CI",
2725
2605
  )
2726
- plt.title(f"Coefficient: {round(coef, rounding)}")
2606
+ plt.title(
2607
+ f"coef: {round(coef, int(np.log(1 / abs(coef))))}, "
2608
+ f"pred change: {round(predicted_change, rounding)}, "
2609
+ f"pred start: {round(predicted_start, rounding)}, "
2610
+ f"pred end: {round(predicted_end, rounding)}"
2611
+ )
2727
2612
  plt.xlabel(x_var)
2728
2613
  plt.ylabel(y_label)
2729
2614
  plt.show()
2730
2615
 
2731
2616
 
2732
- def concat_image_collections(collections: Sequence[ImageCollection]) -> ImageCollection:
2733
- """Union multiple ImageCollections together.
2617
+ def _get_all_regex_matches(xml_file: str, regexes: tuple[str]) -> tuple[str]:
2618
+ for regex in regexes:
2619
+ try:
2620
+ return re.search(regex, xml_file)
2621
+ except (TypeError, AttributeError):
2622
+ continue
2623
+ raise ValueError(
2624
+ f"Could not find processing_baseline info from {regexes} in {xml_file}"
2625
+ )
2734
2626
 
2735
- Same as using the union operator |.
2736
- """
2737
- resolutions = {x.res for x in collections}
2738
- if len(resolutions) > 1:
2739
- raise ValueError(f"resoultion mismatch. {resolutions}")
2740
- images = list(itertools.chain.from_iterable([x.images for x in collections]))
2741
- levels = {x.level for x in collections}
2627
+
2628
+ class Sentinel2Config:
2629
+ """Holder of Sentinel 2 regexes, band_ids etc."""
2630
+
2631
+ image_regexes: ClassVar[str] = (config.SENTINEL2_IMAGE_REGEX,)
2632
+ filename_regexes: ClassVar[str] = (
2633
+ config.SENTINEL2_FILENAME_REGEX,
2634
+ config.SENTINEL2_CLOUD_FILENAME_REGEX,
2635
+ )
2636
+ metadata_attributes: ClassVar[
2637
+ dict[str, Callable | functools.partial | tuple[str]]
2638
+ ] = {
2639
+ "processing_baseline": functools.partial(
2640
+ _extract_regex_match_from_string,
2641
+ regexes=(r"<PROCESSING_BASELINE>(.*?)</PROCESSING_BASELINE>",),
2642
+ ),
2643
+ "cloud_coverage_percentage": "_get_cloud_coverage_percentage",
2644
+ "is_refined": functools.partial(
2645
+ _any_regex_matches, regexes=(r'<Image_Refining flag="REFINED">',)
2646
+ ),
2647
+ "boa_add_offset": "_get_boa_add_offset_dict",
2648
+ }
2649
+ all_bands: ClassVar[list[str]] = list(config.SENTINEL2_BANDS)
2650
+ rbg_bands: ClassVar[list[str]] = config.SENTINEL2_RBG_BANDS
2651
+ ndvi_bands: ClassVar[list[str]] = config.SENTINEL2_NDVI_BANDS
2652
+ l2a_bands: ClassVar[dict[str, int]] = config.SENTINEL2_L2A_BANDS
2653
+ l1c_bands: ClassVar[dict[str, int]] = config.SENTINEL2_L1C_BANDS
2654
+ masking: ClassVar[BandMasking] = BandMasking(
2655
+ band_id="SCL", values=(3, 8, 9, 10, 11)
2656
+ )
2657
+
2658
+ def _get_cloud_coverage_percentage(self, xml_file: str) -> float:
2659
+ return float(
2660
+ _extract_regex_match_from_string(
2661
+ xml_file,
2662
+ (
2663
+ r"<Cloud_Coverage_Assessment>([\d.]+)</Cloud_Coverage_Assessment>",
2664
+ r"<CLOUDY_PIXEL_OVER_LAND_PERCENTAGE>([\d.]+)</CLOUDY_PIXEL_OVER_LAND_PERCENTAGE>",
2665
+ ),
2666
+ )
2667
+ )
2668
+
2669
+ def _get_boa_add_offset_dict(self, xml_file: str) -> BandIdDict:
2670
+ pat = re.compile(
2671
+ r"""
2672
+ <BOA_ADD_OFFSET\s*
2673
+ band_id="(?P<band_id>\d+)"\s*
2674
+ >\s*(?P<value>-?\d+)\s*
2675
+ </BOA_ADD_OFFSET>
2676
+ """,
2677
+ flags=re.VERBOSE,
2678
+ )
2679
+
2680
+ try:
2681
+ matches = [x.groupdict() for x in re.finditer(pat, xml_file)]
2682
+ except (TypeError, AttributeError, KeyError) as e:
2683
+ raise _RegexError(f"Could not find boa_add_offset info from {pat}") from e
2684
+ if not matches:
2685
+ raise _RegexError(f"Could not find boa_add_offset info from {pat}")
2686
+ return BandIdDict(
2687
+ pd.DataFrame(matches).set_index("band_id")["value"].astype(int).to_dict()
2688
+ )
2689
+
2690
+
2691
+ class Sentinel2CloudlessConfig(Sentinel2Config):
2692
+ """Holder of regexes, band_ids etc. for Sentinel 2 cloudless mosaic."""
2693
+
2694
+ image_regexes: ClassVar[str] = (config.SENTINEL2_MOSAIC_IMAGE_REGEX,)
2695
+ filename_regexes: ClassVar[str] = (config.SENTINEL2_MOSAIC_FILENAME_REGEX,)
2696
+ masking: ClassVar[None] = None
2697
+ all_bands: ClassVar[list[str]] = [
2698
+ x.replace("B0", "B") for x in Sentinel2Config.all_bands
2699
+ ]
2700
+ rbg_bands: ClassVar[list[str]] = [
2701
+ x.replace("B0", "B") for x in Sentinel2Config.rbg_bands
2702
+ ]
2703
+ ndvi_bands: ClassVar[list[str]] = [
2704
+ x.replace("B0", "B") for x in Sentinel2Config.ndvi_bands
2705
+ ]
2706
+
2707
+
2708
+ class Sentinel2Band(Sentinel2Config, Band):
2709
+ """Band with Sentinel2 specific name variables and regexes."""
2710
+
2711
+
2712
+ class Sentinel2Image(Sentinel2Config, Image):
2713
+ """Image with Sentinel2 specific name variables and regexes."""
2714
+
2715
+ band_class: ClassVar[Sentinel2Band] = Sentinel2Band
2716
+
2717
+ def ndvi(
2718
+ self,
2719
+ red_band: str = Sentinel2Config.ndvi_bands[0],
2720
+ nir_band: str = Sentinel2Config.ndvi_bands[1],
2721
+ copy: bool = True,
2722
+ ) -> NDVIBand:
2723
+ """Calculate the NDVI for the Image."""
2724
+ return super().ndvi(red_band=red_band, nir_band=nir_band, copy=copy)
2725
+
2726
+
2727
+ class Sentinel2Collection(Sentinel2Config, ImageCollection):
2728
+ """ImageCollection with Sentinel2 specific name variables and path regexes."""
2729
+
2730
+ image_class: ClassVar[Sentinel2Image] = Sentinel2Image
2731
+ band_class: ClassVar[Sentinel2Band] = Sentinel2Band
2732
+
2733
+ def __init__(self, data: str | Path | Sequence[Image], **kwargs) -> None:
2734
+ """ImageCollection with Sentinel2 specific name variables and path regexes."""
2735
+ level = kwargs.get("level", NoLevel)
2736
+ if isinstance(level, type) and isinstance(level(), NoLevel):
2737
+ raise ValueError("Must specify level for Sentinel2Collection.")
2738
+ super().__init__(data=data, **kwargs)
2739
+
2740
+
2741
+ class Sentinel2CloudlessBand(Sentinel2CloudlessConfig, Band):
2742
+ """Band for cloudless mosaic with Sentinel2 specific name variables and regexes."""
2743
+
2744
+
2745
+ class Sentinel2CloudlessImage(Sentinel2CloudlessConfig, Sentinel2Image):
2746
+ """Image for cloudless mosaic with Sentinel2 specific name variables and regexes."""
2747
+
2748
+ band_class: ClassVar[Sentinel2CloudlessBand] = Sentinel2CloudlessBand
2749
+
2750
+ ndvi = Sentinel2Image.ndvi
2751
+
2752
+
2753
+ class Sentinel2CloudlessCollection(Sentinel2CloudlessConfig, ImageCollection):
2754
+ """ImageCollection with Sentinel2 specific name variables and regexes."""
2755
+
2756
+ image_class: ClassVar[Sentinel2CloudlessImage] = Sentinel2CloudlessImage
2757
+ band_class: ClassVar[Sentinel2Band] = Sentinel2CloudlessBand
2758
+
2759
+
2760
+ def concat_image_collections(collections: Sequence[ImageCollection]) -> ImageCollection:
2761
+ """Union multiple ImageCollections together.
2762
+
2763
+ Same as using the union operator |.
2764
+ """
2765
+ resolutions = {x.res for x in collections}
2766
+ if len(resolutions) > 1:
2767
+ raise ValueError(f"resoultion mismatch. {resolutions}")
2768
+ images = list(itertools.chain.from_iterable([x.images for x in collections]))
2769
+ levels = {x.level for x in collections}
2742
2770
  level = next(iter(levels)) if len(levels) == 1 else None
2743
2771
  first_collection = collections[0]
2744
2772
 
@@ -2816,7 +2844,7 @@ def _clip_loaded_array(
2816
2844
  out_shape: tuple[int, int],
2817
2845
  **kwargs,
2818
2846
  ) -> np.ndarray:
2819
- # xarray needs a numpy array of polygon(s)
2847
+ # xarray needs a numpy array of polygons
2820
2848
  bounds_arr: np.ndarray = GeoSeries([to_shapely(bounds)]).values
2821
2849
  try:
2822
2850
 
@@ -2837,6 +2865,29 @@ def _clip_loaded_array(
2837
2865
  return np.array([])
2838
2866
 
2839
2867
 
2868
+ def _fix_path(path: str) -> str:
2869
+ return (
2870
+ str(path).replace("\\", "/").replace(r"\"", "/").replace("//", "/").rstrip("/")
2871
+ )
2872
+
2873
+
2874
def _get_all_file_paths(path: str) -> list[str]:
    """Return the sorted, de-duplicated glob matches found below ``path``.

    On Dapla a single ``"/**"`` pattern is globbed. Otherwise the patterns
    ``"/**"`` repeated one to five times are globbed and their results
    combined.

    Args:
        path: Directory path that the glob patterns are appended to.

    Returns:
        Sorted list of unique paths returned by ``_glob_func``.
    """
    deepest_level = 1 if is_dapla() else 5

    matches: set[str] = set()
    for level in range(1, deepest_level + 1):
        matches.update(_glob_func(path + "/**" * level))

    return sorted(matches)
2889
+
2890
+
2840
2891
  def _get_images(
2841
2892
  image_paths: list[str],
2842
2893
  *,
@@ -2874,21 +2925,6 @@ def _get_images(
2874
2925
  return images
2875
2926
 
2876
2927
 
2877
- def numpy_to_torch(array: np.ndarray) -> torch.Tensor:
2878
- """Convert numpy array to a pytorch tensor."""
2879
- # fix numpy dtypes which are not supported by pytorch tensors
2880
- if array.dtype == np.uint16:
2881
- array = array.astype(np.int32)
2882
- elif array.dtype == np.uint32:
2883
- array = array.astype(np.int64)
2884
-
2885
- return torch.tensor(array)
2886
-
2887
-
2888
- class _RegexError(ValueError):
2889
- pass
2890
-
2891
-
2892
2928
  class ArrayNotLoadedError(ValueError):
2893
2929
  """Arrays are not loaded."""
2894
2930
 
@@ -2904,10 +2940,12 @@ class PathlessImageError(ValueError):
2904
2940
  """String representation."""
2905
2941
  if self.instance._merged:
2906
2942
  what = "that have been merged"
2907
- elif self.isinstance._from_array:
2943
+ elif self.instance._from_array:
2908
2944
  what = "from arrays"
2909
- elif self.isinstance._from_gdf:
2945
+ elif self.instance._from_gdf:
2910
2946
  what = "from GeoDataFrames"
2947
+ else:
2948
+ raise ValueError(self.instance)
2911
2949
 
2912
2950
  return (
2913
2951
  f"{self.instance.__class__.__name__} instances {what} "
@@ -2915,165 +2953,32 @@ class PathlessImageError(ValueError):
2915
2953
  )
2916
2954
 
2917
2955
 
2918
- def _get_regex_match_from_xml_in_local_dir(
2919
- paths: list[str], regexes: str | tuple[str]
2920
- ) -> str | dict[str, str]:
2921
- for i, path in enumerate(paths):
2922
- if ".xml" not in path:
2923
- continue
2924
- with _open_func(path, "rb") as file:
2925
- filebytes: bytes = file.read()
2926
- try:
2927
- return _extract_regex_match_from_string(
2928
- filebytes.decode("utf-8"), regexes
2929
- )
2930
- except _RegexError as e:
2931
- if i == len(paths) - 1:
2932
- raise e
2933
-
2934
-
2935
- def _extract_regex_match_from_string(
2936
- xml_file: str, regexes: tuple[str | re.Pattern]
2937
- ) -> str | dict[str, str]:
2938
- if all(isinstance(x, str) for x in regexes):
2939
- for regex in regexes:
2940
- try:
2941
- return re.search(regex, xml_file).group(1)
2942
- except (TypeError, AttributeError):
2943
- continue
2944
- raise _RegexError()
2945
-
2946
- out = {}
2947
- for regex in regexes:
2948
- try:
2949
- matches = re.search(regex, xml_file)
2950
- out |= matches.groupdict()
2951
- except (TypeError, AttributeError):
2952
- continue
2953
- if not out:
2954
- raise _RegexError()
2955
- return out
2956
-
2957
-
2958
- def _fix_path(path: str) -> str:
2959
- return (
2960
- str(path).replace("\\", "/").replace(r"\"", "/").replace("//", "/").rstrip("/")
2961
- )
2962
-
2963
-
2964
- def _get_regexes_matches_for_df(
2965
- df, match_col: str, patterns: Sequence[re.Pattern]
2966
- ) -> pd.DataFrame:
2967
- if not len(df):
2968
- return df
2969
-
2970
- non_optional_groups = list(
2971
- set(
2972
- itertools.chain.from_iterable(
2973
- [_get_non_optional_groups(pat) for pat in patterns]
2974
- )
2975
- )
2976
- )
2977
-
2978
- if not non_optional_groups:
2979
- return df
2980
-
2981
- assert df.index.is_unique
2982
- keep = []
2983
- for pat in patterns:
2984
- for i, row in df[match_col].items():
2985
- matches = _get_first_group_match(pat, row)
2986
- if all(group in matches for group in non_optional_groups):
2987
- keep.append(i)
2988
-
2989
- return df.loc[keep]
2990
-
2991
-
2992
- def _get_non_optional_groups(pat: re.Pattern | str) -> list[str]:
2993
- return [
2994
- x
2995
- for x in [
2996
- _extract_group_name(group)
2997
- for group in pat.pattern.split("\n")
2998
- if group
2999
- and not group.replace(" ", "").startswith("#")
3000
- and not group.replace(" ", "").split("#")[0].endswith("?")
3001
- ]
3002
- if x is not None
3003
- ]
3004
-
3005
-
3006
- def _extract_group_name(txt: str) -> str | None:
3007
- try:
3008
- return re.search(r"\(\?P<(\w+)>", txt)[1]
3009
- except TypeError:
3010
- return None
3011
-
3012
-
3013
- def _get_first_group_match(pat: re.Pattern, text: str) -> dict[str, str]:
3014
- groups = pat.groupindex.keys()
3015
- all_matches: dict[str, str] = {}
3016
- for x in pat.findall(text):
3017
- for group, value in zip(groups, x, strict=True):
3018
- if value and group not in all_matches:
3019
- all_matches[group] = value
3020
- return all_matches
3021
-
3022
-
3023
2956
  def _date_is_within(
3024
- path,
3025
- date_ranges: (
3026
- tuple[str | None, str | None] | tuple[tuple[str | None, str | None], ...] | None
3027
- ),
3028
- image_patterns: Sequence[re.Pattern],
3029
- date_format: str,
2957
+ date: str | None,
2958
+ date_ranges: DATE_RANGES_TYPE,
3030
2959
  ) -> bool:
3031
- for pat in image_patterns:
3032
-
3033
- try:
3034
- date = _get_first_group_match(pat, Path(path).name)["date"]
3035
- break
3036
- except KeyError:
3037
- date = None
2960
+ if date_ranges is None:
2961
+ return True
3038
2962
 
3039
2963
  if date is None:
3040
2964
  return False
3041
2965
 
3042
- if date_ranges is None:
3043
- return True
2966
+ date = pd.Timestamp(date)
3044
2967
 
3045
- if all(x is None or isinstance(x, (str, float)) for x in date_ranges):
2968
+ if all(x is None or isinstance(x, str) for x in date_ranges):
3046
2969
  date_ranges = (date_ranges,)
3047
2970
 
3048
- if all(isinstance(x, float) for date_range in date_ranges for x in date_range):
3049
- date = disambiguate_timestamp(date, date_format)
3050
- else:
3051
- date = date[:8]
3052
-
3053
2971
  for date_range in date_ranges:
3054
2972
  date_min, date_max = date_range
3055
2973
 
3056
- if isinstance(date_min, float) and isinstance(date_max, float):
3057
- if date[0] >= date_min + 0.0000001 and date[1] <= date_max - 0.0000001:
3058
- return True
3059
- continue
2974
+ if date_min is not None:
2975
+ date_min = pd.Timestamp(date_min)
2976
+ if date_max is not None:
2977
+ date_max = pd.Timestamp(date_max)
3060
2978
 
3061
- try:
3062
- date_min = date_min or "00000000"
3063
- date_max = date_max or "99999999"
3064
- if not (
3065
- isinstance(date_min, str)
3066
- and len(date_min) == 8
3067
- and isinstance(date_max, str)
3068
- and len(date_max) == 8
3069
- ):
3070
- raise ValueError()
3071
- except ValueError as err:
3072
- raise TypeError(
3073
- "date_ranges should be a tuple of two 8-charactered strings (start and end date)."
3074
- f"Got {date_range} of type {[type(x) for x in date_range]}"
3075
- ) from err
3076
- if date >= date_min and date <= date_max:
2979
+ if (date_min is None or date >= date_min) and (
2980
+ date_max is None or date <= date_max
2981
+ ):
3077
2982
  return True
3078
2983
 
3079
2984
  return False
@@ -3093,10 +2998,6 @@ def _get_dtype_max(dtype: str | type) -> int | float:
3093
2998
  return np.finfo(dtype).max
3094
2999
 
3095
3000
 
3096
- def _img_ndvi(img, **kwargs):
3097
- return Image([img.ndvi(**kwargs)])
3098
-
3099
-
3100
3001
  def _intesects(x, other) -> bool:
3101
3002
  return box(*x.bounds).intersects(other)
3102
3003
 
@@ -3116,6 +3017,17 @@ def _copy_and_add_df_parallel(
3116
3017
  return (i, copied)
3117
3018
 
3118
3019
 
3020
def _get_bounds(bounds, bbox) -> None | Polygon:
    """Combine ``bounds`` and ``bbox`` into a single geometry.

    Args:
        bounds: Value accepted by ``to_shapely``, or None.
        bbox: Value accepted by ``to_shapely``, or None.

    Returns:
        None if both arguments are None, the converted geometry of the one
        that is given if only one is, and the intersection of the two
        converted geometries if both are given.
    """
    if bounds is None:
        return None if bbox is None else to_shapely(bbox)
    if bbox is None:
        return to_shapely(bounds)
    return to_shapely(bounds).intersection(to_shapely(bbox))
3029
+
3030
+
3119
3031
  def _get_single_value(values: tuple):
3120
3032
  if len(set(values)) == 1:
3121
3033
  return next(iter(values))
@@ -3173,85 +3085,126 @@ def array_buffer(arr: np.ndarray, distance: int) -> np.ndarray:
3173
3085
  return binary_erosion(arr, structure=structure).astype(dtype)
3174
3086
 
3175
3087
 
3176
- class Sentinel2Config:
3177
- """Holder of Sentinel 2 regexes, band_ids etc."""
3178
-
3179
- image_regexes: ClassVar[str] = (config.SENTINEL2_IMAGE_REGEX,)
3180
- filename_regexes: ClassVar[str] = (
3181
- config.SENTINEL2_FILENAME_REGEX,
3182
- config.SENTINEL2_CLOUD_FILENAME_REGEX,
3183
- )
3184
- all_bands: ClassVar[list[str]] = list(config.SENTINEL2_BANDS)
3185
- rbg_bands: ClassVar[list[str]] = config.SENTINEL2_RBG_BANDS
3186
- ndvi_bands: ClassVar[list[str]] = config.SENTINEL2_NDVI_BANDS
3187
- l2a_bands: ClassVar[dict[str, int]] = config.SENTINEL2_L2A_BANDS
3188
- l1c_bands: ClassVar[dict[str, int]] = config.SENTINEL2_L1C_BANDS
3189
- date_format: ClassVar[str] = "%Y%m%d" # T%H%M%S"
3190
- masking: ClassVar[BandMasking] = BandMasking(
3191
- band_id="SCL", values=(3, 8, 9, 10, 11)
3192
- )
3193
-
3194
-
3195
- class Sentinel2CloudlessConfig(Sentinel2Config):
3196
- """Holder of regexes, band_ids etc. for Sentinel 2 cloudless mosaic."""
3088
def get_cmap(arr: np.ndarray) -> LinearSegmentedColormap:
    """Build a blue-gray-green colormap whose colors follow the data range.

    The low-end and middle colors are selected from the minimum of ``arr``
    and the high-end color from its maximum. The selection thresholds span
    -0.75 to 0.75 (presumably because the values are NDVI in [-1, 1];
    confirm against callers).

    Args:
        arr: Array of values to be visualised. Only its minimum and
            maximum are read.

    Returns:
        A LinearSegmentedColormap named "blue_gray_green" with 50 levels.
    """
    # Candidate RGB colors for the low end, indexed by the start index.
    low_end_colors = [
        [0.1, 0.1, 1.0],
        [0.2, 0.2, 0.9],
        [0.3, 0.3, 0.8],
        [0.4, 0.4, 0.7],
        [0.6, 0.6, 0.6],
        [0.6, 0.6, 0.6],
        [0.7, 0.7, 0.7],
        [0.8, 0.8, 0.8],
    ]
    # Candidate RGB colors for the middle, also indexed by the start index.
    middle_colors = [
        [0.6, 0.6, 0.6],
        [0.6, 0.6, 0.6],
        [0.6, 0.6, 0.6],
        [0.6, 0.6, 0.6],
        [0.6, 0.6, 0.6],
        [0.4, 0.7, 0.4],
        [0.3, 0.7, 0.3],
        [0.2, 0.8, 0.2],
    ]
    # Candidate RGB colors for the high end, indexed by the stop index.
    high_end_colors = [
        [0.6, 0.6, 0.6],
        [0.4, 0.7, 0.4],
        [0.3, 0.8, 0.3],
        [0.25, 0.4, 0.25],
        [0.2, 0.5, 0.2],
        [0.10, 0.7, 0.10],
        [0, 0.9, 0],
    ]

    def count_thresholds_reached(value, thresholds) -> int:
        """Return the index of the first threshold that ``value`` is below."""
        for index, threshold in enumerate(thresholds):
            if value < threshold:
                return index
        return len(thresholds)

    # Start index in 0..7, from the smallest value in the array.
    start = count_thresholds_reached(
        np.min(arr), (-0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75)
    )

    # Stop index in 0..6, from the largest value. The first step uses an
    # inclusive comparison, unlike the remaining strict ones.
    max_value = np.max(arr)
    if max_value <= 0.05:
        stop = 0
    else:
        stop = 1 + count_thresholds_reached(
            max_value, (0.175, 0.25, 0.375, 0.5, 0.75)
        )

    low = low_end_colors[start]
    middle = middle_colors[start]
    high = high_end_colors[stop]

    # Each channel goes from the low color at 0.0, holds the middle color
    # between 0.3 and 0.7, and ends at the high color at 1.0.
    cdict = {
        channel: [
            (0.0, low[i], low[i]),
            (0.3, middle[i], middle[i]),
            (0.7, middle[i], middle[i]),
            (1.0, high[i], high[i]),
        ]
        for i, channel in enumerate(("red", "green", "blue"))
    }

    return LinearSegmentedColormap("blue_gray_green", segmentdata=cdict, N=50)