rastr 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rastr might be problematic; see the package's registry page for more details.

rastr/raster.py CHANGED
@@ -4,17 +4,20 @@ from __future__ import annotations
4
4
 
5
5
  import importlib.util
6
6
  import warnings
7
+ from collections.abc import Collection
7
8
  from contextlib import contextmanager
8
9
  from pathlib import Path
9
10
  from typing import TYPE_CHECKING, Any, Literal, overload
10
11
 
11
12
  import numpy as np
12
13
  import numpy.ma
14
+ import rasterio.features
13
15
  import rasterio.plot
14
16
  import rasterio.sample
15
17
  import rasterio.transform
16
18
  import skimage.measure
17
19
  from pydantic import BaseModel, InstanceOf, field_validator
20
+ from pyproj import Transformer
18
21
  from pyproj.crs.crs import CRS
19
22
  from rasterio.enums import Resampling
20
23
  from rasterio.io import MemoryFile
@@ -29,10 +32,14 @@ if TYPE_CHECKING:
29
32
  from collections.abc import Callable, Generator
30
33
 
31
34
  import geopandas as gpd
35
+ from affine import Affine
36
+ from branca.colormap import LinearColormap as BrancaLinearColormap
32
37
  from folium import Map
33
38
  from matplotlib.axes import Axes
39
+ from matplotlib.image import AxesImage
34
40
  from numpy.typing import ArrayLike, NDArray
35
41
  from rasterio.io import BufferedDatasetWriter, DatasetReader, DatasetWriter
42
+ from shapely import MultiPolygon
36
43
  from typing_extensions import Self
37
44
 
38
45
  try:
@@ -45,17 +52,28 @@ FOLIUM_INSTALLED = importlib.util.find_spec("folium") is not None
45
52
  BRANCA_INSTALLED = importlib.util.find_spec("branca") is not None
46
53
  MATPLOTLIB_INSTALLED = importlib.util.find_spec("matplotlib") is not None
47
54
 
55
+ CONTOUR_PERTURB_EPS = 1e-10
56
+
48
57
 
49
58
  class RasterCellArrayShapeError(ValueError):
50
59
  """Custom error for invalid raster cell array shapes."""
51
60
 
52
61
 
53
- class RasterModel(BaseModel):
62
+ class Raster(BaseModel):
54
63
  """2-dimensional raster and metadata."""
55
64
 
56
65
  arr: InstanceOf[np.ndarray]
57
66
  raster_meta: RasterMeta
58
67
 
68
+ @field_validator("arr")
69
+ @classmethod
70
+ def check_2d_array(cls, v: NDArray) -> NDArray:
71
+ """Validator to ensure the cell array is 2D."""
72
+ if v.ndim != 2:
73
+ msg = "Cell array must be 2D"
74
+ raise RasterCellArrayShapeError(msg)
75
+ return v
76
+
59
77
  @property
60
78
  def meta(self) -> RasterMeta:
61
79
  """Alias for raster_meta."""
@@ -80,6 +98,26 @@ class RasterModel(BaseModel):
80
98
  """Set the CRS via meta."""
81
99
  self.meta.crs = value
82
100
 
