ssb-sgis 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. sgis/__init__.py +20 -9
  2. sgis/debug_config.py +24 -0
  3. sgis/exceptions.py +2 -2
  4. sgis/geopandas_tools/bounds.py +33 -36
  5. sgis/geopandas_tools/buffer_dissolve_explode.py +136 -35
  6. sgis/geopandas_tools/centerlines.py +4 -91
  7. sgis/geopandas_tools/cleaning.py +1576 -583
  8. sgis/geopandas_tools/conversion.py +38 -19
  9. sgis/geopandas_tools/duplicates.py +29 -8
  10. sgis/geopandas_tools/general.py +263 -100
  11. sgis/geopandas_tools/geometry_types.py +4 -4
  12. sgis/geopandas_tools/neighbors.py +19 -15
  13. sgis/geopandas_tools/overlay.py +2 -2
  14. sgis/geopandas_tools/point_operations.py +5 -5
  15. sgis/geopandas_tools/polygon_operations.py +510 -105
  16. sgis/geopandas_tools/polygons_as_rings.py +40 -8
  17. sgis/geopandas_tools/sfilter.py +29 -12
  18. sgis/helpers.py +3 -3
  19. sgis/io/dapla_functions.py +238 -19
  20. sgis/io/read_parquet.py +1 -1
  21. sgis/maps/examine.py +27 -12
  22. sgis/maps/explore.py +450 -65
  23. sgis/maps/legend.py +177 -76
  24. sgis/maps/map.py +206 -103
  25. sgis/maps/maps.py +178 -105
  26. sgis/maps/thematicmap.py +243 -83
  27. sgis/networkanalysis/_service_area.py +6 -1
  28. sgis/networkanalysis/closing_network_holes.py +2 -2
  29. sgis/networkanalysis/cutting_lines.py +15 -8
  30. sgis/networkanalysis/directednetwork.py +1 -1
  31. sgis/networkanalysis/finding_isolated_networks.py +15 -8
  32. sgis/networkanalysis/networkanalysis.py +17 -19
  33. sgis/networkanalysis/networkanalysisrules.py +1 -1
  34. sgis/networkanalysis/traveling_salesman.py +1 -1
  35. sgis/parallel/parallel.py +64 -27
  36. sgis/raster/__init__.py +0 -6
  37. sgis/raster/base.py +208 -0
  38. sgis/raster/cube.py +54 -8
  39. sgis/raster/image_collection.py +3257 -0
  40. sgis/raster/indices.py +17 -5
  41. sgis/raster/raster.py +138 -243
  42. sgis/raster/sentinel_config.py +120 -0
  43. sgis/raster/zonal.py +0 -1
  44. {ssb_sgis-1.0.2.dist-info → ssb_sgis-1.0.4.dist-info}/METADATA +6 -7
  45. ssb_sgis-1.0.4.dist-info/RECORD +62 -0
  46. sgis/raster/methods_as_functions.py +0 -0
  47. sgis/raster/torchgeo.py +0 -171
  48. ssb_sgis-1.0.2.dist-info/RECORD +0 -61
  49. {ssb_sgis-1.0.2.dist-info → ssb_sgis-1.0.4.dist-info}/LICENSE +0 -0
  50. {ssb_sgis-1.0.2.dist-info → ssb_sgis-1.0.4.dist-info}/WHEEL +0 -0
sgis/raster/image_collection.py (new file)
@@ -0,0 +1,3257 @@
1
+ import datetime
2
+ import functools
3
+ import glob
4
+ import itertools
5
+ import math
6
+ import os
7
+ import random
8
+ import re
9
+ from collections.abc import Callable
10
+ from collections.abc import Iterable
11
+ from collections.abc import Iterator
12
+ from collections.abc import Sequence
13
+ from copy import deepcopy
14
+ from dataclasses import dataclass
15
+ from pathlib import Path
16
+ from typing import Any
17
+ from typing import ClassVar
18
+
19
+ import joblib
20
+ import matplotlib.pyplot as plt
21
+ import numpy as np
22
+ import pandas as pd
23
+ import pyproj
24
+ import rasterio
25
+ from affine import Affine
26
+ from geopandas import GeoDataFrame
27
+ from geopandas import GeoSeries
28
+ from matplotlib.colors import LinearSegmentedColormap
29
+ from rasterio.enums import MergeAlg
30
+ from rtree.index import Index
31
+ from rtree.index import Property
32
+ from scipy import stats
33
+ from scipy.ndimage import binary_dilation
34
+ from scipy.ndimage import binary_erosion
35
+ from shapely import Geometry
36
+ from shapely import box
37
+ from shapely import unary_union
38
+ from shapely.geometry import MultiPolygon
39
+ from shapely.geometry import Point
40
+ from shapely.geometry import Polygon
41
+
42
+ try:
43
+ import dapla as dp
44
+ from dapla.gcs import GCSFileSystem
45
+ except ImportError:
46
+
47
+ class GCSFileSystem:
48
+ """Placeholder."""
49
+
50
+
51
+ try:
52
+ from rioxarray.exceptions import NoDataInBounds
53
+ from rioxarray.merge import merge_arrays
54
+ from rioxarray.rioxarray import _generate_spatial_coords
55
+ except ImportError:
56
+ pass
57
+ try:
58
+ import xarray as xr
59
+ from xarray import DataArray
60
+ except ImportError:
61
+
62
+ class DataArray:
63
+ """Placeholder."""
64
+
65
+
66
+ try:
67
+ import torch
68
+ except ImportError:
69
+ pass
70
+
71
+ try:
72
+ from gcsfs.core import GCSFile
73
+ except ImportError:
74
+
75
+ class GCSFile:
76
+ """Placeholder."""
77
+
78
+
79
+ try:
80
+ from torchgeo.datasets.utils import disambiguate_timestamp
81
+ except ImportError:
82
+
83
+ class torch:
84
+ """Placeholder."""
85
+
86
+ class Tensor:
87
+ """Placeholder to reference torch.Tensor."""
88
+
89
+
90
+ try:
91
+ from torchgeo.datasets.utils import BoundingBox
92
+ except ImportError:
93
+
94
+ class BoundingBox:
95
+ """Placeholder."""
96
+
97
+ def __init__(self, *args, **kwargs) -> None:
98
+ """Placeholder."""
99
+ raise ImportError("missing optional dependency 'torchgeo'")
100
+
101
+
102
+ from ..geopandas_tools.bounds import get_total_bounds
103
+ from ..geopandas_tools.conversion import to_bbox
104
+ from ..geopandas_tools.conversion import to_gdf
105
+ from ..geopandas_tools.conversion import to_shapely
106
+ from ..geopandas_tools.general import get_common_crs
107
+ from ..helpers import get_all_files
108
+ from ..helpers import get_numpy_func
109
+ from ..io._is_dapla import is_dapla
110
+ from ..io.opener import opener
111
+ from . import sentinel_config as config
112
+ from .base import _array_to_geojson
113
+ from .base import _gdf_to_arr
114
+ from .base import _get_shape_from_bounds
115
+ from .base import _get_transform_from_bounds
116
+ from .base import get_index_mapper
117
+ from .indices import ndvi
118
+ from .zonal import _aggregate
119
+ from .zonal import _make_geometry_iterrows
120
+ from .zonal import _no_overlap_df
121
+ from .zonal import _prepare_zonal
122
+ from .zonal import _zonal_post
123
+
124
+ if is_dapla():
125
+
126
+ def _ls_func(*args, **kwargs) -> list[str]:
127
+ return dp.FileClient.get_gcs_file_system().ls(*args, **kwargs)
128
+
129
+ def _glob_func(*args, **kwargs) -> list[str]:
130
+ return dp.FileClient.get_gcs_file_system().glob(*args, **kwargs)
131
+
132
+ def _open_func(*args, **kwargs) -> GCSFile:
133
+ return dp.FileClient.get_gcs_file_system().open(*args, **kwargs)
134
+
135
+ def _rm_file_func(*args, **kwargs) -> None:
136
+ return dp.FileClient.get_gcs_file_system().rm_file(*args, **kwargs)
137
+
138
+ def _read_parquet_func(*args, **kwargs) -> list[str]:
139
+ return dp.read_pandas(*args, **kwargs)
140
+
141
+ else:
142
+ _ls_func = functools.partial(get_all_files, recursive=False)
143
+ _open_func = open
144
+ _glob_func = glob.glob
145
+ _rm_file_func = os.remove
146
+ _read_parquet_func = pd.read_parquet
147
+
148
+ TORCHGEO_RETURN_TYPE = dict[str, torch.Tensor | pyproj.CRS | BoundingBox]
149
+ FILENAME_COL_SUFFIX = "_filename"
150
+ DEFAULT_FILENAME_REGEX = r"""
151
+ .*?
152
+ (?:_(?P<date>\d{8}(?:T\d{6})?))? # Optional date group
153
+ .*?
154
+ (?:_(?P<band>B\d{1,2}A|B\d{1,2}))? # Optional band group
155
+ \.(?:tif|tiff|jp2)$ # End with .tif, .tiff, or .jp2
156
+ """
157
+ DEFAULT_IMAGE_REGEX = r"""
158
+ .*?
159
+ (?:_(?P<date>\d{8}(?:T\d{6})?))? # Optional date group
160
+ (?:_(?P<band>B\d{1,2}A|B\d{1,2}))? # Optional band group
161
+ """
162
+
163
+ ALLOWED_INIT_KWARGS = [
164
+ "image_class",
165
+ "band_class",
166
+ "image_regexes",
167
+ "filename_regexes",
168
+ "date_format",
169
+ "cloud_cover_regexes",
170
+ "bounds_regexes",
171
+ "all_bands",
172
+ "crs",
173
+ "masking",
174
+ "_merged",
175
+ ]
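# Keyword arguments outside this allow-list are rejected in _ImageBase.__init__
# (see the loop over kwargs further down), so typos surface immediately. Sketch
# with a hypothetical path; the error is raised before any file is opened.
#
#   >>> Band("data/b02.tif", res=10, resolution=10)
#   ... # raises ValueError: Band got an unexpected keyword argument 'resolution'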
176
+
177
+
178
+ class ImageCollectionGroupBy:
179
+ """Iterator and merger class returned from groupby.
180
+
181
+ Can be iterated through like pandas.DataFrameGroupBy.
182
+ Or use the methods merge_by_band or merge.
183
+ """
184
+
185
+ def __init__(
186
+ self,
187
+ data: Iterable[tuple[tuple[Any, ...], "ImageCollection"]],
188
+ by: list[str],
189
+ collection: "ImageCollection",
190
+ ) -> None:
191
+ """Initialiser.
192
+
193
+ Args:
194
+ data: Iterable of group values and ImageCollection groups.
195
+ by: list of group attributes.
196
+ collection: ImageCollection instance. Used to pass attributes.
197
+ """
198
+ self.data = list(data)
199
+ self.by = by
200
+ self.collection = collection
201
+
202
+ def merge_by_band(
203
+ self,
204
+ bounds: tuple | Geometry | GeoDataFrame | GeoSeries | None = None,
205
+ method: str | Callable = "mean",
206
+ as_int: bool = True,
207
+ indexes: int | tuple[int] | None = None,
208
+ **kwargs,
209
+ ) -> "ImageCollection":
210
+ """Merge each group into separate Bands per band_id, returned as an ImageCollection."""
211
+ images = self._run_func_for_collection_groups(
212
+ _merge_by_band,
213
+ method=method,
214
+ bounds=bounds,
215
+ as_int=as_int,
216
+ indexes=indexes,
217
+ **kwargs,
218
+ )
219
+ for img, (group_values, _) in zip(images, self.data, strict=True):
220
+ for attr, group_value in zip(self.by, group_values, strict=True):
221
+ try:
222
+ setattr(img, attr, group_value)
223
+ except AttributeError:
224
+ setattr(img, f"_{attr}", group_value)
225
+
226
+ collection = ImageCollection(
227
+ images,
228
+ # TODO band_class?
229
+ level=self.collection.level,
230
+ **self.collection._common_init_kwargs,
231
+ )
232
+ collection._merged = True
233
+ return collection
234
+
235
+ def merge(
236
+ self,
237
+ bounds: tuple | Geometry | GeoDataFrame | GeoSeries | None = None,
238
+ method: str | Callable = "mean",
239
+ as_int: bool = True,
240
+ indexes: int | tuple[int] | None = None,
241
+ **kwargs,
242
+ ) -> "Image":
243
+ """Merge each group into a single Band, returned as combined Image."""
244
+ bands: list[Band] = self._run_func_for_collection_groups(
245
+ _merge,
246
+ method=method,
247
+ bounds=bounds,
248
+ as_int=as_int,
249
+ indexes=indexes,
250
+ **kwargs,
251
+ )
252
+ for band, (group_values, _) in zip(bands, self.data, strict=True):
253
+ for attr, group_value in zip(self.by, group_values, strict=True):
254
+ try:
255
+ setattr(band, attr, group_value)
256
+ except AttributeError:
257
+ if hasattr(band, f"_{attr}"):
258
+ setattr(band, f"_{attr}", group_value)
259
+
260
+ if "band_id" in self.by:
261
+ for band in bands:
262
+ assert band.band_id is not None
263
+
264
+ image = Image(
265
+ bands,
266
+ # TODO band_class?
267
+ **self.collection._common_init_kwargs,
268
+ )
269
+ image._merged = True
270
+ return image
271
+
272
+ def _run_func_for_collection_groups(self, func: Callable, **kwargs) -> list[Any]:
273
+ if self.collection.processes == 1:
274
+ return [func(group, **kwargs) for _, group in self]
275
+ processes = min(self.collection.processes, len(self))
276
+
277
+ if processes == 0:
278
+ return []
279
+
280
+ with joblib.Parallel(n_jobs=processes, backend="threading") as parallel:
281
+ return parallel(joblib.delayed(func)(group, **kwargs) for _, group in self)
282
+
283
+ def __iter__(self) -> Iterator[tuple[tuple[Any, ...], "ImageCollection"]]:
284
+ """Iterate over the group values and the ImageCollection groups themselves."""
285
+ return iter(self.data)
286
+
287
+ def __len__(self) -> int:
288
+ """Number of ImageCollection groups."""
289
+ return len(self.data)
290
+
291
+ def __repr__(self) -> str:
292
+ """String representation."""
293
+ return f"{self.__class__.__name__}({len(self)})"
294
+
295
+
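# Usage sketch for the groupby/merge API above; the collection path, resolution
# and group keys are hypothetical. ImageCollection.groupby (further down in this
# file) is what constructs this class.
#
#   >>> collection = ImageCollection("data/sentinel2", res=10, level=None)
#   >>> for group_values, group in collection.groupby("date"):
#   ...     ...  # each group is an ImageCollection holding one date
#   >>> composites = collection.groupby("date").merge_by_band(method="median")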
296
+ @dataclass(frozen=True)
297
+ class BandMasking:
298
+ """Basically a frozen dict with forced keys."""
299
+
300
+ band_id: str
301
+ values: tuple[int]
302
+
303
+ def __getitem__(self, item: str) -> Any:
304
+ """Index into attributes to mimick dict."""
305
+ return getattr(self, item)
306
+
307
+
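# Sketch of how a BandMasking is meant to be declared on a subclass. The band id
# and values here are hypothetical, not the package's Sentinel-2 configuration.
#
#   >>> masking = BandMasking(band_id="SCL", values=(3, 8, 9, 10, 11))
#   >>> masking["band_id"]
#   'SCL'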
308
+ class _ImageBase:
309
+ image_regexes: ClassVar[str | tuple[str] | None] = (DEFAULT_IMAGE_REGEX,)
310
+ filename_regexes: ClassVar[str | tuple[str]] = (DEFAULT_FILENAME_REGEX,)
311
+ date_format: ClassVar[str] = "%Y%m%d" # T%H%M%S"
312
+ masking: ClassVar[BandMasking | None] = None
313
+
314
+ def __init__(self, **kwargs) -> None:
315
+
316
+ self._mask = None
317
+ self._bounds = None
318
+ self._merged = False
319
+ self._from_array = False
320
+ self._from_gdf = False
321
+
322
+ if self.filename_regexes:
323
+ if isinstance(self.filename_regexes, str):
324
+ self.filename_regexes = (self.filename_regexes,)
325
+ self.filename_patterns = [
326
+ re.compile(regexes, flags=re.VERBOSE)
327
+ for regexes in self.filename_regexes
328
+ ]
329
+ else:
330
+ self.filename_patterns = ()
331
+
332
+ if self.image_regexes:
333
+ if isinstance(self.image_regexes, str):
334
+ self.image_regexes = (self.image_regexes,)
335
+ self.image_patterns = [
336
+ re.compile(regexes, flags=re.VERBOSE) for regexes in self.image_regexes
337
+ ]
338
+ else:
339
+ self.image_patterns = ()
340
+
341
+ for key, value in kwargs.items():
342
+ if key in ALLOWED_INIT_KWARGS and key in dir(self):
343
+ setattr(self, key, value)
344
+ else:
345
+ raise ValueError(
346
+ f"{self.__class__.__name__} got an unexpected keyword argument '{key}'"
347
+ )
348
+
349
+ @property
350
+ def _common_init_kwargs(self) -> dict:
351
+ return {
352
+ "file_system": self.file_system,
353
+ "processes": self.processes,
354
+ "res": self.res,
355
+ "bbox": self._bbox,
356
+ "nodata": self.nodata,
357
+ }
358
+
359
+ @property
360
+ def path(self) -> str:
361
+ try:
362
+ return self._path
363
+ except AttributeError as e:
364
+ raise PathlessImageError(self) from e
365
+
366
+ @property
367
+ def res(self) -> int:
368
+ """Pixel resolution."""
369
+ return self._res
370
+
371
+ @property
372
+ def centroid(self) -> Point:
373
+ """Centerpoint of the object."""
374
+ return self.union_all().centroid
375
+
376
+ def _name_regex_searcher(
377
+ self, group: str, patterns: tuple[re.Pattern]
378
+ ) -> str | None:
379
+ if not patterns or not any(pat.groups for pat in patterns):
380
+ return None
381
+ for pat in patterns:
382
+ try:
383
+ return _get_first_group_match(pat, self.name)[group]
385
+ except (TypeError, KeyError):
386
+ pass
387
+ if not any(group in _get_non_optional_groups(pat) for pat in patterns):
388
+ return None
389
+ raise ValueError(
390
+ f"Couldn't find group '{group}' in name {self.name} with regex patterns {patterns}"
391
+ )
392
+
393
+ def _create_metadata_df(self, file_paths: list[str]) -> pd.DataFrame:
394
+ """Create a dataframe with file paths and image paths that match regexes."""
395
+ df = pd.DataFrame({"file_path": file_paths})
396
+
397
+ df["filename"] = df["file_path"].apply(lambda x: _fix_path(Path(x).name))
398
+
399
+ if not self.single_banded:
400
+ df["image_path"] = df["file_path"].apply(
401
+ lambda x: _fix_path(str(Path(x).parent))
402
+ )
403
+ else:
404
+ df["image_path"] = df["file_path"]
405
+
406
+ if not len(df):
407
+ return df
408
+
409
+ if self.filename_patterns:
410
+ df = _get_regexes_matches_for_df(df, "filename", self.filename_patterns)
411
+
412
+ if not len(df):
413
+ return df
414
+
415
+ grouped = df.drop_duplicates("image_path").set_index("image_path")
416
+ for col in ["file_path", "filename"]:
417
+ if col in df:
418
+ grouped[col] = df.groupby("image_path")[col].apply(tuple)
419
+
420
+ grouped = grouped.reset_index()
421
+ else:
422
+ df["file_path"] = df.groupby("image_path")["file_path"].apply(tuple)
423
+ df["filename"] = df.groupby("image_path")["filename"].apply(tuple)
424
+ grouped = df.drop_duplicates("image_path")
425
+
426
+ grouped["imagename"] = grouped["image_path"].apply(
427
+ lambda x: _fix_path(Path(x).name)
428
+ )
429
+
430
+ if self.image_patterns and len(grouped):
431
+ grouped = _get_regexes_matches_for_df(
432
+ grouped, "imagename", self.image_patterns
433
+ )
434
+
435
+ return grouped
436
+
437
+ def copy(self) -> "_ImageBase":
438
+ """Copy the instance and its attributes."""
439
+ copied = deepcopy(self)
440
+ for key, value in copied.__dict__.items():
441
+ try:
442
+ setattr(copied, key, value.copy())
443
+ except AttributeError:
444
+ setattr(copied, key, deepcopy(value))
445
+ except TypeError:
446
+ continue
447
+ return copied
448
+
449
+
450
+ class _ImageBandBase(_ImageBase):
451
+ def intersects(self, other: GeoDataFrame | GeoSeries | Geometry) -> bool:
452
+ if hasattr(other, "crs") and not pyproj.CRS(self.crs).equals(
453
+ pyproj.CRS(other.crs)
454
+ ):
455
+ raise ValueError(f"crs mismatch: {self.crs} and {other.crs}")
456
+ return self.union_all().intersects(to_shapely(other))
457
+
458
+ @property
459
+ def mask_percentage(self) -> float:
460
+ return self.mask.values.sum() / (self.mask.width * self.mask.height) * 100
461
+
462
+ @property
463
+ def year(self) -> str:
464
+ if hasattr(self, "_year") and self._year:
465
+ return self._year
466
+ return self.date[:4]
467
+
468
+ @property
469
+ def month(self) -> str:
470
+ if hasattr(self, "_month") and self._month:
471
+ return self._month
472
+ return "".join(self.date.split("-"))[4:6]
473
+
474
+ @property
475
+ def name(self) -> str | None:
476
+ if hasattr(self, "_name") and self._name is not None:
477
+ return self._name
478
+ try:
479
+ return Path(self.path).name
480
+ except (ValueError, AttributeError):
481
+ return None
482
+
483
+ @name.setter
484
+ def name(self, value) -> None:
485
+ self._name = value
486
+
487
+ @property
488
+ def stem(self) -> str | None:
489
+ try:
490
+ return Path(self.path).stem
491
+ except (AttributeError, ValueError):
492
+ return None
493
+
494
+ @property
495
+ def level(self) -> str:
496
+ return self._name_regex_searcher("level", self.image_patterns)
497
+
498
+ @property
499
+ def mint(self) -> float:
500
+ return disambiguate_timestamp(self.date, self.date_format)[0]
501
+
502
+ @property
503
+ def maxt(self) -> float:
504
+ return disambiguate_timestamp(self.date, self.date_format)[1]
505
+
506
+ def union_all(self) -> Polygon:
507
+ try:
508
+ return box(*self.bounds)
509
+ except TypeError:
510
+ return Polygon()
511
+
512
+ @property
513
+ def torch_bbox(self) -> BoundingBox:
514
+ bounds = GeoSeries([self.union_all()]).bounds
515
+ return BoundingBox(
516
+ minx=bounds.minx[0],
517
+ miny=bounds.miny[0],
518
+ maxx=bounds.maxx[0],
519
+ maxy=bounds.maxy[0],
520
+ mint=self.mint,
521
+ maxt=self.maxt,
522
+ )
523
+
524
+
525
+ class Band(_ImageBandBase):
526
+ """Band holding a single 2 dimensional array representing an image band."""
527
+
528
+ cmap: ClassVar[str | None] = None
529
+
530
+ @classmethod
531
+ def from_gdf(
532
+ cls,
533
+ gdf: GeoDataFrame | GeoSeries,
534
+ res: int,
535
+ *,
536
+ fill: int = 0,
537
+ all_touched: bool = False,
538
+ merge_alg: Callable = MergeAlg.replace,
539
+ default_value: int = 1,
540
+ dtype: Any | None = None,
541
+ **kwargs,
542
+ ) -> "Band":
543
+ """Create Band from a GeoDataFrame."""
544
+ arr: np.ndarray = _gdf_to_arr(
545
+ gdf,
546
+ res=res,
547
+ fill=fill,
548
+ all_touched=all_touched,
549
+ merge_alg=merge_alg,
550
+ default_value=default_value,
551
+ dtype=dtype,
552
+ )
553
+
554
+ obj = cls(arr, res=res, crs=gdf.crs, bounds=gdf.total_bounds, **kwargs)
555
+ obj._from_gdf = True
556
+ return obj
557
+
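# Sketch of building a Band from vector data with from_gdf. The geometry and CRS
# are hypothetical, and the resulting array shape depends on _gdf_to_arr (defined
# in .base, not shown in this hunk).
#
#   >>> from geopandas import GeoDataFrame
#   >>> from shapely.geometry import box
#   >>> gdf = GeoDataFrame({"geometry": [box(0, 0, 100, 100)]}, crs=25833)
#   >>> band = Band.from_gdf(gdf, res=10)
#   >>> band.values  # numpy array rasterized from the polygon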
558
+ def __init__(
559
+ self,
560
+ data: str | np.ndarray,
561
+ res: int | None,
562
+ crs: Any | None = None,
563
+ bounds: tuple[float, float, float, float] | None = None,
564
+ cmap: str | None = None,
565
+ name: str | None = None,
566
+ file_system: GCSFileSystem | None = None,
567
+ band_id: str | None = None,
568
+ processes: int = 1,
569
+ bbox: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None = None,
570
+ mask: "Band | None" = None,
571
+ nodata: int | None = None,
572
+ **kwargs,
573
+ ) -> None:
574
+ """Band initialiser."""
575
+ super().__init__(**kwargs)
576
+
577
+ self._mask = mask
578
+ self._bbox = to_bbox(bbox) if bbox is not None else None
579
+ self._values = None
580
+ self._crs = None
581
+ self.nodata = nodata
582
+
583
+ bounds = to_bbox(bounds) if bounds is not None else None
584
+
585
+ self._bounds = bounds
586
+
587
+ if isinstance(data, np.ndarray):
588
+ self.values = data
589
+ if self._bounds is None:
590
+ raise ValueError("Must specify bounds when data is an array.")
591
+ self._crs = crs
592
+ self.transform = _get_transform_from_bounds(
593
+ self._bounds, shape=self.values.shape
594
+ )
595
+ self._from_array = True
596
+
597
+ elif not isinstance(data, (str | Path | os.PathLike)):
598
+ raise TypeError(
599
+ "'data' must be string, Path-like or numpy.ndarray. "
600
+ f"Got {type(data)}"
601
+ )
602
+ else:
603
+ self._path = str(data)
604
+
605
+ self._res = res
606
+ if cmap is not None:
607
+ self.cmap = cmap
608
+ self.file_system = file_system
609
+ self._name = name
610
+ self._band_id = band_id
611
+ self.processes = processes
612
+
613
+ # if self.filename_regexes:
614
+ # if isinstance(self.filename_regexes, str):
615
+ # self.filename_regexes = [self.filename_regexes]
616
+ # self.filename_patterns = [
617
+ # re.compile(pat, flags=re.VERBOSE) for pat in self.filename_regexes
618
+ # ]
619
+ # else:
620
+ # self.filename_patterns = None
621
+
622
+ def __lt__(self, other: "Band") -> bool:
623
+ """Makes Bands sortable by band_id."""
624
+ return self.band_id < other.band_id
625
+
626
+ @property
627
+ def values(self) -> np.ndarray:
628
+ """The numpy array, if loaded."""
629
+ if self._values is None:
630
+ raise ArrayNotLoadedError("array is not loaded.")
631
+ return self._values
632
+
633
+ @values.setter
634
+ def values(self, new_val):
635
+ if not isinstance(new_val, np.ndarray):
636
+ raise TypeError(
637
+ f"{self.__class__.__name__} 'values' must be np.ndarray. Got {type(new_val)}"
638
+ )
639
+ self._values = new_val
640
+
641
+ @property
642
+ def mask(self) -> "Band":
643
+ """Mask Band."""
644
+ return self._mask
645
+
646
+ @mask.setter
647
+ def mask(self, values: "Band") -> None:
648
+ if values is not None and not isinstance(values, Band):
649
+ raise TypeError(f"'mask' should be of type Band. Got {type(values)}")
650
+ self._mask = values
651
+
652
+ @property
653
+ def band_id(self) -> str:
654
+ """Band id."""
655
+ if self._band_id is not None:
656
+ return self._band_id
657
+ return self._name_regex_searcher("band", self.filename_patterns)
658
+
659
+ @property
660
+ def height(self) -> int:
661
+ """Pixel heigth of the image band."""
662
+ return self.values.shape[-2]
663
+
664
+ @property
665
+ def width(self) -> int:
666
+ """Pixel width of the image band."""
667
+ return self.values.shape[-1]
668
+
669
+ @property
670
+ def tile(self) -> str:
671
+ """Tile name from filename_regex."""
672
+ if hasattr(self, "_tile") and self._tile:
673
+ return self._tile
674
+ return self._name_regex_searcher(
675
+ "tile", self.filename_patterns + self.image_patterns
676
+ )
677
+
678
+ @property
679
+ def date(self) -> str:
680
+ """Tile name from filename_regex."""
681
+ if hasattr(self, "_date") and self._date:
682
+ return self._date
683
+
684
+ return self._name_regex_searcher(
685
+ "date", self.filename_patterns + self.image_patterns
686
+ )
687
+
688
+ @property
689
+ def crs(self) -> str | None:
690
+ """Coordinate reference system."""
691
+ if self._crs is not None:
692
+ return self._crs
693
+ with opener(self.path, file_system=self.file_system) as file:
694
+ with rasterio.open(file) as src:
695
+ # self._bounds = to_bbox(src.bounds)
696
+ self._crs = src.crs
697
+ return self._crs
698
+
699
+ @property
700
+ def bounds(self) -> tuple[int, int, int, int] | None:
701
+ """Bounds as tuple (minx, miny, maxx, maxy)."""
702
+ if self._bounds is not None:
703
+ return self._bounds
704
+ with opener(self.path, file_system=self.file_system) as file:
705
+ with rasterio.open(file) as src:
706
+ self._bounds = to_bbox(src.bounds)
707
+ self._crs = src.crs
708
+ return self._bounds
709
+
710
+ def get_n_largest(
711
+ self, n: int, precision: float = 0.000001, column: str = "value"
712
+ ) -> GeoDataFrame:
713
+ """Get the largest values of the array as polygons in a GeoDataFrame."""
714
+ copied = self.copy()
715
+ value_must_be_at_least = np.sort(np.ravel(copied.values))[-n] - (precision or 0)
716
+ copied._values = np.where(copied.values >= value_must_be_at_least, 1, 0)
717
+ df = copied.to_gdf(column).loc[lambda x: x[column] == 1]
718
+ df[column] = f"largest_{n}"
719
+ return df
720
+
721
+ def get_n_smallest(
722
+ self, n: int, precision: float = 0.000001, column: str = "value"
723
+ ) -> GeoDataFrame:
724
+ """Get the lowest values of the array as polygons in a GeoDataFrame."""
725
+ copied = self.copy()
726
+ value_must_be_at_least = np.sort(np.ravel(copied.values))[n] - (precision or 0)
727
+ copied._values = np.where(copied.values <= value_must_be_at_least, 1, 0)
728
+ df = copied.to_gdf(column).loc[lambda x: x[column] == 1]
729
+ df[column] = f"smallest_{n}"
730
+ return df
731
+
732
+ def load(
733
+ self,
734
+ bounds: tuple | Geometry | GeoDataFrame | GeoSeries | None = None,
735
+ indexes: int | tuple[int] | None = None,
736
+ masked: bool | None = None,
737
+ **kwargs,
738
+ ) -> "Band":
739
+ """Load and potentially clip the array.
740
+
741
+ The array is stored in the 'values' property.
742
+ """
743
+ if masked is None:
744
+ masked = self.mask is None
745
+
746
+ bounds_was_none = bounds is None
747
+
748
+ try:
749
+ if not isinstance(self.values, np.ndarray):
750
+ raise ValueError()
751
+ has_array = True
752
+ except ValueError: # also catches ArrayNotLoadedError
753
+ has_array = False
754
+
755
+ # get common bounds of function argument 'bounds' and previously set bbox
756
+ if bounds is None and self._bbox is None:
757
+ bounds = None
758
+ elif bounds is not None and self._bbox is None:
759
+ bounds = to_shapely(bounds).intersection(self.union_all())
760
+ elif bounds is None and self._bbox is not None:
761
+ bounds = to_shapely(self._bbox).intersection(self.union_all())
762
+ else:
763
+ bounds = to_shapely(bounds).intersection(to_shapely(self._bbox))
764
+
765
+ should_return_empty: bool = bounds is not None and bounds.area == 0
766
+ if should_return_empty:
767
+ self._values = np.array([])
768
+ if self.mask is not None and not self.is_mask:
769
+ self._mask = self._mask.load()
770
+ # self._mask = np.ma.array([], [])
771
+ self._bounds = None
772
+ self.transform = None
773
+ return self
774
+
775
+ if has_array and bounds_was_none:
776
+ return self
777
+
778
+ # round down/up to integer to avoid precision trouble
779
+ if bounds is not None:
780
+ # bounds = to_bbox(bounds)
781
+ minx, miny, maxx, maxy = to_bbox(bounds)
782
+ bounds = (int(minx), int(miny), math.ceil(maxx), math.ceil(maxy))
783
+
784
+ boundless = False
785
+
786
+ if indexes is None:
787
+ indexes = 1
788
+
789
+ # as tuple to ensure we get 3d array
790
+ _indexes: tuple[int] = (indexes,) if isinstance(indexes, int) else indexes
791
+
792
+ # allow setting a fixed out_shape for the array, in order to make mask same shape as values
793
+ out_shape = kwargs.pop("out_shape", None)
794
+
795
+ if has_array:
796
+ self.values = _clip_loaded_array(
797
+ self.values, bounds, self.transform, self.crs, out_shape, **kwargs
798
+ )
799
+ self._bounds = bounds
800
+ self.transform = _get_transform_from_bounds(self._bounds, self.values.shape)
801
+
802
+ else:
803
+ with opener(self.path, file_system=self.file_system) as f:
804
+ with rasterio.open(f, nodata=self.nodata) as src:
805
+ self._res = int(src.res[0]) if not self.res else self.res
806
+
807
+ if self.nodata is None or np.isnan(self.nodata):
808
+ self.nodata = src.nodata
809
+ else:
810
+ dtype_min_value = _get_dtype_min(src.dtypes[0])
811
+ dtype_max_value = _get_dtype_max(src.dtypes[0])
812
+ if (
813
+ self.nodata > dtype_max_value
814
+ or self.nodata < dtype_min_value
815
+ ):
816
+ src._dtypes = tuple(
817
+ rasterio.dtypes.get_minimum_dtype(self.nodata)
818
+ for _ in range(len(_indexes))
819
+ )
820
+
821
+ if bounds is None:
822
+ if self._res != int(src.res[0]):
823
+ if out_shape is None:
824
+ out_shape = _get_shape_from_bounds(
825
+ to_bbox(src.bounds), self.res, indexes
826
+ )
827
+ self.transform = _get_transform_from_bounds(
828
+ to_bbox(src.bounds), shape=out_shape
829
+ )
830
+ else:
831
+ self.transform = src.transform
832
+
833
+ self._values = src.read(
834
+ indexes=indexes,
835
+ out_shape=out_shape,
836
+ masked=masked,
837
+ **kwargs,
838
+ )
839
+ else:
840
+ window = rasterio.windows.from_bounds(
841
+ *bounds, transform=src.transform
842
+ )
843
+
844
+ if out_shape is None:
845
+ out_shape = _get_shape_from_bounds(
846
+ bounds, self.res, indexes
847
+ )
848
+
849
+ self._values = src.read(
850
+ indexes=indexes,
851
+ window=window,
852
+ boundless=boundless,
853
+ out_shape=out_shape,
854
+ masked=masked,
855
+ **kwargs,
856
+ )
857
+
858
+ assert out_shape == self._values.shape, (
859
+ out_shape,
860
+ self._values.shape,
861
+ )
862
+
863
+ self.transform = rasterio.transform.from_bounds(
864
+ *bounds, self.width, self.height
865
+ )
866
+ self._bounds = bounds
867
+
868
+ if self.nodata is not None and not np.isnan(self.nodata):
869
+ if isinstance(self.values, np.ma.core.MaskedArray):
870
+ self.values.data[self.values.data == src.nodata] = (
871
+ self.nodata
872
+ )
873
+ else:
874
+ self.values[self.values == src.nodata] = self.nodata
875
+
876
+ if self.masking and self.is_mask:
877
+ self.values = np.isin(self.values, self.masking["values"])
878
+
879
+ elif self.mask is not None and not isinstance(
880
+ self.values, np.ma.core.MaskedArray
881
+ ):
882
+ self.mask = self.mask.copy().load(
883
+ bounds=bounds, indexes=indexes, out_shape=out_shape, **kwargs
884
+ )
885
+ mask_arr = self.mask.values
886
+
887
+ # if self.masking:
888
+ # mask_arr = np.isin(mask_arr, self.masking["values"])
889
+
890
+ self._values = np.ma.array(
891
+ self._values, mask=mask_arr, fill_value=self.nodata
892
+ )
893
+
894
+ return self
895
+
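# Sketch of lazy loading: a Band created from a path reads nothing until load()
# is called, optionally clipped to bounds. Path and bounds are hypothetical.
#
#   >>> band = Band("data/dem.tif", res=10)
#   >>> band = band.load(bounds=(560000, 6650000, 561000, 6651000))
#   >>> band.values  # numpy (possibly masked) array clipped to the bounds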
896
+ @property
897
+ def is_mask(self) -> bool:
898
+ """True if the band_id is equal to the masking band_id."""
899
+ return self.band_id == self.masking["band_id"]
900
+
901
+ def write(
902
+ self, path: str | Path, driver: str = "GTiff", compress: str = "LZW", **kwargs
903
+ ) -> None:
904
+ """Write the array as an image file."""
905
+ if not hasattr(self, "_values"):
906
+ raise ValueError(
907
+ "Can only write image band from Band constructed from array."
908
+ )
909
+
910
+ if self.crs is None:
911
+ raise ValueError("Cannot write None crs to image.")
912
+
913
+ profile = {
914
+ "driver": driver,
915
+ "compress": compress,
916
+ "dtype": rasterio.dtypes.get_minimum_dtype(self.values),
917
+ "crs": self.crs,
918
+ "transform": self.transform,
919
+ "nodata": self.nodata,
920
+ "count": 1 if len(self.values.shape) == 2 else self.values.shape[0],
921
+ "height": self.height,
922
+ "width": self.width,
923
+ } | kwargs
924
+
925
+ with opener(path, "wb", file_system=self.file_system) as f:
926
+ with rasterio.open(f, "w", **profile) as dst:
927
+
928
+ if dst.nodata is None:
929
+ dst.nodata = _get_dtype_min(dst.dtypes[0])
930
+
931
+ # if (
932
+ # isinstance(self.values, np.ma.core.MaskedArray)
933
+ # # and dst.nodata is not None
934
+ # ):
935
+ # self.values.data[np.isnan(self.values.data)] = dst.nodata
936
+ # self.values.data[self.values.mask] = dst.nodata
937
+
938
+ if len(self.values.shape) == 2:
939
+ dst.write(self.values, indexes=1)
940
+ else:
941
+ for i in range(self.values.shape[0]):
942
+ dst.write(self.values[i], indexes=i + 1)
943
+
944
+ if isinstance(self.values, np.ma.core.MaskedArray):
945
+ dst.write_mask(self.values.mask)
946
+
947
+ self._path = str(path)
948
+
949
+ def apply(self, func: Callable, **kwargs) -> "Band":
950
+ """Apply a function to the array."""
951
+ self.values = func(self.values, **kwargs)
952
+ return self
953
+
954
+ def normalize(self) -> "Band":
955
+ """Normalize array values between 0 and 1."""
956
+ arr = self.values
957
+ self.values = (arr - np.min(arr)) / (np.max(arr) - np.min(arr))
958
+ return self
959
+
960
+ def sample(self, size: int = 1000, mask: Any = None, **kwargs) -> "Band":
961
+ """Take a random spatial sample area of the Band."""
962
+ copied = self.copy()
963
+ if mask is not None:
964
+ point = GeoSeries([copied.union_all()]).clip(mask).sample_points(1)
965
+ else:
966
+ point = GeoSeries([copied.union_all()]).sample_points(1)
967
+ buffered = point.buffer(size / 2).clip(copied.union_all())
968
+ copied = copied.load(bounds=buffered.total_bounds, **kwargs)
969
+ return copied
970
+
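# Sketch of sample: a random point inside the band bounds is buffered by size / 2,
# so roughly a size-by-size square is loaded (assuming a metric CRS).
#
#   >>> small = band.sample(size=500)  # Band clipped to a random ~500 m square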
971
+ def buffer(self, distance: int, copy: bool = True) -> "Band":
972
+ """Buffer array points with the value 1 in a binary array.
973
+
974
+ Args:
975
+ distance: Number of array cells to buffer by.
976
+ copy: Whether to copy the Band.
977
+
978
+ Returns:
979
+ Band with buffered values.
980
+ """
981
+ copied = self.copy() if copy else self
982
+ copied.values = array_buffer(copied.values, distance)
983
+ return copied
984
+
985
+ def gradient(self, degrees: bool = False, copy: bool = True) -> "Band":
986
+ """Get the slope of an elevation band.
987
+
988
+ Calculates the absolute slope between the grid cells
989
+ based on the image resolution.
990
+
991
+ For multi-band images, the calculation is done for each band.
992
+
993
+ Args:
995
+ degrees: If False (default), the returned values will be in ratios,
996
+ where a value of 1 means 1 meter up per 1 meter forward. If True,
997
+ the values will be in degrees from 0 to 90.
998
+ copy: Whether to copy or overwrite the original Band.
999
+ Defaults to True.
1000
+
1001
+ Returns:
1002
+ The class instance with new array values, or a copy if copy is True.
1003
+
1004
+ Examples:
1005
+ ---------
1006
+ Making an array where the elevation rises by 10 for each cell towards the center.
1007
+
1008
+ >>> import sgis as sg
1009
+ >>> import numpy as np
1010
+ >>> arr = np.array(
1011
+ ... [
1012
+ ... [100, 100, 100, 100, 100],
1013
+ ... [100, 110, 110, 110, 100],
1014
+ ... [100, 110, 120, 110, 100],
1015
+ ... [100, 110, 110, 110, 100],
1016
+ ... [100, 100, 100, 100, 100],
1017
+ ... ]
1018
+ ... )
1019
+
1020
+ Now let's create a Band from this array with a resolution of 10.
1021
+
1022
+ >>> band = sg.Band(arr, crs=None, bounds=(0, 0, 50, 50), res=10)
1023
+
1024
+ The gradient will be 1 (1 meter up for every meter forward).
1025
+ A copy is returned by default; pass copy=False to modify the values in place.
1026
+
1027
+ >>> band = band.gradient()
1028
+ >>> band.values
1029
+ array([[0., 1., 1., 1., 0.],
1030
+ [1., 1., 1., 1., 1.],
1031
+ [1., 1., 0., 1., 1.],
1032
+ [1., 1., 1., 1., 1.],
1033
+ [0., 1., 1., 1., 0.]])
1034
+ """
1035
+ copied = self.copy() if copy else self
1036
+ copied._values = _get_gradient(copied, degrees=degrees, copy=copy)
1037
+ return copied
1038
+
1039
+ def zonal(
1040
+ self,
1041
+ polygons: GeoDataFrame,
1042
+ aggfunc: str | Callable | list[Callable | str],
1043
+ array_func: Callable | None = None,
1044
+ dropna: bool = True,
1045
+ ) -> GeoDataFrame:
1046
+ """Calculate zonal statistics in polygons.
1047
+
1048
+ Args:
1049
+ polygons: A GeoDataFrame of polygon geometries.
1050
+ aggfunc: Function(s) of which to aggregate the values
1051
+ within each polygon.
1052
+ array_func: Optional calculation of the raster
1053
+ array before calculating the zonal statistics.
1054
+ dropna: If True (default), polygons with all missing
1055
+ values will be removed.
1056
+
1057
+ Returns:
1058
+ A GeoDataFrame with aggregated values per polygon.
1059
+ """
1060
+ idx_mapper, idx_name = get_index_mapper(polygons)
1061
+ polygons, aggfunc, func_names = _prepare_zonal(polygons, aggfunc)
1062
+ poly_iter = _make_geometry_iterrows(polygons)
1063
+
1064
+ kwargs = {
1065
+ "band": self,
1066
+ "aggfunc": aggfunc,
1067
+ "array_func": array_func,
1068
+ "func_names": func_names,
1069
+ }
1070
+
1071
+ if self.processes == 1:
1072
+ aggregated = [_zonal_one_pair(i, poly, **kwargs) for i, poly in poly_iter]
1073
+ else:
1074
+ with joblib.Parallel(n_jobs=self.processes, backend="loky") as parallel:
1075
+ aggregated = parallel(
1076
+ joblib.delayed(_zonal_one_pair)(i, poly, **kwargs)
1077
+ for i, poly in poly_iter
1078
+ )
1079
+
1080
+ return _zonal_post(
1081
+ aggregated,
1082
+ polygons=polygons,
1083
+ idx_mapper=idx_mapper,
1084
+ idx_name=idx_name,
1085
+ dropna=dropna,
1086
+ )
1087
+
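# Zonal-statistics sketch; polygons is a hypothetical GeoDataFrame in the same
# CRS as the band, and the band must be loaded first. Output columns are named by
# _prepare_zonal/_zonal_post (defined in .zonal, not shown in this hunk).
#
#   >>> stats = band.load().zonal(polygons, aggfunc=["mean", "sum"])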
1088
+ def to_gdf(self, column: str = "value") -> GeoDataFrame:
1089
+ """Create a GeoDataFrame from the image Band.
1090
+
1091
+ Args:
1092
+ column: Name of resulting column that holds the raster values.
1093
+
1094
+ Returns:
1095
+ A GeoDataFrame with a geometry column and array values.
1096
+ """
1097
+ if not hasattr(self, "_values"):
1098
+ raise ValueError("Array is not loaded.")
1099
+
1100
+ if self.values.shape[0] == 0:
1101
+ return GeoDataFrame({"geometry": []}, crs=self.crs)
1102
+
1103
+ return GeoDataFrame(
1104
+ pd.DataFrame(
1105
+ _array_to_geojson(
1106
+ self.values, self.transform, processes=self.processes
1107
+ ),
1108
+ columns=[column, "geometry"],
1109
+ ),
1110
+ geometry="geometry",
1111
+ crs=self.crs,
1112
+ )
1113
+
1114
+ def to_xarray(self) -> DataArray:
1115
+ """Convert the raster to an xarray.DataArray."""
1116
+ name = self.name or self.__class__.__name__.lower()
1117
+ coords = _generate_spatial_coords(self.transform, self.width, self.height)
1118
+ if len(self.values.shape) == 2:
1119
+ dims = ["y", "x"]
1120
+ elif len(self.values.shape) == 3:
1121
+ dims = ["band", "y", "x"]
1122
+ else:
1123
+ raise ValueError("Array must be 2 or 3 dimensional.")
1124
+ return xr.DataArray(
1125
+ self.values,
1126
+ coords=coords,
1127
+ dims=dims,
1128
+ name=name,
1129
+ attrs={"crs": self.crs},
1130
+ )
1131
+
1132
+ def __repr__(self) -> str:
1133
+ """String representation."""
1134
+ try:
1135
+ band_id = f"'{self.band_id}'" if self.band_id else None
1136
+ except (ValueError, AttributeError):
1137
+ band_id = None
1138
+ try:
1139
+ path = f"'{self.path}'"
1140
+ except (ValueError, AttributeError):
1141
+ path = None
1142
+ return (
1143
+ f"{self.__class__.__name__}(band_id={band_id}, res={self.res}, path={path})"
1144
+ )
1145
+
1146
+
1147
+ class NDVIBand(Band):
1148
+ """Band for NDVI values."""
1149
+
1150
+ cmap: str = "Greens"
1151
+
1152
+ # @staticmethod
1153
+ # def get_cmap(arr: np.ndarray):
1154
+ # return get_cmap(arr)
1155
+
1156
+
1157
+ def get_cmap(arr: np.ndarray) -> LinearSegmentedColormap:
1158
+
1159
+ # blue = [[i / 10 + 0.1, i / 10 + 0.1, 1 - (i / 10) + 0.1] for i in range(11)][1:]
1160
+ blue = [
1161
+ [0.1, 0.1, 1.0],
1162
+ [0.2, 0.2, 0.9],
1163
+ [0.3, 0.3, 0.8],
1164
+ [0.4, 0.4, 0.7],
1165
+ [0.6, 0.6, 0.6],
1166
+ [0.6, 0.6, 0.6],
1167
+ [0.7, 0.7, 0.7],
1168
+ [0.8, 0.8, 0.8],
1169
+ ]
1170
+ # gray = list(reversed([[i / 10 - 0.1, i / 10, i / 10 - 0.1] for i in range(11)][1:]))
1171
+ gray = [
1172
+ [0.6, 0.6, 0.6],
1173
+ [0.6, 0.6, 0.6],
1174
+ [0.6, 0.6, 0.6],
1175
+ [0.6, 0.6, 0.6],
1176
+ [0.6, 0.6, 0.6],
1177
+ [0.4, 0.7, 0.4],
1178
+ [0.3, 0.7, 0.3],
1179
+ [0.2, 0.8, 0.2],
1180
+ ]
1181
+ # gray = [[0.6, 0.6, 0.6] for i in range(10)]
1182
+ # green = [[0.2 + i/20, i / 10 - 0.1, + i/20] for i in range(11)][1:]
1183
+ green = [
1184
+ [0.25, 0.0, 0.05],
1185
+ [0.3, 0.1, 0.1],
1186
+ [0.35, 0.2, 0.15],
1187
+ [0.4, 0.3, 0.2],
1188
+ [0.45, 0.4, 0.25],
1189
+ [0.5, 0.5, 0.3],
1190
+ [0.55, 0.6, 0.35],
1191
+ [0.7, 0.9, 0.5],
1192
+ ]
1193
+ green = [
1194
+ [0.6, 0.6, 0.6],
1195
+ [0.4, 0.7, 0.4],
1196
+ [0.3, 0.8, 0.3],
1197
+ [0.25, 0.4, 0.25],
1198
+ [0.2, 0.5, 0.2],
1199
+ [0.10, 0.7, 0.10],
1200
+ [0, 0.9, 0],
1201
+ ]
1202
+
1203
+ def get_start(arr):
1204
+ min_value = np.min(arr)
1205
+ if min_value < -0.75:
1206
+ return 0
1207
+ if min_value < -0.5:
1208
+ return 1
1209
+ if min_value < -0.25:
1210
+ return 2
1211
+ if min_value < 0:
1212
+ return 3
1213
+ if min_value < 0.25:
1214
+ return 4
1215
+ if min_value < 0.5:
1216
+ return 5
1217
+ if min_value < 0.75:
1218
+ return 6
1219
+ return 7
1220
+
1221
+ def get_stop(arr):
1222
+ max_value = np.max(arr)
1223
+ if max_value <= 0.05:
1224
+ return 0
1225
+ if max_value < 0.175:
1226
+ return 1
1227
+ if max_value < 0.25:
1228
+ return 2
1229
+ if max_value < 0.375:
1230
+ return 3
1231
+ if max_value < 0.5:
1232
+ return 4
1233
+ if max_value < 0.75:
1234
+ return 5
1235
+ return 6
1236
+
1237
+ cmap_name = "blue_gray_green"
1238
+
1239
+ start = get_start(arr)
1240
+ stop = get_stop(arr)
1241
+ blue = blue[start]
1242
+ gray = gray[start]
1243
+ # green = green[start]
1244
+ green = green[stop]
1245
+
1246
+ # green[0] = np.arange(0, 1, 0.1)[::-1][stop]
1247
+ # green[1] = np.arange(0, 1, 0.1)[stop]
1248
+ # green[2] = np.arange(0, 1, 0.1)[::-1][stop]
1249
+
1250
+ print(green)
1251
+ print(start, stop)
1252
+ print("blue gray green")
1253
+ print(blue)
1254
+ print(gray)
1255
+ print(green)
1256
+
1257
+ # Define the segments of the colormap
1258
+ cdict = {
1259
+ "red": [
1260
+ (0.0, blue[0], blue[0]),
1261
+ (0.3, gray[0], gray[0]),
1262
+ (0.7, gray[0], gray[0]),
1263
+ (1.0, green[0], green[0]),
1264
+ ],
1265
+ "green": [
1266
+ (0.0, blue[1], blue[1]),
1267
+ (0.3, gray[1], gray[1]),
1268
+ (0.7, gray[1], gray[1]),
1269
+ (1.0, green[1], green[1]),
1270
+ ],
1271
+ "blue": [
1272
+ (0.0, blue[2], blue[2]),
1273
+ (0.3, gray[2], gray[2]),
1274
+ (0.7, gray[2], gray[2]),
1275
+ (1.0, green[2], green[2]),
1276
+ ],
1277
+ }
1278
+
1279
+ return LinearSegmentedColormap(cmap_name, segmentdata=cdict, N=50)
1280
+
1281
+
1282
+ def median_as_int_and_minimum_dtype(arr: np.ndarray) -> np.ndarray:
1283
+ arr = np.median(arr, axis=0).astype(int)
1284
+ min_dtype = rasterio.dtypes.get_minimum_dtype(arr)
1285
+ return arr.astype(min_dtype)
1286
+
1287
+
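# Small worked example for the helper above: the median is taken over the first
# axis (across images) and cast to the smallest dtype that holds the result.
#
#   >>> arr = np.array([[[0, 200]], [[10, 250]], [[20, 300]]])
#   >>> median_as_int_and_minimum_dtype(arr)
#   array([[ 10, 250]], dtype=uint8)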
1288
+ class Image(_ImageBandBase):
1289
+ """Image consisting of one or more Bands."""
1290
+
1291
+ cloud_cover_regexes: ClassVar[tuple[str] | None] = None
1292
+ band_class: ClassVar[Band] = Band
1293
+
1294
+ def __init__(
1295
+ self,
1296
+ data: str | Path | Sequence[Band],
1297
+ res: int | None = None,
1298
+ crs: Any | None = None,
1299
+ single_banded: bool = False,
1300
+ file_system: GCSFileSystem | None = None,
1301
+ df: pd.DataFrame | None = None,
1302
+ all_file_paths: list[str] | None = None,
1303
+ processes: int = 1,
1304
+ bbox: GeoDataFrame | GeoSeries | Geometry | tuple | None = None,
1305
+ nodata: int | None = None,
1306
+ **kwargs,
1307
+ ) -> None:
1308
+ """Image initialiser."""
1309
+ super().__init__(**kwargs)
1310
+
1311
+ self.nodata = nodata
1312
+ self._res = res
1313
+ self._crs = crs
1314
+ self.file_system = file_system
1315
+ self._bbox = to_bbox(bbox) if bbox is not None else None
1316
+ # self._mask = _mask
1317
+ self.single_banded = single_banded
1318
+ self.processes = processes
1319
+ self._all_file_paths = all_file_paths
1320
+
1321
+ if hasattr(data, "__iter__") and all(isinstance(x, Band) for x in data):
1322
+ self._bands = list(data)
1323
+ if res is None:
1324
+ res = list({band.res for band in self._bands})
1325
+ if len(res) == 1:
1326
+ self._res = res[0]
1327
+ else:
1328
+ raise ValueError(f"Different resolutions for the bands: {res}")
1329
+ else:
1330
+ self._res = res
1331
+ return
1332
+
1333
+ if not isinstance(data, (str | Path | os.PathLike)):
1334
+ raise TypeError("'data' must be string, Path-like or a sequence of Band.")
1335
+
1336
+ self._bands = None
1337
+ self._path = str(data)
1338
+
1339
+ if df is None:
1340
+ if is_dapla():
1341
+ file_paths = list(sorted(set(_glob_func(self.path + "/**"))))
1342
+ else:
1343
+ file_paths = list(
1344
+ sorted(
1345
+ set(
1346
+ _glob_func(self.path + "/**/**")
1347
+ + _glob_func(self.path + "/**/**/**")
1348
+ + _glob_func(self.path + "/**/**/**/**")
1349
+ + _glob_func(self.path + "/**/**/**/**/**")
1350
+ )
1351
+ )
1352
+ )
1353
+ if not file_paths:
1354
+ file_paths = [self.path]
1355
+ df = self._create_metadata_df(file_paths)
1356
+
1357
+ df["image_path"] = df["image_path"].astype(str)
1358
+
1359
+ cols_to_explode = [
1360
+ "file_path",
1361
+ "filename",
1362
+ *[x for x in df if FILENAME_COL_SUFFIX in x],
1363
+ ]
1364
+ try:
1365
+ df = df.explode(cols_to_explode, ignore_index=True)
1366
+ except ValueError:
1367
+ for col in cols_to_explode:
1368
+ df = df.explode(col)
1369
+ df = df.loc[lambda x: ~x["filename"].duplicated()].reset_index(drop=True)
1370
+
1371
+ df = df.loc[lambda x: x["image_path"].str.contains(_fix_path(self.path))]
1372
+
1373
+ if self.cloud_cover_regexes:
1374
+ if all_file_paths is None:
1375
+ file_paths = _ls_func(self.path)
1376
+ else:
1377
+ file_paths = [path for path in all_file_paths if self.name in path]
1378
+ self.cloud_coverage_percentage = float(
1379
+ _get_regex_match_from_xml_in_local_dir(
1380
+ file_paths, regexes=self.cloud_cover_regexes
1381
+ )
1382
+ )
1383
+ else:
1384
+ self.cloud_coverage_percentage = None
1385
+
1386
+ self._df = df
1387
+
1388
+ @property
1389
+ def values(self) -> np.ndarray:
1390
+ """3 dimensional numpy array."""
1391
+ return np.array([band.values for band in self])
1392
+
1393
+ def ndvi(self, red_band: str, nir_band: str, copy: bool = True) -> NDVIBand:
1394
+ """Calculate the NDVI for the Image."""
1395
+ copied = self.copy() if copy else self
1396
+ red = copied[red_band].load()
1397
+ nir = copied[nir_band].load()
1398
+
1399
+ arr: np.ndarray | np.ma.core.MaskedArray = ndvi(red.values, nir.values)
1400
+
1401
+ # if self.nodata is not None and not np.isnan(self.nodata):
1402
+ # try:
1403
+ # arr.data[arr.mask] = self.nodata
1404
+ # arr = arr.copy()
1405
+ # except AttributeError:
1406
+ # pass
1407
+
1408
+ return NDVIBand(
1409
+ arr,
1410
+ bounds=red.bounds,
1411
+ crs=red.crs,
1412
+ mask=red.mask,
1413
+ **red._common_init_kwargs,
1414
+ )
1415
+
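# NDVI sketch. The band ids are hypothetical (Sentinel-2 style red/NIR); which ids
# exist depends on the collection's filename regexes.
#
#   >>> ndvi_band = image.ndvi(red_band="B04", nir_band="B08")
#   >>> ndvi_band.cmap
#   'Greens'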
1416
+ def get_brightness(
1417
+ self,
1418
+ bounds: tuple | Geometry | GeoDataFrame | GeoSeries | None = None,
1419
+ rbg_bands: list[str] | None = None,
1420
+ ) -> Band:
1421
+ """Get a Band with a brightness score of the Image's RBG bands."""
1422
+ if rbg_bands is None:
1423
+ try:
1424
+ r, b, g = self.rbg_bands
1425
+ except AttributeError as err:
1426
+ raise AttributeError(
1427
+ "Must specify rbg_bands when there is no class variable 'rbd_bands'"
1428
+ ) from err
1429
+ else:
1430
+ r, b, g = rbg_bands
1431
+
1432
+ red = self[r].load(bounds=bounds)
1433
+ blue = self[b].load(bounds=bounds)
1434
+ green = self[g].load(bounds=bounds)
1435
+
1436
+ brightness = (
1437
+ 0.299 * red.values + 0.587 * green.values + 0.114 * blue.values
1438
+ ).astype(int)
1439
+
1440
+ return Band(
1441
+ brightness,
1442
+ bounds=red.bounds,
1443
+ crs=self.crs,
1444
+ mask=self.mask,
1445
+ **self._common_init_kwargs,
1446
+ )
1447
+
1448
+ @property
1449
+ def mask(self) -> Band | None:
1450
+ """Mask Band."""
1451
+ if self._mask is not None:
1452
+ return self._mask
1453
+ if self.masking is None:
1454
+ return None
1455
+
1456
+ mask_band_id = self.masking["band_id"]
1457
+ mask_paths = [path for path in self._df["file_path"] if mask_band_id in path]
1458
+ if len(mask_paths) > 1:
1459
+ raise ValueError(
1460
+ f"Multiple file_paths match mask band_id {mask_band_id} for {self.path}"
1461
+ )
1462
+ elif not mask_paths:
1463
+ raise ValueError(
1464
+ f"No file_paths match mask band_id {mask_band_id} for {self.path}"
1465
+ )
1466
+ self._mask = self.band_class(
1467
+ mask_paths[0],
1468
+ **self._common_init_kwargs,
1469
+ )
1470
+
1471
+ return self._mask
1472
+
1473
+ @mask.setter
1474
+ def mask(self, values: Band) -> None:
1475
+ if values is None:
1476
+ self._mask = None
1477
+ for band in self:
1478
+ band.mask = None
1479
+ return
1480
+ if not isinstance(values, Band):
1481
+ raise TypeError(f"mask must be Band. Got {type(values)}")
1482
+ self._mask = values
1483
+ mask_arr = self._mask.values
1484
+ for band in self:
1485
+ band._mask = self._mask
1486
+ try:
1487
+ band.values = np.ma.array(
1488
+ band.values, mask=mask_arr, fill_value=band.nodata
1489
+ )
1490
+ except ArrayNotLoadedError:
1491
+ pass
1492
+
1493
+ @property
1494
+ def band_ids(self) -> list[str]:
1495
+ """The Band ids."""
1496
+ return [band.band_id for band in self]
1497
+
1498
+ @property
1499
+ def file_paths(self) -> list[str]:
1500
+ """The Band file paths."""
1501
+ return [band.path for band in self]
1502
+
1503
+ @property
1504
+ def bands(self) -> list[Band]:
1505
+ """The Image Bands."""
1506
+ if self._bands is not None:
1507
+ return self._bands
1508
+
1509
+ # if self.masking:
1510
+ # mask_band_id = self.masking["band_id"]
1511
+ # mask_paths = [
1512
+ # path for path in self._df["file_path"] if mask_band_id in path
1513
+ # ]
1514
+ # if len(mask_paths) > 1:
1515
+ # raise ValueError(
1516
+ # f"Multiple file_paths match mask band_id {mask_band_id}"
1517
+ # )
1518
+ # elif not mask_paths:
1519
+ # raise ValueError(f"No file_paths match mask band_id {mask_band_id}")
1520
+ # arr = (
1521
+ # self.band_class(
1522
+ # mask_paths[0],
1523
+ # # mask=self.mask,
1524
+ # **self._common_init_kwargs,
1525
+ # )
1526
+ # .load()
1527
+ # .values
1528
+ # )
1529
+ # self._mask = np.ma.array(
1530
+ # arr, mask=np.isin(arr, self.masking["values"]), fill_value=None
1531
+ # )
1532
+
1533
+ self._bands = [
1534
+ self.band_class(
1535
+ path,
1536
+ mask=self.mask,
1537
+ **self._common_init_kwargs,
1538
+ )
1539
+ for path in (self._df["file_path"])
1540
+ ]
1541
+
1542
+ if self.masking:
1543
+ mask_band_id = self.masking["band_id"]
1544
+ self._bands = [
1545
+ band for band in self._bands if mask_band_id not in band.path
1546
+ ]
1547
+
1548
+ if (
1549
+ self.filename_patterns
1550
+ and any(_get_non_optional_groups(pat) for pat in self.filename_patterns)
1551
+ or self.image_patterns
1552
+ and any(_get_non_optional_groups(pat) for pat in self.image_patterns)
1553
+ ):
1554
+ self._bands = [band for band in self._bands if band.band_id is not None]
1555
+
1556
+ if self.filename_patterns:
1557
+ self._bands = [
1558
+ band
1559
+ for band in self._bands
1560
+ if any(
1561
+ # _get_first_group_match(pat, band.name)
1562
+ re.search(pat, band.name)
1563
+ for pat in self.filename_patterns
1564
+ )
1565
+ ]
1566
+
1567
+ if self.image_patterns:
1568
+ self._bands = [
1569
+ band
1570
+ for band in self._bands
1571
+ if any(
1572
+ re.search(pat, Path(band.path).parent.name)
1573
+ # _get_first_group_match(pat, Path(band.path).parent.name)
1574
+ for pat in self.image_patterns
1575
+ )
1576
+ ]
1577
+
1578
+ if self._should_be_sorted:
1579
+ self._bands = list(sorted(self._bands))
1580
+
1581
+ return self._bands
1582
+
1583
+ @property
1584
+ def _should_be_sorted(self) -> bool:
1585
+ sort_groups = ["band", "band_id"]
1586
+ return self.filename_patterns and any(
1587
+ group in _get_non_optional_groups(pat)
1588
+ for group in sort_groups
1589
+ for pat in self.filename_patterns
1590
+ )
1591
+
1592
+ @property
1593
+ def tile(self) -> str:
1594
+ """Tile name from filename_regex."""
1595
+ if hasattr(self, "_tile") and self._tile:
1596
+ return self._tile
1597
+ return self._name_regex_searcher(
1598
+ "tile", self.image_patterns + self.filename_patterns
1599
+ )
1600
+
1601
+ @property
1602
+ def date(self) -> str:
1603
+ """Tile name from filename_regex."""
1604
+ if hasattr(self, "_date") and self._date:
1605
+ return self._date
1606
+
1607
+ return self._name_regex_searcher(
1608
+ "date", self.image_patterns + self.filename_patterns
1609
+ )
1610
+
1611
+ @property
1612
+ def crs(self) -> str | None:
1613
+ """Coordinate reference system of the Image."""
1614
+ if self._crs is not None:
1615
+ return self._crs
1616
+ if not len(self):
1617
+ return None
1618
+ self._crs = get_common_crs(self)
1619
+ return self._crs
1620
+
1621
+ @property
1622
+ def bounds(self) -> tuple[int, int, int, int] | None:
1623
+ """Bounds of the Image (minx, miny, maxx, maxy)."""
1624
+ return get_total_bounds([band.bounds for band in self])
1625
+
1626
+ def to_gdf(self, column: str = "value") -> GeoDataFrame:
1627
+ """Convert the array to a GeoDataFrame of grid polygons and values."""
1628
+ return pd.concat(
1629
+ [band.to_gdf(column=column) for band in self], ignore_index=True
1630
+ )
1631
+
1632
+ def sample(
1633
+ self, n: int = 1, size: int = 1000, mask: Any = None, **kwargs
1634
+ ) -> "Image":
1635
+ """Take a random spatial sample of the image."""
1636
+ copied = self.copy()
1637
+ if mask is not None:
1638
+ points = GeoSeries([self.union_all()]).clip(mask).sample_points(n)
1639
+ else:
1640
+ points = GeoSeries([self.union_all()]).sample_points(n)
1641
+ buffered = points.buffer(size / 2).clip(self.union_all())
1642
+ boxes = to_gdf([box(*arr) for arr in buffered.bounds.values], crs=self.crs)
1643
+ copied._bands = [band.load(bounds=boxes, **kwargs) for band in copied]
1644
+ copied._bounds = get_total_bounds([band.bounds for band in copied])
1645
+ return copied
1646
+
1647
+ def __getitem__(
1648
+ self, band: str | int | Sequence[str] | Sequence[int]
1649
+ ) -> "Band | Image":
1650
+ """Get bands by band_id or integer index.
1651
+
1652
+ Returns a Band if a string or int is passed,
1653
+ returns an Image if a sequence of strings or integers is passed.
1654
+ """
1655
+ if isinstance(band, str):
1656
+ return self._get_band(band)
1657
+ if isinstance(band, int):
1658
+ return self.bands[band] # .copy()
1659
+
1660
+ copied = self.copy()
1661
+ try:
1662
+ copied._bands = [copied._get_band(x) for x in band]
1663
+ except TypeError:
1664
+ try:
1665
+ copied._bands = [copied.bands[i] for i in band]
1666
+ except TypeError as e:
1667
+ raise TypeError(
1668
+ f"{self.__class__.__name__} indices should be string, int "
1669
+ f"or sequence of string or int. Got {band}."
1670
+ ) from e
1671
+ return copied
1672
+
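# Indexing sketch for __getitem__ above: a single key returns a Band, a sequence
# returns a new Image. Band ids are hypothetical.
#
#   >>> image["B02"]           # Band
#   >>> image[0]               # first Band
#   >>> image[["B02", "B03"]]  # Image with two bands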
1673
+ def __contains__(self, item: str | Sequence[str]) -> bool:
1674
+ """Check if the Image contains a band_id (str) or all band_ids in a sequence."""
1675
+ if isinstance(item, str):
1676
+ return item in self.band_ids
1677
+ return all(x in self.band_ids for x in item)
1678
+
1679
+ def __lt__(self, other: "Image") -> bool:
1680
+ """Makes Images sortable by date."""
1681
+ try:
1682
+ return self.date < other.date
1683
+ except Exception as e:
1684
+ print(self.path)
1685
+ print(self.date)
1686
+ print(other.path)
1687
+ print(other.date)
1688
+ raise e
1689
+
1690
+ def __iter__(self) -> Iterator[Band]:
1691
+ """Iterate over the Bands."""
1692
+ return iter(self.bands)
1693
+
1694
+ def __len__(self) -> int:
1695
+ """Number of bands in the Image."""
1696
+ return len(self.bands)
1697
+
1698
+ def __repr__(self) -> str:
1699
+ """String representation."""
1700
+ return f"{self.__class__.__name__}(bands={self.bands})"
1701
+
1702
+ def _get_band(self, band: str) -> Band:
1703
+ if not isinstance(band, str):
1704
+ raise TypeError(f"band must be string. Got {type(band)}")
1705
+
1706
+ bands = [x for x in self.bands if x.band_id == band]
1707
+ if len(bands) == 1:
1708
+ return bands[0]
1709
+ if len(bands) > 1:
1710
+ raise ValueError(f"Multiple matches for band_id {band} for {self}")
1711
+
1712
+ bands = [x for x in self.bands if x.band_id == band.replace("B0", "B")]
1713
+ if len(bands) == 1:
1714
+ return bands[0]
1715
+
1716
+ bands = [x for x in self.bands if x.band_id.replace("B0", "B") == band]
1717
+ if len(bands) == 1:
1718
+ return bands[0]
1719
+
1720
+ try:
1721
+ more_bands = [x for x in self.bands if x.path == band]
1722
+ except PathlessImageError:
1723
+ more_bands = bands
1724
+
1725
+ if len(more_bands) == 1:
1726
+ return more_bands[0]
1727
+
1728
+ if len(bands) > 1:
1729
+ prefix = "Multiple"
1730
+ elif not bands:
1731
+ prefix = "No"
1732
+
1733
+ raise KeyError(
1734
+ f"{prefix} matches for band {band} among paths {[Path(band.path).name for band in self.bands]}"
1735
+ )
1736
+
1737
+
1738
+ class ImageCollection(_ImageBase):
1739
+ """Collection of Images.
1740
+
1741
+ Loops though Images.
1742
+ """
1743
+
1744
+ image_class: ClassVar[Image] = Image
1745
+ band_class: ClassVar[Band] = Band
1746
+
1747
+ def __init__(
1748
+ self,
1749
+ data: str | Path | Sequence[Image],
1750
+ res: int,
1751
+ level: str | None,
1752
+ crs: Any | None = None,
1753
+ single_banded: bool = False,
1754
+ processes: int = 1,
1755
+ file_system: GCSFileSystem | None = None,
1756
+ df: pd.DataFrame | None = None,
1757
+ bbox: Any | None = None,
1758
+ nodata: int | None = None,
1759
+ metadata: str | dict | pd.DataFrame | None = None,
1760
+ **kwargs,
1761
+ ) -> None:
1762
+ """Initialiser."""
1763
+ super().__init__(**kwargs)
1764
+
1765
+ self.nodata = nodata
1766
+ self.level = level
1767
+ self._crs = crs
1768
+ self.processes = processes
1769
+ self.file_system = file_system
1770
+ self._res = res
1771
+ self._bbox = to_bbox(bbox) if bbox is not None else None
1772
+ self._band_ids = None
1773
+ self.single_banded = single_banded
1774
+
1775
+ if metadata is not None:
1776
+ if isinstance(metadata, (str | Path | os.PathLike)):
1777
+ self.metadata = _read_parquet_func(metadata)
1778
+ else:
1779
+ self.metadata = metadata
1780
+ else:
1781
+ self.metadata = metadata
1782
+
1783
+ if hasattr(data, "__iter__") and all(isinstance(x, Image) for x in data):
1784
+ self._path = None
1785
+ self.images = [x.copy() for x in data]
1786
+ return
1787
+ else:
1788
+ self._images = None
1789
+
1790
+ if not isinstance(data, (str | Path | os.PathLike)):
1791
+ raise TypeError("'data' must be string, Path-like or a sequence of Image.")
1792
+
1793
+ self._path = str(data)
1794
+
1795
+ if is_dapla():
1796
+ self._all_file_paths = list(sorted(set(_glob_func(self.path + "/**"))))
1797
+ else:
1798
+ self._all_file_paths = list(
1799
+ sorted(
1800
+ set(
1801
+ _glob_func(self.path + "/**/**")
1802
+ + _glob_func(self.path + "/**/**/**")
1803
+ + _glob_func(self.path + "/**/**/**/**")
1804
+ + _glob_func(self.path + "/**/**/**/**/**")
1805
+ )
1806
+ )
1807
+ )
1808
+
1809
+ if self.level:
1810
+ self._all_file_paths = [
1811
+ path for path in self._all_file_paths if self.level in path
1812
+ ]
1813
+
1814
+ if df is not None:
1815
+ self._df = df
1816
+ else:
1817
+ self._df = self._create_metadata_df(self._all_file_paths)
1818
+
1819
+ @property
1820
+ def values(self) -> np.ndarray:
1821
+ """4 dimensional numpy array."""
1822
+ return np.array([img.values for img in self])
1823
+
1824
+ @property
1825
+ def mask(self) -> np.ndarray:
1826
+ """4 dimensional numpy array."""
1827
+ return np.array([img.mask.values for img in self])
1828
+
1829
+ # def ndvi(
1830
+ # self, red_band: str, nir_band: str, copy: bool = True
1831
+ # ) -> "ImageCollection":
1832
+ # # copied = self.copy() if copy else self
1833
+
1834
+ # with joblib.Parallel(n_jobs=self.processes, backend="loky") as parallel:
1835
+ # ndvi_images = parallel(
1836
+ # joblib.delayed(_img_ndvi)(
1837
+ # img, red_band=red_band, nir_band=nir_band, copy=False
1838
+ # )
1839
+ # for img in self
1840
+ # )
1841
+
1842
+ # return ImageCollection(ndvi_images, single_banded=True)
1843
+
1844
+ def groupby(self, by: str | list[str], **kwargs) -> ImageCollectionGroupBy:
1845
+ """Group the Collection by Image or Band attribute(s)."""
1846
+ df = pd.DataFrame(
1847
+ [(i, img) for i, img in enumerate(self) for _ in img],
1848
+ columns=["_image_idx", "_image_instance"],
1849
+ )
1850
+
1851
+ if isinstance(by, str):
1852
+ by = [by]
1853
+
1854
+ for attr in by:
1855
+ if attr == "bounds":
1856
+ # need integers to check equality when grouping
1857
+ df[attr] = [
1858
+ tuple(int(x) for x in band.bounds) for img in self for band in img
1859
+ ]
1860
+ continue
1861
+
1862
+ try:
1863
+ df[attr] = [getattr(band, attr) for img in self for band in img]
1864
+ except AttributeError:
1865
+ df[attr] = [getattr(img, attr) for img in self for _ in img]
1866
+
1867
+ with joblib.Parallel(n_jobs=self.processes, backend="loky") as parallel:
1868
+ return ImageCollectionGroupBy(
1869
+ sorted(
1870
+ parallel(
1871
+ joblib.delayed(_copy_and_add_df_parallel)(i, group, self)
1872
+ for i, group in df.groupby(by, **kwargs)
1873
+ )
1874
+ ),
1875
+ by=by,
1876
+ collection=self,
1877
+ )
1878
+
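+ # Hypothetical usage sketch: groupby yields (group_values, sub-collection)
+ # pairs, the same pattern merge_by_band() relies on below. The directory
+ # path is an assumption for illustration.
+ #
+ #     collection = ImageCollection("/path/to/images", res=10, level=None)
+ #     for (band_id,), subcollection in collection.groupby("band_id"):
+ #         print(band_id, len(subcollection))
+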
1879
+ def explode(self, copy: bool = True) -> "ImageCollection":
1880
+ """Make all Images single-banded."""
1881
+ copied = self.copy() if copy else self
1882
+ copied.images = [
1883
+ self.image_class(
1884
+ [band],
1885
+ single_banded=True,
1886
+ masking=self.masking,
1887
+ band_class=self.band_class,
1888
+ **self._common_init_kwargs,
1889
+ df=self._df,
1890
+ all_file_paths=self._all_file_paths,
1891
+ )
1892
+ for img in self
1893
+ for band in img
1894
+ ]
1895
+ return copied
1896
+
1897
+ def merge(
1898
+ self,
1899
+ bounds: tuple | Geometry | GeoDataFrame | GeoSeries | None = None,
1900
+ method: str | Callable = "mean",
1901
+ as_int: bool = True,
1902
+ indexes: int | tuple[int] | None = None,
1903
+ **kwargs,
1904
+ ) -> Band:
1905
+ """Merge all areas and all bands to a single Band."""
1906
+ bounds = to_bbox(bounds) if bounds is not None else self._bbox
1907
+ crs = self.crs
1908
+
1909
+ if indexes is None:
1910
+ indexes = 1
1911
+
1912
+ if isinstance(indexes, int):
1913
+ _indexes = (indexes,)
1914
+ else:
1915
+ _indexes = indexes
1916
+
1917
+ if method == "mean":
1918
+ _method = "sum"
1919
+ else:
1920
+ _method = method
1921
+
1922
+ if self.masking or method not in list(rasterio.merge.MERGE_METHODS) + ["mean"]:
1923
+ arr = self._merge_with_numpy_func(
1924
+ method=method,
1925
+ bounds=bounds,
1926
+ as_int=as_int,
1927
+ **kwargs,
1928
+ )
1929
+ else:
1930
+ datasets = [_open_raster(path) for path in self.file_paths]
1931
+ arr, _ = rasterio.merge.merge(
1932
+ datasets,
1933
+ res=self.res,
1934
+ bounds=(bounds if bounds is not None else self.bounds),
1935
+ indexes=_indexes,
1936
+ method=_method,
1937
+ nodata=self.nodata,
1938
+ **kwargs,
1939
+ )
1940
+
1941
+ if isinstance(indexes, int) and len(arr.shape) == 3 and arr.shape[0] == 1:
1942
+ arr = arr[0]
1943
+
1944
+ # the numpy merge path already averages; only divide when rasterio's "sum" was used
+ if method == "mean" and not self.masking:
1945
+ if as_int:
1946
+ arr = arr // len(datasets)
1947
+ else:
1948
+ arr = arr / len(datasets)
1949
+
1950
+ if bounds is None:
1951
+ bounds = self.bounds
1952
+
1953
+ # return self.band_class(
1954
+ band = Band(
1955
+ arr,
1956
+ bounds=bounds,
1957
+ crs=crs,
1958
+ mask=self.mask,
1959
+ **self._common_init_kwargs,
1960
+ )
1961
+
1962
+ band._merged = True
1963
+ return band
1964
+
1965
+ def merge_by_band(
1966
+ self,
1967
+ bounds: tuple | Geometry | GeoDataFrame | GeoSeries | None = None,
1968
+ method: str = "mean",
1969
+ as_int: bool = True,
1970
+ indexes: int | tuple[int] | None = None,
1971
+ **kwargs,
1972
+ ) -> Image:
1973
+ """Merge all areas to a single tile, one band per band_id."""
1974
+ bounds = to_bbox(bounds) if bounds is not None else self._bbox
1975
+ bounds = self.bounds if bounds is None else bounds
1976
+ out_bounds = bounds
1977
+ crs = self.crs
1978
+
1979
+ if indexes is None:
1980
+ indexes = 1
1981
+
1982
+ if isinstance(indexes, int):
1983
+ _indexes = (indexes,)
1984
+ else:
1985
+ _indexes = indexes
1986
+
1987
+ if method == "mean":
1988
+ _method = "sum"
1989
+ else:
1990
+ _method = method
1991
+
1992
+ arrs = []
1993
+ bands: list[Band] = []
1994
+ for (band_id,), band_collection in self.groupby("band_id"):
1995
+ if self.masking or method not in list(rasterio.merge.MERGE_METHODS) + [
1996
+ "mean"
1997
+ ]:
1998
+ arr = band_collection._merge_with_numpy_func(
1999
+ method=method,
2000
+ bounds=bounds,
2001
+ as_int=as_int,
2002
+ **kwargs,
2003
+ )
2004
+ else:
2005
+ datasets = [_open_raster(path) for path in band_collection.file_paths]
2006
+ arr, _ = rasterio.merge.merge(
2007
+ datasets,
2008
+ res=self.res,
2009
+ bounds=(bounds if bounds is not None else self.bounds),
2010
+ indexes=_indexes,
2011
+ method=_method,
2012
+ nodata=self.nodata,
2013
+ **kwargs,
2014
+ )
2015
+ if isinstance(indexes, int):
2016
+ arr = arr[0]
2017
+ if method == "mean":
2018
+ if as_int:
2019
+ arr = arr // len(datasets)
2020
+ else:
2021
+ arr = arr / len(datasets)
2022
+
2023
+ arrs.append(arr)
2024
+ bands.append(
2025
+ self.band_class(
2026
+ arr,
2027
+ bounds=out_bounds,
2028
+ crs=crs,
2029
+ band_id=band_id,
2030
+ **self._common_init_kwargs,
2031
+ )
2032
+ )
2033
+
2034
+ # return self.image_class(
2035
+ image = Image(
2036
+ bands,
2037
+ band_class=self.band_class,
2038
+ **self._common_init_kwargs,
2039
+ )
2040
+
2041
+ image._merged = True
2042
+ return image
2043
+
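+ # Hypothetical usage sketch: merge() collapses the whole collection to a
+ # single Band, while merge_by_band() returns one Image with one Band per
+ # band_id. The collection variable is assumed to exist as above.
+ #
+ #     single_band = collection.merge(method="mean")
+ #     one_image_per_band_id = collection.merge_by_band()
+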
2044
+ def _merge_with_numpy_func(
2045
+ self,
2046
+ method: str | Callable,
2047
+ bounds: tuple | Geometry | GeoDataFrame | GeoSeries | None = None,
2048
+ as_int: bool = True,
2049
+ indexes: int | tuple[int] | None = None,
2050
+ **kwargs,
2051
+ ) -> np.ndarray:
2052
+ arrs = []
2053
+ kwargs["indexes"] = indexes
2054
+ bounds = to_shapely(bounds) if bounds is not None else None
2055
+ numpy_func = get_numpy_func(method) if not callable(method) else method
2056
+ for (_bounds,), collection in self.groupby("bounds"):
2057
+ _bounds = (
2058
+ to_shapely(_bounds).intersection(bounds)
2059
+ if bounds is not None
2060
+ else to_shapely(_bounds)
2061
+ )
2062
+ if not _bounds.area:
2063
+ continue
2064
+
2065
+ _bounds = to_bbox(_bounds)
2066
+ arr = np.array(
2067
+ [
2068
+ (
2069
+ band.load(
2070
+ bounds=(_bounds if _bounds is not None else None),
2071
+ **kwargs,
2072
+ )
2073
+ ).values
2074
+ for img in collection
2075
+ for band in img
2076
+ ]
2077
+ )
2078
+ arr = numpy_func(arr, axis=0)
2079
+ if as_int:
2080
+ arr = arr.astype(int)
2081
+ min_dtype = rasterio.dtypes.get_minimum_dtype(arr)
2082
+ arr = arr.astype(min_dtype)
2083
+
2084
+ if len(arr.shape) == 2:
2085
+ height, width = arr.shape
2086
+ elif len(arr.shape) == 3:
2087
+ height, width = arr.shape[1:]
2088
+ else:
2089
+ raise ValueError(arr.shape)
2090
+
2091
+ transform = rasterio.transform.from_bounds(*_bounds, width, height)
2092
+ coords = _generate_spatial_coords(transform, width, height)
2093
+
2094
+ arrs.append(
2095
+ xr.DataArray(
2096
+ arr,
2097
+ coords=coords,
2098
+ dims=["y", "x"],
2099
+ attrs={"crs": self.crs},
2100
+ )
2101
+ )
2102
+
2103
+ merged = merge_arrays(
2104
+ arrs,
2105
+ res=self.res,
2106
+ nodata=self.nodata,
2107
+ )
2108
+
2109
+ return merged.to_numpy()
2110
+
2111
+ def sort_images(self, ascending: bool = True) -> "ImageCollection":
2112
+ """Sort Images by date."""
2113
+ self._images = (
2114
+ list(sorted([img for img in self if img.date is not None]))
2115
+ + sorted(
2116
+ [img for img in self if img.date is None and img.path is not None],
2117
+ key=lambda x: x.path,
2118
+ )
2119
+ + [img for img in self if img.date is None and img.path is None]
2120
+ )
2121
+ if not ascending:
2122
+ self._images = list(reversed(self.images))
2123
+ return self
2124
+
2125
+ def load(
2126
+ self,
2127
+ bounds: tuple | Geometry | GeoDataFrame | GeoSeries | None = None,
2128
+ indexes: int | tuple[int] | None = None,
2129
+ **kwargs,
2130
+ ) -> "ImageCollection":
2131
+ """Load all image Bands with threading."""
2132
+ with joblib.Parallel(n_jobs=self.processes, backend="threading") as parallel:
2133
+ if self.masking:
2134
+ parallel(
2135
+ joblib.delayed(_load_band)(
2136
+ img.mask, bounds=bounds, indexes=indexes, **kwargs
2137
+ )
2138
+ for img in self
2139
+ )
2140
+ parallel(
2141
+ joblib.delayed(_load_band)(
2142
+ band, bounds=bounds, indexes=indexes, **kwargs
2143
+ )
2144
+ for img in self
2145
+ for band in img
2146
+ )
2147
+
2148
+ return self
2149
+
2150
+ def set_bbox(
2151
+ self, bbox: GeoDataFrame | GeoSeries | Geometry | tuple[float]
2152
+ ) -> "ImageCollection":
2153
+ """Set the mask to be used to clip the images to."""
2154
+ self._bbox = to_bbox(bbox)
2155
+ # only update images when already instantiated
2156
+ if self._images is not None:
2157
+ for img in self._images:
2158
+ img._bbox = self._bbox
2159
+ if img._bands is not None:
2160
+ for band in img:
2161
+ band._bbox = self._bbox
2162
+ bounds = box(*band._bbox).intersection(box(*band.bounds))
2163
+ band._bounds = to_bbox(bounds) if not bounds.is_empty else None
2164
+
2165
+ return self
2166
+
2167
+ def apply(self, func: Callable, **kwargs) -> "ImageCollection":
2168
+ """Apply a function to all bands in each image of the collection."""
2169
+ for img in self:
2170
+ img.bands = [func(band, **kwargs) for band in img]
2171
+ return self
2172
+
2173
+ def filter(
2174
+ self,
2175
+ bands: str | list[str] | None = None,
2176
+ exclude_bands: str | list[str] | None = None,
2177
+ date_ranges: (
2178
+ tuple[str | None, str | None]
2179
+ | tuple[tuple[str | None, str | None], ...]
2180
+ | None
2181
+ ) = None,
2182
+ bbox: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None = None,
2183
+ intersects: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None = None,
2184
+ max_cloud_coverage: int | None = None,
2185
+ copy: bool = True,
2186
+ ) -> "ImageCollection":
2187
+ """Filter images and bands in the collection."""
2188
+ copied = self.copy() if copy else self
2189
+
2190
+ if isinstance(bbox, BoundingBox):
2191
+ date_ranges = (bbox.mint, bbox.maxt)
2192
+
2193
+ if date_ranges:
2194
+ copied = copied._filter_dates(date_ranges)
2195
+
2196
+ if max_cloud_coverage is not None:
2197
+ copied.images = [
2198
+ image
2199
+ for image in copied.images
2200
+ if image.cloud_coverage_percentage < max_cloud_coverage
2201
+ ]
2202
+
2203
+ if bbox is not None:
2204
+ copied = copied._filter_bounds(bbox)
2205
+ copied.set_bbox(bbox)
2206
+
2207
+ if intersects is not None:
2208
+ copied = copied._filter_bounds(intersects)
2209
+
2210
+ if bands is not None:
2211
+ if isinstance(bands, str):
2212
+ bands = [bands]
2213
+ bands = set(bands)
2214
+ copied._band_ids = bands
2215
+ copied.images = [img[bands] for img in copied.images if bands in img]
2216
+
2217
+ if exclude_bands is not None:
2218
+ if isinstance(exclude_bands, str):
2219
+ exclude_bands = {exclude_bands}
2220
+ else:
2221
+ exclude_bands = set(exclude_bands)
2222
+ include_bands: list[list[str]] = [
2223
+ [band_id for band_id in img.band_ids if band_id not in exclude_bands]
2224
+ for img in copied
2225
+ ]
2226
+ copied.images = [
2227
+ img[bands]
2228
+ for img, bands in zip(copied.images, include_bands, strict=False)
2229
+ if bands
2230
+ ]
2231
+
2232
+ return copied
2233
+
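+ # Hypothetical usage sketch: filter() narrows a copy of the collection by
+ # band ids, 8-character date strings, bounds and cloud cover. The band id,
+ # dates and threshold are assumptions for illustration.
+ #
+ #     subset = collection.filter(
+ #         bands=["B02"],
+ #         date_ranges=("20230101", "20230601"),
+ #         max_cloud_coverage=20,
+ #     )
+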
2234
+ def _filter_dates(
2235
+ self,
2236
+ date_ranges: (
2237
+ tuple[str | None, str | None] | tuple[tuple[str | None, str | None], ...]
2238
+ ),
2239
+ ) -> "ImageCollection":
2240
+ if not isinstance(date_ranges, (tuple, list)):
2241
+ raise TypeError(
2242
+ "date_ranges should be a 2-length tuple of strings or None, "
2243
+ "or a tuple of tuples for multiple date ranges"
2244
+ )
2245
+ if not self.image_patterns:
2246
+ raise ValueError(
2247
+ "Cannot set date_ranges when the class's image_regexes attribute is None"
2248
+ )
2249
+
2250
+ self.images = [
2251
+ img
2252
+ for img in self
2253
+ if _date_is_within(
2254
+ img.path, date_ranges, self.image_patterns, self.date_format
2255
+ )
2256
+ ]
2257
+ return self
2258
+
2259
+ def _filter_bounds(
2260
+ self, other: GeoDataFrame | GeoSeries | Geometry | tuple
2261
+ ) -> "ImageCollection":
2262
+ if self._images is None:
2263
+ return self
2264
+
2265
+ other = to_shapely(other)
2266
+
2267
+ # intersects_list = GeoSeries([img.union_all() for img in self]).intersects(other)
2268
+ with joblib.Parallel(n_jobs=self.processes, backend="loky") as parallel:
2269
+ intersects_list: list[bool] = parallel(
2270
+ joblib.delayed(_intesects)(image, other) for image in self
2271
+ )
2272
+
2273
+ self.images = [
2274
+ image
2275
+ for image, intersects in zip(self, intersects_list, strict=False)
2276
+ if intersects
2277
+ ]
2278
+ return self
2279
+
2280
+ def to_gdfs(self, column: str = "value") -> dict[str, GeoDataFrame]:
2281
+ """Convert each band in each Image to a GeoDataFrame."""
2282
+ out = {}
2283
+ i = 0
2284
+ for img in self:
2285
+ for band in img:
2286
+ i += 1
2287
+ try:
2288
+ name = band.name
2289
+ except AttributeError:
2290
+ name = f"{self.__class__.__name__}({i})"
2291
+
2292
+ band.load()
2293
+
2294
+ if name not in out:
2295
+ out[name] = band.to_gdf(column=column)
2296
+ else:
2297
+ out[name] = f"{self.__class__.__name__}({i})"
2298
+ return out
2299
+
2300
+ def sample(self, n: int = 1, size: int = 500) -> "ImageCollection":
2301
+ """Sample one or more areas of a given size and set this as mask for the images."""
2302
+ unioned = self.union_all()
2303
+ buffered_in = unioned.buffer(-size / 2)
2304
+ if not buffered_in.is_empty:
2305
+ bbox = to_gdf(buffered_in)
2306
+ else:
2307
+ bbox = to_gdf(unioned)
2308
+
2309
+ copied = self.copy()
2310
+ sampled_images = []
2311
+ while len(sampled_images) < n:
2312
+ mask = to_bbox(bbox.sample_points(1).buffer(size))
2313
+ images = copied.filter(bbox=mask).images
2314
+ random.shuffle(images)
2315
+ try:
2316
+ images = images[:n]
2317
+ except IndexError:
2318
+ pass
2319
+ sampled_images += images
2320
+ copied._images = sampled_images[:n]
2321
+ if copied._should_be_sorted:
2322
+ copied._images = list(sorted(copied._images))
2323
+
2324
+ return copied
2325
+
2326
+ def sample_tiles(self, n: int) -> "ImageCollection":
2327
+ """Sample one or more tiles in a copy of the ImageCollection."""
2328
+ copied = self.copy()
2329
+ sampled_tiles = list({img.tile for img in self})
2330
+ random.shuffle(sampled_tiles)
2331
+ sampled_tiles = sampled_tiles[:n]
2332
+
2333
+ copied.images = [image for image in self if image.tile in sampled_tiles]
2334
+ return copied
2335
+
2336
+ def sample_images(self, n: int) -> "ImageCollection":
2337
+ """Sample one or more images in a copy of the ImageCollection."""
2338
+ copied = self.copy()
2339
+ images = copied.images
2340
+ if n > len(images):
2341
+ raise ValueError(
2342
+ f"n ({n}) is higher than number of images in collection ({len(images)})"
2343
+ )
2344
+ sample = []
2345
+ for _ in range(n):
2346
+ random.shuffle(images)
2347
+ img = images.pop()
2348
+ sample.append(img)
2349
+
2350
+ copied.images = sample
2351
+
2352
+ return copied
2353
+
2354
+ def __or__(self, collection: "ImageCollection") -> "ImageCollection":
2355
+ """Concatenate the collection with another collection."""
2356
+ return concat_image_collections([self, collection])
2357
+
2358
+ def __iter__(self) -> Iterator[Image]:
2359
+ """Iterate over the images."""
2360
+ return iter(self.images)
2361
+
2362
+ def __len__(self) -> int:
2363
+ """Number of images."""
2364
+ return len(self.images)
2365
+
2366
+ def __getitem__(
2367
+ self,
2368
+ item: int | slice | Sequence[int | bool] | BoundingBox | Sequence[BoundingBox],
2369
+ ) -> Image | TORCHGEO_RETURN_TYPE:
2370
+ """Select one Image by integer index, or multiple Images by slice, list of int or torchgeo.BoundingBox."""
2371
+ if isinstance(item, int):
2372
+ return self.images[item]
2373
+
2374
+ if isinstance(item, slice):
2375
+ copied = self.copy()
2376
+ copied.images = copied.images[item]
2377
+ return copied
2378
+
2379
+ if isinstance(item, ImageCollection):
2380
+
2381
+ def _get_from_single_element_list(lst: list[Any]) -> Any:
2382
+ if len(lst) != 1:
2383
+ raise ValueError(lst)
2384
+ return next(iter(lst))
2385
+
2386
+ copied = self.copy()
2387
+ copied._images = [
2388
+ _get_from_single_element_list(
2389
+ [img2 for img2 in copied if img2.stem in img.path]
2390
+ )
2391
+ for img in item
2392
+ ]
2393
+ return copied
2394
+
2395
+ if not isinstance(item, BoundingBox) and not (
2396
+ isinstance(item, Iterable)
2397
+ and len(item)
2398
+ and all(isinstance(x, BoundingBox) for x in item)
2399
+ ):
2400
+ copied = self.copy()
2401
+ if callable(item):
2402
+ item = [item(img) for img in copied]
2403
+
2404
+ # check for base bool and numpy bool
2405
+ if all("bool" in str(type(x)) for x in item):
2406
+ copied.images = [img for x, img in zip(item, copied, strict=True) if x]
2407
+
2408
+ else:
2409
+ copied.images = [copied.images[i] for i in item]
2410
+ return copied
2411
+
2412
+ if isinstance(item, BoundingBox):
2413
+ date_ranges: tuple[float, float] = (item.mint, item.maxt)
2414
+ data: torch.Tensor = numpy_to_torch(
2415
+ np.array(
2416
+ [
2417
+ band.values
2418
+ for band in self.filter(
2419
+ bbox=item, date_ranges=date_ranges
2420
+ ).merge_by_band(bounds=item)
2421
+ ]
2422
+ )
2423
+ )
2424
+ else:
2425
+ bboxes: list[Polygon] = [to_bbox(x) for x in item]
2426
+ date_ranges: list[tuple[float, float]] = [(x.mint, x.maxt) for x in item]
2427
+ data: torch.Tensor = torch.cat(
2428
+ [
2429
+ numpy_to_torch(
2430
+ np.array(
2431
+ [
2432
+ band.values
2433
+ for band in self.filter(
2434
+ bbox=bbox, date_ranges=date_range
2435
+ ).merge_by_band(bounds=bbox)
2436
+ ]
2437
+ )
2438
+ )
2439
+ for bbox, date_range in zip(bboxes, date_ranges, strict=True)
2440
+ ]
2441
+ )
2442
+
2443
+ crs = get_common_crs(self.images)
2444
+
2445
+ key = "image" # if self.is_image else "mask"
2446
+ sample = {key: data, "crs": crs, "bbox": item}
2447
+
2448
+ return sample
2449
+
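+ # Hypothetical usage sketch: integer indexing returns a single Image, while
+ # slices and sequences of int/bool return a filtered copy; a torchgeo
+ # BoundingBox instead returns a dict with "image", "crs" and "bbox".
+ #
+ #     first_image = collection[0]
+ #     last_two = collection[-2:]
+ #     first_and_third = collection[[0, 2]]
+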
2450
+ @property
2451
+ def mint(self) -> float:
2452
+ """Min timestamp of the images combined."""
2453
+ return min(img.mint for img in self)
2454
+
2455
+ @property
2456
+ def maxt(self) -> float:
2457
+ """Max timestamp of the images combined."""
2458
+ return max(img.maxt for img in self)
2459
+
2460
+ @property
2461
+ def band_ids(self) -> list[str]:
2462
+ """Sorted list of unique band_ids."""
2463
+ return list(sorted({band.band_id for img in self for band in img}))
2464
+
2465
+ @property
2466
+ def file_paths(self) -> list[str]:
2467
+ """Sorted list of all file paths, meaning all band paths."""
2468
+ return list(sorted({band.path for img in self for band in img}))
2469
+
2470
+ @property
2471
+ def dates(self) -> list[str]:
2472
+ """List of image dates."""
2473
+ return [img.date for img in self]
2474
+
2475
+ def dates_as_int(self) -> list[int]:
2476
+ """List of image dates as 8-length integers."""
2477
+ return [int(img.date[:8]) for img in self]
2478
+
2479
+ @property
2480
+ def image_paths(self) -> list[str]:
2481
+ """List of image paths."""
2482
+ return [img.path for img in self]
2483
+
2484
+ @property
2485
+ def images(self) -> list["Image"]:
2486
+ """List of images in the Collection."""
2487
+ if self._images is not None:
2488
+ return self._images
2489
+ # only fetch images when they are needed
2490
+ self._images = _get_images(
2491
+ list(self._df["image_path"]),
2492
+ all_file_paths=self._all_file_paths,
2493
+ df=self._df,
2494
+ image_class=self.image_class,
2495
+ band_class=self.band_class,
2496
+ masking=self.masking,
2497
+ **self._common_init_kwargs,
2498
+ )
2499
+ if self.masking is not None:
2500
+ images = []
2501
+ for image in self._images:
2502
+ try:
2503
+ if not isinstance(image.mask, Band):
2504
+ raise ValueError()
2505
+ images.append(image)
2506
+ except ValueError:
2507
+ continue
2508
+ self._images = images
2509
+ for image in self._images:
2510
+ image._bands = [band for band in image if band.band_id is not None]
2511
+
2512
+ if self.metadata is not None:
2513
+ for img in self:
2514
+ for band in img:
2515
+ for key in ["crs", "bounds"]:
2516
+ try:
2517
+ value = self.metadata[band.path][key]
2518
+ except KeyError:
2519
+ value = self.metadata[key][band.path]
2520
+ setattr(band, f"_{key}", value)
2521
+
2522
+ self._images = [img for img in self if len(img)]
2523
+
2524
+ if self._should_be_sorted:
2525
+ self._images = list(sorted(self._images))
2526
+
2527
+ return self._images
2528
+
2529
+ @property
2530
+ def _should_be_sorted(self) -> bool:
2531
+ """True if the ImageCollection has regexes that should make it sortable by date."""
2532
+ sort_group = "date"
2533
+ return (
2534
+ self.filename_patterns
2535
+ and any(
2536
+ sort_group in pat.groupindex
2537
+ and sort_group in _get_non_optional_groups(pat)
2538
+ for pat in self.filename_patterns
2539
+ )
2540
+ or self.image_patterns
2541
+ and any(
2542
+ sort_group in pat.groupindex
2543
+ and sort_group in _get_non_optional_groups(pat)
2544
+ for pat in self.image_patterns
2545
+ )
2546
+ or all(img.date is not None for img in self)
2547
+ )
2548
+
2549
+ @images.setter
2550
+ def images(self, new_value: list["Image"]) -> list["Image"]:
2551
+ self._images = list(new_value)
2552
+ if not all(isinstance(x, Image) for x in self._images):
2553
+ raise TypeError("images should be a sequence of Image.")
2554
+
2555
+ @property
2556
+ def index(self) -> Index:
2557
+ """Spatial index that makes torchgeo think this class is a RasterDataset."""
2558
+ try:
2559
+ if len(self) == len(self._index):
2560
+ return self._index
2561
+ except AttributeError:
2562
+ self._index = Index(interleaved=False, properties=Property(dimension=3))
2563
+
2564
+ for i, img in enumerate(self.images):
2565
+ if img.date:
2566
+ try:
2567
+ mint, maxt = disambiguate_timestamp(img.date, self.date_format)
2568
+ except (NameError, TypeError):
2569
+ mint, maxt = 0, 1
2570
+ else:
2571
+ mint, maxt = 0, 1
2572
+ # important: torchgeo has a different order of the bbox than shapely and geopandas
2573
+ minx, miny, maxx, maxy = img.bounds
2574
+ self._index.insert(i, (minx, maxx, miny, maxy, mint, maxt))
2575
+ return self._index
2576
+
2577
+ def __repr__(self) -> str:
2578
+ """String representation."""
2579
+ return f"{self.__class__.__name__}({len(self)}, path='{self.path}')"
2580
+
2581
+ def union_all(self) -> Polygon | MultiPolygon:
2582
+ """(Multi)Polygon representing the union of all image bounds."""
2583
+ return unary_union([img.union_all() for img in self])
2584
+
2585
+ @property
2586
+ def bounds(self) -> tuple[int, int, int, int]:
2587
+ """Total bounds for all Images combined."""
2588
+ return get_total_bounds([img.bounds for img in self])
2589
+
2590
+ @property
2591
+ def crs(self) -> Any:
2592
+ """Common coordinate reference system of the Images."""
2593
+ if self._crs is not None:
2594
+ return self._crs
2595
+ self._crs = get_common_crs([img.crs for img in self])
2596
+ return self._crs
2597
+
2598
+ def plot_pixels(
2599
+ self,
2600
+ by: str | list[str] | None = None,
2601
+ x_var: str = "date",
2602
+ y_label: str = "value",
2603
+ p: float = 0.95,
2604
+ ylim: tuple[float, float] | None = None,
2605
+ figsize: tuple[int] = (20, 8),
2606
+ ) -> None:
2607
+ """Plot each individual pixel in a dotplot for all dates.
2608
+
2609
+ Args:
2610
+ by: Band attributes to groupby. Defaults to "bounds" and "band_id"
2611
+ if all bands have non-None band_ids, otherwise defaults to "bounds".
2612
+ x_var: Attribute to use on the x-axis. Defaults to "date"
2613
+ if the ImageCollection is sortable by date, otherwise a range index.
2614
+ Can be set to "days_since_start".
2615
+ y_label: Label to use on the y-axis.
2616
+ p: Confidence level for the confidence interval.
2617
+ ylim: Limits of the y-axis.
2618
+ figsize: Figure size as tuple (width, height).
2619
+
2620
+ """
2621
+ if by is None and all(band.band_id is not None for img in self for band in img):
2622
+ by = ["bounds", "band_id"]
2623
+ elif by is None:
2624
+ by = ["bounds"]
2625
+
2626
+ alpha = 1 - p
2627
+
2628
+ for img in self:
2629
+ for band in img:
2630
+ band.load()
2631
+
2632
+ for group_values, subcollection in self.groupby(by):
2633
+ print("group_values:", *group_values)
2634
+
2635
+ y = np.array([band.values for img in subcollection for band in img])
2636
+ if "date" in x_var and subcollection._should_be_sorted:
2637
+ x = np.array(
2638
+ [
2639
+ datetime.datetime.strptime(band.date[:8], "%Y%m%d").date()
2640
+ for img in subcollection
2641
+ for band in img
2642
+ ]
2643
+ )
2644
+ x = (
2645
+ pd.to_datetime(
2646
+ [band.date[:8] for img in subcollection for band in img]
2647
+ )
2648
+ - pd.Timestamp(np.min(x))
2649
+ ).days
2650
+ else:
2651
+ x = np.arange(0, len(y))
2652
+
2653
+ mask = np.array(
2654
+ [
2655
+ (
2656
+ band.values.mask
2657
+ if hasattr(band.values, "mask")
2658
+ else np.full(band.values.shape, False)
2659
+ )
2660
+ for img in subcollection
2661
+ for band in img
2662
+ ]
2663
+ )
2664
+
2665
+ if x_var == "days_since_start":
2666
+ x = x - np.min(x)
2667
+
2668
+ for i in range(y.shape[1]):
2669
+ for j in range(y.shape[2]):
2670
+ this_y = y[:, i, j]
2671
+
2672
+ this_mask = mask[:, i, j]
2673
+ this_x = x[~this_mask]
2674
+ this_y = this_y[~this_mask]
2675
+
2676
+ if ylim:
2677
+ condition = (this_y >= ylim[0]) & (this_y <= ylim[1])
2678
+ this_y = this_y[condition]
2679
+ this_x = this_x[condition]
2680
+
2681
+ coef, intercept = np.linalg.lstsq(
2682
+ np.vstack([this_x, np.ones(this_x.shape[0])]).T,
2683
+ this_y,
2684
+ rcond=None,
2685
+ )[0]
2686
+ predicted = np.array([intercept + coef * x for x in this_x])
2687
+
2688
+ # Degrees of freedom
2689
+ dof = len(this_x) - 2
2690
+
2691
+ # critical t-value for the chosen confidence level
2692
+ t_val = stats.t.ppf(1 - alpha / 2, dof)
2693
+
2694
+ # Mean squared error of the residuals
2695
+ mse = np.sum((this_y - predicted) ** 2) / dof
2696
+
2697
+ # Calculate the standard error of predictions
2698
+ pred_stderr = np.sqrt(
2699
+ mse
2700
+ * (
2701
+ 1 / len(this_x)
2702
+ + (this_x - np.mean(this_x)) ** 2
2703
+ / np.sum((this_x - np.mean(this_x)) ** 2)
2704
+ )
2705
+ )
2706
+
2707
+ # Calculate the confidence interval for predictions
2708
+ ci_lower = predicted - t_val * pred_stderr
2709
+ ci_upper = predicted + t_val * pred_stderr
2710
+
2711
+ rounding = int(np.log(1 / abs(coef)))
2712
+
2713
+ fig = plt.figure(figsize=figsize)
2714
+ ax = fig.add_subplot(1, 1, 1)
2715
+
2716
+ ax.scatter(this_x, this_y, color="#2c93db")
2717
+ ax.plot(this_x, predicted, color="#e0436b")
2718
+ ax.fill_between(
2719
+ this_x,
2720
+ ci_lower,
2721
+ ci_upper,
2722
+ color="#e0436b",
2723
+ alpha=0.2,
2724
+ label=f"{int(alpha*100)}% CI",
2725
+ )
2726
+ plt.title(f"Coefficient: {round(coef, rounding)}")
2727
+ plt.xlabel(x_var)
2728
+ plt.ylabel(y_label)
2729
+ plt.show()
2730
+
2731
+
2732
+ def concat_image_collections(collections: Sequence[ImageCollection]) -> ImageCollection:
2733
+ """Union multiple ImageCollections together.
2734
+
2735
+ Same as using the union operator |.
2736
+ """
2737
+ resolutions = {x.res for x in collections}
2738
+ if len(resolutions) > 1:
2739
+ raise ValueError(f"resoultion mismatch. {resolutions}")
2740
+ images = list(itertools.chain.from_iterable([x.images for x in collections]))
2741
+ levels = {x.level for x in collections}
2742
+ level = next(iter(levels)) if len(levels) == 1 else None
2743
+ first_collection = collections[0]
2744
+
2745
+ out_collection = first_collection.__class__(
2746
+ images,
2747
+ level=level,
2748
+ band_class=first_collection.band_class,
2749
+ image_class=first_collection.image_class,
2750
+ **first_collection._common_init_kwargs,
2751
+ )
2752
+ out_collection._all_file_paths = list(
2753
+ sorted(
2754
+ set(itertools.chain.from_iterable([x._all_file_paths for x in collections]))
2755
+ )
2756
+ )
2757
+ return out_collection
2758
+
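+ # Hypothetical usage sketch: collections with the same resolution can be
+ # combined with this function or, equivalently, the | operator.
+ #
+ #     combined = concat_image_collections([collection_a, collection_b])
+ #     also_combined = collection_a | collection_b
+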
2759
+
2760
+ def _get_gradient(band: Band, degrees: bool = False, copy: bool = True) -> Band:
2761
+ copied = band.copy() if copy else band
2762
+ if len(copied.values.shape) == 3:
2763
+ return np.array(
2764
+ [_slope_2d(arr, copied.res, degrees=degrees) for arr in copied.values]
2765
+ )
2766
+ elif len(copied.values.shape) == 2:
2767
+ return _slope_2d(copied.values, copied.res, degrees=degrees)
2768
+ else:
2769
+ raise ValueError("array must be 2 or 3 dimensional")
2770
+
2771
+
2772
+ def to_xarray(
2773
+ array: np.ndarray, transform: Affine, crs: Any, name: str | None = None
2774
+ ) -> DataArray:
2775
+ """Convert the raster to an xarray.DataArray."""
2776
+ if len(array.shape) == 2:
2777
+ height, width = array.shape
2778
+ dims = ["y", "x"]
2779
+ elif len(array.shape) == 3:
2780
+ height, width = array.shape[1:]
2781
+ dims = ["band", "y", "x"]
2782
+ else:
2783
+ raise ValueError(f"Array should be 2 or 3 dimensional. Got shape {array.shape}")
2784
+
2785
+ coords = _generate_spatial_coords(transform, width, height)
2786
+ return xr.DataArray(
2787
+ array,
2788
+ coords=coords,
2789
+ dims=dims,
2790
+ name=name,
2791
+ attrs={"crs": crs},
2792
+ )
2793
+
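+ # Hypothetical usage sketch: a small array wrapped with an Affine transform
+ # built from bounds becomes a georeferenced DataArray. Bounds, shape and crs
+ # are assumptions for illustration.
+ #
+ #     arr = np.zeros((10, 10))
+ #     transform = rasterio.transform.from_bounds(0, 0, 100, 100, 10, 10)
+ #     da = to_xarray(arr, transform=transform, crs=25833, name="example")
+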
2794
+
2795
+ def _slope_2d(array: np.ndarray, res: int, degrees: bool) -> np.ndarray:
2796
+ gradient_x, gradient_y = np.gradient(array, res, res)
2797
+
2798
+ gradient = abs(gradient_x) + abs(gradient_y)
2799
+
2800
+ if not degrees:
2801
+ return gradient
2802
+
2803
+ radians = np.arctan(gradient)
2804
+ degrees = np.degrees(radians)
2805
+
2806
+ assert np.max(degrees) <= 90
2807
+
2808
+ return degrees
2809
+
2810
+
2811
+ def _clip_loaded_array(
2812
+ arr: np.ndarray,
2813
+ bounds: tuple[int, int, int, int],
2814
+ transform: Affine,
2815
+ crs: Any,
2816
+ out_shape: tuple[int, int],
2817
+ **kwargs,
2818
+ ) -> np.ndarray:
2819
+ # xarray needs a numpy array of polygon(s)
2820
+ bounds_arr: np.ndarray = GeoSeries([to_shapely(bounds)]).values
2821
+ try:
2822
+
2823
+ while out_shape != arr.shape:
2824
+ arr = (
2825
+ to_xarray(
2826
+ arr,
2827
+ transform=transform,
2828
+ crs=crs,
2829
+ )
2830
+ .rio.clip(bounds_arr, crs=crs, **kwargs)
2831
+ .to_numpy()
2832
+ )
2833
+ # bounds_arr = bounds_arr.buffer(0.0000001)
2834
+ return arr
2835
+
2836
+ except NoDataInBounds:
2837
+ return np.array([])
2838
+
2839
+
2840
+ def _get_images(
2841
+ image_paths: list[str],
2842
+ *,
2843
+ all_file_paths: list[str],
2844
+ df: pd.DataFrame,
2845
+ processes: int,
2846
+ image_class: type,
2847
+ band_class: type,
2848
+ bbox: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None,
2849
+ masking: BandMasking | None,
2850
+ **kwargs,
2851
+ ) -> list[Image]:
2852
+
2853
+ with joblib.Parallel(n_jobs=processes, backend="loky") as parallel:
2854
+ images = parallel(
2855
+ joblib.delayed(image_class)(
2856
+ path,
2857
+ df=df,
2858
+ all_file_paths=all_file_paths,
2859
+ masking=masking,
2860
+ band_class=band_class,
2861
+ **kwargs,
2862
+ )
2863
+ for path in image_paths
2864
+ )
2865
+ if bbox is not None:
2866
+ intersects_list = GeoSeries([img.union_all() for img in images]).intersects(
2867
+ to_shapely(bbox)
2868
+ )
2869
+ return [
2870
+ img
2871
+ for img, intersects in zip(images, intersects_list, strict=False)
2872
+ if intersects
2873
+ ]
2874
+ return images
2875
+
2876
+
2877
+ def numpy_to_torch(array: np.ndarray) -> torch.Tensor:
2878
+ """Convert numpy array to a pytorch tensor."""
2879
+ # fix numpy dtypes which are not supported by pytorch tensors
2880
+ if array.dtype == np.uint16:
2881
+ array = array.astype(np.int32)
2882
+ elif array.dtype == np.uint32:
2883
+ array = array.astype(np.int64)
2884
+
2885
+ return torch.tensor(array)
2886
+
2887
+
2888
+ class _RegexError(ValueError):
2889
+ pass
2890
+
2891
+
2892
+ class ArrayNotLoadedError(ValueError):
2893
+ """Arrays are not loaded."""
2894
+
2895
+
2896
+ class PathlessImageError(ValueError):
2897
+ """'path' attribute is needed but instance has no path."""
2898
+
2899
+ def __init__(self, instance: _ImageBase) -> None:
2900
+ """Initialise error class."""
2901
+ self.instance = instance
2902
+
2903
+ def __str__(self) -> str:
2904
+ """String representation."""
2905
+ if self.instance._merged:
2906
+ what = "that have been merged"
2907
+ elif self.instance._from_array:
2908
+ what = "from arrays"
2909
+ elif self.instance._from_gdf:
2910
+ what = "from GeoDataFrames"
2911
+
2912
+ return (
2913
+ f"{self.instance.__class__.__name__} instances {what} "
2914
+ "have no 'path' until they are written to file."
2915
+ )
2916
+
2917
+
2918
+ def _get_regex_match_from_xml_in_local_dir(
2919
+ paths: list[str], regexes: str | tuple[str]
2920
+ ) -> str | dict[str, str]:
2921
+ for i, path in enumerate(paths):
2922
+ if ".xml" not in path:
2923
+ continue
2924
+ with _open_func(path, "rb") as file:
2925
+ filebytes: bytes = file.read()
2926
+ try:
2927
+ return _extract_regex_match_from_string(
2928
+ filebytes.decode("utf-8"), regexes
2929
+ )
2930
+ except _RegexError as e:
2931
+ if i == len(paths) - 1:
2932
+ raise e
2933
+
2934
+
2935
+ def _extract_regex_match_from_string(
2936
+ xml_file: str, regexes: tuple[str | re.Pattern]
2937
+ ) -> str | dict[str, str]:
2938
+ if all(isinstance(x, str) for x in regexes):
2939
+ for regex in regexes:
2940
+ try:
2941
+ return re.search(regex, xml_file).group(1)
2942
+ except (TypeError, AttributeError):
2943
+ continue
2944
+ raise _RegexError()
2945
+
2946
+ out = {}
2947
+ for regex in regexes:
2948
+ try:
2949
+ matches = re.search(regex, xml_file)
2950
+ out |= matches.groupdict()
2951
+ except (TypeError, AttributeError):
2952
+ continue
2953
+ if not out:
2954
+ raise _RegexError()
2955
+ return out
2956
+
2957
+
2958
+ def _fix_path(path: str) -> str:
2959
+ return (
2960
+ str(path).replace("\\", "/").replace(r"\"", "/").replace("//", "/").rstrip("/")
2961
+ )
2962
+
2963
+
2964
+ def _get_regexes_matches_for_df(
2965
+ df, match_col: str, patterns: Sequence[re.Pattern]
2966
+ ) -> pd.DataFrame:
2967
+ if not len(df):
2968
+ return df
2969
+
2970
+ non_optional_groups = list(
2971
+ set(
2972
+ itertools.chain.from_iterable(
2973
+ [_get_non_optional_groups(pat) for pat in patterns]
2974
+ )
2975
+ )
2976
+ )
2977
+
2978
+ if not non_optional_groups:
2979
+ return df
2980
+
2981
+ assert df.index.is_unique
2982
+ keep = []
2983
+ for pat in patterns:
2984
+ for i, row in df[match_col].items():
2985
+ matches = _get_first_group_match(pat, row)
2986
+ if all(group in matches for group in non_optional_groups):
2987
+ keep.append(i)
2988
+
2989
+ return df.loc[keep]
2990
+
2991
+
2992
+ def _get_non_optional_groups(pat: re.Pattern | str) -> list[str]:
2993
+ return [
2994
+ x
2995
+ for x in [
2996
+ _extract_group_name(group)
2997
+ for group in pat.pattern.split("\n")
2998
+ if group
2999
+ and not group.replace(" ", "").startswith("#")
3000
+ and not group.replace(" ", "").split("#")[0].endswith("?")
3001
+ ]
3002
+ if x is not None
3003
+ ]
3004
+
3005
+
3006
+ def _extract_group_name(txt: str) -> str | None:
3007
+ try:
3008
+ return re.search(r"\(\?P<(\w+)>", txt)[1]
3009
+ except TypeError:
3010
+ return None
3011
+
3012
+
3013
+ def _get_first_group_match(pat: re.Pattern, text: str) -> dict[str, str]:
3014
+ groups = pat.groupindex.keys()
3015
+ all_matches: dict[str, str] = {}
3016
+ for x in pat.findall(text):
3017
+ for group, value in zip(groups, x, strict=True):
3018
+ if value and group not in all_matches:
3019
+ all_matches[group] = value
3020
+ return all_matches
3021
+
3022
+
3023
+ def _date_is_within(
3024
+ path,
3025
+ date_ranges: (
3026
+ tuple[str | None, str | None] | tuple[tuple[str | None, str | None], ...] | None
3027
+ ),
3028
+ image_patterns: Sequence[re.Pattern],
3029
+ date_format: str,
3030
+ ) -> bool:
3031
+ for pat in image_patterns:
3032
+
3033
+ try:
3034
+ date = _get_first_group_match(pat, Path(path).name)["date"]
3035
+ break
3036
+ except KeyError:
3037
+ date = None
3038
+
3039
+ if date is None:
3040
+ return False
3041
+
3042
+ if date_ranges is None:
3043
+ return True
3044
+
3045
+ if all(x is None or isinstance(x, (str, float)) for x in date_ranges):
3046
+ date_ranges = (date_ranges,)
3047
+
3048
+ if all(isinstance(x, float) for date_range in date_ranges for x in date_range):
3049
+ date = disambiguate_timestamp(date, date_format)
3050
+ else:
3051
+ date = date[:8]
3052
+
3053
+ for date_range in date_ranges:
3054
+ date_min, date_max = date_range
3055
+
3056
+ if isinstance(date_min, float) and isinstance(date_max, float):
3057
+ if date[0] >= date_min + 0.0000001 and date[1] <= date_max - 0.0000001:
3058
+ return True
3059
+ continue
3060
+
3061
+ try:
3062
+ date_min = date_min or "00000000"
3063
+ date_max = date_max or "99999999"
3064
+ if not (
3065
+ isinstance(date_min, str)
3066
+ and len(date_min) == 8
3067
+ and isinstance(date_max, str)
3068
+ and len(date_max) == 8
3069
+ ):
3070
+ raise ValueError()
3071
+ except ValueError as err:
3072
+ raise TypeError(
3073
+ "date_ranges should be a tuple of two 8-charactered strings (start and end date)."
3074
+ f"Got {date_range} of type {[type(x) for x in date_range]}"
3075
+ ) from err
3076
+ if date >= date_min and date <= date_max:
3077
+ return True
3078
+
3079
+ return False
3080
+
3081
+
3082
+ def _get_dtype_min(dtype: str | type) -> int | float:
3083
+ try:
3084
+ return np.iinfo(dtype).min
3085
+ except ValueError:
3086
+ return np.finfo(dtype).min
3087
+
3088
+
3089
+ def _get_dtype_max(dtype: str | type) -> int | float:
3090
+ try:
3091
+ return np.iinfo(dtype).max
3092
+ except ValueError:
3093
+ return np.finfo(dtype).max
3094
+
3095
+
3096
+ def _img_ndvi(img, **kwargs):
3097
+ return Image([img.ndvi(**kwargs)])
3098
+
3099
+
3100
+ def _intesects(x, other) -> bool:
3101
+ return box(*x.bounds).intersects(other)
3102
+
3103
+
3104
+ def _copy_and_add_df_parallel(
3105
+ i: tuple[Any, ...], group: pd.DataFrame, self: ImageCollection
3106
+ ) -> tuple[tuple[Any], ImageCollection]:
3107
+ copied = self.copy()
3108
+ copied.images = [
3109
+ img.copy() for img in group.drop_duplicates("_image_idx")["_image_instance"]
3110
+ ]
3111
+ if "band_id" in group:
3112
+ band_ids = set(group["band_id"].values)
3113
+ for img in copied.images:
3114
+ img._bands = [band for band in img if band.band_id in band_ids]
3115
+
3116
+ return (i, copied)
3117
+
3118
+
3119
+ def _get_single_value(values: tuple):
3120
+ if len(set(values)) == 1:
3121
+ return next(iter(values))
3122
+ else:
3123
+ return None
3124
+
3125
+
3126
+ def _open_raster(path: str | Path) -> rasterio.io.DatasetReader:
3127
+ with opener(path) as file:
3128
+ return rasterio.open(file)
3129
+
3130
+
3131
+ def _load_band(band: Band, **kwargs) -> None:
3132
+ band.load(**kwargs)
3133
+
3134
+
3135
+ def _merge_by_band(collection: ImageCollection, **kwargs) -> Image:
3136
+ return collection.merge_by_band(**kwargs)
3137
+
3138
+
3139
+ def _merge(collection: ImageCollection, **kwargs) -> Band:
3140
+ return collection.merge(**kwargs)
3141
+
3142
+
3143
+ def _zonal_one_pair(i: int, poly: Polygon, band: Band, aggfunc, array_func, func_names):
3144
+ clipped = band.copy().load(bounds=poly)
3145
+ if not np.size(clipped.values):
3146
+ return _no_overlap_df(func_names, i, date=band.date)
3147
+ return _aggregate(clipped.values, array_func, aggfunc, func_names, band.date, i)
3148
+
3149
+
3150
+ def array_buffer(arr: np.ndarray, distance: int) -> np.ndarray:
3151
+ """Buffer array points with the value 1 in a binary array.
3152
+
3153
+ Args:
3154
+ arr: The array.
3155
+ distance: Number of array cells to buffer by.
3156
+
3157
+ Returns:
3158
+ Array with buffered values.
3159
+ """
3160
+ if not np.all(np.isin(arr, (1, 0, True, False))):
3161
+ raise ValueError("Array must be all 0s and 1s or boolean.")
3162
+
3163
+ dtype = arr.dtype
3164
+
3165
+ structure = np.ones((2 * abs(distance) + 1, 2 * abs(distance) + 1))
3166
+
3167
+ arr = np.where(arr, 1, 0)
3168
+
3169
+ if distance > 0:
3170
+ return binary_dilation(arr, structure=structure).astype(dtype)
3171
+ elif distance < 0:
3172
+
3173
+ return binary_erosion(arr, structure=structure).astype(dtype)
3174
+
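+ # Hypothetical usage sketch: buffering a single True cell by 1 dilates it to
+ # a 3x3 block; a negative distance erodes it back.
+ #
+ #     arr = np.zeros((5, 5), dtype=bool)
+ #     arr[2, 2] = True
+ #     dilated = array_buffer(arr, 1)
+ #     eroded = array_buffer(dilated, -1)
+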
3175
+
3176
+ class Sentinel2Config:
3177
+ """Holder of Sentinel 2 regexes, band_ids etc."""
3178
+
3179
+ image_regexes: ClassVar[str] = (config.SENTINEL2_IMAGE_REGEX,)
3180
+ filename_regexes: ClassVar[str] = (
3181
+ config.SENTINEL2_FILENAME_REGEX,
3182
+ config.SENTINEL2_CLOUD_FILENAME_REGEX,
3183
+ )
3184
+ all_bands: ClassVar[list[str]] = list(config.SENTINEL2_BANDS)
3185
+ rbg_bands: ClassVar[list[str]] = config.SENTINEL2_RBG_BANDS
3186
+ ndvi_bands: ClassVar[list[str]] = config.SENTINEL2_NDVI_BANDS
3187
+ l2a_bands: ClassVar[dict[str, int]] = config.SENTINEL2_L2A_BANDS
3188
+ l1c_bands: ClassVar[dict[str, int]] = config.SENTINEL2_L1C_BANDS
3189
+ date_format: ClassVar[str] = "%Y%m%d" # T%H%M%S"
3190
+ masking: ClassVar[BandMasking] = BandMasking(
3191
+ band_id="SCL", values=(3, 8, 9, 10, 11)
3192
+ )
3193
+
3194
+
3195
+ class Sentinel2CloudlessConfig(Sentinel2Config):
3196
+ """Holder of regexes, band_ids etc. for Sentinel 2 cloudless mosaic."""
3197
+
3198
+ image_regexes: ClassVar[str] = (config.SENTINEL2_MOSAIC_IMAGE_REGEX,)
3199
+ filename_regexes: ClassVar[str] = (config.SENTINEL2_MOSAIC_FILENAME_REGEX,)
3200
+ masking: ClassVar[None] = None
3201
+ date_format: ClassVar[str] = "%Y%m%d"
3202
+ all_bands: ClassVar[list[str]] = [
3203
+ x.replace("B0", "B") for x in Sentinel2Config.all_bands
3204
+ ]
3205
+ rbg_bands: ClassVar[list[str]] = [
3206
+ x.replace("B0", "B") for x in Sentinel2Config.rbg_bands
3207
+ ]
3208
+ ndvi_bands: ClassVar[list[str]] = [
3209
+ x.replace("B0", "B") for x in Sentinel2Config.ndvi_bands
3210
+ ]
3211
+
3212
+
3213
+ class Sentinel2Band(Sentinel2Config, Band):
3214
+ """Band with Sentinel2 specific name variables and regexes."""
3215
+
3216
+
3217
+ class Sentinel2Image(Sentinel2Config, Image):
3218
+ """Image with Sentinel2 specific name variables and regexes."""
3219
+
3220
+ cloud_cover_regexes: ClassVar[tuple[str]] = config.CLOUD_COVERAGE_REGEXES
3221
+ band_class: ClassVar[Sentinel2Band] = Sentinel2Band
3222
+
3223
+ def ndvi(
3224
+ self,
3225
+ red_band: str = Sentinel2Config.ndvi_bands[0],
3226
+ nir_band: str = Sentinel2Config.ndvi_bands[1],
3227
+ copy: bool = True,
3228
+ ) -> NDVIBand:
3229
+ """Calculate the NDVI for the Image."""
3230
+ return super().ndvi(red_band=red_band, nir_band=nir_band, copy=copy)
3231
+
3232
+
3233
+ class Sentinel2Collection(Sentinel2Config, ImageCollection):
3234
+ """ImageCollection with Sentinel2 specific name variables and regexes."""
3235
+
3236
+ image_class: ClassVar[Sentinel2Image] = Sentinel2Image
3237
+ band_class: ClassVar[Sentinel2Band] = Sentinel2Band
3238
+
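+ # Hypothetical usage sketch: the Sentinel-2 classes only add regexes, band
+ # ids and masking on top of ImageCollection. The path and level string are
+ # assumptions for illustration.
+ #
+ #     collection = Sentinel2Collection("/path/to/sentinel2", res=10, level="L2A")
+ #     ndvi = collection[0].ndvi()
+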
3239
+
3240
+ class Sentinel2CloudlessBand(Sentinel2CloudlessConfig, Band):
3241
+ """Band for cloudless mosaic with Sentinel2 specific name variables and regexes."""
3242
+
3243
+
3244
+ class Sentinel2CloudlessImage(Sentinel2CloudlessConfig, Sentinel2Image):
3245
+ """Image for cloudless mosaic with Sentinel2 specific name variables and regexes."""
3246
+
3247
+ cloud_cover_regexes: ClassVar[None] = None
3248
+ band_class: ClassVar[Sentinel2CloudlessBand] = Sentinel2CloudlessBand
3249
+
3250
+ ndvi = Sentinel2Image.ndvi
3251
+
3252
+
3253
+ class Sentinel2CloudlessCollection(Sentinel2CloudlessConfig, ImageCollection):
3254
+ """ImageCollection with Sentinel2 specific name variables and regexes."""
3255
+
3256
+ image_class: ClassVar[Sentinel2CloudlessImage] = Sentinel2CloudlessImage
3257
+ band_class: ClassVar[Sentinel2CloudlessBand] = Sentinel2CloudlessBand