sgspy 1.0.2__cp314-cp314-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sgspy/__init__.py +82 -0
- sgspy/_sgs.cpython-314-x86_64-linux-gnu.so +0 -0
- sgspy/calculate/__init__.py +18 -0
- sgspy/calculate/pca/__init__.py +2 -0
- sgspy/calculate/pca/pca.py +158 -0
- sgspy/calculate/representation/__init__.py +2 -0
- sgspy/calculate/representation/representation.py +3 -0
- sgspy/sample/__init__.py +30 -0
- sgspy/sample/ahels/__init__.py +2 -0
- sgspy/sample/ahels/ahels.py +3 -0
- sgspy/sample/clhs/__init__.py +2 -0
- sgspy/sample/clhs/clhs.py +202 -0
- sgspy/sample/nc/__init__.py +2 -0
- sgspy/sample/nc/nc.py +3 -0
- sgspy/sample/srs/__init__.py +2 -0
- sgspy/sample/srs/srs.py +228 -0
- sgspy/sample/strat/__init__.py +2 -0
- sgspy/sample/strat/strat.py +394 -0
- sgspy/sample/systematic/__init__.py +2 -0
- sgspy/sample/systematic/systematic.py +233 -0
- sgspy/stratify/__init__.py +27 -0
- sgspy/stratify/breaks/__init__.py +2 -0
- sgspy/stratify/breaks/breaks.py +222 -0
- sgspy/stratify/kmeans/__init__.py +2 -0
- sgspy/stratify/kmeans/kmeans.py +3 -0
- sgspy/stratify/map/__init__.py +2 -0
- sgspy/stratify/map/map_stratifications.py +244 -0
- sgspy/stratify/poly/__init__.py +2 -0
- sgspy/stratify/poly/poly.py +170 -0
- sgspy/stratify/quantiles/__init__.py +2 -0
- sgspy/stratify/quantiles/quantiles.py +276 -0
- sgspy/utils/__init__.py +18 -0
- sgspy/utils/plot.py +143 -0
- sgspy/utils/raster.py +605 -0
- sgspy/utils/vector.py +268 -0
- sgspy-1.0.2.data/data/sgspy/libonedal.so.3 +0 -0
- sgspy-1.0.2.data/data/sgspy/proj.db +0 -0
- sgspy-1.0.2.dist-info/METADATA +13 -0
- sgspy-1.0.2.dist-info/RECORD +40 -0
- sgspy-1.0.2.dist-info/WHEEL +5 -0
sgspy/utils/raster.py
ADDED
|
@@ -0,0 +1,605 @@
|
|
|
1
|
+
# ******************************************************************************
|
|
2
|
+
#
|
|
3
|
+
# Project: sgs
|
|
4
|
+
# Purpose: GDALDataset wrapper for raster operations
|
|
5
|
+
# Author: Joseph Meyer
|
|
6
|
+
# Date: June, 2025
|
|
7
|
+
#
|
|
8
|
+
# ******************************************************************************
|
|
9
|
+
|
|
10
|
+
import importlib.util
|
|
11
|
+
import sys
|
|
12
|
+
import os
|
|
13
|
+
import site
|
|
14
|
+
import shutil
|
|
15
|
+
from typing import Optional
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import matplotlib.pyplot as plt
|
|
19
|
+
import matplotlib #for type checking matplotlib.axes.Axes
|
|
20
|
+
|
|
21
|
+
from .import plot
|
|
22
|
+
from .plot import plot_raster
|
|
23
|
+
|
|
24
|
+
# Ensure the compiled _sgs extension module can be found.
#
# The extension may live in <site-packages>/sgspy (installed wheel layout) or in
# the parent directory of this file, so both locations are appended to sys.path
# before importing it.
#
# site.getsitepackages() may be unavailable in some virtual-environment setups,
# and on some distributions (e.g. Debian's 'dist-packages') none of its entries
# contain 'site-packages'. Indexing the filtered list with [0] would then raise
# IndexError at import time and make the whole package unimportable, so fall
# back to the parent-directory lookup instead.
try:
    _site_dirs = site.getsitepackages()
