sgspy 1.0.1__cp312-cp312-manylinux_2_39_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. sgspy/__init__.py +82 -0
  2. sgspy/_sgs.cpython-312-x86_64-linux-gnu.so +0 -0
  3. sgspy/calculate/__init__.py +18 -0
  4. sgspy/calculate/pca/__init__.py +2 -0
  5. sgspy/calculate/pca/pca.py +152 -0
  6. sgspy/calculate/representation/__init__.py +2 -0
  7. sgspy/calculate/representation/representation.py +3 -0
  8. sgspy/sample/__init__.py +30 -0
  9. sgspy/sample/ahels/__init__.py +2 -0
  10. sgspy/sample/ahels/ahels.py +3 -0
  11. sgspy/sample/clhs/__init__.py +2 -0
  12. sgspy/sample/clhs/clhs.py +198 -0
  13. sgspy/sample/nc/__init__.py +2 -0
  14. sgspy/sample/nc/nc.py +3 -0
  15. sgspy/sample/srs/__init__.py +2 -0
  16. sgspy/sample/srs/srs.py +224 -0
  17. sgspy/sample/strat/__init__.py +2 -0
  18. sgspy/sample/strat/strat.py +390 -0
  19. sgspy/sample/systematic/__init__.py +2 -0
  20. sgspy/sample/systematic/systematic.py +229 -0
  21. sgspy/stratify/__init__.py +27 -0
  22. sgspy/stratify/breaks/__init__.py +2 -0
  23. sgspy/stratify/breaks/breaks.py +218 -0
  24. sgspy/stratify/kmeans/__init__.py +2 -0
  25. sgspy/stratify/kmeans/kmeans.py +3 -0
  26. sgspy/stratify/map/__init__.py +2 -0
  27. sgspy/stratify/map/map_stratifications.py +240 -0
  28. sgspy/stratify/poly/__init__.py +2 -0
  29. sgspy/stratify/poly/poly.py +166 -0
  30. sgspy/stratify/quantiles/__init__.py +2 -0
  31. sgspy/stratify/quantiles/quantiles.py +272 -0
  32. sgspy/utils/__init__.py +18 -0
  33. sgspy/utils/plot.py +143 -0
  34. sgspy/utils/raster.py +602 -0
  35. sgspy/utils/vector.py +262 -0
  36. sgspy-1.0.1.data/data/sgspy/libonedal.so.3 +0 -0
  37. sgspy-1.0.1.data/data/sgspy/proj.db +0 -0
  38. sgspy-1.0.1.dist-info/METADATA +13 -0
  39. sgspy-1.0.1.dist-info/RECORD +40 -0
  40. sgspy-1.0.1.dist-info/WHEEL +5 -0
sgspy/utils/raster.py ADDED
@@ -0,0 +1,602 @@
1
+ # ******************************************************************************
2
+ #
3
+ # Project: sgs
4
+ # Purpose: GDALDataset wrapper for raster operations
5
+ # Author: Joseph Meyer
6
+ # Date: June, 2025
7
+ #
8
+ # ******************************************************************************
9
+
10
+ import importlib.util
11
+ import sys
12
+ import os
13
+ import shutil
14
+ from typing import Optional
15
+
16
+ import numpy as np
17
+ import matplotlib.pyplot as plt
18
+ import matplotlib #for type checking matplotlib.axes.Axes
19
+
20
+ from .import plot
21
+ from .plot import plot_raster
22
+
23
# Make the compiled _sgs extension, which ships in the sgspy package directory
# (one level above this file), importable as a top-level module.
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from _sgs import GDALRasterWrapper

# rasterio optional import
try:
    import rasterio
    RASTERIO = True
except ImportError:
    RASTERIO = False

# gdal optional import
try:
    from osgeo import gdal
    from osgeo import gdal_array
    GDAL = True
except ImportError:
    GDAL = False

