sgspy 1.0.2__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. sgspy/__init__.py +82 -0
  2. sgspy/_sgs.cpython-310-x86_64-linux-gnu.so +0 -0
  3. sgspy/calculate/__init__.py +18 -0
  4. sgspy/calculate/pca/__init__.py +2 -0
  5. sgspy/calculate/pca/pca.py +158 -0
  6. sgspy/calculate/representation/__init__.py +2 -0
  7. sgspy/calculate/representation/representation.py +3 -0
  8. sgspy/sample/__init__.py +30 -0
  9. sgspy/sample/ahels/__init__.py +2 -0
  10. sgspy/sample/ahels/ahels.py +3 -0
  11. sgspy/sample/clhs/__init__.py +2 -0
  12. sgspy/sample/clhs/clhs.py +202 -0
  13. sgspy/sample/nc/__init__.py +2 -0
  14. sgspy/sample/nc/nc.py +3 -0
  15. sgspy/sample/srs/__init__.py +2 -0
  16. sgspy/sample/srs/srs.py +228 -0
  17. sgspy/sample/strat/__init__.py +2 -0
  18. sgspy/sample/strat/strat.py +394 -0
  19. sgspy/sample/systematic/__init__.py +2 -0
  20. sgspy/sample/systematic/systematic.py +233 -0
  21. sgspy/stratify/__init__.py +27 -0
  22. sgspy/stratify/breaks/__init__.py +2 -0
  23. sgspy/stratify/breaks/breaks.py +222 -0
  24. sgspy/stratify/kmeans/__init__.py +2 -0
  25. sgspy/stratify/kmeans/kmeans.py +3 -0
  26. sgspy/stratify/map/__init__.py +2 -0
  27. sgspy/stratify/map/map_stratifications.py +244 -0
  28. sgspy/stratify/poly/__init__.py +2 -0
  29. sgspy/stratify/poly/poly.py +170 -0
  30. sgspy/stratify/quantiles/__init__.py +2 -0
  31. sgspy/stratify/quantiles/quantiles.py +276 -0
  32. sgspy/utils/__init__.py +18 -0
  33. sgspy/utils/plot.py +143 -0
  34. sgspy/utils/raster.py +605 -0
  35. sgspy/utils/vector.py +268 -0
  36. sgspy-1.0.2.data/data/sgspy/libonedal.so.3 +0 -0
  37. sgspy-1.0.2.data/data/sgspy/proj.db +0 -0
  38. sgspy-1.0.2.dist-info/METADATA +13 -0
  39. sgspy-1.0.2.dist-info/RECORD +40 -0
  40. sgspy-1.0.2.dist-info/WHEEL +5 -0
sgspy/utils/raster.py ADDED
@@ -0,0 +1,605 @@
1
+ # ******************************************************************************
2
+ #
3
+ # Project: sgs
4
+ # Purpose: GDALDataset wrapper for raster operations
5
+ # Author: Joseph Meyer
6
+ # Date: June, 2025
7
+ #
8
+ # ******************************************************************************
9
+
10
+ import importlib.util
11
+ import sys
12
+ import os
13
+ import site
14
+ import shutil
15
+ from typing import Optional
16
+
17
+ import numpy as np
18
+ import matplotlib.pyplot as plt
19
+ import matplotlib #for type checking matplotlib.axes.Axes
20
+
21
+ from .import plot
22
+ from .plot import plot_raster
23
+
24
# Ensure the compiled _sgs extension module can be imported. It is searched for in
# <site-packages>/sgspy (only when the site module reports a directory containing
# 'site-packages' -- e.g. Debian/Ubuntu system Pythons use 'dist-packages' instead,
# so none may exist) and in the sgspy package directory one level above this file.
site_packages = next((path for path in site.getsitepackages() if 'site-packages' in path), None)
if site_packages is not None:
    sys.path.append(os.path.join(site_packages, "sgspy"))
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from _sgs import GDALRasterWrapper
29
+
30
# Optional third-party backends. Each flag records whether its import succeeded;
# the SpatialRaster conversion helpers check the flag before using the module.
try:
    import rasterio
except ImportError:
    RASTERIO = False
else:
    RASTERIO = True

try:
    from osgeo import gdal
    from osgeo import gdal_array
except ImportError:
    GDAL = False
else:
    GDAL = True

