ssb-sgis 1.0.5__py3-none-any.whl → 1.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,14 +2,15 @@ import datetime
2
2
  import functools
3
3
  import glob
4
4
  import itertools
5
- import math
6
5
  import os
7
6
  import random
8
7
  import re
8
+ import time
9
9
  from collections.abc import Callable
10
10
  from collections.abc import Iterable
11
11
  from collections.abc import Iterator
12
12
  from collections.abc import Sequence
13
+ from concurrent.futures import ThreadPoolExecutor
13
14
  from copy import deepcopy
14
15
  from dataclasses import dataclass
15
16
  from pathlib import Path
@@ -26,9 +27,8 @@ from affine import Affine
26
27
  from geopandas import GeoDataFrame
27
28
  from geopandas import GeoSeries
28
29
  from matplotlib.colors import LinearSegmentedColormap
30
+ from pandas.api.types import is_dict_like
29
31
  from rasterio.enums import MergeAlg
30
- from rtree.index import Index
31
- from rtree.index import Property
32
32
  from scipy import stats
33
33
  from scipy.ndimage import binary_dilation
34
34
  from scipy.ndimage import binary_erosion
@@ -49,24 +49,15 @@ except ImportError:
49
49
 
50
50
 
51
51
  try:
52
- from rioxarray.exceptions import NoDataInBounds
53
- from rioxarray.merge import merge_arrays
54
- from rioxarray.rioxarray import _generate_spatial_coords
55
- except ImportError:
56
- pass
57
- try:
58
- import xarray as xr
59
- from xarray import DataArray
52
+ from google.auth import exceptions
60
53
  except ImportError:
61
54
 
62
- class DataArray:
55
+ class exceptions:
63
56
  """Placeholder."""
64
57
 
58
+ class RefreshError:
59
+ """Placeholder."""
65
60
 
66
- try:
67
- import torch
68
- except ImportError:
69
- pass
70
61
 
71
62
  try:
72
63
  from gcsfs.core import GCSFile
@@ -77,33 +68,31 @@ except ImportError:
77
68
 
78
69
 
79
70
  try:
80
- from torchgeo.datasets.utils import disambiguate_timestamp
71
+ from rioxarray.exceptions import NoDataInBounds
72
+ from rioxarray.merge import merge_arrays
73
+ from rioxarray.rioxarray import _generate_spatial_coords
81
74
  except ImportError:
82
-
83
- class torch:
84
- """Placeholder."""
85
-
86
- class Tensor:
87
- """Placeholder to reference torch.Tensor."""
88
-
89
-
75
+ pass
90
76
  try:
91
- from torchgeo.datasets.utils import BoundingBox
77
+ import xarray as xr
78
+ from xarray import DataArray
79
+ from xarray import Dataset
92
80
  except ImportError:
93
81
 
94
- class BoundingBox:
82
+ class DataArray:
95
83
  """Placeholder."""
96
84
 
97
- def __init__(self, *args, **kwargs) -> None:
98
- """Placeholder."""
99
- raise ImportError("missing optional dependency 'torchgeo'")
85
+ class Dataset:
86
+ """Placeholder."""
100
87
 
101
88
 
102
89
  from ..geopandas_tools.bounds import get_total_bounds
103
90
  from ..geopandas_tools.conversion import to_bbox
104
91
  from ..geopandas_tools.conversion import to_gdf
92
+ from ..geopandas_tools.conversion import to_geoseries
105
93
  from ..geopandas_tools.conversion import to_shapely
106
94
  from ..geopandas_tools.general import get_common_crs
95
+ from ..helpers import _fix_path
107
96
  from ..helpers import get_all_files
108
97
  from ..helpers import get_numpy_func
109
98
  from ..io._is_dapla import is_dapla
@@ -115,6 +104,11 @@ from .base import _get_shape_from_bounds
115
104
  from .base import _get_transform_from_bounds
116
105
  from .base import get_index_mapper
117
106
  from .indices import ndvi
107
+ from .regex import _extract_regex_match_from_string
108
+ from .regex import _get_first_group_match
109
+ from .regex import _get_non_optional_groups
110
+ from .regex import _get_regexes_matches_for_df
111
+ from .regex import _RegexError
118
112
  from .zonal import _aggregate
119
113
  from .zonal import _make_geometry_iterrows
120
114
  from .zonal import _no_overlap_df
@@ -132,9 +126,6 @@ if is_dapla():
132
126
  def _open_func(*args, **kwargs) -> GCSFile:
133
127
  return dp.FileClient.get_gcs_file_system().open(*args, **kwargs)
134
128
 
135
- def _rm_file_func(*args, **kwargs) -> None:
136
- return dp.FileClient.get_gcs_file_system().rm_file(*args, **kwargs)
137
-
138
129
  def _read_parquet_func(*args, **kwargs) -> list[str]:
139
130
  return dp.read_pandas(*args, **kwargs)
140
131
 
@@ -142,22 +133,25 @@ else:
142
133
  _ls_func = functools.partial(get_all_files, recursive=False)
143
134
  _open_func = open
144
135
  _glob_func = glob.glob
145
- _rm_file_func = os.remove
146
136
  _read_parquet_func = pd.read_parquet
147
137
 
148
- TORCHGEO_RETURN_TYPE = dict[str, torch.Tensor | pyproj.CRS | BoundingBox]
138
+ DATE_RANGES_TYPE = (
139
+ tuple[str | pd.Timestamp | None, str | pd.Timestamp | None]
140
+ | tuple[tuple[str | pd.Timestamp | None, str | pd.Timestamp | None], ...]
141
+ )
142
+
149
143
  FILENAME_COL_SUFFIX = "_filename"
144
+
150
145
  DEFAULT_FILENAME_REGEX = r"""
151
146
  .*?
152
- (?:_(?P<date>\d{8}(?:T\d{6})?))? # Optional date group
147
+ (?:_?(?P<date>\d{8}(?:T\d{6})?))? # Optional underscore and date group
153
148
  .*?
154
- (?:_(?P<band>B\d{1,2}A|B\d{1,2}))? # Optional band group
149
+ (?:_?(?P<band>B\d{1,2}A|B\d{1,2}))? # Optional underscore and band group
155
150
  \.(?:tif|tiff|jp2)$ # End with .tif, .tiff, or .jp2
156
151
  """
157
152
  DEFAULT_IMAGE_REGEX = r"""
158
153
  .*?
159
- (?:_(?P<date>\d{8}(?:T\d{6})?))? # Optional date group
160
- (?:_(?P<band>B\d{1,2}A|B\d{1,2}))? # Optional band group
154
+ (?:_?(?P<date>\d{8}(?:T\d{6})?))? # Optional underscore and date group
161
155
  """
162
156
 
