geo-explorer 0.9.10__tar.gz → 0.9.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {geo_explorer-0.9.10 → geo_explorer-0.9.11}/PKG-INFO +1 -1
- {geo_explorer-0.9.10 → geo_explorer-0.9.11}/pyproject.toml +1 -1
- {geo_explorer-0.9.10 → geo_explorer-0.9.11}/src/geo_explorer/file_browser.py +40 -25
- {geo_explorer-0.9.10 → geo_explorer-0.9.11}/src/geo_explorer/geo_explorer.py +45 -103
- geo_explorer-0.9.11/src/geo_explorer/img.py +297 -0
- {geo_explorer-0.9.10 → geo_explorer-0.9.11}/src/geo_explorer/utils.py +29 -0
- geo_explorer-0.9.10/src/geo_explorer/nc.py +0 -199
- {geo_explorer-0.9.10 → geo_explorer-0.9.11}/LICENSE +0 -0
- {geo_explorer-0.9.10 → geo_explorer-0.9.11}/LICENSE.md +0 -0
- {geo_explorer-0.9.10 → geo_explorer-0.9.11}/README.md +0 -0
- {geo_explorer-0.9.10 → geo_explorer-0.9.11}/src/geo_explorer/__init__.py +0 -0
- {geo_explorer-0.9.10 → geo_explorer-0.9.11}/src/geo_explorer/assets/chroma.min.js +0 -0
- {geo_explorer-0.9.10 → geo_explorer-0.9.11}/src/geo_explorer/assets/on_each_feature.js +0 -0
- {geo_explorer-0.9.10 → geo_explorer-0.9.11}/src/geo_explorer/assets/stylesheet.css +0 -0
- {geo_explorer-0.9.10 → geo_explorer-0.9.11}/src/geo_explorer/fs.py +0 -0
|
@@ -19,6 +19,9 @@ from .utils import _clicked_button_style
|
|
|
19
19
|
from .utils import _standardize_path
|
|
20
20
|
from .utils import _unclicked_button_style
|
|
21
21
|
from .utils import get_button_with_tooltip
|
|
22
|
+
from .utils import time_function_call
|
|
23
|
+
from .utils import time_method_call
|
|
24
|
+
from .utils import _PROFILE_DICT
|
|
22
25
|
|
|
23
26
|
|
|
24
27
|
class FileBrowser:
|
|
@@ -29,15 +32,18 @@ class FileBrowser:
|
|
|
29
32
|
start_dir: str,
|
|
30
33
|
favorites: list[str] | None = None,
|
|
31
34
|
file_system: AbstractFileSystem | None = None,
|
|
35
|
+
sum_partition_sizes: bool = True,
|
|
32
36
|
) -> None:
|
|
33
37
|
self.start_dir = _standardize_path(start_dir)
|
|
34
38
|
self.file_system = file_system
|
|
35
39
|
self.favorites = (
|
|
36
40
|
[_standardize_path(x) for x in favorites] if favorites is not None else []
|
|
37
41
|
)
|
|
42
|
+
self.sum_partition_sizes = sum_partition_sizes
|
|
38
43
|
self._history = [self.start_dir]
|
|
39
44
|
self._register_callbacks()
|
|
40
45
|
|
|
46
|
+
@time_method_call(_PROFILE_DICT)
|
|
41
47
|
def get_file_browser_components(self, width: str = "140vh") -> list[Component]:
|
|
42
48
|
return [
|
|
43
49
|
dbc.Row(
|
|
@@ -295,6 +301,7 @@ class FileBrowser:
|
|
|
295
301
|
State("file-list", "children"),
|
|
296
302
|
State("file-data-dict", "data"),
|
|
297
303
|
)
|
|
304
|
+
@time_method_call(_PROFILE_DICT)
|
|
298
305
|
def update_file_list(
|
|
299
306
|
path,
|
|
300
307
|
search_word,
|
|
@@ -350,6 +357,7 @@ class FileBrowser:
|
|
|
350
357
|
|
|
351
358
|
return (file_data_dict, file_list, alert, sort_by_clicks, self._history[1:])
|
|
352
359
|
|
|
360
|
+
@time_method_call(_PROFILE_DICT)
|
|
353
361
|
def _list_dir(
|
|
354
362
|
self,
|
|
355
363
|
path: str,
|
|
@@ -386,11 +394,13 @@ class FileBrowser:
|
|
|
386
394
|
|
|
387
395
|
if (recursive or 0) % 2 == 0:
|
|
388
396
|
|
|
397
|
+
@time_function_call(_PROFILE_DICT)
|
|
389
398
|
def _ls(path):
|
|
390
399
|
return file_system.ls(path, detail=True)
|
|
391
400
|
|
|
392
401
|
else:
|
|
393
402
|
|
|
403
|
+
@time_function_call(_PROFILE_DICT)
|
|
394
404
|
def _ls(path):
|
|
395
405
|
path = str(Path(path) / "**")
|
|
396
406
|
return _try_glob(path, file_system)
|
|
@@ -417,7 +427,7 @@ class FileBrowser:
|
|
|
417
427
|
if isinstance(paths, dict):
|
|
418
428
|
paths = list(paths.values())
|
|
419
429
|
|
|
420
|
-
def
|
|
430
|
+
def is_dir_or_relevant_file_format(x) -> bool:
|
|
421
431
|
return x["type"] == "directory" or any(
|
|
422
432
|
x["name"].endswith(txt) for txt in self.file_formats
|
|
423
433
|
)
|
|
@@ -427,13 +437,35 @@ class FileBrowser:
|
|
|
427
437
|
for x in paths
|
|
428
438
|
if isinstance(x, dict)
|
|
429
439
|
and _contains(x["name"])
|
|
430
|
-
and
|
|
440
|
+
and is_dir_or_relevant_file_format(x)
|
|
431
441
|
and Path(path).parts != Path(x["name"]).parts
|
|
432
442
|
]
|
|
433
443
|
|
|
434
444
|
paths.sort(key=lambda x: x["name"])
|
|
435
445
|
isdir_list = [x["type"] == "directory" for x in paths]
|
|
436
446
|
|
|
447
|
+
if self.sum_partition_sizes:
|
|
448
|
+
self._sum_partition_sizes(paths, file_system)
|
|
449
|
+
|
|
450
|
+
return (
|
|
451
|
+
paths,
|
|
452
|
+
[
|
|
453
|
+
_get_file_list_row(
|
|
454
|
+
x["name"],
|
|
455
|
+
x.get("updated", None),
|
|
456
|
+
x["size"],
|
|
457
|
+
isdir,
|
|
458
|
+
path,
|
|
459
|
+
self.file_formats,
|
|
460
|
+
file_system,
|
|
461
|
+
)
|
|
462
|
+
for x, isdir in zip(paths, isdir_list, strict=True)
|
|
463
|
+
if isinstance(x, dict)
|
|
464
|
+
],
|
|
465
|
+
None,
|
|
466
|
+
)
|
|
467
|
+
|
|
468
|
+
def _sum_partition_sizes(self, paths, file_system):
|
|
437
469
|
partitioned = {
|
|
438
470
|
i: x
|
|
439
471
|
for i, x in enumerate(paths)
|
|
@@ -444,6 +476,7 @@ class FileBrowser:
|
|
|
444
476
|
)
|
|
445
477
|
}
|
|
446
478
|
|
|
479
|
+
@time_function_call(_PROFILE_DICT)
|
|
447
480
|
def get_summed_size_and_latest_timestamp_in_subdirs(
|
|
448
481
|
x,
|
|
449
482
|
) -> tuple[float, datetime.datetime]:
|
|
@@ -462,11 +495,9 @@ class FileBrowser:
|
|
|
462
495
|
)
|
|
463
496
|
|
|
464
497
|
with ThreadPoolExecutor() as executor:
|
|
465
|
-
summed_size_ant_time =
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
partitioned.values(),
|
|
469
|
-
)
|
|
498
|
+
summed_size_ant_time = executor.map(
|
|
499
|
+
get_summed_size_and_latest_timestamp_in_subdirs,
|
|
500
|
+
partitioned.values(),
|
|
470
501
|
)
|
|
471
502
|
for i, (size, timestamp) in zip(
|
|
472
503
|
partitioned, summed_size_ant_time, strict=True
|
|
@@ -474,25 +505,8 @@ class FileBrowser:
|
|
|
474
505
|
paths[i]["size"] = size
|
|
475
506
|
paths[i]["updated"] = timestamp
|
|
476
507
|
|
|
477
|
-
return (
|
|
478
|
-
paths,
|
|
479
|
-
[
|
|
480
|
-
_get_file_list_row(
|
|
481
|
-
x["name"],
|
|
482
|
-
x.get("updated", None),
|
|
483
|
-
x["size"],
|
|
484
|
-
isdir,
|
|
485
|
-
path,
|
|
486
|
-
self.file_formats,
|
|
487
|
-
file_system,
|
|
488
|
-
)
|
|
489
|
-
for x, isdir in zip(paths, isdir_list, strict=True)
|
|
490
|
-
if isinstance(x, dict)
|
|
491
|
-
],
|
|
492
|
-
None,
|
|
493
|
-
)
|
|
494
|
-
|
|
495
508
|
|
|
509
|
+
@time_function_call(_PROFILE_DICT)
|
|
496
510
|
def _get_file_list_row(
|
|
497
511
|
path, timestamp, size, isdir: bool, current_path, file_formats, file_system
|
|
498
512
|
):
|
|
@@ -564,6 +578,7 @@ def _get_file_list_row(
|
|
|
564
578
|
)
|
|
565
579
|
|
|
566
580
|
|
|
581
|
+
@time_function_call(_PROFILE_DICT)
|
|
567
582
|
def _try_glob(path, file_system):
|
|
568
583
|
try:
|
|
569
584
|
return file_system.glob(path, detail=True, recursive=True)
|
|
@@ -39,7 +39,6 @@ import pandas as pd
|
|
|
39
39
|
import polars as pl
|
|
40
40
|
import pyarrow
|
|
41
41
|
import pyarrow.parquet as pq
|
|
42
|
-
import rasterio
|
|
43
42
|
import sgis as sg
|
|
44
43
|
import shapely
|
|
45
44
|
from dash import Dash
|
|
@@ -82,9 +81,9 @@ except ImportError:
|
|
|
82
81
|
|
|
83
82
|
from .file_browser import FileBrowser
|
|
84
83
|
from .fs import LocalFileSystem
|
|
85
|
-
from .
|
|
86
|
-
from .
|
|
87
|
-
from .
|
|
84
|
+
from .img import GeoTIFFConfig
|
|
85
|
+
from .img import AbstractImageConfig
|
|
86
|
+
from .img import NetCDFConfig
|
|
88
87
|
from .utils import _PROFILE_DICT
|
|
89
88
|
from .utils import DEBUG
|
|
90
89
|
from .utils import _clicked_button_style
|
|
@@ -1257,7 +1256,9 @@ class GeoExplorer:
|
|
|
1257
1256
|
port: int = 8050,
|
|
1258
1257
|
file_system: AbstractFileSystem | None = None,
|
|
1259
1258
|
data: (
|
|
1260
|
-
dict[str, str | GeoDataFrame |
|
|
1259
|
+
dict[str, str | GeoDataFrame | AbstractImageConfig]
|
|
1260
|
+
| list[str | dict]
|
|
1261
|
+
| None
|
|
1261
1262
|
) = None,
|
|
1262
1263
|
column: str | None = None,
|
|
1263
1264
|
color_dict: dict | None = None,
|
|
@@ -1273,6 +1274,7 @@ class GeoExplorer:
|
|
|
1273
1274
|
nan_color: str = "#969696",
|
|
1274
1275
|
nan_label: str = "Missing",
|
|
1275
1276
|
max_read_size_per_callback: int = 1e9,
|
|
1277
|
+
sum_partition_sizes: bool = True,
|
|
1276
1278
|
**kwargs,
|
|
1277
1279
|
) -> None:
|
|
1278
1280
|
"""Initialiser."""
|
|
@@ -1313,7 +1315,10 @@ class GeoExplorer:
|
|
|
1313
1315
|
self._concatted_data: pl.DataFrame | None = None
|
|
1314
1316
|
self._selected_features = {}
|
|
1315
1317
|
self._file_browser = FileBrowser(
|
|
1316
|
-
start_dir,
|
|
1318
|
+
start_dir,
|
|
1319
|
+
file_system=file_system,
|
|
1320
|
+
favorites=favorites,
|
|
1321
|
+
sum_partition_sizes=sum_partition_sizes,
|
|
1317
1322
|
)
|
|
1318
1323
|
self._current_table_view = None
|
|
1319
1324
|
self.max_read_size_per_callback = max_read_size_per_callback
|
|
@@ -1718,17 +1723,25 @@ class GeoExplorer:
|
|
|
1718
1723
|
raise ValueError(error_mess)
|
|
1719
1724
|
for key, value in x.items():
|
|
1720
1725
|
key = _standardize_path(key)
|
|
1721
|
-
if isinstance(value,
|
|
1722
|
-
# setting nc files as unchecked because they might be very large
|
|
1723
|
-
self.selected_files[key] = False
|
|
1724
|
-
self._queries[key] = value.code_block
|
|
1726
|
+
if isinstance(value, AbstractImageConfig):
|
|
1725
1727
|
self._nc[key] = value
|
|
1728
|
+
# setting image files as unchecked because they might be very large
|
|
1729
|
+
self.selected_files[key] = False
|
|
1730
|
+
self.set_query(key, value.code_block)
|
|
1731
|
+
child_paths = self._get_child_paths([key]) or [key]
|
|
1732
|
+
try:
|
|
1733
|
+
bbox_series_dict |= {
|
|
1734
|
+
x: shapely.box(*self._nc[key].get_bounds(None, x))
|
|
1735
|
+
for x in child_paths
|
|
1736
|
+
}
|
|
1737
|
+
except Exception:
|
|
1738
|
+
pass
|
|
1726
1739
|
continue
|
|
1727
1740
|
if value is not None and not isinstance(value, (GeoDataFrame | str)):
|
|
1728
1741
|
raise ValueError(error_mess)
|
|
1729
1742
|
elif not isinstance(value, GeoDataFrame):
|
|
1730
1743
|
self.selected_files[key] = True
|
|
1731
|
-
self.
|
|
1744
|
+
self.set_query(key, value)
|
|
1732
1745
|
continue
|
|
1733
1746
|
value, dtypes = _geopandas_to_polars(value, key)
|
|
1734
1747
|
bbox_series_dict[key] = shapely.box(
|
|
@@ -1785,13 +1798,13 @@ class GeoExplorer:
|
|
|
1785
1798
|
continue
|
|
1786
1799
|
loaded_data_sorted[key] = df.with_columns(
|
|
1787
1800
|
_unique_id=_get_unique_id(self._max_unique_id_int)
|
|
1788
|
-
).drop("id",
|
|
1801
|
+
).drop("id", strict=False)
|
|
1789
1802
|
else:
|
|
1790
1803
|
x = _standardize_path(x)
|
|
1791
1804
|
df = self._loaded_data[x]
|
|
1792
1805
|
loaded_data_sorted[x] = df.with_columns(
|
|
1793
1806
|
_unique_id=_get_unique_id(self._max_unique_id_int)
|
|
1794
|
-
).drop("id",
|
|
1807
|
+
).drop("id", strict=False)
|
|
1795
1808
|
self._max_unique_id_int += 1
|
|
1796
1809
|
|
|
1797
1810
|
self._loaded_data = loaded_data_sorted
|
|
@@ -2157,7 +2170,7 @@ class GeoExplorer:
|
|
|
2157
2170
|
and query.startswith("pl.col")
|
|
2158
2171
|
):
|
|
2159
2172
|
query = f"{old_query}, {query}"
|
|
2160
|
-
self.
|
|
2173
|
+
self.set_query(path, query)
|
|
2161
2174
|
return query
|
|
2162
2175
|
|
|
2163
2176
|
@callback(
|
|
@@ -3070,7 +3083,7 @@ class GeoExplorer:
|
|
|
3070
3083
|
bins=bins,
|
|
3071
3084
|
opacity=opacity,
|
|
3072
3085
|
n_rows_per_path=n_rows_per_path,
|
|
3073
|
-
columns=self._columns,
|
|
3086
|
+
columns=self._columns(),
|
|
3074
3087
|
current_columns=current_columns,
|
|
3075
3088
|
)
|
|
3076
3089
|
results = [
|
|
@@ -3104,15 +3117,21 @@ class GeoExplorer:
|
|
|
3104
3117
|
):
|
|
3105
3118
|
_read_files(self, [img_path], mask=bbox)
|
|
3106
3119
|
img_bbox = self._bbox_series.loc[img_path]
|
|
3107
|
-
|
|
3108
|
-
if
|
|
3120
|
+
clipped_bbox = img_bbox.intersection(bbox)
|
|
3121
|
+
if clipped_bbox.is_empty:
|
|
3109
3122
|
continue
|
|
3110
3123
|
try:
|
|
3111
|
-
|
|
3112
|
-
|
|
3124
|
+
self._nc[selected_path].validate_code_block()
|
|
3125
|
+
ds = self._nc[selected_path].filter_ds(
|
|
3126
|
+
ds=self._loaded_data[img_path],
|
|
3127
|
+
bounds=clipped_bbox.bounds,
|
|
3128
|
+
path=img_path,
|
|
3113
3129
|
)
|
|
3114
3130
|
except Exception as e:
|
|
3115
3131
|
traceback.print_exc()
|
|
3132
|
+
print(
|
|
3133
|
+
"\nNote: the above was a print of the error traceback from invalid query of dataset"
|
|
3134
|
+
)
|
|
3116
3135
|
alerts.append(
|
|
3117
3136
|
dbc.Alert(
|
|
3118
3137
|
f"{type(e).__name__}: {e}. (Traceback printed in terminal)",
|
|
@@ -3128,7 +3147,7 @@ class GeoExplorer:
|
|
|
3128
3147
|
if np.isnan(arr).any() and not np.all(np.isnan(arr)):
|
|
3129
3148
|
arr[np.isnan(arr)] = np.min(arr[~np.isnan(arr)])
|
|
3130
3149
|
|
|
3131
|
-
images[img_path] = (arr,
|
|
3150
|
+
images[img_path] = (arr, clipped_bbox)
|
|
3132
3151
|
|
|
3133
3152
|
if images:
|
|
3134
3153
|
# make sure all single-band images are normalized by same extremities
|
|
@@ -3861,7 +3880,6 @@ class GeoExplorer:
|
|
|
3861
3880
|
deleted_files.add(path)
|
|
3862
3881
|
break
|
|
3863
3882
|
|
|
3864
|
-
assert len(deleted_files) == 1, deleted_files
|
|
3865
3883
|
deleted_files2 = set()
|
|
3866
3884
|
for i, path2 in enumerate(list(self._loaded_data)):
|
|
3867
3885
|
parts = Path(path2).parts
|
|
@@ -4000,7 +4018,7 @@ class GeoExplorer:
|
|
|
4000
4018
|
if query == self._queries.get(path):
|
|
4001
4019
|
continue
|
|
4002
4020
|
self._check_for_circular_queries(query, path)
|
|
4003
|
-
self.
|
|
4021
|
+
self.set_query(path, query)
|
|
4004
4022
|
except RecursionError as e:
|
|
4005
4023
|
out_alerts.append(
|
|
4006
4024
|
dbc.Alert(
|
|
@@ -4016,6 +4034,11 @@ class GeoExplorer:
|
|
|
4016
4034
|
return out_alerts
|
|
4017
4035
|
return None
|
|
4018
4036
|
|
|
4037
|
+
def set_query(self, key: str, value: str | None) -> None:
|
|
4038
|
+
self._queries[key] = value
|
|
4039
|
+
if key in self._nc:
|
|
4040
|
+
self._nc[key].code_block = value
|
|
4041
|
+
|
|
4019
4042
|
@time_method_call(_PROFILE_DICT)
|
|
4020
4043
|
def _get_unique_stem(self, path) -> str:
|
|
4021
4044
|
name = _get_stem(path)
|
|
@@ -4023,19 +4046,6 @@ class GeoExplorer:
|
|
|
4023
4046
|
name = _get_stem_from_parent(path)
|
|
4024
4047
|
return name
|
|
4025
4048
|
|
|
4026
|
-
def _open_img_path_as_xarray(self, img_path, selected_path, clipped_bounds):
|
|
4027
|
-
if is_netcdf(img_path):
|
|
4028
|
-
return self._nc[selected_path].filter_ds(
|
|
4029
|
-
ds=self._loaded_data[img_path],
|
|
4030
|
-
bounds=clipped_bounds.bounds,
|
|
4031
|
-
code_block=self._queries.get(selected_path),
|
|
4032
|
-
)
|
|
4033
|
-
else:
|
|
4034
|
-
return rasterio_to_xarray(
|
|
4035
|
-
img_path, clipped_bounds, code_block=self._queries.get(selected_path)
|
|
4036
|
-
)
|
|
4037
|
-
|
|
4038
|
-
@property
|
|
4039
4049
|
def _columns(self) -> dict[str, set[str]]:
|
|
4040
4050
|
return {path: set(dtypes) for path, dtypes in self._dtypes.items()} | {
|
|
4041
4051
|
path: {"value"} for path in self._nc
|
|
@@ -4759,74 +4769,6 @@ def get_numeric_colors(values_no_nans_unique, values_no_nans, cmap, k):
|
|
|
4759
4769
|
return color_dict, bins
|
|
4760
4770
|
|
|
4761
4771
|
|
|
4762
|
-
def rasterio_to_numpy(
|
|
4763
|
-
img_path, bbox, return_attrs: list[str] | None = None
|
|
4764
|
-
) -> np.ndarray | tuple[Any]:
|
|
4765
|
-
with rasterio.open(img_path) as src:
|
|
4766
|
-
bounds_in_img_crs = GeoSeries([bbox], crs=4326).to_crs(src.crs).total_bounds
|
|
4767
|
-
window = rasterio.windows.from_bounds(
|
|
4768
|
-
*bounds_in_img_crs, transform=src.transform
|
|
4769
|
-
)
|
|
4770
|
-
arr = src.read(window=window, boundless=False, masked=False)
|
|
4771
|
-
if not return_attrs:
|
|
4772
|
-
return arr
|
|
4773
|
-
return (arr, *[getattr(src, attr) for attr in return_attrs])
|
|
4774
|
-
|
|
4775
|
-
|
|
4776
|
-
def rasterio_to_xarray(img_path, bbox, code_block):
|
|
4777
|
-
import xarray as xr
|
|
4778
|
-
from rioxarray.rioxarray import _generate_spatial_coords
|
|
4779
|
-
|
|
4780
|
-
arr, crs, descriptions = rasterio_to_numpy(
|
|
4781
|
-
img_path, bbox, return_attrs=["crs", "descriptions"]
|
|
4782
|
-
)
|
|
4783
|
-
bounds_in_img_crs = GeoSeries([bbox], crs=4326).to_crs(crs).total_bounds
|
|
4784
|
-
|
|
4785
|
-
if not all(arr.shape):
|
|
4786
|
-
return xr.DataArray(
|
|
4787
|
-
arr,
|
|
4788
|
-
dims=["y", "x"],
|
|
4789
|
-
attrs={"crs": crs},
|
|
4790
|
-
)
|
|
4791
|
-
|
|
4792
|
-
if len(arr.shape) == 2:
|
|
4793
|
-
height, width = arr.shape
|
|
4794
|
-
elif len(arr.shape) == 3 and arr.shape[0] == 1:
|
|
4795
|
-
arr = arr[0]
|
|
4796
|
-
height, width = arr.shape
|
|
4797
|
-
elif len(arr.shape) == 3:
|
|
4798
|
-
height, width = arr.shape[1:]
|
|
4799
|
-
else:
|
|
4800
|
-
raise ValueError(arr.shape)
|
|
4801
|
-
|
|
4802
|
-
transform = rasterio.transform.from_bounds(*bounds_in_img_crs, width, height)
|
|
4803
|
-
coords = _generate_spatial_coords(transform, width, height)
|
|
4804
|
-
|
|
4805
|
-
if len(arr.shape) == 2:
|
|
4806
|
-
ds = xr.DataArray(
|
|
4807
|
-
arr,
|
|
4808
|
-
coords=coords,
|
|
4809
|
-
dims=["y", "x"],
|
|
4810
|
-
attrs={"crs": crs},
|
|
4811
|
-
)
|
|
4812
|
-
else:
|
|
4813
|
-
if len(descriptions) != arr.shape[0]:
|
|
4814
|
-
descriptions = range(arr.shape[0])
|
|
4815
|
-
ds = xr.Dataset(
|
|
4816
|
-
{
|
|
4817
|
-
desc: xr.DataArray(
|
|
4818
|
-
arr[i],
|
|
4819
|
-
coords=coords,
|
|
4820
|
-
dims=["y", "x"],
|
|
4821
|
-
attrs={"crs": crs},
|
|
4822
|
-
name=desc,
|
|
4823
|
-
)
|
|
4824
|
-
for i, desc in enumerate(descriptions)
|
|
4825
|
-
}
|
|
4826
|
-
)
|
|
4827
|
-
return _run_code_block(ds, code_block)
|
|
4828
|
-
|
|
4829
|
-
|
|
4830
4772
|
def as_sized_array(arr: np.ndarray) -> np.ndarray:
|
|
4831
4773
|
try:
|
|
4832
4774
|
len(arr)
|
|
@@ -0,0 +1,297 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
from typing import Any
|
|
3
|
+
from typing import ClassVar
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import pyproj
|
|
8
|
+
import rasterio
|
|
9
|
+
import shapely
|
|
10
|
+
from geopandas import GeoDataFrame
|
|
11
|
+
from geopandas import GeoSeries
|
|
12
|
+
from shapely.geometry import Polygon
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
from xarray import DataArray
|
|
16
|
+
from xarray import Dataset
|
|
17
|
+
except ImportError:
|
|
18
|
+
|
|
19
|
+
class DataArray:
|
|
20
|
+
"""Placeholder."""
|
|
21
|
+
|
|
22
|
+
class Dataset:
|
|
23
|
+
"""Placeholder."""
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
from .utils import _PROFILE_DICT
|
|
27
|
+
from .utils import get_xarray_bounds
|
|
28
|
+
from .utils import time_method_call
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class AbstractImageConfig(abc.ABC):
|
|
32
|
+
rgb_bands: ClassVar[list[str] | None] = None
|
|
33
|
+
|
|
34
|
+
def __init__(self, code_block: str | None = None) -> None:
|
|
35
|
+
self.code_block = code_block
|
|
36
|
+
self.validate_code_block()
|
|
37
|
+
|
|
38
|
+
@abc.abstractmethod
|
|
39
|
+
def get_crs(self, ds: Dataset, path: str) -> pyproj.CRS:
|
|
40
|
+
pass
|
|
41
|
+
|
|
42
|
+
@abc.abstractmethod
|
|
43
|
+
def get_bounds(self, ds: Dataset, path: str) -> tuple[float, float, float, float]:
|
|
44
|
+
pass
|
|
45
|
+
|
|
46
|
+
@time_method_call(_PROFILE_DICT)
|
|
47
|
+
def filter_ds(
|
|
48
|
+
self,
|
|
49
|
+
ds: Dataset,
|
|
50
|
+
path: str,
|
|
51
|
+
bounds: tuple[float, float, float, float],
|
|
52
|
+
) -> GeoDataFrame | None:
|
|
53
|
+
crs = self.get_crs(ds, None)
|
|
54
|
+
ds_bounds = get_xarray_bounds(ds)
|
|
55
|
+
|
|
56
|
+
bbox_correct_crs = (
|
|
57
|
+
GeoSeries([shapely.box(*bounds)], crs=4326).to_crs(crs).union_all()
|
|
58
|
+
)
|
|
59
|
+
clipped_bbox = bbox_correct_crs.intersection(shapely.box(*ds_bounds))
|
|
60
|
+
minx, miny, maxx, maxy = clipped_bbox.bounds
|
|
61
|
+
|
|
62
|
+
ds = ds.sel(
|
|
63
|
+
x=slice(minx, maxx),
|
|
64
|
+
y=slice(maxy, miny),
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
return self._run_code_block(ds)
|
|
68
|
+
|
|
69
|
+
@time_method_call(_PROFILE_DICT)
|
|
70
|
+
def to_numpy(self, xarr: Dataset | DataArray) -> GeoDataFrame | None:
|
|
71
|
+
if isinstance(xarr, Dataset) and len(xarr.data_vars) == 1:
|
|
72
|
+
xarr = xarr[next(iter(xarr.data_vars))]
|
|
73
|
+
elif isinstance(xarr, Dataset):
|
|
74
|
+
try:
|
|
75
|
+
xarr = xarr[self.rgb_bands]
|
|
76
|
+
except Exception:
|
|
77
|
+
pass
|
|
78
|
+
|
|
79
|
+
if "time" in set(xarr.dims) and (
|
|
80
|
+
not hasattr(xarr["time"].values, "__len__") or len(xarr["time"].values) > 1
|
|
81
|
+
):
|
|
82
|
+
print("Note: selecting first DataArray in time dimension.")
|
|
83
|
+
xarr = xarr.isel(time=0)
|
|
84
|
+
|
|
85
|
+
if isinstance(xarr, Dataset):
|
|
86
|
+
try:
|
|
87
|
+
return np.array([xarr[band].values for band in self.rgb_bands])
|
|
88
|
+
except Exception:
|
|
89
|
+
pass
|
|
90
|
+
if len(xarr.data_vars) == 3:
|
|
91
|
+
try:
|
|
92
|
+
return np.array([xarr[band].values for band in xarr.data_vars])
|
|
93
|
+
except Exception:
|
|
94
|
+
pass
|
|
95
|
+
arrs = []
|
|
96
|
+
for var in xarr.data_vars:
|
|
97
|
+
arr = xarr[var].values
|
|
98
|
+
if len(arr.shape) == 2:
|
|
99
|
+
arrs.append(arr)
|
|
100
|
+
if len(arr.shape) == 3:
|
|
101
|
+
arrs.append(arr[0])
|
|
102
|
+
return arrs[np.argmax([arr.shape for arr in arrs])]
|
|
103
|
+
|
|
104
|
+
if isinstance(xarr, np.ndarray):
|
|
105
|
+
return xarr
|
|
106
|
+
|
|
107
|
+
return xarr.values
|
|
108
|
+
|
|
109
|
+
def validate_code_block(self) -> None:
|
|
110
|
+
if not self.code_block:
|
|
111
|
+
return
|
|
112
|
+
try:
|
|
113
|
+
assert "ds" in self.code_block
|
|
114
|
+
eval(self.code_block)
|
|
115
|
+
except NameError:
|
|
116
|
+
return
|
|
117
|
+
except SyntaxError:
|
|
118
|
+
pass
|
|
119
|
+
if "xarr=" not in self.code_block.replace(" ", "") or not any(
|
|
120
|
+
txt in self.code_block.replace(" ", "") for txt in ("=ds", "(ds")
|
|
121
|
+
):
|
|
122
|
+
raise ValueError(
|
|
123
|
+
"'code_block' must be a piece of code that takes the xarray object 'ds' and defines the object 'xarr'. "
|
|
124
|
+
f"Got '{self.code_block}'"
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
def _run_code_block(self, ds: DataArray | Dataset) -> Dataset | DataArray:
|
|
128
|
+
if not self.code_block:
|
|
129
|
+
return ds
|
|
130
|
+
|
|
131
|
+
try:
|
|
132
|
+
xarr = eval(self.code_block)
|
|
133
|
+
if callable(xarr):
|
|
134
|
+
xarr = xarr(ds)
|
|
135
|
+
except SyntaxError:
|
|
136
|
+
loc = {}
|
|
137
|
+
exec(self.code_block, globals() | {"ds": ds}, loc)
|
|
138
|
+
xarr = loc["xarr"]
|
|
139
|
+
|
|
140
|
+
if isinstance(xarr, np.ndarray) and isinstance(ds, DataArray):
|
|
141
|
+
ds.values = xarr
|
|
142
|
+
return ds
|
|
143
|
+
elif isinstance(xarr, np.ndarray) and isinstance(ds, Dataset):
|
|
144
|
+
raise ValueError(
|
|
145
|
+
"Cannot return np.ndarray from 'code_block' if ds is xarray.Dataset."
|
|
146
|
+
)
|
|
147
|
+
return xarr
|
|
148
|
+
|
|
149
|
+
def __str__(self) -> str:
|
|
150
|
+
code_block = f"'{self.code_block}'" if self.code_block else None
|
|
151
|
+
return f"{self.__class__.__name__}({code_block})"
|
|
152
|
+
|
|
153
|
+
def __repr__(self) -> str:
|
|
154
|
+
return str(self)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class NetCDFConfig(AbstractImageConfig):
|
|
158
|
+
"""Sets the configuration for reading NetCDF files and getting crs and bounds.
|
|
159
|
+
|
|
160
|
+
Args:
|
|
161
|
+
code_block: String of Python code that takes an xarray.Dataset (ds) and returns an xarray.DataArray (xarr).
|
|
162
|
+
Note that the input Dataset must be references as 'ds' and the output must be assigned to 'xarr'.
|
|
163
|
+
"""
|
|
164
|
+
|
|
165
|
+
rgb_bands: ClassVar[list[str]] = ["B4", "B3", "B2"]
|
|
166
|
+
|
|
167
|
+
def get_bounds(self, ds, path) -> tuple[float, float, float, float]:
|
|
168
|
+
return get_xarray_bounds(ds)
|
|
169
|
+
|
|
170
|
+
def get_crs(self, ds: Dataset, path: str) -> pyproj.CRS:
|
|
171
|
+
attrs = [
|
|
172
|
+
x
|
|
173
|
+
for x in set(ds.attrs).union(
|
|
174
|
+
set(ds.data_vars if isinstance(ds, Dataset) else set())
|
|
175
|
+
)
|
|
176
|
+
if "projection" in x.lower() or "crs" in x.lower()
|
|
177
|
+
]
|
|
178
|
+
if not attrs:
|
|
179
|
+
raise ValueError(f"Could not find CRS attribute/data_var in dataset: {ds}")
|
|
180
|
+
|
|
181
|
+
def getattr_xarray(ds, attr):
|
|
182
|
+
x = ds.attrs.get(attr, ds.get(attr))
|
|
183
|
+
if isinstance(x, DataArray):
|
|
184
|
+
return pyproj.CRS(str(x.values))
|
|
185
|
+
elif x is not None:
|
|
186
|
+
return pyproj.CRS(x)
|
|
187
|
+
raise ValueError(f"Could not find CRS attribute/data_var in dataset: {ds}")
|
|
188
|
+
|
|
189
|
+
for i, attr in enumerate(attrs):
|
|
190
|
+
try:
|
|
191
|
+
return getattr_xarray(ds, attr)
|
|
192
|
+
except Exception as e:
|
|
193
|
+
if i == len(attrs) - 1:
|
|
194
|
+
attrs_dict = {attr: ds.attrs.get(attr, ds[attr]) for attr in attrs}
|
|
195
|
+
raise ValueError(
|
|
196
|
+
f"No valid CRS attribute found among {attrs_dict}"
|
|
197
|
+
) from e
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
class NBSNetCDFConfig(NetCDFConfig):
|
|
201
|
+
def get_crs(self, ds: Dataset, path: str) -> pyproj.CRS:
|
|
202
|
+
return pyproj.CRS(ds.UTM_projection.epsg_code)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class Sentinel2NBSNetCDFConfig(NBSNetCDFConfig):
|
|
206
|
+
rgb_bands: ClassVar[list[str]] = ["B4", "B3", "B2"]
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
class GeoTIFFConfig(AbstractImageConfig):
|
|
210
|
+
def get_crs(self, ds, path):
|
|
211
|
+
with rasterio.open(path) as src:
|
|
212
|
+
return src.crs
|
|
213
|
+
|
|
214
|
+
def get_bounds(self, ds, path) -> tuple[float, float, float, float]:
|
|
215
|
+
with rasterio.open(path) as src:
|
|
216
|
+
return tuple(
|
|
217
|
+
GeoSeries([shapely.box(*src.bounds)], crs=src.crs)
|
|
218
|
+
.to_crs(4326)
|
|
219
|
+
.total_bounds
|
|
220
|
+
)
|
|
221
|
+
|
|
222
|
+
def filter_ds(
|
|
223
|
+
self,
|
|
224
|
+
ds: None,
|
|
225
|
+
path: str,
|
|
226
|
+
bounds: tuple[float, float, float, float],
|
|
227
|
+
) -> DataArray | Dataset:
|
|
228
|
+
from rioxarray.rioxarray import _generate_spatial_coords
|
|
229
|
+
|
|
230
|
+
bbox = shapely.box(*bounds)
|
|
231
|
+
arr, crs, descriptions = rasterio_to_numpy(
|
|
232
|
+
path, bbox, return_attrs=["crs", "descriptions"]
|
|
233
|
+
)
|
|
234
|
+
bounds_in_img_crs = GeoSeries([bbox], crs=4326).to_crs(crs).total_bounds
|
|
235
|
+
|
|
236
|
+
if not all(arr.shape):
|
|
237
|
+
return DataArray(
|
|
238
|
+
arr,
|
|
239
|
+
dims=["y", "x"],
|
|
240
|
+
attrs={"crs": crs},
|
|
241
|
+
)
|
|
242
|
+
|
|
243
|
+
if len(arr.shape) == 2:
|
|
244
|
+
height, width = arr.shape
|
|
245
|
+
elif len(arr.shape) == 3 and arr.shape[0] == 1:
|
|
246
|
+
arr = arr[0]
|
|
247
|
+
height, width = arr.shape
|
|
248
|
+
elif len(arr.shape) == 3:
|
|
249
|
+
height, width = arr.shape[1:]
|
|
250
|
+
else:
|
|
251
|
+
raise ValueError(arr.shape)
|
|
252
|
+
|
|
253
|
+
transform = rasterio.transform.from_bounds(*bounds_in_img_crs, width, height)
|
|
254
|
+
coords = _generate_spatial_coords(transform, width, height)
|
|
255
|
+
|
|
256
|
+
if len(arr.shape) == 2:
|
|
257
|
+
ds = DataArray(
|
|
258
|
+
arr,
|
|
259
|
+
coords=coords,
|
|
260
|
+
dims=["y", "x"],
|
|
261
|
+
attrs={"crs": crs},
|
|
262
|
+
)
|
|
263
|
+
else:
|
|
264
|
+
if len(descriptions) != arr.shape[0]:
|
|
265
|
+
descriptions = range(arr.shape[0])
|
|
266
|
+
ds = Dataset(
|
|
267
|
+
{
|
|
268
|
+
desc: DataArray(
|
|
269
|
+
arr[i],
|
|
270
|
+
coords=coords,
|
|
271
|
+
dims=["y", "x"],
|
|
272
|
+
attrs={"crs": crs},
|
|
273
|
+
name=desc,
|
|
274
|
+
)
|
|
275
|
+
for i, desc in enumerate(descriptions)
|
|
276
|
+
}
|
|
277
|
+
)
|
|
278
|
+
return self._run_code_block(ds)
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def rasterio_to_numpy(
|
|
282
|
+
img_path: str, bbox: Polygon, return_attrs: list[str] | None = None
|
|
283
|
+
) -> np.ndarray | tuple[Any]:
|
|
284
|
+
with rasterio.open(img_path) as src:
|
|
285
|
+
bounds_in_img_crs = GeoSeries([bbox], crs=4326).to_crs(src.crs).total_bounds
|
|
286
|
+
window = rasterio.windows.from_bounds(
|
|
287
|
+
*bounds_in_img_crs, transform=src.transform
|
|
288
|
+
)
|
|
289
|
+
arr = src.read(window=window, boundless=False, masked=False)
|
|
290
|
+
if not return_attrs:
|
|
291
|
+
return arr
|
|
292
|
+
return (arr, *[getattr(src, attr) for attr in return_attrs])
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def _pd():
|
|
296
|
+
"""Function that makes sure 'pd' is not removed by 'ruff' fixes. Because pd is useful in code_block."""
|
|
297
|
+
pd
|
|
@@ -87,7 +87,36 @@ if not DEBUG:
|
|
|
87
87
|
return decorator
|
|
88
88
|
|
|
89
89
|
|
|
90
|
+
def get_xarray_resolution(ds) -> int:
|
|
91
|
+
minx, miny, maxx, maxy = _get_raw_xarray_bounds(ds)
|
|
92
|
+
diffx = maxx - minx
|
|
93
|
+
diffy = maxy - miny
|
|
94
|
+
try:
|
|
95
|
+
resx = diffx / (ds.sizes["x"] - 1)
|
|
96
|
+
resy = diffy / (ds.sizes["y"] - 1)
|
|
97
|
+
except ZeroDivisionError:
|
|
98
|
+
raise ValueError(
|
|
99
|
+
f"Cannot calculate resolution for Dataset with {diffx=}, {diffy=}, {ds.sizes['x']=}, {ds.sizes['y']=}"
|
|
100
|
+
)
|
|
101
|
+
if resx != resy:
|
|
102
|
+
raise ValueError(
|
|
103
|
+
f"x and y resolution differ: resx={resx}, resy={resy} for Dataset: {ds}"
|
|
104
|
+
)
|
|
105
|
+
return resx
|
|
106
|
+
|
|
107
|
+
|
|
90
108
|
def get_xarray_bounds(ds) -> tuple[float, float, float, float]:
|
|
109
|
+
res = get_xarray_resolution(ds)
|
|
110
|
+
minx, miny, maxx, maxy = _get_raw_xarray_bounds(ds)
|
|
111
|
+
return (
|
|
112
|
+
minx - res / 2,
|
|
113
|
+
miny - res / 2,
|
|
114
|
+
maxx + res / 2,
|
|
115
|
+
maxy + res / 2,
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _get_raw_xarray_bounds(ds) -> tuple[float, float, float, float]:
|
|
91
120
|
return (
|
|
92
121
|
float(ds["x"].min().values),
|
|
93
122
|
float(ds["y"].min().values),
|
|
@@ -1,199 +0,0 @@
|
|
|
1
|
-
import abc
|
|
2
|
-
from typing import ClassVar
|
|
3
|
-
|
|
4
|
-
import numpy as np
|
|
5
|
-
import pandas as pd
|
|
6
|
-
import pyproj
|
|
7
|
-
import rasterio
|
|
8
|
-
import shapely
|
|
9
|
-
from geopandas import GeoDataFrame
|
|
10
|
-
from geopandas import GeoSeries
|
|
11
|
-
from shapely.geometry import Polygon
|
|
12
|
-
|
|
13
|
-
try:
|
|
14
|
-
from xarray import DataArray
|
|
15
|
-
from xarray import Dataset
|
|
16
|
-
except ImportError:
|
|
17
|
-
|
|
18
|
-
class DataArray:
|
|
19
|
-
"""Placeholder."""
|
|
20
|
-
|
|
21
|
-
class Dataset:
|
|
22
|
-
"""Placeholder."""
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
from .utils import _PROFILE_DICT
|
|
26
|
-
from .utils import get_xarray_bounds
|
|
27
|
-
from .utils import time_method_call
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
class AbstractImageConfig(abc.ABC):
    """Base class describing how to locate, clip and convert an image dataset.

    Subclasses implement 'get_crs' and 'get_bounds'; this class provides the
    shared clipping ('filter_ds') and array conversion ('to_numpy') logic.

    Args:
        code_block: Optional string of Python code that takes the xarray
            object 'ds' and defines the object 'xarr'.
    """

    # Variable names selected (in this order) when a multi-variable Dataset is
    # converted to an array. None means "use every data variable".
    rgb_bands: ClassVar[list[str] | None] = None
    # Name of the xarray method used to collapse the 'time' dimension.
    # None means "take the first time step".
    reducer: ClassVar[str | None] = None

    def __init__(self, code_block: str | None = None) -> None:
        # Assigning through the property validates the value and stores it.
        self.code_block = code_block

    @property
    def code_block(self) -> str | None:
        """The validated code string, or None when no code was given."""
        return self._code_block

    @code_block.setter
    def code_block(self, value: str | None):
        """Validate and store 'value'.

        Raises:
            ValueError: If a non-empty value does not assign to 'xarr' or does
                not appear to use 'ds' (checked textually, ignoring spaces).
        """
        if value:
            compact = value.replace(" ", "")
            assigns_xarr = "xarr=" in compact
            uses_ds = any(txt in compact for txt in ("=ds", "(ds"))
            if not (assigns_xarr and uses_ds):
                raise ValueError(
                    "'code_block' must be a piece of code that takes the xarray object 'ds' and defines the object 'xarr'. "
                    f"Got '{value}'"
                )
        self._code_block = value

    @abc.abstractmethod
    def get_crs(self, ds: Dataset, path: str) -> pyproj.CRS:
        """Return the coordinate reference system of the dataset."""

    @abc.abstractmethod
    def get_bounds(self, ds: Dataset, path: str) -> tuple[float, float, float, float]:
        """Return the bounds of the dataset as (minx, miny, maxx, maxy)."""

    @time_method_call(_PROFILE_DICT)
    def filter_ds(
        self,
        ds: Dataset,
        bounds: tuple[float, float, float, float],
        code_block: str | None,
        path: str | None = None,
    ) -> Dataset | DataArray:
        """Clip 'ds' to 'bounds' and then run 'code_block' on the result.

        Args:
            ds: Dataset with 'x' and 'y' coordinates.
            bounds: (minx, miny, maxx, maxy) in EPSG:4326.
            code_block: Optional code passed on to '_run_code_block'.
            path: Optional file path forwarded to 'get_crs', for subclasses
                that read the CRS from the file rather than from 'ds'.
                Defaults to None, which matches the previous behaviour.

        Returns:
            The clipped dataset, or whatever 'code_block' produces from it.
        """
        crs = self.get_crs(ds, path)
        # NOTE(review): presumably expressed in the dataset's own CRS, since
        # they are intersected with the reprojected box below - confirm.
        ds_bounds = get_xarray_bounds(ds)

        bbox_correct_crs = (
            GeoSeries([shapely.box(*bounds)], crs=4326).to_crs(crs).union_all()
        )
        clipped_bbox = bbox_correct_crs.intersection(shapely.box(*ds_bounds))
        minx, miny, maxx, maxy = clipped_bbox.bounds

        # The y slice runs from max to min.
        # NOTE(review): this assumes 'y' is stored in descending order - confirm.
        ds = ds.sel(
            x=slice(minx, maxx),
            y=slice(maxy, miny),
        )

        return _run_code_block(ds, code_block)

    @time_method_call(_PROFILE_DICT)
    def to_numpy(self, xarr: Dataset | DataArray) -> np.ndarray:
        """Convert an xarray object (or array) to a numpy array.

        A Dataset with a single variable is unwrapped to that variable. A
        Dataset with several variables is stacked along a new first axis,
        using 'rgb_bands' when set and every data variable otherwise. A 'time'
        dimension with more than one step is collapsed with 'reducer', or by
        taking the first step when 'reducer' is None.
        """
        if isinstance(xarr, Dataset) and len(xarr.data_vars) == 1:
            xarr = xarr[next(iter(xarr.data_vars))]
        elif isinstance(xarr, Dataset) and self.rgb_bands:
            try:
                xarr = xarr[self.rgb_bands]
            except KeyError:
                # Best effort: keep the full Dataset when a band is missing.
                # The per-band lookup further down then raises the KeyError.
                pass

        if "time" in set(xarr.dims) and (
            not hasattr(xarr["time"].values, "__len__") or len(xarr["time"].values) > 1
        ):
            if self.reducer is None:
                xarr = xarr.isel(time=0)
            else:
                xarr = getattr(xarr, self.reducer)(dim="time")

        if isinstance(xarr, Dataset) and self.rgb_bands:
            return np.array([xarr[band].values for band in self.rgb_bands])
        elif isinstance(xarr, Dataset):
            return np.array([xarr[var].values for var in xarr.data_vars])

        if isinstance(xarr, np.ndarray):
            return xarr

        return xarr.values

    def __str__(self) -> str:
        code_block = f"'{self.code_block}'" if self.code_block else None
        return f"{self.__class__.__name__}({code_block})"

    def __repr__(self) -> str:
        return str(self)
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
def _run_code_block(
    ds: DataArray | Dataset, code_block: str | None
) -> Dataset | DataArray:
    """Apply the user-supplied 'code_block' to 'ds' and return the outcome.

    The code is first evaluated as an expression; a callable result is called
    with 'ds'. Code that is not a valid expression is executed as statements
    instead and must assign its result to 'xarr'. An empty or missing
    'code_block' returns 'ds' untouched.

    Raises:
        ValueError: If the code yields a numpy array while 'ds' is a Dataset.
    """
    if not code_block:
        return ds

    # NOTE: eval/exec run arbitrary code - 'code_block' must come from a
    # trusted source.
    try:
        xarr = eval(code_block)
        if callable(xarr):
            xarr = xarr(ds)
    except SyntaxError:
        # Statements cannot be eval'ed; run them and pick up 'xarr' afterwards.
        namespace = {}
        exec(code_block, globals() | {"ds": ds}, namespace)
        xarr = namespace["xarr"]

    if not isinstance(xarr, np.ndarray):
        return xarr

    # A bare array is written back into the DataArray it was derived from.
    if isinstance(ds, DataArray):
        ds.values = xarr
        return ds

    if isinstance(ds, Dataset):
        raise ValueError(
            "Cannot return np.ndarray from 'code_block' if ds is xarray.Dataset."
        )

    return xarr
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
class GeoTIFFConfig(AbstractImageConfig):
    """Image configuration that reads CRS and bounds from the file with rasterio."""

    def get_crs(self, ds, path):
        """Return the CRS stored in the file at 'path'; 'ds' is not used."""
        with rasterio.open(path) as raster:
            file_crs = raster.crs
            return file_crs

    def get_bounds(self, ds, path) -> tuple[float, float, float, float]:
        """Return the file's bounding box converted to EPSG:4326; 'ds' is not used."""
        with rasterio.open(path) as raster:
            footprint = shapely.box(*raster.bounds)
            in_native_crs = GeoSeries([footprint], crs=raster.crs)
            in_4326 = in_native_crs.to_crs(4326)
            return tuple(in_4326.total_bounds)
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
class NetCDFConfig(AbstractImageConfig):
    """Configuration for reading NetCDF files and deriving their crs and bounds.

    Args:
        code_block: String of Python code that takes an xarray.Dataset (ds) and returns an xarray.DataArray (xarr).
            Note that the input Dataset must be referenced as 'ds' and the output must be assigned to 'xarr'.
    """

    rgb_bands: ClassVar[list[str]] = ["B4", "B3", "B2"]

    def get_bounds(self, ds, path) -> tuple[float, float, float, float]:
        """Return the bounds computed from 'ds' by get_xarray_bounds; 'path' is not used."""
        return get_xarray_bounds(ds)

    def get_crs(self, ds: Dataset, path: str) -> pyproj.CRS:
        """Build a CRS from the dataset attributes; 'path' is not used.

        Every attribute whose name contains 'projection' or 'crs'
        (case-insensitive) is tried in order and the first one pyproj
        accepts is returned.

        Raises:
            ValueError: If no such attribute exists, or none of them is valid.
        """
        candidates = []
        for name in ds.attrs:
            lowered = name.lower()
            if "projection" in lowered or "crs" in lowered:
                candidates.append(name)

        if not candidates:
            raise ValueError(f"Could not find CRS attribute in dataset: {ds}")

        final_name = candidates[-1]
        for name in candidates:
            try:
                return pyproj.CRS(ds.attrs[name])
            except Exception as err:
                # Keep trying until the last candidate has failed as well.
                if name != final_name:
                    continue
                tried = {key: ds.attrs[key] for key in candidates}
                raise ValueError(
                    f"No valid CRS attribute found among {tried}"
                ) from err
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
class NBSNetCDFConfig(NetCDFConfig):
    """NetCDF configuration that takes the CRS from the 'UTM_projection' variable."""

    def get_crs(self, ds: Dataset, path: str) -> pyproj.CRS:
        """Return the CRS built from ds.UTM_projection.epsg_code; 'path' is not used."""
        epsg_code = ds.UTM_projection.epsg_code
        return pyproj.CRS(epsg_code)
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
class Sentinel2NBSNetCDFConfig(NBSNetCDFConfig):
    """NBSNetCDFConfig with the RGB bands set to B4, B3 and B2."""

    # Variable names used, in this order, when stacking bands into an array.
    rgb_bands: ClassVar[list[str]] = ["B4", "B3", "B2"]
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
def _pd():
    """Reference 'pd' so that 'ruff' fixes do not strip the pandas import.

    The import is kept because pd is useful inside a user-supplied code_block.
    """
    _ = pd
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|