rastr 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rastr might be problematic. Click here for more details.

rastr/raster.py CHANGED
@@ -1,30 +1,24 @@
1
1
  """Raster data structure."""
2
2
 
3
- from collections.abc import Callable, Generator
3
+ from __future__ import annotations
4
+
5
+ import importlib.util
6
+ import warnings
4
7
  from contextlib import contextmanager
5
8
  from pathlib import Path
6
- from typing import TYPE_CHECKING, Literal
9
+ from typing import TYPE_CHECKING, Any, Literal
7
10
 
8
- import geopandas as gpd
9
- import matplotlib as mpl
10
11
  import numpy as np
11
- import pandas as pd
12
+ import numpy.ma
12
13
  import rasterio.plot
13
14
  import rasterio.sample
14
15
  import rasterio.transform
15
16
  import skimage.measure
16
- import xyzservices.providers as xyz
17
- from matplotlib import pyplot as plt
18
- from matplotlib.axes import Axes
19
- from mpl_toolkits.axes_grid1 import make_axes_locatable
20
- from numpy.typing import NDArray
21
17
  from pydantic import BaseModel, InstanceOf, field_validator
22
18
  from pyproj.crs.crs import CRS
23
19
  from rasterio.enums import Resampling
24
- from rasterio.io import BufferedDatasetWriter, DatasetReader, DatasetWriter, MemoryFile
25
- from scipy.ndimage import gaussian_filter
26
- from shapely.geometry import LineString, Polygon
27
- from typing_extensions import Self
20
+ from rasterio.io import MemoryFile
21
+ from shapely.geometry import LineString, Point, Polygon
28
22
 
29
23
  from rastr.arr.fill import fillna_nearest_neighbours
30
24
  from rastr.gis.fishnet import create_fishnet
@@ -32,16 +26,14 @@ from rastr.gis.smooth import catmull_rom_smooth
32
26
  from rastr.meta import RasterMeta
33
27
 
34
28
  if TYPE_CHECKING:
35
- from folium import Map
29
+ from collections.abc import Callable, Generator
36
30
 
37
- try:
38
- import folium
39
- import folium.raster_layers
31
+ import geopandas as gpd
40
32
  from folium import Map
41
- except ImportError:
42
- FOLIUM_INSTALLED = False
43
- else:
44
- FOLIUM_INSTALLED = True
33
+ from matplotlib.axes import Axes
34
+ from numpy.typing import ArrayLike, NDArray
35
+ from rasterio.io import BufferedDatasetWriter, DatasetReader, DatasetWriter
36
+ from typing_extensions import Self
45
37
 
46
38
  try:
47
39
  from rasterio._err import CPLE_BaseError
@@ -49,7 +41,9 @@ except ImportError:
49
41
  CPLE_BaseError = Exception # Fallback if private module import fails
50
42
 
51
43
 
52
- CTX_BASEMAP_SOURCE = xyz.Esri.WorldImagery # pyright: ignore[reportAttributeAccessIssue]
44
+ FOLIUM_INSTALLED = importlib.util.find_spec("folium") is not None
45
+ BRANCA_INSTALLED = importlib.util.find_spec("branca") is not None
46
+ MATPLOTLIB_INSTALLED = importlib.util.find_spec("matplotlib") is not None
53
47
 
54
48
 
55
49
  class RasterCellArrayShapeError(ValueError):
@@ -62,6 +56,42 @@ class RasterModel(BaseModel):
62
56
  arr: InstanceOf[np.ndarray]
63
57
  raster_meta: RasterMeta
64
58
 
