ssb-sgis: ssb_sgis-1.0.3-py3-none-any.whl → ssb_sgis-1.0.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. sgis/__init__.py +10 -3
  2. sgis/debug_config.py +24 -0
  3. sgis/geopandas_tools/bounds.py +16 -21
  4. sgis/geopandas_tools/buffer_dissolve_explode.py +112 -30
  5. sgis/geopandas_tools/centerlines.py +4 -91
  6. sgis/geopandas_tools/cleaning.py +1576 -583
  7. sgis/geopandas_tools/conversion.py +24 -14
  8. sgis/geopandas_tools/duplicates.py +27 -6
  9. sgis/geopandas_tools/general.py +259 -100
  10. sgis/geopandas_tools/geometry_types.py +1 -1
  11. sgis/geopandas_tools/neighbors.py +16 -12
  12. sgis/geopandas_tools/overlay.py +7 -3
  13. sgis/geopandas_tools/point_operations.py +3 -3
  14. sgis/geopandas_tools/polygon_operations.py +505 -100
  15. sgis/geopandas_tools/polygons_as_rings.py +40 -8
  16. sgis/geopandas_tools/sfilter.py +26 -9
  17. sgis/io/dapla_functions.py +238 -19
  18. sgis/maps/examine.py +11 -10
  19. sgis/maps/explore.py +227 -155
  20. sgis/maps/legend.py +13 -4
  21. sgis/maps/map.py +22 -13
  22. sgis/maps/maps.py +100 -29
  23. sgis/maps/thematicmap.py +25 -18
  24. sgis/networkanalysis/_service_area.py +6 -1
  25. sgis/networkanalysis/cutting_lines.py +12 -5
  26. sgis/networkanalysis/finding_isolated_networks.py +13 -6
  27. sgis/networkanalysis/networkanalysis.py +10 -12
  28. sgis/parallel/parallel.py +27 -10
  29. sgis/raster/base.py +208 -0
  30. sgis/raster/cube.py +3 -3
  31. sgis/raster/image_collection.py +1421 -724
  32. sgis/raster/indices.py +10 -7
  33. sgis/raster/raster.py +7 -7
  34. sgis/raster/sentinel_config.py +33 -17
  35. {ssb_sgis-1.0.3.dist-info → ssb_sgis-1.0.5.dist-info}/METADATA +6 -7
  36. ssb_sgis-1.0.5.dist-info/RECORD +62 -0
  37. ssb_sgis-1.0.3.dist-info/RECORD +0 -61
  38. {ssb_sgis-1.0.3.dist-info → ssb_sgis-1.0.5.dist-info}/LICENSE +0 -0
  39. {ssb_sgis-1.0.3.dist-info → ssb_sgis-1.0.5.dist-info}/WHEEL +0 -0
@@ -1,22 +1,23 @@
1
+ import datetime
1
2
  import functools
2
3
  import glob
3
4
  import itertools
4
- import numbers
5
+ import math
5
6
  import os
6
7
  import random
7
8
  import re
8
- import warnings
9
9
  from collections.abc import Callable
10
10
  from collections.abc import Iterable
11
11
  from collections.abc import Iterator
12
12
  from collections.abc import Sequence
13
13
  from copy import deepcopy
14
- from json import loads
14
+ from dataclasses import dataclass
15
15
  from pathlib import Path
16
16
  from typing import Any
17
17
  from typing import ClassVar
18
18
 
19
19
  import joblib
20
+ import matplotlib.pyplot as plt
20
21
  import numpy as np
21
22
  import pandas as pd
22
23
  import pyproj
@@ -24,17 +25,19 @@ import rasterio
24
25
  from affine import Affine
25
26
  from geopandas import GeoDataFrame
26
27
  from geopandas import GeoSeries
27
- from rasterio import features
28
+ from matplotlib.colors import LinearSegmentedColormap
28
29
  from rasterio.enums import MergeAlg
29
30
  from rtree.index import Index
30
31
  from rtree.index import Property
32
+ from scipy import stats
33
+ from scipy.ndimage import binary_dilation
34
+ from scipy.ndimage import binary_erosion
31
35
  from shapely import Geometry
32
36
  from shapely import box
33
37
  from shapely import unary_union
34
38
  from shapely.geometry import MultiPolygon
35
39
  from shapely.geometry import Point
36
40
  from shapely.geometry import Polygon
37
- from shapely.geometry import shape
38
41
 
39
42
  try:
40
43
  import dapla as dp
@@ -97,7 +100,7 @@ except ImportError:
97
100
 
98
101
 
99
102
  from ..geopandas_tools.bounds import get_total_bounds
100
- from ..geopandas_tools.bounds import to_bbox
103
+ from ..geopandas_tools.conversion import to_bbox
101
104
  from ..geopandas_tools.conversion import to_gdf
102
105
  from ..geopandas_tools.conversion import to_shapely
103
106
  from ..geopandas_tools.general import get_common_crs
@@ -106,6 +109,10 @@ from ..helpers import get_numpy_func
106
109
  from ..io._is_dapla import is_dapla
107
110
  from ..io.opener import opener
108
111
  from . import sentinel_config as config
112
+ from .base import _array_to_geojson
113
+ from .base import _gdf_to_arr
114
+ from .base import _get_shape_from_bounds
115
+ from .base import _get_transform_from_bounds
109
116
  from .base import get_index_mapper
110
117
  from .indices import ndvi
111
118
  from .zonal import _aggregate
@@ -116,23 +123,56 @@ from .zonal import _zonal_post
116
123
 
117
124
  if is_dapla():
118
125
 
119
- def ls_func(*args, **kwargs) -> list[str]:
126
+ def _ls_func(*args, **kwargs) -> list[str]:
120
127
  return dp.FileClient.get_gcs_file_system().ls(*args, **kwargs)
121
128
 
122
- def glob_func(*args, **kwargs) -> list[str]:
129
+ def _glob_func(*args, **kwargs) -> list[str]:
123
130
  return dp.FileClient.get_gcs_file_system().glob(*args, **kwargs)
124
131
 
125
- def open_func(*args, **kwargs) -> GCSFile:
132
+ def _open_func(*args, **kwargs) -> GCSFile:
126
133
  return dp.FileClient.get_gcs_file_system().open(*args, **kwargs)
127
134
 
135
+ def _rm_file_func(*args, **kwargs) -> None:
136
+ return dp.FileClient.get_gcs_file_system().rm_file(*args, **kwargs)
137
+
138
+ def _read_parquet_func(*args, **kwargs) -> list[str]:
139
+ return dp.read_pandas(*args, **kwargs)
140
+
128
141
  else:
129
- ls_func = functools.partial(get_all_files, recursive=False)
130
- open_func = open
131
- glob_func = glob.glob
142
+ _ls_func = functools.partial(get_all_files, recursive=False)
143
+ _open_func = open
144
+ _glob_func = glob.glob
145
+ _rm_file_func = os.remove
146
+ _read_parquet_func = pd.read_parquet
132
147
 
133
148
  TORCHGEO_RETURN_TYPE = dict[str, torch.Tensor | pyproj.CRS | BoundingBox]
134
149
  FILENAME_COL_SUFFIX = "_filename"
135
- DEFAULT_FILENAME_REGEX = r".*\.(?:tif|tiff|jp2)$"
150
+ DEFAULT_FILENAME_REGEX = r"""
151
+ .*?
152
+ (?:_(?P<date>\d{8}(?:T\d{6})?))? # Optional date group
153
+ .*?
154
+ (?:_(?P<band>B\d{1,2}A|B\d{1,2}))? # Optional band group
155
+ \.(?:tif|tiff|jp2)$ # End with .tif, .tiff, or .jp2
156
+ """
157
+ DEFAULT_IMAGE_REGEX = r"""
158
+ .*?
159
+ (?:_(?P<date>\d{8}(?:T\d{6})?))? # Optional date group
160
+ (?:_(?P<band>B\d{1,2}A|B\d{1,2}))? # Optional band group
161
+ """
162
+
163
+ ALLOWED_INIT_KWARGS = [
164
+ "image_class",
165
+ "band_class",
166
+ "image_regexes",
167
+ "filename_regexes",
168
+ "date_format",
169
+ "cloud_cover_regexes",
170
+ "bounds_regexes",
171
+ "all_bands",
172
+ "crs",
173
+ "masking",
174
+ "_merged",
175
+ ]
136
176
 
137
177
 
138
178
  class ImageCollectionGroupBy:
@@ -160,7 +200,12 @@ class ImageCollectionGroupBy:
160
200
  self.collection = collection
161
201
 
162
202
  def merge_by_band(
163
- self, bounds=None, method="median", as_int: bool = True, indexes=None, **kwargs
203
+ self,
204
+ bounds: tuple | Geometry | GeoDataFrame | GeoSeries | None = None,
205
+ method: str | Callable = "mean",
206
+ as_int: bool = True,
207
+ indexes: int | tuple[int] | None = None,
208
+ **kwargs,
164
209
  ) -> "ImageCollection":
165
210
  """Merge each group into separate Bands per band_id, returned as an ImageCollection."""
166
211
  images = self._run_func_for_collection_groups(
@@ -171,10 +216,8 @@ class ImageCollectionGroupBy:
171
216
  indexes=indexes,
172
217
  **kwargs,
173
218
  )
174
- print("hihihih")
175
219
  for img, (group_values, _) in zip(images, self.data, strict=True):
176
220
  for attr, group_value in zip(self.by, group_values, strict=True):
177
- print(attr, group_value)
178
221
  try:
179
222
  setattr(img, attr, group_value)
180
223
  except AttributeError:
@@ -182,6 +225,7 @@ class ImageCollectionGroupBy:
182
225
 
183
226
  collection = ImageCollection(
184
227
  images,
228
+ # TODO band_class?
185
229
  level=self.collection.level,
186
230
  **self.collection._common_init_kwargs,
187
231
  )
@@ -189,7 +233,12 @@ class ImageCollectionGroupBy:
189
233
  return collection
190
234
 
191
235
  def merge(
192
- self, bounds=None, method="median", as_int: bool = True, indexes=None, **kwargs
236
+ self,
237
+ bounds: tuple | Geometry | GeoDataFrame | GeoSeries | None = None,
238
+ method: str | Callable = "mean",
239
+ as_int: bool = True,
240
+ indexes: int | tuple[int] | None = None,
241
+ **kwargs,
193
242
  ) -> "Image":
194
243
  """Merge each group into a single Band, returned as combined Image."""
195
244
  bands: list[Band] = self._run_func_for_collection_groups(
@@ -214,6 +263,7 @@ class ImageCollectionGroupBy:
214
263
 
215
264
  image = Image(
216
265
  bands,
266
+ # TODO band_class?
217
267
  **self.collection._common_init_kwargs,
218
268
  )
219
269
  image._merged = True
@@ -243,13 +293,28 @@ class ImageCollectionGroupBy:
243
293
  return f"{self.__class__.__name__}({len(self)})"
244
294
 
245
295
 
296
+ @dataclass(frozen=True)
297
+ class BandMasking:
298
+ """Basically a frozen dict with forced keys."""
299
+
300
+ band_id: str
301
+ values: tuple[int]
302
+
303
+ def __getitem__(self, item: str) -> Any:
304
+ """Index into attributes to mimick dict."""
305
+ return getattr(self, item)
306
+
307
+
246
308
  class _ImageBase:
247
- image_regexes: ClassVar[str | None] = None
309
+ image_regexes: ClassVar[str | None] = (DEFAULT_IMAGE_REGEX,)
248
310
  filename_regexes: ClassVar[str | tuple[str]] = (DEFAULT_FILENAME_REGEX,)
249
311
  date_format: ClassVar[str] = "%Y%m%d" # T%H%M%S"
312
+ masking: ClassVar[BandMasking | None] = None
250
313
 
251
- def __init__(self) -> None:
314
+ def __init__(self, **kwargs) -> None:
252
315
 
316
+ self._mask = None
317
+ self._bounds = None
253
318
  self._merged = False
254
319
  self._from_array = False
255
320
  self._from_gdf = False
@@ -262,7 +327,7 @@ class _ImageBase:
262
327
  for regexes in self.filename_regexes
263
328
  ]
264
329
  else:
265
- self.filename_patterns = None
330
+ self.filename_patterns = ()
266
331
 
267
332
  if self.image_regexes:
268
333
  if isinstance(self.image_regexes, str):
@@ -271,7 +336,15 @@ class _ImageBase:
271
336
  re.compile(regexes, flags=re.VERBOSE) for regexes in self.image_regexes
272
337
  ]
273
338
  else:
274
- self.image_patterns = None
339
+ self.image_patterns = ()
340
+
341
+ for key, value in kwargs.items():
342
+ if key in ALLOWED_INIT_KWARGS and key in dir(self):
343
+ setattr(self, key, value)
344
+ else:
345
+ raise ValueError(
346
+ f"{self.__class__.__name__} got an unexpected keyword argument '{key}'"
347
+ )
275
348
 
276
349
  @property
277
350
  def _common_init_kwargs(self) -> dict:
@@ -279,7 +352,8 @@ class _ImageBase:
279
352
  "file_system": self.file_system,
280
353
  "processes": self.processes,
281
354
  "res": self.res,
282
- "_mask": self._mask,
355
+ "bbox": self._bbox,
356
+ "nodata": self.nodata,
283
357
  }
284
358
 
285
359
  @property
@@ -287,7 +361,7 @@ class _ImageBase:
287
361
  try:
288
362
  return self._path
289
363
  except AttributeError as e:
290
- raise PathlessImageError(self.__class__.__name__) from e
364
+ raise PathlessImageError(self) from e
291
365
 
292
366
  @property
293
367
  def res(self) -> int:
@@ -296,7 +370,8 @@ class _ImageBase:
296
370
 
297
371
  @property
298
372
  def centroid(self) -> Point:
299
- return self.unary_union.centroid
373
+ """Centerpoint of the object."""
374
+ return self.union_all().centroid
300
375
 