# Directory holding the proj.db shipped as wheel data (sgspy-<version>.data/data/sgspy),
# which pip installs relative to sys.prefix.
# NOTE(review): a `pip install --user` places wheel data under the user base
# rather than sys.prefix -- confirm proj.db is still found in that case.
PROJDB_PATH = os.path.join(sys.prefix, "sgspy")
43
+
44
+ ##
45
+ # @ingroup user_utils
46
+ # This class represents a spatial raster, and is used as an input to many sgs functions.
47
+ #
48
+ # It has a number of additional uses, including accessing the raster data within as a numpy array,
49
+ # plotting with matplotlib, as well as converting to a GDAL or Rasterio dataset object. This class
50
+ # also has various attributes representing metadata of the raster which may be useful and can be
51
+ # seen in the 'Public Attributes' section.
52
+ #
53
+ # Accessing raster data
54
+ # --------------------
55
+ #
56
+ # raster data can be accessed in the form of a NumPy array per band. This can be done using
57
+ # the 'band' function. The band function takes a single parameter, which must be either
58
+ # an integer or a string. If it is an integer, it must refer to a valid zero-indexed band
59
+ # number. If it is a string, it must refer to a valid band name within the raster. This function
60
+ # may fail if the band is too large to fit in memory.
61
+ #
62
+ # rast = sgspy.SpatialRaster('test.tif') #raster with three layers
63
+ #
64
+ # b0 = rast.band(band=0) @n
65
+ # b1 = rast.band(band=1) @n
66
+ # b2 = rast.band(band=2) @n
67
+ #
68
+ # zq90 = rast.band(band='zq90') @n
69
+ # pzabove2 = rast.band(band='pzabove2') @n
70
+ # zstd = rast.band(band='zstd') @n
71
+ #
72
+ # Accessing raster information
73
+ # --------------------
74
+ #
75
+ # raster metadata can be displayed using the info() function. Info
76
+ # includes: raster driver, band names, dimensions, pixel size, bounds, and crs.
77
+ #
78
+ # rast = sgspy.SpatialRaster('test.tif') @n
79
+ # rast.info()
80
+ #
81
+ # Plotting raster
82
+ # --------------------
83
+ #
84
+ # the plot() function provides a wrapper around matplotlib's imshow
85
+ # functionality (matplotlib.pyplot.imshow). Only a single band can
86
+ # be plotted, and for multi-band rasters an indication must be given
87
+ # for which band to plot.
88
+ #
89
+ # Target width and height can be given in the parameters
90
+ # target_width and target_height. Default parameters are 1000 pixels for both.
91
+ # Information on the actual downsampling can be found here:
92
+ # https://gdal.org/en/stable/api/gdaldataset_cpp.html#classGDALDataset_1ae66e21b09000133a0f4d99baabf7a0ec
93
+ #
94
+ # If no 'band' argument is given, the function will throw an error if the
95
+ # image does not contain a single band.
96
+ #
97
+ # The 'band' argument allows the end-user to specify either the band
98
+ # index or the band name. 'band' may be an int or str.
99
+ #
100
+ # Optionally, any of the arguments which may be passed to the matplotlib
101
+ # imshow function may also be passed to plot(), such as cmap
102
+ # for a specific color mapping.
103
+ #
104
+ # #plots the single band @n
105
+ # rast = sgspy.SpatialRaster('test_single_band_raster.tif') @n
106
+ # rast.plot()
107
+ #
108
+ # #plots the second band @n
109
+ # rast = sgspy.SpatialRaster('test_multi_band_raster.tif') @n
110
+ # rast.plot(band=1)
111
+ #
112
+ # #plots the 'zq90' band @n
113
+ # rast = sgspy.SpatialRaster('test_multi_band_raster.tif') @n
114
+ # rast.plot(band='zq90')
115
+ #
116
+ # Public Attributes
117
+ # --------------------
118
+ # driver : str @n
119
+ # gdal dataset driver, for info/display purposes @n @n
120
+ # width : int @n
121
+ # the pixel width of the raster image @n @n
122
+ # height : int @n
123
+ # the pixel height of the raster image @n @n
124
+ # band_count : int @n
125
+ # the number of bands in the raster image @n @n
126
+ # bands : list[str] @n
127
+ # the raster band names @n @n
128
+ # crs : str @n
129
+ # coordinate reference system @n @n
130
+ # projection : str @n
131
+ # full projection string as wkt @n @n
132
+ # xmin : double @n
133
+ # minimum x value as defined by the gdal geotransform @n @n
134
+ # xmax : double @n
135
+ # maximum x value as defined by the gdal geotransform @n @n
136
+ # ymin : double @n
137
+ # minimum y value as defined by the gdal geotransform @n @n
138
+ # ymax : double @n
139
+ # maximum y value as defined by the gdal geotransform @n @n
140
+ # pixel_height : double @n
141
+ # pixel height as defined by the gdal geotransform @n @n
142
+ # pixel_width : double @n
143
+ # pixel width as defined by the gdal geotransform @n @n
144
+ #
145
+ # Public Methods
146
+ # --------------------
147
+ # info() @n
148
+ # takes no arguments, prints raster information to the console @n @n
149
+ # plot() @n
150
+ # takes optional 'ax', 'target_width', 'target_height', and 'band' (int or str) arguments @n @n
151
+ # band() @n
152
+ # returns the band data as a numpy array, may throw an error if the raster band is too large
153
+ #
154
+ # Optionally, any of the arguments that can be passed to matplotlib.pyplot.imshow
155
+ # can also be passed to plot().
156
class SpatialRaster:
    """
    Wrapper around a C++ GDALRasterWrapper representing a spatial raster.

    Exposes raster metadata as attributes, band data as NumPy arrays,
    plotting via matplotlib, and conversion to/from rasterio and GDAL
    dataset objects.
    """

    # Class-level defaults; instances shadow them when they change.
    # have_temp_dir: when True, __del__ removes self.temp_dir. Neither attribute
    # is set inside this class -- presumably other sgspy code sets them (verify).
    have_temp_dir = False
    # temp_dataset: when True, the raster lives in a temporary file that is deleted
    # with the C++ object, so to_rasterio()/to_gdal() refuse to export it.
    temp_dataset = False
    # filename: the path given to the constructor, or "" when wrapping an
    # existing GDALRasterWrapper.
    filename = ""
    # closed: set by to_rasterio()/to_gdal() once the wrapped C++ raster is closed.
    closed = False

    def __init__(self,
                 image: str | GDALRasterWrapper):
        """
        Constructing method for the SpatialRaster class.

        Has one required parameter specifying either a raster path or an existing
        GDALRasterWrapper. The following attributes are populated:
        self.cpp_raster, self.driver, self.width, self.height, self.band_count,
        self.crs, self.projection, self.xmin, self.xmax, self.ymin, self.ymax,
        self.pixel_height, self.pixel_width, self.bands, self.band_name_dict,
        self.band_data_dict

        Parameters
        --------------------
        image : str | os.PathLike | GDALRasterWrapper
            a raster file path (opened with GDALRasterWrapper), or an
            already-constructed GDALRasterWrapper to wrap directly

        Raises
        --------------------
        TypeError
            if 'image' is not a str, path-like object, or GDALRasterWrapper
        """
        # accept pathlib.Path and other path-like objects as well as plain strings
        if isinstance(image, os.PathLike):
            image = os.fspath(image)

        if isinstance(image, str):
            self.cpp_raster = GDALRasterWrapper(image, PROJDB_PATH)
            self.filename = image
        elif isinstance(image, GDALRasterWrapper):
            self.cpp_raster = image
        else:
            raise TypeError("'image' parameter of SpatialRaster constructor must be of type str or GDALRasterWrapper")

        self.driver = self.cpp_raster.get_driver()
        self.width = self.cpp_raster.get_width()
        self.height = self.cpp_raster.get_height()
        self.band_count = self.cpp_raster.get_band_count()
        self.crs = self.cpp_raster.get_crs()
        # Drops non-ASCII characters, then interprets backslash escape sequences in
        # the WKT. NOTE(review): the reason for the unicode_escape decode is not
        # documented here -- confirm it is intended.
        self.projection = self.cpp_raster.get_projection().encode('ascii', 'ignore').decode('unicode_escape')
        self.xmin = self.cpp_raster.get_xmin()
        self.xmax = self.cpp_raster.get_xmax()
        self.ymin = self.cpp_raster.get_ymin()
        self.ymax = self.cpp_raster.get_ymax()
        self.pixel_width = self.cpp_raster.get_pixel_width()
        self.pixel_height = self.cpp_raster.get_pixel_height()
        self.bands = self.cpp_raster.get_bands()
        # band name -> zero-based band index, used by get_band_index()
        self.band_name_dict = {name: index for index, name in enumerate(self.bands)}
        # zero-based band index -> NumPy array, filled lazily by band()/load_arr()
        self.band_data_dict = {}

    def __del__(self):
        """Remove this raster's temporary directory, if it has one."""
        temp_dir = getattr(self, "temp_dir", None)
        if self.have_temp_dir and temp_dir is not None:
            # exceptions cannot propagate out of __del__ (they are only printed),
            # and the directory may already be gone, so errors are ignored
            shutil.rmtree(temp_dir, ignore_errors=True)

    def _check_open(self):
        """Raise RuntimeError if the wrapped C++ raster has already been closed."""
        if self.closed:
            raise RuntimeError("the C++ object which this class wraps has been cleaned up and closed.")

    def _check_not_temp_dataset(self):
        """Raise RuntimeError if the raster is backed by a temporary file that cannot be exported."""
        if self.temp_dataset:
            raise RuntimeError("the dataset has been saved as a temporary file which will be deleted when the C++ object containing it is deleted. the dataset must be either in-memory or have a filename.")

    @staticmethod
    def _is_band_key(band) -> bool:
        """Return True for str band names and (non-bool) integer band indices, including NumPy integers."""
        return isinstance(band, (str, int, np.integer)) and not isinstance(band, bool)

    @staticmethod
    def _check_arr_shape(arr: np.ndarray, band_count: int, height: int, width: int):
        """
        Validate that 'arr' matches a raster of the given dimensions.

        'arr' must have shape (height, width) for a single-band raster, or
        (band_count, height, width).

        Raises
        --------------------
        ValueError
            if 'arr' is neither 2- nor 3-dimensional
        RuntimeError
            if the band count, height, or width does not match
        """
        if arr.ndim == 2:
            (arr_height, arr_width) = arr.shape
            if band_count != 1:
                raise RuntimeError("if the array parameter contains only a single band with shape (height, width), the raster must contain only a single band.")
        elif arr.ndim == 3:
            (arr_band_count, arr_height, arr_width) = arr.shape
            if arr_band_count != band_count:
                raise RuntimeError("the array parameter must contains the same number of bands as the raster with shape (band_count, height, width).")
        else:
            raise ValueError("the array parameter must have shape (height, width) or (band_count, height, width).")

        if arr_height != height:
            raise RuntimeError("the height of the array passed must be equal to the height of the raster dataset.")

        if arr_width != width:
            raise RuntimeError("the width of the array passed must be equal to the width of the raster dataset.")

    def _read_all_bands(self) -> np.ndarray:
        """Read every band from the C++ raster and stack them into an array of shape (band_count, height, width)."""
        bands = [
            np.asarray(self.cpp_raster.get_raster_as_memoryview(self.width, self.height, i))
            for i in range(self.band_count)
        ]

        # ensure numpy array doesn't accidentally get cleaned up by C++ object deletion
        self.cpp_raster.release_band_buffers()

        return np.stack(bands, axis=0)

    def info(self):
        """
        Print the driver, band names, size, pixel size, bounds, and crs of the raster.

        Raises
        --------------------
        RuntimeError
            if the wrapped C++ raster has been closed
        """
        self._check_open()

        print(f"driver: {self.driver}")
        # list every band name (formatting with *self.bands printed only the first)
        print(f"bands: {', '.join(str(name) for name in self.bands)}")
        print(f"size: {self.band_count} x {self.width} x {self.height}")
        # x is the pixel width, y is the pixel height
        print(f"pixel size: (x, y): ({self.pixel_width}, {self.pixel_height})")
        print(f"bounds (xmin, xmax, ymin, ymax): ({self.xmin}, {self.xmax}, {self.ymin}, {self.ymax})")
        print(f"crs: {self.crs}")

    def get_band_index(self, band: str | int) -> int:
        """
        Convert a band name or index to a zero-based band index.

        Parameters
        --------------------
        band : str or int
            band name, or zero-based band index (NumPy integers are accepted)

        Returns
        --------------------
        int
            zero-based band index

        Raises
        --------------------
        TypeError
            if 'band' is not a str or int
        RuntimeError
            if the wrapped C++ raster has been closed
        KeyError
            if 'band' is a str that is not one of the raster's band names
        IndexError
            if 'band' is an int outside [0, band_count)
        """
        if not self._is_band_key(band):
            raise TypeError("'band' parameter must be of type str or int.")

        self._check_open()

        if isinstance(band, str):
            try:
                return self.band_name_dict[band]
            except KeyError:
                raise KeyError(f"'{band}' is not a band name of this raster; valid band names are: {self.bands}") from None

        index = int(band)
        if not 0 <= index < self.band_count:
            raise IndexError(f"band index {index} is out of range for a raster with {self.band_count} band(s).")
        return index

    def load_arr(self, band_index: int):
        """
        Load one band of the raster as a read-only NumPy array and cache it in self.band_data_dict.

        Parameters
        --------------------
        band_index : int
            zero-based band index

        Raises
        --------------------
        TypeError
            if 'band_index' is not an int
        RuntimeError
            if the wrapped C++ raster has been closed
        """
        if isinstance(band_index, bool) or not isinstance(band_index, (int, np.integer)):
            raise TypeError("'band_index' parameter must be of type int.")

        self._check_open()

        band_index = int(band_index)
        # toreadonly() keeps callers from writing through the array into the
        # underlying buffer (presumably owned by the C++ raster). np.asarray wraps a
        # buffer-protocol object without copying; the 'copy' keyword is omitted
        # because it only exists in NumPy >= 2.0.
        self.band_data_dict[band_index] = np.asarray(
            self.cpp_raster.get_raster_as_memoryview(self.width, self.height, band_index).toreadonly()
        )

    def band(self, band: str | int) -> np.ndarray:
        """
        Get a NumPy array with the specified band's data, loading it on first access.

        Parameters
        --------------------
        band : int | str
            band name, or zero-based band index

        Returns
        --------------------
        np.ndarray
            read-only array of the band's data (cached; repeated calls return the same array)

        Raises
        --------------------
        TypeError
            if 'band' is not an int or str
        RuntimeError
            if the wrapped C++ raster has been closed
        """
        if not self._is_band_key(band):
            raise TypeError("'band' parameter must be of type int or str.")

        self._check_open()

        index = self.get_band_index(band)

        if index not in self.band_data_dict:
            self.load_arr(index)

        return self.band_data_dict[index]

    def plot(self,
             ax: Optional[matplotlib.axes.Axes] = None,
             target_width: int = 1000,
             target_height: int = 1000,
             band: Optional[int | str] = None,
             **kwargs):
        """
        Calls plot_raster() on self.

        When no axes are given, a new figure is created and shown with plt.show().

        Parameters
        --------------------
        ax : matplotlib.axes.Axes
            axes to plot the raster on
        target_width : int
            maximum width in pixels for the image (after downsampling)
        target_height : int
            maximum height in pixels for the image (after downsampling)
        band : int or str
            specification of which band to plot
        **kwargs
            any parameters which may be passed to matplotlib.pyplot.imshow

        Raises
        --------------------
        RuntimeError
            if the wrapped C++ raster has been closed
        """
        self._check_open()

        show = ax is None
        if show:
            _, ax = plt.subplots()

        # previously target_width was passed for both dimensions, ignoring target_height
        plot_raster(self, ax, target_width, target_height, band, **kwargs)

        if show:
            plt.show()

    @classmethod
    def from_rasterio(cls, ds, arr: Optional[np.ndarray] = None):
        """
        Convert a rasterio dataset object into an sgspy.SpatialRaster.

        A np.ndarray may be passed as the 'arr' parameter; if so, the following must be true:
        arr.shape == (ds.count, ds.height, ds.width), or arr.shape == (ds.height, ds.width)
        for a single-band dataset. The array then supplies the data and 'ds' the metadata.
        An in-memory ("MEM") dataset is always read into an array. Otherwise, 'ds' is
        closed and its file is reopened by GDALRasterWrapper.

        Examples:

        ds = rasterio.open("rast.tif")
        rast = sgspy.SpatialRaster.from_rasterio(ds)


        ds = rasterio.open("rast.tif")
        arr = ds.read()
        arr[arr < 2] = np.nan
        rast = sgspy.SpatialRaster.from_rasterio(ds, arr)

        Raises
        --------------------
        RuntimeError
            if rasterio is not available, or 'arr' does not match the dataset
        TypeError
            if 'ds' is not a rasterio dataset, or 'arr' is not a np.ndarray
        ValueError
            if 'arr' is neither 2- nor 3-dimensional
        """
        if not RASTERIO:
            raise RuntimeError("from_rasterio() can only be called if rasterio was successfully imported, but it wasn't.")

        if not isinstance(ds, (rasterio.io.DatasetReader, rasterio.io.DatasetWriter)):
            raise TypeError("the ds parameter passed to from_rasterio() must be of type rasterio.io.DatasetReader or rasterio.io.DatasetWriter.")

        if ds.driver == "MEM" and arr is None:
            arr = ds.read()

        if arr is None:
            # close the rasterio dataset, and open a GDALRasterWrapper of the file
            filename = ds.name
            ds.close()
            return cls(filename)

        if not isinstance(arr, np.ndarray):
            raise TypeError("the 'arr' parameter, if passed, must be of type np.ndarray")

        cls._check_arr_shape(arr, ds.count, ds.height, ds.width)

        # per-band nodata values; bands without one are given NaN
        nan_vals = [np.nan if nodata is None else nodata for nodata in ds.nodatavals]
        # rasterio reports None for bands without a description; pass "" so every entry is a str
        band_names = ["" if name is None else name for name in ds.descriptions]
        # a dataset without a crs has no wkt to read
        projection = ds.crs.wkt if ds.crs is not None else ""

        # create an in-memory dataset using the numpy array as the data, and the rasterio dataset to provide metadata
        geotransform = ds.get_transform()
        buffer = memoryview(np.ascontiguousarray(arr))
        return cls(GDALRasterWrapper(buffer, geotransform, projection, nan_vals, band_names, PROJDB_PATH))

    def to_rasterio(self, with_arr=False):
        """
        Convert this sgspy.SpatialRaster into a rasterio dataset.

        If with_arr is set to True, a numpy.ndarray of shape (band_count, height, width)
        is returned in a tuple with the rasterio dataset object. For an in-memory raster,
        the data is copied into a rasterio MEM dataset and the wrapped C++ raster is
        closed, so this SpatialRaster can no longer be used. Otherwise the raster's file
        is opened with rasterio.open().

        Examples:

        rast = sgspy.SpatialRaster('rast.tif')
        ds = rast.to_rasterio()

        rast = sgspy.SpatialRaster('mraster.tif')
        ds, arr = rast.to_rasterio(with_arr=True)

        Raises
        --------------------
        TypeError
            if 'with_arr' is not a bool
        RuntimeError
            if rasterio is not available, the raster is closed or backed by a
            temporary file, or its bands have differing data types
        """
        if type(with_arr) is not bool:
            raise TypeError("'with_arr' parameter must be of type bool.")

        if not RASTERIO:
            raise RuntimeError("to_rasterio() can only be called if rasterio was successfully imported, but it wasn't.")

        self._check_open()
        self._check_not_temp_dataset()

        in_mem = "MEM" in self.driver
        arr = self._read_all_bands() if (with_arr or in_mem) else None

        if not in_mem:
            ds = rasterio.open(self.filename)
            return (ds, arr) if with_arr else ds

        gt = self.cpp_raster.get_geotransform()
        # GDAL geotransform order is (c, a, b, f, d, e); rasterio's Affine takes (a, b, c, d, e, f)
        transform = rasterio.transform.Affine(gt[1], gt[2], gt[0], gt[4], gt[5], gt[3])

        dtype = self.cpp_raster.get_data_type()
        if dtype == "":
            raise RuntimeError("sgs dataset has bands with different types, which is not supported by rasterio.")

        # rasterio takes a single dataset-wide nodata value; band 0's is used
        nan = self.cpp_raster.get_band_nodata_value(0)

        self.cpp_raster.close()
        self.closed = True

        # NOTE(review): the MemoryFile object is not kept alive beyond this call --
        # confirm the returned dataset does not depend on it.
        ds = rasterio.MemoryFile().open(
            driver="MEM",
            width=self.width,
            height=self.height,
            count=self.band_count,
            crs=self.projection,
            transform=transform,
            dtype=dtype,
            nodata=nan,
        )
        ds.write(arr)

        # rasterio band numbers are 1-based
        for band_number, name in enumerate(self.bands, start=1):
            ds.set_band_description(band_number, name)

        return (ds, arr) if with_arr else ds

    @classmethod
    def from_gdal(cls, ds, arr: Optional[np.ndarray] = None):
        """
        Convert a gdal.Dataset object into an sgspy.SpatialRaster.

        A np.ndarray may be passed as the 'arr' parameter; if so, the following must be true:
        arr.shape == (ds.RasterCount, ds.RasterYSize, ds.RasterXSize), or
        arr.shape == (ds.RasterYSize, ds.RasterXSize) for a single-band dataset.
        An in-memory ("MEM") dataset is always read into an array. 'ds' is closed in
        every case.

        Examples:

        ds = gdal.Open("rast.tif")
        rast = sgspy.SpatialRaster.from_gdal(ds)


        ds = gdal.Open("rast.tif")
        bands = []
        for i in range(1, ds.RasterCount + 1):
            bands.append(ds.GetRasterBand(i).ReadAsArray())
        arr = np.stack(bands, axis=0)
        arr[arr < 2] = np.nan
        rast = sgspy.SpatialRaster.from_gdal(ds, arr)

        Raises
        --------------------
        RuntimeError
            if gdal is not available, or 'arr' does not match the dataset
        TypeError
            if 'ds' is not a gdal.Dataset, or 'arr' is not a np.ndarray
        ValueError
            if 'arr' is neither 2- nor 3-dimensional
        """
        if not GDAL:
            raise RuntimeError("from_gdal() can only be called if gdal was successfully imported, but it wasn't")

        if not isinstance(ds, gdal.Dataset):
            raise TypeError("the ds parameter passed to from_gdal() must be of type gdal.Dataset")

        if ds.GetDriver().ShortName == "MEM" and arr is None:
            # GDAL band numbers are 1-based
            arr = np.stack(
                [ds.GetRasterBand(i).ReadAsArray() for i in range(1, ds.RasterCount + 1)],
                axis=0,
            )

        if arr is None:
            # close the gdal dataset, and open a GDALRasterWrapper of the file
            filename = ds.GetName()
            ds.Close()
            return cls(filename)

        if not isinstance(arr, np.ndarray):
            raise TypeError("'arr' parameter, if passed, must be of type np.ndarray")

        cls._check_arr_shape(arr, ds.RasterCount, ds.RasterYSize, ds.RasterXSize)

        nan_vals = []
        band_names = []
        for i in range(1, ds.RasterCount + 1):
            gdal_band = ds.GetRasterBand(i)
            nodata = gdal_band.GetNoDataValue()
            # GetNoDataValue() returns None when no nodata is set; use NaN as from_rasterio() does
            nan_vals.append(np.nan if nodata is None else nodata)
            band_names.append(gdal_band.GetDescription())

        geotransform = ds.GetGeoTransform()
        projection = ds.GetProjection()
        buffer = memoryview(np.ascontiguousarray(arr))

        ds.Close()
        return cls(GDALRasterWrapper(buffer, geotransform, projection, nan_vals, band_names, PROJDB_PATH))

    def to_gdal(self, with_arr=False):
        """
        Convert this sgspy.SpatialRaster into a GDAL dataset.

        If with_arr is set to True, a numpy.ndarray of shape (band_count, height, width)
        is returned in a tuple with the GDAL dataset object. In every case the wrapped
        C++ raster is closed, so this SpatialRaster can no longer be used. An in-memory
        raster is copied into a GDAL MEM dataset; otherwise the file is opened with
        gdal.Open().

        Examples:

        rast = sgspy.SpatialRaster('rast.tif')
        ds = rast.to_gdal()

        rast = sgspy.SpatialRaster('mraster.tif')
        ds, arr = rast.to_gdal(with_arr=True)

        Raises
        --------------------
        RuntimeError
            if gdal is not available, or the raster is closed or backed by a temporary file
        """
        if not GDAL:
            raise RuntimeError("to_gdal() can only be called if gdal was successfully imported, but it wasn't")

        self._check_open()
        self._check_not_temp_dataset()

        in_mem = "MEM" in self.driver
        arr = self._read_all_bands() if (with_arr or in_mem) else None

        if not in_mem:
            self.cpp_raster.close()
            self.closed = True

            ds = gdal.Open(self.filename)
            return (ds, arr) if with_arr else ds

        geotransform = self.cpp_raster.get_geotransform()
        projection = self.projection
        nan_vals = [self.cpp_raster.get_band_nodata_value(i) for i in range(self.band_count)]
        band_names = self.bands

        self.cpp_raster.close()
        self.closed = True

        ds = gdal.GetDriverByName("MEM").Create("", self.width, self.height, 0, gdal.GDT_Unknown)
        ds.SetGeoTransform(geotransform)
        ds.SetProjection(projection)
        # NOTE: writing the stacked array into a new GDAL MEM dataset copies the data.
        # This is done instead of passing the data pointer directly, for fear of
        # memory leaks, dangling pointers, or freeing memory that is still in use.
        for band_number, band_arr in enumerate(arr, start=1):  # GDAL band numbers are 1-based
            ds.AddBand(gdal_array.NumericTypeCodeToGDALTypeCode(band_arr.dtype))
            gdal_band = ds.GetRasterBand(band_number)
            gdal_band.WriteArray(band_arr)
            gdal_band.SetNoDataValue(nan_vals[band_number - 1])
            gdal_band.SetDescription(band_names[band_number - 1])

        return (ds, arr) if with_arr else ds