59
+ @property
60
+ def meta(self) -> RasterMeta:
61
+ """Alias for raster_meta."""
62
+ return self.raster_meta
63
+
64
+ @meta.setter
65
+ def meta(self, value: RasterMeta) -> None:
66
+ self.raster_meta = value
67
+
68
+ def __init__(
69
+ self,
70
+ *,
71
+ arr: ArrayLike,
72
+ meta: RasterMeta | None = None,
73
+ raster_meta: RasterMeta | None = None,
74
+ ) -> None:
75
+ arr = np.asarray(arr)
76
+
77
+ # Set the meta
78
+ if meta is not None and raster_meta is not None:
79
+ msg = (
80
+ "Only one of 'meta' or 'raster_meta' should be provided, they are "
81
+ "aliases."
82
+ )
83
+ raise ValueError(msg)
84
+ elif meta is not None and raster_meta is None:
85
+ raster_meta = meta
86
+ elif meta is None and raster_meta is not None:
87
+ pass
88
+ else:
89
+ # Don't need to mention `'meta'` to simplify the messaging.
90
+ msg = "The attribute 'raster_meta' is required."
91
+ raise ValueError(msg)
92
+
93
+ super().__init__(arr=arr, raster_meta=raster_meta)
94
+
65
95
  def __eq__(self, other: object) -> bool:
66
96
  """Check equality of two RasterModel objects."""
67
97
  if not isinstance(other, RasterModel):
@@ -74,9 +104,10 @@ class RasterModel(BaseModel):
74
104
  __hash__ = BaseModel.__hash__
75
105
 
76
106
  def __add__(self, other: float | Self) -> Self:
107
+ cls = self.__class__
77
108
  if isinstance(other, float | int):
78
109
  new_arr = self.arr + other
79
- return RasterModel(arr=new_arr, raster_meta=self.raster_meta)
110
+ return cls(arr=new_arr, raster_meta=self.raster_meta)
80
111
  elif isinstance(other, RasterModel):
81
112
  if self.raster_meta != other.raster_meta:
82
113
  msg = (
@@ -91,7 +122,7 @@ class RasterModel(BaseModel):
91
122
  )
92
123
  raise ValueError(msg)
93
124
  new_arr = self.arr + other.arr
94
- return RasterModel(arr=new_arr, raster_meta=self.raster_meta)
125
+ return cls(arr=new_arr, raster_meta=self.raster_meta)
95
126
  else:
96
127
  return NotImplemented
97
128
 
@@ -99,9 +130,10 @@ class RasterModel(BaseModel):
99
130
  return self + other
100
131
 
101
132
  def __mul__(self, other: float | Self) -> Self:
133
+ cls = self.__class__
102
134
  if isinstance(other, float | int):
103
135
  new_arr = self.arr * other
104
- return RasterModel(arr=new_arr, raster_meta=self.raster_meta)
136
+ return cls(arr=new_arr, raster_meta=self.raster_meta)
105
137
  elif isinstance(other, RasterModel):
106
138
  if self.raster_meta != other.raster_meta:
