ssb-sgis 1.0.2__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. sgis/__init__.py +10 -6
  2. sgis/exceptions.py +2 -2
  3. sgis/geopandas_tools/bounds.py +17 -15
  4. sgis/geopandas_tools/buffer_dissolve_explode.py +24 -5
  5. sgis/geopandas_tools/conversion.py +15 -6
  6. sgis/geopandas_tools/duplicates.py +2 -2
  7. sgis/geopandas_tools/general.py +9 -5
  8. sgis/geopandas_tools/geometry_types.py +3 -3
  9. sgis/geopandas_tools/neighbors.py +3 -3
  10. sgis/geopandas_tools/point_operations.py +2 -2
  11. sgis/geopandas_tools/polygon_operations.py +5 -5
  12. sgis/geopandas_tools/sfilter.py +3 -3
  13. sgis/helpers.py +3 -3
  14. sgis/io/read_parquet.py +1 -1
  15. sgis/maps/examine.py +16 -2
  16. sgis/maps/explore.py +370 -57
  17. sgis/maps/legend.py +164 -72
  18. sgis/maps/map.py +184 -90
  19. sgis/maps/maps.py +92 -90
  20. sgis/maps/thematicmap.py +236 -83
  21. sgis/networkanalysis/closing_network_holes.py +2 -2
  22. sgis/networkanalysis/cutting_lines.py +3 -3
  23. sgis/networkanalysis/directednetwork.py +1 -1
  24. sgis/networkanalysis/finding_isolated_networks.py +2 -2
  25. sgis/networkanalysis/networkanalysis.py +7 -7
  26. sgis/networkanalysis/networkanalysisrules.py +1 -1
  27. sgis/networkanalysis/traveling_salesman.py +1 -1
  28. sgis/parallel/parallel.py +39 -19
  29. sgis/raster/__init__.py +0 -6
  30. sgis/raster/cube.py +51 -5
  31. sgis/raster/image_collection.py +2560 -0
  32. sgis/raster/indices.py +14 -5
  33. sgis/raster/raster.py +131 -236
  34. sgis/raster/sentinel_config.py +104 -0
  35. sgis/raster/zonal.py +0 -1
  36. {ssb_sgis-1.0.2.dist-info → ssb_sgis-1.0.3.dist-info}/METADATA +1 -1
  37. ssb_sgis-1.0.3.dist-info/RECORD +61 -0
  38. sgis/raster/methods_as_functions.py +0 -0
  39. sgis/raster/torchgeo.py +0 -171
  40. ssb_sgis-1.0.2.dist-info/RECORD +0 -61
  41. {ssb_sgis-1.0.2.dist-info → ssb_sgis-1.0.3.dist-info}/LICENSE +0 -0
  42. {ssb_sgis-1.0.2.dist-info → ssb_sgis-1.0.3.dist-info}/WHEEL +0 -0
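The bulk of this release is the new sgis/raster/image_collection.py module (item 31 above), whose full diff follows. As a rough orientation before the listing, here is a minimal usage sketch based on the class signatures shown below; the directory path, resolution, level string and merge method are illustrative assumptions, not values taken from the package.

    from sgis.raster.image_collection import ImageCollection

    # A regex-configured subclass (like the Sentinel-2 classes set up via
    # sentinel_config) is assumed in practice; the base class ships only a
    # minimal filename regex without 'band'/'date'/'tile' groups.
    collection = ImageCollection("data/sentinel2", res=10, level="L2A")

    # Merge everything down to one raster per band id.
    composite = collection.merge_by_band(method="median")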
sgis/raster/image_collection.py (new file)
@@ -0,0 +1,2560 @@
1
+ import functools
2
+ import glob
3
+ import itertools
4
+ import numbers
5
+ import os
6
+ import random
7
+ import re
8
+ import warnings
9
+ from collections.abc import Callable
10
+ from collections.abc import Iterable
11
+ from collections.abc import Iterator
12
+ from collections.abc import Sequence
13
+ from copy import deepcopy
14
+ from json import loads
15
+ from pathlib import Path
16
+ from typing import Any
17
+ from typing import ClassVar
18
+
19
+ import joblib
20
+ import numpy as np
21
+ import pandas as pd
22
+ import pyproj
23
+ import rasterio
24
+ from affine import Affine
25
+ from geopandas import GeoDataFrame
26
+ from geopandas import GeoSeries
27
+ from rasterio import features
28
+ from rasterio.enums import MergeAlg
29
+ from rtree.index import Index
30
+ from rtree.index import Property
31
+ from shapely import Geometry
32
+ from shapely import box
33
+ from shapely import unary_union
34
+ from shapely.geometry import MultiPolygon
35
+ from shapely.geometry import Point
36
+ from shapely.geometry import Polygon
37
+ from shapely.geometry import shape
38
+
39
+ try:
40
+ import dapla as dp
41
+ from dapla.gcs import GCSFileSystem
42
+ except ImportError:
43
+
44
+ class GCSFileSystem:
45
+ """Placeholder."""
46
+
47
+
48
+ try:
49
+ from rioxarray.exceptions import NoDataInBounds
50
+ from rioxarray.merge import merge_arrays
51
+ from rioxarray.rioxarray import _generate_spatial_coords
52
+ except ImportError:
53
+ pass
54
+ try:
55
+ import xarray as xr
56
+ from xarray import DataArray
57
+ except ImportError:
58
+
59
+ class DataArray:
60
+ """Placeholder."""
61
+
62
+
63
+ try:
64
+ import torch
65
+ except ImportError:
66
+ pass
67
+
68
+ try:
69
+ from gcsfs.core import GCSFile
70
+ except ImportError:
71
+
72
+ class GCSFile:
73
+ """Placeholder."""
74
+
75
+
76
+ try:
77
+ from torchgeo.datasets.utils import disambiguate_timestamp
78
+ except ImportError:
79
+
80
+ class torch:
81
+ """Placeholder."""
82
+
83
+ class Tensor:
84
+ """Placeholder to reference torch.Tensor."""
85
+
86
+
87
+ try:
88
+ from torchgeo.datasets.utils import BoundingBox
89
+ except ImportError:
90
+
91
+ class BoundingBox:
92
+ """Placeholder."""
93
+
94
+ def __init__(self, *args, **kwargs) -> None:
95
+ """Placeholder."""
96
+ raise ImportError("missing optional dependency 'torchgeo'")
97
+
98
+
99
+ from ..geopandas_tools.bounds import get_total_bounds
100
+ from ..geopandas_tools.bounds import to_bbox
101
+ from ..geopandas_tools.conversion import to_gdf
102
+ from ..geopandas_tools.conversion import to_shapely
103
+ from ..geopandas_tools.general import get_common_crs
104
+ from ..helpers import get_all_files
105
+ from ..helpers import get_numpy_func
106
+ from ..io._is_dapla import is_dapla
107
+ from ..io.opener import opener
108
+ from . import sentinel_config as config
109
+ from .base import get_index_mapper
110
+ from .indices import ndvi
111
+ from .zonal import _aggregate
112
+ from .zonal import _make_geometry_iterrows
113
+ from .zonal import _no_overlap_df
114
+ from .zonal import _prepare_zonal
115
+ from .zonal import _zonal_post
116
+
117
+ if is_dapla():
118
+
119
+ def ls_func(*args, **kwargs) -> list[str]:
120
+ return dp.FileClient.get_gcs_file_system().ls(*args, **kwargs)
121
+
122
+ def glob_func(*args, **kwargs) -> list[str]:
123
+ return dp.FileClient.get_gcs_file_system().glob(*args, **kwargs)
124
+
125
+ def open_func(*args, **kwargs) -> GCSFile:
126
+ return dp.FileClient.get_gcs_file_system().open(*args, **kwargs)
127
+
128
+ else:
129
+ ls_func = functools.partial(get_all_files, recursive=False)
130
+ open_func = open
131
+ glob_func = glob.glob
132
+
133
+ TORCHGEO_RETURN_TYPE = dict[str, torch.Tensor | pyproj.CRS | BoundingBox]
134
+ FILENAME_COL_SUFFIX = "_filename"
135
+ DEFAULT_FILENAME_REGEX = r".*\.(?:tif|tiff|jp2)$"
136
+
137
+
138
+ class ImageCollectionGroupBy:
139
+ """Iterator and merger class returned from groupby.
140
+
141
+ Can be iterated over like pandas.DataFrameGroupBy,
142
+ or merged with the merge_by_band or merge methods.
143
+ """
144
+
145
+ def __init__(
146
+ self,
147
+ data: Iterable[tuple[tuple[Any, ...], "ImageCollection"]],
148
+ by: list[str],
149
+ collection: "ImageCollection",
150
+ ) -> None:
151
+ """Initialiser.
152
+
153
+ Args:
154
+ data: Iterable of group values and ImageCollection groups.
155
+ by: list of group attributes.
156
+ collection: ImageCollection instance. Used to pass attributes.
157
+ """
158
+ self.data = list(data)
159
+ self.by = by
160
+ self.collection = collection
161
+
162
+ def merge_by_band(
163
+ self, bounds=None, method="median", as_int: bool = True, indexes=None, **kwargs
164
+ ) -> "ImageCollection":
165
+ """Merge each group into separate Bands per band_id, returned as an ImageCollection."""
166
+ images = self._run_func_for_collection_groups(
167
+ _merge_by_band,
168
+ method=method,
169
+ bounds=bounds,
170
+ as_int=as_int,
171
+ indexes=indexes,
172
+ **kwargs,
173
+ )
175
+ for img, (group_values, _) in zip(images, self.data, strict=True):
176
+ for attr, group_value in zip(self.by, group_values, strict=True):
178
+ try:
179
+ setattr(img, attr, group_value)
180
+ except AttributeError:
181
+ setattr(img, f"_{attr}", group_value)
182
+
183
+ collection = ImageCollection(
184
+ images,
185
+ level=self.collection.level,
186
+ **self.collection._common_init_kwargs,
187
+ )
188
+ collection._merged = True
189
+ return collection
190
+
191
+ def merge(
192
+ self, bounds=None, method="median", as_int: bool = True, indexes=None, **kwargs
193
+ ) -> "Image":
194
+ """Merge each group into a single Band, returned as combined Image."""
195
+ bands: list[Band] = self._run_func_for_collection_groups(
196
+ _merge,
197
+ method=method,
198
+ bounds=bounds,
199
+ as_int=as_int,
200
+ indexes=indexes,
201
+ **kwargs,
202
+ )
203
+ for band, (group_values, _) in zip(bands, self.data, strict=True):
204
+ for attr, group_value in zip(self.by, group_values, strict=True):
205
+ try:
206
+ setattr(band, attr, group_value)
207
+ except AttributeError:
208
+ if hasattr(band, f"_{attr}"):
209
+ setattr(band, f"_{attr}", group_value)
210
+
211
+ if "band_id" in self.by:
212
+ for band in bands:
213
+ assert band.band_id is not None
214
+
215
+ image = Image(
216
+ bands,
217
+ **self.collection._common_init_kwargs,
218
+ )
219
+ image._merged = True
220
+ return image
221
+
222
+ def _run_func_for_collection_groups(self, func: Callable, **kwargs) -> list[Any]:
223
+ if self.collection.processes == 1:
224
+ return [func(group, **kwargs) for _, group in self]
225
+ processes = min(self.collection.processes, len(self))
226
+
227
+ if processes == 0:
228
+ return []
229
+
230
+ with joblib.Parallel(n_jobs=processes, backend="threading") as parallel:
231
+ return parallel(joblib.delayed(func)(group, **kwargs) for _, group in self)
232
+
233
+ def __iter__(self) -> Iterator[tuple[tuple[Any, ...], "ImageCollection"]]:
234
+ """Iterate over the group values and the ImageCollection groups themselves."""
235
+ return iter(self.data)
236
+
237
+ def __len__(self) -> int:
238
+ """Number of ImageCollection groups."""
239
+ return len(self.data)
240
+
241
+ def __repr__(self) -> str:
242
+ """String representation."""
243
+ return f"{self.__class__.__name__}({len(self)})"
244
+
245
+
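A minimal sketch of how ImageCollectionGroupBy is meant to be used, based on the groupby method defined on ImageCollection further down in this file. The path, grouping attributes and merge method are illustrative; grouping on 'tile' and 'date' assumes a subclass whose regexes define those groups, as the Sentinel-2 classes configured in sentinel_config do.

    from sgis.raster.image_collection import ImageCollection

    collection = ImageCollection("data/sentinel2", res=10, level="L2A")  # hypothetical path

    # Iterate like a pandas groupby: a tuple of group values, then the group itself.
    for (tile, date), group in collection.groupby(["tile", "date"]):
        print(tile, date, len(group))

    # Or merge each (tile, date) group into one Image, collected in an ImageCollection.
    merged = collection.groupby(["tile", "date"]).merge_by_band(method="median")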
246
+ class _ImageBase:
247
+ image_regexes: ClassVar[str | None] = None
248
+ filename_regexes: ClassVar[str | tuple[str]] = (DEFAULT_FILENAME_REGEX,)
249
+ date_format: ClassVar[str] = "%Y%m%d" # T%H%M%S"
250
+
251
+ def __init__(self) -> None:
252
+
253
+ self._merged = False
254
+ self._from_array = False
255
+ self._from_gdf = False
256
+
257
+ if self.filename_regexes:
258
+ if isinstance(self.filename_regexes, str):
259
+ self.filename_regexes = (self.filename_regexes,)
260
+ self.filename_patterns = [
261
+ re.compile(regexes, flags=re.VERBOSE)
262
+ for regexes in self.filename_regexes
263
+ ]
264
+ else:
265
+ self.filename_patterns = None
266
+
267
+ if self.image_regexes:
268
+ if isinstance(self.image_regexes, str):
269
+ self.image_regexes = (self.image_regexes,)
270
+ self.image_patterns = [
271
+ re.compile(regexes, flags=re.VERBOSE) for regexes in self.image_regexes
272
+ ]
273
+ else:
274
+ self.image_patterns = None
275
+
276
+ @property
277
+ def _common_init_kwargs(self) -> dict:
278
+ return {
279
+ "file_system": self.file_system,
280
+ "processes": self.processes,
281
+ "res": self.res,
282
+ "_mask": self._mask,
283
+ }
284
+
285
+ @property
286
+ def path(self) -> str:
287
+ try:
288
+ return self._path
289
+ except AttributeError as e:
290
+ raise PathlessImageError(self.__class__.__name__) from e
291
+
292
+ @property
293
+ def res(self) -> int:
294
+ """Pixel resolution."""
295
+ return self._res
296
+
297
+ @property
298
+ def centroid(self) -> Point:
299
+ return self.unary_union.centroid
300
+
301
+ def _name_regex_searcher(
302
+ self, group: str, patterns: tuple[re.Pattern]
303
+ ) -> str | None:
304
+ if not patterns or not any(pat.groups for pat in patterns):
305
+ return None
306
+ for pat in patterns:
307
+ try:
308
+ return re.match(pat, self.name).group(group)
309
+ except (AttributeError, TypeError, IndexError):
310
+ pass
311
+ raise ValueError(
312
+ f"Couldn't find group '{group}' in name {self.name} with regex patterns {patterns}"
313
+ )
314
+
315
+ def _create_metadata_df(self, file_paths: list[str]) -> None:
316
+ df = pd.DataFrame({"file_path": file_paths})
317
+
318
+ df["filename"] = df["file_path"].apply(lambda x: _fix_path(Path(x).name))
319
+ if not self.single_banded:
320
+ df["image_path"] = df["file_path"].apply(
321
+ lambda x: _fix_path(str(Path(x).parent))
322
+ )
323
+ else:
324
+ df["image_path"] = df["file_path"]
325
+
326
+ if not len(df):
327
+ return df
328
+
329
+ if self.filename_patterns:
330
+ df, match_cols_filename = _get_regexes_matches_for_df(
331
+ df, "filename", self.filename_patterns, suffix=FILENAME_COL_SUFFIX
332
+ )
333
+
334
+ if not len(df):
335
+ return df
336
+
337
+ self._match_cols_filename = match_cols_filename
338
+ grouped = (
339
+ df.drop(columns=match_cols_filename, errors="ignore")
340
+ .drop_duplicates("image_path")
341
+ .set_index("image_path")
342
+ )
343
+ for col in ["file_path", "filename", *match_cols_filename]:
344
+ if col in df:
345
+ grouped[col] = df.groupby("image_path")[col].apply(tuple)
346
+
347
+ grouped = grouped.reset_index()
348
+ else:
349
+ df["file_path"] = df.groupby("image_path")["file_path"].apply(tuple)
350
+ df["filename"] = df.groupby("image_path")["filename"].apply(tuple)
351
+ grouped = df.drop_duplicates("image_path")
352
+
353
+ grouped["imagename"] = grouped["image_path"].apply(
354
+ lambda x: _fix_path(Path(x).name)
355
+ )
356
+
357
+ if self.image_patterns and len(grouped):
358
+ grouped, _ = _get_regexes_matches_for_df(
359
+ grouped, "imagename", self.image_patterns, suffix=""
360
+ )
361
+
362
+ if "date" in grouped:
363
+ return grouped.sort_values("date")
364
+ else:
365
+ return grouped
366
+
367
+ def copy(self) -> "_ImageBase":
368
+ """Copy the instance and its attributes."""
369
+ copied = deepcopy(self)
370
+ for key, value in copied.__dict__.items():
371
+ try:
372
+ setattr(copied, key, value.copy())
373
+ except AttributeError:
374
+ setattr(copied, key, deepcopy(value))
375
+ except TypeError:
376
+ continue
377
+ return copied
378
+
379
+
380
+ class _ImageBandBase(_ImageBase):
381
+ def intersects(self, other: GeoDataFrame | GeoSeries | Geometry) -> bool:
382
+ if hasattr(other, "crs") and not pyproj.CRS(self.crs).equals(
383
+ pyproj.CRS(other.crs)
384
+ ):
385
+ raise ValueError(f"crs mismatch: {self.crs} and {other.crs}")
386
+ return self.unary_union.intersects(to_shapely(other))
387
+
388
+ @property
389
+ def year(self) -> str:
390
+ return self.date[:4]
391
+
392
+ @property
393
+ def month(self) -> str:
394
+ return "".join(self.date.split("-"))[:6]
395
+
396
+ @property
397
+ def yyyymmd(self) -> str:
398
+ return "".join(self.date.split("-"))[:7]
399
+
400
+ @property
401
+ def name(self) -> str | None:
402
+ if hasattr(self, "_name") and self._name is not None:
403
+ return self._name
404
+ try:
405
+ return Path(self.path).name
406
+ except (ValueError, AttributeError):
407
+ return None
408
+
409
+ @name.setter
410
+ def name(self, value) -> None:
411
+ self._name = value
412
+
413
+ @property
414
+ def stem(self) -> str | None:
415
+ try:
416
+ return Path(self.path).stem
417
+ except (AttributeError, ValueError):
418
+ return None
419
+
420
+ @property
421
+ def level(self) -> str:
422
+ return self._name_regex_searcher("level", self.image_patterns)
423
+
424
+ @property
425
+ def mint(self) -> float:
426
+ return disambiguate_timestamp(self.date, self.date_format)[0]
427
+
428
+ @property
429
+ def maxt(self) -> float:
430
+ return disambiguate_timestamp(self.date, self.date_format)[1]
431
+
432
+ @property
433
+ def unary_union(self) -> Polygon:
434
+ return box(*self.bounds)
435
+
436
+ @property
437
+ def bbox(self) -> BoundingBox:
438
+ bounds = GeoSeries([self.unary_union]).bounds
439
+ return BoundingBox(
440
+ minx=bounds.minx[0],
441
+ miny=bounds.miny[0],
442
+ maxx=bounds.maxx[0],
443
+ maxy=bounds.maxy[0],
444
+ mint=self.mint,
445
+ maxt=self.maxt,
446
+ )
447
+
448
+
449
+ class Band(_ImageBandBase):
450
+ """Band holding a single 2 dimensional array representing an image band."""
451
+
452
+ cmap: ClassVar[str | None] = None
453
+
454
+ def __init__(
455
+ self,
456
+ data: str | np.ndarray,
457
+ res: int | None,
458
+ crs: Any | None = None,
459
+ bounds: tuple[float, float, float, float] | None = None,
460
+ cmap: str | None = None,
461
+ name: str | None = None,
462
+ file_system: GCSFileSystem | None = None,
463
+ band_id: str | None = None,
464
+ processes: int = 1,
465
+ _mask: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None = None,
466
+ **kwargs,
467
+ ) -> None:
468
+ """Band initialiser."""
469
+ if isinstance(data, (GeoDataFrame | GeoSeries)):
470
+ if res is None:
471
+ raise ValueError("Must specify res when data is vector geometries.")
472
+ bounds = to_bbox(bounds) if bounds is not None else data.total_bounds
473
+ crs = crs if crs else data.crs
474
+ data: np.ndarray = _arr_from_gdf(data, res=res, **kwargs)
475
+ self._from_gdf = True
476
+
477
+ if isinstance(data, np.ndarray):
478
+ self._values = data
479
+ if bounds is None:
480
+ raise ValueError("Must specify bounds when data is an array.")
481
+ self._bounds = to_bbox(bounds)
482
+ self._crs = crs
483
+ self.transform = _get_transform_from_bounds(
484
+ self.bounds, shape=self.values.shape
485
+ )
486
+ self._from_array = True
487
+
488
+ elif not isinstance(data, (str | Path | os.PathLike)):
489
+ raise TypeError(
490
+ "'data' must be string, Path-like or numpy.ndarray. "
491
+ f"Got {type(data)}"
492
+ )
493
+ else:
494
+ self._path = str(data)
495
+
496
+ self._res = res
497
+ if cmap is not None:
498
+ self.cmap = cmap
499
+ self.file_system = file_system
500
+ self._mask = _mask
501
+ self._name = name
502
+ self._band_id = band_id
503
+ self.processes = processes
504
+
505
+ if self.filename_regexes:
506
+ if isinstance(self.filename_regexes, str):
507
+ self.filename_regexes = [self.filename_regexes]
508
+ self.filename_patterns = [
509
+ re.compile(pat, flags=re.VERBOSE) for pat in self.filename_regexes
510
+ ]
511
+ else:
512
+ self.filename_patterns = None
513
+
514
+ def __lt__(self, other: "Band") -> bool:
515
+ """Makes Bands sortable by band_id."""
516
+ return self.band_id < other.band_id
517
+
518
+ @property
519
+ def values(self) -> np.ndarray:
520
+ """The numpy array, if loaded."""
521
+ try:
522
+ return self._values
523
+ except AttributeError as e:
524
+ raise ValueError("array is not loaded.") from e
525
+
526
+ @values.setter
527
+ def values(self, new_val):
528
+ if isinstance(new_val, np.ndarray):
529
+ raise TypeError(f"{self.__class__.__name__} 'values' must be np.ndarray.")
530
+ self._values = new_val
531
+
532
+ @property
533
+ def band_id(self) -> str:
534
+ """Band id."""
535
+ if self._band_id is not None:
536
+ return self._band_id
537
+ return self._name_regex_searcher("band", self.filename_patterns)
538
+
539
+ @property
540
+ def height(self) -> int:
541
+ """Pixel heigth of the image band."""
542
+ i = 1 if len(self.values.shape) == 3 else 0
543
+ return self.values.shape[i]
544
+
545
+ @property
546
+ def width(self) -> int:
547
+ """Pixel width of the image band."""
548
+ i = 2 if len(self.values.shape) == 3 else 1
549
+ return self.values.shape[i]
550
+
551
+ @property
552
+ def tile(self) -> str:
553
+ """Tile name from filename_regex."""
554
+ if hasattr(self, "_tile") and self._tile:
555
+ return self._tile
556
+ return self._name_regex_searcher("tile", self.filename_patterns)
557
+
558
+ @property
559
+ def date(self) -> str:
560
+ """Tile name from filename_regex."""
561
+ if hasattr(self, "_date") and self._date:
562
+ return self._date
563
+
564
+ return self._name_regex_searcher("date", self.filename_patterns)
565
+
566
+ @property
567
+ def crs(self) -> str | None:
568
+ """Coordinate reference system."""
569
+ try:
570
+ return self._crs
571
+ except AttributeError:
572
+ with opener(self.path, file_system=self.file_system) as file:
573
+ with rasterio.open(file) as src:
574
+ self._bounds = to_bbox(src.bounds)
575
+ self._crs = src.crs
576
+ return self._crs
577
+
578
+ @property
579
+ def bounds(self) -> tuple[int, int, int, int] | None:
580
+ """Bounds as tuple (minx, miny, maxx, maxy)."""
581
+ try:
582
+ return self._bounds
583
+ except AttributeError:
584
+ with opener(self.path, file_system=self.file_system) as file:
585
+ with rasterio.open(file) as src:
586
+ self._bounds = to_bbox(src.bounds)
587
+ self._crs = src.crs
588
+ return self._bounds
589
+ except TypeError:
590
+ return None
591
+
592
+ def get_n_largest(
593
+ self, n: int, precision: float = 0.000001, column: str = "value"
594
+ ) -> GeoDataFrame:
595
+ """Get the largest values of the array as polygons in a GeoDataFrame."""
596
+ copied = self.copy()
597
+ value_must_be_at_least = np.sort(np.ravel(copied.values))[-n] - (precision or 0)
598
+ copied._values = np.where(copied.values >= value_must_be_at_least, 1, 0)
599
+ df = copied.to_gdf(column).loc[lambda x: x[column] == 1]
600
+ df[column] = f"largest_{n}"
601
+ return df
602
+
603
+ def get_n_smallest(
604
+ self, n: int, precision: float = 0.000001, column: str = "value"
605
+ ) -> GeoDataFrame:
606
+ """Get the lowest values of the array as polygons in a GeoDataFrame."""
607
+ copied = self.copy()
608
+ value_must_be_at_least = np.sort(np.ravel(copied.values))[n] - (precision or 0)
609
+ copied._values = np.where(copied.values <= value_must_be_at_least, 1, 0)
610
+ df = copied.to_gdf(column).loc[lambda x: x[column] == 1]
611
+ df[column] = f"smallest_{n}"
612
+ return df
613
+
614
+ def load(self, bounds=None, indexes=None, **kwargs) -> "Band":
615
+ """Load and potentially clip the array.
616
+
617
+ The array is stored in the 'values' property.
618
+ """
619
+ bounds = to_bbox(bounds) if bounds is not None else self._mask
620
+
621
+ try:
622
+ assert isinstance(self.values, np.ndarray)
623
+ has_array = True
624
+ except (ValueError, AssertionError):
625
+ has_array = False
626
+
627
+ if has_array:
628
+ if bounds is None:
629
+ return self
630
+ # bounds_shapely = to_shapely(bounds)
631
+ # if not bounds_shapely.intersects(self.)
632
+ # bounds_arr = GeoSeries([bounds_shapely]).values
633
+ bounds_arr = GeoSeries([to_shapely(bounds)]).values
634
+ try:
635
+ self._values = (
636
+ to_xarray(
637
+ self.values,
638
+ transform=self.transform,
639
+ crs=self.crs,
640
+ name=self.name,
641
+ )
642
+ .rio.clip(bounds_arr, crs=self.crs)
643
+ .to_numpy()
644
+ )
645
+ except NoDataInBounds:
646
+ self._values = np.array([])
647
+ self._bounds = bounds
648
+ return self
649
+
650
+ with opener(self.path, file_system=self.file_system) as f:
651
+ with rasterio.open(f) as src:
652
+ self._res = int(src.res[0]) if not self.res else self.res
653
+ # if bounds is None:
654
+ # out_shape = _get_shape_from_res(to_bbox(src.bounds), self.res, indexes)
655
+ # self.transform = src.transform
656
+ # arr = src.load(indexes=indexes, out_shape=out_shape, **kwargs)
657
+ # # if isinstance(indexes, int) and len(arr.shape) == 3:
658
+ # # return arr[0]
659
+ # return arr
660
+ # else:
661
+ # window = rasterio.windows.from_bounds(
662
+ # *bounds, transform=src.transform
663
+ # )
664
+ # out_shape = _get_shape_from_bounds(bounds, self.res)
665
+
666
+ # arr = src.read(
667
+ # indexes=indexes,
668
+ # out_shape=out_shape,
669
+ # window=window,
670
+ # boundless=boundless,
671
+ # **kwargs,
672
+ # )
673
+ # if isinstance(indexes, int):
674
+ # # arr = arr[0]
675
+ # height, width = arr.shape
676
+ # else:
677
+ # height, width = arr.shape[1:]
678
+
679
+ # self.transform = rasterio.transform.from_bounds(
680
+ # *bounds, width, height
681
+ # )
682
+ # if bounds is not None:
683
+ # self._bounds = bounds
684
+ # return arr
685
+
686
+ if indexes is None and len(src.indexes) == 1:
687
+ indexes = 1
688
+
689
+ if isinstance(indexes, int):
690
+ _indexes = (indexes,)
691
+ else:
692
+ _indexes = indexes
693
+
694
+ arr, transform = rasterio.merge.merge(
695
+ [src],
696
+ res=self.res,
697
+ indexes=_indexes,
698
+ bounds=bounds,
699
+ **kwargs,
700
+ )
701
+ self.transform = transform
702
+ if bounds is not None:
703
+ self._bounds = bounds
704
+
705
+ if isinstance(indexes, int):
706
+ arr = arr[0]
707
+
708
+ self._values = arr
709
+ return self
710
+
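A sketch of the lazy-loading pattern above: a Band constructed from a file path reads nothing until load is called, optionally clipped to a bounding box while reading. The file name and coordinates are hypothetical.

    from sgis.raster.image_collection import Band

    band = Band("data/sentinel2/T32VNM_20230703_B04_10m.tif", res=10)
    band = band.load(bounds=(569000, 6657000, 570000, 6658000))  # load returns self
    print(band.values.shape, band.bounds)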
711
+ def write(self, path: str | Path, **kwargs) -> None:
712
+ """Write the array as an image file."""
713
+ if not hasattr(self, "_values"):
714
+ raise ValueError(
715
+ "Can only write image band from Band constructed from array."
716
+ )
717
+
718
+ if self.crs is None:
719
+ raise ValueError("Cannot write None crs to image.")
720
+
721
+ profile = {
722
+ # "driver": self.driver,
723
+ # "compress": self.compress,
724
+ # "dtype": self.dtype,
725
+ "crs": self.crs,
726
+ "transform": self.transform,
727
+ # "nodata": self.nodata,
728
+ # "count": self.count,
729
+ # "height": self.height,
730
+ # "width": self.width,
731
+ # "indexes": self.indexes,
732
+ } | kwargs
733
+
734
+ with opener(path, "w", file_system=self.file_system) as f:
735
+ with rasterio.open(f, **profile) as dst:
736
+ # bounds = to_bbox(self._mask) if self._mask is not None else dst.bounds
737
+
738
+ # res = dst.res if not self.res else self.res
739
+
740
+ if len(self.values.shape) == 2:
741
+ dst.write(self.values, indexes=1)
742
+ else:
743
+ for i in range(self.values.shape[0]):
744
+ dst.write(self.values[i], indexes=i + 1)
745
+
746
+ self._path = str(path)
747
+
748
+ def sample(self, size: int = 1000, mask: Any = None, **kwargs) -> "Band":
749
+ """Take a random spatial sample area of the Band."""
750
+ copied = self.copy()
751
+ if mask is not None:
752
+ point = GeoSeries([copied.unary_union]).clip(mask).sample_points(1)
753
+ else:
754
+ point = GeoSeries([copied.unary_union]).sample_points(1)
755
+ buffered = point.buffer(size / 2).clip(copied.unary_union)
756
+ copied = copied.load(bounds=buffered.total_bounds, **kwargs)
757
+ return copied
758
+
759
+ def get_gradient(self, degrees: bool = False, copy: bool = True) -> "Band":
760
+ """Get the slope of an elevation band.
761
+
762
+ Calculates the absolute slope between the grid cells
763
+ based on the image resolution.
764
+
765
+ For multi-band images, the calculation is done for each band.
766
+
767
+ Args:
769
+ degrees: If False (default), the returned values will be in ratios,
770
+ where a value of 1 means 1 meter up per 1 meter forward. If True,
771
+ the values will be in degrees from 0 to 90.
772
+ copy: Whether to copy or overwrite the original Band.
773
+ Defaults to True.
774
+
775
+ Returns:
776
+ The class instance with new array values, or a copy if copy is True.
777
+
778
+ Examples:
779
+ ---------
780
+ Making an array where the gradient to the center is always 10.
781
+
782
+ >>> import sgis as sg
783
+ >>> import numpy as np
784
+ >>> arr = np.array(
785
+ ... [
786
+ ... [100, 100, 100, 100, 100],
787
+ ... [100, 110, 110, 110, 100],
788
+ ... [100, 110, 120, 110, 100],
789
+ ... [100, 110, 110, 110, 100],
790
+ ... [100, 100, 100, 100, 100],
791
+ ... ]
792
+ ... )
793
+
794
+ Now let's create a Band from this array with a resolution of 10.
795
+
796
+ >>> band = sg.Band(arr, crs=None, bounds=(0, 0, 50, 50), res=10)
797
+
798
+ The gradient will be 1 (1 meter up for every meter forward).
799
+ The calculation is done on a copy by default (pass copy=False to overwrite in place).
801
+
802
+ >>> band = band.get_gradient()
802
+ >>> band.values
803
+ array([[0., 1., 1., 1., 0.],
804
+ [1., 1., 1., 1., 1.],
805
+ [1., 1., 0., 1., 1.],
806
+ [1., 1., 1., 1., 1.],
807
+ [0., 1., 1., 1., 0.]])
808
+ """
809
+ copied = self.copy() if copy else self
810
+ copied._values = _get_gradient(copied, degrees=degrees, copy=copy)
811
+ return copied
812
+
813
+ def zonal(
814
+ self,
815
+ polygons: GeoDataFrame,
816
+ aggfunc: str | Callable | list[Callable | str],
817
+ array_func: Callable | None = None,
818
+ dropna: bool = True,
819
+ ) -> GeoDataFrame:
820
+ """Calculate zonal statistics in polygons.
821
+
822
+ Args:
823
+ polygons: A GeoDataFrame of polygon geometries.
824
+ aggfunc: Function(s) of which to aggregate the values
825
+ within each polygon.
826
+ array_func: Optional calculation of the raster
827
+ array before calculating the zonal statistics.
828
+ dropna: If True (default), polygons with all missing
829
+ values will be removed.
830
+
831
+ Returns:
832
+ A GeoDataFrame with aggregated values per polygon.
833
+ """
834
+ idx_mapper, idx_name = get_index_mapper(polygons)
835
+ polygons, aggfunc, func_names = _prepare_zonal(polygons, aggfunc)
836
+ poly_iter = _make_geometry_iterrows(polygons)
837
+
838
+ kwargs = {
839
+ "band": self,
840
+ "aggfunc": aggfunc,
841
+ "array_func": array_func,
842
+ "func_names": func_names,
843
+ }
844
+
845
+ if self.processes == 1:
846
+ aggregated = [_zonal_one_pair(i, poly, **kwargs) for i, poly in poly_iter]
847
+ else:
848
+ with joblib.Parallel(n_jobs=self.processes, backend="loky") as parallel:
849
+ aggregated = parallel(
850
+ joblib.delayed(_zonal_one_pair)(i, poly, **kwargs)
851
+ for i, poly in poly_iter
852
+ )
853
+
854
+ return _zonal_post(
855
+ aggregated,
856
+ polygons=polygons,
857
+ idx_mapper=idx_mapper,
858
+ idx_name=idx_name,
859
+ dropna=dropna,
860
+ )
861
+
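A short sketch of zonal statistics with the method above; 'band' is assumed to be a loaded Band and 'polygons' an existing GeoDataFrame in the same CRS.

    stats = band.zonal(polygons, aggfunc=["mean", "max"], dropna=True)
    print(stats.head())  # aggregated values per polygon, as a GeoDataFrame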
862
+ def to_gdf(self, column: str = "value") -> GeoDataFrame:
863
+ """Create a GeoDataFrame from the image Band.
864
+
865
+ Args:
866
+ column: Name of resulting column that holds the raster values.
867
+
868
+ Returns:
869
+ A GeoDataFrame with a geometry column and array values.
870
+ """
871
+ if not hasattr(self, "_values"):
872
+ raise ValueError("Array is not loaded.")
873
+
874
+ if self.values.shape[0] == 0:
875
+ return GeoDataFrame({"geometry": []}, crs=self.crs)
876
+
877
+ return GeoDataFrame(
878
+ pd.DataFrame(
879
+ _array_to_geojson(
880
+ self.values, self.transform, processes=self.processes
881
+ ),
882
+ columns=[column, "geometry"],
883
+ ),
884
+ geometry="geometry",
885
+ crs=self.crs,
886
+ )
887
+
888
+ def to_xarray(self) -> DataArray:
889
+ """Convert the raster to an xarray.DataArray."""
890
+ name = self.name or self.__class__.__name__.lower()
891
+ coords = _generate_spatial_coords(self.transform, self.width, self.height)
892
+ if len(self.values.shape) == 2:
893
+ dims = ["y", "x"]
894
+ elif len(self.values.shape) == 3:
895
+ dims = ["band", "y", "x"]
896
+ else:
897
+ raise ValueError("Array must be 2 or 3 dimensional.")
898
+ return xr.DataArray(
899
+ self.values,
900
+ coords=coords,
901
+ dims=dims,
902
+ name=name,
903
+ attrs={"crs": self.crs},
904
+ ) # .transpose("y", "x")
905
+
906
+ def __repr__(self) -> str:
907
+ """String representation."""
908
+ try:
909
+ band_id = f"'{self.band_id}'" if self.band_id else None
910
+ except (ValueError, AttributeError):
911
+ band_id = None
912
+ try:
913
+ path = f"'{self.path}'"
914
+ except (ValueError, AttributeError):
915
+ path = None
916
+ return (
917
+ f"{self.__class__.__name__}(band_id={band_id}, res={self.res}, path={path})"
918
+ )
919
+
920
+
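A small sketch of constructing a Band directly from a NumPy array, mirroring the get_gradient doctest above. The CRS code and bounds are made-up values, and to_xarray additionally requires the optional xarray/rioxarray dependencies.

    import numpy as np
    from sgis.raster.image_collection import Band

    arr = np.random.default_rng(0).integers(0, 255, size=(100, 100))
    band = Band(arr, res=10, crs=25833, bounds=(0, 0, 1000, 1000))

    gdf = band.to_gdf(column="value")  # one polygon per pixel with its value
    da = band.to_xarray()              # 2D DataArray with spatial coordinates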
921
+ class NDVIBand(Band):
922
+ """Band for NDVI values."""
923
+
924
+ cmap: str = "Greens"
925
+
926
+
927
+ class Image(_ImageBandBase):
928
+ """Image consisting of one or more Bands."""
929
+
930
+ cloud_cover_regexes: ClassVar[tuple[str] | None] = None
931
+ band_class: ClassVar[Band] = Band
932
+
933
+ def __init__(
934
+ self,
935
+ data: str | Path | Sequence[Band],
936
+ res: int | None = None,
937
+ # crs: Any | None = None,
938
+ single_banded: bool = False,
939
+ file_system: GCSFileSystem | None = None,
940
+ df: pd.DataFrame | None = None,
941
+ all_file_paths: list[str] | None = None,
942
+ _mask: GeoDataFrame | GeoSeries | Geometry | tuple | None = None,
943
+ processes: int = 1,
944
+ ) -> None:
945
+ """Image initialiser."""
946
+ super().__init__()
947
+
948
+ self._res = res
949
+ # self._crs = crs
950
+ self.file_system = file_system
951
+ self._mask = _mask
952
+ self.single_banded = single_banded
953
+ self.processes = processes
954
+
955
+ if hasattr(data, "__iter__") and all(isinstance(x, Band) for x in data):
956
+ self._bands = list(data)
957
+ self._bounds = get_total_bounds(self._bands)
958
+ self._crs = get_common_crs(self._bands)
959
+ res = list({band.res for band in self._bands})
960
+ if len(res) == 1:
961
+ self._res = res[0]
962
+ else:
963
+ raise ValueError(f"Different resolutions for the bands: {res}")
964
+ return
965
+
966
+ if not isinstance(data, (str | Path | os.PathLike)):
967
+ raise TypeError("'data' must be string, Path-like or a sequence of Band.")
968
+
969
+ self._path = str(data)
970
+
971
+ if df is None:
972
+ if is_dapla():
973
+ file_paths = list(sorted(set(glob_func(self.path + "/**"))))
974
+ else:
975
+ file_paths = list(
976
+ sorted(
977
+ set(
978
+ glob_func(self.path + "/**/**")
979
+ + glob_func(self.path + "/**/**/**")
980
+ + glob_func(self.path + "/**/**/**/**")
981
+ + glob_func(self.path + "/**/**/**/**/**")
982
+ )
983
+ )
984
+ )
985
+ df = self._create_metadata_df(file_paths)
986
+
987
+ df["image_path"] = df["image_path"].astype(str)
988
+
989
+ cols_to_explode = [
990
+ "file_path",
991
+ "filename",
992
+ *[x for x in df if FILENAME_COL_SUFFIX in x],
993
+ ]
994
+ try:
995
+ df = df.explode(cols_to_explode, ignore_index=True)
996
+ except ValueError:
997
+ for col in cols_to_explode:
998
+ df = df.explode(col)
999
+ df = df.loc[lambda x: ~x["filename"].duplicated()].reset_index(drop=True)
1000
+
1001
+ df = df.loc[lambda x: x["image_path"].str.contains(_fix_path(self.path))]
1002
+
1003
+ if self.filename_patterns and any(pat.groups for pat in self.filename_patterns):
1004
+ df = df.loc[
1005
+ lambda x: (x[f"band{FILENAME_COL_SUFFIX}"].notna())
1006
+ ].sort_values(f"band{FILENAME_COL_SUFFIX}")
1007
+
1008
+ if self.cloud_cover_regexes:
1009
+ if all_file_paths is None:
1010
+ file_paths = ls_func(self.path)
1011
+ else:
1012
+ file_paths = [path for path in all_file_paths if self.name in path]
1013
+ self.cloud_coverage_percentage = float(
1014
+ _get_regex_match_from_xml_in_local_dir(
1015
+ file_paths, regexes=self.cloud_cover_regexes
1016
+ )
1017
+ )
1018
+ else:
1019
+ self.cloud_coverage_percentage = None
1020
+
1021
+ self._bands = [
1022
+ self.band_class(
1023
+ path,
1024
+ **self._common_init_kwargs,
1025
+ )
1026
+ for path in (df["file_path"])
1027
+ ]
1028
+
1029
+ if self.filename_patterns and any(pat.groups for pat in self.filename_patterns):
1030
+ self._bands = list(sorted(self._bands))
1031
+
1032
+ def get_ndvi(self, red_band: str, nir_band: str) -> NDVIBand:
1033
+ """Calculate the NDVI for the Image."""
1034
+ red = self[red_band].load().values
1035
+ nir = self[nir_band].load().values
1036
+
1037
+ arr: np.ndarray = ndvi(red, nir)
1038
+
1039
+ return NDVIBand(
1040
+ arr,
1041
+ bounds=self.bounds,
1042
+ crs=self.crs,
1043
+ **self._common_init_kwargs,
1044
+ )
1045
+
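A sketch of get_ndvi on a single Image. The band ids assume Sentinel-2-style naming ("B04" red, "B08" near-infrared) resolved by the filename regexes of a subclass such as those configured in sentinel_config; they are not taken from this base class.

    img = collection[0]                                  # first Image of a collection (assumed)
    ndvi = img.get_ndvi(red_band="B04", nir_band="B08")  # returns an NDVIBand
    ndvi_gdf = ndvi.to_gdf(column="ndvi")                # vectorize for inspection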
1046
+ def get_brightness(self, bounds=None, rbg_bands: list[str] | None = None) -> Band:
1047
+ """Get a Band with a brightness score of the Image's RBG bands."""
1048
+ if rbg_bands is None:
1049
+ try:
1050
+ r, b, g = self.rbg_bands
1051
+ except AttributeError as err:
1052
+ raise AttributeError(
1053
+ "Must specify rbg_bands when there is no class variable 'rbd_bands'"
1054
+ ) from err
1055
+ else:
1056
+ r, b, g = rbg_bands
1057
+
1058
+ red = self[r].load(bounds=bounds)
1059
+ blue = self[b].load(bounds=bounds)
1060
+ green = self[g].load(bounds=bounds)
1061
+
1062
+ brightness = (
1063
+ 0.299 * red.values + 0.587 * green.values + 0.114 * blue.values
1064
+ ).astype(int)
1065
+
1066
+ return Band(
1067
+ brightness,
1068
+ bounds=red.bounds,
1069
+ crs=self.crs,
1070
+ **self._common_init_kwargs,
1071
+ )
1072
+
1073
+ @property
1074
+ def band_ids(self) -> list[str]:
1075
+ """The Band ids."""
1076
+ return [band.band_id for band in self]
1077
+
1078
+ @property
1079
+ def file_paths(self) -> list[str]:
1080
+ """The Band file paths."""
1081
+ return [band.path for band in self]
1082
+
1083
+ @property
1084
+ def bands(self) -> list[Band]:
1085
+ """The Image Bands."""
1086
+ return self._bands
1087
+
1088
+ @property
1089
+ def tile(self) -> str:
1090
+ """Tile name from filename_regex."""
1091
+ if hasattr(self, "_tile") and self._tile:
1092
+ return self._tile
1093
+ return self._name_regex_searcher("tile", self.image_patterns)
1094
+
1095
+ @property
1096
+ def date(self) -> str:
1097
+ """Tile name from filename_regex."""
1098
+ if hasattr(self, "_date") and self._date:
1099
+ return self._date
1100
+
1101
+ return self._name_regex_searcher("date", self.image_patterns)
1102
+
1103
+ @property
1104
+ def crs(self) -> str | None:
1105
+ """Coordinate reference system of the Image."""
1106
+ try:
1107
+ return self._crs
1108
+ except AttributeError:
1109
+ if not len(self):
1110
+ return None
1111
+ with opener(self.file_paths[0], file_system=self.file_system) as file:
1112
+ with rasterio.open(file) as src:
1113
+ self._bounds = to_bbox(src.bounds)
1114
+ self._crs = src.crs
1115
+ return self._crs
1116
+
1117
+ @property
1118
+ def bounds(self) -> tuple[int, int, int, int] | None:
1119
+ """Bounds of the Image (minx, miny, maxx, maxy)."""
1120
+ try:
1121
+ return self._bounds
1122
+ except AttributeError:
1123
+ if not len(self):
1124
+ return None
1125
+ with opener(self.file_paths[0], file_system=self.file_system) as file:
1126
+ with rasterio.open(file) as src:
1127
+ self._bounds = to_bbox(src.bounds)
1128
+ self._crs = src.crs
1129
+ return self._bounds
1130
+ except TypeError:
1131
+ return None
1132
+
1133
+ # @property
1134
+ # def year(self) -> str:
1135
+ # return self.date[:4]
1136
+
1137
+ # @property
1138
+ # def month(self) -> str:
1139
+ # return "".join(self.date.split("-"))[:6]
1140
+
1141
+ # def write(self, image_path: str | Path, file_type: str = "tif", **kwargs) -> None:
1142
+ # _test = kwargs.pop("_test")
1143
+ # suffix = "." + file_type.strip(".")
1144
+ # img_path = Path(image_path)
1145
+ # for band in self:
1146
+ # band_path = (img_path / band.name).with_suffix(suffix)
1147
+ # if _test:
1148
+ # print(f"{self.__class__.__name__}.write: {band_path}")
1149
+ # continue
1150
+
1151
+ # band.write(band_path, **kwargs)
1152
+
1153
+ # def read(self, bounds=None, **kwargs) -> np.ndarray:
1154
+ # """Return 3 dimensional numpy.ndarray of shape (n bands, width, height)."""
1155
+ # return np.array(
1156
+ # [(band.load(bounds=bounds, **kwargs).values) for band in self.bands]
1157
+ # )
1158
+
1159
+ def to_gdf(self, column: str = "value") -> GeoDataFrame:
1160
+ """Convert the array to a GeoDataFrame of grid polygons and values."""
1161
+ return pd.concat(
1162
+ [band.to_gdf(column=column) for band in self], ignore_index=True
1163
+ )
1164
+
1165
+ def sample(
1166
+ self, n: int = 1, size: int = 1000, mask: Any = None, **kwargs
1167
+ ) -> "Image":
1168
+ """Take a random spatial sample of the image."""
1169
+ copied = self.copy()
1170
+ if mask is not None:
1171
+ points = GeoSeries([self.unary_union]).clip(mask).sample_points(n)
1172
+ else:
1173
+ points = GeoSeries([self.unary_union]).sample_points(n)
1174
+ buffered = points.buffer(size / 2).clip(self.unary_union)
1175
+ boxes = to_gdf([box(*arr) for arr in buffered.bounds.values], crs=self.crs)
1176
+ copied._bands = [band.load(bounds=boxes, **kwargs) for band in copied]
1177
+ return copied
1178
+
1179
+ # def get_filepath(self, band: str) -> str:
1180
+ # simple_string_match = [path for path in self.file_paths if str(band) in path]
1181
+ # if len(simple_string_match) == 1:
1182
+ # return simple_string_match[0]
1183
+
1184
+ # regexes_matches = []
1185
+ # for path in self.file_paths:
1186
+ # for pat in self.filename_patterns:
1187
+ # match_ = re.search(pat, Path(path).name)
1188
+ # if match_ and str(band) == match_.group("band"):
1189
+ # regexes_matches.append(path)
1190
+
1191
+ # if len(regexes_matches) == 1:
1192
+ # return regexes_matches[0]
1193
+
1194
+ # if len(regexes_matches) > 1:
1195
+ # prefix = "Multiple"
1196
+ # elif not regexes_matches:
1197
+ # prefix = "No"
1198
+
1199
+ # raise KeyError(
1200
+ # f"{prefix} matches for band {band} among paths {[Path(x).name for x in self.file_paths]}"
1201
+ # )
1202
+
1203
+ def __getitem__(
1204
+ self, band: str | int | Sequence[str] | Sequence[int]
1205
+ ) -> "Band | Image":
1206
+ """Get bands by band_id or integer index.
1207
+
1208
+ Returns a Band if a string or int is passed,
1209
+ returns an Image if a sequence of strings or integers is passed.
1210
+ """
1211
+ if isinstance(band, str):
1212
+ return self._get_band(band)
1213
+ if isinstance(band, int):
1214
+ return self.bands[band] # .copy()
1215
+
1216
+ copied = self.copy()
1217
+ try:
1218
+ copied._bands = [copied._get_band(x) for x in band]
1219
+ except TypeError:
1220
+ try:
1221
+ copied._bands = [copied.bands[i] for i in band]
1222
+ except TypeError as e:
1223
+ raise TypeError(
1224
+ f"{self.__class__.__name__} indices should be string, int "
1225
+ f"or sequence of string or int. Got {band}."
1226
+ ) from e
1227
+ return copied
1228
+
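In short, the indexing described in __getitem__ above; the band ids are illustrative Sentinel-2-style names.

    red = img["B04"]                  # band_id string  -> Band
    first = img[0]                    # integer index   -> Band
    rgb = img[["B04", "B03", "B02"]]  # sequence        -> Image with those bands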
1229
+ def __contains__(self, item: str | Sequence[str]) -> bool:
1230
+ """Check if the Image contains a band_id (str) or all band_ids in a sequence."""
1231
+ if isinstance(item, str):
1232
+ return item in self.band_ids
1233
+ return all(x in self.band_ids for x in item)
1234
+
1235
+ def __lt__(self, other: "Image") -> bool:
1236
+ """Makes Images sortable by date."""
1237
+ return self.date < other.date
1238
+
1239
+ def __iter__(self) -> Iterator[Band]:
1240
+ """Iterate over the Bands."""
1241
+ return iter(self.bands)
1242
+
1243
+ def __len__(self) -> int:
1244
+ """Number of bands in the Image."""
1245
+ return len(self.bands)
1246
+
1247
+ def __repr__(self) -> str:
1248
+ """String representation."""
1249
+ return f"{self.__class__.__name__}(bands={self.bands})"
1250
+
1251
+ def get_cloud_band(self) -> Band:
1252
+ """Get a Band where self.cloud_values have value 1 and the rest have value 0."""
1253
+ scl = self[self.cloud_band].load()
1254
+ scl._values = np.where(np.isin(scl.values, self.cloud_values), 1, 0)
1255
+ return scl
1256
+
1257
+ # @property
1258
+ # def transform(self) -> Affine | None:
1259
+ # """Get the Affine transform of the image."""
1260
+ # try:
1261
+ # return rasterio.transform.from_bounds(*self.bounds, self.width, self.height)
1262
+ # except (ZeroDivisionError, TypeError):
1263
+ # if not self.width or not self.height:
1264
+ # return None
1265
+
1266
+ # @property
1267
+ # def shape(self) -> tuple[int]:
1268
+ # return self._shape
1269
+
1270
+ # @property
1271
+ # def height(self) -> int:
1272
+ # i = 1 if len(self.shape) == 3 else 0
1273
+ # return self.shape[i]
1274
+
1275
+ # @property
1276
+ # def width(self) -> int:
1277
+ # i = 2 if len(self.shape) == 3 else 1
1278
+ # return self.shape[i]
1279
+
1280
+ # @transform.setter
1281
+ # def transform(self, value: Affine) -> None:
1282
+ # self._bounds = rasterio.transform.array_bounds(self.height, self.width, value)
1283
+
1284
+ def _get_band(self, band: str) -> Band:
1285
+ if not isinstance(band, str):
1286
+ raise TypeError(f"band must be string. Got {type(band)}")
1287
+
1288
+ bands = [x for x in self.bands if x.band_id == band]
1289
+ if len(bands) == 1:
1290
+ return bands[0]
1291
+ if len(bands) > 1:
1292
+ raise ValueError(
1293
+ f"Multiple matches for band_id {band} among {[x for x in self]}"
1294
+ )
1295
+
1296
+ try:
1297
+ more_bands = [x for x in self.bands if x.path == band]
1298
+ except PathlessImageError:
1299
+ more_bands = bands
1300
+
1301
+ if len(more_bands) == 1:
1302
+ return more_bands[0]
1303
+
1304
+ if len(bands) > 1:
1305
+ prefix = "Multiple"
1306
+ elif not bands:
1307
+ prefix = "No"
1308
+
1309
+ raise KeyError(
1310
+ f"{prefix} matches for band {band} among paths {[Path(band.path).name for band in self.bands]}"
1311
+ )
1312
+
1313
+
1314
+ class ImageCollection(_ImageBase):
1315
+ """Collection of Images.
1316
+
1317
+ Loops through the Images.
1318
+ """
1319
+
1320
+ image_class: ClassVar[Image] = Image
1321
+ band_class: ClassVar[Band] = Band
1322
+
1323
+ def __init__(
1324
+ self,
1325
+ data: str | Path | Sequence[Image],
1326
+ res: int,
1327
+ level: str | None,
1328
+ single_banded: bool = False,
1329
+ processes: int = 1,
1330
+ file_system: GCSFileSystem | None = None,
1331
+ df: pd.DataFrame | None = None,
1332
+ _mask: Any | None = None,
1333
+ ) -> None:
1334
+ """Initialiser."""
1335
+ super().__init__()
1336
+
1337
+ self.level = level
1338
+ self.processes = processes
1339
+ self.file_system = file_system
1340
+ self._res = res
1341
+ self._mask = _mask
1342
+ self._band_ids = None
1343
+ self.single_banded = single_banded
1344
+
1345
+ if hasattr(data, "__iter__") and all(isinstance(x, Image) for x in data):
1346
+ self._path = None
1347
+ self.images = data
1348
+ return
1349
+
1350
+ if not isinstance(data, (str | Path | os.PathLike)):
1351
+ raise TypeError("'data' must be string, Path-like or a sequence of Image.")
1352
+
1353
+ self._path = str(data)
1354
+
1355
+ if is_dapla():
1356
+ self._all_filepaths = list(sorted(set(glob_func(self.path + "/**"))))
1357
+ else:
1358
+ self._all_filepaths = list(
1359
+ sorted(
1360
+ set(
1361
+ glob_func(self.path + "/**/**")
1362
+ + glob_func(self.path + "/**/**/**")
1363
+ + glob_func(self.path + "/**/**/**/**")
1364
+ + glob_func(self.path + "/**/**/**/**/**")
1365
+ )
1366
+ )
1367
+ )
1368
+
1369
+ if self.level:
1370
+ self._all_filepaths = [
1371
+ path for path in self._all_filepaths if self.level in path
1372
+ ]
1373
+
1374
+ if df is not None:
1375
+ self._df = df
1376
+ else:
1377
+ self._df = self._create_metadata_df(self._all_filepaths)
1378
+
1379
+ def groupby(self, by: str | list[str], **kwargs) -> ImageCollectionGroupBy:
1380
+ """Group the Collection by Image or Band attribute(s)."""
1381
+ df = pd.DataFrame(
1382
+ [(i, img) for i, img in enumerate(self) for _ in img],
1383
+ columns=["_image_idx", "_image_instance"],
1384
+ )
1385
+
1386
+ if isinstance(by, str):
1387
+ by = [by]
1388
+
1389
+ for attr in by:
1390
+ if attr == "bounds":
1391
+ # need integers to check equality when grouping
1392
+ df[attr] = [
1393
+ tuple(int(x) for x in band.bounds) for img in self for band in img
1394
+ ]
1395
+ continue
1396
+
1397
+ try:
1398
+ df[attr] = [getattr(band, attr) for img in self for band in img]
1399
+ except AttributeError:
1400
+ df[attr] = [getattr(img, attr) for img in self for _ in img]
1401
+
1402
+ with joblib.Parallel(n_jobs=self.processes, backend="loky") as parallel:
1403
+ return ImageCollectionGroupBy(
1404
+ sorted(
1405
+ parallel(
1406
+ joblib.delayed(_copy_and_add_df_parallel)(i, group, self)
1407
+ for i, group in df.groupby(by, **kwargs)
1408
+ )
1409
+ ),
1410
+ by=by,
1411
+ collection=self,
1412
+ )
1413
+
1414
+ def explode(self, copy: bool = True) -> "ImageCollection":
1415
+ """Make all Images single-banded."""
1416
+ copied = self.copy() if copy else self
1417
+ copied.images = [
1418
+ self.image_class(
1419
+ [band],
1420
+ single_banded=True,
1421
+ **self._common_init_kwargs,
1422
+ df=self._df,
1423
+ all_file_paths=self._all_filepaths,
1424
+ )
1425
+ for img in self
1426
+ for band in img
1427
+ ]
1428
+ return copied
1429
+
1430
+ def merge(
1431
+ self, bounds=None, method="median", as_int: bool = True, indexes=None, **kwargs
1432
+ ) -> Band:
1433
+ """Merge all areas and all bands to a single Band."""
1434
+ bounds = to_bbox(bounds) if bounds is not None else self._mask
1435
+ crs = self.crs
1436
+
1437
+ if indexes is None:
1438
+ indexes = 1
1439
+
1440
+ if isinstance(indexes, int):
1441
+ _indexes = (indexes,)
1442
+ else:
1443
+ _indexes = indexes
1444
+
1445
+ if method == "mean":
1446
+ _method = "sum"
1447
+ else:
1448
+ _method = method
1449
+
1450
+ if method not in list(rasterio.merge.MERGE_METHODS) + ["mean"]:
1451
+ arr = self._merge_with_numpy_func(
1452
+ method=method,
1453
+ bounds=bounds,
1454
+ as_int=as_int,
1455
+ )
1456
+ else:
1457
+ datasets = [_open_raster(path) for path in self.file_paths]
1458
+ arr, _ = rasterio.merge.merge(
1459
+ datasets,
1460
+ res=self.res,
1461
+ bounds=bounds,
1462
+ indexes=_indexes,
1463
+ method=_method,
1464
+ **kwargs,
1465
+ )
1466
+
1467
+ if isinstance(indexes, int) and len(arr.shape) == 3 and arr.shape[0] == 1:
1468
+ arr = arr[0]
1469
+
1470
+ if method == "mean":
1471
+ if as_int:
1472
+ arr = arr // len(datasets)
1473
+ else:
1474
+ arr = arr / len(datasets)
1475
+
1476
+ if bounds is None:
1477
+ bounds = self.bounds
1478
+
1479
+ # return self.band_class(
1480
+ band = Band(
1481
+ arr,
1482
+ bounds=bounds,
1483
+ crs=crs,
1484
+ **self._common_init_kwargs,
1485
+ )
1486
+
1487
+ band._merged = True
1488
+ return band
1489
+
1490
+ def merge_by_band(
1491
+ self,
1492
+ bounds: tuple | Geometry | GeoDataFrame | GeoSeries | None = None,
1493
+ method: str = "median",
1494
+ as_int: bool = True,
1495
+ indexes: int | tuple[int] | None = None,
1496
+ **kwargs,
1497
+ ) -> Image:
1498
+ """Merge all areas to a single tile, one band per band_id."""
1499
+ bounds = to_bbox(bounds) if bounds is not None else self._mask
1500
+ bounds = self.bounds if bounds is None else bounds
1501
+ out_bounds = self.bounds if bounds is None else bounds
1502
+ crs = self.crs
1503
+
1504
+ if indexes is None:
1505
+ indexes = 1
1506
+
1507
+ if isinstance(indexes, int):
1508
+ _indexes = (indexes,)
1509
+ else:
1510
+ _indexes = indexes
1511
+
1512
+ if method == "mean":
1513
+ _method = "sum"
1514
+ else:
1515
+ _method = method
1516
+
1517
+ arrs = []
1518
+ bands: list[Band] = []
1519
+ for (band_id,), band_collection in self.groupby("band_id"):
1520
+ if method not in list(rasterio.merge.MERGE_METHODS) + ["mean"]:
1521
+ arr = band_collection._merge_with_numpy_func(
1522
+ method=method,
1523
+ bounds=bounds,
1524
+ as_int=as_int,
1525
+ )
1526
+ else:
1527
+ datasets = [_open_raster(path) for path in band_collection.file_paths]
1528
+ arr, _ = rasterio.merge.merge(
1529
+ datasets,
1530
+ res=self.res,
1531
+ bounds=bounds,
1532
+ indexes=_indexes,
1533
+ method=_method,
1534
+ **kwargs,
1535
+ )
1536
+ if isinstance(indexes, int):
1537
+ arr = arr[0]
1538
+ if method == "mean":
1539
+ if as_int:
1540
+ arr = arr // len(datasets)
1541
+ else:
1542
+ arr = arr / len(datasets)
1543
+
1544
+ arrs.append(arr)
1545
+ bands.append(
1546
+ self.band_class(
1547
+ arr,
1548
+ bounds=out_bounds,
1549
+ crs=crs,
1550
+ band_id=band_id,
1551
+ **self._common_init_kwargs,
1552
+ )
1553
+ )
1554
+
1555
+ # return self.image_class(
1556
+ image = Image(
1557
+ bands,
1558
+ **self._common_init_kwargs,
1559
+ )
1560
+
1561
+ image._merged = True
1562
+ return image
1563
+
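A sketch contrasting the two collection-level merge methods above; the bounding box and merge methods are arbitrary examples.

    bbox = (569000, 6657000, 579000, 6667000)  # hypothetical area of interest
    one_band = collection.merge(bounds=bbox, method="mean")            # everything -> one Band
    per_band = collection.merge_by_band(bounds=bbox, method="median")  # one Image, one Band per band_id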
1564
+ def _merge_with_numpy_func(
1565
+ self,
1566
+ method: str | Callable,
1567
+ bounds=None,
1568
+ as_int: bool = True,
1569
+ indexes=None,
1570
+ **kwargs,
1571
+ ) -> np.ndarray:
1572
+ arrs = []
1573
+ numpy_func = get_numpy_func(method) if not callable(method) else method
1574
+ for (_bounds,), collection in self.groupby("bounds"):
1575
+ arr = np.array(
1576
+ [
1577
+ band.load(bounds=bounds, indexes=indexes, **kwargs).values
1578
+ for img in collection
1579
+ for band in img
1580
+ ]
1581
+ )
1582
+ arr = numpy_func(arr, axis=0)
1583
+ if as_int:
1584
+ arr = arr.astype(int)
1585
+ min_dtype = rasterio.dtypes.get_minimum_dtype(arr)
1586
+ arr = arr.astype(min_dtype)
1587
+
1588
+ if len(arr.shape) == 2:
1589
+ height, width = arr.shape
1590
+ elif len(arr.shape) == 3:
1591
+ height, width = arr.shape[1:]
1592
+ else:
1593
+ raise ValueError(arr.shape)
1594
+
1595
+ transform = rasterio.transform.from_bounds(*_bounds, width, height)
1596
+ coords = _generate_spatial_coords(transform, width, height)
1597
+
1598
+ arrs.append(
1599
+ xr.DataArray(
1600
+ arr,
1601
+ coords=coords,
1602
+ dims=["y", "x"],
1603
+ name=str(_bounds),
1604
+ attrs={"crs": self.crs},
1605
+ )
1606
+ )
1607
+
1608
+ if bounds is None:
1609
+ bounds = self.bounds
1610
+
1611
+ merged = merge_arrays(arrs, bounds=bounds, res=self.res)
1612
+
1613
+ return merged.to_numpy()
1614
+
1615
+ # def write(self, root: str | Path, file_type: str = "tif", **kwargs) -> None:
1616
+ # _test = kwargs.pop("_test")
1617
+ # suffix = "." + file_type.strip(".")
1618
+ # for img in self:
1619
+ # img_path = Path(root) / img.name
1620
+ # for band in img:
1621
+ # band_path = (img_path / band.name).with_suffix(suffix)
1622
+ # if _test:
1623
+ # print(f"{self.__class__.__name__}.write: {band_path}")
1624
+ # continue
1625
+ # band.write(band_path, **kwargs)
1626
+
1627
+ def load_bands(self, bounds=None, indexes=None, **kwargs) -> "ImageCollection":
1628
+ """Load all image Bands with threading."""
1629
+ with joblib.Parallel(n_jobs=self.processes, backend="threading") as parallel:
1630
+ parallel(
1631
+ joblib.delayed(_load_band)(
1632
+ band, bounds=bounds, indexes=indexes, **kwargs
1633
+ )
1634
+ for img in self
1635
+ for band in img
1636
+ if bounds is None or img.intersects(bounds)
1637
+ )
1638
+
1639
+ return self
1640
+
1641
+ def set_mask(
1642
+ self, mask: GeoDataFrame | GeoSeries | Geometry | tuple[float]
1643
+ ) -> "ImageCollection":
1644
+ """Set the mask to be used to clip the images to."""
1645
+ self._mask = to_bbox(mask)
1646
+ # only update images when already instantiated
1647
+ if hasattr(self, "_images"):
1648
+ for img in self._images:
1649
+ img._mask = self._mask
1650
+ img._bounds = self._mask
1651
+ for band in img:
1652
+ band._mask = self._mask
1653
+ band._bounds = self._mask
1654
+ return self
1655
+
1656
+ def filter(
1657
+ self,
1658
+ bands: str | list[str] | None = None,
1659
+ date_ranges: (
1660
+ tuple[str | None, str | None]
1661
+ | tuple[tuple[str | None, str | None], ...]
1662
+ | None
1663
+ ) = None,
1664
+ bounds: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None = None,
1665
+ max_cloud_coverage: int | None = None,
1666
+ copy: bool = True,
1667
+ ) -> "ImageCollection":
1668
+ """Filter images and bands in the collection."""
1669
+ copied = self.copy() if copy else self
1670
+
1671
+ if isinstance(bounds, BoundingBox):
1672
+ date_ranges = (bounds.mint, bounds.maxt)
1673
+
1674
+ if date_ranges:
1675
+ copied = copied._filter_dates(date_ranges, copy=False)
1676
+
1677
+ if max_cloud_coverage is not None:
1678
+ copied.images = [
1679
+ image
1680
+ for image in copied.images
1681
+ if image.cloud_coverage_percentage < max_cloud_coverage
1682
+ ]
1683
+
1684
+ if bounds is not None:
1685
+ copied = copied._filter_bounds(bounds, copy=False)
1686
+
1687
+ if bands is not None:
1688
+ if isinstance(bands, str):
1689
+ bands = [bands]
1690
+ bands = set(bands)
1691
+ copied._band_ids = bands
1692
+ copied.images = [img[bands] for img in copied.images if bands in img]
1693
+
1694
+ return copied
1695
+
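A sketch of chained filtering with the method above. The date strings follow the class's date_format ("%Y%m%d"), the bounding box is hypothetical, date filtering requires image_regexes that define a date group (as in the Sentinel-2 subclasses), and max_cloud_coverage assumes the images expose cloud_coverage_percentage via cloud_cover_regexes; all values are illustrative.

    summer = collection.filter(
        bands=["B02", "B03", "B04"],
        date_ranges=("20230601", "20230831"),
        bounds=(569000, 6657000, 579000, 6667000),
        max_cloud_coverage=20,
    )
    print(len(summer), "images left after filtering")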
1696
+ def _filter_dates(
1697
+ self,
1698
+ date_ranges: (
1699
+ tuple[str | None, str | None] | tuple[tuple[str | None, str | None], ...]
1700
+ ),
1701
+ copy: bool = True,
1702
+ ) -> "ImageCollection":
1703
+ if not isinstance(date_ranges, (tuple, list)):
1704
+ raise TypeError(
1705
+ "date_ranges should be a 2-length tuple of strings or None, "
1706
+ "or a tuple of tuples for multiple date ranges"
1707
+ )
1708
+ if self.image_patterns is None:
1709
+ raise ValueError(
1710
+ "Cannot set date_ranges when the class's image_regexes attribute is None"
1711
+ )
1712
+
1713
+ copied = self.copy() if copy else self
1714
+
1715
+ copied.images = [
1716
+ img
1717
+ for img in self
1718
+ if _date_is_within(
1719
+ img.path, date_ranges, copied.image_patterns, copied.date_format
1720
+ )
1721
+ ]
1722
+ return copied
1723
+
1724
+ def _filter_bounds(
1725
+ self, other: GeoDataFrame | GeoSeries | Geometry | tuple, copy: bool = True
1726
+ ) -> "ImageCollection":
1727
+ copied = self.copy() if copy else self
1728
+
1729
+ other = to_shapely(other)
1730
+
1731
+ with joblib.Parallel(n_jobs=copied.processes, backend="threading") as parallel:
1732
+ intersects_list: list[bool] = parallel(
1733
+ joblib.delayed(_intesects)(image, other) for image in copied
1734
+ )
1735
+ copied.images = [
1736
+ image
1737
+ for image, intersects in zip(copied, intersects_list, strict=False)
1738
+ if intersects
1739
+ ]
1740
+ return copied
1741
+
1742
+ def to_gdfs(self, column: str = "value") -> dict[str, GeoDataFrame]:
1743
+ """Convert each band in each Image to a GeoDataFrame."""
1744
+ out = {}
1745
+ i = 0
1746
+ for img in self:
1747
+ for band in img:
1748
+ i += 1
1749
+ try:
1750
+ name = band.name
1751
+ except AttributeError:
1752
+ name = f"{self.__class__.__name__}({i})"
1753
+
1754
+ if name not in out:
1755
+ out[name] = band.to_gdf(column=column)
1756
+ else:
1757
+ out[name] = f"{self.__class__.__name__}({i})"
1758
+ return out
1759
+
1760
+ def sample(self, n: int = 1, size: int = 500) -> "ImageCollection":
1761
+ """Sample one or more areas of a given size and set this as mask for the images."""
1762
+ images = []
1763
+ bbox = to_gdf(self.unary_union).geometry.buffer(-size / 2)
1764
+ copied = self.copy()
1765
+ for _ in range(n):
1766
+ mask = to_bbox(bbox.sample_points(1).buffer(size))
1767
+ images += copied.filter(bounds=mask).set_mask(mask).images
1768
+ copied.images = images
1769
+ return copied
1770
+
1771
+ def sample_tiles(self, n: int) -> "ImageCollection":
1772
+ """Sample one or more tiles in a copy of the ImageCollection."""
1773
+ copied = self.copy()
1774
+ sampled_tiles = list({img.tile for img in self})
1775
+ random.shuffle(sampled_tiles)
1776
+ sampled_tiles = sampled_tiles[:n]
1777
+
1778
+ copied.images = [image for image in self if image.tile in sampled_tiles]
1779
+ return copied
1780
+
1781
+ def sample_images(self, n: int) -> "ImageCollection":
1782
+ """Sample one or more images in a copy of the ImageCollection."""
1783
+ copied = self.copy()
1784
+ images = copied.images
1785
+ if n > len(images):
1786
+ raise ValueError(
1787
+ f"n ({n}) is higher than number of images in collection ({len(images)})"
1788
+ )
1789
+ sample = []
1790
+ for _ in range(n):
1791
+ random.shuffle(images)
1792
+ img = images.pop()
1793
+ sample.append(img)
1794
+
1795
+ copied.images = sample
1796
+
1797
+ return copied
1798
+
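# Sampling sketch (illustrative only), assuming `collection` is an
# ImageCollection. All three methods return copies; the original is unchanged.
subset_areas = collection.sample(n=2, size=1000)  # two random masks of roughly 1000 units
subset_tiles = collection.sample_tiles(3)         # three random tiles
subset_images = collection.sample_images(5)       # five random images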
1799
+ def __or__(self, collection: "ImageCollection") -> "ImageCollection":
1800
+ """Concatenate the collection with another collection."""
1801
+ return concat_image_collections([self, collection])
1802
+
1803
+ def __iter__(self) -> Iterator[Image]:
1804
+ """Iterate over the images."""
1805
+ return iter(self.images)
1806
+
1807
+ def __len__(self) -> int:
1808
+ """Number of images."""
1809
+ return len(self.images)
1810
+
1811
+ def __getitem__(
1812
+ self,
1813
+ item: int | slice | Sequence[int | bool] | BoundingBox | Sequence[BoundingBox],
1814
+ ) -> Image | TORCHGEO_RETURN_TYPE:
1815
+ """Select one Image by integer index, or multiple Images by slice, list of int or torchgeo.BoundingBox."""
1816
+ if isinstance(item, int):
1817
+ return self.images[item]
1818
+
1819
+ if isinstance(item, slice):
1820
+ copied = self.copy()
1821
+ copied.images = copied.images[item]
1822
+ return copied
1823
+
1824
+ if not isinstance(item, BoundingBox) and not (
1825
+ isinstance(item, Iterable) and all(isinstance(x, BoundingBox) for x in item)
1826
+ ):
1827
+ try:
1828
+ copied = self.copy()
1829
+ if all(isinstance(x, bool) for x in item):
1830
+ copied.images = [
1831
+ img for x, img in zip(item, copied, strict=True) if x
1832
+ ]
1833
+ else:
1834
+ copied.images = [copied.images[i] for i in item]
1835
+ return copied
1836
+ except Exception as e:
1837
+ if hasattr(item, "__iter__"):
1838
+ endnote = f" of length {len(item)} with types {set(type(x) for x in item)}"
1839
+ raise TypeError(
1840
+ "ImageCollection indices must be int or BoundingBox. "
1841
+ f"Got {type(item)}{endnote}"
1842
+ ) from e
1843
+
1844
+ elif isinstance(item, BoundingBox):
1845
+ date_ranges: tuple[float, float] = (item.mint, item.maxt)
1846
+ data: torch.Tensor = numpy_to_torch(
1847
+ np.array(
1848
+ [
1849
+ band.values
1850
+ for band in self.filter(
1851
+ bounds=item, date_ranges=date_ranges
1852
+ ).merge_by_band(bounds=item)
1853
+ ]
1854
+ )
1855
+ )
1856
+ else:
1857
+ bboxes: list[Polygon] = [to_bbox(x) for x in item]
1858
+ date_ranges: list[tuple[float, float]] = [(x.mint, x.maxt) for x in item]
1859
+ data: torch.Tensor = torch.cat(
1860
+ [
1861
+ numpy_to_torch(
1862
+ np.array(
1863
+ [
1864
+ band.values
1865
+ for band in self.filter(
1866
+ bounds=bbox, date_ranges=date_range
1867
+ ).merge_by_band(bounds=bbox)
1868
+ ]
1869
+ )
1870
+ )
1871
+ for bbox, date_range in zip(bboxes, date_ranges, strict=True)
1872
+ ]
1873
+ )
1874
+
1875
+ crs = get_common_crs(self.images)
1876
+
1877
+ key = "image" # if self.is_image else "mask"
1878
+ sample = {key: data, "crs": crs, "bbox": item}
1879
+
1880
+ return sample
1881
+
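# Indexing sketch (illustrative only), assuming `collection` is an
# ImageCollection. Ints, slices and int/bool sequences return Image(s) or a
# copied collection; a torchgeo BoundingBox returns a sample dict.
first_image = collection[0]                # a single Image
first_three = collection[:3]               # an ImageCollection copy
by_position = collection[[0, 2]]           # by list of int
by_mask = collection[[True, False, True]]  # boolean mask, same length as the collection
# sample = collection[bounding_box]        # {"image": tensor, "crs": ..., "bbox": bounding_box}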
1882
+ # def dates_as_float(self) -> list[tuple[float, float]]:
1883
+ # return [disambiguate_timestamp(date, self.date_format) for date in self.dates]
1884
+
1885
+ @property
1886
+ def mint(self) -> float:
1887
+ """Min timestamp of the images combined."""
1888
+ return min(img.mint for img in self)
1889
+
1890
+ @property
1891
+ def maxt(self) -> float:
1892
+ """Max timestamp of the images combined."""
1893
+ return max(img.maxt for img in self)
1894
+
1895
+ @property
1896
+ def band_ids(self) -> list[str]:
1897
+ """Sorted list of unique band_ids."""
1898
+ return list(sorted({band.band_id for img in self for band in img}))
1899
+
1900
+ @property
1901
+ def file_paths(self) -> list[str]:
1902
+ """Sorted list of all file paths, meaning all band paths."""
1903
+ return list(sorted({band.path for img in self for band in img}))
1904
+
1905
+ @property
1906
+ def dates(self) -> list[str]:
1907
+ """List of image dates."""
1908
+ return [img.date for img in self]
1909
+
1910
+ def dates_as_int(self) -> list[int]:
1911
+ """List of image dates as 8-length integers."""
1912
+ return [int(img.date[:8]) for img in self]
1913
+
1914
+ @property
1915
+ def image_paths(self) -> list[str]:
1916
+ """List of image paths."""
1917
+ return [img.path for img in self]
1918
+
1919
+ @property
1920
+ def images(self) -> list["Image"]:
1921
+ """List of images in the Collection."""
1922
+ try:
1923
+ return self._images
1924
+ except AttributeError:
1925
+ # only fetch images when they are needed
1926
+ self._images = _get_images(
1927
+ list(self._df["image_path"]),
1928
+ all_file_paths=self._all_filepaths,
1929
+ df=self._df,
1930
+ res=self.res,
1931
+ processes=self.processes,
1932
+ image_class=self.image_class,
1933
+ _mask=self._mask,
1934
+ )
1935
+ if self.image_regexes:
1936
+ self._images = list(sorted(self._images))
1937
+ return self._images
1938
+
1939
+ @images.setter
1940
+ def images(self, new_value: list["Image"]) -> list["Image"]:
1941
+ if self.filename_patterns and any(pat.groups for pat in self.filename_patterns):
1942
+ self._images = list(sorted(new_value))
1943
+ else:
1944
+ self._images = list(new_value)
1945
+ if not all(isinstance(x, Image) for x in self._images):
1946
+ raise TypeError("images should be a sequence of Image.")
1947
+
1948
+ @property
1949
+ def index(self) -> Index:
1950
+ """Spatial index that makes torchgeo think this class is a RasterDataset."""
1951
+ try:
1952
+ if len(self) == len(self._index):
1953
+ return self._index
1954
+ except AttributeError:
1955
+ self._index = Index(interleaved=False, properties=Property(dimension=3))
1956
+
1957
+ for i, img in enumerate(self.images):
1958
+ if img.date:
1959
+ try:
1960
+ mint, maxt = disambiguate_timestamp(img.date, self.date_format)
1961
+ except (NameError, TypeError):
1962
+ mint, maxt = 0, 1
1963
+ else:
1964
+ mint, maxt = 0, 1
1965
+ # important: torchgeo has a different order of the bbox than shapely and geopandas
1966
+ minx, miny, maxx, maxy = img.bounds
1967
+ self._index.insert(i, (minx, maxx, miny, maxy, mint, maxt))
1968
+ return self._index
1969
+
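# Query sketch for the torchgeo-style index (illustrative only), assuming
# `collection` is an ImageCollection. With interleaved=False the rtree expects
# (minx, maxx, miny, maxy, mint, maxt), not the shapely (minx, miny, maxx, maxy) order.
minx, miny, maxx, maxy = 254000.0, 6550000.0, 260000.0, 6560000.0  # placeholder bounds
hits = list(collection.index.intersection((minx, maxx, miny, maxy, 0.0, 1.0)))
matching_images = [collection.images[i] for i in hits]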
1970
+ def __repr__(self) -> str:
1971
+ """String representation."""
1972
+ return f"{self.__class__.__name__}({len(self)})"
1973
+
1974
+ @property
1975
+ def unary_union(self) -> Polygon | MultiPolygon:
1976
+ """(Multi)Polygon representing the union of all image bounds."""
1977
+ return unary_union([img.unary_union for img in self])
1978
+
1979
+ @property
1980
+ def bounds(self) -> tuple[int, int, int, int]:
1981
+ """Total bounds for all Images combined."""
1982
+ return get_total_bounds([img.bounds for img in self])
1983
+
1984
+ @property
1985
+ def crs(self) -> Any:
1986
+ """Common coordinate reference system of the Images."""
1987
+ return get_common_crs([img.crs for img in self])
1988
+
1989
+
1990
+ def concat_image_collections(collections: Sequence[ImageCollection]) -> ImageCollection:
1991
+ """Union multiple ImageCollections together.
1992
+
1993
+ Same as using the union operator |.
1994
+ """
1995
+ resolutions = {x.res for x in collections}
1996
+ if len(resolutions) > 1:
1997
+ raise ValueError(f"resoultion mismatch. {resolutions}")
1998
+ images = list(itertools.chain.from_iterable([x.images for x in collections]))
1999
+ levels = {x.level for x in collections}
2000
+ level = next(iter(levels)) if len(levels) == 1 else None
2001
+ first_collection = collections[0]
2002
+
2003
+ out_collection = first_collection.__class__(
2004
+ images,
2005
+ level=level,
2006
+ **first_collection._common_init_kwargs,
2007
+ # res=list(resolutions)[0],
2008
+ )
2009
+ out_collection._all_filepaths = list(
2010
+ sorted(
2011
+ set(itertools.chain.from_iterable([x._all_filepaths for x in collections]))
2012
+ )
2013
+ )
2014
+ return out_collection
2015
+
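# Equivalence sketch (illustrative only), assuming `collection_a` and
# `collection_b` are ImageCollections with the same resolution: the | operator
# is shorthand for concat_image_collections.
combined = collection_a | collection_b
also_combined = concat_image_collections([collection_a, collection_b])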
2016
+
2017
+ def _get_gradient(band: Band, degrees: bool = False, copy: bool = True) -> Band:
2018
+ copied = band.copy() if copy else band
2019
+ if len(copied.values.shape) == 3:
2020
+ return np.array(
2021
+ [_slope_2d(arr, copied.res, degrees=degrees) for arr in copied.values]
2022
+ )
2023
+ elif len(copied.values.shape) == 2:
2024
+ return _slope_2d(copied.values, copied.res, degrees=degrees)
2025
+ else:
2026
+ raise ValueError("array must be 2 or 3 dimensional")
2027
+
2028
+
2029
+ def to_xarray(
2030
+ array: np.ndarray, transform: Affine, crs: Any, name: str | None = None
2031
+ ) -> DataArray:
2032
+ """Convert the raster to an xarray.DataArray."""
2033
+ if len(array.shape) == 2:
2034
+ height, width = array.shape
2035
+ dims = ["y", "x"]
2036
+ elif len(array.shape) == 3:
2037
+ height, width = array.shape[1:]
2038
+ dims = ["band", "y", "x"]
2039
+ else:
2040
+ raise ValueError(f"Array should be 2 or 3 dimensional. Got shape {array.shape}")
2041
+
2042
+ coords = _generate_spatial_coords(transform, width, height)
2043
+ return xr.DataArray(
2044
+ array,
2045
+ coords=coords,
2046
+ dims=dims,
2047
+ name=name,
2048
+ attrs={"crs": crs},
2049
+ ) # .transpose("y", "x")
2050
+
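# Sketch of to_xarray() on plain inputs (illustrative only); the 2x2 array and
# the affine transform below are placeholders (10-unit pixels, top-left at (0, 20)).
import numpy as np
from affine import Affine

arr = np.arange(4, dtype="float32").reshape(2, 2)
da = to_xarray(arr, transform=Affine(10.0, 0.0, 0.0, 0.0, -10.0, 20.0), crs=25833, name="example")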
2051
+
2052
+ def _slope_2d(array: np.ndarray, res: int, degrees: bool) -> np.ndarray:
2053
+ gradient_x, gradient_y = np.gradient(array, res, res)
2054
+
2055
+ gradient = abs(gradient_x) + abs(gradient_y)
2056
+
2057
+ if not degrees:
2058
+ return gradient
2059
+
2060
+ radians = np.arctan(gradient)
2061
+ degrees = np.degrees(radians)
2062
+
2063
+ assert np.max(degrees) <= 90
2064
+
2065
+ return degrees
2066
+
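# Worked example of the slope calculation above (illustrative only): on a plane
# rising 1 unit per cell along x, |dz/dx| + |dz/dy| = 1, and arctan(1) = 45 degrees.
import numpy as np

plane = np.array([[0.0, 1.0, 2.0], [0.0, 1.0, 2.0], [0.0, 1.0, 2.0]])
assert np.allclose(_slope_2d(plane, res=1, degrees=True), 45.0)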
2067
+
2068
+ def _get_images(
2069
+ image_paths: list[str],
2070
+ *,
2071
+ res: int,
2072
+ all_file_paths: list[str],
2073
+ df: pd.DataFrame,
2074
+ processes: int,
2075
+ image_class: Image,
2076
+ _mask: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None,
2077
+ ) -> list[Image]:
2078
+ with joblib.Parallel(n_jobs=processes, backend="threading") as parallel:
2079
+ return parallel(
2080
+ joblib.delayed(image_class)(
2081
+ path,
2082
+ df=df,
2083
+ res=res,
2084
+ all_file_paths=all_file_paths,
2085
+ _mask=_mask,
2086
+ processes=processes,
2087
+ )
2088
+ for path in image_paths
2089
+ )
2090
+
2091
+
2092
+ def numpy_to_torch(array: np.ndarray) -> torch.Tensor:
2093
+ """Convert numpy array to a pytorch tensor."""
2094
+ # fix numpy dtypes which are not supported by pytorch tensors
2095
+ if array.dtype == np.uint16:
2096
+ array = array.astype(np.int32)
2097
+ elif array.dtype == np.uint32:
2098
+ array = array.astype(np.int64)
2099
+
2100
+ return torch.tensor(array)
2101
+
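# Dtype sketch for numpy_to_torch (illustrative only): uint16/uint32 arrays are
# widened first because pytorch tensors lack (full) support for those dtypes.
import numpy as np

assert numpy_to_torch(np.array([1, 2], dtype=np.uint16)).dtype == torch.int32
assert numpy_to_torch(np.array([1, 2], dtype=np.uint32)).dtype == torch.int64
assert numpy_to_torch(np.zeros((2, 2), dtype=np.uint8)).dtype == torch.uint8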
2102
+
2103
+ class _RegexError(ValueError):
2104
+ pass
2105
+
2106
+
2107
+ class PathlessImageError(ValueError):
2108
+ """Exception for when Images, Bands or ImageCollections have no path."""
2109
+
2110
+ def __init__(self, instance: _ImageBase) -> None:
2111
+ """Initialise error class."""
2112
+ self.instance = instance
2113
+
2114
+ def __str__(self) -> str:
2115
+ """String representation."""
2116
+ if self.instance._merged:
2117
+ what = "that have been merged"
2118
+ elif self.instance._from_array:
2119
+ what = "from arrays"
2120
+ elif self.instance._from_gdf:
2121
+ what = "from GeoDataFrames"
2122
+
2123
+ return (
2124
+ f"{self.instance.__class__.__name__} instances {what} "
2125
+ "have no 'path' until they are written to file."
2126
+ )
2127
+
2128
+
2129
+ def _get_regex_match_from_xml_in_local_dir(
2130
+ paths: list[str], regexes: str | tuple[str]
2131
+ ) -> str | dict[str, str]:
2132
+ for i, path in enumerate(paths):
2133
+ if ".xml" not in path:
2134
+ continue
2135
+ with open_func(path, "rb") as file:
2136
+ filebytes: bytes = file.read()
2137
+ try:
2138
+ return _extract_regex_match_from_string(
2139
+ filebytes.decode("utf-8"), regexes
2140
+ )
2141
+ except _RegexError as e:
2142
+ if i == len(paths) - 1:
2143
+ raise e
2144
+
2145
+
2146
+ def _extract_regex_match_from_string(
2147
+ xml_file: str, regexes: tuple[str]
2148
+ ) -> str | dict[str, str]:
2149
+ for regex in regexes:
2150
+ if isinstance(regex, dict):
2151
+ out = {}
2152
+ for key, value in regex.items():
2153
+ try:
2154
+ out[key] = re.search(value, xml_file).group(1)
2155
+ except (TypeError, AttributeError):
2156
+ continue
2157
+ if len(out) != len(regex):
2158
+ raise _RegexError()
2159
+ return out
2160
+ try:
2161
+ return re.search(regex, xml_file).group(1)
2162
+ except (TypeError, AttributeError):
2163
+ continue
2164
+ raise _RegexError()
2165
+
2166
+
2167
+ def _fix_path(path: str) -> str:
2168
+ return (
2169
+ str(path).replace("\\", "/").replace(r"\"", "/").replace("//", "/").rstrip("/")
2170
+ )
2171
+
2172
+
2173
+ def _get_regexes_matches_for_df(
2174
+ df, match_col: str, patterns: Sequence[re.Pattern], suffix: str = ""
2175
+ ) -> tuple[pd.DataFrame, list[str]]:
2176
+ if not len(df):
2177
+ return df, []
2178
+ assert df.index.is_unique
2179
+ matches: list[pd.DataFrame] = []
2180
+ for pat in patterns:
2181
+ if pat.groups:
2182
+ try:
2183
+ matches.append(df[match_col].str.extract(pat))
2184
+ except ValueError:
2185
+ continue
2186
+ else:
2187
+ match_ = df[match_col].loc[df[match_col].str.match(pat)]
2188
+ if len(match_):
2189
+ matches.append(match_)
2190
+
2191
+ matches = pd.concat(matches).groupby(level=0, dropna=True).first()
2192
+
2193
+ if isinstance(matches, pd.Series):
2194
+ matches = pd.DataFrame({matches.name: matches.values}, index=matches.index)
2195
+
2196
+ match_cols = [f"{col}{suffix}" for col in matches.columns]
2197
+ df[match_cols] = matches
2198
+ return (
2199
+ df.loc[~df[match_cols].isna().all(axis=1)].drop(
2200
+ columns=f"{match_col}{suffix}", errors="ignore"
2201
+ ),
2202
+ match_cols,
2203
+ )
2204
+
2205
+
2206
+ def _arr_from_gdf(
2207
+ gdf: GeoDataFrame,
2208
+ res: int,
2209
+ fill: int = 0,
2210
+ all_touched: bool = False,
2211
+ merge_alg: Callable = MergeAlg.replace,
2212
+ default_value: int = 1,
2213
+ dtype: Any | None = None,
2214
+ ) -> np.ndarray:
2215
+ """Construct Raster from a GeoDataFrame or GeoSeries.
2216
+
2217
+ The GeoDataFrame should have only a geometry column and one numeric column (or a numeric index) to use as array values.
2218
+
2219
+ Args:
2220
+ gdf: The GeoDataFrame to rasterize.
2221
+ res: Resolution of the raster in units of the GeoDataFrame's coordinate reference system.
2222
+ fill: Fill value for areas outside of input geometries (default is 0).
2223
+ all_touched: Whether to consider all pixels touched by geometries,
2224
+ not just those whose center is within the polygon (default is False).
2225
+ merge_alg: Merge algorithm to use when combining geometries
2226
+ (default is 'MergeAlg.replace').
2227
+ default_value: Default value to use for the rasterized pixels
2228
+ (default is 1).
2229
+ dtype: Data type of the output array. If None, it will be
2230
+ determined automatically.
2231
+
2232
+ Returns:
2233
+ A numpy array rasterized from the GeoDataFrame at the given resolution.
2234
+
2235
+ Raises:
2236
+ ValueError: If the GeoDataFrame has more than one non-geometry column,
2237
+ or if its index is a MultiIndex.
2238
+ """
2239
+ if isinstance(gdf, GeoSeries):
2240
+ values = gdf.index
2241
+ gdf = gdf.to_frame("geometry")
2242
+ elif isinstance(gdf, GeoDataFrame):
2243
+ if len(gdf.columns) > 2:
2244
+ raise ValueError(
2245
+ "gdf should have only a geometry column and one numeric column to "
2246
+ "use as array values. "
2247
+ "Alternatively only a geometry column and a numeric index."
2248
+ )
2249
+ elif len(gdf.columns) == 1:
2250
+ values = gdf.index
2251
+ else:
2252
+ col: str = next(
2253
+ iter([col for col in gdf if col != gdf._geometry_column_name])
2254
+ )
2255
+ values = gdf[col]
2256
+
2257
+ if isinstance(values, pd.MultiIndex):
2258
+ raise ValueError("Index cannot be MultiIndex.")
2259
+
2260
+ shape = _get_shape_from_bounds(gdf.total_bounds, res=res)
2261
+ transform = _get_transform_from_bounds(gdf.total_bounds, shape)
2262
+
2263
+ return features.rasterize(
2264
+ _gdf_to_geojson_with_col(gdf, values),
2265
+ out_shape=shape,
2266
+ transform=transform,
2267
+ fill=fill,
2268
+ all_touched=all_touched,
2269
+ merge_alg=merge_alg,
2270
+ default_value=default_value,
2271
+ dtype=dtype,
2272
+ )
2273
+
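# Rasterization sketch for _arr_from_gdf (illustrative only): one geometry
# column plus one numeric column becomes a 2D array at the given resolution.
import geopandas as gpd
from shapely.geometry import box as _box

gdf = gpd.GeoDataFrame(
    {"value": [1, 2]},
    geometry=[_box(0, 0, 50, 100), _box(50, 0, 100, 100)],
    crs=25833,
)
arr = _arr_from_gdf(gdf, res=10, dtype="int32")  # shape (10, 10): left half 1, right half 2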
2274
+
2275
+ def _gdf_to_geojson_with_col(gdf: GeoDataFrame, values: np.ndarray) -> list[dict]:
2276
+ with warnings.catch_warnings():
2277
+ warnings.filterwarnings("ignore", category=UserWarning)
2278
+ return [
2279
+ (feature["geometry"], val)
2280
+ for val, feature in zip(
2281
+ values, loads(gdf.to_json())["features"], strict=False
2282
+ )
2283
+ ]
2284
+
2285
+
2286
+ def _date_is_within(
2287
+ path,
2288
+ date_ranges: (
2289
+ tuple[str | None, str | None] | tuple[tuple[str | None, str | None], ...] | None
2290
+ ),
2291
+ image_patterns: Sequence[re.Pattern],
2292
+ date_format: str,
2293
+ ) -> bool:
2294
+ for pat in image_patterns:
2295
+ try:
2296
+ date = re.match(pat, Path(path).name).group("date")
2297
+ break
2298
+ except AttributeError:
2299
+ date = None
2300
+
2301
+ if date is None:
2302
+ return False
2303
+
2304
+ if date_ranges is None:
2305
+ return True
2306
+
2307
+ if all(x is None or isinstance(x, (str, float)) for x in date_ranges):
2308
+ date_ranges = (date_ranges,)
2309
+
2310
+ if all(isinstance(x, float) for date_range in date_ranges for x in date_range):
2311
+ date = disambiguate_timestamp(date, date_format)
2312
+ else:
2313
+ date = date[:8]
2314
+
2315
+ for date_range in date_ranges:
2316
+ date_min, date_max = date_range
2317
+
2318
+ if isinstance(date_min, float) and isinstance(date_max, float):
2319
+ if date[0] >= date_min + 0.0000001 and date[1] <= date_max - 0.0000001:
2320
+ return True
2321
+ continue
2322
+
2323
+ try:
2324
+ date_min = date_min or "00000000"
2325
+ date_max = date_max or "99999999"
2326
+ assert isinstance(date_min, str)
2327
+ assert len(date_min) == 8
2328
+ assert isinstance(date_max, str)
2329
+ assert len(date_max) == 8
2330
+ except AssertionError as err:
2331
+ raise TypeError(
2332
+ "date_ranges should be a tuple of two 8-charactered strings (start and end date)."
2333
+ f"Got {date_range} of type {[type(x) for x in date_range]}"
2334
+ ) from err
2335
+ if date >= date_min and date <= date_max:
2336
+ return True
2337
+
2338
+ return False
2339
+
2340
+
2341
+ def _get_shape_from_bounds(
2342
+ obj: GeoDataFrame | GeoSeries | Geometry | tuple, res: int
2343
+ ) -> tuple[int, int]:
2344
+ resx, resy = (res, res) if isinstance(res, numbers.Number) else res
2345
+
2346
+ minx, miny, maxx, maxy = to_bbox(obj)
2347
+ diffx = maxx - minx
2348
+ diffy = maxy - miny
2349
+ width = int(diffx / resx)
2350
+ height = int(diffy / resy)
2351
+ return height, width
2352
+
2353
+
2354
+ def _get_transform_from_bounds(
2355
+ obj: GeoDataFrame | GeoSeries | Geometry | tuple, shape: tuple[float, ...]
2356
+ ) -> Affine:
2357
+ minx, miny, maxx, maxy = to_bbox(obj)
2358
+ if len(shape) == 2:
2359
+ width, height = shape
2360
+ elif len(shape) == 3:
2361
+ _, width, height = shape
2362
+ else:
2363
+ raise ValueError
2364
+ return rasterio.transform.from_bounds(minx, miny, maxx, maxy, width, height)
2365
+
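# Sketch tying the two helpers together (illustrative only): bounds plus a
# resolution give a raster shape, which gives the affine transform for the grid.
bounds = (0.0, 0.0, 100.0, 100.0)               # minx, miny, maxx, maxy
shape = _get_shape_from_bounds(bounds, res=10)  # (height, width) == (10, 10)
transform = _get_transform_from_bounds(bounds, shape)
# transform == Affine(10.0, 0.0, 0.0, 0.0, -10.0, 100.0)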
2366
+
2367
+ def _get_shape_from_res(
2368
+ bounds: tuple[float], res: int, indexes: int | tuple[int]
2369
+ ) -> tuple[int] | None:
2370
+ if res is None:
2371
+ return None
2372
+ if hasattr(res, "__iter__") and len(res) == 2:
2373
+ res = res[0]
2374
+ diffx = bounds[2] - bounds[0]
2375
+ diffy = bounds[3] - bounds[1]
2376
+ width = int(diffx / res)
2377
+ height = int(diffy / res)
2378
+ if not isinstance(indexes, int):
2379
+ return len(indexes), width, height
2380
+ return width, height
2381
+
2382
+
2383
+ def _array_to_geojson(
2384
+ array: np.ndarray, transform: Affine, processes: int
2385
+ ) -> list[tuple]:
2386
+ if np.ma.is_masked(array):
2387
+ array = array.data
2388
+ try:
2389
+ return _array_to_geojson_loop(array, transform, processes)
2390
+
2391
+ except ValueError:
2392
+ try:
2393
+ array = array.astype(np.float32)
2394
+ return _array_to_geojson_loop(array, transform, processes)
2395
+
2396
+ except Exception as err:
2397
+ raise err.__class__(array.shape, err) from err
2398
+
2399
+
2400
+ def _array_to_geojson_loop(array, transform, processes):
2401
+ if processes == 1:
2402
+ return [
2403
+ (value, shape(geom))
2404
+ for geom, value in features.shapes(array, transform=transform, mask=None)
2405
+ ]
2406
+ else:
2407
+ with joblib.Parallel(n_jobs=processes, backend="threading") as parallel:
2408
+ return parallel(
2409
+ joblib.delayed(_value_geom_pair)(value, geom)
2410
+ for geom, value in features.shapes(
2411
+ array, transform=transform, mask=None
2412
+ )
2413
+ )
2414
+
2415
+
2416
+ def _value_geom_pair(value, geom):
2417
+ return (value, shape(geom))
2418
+
2419
+
2420
+ def _intesects(x, other) -> bool:
2421
+ return box(*x.bounds).intersects(other)
2422
+
2423
+
2424
+ def _copy_and_add_df_parallel(
2425
+ i: tuple[Any, ...], group: pd.DataFrame, self: ImageCollection
2426
+ ) -> tuple[tuple[Any], ImageCollection]:
2427
+ copied = self.copy()
2428
+ copied.images = [
2429
+ img.copy() for img in group.drop_duplicates("_image_idx")["_image_instance"]
2430
+ ]
2431
+ if "band_id" in group:
2432
+ band_ids = set(group["band_id"].values)
2433
+ for img in copied.images:
2434
+ img._bands = [band for band in img if band.band_id in band_ids]
2435
+
2436
+ # for col in group.columns.difference({"_image_instance", "_image_idx"}):
2437
+ # if not all(
2438
+ # col in dir(band) or col in band.__dict__ for img in copied for band in img
2439
+ # ):
2440
+ # continue
2441
+ # values = set(group[col].values)
2442
+ # for img in copied.images:
2443
+ # img._bands = [band for band in img if getattr(band, col) in values]
2444
+
2445
+ return (i, copied)
2446
+
2447
+
2448
+ def _open_raster(path: str | Path) -> rasterio.io.DatasetReader:
2449
+ with opener(path) as file:
2450
+ return rasterio.open(file)
2451
+
2452
+
2453
+ def _load_band(band: Band, **kwargs) -> None:
2454
+ band.load(**kwargs)
2455
+
2456
+
2457
+ def _merge_by_band(collection: ImageCollection, **kwargs) -> Image:
2458
+ print("_merge_by_band", collection.dates)
2459
+ return collection.merge_by_band(**kwargs)
2460
+
2461
+
2462
+ def _merge(collection: ImageCollection, **kwargs) -> Band:
2463
+ return collection.merge(**kwargs)
2464
+
2465
+
2466
+ def _zonal_one_pair(i: int, poly: Polygon, band: Band, aggfunc, array_func, func_names):
2467
+ clipped = band.copy().load(bounds=poly)
2468
+ if not np.size(clipped.values):
2469
+ return _no_overlap_df(func_names, i, date=band.date)
2470
+ return _aggregate(clipped.values, array_func, aggfunc, func_names, band.date, i)
2471
+
2472
+
2473
+ class Sentinel2Config:
2474
+ """Holder of Sentinel 2 regexes, band_ids etc."""
2475
+
2476
+ image_regexes: ClassVar[str] = (
2477
+ config.SENTINEL2_IMAGE_REGEX,
2478
+ ) # config.SENTINEL2_MOSAIC_IMAGE_REGEX,)
2479
+ filename_regexes: ClassVar[str] = (
2480
+ config.SENTINEL2_FILENAME_REGEX,
2481
+ # config.SENTINEL2_MOSAIC_FILENAME_REGEX,
2482
+ config.SENTINEL2_CLOUD_FILENAME_REGEX,
2483
+ )
2484
+ all_bands: ClassVar[list[str]] = list(config.SENTINEL2_BANDS)
2485
+ rbg_bands: ClassVar[list[str]] = ["B02", "B03", "B04"]
2486
+ ndvi_bands: ClassVar[list[str]] = ["B04", "B08"]
2487
+ cloud_band: ClassVar[str] = "SCL"
2488
+ cloud_values: ClassVar[tuple[int]] = (3, 8, 9, 10, 11)
2489
+ l2a_bands: ClassVar[dict[str, int]] = config.SENTINEL2_L2A_BANDS
2490
+ l1c_bands: ClassVar[dict[str, int]] = config.SENTINEL2_L1C_BANDS
2491
+ date_format: ClassVar[str] = "%Y%m%d" # T%H%M%S"
2492
+
2493
+
2494
+ class Sentinel2CloudlessConfig(Sentinel2Config):
2495
+ """Holder of regexes, band_ids etc. for Sentinel 2 cloudless mosaic."""
2496
+
2497
+ image_regexes: ClassVar[str] = (config.SENTINEL2_MOSAIC_IMAGE_REGEX,)
2498
+ filename_regexes: ClassVar[str] = (config.SENTINEL2_MOSAIC_FILENAME_REGEX,)
2499
+ cloud_band: ClassVar[None] = None
2500
+ cloud_values: ClassVar[None] = None
2501
+ date_format: ClassVar[str] = "%Y%m%d"
2502
+
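# The config ClassVars above are plain class attributes, so they can be read
# directly (illustrative only):
assert Sentinel2Config.rbg_bands == ["B02", "B03", "B04"]
assert Sentinel2Config.cloud_values == (3, 8, 9, 10, 11)
assert Sentinel2CloudlessConfig.cloud_band is None  # the mosaic has no cloud band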
2503
+
2504
+ class Sentinel2Band(Sentinel2Config, Band):
2505
+ """Band with Sentinel2 specific name variables and regexes."""
2506
+
2507
+
2508
+ class Sentinel2Image(Sentinel2Config, Image):
2509
+ """Image with Sentinel2 specific name variables and regexes."""
2510
+
2511
+ cloud_cover_regexes: ClassVar[tuple[str]] = config.CLOUD_COVERAGE_REGEXES
2512
+ band_class: ClassVar[Sentinel2Band] = Sentinel2Band
2513
+
2514
+ def get_ndvi(
2515
+ self,
2516
+ red_band: str = Sentinel2Config.ndvi_bands[0],
2517
+ nir_band: str = Sentinel2Config.ndvi_bands[1],
2518
+ ) -> NDVIBand:
2519
+ """Calculate the NDVI for the Image."""
2520
+ return super().get_ndvi(red_band=red_band, nir_band=nir_band)
2521
+
2522
+
2523
+ class Sentinel2Collection(Sentinel2Config, ImageCollection):
2524
+ """ImageCollection with Sentinel2 specific name variables and regexes."""
2525
+
2526
+ image_class: ClassVar[Sentinel2Image] = Sentinel2Image
2527
+ band_class: ClassVar[Sentinel2Band] = Sentinel2Band
2528
+
2529
+
2530
+ class Sentinel2CloudlessBand(Sentinel2CloudlessConfig, Band):
2531
+ """Band for cloudless mosaic with Sentinel2 specific name variables and regexes."""
2532
+
2533
+
2534
+ class Sentinel2CloudlessImage(Sentinel2CloudlessConfig, Sentinel2Image):
2535
+ """Image for cloudless mosaic with Sentinel2 specific name variables and regexes."""
2536
+
2537
+ # image_regexes: ClassVar[str] = (config.SENTINEL2_MOSAIC_IMAGE_REGEX,)
2538
+ # filename_regexes: ClassVar[str] = (config.SENTINEL2_MOSAIC_FILENAME_REGEX,)
2539
+
2540
+ cloud_cover_regexes: ClassVar[None] = None
2541
+ band_class: ClassVar[Sentinel2CloudlessBand] = Sentinel2CloudlessBand
2542
+
2543
+ get_ndvi = Sentinel2Image.get_ndvi
2544
+ # def get_ndvi(
2545
+ # self,
2546
+ # red_band: str = Sentinel2Config.ndvi_bands[0],
2547
+ # nir_band: str = Sentinel2Config.ndvi_bands[1],
2548
+ # ) -> NDVIBand:
2549
+ # """Calculate the NDVI for the Image."""
2550
+ # return super().get_ndvi(red_band=red_band, nir_band=nir_band)
2551
+
2552
+
2553
+ class Sentinel2CloudlessCollection(Sentinel2CloudlessConfig, ImageCollection):
2554
+ """ImageCollection with Sentinel2 specific name variables and regexes."""
2555
+
2556
+ # image_regexes: ClassVar[str] = (config.SENTINEL2_MOSAIC_IMAGE_REGEX,)
2557
+ # filename_regexes: ClassVar[str] = (config.SENTINEL2_MOSAIC_FILENAME_REGEX,)
2558
+
2559
+ image_class: ClassVar[Sentinel2CloudlessImage] = Sentinel2CloudlessImage
2560
+ band_class: ClassVar[Sentinel2Band] = Sentinel2Band