# Directory passed to the C++ wrapper together with every raster. The wheel's data
# files (including proj.db) are installed under <sys.prefix>/sgspy.
PROJDB_PATH = os.path.join(sys.prefix, "sgspy")
46
+
47
##
# @ingroup user_utils
# This class represents a spatial raster, and is used as an input to many sgs functions.
#
# It has a number of additional uses, including accessing the raster data within as a numpy array,
# plotting with matplotlib, as well as converting to a GDAL or Rasterio dataset object. This class
# also has various attributes representing metadata of the raster which may be useful and can be
# seen in the 'Public Attributes' section.
#
# Accessing raster data
# --------------------
#
# Raster data can be accessed in the form of a NumPy array per band. This can be done using
# the 'band' function. The band function takes a single parameter, which must be either
# an integer or a string. If it is an integer, it must refer to a valid zero-indexed band
# number. If it is a string, it must refer to a valid band name within the raster. This function
# may fail if the band is too large to fit in memory. Loaded bands are cached, and the
# returned arrays are read-only.
#
# rast = sgspy.SpatialRaster('test.tif') #raster with three layers @n
#
# b0 = rast.band(band=0) @n
# b1 = rast.band(band=1) @n
# b2 = rast.band(band=2) @n
#
# zq90 = rast.band(band='zq90') @n
# pzabove2 = rast.band(band='pzabove2') @n
# zstd = rast.band(band='zstd') @n
#
# Accessing raster information
# --------------------
#
# Raster metadata can be displayed using the info() function. Info
# includes: raster driver, band names, dimensions, pixel size, bounds, and crs.
#
# rast = sgspy.SpatialRaster('test.tif') @n
# rast.info()
#
# Plotting raster
# --------------------
#
# The plot() function provides a wrapper around matplotlib's imshow
# functionality (matplotlib.pyplot.imshow). Only a single band can
# be plotted, and for multi-band rasters an indication must be given
# for which band to plot.
#
# Target width and height can be given in the parameters
# target_width and target_height. Default values are 1000 pixels for both.
# Information on the actual downsampling can be found here:
# https://gdal.org/en/stable/api/gdaldataset_cpp.html#classGDALDataset_1ae66e21b09000133a0f4d99baabf7a0ec
#
# If no 'band' argument is given, the function will throw an error if the
# image does not contain a single band.
#
# The 'band' argument allows the end-user to specify either the band
# index or the band name. 'band' may be an int or str.
#
# Optionally, any of the arguments which may be passed to the matplotlib
# imshow function may also be passed to plot(), such as cmap
# for a specific color mapping.
#
# #plots the single band @n
# rast = sgspy.SpatialRaster('test_single_band_raster.tif') @n
# rast.plot()
#
# #plots the second band @n
# rast = sgspy.SpatialRaster('test_multi_band_raster.tif') @n
# rast.plot(band=1)
#
# #plots the 'zq90' band @n
# rast = sgspy.SpatialRaster('test_multi_band_raster.tif') @n
# rast.plot(band='zq90')
#
# Public Attributes
# --------------------
# driver : str @n
# gdal dataset driver, for info/display purposes @n @n
# width : int @n
# the pixel width of the raster image @n @n
# height : int @n
# the pixel height of the raster image @n @n
# band_count : int @n
# the number of bands in the raster image @n @n
# bands : list[str] @n
# the raster band names @n @n
# crs : str @n
# coordinate reference system @n @n
# projection : str @n
# full projection string as wkt @n @n
# xmin : double @n
# minimum x value as defined by the gdal geotransform @n @n
# xmax : double @n
# maximum x value as defined by the gdal geotransform @n @n
# ymin : double @n
# minimum y value as defined by the gdal geotransform @n @n
# ymax : double @n
# maximum y value as defined by the gdal geotransform @n @n
# pixel_height : double @n
# pixel height as defined by the gdal geotransform @n @n
# pixel_width : double @n
# pixel width as defined by the gdal geotransform @n @n
# filename : str @n
# the path the raster was opened from ("" if built from a GDALRasterWrapper) @n @n
#
# Public Methods
# --------------------
# info() @n
# takes no arguments, prints raster information to the console @n @n
# plot() @n
# takes optional 'ax', 'target_width', 'target_height', and 'band' (int or str) arguments @n @n
# band() @n
# returns the band data as a numpy array, may throw an error if the raster band is too large @n @n
# from_rasterio() / to_rasterio() @n
# convert from / to a rasterio dataset @n @n
# from_gdal() / to_gdal() @n
# convert from / to a gdal.Dataset
#
# Optionally, any of the arguments that can be passed to matplotlib.pyplot.imshow
# can also be passed to plot().
class SpatialRaster:

    # When True, __del__ removes self.temp_dir. Neither is set in this file; presumably
    # other sgspy code assigns both together -- verify against callers.
    have_temp_dir = False
    # When True, to_rasterio() and to_gdal() refuse to convert this raster.
    temp_dataset = False
    # Path the raster was opened from; stays "" when built from a GDALRasterWrapper.
    filename = ""
    # Set to True once the wrapped C++ raster has been closed by to_rasterio()/to_gdal().
    closed = False

    def __init__(self,
                 image: str | os.PathLike | GDALRasterWrapper):
        """
        Constructing method for the SpatialRaster class.

        Has one required parameter specifying either a raster file path or an
        already-constructed GDALRasterWrapper. The following attributes are populated:
            self.cpp_raster
            self.driver
            self.width
            self.height
            self.band_count
            self.crs
            self.projection
            self.xmin
            self.xmax
            self.ymin
            self.ymax
            self.pixel_height
            self.pixel_width
            self.bands

        Parameters
        --------------------
        image : str | os.PathLike | GDALRasterWrapper
            a raster file path, or an existing GDALRasterWrapper to wrap

        Raises
        --------------------
        TypeError
            if 'image' is not a str, os.PathLike, or GDALRasterWrapper
        """
        if isinstance(image, GDALRasterWrapper):
            self.cpp_raster = image
        elif isinstance(image, (str, os.PathLike)):
            # os.fsdecode accepts str and path-like objects and always yields a str
            path = os.fsdecode(image)
            self.cpp_raster = GDALRasterWrapper(path, PROJDB_PATH)
            self.filename = path
        else:
            raise TypeError("'image' parameter of SpatialRaster constructor must be of type str, os.PathLike, or GDALRasterWrapper")

        raster = self.cpp_raster
        self.driver = raster.get_driver()
        self.width = raster.get_width()
        self.height = raster.get_height()
        self.band_count = raster.get_band_count()
        self.crs = raster.get_crs()
        # non-ASCII characters are dropped, then backslash escape sequences are interpreted
        self.projection = raster.get_projection().encode('ascii', 'ignore').decode('unicode_escape')
        self.xmin = raster.get_xmin()
        self.xmax = raster.get_xmax()
        self.ymin = raster.get_ymin()
        self.ymax = raster.get_ymax()
        self.pixel_width = raster.get_pixel_width()
        self.pixel_height = raster.get_pixel_height()
        self.bands = raster.get_bands()
        # band name -> zero-based band index
        self.band_name_dict = {name: index for index, name in enumerate(self.bands)}
        # band index -> numpy array, filled lazily by load_arr()
        self.band_data_dict = {}

    def __del__(self):
        """Remove the temporary directory, if one was registered via have_temp_dir/temp_dir."""
        if self.have_temp_dir:
            # best-effort cleanup: a finalizer must not fail on an already-removed directory
            shutil.rmtree(self.temp_dir, ignore_errors=True)

    def _ensure_open(self):
        """Raise RuntimeError if the wrapped C++ raster has already been closed."""
        if self.closed:
            raise RuntimeError("the C++ object which this class wraps has been cleaned up and closed.")

    def info(self):
        """
        Displays driver, band, size, pixel size, bound, and crs information of the raster.

        Raises
        --------------------
        RuntimeError
            if the wrapped C++ raster has been closed
        """
        self._ensure_open()

        print("driver: {}".format(self.driver))
        print("bands: {}".format(", ".join(str(name) for name in self.bands)))
        print("size: {} x {} x {}".format(self.band_count, self.width, self.height))
        print("pixel size: (x, y): ({}, {})".format(self.pixel_width, self.pixel_height))
        print("bounds (xmin, xmax, ymin, ymax): ({}, {}, {}, {})".format(self.xmin, self.xmax, self.ymin, self.ymax))
        print("crs: {}".format(self.crs))

    def get_band_index(self, band: str | int) -> int:
        """
        Utilizes the band_name_dict to convert a band name to an index if required.

        Parameters:
            band : str or int
                string representing a band name, or int representing a band index
                (ints are returned unchanged)

        Raises:
            TypeError if band is not exactly a str or int,
            RuntimeError if the wrapped C++ raster has been closed,
            KeyError if a band name is not present in the raster
        """
        if type(band) not in (str, int):
            raise TypeError("'band' parameter must be of type str or int.")

        self._ensure_open()

        if type(band) is str:
            try:
                return self.band_name_dict[band]
            except KeyError:
                raise KeyError("band name '{}' not found in raster; available bands: {}".format(band, self.bands)) from None

        return band

    def load_arr(self, band_index: int):
        """
        Loads one band of the raster's gdal dataset into a numpy array and caches it
        in band_data_dict.

        Parameters:
            band_index : int
                integer representing band index

        Raises:
            TypeError if band_index is not an int,
            RuntimeError if the wrapped C++ raster has been closed
        """
        if type(band_index) is not int:
            raise TypeError("'band_index' parameter must be of type int.")

        self._ensure_open()

        view = self.cpp_raster.get_raster_as_memoryview(self.width, self.height, band_index)
        # wrap the C++-provided buffer as a read-only array; copy=False makes NumPy
        # refuse to silently copy it
        self.band_data_dict[band_index] = np.asarray(view.toreadonly(), copy=False)

    def band(self, band: str | int):
        """
        Gets a numpy array with the specified band's data, loading it on first access.

        Parameters:
            band : int | str
                string (band name) or int (zero-based band index) representing the band

        Returns:
            the cached, read-only numpy array for the band
        """
        if type(band) not in (int, str):
            raise TypeError("'band' parameter must be of type int or str.")

        self._ensure_open()

        index = self.get_band_index(band)
        if index not in self.band_data_dict:
            self.load_arr(index)

        return self.band_data_dict[index]

    def plot(self,
             ax: Optional[matplotlib.axes.Axes] = None,
             target_width: int = 1000,
             target_height: int = 1000,
             band: Optional[int | str] = None,
             **kwargs):
        """
        Calls plot_raster() on self.

        If 'ax' is None, a new figure is created with matplotlib.pyplot.subplots()
        and displayed with matplotlib.pyplot.show() after plotting.

        Parameters
        --------------------
        ax : matplotlib.axes.Axes
            axes to plot the raster on
        target_width : int
            maximum width in pixels for the image (after downsampling)
        target_height : int
            maximum height in pixels for the image (after downsampling)
        band : int or str
            specification of which band to plot
        **kwargs
            any parameters which may be passed to matplotlib.pyplot.imshow
        """
        self._ensure_open()

        if ax is None:
            _, ax = plt.subplots()
            plot_raster(self, ax, target_width, target_height, band, **kwargs)
            plt.show()
        else:
            plot_raster(self, ax, target_width, target_height, band, **kwargs)

    @staticmethod
    def _check_arr_shape(shape, band_count, height, width):
        """
        Validate that an array shape, (height, width) or (band_count, height, width),
        matches a dataset's band count and pixel dimensions.

        Raises:
            RuntimeError on a band-count, height, or width mismatch,
            ValueError if the shape has neither 2 nor 3 dimensions
        """
        if len(shape) == 2:
            arr_height, arr_width = shape
            if band_count != 1:
                raise RuntimeError("if the array parameter contains only a single band with shape (height, width), the raster must contain only a single band.")
        elif len(shape) == 3:
            arr_band_count, arr_height, arr_width = shape
            if arr_band_count != band_count:
                raise RuntimeError("the array parameter must contains the same number of bands as the raster with shape (band_count, height, width).")
        else:
            raise ValueError("the array parameter must have shape (height, width) or (band_count, height, width).")

        if arr_height != height:
            raise RuntimeError("the height of the array passed must be equal to the height of the raster dataset.")

        if arr_width != width:
            raise RuntimeError("the width of the array passed must be equal to the width of the raster dataset.")

    def _read_all_bands(self):
        """Read every band from the C++ raster and stack them along a new leading band axis."""
        bands = [
            np.asarray(self.cpp_raster.get_raster_as_memoryview(self.width, self.height, index))
            for index in range(self.band_count)
        ]

        # ensure numpy array doesn't accidentally get cleaned up by C++ object deletion
        self.cpp_raster.release_band_buffers()

        return np.stack(bands, axis=0)

    @classmethod
    def from_rasterio(cls, ds, arr = None):
        """
        This function is used to convert from a rasterio dataset object representing a raster into an sgspy.SpatialRaster
        object. A np.ndarray may be passed as the 'arr' parameter, if so, the following must be true:
        arr.shape == (ds.count, ds.height, ds.width)

        If no array is given (and the dataset is not in-memory), the rasterio dataset is closed
        and the file it refers to is reopened by sgs.

        Examples:

        ds = rasterio.open("rast.tif")
        rast = sgspy.SpatialRaster.from_rasterio(ds)


        ds = rasterio.open("rast.tif")
        arr = ds.read()
        arr[arr < 2] = np.nan
        rast = sgspy.SpatialRaster.from_rasterio(ds, arr)
        """
        if not RASTERIO:
            raise RuntimeError("from_rasterio() can only be called if rasterio was successfully imported, but it wasn't.")

        if not isinstance(ds, (rasterio.io.DatasetReader, rasterio.io.DatasetWriter)):
            raise TypeError("the ds parameter passed to from_rasterio() must be of type rasterio.io.DatasetReader or rasterio.io.DatasetWriter.")

        # an in-memory dataset has no file to reopen, so its pixels must be read now
        if ds.driver == "MEM" and arr is None:
            arr = ds.read()

        if arr is None:
            # close the rasterio dataset, and open a GDALRasterWrapper of the file
            filename = ds.name
            ds.close()
            return cls(filename)

        if type(arr) is not np.ndarray:
            raise TypeError("the 'arr' parameter, if passed, must be of type np.ndarray")

        cls._check_arr_shape(arr.shape, ds.count, ds.height, ds.width)

        # a None nodata value is replaced with NaN before being handed to the C++ wrapper
        nan = ds.profile["nodata"]
        if nan is None:
            nan = np.nan

        # create an in-memory dataset using the numpy array as the data, and the rasterio dataset to provide metadata
        geotransform = ds.get_transform()
        projection = ds.crs.wkt
        buffer = memoryview(np.ascontiguousarray(arr))
        return cls(GDALRasterWrapper(buffer, geotransform, projection, [nan] * ds.count, ds.descriptions, PROJDB_PATH))

    def to_rasterio(self, with_arr = False):
        """
        This function is used to convert an sgspy.SpatialRaster into a rasterio dataset. If with_arr is set to True,
        the function will return a numpy.ndarray as a tuple with the rasterio dataset object.

        For an in-memory raster the data is copied into a rasterio "MEM" dataset and the
        wrapped C++ raster is closed (this SpatialRaster can no longer be used). Otherwise the
        raster's file is opened with rasterio.

        Examples:

        rast = sgspy.SpatialRaster('rast.tif')
        ds = rast.to_rasterio()

        rast = sgspy.SpatialRaster('mraster.tif')
        ds, arr = rast.to_rasterio(with_arr=True)
        """
        if type(with_arr) is not bool:
            raise TypeError("'with_arr' parameter must be of type bool.")

        if not RASTERIO:
            raise RuntimeError("to_rasterio() can only be called if rasterio was successfully imported, but it wasn't.")

        self._ensure_open()

        if self.temp_dataset:
            raise RuntimeError("the dataset has been saved as a temporary file which will be deleted when the C++ object containing it is deleted. the dataset must be either in-memory or have a filename.")

        in_mem = "MEM" in self.driver
        arr = self._read_all_bands() if (with_arr or in_mem) else None

        if in_mem:
            gt = self.cpp_raster.get_geotransform()
            # rasterio's Affine(a, b, c, d, e, f) orders the coefficients differently from
            # GDAL's geotransform (c, a, b, f, d, e), hence the reshuffle
            transform = rasterio.transform.Affine(gt[1], gt[2], gt[0], gt[4], gt[5], gt[3])

            dtype = self.cpp_raster.get_data_type()
            if dtype == "":
                raise RuntimeError("sgs dataset has bands with different types, which is not supported by rasterio.")

            # the whole rasterio dataset uses band 0's nodata value
            nodata = self.cpp_raster.get_band_nodata_value(0)

            self.cpp_raster.close()
            self.closed = True

            ds = rasterio.MemoryFile().open(
                driver="MEM",
                width=self.width,
                height=self.height,
                count=self.band_count,
                crs=self.projection,
                transform=transform,
                dtype=dtype,
                nodata=nodata,
            )
            ds.write(arr)

            for index, name in enumerate(self.bands, start=1):
                ds.set_band_description(index, name)
        else:
            ds = rasterio.open(self.filename)

        return (ds, arr) if with_arr else ds

    @classmethod
    def from_gdal(cls, ds, arr = None):
        """
        This function is used to convert from a gdal.Dataset object representing a raster into an sgspy.SpatialRaster
        object. A np.ndarray may be passed as the 'arr' parameter, if so, the following must be true:
        arr.shape == (ds.RasterCount, ds.RasterYSize, ds.RasterXSize)

        The gdal dataset is closed in all cases. Bands without a nodata value are given NaN,
        matching from_rasterio().

        Examples:

        ds = gdal.Open("rast.tif")
        rast = sgspy.SpatialRaster.from_gdal(ds)


        ds = gdal.Open("rast.tif")
        bands = []
        for i in range(1, ds.RasterCount + 1):
            bands.append(ds.GetRasterBand(i).ReadAsArray())
        arr = np.stack(bands, axis=0)
        arr[arr < 2] = np.nan
        rast = sgspy.SpatialRaster.from_gdal(ds, arr)
        """
        if not GDAL:
            raise RuntimeError("from_gdal() can only be called if gdal was successfully imported, but it wasn't")

        if not isinstance(ds, gdal.Dataset):
            raise TypeError("the ds parameter passed to from_gdal() must be of type gdal.Dataset")

        # an in-memory dataset has no file to reopen, so its pixels must be read now
        if ds.GetDriver().ShortName == "MEM" and arr is None:
            arr = np.stack(
                [ds.GetRasterBand(i).ReadAsArray() for i in range(1, ds.RasterCount + 1)],
                axis=0,
            )

        if arr is None:
            filename = ds.GetName()
            ds.Close()
            return cls(filename)

        if type(arr) is not np.ndarray:
            raise TypeError("'arr' parameter, if passed, must be of type np.ndarray")

        cls._check_arr_shape(arr.shape, ds.RasterCount, ds.RasterYSize, ds.RasterXSize)

        nan_vals = []
        band_names = []
        for i in range(1, ds.RasterCount + 1):
            band = ds.GetRasterBand(i)
            nodata = band.GetNoDataValue()
            # GetNoDataValue() returns None when no nodata value is set; use NaN as from_rasterio() does
            nan_vals.append(np.nan if nodata is None else nodata)
            band_names.append(band.GetDescription())

        geotransform = ds.GetGeoTransform()
        projection = ds.GetProjection()
        buffer = memoryview(np.ascontiguousarray(arr))

        ds.Close()
        return cls(GDALRasterWrapper(buffer, geotransform, projection, nan_vals, band_names, PROJDB_PATH))

    def to_gdal(self, with_arr = False):
        """
        This function is used to convert an sgspy.SpatialRaster into a GDAL dataset. If with_arr is set to True,
        the function will return a numpy.ndarray as a tuple with the GDAL dataset object.

        The wrapped C++ raster is closed in all cases (this SpatialRaster can no longer be
        used). In-memory rasters are copied into a GDAL "MEM" dataset; otherwise the raster's
        file is opened with gdal.Open().

        Examples:

        rast = sgspy.SpatialRaster('rast.tif')
        ds = rast.to_gdal()

        rast = sgspy.SpatialRaster('mraster.tif')
        ds, arr = rast.to_gdal(with_arr=True)
        """
        if not GDAL:
            raise RuntimeError("to_gdal() can only be called if gdal was successfully imported, but it wasn't")

        self._ensure_open()

        if self.temp_dataset:
            raise RuntimeError("the dataset has been saved as a temporary file which will be deleted when the C++ object containing it is deleted. the dataset must be either in-memory or have a filename.")

        in_mem = "MEM" in self.driver
        arr = self._read_all_bands() if (with_arr or in_mem) else None

        if in_mem:
            geotransform = self.cpp_raster.get_geotransform()
            projection = self.projection
            nan_vals = [self.cpp_raster.get_band_nodata_value(i) for i in range(self.band_count)]
            band_names = self.bands

            self.cpp_raster.close()
            self.closed = True

            ds = gdal.GetDriverByName("MEM").Create("", self.width, self.height, 0, gdal.GDT_Unknown)
            ds.SetGeoTransform(geotransform)
            ds.SetProjection(projection)
            # NOTE: writing the array into a new in-memory GDAL dataset copies the data.
            #
            # I'm doing it this way for now instead of somehow passing the data pointer directly, for fear of memory leaks/dangling pointers/accidentally deleting memory still in use.
            for index, band_arr in enumerate(arr, start=1):
                # bands are created one at a time so each gets the GDAL type matching its numpy dtype
                ds.AddBand(gdal_array.NumericTypeCodeToGDALTypeCode(band_arr.dtype))
                band = ds.GetRasterBand(index)
                band.WriteArray(band_arr)
                band.SetNoDataValue(nan_vals[index - 1])
                band.SetDescription(band_names[index - 1])
        else:
            self.cpp_raster.close()
            self.closed = True

            ds = gdal.Open(self.filename)

        return (ds, arr) if with_arr else ds