except AttributeError:
    _site_dirs = []
site_packages = next((p for p in _site_dirs if 'site-packages' in p), None)
if site_packages is not None:
    sys.path.append(os.path.join(site_packages, "sgspy"))
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from _sgs import GDALRasterWrapper
|
|
29
|
+
|
|
30
|
+
# Optional third-party backends. Each flag records whether the library could be
# imported, so SpatialRaster's conversion helpers can raise a clear RuntimeError
# instead of failing with a NameError.
try:
    import rasterio
except ImportError:
    RASTERIO = False
else:
    RASTERIO = True

try:
    from osgeo import gdal
    from osgeo import gdal_array
except ImportError:
    GDAL = False
else:
    GDAL = True

# Directory passed to the C++ GDALRasterWrapper alongside each raster; the
# wheel's data files (proj.db among them) are installed under <sys.prefix>/sgspy.
PROJDB_PATH = os.path.join(sys.prefix, "sgspy")
|
|
46
|
+
|
|
47
|
+
##
|
|
48
|
+
# @ingroup user_utils
|
|
49
|
+
# This class represents a spatial raster, and is used as an input to many sgs functions.
|
|
50
|
+
#
|
|
51
|
+
# It has a number of additional uses, including accessing the raster data within as a numpy array,
|
|
52
|
+
# plotting with matplotlib, as well as converting to a GDAL or Rasterio dataset object. This class
|
|
53
|
+
# also has various attributes representing metadata of the raster which may be useful and can be
|
|
54
|
+
# seen in the 'Public Attributes' section.
|
|
55
|
+
#
|
|
56
|
+
# Accessing raster data
|
|
57
|
+
# --------------------
|
|
58
|
+
#
|
|
59
|
+
# raster data can be accessed in the form of a NumPy array per band. This can be done using
|
|
60
|
+
# the 'band' function. The band function takes a single parameter, which must be either
|
|
61
|
+
# an integer or a string. If it is an integer, it must refer to a valid zero-indexed band
|
|
62
|
+
# number. If it is a string, it must refer to a valid band name within the raster. This function
|
|
63
|
+
# may fail if the band is too large to fit in memory.
|
|
64
|
+
#
|
|
65
|
+
# rast = sgspy.SpatialRaster('test.tif') #raster with three layers
|
|
66
|
+
#
|
|
67
|
+
# b0 = rast.band(band=0) @n
|
|
68
|
+
# b1 = rast.band(band=1) @n
|
|
69
|
+
# b2 = rast.band(band=2) @n
|
|
70
|
+
#
|
|
71
|
+
# zq90 = rast.band(band='zq90') @n
|
|
72
|
+
# pzabove2 = rast.band(band='pzabove2') @n
|
|
73
|
+
# zstd = rast.band(band='zstd') @n
|
|
74
|
+
#
|
|
75
|
+
# Accessing raster information
|
|
76
|
+
# --------------------
|
|
77
|
+
#
|
|
78
|
+
# raster metadata can be displayed using the info() function. Info
|
|
79
|
+
# includes: raster driver, band names, dimensions, pixel size, bounds, and crs.
|
|
80
|
+
#
|
|
81
|
+
# rast = sgspy.SpatialRaster('test.tif') @n
|
|
82
|
+
# rast.info()
|
|
83
|
+
#
|
|
84
|
+
# Plotting raster
|
|
85
|
+
# --------------------
|
|
86
|
+
#
|
|
87
|
+
# the plot() function provides a wrapper around matplotlib's imshow
|
|
88
|
+
# functionality (matplotlib.pyplot.imshow). Only a single band can
|
|
89
|
+
# be plotted, and for multi-band rasters an indication must be given
|
|
90
|
+
# for which band to plot.
|
|
91
|
+
#
|
|
92
|
+
# Target width and height can be given in the parameters
|
|
93
|
+
# target_width and target_height. Default parameters are 1000 pixels for both.
|
|
94
|
+
# Information on the actual downsampling can be found here:
|
|
95
|
+
# https://gdal.org/en/stable/api/gdaldataset_cpp.html#classGDALDataset_1ae66e21b09000133a0f4d99baabf7a0ec
|
|
96
|
+
#
|
|
97
|
+
# If no 'band' argument is given, the function will throw an error if the
|
|
98
|
+
# image does not contain a single band.
|
|
99
|
+
#
|
|
100
|
+
# The 'band' argument allows the end-user to specify either the band
|
|
101
|
+
# index or the band name. 'band' may be an int or str.
|
|
102
|
+
#
|
|
103
|
+
# Optionally, any of the arguments which may be passed to the matplotlib
|
|
104
|
+
# imshow function may also be passed to plot(), such as cmap
|
|
105
|
+
# for a specific color mapping.
|
|
106
|
+
#
|
|
107
|
+
# #plots the single band @n
|
|
108
|
+
# rast = sgspy.SpatialRaster('test_single_band_raster.tif') @n
|
|
109
|
+
# rast.plot()
|
|
110
|
+
#
|
|
111
|
+
# #plots the second band @n
|
|
112
|
+
# rast = sgspy.SpatialRaster('test_multi_band_raster.tif') @n
|
|
113
|
+
# rast.plot(band=1)
|
|
114
|
+
#
|
|
115
|
+
# #plots the 'zq90' band @n
|
|
116
|
+
# rast = sgspy.SpatialRaster('test_multi_band_raster.tif') @n
|
|
117
|
+
# rast.plot(band='zq90')
|
|
118
|
+
#
|
|
119
|
+
# Public Attributes
|
|
120
|
+
# --------------------
|
|
121
|
+
# driver : str @n
|
|
122
|
+
# gdal dataset driver, for info/display purposes @n @n
|
|
123
|
+
# width : int @n
|
|
124
|
+
# the pixel width of the raster image @n @n
|
|
125
|
+
# height : int @n
|
|
126
|
+
# the pixel height of the raster image @n @n
|
|
127
|
+
# band_count : int @n
|
|
128
|
+
# the number of bands in the raster image @n @n
|
|
129
|
+
# bands : list[str] @n
|
|
130
|
+
# the raster band names @n @n
|
|
131
|
+
# crs : str @n
|
|
132
|
+
# coordinate reference system @n @n
|
|
133
|
+
# projection : str @n
|
|
134
|
+
# full projection string as wkt @n @n
|
|
135
|
+
# xmin : double @n
|
|
136
|
+
# minimum x value as defined by the gdal geotransform @n @n
|
|
137
|
+
# xmax : double @n
|
|
138
|
+
# maximum x value as defined by the gdal geotransform @n @n
|
|
139
|
+
# ymin : double @n
|
|
140
|
+
# minimum y value as defined by the gdal geotransform @n @n
|
|
141
|
+
# ymax : double @n
|
|
142
|
+
# maximum y value as defined by the gdal geotransform @n @n
|
|
143
|
+
# pixel_height : double @n
|
|
144
|
+
# pixel height as defined by the gdal geotransform @n @n
|
|
145
|
+
# pixel_width : double @n
|
|
146
|
+
# pixel width as defined by the gdal geotransform @n @n
|
|
147
|
+
#
|
|
148
|
+
# Public Methods
|
|
149
|
+
# --------------------
|
|
150
|
+
# info() @n
|
|
151
|
+
# takes no arguments, prints raster information to the console @n @n
|
|
152
|
+
# plot() @n
|
|
153
|
+
# takes optional 'ax', 'target_width', 'target_height', and 'band' (int or str) arguments @n @n
|
|
154
|
+
# band() @n
|
|
155
|
+
# takes one 'band' argument (int index or str name) and returns the band data as a numpy array, may throw an error if the raster band is too large
|
|
156
|
+
#
|
|
157
|
+
# Optionally, any of the arguments that can be passed to matplotlib.pyplot.imshow
|
|
158
|
+
# can also be passed to plot().
|
|
159
|
+
class SpatialRaster:
    """
    Wrapper around a C++ GDALRasterWrapper representing a spatial raster.

    Exposes raster metadata as attributes, per-band data as NumPy arrays,
    plotting through matplotlib, and conversion to and from rasterio and
    GDAL dataset objects.
    """

    # Class-level defaults. Defining them on the class (not only in __init__)
    # keeps __del__ and the closed checks working even on an instance whose
    # __init__ raised before completing.
    have_temp_dir = False  # when True, __del__ removes the directory stored in self.temp_dir
    temp_dataset = False   # True when the data lives in a temporary file owned by the C++ object
    filename = ""          # source file path; stays "" when built from a GDALRasterWrapper
    closed = False         # True once the wrapped C++ raster has been closed

    def __init__(self,
                 image: str | os.PathLike | GDALRasterWrapper):
        """
        Constructing method for the SpatialRaster class.

        Wraps either a raster file (opened through GDALRasterWrapper) or an
        already-constructed GDALRasterWrapper. The following attributes are
        populated:
            self.cpp_raster
            self.driver
            self.width
            self.height
            self.band_count
            self.crs
            self.projection
            self.xmin
            self.xmax
            self.ymin
            self.ymax
            self.pixel_height
            self.pixel_width
            self.bands
        It also builds self.band_name_dict (band name -> zero-based index) and
        an empty self.band_data_dict cache that band() fills lazily.

        Parameters
        --------------------
        image : str | os.PathLike | GDALRasterWrapper
            a raster file path, or an existing GDALRasterWrapper

        Raises
        --------------------
        TypeError
            if 'image' is neither a path nor a GDALRasterWrapper
        """
        if isinstance(image, (str, os.PathLike)):
            path = os.fspath(image)
            self.cpp_raster = GDALRasterWrapper(path, PROJDB_PATH)
            self.filename = path
        elif isinstance(image, GDALRasterWrapper):
            self.cpp_raster = image
        else:
            raise TypeError("'image' parameter of SpatialRaster constructor must be of type str or GDALRasterWrapper")

        self.driver = self.cpp_raster.get_driver()
        self.width = self.cpp_raster.get_width()
        self.height = self.cpp_raster.get_height()
        self.band_count = self.cpp_raster.get_band_count()
        self.crs = self.cpp_raster.get_crs()
        # NOTE(review): drops any non-ASCII characters, then interprets backslash
        # escape sequences in the WKT; the reason for this isn't visible here.
        self.projection = self.cpp_raster.get_projection().encode('ascii', 'ignore').decode('unicode_escape')
        self.xmin = self.cpp_raster.get_xmin()
        self.xmax = self.cpp_raster.get_xmax()
        self.ymin = self.cpp_raster.get_ymin()
        self.ymax = self.cpp_raster.get_ymax()
        self.pixel_width = self.cpp_raster.get_pixel_width()
        self.pixel_height = self.cpp_raster.get_pixel_height()
        self.bands = self.cpp_raster.get_bands()
        # band name -> zero-based band index, used by get_band_index()
        self.band_name_dict = {name: index for index, name in enumerate(self.bands)}
        # zero-based band index -> NumPy array, filled lazily by load_arr()
        self.band_data_dict = {}

    def __del__(self):
        """Remove the temporary directory attached to this raster, if there is one."""
        # Exceptions raised inside __del__ cannot be handled by the caller
        # (Python only prints a warning), so cleanup is best-effort: tolerate a
        # missing temp_dir attribute and a directory that is already gone.
        temp_dir = getattr(self, "temp_dir", None)
        if self.have_temp_dir and temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)

    def _check_open(self):
        """Raise RuntimeError if the wrapped C++ raster has already been closed."""
        if self.closed:
            raise RuntimeError("the C++ object which this class wraps has been cleaned up and closed.")

    def info(self):
        """
        Displays driver, band, size, pixel size, bound, and crs information of the raster.

        Raises
        --------------------
        RuntimeError
            if the wrapped C++ raster has been closed
        """
        self._check_open()

        print("driver: {}".format(self.driver))
        # list every band name (formatting with *self.bands printed only the first)
        print("bands: {}".format(", ".join(map(str, self.bands))))
        print("size: {} x {} x {}".format(self.band_count, self.width, self.height))
        # x is the pixel width and y the pixel height
        print("pixel size: (x, y): ({}, {})".format(self.pixel_width, self.pixel_height))
        print("bounds (xmin, xmax, ymin, ymax): ({}, {}, {}, {})".format(self.xmin, self.xmax, self.ymin, self.ymax))
        print("crs: {}".format(self.crs))

    def get_band_index(self, band: str | int) -> int:
        """
        Converts a band name or index into a validated zero-based band index.

        Parameters:
            band : str or int
                band name (looked up in band_name_dict) or zero-based band index

        Returns:
            int: the zero-based band index

        Raises:
            TypeError: if 'band' is not exactly a str or an int
            RuntimeError: if the wrapped C++ raster has been closed
            KeyError: if 'band' is a name that is not in the raster
            IndexError: if 'band' is an int outside [0, band_count)
        """
        if type(band) not in [str, int]:
            raise TypeError("'band' parameter must be of type str or int.")

        self._check_open()

        if type(band) is str:
            if band not in self.band_name_dict:
                raise KeyError(f"no band named {band!r}; available band names: {self.bands}")
            return self.band_name_dict[band]

        # validate here so an out-of-range index fails with a clear Python error
        # instead of being handed to the C++ layer
        if not 0 <= band < self.band_count:
            raise IndexError(f"band index {band} is out of range for a raster with {self.band_count} band(s).")

        return band

    def load_arr(self, band_index: int):
        """
        Loads one band of the raster into band_data_dict as a read-only NumPy array.

        The array is built with copy=False from a memoryview returned by the C++
        raster, so it presumably shares memory with the C++ object; it is made
        read-only so callers cannot mutate that buffer through it.

        Parameters:
            band_index : int
                zero-based band index

        Raises:
            TypeError: if 'band_index' is not an int
            RuntimeError: if the wrapped C++ raster has been closed
        """
        if type(band_index) is not int:
            raise TypeError("'band_index' parameter must be of type int.")

        self._check_open()

        view = self.cpp_raster.get_raster_as_memoryview(self.width, self.height, band_index)
        self.band_data_dict[band_index] = np.asarray(view.toreadonly(), copy=False)

    def band(self, band: str | int):
        """
        Gets a NumPy array with the specified band's data, loading and caching it on first use.

        Parameters:
            band : int | str
                zero-based band index or band name

        Returns:
            np.ndarray: the (read-only) band data

        Raises:
            TypeError: if 'band' is not exactly an int or a str
            RuntimeError: if the wrapped C++ raster has been closed
        """
        if type(band) not in [int, str]:
            raise TypeError("'band' parameter must be of type int or str.")

        self._check_open()

        index = self.get_band_index(band)
        if index not in self.band_data_dict:
            self.load_arr(index)
        return self.band_data_dict[index]

    def plot(self,
             ax: Optional[matplotlib.axes.Axes] = None,
             target_width: int = 1000,
             target_height: int = 1000,
             band: Optional[int | str] = None,
             **kwargs):
        """
        Calls plot_raster() on self.

        When 'ax' is None, a new figure is created, plotted on, and shown with
        plt.show(); otherwise the raster is drawn on the given axes.

        Parameters
        --------------------
        ax : matplotlib.axes.Axes
            axes to plot the raster on
        target_width : int
            maximum width in pixels for the image (after downsampling)
        target_height : int
            maximum height in pixels for the image (after downsampling)
        band : int or str
            specification of which bands to plot
        **kwargs
            any parameters which may be passed to matplotlib.pyplot.imshow
        """
        self._check_open()

        # pass target_height through; previously target_width was passed twice
        if ax is not None:
            plot_raster(self, ax, target_width, target_height, band, **kwargs)
        else:
            _, ax = plt.subplots()
            plot_raster(self, ax, target_width, target_height, band, **kwargs)
            plt.show()

    @staticmethod
    def _check_array_shape(arr, band_count, height, width):
        """
        Validates that 'arr' matches a dataset with the given band count and dimensions.

        'arr' may have shape (height, width) only when band_count is 1, or shape
        (band_count, height, width). Raises RuntimeError otherwise.
        """
        shape = arr.shape
        if len(shape) == 2:
            (arr_height, arr_width) = shape
            if band_count != 1:
                raise RuntimeError("if the array parameter contains only a single band with shape (height, width), the raster must contain only a single band.")
        elif len(shape) == 3:
            (arr_band_count, arr_height, arr_width) = shape
            if arr_band_count != band_count:
                raise RuntimeError("the array parameter must contains the same number of bands as the raster with shape (band_count, height, width).")
        else:
            # previously any other rank failed with an opaque tuple-unpacking ValueError
            raise RuntimeError("the array parameter must have shape (height, width) or (band_count, height, width).")

        if arr_height != height:
            raise RuntimeError("the height of the array passed must be equal to the height of the raster dataset.")

        if arr_width != width:
            raise RuntimeError("the width of the array passed must be equal to the width of the raster dataset.")

    @classmethod
    def from_rasterio(cls, ds, arr = None):
        """
        This function is used to convert from a rasterio dataset object representing a raster into an sgspy.SpatialRaster
        object. A np.ndarray may be passed as the 'arr' parameter, if so, the following must be true:
        arr.shape == (ds.count, ds.height, ds.width)
        (or arr.shape == (ds.height, ds.width) for a single-band dataset).

        Without 'arr', a file-backed dataset is closed and its file reopened by name;
        an in-memory ("MEM") dataset is read into an array first, since it has no file.

        Examples:

        ds = rasterio.open("rast.tif")
        rast = sgspy.SpatialRaster.from_rasterio(ds)


        ds = rasterio.open("rast.tif")
        arr = ds.read()
        arr[arr < 2] = np.nan
        rast = sgspy.SpatialRaster.from_rasterio(ds, arr)
        """
        if not RASTERIO:
            raise RuntimeError("from_rasterio() can only be called if rasterio was successfully imported, but it wasn't.")

        if type(ds) is not rasterio.io.DatasetReader and type(ds) is not rasterio.io.DatasetWriter:
            raise TypeError("the ds parameter passed to from_rasterio() must be of type rasterio.io.DatasetReader or rasterio.io.DatasetWriter.")

        if ds.driver == "MEM" and arr is None:
            arr = ds.read()

        if arr is None:
            # close the rasterio dataset, and open a GDALRasterWrapper of the file
            filename = ds.name
            ds.close()
            return cls(filename)

        if type(arr) is not np.ndarray:
            raise TypeError("the 'arr' parameter, if passed, must be of type np.ndarray")

        cls._check_array_shape(arr, ds.count, ds.height, ds.width)

        # fall back to NaN when the dataset declares no nodata value
        nodata = ds.profile["nodata"]
        if nodata is None:
            nodata = np.nan

        # create an in-memory dataset using the numpy array as the data, and the rasterio dataset to provide metadata
        geotransform = ds.get_transform()
        # ds.crs is None for a dataset without a CRS; use an empty projection
        # string, which is what gdal's GetProjection() returns in that case
        projection = ds.crs.wkt if ds.crs is not None else ""
        # rasterio reports unset band descriptions as None; pass "" instead,
        # matching what GDAL's GetDescription() yields in from_gdal()
        band_names = [name if name is not None else "" for name in ds.descriptions]
        arr = np.ascontiguousarray(arr)
        buffer = memoryview(arr)
        return cls(GDALRasterWrapper(buffer, geotransform, projection, [nodata] * ds.count, band_names, PROJDB_PATH))

    def _read_all_bands(self):
        """
        Stacks every band of the C++ raster into a new NumPy array along axis 0.

        release_band_buffers() is called afterwards so the returned array does not
        get cleaned up by deletion of the C++ object.
        """
        bands = [
            np.asarray(self.cpp_raster.get_raster_as_memoryview(self.width, self.height, i))
            for i in range(self.band_count)
        ]

        #ensure numpy array doesn't accidentally get cleaned up by C++ object deletion
        self.cpp_raster.release_band_buffers()

        return np.stack(bands, axis=0)

    def to_rasterio(self, with_arr = False):
        """
        This function is used to convert an sgspy.SpatialRaster into a rasterio dataset. If with_arr is set to True,
        the function will return a numpy.ndarray as a tuple with the rasterio dataset object.

        An in-memory raster is copied into a rasterio MemoryFile dataset and the
        wrapped C++ raster is closed; a file-backed raster is reopened by filename.

        Examples:

        rast = sgspy.SpatialRaster('rast.tif')
        ds = rast.to_rasterio()

        rast = sgspy.SpatialRaster('mraster.tif')
        ds, arr = rast.to_rasterio(with_arr=True)

        Raises:
            TypeError: if 'with_arr' is not a bool
            RuntimeError: if rasterio is unavailable, the raster is closed, the dataset is a
                temporary file, or the bands have differing data types
        """
        if type(with_arr) is not bool:
            raise TypeError("'with_arr' parameter must be of type bool.")

        if not RASTERIO:
            raise RuntimeError("to_rasterio() can only be called if rasterio was successfully imported, but it wasn't.")

        self._check_open()

        if self.temp_dataset:
            raise RuntimeError("the dataset has been saved as a temporary file which will be deleted when the C++ object containing it is deleted. the dataset must be either in-memory or have a filename.")

        in_mem = "MEM" in self.driver

        if with_arr or in_mem:
            arr = self._read_all_bands()

        if in_mem:
            gt = self.cpp_raster.get_geotransform()
            # rasterio's Affine uses a different coefficient ordering than a gdal geotransform
            transform = rasterio.transform.Affine(gt[1], gt[2], gt[0], gt[4], gt[5], gt[3])

            dtype = self.cpp_raster.get_data_type()
            # an empty data type signals that the bands have differing types
            if dtype == "":
                raise RuntimeError("sgs dataset has bands with different types, which is not supported by rasterio.")

            # open() below takes a single nodata value, so the first band's is used
            nodata = self.cpp_raster.get_band_nodata_value(0)

            self.cpp_raster.close()
            self.closed = True

            ds = rasterio.MemoryFile().open(
                driver="MEM",
                width=self.width,
                height=self.height,
                count=self.band_count,
                crs=self.projection,
                transform=transform,
                dtype=dtype,
                nodata=nodata,
            )
            ds.write(arr)

            # rasterio band indices are 1-based
            for i, name in enumerate(self.bands, start=1):
                ds.set_band_description(i, name)
        else:
            ds = rasterio.open(self.filename)

        return (ds, arr) if with_arr else ds

    @classmethod
    def from_gdal(cls, ds, arr = None):
        """
        This function is used to convert from a gdal.Dataset object representing a raster into an sgspy.SpatialRaster
        object. A np.ndarray may be passed as the 'arr' parameter, if so, the following must be true:
        arr.shape == (ds.RasterCount, ds.RasterYSize, ds.RasterXSize)
        (or arr.shape == (ds.RasterYSize, ds.RasterXSize) for a single-band dataset).

        The gdal dataset is closed in every successful path. Without 'arr', a
        file-backed dataset is reopened by name; an in-memory ("MEM") dataset is
        read into an array first, since it has no file.

        Examples:

        ds = gdal.Open("rast.tif")
        rast = sgspy.SpatialRaster.from_gdal(ds)


        ds = gdal.Open("rast.tif")
        bands = []
        for i in range(1, ds.RasterCount + 1):
            bands.append(ds.GetRasterBand(i).ReadAsArray())
        arr = np.stack(bands, axis=0)
        arr[arr < 2] = np.nan
        rast = sgspy.SpatialRaster.from_gdal(ds, arr)
        """
        if not GDAL:
            raise RuntimeError("from_gdal() can only be called if gdal was successfully imported, but it wasn't")

        if type(ds) is not gdal.Dataset:
            raise TypeError("the ds parameter passed to from_gdal() must be of type gdal.Dataset")

        if ds.GetDriver().ShortName == "MEM" and arr is None:
            # gdal band indices are 1-based
            arr = np.stack(
                [ds.GetRasterBand(i).ReadAsArray() for i in range(1, ds.RasterCount + 1)],
                axis=0,
            )

        if arr is None:
            filename = ds.GetName()
            ds.Close()
            return cls(filename)

        if type(arr) is not np.ndarray:
            raise TypeError("'arr' parameter, if passed, must be of type np.ndarray")

        cls._check_array_shape(arr, ds.RasterCount, ds.RasterYSize, ds.RasterXSize)

        nan_vals = []
        band_names = []
        for i in range(1, ds.RasterCount + 1):
            gdal_band = ds.GetRasterBand(i)
            nan_vals.append(gdal_band.GetNoDataValue())
            band_names.append(gdal_band.GetDescription())

        geotransform = ds.GetGeoTransform()
        projection = ds.GetProjection()
        arr = np.ascontiguousarray(arr)
        buffer = memoryview(arr)

        ds.Close()
        return cls(GDALRasterWrapper(buffer, geotransform, projection, nan_vals, band_names, PROJDB_PATH))

    def to_gdal(self, with_arr = False):
        """
        This function is used to convert an sgspy.SpatialRaster into a GDAL dataset. If with_arr is set to True,
        the function will return a numpy.ndarray as a tuple with the GDAL dataset object.

        The wrapped C++ raster is closed in both paths: an in-memory raster is
        copied into a new GDAL "MEM" dataset, and a file-backed raster is reopened
        by filename.

        Examples:

        rast = sgspy.SpatialRaster('rast.tif')
        ds = rast.to_gdal()

        rast = sgspy.SpatialRaster('mraster.tif')
        ds, arr = rast.to_gdal(with_arr=True)

        Raises:
            RuntimeError: if gdal is unavailable, the raster is closed, or the dataset is a temporary file
        """
        if not GDAL:
            raise RuntimeError("to_gdal() can only be called if gdal was successfully imported, but it wasn't")

        self._check_open()

        if self.temp_dataset:
            raise RuntimeError("the dataset has been saved as a temporary file which will be deleted when the C++ object containing it is deleted. the dataset must be either in-memory or have a filename.")

        in_mem = "MEM" in self.driver

        if with_arr or in_mem:
            arr = self._read_all_bands()

        if in_mem:
            geotransform = self.cpp_raster.get_geotransform()
            nan_vals = [self.cpp_raster.get_band_nodata_value(i) for i in range(self.band_count)]

            self.cpp_raster.close()
            self.closed = True

            ds = gdal.GetDriverByName("MEM").Create("", self.width, self.height, 0, gdal.GDT_Unknown)
            ds.SetGeoTransform(geotransform)
            ds.SetProjection(self.projection)
            # NOTE: copying from an in-memory GDAL dataset then writing to a rasterio MemoryFile() may cause extra data copying.
            #
            # I'm doing it this way for now instead of somehow passing the data pointer directly, for fear of memory leaks/dangling pointers/accidentally deleting memory still in use.
            for i, band_arr in enumerate(arr, start=1):
                # bands are added one at a time so each can take its own array's data type
                ds.AddBand(gdal_array.NumericTypeCodeToGDALTypeCode(band_arr.dtype))
                gdal_band = ds.GetRasterBand(i)
                gdal_band.WriteArray(band_arr)
                gdal_band.SetNoDataValue(nan_vals[i - 1])
                gdal_band.SetDescription(self.bands[i - 1])
        else:
            self.cpp_raster.close()
            self.closed = True

            ds = gdal.Open(self.filename)

        return (ds, arr) if with_arr else ds
|