301
376
  def _name_regex_searcher(
302
377
  self, group: str, patterns: tuple[re.Pattern]
@@ -305,17 +380,22 @@ class _ImageBase:
305
380
  return None
306
381
  for pat in patterns:
307
382
  try:
383
+ return _get_first_group_match(pat, self.name)[group]
308
384
  return re.match(pat, self.name).group(group)
309
- except (AttributeError, TypeError, IndexError):
385
+ except (TypeError, KeyError):
310
386
  pass
387
+ if not any(group in _get_non_optional_groups(pat) for pat in patterns):
388
+ return None
311
389
  raise ValueError(
312
390
  f"Couldn't find group '{group}' in name {self.name} with regex patterns {patterns}"
313
391
  )
314
392
 
315
- def _create_metadata_df(self, file_paths: list[str]) -> None:
393
+ def _create_metadata_df(self, file_paths: list[str]) -> pd.DataFrame:
394
+ """Create a dataframe with file paths and image paths that match regexes."""
316
395
  df = pd.DataFrame({"file_path": file_paths})
317
396
 
318
397
  df["filename"] = df["file_path"].apply(lambda x: _fix_path(Path(x).name))
398
+
319
399
  if not self.single_banded:
320
400
  df["image_path"] = df["file_path"].apply(
321
401
  lambda x: _fix_path(str(Path(x).parent))
@@ -327,20 +407,13 @@ class _ImageBase:
327
407
  return df
328
408
 
329
409
  if self.filename_patterns:
330
- df, match_cols_filename = _get_regexes_matches_for_df(
331
- df, "filename", self.filename_patterns, suffix=FILENAME_COL_SUFFIX
332
- )
410
+ df = _get_regexes_matches_for_df(df, "filename", self.filename_patterns)
333
411
 
334
412
  if not len(df):
335
413
  return df
336
414
 
337
- self._match_cols_filename = match_cols_filename
338
- grouped = (
339
- df.drop(columns=match_cols_filename, errors="ignore")
340
- .drop_duplicates("image_path")
341
- .set_index("image_path")
342
- )
343
- for col in ["file_path", "filename", *match_cols_filename]:
415
+ grouped = df.drop_duplicates("image_path").set_index("image_path")
416
+ for col in ["file_path", "filename"]:
344
417
  if col in df:
345
418
  grouped[col] = df.groupby("image_path")[col].apply(tuple)
346
419
 
@@ -355,14 +428,11 @@ class _ImageBase:
355
428
  )
356
429
 
357
430
  if self.image_patterns and len(grouped):
358
- grouped, _ = _get_regexes_matches_for_df(
359
- grouped, "imagename", self.image_patterns, suffix=""
431
+ grouped = _get_regexes_matches_for_df(
432
+ grouped, "imagename", self.image_patterns
360
433
  )
361
434
 
362
- if "date" in grouped:
363
- return grouped.sort_values("date")
364
- else:
365
- return grouped
435
+ return grouped
366
436
 
367
437
  def copy(self) -> "_ImageBase":
368
438
  """Copy the instance and its attributes."""
@@ -383,19 +453,23 @@ class _ImageBandBase(_ImageBase):
383
453
  pyproj.CRS(other.crs)
384
454
  ):
385
455
  raise ValueError(f"crs mismatch: {self.crs} and {other.crs}")
386
- return self.unary_union.intersects(to_shapely(other))
456
+ return self.union_all().intersects(to_shapely(other))
457
+
458
+ @property
459
+ def mask_percentage(self) -> float:
460
+ return self.mask.values.sum() / (self.mask.width * self.mask.height) * 100
387
461
 
388
462
  @property
389
463
  def year(self) -> str:
464
+ if hasattr(self, "_year") and self._year:
465
+ return self._year
390
466
  return self.date[:4]
391
467
 
392
468
  @property
393
469
  def month(self) -> str:
394
- return "".join(self.date.split("-"))[:6]
395
-
396
- @property
397
- def yyyymmd(self) -> str:
398
- return "".join(self.date.split("-"))[:7]
470
+ if hasattr(self, "_month") and self._month:
471
+ return self._month
472
+ return "".join(self.date.split("-"))[4:6]
399
473
 
400
474
  @property
401
475
  def name(self) -> str | None:
@@ -429,13 +503,15 @@ class _ImageBandBase(_ImageBase):
429
503
  def maxt(self) -> float:
430
504
  return disambiguate_timestamp(self.date, self.date_format)[1]
431
505
 
432
- @property
433
- def unary_union(self) -> Polygon:
434
- return box(*self.bounds)
506
+ def union_all(self) -> Polygon:
507
+ try:
508
+ return box(*self.bounds)
509
+ except TypeError:
510
+ return Polygon()
435
511
 
436
512
  @property
437
- def bbox(self) -> BoundingBox:
438
- bounds = GeoSeries([self.unary_union]).bounds
513
+ def torch_bbox(self) -> BoundingBox:
514
+ bounds = GeoSeries([self.union_all()]).bounds
439
515
  return BoundingBox(
440
516
  minx=bounds.minx[0],
441
517
  miny=bounds.miny[0],
@@ -451,6 +527,34 @@ class Band(_ImageBandBase):
451
527
 
452
528
  cmap: ClassVar[str | None] = None
453
529
 
530
+ @classmethod
531
+ def from_gdf(
532
+ cls,
533
+ gdf: GeoDataFrame | GeoSeries,
534
+ res: int,
535
+ *,
536
+ fill: int = 0,
537
+ all_touched: bool = False,
538
+ merge_alg: Callable = MergeAlg.replace,
539
+ default_value: int = 1,
540
+ dtype: Any | None = None,
541
+ **kwargs,
542
+ ) -> None:
543
+ """Create Band from a GeoDataFrame."""
544
+ arr: np.ndarray = _gdf_to_arr(
545
+ gdf,
546
+ res=res,
547
+ fill=fill,
548
+ all_touched=all_touched,
549
+ merge_alg=merge_alg,
550
+ default_value=default_value,
551
+ dtype=dtype,
552
+ )
553
+
554
+ obj = cls(arr, res=res, crs=gdf.crs, bounds=gdf.total_bounds, **kwargs)
555
+ obj._from_gdf = True
556
+ return obj
557
+
454
558
  def __init__(
455
559
  self,
456
560
  data: str | np.ndarray,
@@ -462,26 +566,31 @@ class Band(_ImageBandBase):
462
566
  file_system: GCSFileSystem | None = None,
463
567
  band_id: str | None = None,
464
568
  processes: int = 1,
465
- _mask: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None = None,
569
+ bbox: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None = None,
570
+ mask: "Band | None" = None,
571
+ nodata: int | None = None,
466
572
  **kwargs,
467
573
  ) -> None:
468
574
  """Band initialiser."""
469
- if isinstance(data, (GeoDataFrame | GeoSeries)):
470
- if res is None:
471
- raise ValueError("Must specify res when data is vector geometries.")
472
- bounds = to_bbox(bounds) if bounds is not None else data.total_bounds
473
- crs = crs if crs else data.crs
474
- data: np.ndarray = _arr_from_gdf(data, res=res, **kwargs)
475
- self._from_gdf = True
575
+ super().__init__(**kwargs)
576
+
577
+ self._mask = mask
578
+ self._bbox = to_bbox(bbox) if bbox is not None else None
579
+ self._values = None
580
+ self._crs = None
581
+ self.nodata = nodata
582
+
583
+ bounds = to_bbox(bounds) if bounds is not None else None
584
+
585
+ self._bounds = bounds
476
586
 
477
587
  if isinstance(data, np.ndarray):
478
- self._values = data
479
- if bounds is None:
588
+ self.values = data
589
+ if self._bounds is None:
480
590
  raise ValueError("Must specify bounds when data is an array.")
481
- self._bounds = to_bbox(bounds)
482
591
  self._crs = crs
483
592
  self.transform = _get_transform_from_bounds(
484
- self.bounds, shape=self.values.shape
593
+ self._bounds, shape=self.values.shape
485
594
  )
486
595
  self._from_array = True
487
596
 
@@ -497,19 +606,18 @@ class Band(_ImageBandBase):
497
606
  if cmap is not None:
498
607
  self.cmap = cmap
499
608
  self.file_system = file_system
500
- self._mask = _mask
501
609
  self._name = name
502
610
  self._band_id = band_id
503
611
  self.processes = processes
504
612
 
505
- if self.filename_regexes:
506
- if isinstance(self.filename_regexes, str):
507
- self.filename_regexes = [self.filename_regexes]
508
- self.filename_patterns = [
509
- re.compile(pat, flags=re.VERBOSE) for pat in self.filename_regexes
510
- ]
511
- else:
512
- self.filename_patterns = None
613
+ # if self.filename_regexes:
614
+ # if isinstance(self.filename_regexes, str):
615
+ # self.filename_regexes = [self.filename_regexes]
616
+ # self.filename_patterns = [
617
+ # re.compile(pat, flags=re.VERBOSE) for pat in self.filename_regexes
618
+ # ]
619
+ # else:
620
+ # self.filename_patterns = None
513
621
 
514
622
  def __lt__(self, other: "Band") -> bool:
515
623
  """Makes Bands sortable by band_id."""
@@ -518,17 +626,29 @@ class Band(_ImageBandBase):
518
626
  @property
519
627
  def values(self) -> np.ndarray:
520
628
  """The numpy array, if loaded."""
521
- try:
522
- return self._values
523
- except AttributeError as e:
524
- raise ValueError("array is not loaded.") from e
629
+ if self._values is None:
630
+ raise ArrayNotLoadedError("array is not loaded.")
631
+ return self._values
525
632
 
526
633
  @values.setter
527
634
  def values(self, new_val):
528
- if isinstance(new_val, np.ndarray):
529
- raise TypeError(f"{self.__class__.__name__} 'values' must be np.ndarray.")
635
+ if not isinstance(new_val, np.ndarray):
636
+ raise TypeError(
637
+ f"{self.__class__.__name__} 'values' must be np.ndarray. Got {type(new_val)}"
638
+ )
530
639
  self._values = new_val
531
640
 
641
+ @property
642
+ def mask(self) -> "Band":
643
+ """Mask Band."""
644
+ return self._mask
645
+
646
+ @mask.setter
647
+ def mask(self, values: "Band") -> None:
648
+ if values is not None and not isinstance(values, Band):
649
+ raise TypeError(f"'mask' should be of type Band. Got {type(values)}")
650
+ self._mask = values
651
+
532
652
  @property
533
653
  def band_id(self) -> str:
534
654
  """Band id."""
@@ -539,21 +659,21 @@ class Band(_ImageBandBase):
539
659
  @property
540
660
  def height(self) -> int:
541
661
  """Pixel heigth of the image band."""
542
- i = 1 if len(self.values.shape) == 3 else 0
543
- return self.values.shape[i]
662
+ return self.values.shape[-2]
544
663
 
545
664
  @property
546
665
  def width(self) -> int:
547
666
  """Pixel width of the image band."""
548
- i = 2 if len(self.values.shape) == 3 else 1
549
- return self.values.shape[i]
667
+ return self.values.shape[-1]
550
668
 
551
669
  @property
552
670
  def tile(self) -> str:
553
671
  """Tile name from filename_regex."""
554
672
  if hasattr(self, "_tile") and self._tile:
555
673
  return self._tile
556
- return self._name_regex_searcher("tile", self.filename_patterns)
674
+ return self._name_regex_searcher(
675
+ "tile", self.filename_patterns + self.image_patterns
676
+ )
557
677
 
558
678
  @property
559
679
  def date(self) -> str:
@@ -561,33 +681,31 @@ class Band(_ImageBandBase):
561
681
  if hasattr(self, "_date") and self._date:
562
682
  return self._date
563
683
 
564
- return self._name_regex_searcher("date", self.filename_patterns)
684
+ return self._name_regex_searcher(
685
+ "date", self.filename_patterns + self.image_patterns
686
+ )
565
687
 
566
688
  @property
567
689
  def crs(self) -> str | None:
568
690
  """Coordinate reference system."""
569
- try:
691
+ if self._crs is not None:
570
692
  return self._crs
571
- except AttributeError:
572
- with opener(self.path, file_system=self.file_system) as file:
573
- with rasterio.open(file) as src:
574
- self._bounds = to_bbox(src.bounds)
575
- self._crs = src.crs
576
- return self._crs
693
+ with opener(self.path, file_system=self.file_system) as file:
694
+ with rasterio.open(file) as src:
695
+ # self._bounds = to_bbox(src.bounds)
696
+ self._crs = src.crs
697
+ return self._crs
577
698
 
578
699
  @property
579
700
  def bounds(self) -> tuple[int, int, int, int] | None:
580
701
  """Bounds as tuple (minx, miny, maxx, maxy)."""
581
- try:
702
+ if self._bounds is not None:
582
703
  return self._bounds
583
- except AttributeError:
584
- with opener(self.path, file_system=self.file_system) as file:
585
- with rasterio.open(file) as src:
586
- self._bounds = to_bbox(src.bounds)
587
- self._crs = src.crs
588
- return self._bounds
589
- except TypeError:
590
- return None
704
+ with opener(self.path, file_system=self.file_system) as file:
705
+ with rasterio.open(file) as src:
706
+ self._bounds = to_bbox(src.bounds)
707
+ self._crs = src.crs
708
+ return self._bounds
591
709
 
592
710
  def get_n_largest(
593
711
  self, n: int, precision: float = 0.000001, column: str = "value"
@@ -611,104 +729,178 @@ class Band(_ImageBandBase):
611
729
  df[column] = f"smallest_{n}"
612
730
  return df
613
731
 
614
- def load(self, bounds=None, indexes=None, **kwargs) -> "Band":
732
+ def load(
733
+ self,
734
+ bounds: tuple | Geometry | GeoDataFrame | GeoSeries | None = None,
735
+ indexes: int | tuple[int] | None = None,
736
+ masked: bool | None = None,
737
+ **kwargs,
738
+ ) -> "Band":
615
739
  """Load and potentially clip the array.
616
740
 
617
741
  The array is stored in the 'values' property.
618
742
  """
619
- bounds = to_bbox(bounds) if bounds is not None else self._mask
743
+ if masked is None:
744
+ masked = True if self.mask is None else False
745
+
746
+ bounds_was_none = bounds is None
620
747
 
621
748
  try:
622
- assert isinstance(self.values, np.ndarray)
749
+ if not isinstance(self.values, np.ndarray):
750
+ raise ValueError()
623
751
  has_array = True
624
- except (ValueError, AssertionError):
752
+ except ValueError: # also catches ArrayNotLoadedError
625
753
  has_array = False
626
754
 
755
+ # get common bounds of function argument 'bounds' and previously set bbox
756
+ if bounds is None and self._bbox is None:
757
+ bounds = None
758
+ elif bounds is not None and self._bbox is None:
759
+ bounds = to_shapely(bounds).intersection(self.union_all())
760
+ elif bounds is None and self._bbox is not None:
761
+ bounds = to_shapely(self._bbox).intersection(self.union_all())
762
+ else:
763
+ bounds = to_shapely(bounds).intersection(to_shapely(self._bbox))
764
+
765
+ should_return_empty: bool = bounds is not None and bounds.area == 0
766
+ if should_return_empty:
767
+ self._values = np.array([])
768
+ if self.mask is not None and not self.is_mask:
769
+ self._mask = self._mask.load()
770
+ # self._mask = np.ma.array([], [])
771
+ self._bounds = None
772
+ self.transform = None
773
+ return self
774
+
775
+ if has_array and bounds_was_none:
776
+ return self
777
+
778
+ # round down/up to integer to avoid precision trouble
779
+ if bounds is not None:
780
+ # bounds = to_bbox(bounds)
781
+ minx, miny, maxx, maxy = to_bbox(bounds)
782
+ bounds = (int(minx), int(miny), math.ceil(maxx), math.ceil(maxy))
783
+
784
+ boundless = False
785
+
786
+ if indexes is None:
787
+ indexes = 1
788
+
789
+ # as tuple to ensure we get 3d array
790
+ _indexes: tuple[int] = (indexes,) if isinstance(indexes, int) else indexes
791
+
792
+ # allow setting a fixed out_shape for the array, in order to make mask same shape as values
793
+ out_shape = kwargs.pop("out_shape", None)
794
+
627
795
  if has_array:
628
- if bounds is None:
629
- return self
630
- # bounds_shapely = to_shapely(bounds)
631
- # if not bounds_shapely.intersects(self.)
632
- # bounds_arr = GeoSeries([bounds_shapely]).values
633
- bounds_arr = GeoSeries([to_shapely(bounds)]).values
634
- try:
635
- self._values = (
636
- to_xarray(
637
- self.values,
638
- transform=self.transform,
639
- crs=self.crs,
640
- name=self.name,
641
- )
642
- .rio.clip(bounds_arr, crs=self.crs)
643
- .to_numpy()
644
- )
645
- except NoDataInBounds:
646
- self._values = np.array([])
796
+ self.values = _clip_loaded_array(
797
+ self.values, bounds, self.transform, self.crs, out_shape, **kwargs
798
+ )
647
799
  self._bounds = bounds
648
- return self
800
+ self.transform = _get_transform_from_bounds(self._bounds, self.values.shape)
649
801
 
650
- with opener(self.path, file_system=self.file_system) as f:
651
- with rasterio.open(f) as src:
652
- self._res = int(src.res[0]) if not self.res else self.res
653
- # if bounds is None:
654
- # out_shape = _get_shape_from_res(to_bbox(src.bounds), self.res, indexes)
655
- # self.transform = src.transform
656
- # arr = src.load(indexes=indexes, out_shape=out_shape, **kwargs)
657
- # # if isinstance(indexes, int) and len(arr.shape) == 3:
658
- # # return arr[0]
659
- # return arr
660
- # else:
661
- # window = rasterio.windows.from_bounds(
662
- # *bounds, transform=src.transform
663
- # )
664
- # out_shape = _get_shape_from_bounds(bounds, self.res)
665
-
666
- # arr = src.read(
667
- # indexes=indexes,
668
- # out_shape=out_shape,
669
- # window=window,
670
- # boundless=boundless,
671
- # **kwargs,
672
- # )
673
- # if isinstance(indexes, int):
674
- # # arr = arr[0]
675
- # height, width = arr.shape
676
- # else:
677
- # height, width = arr.shape[1:]
678
-
679
- # self.transform = rasterio.transform.from_bounds(
680
- # *bounds, width, height
681
- # )
682
- # if bounds is not None:
683
- # self._bounds = bounds
684
- # return arr
685
-
686
- if indexes is None and len(src.indexes) == 1:
687
- indexes = 1
802
+ else:
803
+ with opener(self.path, file_system=self.file_system) as f:
804
+ with rasterio.open(f, nodata=self.nodata) as src:
805
+ self._res = int(src.res[0]) if not self.res else self.res
688
806
 
689
- if isinstance(indexes, int):
690
- _indexes = (indexes,)
691
- else:
692
- _indexes = indexes
807
+ if self.nodata is None or np.isnan(self.nodata):
808
+ self.nodata = src.nodata
809
+ else:
810
+ dtype_min_value = _get_dtype_min(src.dtypes[0])
811
+ dtype_max_value = _get_dtype_max(src.dtypes[0])
812
+ if (
813
+ self.nodata > dtype_max_value
814
+ or self.nodata < dtype_min_value
815
+ ):
816
+ src._dtypes = tuple(
817
+ rasterio.dtypes.get_minimum_dtype(self.nodata)
818
+ for _ in range(len(_indexes))
819
+ )
820
+
821
+ if bounds is None:
822
+ if self._res != int(src.res[0]):
823
+ if out_shape is None:
824
+ out_shape = _get_shape_from_bounds(
825
+ to_bbox(src.bounds), self.res, indexes
826
+ )
827
+ self.transform = _get_transform_from_bounds(
828
+ to_bbox(src.bounds), shape=out_shape
829
+ )
830
+ else:
831
+ self.transform = src.transform
832
+
833
+ self._values = src.read(
834
+ indexes=indexes,
835
+ out_shape=out_shape,
836
+ masked=masked,
837
+ **kwargs,
838
+ )
839
+ else:
840
+ window = rasterio.windows.from_bounds(
841
+ *bounds, transform=src.transform
842
+ )
693
843
 
694
- arr, transform = rasterio.merge.merge(
695
- [src],
696
- res=self.res,
697
- indexes=_indexes,
698
- bounds=bounds,
699
- **kwargs,
700
- )
701
- self.transform = transform
702
- if bounds is not None:
703
- self._bounds = bounds
844
+ if out_shape is None:
845
+ out_shape = _get_shape_from_bounds(
846
+ bounds, self.res, indexes
847
+ )
848
+
849
+ self._values = src.read(
850
+ indexes=indexes,
851
+ window=window,
852
+ boundless=boundless,
853
+ out_shape=out_shape,
854
+ masked=masked,
855
+ **kwargs,
856
+ )
704
857
 
705
- if isinstance(indexes, int):
706
- arr = arr[0]
858
+ assert out_shape == self._values.shape, (
859
+ out_shape,
860
+ self._values.shape,
861
+ )
862
+
863
+ self.transform = rasterio.transform.from_bounds(
864
+ *bounds, self.width, self.height
865
+ )
866
+ self._bounds = bounds
867
+
868
+ if self.nodata is not None and not np.isnan(self.nodata):
869
+ if isinstance(self.values, np.ma.core.MaskedArray):
870
+ self.values.data[self.values.data == src.nodata] = (
871
+ self.nodata
872
+ )
873
+ else:
874
+ self.values[self.values == src.nodata] = self.nodata
875
+
876
+ if self.masking and self.is_mask:
877
+ self.values = np.isin(self.values, self.masking["values"])
878
+
879
+ elif self.mask is not None and not isinstance(
880
+ self.values, np.ma.core.MaskedArray
881
+ ):
882
+ self.mask = self.mask.copy().load(
883
+ bounds=bounds, indexes=indexes, out_shape=out_shape, **kwargs
884
+ )
885
+ mask_arr = self.mask.values
886
+
887
+ # if self.masking:
888
+ # mask_arr = np.isin(mask_arr, self.masking["values"])
889
+
890
+ self._values = np.ma.array(
891
+ self._values, mask=mask_arr, fill_value=self.nodata
892
+ )
707
893
 
708
- self._values = arr
709
894
  return self
710
895
 
711
- def write(self, path: str | Path, **kwargs) -> None:
896
+ @property
897
+ def is_mask(self) -> bool:
898
+ """True if the band_id is equal to the masking band_id."""
899
+ return self.band_id == self.masking["band_id"]
900
+
901
+ def write(
902
+ self, path: str | Path, driver: str = "GTiff", compress: str = "LZW", **kwargs
903
+ ) -> None:
712
904
  """Write the array as an image file."""
713
905
  if not hasattr(self, "_values"):
714
906
  raise ValueError(
@@ -719,44 +911,78 @@ class Band(_ImageBandBase):
719
911
  raise ValueError("Cannot write None crs to image.")
720
912
 
721
913
  profile = {
722
- # "driver": self.driver,
723
- # "compress": self.compress,
724
- # "dtype": self.dtype,
914
+ "driver": driver,
915
+ "compress": compress,
916
+ "dtype": rasterio.dtypes.get_minimum_dtype(self.values),
725
917
  "crs": self.crs,
726
918
  "transform": self.transform,
727
- # "nodata": self.nodata,
728
- # "count": self.count,
729
- # "height": self.height,
730
- # "width": self.width,
731
- # "indexes": self.indexes,
919
+ "nodata": self.nodata,
920
+ "count": 1 if len(self.values.shape) == 2 else self.values.shape[0],
921
+ "height": self.height,
922
+ "width": self.width,
732
923
  } | kwargs
733
924
 
734
- with opener(path, "w", file_system=self.file_system) as f:
735
- with rasterio.open(f, **profile) as dst:
736
- # bounds = to_bbox(self._mask) if self._mask is not None else dst.bounds
925
+ with opener(path, "wb", file_system=self.file_system) as f:
926
+ with rasterio.open(f, "w", **profile) as dst:
927
+
928
+ if dst.nodata is None:
929
+ dst.nodata = _get_dtype_min(dst.dtypes[0])
737
930
 
738
- # res = dst.res if not self.res else self.res
931
+ # if (
932
+ # isinstance(self.values, np.ma.core.MaskedArray)
933
+ # # and dst.nodata is not None
934
+ # ):
935
+ # self.values.data[np.isnan(self.values.data)] = dst.nodata
936
+ # self.values.data[self.values.mask] = dst.nodata
739
937
 
740
938
  if len(self.values.shape) == 2:
741
- return dst.write(self.values, indexes=1)
939
+ dst.write(self.values, indexes=1)
940
+ else:
941
+ for i in range(self.values.shape[0]):
942
+ dst.write(self.values[i], indexes=i + 1)
742
943
 
743
- for i in range(self.values.shape[0]):
744
- dst.write(self.values[i], indexes=i + 1)
944
+ if isinstance(self.values, np.ma.core.MaskedArray):
945
+ dst.write_mask(self.values.mask)
745
946
 
746
947
  self._path = str(path)
747
948
 
949
+ def apply(self, func: Callable, **kwargs) -> "Band":
950
+ """Apply a function to the array."""
951
+ self.values = func(self.values, **kwargs)
952
+ return self
953
+
954
+ def normalize(self) -> "Band":
955
+ """Normalize array values between 0 and 1."""
956
+ arr = self.values
957
+ self.values = (arr - np.min(arr)) / (np.max(arr) - np.min(arr))
958
+ return self
959
+
748
960
  def sample(self, size: int = 1000, mask: Any = None, **kwargs) -> "Image":
749
961
  """Take a random spatial sample area of the Band."""
750
962
  copied = self.copy()
751
963
  if mask is not None:
752
- point = GeoSeries([copied.unary_union]).clip(mask).sample_points(1)
964
+ point = GeoSeries([copied.union_all()]).clip(mask).sample_points(1)
753
965
  else:
754
- point = GeoSeries([copied.unary_union]).sample_points(1)
755
- buffered = point.buffer(size / 2).clip(copied.unary_union)
966
+ point = GeoSeries([copied.union_all()]).sample_points(1)
967
+ buffered = point.buffer(size / 2).clip(copied.union_all())
756
968
  copied = copied.load(bounds=buffered.total_bounds, **kwargs)
757
969
  return copied
758
970
 
759
- def get_gradient(self, degrees: bool = False, copy: bool = True) -> "Band":
971
+ def buffer(self, distance: int, copy: bool = True) -> "Band":
972
+ """Buffer array points with the value 1 in a binary array.
973
+
974
+ Args:
975
+ distance: Number of array cells to buffer by.
976
+ copy: Whether to copy the Band.
977
+
978
+ Returns:
979
+ Band with buffered values.
980
+ """
981
+ copied = self.copy() if copy else self
982
+ copied.values = array_buffer(copied.values, distance)
983
+ return copied
984
+
985
+ def gradient(self, degrees: bool = False, copy: bool = True) -> "Band":
760
986
  """Get the slope of an elevation band.
761
987
 
762
988
  Calculates the absolute slope between the grid cells
@@ -901,7 +1127,7 @@ class Band(_ImageBandBase):
901
1127
  dims=dims,
902
1128
  name=name,
903
1129
  attrs={"crs": self.crs},
904
- ) # .transpose("y", "x")
1130
+ )
905
1131
 
906
1132
  def __repr__(self) -> str:
907
1133
  """String representation."""
@@ -923,6 +1149,141 @@ class NDVIBand(Band):
923
1149
 
924
1150
  cmap: str = "Greens"
925
1151
 
1152
+ # @staticmethod
1153
+ # def get_cmap(arr: np.ndarray):
1154
+ # return get_cmap(arr)
1155
+
1156
+
1157
+ def get_cmap(arr: np.ndarray) -> LinearSegmentedColormap:
1158
+
1159
+ # blue = [[i / 10 + 0.1, i / 10 + 0.1, 1 - (i / 10) + 0.1] for i in range(11)][1:]
1160
+ blue = [
1161
+ [0.1, 0.1, 1.0],
1162
+ [0.2, 0.2, 0.9],
1163
+ [0.3, 0.3, 0.8],
1164
+ [0.4, 0.4, 0.7],
1165
+ [0.6, 0.6, 0.6],
1166
+ [0.6, 0.6, 0.6],
1167
+ [0.7, 0.7, 0.7],
1168
+ [0.8, 0.8, 0.8],
1169
+ ]
1170
+ # gray = list(reversed([[i / 10 - 0.1, i / 10, i / 10 - 0.1] for i in range(11)][1:]))
1171
+ gray = [
1172
+ [0.6, 0.6, 0.6],
1173
+ [0.6, 0.6, 0.6],
1174
+ [0.6, 0.6, 0.6],
1175
+ [0.6, 0.6, 0.6],
1176
+ [0.6, 0.6, 0.6],
1177
+ [0.4, 0.7, 0.4],
1178
+ [0.3, 0.7, 0.3],
1179
+ [0.2, 0.8, 0.2],
1180
+ ]
1181
+ # gray = [[0.6, 0.6, 0.6] for i in range(10)]
1182
+ # green = [[0.2 + i/20, i / 10 - 0.1, + i/20] for i in range(11)][1:]
1183
+ green = [
1184
+ [0.25, 0.0, 0.05],
1185
+ [0.3, 0.1, 0.1],
1186
+ [0.35, 0.2, 0.15],
1187
+ [0.4, 0.3, 0.2],
1188
+ [0.45, 0.4, 0.25],
1189
+ [0.5, 0.5, 0.3],
1190
+ [0.55, 0.6, 0.35],
1191
+ [0.7, 0.9, 0.5],
1192
+ ]
1193
+ green = [
1194
+ [0.6, 0.6, 0.6],
1195
+ [0.4, 0.7, 0.4],
1196
+ [0.3, 0.8, 0.3],
1197
+ [0.25, 0.4, 0.25],
1198
+ [0.2, 0.5, 0.2],
1199
+ [0.10, 0.7, 0.10],
1200
+ [0, 0.9, 0],
1201
+ ]
1202
+
1203
+ def get_start(arr):
1204
+ min_value = np.min(arr)
1205
+ if min_value < -0.75:
1206
+ return 0
1207
+ if min_value < -0.5:
1208
+ return 1
1209
+ if min_value < -0.25:
1210
+ return 2
1211
+ if min_value < 0:
1212
+ return 3
1213
+ if min_value < 0.25:
1214
+ return 4
1215
+ if min_value < 0.5:
1216
+ return 5
1217
+ if min_value < 0.75:
1218
+ return 6
1219
+ return 7
1220
+
1221
+ def get_stop(arr):
1222
+ max_value = np.max(arr)
1223
+ if max_value <= 0.05:
1224
+ return 0
1225
+ if max_value < 0.175:
1226
+ return 1
1227
+ if max_value < 0.25:
1228
+ return 2
1229
+ if max_value < 0.375:
1230
+ return 3
1231
+ if max_value < 0.5:
1232
+ return 4
1233
+ if max_value < 0.75:
1234
+ return 5
1235
+ return 6
1236
+
1237
+ cmap_name = "blue_gray_green"
1238
+
1239
+ start = get_start(arr)
1240
+ stop = get_stop(arr)
1241
+ blue = blue[start]
1242
+ gray = gray[start]
1243
+ # green = green[start]
1244
+ green = green[stop]
1245
+
1246
+ # green[0] = np.arange(0, 1, 0.1)[::-1][stop]
1247
+ # green[1] = np.arange(0, 1, 0.1)[stop]
1248
+ # green[2] = np.arange(0, 1, 0.1)[::-1][stop]
1249
+
1250
+ print(green)
1251
+ print(start, stop)
1252
+ print("blue gray green")
1253
+ print(blue)
1254
+ print(gray)
1255
+ print(green)
1256
+
1257
+ # Define the segments of the colormap
1258
+ cdict = {
1259
+ "red": [
1260
+ (0.0, blue[0], blue[0]),
1261
+ (0.3, gray[0], gray[0]),
1262
+ (0.7, gray[0], gray[0]),
1263
+ (1.0, green[0], green[0]),
1264
+ ],
1265
+ "green": [
1266
+ (0.0, blue[1], blue[1]),
1267
+ (0.3, gray[1], gray[1]),
1268
+ (0.7, gray[1], gray[1]),
1269
+ (1.0, green[1], green[1]),
1270
+ ],
1271
+ "blue": [
1272
+ (0.0, blue[2], blue[2]),
1273
+ (0.3, gray[2], gray[2]),
1274
+ (0.7, gray[2], gray[2]),
1275
+ (1.0, green[2], green[2]),
1276
+ ],
1277
+ }
1278
+
1279
+ return LinearSegmentedColormap(cmap_name, segmentdata=cdict, N=50)
1280
+
1281
+
1282
+ def median_as_int_and_minimum_dtype(arr: np.ndarray) -> np.ndarray:
1283
+ arr = np.median(arr, axis=0).astype(int)
1284
+ min_dtype = rasterio.dtypes.get_minimum_dtype(arr)
1285
+ return arr.astype(min_dtype)
1286
+
926
1287
 
927
1288
  class Image(_ImageBandBase):
928
1289
  """Image consisting of one or more Bands."""
@@ -934,54 +1295,63 @@ class Image(_ImageBandBase):
934
1295
  self,
935
1296
  data: str | Path | Sequence[Band],
936
1297
  res: int | None = None,
937
- # crs: Any | None = None,
1298
+ crs: Any | None = None,
938
1299
  single_banded: bool = False,
939
1300
  file_system: GCSFileSystem | None = None,
940
1301
  df: pd.DataFrame | None = None,
941
1302
  all_file_paths: list[str] | None = None,
942
- _mask: GeoDataFrame | GeoSeries | Geometry | tuple | None = None,
943
1303
  processes: int = 1,
1304
+ bbox: GeoDataFrame | GeoSeries | Geometry | tuple | None = None,
1305
+ nodata: int | None = None,
1306
+ **kwargs,
944
1307
  ) -> None:
945
1308
  """Image initialiser."""
946
- super().__init__()
1309
+ super().__init__(**kwargs)
947
1310
 
1311
+ self.nodata = nodata
948
1312
  self._res = res
949
- # self._crs = crs
1313
+ self._crs = crs
950
1314
  self.file_system = file_system
951
- self._mask = _mask
1315
+ self._bbox = to_bbox(bbox) if bbox is not None else None
1316
+ # self._mask = _mask
952
1317
  self.single_banded = single_banded
953
1318
  self.processes = processes
1319
+ self._all_file_paths = all_file_paths
954
1320
 
955
1321
  if hasattr(data, "__iter__") and all(isinstance(x, Band) for x in data):
956
1322
  self._bands = list(data)
957
- self._bounds = get_total_bounds(self._bands)
958
- self._crs = get_common_crs(self._bands)
959
- res = list({band.res for band in self._bands})
960
- if len(res) == 1:
961
- self._res = res[0]
1323
+ if res is None:
1324
+ res = list({band.res for band in self._bands})
1325
+ if len(res) == 1:
1326
+ self._res = res[0]
1327
+ else:
1328
+ raise ValueError(f"Different resolutions for the bands: {res}")
962
1329
  else:
963
- raise ValueError(f"Different resolutions for the bands: {res}")
1330
+ self._res = res
964
1331
  return
965
1332
 
966
1333
  if not isinstance(data, (str | Path | os.PathLike)):
967
1334
  raise TypeError("'data' must be string, Path-like or a sequence of Band.")
968
1335
 
1336
+ self._bands = None
969
1337
  self._path = str(data)
970
1338
 
971
1339
  if df is None:
972
1340
  if is_dapla():
973
- file_paths = list(sorted(set(glob_func(self.path + "/**"))))
1341
+ file_paths = list(sorted(set(_glob_func(self.path + "/**"))))
974
1342
  else:
975
1343
  file_paths = list(
976
1344
  sorted(
977
1345
  set(
978
- glob_func(self.path + "/**/**")
979
- + glob_func(self.path + "/**/**/**")
980
- + glob_func(self.path + "/**/**/**/**")
981
- + glob_func(self.path + "/**/**/**/**/**")
1346
+ _glob_func(self.path + "/**/**")
1347
+ + _glob_func(self.path + "/**/**/**")
1348
+ + _glob_func(self.path + "/**/**/**/**")
1349
+ + _glob_func(self.path + "/**/**/**/**/**")
982
1350
  )
983
1351
  )
984
1352
  )
1353
+ if not file_paths:
1354
+ file_paths = [self.path]
985
1355
  df = self._create_metadata_df(file_paths)
986
1356
 
987
1357
  df["image_path"] = df["image_path"].astype(str)
@@ -1000,14 +1370,9 @@ class Image(_ImageBandBase):
1000
1370
 
1001
1371
  df = df.loc[lambda x: x["image_path"].str.contains(_fix_path(self.path))]
1002
1372
 
1003
- if self.filename_patterns and any(pat.groups for pat in self.filename_patterns):
1004
- df = df.loc[
1005
- lambda x: (x[f"band{FILENAME_COL_SUFFIX}"].notna())
1006
- ].sort_values(f"band{FILENAME_COL_SUFFIX}")
1007
-
1008
1373
  if self.cloud_cover_regexes:
1009
1374
  if all_file_paths is None:
1010
- file_paths = ls_func(self.path)
1375
+ file_paths = _ls_func(self.path)
1011
1376
  else:
1012
1377
  file_paths = [path for path in all_file_paths if self.name in path]
1013
1378
  self.cloud_coverage_percentage = float(
@@ -1018,32 +1383,41 @@ class Image(_ImageBandBase):
1018
1383
  else:
1019
1384
  self.cloud_coverage_percentage = None
1020
1385
 
1021
- self._bands = [
1022
- self.band_class(
1023
- path,
1024
- **self._common_init_kwargs,
1025
- )
1026
- for path in (df["file_path"])
1027
- ]
1386
+ self._df = df
1028
1387
 
1029
- if self.filename_patterns and any(pat.groups for pat in self.filename_patterns):
1030
- self._bands = list(sorted(self._bands))
1388
+ @property
1389
+ def values(self) -> np.ndarray:
1390
+ """3 dimensional numpy array."""
1391
+ return np.array([band.values for band in self])
1031
1392
 
1032
- def get_ndvi(self, red_band: str, nir_band: str) -> NDVIBand:
1393
+ def ndvi(self, red_band: str, nir_band: str, copy: bool = True) -> NDVIBand:
1033
1394
  """Calculate the NDVI for the Image."""
1034
- red = self[red_band].load().values
1035
- nir = self[nir_band].load().values
1395
+ copied = self.copy() if copy else self
1396
+ red = copied[red_band].load()
1397
+ nir = copied[nir_band].load()
1036
1398
 
1037
- arr: np.ndarray = ndvi(red, nir)
1399
+ arr: np.ndarray | np.ma.core.MaskedArray = ndvi(red.values, nir.values)
1400
+
1401
+ # if self.nodata is not None and not np.isnan(self.nodata):
1402
+ # try:
1403
+ # arr.data[arr.mask] = self.nodata
1404
+ # arr = arr.copy()
1405
+ # except AttributeError:
1406
+ # pass
1038
1407
 
1039
1408
  return NDVIBand(
1040
1409
  arr,
1041
- bounds=self.bounds,
1042
- crs=self.crs,
1043
- **self._common_init_kwargs,
1410
+ bounds=red.bounds,
1411
+ crs=red.crs,
1412
+ mask=red.mask,
1413
+ **red._common_init_kwargs,
1044
1414
  )
1045
1415
 
1046
- def get_brightness(self, bounds=None, rbg_bands: list[str] | None = None) -> Band:
1416
+ def get_brightness(
1417
+ self,
1418
+ bounds: tuple | Geometry | GeoDataFrame | GeoSeries | None = None,
1419
+ rbg_bands: list[str] | None = None,
1420
+ ) -> Band:
1047
1421
  """Get a Band with a brightness score of the Image's RBG bands."""
1048
1422
  if rbg_bands is None:
1049
1423
  try:
@@ -1067,9 +1441,55 @@ class Image(_ImageBandBase):
1067
1441
  brightness,
1068
1442
  bounds=red.bounds,
1069
1443
  crs=self.crs,
1444
+ mask=self.mask,
1445
+ **self._common_init_kwargs,
1446
+ )
1447
+
1448
+ @property
1449
+ def mask(self) -> Band | None:
1450
+ """Mask Band."""
1451
+ if self._mask is not None:
1452
+ return self._mask
1453
+ if self.masking is None:
1454
+ return None
1455
+
1456
+ mask_band_id = self.masking["band_id"]
1457
+ mask_paths = [path for path in self._df["file_path"] if mask_band_id in path]
1458
+ if len(mask_paths) > 1:
1459
+ raise ValueError(
1460
+ f"Multiple file_paths match mask band_id {mask_band_id} for {self.path}"
1461
+ )
1462
+ elif not mask_paths:
1463
+ raise ValueError(
1464
+ f"No file_paths match mask band_id {mask_band_id} for {self.path}"
1465
+ )
1466
+ self._mask = self.band_class(
1467
+ mask_paths[0],
1070
1468
  **self._common_init_kwargs,
1071
1469
  )
1072
1470
 
1471
+ return self._mask
1472
+
1473
+ @mask.setter
1474
+ def mask(self, values: Band) -> None:
1475
+ if values is None:
1476
+ self._mask = None
1477
+ for band in self:
1478
+ band.mask = None
1479
+ return
1480
+ if not isinstance(values, Band):
1481
+ raise TypeError(f"mask must be Band. Got {type(values)}")
1482
+ self._mask = values
1483
+ mask_arr = self._mask.values
1484
+ for band in self:
1485
+ band._mask = self._mask
1486
+ try:
1487
+ band.values = np.ma.array(
1488
+ band.values, mask=mask_arr, fill_value=band.nodata
1489
+ )
1490
+ except ArrayNotLoadedError:
1491
+ pass
1492
+
1073
1493
  @property
1074
1494
  def band_ids(self) -> list[str]:
1075
1495
  """The Band ids."""
@@ -1083,14 +1503,100 @@ class Image(_ImageBandBase):
1083
1503
  @property
1084
1504
  def bands(self) -> list[Band]:
1085
1505
  """The Image Bands."""
1506
+ if self._bands is not None:
1507
+ return self._bands
1508
+
1509
+ # if self.masking:
1510
+ # mask_band_id = self.masking["band_id"]
1511
+ # mask_paths = [
1512
+ # path for path in self._df["file_path"] if mask_band_id in path
1513
+ # ]
1514
+ # if len(mask_paths) > 1:
1515
+ # raise ValueError(
1516
+ # f"Multiple file_paths match mask band_id {mask_band_id}"
1517
+ # )
1518
+ # elif not mask_paths:
1519
+ # raise ValueError(f"No file_paths match mask band_id {mask_band_id}")
1520
+ # arr = (
1521
+ # self.band_class(
1522
+ # mask_paths[0],
1523
+ # # mask=self.mask,
1524
+ # **self._common_init_kwargs,
1525
+ # )
1526
+ # .load()
1527
+ # .values
1528
+ # )
1529
+ # self._mask = np.ma.array(
1530
+ # arr, mask=np.isin(arr, self.masking["values"]), fill_value=None
1531
+ # )
1532
+
1533
+ self._bands = [
1534
+ self.band_class(
1535
+ path,
1536
+ mask=self.mask,
1537
+ **self._common_init_kwargs,
1538
+ )
1539
+ for path in (self._df["file_path"])
1540
+ ]
1541
+
1542
+ if self.masking:
1543
+ mask_band_id = self.masking["band_id"]
1544
+ self._bands = [
1545
+ band for band in self._bands if mask_band_id not in band.path
1546
+ ]
1547
+
1548
+ if (
1549
+ self.filename_patterns
1550
+ and any(_get_non_optional_groups(pat) for pat in self.filename_patterns)
1551
+ or self.image_patterns
1552
+ and any(_get_non_optional_groups(pat) for pat in self.image_patterns)
1553
+ ):
1554
+ self._bands = [band for band in self._bands if band.band_id is not None]
1555
+
1556
+ if self.filename_patterns:
1557
+ self._bands = [
1558
+ band
1559
+ for band in self._bands
1560
+ if any(
1561
+ # _get_first_group_match(pat, band.name)
1562
+ re.search(pat, band.name)
1563
+ for pat in self.filename_patterns
1564
+ )
1565
+ ]
1566
+
1567
+ if self.image_patterns:
1568
+ self._bands = [
1569
+ band
1570
+ for band in self._bands
1571
+ if any(
1572
+ re.search(pat, Path(band.path).parent.name)
1573
+ # _get_first_group_match(pat, Path(band.path).parent.name)
1574
+ for pat in self.image_patterns
1575
+ )
1576
+ ]
1577
+
1578
+ if self._should_be_sorted:
1579
+ self._bands = list(sorted(self._bands))
1580
+
1086
1581
  return self._bands
1087
1582
 
1583
+ @property
1584
+ def _should_be_sorted(self) -> bool:
1585
+ sort_groups = ["band", "band_id"]
1586
+ return self.filename_patterns and any(
1587
+ group in _get_non_optional_groups(pat)
1588
+ for group in sort_groups
1589
+ for pat in self.filename_patterns
1590
+ )
1591
+
1088
1592
  @property
1089
1593
  def tile(self) -> str:
1090
1594
  """Tile name from filename_regex."""
1091
1595
  if hasattr(self, "_tile") and self._tile:
1092
1596
  return self._tile
1093
- return self._name_regex_searcher("tile", self.image_patterns)
1597
+ return self._name_regex_searcher(
1598
+ "tile", self.image_patterns + self.filename_patterns
1599
+ )
1094
1600
 
1095
1601
  @property
1096
1602
  def date(self) -> str:
@@ -1098,63 +1604,24 @@ class Image(_ImageBandBase):
1098
1604
  if hasattr(self, "_date") and self._date:
1099
1605
  return self._date
1100
1606
 
1101
- return self._name_regex_searcher("date", self.image_patterns)
1607
+ return self._name_regex_searcher(
1608
+ "date", self.image_patterns + self.filename_patterns
1609
+ )
1102
1610
 
1103
1611
  @property
1104
1612
  def crs(self) -> str | None:
1105
1613
  """Coordinate reference system of the Image."""
1106
- try:
1614
+ if self._crs is not None:
1107
1615
  return self._crs
1108
- except AttributeError:
1109
- if not len(self):
1110
- return None
1111
- with opener(self.file_paths[0], file_system=self.file_system) as file:
1112
- with rasterio.open(file) as src:
1113
- self._bounds = to_bbox(src.bounds)
1114
- self._crs = src.crs
1115
- return self._crs
1616
+ if not len(self):
1617
+ return None
1618
+ self._crs = get_common_crs(self)
1619
+ return self._crs
1116
1620
 
1117
1621
  @property
1118
1622
  def bounds(self) -> tuple[int, int, int, int] | None:
1119
1623
  """Bounds of the Image (minx, miny, maxx, maxy)."""
1120
- try:
1121
- return self._bounds
1122
- except AttributeError:
1123
- if not len(self):
1124
- return None
1125
- with opener(self.file_paths[0], file_system=self.file_system) as file:
1126
- with rasterio.open(file) as src:
1127
- self._bounds = to_bbox(src.bounds)
1128
- self._crs = src.crs
1129
- return self._bounds
1130
- except TypeError:
1131
- return None
1132
-
1133
- # @property
1134
- # def year(self) -> str:
1135
- # return self.date[:4]
1136
-
1137
- # @property
1138
- # def month(self) -> str:
1139
- # return "".join(self.date.split("-"))[:6]
1140
-
1141
- # def write(self, image_path: str | Path, file_type: str = "tif", **kwargs) -> None:
1142
- # _test = kwargs.pop("_test")
1143
- # suffix = "." + file_type.strip(".")
1144
- # img_path = Path(image_path)
1145
- # for band in self:
1146
- # band_path = (img_path / band.name).with_suffix(suffix)
1147
- # if _test:
1148
- # print(f"{self.__class__.__name__}.write: {band_path}")
1149
- # continue
1150
-
1151
- # band.write(band_path, **kwargs)
1152
-
1153
- # def read(self, bounds=None, **kwargs) -> np.ndarray:
1154
- # """Return 3 dimensional numpy.ndarray of shape (n bands, width, height)."""
1155
- # return np.array(
1156
- # [(band.load(bounds=bounds, **kwargs).values) for band in self.bands]
1157
- # )
1624
+ return get_total_bounds([band.bounds for band in self])
1158
1625
 
1159
1626
  def to_gdf(self, column: str = "value") -> GeoDataFrame:
1160
1627
  """Convert the array to a GeoDataFrame of grid polygons and values."""
@@ -1168,38 +1635,15 @@ class Image(_ImageBandBase):
1168
1635
  """Take a random spatial sample of the image."""
1169
1636
  copied = self.copy()
1170
1637
  if mask is not None:
1171
- points = GeoSeries([self.unary_union]).clip(mask).sample_points(n)
1638
+ points = GeoSeries([self.union_all()]).clip(mask).sample_points(n)
1172
1639
  else:
1173
- points = GeoSeries([self.unary_union]).sample_points(n)
1174
- buffered = points.buffer(size / 2).clip(self.unary_union)
1640
+ points = GeoSeries([self.union_all()]).sample_points(n)
1641
+ buffered = points.buffer(size / 2).clip(self.union_all())
1175
1642
  boxes = to_gdf([box(*arr) for arr in buffered.bounds.values], crs=self.crs)
1176
1643
  copied._bands = [band.load(bounds=boxes, **kwargs) for band in copied]
1644
+ copied._bounds = get_total_bounds([band.bounds for band in copied])
1177
1645
  return copied
1178
1646
 
1179
- # def get_filepath(self, band: str) -> str:
1180
- # simple_string_match = [path for path in self.file_paths if str(band) in path]
1181
- # if len(simple_string_match) == 1:
1182
- # return simple_string_match[0]
1183
-
1184
- # regexes_matches = []
1185
- # for path in self.file_paths:
1186
- # for pat in self.filename_patterns:
1187
- # match_ = re.search(pat, Path(path).name)
1188
- # if match_ and str(band) == match_.group("band"):
1189
- # regexes_matches.append(path)
1190
-
1191
- # if len(regexes_matches) == 1:
1192
- # return regexes_matches[0]
1193
-
1194
- # if len(regexes_matches) > 1:
1195
- # prefix = "Multiple"
1196
- # elif not regexes_matches:
1197
- # prefix = "No"
1198
-
1199
- # raise KeyError(
1200
- # f"{prefix} matches for band {band} among paths {[Path(x).name for x in self.file_paths]}"
1201
- # )
1202
-
1203
1647
  def __getitem__(
1204
1648
  self, band: str | int | Sequence[str] | Sequence[int]
1205
1649
  ) -> "Band | Image":
@@ -1234,7 +1678,14 @@ class Image(_ImageBandBase):
1234
1678
 
1235
1679
  def __lt__(self, other: "Image") -> bool:
1236
1680
  """Makes Images sortable by date."""
1237
- return self.date < other.date
1681
+ try:
1682
+ return self.date < other.date
1683
+ except Exception as e:
1684
+ print(self.path)
1685
+ print(self.date)
1686
+ print(other.path)
1687
+ print(other.date)
1688
+ raise e
1238
1689
 
1239
1690
  def __iter__(self) -> Iterator[Band]:
1240
1691
  """Iterate over the Bands."""
@@ -1248,39 +1699,6 @@ class Image(_ImageBandBase):
1248
1699
  """String representation."""
1249
1700
  return f"{self.__class__.__name__}(bands={self.bands})"
1250
1701
 
1251
- def get_cloud_band(self) -> Band:
1252
- """Get a Band where self.cloud_values have value 1 and the rest have value 0."""
1253
- scl = self[self.cloud_band].load()
1254
- scl._values = np.where(np.isin(scl.values, self.cloud_values), 1, 0)
1255
- return scl
1256
-
1257
- # @property
1258
- # def transform(self) -> Affine | None:
1259
- # """Get the Affine transform of the image."""
1260
- # try:
1261
- # return rasterio.transform.from_bounds(*self.bounds, self.width, self.height)
1262
- # except (ZeroDivisionError, TypeError):
1263
- # if not self.width or not self.height:
1264
- # return None
1265
-
1266
- # @property
1267
- # def shape(self) -> tuple[int]:
1268
- # return self._shape
1269
-
1270
- # @property
1271
- # def height(self) -> int:
1272
- # i = 1 if len(self.shape) == 3 else 0
1273
- # return self.shape[i]
1274
-
1275
- # @property
1276
- # def width(self) -> int:
1277
- # i = 2 if len(self.shape) == 3 else 1
1278
- # return self.shape[i]
1279
-
1280
- # @transform.setter
1281
- # def transform(self, value: Affine) -> None:
1282
- # self._bounds = rasterio.transform.array_bounds(self.height, self.width, value)
1283
-
1284
1702
  def _get_band(self, band: str) -> Band:
1285
1703
  if not isinstance(band, str):
1286
1704
  raise TypeError(f"band must be string. Got {type(band)}")
@@ -1289,9 +1707,15 @@ class Image(_ImageBandBase):
1289
1707
  if len(bands) == 1:
1290
1708
  return bands[0]
1291
1709
  if len(bands) > 1:
1292
- raise ValueError(
1293
- f"Multiple matches for band_id {band} among {[x for x in self]}"
1294
- )
1710
+ raise ValueError(f"Multiple matches for band_id {band} for {self}")
1711
+
1712
+ bands = [x for x in self.bands if x.band_id == band.replace("B0", "B")]
1713
+ if len(bands) == 1:
1714
+ return bands[0]
1715
+
1716
+ bands = [x for x in self.bands if x.band_id.replace("B0", "B") == band]
1717
+ if len(bands) == 1:
1718
+ return bands[0]
1295
1719
 
1296
1720
  try:
1297
1721
  more_bands = [x for x in self.bands if x.path == band]
@@ -1325,27 +1749,43 @@ class ImageCollection(_ImageBase):
1325
1749
  data: str | Path | Sequence[Image],
1326
1750
  res: int,
1327
1751
  level: str | None,
1752
+ crs: Any | None = None,
1328
1753
  single_banded: bool = False,
1329
1754
  processes: int = 1,
1330
1755
  file_system: GCSFileSystem | None = None,
1331
1756
  df: pd.DataFrame | None = None,
1332
- _mask: Any | None = None,
1757
+ bbox: Any | None = None,
1758
+ nodata: int | None = None,
1759
+ metadata: str | dict | pd.DataFrame | None = None,
1760
+ **kwargs,
1333
1761
  ) -> None:
1334
1762
  """Initialiser."""
1335
- super().__init__()
1763
+ super().__init__(**kwargs)
1336
1764
 
1765
+ self.nodata = nodata
1337
1766
  self.level = level
1767
+ self._crs = crs
1338
1768
  self.processes = processes
1339
1769
  self.file_system = file_system
1340
1770
  self._res = res
1341
- self._mask = _mask
1771
+ self._bbox = to_bbox(bbox) if bbox is not None else None
1342
1772
  self._band_ids = None
1343
1773
  self.single_banded = single_banded
1344
1774
 
1775
+ if metadata is not None:
1776
+ if isinstance(metadata, (str | Path | os.PathLike)):
1777
+ self.metadata = _read_parquet_func(metadata)
1778
+ else:
1779
+ self.metadata = metadata
1780
+ else:
1781
+ self.metadata = metadata
1782
+
1345
1783
  if hasattr(data, "__iter__") and all(isinstance(x, Image) for x in data):
1346
1784
  self._path = None
1347
- self.images = data
1785
+ self.images = [x.copy() for x in data]
1348
1786
  return
1787
+ else:
1788
+ self._images = None
1349
1789
 
1350
1790
  if not isinstance(data, (str | Path | os.PathLike)):
1351
1791
  raise TypeError("'data' must be string, Path-like or a sequence of Image.")
@@ -1353,28 +1793,53 @@ class ImageCollection(_ImageBase):
1353
1793
  self._path = str(data)
1354
1794
 
1355
1795
  if is_dapla():
1356
- self._all_filepaths = list(sorted(set(glob_func(self.path + "/**"))))
1796
+ self._all_file_paths = list(sorted(set(_glob_func(self.path + "/**"))))
1357
1797
  else:
1358
- self._all_filepaths = list(
1798
+ self._all_file_paths = list(
1359
1799
  sorted(
1360
1800
  set(
1361
- glob_func(self.path + "/**/**")
1362
- + glob_func(self.path + "/**/**/**")
1363
- + glob_func(self.path + "/**/**/**/**")
1364
- + glob_func(self.path + "/**/**/**/**/**")
1801
+ _glob_func(self.path + "/**/**")
1802
+ + _glob_func(self.path + "/**/**/**")
1803
+ + _glob_func(self.path + "/**/**/**/**")
1804
+ + _glob_func(self.path + "/**/**/**/**/**")
1365
1805
  )
1366
1806
  )
1367
1807
  )
1368
1808
 
1369
1809
  if self.level:
1370
- self._all_filepaths = [
1371
- path for path in self._all_filepaths if self.level in path
1810
+ self._all_file_paths = [
1811
+ path for path in self._all_file_paths if self.level in path
1372
1812
  ]
1373
1813
 
1374
1814
  if df is not None:
1375
1815
  self._df = df
1376
1816
  else:
1377
- self._df = self._create_metadata_df(self._all_filepaths)
1817
+ self._df = self._create_metadata_df(self._all_file_paths)
1818
+
1819
+ @property
1820
+ def values(self) -> np.ndarray:
1821
+ """4 dimensional numpy array."""
1822
+ return np.array([img.values for img in self])
1823
+
1824
+ @property
1825
+ def mask(self) -> np.ndarray:
1826
+ """4 dimensional numpy array."""
1827
+ return np.array([img.mask.values for img in self])
1828
+
1829
+ # def ndvi(
1830
+ # self, red_band: str, nir_band: str, copy: bool = True
1831
+ # ) -> "ImageCollection":
1832
+ # # copied = self.copy() if copy else self
1833
+
1834
+ # with joblib.Parallel(n_jobs=self.processes, backend="loky") as parallel:
1835
+ # ndvi_images = parallel(
1836
+ # joblib.delayed(_img_ndvi)(
1837
+ # img, red_band=red_band, nir_band=nir_band, copy=False
1838
+ # )
1839
+ # for img in self
1840
+ # )
1841
+
1842
+ # return ImageCollection(ndvi_images, single_banded=True)
1378
1843
 
1379
1844
  def groupby(self, by: str | list[str], **kwargs) -> ImageCollectionGroupBy:
1380
1845
  """Group the Collection by Image or Band attribute(s)."""
@@ -1418,9 +1883,11 @@ class ImageCollection(_ImageBase):
1418
1883
  self.image_class(
1419
1884
  [band],
1420
1885
  single_banded=True,
1886
+ masking=self.masking,
1887
+ band_class=self.band_class,
1421
1888
  **self._common_init_kwargs,
1422
1889
  df=self._df,
1423
- all_file_paths=self._all_filepaths,
1890
+ all_file_paths=self._all_file_paths,
1424
1891
  )
1425
1892
  for img in self
1426
1893
  for band in img
@@ -1428,10 +1895,15 @@ class ImageCollection(_ImageBase):
1428
1895
  return copied
1429
1896
 
1430
1897
  def merge(
1431
- self, bounds=None, method="median", as_int: bool = True, indexes=None, **kwargs
1898
+ self,
1899
+ bounds: tuple | Geometry | GeoDataFrame | GeoSeries | None = None,
1900
+ method: str | Callable = "mean",
1901
+ as_int: bool = True,
1902
+ indexes: int | tuple[int] | None = None,
1903
+ **kwargs,
1432
1904
  ) -> Band:
1433
1905
  """Merge all areas and all bands to a single Band."""
1434
- bounds = to_bbox(bounds) if bounds is not None else self._mask
1906
+ bounds = to_bbox(bounds) if bounds is not None else self._bbox
1435
1907
  crs = self.crs
1436
1908
 
1437
1909
  if indexes is None:
@@ -1447,20 +1919,22 @@ class ImageCollection(_ImageBase):
1447
1919
  else:
1448
1920
  _method = method
1449
1921
 
1450
- if method not in list(rasterio.merge.MERGE_METHODS) + ["mean"]:
1922
+ if self.masking or method not in list(rasterio.merge.MERGE_METHODS) + ["mean"]:
1451
1923
  arr = self._merge_with_numpy_func(
1452
1924
  method=method,
1453
1925
  bounds=bounds,
1454
1926
  as_int=as_int,
1927
+ **kwargs,
1455
1928
  )
1456
1929
  else:
1457
1930
  datasets = [_open_raster(path) for path in self.file_paths]
1458
1931
  arr, _ = rasterio.merge.merge(
1459
1932
  datasets,
1460
1933
  res=self.res,
1461
- bounds=bounds,
1934
+ bounds=(bounds if bounds is not None else self.bounds),
1462
1935
  indexes=_indexes,
1463
1936
  method=_method,
1937
+ nodata=self.nodata,
1464
1938
  **kwargs,
1465
1939
  )
1466
1940
 
@@ -1481,6 +1955,7 @@ class ImageCollection(_ImageBase):
1481
1955
  arr,
1482
1956
  bounds=bounds,
1483
1957
  crs=crs,
1958
+ mask=self.mask,
1484
1959
  **self._common_init_kwargs,
1485
1960
  )
1486
1961
 
@@ -1490,15 +1965,15 @@ class ImageCollection(_ImageBase):
1490
1965
  def merge_by_band(
1491
1966
  self,
1492
1967
  bounds: tuple | Geometry | GeoDataFrame | GeoSeries | None = None,
1493
- method: str = "median",
1968
+ method: str = "mean",
1494
1969
  as_int: bool = True,
1495
1970
  indexes: int | tuple[int] | None = None,
1496
1971
  **kwargs,
1497
1972
  ) -> Image:
1498
1973
  """Merge all areas to a single tile, one band per band_id."""
1499
- bounds = to_bbox(bounds) if bounds is not None else self._mask
1974
+ bounds = to_bbox(bounds) if bounds is not None else self._bbox
1500
1975
  bounds = self.bounds if bounds is None else bounds
1501
- out_bounds = self.bounds if bounds is None else bounds
1976
+ out_bounds = bounds
1502
1977
  crs = self.crs
1503
1978
 
1504
1979
  if indexes is None:
@@ -1517,20 +1992,24 @@ class ImageCollection(_ImageBase):
1517
1992
  arrs = []
1518
1993
  bands: list[Band] = []
1519
1994
  for (band_id,), band_collection in self.groupby("band_id"):
1520
- if method not in list(rasterio.merge.MERGE_METHODS) + ["mean"]:
1995
+ if self.masking or method not in list(rasterio.merge.MERGE_METHODS) + [
1996
+ "mean"
1997
+ ]:
1521
1998
  arr = band_collection._merge_with_numpy_func(
1522
1999
  method=method,
1523
2000
  bounds=bounds,
1524
2001
  as_int=as_int,
2002
+ **kwargs,
1525
2003
  )
1526
2004
  else:
1527
2005
  datasets = [_open_raster(path) for path in band_collection.file_paths]
1528
2006
  arr, _ = rasterio.merge.merge(
1529
2007
  datasets,
1530
2008
  res=self.res,
1531
- bounds=bounds,
2009
+ bounds=(bounds if bounds is not None else self.bounds),
1532
2010
  indexes=_indexes,
1533
2011
  method=_method,
2012
+ nodata=self.nodata,
1534
2013
  **kwargs,
1535
2014
  )
1536
2015
  if isinstance(indexes, int):
@@ -1555,6 +2034,7 @@ class ImageCollection(_ImageBase):
1555
2034
  # return self.image_class(
1556
2035
  image = Image(
1557
2036
  bands,
2037
+ band_class=self.band_class,
1558
2038
  **self._common_init_kwargs,
1559
2039
  )
1560
2040
 
@@ -1564,17 +2044,33 @@ class ImageCollection(_ImageBase):
1564
2044
  def _merge_with_numpy_func(
1565
2045
  self,
1566
2046
  method: str | Callable,
1567
- bounds=None,
2047
+ bounds: tuple | Geometry | GeoDataFrame | GeoSeries | None = None,
1568
2048
  as_int: bool = True,
1569
- indexes=None,
2049
+ indexes: int | tuple[int] | None = None,
1570
2050
  **kwargs,
1571
2051
  ) -> np.ndarray:
1572
2052
  arrs = []
2053
+ kwargs["indexes"] = indexes
2054
+ bounds = to_shapely(bounds) if bounds is not None else None
1573
2055
  numpy_func = get_numpy_func(method) if not callable(method) else method
1574
2056
  for (_bounds,), collection in self.groupby("bounds"):
2057
+ _bounds = (
2058
+ to_shapely(_bounds).intersection(bounds)
2059
+ if bounds is not None
2060
+ else to_shapely(_bounds)
2061
+ )
2062
+ if not _bounds.area:
2063
+ continue
2064
+
2065
+ _bounds = to_bbox(_bounds)
1575
2066
  arr = np.array(
1576
2067
  [
1577
- band.load(bounds=bounds, indexes=indexes, **kwargs).values
2068
+ (
2069
+ band.load(
2070
+ bounds=(_bounds if _bounds is not None else None),
2071
+ **kwargs,
2072
+ )
2073
+ ).values
1578
2074
  for img in collection
1579
2075
  for band in img
1580
2076
  ]
@@ -1600,79 +2096,102 @@ class ImageCollection(_ImageBase):
1600
2096
  arr,
1601
2097
  coords=coords,
1602
2098
  dims=["y", "x"],
1603
- name=str(_bounds),
1604
2099
  attrs={"crs": self.crs},
1605
2100
  )
1606
2101
  )
1607
2102
 
1608
- if bounds is None:
1609
- bounds = self.bounds
1610
-
1611
- merged = merge_arrays(arrs, bounds=bounds, res=self.res)
2103
+ merged = merge_arrays(
2104
+ arrs,
2105
+ res=self.res,
2106
+ nodata=self.nodata,
2107
+ )
1612
2108
 
1613
2109
  return merged.to_numpy()
1614
2110
 
1615
- # def write(self, root: str | Path, file_type: str = "tif", **kwargs) -> None:
1616
- # _test = kwargs.pop("_test")
1617
- # suffix = "." + file_type.strip(".")
1618
- # for img in self:
1619
- # img_path = Path(root) / img.name
1620
- # for band in img:
1621
- # band_path = (img_path / band.name).with_suffix(suffix)
1622
- # if _test:
1623
- # print(f"{self.__class__.__name__}.write: {band_path}")
1624
- # continue
1625
- # band.write(band_path, **kwargs)
1626
-
1627
- def load_bands(self, bounds=None, indexes=None, **kwargs) -> "ImageCollection":
2111
+ def sort_images(self, ascending: bool = True) -> "ImageCollection":
2112
+ """Sort Images by date."""
2113
+ self._images = (
2114
+ list(sorted([img for img in self if img.date is not None]))
2115
+ + sorted(
2116
+ [img for img in self if img.date is None and img.path is not None],
2117
+ key=lambda x: x.path,
2118
+ )
2119
+ + [img for img in self if img.date is None and img.path is None]
2120
+ )
2121
+ if not ascending:
2122
+ self._images = list(reversed(self.images))
2123
+ return self
2124
+
2125
+ def load(
2126
+ self,
2127
+ bounds: tuple | Geometry | GeoDataFrame | GeoSeries | None = None,
2128
+ indexes: int | tuple[int] | None = None,
2129
+ **kwargs,
2130
+ ) -> "ImageCollection":
1628
2131
  """Load all image Bands with threading."""
1629
2132
  with joblib.Parallel(n_jobs=self.processes, backend="threading") as parallel:
2133
+ if self.masking:
2134
+ parallel(
2135
+ joblib.delayed(_load_band)(
2136
+ img.mask, bounds=bounds, indexes=indexes, **kwargs
2137
+ )
2138
+ for img in self
2139
+ )
1630
2140
  parallel(
1631
2141
  joblib.delayed(_load_band)(
1632
2142
  band, bounds=bounds, indexes=indexes, **kwargs
1633
2143
  )
1634
2144
  for img in self
1635
2145
  for band in img
1636
- if bounds is None or img.intersects(bounds)
1637
2146
  )
1638
2147
 
1639
2148
  return self
1640
2149
 
1641
- def set_mask(
1642
- self, mask: GeoDataFrame | GeoSeries | Geometry | tuple[float]
2150
+ def set_bbox(
2151
+ self, bbox: GeoDataFrame | GeoSeries | Geometry | tuple[float]
1643
2152
  ) -> "ImageCollection":
1644
2153
  """Set the mask to be used to clip the images to."""
1645
- self._mask = to_bbox(mask)
2154
+ self._bbox = to_bbox(bbox)
1646
2155
  # only update images when already instansiated
1647
- if hasattr(self, "_images"):
2156
+ if self._images is not None:
1648
2157
  for img in self._images:
1649
- img._mask = self._mask
1650
- img._bounds = self._mask
1651
- for band in img:
1652
- band._mask = self._mask
1653
- band._bounds = self._mask
2158
+ img._bbox = self._bbox
2159
+ if img._bands is not None:
2160
+ for band in img:
2161
+ band._bbox = self._bbox
2162
+ bounds = box(*band._bbox).intersection(box(*band.bounds))
2163
+ band._bounds = to_bbox(bounds) if not bounds.is_empty else None
2164
+
2165
+ return self
2166
+
2167
+ def apply(self, func: Callable, **kwargs) -> "ImageCollection":
2168
+ """Apply a function to all bands in each image of the collection."""
2169
+ for img in self:
2170
+ img.bands = [func(band, **kwargs) for band in img]
1654
2171
  return self
1655
2172
 
1656
2173
  def filter(
1657
2174
  self,
1658
2175
  bands: str | list[str] | None = None,
2176
+ exclude_bands: str | list[str] | None = None,
1659
2177
  date_ranges: (
1660
2178
  tuple[str | None, str | None]
1661
2179
  | tuple[tuple[str | None, str | None], ...]
1662
2180
  | None
1663
2181
  ) = None,
1664
- bounds: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None = None,
2182
+ bbox: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None = None,
2183
+ intersects: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None = None,
1665
2184
  max_cloud_coverage: int | None = None,
1666
2185
  copy: bool = True,
1667
2186
  ) -> "ImageCollection":
1668
2187
  """Filter images and bands in the collection."""
1669
2188
  copied = self.copy() if copy else self
1670
2189
 
1671
- if isinstance(bounds, BoundingBox):
1672
- date_ranges = (bounds.mint, bounds.maxt)
2190
+ if isinstance(bbox, BoundingBox):
2191
+ date_ranges = (bbox.mint, bbox.maxt)
1673
2192
 
1674
2193
  if date_ranges:
1675
- copied = copied._filter_dates(date_ranges, copy=False)
2194
+ copied = copied._filter_dates(date_ranges)
1676
2195
 
1677
2196
  if max_cloud_coverage is not None:
1678
2197
  copied.images = [
@@ -1681,8 +2200,12 @@ class ImageCollection(_ImageBase):
1681
2200
  if image.cloud_coverage_percentage < max_cloud_coverage
1682
2201
  ]
1683
2202
 
1684
- if bounds is not None:
1685
- copied = copied._filter_bounds(bounds, copy=False)
2203
+ if bbox is not None:
2204
+ copied = copied._filter_bounds(bbox)
2205
+ copied.set_bbox(bbox)
2206
+
2207
+ if intersects is not None:
2208
+ copied = copied._filter_bounds(intersects)
1686
2209
 
1687
2210
  if bands is not None:
1688
2211
  if isinstance(bands, str):
@@ -1691,6 +2214,21 @@ class ImageCollection(_ImageBase):
1691
2214
  copied._band_ids = bands
1692
2215
  copied.images = [img[bands] for img in copied.images if bands in img]
1693
2216
 
2217
+ if exclude_bands is not None:
2218
+ if isinstance(exclude_bands, str):
2219
+ exclude_bands = {exclude_bands}
2220
+ else:
2221
+ exclude_bands = set(exclude_bands)
2222
+ include_bands: list[list[str]] = [
2223
+ [band_id for band_id in img.band_ids if band_id not in exclude_bands]
2224
+ for img in copied
2225
+ ]
2226
+ copied.images = [
2227
+ img[bands]
2228
+ for img, bands in zip(copied.images, include_bands, strict=False)
2229
+ if bands
2230
+ ]
2231
+
1694
2232
  return copied
1695
2233
 
1696
2234
  def _filter_dates(
@@ -1698,46 +2236,46 @@ class ImageCollection(_ImageBase):
1698
2236
  date_ranges: (
1699
2237
  tuple[str | None, str | None] | tuple[tuple[str | None, str | None], ...]
1700
2238
  ),
1701
- copy: bool = True,
1702
2239
  ) -> "ImageCollection":
1703
2240
  if not isinstance(date_ranges, (tuple, list)):
1704
2241
  raise TypeError(
1705
2242
  "date_ranges should be a 2-length tuple of strings or None, "
1706
2243
  "or a tuple of tuples for multiple date ranges"
1707
2244
  )
1708
- if self.image_patterns is None:
2245
+ if not self.image_patterns:
1709
2246
  raise ValueError(
1710
2247
  "Cannot set date_ranges when the class's image_regexes attribute is None"
1711
2248
  )
1712
2249
 
1713
- copied = self.copy() if copy else self
1714
-
1715
- copied.images = [
2250
+ self.images = [
1716
2251
  img
1717
2252
  for img in self
1718
2253
  if _date_is_within(
1719
- img.path, date_ranges, copied.image_patterns, copied.date_format
2254
+ img.path, date_ranges, self.image_patterns, self.date_format
1720
2255
  )
1721
2256
  ]
1722
- return copied
2257
+ return self
1723
2258
 
1724
2259
  def _filter_bounds(
1725
- self, other: GeoDataFrame | GeoSeries | Geometry | tuple, copy: bool = True
2260
+ self, other: GeoDataFrame | GeoSeries | Geometry | tuple
1726
2261
  ) -> "ImageCollection":
1727
- copied = self.copy() if copy else self
2262
+ if self._images is None:
2263
+ return self
1728
2264
 
1729
2265
  other = to_shapely(other)
1730
2266
 
1731
- with joblib.Parallel(n_jobs=copied.processes, backend="threading") as parallel:
2267
+ # intersects_list = GeoSeries([img.union_all() for img in self]).intersects(other)
2268
+ with joblib.Parallel(n_jobs=self.processes, backend="loky") as parallel:
1732
2269
  intersects_list: list[bool] = parallel(
1733
- joblib.delayed(_intesects)(image, other) for image in copied
2270
+ joblib.delayed(_intesects)(image, other) for image in self
1734
2271
  )
1735
- copied.images = [
2272
+
2273
+ self.images = [
1736
2274
  image
1737
- for image, intersects in zip(copied, intersects_list, strict=False)
2275
+ for image, intersects in zip(self, intersects_list, strict=False)
1738
2276
  if intersects
1739
2277
  ]
1740
- return copied
2278
+ return self
1741
2279
 
1742
2280
  def to_gdfs(self, column: str = "value") -> dict[str, GeoDataFrame]:
1743
2281
  """Convert each band in each Image to a GeoDataFrame."""
@@ -1751,21 +2289,38 @@ class ImageCollection(_ImageBase):
1751
2289
  except AttributeError:
1752
2290
  name = f"{self.__class__.__name__}({i})"
1753
2291
 
2292
+ band.load()
2293
+
1754
2294
  if name not in out:
1755
2295
  out[name] = band.to_gdf(column=column)
1756
- else:
1757
- out[name] = f"{self.__class__.__name__}({i})"
2296
+ # else:
2297
+ # out[name] = f"{self.__class__.__name__}({i})"
1758
2298
  return out
1759
2299
 
1760
2300
  def sample(self, n: int = 1, size: int = 500) -> "ImageCollection":
1761
2301
  """Sample one or more areas of a given size and set this as mask for the images."""
1762
- images = []
1763
- bbox = to_gdf(self.unary_union).geometry.buffer(-size / 2)
2302
+ unioned = self.union_all()
2303
+ buffered_in = unioned.buffer(-size / 2)
2304
+ if not buffered_in.is_empty:
2305
+ bbox = to_gdf(buffered_in)
2306
+ else:
2307
+ bbox = to_gdf(unioned)
2308
+
1764
2309
  copied = self.copy()
1765
- for _ in range(n):
2310
+ sampled_images = []
2311
+ while len(sampled_images) < n:
1766
2312
  mask = to_bbox(bbox.sample_points(1).buffer(size))
1767
- images += copied.filter(bounds=mask).set_mask(mask).images
1768
- copied.images = images
2313
+ images = copied.filter(bbox=mask).images
2314
+ random.shuffle(images)
2315
+ try:
2316
+ images = images[:n]
2317
+ except IndexError:
2318
+ pass
2319
+ sampled_images += images
2320
+ copied._images = sampled_images[:n]
2321
+ if copied._should_be_sorted:
2322
+ copied._images = list(sorted(copied._images))
2323
+
1769
2324
  return copied
1770
2325
 
1771
2326
  def sample_tiles(self, n: int) -> "ImageCollection":
@@ -1821,34 +2376,47 @@ class ImageCollection(_ImageBase):
1821
2376
  copied.images = copied.images[item]
1822
2377
  return copied
1823
2378
 
2379
+ if isinstance(item, ImageCollection):
2380
+
2381
+ def _get_from_single_element_list(lst: list[Any]) -> Any:
2382
+ if len(lst) != 1:
2383
+ raise ValueError(lst)
2384
+ return next(iter(lst))
2385
+
2386
+ copied = self.copy()
2387
+ copied._images = [
2388
+ _get_from_single_element_list(
2389
+ [img2 for img2 in copied if img2.stem in img.path]
2390
+ )
2391
+ for img in item
2392
+ ]
2393
+ return copied
2394
+
1824
2395
  if not isinstance(item, BoundingBox) and not (
1825
- isinstance(item, Iterable) and all(isinstance(x, BoundingBox) for x in item)
2396
+ isinstance(item, Iterable)
2397
+ and len(item)
2398
+ and all(isinstance(x, BoundingBox) for x in item)
1826
2399
  ):
1827
- try:
1828
- copied = self.copy()
1829
- if all(isinstance(x, bool) for x in item):
1830
- copied.images = [
1831
- img for x, img in zip(item, copied, strict=True) if x
1832
- ]
1833
- else:
1834
- copied.images = [copied.images[i] for i in item]
1835
- return copied
1836
- except Exception as e:
1837
- if hasattr(item, "__iter__"):
1838
- endnote = f" of length {len(item)} with types {set(type(x) for x in item)}"
1839
- raise TypeError(
1840
- "ImageCollection indices must be int or BoundingBox. "
1841
- f"Got {type(item)}{endnote}"
1842
- ) from e
2400
+ copied = self.copy()
2401
+ if callable(item):
2402
+ item = [item(img) for img in copied]
2403
+
2404
+ # check for base bool and numpy bool
2405
+ if all("bool" in str(type(x)) for x in item):
2406
+ copied.images = [img for x, img in zip(item, copied, strict=True) if x]
1843
2407
 
1844
- elif isinstance(item, BoundingBox):
2408
+ else:
2409
+ copied.images = [copied.images[i] for i in item]
2410
+ return copied
2411
+
2412
+ if isinstance(item, BoundingBox):
1845
2413
  date_ranges: tuple[str] = (item.mint, item.maxt)
1846
2414
  data: torch.Tensor = numpy_to_torch(
1847
2415
  np.array(
1848
2416
  [
1849
2417
  band.values
1850
2418
  for band in self.filter(
1851
- bounds=item, date_ranges=date_ranges
2419
+ bbox=item, date_ranges=date_ranges
1852
2420
  ).merge_by_band(bounds=item)
1853
2421
  ]
1854
2422
  )
@@ -1863,7 +2431,7 @@ class ImageCollection(_ImageBase):
1863
2431
  [
1864
2432
  band.values
1865
2433
  for band in self.filter(
1866
- bounds=bbox, date_ranges=date_range
2434
+ bbox=bbox, date_ranges=date_range
1867
2435
  ).merge_by_band(bounds=bbox)
1868
2436
  ]
1869
2437
  )
@@ -1879,9 +2447,6 @@ class ImageCollection(_ImageBase):
1879
2447
 
1880
2448
  return sample
1881
2449
 
1882
- # def dates_as_float(self) -> list[tuple[float, float]]:
1883
- # return [disambiguate_timestamp(date, self.date_format) for date in self.dates]
1884
-
1885
2450
  @property
1886
2451
  def mint(self) -> float:
1887
2452
  """Min timestamp of the images combined."""
@@ -1919,29 +2484,71 @@ class ImageCollection(_ImageBase):
1919
2484
  @property
1920
2485
  def images(self) -> list["Image"]:
1921
2486
  """List of images in the Collection."""
1922
- try:
2487
+ if self._images is not None:
1923
2488
  return self._images
1924
- except AttributeError:
1925
- # only fetch images when they are needed
1926
- self._images = _get_images(
1927
- list(self._df["image_path"]),
1928
- all_file_paths=self._all_filepaths,
1929
- df=self._df,
1930
- res=self.res,
1931
- processes=self.processes,
1932
- image_class=self.image_class,
1933
- _mask=self._mask,
2489
+ # only fetch images when they are needed
2490
+ self._images = _get_images(
2491
+ list(self._df["image_path"]),
2492
+ all_file_paths=self._all_file_paths,
2493
+ df=self._df,
2494
+ image_class=self.image_class,
2495
+ band_class=self.band_class,
2496
+ masking=self.masking,
2497
+ **self._common_init_kwargs,
2498
+ )
2499
+ if self.masking is not None:
2500
+ images = []
2501
+ for image in self._images:
2502
+ try:
2503
+ if not isinstance(image.mask, Band):
2504
+ raise ValueError()
2505
+ images.append(image)
2506
+ except ValueError:
2507
+ continue
2508
+ self._images = images
2509
+ for image in self._images:
2510
+ image._bands = [band for band in image if band.band_id is not None]
2511
+
2512
+ if self.metadata is not None:
2513
+ for img in self:
2514
+ for band in img:
2515
+ for key in ["crs", "bounds"]:
2516
+ try:
2517
+ value = self.metadata[band.path][key]
2518
+ except KeyError:
2519
+ value = self.metadata[key][band.path]
2520
+ setattr(band, f"_{key}", value)
2521
+
2522
+ self._images = [img for img in self if len(img)]
2523
+
2524
+ if self._should_be_sorted:
2525
+ self._images = list(sorted(self._images))
2526
+
2527
+ return self._images
2528
+
2529
+ @property
2530
+ def _should_be_sorted(self) -> bool:
2531
+ """True if the ImageCollection has regexes that should make it sortable by date."""
2532
+ sort_group = "date"
2533
+ return (
2534
+ self.filename_patterns
2535
+ and any(
2536
+ sort_group in pat.groupindex
2537
+ and sort_group in _get_non_optional_groups(pat)
2538
+ for pat in self.filename_patterns
1934
2539
  )
1935
- if self.image_regexes:
1936
- self._images = list(sorted(self._images))
1937
- return self._images
2540
+ or self.image_patterns
2541
+ and any(
2542
+ sort_group in pat.groupindex
2543
+ and sort_group in _get_non_optional_groups(pat)
2544
+ for pat in self.image_patterns
2545
+ )
2546
+ or all(img.date is not None for img in self)
2547
+ )
1938
2548
 
1939
2549
  @images.setter
1940
2550
  def images(self, new_value: list["Image"]) -> list["Image"]:
1941
- if self.filename_patterns and any(pat.groups for pat in self.filename_patterns):
1942
- self._images = list(sorted(new_value))
1943
- else:
1944
- self._images = list(new_value)
2551
+ self._images = list(new_value)
1945
2552
  if not all(isinstance(x, Image) for x in self._images):
1946
2553
  raise TypeError("images should be a sequence of Image.")
1947
2554
 
@@ -1969,12 +2576,11 @@ class ImageCollection(_ImageBase):
1969
2576
 
1970
2577
  def __repr__(self) -> str:
1971
2578
  """String representation."""
1972
- return f"{self.__class__.__name__}({len(self)})"
2579
+ return f"{self.__class__.__name__}({len(self)}, path='{self.path}')"
1973
2580
 
1974
- @property
1975
- def unary_union(self) -> Polygon | MultiPolygon:
2581
+ def union_all(self) -> Polygon | MultiPolygon:
1976
2582
  """(Multi)Polygon representing the union of all image bounds."""
1977
- return unary_union([img.unary_union for img in self])
2583
+ return unary_union([img.union_all() for img in self])
1978
2584
 
1979
2585
  @property
1980
2586
  def bounds(self) -> tuple[int, int, int, int]:
@@ -1984,7 +2590,143 @@ class ImageCollection(_ImageBase):
1984
2590
  @property
1985
2591
  def crs(self) -> Any:
1986
2592
  """Common coordinate reference system of the Images."""
1987
- return get_common_crs([img.crs for img in self])
2593
+ if self._crs is not None:
2594
+ return self._crs
2595
+ self._crs = get_common_crs([img.crs for img in self])
2596
+ return self._crs
2597
+
2598
+ def plot_pixels(
2599
+ self,
2600
+ by: str | list[str] | None = None,
2601
+ x_var: str = "date",
2602
+ y_label: str = "value",
2603
+ p: float = 0.95,
2604
+ ylim: tuple[float, float] | None = None,
2605
+ figsize: tuple[int] = (20, 8),
2606
+ ) -> None:
2607
+ """Plot each individual pixel in a dotplot for all dates.
2608
+
2609
+ Args:
2610
+ by: Band attributes to groupby. Defaults to "bounds" and "band_id"
2611
+ if all bands have no-None band_ids, otherwise defaults to "bounds".
2612
+ x_var: Attribute to use on the x-axis. Defaults to "date"
2613
+ if the ImageCollection is sortable by date, otherwise a range index.
2614
+ Can be set to "days_since_start".
2615
+ y_label: Label to use on the y-axis.
2616
+ p: p-value for the confidence interval.
2617
+ ylim: Limits of the y-axis.
2618
+ figsize: Figure size as tuple (width, height).
2619
+
2620
+ """
2621
+ if by is None and all(band.band_id is not None for img in self for band in img):
2622
+ by = ["bounds", "band_id"]
2623
+ elif by is None:
2624
+ by = ["bounds"]
2625
+
2626
+ alpha = 1 - p
2627
+
2628
+ for img in self:
2629
+ for band in img:
2630
+ band.load()
2631
+
2632
+ for group_values, subcollection in self.groupby(by):
2633
+ print("group_values:", *group_values)
2634
+
2635
+ y = np.array([band.values for img in subcollection for band in img])
2636
+ if "date" in x_var and subcollection._should_be_sorted:
2637
+ x = np.array(
2638
+ [
2639
+ datetime.datetime.strptime(band.date[:8], "%Y%m%d").date()
2640
+ for img in subcollection
2641
+ for band in img
2642
+ ]
2643
+ )
2644
+ x = (
2645
+ pd.to_datetime(
2646
+ [band.date[:8] for img in subcollection for band in img]
2647
+ )
2648
+ - pd.Timestamp(np.min(x))
2649
+ ).days
2650
+ else:
2651
+ x = np.arange(0, len(y))
2652
+
2653
+ mask = np.array(
2654
+ [
2655
+ (
2656
+ band.values.mask
2657
+ if hasattr(band.values, "mask")
2658
+ else np.full(band.values.shape, False)
2659
+ )
2660
+ for img in subcollection
2661
+ for band in img
2662
+ ]
2663
+ )
2664
+
2665
+ if x_var == "days_since_start":
2666
+ x = x - np.min(x)
2667
+
2668
+ for i in range(y.shape[1]):
2669
+ for j in range(y.shape[2]):
2670
+ this_y = y[:, i, j]
2671
+
2672
+ this_mask = mask[:, i, j]
2673
+ this_x = x[~this_mask]
2674
+ this_y = this_y[~this_mask]
2675
+
2676
+ if ylim:
2677
+ condition = (this_y >= ylim[0]) & (this_y <= ylim[1])
2678
+ this_y = this_y[condition]
2679
+ this_x = this_x[condition]
2680
+
2681
+ coef, intercept = np.linalg.lstsq(
2682
+ np.vstack([this_x, np.ones(this_x.shape[0])]).T,
2683
+ this_y,
2684
+ rcond=None,
2685
+ )[0]
2686
+ predicted = np.array([intercept + coef * x for x in this_x])
2687
+
2688
+ # Degrees of freedom
2689
+ dof = len(this_x) - 2
2690
+
2691
+ # 95% confidence interval
2692
+ t_val = stats.t.ppf(1 - alpha / 2, dof)
2693
+
2694
+ # Mean squared error of the residuals
2695
+ mse = np.sum((this_y - predicted) ** 2) / dof
2696
+
2697
+ # Calculate the standard error of predictions
2698
+ pred_stderr = np.sqrt(
2699
+ mse
2700
+ * (
2701
+ 1 / len(this_x)
2702
+ + (this_x - np.mean(this_x)) ** 2
2703
+ / np.sum((this_x - np.mean(this_x)) ** 2)
2704
+ )
2705
+ )
2706
+
2707
+ # Calculate the confidence interval for predictions
2708
+ ci_lower = predicted - t_val * pred_stderr
2709
+ ci_upper = predicted + t_val * pred_stderr
2710
+
2711
+ rounding = int(np.log(1 / abs(coef)))
2712
+
2713
+ fig = plt.figure(figsize=figsize)
2714
+ ax = fig.add_subplot(1, 1, 1)
2715
+
2716
+ ax.scatter(this_x, this_y, color="#2c93db")
2717
+ ax.plot(this_x, predicted, color="#e0436b")
2718
+ ax.fill_between(
2719
+ this_x,
2720
+ ci_lower,
2721
+ ci_upper,
2722
+ color="#e0436b",
2723
+ alpha=0.2,
2724
+ label=f"{int(alpha*100)}% CI",
2725
+ )
2726
+ plt.title(f"Coefficient: {round(coef, rounding)}")
2727
+ plt.xlabel(x_var)
2728
+ plt.ylabel(y_label)
2729
+ plt.show()
1988
2730
 
1989
2731
 
1990
2732
  def concat_image_collections(collections: Sequence[ImageCollection]) -> ImageCollection:
@@ -2003,12 +2745,13 @@ def concat_image_collections(collections: Sequence[ImageCollection]) -> ImageCol
2003
2745
  out_collection = first_collection.__class__(
2004
2746
  images,
2005
2747
  level=level,
2748
+ band_class=first_collection.band_class,
2749
+ image_class=first_collection.image_class,
2006
2750
  **first_collection._common_init_kwargs,
2007
- # res=list(resolutions)[0],
2008
2751
  )
2009
- out_collection._all_filepaths = list(
2752
+ out_collection._all_file_paths = list(
2010
2753
  sorted(
2011
- set(itertools.chain.from_iterable([x._all_filepaths for x in collections]))
2754
+ set(itertools.chain.from_iterable([x._all_file_paths for x in collections]))
2012
2755
  )
2013
2756
  )
2014
2757
  return out_collection
@@ -2046,7 +2789,7 @@ def to_xarray(
2046
2789
  dims=dims,
2047
2790
  name=name,
2048
2791
  attrs={"crs": crs},
2049
- ) # .transpose("y", "x")
2792
+ )
2050
2793
 
2051
2794
 
2052
2795
  def _slope_2d(array: np.ndarray, res: int, degrees: int) -> np.ndarray:
@@ -2065,28 +2808,70 @@ def _slope_2d(array: np.ndarray, res: int, degrees: int) -> np.ndarray:
2065
2808
  return degrees
2066
2809
 
2067
2810
 
2811
+ def _clip_loaded_array(
2812
+ arr: np.ndarray,
2813
+ bounds: tuple[int, int, int, int],
2814
+ transform: Affine,
2815
+ crs: Any,
2816
+ out_shape: tuple[int, int],
2817
+ **kwargs,
2818
+ ) -> np.ndarray:
2819
+ # xarray needs a numpy array of polygon(s)
2820
+ bounds_arr: np.ndarray = GeoSeries([to_shapely(bounds)]).values
2821
+ try:
2822
+
2823
+ while out_shape != arr.shape:
2824
+ arr = (
2825
+ to_xarray(
2826
+ arr,
2827
+ transform=transform,
2828
+ crs=crs,
2829
+ )
2830
+ .rio.clip(bounds_arr, crs=crs, **kwargs)
2831
+ .to_numpy()
2832
+ )
2833
+ # bounds_arr = bounds_arr.buffer(0.0000001)
2834
+ return arr
2835
+
2836
+ except NoDataInBounds:
2837
+ return np.array([])
2838
+
2839
+
2068
2840
  def _get_images(
2069
2841
  image_paths: list[str],
2070
2842
  *,
2071
- res: int,
2072
2843
  all_file_paths: list[str],
2073
2844
  df: pd.DataFrame,
2074
2845
  processes: int,
2075
- image_class: Image,
2076
- _mask: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None,
2846
+ image_class: type,
2847
+ band_class: type,
2848
+ bbox: GeoDataFrame | GeoSeries | Geometry | tuple[float] | None,
2849
+ masking: BandMasking | None,
2850
+ **kwargs,
2077
2851
  ) -> list[Image]:
2078
- with joblib.Parallel(n_jobs=processes, backend="threading") as parallel:
2079
- return parallel(
2852
+
2853
+ with joblib.Parallel(n_jobs=processes, backend="loky") as parallel:
2854
+ images = parallel(
2080
2855
  joblib.delayed(image_class)(
2081
2856
  path,
2082
2857
  df=df,
2083
- res=res,
2084
2858
  all_file_paths=all_file_paths,
2085
- _mask=_mask,
2086
- processes=processes,
2859
+ masking=masking,
2860
+ band_class=band_class,
2861
+ **kwargs,
2087
2862
  )
2088
2863
  for path in image_paths
2089
2864
  )
2865
+ if bbox is not None:
2866
+ intersects_list = GeoSeries([img.union_all() for img in images]).intersects(
2867
+ to_shapely(bbox)
2868
+ )
2869
+ return [
2870
+ img
2871
+ for img, intersects in zip(images, intersects_list, strict=False)
2872
+ if intersects
2873
+ ]
2874
+ return images
2090
2875
 
2091
2876
 
2092
2877
  def numpy_to_torch(array: np.ndarray) -> torch.Tensor:
@@ -2104,8 +2889,12 @@ class _RegexError(ValueError):
2104
2889
  pass
2105
2890
 
2106
2891
 
2892
+ class ArrayNotLoadedError(ValueError):
2893
+ """Arrays are not loaded."""
2894
+
2895
+
2107
2896
  class PathlessImageError(ValueError):
2108
- """Exception for when Images, Bands or ImageCollections have no path."""
2897
+ """'path' attribute is needed but instance has no path."""
2109
2898
 
2110
2899
  def __init__(self, instance: _ImageBase) -> None:
2111
2900
  """Initialise error class."""
@@ -2132,7 +2921,7 @@ def _get_regex_match_from_xml_in_local_dir(
2132
2921
  for i, path in enumerate(paths):
2133
2922
  if ".xml" not in path:
2134
2923
  continue
2135
- with open_func(path, "rb") as file:
2924
+ with _open_func(path, "rb") as file:
2136
2925
  filebytes: bytes = file.read()
2137
2926
  try:
2138
2927
  return _extract_regex_match_from_string(
@@ -2144,24 +2933,26 @@ def _get_regex_match_from_xml_in_local_dir(
2144
2933
 
2145
2934
 
2146
2935
  def _extract_regex_match_from_string(
2147
- xml_file: str, regexes: tuple[str]
2936
+ xml_file: str, regexes: tuple[str | re.Pattern]
2148
2937
  ) -> str | dict[str, str]:
2938
+ if all(isinstance(x, str) for x in regexes):
2939
+ for regex in regexes:
2940
+ try:
2941
+ return re.search(regex, xml_file).group(1)
2942
+ except (TypeError, AttributeError):
2943
+ continue
2944
+ raise _RegexError()
2945
+
2946
+ out = {}
2149
2947
  for regex in regexes:
2150
- if isinstance(regex, dict):
2151
- out = {}
2152
- for key, value in regex.items():
2153
- try:
2154
- out[key] = re.search(value, xml_file).group(1)
2155
- except (TypeError, AttributeError):
2156
- continue
2157
- if len(out) != len(regex):
2158
- raise _RegexError()
2159
- return out
2160
2948
  try:
2161
- return re.search(regex, xml_file).group(1)
2949
+ matches = re.search(regex, xml_file)
2950
+ out |= matches.groupdict()
2162
2951
  except (TypeError, AttributeError):
2163
2952
  continue
2164
- raise _RegexError()
2953
+ if not out:
2954
+ raise _RegexError()
2955
+ return out
2165
2956
 
2166
2957
 
2167
2958
  def _fix_path(path: str) -> str:
@@ -2171,116 +2962,62 @@ def _fix_path(path: str) -> str:
2171
2962
 
2172
2963
 
2173
2964
  def _get_regexes_matches_for_df(
2174
- df, match_col: str, patterns: Sequence[re.Pattern], suffix: str = ""
2175
- ) -> tuple[pd.DataFrame, list[str]]:
2965
+ df, match_col: str, patterns: Sequence[re.Pattern]
2966
+ ) -> pd.DataFrame:
2176
2967
  if not len(df):
2177
- return df, []
2178
- assert df.index.is_unique
2179
- matches: list[pd.DataFrame] = []
2180
- for pat in patterns:
2181
- if pat.groups:
2182
- try:
2183
- matches.append(df[match_col].str.extract(pat))
2184
- except ValueError:
2185
- continue
2186
- else:
2187
- match_ = df[match_col].loc[df[match_col].str.match(pat)]
2188
- if len(match_):
2189
- matches.append(match_)
2190
-
2191
- matches = pd.concat(matches).groupby(level=0, dropna=True).first()
2192
-
2193
- if isinstance(matches, pd.Series):
2194
- matches = pd.DataFrame({matches.name: matches.values}, index=matches.index)
2968
+ return df
2195
2969
 
2196
- match_cols = [f"{col}{suffix}" for col in matches.columns]
2197
- df[match_cols] = matches
2198
- return (
2199
- df.loc[~df[match_cols].isna().all(axis=1)].drop(
2200
- columns=f"{match_col}{suffix}", errors="ignore"
2201
- ),
2202
- match_cols,
2970
+ non_optional_groups = list(
2971
+ set(
2972
+ itertools.chain.from_iterable(
2973
+ [_get_non_optional_groups(pat) for pat in patterns]
2974
+ )
2975
+ )
2203
2976
  )
2204
2977
 
2978
+ if not non_optional_groups:
2979
+ return df
2205
2980
 
2206
- def _arr_from_gdf(
2207
- gdf: GeoDataFrame,
2208
- res: int,
2209
- fill: int = 0,
2210
- all_touched: bool = False,
2211
- merge_alg: Callable = MergeAlg.replace,
2212
- default_value: int = 1,
2213
- dtype: Any | None = None,
2214
- ) -> np.ndarray:
2215
- """Construct Raster from a GeoDataFrame or GeoSeries.
2216
-
2217
- The GeoDataFrame should have
2218
-
2219
- Args:
2220
- gdf: The GeoDataFrame to rasterize.
2221
- res: Resolution of the raster in units of the GeoDataFrame's coordinate reference system.
2222
- fill: Fill value for areas outside of input geometries (default is 0).
2223
- all_touched: Whether to consider all pixels touched by geometries,
2224
- not just those whose center is within the polygon (default is False).
2225
- merge_alg: Merge algorithm to use when combining geometries
2226
- (default is 'MergeAlg.replace').
2227
- default_value: Default value to use for the rasterized pixels
2228
- (default is 1).
2229
- dtype: Data type of the output array. If None, it will be
2230
- determined automatically.
2981
+ assert df.index.is_unique
2982
+ keep = []
2983
+ for pat in patterns:
2984
+ for i, row in df[match_col].items():
2985
+ matches = _get_first_group_match(pat, row)
2986
+ if all(group in matches for group in non_optional_groups):
2987
+ keep.append(i)
2988
+
2989
+ return df.loc[keep]
2990
+
2991
+
2992
+ def _get_non_optional_groups(pat: re.Pattern | str) -> list[str]:
2993
+ return [
2994
+ x
2995
+ for x in [
2996
+ _extract_group_name(group)
2997
+ for group in pat.pattern.split("\n")
2998
+ if group
2999
+ and not group.replace(" ", "").startswith("#")
3000
+ and not group.replace(" ", "").split("#")[0].endswith("?")
3001
+ ]
3002
+ if x is not None
3003
+ ]
2231
3004
 
2232
- Returns:
2233
- A Raster instance based on the specified GeoDataFrame and parameters.
2234
3005
 
2235
- Raises:
2236
- TypeError: If 'transform' is provided in kwargs, as this is
2237
- computed based on the GeoDataFrame bounds and resolution.
2238
- """
2239
- if isinstance(gdf, GeoSeries):
2240
- values = gdf.index
2241
- gdf = gdf.to_frame("geometry")
2242
- elif isinstance(gdf, GeoDataFrame):
2243
- if len(gdf.columns) > 2:
2244
- raise ValueError(
2245
- "gdf should have only a geometry column and one numeric column to "
2246
- "use as array values. "
2247
- "Alternatively only a geometry column and a numeric index."
2248
- )
2249
- elif len(gdf.columns) == 1:
2250
- values = gdf.index
2251
- else:
2252
- col: str = next(
2253
- iter([col for col in gdf if col != gdf._geometry_column_name])
2254
- )
2255
- values = gdf[col]
2256
-
2257
- if isinstance(values, pd.MultiIndex):
2258
- raise ValueError("Index cannot be MultiIndex.")
2259
-
2260
- shape = _get_shape_from_bounds(gdf.total_bounds, res=res)
2261
- transform = _get_transform_from_bounds(gdf.total_bounds, shape)
2262
-
2263
- return features.rasterize(
2264
- _gdf_to_geojson_with_col(gdf, values),
2265
- out_shape=shape,
2266
- transform=transform,
2267
- fill=fill,
2268
- all_touched=all_touched,
2269
- merge_alg=merge_alg,
2270
- default_value=default_value,
2271
- dtype=dtype,
2272
- )
3006
+ def _extract_group_name(txt: str) -> str | None:
3007
+ try:
3008
+ return re.search(r"\(\?P<(\w+)>", txt)[1]
3009
+ except TypeError:
3010
+ return None
2273
3011
 
2274
3012
 
2275
- def _gdf_to_geojson_with_col(gdf: GeoDataFrame, values: np.ndarray) -> list[dict]:
2276
- with warnings.catch_warnings():
2277
- warnings.filterwarnings("ignore", category=UserWarning)
2278
- return [
2279
- (feature["geometry"], val)
2280
- for val, feature in zip(
2281
- values, loads(gdf.to_json())["features"], strict=False
2282
- )
2283
- ]
3013
+ def _get_first_group_match(pat: re.Pattern, text: str) -> dict[str, str]:
3014
+ groups = pat.groupindex.keys()
3015
+ all_matches: dict[str, str] = {}
3016
+ for x in pat.findall(text):
3017
+ for group, value in zip(groups, x, strict=True):
3018
+ if value and group not in all_matches:
3019
+ all_matches[group] = value
3020
+ return all_matches
2284
3021
 
2285
3022
 
2286
3023
  def _date_is_within(
@@ -2292,10 +3029,11 @@ def _date_is_within(
2292
3029
  date_format: str,
2293
3030
  ) -> bool:
2294
3031
  for pat in image_patterns:
3032
+
2295
3033
  try:
2296
- date = re.match(pat, Path(path).name).group("date")
3034
+ date = _get_first_group_match(pat, Path(path).name)["date"]
2297
3035
  break
2298
- except AttributeError:
3036
+ except KeyError:
2299
3037
  date = None
2300
3038
 
2301
3039
  if date is None:
@@ -2323,11 +3061,14 @@ def _date_is_within(
2323
3061
  try:
2324
3062
  date_min = date_min or "00000000"
2325
3063
  date_max = date_max or "99999999"
2326
- assert isinstance(date_min, str)
2327
- assert len(date_min) == 8
2328
- assert isinstance(date_max, str)
2329
- assert len(date_max) == 8
2330
- except AssertionError as err:
3064
+ if not (
3065
+ isinstance(date_min, str)
3066
+ and len(date_min) == 8
3067
+ and isinstance(date_max, str)
3068
+ and len(date_max) == 8
3069
+ ):
3070
+ raise ValueError()
3071
+ except ValueError as err:
2331
3072
  raise TypeError(
2332
3073
  "date_ranges should be a tuple of two 8-charactered strings (start and end date)."
2333
3074
  f"Got {date_range} of type {[type(x) for x in date_range]}"
@@ -2338,83 +3079,22 @@ def _date_is_within(
2338
3079
  return False
2339
3080
 
2340
3081
 
2341
- def _get_shape_from_bounds(
2342
- obj: GeoDataFrame | GeoSeries | Geometry | tuple, res: int
2343
- ) -> tuple[int, int]:
2344
- resx, resy = (res, res) if isinstance(res, numbers.Number) else res
2345
-
2346
- minx, miny, maxx, maxy = to_bbox(obj)
2347
- diffx = maxx - minx
2348
- diffy = maxy - miny
2349
- width = int(diffx / resx)
2350
- heigth = int(diffy / resy)
2351
- return heigth, width
2352
-
2353
-
2354
- def _get_transform_from_bounds(
2355
- obj: GeoDataFrame | GeoSeries | Geometry | tuple, shape: tuple[float, ...]
2356
- ) -> Affine:
2357
- minx, miny, maxx, maxy = to_bbox(obj)
2358
- if len(shape) == 2:
2359
- width, height = shape
2360
- elif len(shape) == 3:
2361
- _, width, height = shape
2362
- else:
2363
- raise ValueError
2364
- return rasterio.transform.from_bounds(minx, miny, maxx, maxy, width, height)
2365
-
2366
-
2367
- def _get_shape_from_res(
2368
- bounds: tuple[float], res: int, indexes: int | tuple[int]
2369
- ) -> tuple[int] | None:
2370
- if res is None:
2371
- return None
2372
- if hasattr(res, "__iter__") and len(res) == 2:
2373
- res = res[0]
2374
- diffx = bounds[2] - bounds[0]
2375
- diffy = bounds[3] - bounds[1]
2376
- width = int(diffx / res)
2377
- height = int(diffy / res)
2378
- if not isinstance(indexes, int):
2379
- return len(indexes), width, height
2380
- return width, height
2381
-
2382
-
2383
- def _array_to_geojson(
2384
- array: np.ndarray, transform: Affine, processes: int
2385
- ) -> list[tuple]:
2386
- if np.ma.is_masked(array):
2387
- array = array.data
3082
+ def _get_dtype_min(dtype: str | type) -> int | float:
2388
3083
  try:
2389
- return _array_to_geojson_loop(array, transform, processes)
2390
-
3084
+ return np.iinfo(dtype).min
2391
3085
  except ValueError:
2392
- try:
2393
- array = array.astype(np.float32)
2394
- return _array_to_geojson_loop(array, transform, processes)
2395
-
2396
- except Exception as err:
2397
- raise err.__class__(array.shape, err) from err
3086
+ return np.finfo(dtype).min
2398
3087
 
2399
3088
 
2400
- def _array_to_geojson_loop(array, transform, processes):
2401
- if processes == 1:
2402
- return [
2403
- (value, shape(geom))
2404
- for geom, value in features.shapes(array, transform=transform, mask=None)
2405
- ]
2406
- else:
2407
- with joblib.Parallel(n_jobs=processes, backend="threading") as parallel:
2408
- return parallel(
2409
- joblib.delayed(_value_geom_pair)(value, geom)
2410
- for geom, value in features.shapes(
2411
- array, transform=transform, mask=None
2412
- )
2413
- )
3089
+ def _get_dtype_max(dtype: str | type) -> int | float:
3090
+ try:
3091
+ return np.iinfo(dtype).max
3092
+ except ValueError:
3093
+ return np.finfo(dtype).max
2414
3094
 
2415
3095
 
2416
- def _value_geom_pair(value, geom):
2417
- return (value, shape(geom))
3096
def _img_ndvi(img, **kwargs):
    """Call ``img.ndvi`` and wrap the result in a one-element Image.

    Args:
        img: Object exposing an ``ndvi`` method.
        **kwargs: Keyword arguments forwarded to ``img.ndvi``.

    Returns:
        An Image constructed from a list holding the ``ndvi`` result.
    """
    ndvi_band = img.ndvi(**kwargs)
    return Image([ndvi_band])
2418
3098
 
2419
3099
 
2420
3100
  def _intesects(x, other) -> bool:
@@ -2433,18 +3113,16 @@ def _copy_and_add_df_parallel(
2433
3113
  for img in copied.images:
2434
3114
  img._bands = [band for band in img if band.band_id in band_ids]
2435
3115
 
2436
- # for col in group.columns.difference({"_image_instance", "_image_idx"}):
2437
- # if not all(
2438
- # col in dir(band) or col in band.__dict__ for img in copied for band in img
2439
- # ):
2440
- # continue
2441
- # values = set(group[col].values)
2442
- # for img in copied.images:
2443
- # img._bands = [band for band in img if getattr(band, col) in values]
2444
-
2445
3116
  return (i, copied)
2446
3117
 
2447
3118
 
3119
+ def _get_single_value(values: tuple):
3120
+ if len(set(values)) == 1:
3121
+ return next(iter(values))
3122
+ else:
3123
+ return None
3124
+
3125
+
2448
3126
def _open_raster(path: str | Path) -> rasterio.io.DatasetReader:
    """Open the raster at ``path`` with rasterio.

    The path is opened with ``opener`` and the resulting file object is
    handed to ``rasterio.open``.

    Args:
        path: Location of the raster file.

    Returns:
        The object returned by ``rasterio.open``.
    """
    # NOTE(review): the dataset is returned from inside the ``with`` block, so
    # ``opener``'s context has exited by the time the caller uses it. Confirm
    # the dataset stays readable after that.
    with opener(path) as file:
        return rasterio.open(file)
@@ -2455,7 +3133,6 @@ def _load_band(band: Band, **kwargs) -> None:
2455
3133
 
2456
3134
 
2457
3135
def _merge_by_band(collection: ImageCollection, **kwargs) -> Image:
    """Call ``collection.merge_by_band`` with the given keyword arguments.

    Args:
        collection: The collection to merge.
        **kwargs: Keyword arguments forwarded to ``merge_by_band``.

    Returns:
        Whatever ``collection.merge_by_band`` returns.
    """
    merged = collection.merge_by_band(**kwargs)
    return merged
2460
3137
 
2461
3138
 
@@ -2470,25 +3147,49 @@ def _zonal_one_pair(i: int, poly: Polygon, band: Band, aggfunc, array_func, func
2470
3147
  return _aggregate(clipped.values, array_func, aggfunc, func_names, band.date, i)
2471
3148
 
2472
3149
 
3150
def array_buffer(arr: np.ndarray, distance: int) -> np.ndarray:
    """Buffer array points with the value 1 in a binary array.

    A positive distance grows (dilates) the areas of 1s, a negative distance
    shrinks (erodes) them, and a distance of 0 returns an unchanged copy.

    Args:
        arr: Array of 0s and 1s, or a boolean array.
        distance: Number of array cells to buffer by. Negative values erode.

    Returns:
        Array with buffered values, in the same dtype as the input.

    Raises:
        ValueError: If the array holds values other than 0/1/True/False.
    """
    if not np.all(np.isin(arr, (0, 1))):
        raise ValueError("Array must be all 0s and 1s or boolean.")

    dtype = arr.dtype
    binary = np.where(arr, 1, 0)

    if distance == 0:
        # Handled explicitly so the function never returns None.
        return binary.astype(dtype)

    # Square neighbourhood reaching `distance` cells in every direction.
    size = 2 * abs(distance) + 1
    structure = np.ones((size, size), dtype=bool)

    operation = binary_dilation if distance > 0 else binary_erosion
    return operation(binary, structure=structure).astype(dtype)
3174
+
3175
+
2473
3176
class Sentinel2Config:
    """Holder of Sentinel 2 regexes, band_ids etc."""

    # Tuple with the regex used for image names.
    image_regexes: ClassVar[tuple[str, ...]] = (config.SENTINEL2_IMAGE_REGEX,)
    # Tuple with the regexes used for file names (regular and cloud files).
    filename_regexes: ClassVar[tuple[str, ...]] = (
        config.SENTINEL2_FILENAME_REGEX,
        config.SENTINEL2_CLOUD_FILENAME_REGEX,
    )
    # Band ids are taken from the package config so they are defined in one place.
    all_bands: ClassVar[list[str]] = list(config.SENTINEL2_BANDS)
    rbg_bands: ClassVar[list[str]] = config.SENTINEL2_RBG_BANDS
    # Indexed as [0] = red band and [1] = near-infrared band by Sentinel2Image.ndvi.
    ndvi_bands: ClassVar[list[str]] = config.SENTINEL2_NDVI_BANDS
    l2a_bands: ClassVar[dict[str, int]] = config.SENTINEL2_L2A_BANDS
    l1c_bands: ClassVar[dict[str, int]] = config.SENTINEL2_L1C_BANDS
    date_format: ClassVar[str] = "%Y%m%d"  # T%H%M%S"
    # Band id and pixel values used for masking.
    # NOTE(review): presumably the SCL classes for clouds/shadow - confirm.
    masking: ClassVar[BandMasking] = BandMasking(
        band_id="SCL", values=(3, 8, 9, 10, 11)
    )
2492
3193
 
2493
3194
 
2494
3195
  class Sentinel2CloudlessConfig(Sentinel2Config):
@@ -2496,9 +3197,17 @@ class Sentinel2CloudlessConfig(Sentinel2Config):
2496
3197
 
2497
3198
  image_regexes: ClassVar[str] = (config.SENTINEL2_MOSAIC_IMAGE_REGEX,)
2498
3199
  filename_regexes: ClassVar[str] = (config.SENTINEL2_MOSAIC_FILENAME_REGEX,)
2499
- cloud_band: ClassVar[None] = None
2500
- cloud_values: ClassVar[None] = None
3200
+ masking: ClassVar[None] = None
2501
3201
  date_format: ClassVar[str] = "%Y%m%d"
3202
+ all_bands: ClassVar[list[str]] = [
3203
+ x.replace("B0", "B") for x in Sentinel2Config.all_bands
3204
+ ]
3205
+ rbg_bands: ClassVar[list[str]] = [
3206
+ x.replace("B0", "B") for x in Sentinel2Config.rbg_bands
3207
+ ]
3208
+ ndvi_bands: ClassVar[list[str]] = [
3209
+ x.replace("B0", "B") for x in Sentinel2Config.ndvi_bands
3210
+ ]
2502
3211
 
2503
3212
 
2504
3213
  class Sentinel2Band(Sentinel2Config, Band):
@@ -2511,13 +3220,14 @@ class Sentinel2Image(Sentinel2Config, Image):
2511
3220
  cloud_cover_regexes: ClassVar[tuple[str]] = config.CLOUD_COVERAGE_REGEXES
2512
3221
  band_class: ClassVar[Sentinel2Band] = Sentinel2Band
2513
3222
 
2514
- def get_ndvi(
3223
+ def ndvi(
2515
3224
  self,
2516
3225
  red_band: str = Sentinel2Config.ndvi_bands[0],
2517
3226
  nir_band: str = Sentinel2Config.ndvi_bands[1],
3227
+ copy: bool = True,
2518
3228
  ) -> NDVIBand:
2519
3229
  """Calculate the NDVI for the Image."""
2520
- return super().get_ndvi(red_band=red_band, nir_band=nir_band)
3230
+ return super().ndvi(red_band=red_band, nir_band=nir_band, copy=copy)
2521
3231
 
2522
3232
 
2523
3233
  class Sentinel2Collection(Sentinel2Config, ImageCollection):
@@ -2534,27 +3244,14 @@ class Sentinel2CloudlessBand(Sentinel2CloudlessConfig, Band):
2534
3244
class Sentinel2CloudlessImage(Sentinel2CloudlessConfig, Sentinel2Image):
    """Image for cloudless mosaic with Sentinel2 specific name variables and regexes."""

    # No cloud cover regexes are defined for this class.
    cloud_cover_regexes: ClassVar[None] = None
    band_class: ClassVar[Sentinel2CloudlessBand] = Sentinel2CloudlessBand

    # NOTE(review): this reuses Sentinel2Image.ndvi, whose default band ids are
    # read from Sentinel2Config.ndvi_bands. Sentinel2CloudlessConfig.ndvi_bands
    # replaces "B0" with "B" in those ids - confirm the defaults resolve for
    # cloudless images. Sentinel2Image is also a base class here, so the alias
    # looks redundant - verify before removing.
    ndvi = Sentinel2Image.ndvi
2551
3251
 
2552
3252
 
2553
3253
class Sentinel2CloudlessCollection(Sentinel2CloudlessConfig, ImageCollection):
    """ImageCollection with Sentinel2 specific name variables and regexes."""

    image_class: ClassVar[Sentinel2CloudlessImage] = Sentinel2CloudlessImage
    # Annotated with the class that is actually assigned (the cloudless band).
    band_class: ClassVar[Sentinel2CloudlessBand] = Sentinel2CloudlessBand