101
+ @property
102
+ def transform(self) -> Affine:
103
+ """Convenience property to access the transform via meta."""
104
+ return self.meta.transform
105
+
106
+ @transform.setter
107
+ def transform(self, value: Affine) -> None:
108
+ """Set the transform via meta."""
109
+ self.meta.transform = value
110
+
111
+ @property
112
+ def cell_size(self) -> float:
113
+ """Convenience property to access the cell size via meta."""
114
+ return self.meta.cell_size
115
+
116
+ @cell_size.setter
117
+ def cell_size(self, value: float) -> None:
118
+ """Set the cell size via meta."""
119
+ self.meta.cell_size = value
120
+
83
121
  def __init__(
84
122
  self,
85
123
  *,
@@ -108,14 +146,25 @@ class RasterModel(BaseModel):
108
146
  super().__init__(arr=arr, raster_meta=raster_meta)
109
147
 
110
148
  def __eq__(self, other: object) -> bool:
111
- """Check equality of two RasterModel objects."""
112
- if not isinstance(other, RasterModel):
149
+ """Check equality of two Raster objects."""
150
+ if not isinstance(other, Raster):
113
151
  return NotImplemented
114
152
  return (
115
153
  np.array_equal(self.arr, other.arr)
116
154
  and self.raster_meta == other.raster_meta
117
155
  )
118
156
 
157
+ def is_like(self, other: Raster) -> bool:
158
+ """Check if two Raster objects have the same metadata and shape.
159
+
160
+ Args:
161
+ other: Another Raster to compare with.
162
+
163
+ Returns:
164
+ True if both rasters have the same meta and shape attributes.
165
+ """
166
+ return self.meta == other.meta and self.shape == other.shape
167
+
119
168
  __hash__ = BaseModel.__hash__
120
169
 
121
170
  def __add__(self, other: float | Self) -> Self:
@@ -123,7 +172,7 @@ class RasterModel(BaseModel):
123
172
  if isinstance(other, float | int):
124
173
  new_arr = self.arr + other
125
174
  return cls(arr=new_arr, raster_meta=self.raster_meta)
126
- elif isinstance(other, RasterModel):
175
+ elif isinstance(other, Raster):
127
176
  if self.raster_meta != other.raster_meta:
128
177
  msg = (
129
178
  "Rasters must have the same metadata (e.g. CRS, cell size, etc.) "
@@ -149,7 +198,7 @@ class RasterModel(BaseModel):
149
198
  if isinstance(other, float | int):
150
199
  new_arr = self.arr * other
151
200
  return cls(arr=new_arr, raster_meta=self.raster_meta)
152
- elif isinstance(other, RasterModel):
201
+ elif isinstance(other, Raster):
153
202
  if self.raster_meta != other.raster_meta:
154
203
  msg = (
155
204
  "Rasters must have the same metadata (e.g. CRS, cell size, etc.) "
@@ -172,7 +221,7 @@ class RasterModel(BaseModel):
172
221
  if isinstance(other, float | int):
173
222
  new_arr = self.arr / other
174
223
  return cls(arr=new_arr, raster_meta=self.raster_meta)
175
- elif isinstance(other, RasterModel):
224
+ elif isinstance(other, Raster):
176
225
  if self.raster_meta != other.raster_meta:
177
226
  msg = (
178
227
  "Rasters must have the same metadata (e.g. CRS, cell size, etc.) "
@@ -222,7 +271,7 @@ class RasterModel(BaseModel):
222
271
  """Create a rasterio in-memory dataset from the Raster object.
223
272
 
224
273
  Example:
225
- >>> raster = RasterModel.example()
274
+ >>> raster = Raster.example()
226
275
  >>> with raster.to_rasterio_dataset() as dataset:
227
276
  >>> ...
228
277
  """
@@ -248,12 +297,30 @@ class RasterModel(BaseModel):
248
297
  finally:
249
298
  memfile.close()
250
299
 
300
+ @overload
301
+ def sample(
302
+ self,
303
+ xy: Collection[tuple[float, float]] | Collection[Point] | ArrayLike,
304
+ *,
305
+ na_action: Literal["raise", "ignore"] = "raise",
306
+ ) -> NDArray: ...
307
+ @overload
251
308
  def sample(
252
309
  self,
253
- xy: list[tuple[float, float]] | list[Point] | ArrayLike,
310
+ xy: tuple[float, float] | Point,
254
311
  *,
255
312
  na_action: Literal["raise", "ignore"] = "raise",
256
- ) -> NDArray[np.float64]:
313
+ ) -> float: ...
314
+ def sample(
315
+ self,
316
+ xy: Collection[tuple[float, float]]
317
+ | Collection[Point]
318
+ | ArrayLike
319
+ | tuple[float, float]
320
+ | Point,
321
+ *,
322
+ na_action: Literal["raise", "ignore"] = "raise",
323
+ ) -> NDArray | float:
257
324
  """Sample raster values at GeoSeries locations and return sampled values.
258
325
 
259
326
  Args:
@@ -270,13 +337,30 @@ class RasterModel(BaseModel):
270
337
  # https://rdrn.me/optimising-sampling/
271
338
 
272
339
  # Convert shapely Points to coordinate tuples if needed
273
- if isinstance(xy, (list, tuple)):
274
- xy = [_get_xy_tuple(point) for point in xy]
340
+ if isinstance(xy, Point):
341
+ xy = [(xy.x, xy.y)]
342
+ singleton = True
343
+ elif (
344
+ isinstance(xy, Collection)
345
+ and len(xy) > 0
346
+ and isinstance(next(iter(xy)), Point)
347
+ ):
348
+ xy = [(point.x, point.y) for point in xy] # pyright: ignore[reportAttributeAccessIssue]
349
+ singleton = False
350
+ elif (
351
+ isinstance(xy, tuple)
352
+ and len(xy) == 2
353
+ and isinstance(next(iter(xy)), (float, int))
354
+ ):
355
+ xy = [xy] # pyright: ignore[reportAssignmentType]
356
+ singleton = True
357
+ else:
358
+ singleton = False
275
359
 
276
360
  xy = np.asarray(xy, dtype=float)
277
361
 
278
- # Short-circuit
279
362
  if len(xy) == 0:
363
+ # Short-circuit
280
364
  return np.array([], dtype=float)
281
365
 
282
366
  # Create in-memory rasterio dataset from the incumbent Raster object
@@ -326,6 +410,10 @@ class RasterModel(BaseModel):
326
410
  axis=0,
327
411
  )
328
412
 
413
+ if singleton:
414
+ (raster_value,) = raster_values
415
+ return raster_value
416
+
329
417
  return raster_values
330
418
 
331
419
  @property
@@ -354,13 +442,16 @@ class RasterModel(BaseModel):
354
442
  ]
355
443
  )
356
444
 
357
- def explore(
445
+ def explore( # noqa: PLR0913 c.f. geopandas.explore which also has many input args
358
446
  self,
359
447
  *,
360
448
  m: Map | None = None,
361
449
  opacity: float = 1.0,
362
- colormap: str = "viridis",
450
+ colormap: str
451
+ | Callable[[float], tuple[float, float, float, float]] = "viridis",
363
452
  cbar_label: str | None = None,
453
+ vmin: float | None = None,
454
+ vmax: float | None = None,
364
455
  ) -> Map:
365
456
  """Display the raster on a folium map."""
366
457
  if not FOLIUM_INSTALLED or not MATPLOTLIB_INSTALLED:
@@ -368,39 +459,44 @@ class RasterModel(BaseModel):
368
459
  raise ImportError(msg)
369
460
 
370
461
  import folium.raster_layers
371
- import geopandas as gpd
372
462
  import matplotlib as mpl
373
463
 
374
464
  if m is None:
375
465
  m = folium.Map()
376
466
 
377
- rgba_map: Callable[[float], tuple[float, float, float, float]] = mpl.colormaps[
378
- colormap
379
- ]
467
+ if vmin is not None and vmax is not None and vmax <= vmin:
468
+ msg = "'vmin' must be less than 'vmax'."
469
+ raise ValueError(msg)
470
+
471
+ if isinstance(colormap, str):
472
+ colormap = mpl.colormaps[colormap]
380
473
 
381
- # Cast to GDF to facilitate converting bounds to WGS84
474
+ # Transform bounds to WGS84 using pyproj directly
382
475
  wgs84_crs = CRS.from_epsg(4326)
383
- gdf = gpd.GeoDataFrame(geometry=[self.bbox], crs=self.raster_meta.crs).to_crs(
384
- wgs84_crs
476
+ transformer = Transformer.from_crs(
477
+ self.raster_meta.crs, wgs84_crs, always_xy=True
385
478
  )
386
- xmin, ymin, xmax, ymax = gdf.total_bounds
387
479
 
388
- arr = np.array(self.arr)
480
+ # Get the corner points of the bounding box
481
+ raster_xmin, raster_ymin, raster_xmax, raster_ymax = self.bounds
482
+ corner_points = [
483
+ (raster_xmin, raster_ymin),
484
+ (raster_xmin, raster_ymax),
485
+ (raster_xmax, raster_ymax),
486
+ (raster_xmax, raster_ymin),
487
+ ]
389
488
 
390
- # Normalize the data to the range [0, 1] as this is the cmap range
391
- with warnings.catch_warnings():
392
- warnings.filterwarnings(
393
- "ignore",
394
- message="All-NaN slice encountered",
395
- category=RuntimeWarning,
396
- )
397
- min_val = np.nanmin(arr)
398
- max_val = np.nanmax(arr)
489
+ # Transform all corner points to WGS84
490
+ transformed_points = [transformer.transform(x, y) for x, y in corner_points]
399
491
 
400
- if max_val > min_val: # Prevent division by zero
401
- arr = (arr - min_val) / (max_val - min_val)
402
- else:
403
- arr = np.zeros_like(arr) # In case all values are the same
492
+ # Find the bounding box of the transformed points
493
+ transformed_xs, transformed_ys = zip(*transformed_points, strict=True)
494
+ xmin, xmax = min(transformed_xs), max(transformed_xs)
495
+ ymin, ymax = min(transformed_ys), max(transformed_ys)
496
+
497
+ # Normalize the array to [0, 1] for colormap mapping
498
+ _vmin, _vmax = _get_vmin_vmax(self, vmin=vmin, vmax=vmax)
499
+ arr = self.normalize(vmin=_vmin, vmax=_vmax).arr
404
500
 
405
501
  # Finally, need to determine whether to flip the image based on negative Affine
406
502
  # coefficients
@@ -411,11 +507,12 @@ class RasterModel(BaseModel):
411
507
  if flip_y:
412
508
  arr = np.flip(arr, axis=0)
413
509
 
510
+ bounds = [[ymin, xmin], [ymax, xmax]]
414
511
  img = folium.raster_layers.ImageOverlay(
415
512
  image=arr,
416
- bounds=[[ymin, xmin], [ymax, xmax]],
513
+ bounds=bounds,
417
514
  opacity=opacity,
418
- colormap=rgba_map,
515
+ colormap=colormap,
419
516
  mercator_project=True,
420
517
  )
421
518
 
@@ -423,26 +520,39 @@ class RasterModel(BaseModel):
423
520
 
424
521
  # Add a colorbar legend
425
522
  if BRANCA_INSTALLED:
426
- from branca.colormap import LinearColormap as BrancaLinearColormap
427
- from matplotlib.colors import to_hex
428
-
429
- # Determine legend data range in original units
430
- vmin = float(min_val) if np.isfinite(min_val) else 0.0
431
- vmax = float(max_val) if np.isfinite(max_val) else 1.0
432
- if vmax <= vmin:
433
- vmax = vmin + 1.0
434
-
435
- sample_points = np.linspace(0, 1, rgba_map.N)
436
- colors = [to_hex(rgba_map(x)) for x in sample_points]
437
- legend = BrancaLinearColormap(colors=colors, vmin=vmin, vmax=vmax)
523
+ cbar = _map_colorbar(colormap=colormap, vmin=_vmin, vmax=_vmax)
438
524
  if cbar_label:
439
- legend.caption = cbar_label
440
- legend.add_to(m)
525
+ cbar.caption = cbar_label
526
+ cbar.add_to(m)
441
527
 
442
- m.fit_bounds([[ymin, xmin], [ymax, xmax]])
528
+ m.fit_bounds(bounds)
443
529
 
444
530
  return m
445
531
 
532
+ def normalize(
533
+ self, *, vmin: float | None = None, vmax: float | None = None
534
+ ) -> Self:
535
+ """Normalize the raster values to the range [0, 1].
536
+
537
+ If custom vmin and vmax values are provided, values below vmin will be set to 0,
538
+ and values above vmax will be set to 1.
539
+
540
+ Args:
541
+ vmin: Minimum value for normalization. Values below this will be set to 0.
542
+ If None, the minimum value in the array is used.
543
+ vmax: Maximum value for normalization. Values above this will be set to 1.
544
+ If None, the maximum value in the array is used.
545
+ """
546
+ _vmin, _vmax = _get_vmin_vmax(self, vmin=vmin, vmax=vmax)
547
+
548
+ arr = self.arr.copy()
549
+ if _vmax > _vmin:
550
+ arr = (arr - _vmin) / (_vmax - _vmin)
551
+ arr = np.clip(arr, 0, 1)
552
+ else:
553
+ arr = np.zeros_like(arr)
554
+ return self.__class__(arr=arr, raster_meta=self.raster_meta)
555
+
446
556
  def to_clipboard(self) -> None:
447
557
  """Copy the raster cell array to the clipboard."""
448
558
  import pandas as pd
@@ -456,8 +566,22 @@ class RasterModel(BaseModel):
456
566
  cbar_label: str | None = None,
457
567
  basemap: bool = False,
458
568
  cmap: str = "viridis",
569
+ suppressed: Collection[float] | float = tuple(),
570
+ **kwargs: Any,
459
571
  ) -> Axes:
460
- """Plot the raster on a matplotlib axis."""
572
+ """Plot the raster on a matplotlib axis.
573
+
574
+ Args:
575
+ ax: A matplotlib axes object to plot on. If None, a new figure will be
576
+ created.
577
+ cbar_label: Label for the colorbar. If None, no label is added.
578
+ basemap: Whether to add a basemap. Currently not implemented.
579
+ cmap: Colormap to use for the plot.
580
+ suppressed: Values to suppress from the plot (i.e. not display). This can be
581
+ useful for zeroes especially.
582
+ **kwargs: Additional keyword arguments to pass to `rasterio.plot.show()`.
583
+ This includes parameters like `alpha` for transparency.
584
+ """
461
585
  if not MATPLOTLIB_INSTALLED:
462
586
  msg = "The 'matplotlib' package is required for 'plot()'."
463
587
  raise ImportError(msg)
@@ -465,6 +589,8 @@ class RasterModel(BaseModel):
465
589
  from matplotlib import pyplot as plt
466
590
  from mpl_toolkits.axes_grid1 import make_axes_locatable
467
591
 
592
+ suppressed = np.array(suppressed)
593
+
468
594
  if ax is None:
469
595
  _, _ax = plt.subplots()
470
596
  _ax: Axes
@@ -474,33 +600,34 @@ class RasterModel(BaseModel):
474
600
  msg = "Basemap plotting is not yet implemented."
475
601
  raise NotImplementedError(msg)
476
602
 
477
- arr = self.arr.copy()
603
+ model = self.model_copy()
604
+ model.arr = model.arr.copy()
478
605
 
479
- # Get extent of the non-zero values in array index coordinates
480
- (x_nonzero,) = np.nonzero(arr.any(axis=0))
481
- (y_nonzero,) = np.nonzero(arr.any(axis=1))
606
+ # Get extent of the unsuppressed values in array index coordinates
607
+ suppressed_mask = np.isin(model.arr, suppressed)
608
+ (x_unsuppressed,) = np.nonzero((~suppressed_mask).any(axis=0))
609
+ (y_unsuppressed,) = np.nonzero((~suppressed_mask).any(axis=1))
482
610
 
483
- if len(x_nonzero) == 0 or len(y_nonzero) == 0:
484
- msg = "Raster contains no non-zero values; cannot plot."
611
+ if len(x_unsuppressed) == 0 or len(y_unsuppressed) == 0:
612
+ msg = "Raster contains no unsuppressed values; cannot plot."
485
613
  raise ValueError(msg)
486
614
 
487
- min_x_nonzero = np.min(x_nonzero)
488
- max_x_nonzero = np.max(x_nonzero)
489
- min_y_nonzero = np.min(y_nonzero)
490
- max_y_nonzero = np.max(y_nonzero)
615
+ # N.B. these are array index coordinates, so np.min and np.max are safe since
616
+ # they cannot encounter NaN values.
617
+ min_x_unsuppressed = np.min(x_unsuppressed)
618
+ max_x_unsuppressed = np.max(x_unsuppressed)
619
+ min_y_unsuppressed = np.min(y_unsuppressed)
620
+ max_y_unsuppressed = np.max(y_unsuppressed)
491
621
 
492
622
  # Transform to raster CRS
493
- x1, y1 = self.raster_meta.transform * (min_x_nonzero, min_y_nonzero) # type: ignore[reportAssignmentType] overloaded tuple size in affine
494
- x2, y2 = self.raster_meta.transform * (max_x_nonzero, max_y_nonzero) # type: ignore[reportAssignmentType]
623
+ x1, y1 = self.raster_meta.transform * (min_x_unsuppressed, min_y_unsuppressed) # type: ignore[reportAssignmentType] overloaded tuple size in affine
624
+ x2, y2 = self.raster_meta.transform * (max_x_unsuppressed, max_y_unsuppressed) # type: ignore[reportAssignmentType]
495
625
  xmin, xmax = sorted([x1, x2])
496
626
  ymin, ymax = sorted([y1, y2])
497
627
 
498
- arr[arr == 0] = np.nan
628
+ model.arr[suppressed_mask] = np.nan
499
629
 
500
- with self.to_rasterio_dataset() as dataset:
501
- img, *_ = rasterio.plot.show(
502
- dataset, with_bounds=True, ax=ax, cmap=cmap
503
- ).get_images()
630
+ img, *_ = model.rio_show(ax=ax, cmap=cmap, with_bounds=True, **kwargs)
504
631
 
505
632
  ax.set_xlim(xmin, xmax)
506
633
  ax.set_ylim(ymin, ymax)
@@ -516,6 +643,20 @@ class RasterModel(BaseModel):
516
643
  fig.colorbar(img, label=cbar_label, cax=cax)
517
644
  return ax
518
645
 
646
+ def rio_show(self, **kwargs: Any) -> list[AxesImage]:
647
+ """Plot the raster using rasterio's built-in plotting function.
648
+
649
+ This is useful for lower-level access to rasterio's plotting capabilities.
650
+ Generally, the `plot()` method is preferred for most use cases.
651
+
652
+ Args:
653
+ **kwargs: Keyword arguments to pass to `rasterio.plot.show()`. This includes
654
+ parameters like `alpha` for transparency, and `with_bounds` to control
655
+ whether to plot in spatial coordinates or array index coordinates.
656
+ """
657
+ with self.to_rasterio_dataset() as dataset:
658
+ return rasterio.plot.show(dataset, **kwargs).get_images()
659
+
519
660
  def as_geodataframe(self, name: str = "value") -> gpd.GeoDataFrame:
520
661
  """Create a GeoDataFrame representation of the raster."""
521
662
  import geopandas as gpd
@@ -532,8 +673,15 @@ class RasterModel(BaseModel):
532
673
 
533
674
  return raster_gdf
534
675
 
535
- def to_file(self, path: Path | str) -> None:
536
- """Write the raster to a GeoTIFF file."""
676
+ def to_file(self, path: Path | str, **kwargs: Any) -> None:
677
+ """Write the raster to a GeoTIFF file.
678
+
679
+ Args:
680
+ path: Path to output file.
681
+ **kwargs: Additional keyword arguments to pass to `rasterio.open()`. If
682
+ `nodata` is provided, NaN values in the raster will be replaced
683
+ with the nodata value.
684
+ """
537
685
 
538
686
  path = Path(path)
539
687
 
@@ -548,6 +696,15 @@ class RasterModel(BaseModel):
548
696
  msg = f"Unsupported file extension: {suffix}"
549
697
  raise ValueError(msg)
550
698
 
699
+ # Handle nodata: use provided value or default to np.nan
700
+ if "nodata" in kwargs:
701
+ # Replace NaN values with the nodata value
702
+ nodata_value = kwargs.pop("nodata")
703
+ arr_to_write = np.where(np.isnan(self.arr), nodata_value, self.arr)
704
+ else:
705
+ nodata_value = np.nan
706
+ arr_to_write = self.arr
707
+
551
708
  with rasterio.open(
552
709
  path,
553
710
  "w",
@@ -558,10 +715,11 @@ class RasterModel(BaseModel):
558
715
  dtype=self.arr.dtype,
559
716
  crs=self.raster_meta.crs,
560
717
  transform=self.raster_meta.transform,
561
- nodata=np.nan,
718
+ nodata=nodata_value,
719
+ **kwargs,
562
720
  ) as dst:
563
721
  try:
564
- dst.write(self.arr, 1)
722
+ dst.write(arr_to_write, 1)
565
723
  except CPLE_BaseError as err:
566
724
  msg = f"Failed to write raster to file: {err}"
567
725
  raise OSError(msg) from err
@@ -576,7 +734,7 @@ class RasterModel(BaseModel):
576
734
 
577
735
  @classmethod
578
736
  def example(cls) -> Self:
579
- """Create an example RasterModel."""
737
+ """Create an example Raster."""
580
738
  # Peaks dataset style example
581
739
  n = 256
582
740
  x = np.linspace(-3, 3, n)
@@ -588,6 +746,34 @@ class RasterModel(BaseModel):
588
746
  raster_meta = RasterMeta.example()
589
747
  return cls(arr=arr, raster_meta=raster_meta)
590
748
 
749
+ @classmethod
750
+ def full_like(cls, other: Raster, *, fill_value: float) -> Self:
751
+ """Create a raster with the same metadata as another but filled with a constant.
752
+
753
+ Args:
754
+ other: The raster to copy metadata from.
755
+ fill_value: The constant value to fill all cells with.
756
+
757
+ Returns:
758
+ A new raster with the same shape and metadata as `other`, but with all cells
759
+ set to `fill_value`.
760
+ """
761
+ arr = np.full(other.shape, fill_value, dtype=np.float32)
762
+ return cls(arr=arr, raster_meta=other.raster_meta)
763
+
764
+ @classmethod
765
+ def read_file(cls, filename: Path | str, crs: CRS | str | None = None) -> Self:
766
+ """Read raster data from a file and return an in-memory Raster object.
767
+
768
+ Args:
769
+ filename: Path to the raster file.
770
+ crs: Optional coordinate reference system to override the file's CRS.
771
+ """
772
+ # Import here to avoid circular import (rastr.io imports Raster)
773
+ from rastr.io import read_raster_inmem # noqa: PLC0415
774
+
775
+ return read_raster_inmem(filename, crs=crs, cls=cls)
776
+
591
777
  @overload
592
778
  def apply(
593
779
  self,
@@ -625,6 +811,59 @@ class RasterModel(BaseModel):
625
811
  new_raster.arr = np.asarray(new_arr)
626
812
  return new_raster
627
813
 
814
+ def max(self) -> float:
815
+ """Get the maximum value in the raster, ignoring NaN values.
816
+
817
+ Returns:
818
+ The maximum value in the raster. Returns NaN if all values are NaN.
819
+ """
820
+ return float(np.nanmax(self.arr))
821
+
822
+ def min(self) -> float:
823
+ """Get the minimum value in the raster, ignoring NaN values.
824
+
825
+ Returns:
826
+ The minimum value in the raster. Returns NaN if all values are NaN.
827
+ """
828
+ return float(np.nanmin(self.arr))
829
+
830
+ def mean(self) -> float:
831
+ """Get the mean value in the raster, ignoring NaN values.
832
+
833
+ Returns:
834
+ The mean value in the raster. Returns NaN if all values are NaN.
835
+ """
836
+ return float(np.nanmean(self.arr))
837
+
838
+ def std(self) -> float:
839
+ """Get the standard deviation of values in the raster, ignoring NaN values.
840
+
841
+ Returns:
842
+ The standard deviation of the raster. Returns NaN if all values are NaN.
843
+ """
844
+ return float(np.nanstd(self.arr))
845
+
846
+ def quantile(self, q: float) -> float:
847
+ """Get the specified quantile value in the raster, ignoring NaN values.
848
+
849
+ Args:
850
+ q: Quantile to compute, must be between 0 and 1 inclusive.
851
+
852
+ Returns:
853
+ The quantile value. Returns NaN if all values are NaN.
854
+ """
855
+ return float(np.nanquantile(self.arr, q))
856
+
857
+ def median(self) -> float:
858
+ """Get the median value in the raster, ignoring NaN values.
859
+
860
+ This is equivalent to quantile(0.5).
861
+
862
+ Returns:
863
+ The median value in the raster. Returns NaN if all values are NaN.
864
+ """
865
+ return float(np.nanmedian(self.arr))
866
+
628
867
  def fillna(self, value: float) -> Self:
629
868
  """Fill NaN values in the raster with a specified value.
630
869
 
@@ -635,6 +874,16 @@ class RasterModel(BaseModel):
635
874
  new_raster.arr = filled_arr
636
875
  return new_raster
637
876
 
877
+ def copy(self) -> Self: # type: ignore[override]
878
+ """Create a copy of the raster.
879
+
880
+ This method wraps `model_copy()` for convenience.
881
+
882
+ Returns:
883
+ A new Raster instance.
884
+ """
885
+ return self.model_copy(deep=True)
886
+
638
887
  def get_xy(self) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
639
888
  """Get the x and y coordinates of the raster cell centres in meshgrid format.
640
889
 
@@ -651,7 +900,7 @@ class RasterModel(BaseModel):
651
900
  return coords[:, :, 0], coords[:, :, 1]
652
901
 
653
902
  def contour(
654
- self, levels: list[float] | NDArray, *, smoothing: bool = True
903
+ self, levels: Collection[float] | NDArray, *, smoothing: bool = True
655
904
  ) -> gpd.GeoDataFrame:
656
905
  """Create contour lines from the raster data, optionally with smoothing.
657
906
 
@@ -664,8 +913,8 @@ class RasterModel(BaseModel):
664
913
  contouring, to denoise the contours.
665
914
 
666
915
  Args:
667
- levels: A list or array of contour levels to generate. The contour lines
668
- will be generated for each level in this sequence.
916
+ levels: A collection or array of contour levels to generate. The contour
917
+ lines will be generated for each level in this sequence.
669
918
  smoothing: Defaults to true, which corresponds to applying a smoothing
670
919
  algorithm to the contour lines. At the moment, this is the
671
920
  Catmull-Rom spline algorithm. If set to False, the raw
@@ -676,9 +925,17 @@ class RasterModel(BaseModel):
676
925
  all_levels = []
677
926
  all_geoms = []
678
927
  for level in levels:
928
+ # If this is the maximum or minimum level, perturb it ever-so-slightly to
929
+ # ensure we get contours at the edges of the raster
930
+ perturbed_level = level
931
+ if level == self.max():
932
+ perturbed_level -= CONTOUR_PERTURB_EPS
933
+ elif level == self.min():
934
+ perturbed_level += CONTOUR_PERTURB_EPS
935
+
679
936
  contours = skimage.measure.find_contours(
680
937
  self.arr,
681
- level=level,
938
+ level=perturbed_level,
682
939
  )
683
940
 
684
941
  # Construct shapely LineString objects
@@ -713,19 +970,40 @@ class RasterModel(BaseModel):
713
970
  # Dissolve contours by level to merge all contour lines of the same level
714
971
  return contour_gdf.dissolve(by="level", as_index=False)
715
972
 
716
- def blur(self, sigma: float) -> Self:
973
+ def blur(self, sigma: float, *, preserve_nan: bool = True) -> Self:
717
974
  """Apply a Gaussian blur to the raster data.
718
975
 
719
976
  Args:
720
977
  sigma: Standard deviation for Gaussian kernel, in units of geographic
721
978
  coordinate distance (e.g. meters). A larger sigma results in a more
722
979
  blurred image.
980
+ preserve_nan: If True, applies NaN-safe blurring by extrapolating NaN values
981
+ before blurring and restoring them afterwards. This prevents
982
+ NaNs from spreading into valid data during the blur operation.
723
983
  """
724
984
  from scipy.ndimage import gaussian_filter
725
985
 
726
986
  cell_sigma = sigma / self.raster_meta.cell_size
727
987
 
728
- blurred_array = gaussian_filter(self.arr, sigma=cell_sigma)
988
+ if preserve_nan:
989
+ # Save the original NaN mask
990
+ nan_mask = np.isnan(self.arr)
991
+
992
+ # If there are no NaNs, just apply regular blur
993
+ if not np.any(nan_mask):
994
+ blurred_array = gaussian_filter(self.arr, sigma=cell_sigma)
995
+ else:
996
+ # Extrapolate to fill NaN values temporarily
997
+ extrapolated_arr = fillna_nearest_neighbours(arr=self.arr)
998
+
999
+ # Apply blur to the extrapolated array
1000
+ blurred_array = gaussian_filter(extrapolated_arr, sigma=cell_sigma)
1001
+
1002
+ # Restore original NaN values
1003
+ blurred_array = np.where(nan_mask, np.nan, blurred_array)
1004
+ else:
1005
+ blurred_array = gaussian_filter(self.arr, sigma=cell_sigma)
1006
+
729
1007
  new_raster = self.model_copy()
730
1008
  new_raster.arr = blurred_array
731
1009
  return new_raster
@@ -751,9 +1029,75 @@ class RasterModel(BaseModel):
751
1029
 
752
1030
  return raster
753
1031
 
1032
+ def pad(self, width: float, *, value: float = np.nan) -> Self:
1033
+ """Extend the raster by adding a constant fill value around the edges.
1034
+
1035
+ By default, the padding value is NaN, but this can be changed via the
1036
+ `value` parameter.
1037
+
1038
+ This grows the raster by adding padding around all edges. New cells are
1039
+ filled with the constant `value`.
1040
+
1041
+ If the width is not an exact multiple of the cell size, the padding may be
1042
+ slightly larger than the specified width, i.e. the value is rounded up to
1043
+ the nearest whole number of cells.
1044
+
1045
+ Args:
1046
+ width: The width of the padding, in the same units as the raster CRS
1047
+ (e.g. meters). This defines how far from the edge the padding
1048
+ extends.
1049
+ value: The constant value to use for padding. Default is NaN.
1050
+ """
1051
+ cell_size = self.raster_meta.cell_size
1052
+
1053
+ # Calculate number of cells to pad in each direction
1054
+ pad_cells = int(np.ceil(width / cell_size))
1055
+
1056
+ # Get current bounds
1057
+ xmin, ymin, xmax, ymax = self.bounds
1058
+
1059
+ # Calculate new bounds with padding
1060
+ new_xmin = xmin - (pad_cells * cell_size)
1061
+ new_ymin = ymin - (pad_cells * cell_size)
1062
+ new_xmax = xmax + (pad_cells * cell_size)
1063
+ new_ymax = ymax + (pad_cells * cell_size)
1064
+
1065
+ # Create padded array
1066
+ new_height = self.arr.shape[0] + 2 * pad_cells
1067
+ new_width = self.arr.shape[1] + 2 * pad_cells
1068
+
1069
+ # Create new array filled with the padding value
1070
+ padded_arr = np.full((new_height, new_width), value, dtype=self.arr.dtype)
1071
+
1072
+ # Copy original array into the center of the padded array
1073
+ padded_arr[
1074
+ pad_cells : pad_cells + self.arr.shape[0],
1075
+ pad_cells : pad_cells + self.arr.shape[1],
1076
+ ] = self.arr
1077
+
1078
+ # Create new transform for the padded raster
1079
+ new_transform = rasterio.transform.from_bounds(
1080
+ west=new_xmin,
1081
+ south=new_ymin,
1082
+ east=new_xmax,
1083
+ north=new_ymax,
1084
+ width=new_width,
1085
+ height=new_height,
1086
+ )
1087
+
1088
+ # Create new raster metadata
1089
+ new_meta = RasterMeta(
1090
+ cell_size=cell_size,
1091
+ crs=self.raster_meta.crs,
1092
+ transform=new_transform,
1093
+ )
1094
+
1095
+ return self.__class__(arr=padded_arr, raster_meta=new_meta)
1096
+
754
1097
  def crop(
755
1098
  self,
756
1099
  bounds: tuple[float, float, float, float],
1100
+ *,
757
1101
  strategy: Literal["underflow", "overflow"] = "underflow",
758
1102
  ) -> Self:
759
1103
  """Crop the raster to the specified bounds as (minx, miny, maxx, maxy).
@@ -767,7 +1111,7 @@ class RasterModel(BaseModel):
767
1111
  remains covered with cells.
768
1112
 
769
1113
  Returns:
770
- A new RasterModel instance cropped to the specified bounds.
1114
+ A new Raster instance cropped to the specified bounds.
771
1115
  """
772
1116
 
773
1117
  minx, miny, maxx, maxy = bounds
@@ -827,6 +1171,159 @@ class RasterModel(BaseModel):
827
1171
  )
828
1172
  return cls(arr=cropped_arr, raster_meta=new_meta)
829
1173
 
1174
def taper_border(self, width: float, *, limit: float = 0.0) -> Self:
    """Linearly ramp the raster's border cells towards ``limit``.

    By default the border is tapered to zero; pass ``limit`` to taper to a
    different value instead.

    The raster keeps its size: values in the border band are overwritten. To grow
    the raster first, call ``pad()`` and then ``taper_border()``.

    The ramp is measured between cell centres, so the outermost ring of cells is
    set exactly to ``limit`` and cells further in move progressively back towards
    their original values.

    Args:
        width: Width of the taper, in the units of the raster CRS (e.g. meters),
            i.e. how far in from the edge the tapering reaches. A width of zero
            (or less) leaves every value unchanged.
        limit: The value the border is tapered towards. Defaults to zero.

    Returns:
        A copy of the raster with the border band tapered. The original raster's
        array is not modified.
    """
    # Taper width expressed in (possibly fractional) cells.
    width_in_cells = width / self.raster_meta.cell_size
    n_rows, n_cols = self.arr.shape

    # Distance, in cells, from each cell centre to the nearest edge. It is the
    # smaller of the row-wise and column-wise distances, combined by broadcasting
    # a column vector against a row vector.
    row_idx = np.arange(n_rows)
    col_idx = np.arange(n_cols)
    edge_dist = np.minimum(
        np.minimum(row_idx, n_rows - 1 - row_idx)[:, np.newaxis],
        np.minimum(col_idx, n_cols - 1 - col_idx)[np.newaxis, :],
    )

    # Only cells nearer the edge than the rounded-up taper width are rewritten.
    in_border = edge_dist < np.ceil(width_in_cells)

    # Ramp factor: 0 on the outermost ring, increasing linearly inwards.
    ramp = np.clip(np.where(in_border, edge_dist, np.nan) / width_in_cells, 0.0, 1.0)
    border_values = np.where(in_border, self.arr, np.nan)
    blended = limit + (border_values - limit) * ramp

    result_arr = self.arr.copy()
    result_arr[in_border] = blended[in_border]

    result = self.model_copy()
    result.arr = result_arr
    return result
1224
+
1225
def clip(
    self,
    polygon: Polygon | MultiPolygon,
    *,
    strategy: Literal["centres"] = "centres",
) -> Self:
    """Clip the raster to the specified polygon, replacing cells outside with NaN.

    The clipping strategy determines how to handle cells that are partially
    within the polygon. Currently, only the 'centres' strategy is supported, which
    retains cells whose centres fall within the polygon.

    Args:
        polygon: A shapely Polygon or MultiPolygon defining the area to clip to.
        strategy: The clipping strategy to use. Currently only 'centres' is
            supported, which retains cells whose centres fall within the
            polygon.

    Returns:
        A new Raster with cells outside the polygon set to NaN. Because NaN is
        written into the array, an integer raster comes back as floating point.
        An empty polygon yields an all-NaN raster.

    Raises:
        NotImplementedError: If ``strategy`` is not ``"centres"``.
    """
    if strategy != "centres":
        msg = f"Unsupported clipping strategy: {strategy}"
        raise NotImplementedError(msg)

    if polygon.is_empty:
        # Nothing lies inside an empty geometry; short-circuit rather than hand
        # rasterio a shape it may reject.
        inside = np.zeros(self.arr.shape, dtype=bool)
    else:
        # With all_touched=False rasterio burns only cells whose centre lies
        # within the polygon, which is exactly the "centres" strategy. Use the
        # same raster_meta / arr attributes as the other methods of this class.
        inside = rasterio.features.rasterize(
            [(polygon, 1)],
            fill=0,
            out_shape=self.arr.shape,
            transform=self.raster_meta.transform,
            all_touched=False,
            dtype=np.uint8,
        ).astype(bool)

    raster = self.model_copy()
    raster.arr = np.where(inside, self.arr, np.nan)
    return raster
1263
+
1264
def _trim_value(self, *, value_mask: NDArray[np.bool_], value_name: str) -> Self:
    """Trim edge rows and columns in which every cell is flagged by ``value_mask``.

    Args:
        value_mask: Boolean array where True marks a cell as trimmable. It is
            indexed alongside ``self.arr`` (shape assumed to match; not checked).
        value_name: Label for the flagged value, used in the error message
            (e.g. 'NaN', 'zero').

    Returns:
        A new raster covering the smallest row/column window that still contains
        every unflagged cell. The transform is shifted so retained cells keep
        their georeferenced positions.

    Raises:
        ValueError: If every cell is flagged.
    """
    if np.all(value_mask):
        msg = f"Cannot crop raster: all values are {value_name}"
        raise ValueError(msg)

    # Indices of rows/columns holding at least one cell that must be kept.
    keep_rows = np.flatnonzero(~np.all(value_mask, axis=1))
    keep_cols = np.flatnonzero(~np.all(value_mask, axis=0))

    # Half-open window bounds.
    top, bottom = keep_rows[0], keep_rows[-1] + 1
    left, right = keep_cols[0], keep_cols[-1] + 1

    # Moving the origin by (left, top) pixels leaves each retained cell in place.
    shifted_transform = (
        self.raster_meta.transform
        * rasterio.transform.Affine.translation(left, top)
    )

    return self.__class__(
        arr=self.arr[top:bottom, left:right],
        raster_meta=RasterMeta(
            cell_size=self.raster_meta.cell_size,
            crs=self.raster_meta.crs,
            transform=shifted_transform,
        ),
    )
1306
+
1307
def trim_nan(self) -> Self:
    """Shrink the raster to the tightest window around its non-NaN cells.

    Only rows and columns at the edges that are entirely NaN are removed. NaN
    cells may still appear along the new edges, as long as no whole edge is NaN.

    Consider using `.extrapolate()` for further cleanup of NaN values.
    """
    nan_cells = np.isnan(self.arr)
    return self._trim_value(value_mask=nan_cells, value_name="NaN")
1317
+
1318
def trim_zeros(self) -> Self:
    """Shrink the raster to the tightest window around its non-zero cells.

    Only rows and columns at the edges that are entirely zero are removed. Zero
    cells may still appear along the new edges, as long as no whole edge is zero.
    """
    zero_cells = self.arr == 0
    return self._trim_value(value_mask=zero_cells, value_name="zero")
1326
+
830
1327
  def resample(
831
1328
  self, new_cell_size: float, *, method: Literal["bilinear"] = "bilinear"
832
1329
  ) -> Self:
@@ -874,26 +1371,55 @@ class RasterModel(BaseModel):
874
1371
 
875
1372
  return cls(arr=new_arr, raster_meta=new_raster_meta)
876
1373
 
877
- @field_validator("arr")
878
- @classmethod
879
- def check_2d_array(cls, v: NDArray) -> NDArray:
880
- """Validator to ensure the cell array is 2D."""
881
- if v.ndim != 2:
882
- msg = "Cell array must be 2D"
883
- raise RasterCellArrayShapeError(msg)
884
- return v
885
1374
 
1375
def _map_colorbar(
    *,
    colormap: Callable[[float], tuple[float, float, float, float]],
    vmin: float,
    vmax: float,
) -> BrancaLinearColormap:
    """Sample ``colormap`` into a branca ``LinearColormap`` spanning ``[vmin, vmax]``.

    A non-finite ``vmin`` falls back to 0.0 and a non-finite ``vmax`` to 1.0. If
    the resulting range is empty or inverted, the upper bound becomes
    ``vmin + 1.0`` so the legend always has positive extent.

    Args:
        colormap: Callable mapping a position in [0, 1] to an RGBA tuple, such
            as a matplotlib colormap.
        vmin: Lower legend bound, in data units.
        vmax: Upper legend bound, in data units.

    Returns:
        A branca ``LinearColormap`` built from hex colours sampled evenly over
        [0, 1]. A ``ListedColormap`` is sampled once per entry; any other
        colormap is sampled at 256 points.
    """
    from branca.colormap import LinearColormap as BrancaLinearColormap
    from matplotlib.colors import ListedColormap, to_hex

    # Legend range in the original data units, sanitised for non-finite input.
    lower = float(vmin) if np.isfinite(vmin) else 0.0
    upper = float(vmax) if np.isfinite(vmax) else 1.0
    if upper <= lower:
        upper = lower + 1.0

    n_samples = colormap.N if isinstance(colormap, ListedColormap) else 256
    hex_colors = [to_hex(colormap(t)) for t in np.linspace(0, 1, n_samples)]
    return BrancaLinearColormap(colors=hex_colors, vmin=lower, vmax=upper)
886
1398
 
887
- def _get_xy_tuple(xy: Any) -> tuple[float, float]:
888
- """Convert Point or coordinate tuple to coordinate tuple.
889
1399
 
890
- Args:
891
- xy: Either a coordinate tuple or a shapely Point object.
1400
+ def _get_vmin_vmax(
1401
+ raster: Raster, *, vmin: float | None = None, vmax: float | None = None
1402
+ ) -> tuple[float, float]:
1403
+ """Get maximum and minimum values from a raster array, ignoring NaNs.
892
1404
 
893
- Returns:
894
- A coordinate tuple (x, y).
1405
+ Allows for custom over-ride vmin and vmax values to be provided.
895
1406
  """
896
- if isinstance(xy, Point):
897
- return (xy.x, xy.y)
898
- x, y = xy
899
- return (float(x), float(y))
1407
+ with warnings.catch_warnings():
1408
+ warnings.filterwarnings(
1409
+ "ignore",
1410
+ message="All-NaN slice encountered",
1411
+ category=RuntimeWarning,
1412
+ )
1413
+ if vmin is None:
1414
+ _vmin = float(raster.min())
1415
+ else:
1416
+ _vmin = vmin
1417
+ if vmax is None:
1418
+ _vmax = float(raster.max())
1419
+ else:
1420
+ _vmax = vmax
1421
+
1422
+ return _vmin, _vmax
1423
+
1424
+
1425
# Backwards-compatible alias: the class was previously named `RasterModel`, so
# code importing the old name keeps working.
RasterModel = Raster