107
139
  msg = (
@@ -113,7 +145,7 @@ class RasterModel(BaseModel):
113
145
  msg = "Rasters must have the same shape to be multiplied"
114
146
  raise ValueError(msg)
115
147
  new_arr = self.arr * other.arr
116
- return RasterModel(arr=new_arr, raster_meta=self.raster_meta)
148
+ return cls(arr=new_arr, raster_meta=self.raster_meta)
117
149
  else:
118
150
  return NotImplemented
119
151
 
@@ -121,9 +153,10 @@ class RasterModel(BaseModel):
121
153
  return self * other
122
154
 
123
155
  def __truediv__(self, other: float | Self) -> Self:
156
+ cls = self.__class__
124
157
  if isinstance(other, float | int):
125
158
  new_arr = self.arr / other
126
- return RasterModel(arr=new_arr, raster_meta=self.raster_meta)
159
+ return cls(arr=new_arr, raster_meta=self.raster_meta)
127
160
  elif isinstance(other, RasterModel):
128
161
  if self.raster_meta != other.raster_meta:
129
162
  msg = (
@@ -135,7 +168,7 @@ class RasterModel(BaseModel):
135
168
  msg = "Rasters must have the same shape to be divided"
136
169
  raise ValueError(msg)
137
170
  new_arr = self.arr / other.arr
138
- return RasterModel(arr=new_arr, raster_meta=self.raster_meta)
171
+ return cls(arr=new_arr, raster_meta=self.raster_meta)
139
172
  else:
140
173
  return NotImplemented
141
174
 
@@ -149,13 +182,24 @@ class RasterModel(BaseModel):
149
182
  return -self + other
150
183
 
151
184
  def __neg__(self) -> Self:
152
- return RasterModel(arr=-self.arr, raster_meta=self.raster_meta)
185
+ cls = self.__class__
186
+ return cls(arr=-self.arr, raster_meta=self.raster_meta)
153
187
 
154
188
  @property
155
189
  def cell_centre_coords(self) -> NDArray[np.float64]:
156
190
  """Get the coordinates of the cell centres in the raster."""
157
191
  return self.raster_meta.get_cell_centre_coords(self.arr.shape)
158
192
 
193
+ @property
194
+ def cell_x_coords(self) -> NDArray[np.float64]:
195
+ """Get the x coordinates of the cell centres in the raster."""
196
+ return self.raster_meta.get_cell_x_coords(self.arr.shape[0])
197
+
198
+ @property
199
+ def cell_y_coords(self) -> NDArray[np.float64]:
200
+ """Get the y coordinates of the cell centres in the raster."""
201
+ return self.raster_meta.get_cell_y_coords(self.arr.shape[1])
202
+
159
203
  @contextmanager
160
204
  def to_rasterio_dataset(
161
205
  self,
@@ -191,14 +235,15 @@ class RasterModel(BaseModel):
191
235
 
192
236
  def sample(
193
237
  self,
194
- xy: list[tuple[float, float]],
238
+ xy: list[tuple[float, float]] | list[Point] | ArrayLike,
195
239
  *,
196
240
  na_action: Literal["raise", "ignore"] = "raise",
197
241
  ) -> NDArray[np.float64]:
198
242
  """Sample raster values at GeoSeries locations and return sampled values.
199
243
 
200
244
  Args:
201
- xy: A list of (x, y) coordinates to sample the raster at.
245
+ xy: A list of (x, y) coordinates or shapely Point objects to sample the
246
+ raster at.
202
247
  na_action: Action to take when a NaN value is encountered in the input xy.
203
248
  Options are "raise" (raise an error) or "ignore" (replace with
204
249
  NaN).
@@ -209,6 +254,12 @@ class RasterModel(BaseModel):
209
254
  # If this function is too slow, consider the optimizations detailed here:
210
255
  # https://rdrn.me/optimising-sampling/
211
256
 
257
+ # Convert shapely Points to coordinate tuples if needed
258
+ if isinstance(xy, (list, tuple)):
259
+ xy = [_get_xy_tuple(point) for point in xy]
260
+
261
+ xy = np.asarray(xy, dtype=float)
262
+
212
263
  # Short-circuit
213
264
  if len(xy) == 0:
214
265
  return np.array([], dtype=float)
@@ -245,7 +296,7 @@ class RasterModel(BaseModel):
245
296
 
246
297
  # Convert the sampled values to a NumPy array and set masked values to NaN
247
298
  raster_values = np.array(
248
- [s.data[0] if not s.mask else np.nan for s in samples]
299
+ [s.data[0] if not numpy.ma.getmask(s) else np.nan for s in samples]
249
300
  ).astype(float)
250
301
 
251
302
  if len(xy_nan_idxs) > 0:
@@ -294,19 +345,25 @@ class RasterModel(BaseModel):
294
345
  m: Map | None = None,
295
346
  opacity: float = 1.0,
296
347
  colormap: str = "viridis",
348
+ cbar_label: str | None = None,
297
349
  ) -> Map:
298
350
  """Display the raster on a folium map."""
299
- if not FOLIUM_INSTALLED:
300
- msg = "The 'folium' package is required for 'explore()'."
351
+ if not FOLIUM_INSTALLED or not MATPLOTLIB_INSTALLED:
352
+ msg = "The 'folium' and 'matplotlib' packages are required for 'explore()'."
301
353
  raise ImportError(msg)
302
354
 
355
+ import folium.raster_layers
356
+ import geopandas as gpd
357
+ import matplotlib as mpl
358
+
303
359
  if m is None:
304
360
  m = folium.Map()
305
361
 
306
- rbga_map: Callable[[float], tuple[float, float, float, float]] = mpl.colormaps[
362
+ rgba_map: Callable[[float], tuple[float, float, float, float]] = mpl.colormaps[
307
363
  colormap
308
364
  ]
309
365
 
366
+ # Cast to GDF to facilitate converting bounds to WGS84
310
367
  wgs84_crs = CRS.from_epsg(4326)
311
368
  gdf = gpd.GeoDataFrame(geometry=[self.bbox], crs=self.raster_meta.crs).to_crs(
312
369
  wgs84_crs
@@ -316,8 +373,15 @@ class RasterModel(BaseModel):
316
373
  arr = np.array(self.arr)
317
374
 
318
375
  # Normalize the data to the range [0, 1] as this is the cmap range
319
- min_val = np.nanmin(arr)
320
- max_val = np.nanmax(arr)
376
+ with warnings.catch_warnings():
377
+ warnings.filterwarnings(
378
+ "ignore",
379
+ message="All-NaN slice encountered",
380
+ category=RuntimeWarning,
381
+ )
382
+ min_val = np.nanmin(arr)
383
+ max_val = np.nanmax(arr)
384
+
321
385
  if max_val > min_val: # Prevent division by zero
322
386
  arr = (arr - min_val) / (max_val - min_val)
323
387
  else:
@@ -328,26 +392,46 @@ class RasterModel(BaseModel):
328
392
  flip_x = self.raster_meta.transform.a < 0
329
393
  flip_y = self.raster_meta.transform.e > 0
330
394
  if flip_x:
331
- arr = np.flip(self.arr, axis=1)
395
+ arr = np.flip(arr, axis=1)
332
396
  if flip_y:
333
- arr = np.flip(self.arr, axis=0)
397
+ arr = np.flip(arr, axis=0)
334
398
 
335
399
  img = folium.raster_layers.ImageOverlay(
336
400
  image=arr,
337
401
  bounds=[[ymin, xmin], [ymax, xmax]],
338
402
  opacity=opacity,
339
- colormap=rbga_map,
403
+ colormap=rgba_map,
340
404
  mercator_project=True,
341
405
  )
342
406
 
343
407
  img.add_to(m)
344
408
 
409
+ # Add a colorbar legend
410
+ if BRANCA_INSTALLED:
411
+ from branca.colormap import LinearColormap as BrancaLinearColormap
412
+ from matplotlib.colors import to_hex
413
+
414
+ # Determine legend data range in original units
415
+ vmin = float(min_val) if np.isfinite(min_val) else 0.0
416
+ vmax = float(max_val) if np.isfinite(max_val) else 1.0
417
+ if vmax <= vmin:
418
+ vmax = vmin + 1.0
419
+
420
+ sample_points = np.linspace(0, 1, rgba_map.N)
421
+ colors = [to_hex(rgba_map(x)) for x in sample_points]
422
+ legend = BrancaLinearColormap(colors=colors, vmin=vmin, vmax=vmax)
423
+ if cbar_label:
424
+ legend.caption = cbar_label
425
+ legend.add_to(m)
426
+
345
427
  m.fit_bounds([[ymin, xmin], [ymax, xmax]])
346
428
 
347
429
  return m
348
430
 
349
431
  def to_clipboard(self) -> None:
350
432
  """Copy the raster cell array to the clipboard."""
433
+ import pandas as pd
434
+
351
435
  pd.DataFrame(self.arr).to_clipboard(index=False, header=False)
352
436
 
353
437
  def plot(
@@ -359,9 +443,17 @@ class RasterModel(BaseModel):
359
443
  cmap: str = "viridis",
360
444
  ) -> Axes:
361
445
  """Plot the raster on a matplotlib axis."""
446
+ if not MATPLOTLIB_INSTALLED:
447
+ msg = "The 'matplotlib' package is required for 'plot()'."
448
+ raise ImportError(msg)
449
+
450
+ from matplotlib import pyplot as plt
451
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
452
+
362
453
  if ax is None:
363
- _, ax = plt.subplots()
364
- ax: Axes
454
+ _, _ax = plt.subplots()
455
+ _ax: Axes
456
+ ax = _ax
365
457
 
366
458
  if basemap:
367
459
  msg = "Basemap plotting is not yet implemented."
@@ -383,8 +475,8 @@ class RasterModel(BaseModel):
383
475
  max_y_nonzero = np.max(y_nonzero)
384
476
 
385
477
  # Transform to raster CRS
386
- x1, y1 = self.raster_meta.transform * (min_x_nonzero, min_y_nonzero)
387
- x2, y2 = self.raster_meta.transform * (max_x_nonzero, max_y_nonzero)
478
+ x1, y1 = self.raster_meta.transform * (min_x_nonzero, min_y_nonzero) # type: ignore[reportAssignmentType] overloaded tuple size in affine
479
+ x2, y2 = self.raster_meta.transform * (max_x_nonzero, max_y_nonzero) # type: ignore[reportAssignmentType]
388
480
  xmin, xmax = sorted([x1, x2])
389
481
  ymin, ymax = sorted([y1, y2])
390
482
 
@@ -405,11 +497,14 @@ class RasterModel(BaseModel):
405
497
  divider = make_axes_locatable(ax)
406
498
  cax = divider.append_axes("right", size="5%", pad=0.05)
407
499
  fig = ax.get_figure()
408
- fig.colorbar(img, label=cbar_label, cax=cax)
500
+ if fig is not None:
501
+ fig.colorbar(img, label=cbar_label, cax=cax)
409
502
  return ax
410
503
 
411
504
  def as_geodataframe(self, name: str = "value") -> gpd.GeoDataFrame:
412
505
  """Create a GeoDataFrame representation of the raster."""
506
+ import geopandas as gpd
507
+
413
508
  polygons = create_fishnet(bounds=self.bounds, res=self.raster_meta.cell_size)
414
509
  point_tuples = [polygon.centroid.coords[0] for polygon in polygons]
415
510
  raster_gdf = gpd.GeoDataFrame(
@@ -422,9 +517,11 @@ class RasterModel(BaseModel):
422
517
 
423
518
  return raster_gdf
424
519
 
425
- def to_file(self, path: Path) -> None:
520
+ def to_file(self, path: Path | str) -> None:
426
521
  """Write the raster to a GeoTIFF file."""
427
522
 
523
+ path = Path(path)
524
+
428
525
  suffix = path.suffix.lower()
429
526
  if suffix in (".tif", ".tiff"):
430
527
  driver = "GTiff"
@@ -455,8 +552,9 @@ class RasterModel(BaseModel):
455
552
  raise OSError(msg) from err
456
553
 
457
554
  def __str__(self) -> str:
555
+ cls = self.__class__
458
556
  mean = np.nanmean(self.arr)
459
- return f"RasterModel(shape={self.arr.shape}, {mean=})"
557
+ return f"{cls.__name__}(shape={self.arr.shape}, {mean=})"
460
558
 
461
559
  def __repr__(self) -> str:
462
560
  return str(self)
@@ -486,24 +584,22 @@ class RasterModel(BaseModel):
486
584
  return new_raster
487
585
 
488
586
  def get_xy(self) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
489
- """Get the x and y coordinates of the raster in meshgrid format."""
490
- col_idx, row_idx = np.meshgrid(
491
- np.arange(self.arr.shape[1]),
492
- np.arange(self.arr.shape[0]),
493
- )
494
-
495
- col_idx = col_idx.flatten()
496
- row_idx = row_idx.flatten()
587
+ """Get the x and y coordinates of the raster cell centres in meshgrid format.
497
588
 
498
- coords = np.vstack((row_idx, col_idx)).T
589
+ Returns the coordinates of the cell centres as two separate 2D arrays in
590
+ meshgrid format, where each array has the same shape as the raster data array.
499
591
 
500
- x, y = rasterio.transform.xy(self.raster_meta.transform, *coords.T)
501
- x = np.array(x).reshape(self.arr.shape)
502
- y = np.array(y).reshape(self.arr.shape)
503
- return x, y
592
+ Returns:
593
+ A tuple of (x, y) coordinate arrays where:
594
+ - x: 2D array of x-coordinates of cell centres
595
+ - y: 2D array of y-coordinates of cell centres
596
+ Both arrays have the same shape as the raster data array.
597
+ """
598
+ coords = self.raster_meta.get_cell_centre_coords(self.arr.shape)
599
+ return coords[:, :, 0], coords[:, :, 1]
504
600
 
505
601
  def contour(
506
- self, *, levels: list[float], smoothing: bool = True
602
+ self, *, levels: list[float] | NDArray, smoothing: bool = True
507
603
  ) -> gpd.GeoDataFrame:
508
604
  """Create contour lines from the raster data, optionally with smoothing.
509
605
 
@@ -514,13 +610,14 @@ class RasterModel(BaseModel):
514
610
  contouring, to denoise the contours.
515
611
 
516
612
  Args:
517
- levels: A list of contour levels to generate. The contour lines will be
518
- generated for each level in this list.
613
+ levels: A list or array of contour levels to generate. The contour lines
614
+ will be generated for each level in this sequence.
519
615
  smoothing: Defaults to true, which corresponds to applying a smoothing
520
616
  algorithm to the contour lines. At the moment, this is the
521
617
  Catmull-Rom spline algorithm. If set to False, the raw
522
618
  contours will be returned without any smoothing.
523
619
  """
620
+ import geopandas as gpd
524
621
 
525
622
  all_levels = []
526
623
  all_geoms = []
@@ -569,6 +666,7 @@ class RasterModel(BaseModel):
569
666
  coordinate distance (e.g. meters). A larger sigma results in a more
570
667
  blurred image.
571
668
  """
669
+ from scipy.ndimage import gaussian_filter
572
670
 
573
671
  cell_sigma = sigma / self.raster_meta.cell_size
574
672
 
@@ -598,6 +696,82 @@ class RasterModel(BaseModel):
598
696
 
599
697
  return raster
600
698
 
699
+ def crop(
700
+ self,
701
+ bounds: tuple[float, float, float, float],
702
+ strategy: Literal["underflow", "overflow"] = "underflow",
703
+ ) -> Self:
704
+ """Crop the raster to the specified bounds.
705
+
706
+ Args:
707
+ bounds: A tuple of (minx, miny, maxx, maxy) defining the bounds to crop to.
708
+ strategy: The cropping strategy to use. 'underflow' will crop the raster
709
+ to be fully within the bounds, ignoring any cells that are
710
+ partially outside the bounds. 'overflow' will instead include
711
+ cells that intersect the bounds, ensuring the bounds area
712
+ remains covered with cells.
713
+
714
+ Returns:
715
+ A new RasterModel instance cropped to the specified bounds.
716
+ """
717
+
718
+ minx, miny, maxx, maxy = bounds
719
+ arr = self.arr
720
+
721
+ # Get the half cell size for cropping
722
+ cell_size = self.raster_meta.cell_size
723
+ half_cell_size = cell_size / 2
724
+
725
+ # Get the cell centre coordinates as 1D arrays
726
+ x_coords = self.cell_x_coords
727
+ y_coords = self.cell_y_coords
728
+
729
+ # Get the indices to crop the array
730
+ if strategy == "underflow":
731
+ x_idx = (x_coords >= minx + half_cell_size) & (
732
+ x_coords <= maxx - half_cell_size
733
+ )
734
+ y_idx = (y_coords >= miny + half_cell_size) & (
735
+ y_coords <= maxy - half_cell_size
736
+ )
737
+ elif strategy == "overflow":
738
+ x_idx = (x_coords > minx - half_cell_size) & (
739
+ x_coords < maxx + half_cell_size
740
+ )
741
+ y_idx = (y_coords > miny - half_cell_size) & (
742
+ y_coords < maxy + half_cell_size
743
+ )
744
+ else:
745
+ msg = f"Unsupported cropping strategy: {strategy}"
746
+ raise NotImplementedError(msg)
747
+
748
+ # Crop the array
749
+ cropped_arr = arr[np.ix_(x_idx, y_idx)]
750
+
751
+ # Check the shape of the cropped array
752
+ if cropped_arr.size == 0:
753
+ msg = "Cropped array is empty; no cells within the specified bounds."
754
+ raise ValueError(msg)
755
+
756
+ # Recalculate the transform for the cropped raster
757
+ x_coords = x_coords[x_idx]
758
+ y_coords = y_coords[y_idx]
759
+ transform = rasterio.transform.from_bounds(
760
+ west=x_coords.min() - half_cell_size,
761
+ south=y_coords.min() - half_cell_size,
762
+ east=x_coords.max() + half_cell_size,
763
+ north=y_coords.max() + half_cell_size,
764
+ width=cropped_arr.shape[1],
765
+ height=cropped_arr.shape[0],
766
+ )
767
+
768
+ # Update the raster
769
+ cls = self.__class__
770
+ new_meta = RasterMeta(
771
+ cell_size=cell_size, crs=self.raster_meta.crs, transform=transform
772
+ )
773
+ return cls(arr=cropped_arr, raster_meta=new_meta)
774
+
601
775
  def resample(
602
776
  self, new_cell_size: float, *, method: Literal["bilinear"] = "bilinear"
603
777
  ) -> Self:
@@ -619,6 +793,7 @@ class RasterModel(BaseModel):
619
793
 
620
794
  factor = self.raster_meta.cell_size / new_cell_size
621
795
 
796
+ cls = self.__class__
622
797
  # Use the rasterio dataset with proper context management
623
798
  with self.to_rasterio_dataset() as dataset:
624
799
  # N.B. the new height and width may increase slightly.
@@ -642,13 +817,28 @@ class RasterModel(BaseModel):
642
817
  cell_size=new_cell_size,
643
818
  )
644
819
 
645
- return RasterModel(arr=new_arr, raster_meta=new_raster_meta)
820
+ return cls(arr=new_arr, raster_meta=new_raster_meta)
646
821
 
647
822
  @field_validator("arr")
648
823
  @classmethod
649
- def check_2d_array(cls, v: np.ndarray) -> np.ndarray:
824
+ def check_2d_array(cls, v: NDArray) -> NDArray:
650
825
  """Validator to ensure the cell array is 2D."""
651
826
  if v.ndim != 2:
652
827
  msg = "Cell array must be 2D"
653
828
  raise RasterCellArrayShapeError(msg)
654
829
  return v
830
+
831
+
832
+ def _get_xy_tuple(xy: Any) -> tuple[float, float]:
833
+ """Convert Point or coordinate tuple to coordinate tuple.
834
+
835
+ Args:
836
+ xy: Either a coordinate tuple or a shapely Point object.
837
+
838
+ Returns:
839
+ A coordinate tuple (x, y).
840
+ """
841
+ if isinstance(xy, Point):
842
+ return (xy.x, xy.y)
843
+ x, y = xy
844
+ return (float(x), float(y))