163
157
  ALLOWED_INIT_KWARGS = [
@@ -165,15 +159,21 @@ ALLOWED_INIT_KWARGS = [
165
159
  "band_class",
166
160
  "image_regexes",
167
161
  "filename_regexes",
168
- "date_format",
169
- "cloud_cover_regexes",
170
- "bounds_regexes",
171
162
  "all_bands",
172
163
  "crs",
164
+ "backend",
173
165
  "masking",
174
166
  "_merged",
175
167
  ]
176
168
 
169
+ _load_counter: int = 0
170
+
171
+
172
def _get_child_paths_threaded(data: Sequence[str]) -> set[str]:
    """List several directories concurrently and merge the results.

    Every path in 'data' is handed to the module-level '_ls_func' from a
    thread pool, and the individual listings are collected in one set.

    Args:
        data: Paths to list.

    Returns:
        The union of everything '_ls_func' returned for the given paths.
    """
    children: set[str] = set()
    with ThreadPoolExecutor() as pool:
        # map() yields the listings in input order; exceptions raised by
        # '_ls_func' surface here, while the pool is still open.
        for listing in pool.map(_ls_func, data):
            children.update(listing)
    return children
176
+
177
177
 
178
178
  class ImageCollectionGroupBy:
179
179
  """Iterator and merger class returned from groupby.
@@ -225,7 +225,6 @@ class ImageCollectionGroupBy:
225
225
 
226
226
  collection = ImageCollection(
227
227
  images,
228
- # TODO band_class?
229
228
  level=self.collection.level,
230
229
  **self.collection._common_init_kwargs,
231
230
  )
@@ -263,7 +262,6 @@ class ImageCollectionGroupBy:
263
262
 
264
263
  image = Image(
265
264
  bands,
266
- # TODO band_class?
267
265
  **self.collection._common_init_kwargs,
268
266
  )
269
267
  image._merged = True
@@ -295,29 +293,40 @@ class ImageCollectionGroupBy:
295
293
 
296
294
  @dataclass(frozen=True)
297
295
  class BandMasking:
298
- """Basically a frozen dict with forced keys."""
296
+ """Frozen dict with forced keys."""
299
297
 
300
298
  band_id: str
301
- values: tuple[int]
299
+ values: Sequence[int] | dict[int, Any]
302
300
 
303
301
  def __getitem__(self, item: str) -> Any:
304
302
  """Index into attributes to mimick dict."""
305
303
  return getattr(self, item)
306
304
 
307
305
 
306
class None_:
    """Sentinel default for keyword arguments that must be given explicitly.

    Used instead of None so that "argument not passed" can be told apart
    from an explicit None.
    """
308
+
309
+
308
310
  class _ImageBase:
309
311
  image_regexes: ClassVar[str | None] = (DEFAULT_IMAGE_REGEX,)
310
312
  filename_regexes: ClassVar[str | tuple[str]] = (DEFAULT_FILENAME_REGEX,)
311
- date_format: ClassVar[str] = "%Y%m%d" # T%H%M%S"
313
+ metadata_attributes: ClassVar[dict | None] = None
312
314
  masking: ClassVar[BandMasking | None] = None
313
315
 
314
- def __init__(self, **kwargs) -> None:
316
+ def __init__(self, *, metadata=None, bbox=None, **kwargs) -> None:
315
317
 
316
318
  self._mask = None
317
319
  self._bounds = None
318
320
  self._merged = False
319
321
  self._from_array = False
320
322
  self._from_gdf = False
323
+ self.metadata_attributes = self.metadata_attributes or {}
324
+ self._path = None
325
+ self._metadata_from_xml = False
326
+
327
+ self._bbox = to_bbox(bbox) if bbox is not None else None
328
+
329
+ self.metadata = self._metadata_to_nested_dict(metadata)
321
330
 
322
331
  if self.filename_regexes:
323
332
  if isinstance(self.filename_regexes, str):
@@ -346,14 +355,45 @@ class _ImageBase:
346
355
  f"{self.__class__.__name__} got an unexpected keyword argument '{key}'"
347
356
  )
348
357
 
358
+ @staticmethod
359
+ def _metadata_to_nested_dict(
360
+ metadata: str | Path | os.PathLike | dict | pd.DataFrame | None,
361
+ ) -> dict[str, dict[str, Any]] | None:
362
+ if metadata is None:
363
+ return {}
364
+ if isinstance(metadata, (str | Path | os.PathLike)):
365
+ metadata = _read_parquet_func(metadata)
366
+
367
+ if isinstance(metadata, pd.DataFrame):
368
+
369
+ def is_scalar(x) -> bool:
370
+ return not hasattr(x, "__len__") or len(x) <= 1
371
+
372
+ def na_to_none(x) -> None:
373
+ """Convert to None rowwise because pandas doesn't always."""
374
+ return x if not (is_scalar(x) and pd.isna(x)) else None
375
+
376
+ # to nested dict because pandas indexing gives rare KeyError with long strings
377
+ metadata = {
378
+ _fix_path(path): {
379
+ attr: na_to_none(value) for attr, value in row.items()
380
+ }
381
+ for path, row in metadata.iterrows()
382
+ }
383
+ elif is_dict_like(metadata):
384
+ metadata = {_fix_path(path): value for path, value in metadata.items()}
385
+
386
+ return metadata
387
+
349
388
  @property
350
389
  def _common_init_kwargs(self) -> dict:
351
390
  return {
352
- "file_system": self.file_system,
353
391
  "processes": self.processes,
354
392
  "res": self.res,
355
393
  "bbox": self._bbox,
356
394
  "nodata": self.nodata,
395
+ "backend": self.backend,
396
+ "metadata": self.metadata,
357
397
  }
358
398
 
359
399
  @property
@@ -373,6 +413,14 @@ class _ImageBase:
373
413
  """Centerpoint of the object."""
374
414
  return self.union_all().centroid
375
415
 
416
def assign(self, **kwargs) -> "_ImageBase":
    """Set attributes in place and return the object itself.

    If setting an attribute raises AttributeError (as a property without
    a setter does), the value is stored under the underscore-prefixed
    name instead.

    Args:
        **kwargs: Attribute names and the values to give them.

    Returns:
        The same object, to allow chaining.
    """
    for attr_name, new_value in kwargs.items():
        try:
            setattr(self, attr_name, new_value)
        except AttributeError:
            private_name = "_" + attr_name
            setattr(self, private_name, new_value)
    return self
423
+
376
424
  def _name_regex_searcher(
377
425
  self, group: str, patterns: tuple[re.Pattern]
378
426
  ) -> str | None:
@@ -381,46 +429,55 @@ class _ImageBase:
381
429
  for pat in patterns:
382
430
  try:
383
431
  return _get_first_group_match(pat, self.name)[group]
384
- return re.match(pat, self.name).group(group)
385
432
  except (TypeError, KeyError):
386
433
  pass
434
+ if isinstance(self, Band):
435
+ for pat in patterns:
436
+ try:
437
+ return _get_first_group_match(
438
+ pat, str(Path(self.path).parent.name)
439
+ )[group]
440
+ except (TypeError, KeyError):
441
+ pass
387
442
  if not any(group in _get_non_optional_groups(pat) for pat in patterns):
388
443
  return None
444
+ band_text = (
445
+ f" or {Path(self.path).parent.name!s}" if isinstance(self, Band) else ""
446
+ )
389
447
  raise ValueError(
390
- f"Couldn't find group '{group}' in name {self.name} with regex patterns {patterns}"
448
+ f"Couldn't find group '{group}' in name {self.name}{band_text} with regex patterns {patterns}"
391
449
  )
392
450
 
393
- def _create_metadata_df(self, file_paths: list[str]) -> pd.DataFrame:
451
+ def _create_metadata_df(self, file_paths: Sequence[str]) -> pd.DataFrame:
394
452
  """Create a dataframe with file paths and image paths that match regexes."""
395
- df = pd.DataFrame({"file_path": file_paths})
453
+ df = pd.DataFrame({"file_path": list(file_paths)})
396
454
 
397
- df["filename"] = df["file_path"].apply(lambda x: _fix_path(Path(x).name))
455
+ df["file_name"] = df["file_path"].apply(lambda x: Path(x).name)
398
456
 
399
- if not self.single_banded:
400
- df["image_path"] = df["file_path"].apply(
401
- lambda x: _fix_path(str(Path(x).parent))
402
- )
403
- else:
404
- df["image_path"] = df["file_path"]
457
+ df["image_path"] = df["file_path"].apply(
458
+ lambda x: _fix_path(str(Path(x).parent))
459
+ )
405
460
 
406
461
  if not len(df):
407
462
  return df
408
463
 
464
+ df = df[~df["file_path"].isin(df["image_path"])]
465
+
409
466
  if self.filename_patterns:
410
- df = _get_regexes_matches_for_df(df, "filename", self.filename_patterns)
467
+ df = _get_regexes_matches_for_df(df, "file_name", self.filename_patterns)
411
468
 
412
469
  if not len(df):
413
470
  return df
414
471
 
415
472
  grouped = df.drop_duplicates("image_path").set_index("image_path")
416
- for col in ["file_path", "filename"]:
473
+ for col in ["file_path", "file_name"]:
417
474
  if col in df:
418
475
  grouped[col] = df.groupby("image_path")[col].apply(tuple)
419
476
 
420
477
  grouped = grouped.reset_index()
421
478
  else:
422
479
  df["file_path"] = df.groupby("image_path")["file_path"].apply(tuple)
423
- df["filename"] = df.groupby("image_path")["filename"].apply(tuple)
480
+ df["file_name"] = df.groupby("image_path")["file_name"].apply(tuple)
424
481
  grouped = df.drop_duplicates("image_path")
425
482
 
426
483
  grouped["imagename"] = grouped["image_path"].apply(
@@ -446,8 +503,19 @@ class _ImageBase:
446
503
  continue
447
504
  return copied
448
505
 
506
def equals(self, other) -> bool:
    """Check whether the public instance attributes equal those of 'other'.

    Only the attributes in this object's '__dict__' whose names do not
    start with an underscore are compared.

    Args:
        other: Object to compare with.

    Returns:
        False if 'other' lacks one of the attributes or holds a different
        value for it, otherwise True.
    """
    missing = object()  # sentinel, since None can be a real attribute value
    for key, value in self.__dict__.items():
        if key.startswith("_"):
            continue
        other_value = getattr(other, key, missing)
        if other_value is missing or value != other_value:
            return False
    return True
514
+
449
515
 
450
516
  class _ImageBandBase(_ImageBase):
517
+ """Common parent class of Image and Band."""
518
+
451
519
  def intersects(self, other: GeoDataFrame | GeoSeries | Geometry) -> bool:
452
520
  if hasattr(other, "crs") and not pyproj.CRS(self.crs).equals(
453
521
  pyproj.CRS(other.crs)
@@ -455,6 +523,12 @@ class _ImageBandBase(_ImageBase):
455
523
  raise ValueError(f"crs mismatch: {self.crs} and {other.crs}")
456
524
  return self.union_all().intersects(to_shapely(other))
457
525
 
526
def union_all(self) -> Polygon:
    """Return the bounds as a box polygon.

    Returns:
        'box(*self.bounds)', or an empty Polygon if that raises TypeError
        (presumably when 'bounds' is None and cannot be unpacked).
    """
    try:
        footprint = box(*self.bounds)
    except TypeError:
        footprint = Polygon()
    return footprint
531
+
458
532
  @property
459
533
  def mask_percentage(self) -> float:
460
534
  return self.mask.values.sum() / (self.mask.width * self.mask.height) * 100
@@ -477,7 +551,7 @@ class _ImageBandBase(_ImageBase):
477
551
  return self._name
478
552
  try:
479
553
  return Path(self.path).name
480
- except (ValueError, AttributeError):
554
+ except (ValueError, AttributeError, TypeError):
481
555
  return None
482
556
 
483
557
  @name.setter
@@ -488,37 +562,101 @@ class _ImageBandBase(_ImageBase):
488
562
  def stem(self) -> str | None:
489
563
  try:
490
564
  return Path(self.path).stem
491
- except (AttributeError, ValueError):
565
+ except (AttributeError, ValueError, TypeError):
492
566
  return None
493
567
 
494
568
  @property
495
569
  def level(self) -> str:
496
570
  return self._name_regex_searcher("level", self.image_patterns)
497
571
 
498
- @property
499
- def mint(self) -> float:
500
- return disambiguate_timestamp(self.date, self.date_format)[0]
572
+ def _get_metadata_attributes(self, metadata_attributes: dict) -> dict:
501
573
 
502
- @property
503
- def maxt(self) -> float:
504
- return disambiguate_timestamp(self.date, self.date_format)[1]
574
+ self._metadata_from_xml = True
505
575
 
506
- def union_all(self) -> Polygon:
507
- try:
508
- return box(*self.bounds)
509
- except TypeError:
510
- return Polygon()
576
+ missing_metadata_attributes = {
577
+ key: value
578
+ for key, value in metadata_attributes.items()
579
+ if not hasattr(self, key) or getattr(self, key) is None
580
+ }
511
581
 
512
- @property
513
- def torch_bbox(self) -> BoundingBox:
514
- bounds = GeoSeries([self.union_all()]).bounds
515
- return BoundingBox(
516
- minx=bounds.minx[0],
517
- miny=bounds.miny[0],
518
- maxx=bounds.maxx[0],
519
- maxy=bounds.maxy[0],
520
- mint=self.mint,
521
- maxt=self.maxt,
582
+ nonmissing_metadata_attributes = {
583
+ key: getattr(self, key)
584
+ for key in metadata_attributes
585
+ if key not in missing_metadata_attributes
586
+ }
587
+
588
+ if not missing_metadata_attributes:
589
+ return nonmissing_metadata_attributes
590
+
591
+ file_contents: list[str] = []
592
+ for path in self._all_file_paths:
593
+ if ".xml" not in path:
594
+ continue
595
+ with _open_func(path, "rb") as file:
596
+ file_contents.append(file.read().decode("utf-8"))
597
+
598
+ for key, value in missing_metadata_attributes.items():
599
+ results = None
600
+ for i, filetext in enumerate(file_contents):
601
+ if isinstance(value, str) and value in dir(self):
602
+ method = getattr(self, value)
603
+ try:
604
+ results = method(filetext)
605
+ except _RegexError as e:
606
+ if i == len(self._all_file_paths) - 1:
607
+ raise e
608
+ continue
609
+ if results is not None:
610
+ break
611
+
612
+ if callable(value):
613
+ try:
614
+ results = value(filetext)
615
+ except _RegexError as e:
616
+ if i == len(self._all_file_paths) - 1:
617
+ raise e
618
+ continue
619
+ if results is not None:
620
+ break
621
+
622
+ try:
623
+ results = _extract_regex_match_from_string(filetext, value)
624
+ except _RegexError as e:
625
+ if i == len(self._all_file_paths) - 1:
626
+ raise e
627
+
628
+ missing_metadata_attributes[key] = results
629
+
630
+ return missing_metadata_attributes | nonmissing_metadata_attributes
631
+
632
def _to_xarray(self, array: np.ndarray, transform: Affine) -> DataArray:
    """Wrap a 2 or 3 dimensional array in an xarray.DataArray.

    Args:
        array: Array with dimensions (y, x) or (band, y, x).
        transform: Transform used to generate the spatial coordinates.

    Returns:
        DataArray named after 'self.name' (or the class name if the name
        is falsy). Its attrs hold "crs" plus "date" and every key in
        'metadata_attributes' that could be read from the object.

    Raises:
        ValueError: If the array is not 2 or 3 dimensional.
    """
    n_dims = len(array.shape)
    if n_dims not in (2, 3):
        raise ValueError(
            f"Array should be 2 or 3 dimensional. Got shape {array.shape}"
        )
    dims = ["y", "x"] if n_dims == 2 else ["band", "y", "x"]
    # the last two axes are (height, width) in both layouts
    height, width = array.shape[-2:]

    coords = _generate_spatial_coords(transform, width, height)

    attrs = {"crs": self.crs}
    for attr in set(self.metadata_attributes).union({"date"}):
        # best effort: attributes that cannot be read are left out
        try:
            attrs[attr] = getattr(self, attr)
        except Exception:
            pass

    array_name = self.name or self.__class__.__name__
    return DataArray(
        array,
        coords=coords,
        dims=dims,
        name=array_name,
        attrs=attrs,
    )
523
661
 
524
662
 
@@ -526,6 +664,7 @@ class Band(_ImageBandBase):
526
664
  """Band holding a single 2 dimensional array representing an image band."""
527
665
 
528
666
  cmap: ClassVar[str | None] = None
667
+ backend: str = "numpy"
529
668
 
530
669
  @classmethod
531
670
  def from_gdf(
@@ -557,42 +696,52 @@ class Band(_ImageBandBase):
557
696
 
558
697
  def __init__(
559
698
  self,
560
- data: str | np.ndarray,
561
- res: int | None,
699
+ data: str | np.ndarray | None = None,
700
+ res: int | None_ = None_,
562
701
  crs: Any | None = None,
563
702
  bounds: tuple[float, float, float, float] | None = None,
564
- cmap: str | None = None,
703
+ nodata: int | None = None,
704
+ mask: "Band | None" = None,
705
+ processes: int = 1,
565
706
  name: str | None = None,
566
- file_system: GCSFileSystem | None = None,
567
707
  band_id: str | None = None,
568
- processes: int = 1,
569
- bbox: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None = None,
570
- mask: "Band | None" = None,
571
- nodata: int | None = None,
708
+ cmap: str | None = None,
709
+ all_file_paths: list[str] | None = None,
572
710
  **kwargs,
573
711
  ) -> None:
574
712
  """Band initialiser."""
713
+ if callable(res) and isinstance(res(), None_):
714
+ raise TypeError("Must specify 'res'")
715
+
716
+ if data is None:
717
+ # allowing 'path' to replace 'data' as argument
718
+ # to make the print repr. valid as initialiser
719
+ if "path" not in kwargs:
720
+ raise TypeError("Must specify either 'data' or 'path'.")
721
+ data = kwargs.pop("path")
722
+
575
723
  super().__init__(**kwargs)
576
724
 
725
+ if isinstance(data, (str | Path | os.PathLike)) and any(
726
+ arg is not None for arg in [crs, bounds]
727
+ ):
728
+ raise ValueError("Can only specify 'bounds' and 'crs' if data is an array.")
729
+
577
730
  self._mask = mask
578
- self._bbox = to_bbox(bbox) if bbox is not None else None
579
731
  self._values = None
580
- self._crs = None
581
732
  self.nodata = nodata
582
-
733
+ self._crs = crs
583
734
  bounds = to_bbox(bounds) if bounds is not None else None
584
-
585
735
  self._bounds = bounds
736
+ self._all_file_paths = all_file_paths
586
737
 
587
738
  if isinstance(data, np.ndarray):
588
- self.values = data
589
739
  if self._bounds is None:
590
740
  raise ValueError("Must specify bounds when data is an array.")
591
741
  self._crs = crs
592
- self.transform = _get_transform_from_bounds(
593
- self._bounds, shape=self.values.shape
594
- )
742
+ self.transform = _get_transform_from_bounds(self._bounds, shape=data.shape)
595
743
  self._from_array = True
744
+ self.values = data
596
745
 
597
746
  elif not isinstance(data, (str | Path | os.PathLike)):
598
747
  raise TypeError(
@@ -600,24 +749,43 @@ class Band(_ImageBandBase):
600
749
  f"Got {type(data)}"
601
750
  )
602
751
  else:
603
- self._path = str(data)
752
+ self._path = _fix_path(str(data))
604
753
 
605
754
  self._res = res
606
755
  if cmap is not None:
607
756
  self.cmap = cmap
608
- self.file_system = file_system
609
757
  self._name = name
610
758
  self._band_id = band_id
611
759
  self.processes = processes
612
760
 
613
- # if self.filename_regexes:
614
- # if isinstance(self.filename_regexes, str):
615
- # self.filename_regexes = [self.filename_regexes]
616
- # self.filename_patterns = [
617
- # re.compile(pat, flags=re.VERBOSE) for pat in self.filename_regexes
618
- # ]
619
- # else:
620
- # self.filename_patterns = None
761
+ if self._all_file_paths:
762
+ self._all_file_paths = {_fix_path(path) for path in self._all_file_paths}
763
+ parent = _fix_path(Path(self.path).parent)
764
+ self._all_file_paths = {
765
+ path for path in self._all_file_paths if parent in path
766
+ }
767
+
768
+ if self.metadata:
769
+ if self.path is not None:
770
+ self.metadata = {
771
+ key: value
772
+ for key, value in self.metadata.items()
773
+ if key == self.path
774
+ }
775
+ this_metadata = self.metadata[self.path]
776
+ for key, value in this_metadata.items():
777
+ if key in dir(self):
778
+ setattr(self, f"_{key}", value)
779
+ else:
780
+ setattr(self, key, value)
781
+
782
+ elif self.metadata_attributes and self.path is not None and not self.is_mask:
783
+ if self._all_file_paths is None:
784
+ self._all_file_paths = _get_all_file_paths(str(Path(self.path).parent))
785
+ for key, value in self._get_metadata_attributes(
786
+ self.metadata_attributes
787
+ ).items():
788
+ setattr(self, key, value)
621
789
 
622
790
  def __lt__(self, other: "Band") -> bool:
623
791
  """Makes Bands sortable by band_id."""
@@ -632,23 +800,35 @@ class Band(_ImageBandBase):
632
800
 
633
801
  @values.setter
634
802
  def values(self, new_val):
635
- if not isinstance(new_val, np.ndarray):
636
- raise TypeError(
637
- f"{self.__class__.__name__} 'values' must be np.ndarray. Got {type(new_val)}"
638
- )
639
- self._values = new_val
803
+ if self.backend == "numpy" and isinstance(new_val, np.ndarray):
804
+ self._values = new_val
805
+ return
806
+ elif self.backend == "xarray" and isinstance(new_val, DataArray):
807
+ # attrs can dissappear, so doing a union
808
+ attrs = self._values.attrs | new_val.attrs
809
+ self._values = new_val
810
+ self._values.attrs = attrs
811
+ return
812
+
813
+ if self.backend == "numpy":
814
+ self._values = self._to_numpy(new_val)
815
+ if self.backend == "xarray":
816
+ if not isinstance(self._values, DataArray):
817
+ self._values = self._to_xarray(
818
+ new_val,
819
+ transform=self.transform,
820
+ )
821
+
822
+ elif isinstance(new_val, np.ndarray):
823
+ self._values.values = new_val
824
+ else:
825
+ self._values = new_val
640
826
 
641
827
  @property
642
828
  def mask(self) -> "Band":
643
829
  """Mask Band."""
644
830
  return self._mask
645
831
 
646
- @mask.setter
647
- def mask(self, values: "Band") -> None:
648
- if values is not None and not isinstance(values, Band):
649
- raise TypeError(f"'mask' should be of type Band. Got {type(values)}")
650
- self._mask = values
651
-
652
832
  @property
653
833
  def band_id(self) -> str:
654
834
  """Band id."""
@@ -686,26 +866,24 @@ class Band(_ImageBandBase):
686
866
  )
687
867
 
688
868
  @property
689
- def crs(self) -> str | None:
869
def crs(self) -> pyproj.CRS | None:
    """Coordinate reference system as a pyproj.CRS.

    If no crs is stored yet, '_add_crs_and_bounds' is called first to
    fill in '_crs' (and '_bounds').
    """
    crs_is_unknown = self._crs is None
    if crs_is_unknown:
        self._add_crs_and_bounds()
    return pyproj.CRS(self._crs)
698
874
 
699
875
  @property
700
876
  def bounds(self) -> tuple[int, int, int, int] | None:
701
877
  """Bounds as tuple (minx, miny, maxx, maxy)."""
702
- if self._bounds is not None:
703
- return self._bounds
704
- with opener(self.path, file_system=self.file_system) as file:
878
+ if self._bounds is None:
879
+ self._add_crs_and_bounds()
880
+ return self._bounds
881
+
882
def _add_crs_and_bounds(self) -> None:
    """Open the raster at 'path' and store its bounds and crs on the object."""
    with opener(self.path) as file, rasterio.open(file) as src:
        self._bounds = to_bbox(src.bounds)
        self._crs = src.crs
708
- return self._bounds
709
887
 
710
888
  def get_n_largest(
711
889
  self, n: int, precision: float = 0.000001, column: str = "value"
@@ -729,59 +907,64 @@ class Band(_ImageBandBase):
729
907
  df[column] = f"smallest_{n}"
730
908
  return df
731
909
 
910
def clip(
    self, mask: GeoDataFrame | GeoSeries | Polygon | MultiPolygon, **kwargs
) -> "Band":
    """Clip the band values to a geometry mask, in place.

    Args:
        mask: Geometry to clip to.
        **kwargs: Passed on to '_clip_xarray'.

    Returns:
        The band itself, with 'values', '_bounds' and 'transform'
        replaced. The new bounds are the bounding box of 'mask'.
    """
    clipped = _clip_xarray(self.to_xarray(), mask, crs=self.crs, **kwargs)
    clipped_bounds = to_bbox(mask)
    self._bounds = clipped_bounds
    # the transform must be in place before 'values' is assigned
    self.transform = _get_transform_from_bounds(clipped_bounds, clipped.shape)
    self.values = clipped
    return self
924
+
732
925
  def load(
733
926
  self,
734
927
  bounds: tuple | Geometry | GeoDataFrame | GeoSeries | None = None,
735
928
  indexes: int | tuple[int] | None = None,
736
929
  masked: bool | None = None,
930
+ file_system=None,
737
931
  **kwargs,
738
932
  ) -> "Band":
739
933
  """Load and potentially clip the array.
740
934
 
741
935
  The array is stored in the 'values' property.
742
936
  """
937
+ global _load_counter
938
+ _load_counter += 1
939
+
743
940
  if masked is None:
744
941
  masked = True if self.mask is None else False
745
942
 
746
943
  bounds_was_none = bounds is None
747
944
 
748
- try:
749
- if not isinstance(self.values, np.ndarray):
750
- raise ValueError()
751
- has_array = True
752
- except ValueError: # also catches ArrayNotLoadedError
753
- has_array = False
754
-
755
- # get common bounds of function argument 'bounds' and previously set bbox
756
- if bounds is None and self._bbox is None:
757
- bounds = None
758
- elif bounds is not None and self._bbox is None:
759
- bounds = to_shapely(bounds).intersection(self.union_all())
760
- elif bounds is None and self._bbox is not None:
761
- bounds = to_shapely(self._bbox).intersection(self.union_all())
762
- else:
763
- bounds = to_shapely(bounds).intersection(to_shapely(self._bbox))
945
+ bounds = _get_bounds(bounds, self._bbox, self.union_all())
764
946
 
765
947
  should_return_empty: bool = bounds is not None and bounds.area == 0
766
948
  if should_return_empty:
767
949
  self._values = np.array([])
768
950
  if self.mask is not None and not self.is_mask:
769
- self._mask = self._mask.load()
770
- # self._mask = np.ma.array([], [])
951
+ self._mask = self._mask.load(
952
+ bounds=bounds, indexes=indexes, file_system=file_system
953
+ )
771
954
  self._bounds = None
772
955
  self.transform = None
956
+ self.values = self._values
957
+
773
958
  return self
774
959
 
775
- if has_array and bounds_was_none:
960
+ if self.has_array and bounds_was_none:
776
961
  return self
777
962
 
778
- # round down/up to integer to avoid precision trouble
779
963
  if bounds is not None:
780
- # bounds = to_bbox(bounds)
781
964
  minx, miny, maxx, maxy = to_bbox(bounds)
782
- bounds = (int(minx), int(miny), math.ceil(maxx), math.ceil(maxy))
783
-
784
- boundless = False
965
+ ## round down/up to integer to avoid precision trouble
966
+ # bounds = (int(minx), int(miny), math.ceil(maxx), math.ceil(maxy))
967
+ bounds = minx, miny, maxx, maxy
785
968
 
786
969
  if indexes is None:
787
970
  indexes = 1
@@ -792,114 +975,132 @@ class Band(_ImageBandBase):
792
975
  # allow setting a fixed out_shape for the array, in order to make mask same shape as values
793
976
  out_shape = kwargs.pop("out_shape", None)
794
977
 
795
- if has_array:
796
- self.values = _clip_loaded_array(
797
- self.values, bounds, self.transform, self.crs, out_shape, **kwargs
978
+ if self.has_array and [int(x) for x in bounds] != [int(x) for x in self.bounds]:
979
+ print(self)
980
+ print(self.mask)
981
+ print(self.mask.values.shape)
982
+ print(self.values.shape)
983
+ print([int(x) for x in bounds], [int(x) for x in self.bounds])
984
+ raise ValueError(
985
+ "Cannot re-load array with different bounds. "
986
+ "Use .copy() to read with different bounds. "
987
+ "Or .clip(mask) to clip."
798
988
  )
799
- self._bounds = bounds
800
- self.transform = _get_transform_from_bounds(self._bounds, self.values.shape)
989
+ # with opener(self.path, file_system=self.file_system) as f:
990
+ with opener(self.path, file_system=file_system) as f:
991
+ with rasterio.open(f, nodata=self.nodata) as src:
992
+ self._res = int(src.res[0]) if not self.res else self.res
801
993
 
802
- else:
803
- with opener(self.path, file_system=self.file_system) as f:
804
- with rasterio.open(f, nodata=self.nodata) as src:
805
- self._res = int(src.res[0]) if not self.res else self.res
806
-
807
- if self.nodata is None or np.isnan(self.nodata):
808
- self.nodata = src.nodata
809
- else:
810
- dtype_min_value = _get_dtype_min(src.dtypes[0])
811
- dtype_max_value = _get_dtype_max(src.dtypes[0])
812
- if (
813
- self.nodata > dtype_max_value
814
- or self.nodata < dtype_min_value
815
- ):
816
- src._dtypes = tuple(
817
- rasterio.dtypes.get_minimum_dtype(self.nodata)
818
- for _ in range(len(_indexes))
819
- )
820
-
821
- if bounds is None:
822
- if self._res != int(src.res[0]):
823
- if out_shape is None:
824
- out_shape = _get_shape_from_bounds(
825
- to_bbox(src.bounds), self.res, indexes
826
- )
827
- self.transform = _get_transform_from_bounds(
828
- to_bbox(src.bounds), shape=out_shape
829
- )
830
- else:
831
- self.transform = src.transform
832
-
833
- self._values = src.read(
834
- indexes=indexes,
835
- out_shape=out_shape,
836
- masked=masked,
837
- **kwargs,
838
- )
839
- else:
840
- window = rasterio.windows.from_bounds(
841
- *bounds, transform=src.transform
994
+ if self.nodata is None or np.isnan(self.nodata):
995
+ self.nodata = src.nodata
996
+ else:
997
+ dtype_min_value = _get_dtype_min(src.dtypes[0])
998
+ dtype_max_value = _get_dtype_max(src.dtypes[0])
999
+ if self.nodata > dtype_max_value or self.nodata < dtype_min_value:
1000
+ src._dtypes = tuple(
1001
+ rasterio.dtypes.get_minimum_dtype(self.nodata)
1002
+ for _ in range(len(_indexes))
842
1003
  )
843
1004
 
1005
+ if bounds is None:
1006
+ if self._res != int(src.res[0]):
844
1007
  if out_shape is None:
845
1008
  out_shape = _get_shape_from_bounds(
846
- bounds, self.res, indexes
1009
+ to_bbox(src.bounds), self.res, indexes
847
1010
  )
848
-
849
- self._values = src.read(
850
- indexes=indexes,
851
- window=window,
852
- boundless=boundless,
853
- out_shape=out_shape,
854
- masked=masked,
855
- **kwargs,
1011
+ self.transform = _get_transform_from_bounds(
1012
+ to_bbox(src.bounds), shape=out_shape
856
1013
  )
1014
+ else:
1015
+ self.transform = src.transform
857
1016
 
858
- assert out_shape == self._values.shape, (
859
- out_shape,
860
- self._values.shape,
861
- )
1017
+ values = src.read(
1018
+ indexes=indexes,
1019
+ out_shape=out_shape,
1020
+ masked=masked,
1021
+ **kwargs,
1022
+ )
1023
+ else:
1024
+ window = rasterio.windows.from_bounds(
1025
+ *bounds, transform=src.transform
1026
+ )
1027
+
1028
+ if out_shape is None:
1029
+ out_shape = _get_shape_from_bounds(bounds, self.res, indexes)
862
1030
 
1031
+ values = src.read(
1032
+ indexes=indexes,
1033
+ window=window,
1034
+ boundless=False,
1035
+ out_shape=out_shape,
1036
+ masked=masked,
1037
+ **kwargs,
1038
+ )
1039
+
1040
+ assert out_shape == values.shape, (
1041
+ out_shape,
1042
+ values.shape,
1043
+ )
1044
+
1045
+ width, height = values.shape[-2:]
1046
+
1047
+ if width and height:
863
1048
  self.transform = rasterio.transform.from_bounds(
864
- *bounds, self.width, self.height
1049
+ *bounds, width, height
865
1050
  )
866
- self._bounds = bounds
867
1051
 
868
- if self.nodata is not None and not np.isnan(self.nodata):
869
- if isinstance(self.values, np.ma.core.MaskedArray):
870
- self.values.data[self.values.data == src.nodata] = (
871
- self.nodata
872
- )
873
- else:
874
- self.values[self.values == src.nodata] = self.nodata
1052
+ if self.nodata is not None and not np.isnan(self.nodata):
1053
+ if isinstance(values, np.ma.core.MaskedArray):
1054
+ values.data[values.data == src.nodata] = self.nodata
1055
+ else:
1056
+ values[values == src.nodata] = self.nodata
875
1057
 
876
1058
  if self.masking and self.is_mask:
877
- self.values = np.isin(self.values, self.masking["values"])
1059
+ values = np.isin(values, list(self.masking["values"]))
878
1060
 
879
- elif self.mask is not None and not isinstance(
880
- self.values, np.ma.core.MaskedArray
881
- ):
882
- self.mask = self.mask.copy().load(
883
- bounds=bounds, indexes=indexes, out_shape=out_shape, **kwargs
884
- )
1061
+ elif self.mask is not None and not isinstance(values, np.ma.core.MaskedArray):
1062
+
1063
+ if not self.mask.has_array:
1064
+ self._mask = self.mask.load(
1065
+ bounds=bounds, indexes=indexes, out_shape=out_shape, **kwargs
1066
+ )
885
1067
  mask_arr = self.mask.values
886
1068
 
887
- # if self.masking:
888
- # mask_arr = np.isin(mask_arr, self.masking["values"])
1069
+ values = np.ma.array(values, mask=mask_arr, fill_value=self.nodata)
889
1070
 
890
- self._values = np.ma.array(
891
- self._values, mask=mask_arr, fill_value=self.nodata
892
- )
1071
+ if bounds is not None:
1072
+ self._bounds = to_bbox(bounds)
1073
+
1074
+ self._values = values
1075
+ # trigger the setter
1076
+ self.values = values
893
1077
 
894
1078
  return self
895
1079
 
896
1080
  @property
897
1081
  def is_mask(self) -> bool:
898
1082
  """True if the band_id is equal to the masking band_id."""
1083
+ if self.masking is None:
1084
+ return False
899
1085
  return self.band_id == self.masking["band_id"]
900
1086
 
1087
+ @property
1088
+ def has_array(self) -> bool:
1089
+ """Whether the array is loaded."""
1090
+ try:
1091
+ if not isinstance(self.values, (np.ndarray | DataArray)):
1092
+ raise ValueError()
1093
+ return True
1094
+ except ValueError: # also catches ArrayNotLoadedError
1095
+ return False
1096
+
901
1097
  def write(
902
- self, path: str | Path, driver: str = "GTiff", compress: str = "LZW", **kwargs
1098
+ self,
1099
+ path: str | Path,
1100
+ driver: str = "GTiff",
1101
+ compress: str = "LZW",
1102
+ file_system=None,
1103
+ **kwargs,
903
1104
  ) -> None:
904
1105
  """Write the array as an image file."""
905
1106
  if not hasattr(self, "_values"):
@@ -922,7 +1123,8 @@ class Band(_ImageBandBase):
922
1123
  "width": self.width,
923
1124
  } | kwargs
924
1125
 
925
- with opener(path, "wb", file_system=self.file_system) as f:
1126
+ # with opener(path, "wb", file_system=self.file_system) as f:
1127
+ with opener(path, "wb", file_system=file_system) as f:
926
1128
  with rasterio.open(f, "w", **profile) as dst:
927
1129
 
928
1130
  if dst.nodata is None:
@@ -944,17 +1146,14 @@ class Band(_ImageBandBase):
944
1146
  if isinstance(self.values, np.ma.core.MaskedArray):
945
1147
  dst.write_mask(self.values.mask)
946
1148
 
947
- self._path = str(path)
1149
+ self._path = _fix_path(str(path))
948
1150
 
949
1151
  def apply(self, func: Callable, **kwargs) -> "Band":
950
- """Apply a function to the array."""
951
- self.values = func(self.values, **kwargs)
952
- return self
953
-
954
- def normalize(self) -> "Band":
955
- """Normalize array values between 0 and 1."""
956
- arr = self.values
957
- self.values = (arr - np.min(arr)) / (np.max(arr) - np.min(arr))
1152
+ """Apply a function to the Band."""
1153
+ results = func(self, **kwargs)
1154
+ if isinstance(results, Band):
1155
+ return results
1156
+ self.values = results
958
1157
  return self
959
1158
 
960
1159
  def sample(self, size: int = 1000, mask: Any = None, **kwargs) -> "Image":
@@ -1112,23 +1311,43 @@ class Band(_ImageBandBase):
1112
1311
  )
1113
1312
 
1114
1313
  def to_xarray(self) -> DataArray:
1115
- """Convert the raster to an xarray.DataArray."""
1116
- name = self.name or self.__class__.__name__.lower()
1117
- coords = _generate_spatial_coords(self.transform, self.width, self.height)
1118
- if len(self.values.shape) == 2:
1119
- dims = ["y", "x"]
1120
- elif len(self.values.shape) == 3:
1121
- dims = ["band", "y", "x"]
1122
- else:
1123
- raise ValueError("Array must be 2 or 3 dimensional.")
1124
- return xr.DataArray(
1314
+ """Convert the raster to an xarray.DataArray."""
1315
+ if self.backend == "xarray":
1316
+ return self.values
1317
+ return self._to_xarray(
1125
1318
  self.values,
1126
- coords=coords,
1127
- dims=dims,
1128
- name=name,
1129
- attrs={"crs": self.crs},
1319
+ transform=self.transform,
1320
+ # name=self.name or self.__class__.__name__.lower(),
1130
1321
  )
1131
1322
 
1323
+ def to_numpy(self) -> np.ndarray | np.ma.core.MaskedArray:
1324
+ """Convert the raster to a numpy.ndarray."""
1325
+ return self._to_numpy(self.values).copy()
1326
+
1327
+ def _to_numpy(
1328
+ self, arr: np.ndarray | DataArray, masked: bool = True
1329
+ ) -> np.ndarray | np.ma.core.MaskedArray:
1330
+ if not isinstance(arr, np.ndarray):
1331
+ if masked:
1332
+ try:
1333
+ mask_arr = arr.isnull().values
1334
+ except AttributeError:
1335
+ mask_arr = np.full(arr.shape, False)
1336
+ try:
1337
+ arr = arr.to_numpy()
1338
+ except AttributeError:
1339
+ arr = arr.values
1340
+ if not isinstance(arr, np.ndarray):
1341
+ arr = np.array(arr)
1342
+ if (
1343
+ masked
1344
+ and self.mask is not None
1345
+ and not self.is_mask
1346
+ and not isinstance(arr, np.ma.core.MaskedArray)
1347
+ ):
1348
+ arr = np.ma.array(arr, mask=mask_arr, fill_value=self.nodata)
1349
+ return arr
1350
+
1132
1351
  def __repr__(self) -> str:
1133
1352
  """String representation."""
1134
1353
  try:
@@ -1154,211 +1373,70 @@ class NDVIBand(Band):
1154
1373
  # return get_cmap(arr)
1155
1374
 
1156
1375
 
1157
- def get_cmap(arr: np.ndarray) -> LinearSegmentedColormap:
1158
-
1159
- # blue = [[i / 10 + 0.1, i / 10 + 0.1, 1 - (i / 10) + 0.1] for i in range(11)][1:]
1160
- blue = [
1161
- [0.1, 0.1, 1.0],
1162
- [0.2, 0.2, 0.9],
1163
- [0.3, 0.3, 0.8],
1164
- [0.4, 0.4, 0.7],
1165
- [0.6, 0.6, 0.6],
1166
- [0.6, 0.6, 0.6],
1167
- [0.7, 0.7, 0.7],
1168
- [0.8, 0.8, 0.8],
1169
- ]
1170
- # gray = list(reversed([[i / 10 - 0.1, i / 10, i / 10 - 0.1] for i in range(11)][1:]))
1171
- gray = [
1172
- [0.6, 0.6, 0.6],
1173
- [0.6, 0.6, 0.6],
1174
- [0.6, 0.6, 0.6],
1175
- [0.6, 0.6, 0.6],
1176
- [0.6, 0.6, 0.6],
1177
- [0.4, 0.7, 0.4],
1178
- [0.3, 0.7, 0.3],
1179
- [0.2, 0.8, 0.2],
1180
- ]
1181
- # gray = [[0.6, 0.6, 0.6] for i in range(10)]
1182
- # green = [[0.2 + i/20, i / 10 - 0.1, + i/20] for i in range(11)][1:]
1183
- green = [
1184
- [0.25, 0.0, 0.05],
1185
- [0.3, 0.1, 0.1],
1186
- [0.35, 0.2, 0.15],
1187
- [0.4, 0.3, 0.2],
1188
- [0.45, 0.4, 0.25],
1189
- [0.5, 0.5, 0.3],
1190
- [0.55, 0.6, 0.35],
1191
- [0.7, 0.9, 0.5],
1192
- ]
1193
- green = [
1194
- [0.6, 0.6, 0.6],
1195
- [0.4, 0.7, 0.4],
1196
- [0.3, 0.8, 0.3],
1197
- [0.25, 0.4, 0.25],
1198
- [0.2, 0.5, 0.2],
1199
- [0.10, 0.7, 0.10],
1200
- [0, 0.9, 0],
1201
- ]
1202
-
1203
- def get_start(arr):
1204
- min_value = np.min(arr)
1205
- if min_value < -0.75:
1206
- return 0
1207
- if min_value < -0.5:
1208
- return 1
1209
- if min_value < -0.25:
1210
- return 2
1211
- if min_value < 0:
1212
- return 3
1213
- if min_value < 0.25:
1214
- return 4
1215
- if min_value < 0.5:
1216
- return 5
1217
- if min_value < 0.75:
1218
- return 6
1219
- return 7
1220
-
1221
- def get_stop(arr):
1222
- max_value = np.max(arr)
1223
- if max_value <= 0.05:
1224
- return 0
1225
- if max_value < 0.175:
1226
- return 1
1227
- if max_value < 0.25:
1228
- return 2
1229
- if max_value < 0.375:
1230
- return 3
1231
- if max_value < 0.5:
1232
- return 4
1233
- if max_value < 0.75:
1234
- return 5
1235
- return 6
1236
-
1237
- cmap_name = "blue_gray_green"
1238
-
1239
- start = get_start(arr)
1240
- stop = get_stop(arr)
1241
- blue = blue[start]
1242
- gray = gray[start]
1243
- # green = green[start]
1244
- green = green[stop]
1245
-
1246
- # green[0] = np.arange(0, 1, 0.1)[::-1][stop]
1247
- # green[1] = np.arange(0, 1, 0.1)[stop]
1248
- # green[2] = np.arange(0, 1, 0.1)[::-1][stop]
1249
-
1250
- print(green)
1251
- print(start, stop)
1252
- print("blue gray green")
1253
- print(blue)
1254
- print(gray)
1255
- print(green)
1256
-
1257
- # Define the segments of the colormap
1258
- cdict = {
1259
- "red": [
1260
- (0.0, blue[0], blue[0]),
1261
- (0.3, gray[0], gray[0]),
1262
- (0.7, gray[0], gray[0]),
1263
- (1.0, green[0], green[0]),
1264
- ],
1265
- "green": [
1266
- (0.0, blue[1], blue[1]),
1267
- (0.3, gray[1], gray[1]),
1268
- (0.7, gray[1], gray[1]),
1269
- (1.0, green[1], green[1]),
1270
- ],
1271
- "blue": [
1272
- (0.0, blue[2], blue[2]),
1273
- (0.3, gray[2], gray[2]),
1274
- (0.7, gray[2], gray[2]),
1275
- (1.0, green[2], green[2]),
1276
- ],
1277
- }
1278
-
1279
- return LinearSegmentedColormap(cmap_name, segmentdata=cdict, N=50)
1280
-
1281
-
1282
- def median_as_int_and_minimum_dtype(arr: np.ndarray) -> np.ndarray:
1283
- arr = np.median(arr, axis=0).astype(int)
1284
- min_dtype = rasterio.dtypes.get_minimum_dtype(arr)
1285
- return arr.astype(min_dtype)
1376
+ def median_as_int_and_minimum_dtype(arr: np.ndarray) -> np.ndarray:
1377
+ arr = np.median(arr, axis=0).astype(int)
1378
+ min_dtype = rasterio.dtypes.get_minimum_dtype(arr)
1379
+ return arr.astype(min_dtype)
1286
1380
 
1287
1381
 
1288
1382
  class Image(_ImageBandBase):
1289
1383
  """Image consisting of one or more Bands."""
1290
1384
 
1291
- cloud_cover_regexes: ClassVar[tuple[str] | None] = None
1292
1385
  band_class: ClassVar[Band] = Band
1386
+ backend: str = "numpy"
1293
1387
 
1294
1388
  def __init__(
1295
1389
  self,
1296
- data: str | Path | Sequence[Band],
1390
+ data: str | Path | Sequence[Band] | None = None,
1297
1391
  res: int | None = None,
1298
- crs: Any | None = None,
1299
- single_banded: bool = False,
1300
- file_system: GCSFileSystem | None = None,
1301
- df: pd.DataFrame | None = None,
1302
- all_file_paths: list[str] | None = None,
1303
1392
  processes: int = 1,
1304
- bbox: GeoDataFrame | GeoSeries | Geometry | tuple | None = None,
1393
+ df: pd.DataFrame | None = None,
1305
1394
  nodata: int | None = None,
1395
+ all_file_paths: list[str] | None = None,
1306
1396
  **kwargs,
1307
1397
  ) -> None:
1308
1398
  """Image initialiser."""
1399
+ if data is None:
1400
+ # allowing 'bands' to replace 'data' as argument
1401
+ # to make the print repr. valid as initialiser
1402
+ if "bands" not in kwargs:
1403
+ raise TypeError("Must specify either 'data' or 'bands'.")
1404
+ data = kwargs.pop("bands")
1405
+
1309
1406
  super().__init__(**kwargs)
1310
1407
 
1311
1408
  self.nodata = nodata
1312
- self._res = res
1313
- self._crs = crs
1314
- self.file_system = file_system
1315
- self._bbox = to_bbox(bbox) if bbox is not None else None
1316
- # self._mask = _mask
1317
- self.single_banded = single_banded
1318
1409
  self.processes = processes
1319
- self._all_file_paths = all_file_paths
1410
+ self._crs = None
1411
+ self._bands = None
1320
1412
 
1321
1413
  if hasattr(data, "__iter__") and all(isinstance(x, Band) for x in data):
1322
- self._bands = list(data)
1323
- if res is None:
1324
- res = list({band.res for band in self._bands})
1325
- if len(res) == 1:
1326
- self._res = res[0]
1327
- else:
1328
- raise ValueError(f"Different resolutions for the bands: {res}")
1329
- else:
1330
- self._res = res
1414
+ self._construct_image_from_bands(data, res)
1331
1415
  return
1332
-
1333
- if not isinstance(data, (str | Path | os.PathLike)):
1416
+ elif not isinstance(data, (str | Path | os.PathLike)):
1334
1417
  raise TypeError("'data' must be string, Path-like or a sequence of Band.")
1335
1418
 
1336
- self._bands = None
1337
- self._path = str(data)
1419
+ self._res = res
1420
+ self._path = _fix_path(data)
1421
+
1422
+ if all_file_paths is None and self.path:
1423
+ self._all_file_paths = _get_all_file_paths(self.path)
1424
+ elif self.path:
1425
+ all_file_paths = {_fix_path(x) for x in all_file_paths}
1426
+ self._all_file_paths = {x for x in all_file_paths if self.path in x}
1427
+ else:
1428
+ self._all_file_paths = None
1338
1429
 
1339
1430
  if df is None:
1340
- if is_dapla():
1341
- file_paths = list(sorted(set(_glob_func(self.path + "/**"))))
1342
- else:
1343
- file_paths = list(
1344
- sorted(
1345
- set(
1346
- _glob_func(self.path + "/**/**")
1347
- + _glob_func(self.path + "/**/**/**")
1348
- + _glob_func(self.path + "/**/**/**/**")
1349
- + _glob_func(self.path + "/**/**/**/**/**")
1350
- )
1351
- )
1352
- )
1353
- if not file_paths:
1354
- file_paths = [self.path]
1355
- df = self._create_metadata_df(file_paths)
1431
+ if not self._all_file_paths:
1432
+ self._all_file_paths = [self.path]
1433
+ df = self._create_metadata_df(self._all_file_paths)
1356
1434
 
1357
1435
  df["image_path"] = df["image_path"].astype(str)
1358
1436
 
1359
1437
  cols_to_explode = [
1360
1438
  "file_path",
1361
- "filename",
1439
+ "file_name",
1362
1440
  *[x for x in df if FILENAME_COL_SUFFIX in x],
1363
1441
  ]
1364
1442
  try:
@@ -1366,44 +1444,82 @@ class Image(_ImageBandBase):
1366
1444
  except ValueError:
1367
1445
  for col in cols_to_explode:
1368
1446
  df = df.explode(col)
1369
- df = df.loc[lambda x: ~x["filename"].duplicated()].reset_index(drop=True)
1447
+ df = df.loc[lambda x: ~x["file_name"].duplicated()].reset_index(drop=True)
1448
+
1449
+ df = df.loc[lambda x: x["image_path"] == self.path]
1450
+
1451
+ self._df = df
1452
+
1453
+ if self.path is not None and self.metadata:
1454
+ self.metadata = {
1455
+ key: value for key, value in self.metadata.items() if self.path in key
1456
+ }
1457
+
1458
+ if self.metadata:
1459
+ try:
1460
+ metadata = self.metadata[self.path]
1461
+ except KeyError:
1462
+ metadata = {}
1463
+ for key, value in metadata.items():
1464
+ if key in dir(self):
1465
+ setattr(self, f"_{key}", value)
1466
+ else:
1467
+ setattr(self, key, value)
1370
1468
 
1371
- df = df.loc[lambda x: x["image_path"].str.contains(_fix_path(self.path))]
1469
+ else:
1470
+ for key, value in self._get_metadata_attributes(
1471
+ self.metadata_attributes
1472
+ ).items():
1473
+ setattr(self, key, value)
1372
1474
 
1373
- if self.cloud_cover_regexes:
1374
- if all_file_paths is None:
1375
- file_paths = _ls_func(self.path)
1475
+ def _construct_image_from_bands(
1476
+ self, data: Sequence[Band], res: int | None
1477
+ ) -> None:
1478
+ self._bands = list(data)
1479
+ if res is None:
1480
+ res = list({band.res for band in self.bands})
1481
+ if len(res) == 1:
1482
+ self._res = res[0]
1376
1483
  else:
1377
- file_paths = [path for path in all_file_paths if self.name in path]
1378
- self.cloud_coverage_percentage = float(
1379
- _get_regex_match_from_xml_in_local_dir(
1380
- file_paths, regexes=self.cloud_cover_regexes
1381
- )
1382
- )
1484
+ raise ValueError(f"Different resolutions for the bands: {res}")
1383
1485
  else:
1384
- self.cloud_coverage_percentage = None
1486
+ self._res = res
1487
+ for key in self.metadata_attributes:
1488
+ band_values = {getattr(band, key) for band in self if hasattr(band, key)}
1489
+ band_values = {x for x in band_values if x is not None}
1490
+ if len(band_values) > 1:
1491
+ raise ValueError(f"Different {key} values in bands: {band_values}")
1492
+ elif len(band_values):
1493
+ try:
1494
+ setattr(self, key, next(iter(band_values)))
1495
+ except AttributeError:
1496
+ setattr(self, f"_{key}", next(iter(band_values)))
1385
1497
 
1386
- self._df = df
1498
+ def copy(self) -> "Image":
1499
+ """Copy the instance and its attributes."""
1500
+ copied = super().copy()
1501
+ for band in copied:
1502
+ band._mask = copied._mask
1503
+ return copied
1387
1504
 
1388
- @property
1389
- def values(self) -> np.ndarray:
1390
- """3 dimensional numpy array."""
1391
- return np.array([band.values for band in self])
1505
+ def apply(self, func: Callable, **kwargs) -> "Image":
1506
+ """Apply a function to each band of the Image."""
1507
+ with joblib.Parallel(n_jobs=self.processes, backend="loky") as parallel:
1508
+ parallel(joblib.delayed(_band_apply)(band, func, **kwargs) for band in self)
1509
+
1510
+ return self
1392
1511
 
1393
- def ndvi(self, red_band: str, nir_band: str, copy: bool = True) -> NDVIBand:
1512
+ def ndvi(
1513
+ self, red_band: str, nir_band: str, padding: int = 0, copy: bool = True
1514
+ ) -> NDVIBand:
1394
1515
  """Calculate the NDVI for the Image."""
1395
1516
  copied = self.copy() if copy else self
1396
1517
  red = copied[red_band].load()
1397
1518
  nir = copied[nir_band].load()
1398
1519
 
1399
- arr: np.ndarray | np.ma.core.MaskedArray = ndvi(red.values, nir.values)
1400
-
1401
- # if self.nodata is not None and not np.isnan(self.nodata):
1402
- # try:
1403
- # arr.data[arr.mask] = self.nodata
1404
- # arr = arr.copy()
1405
- # except AttributeError:
1406
- # pass
1520
+ arr: np.ndarray | np.ma.core.MaskedArray = ndvi(
1521
+ red.values, nir.values, padding=padding
1522
+ )
1407
1523
 
1408
1524
  return NDVIBand(
1409
1525
  arr,
@@ -1445,37 +1561,63 @@ class Image(_ImageBandBase):
1445
1561
  **self._common_init_kwargs,
1446
1562
  )
1447
1563
 
1564
+ def to_xarray(self) -> DataArray:
1565
+ """Convert the raster to an xarray.DataArray."""
1566
+ if self.backend == "xarray":
1567
+ return self.values
1568
+
1569
+ return self._to_xarray(
1570
+ np.array([band.values for band in self]),
1571
+ transform=self[0].transform,
1572
+ )
1573
+
1448
1574
  @property
1449
1575
  def mask(self) -> Band | None:
1450
1576
  """Mask Band."""
1451
- if self._mask is not None:
1452
- return self._mask
1453
1577
  if self.masking is None:
1454
1578
  return None
1455
1579
 
1580
+ elif self._mask is not None:
1581
+ return self._mask
1582
+
1583
+ elif self._bands is not None and all(band.mask is not None for band in self):
1584
+ if len({id(band.mask) for band in self}) > 1:
1585
+ raise ValueError(
1586
+ "Image bands must have same mask.",
1587
+ {id(band.mask) for band in self},
1588
+ ) # TODO
1589
+ self._mask = next(
1590
+ iter([band.mask for band in self if band.mask is not None])
1591
+ )
1592
+ return self._mask
1593
+
1456
1594
  mask_band_id = self.masking["band_id"]
1457
- mask_paths = [path for path in self._df["file_path"] if mask_band_id in path]
1595
+ mask_paths = [path for path in self._all_file_paths if mask_band_id in path]
1458
1596
  if len(mask_paths) > 1:
1459
1597
  raise ValueError(
1460
1598
  f"Multiple file_paths match mask band_id {mask_band_id} for {self.path}"
1461
1599
  )
1462
1600
  elif not mask_paths:
1463
1601
  raise ValueError(
1464
- f"No file_paths match mask band_id {mask_band_id} for {self.path}"
1602
+ f"No file_paths match mask band_id {mask_band_id} for {self.path} among "
1603
+ + str([Path(x).name for x in _ls_func(self.path)])
1465
1604
  )
1605
+
1466
1606
  self._mask = self.band_class(
1467
1607
  mask_paths[0],
1468
1608
  **self._common_init_kwargs,
1469
1609
  )
1470
-
1610
+ if self._bands is not None:
1611
+ for band in self:
1612
+ band._mask = self._mask
1471
1613
  return self._mask
1472
1614
 
1473
1615
  @mask.setter
1474
- def mask(self, values: Band) -> None:
1616
+ def mask(self, values: Band | None) -> None:
1475
1617
  if values is None:
1476
1618
  self._mask = None
1477
1619
  for band in self:
1478
- band.mask = None
1620
+ band._mask = None
1479
1621
  return
1480
1622
  if not isinstance(values, Band):
1481
1623
  raise TypeError(f"mask must be Band. Got {type(values)}")
@@ -1485,7 +1627,7 @@ class Image(_ImageBandBase):
1485
1627
  band._mask = self._mask
1486
1628
  try:
1487
1629
  band.values = np.ma.array(
1488
- band.values, mask=mask_arr, fill_value=band.nodata
1630
+ band.values.data, mask=mask_arr, fill_value=band.nodata
1489
1631
  )
1490
1632
  except ArrayNotLoadedError:
1491
1633
  pass
@@ -1506,45 +1648,24 @@ class Image(_ImageBandBase):
1506
1648
  if self._bands is not None:
1507
1649
  return self._bands
1508
1650
 
1509
- # if self.masking:
1510
- # mask_band_id = self.masking["band_id"]
1511
- # mask_paths = [
1512
- # path for path in self._df["file_path"] if mask_band_id in path
1513
- # ]
1514
- # if len(mask_paths) > 1:
1515
- # raise ValueError(
1516
- # f"Multiple file_paths match mask band_id {mask_band_id}"
1517
- # )
1518
- # elif not mask_paths:
1519
- # raise ValueError(f"No file_paths match mask band_id {mask_band_id}")
1520
- # arr = (
1521
- # self.band_class(
1522
- # mask_paths[0],
1523
- # # mask=self.mask,
1524
- # **self._common_init_kwargs,
1525
- # )
1526
- # .load()
1527
- # .values
1528
- # )
1529
- # self._mask = np.ma.array(
1530
- # arr, mask=np.isin(arr, self.masking["values"]), fill_value=None
1531
- # )
1651
+ if self.masking:
1652
+ mask_band_id = self.masking["band_id"]
1653
+ paths = [path for path in self._df["file_path"] if mask_band_id not in path]
1654
+ else:
1655
+ paths = self._df["file_path"]
1656
+
1657
+ mask = self.mask
1532
1658
 
1533
1659
  self._bands = [
1534
1660
  self.band_class(
1535
1661
  path,
1536
- mask=self.mask,
1662
+ mask=mask,
1663
+ all_file_paths=self._all_file_paths,
1537
1664
  **self._common_init_kwargs,
1538
1665
  )
1539
- for path in (self._df["file_path"])
1666
+ for path in paths
1540
1667
  ]
1541
1668
 
1542
- if self.masking:
1543
- mask_band_id = self.masking["band_id"]
1544
- self._bands = [
1545
- band for band in self._bands if mask_band_id not in band.path
1546
- ]
1547
-
1548
1669
  if (
1549
1670
  self.filename_patterns
1550
1671
  and any(_get_non_optional_groups(pat) for pat in self.filename_patterns)
@@ -1557,11 +1678,7 @@ class Image(_ImageBandBase):
1557
1678
  self._bands = [
1558
1679
  band
1559
1680
  for band in self._bands
1560
- if any(
1561
- # _get_first_group_match(pat, band.name)
1562
- re.search(pat, band.name)
1563
- for pat in self.filename_patterns
1564
- )
1681
+ if any(re.search(pat, band.name) for pat in self.filename_patterns)
1565
1682
  ]
1566
1683
 
1567
1684
  if self.image_patterns:
@@ -1570,7 +1687,6 @@ class Image(_ImageBandBase):
1570
1687
  for band in self._bands
1571
1688
  if any(
1572
1689
  re.search(pat, Path(band.path).parent.name)
1573
- # _get_first_group_match(pat, Path(band.path).parent.name)
1574
1690
  for pat in self.image_patterns
1575
1691
  )
1576
1692
  ]
@@ -1583,10 +1699,14 @@ class Image(_ImageBandBase):
1583
1699
  @property
1584
1700
  def _should_be_sorted(self) -> bool:
1585
1701
  sort_groups = ["band", "band_id"]
1586
- return self.filename_patterns and any(
1587
- group in _get_non_optional_groups(pat)
1588
- for group in sort_groups
1589
- for pat in self.filename_patterns
1702
+ return (
1703
+ self.filename_patterns
1704
+ and any(
1705
+ group in _get_non_optional_groups(pat)
1706
+ for group in sort_groups
1707
+ for pat in self.filename_patterns
1708
+ )
1709
+ or all(band.band_id is not None for band in self)
1590
1710
  )
1591
1711
 
1592
1712
  @property
@@ -1621,7 +1741,14 @@ class Image(_ImageBandBase):
1621
1741
  @property
1622
1742
  def bounds(self) -> tuple[int, int, int, int] | None:
1623
1743
  """Bounds of the Image (minx, miny, maxx, maxy)."""
1624
- return get_total_bounds([band.bounds for band in self])
1744
+ try:
1745
+ return get_total_bounds([band.bounds for band in self])
1746
+ except exceptions.RefreshError:
1747
+ bounds = []
1748
+ for band in self:
1749
+ time.sleep(0.1)
1750
+ bounds.append(band.bounds)
1751
+ return get_total_bounds(bounds)
1625
1752
 
1626
1753
  def to_gdf(self, column: str = "value") -> GeoDataFrame:
1627
1754
  """Convert the array to a GeoDataFrame of grid polygons and values."""
@@ -1647,7 +1774,7 @@ class Image(_ImageBandBase):
1647
1774
  def __getitem__(
1648
1775
  self, band: str | int | Sequence[str] | Sequence[int]
1649
1776
  ) -> "Band | Image":
1650
- """Get bands by band_id or integer index.
1777
+ """Get bands by band_id or integer index or a sequence of such.
1651
1778
 
1652
1779
  Returns a Band if a string or int is passed,
1653
1780
  returns an Image if a sequence of strings or integers is passed.
@@ -1655,7 +1782,7 @@ class Image(_ImageBandBase):
1655
1782
  if isinstance(band, str):
1656
1783
  return self._get_band(band)
1657
1784
  if isinstance(band, int):
1658
- return self.bands[band] # .copy()
1785
+ return self.bands[band]
1659
1786
 
1660
1787
  copied = self.copy()
1661
1788
  try:
@@ -1681,10 +1808,7 @@ class Image(_ImageBandBase):
1681
1808
  try:
1682
1809
  return self.date < other.date
1683
1810
  except Exception as e:
1684
- print(self.path)
1685
- print(self.date)
1686
- print(other.path)
1687
- print(other.date)
1811
+ print("", self.path, self.date, other.path, other.date, sep="\n")
1688
1812
  raise e
1689
1813
 
1690
1814
  def __iter__(self) -> Iterator[Band]:
@@ -1743,103 +1867,73 @@ class ImageCollection(_ImageBase):
1743
1867
 
1744
1868
  image_class: ClassVar[Image] = Image
1745
1869
  band_class: ClassVar[Band] = Band
1870
+ _metadata_attribute_collection_type: ClassVar[type] = pd.Series
1871
+ backend: str = "numpy"
1746
1872
 
1747
1873
  def __init__(
1748
1874
  self,
1749
- data: str | Path | Sequence[Image],
1875
+ data: str | Path | Sequence[Image] | Sequence[str | Path],
1750
1876
  res: int,
1751
- level: str | None,
1752
- crs: Any | None = None,
1753
- single_banded: bool = False,
1877
+ level: str | None = None_,
1754
1878
  processes: int = 1,
1755
- file_system: GCSFileSystem | None = None,
1756
- df: pd.DataFrame | None = None,
1757
- bbox: Any | None = None,
1758
- nodata: int | None = None,
1759
1879
  metadata: str | dict | pd.DataFrame | None = None,
1880
+ nodata: int | None = None,
1760
1881
  **kwargs,
1761
1882
  ) -> None:
1762
1883
  """Initialiser."""
1763
- super().__init__(**kwargs)
1884
+ if data is not None and kwargs.get("root"):
1885
+ root = _fix_path(kwargs.pop("root"))
1886
+ data = [f"{root}/{name}" for name in data]
1887
+ _from_root = True
1888
+ else:
1889
+ _from_root = False
1890
+
1891
+ super().__init__(metadata=metadata, **kwargs)
1892
+
1893
+ if callable(level) and isinstance(level(), None_):
1894
+ level = None
1764
1895
 
1765
1896
  self.nodata = nodata
1766
1897
  self.level = level
1767
- self._crs = crs
1768
1898
  self.processes = processes
1769
- self.file_system = file_system
1770
1899
  self._res = res
1771
- self._bbox = to_bbox(bbox) if bbox is not None else None
1772
- self._band_ids = None
1773
- self.single_banded = single_banded
1900
+ self._crs = None
1774
1901
 
1775
- if metadata is not None:
1776
- if isinstance(metadata, (str | Path | os.PathLike)):
1777
- self.metadata = _read_parquet_func(metadata)
1778
- else:
1779
- self.metadata = metadata
1780
- else:
1781
- self.metadata = metadata
1902
+ self._df = None
1903
+ self._all_file_paths = None
1904
+ self._images = None
1782
1905
 
1783
- if hasattr(data, "__iter__") and all(isinstance(x, Image) for x in data):
1906
+ if hasattr(data, "__iter__") and not isinstance(data, str):
1784
1907
  self._path = None
1785
- self.images = [x.copy() for x in data]
1786
- return
1787
- else:
1788
- self._images = None
1908
+ if all(isinstance(x, Image) for x in data):
1909
+ self.images = [x.copy() for x in data]
1910
+ return
1911
+ elif all(isinstance(x, (str | Path | os.PathLike)) for x in data):
1912
+ # adding band paths (asuming 'data' is a sequence of image paths)
1913
+ try:
1914
+ self._all_file_paths = _get_child_paths_threaded(data) | set(data)
1915
+ except FileNotFoundError as e:
1916
+ if _from_root:
1917
+ raise TypeError(
1918
+ "When passing 'root', 'data' must be a sequence of image names that have 'root' as parent path."
1919
+ ) from e
1920
+ raise e
1921
+ self._df = self._create_metadata_df(self._all_file_paths)
1922
+ return
1789
1923
 
1790
1924
  if not isinstance(data, (str | Path | os.PathLike)):
1791
1925
  raise TypeError("'data' must be string, Path-like or a sequence of Image.")
1792
1926
 
1793
- self._path = str(data)
1927
+ self._path = _fix_path(str(data))
1794
1928
 
1795
- if is_dapla():
1796
- self._all_file_paths = list(sorted(set(_glob_func(self.path + "/**"))))
1797
- else:
1798
- self._all_file_paths = list(
1799
- sorted(
1800
- set(
1801
- _glob_func(self.path + "/**/**")
1802
- + _glob_func(self.path + "/**/**/**")
1803
- + _glob_func(self.path + "/**/**/**/**")
1804
- + _glob_func(self.path + "/**/**/**/**/**")
1805
- )
1806
- )
1807
- )
1929
+ self._all_file_paths = _get_all_file_paths(self.path)
1808
1930
 
1809
1931
  if self.level:
1810
1932
  self._all_file_paths = [
1811
1933
  path for path in self._all_file_paths if self.level in path
1812
1934
  ]
1813
1935
 
1814
- if df is not None:
1815
- self._df = df
1816
- else:
1817
- self._df = self._create_metadata_df(self._all_file_paths)
1818
-
1819
- @property
1820
- def values(self) -> np.ndarray:
1821
- """4 dimensional numpy array."""
1822
- return np.array([img.values for img in self])
1823
-
1824
- @property
1825
- def mask(self) -> np.ndarray:
1826
- """4 dimensional numpy array."""
1827
- return np.array([img.mask.values for img in self])
1828
-
1829
- # def ndvi(
1830
- # self, red_band: str, nir_band: str, copy: bool = True
1831
- # ) -> "ImageCollection":
1832
- # # copied = self.copy() if copy else self
1833
-
1834
- # with joblib.Parallel(n_jobs=self.processes, backend="loky") as parallel:
1835
- # ndvi_images = parallel(
1836
- # joblib.delayed(_img_ndvi)(
1837
- # img, red_band=red_band, nir_band=nir_band, copy=False
1838
- # )
1839
- # for img in self
1840
- # )
1841
-
1842
- # return ImageCollection(ndvi_images, single_banded=True)
1936
+ self._df = self._create_metadata_df(self._all_file_paths)
1843
1937
 
1844
1938
  def groupby(self, by: str | list[str], **kwargs) -> ImageCollectionGroupBy:
1845
1939
  """Group the Collection by Image or Band attribute(s)."""
@@ -1882,7 +1976,6 @@ class ImageCollection(_ImageBase):
1882
1976
  copied.images = [
1883
1977
  self.image_class(
1884
1978
  [band],
1885
- single_banded=True,
1886
1979
  masking=self.masking,
1887
1980
  band_class=self.band_class,
1888
1981
  **self._common_init_kwargs,
@@ -1892,6 +1985,64 @@ class ImageCollection(_ImageBase):
1892
1985
  for img in self
1893
1986
  for band in img
1894
1987
  ]
1988
+ for img in copied:
1989
+ assert len(img) == 1
1990
+ try:
1991
+ img._path = _fix_path(img[0].path)
1992
+ except PathlessImageError:
1993
+ pass
1994
+ return copied
1995
+
1996
+ def apply(self, func: Callable, **kwargs) -> "ImageCollection":
1997
+ """Apply a function to all bands in each image of the collection."""
1998
+ with joblib.Parallel(n_jobs=self.processes, backend="loky") as parallel:
1999
+ parallel(
2000
+ joblib.delayed(_band_apply)(band, func, **kwargs)
2001
+ for img in self
2002
+ for band in img
2003
+ )
2004
+
2005
+ return self
2006
+
2007
+ def get_unique_band_ids(self) -> list[str]:
2008
+ """Get a list of unique band_ids across all images."""
2009
+ return list({band.band_id for img in self for band in img})
2010
+
2011
+ def filter(
2012
+ self,
2013
+ bands: str | list[str] | None = None,
2014
+ date_ranges: DATE_RANGES_TYPE = None,
2015
+ bbox: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None = None,
2016
+ intersects: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None = None,
2017
+ max_cloud_cover: int | None = None,
2018
+ copy: bool = True,
2019
+ ) -> "ImageCollection":
2020
+ """Filter images and bands in the collection."""
2021
+ copied = self.copy() if copy else self
2022
+
2023
+ if date_ranges:
2024
+ copied = copied._filter_dates(date_ranges)
2025
+
2026
+ if max_cloud_cover is not None:
2027
+ copied.images = [
2028
+ image
2029
+ for image in copied.images
2030
+ if image.cloud_cover_percentage < max_cloud_cover
2031
+ ]
2032
+
2033
+ if bbox is not None:
2034
+ copied = copied._filter_bounds(bbox)
2035
+ copied._set_bbox(bbox)
2036
+
2037
+ if intersects is not None:
2038
+ copied = copied._filter_bounds(intersects)
2039
+
2040
+ if bands is not None:
2041
+ if isinstance(bands, str):
2042
+ bands = [bands]
2043
+ bands = set(bands)
2044
+ copied.images = [img[bands] for img in copied.images if bands in img]
2045
+
1895
2046
  return copied
1896
2047
 
1897
2048
  def merge(
@@ -1903,8 +2054,11 @@ class ImageCollection(_ImageBase):
1903
2054
  **kwargs,
1904
2055
  ) -> Band:
1905
2056
  """Merge all areas and all bands to a single Band."""
1906
- bounds = to_bbox(bounds) if bounds is not None else self._bbox
1907
- crs = self.crs
2057
+ bounds = _get_bounds(bounds, self._bbox, self.union_all())
2058
+ if bounds is not None:
2059
+ bounds = to_bbox(bounds)
2060
+
2061
+ crs = self.crs
1908
2062
 
1909
2063
  if indexes is None:
1910
2064
  indexes = 1
@@ -1938,14 +2092,14 @@ class ImageCollection(_ImageBase):
1938
2092
  **kwargs,
1939
2093
  )
1940
2094
 
1941
- if isinstance(indexes, int) and len(arr.shape) == 3 and arr.shape[0] == 1:
1942
- arr = arr[0]
2095
+ if isinstance(indexes, int) and len(arr.shape) == 3 and arr.shape[0] == 1:
2096
+ arr = arr[0]
1943
2097
 
1944
- if method == "mean":
1945
- if as_int:
1946
- arr = arr // len(datasets)
1947
- else:
1948
- arr = arr / len(datasets)
2098
+ if method == "mean":
2099
+ if as_int:
2100
+ arr = arr // len(datasets)
2101
+ else:
2102
+ arr = arr / len(datasets)
1949
2103
 
1950
2104
  if bounds is None:
1951
2105
  bounds = self.bounds
@@ -1971,7 +2125,9 @@ class ImageCollection(_ImageBase):
1971
2125
  **kwargs,
1972
2126
  ) -> Image:
1973
2127
  """Merge all areas to a single tile, one band per band_id."""
1974
- bounds = to_bbox(bounds) if bounds is not None else self._bbox
2128
+ bounds = _get_bounds(bounds, self._bbox, self.union_all())
2129
+ if bounds is not None:
2130
+ bounds = to_bbox(bounds)
1975
2131
  bounds = self.bounds if bounds is None else bounds
1976
2132
  out_bounds = bounds
1977
2133
  crs = self.crs
@@ -2031,7 +2187,7 @@ class ImageCollection(_ImageBase):
2031
2187
  )
2032
2188
  )
2033
2189
 
2034
- # return self.image_class(
2190
+ # return self.image_class( # TODO
2035
2191
  image = Image(
2036
2192
  bands,
2037
2193
  band_class=self.band_class,
@@ -2066,10 +2222,13 @@ class ImageCollection(_ImageBase):
2066
2222
  arr = np.array(
2067
2223
  [
2068
2224
  (
2069
- band.load(
2070
- bounds=(_bounds if _bounds is not None else None),
2071
- **kwargs,
2072
- )
2225
+ # band.load(
2226
+ # bounds=(_bounds if _bounds is not None else None),
2227
+ # **kwargs,
2228
+ # )
2229
+ # if not band.has_array
2230
+ # else
2231
+ band
2073
2232
  ).values
2074
2233
  for img in collection
2075
2234
  for band in img
@@ -2092,7 +2251,7 @@ class ImageCollection(_ImageBase):
2092
2251
  coords = _generate_spatial_coords(transform, width, height)
2093
2252
 
2094
2253
  arrs.append(
2095
- xr.DataArray(
2254
+ DataArray(
2096
2255
  arr,
2097
2256
  coords=coords,
2098
2257
  dims=["y", "x"],
@@ -2109,7 +2268,7 @@ class ImageCollection(_ImageBase):
2109
2268
  return merged.to_numpy()
2110
2269
 
2111
2270
  def sort_images(self, ascending: bool = True) -> "ImageCollection":
2112
- """Sort Images by date."""
2271
+ """Sort Images by date, then file path if date attribute is missing."""
2113
2272
  self._images = (
2114
2273
  list(sorted([img for img in self if img.date is not None]))
2115
2274
  + sorted(
@@ -2126,20 +2285,56 @@ class ImageCollection(_ImageBase):
2126
2285
  self,
2127
2286
  bounds: tuple | Geometry | GeoDataFrame | GeoSeries | None = None,
2128
2287
  indexes: int | tuple[int] | None = None,
2288
+ file_system=None,
2129
2289
  **kwargs,
2130
2290
  ) -> "ImageCollection":
2131
2291
  """Load all image Bands with threading."""
2292
+ if (
2293
+ bounds is None
2294
+ and indexes is None
2295
+ and all(band.has_array for img in self for band in img)
2296
+ ):
2297
+ return self
2298
+
2299
+ # if self.processes == 1:
2300
+ # for img in self:
2301
+ # for band in img:
2302
+ # band.load(
2303
+ # bounds=bounds,
2304
+ # indexes=indexes,
2305
+ # file_system=file_system,
2306
+ # **kwargs,
2307
+ # )
2308
+ # return self
2309
+
2132
2310
  with joblib.Parallel(n_jobs=self.processes, backend="threading") as parallel:
2133
2311
  if self.masking:
2134
2312
  parallel(
2135
2313
  joblib.delayed(_load_band)(
2136
- img.mask, bounds=bounds, indexes=indexes, **kwargs
2314
+ img.mask,
2315
+ bounds=bounds,
2316
+ indexes=indexes,
2317
+ file_system=file_system,
2318
+ **kwargs,
2137
2319
  )
2138
2320
  for img in self
2139
2321
  )
2322
+ for img in self:
2323
+ for band in img:
2324
+ band._mask = img.mask
2325
+
2326
+ # print({img.mask.has_array for img in self })
2327
+ # print({band.mask.has_array for img in self for band in img})
2328
+
2329
+ # with joblib.Parallel(n_jobs=self.processes, backend="threading") as parallel:
2330
+
2140
2331
  parallel(
2141
2332
  joblib.delayed(_load_band)(
2142
- band, bounds=bounds, indexes=indexes, **kwargs
2333
+ band,
2334
+ bounds=bounds,
2335
+ indexes=indexes,
2336
+ file_system=file_system,
2337
+ **kwargs,
2143
2338
  )
2144
2339
  for img in self
2145
2340
  for band in img
@@ -2147,7 +2342,28 @@ class ImageCollection(_ImageBase):
2147
2342
 
2148
2343
  return self
2149
2344
 
2150
- def set_bbox(
2345
+ def clip(
2346
+ self,
2347
+ mask: Geometry | GeoDataFrame | GeoSeries,
2348
+ **kwargs,
2349
+ ) -> "ImageCollection":
2350
+ """Clip all image Bands with 'loky'."""
2351
+ if self.processes == 1:
2352
+ for img in self:
2353
+ for band in img:
2354
+ band.clip(mask, **kwargs)
2355
+ return self
2356
+
2357
+ with joblib.Parallel(n_jobs=self.processes, backend="loky") as parallel:
2358
+ parallel(
2359
+ joblib.delayed(_clip_band)(band, mask, **kwargs)
2360
+ for img in self
2361
+ for band in img
2362
+ )
2363
+
2364
+ return self
2365
+
2366
+ def _set_bbox(
2151
2367
  self, bbox: GeoDataFrame | GeoSeries | Geometry | tuple[float]
2152
2368
  ) -> "ImageCollection":
2153
2369
  """Set the mask to be used to clip the images to."""
@@ -2156,86 +2372,23 @@ class ImageCollection(_ImageBase):
2156
2372
  if self._images is not None:
2157
2373
  for img in self._images:
2158
2374
  img._bbox = self._bbox
2159
- if img._bands is not None:
2160
- for band in img:
2161
- band._bbox = self._bbox
2162
- bounds = box(*band._bbox).intersection(box(*band.bounds))
2163
- band._bounds = to_bbox(bounds) if not bounds.is_empty else None
2164
-
2165
- return self
2375
+ if img.mask is not None:
2376
+ img.mask._bbox = self._bbox
2377
+ if img.bands is None:
2378
+ continue
2379
+ for band in img:
2380
+ band._bbox = self._bbox
2381
+ bounds = box(*band._bbox).intersection(box(*band.bounds))
2382
+ band._bounds = to_bbox(bounds) if not bounds.is_empty else None
2383
+ if band.mask is not None:
2384
+ band.mask._bbox = self._bbox
2385
+ band.mask._bounds = band._bounds
2166
2386
 
2167
- def apply(self, func: Callable, **kwargs) -> "ImageCollection":
2168
- """Apply a function to all bands in each image of the collection."""
2169
- for img in self:
2170
- img.bands = [func(band, **kwargs) for band in img]
2171
2387
  return self
2172
2388
 
2173
- def filter(
2174
- self,
2175
- bands: str | list[str] | None = None,
2176
- exclude_bands: str | list[str] | None = None,
2177
- date_ranges: (
2178
- tuple[str | None, str | None]
2179
- | tuple[tuple[str | None, str | None], ...]
2180
- | None
2181
- ) = None,
2182
- bbox: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None = None,
2183
- intersects: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None = None,
2184
- max_cloud_coverage: int | None = None,
2185
- copy: bool = True,
2186
- ) -> "ImageCollection":
2187
- """Filter images and bands in the collection."""
2188
- copied = self.copy() if copy else self
2189
-
2190
- if isinstance(bbox, BoundingBox):
2191
- date_ranges = (bbox.mint, bbox.maxt)
2192
-
2193
- if date_ranges:
2194
- copied = copied._filter_dates(date_ranges)
2195
-
2196
- if max_cloud_coverage is not None:
2197
- copied.images = [
2198
- image
2199
- for image in copied.images
2200
- if image.cloud_coverage_percentage < max_cloud_coverage
2201
- ]
2202
-
2203
- if bbox is not None:
2204
- copied = copied._filter_bounds(bbox)
2205
- copied.set_bbox(bbox)
2206
-
2207
- if intersects is not None:
2208
- copied = copied._filter_bounds(intersects)
2209
-
2210
- if bands is not None:
2211
- if isinstance(bands, str):
2212
- bands = [bands]
2213
- bands = set(bands)
2214
- copied._band_ids = bands
2215
- copied.images = [img[bands] for img in copied.images if bands in img]
2216
-
2217
- if exclude_bands is not None:
2218
- if isinstance(exclude_bands, str):
2219
- exclude_bands = {exclude_bands}
2220
- else:
2221
- exclude_bands = set(exclude_bands)
2222
- include_bands: list[list[str]] = [
2223
- [band_id for band_id in img.band_ids if band_id not in exclude_bands]
2224
- for img in copied
2225
- ]
2226
- copied.images = [
2227
- img[bands]
2228
- for img, bands in zip(copied.images, include_bands, strict=False)
2229
- if bands
2230
- ]
2231
-
2232
- return copied
2233
-
2234
2389
  def _filter_dates(
2235
2390
  self,
2236
- date_ranges: (
2237
- tuple[str | None, str | None] | tuple[tuple[str | None, str | None], ...]
2238
- ),
2391
+ date_ranges: DATE_RANGES_TYPE = None,
2239
2392
  ) -> "ImageCollection":
2240
2393
  if not isinstance(date_ranges, (tuple, list)):
2241
2394
  raise TypeError(
@@ -2247,13 +2400,7 @@ class ImageCollection(_ImageBase):
2247
2400
  "Cannot set date_ranges when the class's image_regexes attribute is None"
2248
2401
  )
2249
2402
 
2250
- self.images = [
2251
- img
2252
- for img in self
2253
- if _date_is_within(
2254
- img.path, date_ranges, self.image_patterns, self.date_format
2255
- )
2256
- ]
2403
+ self.images = [img for img in self if _date_is_within(img.date, date_ranges)]
2257
2404
  return self
2258
2405
 
2259
2406
  def _filter_bounds(
@@ -2264,11 +2411,15 @@ class ImageCollection(_ImageBase):
2264
2411
 
2265
2412
  other = to_shapely(other)
2266
2413
 
2267
- # intersects_list = GeoSeries([img.union_all() for img in self]).intersects(other)
2268
- with joblib.Parallel(n_jobs=self.processes, backend="loky") as parallel:
2269
- intersects_list: list[bool] = parallel(
2270
- joblib.delayed(_intesects)(image, other) for image in self
2271
- )
2414
+ if self.processes == 1:
2415
+ intersects_list: pd.Series = GeoSeries(
2416
+ [img.union_all() for img in self]
2417
+ ).intersects(other)
2418
+ else:
2419
+ with joblib.Parallel(n_jobs=self.processes, backend="loky") as parallel:
2420
+ intersects_list: list[bool] = parallel(
2421
+ joblib.delayed(_intesects)(image, other) for image in self
2422
+ )
2272
2423
 
2273
2424
  self.images = [
2274
2425
  image
@@ -2277,6 +2428,69 @@ class ImageCollection(_ImageBase):
2277
2428
  ]
2278
2429
  return self
2279
2430
 
2431
+ def to_xarray(
2432
+ self,
2433
+ **kwargs,
2434
+ ) -> Dataset:
2435
+ """Convert the raster to an xarray.Dataset.
2436
+
2437
+ Images are converted to 2d arrays for each unique bounds.
2438
+ The spatial dimensions will be labeled "x" and "y". The third
2439
+ dimension defaults to "date" if all images have date attributes.
2440
+ Otherwise defaults to the image name.
2441
+ """
2442
+ if any(not band.has_array for img in self for band in img):
2443
+ raise ValueError("Arrays must be loaded.")
2444
+
2445
+ # if by is None:
2446
+ if all(img.date for img in self):
2447
+ by = ["date"]
2448
+ elif not pd.Index([img.name for img in self]).is_unique:
2449
+ raise ValueError("Images must have unique names.")
2450
+ else:
2451
+ by = ["name"]
2452
+ # elif isinstance(by, str):
2453
+ # by = [by]
2454
+
2455
+ xarrs: dict[str, DataArray] = {}
2456
+ for (bounds, band_id), collection in self.groupby(["bounds", "band_id"]):
2457
+ name = f"{band_id}_{'-'.join(str(int(x)) for x in bounds)}"
2458
+ first_band = collection[0][0]
2459
+ coords = _generate_spatial_coords(
2460
+ first_band.transform, first_band.width, first_band.height
2461
+ )
2462
+ values = np.array([band.to_numpy() for img in collection for band in img])
2463
+ assert len(values) == len(collection)
2464
+
2465
+ # coords["band_id"] = [
2466
+ # band.band_id or i for i, band in enumerate(collection[0])
2467
+ # ]
2468
+ for attr in by:
2469
+ coords[attr] = [getattr(img, attr) for img in collection]
2470
+ # coords["band"] = band_id #
2471
+
2472
+ dims = [*by, "y", "x"]
2473
+ # dims = ["band", "y", "x"]
2474
+ # dims = {}
2475
+ # for attr in by:
2476
+ # dims[attr] = [getattr(img, attr) for img in collection]
2477
+
2478
+ xarrs[name] = DataArray(
2479
+ values,
2480
+ coords=coords,
2481
+ dims=dims,
2482
+ # name=name,
2483
+ name=band_id,
2484
+ attrs={
2485
+ "crs": collection.crs,
2486
+ "band_id": band_id,
2487
+ }, # , "bounds": bounds},
2488
+ **kwargs,
2489
+ )
2490
+
2491
+ return xr.combine_by_coords(list(xarrs.values()))
2492
+ # return Dataset(xarrs)
2493
+
2280
2494
  def to_gdfs(self, column: str = "value") -> dict[str, GeoDataFrame]:
2281
2495
  """Convert each band in each Image to a GeoDataFrame."""
2282
2496
  out = {}
@@ -2289,12 +2503,8 @@ class ImageCollection(_ImageBase):
2289
2503
  except AttributeError:
2290
2504
  name = f"{self.__class__.__name__}({i})"
2291
2505
 
2292
- band.load()
2293
-
2294
2506
  if name not in out:
2295
2507
  out[name] = band.to_gdf(column=column)
2296
- # else:
2297
- # out[name] = f"{self.__class__.__name__}({i})"
2298
2508
  return out
2299
2509
 
2300
2510
  def sample(self, n: int = 1, size: int = 500) -> "ImageCollection":
@@ -2363,11 +2573,16 @@ class ImageCollection(_ImageBase):
2363
2573
  """Number of images."""
2364
2574
  return len(self.images)
2365
2575
 
2366
- def __getitem__(
2367
- self,
2368
- item: int | slice | Sequence[int | bool] | BoundingBox | Sequence[BoundingBox],
2369
- ) -> Image | TORCHGEO_RETURN_TYPE:
2370
- """Select one Image by integer index, or multiple Images by slice, list of int or torchgeo.BoundingBox."""
2576
+ def __getattr__(self, attr: str) -> Any:
2577
+ """Make iterable of metadata_attribute."""
2578
+ if attr in (self.metadata_attributes or {}):
2579
+ return self._metadata_attribute_collection_type(
2580
+ [getattr(img, attr) for img in self]
2581
+ )
2582
+ return super().__getattribute__(attr)
2583
+
2584
+ def __getitem__(self, item: int | slice | Sequence[int | bool]) -> Image:
2585
+ """Select one Image by integer index, or multiple Images by slice, list of int."""
2371
2586
  if isinstance(item, int):
2372
2587
  return self.images[item]
2373
2588
 
@@ -2392,90 +2607,23 @@ class ImageCollection(_ImageBase):
2392
2607
  ]
2393
2608
  return copied
2394
2609
 
2395
- if not isinstance(item, BoundingBox) and not (
2396
- isinstance(item, Iterable)
2397
- and len(item)
2398
- and all(isinstance(x, BoundingBox) for x in item)
2399
- ):
2400
- copied = self.copy()
2401
- if callable(item):
2402
- item = [item(img) for img in copied]
2403
-
2404
- # check for base bool and numpy bool
2405
- if all("bool" in str(type(x)) for x in item):
2406
- copied.images = [img for x, img in zip(item, copied, strict=True) if x]
2610
+ copied = self.copy()
2611
+ if callable(item):
2612
+ item = [item(img) for img in copied]
2407
2613
 
2408
- else:
2409
- copied.images = [copied.images[i] for i in item]
2410
- return copied
2614
+ # check for base bool and numpy bool
2615
+ if all("bool" in str(type(x)) for x in item):
2616
+ copied.images = [img for x, img in zip(item, copied, strict=True) if x]
2411
2617
 
2412
- if isinstance(item, BoundingBox):
2413
- date_ranges: tuple[str] = (item.mint, item.maxt)
2414
- data: torch.Tensor = numpy_to_torch(
2415
- np.array(
2416
- [
2417
- band.values
2418
- for band in self.filter(
2419
- bbox=item, date_ranges=date_ranges
2420
- ).merge_by_band(bounds=item)
2421
- ]
2422
- )
2423
- )
2424
2618
  else:
2425
- bboxes: list[Polygon] = [to_bbox(x) for x in item]
2426
- date_ranges: list[list[str, str]] = [(x.mint, x.maxt) for x in item]
2427
- data: torch.Tensor = torch.cat(
2428
- [
2429
- numpy_to_torch(
2430
- np.array(
2431
- [
2432
- band.values
2433
- for band in self.filter(
2434
- bbox=bbox, date_ranges=date_range
2435
- ).merge_by_band(bounds=bbox)
2436
- ]
2437
- )
2438
- )
2439
- for bbox, date_range in zip(bboxes, date_ranges, strict=True)
2440
- ]
2441
- )
2442
-
2443
- crs = get_common_crs(self.images)
2444
-
2445
- key = "image" # if self.is_image else "mask"
2446
- sample = {key: data, "crs": crs, "bbox": item}
2447
-
2448
- return sample
2449
-
2450
- @property
2451
- def mint(self) -> float:
2452
- """Min timestamp of the images combined."""
2453
- return min(img.mint for img in self)
2454
-
2455
- @property
2456
- def maxt(self) -> float:
2457
- """Max timestamp of the images combined."""
2458
- return max(img.maxt for img in self)
2459
-
2460
- @property
2461
- def band_ids(self) -> list[str]:
2462
- """Sorted list of unique band_ids."""
2463
- return list(sorted({band.band_id for img in self for band in img}))
2464
-
2465
- @property
2466
- def file_paths(self) -> list[str]:
2467
- """Sorted list of all file paths, meaning all band paths."""
2468
- return list(sorted({band.path for img in self for band in img}))
2619
+ copied.images = [copied.images[i] for i in item]
2620
+ return copied
2469
2621
 
2470
2622
  @property
2471
2623
  def dates(self) -> list[str]:
2472
2624
  """List of image dates."""
2473
2625
  return [img.date for img in self]
2474
2626
 
2475
- def dates_as_int(self) -> list[int]:
2476
- """List of image dates as 8-length integers."""
2477
- return [int(img.date[:8]) for img in self]
2478
-
2479
2627
  @property
2480
2628
  def image_paths(self) -> list[str]:
2481
2629
  """List of image paths."""
@@ -2496,29 +2644,22 @@ class ImageCollection(_ImageBase):
2496
2644
  masking=self.masking,
2497
2645
  **self._common_init_kwargs,
2498
2646
  )
2647
+
2499
2648
  if self.masking is not None:
2500
2649
  images = []
2501
2650
  for image in self._images:
2651
+ # TODO why this loop?
2502
2652
  try:
2503
2653
  if not isinstance(image.mask, Band):
2504
2654
  raise ValueError()
2505
2655
  images.append(image)
2506
- except ValueError:
2656
+ except ValueError as e:
2657
+ raise e
2507
2658
  continue
2508
2659
  self._images = images
2509
2660
  for image in self._images:
2510
2661
  image._bands = [band for band in image if band.band_id is not None]
2511
2662
 
2512
- if self.metadata is not None:
2513
- for img in self:
2514
- for band in img:
2515
- for key in ["crs", "bounds"]:
2516
- try:
2517
- value = self.metadata[band.path][key]
2518
- except KeyError:
2519
- value = self.metadata[key][band.path]
2520
- setattr(band, f"_{key}", value)
2521
-
2522
2663
  self._images = [img for img in self if len(img)]
2523
2664
 
2524
2665
  if self._should_be_sorted:
@@ -2543,7 +2684,7 @@ class ImageCollection(_ImageBase):
2543
2684
  and sort_group in _get_non_optional_groups(pat)
2544
2685
  for pat in self.image_patterns
2545
2686
  )
2546
- or all(img.date is not None for img in self)
2687
+ or all(getattr(img, sort_group) is not None for img in self)
2547
2688
  )
2548
2689
 
2549
2690
  @images.setter
@@ -2552,31 +2693,20 @@ class ImageCollection(_ImageBase):
2552
2693
  if not all(isinstance(x, Image) for x in self._images):
2553
2694
  raise TypeError("images should be a sequence of Image.")
2554
2695
 
2555
- @property
2556
- def index(self) -> Index:
2557
- """Spatial index that makes torchgeo think this class is a RasterDataset."""
2558
- try:
2559
- if len(self) == len(self._index):
2560
- return self._index
2561
- except AttributeError:
2562
- self._index = Index(interleaved=False, properties=Property(dimension=3))
2563
-
2564
- for i, img in enumerate(self.images):
2565
- if img.date:
2566
- try:
2567
- mint, maxt = disambiguate_timestamp(img.date, self.date_format)
2568
- except (NameError, TypeError):
2569
- mint, maxt = 0, 1
2570
- else:
2571
- mint, maxt = 0, 1
2572
- # important: torchgeo has a different order of the bbox than shapely and geopandas
2573
- minx, miny, maxx, maxy = img.bounds
2574
- self._index.insert(i, (minx, maxx, miny, maxy, mint, maxt))
2575
- return self._index
2576
-
2577
2696
  def __repr__(self) -> str:
2578
2697
  """String representation."""
2579
- return f"{self.__class__.__name__}({len(self)}, path='{self.path}')"
2698
+ root = ""
2699
+ if self.path is not None:
2700
+ data = f"'{self.path}'"
2701
+ elif all(img.path is not None for img in self):
2702
+ data = [img.path for img in self]
2703
+ parents = {str(Path(path).parent) for path in data}
2704
+ if len(parents) == 1:
2705
+ data = [Path(path).name for path in data]
2706
+ root = f" root='{next(iter(parents))}',"
2707
+ else:
2708
+ data = [img for img in self]
2709
+ return f"{self.__class__.__name__}({data},{root} res={self.res}, level='{self.level}')"
2580
2710
 
2581
2711
  def union_all(self) -> Polygon | MultiPolygon:
2582
2712
  """(Multi)Polygon representing the union of all image bounds."""
@@ -2603,6 +2733,7 @@ class ImageCollection(_ImageBase):
2603
2733
  p: float = 0.95,
2604
2734
  ylim: tuple[float, float] | None = None,
2605
2735
  figsize: tuple[int] = (20, 8),
2736
+ rounding: int = 3,
2606
2737
  ) -> None:
2607
2738
  """Plot each individual pixel in a dotplot for all dates.
2608
2739
 
@@ -2616,6 +2747,7 @@ class ImageCollection(_ImageBase):
2616
2747
  p: p-value for the confidence interval.
2617
2748
  ylim: Limits of the y-axis.
2618
2749
  figsize: Figure size as tuple (width, height).
2750
+ rounding: rounding of title n
2619
2751
 
2620
2752
  """
2621
2753
  if by is None and all(band.band_id is not None for img in self for band in img):
@@ -2625,12 +2757,11 @@ class ImageCollection(_ImageBase):
2625
2757
 
2626
2758
  alpha = 1 - p
2627
2759
 
2628
- for img in self:
2629
- for band in img:
2630
- band.load()
2631
-
2632
2760
  for group_values, subcollection in self.groupby(by):
2633
- print("group_values:", *group_values)
2761
+ print("subcollection group values:", group_values)
2762
+
2763
+ if "date" in x_var and subcollection._should_be_sorted:
2764
+ subcollection._images = list(sorted(subcollection._images))
2634
2765
 
2635
2766
  y = np.array([band.values for img in subcollection for band in img])
2636
2767
  if "date" in x_var and subcollection._should_be_sorted:
@@ -2641,6 +2772,7 @@ class ImageCollection(_ImageBase):
2641
2772
  for band in img
2642
2773
  ]
2643
2774
  )
2775
+ first_date = pd.Timestamp(x[0])
2644
2776
  x = (
2645
2777
  pd.to_datetime(
2646
2778
  [band.date[:8] for img in subcollection for band in img]
@@ -2685,6 +2817,10 @@ class ImageCollection(_ImageBase):
2685
2817
  )[0]
2686
2818
  predicted = np.array([intercept + coef * x for x in this_x])
2687
2819
 
2820
+ predicted_start = predicted[0]
2821
+ predicted_end = predicted[-1]
2822
+ predicted_change = predicted_end - predicted_start
2823
+
2688
2824
  # Degrees of freedom
2689
2825
  dof = len(this_x) - 2
2690
2826
 
@@ -2708,8 +2844,6 @@ class ImageCollection(_ImageBase):
2708
2844
  ci_lower = predicted - t_val * pred_stderr
2709
2845
  ci_upper = predicted + t_val * pred_stderr
2710
2846
 
2711
- rounding = int(np.log(1 / abs(coef)))
2712
-
2713
2847
  fig = plt.figure(figsize=figsize)
2714
2848
  ax = fig.add_subplot(1, 1, 1)
2715
2849
 
@@ -2723,120 +2857,353 @@ class ImageCollection(_ImageBase):
2723
2857
  alpha=0.2,
2724
2858
  label=f"{int(alpha*100)}% CI",
2725
2859
  )
2726
- plt.title(f"Coefficient: {round(coef, rounding)}")
2860
+ plt.title(
2861
+ f"coef: {round(coef, int(np.log(1 / abs(coef))))}, "
2862
+ f"pred change: {round(predicted_change, rounding)}, "
2863
+ f"pred start: {round(predicted_start, rounding)}, "
2864
+ f"pred end: {round(predicted_end, rounding)}"
2865
+ )
2727
2866
  plt.xlabel(x_var)
2728
2867
  plt.ylabel(y_label)
2729
- plt.show()
2730
2868
 
2869
+ if x_var == "date":
2870
+ date_labels = pd.to_datetime(
2871
+ [first_date + pd.Timedelta(days=int(day)) for day in this_x]
2872
+ )
2731
2873
 
2732
- def concat_image_collections(collections: Sequence[ImageCollection]) -> ImageCollection:
2733
- """Union multiple ImageCollections together.
2874
+ _, unique_indices = np.unique(
2875
+ date_labels.strftime("%Y-%m"), return_index=True
2876
+ )
2734
2877
 
2735
- Same as using the union operator |.
2736
- """
2737
- resolutions = {x.res for x in collections}
2738
- if len(resolutions) > 1:
2739
- raise ValueError(f"resoultion mismatch. {resolutions}")
2740
- images = list(itertools.chain.from_iterable([x.images for x in collections]))
2741
- levels = {x.level for x in collections}
2742
- level = next(iter(levels)) if len(levels) == 1 else None
2743
- first_collection = collections[0]
2878
+ unique_x = np.array(this_x)[unique_indices]
2879
+ unique_labels = date_labels[unique_indices].strftime("%Y-%m")
2744
2880
 
2745
- out_collection = first_collection.__class__(
2746
- images,
2747
- level=level,
2748
- band_class=first_collection.band_class,
2749
- image_class=first_collection.image_class,
2750
- **first_collection._common_init_kwargs,
2881
+ ax.set_xticks(unique_x)
2882
+ ax.set_xticklabels(unique_labels, rotation=45, ha="right")
2883
+ # ax.tick_params(axis="x", length=10, width=2)
2884
+
2885
+ plt.show()
2886
+
2887
+
2888
+ def _get_all_regex_matches(xml_file: str, regexes: tuple[str]) -> tuple[str]:
2889
+ for regex in regexes:
2890
+ try:
2891
+ return re.search(regex, xml_file)
2892
+ except (TypeError, AttributeError):
2893
+ continue
2894
+ raise ValueError(
2895
+ f"Could not find processing_baseline info from {regexes} in {xml_file}"
2751
2896
  )
2752
- out_collection._all_file_paths = list(
2753
- sorted(
2754
- set(itertools.chain.from_iterable([x._all_file_paths for x in collections]))
2755
- )
2897
+
2898
+
2899
class Sentinel2Config:
    """Holder of Sentinel 2 regexes, band_ids etc."""

    image_regexes: ClassVar[tuple[str, ...]] = (config.SENTINEL2_IMAGE_REGEX,)
    filename_regexes: ClassVar[tuple[str, ...]] = (config.SENTINEL2_FILENAME_REGEX,)
    # Attribute name -> how to extract it from the xml text: either a callable
    # taking the xml text, or the name of a method defined on the class.
    metadata_attributes: ClassVar[dict[str, Callable | functools.partial | str]] = {
        "processing_baseline": functools.partial(
            _extract_regex_match_from_string,
            regexes=(r"<PROCESSING_BASELINE>(.*?)</PROCESSING_BASELINE>",),
        ),
        "cloud_cover_percentage": "_get_cloud_cover_percentage",
        "is_refined": "_get_image_refining_flag",
        "boa_quantification_value": "_get_boa_quantification_value",
    }
    # band_id -> integer value (presumably the band's resolution in meters;
    # the l2a_bands comprehension below names it 'res').
    l1c_bands: ClassVar[dict[str, int]] = {
        "B01": 60,
        "B02": 10,
        "B03": 10,
        "B04": 10,
        "B05": 20,
        "B06": 20,
        "B07": 20,
        "B08": 10,
        "B8A": 20,
        "B09": 60,
        "B10": 60,
        "B11": 20,
        "B12": 20,
    }
    # Same as l1c_bands, minus B10.
    l2a_bands: ClassVar[dict[str, int]] = {
        key: res for key, res in l1c_bands.items() if key != "B10"
    }
    all_bands: ClassVar[dict[str, int]] = l1c_bands
    rbg_bands: ClassVar[tuple[str, ...]] = ("B04", "B02", "B03")
    ndvi_bands: ClassVar[tuple[str, ...]] = ("B04", "B08")
    masking: ClassVar[BandMasking] = BandMasking(
        band_id="SCL",
        values={
            2: "Topographic casted shadows",
            3: "Cloud shadows",
            8: "Cloud medium probability",
            9: "Cloud high probability",
            10: "Thin cirrus",
            11: "Snow or ice",
        },
    )

    def _get_image_refining_flag(self, xml_file: str) -> bool:
        """Return True if the xml text flags the image as REFINED, False if NOT_REFINED.

        Raises:
            _RegexError: If no Image_Refining flag is found in the xml text.
        """
        match_ = re.search(
            r'Image_Refining flag="(?:REFINED|NOT_REFINED)"',
            xml_file,
        )
        if match_ is None:
            raise _RegexError()

        # The pattern only matches the two literal flag values,
        # so anything that is not NOT_REFINED is REFINED.
        return "NOT_REFINED" not in match_.group(0)

    def _get_boa_quantification_value(self, xml_file: str) -> int:
        """Return the BOA_QUANTIFICATION_VALUE from the xml text as an integer.

        An optional leading minus sign is outside the capture group,
        so the returned value is never negative.
        """
        return int(
            _extract_regex_match_from_string(
                xml_file,
                (
                    r'<BOA_QUANTIFICATION_VALUE unit="none">-?(\d+)</BOA_QUANTIFICATION_VALUE>',
                ),
            )
        )

    def _get_cloud_cover_percentage(self, xml_file: str) -> float:
        """Return the cloud cover value from the xml text as a float.

        Two tag names are passed to the extraction helper:
        Cloud_Coverage_Assessment and CLOUDY_PIXEL_OVER_LAND_PERCENTAGE.
        """
        return float(
            _extract_regex_match_from_string(
                xml_file,
                (
                    r"<Cloud_Coverage_Assessment>([\d.]+)</Cloud_Coverage_Assessment>",
                    r"<CLOUDY_PIXEL_OVER_LAND_PERCENTAGE>([\d.]+)</CLOUDY_PIXEL_OVER_LAND_PERCENTAGE>",
                ),
            )
        )
2771
2983
 
2772
- def to_xarray(
2773
- array: np.ndarray, transform: Affine, crs: Any, name: str | None = None
2774
- ) -> DataArray:
2775
- """Convert the raster to an xarray.DataArray."""
2776
- if len(array.shape) == 2:
2777
- height, width = array.shape
2778
- dims = ["y", "x"]
2779
- elif len(array.shape) == 3:
2780
- height, width = array.shape[1:]
2781
- dims = ["band", "y", "x"]
2782
- else:
2783
- raise ValueError(f"Array should be 2 or 3 dimensional. Got shape {array.shape}")
2784
-
2785
- coords = _generate_spatial_coords(transform, width, height)
2786
- return xr.DataArray(
2787
- array,
2788
- coords=coords,
2789
- dims=dims,
2790
- name=name,
2791
- attrs={"crs": crs},
2792
- )
2793
2984
 
2985
class Sentinel2CloudlessConfig(Sentinel2Config):
    """Holder of regexes, band_ids etc. for Sentinel 2 cloudless mosaic."""

    image_regexes: ClassVar[tuple[str, ...]] = (config.SENTINEL2_MOSAIC_IMAGE_REGEX,)
    filename_regexes: ClassVar[tuple[str, ...]] = (
        config.SENTINEL2_MOSAIC_FILENAME_REGEX,
    )
    masking: ClassVar[None] = None
    # Band ids are the Sentinel2Config ids with "B0" replaced by "B"
    # (e.g. "B02" -> "B2").
    all_bands: ClassVar[list[str]] = [
        x.replace("B0", "B") for x in Sentinel2Config.all_bands
    ]
    # Tuples rather than sets: the band order carries meaning
    # and the parent class also exposes these attributes as tuples.
    rbg_bands: ClassVar[tuple[str, ...]] = tuple(
        key.replace("B0", "B") for key in Sentinel2Config.rbg_bands
    )
    ndvi_bands: ClassVar[tuple[str, ...]] = tuple(
        key.replace("B0", "B") for key in Sentinel2Config.ndvi_bands
    )
2797
3000
 
2798
- gradient = abs(gradient_x) + abs(gradient_y)
2799
3001
 
2800
- if not degrees:
2801
- return gradient
3002
class Sentinel2Band(Sentinel2Config, Band):
    """Band with Sentinel2 specific name variables and regexes."""

    # Adds the band specific 'boa_add_offset' on top of the shared attributes.
    metadata_attributes = Sentinel2Config.metadata_attributes | {
        "boa_add_offset": "_get_boa_add_offset_dict",
    }

    def _get_boa_add_offset_dict(self, xml_file: str) -> int | None:
        """Return the BOA_ADD_OFFSET value of this band from the xml text.

        Args:
            xml_file: Content of the xml file as a string.

        Returns:
            The offset of this band as an integer, or None if the band is a mask.

        Raises:
            _RegexError: If no BOA_ADD_OFFSET entries are found.
            ValueError: If the entries are index coded and their count
                differs from the number of L1C bands.
            KeyError: If no entry corresponds to the band_id of this band.
        """
        if self.is_mask:
            return None

        pat = re.compile(
            r"""
            <BOA_ADD_OFFSET\s*
            band_id="(?P<band_id>\d+)"\s*
            >\s*(?P<value>-?\d+)\s*
            </BOA_ADD_OFFSET>
            """,
            flags=re.VERBOSE,
        )

        try:
            matches = [found.groupdict() for found in pat.finditer(xml_file)]
        except (TypeError, AttributeError, KeyError) as e:
            raise _RegexError(f"Could not find boa_add_offset info from {pat}") from e
        if not matches:
            raise _RegexError(f"Could not find boa_add_offset info from {pat}")

        offset_series = pd.DataFrame(matches).set_index("band_id")["value"]
        offsets = offset_series.astype(int).to_dict()

        # Some xml files number the bands 0..n-1 instead of naming them.
        # Those indices are translated to actual band ids (B01 etc.).
        xml_indices = [int(key) for key in offsets]
        if xml_indices == list(range(len(offsets))):
            # The xml files contain 13 band ids for both L1C and L2A,
            # even though L2A doesn't have band B10.
            band_names = list(self.l1c_bands)
            if len(band_names) != len(offsets):
                raise ValueError(
                    f"Different number of bands in xml file and config for {self.name}: {band_names}, {list(offsets)}"
                )
            offsets = dict(zip(band_names, offsets.values(), strict=True))

        try:
            return offsets[self.band_id]
        except KeyError as e:
            # Fall back to progressively stripped ids: first without "B0",
            # then without "B", then without "A".
            stripped_id = self.band_id.upper()
            for fragment in ("B0", "B", "A"):
                stripped_id = stripped_id.replace(fragment, "")
                try:
                    return offsets[stripped_id]
                except KeyError:
                    pass
            raise KeyError(self.band_id, offsets) from e
3062
+
3063
+
3064
class Sentinel2Image(Sentinel2Config, Image):
    """Image with Sentinel2 specific name variables and regexes."""

    band_class: ClassVar[Sentinel2Band] = Sentinel2Band

    def ndvi(
        self,
        red_band: str = "B04",
        nir_band: str = "B08",
        padding: int = 0,
        copy: bool = True,
    ) -> NDVIBand:
        """Calculate the NDVI for the Image.

        Forwards to the parent implementation with B04 as the default red band
        and B08 as the default near infrared band.
        """
        ndvi_kwargs = {
            "red_band": red_band,
            "nir_band": nir_band,
            "padding": padding,
            "copy": copy,
        }
        return super().ndvi(**ndvi_kwargs)
3080
+
3081
+
3082
class Sentinel2Collection(Sentinel2Config, ImageCollection):
    """ImageCollection with Sentinel2 specific name variables and path regexes."""

    image_class: ClassVar[Sentinel2Image] = Sentinel2Image
    band_class: ClassVar[Sentinel2Band] = Sentinel2Band

    def __init__(self, data: str | Path | Sequence[Image], **kwargs) -> None:
        """ImageCollection with Sentinel2 specific name variables and path regexes.

        Raises:
            ValueError: If 'level' is not given as keyword argument.
        """
        # None_ is the fallback when 'level' is absent from kwargs; calling it
        # gives a None_ instance, which is what the isinstance check detects.
        level = kwargs.get("level", None_)
        level_is_missing = callable(level) and isinstance(level(), None_)
        if level_is_missing:
            raise ValueError("Must specify level for Sentinel2Collection.")
        super().__init__(data=data, **kwargs)
3094
+
3095
+
3096
class Sentinel2CloudlessBand(Sentinel2CloudlessConfig, Band):
    """Band for cloudless mosaic with Sentinel2 specific name variables and regexes.

    Adds nothing of its own: it only combines Sentinel2CloudlessConfig with Band.
    """
3098
+
3099
+
3100
class Sentinel2CloudlessImage(Sentinel2CloudlessConfig, Sentinel2Image):
    """Image for cloudless mosaic with Sentinel2 specific name variables and regexes."""

    band_class: ClassVar[Sentinel2CloudlessBand] = Sentinel2CloudlessBand

    # Sentinel2Image is already a base class; this assignment pins its ndvi
    # implementation directly on this class.
    ndvi = Sentinel2Image.ndvi
3106
+
3107
+
3108
class Sentinel2CloudlessCollection(Sentinel2CloudlessConfig, ImageCollection):
    """ImageCollection for cloudless mosaic with Sentinel2 specific name variables and regexes."""

    # Classes used for the images and bands of the collection.
    image_class: ClassVar[Sentinel2CloudlessImage] = Sentinel2CloudlessImage
    band_class: ClassVar[Sentinel2CloudlessBand] = Sentinel2CloudlessBand
3113
+
3114
+
3115
def concat_image_collections(collections: Sequence[ImageCollection]) -> ImageCollection:
    """Union multiple ImageCollections together.

    Same as using the union operator |.

    Args:
        collections: The collections to combine. All must have the same resolution.

    Returns:
        A new collection of the same class as the first collection, holding
        the images of all collections. The level is kept if all collections
        share it, otherwise it is None. Band class, image class and the
        remaining init kwargs are taken from the first collection.

    Raises:
        ValueError: If the collections have different resolutions.
    """
    resolutions = {x.res for x in collections}
    if len(resolutions) > 1:
        raise ValueError(f"resolution mismatch. {resolutions}")

    images = list(itertools.chain.from_iterable(x.images for x in collections))
    levels = {x.level for x in collections}
    level = next(iter(levels)) if len(levels) == 1 else None
    first_collection = collections[0]

    out_collection = first_collection.__class__(
        images,
        level=level,
        band_class=first_collection.band_class,
        image_class=first_collection.image_class,
        **first_collection._common_init_kwargs,
    )
    # Unique file paths of all collections, in sorted order.
    out_collection._all_file_paths = sorted(
        set(itertools.chain.from_iterable(x._all_file_paths for x in collections))
    )
    return out_collection
3141
+
3142
+
3143
def _get_gradient(band: Band, degrees: bool = False, copy: bool = True) -> np.ndarray:
    """Calculate the slope of the values of a band.

    Args:
        band: Band whose 'values' array is 2 or 3 dimensional.
        degrees: Passed on to _slope_2d.
        copy: Whether to work on a copy of the band.

    Returns:
        An array of slopes. A 3 dimensional array is handled one
        2 dimensional slice at a time along the first axis.

    Raises:
        ValueError: If the array is not 2 or 3 dimensional.
    """
    copied = band.copy() if copy else band
    values = copied.values
    n_dims = len(values.shape)
    if n_dims == 3:
        return np.array(
            [_slope_2d(arr, copied.res, degrees=degrees) for arr in values]
        )
    if n_dims == 2:
        return _slope_2d(values, copied.res, degrees=degrees)
    raise ValueError("array must be 2 or 3 dimensional")
3153
+
3154
+
3155
+ def _slope_2d(array: np.ndarray, res: int, degrees: int) -> np.ndarray:
3156
+ gradient_x, gradient_y = np.gradient(array, res, res)
3157
+
3158
+ gradient = abs(gradient_x) + abs(gradient_y)
3159
+
3160
+ if not degrees:
3161
+ return gradient
3162
+
3163
+ radians = np.arctan(gradient)
3164
+ degrees = np.degrees(radians)
3165
+
3166
+ assert np.max(degrees) <= 90
3167
+
3168
+ return degrees
2835
3169
 
3170
+
3171
def _clip_xarray(
    xarr: DataArray,
    mask: tuple[int, int, int, int],
    crs: Any,
    **kwargs,
) -> DataArray:
    """Clip the array to the mask with the 'rio' accessor.

    Args:
        xarr: The array to clip.
        mask: Converted to geometries with to_geoseries.
        crs: Coordinate reference system of the mask, passed on to rio.clip.
        **kwargs: Additional keyword arguments passed to rio.clip.

    Returns:
        The clipped array, or an empty numpy array if clipping
        raises NoDataInBounds.
    """
    # xarray needs a numpy array of polygons
    geometries: np.ndarray = to_geoseries(mask).values
    try:
        clipped = xarr.rio.clip(geometries, crs=crs, **kwargs)
    except NoDataInBounds:
        clipped = np.array([])
    return clipped
2838
3187
 
2839
3188
 
3189
def _get_all_file_paths(path: str) -> set[str]:
    """Return the set of file paths found below a directory.

    In Dapla, the single pattern 'path/**' is globbed. Otherwise the patterns
    'path/**', 'path/**/**', ... are globbed, up to five levels.

    Args:
        path: Directory path as a string.

    Returns:
        Set of paths, each passed through _fix_path.
    """
    n_levels = 1 if is_dapla() else 5
    # The result is a set, so sorting or deduplicating the glob results
    # beforehand would not change anything.
    return {
        _fix_path(file_path)
        for level in range(1, n_levels + 1)
        for file_path in _glob_func(path + "/**" * level)
    }
3205
+
3206
+
2840
3207
  def _get_images(
2841
3208
  image_paths: list[str],
2842
3209
  *,
@@ -2849,9 +3216,8 @@ def _get_images(
2849
3216
  masking: BandMasking | None,
2850
3217
  **kwargs,
2851
3218
  ) -> list[Image]:
2852
-
2853
- with joblib.Parallel(n_jobs=processes, backend="loky") as parallel:
2854
- images = parallel(
3219
+ with joblib.Parallel(n_jobs=processes, backend="threading") as parallel:
3220
+ images: list[Image] = parallel(
2855
3221
  joblib.delayed(image_class)(
2856
3222
  path,
2857
3223
  df=df,
@@ -2874,21 +3240,6 @@ def _get_images(
2874
3240
  return images
2875
3241
 
2876
3242
 
2877
- def numpy_to_torch(array: np.ndarray) -> torch.Tensor:
2878
- """Convert numpy array to a pytorch tensor."""
2879
- # fix numpy dtypes which are not supported by pytorch tensors
2880
- if array.dtype == np.uint16:
2881
- array = array.astype(np.int32)
2882
- elif array.dtype == np.uint32:
2883
- array = array.astype(np.int64)
2884
-
2885
- return torch.tensor(array)
2886
-
2887
-
2888
- class _RegexError(ValueError):
2889
- pass
2890
-
2891
-
2892
3243
  class ArrayNotLoadedError(ValueError):
2893
3244
  """Arrays are not loaded."""
2894
3245
 
@@ -2904,10 +3255,12 @@ class PathlessImageError(ValueError):
2904
3255
  """String representation."""
2905
3256
  if self.instance._merged:
2906
3257
  what = "that have been merged"
2907
- elif self.isinstance._from_array:
3258
+ elif self.instance._from_array:
2908
3259
  what = "from arrays"
2909
- elif self.isinstance._from_gdf:
3260
+ elif self.instance._from_gdf:
2910
3261
  what = "from GeoDataFrames"
3262
+ else:
3263
+ raise ValueError(self.instance)
2911
3264
 
2912
3265
  return (
2913
3266
  f"{self.instance.__class__.__name__} instances {what} "
@@ -2915,165 +3268,32 @@ class PathlessImageError(ValueError):
2915
3268
  )
2916
3269
 
2917
3270
 
2918
- def _get_regex_match_from_xml_in_local_dir(
2919
- paths: list[str], regexes: str | tuple[str]
2920
- ) -> str | dict[str, str]:
2921
- for i, path in enumerate(paths):
2922
- if ".xml" not in path:
2923
- continue
2924
- with _open_func(path, "rb") as file:
2925
- filebytes: bytes = file.read()
2926
- try:
2927
- return _extract_regex_match_from_string(
2928
- filebytes.decode("utf-8"), regexes
2929
- )
2930
- except _RegexError as e:
2931
- if i == len(paths) - 1:
2932
- raise e
2933
-
2934
-
2935
- def _extract_regex_match_from_string(
2936
- xml_file: str, regexes: tuple[str | re.Pattern]
2937
- ) -> str | dict[str, str]:
2938
- if all(isinstance(x, str) for x in regexes):
2939
- for regex in regexes:
2940
- try:
2941
- return re.search(regex, xml_file).group(1)
2942
- except (TypeError, AttributeError):
2943
- continue
2944
- raise _RegexError()
2945
-
2946
- out = {}
2947
- for regex in regexes:
2948
- try:
2949
- matches = re.search(regex, xml_file)
2950
- out |= matches.groupdict()
2951
- except (TypeError, AttributeError):
2952
- continue
2953
- if not out:
2954
- raise _RegexError()
2955
- return out
2956
-
2957
-
2958
- def _fix_path(path: str) -> str:
2959
- return (
2960
- str(path).replace("\\", "/").replace(r"\"", "/").replace("//", "/").rstrip("/")
2961
- )
2962
-
2963
-
2964
- def _get_regexes_matches_for_df(
2965
- df, match_col: str, patterns: Sequence[re.Pattern]
2966
- ) -> pd.DataFrame:
2967
- if not len(df):
2968
- return df
2969
-
2970
- non_optional_groups = list(
2971
- set(
2972
- itertools.chain.from_iterable(
2973
- [_get_non_optional_groups(pat) for pat in patterns]
2974
- )
2975
- )
2976
- )
2977
-
2978
- if not non_optional_groups:
2979
- return df
2980
-
2981
- assert df.index.is_unique
2982
- keep = []
2983
- for pat in patterns:
2984
- for i, row in df[match_col].items():
2985
- matches = _get_first_group_match(pat, row)
2986
- if all(group in matches for group in non_optional_groups):
2987
- keep.append(i)
2988
-
2989
- return df.loc[keep]
2990
-
2991
-
2992
- def _get_non_optional_groups(pat: re.Pattern | str) -> list[str]:
2993
- return [
2994
- x
2995
- for x in [
2996
- _extract_group_name(group)
2997
- for group in pat.pattern.split("\n")
2998
- if group
2999
- and not group.replace(" ", "").startswith("#")
3000
- and not group.replace(" ", "").split("#")[0].endswith("?")
3001
- ]
3002
- if x is not None
3003
- ]
3004
-
3005
-
3006
- def _extract_group_name(txt: str) -> str | None:
3007
- try:
3008
- return re.search(r"\(\?P<(\w+)>", txt)[1]
3009
- except TypeError:
3010
- return None
3011
-
3012
-
3013
- def _get_first_group_match(pat: re.Pattern, text: str) -> dict[str, str]:
3014
- groups = pat.groupindex.keys()
3015
- all_matches: dict[str, str] = {}
3016
- for x in pat.findall(text):
3017
- for group, value in zip(groups, x, strict=True):
3018
- if value and group not in all_matches:
3019
- all_matches[group] = value
3020
- return all_matches
3021
-
3022
-
3023
3271
  def _date_is_within(
3024
- path,
3025
- date_ranges: (
3026
- tuple[str | None, str | None] | tuple[tuple[str | None, str | None], ...] | None
3027
- ),
3028
- image_patterns: Sequence[re.Pattern],
3029
- date_format: str,
3272
+ date: str | None,
3273
+ date_ranges: DATE_RANGES_TYPE,
3030
3274
  ) -> bool:
3031
- for pat in image_patterns:
3032
-
3033
- try:
3034
- date = _get_first_group_match(pat, Path(path).name)["date"]
3035
- break
3036
- except KeyError:
3037
- date = None
3275
+ if date_ranges is None:
3276
+ return True
3038
3277
 
3039
3278
  if date is None:
3040
3279
  return False
3041
3280
 
3042
- if date_ranges is None:
3043
- return True
3281
+ date = pd.Timestamp(date)
3044
3282
 
3045
- if all(x is None or isinstance(x, (str, float)) for x in date_ranges):
3283
+ if all(x is None or isinstance(x, str) for x in date_ranges):
3046
3284
  date_ranges = (date_ranges,)
3047
3285
 
3048
- if all(isinstance(x, float) for date_range in date_ranges for x in date_range):
3049
- date = disambiguate_timestamp(date, date_format)
3050
- else:
3051
- date = date[:8]
3052
-
3053
3286
  for date_range in date_ranges:
3054
3287
  date_min, date_max = date_range
3055
3288
 
3056
- if isinstance(date_min, float) and isinstance(date_max, float):
3057
- if date[0] >= date_min + 0.0000001 and date[1] <= date_max - 0.0000001:
3058
- return True
3059
- continue
3289
+ if date_min is not None:
3290
+ date_min = pd.Timestamp(date_min)
3291
+ if date_max is not None:
3292
+ date_max = pd.Timestamp(date_max)
3060
3293
 
3061
- try:
3062
- date_min = date_min or "00000000"
3063
- date_max = date_max or "99999999"
3064
- if not (
3065
- isinstance(date_min, str)
3066
- and len(date_min) == 8
3067
- and isinstance(date_max, str)
3068
- and len(date_max) == 8
3069
- ):
3070
- raise ValueError()
3071
- except ValueError as err:
3072
- raise TypeError(
3073
- "date_ranges should be a tuple of two 8-charactered strings (start and end date)."
3074
- f"Got {date_range} of type {[type(x) for x in date_range]}"
3075
- ) from err
3076
- if date >= date_min and date <= date_max:
3294
+ if (date_min is None or date >= date_min) and (
3295
+ date_max is None or date <= date_max
3296
+ ):
3077
3297
  return True
3078
3298
 
3079
3299
  return False
@@ -3093,10 +3313,6 @@ def _get_dtype_max(dtype: str | type) -> int | float:
3093
3313
  return np.finfo(dtype).max
3094
3314
 
3095
3315
 
3096
- def _img_ndvi(img, **kwargs):
3097
- return Image([img.ndvi(**kwargs)])
3098
-
3099
-
3100
3316
def _intesects(x, other) -> bool:
    """Check if the bounding box of x intersects the other geometry."""
    bounding_box = box(*x.bounds)
    return bounding_box.intersects(other)
3102
3318
 
@@ -3116,6 +3332,17 @@ def _copy_and_add_df_parallel(
3116
3332
  return (i, copied)
3117
3333
 
3118
3334
 
3335
def _get_bounds(bounds, bbox, band_bounds: Polygon) -> None | Polygon:
    """Combine 'bounds' and 'bbox' into one geometry.

    Returns:
        None if both are None. If only one is given, its intersection with
        'band_bounds'. If both are given, the intersection of the two
        ('band_bounds' is not used in that case).
    """
    if bounds is None:
        if bbox is None:
            return None
        return to_shapely(bbox).intersection(band_bounds)

    if bbox is None:
        return to_shapely(bounds).intersection(band_bounds)

    return to_shapely(bounds).intersection(to_shapely(bbox))
3344
+
3345
+
3119
3346
  def _get_single_value(values: tuple):
3120
3347
  if len(set(values)) == 1:
3121
3348
  return next(iter(values))
@@ -3129,7 +3356,15 @@ def _open_raster(path: str | Path) -> rasterio.io.DatasetReader:
3129
3356
 
3130
3357
 
3131
3358
  def _load_band(band: Band, **kwargs) -> None:
3132
- band.load(**kwargs)
3359
+ return band.load(**kwargs)
3360
+
3361
+
3362
+ def _band_apply(band: Band, func: Callable, **kwargs) -> None:
3363
+ return band.apply(func, **kwargs)
3364
+
3365
+
3366
+ def _clip_band(band: Band, mask, **kwargs) -> None:
3367
+ return band.clip(mask, **kwargs)
3133
3368
 
3134
3369
 
3135
3370
  def _merge_by_band(collection: ImageCollection, **kwargs) -> Image:
@@ -3141,7 +3376,7 @@ def _merge(collection: ImageCollection, **kwargs) -> Band:
3141
3376
 
3142
3377
 
3143
3378
def _zonal_one_pair(i: int, poly: Polygon, band: Band, aggfunc, array_func, func_names):
    """Aggregate the values of one band inside one polygon.

    A copy of the band is clipped to the polygon. If the clipped array is
    empty, the result of _no_overlap_df is returned. Otherwise the clipped
    values are passed to _aggregate.
    """
    clipped = band.copy().clip(poly)
    if np.size(clipped.values):
        return _aggregate(
            clipped.values, array_func, aggfunc, func_names, band.date, i
        )
    return _no_overlap_df(func_names, i, date=band.date)
@@ -3173,85 +3408,126 @@ def array_buffer(arr: np.ndarray, distance: int) -> np.ndarray:
3173
3408
  return binary_erosion(arr, structure=structure).astype(dtype)
3174
3409
 
3175
3410
 
3176
- class Sentinel2Config:
3177
- """Holder of Sentinel 2 regexes, band_ids etc."""
3178
-
3179
- image_regexes: ClassVar[str] = (config.SENTINEL2_IMAGE_REGEX,)
3180
- filename_regexes: ClassVar[str] = (
3181
- config.SENTINEL2_FILENAME_REGEX,
3182
- config.SENTINEL2_CLOUD_FILENAME_REGEX,
3183
- )
3184
- all_bands: ClassVar[list[str]] = list(config.SENTINEL2_BANDS)
3185
- rbg_bands: ClassVar[list[str]] = config.SENTINEL2_RBG_BANDS
3186
- ndvi_bands: ClassVar[list[str]] = config.SENTINEL2_NDVI_BANDS
3187
- l2a_bands: ClassVar[dict[str, int]] = config.SENTINEL2_L2A_BANDS
3188
- l1c_bands: ClassVar[dict[str, int]] = config.SENTINEL2_L1C_BANDS
3189
- date_format: ClassVar[str] = "%Y%m%d" # T%H%M%S"
3190
- masking: ClassVar[BandMasking] = BandMasking(
3191
- band_id="SCL", values=(3, 8, 9, 10, 11)
3192
- )
3193
-
3194
-
3195
- class Sentinel2CloudlessConfig(Sentinel2Config):
3196
- """Holder of regexes, band_ids etc. for Sentinel 2 cloudless mosaic."""
3411
def get_cmap(arr: np.ndarray) -> LinearSegmentedColormap:
    """Build a blue-gray-green colormap adapted to the value range of an array.

    The color at the low end is chosen from the minimum value of the array
    and the color at the high end from the maximum value. The thresholds
    lie between -0.75 and 0.75 (presumably meant for NDVI-like values).

    Args:
        arr: Array of values to be plotted with the colormap.

    Returns:
        A LinearSegmentedColormap named 'blue_gray_green' with 50 levels,
        going from blue (0.0) via gray (0.3 to 0.7) to green (1.0).
    """
    # RGB triplets for the low end, indexed by get_start.
    blue_shades = [
        [0.1, 0.1, 1.0],
        [0.2, 0.2, 0.9],
        [0.3, 0.3, 0.8],
        [0.4, 0.4, 0.7],
        [0.6, 0.6, 0.6],
        [0.6, 0.6, 0.6],
        [0.7, 0.7, 0.7],
        [0.8, 0.8, 0.8],
    ]
    # RGB triplets for the middle, also indexed by get_start.
    gray_shades = [
        [0.6, 0.6, 0.6],
        [0.6, 0.6, 0.6],
        [0.6, 0.6, 0.6],
        [0.6, 0.6, 0.6],
        [0.6, 0.6, 0.6],
        [0.4, 0.7, 0.4],
        [0.3, 0.7, 0.3],
        [0.2, 0.8, 0.2],
    ]
    # RGB triplets for the high end, indexed by get_stop.
    green_shades = [
        [0.6, 0.6, 0.6],
        [0.4, 0.7, 0.4],
        [0.3, 0.8, 0.3],
        [0.25, 0.4, 0.25],
        [0.2, 0.5, 0.2],
        [0.10, 0.7, 0.10],
        [0, 0.9, 0],
    ]

    def get_start(arr):
        """Index of the first threshold that the minimum value is below."""
        min_value = np.min(arr)
        thresholds = (-0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75)
        for index, threshold in enumerate(thresholds):
            if min_value < threshold:
                return index
        return len(thresholds)

    def get_stop(arr):
        """Index of the first threshold that the maximum value is below."""
        max_value = np.max(arr)
        # The lowest threshold is inclusive, the others are not.
        if max_value <= 0.05:
            return 0
        thresholds = (0.175, 0.25, 0.375, 0.5, 0.75)
        for index, threshold in enumerate(thresholds, start=1):
            if max_value < threshold:
                return index
        return len(thresholds) + 1

    cmap_name = "blue_gray_green"

    start = get_start(arr)
    stop = get_stop(arr)
    blue = blue_shades[start]
    gray = gray_shades[start]
    green = green_shades[stop]

    # Define the segments of the colormap: one list of
    # (position, value, value) anchors per color channel.
    cdict = {
        channel: [
            (0.0, blue[i], blue[i]),
            (0.3, gray[i], gray[i]),
            (0.7, gray[i], gray[i]),
            (1.0, green[i], green[i]),
        ]
        for i, channel in enumerate(("red", "green", "blue"))
    }

    return LinearSegmentedColormap(cmap_name, segmentdata=cdict, N=50)