geo-explorer 0.9.10__tar.gz → 0.9.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: geo-explorer
-Version: 0.9.10
+Version: 0.9.12
 Summary: Explore geodata interactively.
 License: MIT
 Author: Morten Letnes
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "geo-explorer"
-version = "0.9.10"
+version = "0.9.12"
 description = "Explore geodata interactively."
 authors = ["Morten Letnes <morten.letnes@ssb.no>"]
 license = "MIT"
@@ -19,6 +19,9 @@ from .utils import _clicked_button_style
 from .utils import _standardize_path
 from .utils import _unclicked_button_style
 from .utils import get_button_with_tooltip
+from .utils import time_function_call
+from .utils import time_method_call
+from .utils import _PROFILE_DICT
 
 
 class FileBrowser:
@@ -29,15 +32,18 @@ class FileBrowser:
         start_dir: str,
         favorites: list[str] | None = None,
         file_system: AbstractFileSystem | None = None,
+        sum_partition_sizes: bool = True,
     ) -> None:
         self.start_dir = _standardize_path(start_dir)
         self.file_system = file_system
         self.favorites = (
             [_standardize_path(x) for x in favorites] if favorites is not None else []
         )
+        self.sum_partition_sizes = sum_partition_sizes
         self._history = [self.start_dir]
         self._register_callbacks()
 
+    @time_method_call(_PROFILE_DICT)
     def get_file_browser_components(self, width: str = "140vh") -> list[Component]:
         return [
             dbc.Row(
@@ -295,6 +301,7 @@ class FileBrowser:
             State("file-list", "children"),
             State("file-data-dict", "data"),
         )
+        @time_method_call(_PROFILE_DICT)
         def update_file_list(
             path,
             search_word,
@@ -350,6 +357,7 @@
 
             return (file_data_dict, file_list, alert, sort_by_clicks, self._history[1:])
 
+    @time_method_call(_PROFILE_DICT)
     def _list_dir(
         self,
         path: str,
@@ -386,11 +394,13 @@
 
         if (recursive or 0) % 2 == 0:
 
+            @time_function_call(_PROFILE_DICT)
             def _ls(path):
                 return file_system.ls(path, detail=True)
 
         else:
 
+            @time_function_call(_PROFILE_DICT)
             def _ls(path):
                 path = str(Path(path) / "**")
                 return _try_glob(path, file_system)
@@ -417,7 +427,7 @@
         if isinstance(paths, dict):
             paths = list(paths.values())
 
-        def is_dir_or_is_partitioned_parquet(x) -> bool:
+        def is_dir_or_relevant_file_format(x) -> bool:
             return x["type"] == "directory" or any(
                 x["name"].endswith(txt) for txt in self.file_formats
             )
@@ -427,13 +437,35 @@
             for x in paths
             if isinstance(x, dict)
             and _contains(x["name"])
-            and is_dir_or_is_partitioned_parquet(x)
+            and is_dir_or_relevant_file_format(x)
             and Path(path).parts != Path(x["name"]).parts
         ]
 
         paths.sort(key=lambda x: x["name"])
         isdir_list = [x["type"] == "directory" for x in paths]
 
+        if self.sum_partition_sizes:
+            self._sum_partition_sizes(paths, file_system)
+
+        return (
+            paths,
+            [
+                _get_file_list_row(
+                    x["name"],
+                    x.get("updated", None),
+                    x["size"],
+                    isdir,
+                    path,
+                    self.file_formats,
+                    file_system,
+                )
+                for x, isdir in zip(paths, isdir_list, strict=True)
+                if isinstance(x, dict)
+            ],
+            None,
+        )
+
+    def _sum_partition_sizes(self, paths, file_system):
         partitioned = {
             i: x
             for i, x in enumerate(paths)
@@ -444,6 +476,7 @@
             )
         }
 
+        @time_function_call(_PROFILE_DICT)
         def get_summed_size_and_latest_timestamp_in_subdirs(
             x,
         ) -> tuple[float, datetime.datetime]:
@@ -462,11 +495,9 @@
             )
 
         with ThreadPoolExecutor() as executor:
-            summed_size_ant_time = list(
-                executor.map(
-                    get_summed_size_and_latest_timestamp_in_subdirs,
-                    partitioned.values(),
-                )
+            summed_size_ant_time = executor.map(
+                get_summed_size_and_latest_timestamp_in_subdirs,
+                partitioned.values(),
             )
             for i, (size, timestamp) in zip(
                 partitioned, summed_size_ant_time, strict=True
@@ -474,25 +505,8 @@
                 paths[i]["size"] = size
                 paths[i]["updated"] = timestamp
 
-        return (
-            paths,
-            [
-                _get_file_list_row(
-                    x["name"],
-                    x.get("updated", None),
-                    x["size"],
-                    isdir,
-                    path,
-                    self.file_formats,
-                    file_system,
-                )
-                for x, isdir in zip(paths, isdir_list, strict=True)
-                if isinstance(x, dict)
-            ],
-            None,
-        )
-
 
+@time_function_call(_PROFILE_DICT)
 def _get_file_list_row(
     path, timestamp, size, isdir: bool, current_path, file_formats, file_system
 ):
@@ -564,6 +578,7 @@ def _get_file_list_row(
     )
 
 
+@time_function_call(_PROFILE_DICT)
 def _try_glob(path, file_system):
     try:
         return file_system.glob(path, detail=True, recursive=True)
@@ -25,6 +25,7 @@ from typing import Any
 from typing import ClassVar
 
 import dash
+import pyproj
 import dash_bootstrap_components as dbc
 import dash_leaflet as dl
 import folium
@@ -39,7 +40,6 @@ import pandas as pd
 import polars as pl
 import pyarrow
 import pyarrow.parquet as pq
-import rasterio
 import sgis as sg
 import shapely
 from dash import Dash
@@ -82,9 +82,9 @@ except ImportError:
 
 from .file_browser import FileBrowser
 from .fs import LocalFileSystem
-from .nc import GeoTIFFConfig
-from .nc import NetCDFConfig
-from .nc import _run_code_block
+from .img import GeoTIFFConfig
+from .img import AbstractImageConfig
+from .img import NetCDFConfig
 from .utils import _PROFILE_DICT
 from .utils import DEBUG
 from .utils import _clicked_button_style
@@ -687,12 +687,15 @@ def _read_files(explorer, paths: list[str], mask=None, **kwargs) -> None:
     paths = [
         path
         for path in paths
-        if mask is None
-        or (
-            path in bbox_set
-            and (
-                pd.isna(explorer._bbox_series[path])
-                or shapely.intersects(mask, explorer._bbox_series[path])
+        if explorer._loaded_data_sizes[path] > 0
+        and (
+            mask is None
+            or (
+                path in bbox_set
+                and (
+                    pd.isna(explorer._bbox_series[path])
+                    or shapely.intersects(mask, explorer._bbox_series[path])
+                )
             )
         )
     ]
@@ -769,30 +772,55 @@ def _try_to_get_bbox_else_none(
 
 
 def _get_bbox_series_as_4326(paths, file_system):
-    # bbox_series = sg.get_bbox_series(paths, file_system=file_system)
-    # return bbox_series.to_crs(4326)
-
     func = partial(_try_to_get_bbox_else_none, file_system=file_system)
     with ThreadPoolExecutor() as executor:
         bbox_and_crs = list(executor.map(func, paths))
 
-    crss = {json.dumps(x[1]) for x in bbox_and_crs}
-    crss = {
-        crs
-        for crs in crss
-        if not any(str(crs).lower() == txt for txt in ["none", "null"])
-    }
+    bbox_and_crs = [(bbox, json.dumps(crs)) for bbox, crs in bbox_and_crs]
+    bbox_and_crs = [
+        (
+            bbox,
+            (
+                pyproj.CRS(crs).to_string()
+                if not any(crs.lower() == txt for txt in ["none", "null"])
+                else None
+            ),
+        )
+        for bbox, crs in bbox_and_crs
+    ]
+    crss = {crs for (_, crs) in bbox_and_crs}
     if not crss:
         return GeoSeries([None for _ in range(len(paths))], index=paths)
-    crs = get_common_crs(crss)
-    return GeoSeries(
-        [
-            shapely.box(*bbox[0]) if bbox[0] is not None else None
-            for bbox in bbox_and_crs
-        ],
-        index=paths,
-        crs=crs,
-    ).to_crs(4326)
+    missing = GeoSeries(
+        {
+            path: Polygon()
+            for path, (bbox, _) in zip(paths, bbox_and_crs, strict=True)
+            if bbox is None
+        }
+    )
+    paths_bbox_and_crs = {
+        path: (bbox, crs)
+        for path, (bbox, crs) in zip(paths, bbox_and_crs, strict=True)
+        if bbox is not None
+    }
+    crs_with_paths = {}
+    geoms = shapely.box(
+        [bbox[0] for (bbox, _) in paths_bbox_and_crs.values()],
+        [bbox[1] for (bbox, _) in paths_bbox_and_crs.values()],
+        [bbox[2] for (bbox, _) in paths_bbox_and_crs.values()],
+        [bbox[3] for (bbox, _) in paths_bbox_and_crs.values()],
+    )
+    for crs in crss:
+        crs_with_paths[crs] = {
+            path: geoms[i]
+            for i, (path, (bbox, this_crs)) in enumerate(paths_bbox_and_crs.items())
+            if this_crs == crs and bbox is not None
+        }
+    no_crs = GeoSeries(crs_with_paths.pop(None, {}), crs=4326)
+    return pd.concat(
+        [GeoSeries(data, crs=crs).to_crs(4326) for crs, data in crs_with_paths.items()]
+        + [missing, no_crs]
+    )
 
 
 def get_index(values: list[Any], ids: list[Any], index: Any):
@@ -1257,7 +1285,9 @@ class GeoExplorer:
         port: int = 8050,
         file_system: AbstractFileSystem | None = None,
         data: (
-            dict[str, str | GeoDataFrame | NetCDFConfig] | list[str | dict] | None
+            dict[str, str | GeoDataFrame | AbstractImageConfig]
+            | list[str | dict]
+            | None
         ) = None,
         column: str | None = None,
         color_dict: dict | None = None,
@@ -1273,6 +1303,7 @@
         nan_color: str = "#969696",
         nan_label: str = "Missing",
         max_read_size_per_callback: int = 1e9,
+        sum_partition_sizes: bool = True,
         **kwargs,
     ) -> None:
         """Initialiser."""
@@ -1313,7 +1344,10 @@
         self._concatted_data: pl.DataFrame | None = None
         self._selected_features = {}
         self._file_browser = FileBrowser(
-            start_dir, file_system=file_system, favorites=favorites
+            start_dir,
+            file_system=file_system,
+            favorites=favorites,
+            sum_partition_sizes=sum_partition_sizes,
         )
         self._current_table_view = None
         self.max_read_size_per_callback = max_read_size_per_callback
@@ -1718,17 +1752,25 @@
                 raise ValueError(error_mess)
             for key, value in x.items():
                 key = _standardize_path(key)
-                if isinstance(value, NetCDFConfig):
-                    # setting nc files as unchecked because they might be very large
-                    self.selected_files[key] = False
-                    self._queries[key] = value.code_block
+                if isinstance(value, AbstractImageConfig):
                     self._nc[key] = value
+                    # setting image files as unchecked because they might be very large
+                    self.selected_files[key] = False
+                    self.set_query(key, value.code_block)
+                    child_paths = self._get_child_paths([key]) or [key]
+                    try:
+                        bbox_series_dict |= {
+                            x: shapely.box(*self._nc[key].get_bounds(None, x))
+                            for x in child_paths
+                        }
+                    except Exception:
+                        pass
                     continue
                 if value is not None and not isinstance(value, (GeoDataFrame | str)):
                     raise ValueError(error_mess)
                 elif not isinstance(value, GeoDataFrame):
                     self.selected_files[key] = True
-                    self._queries[key] = value
+                    self.set_query(key, value)
                     continue
                 value, dtypes = _geopandas_to_polars(value, key)
                 bbox_series_dict[key] = shapely.box(
@@ -1785,13 +1827,13 @@
                     continue
                 loaded_data_sorted[key] = df.with_columns(
                     _unique_id=_get_unique_id(self._max_unique_id_int)
-                ).drop("id", errors="ignore")
+                ).drop("id", strict=False)
             else:
                 x = _standardize_path(x)
                 df = self._loaded_data[x]
                 loaded_data_sorted[x] = df.with_columns(
                     _unique_id=_get_unique_id(self._max_unique_id_int)
-                ).drop("id", errors="ignore")
+                ).drop("id", strict=False)
             self._max_unique_id_int += 1
 
         self._loaded_data = loaded_data_sorted
@@ -2157,7 +2199,7 @@
             and query.startswith("pl.col")
         ):
             query = f"{old_query}, {query}"
-        self._queries[path] = query
+        self.set_query(path, query)
         return query
 
     @callback(
  @callback(
@@ -2831,7 +2873,7 @@ class GeoExplorer:
2831
2873
  self.color_dict = {}
2832
2874
  elif not column and triggered is None:
2833
2875
  column = self.column
2834
- elif self._concatted_data is None:
2876
+ if self._concatted_data is None:
2835
2877
  return (
2836
2878
  [],
2837
2879
  None,
@@ -3070,7 +3112,7 @@
                 bins=bins,
                 opacity=opacity,
                 n_rows_per_path=n_rows_per_path,
-                columns=self._columns,
+                columns=self._columns(),
                 current_columns=current_columns,
             )
             results = [
@@ -3104,15 +3146,21 @@
            ):
                _read_files(self, [img_path], mask=bbox)
            img_bbox = self._bbox_series.loc[img_path]
-            clipped_bounds = img_bbox.intersection(bbox)
-            if clipped_bounds.is_empty:
+            clipped_bbox = img_bbox.intersection(bbox)
+            if clipped_bbox.is_empty:
                continue
            try:
-                ds = self._open_img_path_as_xarray(
-                    img_path, selected_path, clipped_bounds
+                self._nc[selected_path].validate_code_block()
+                ds = self._nc[selected_path].filter_ds(
+                    ds=self._loaded_data[img_path],
+                    bounds=clipped_bbox.bounds,
+                    path=img_path,
                )
            except Exception as e:
                traceback.print_exc()
+                print(
+                    "\nNote: the above was a print of the error traceback from invalid query of dataset"
+                )
                alerts.append(
                    dbc.Alert(
                        f"{type(e).__name__}: {e}. (Traceback printed in terminal)",
@@ -3128,10 +3176,11 @@
            if np.isnan(arr).any() and not np.all(np.isnan(arr)):
                arr[np.isnan(arr)] = np.min(arr[~np.isnan(arr)])
 
-            images[img_path] = (arr, clipped_bounds)
+            images[img_path] = (arr, clipped_bbox)
 
        if images:
            # make sure all single-band images are normalized by same extremities
+            debug_print(images)
            vmin = np.min([np.min(x[0]) for x in images.values()])
            vmax = np.min([np.max(x[0]) for x in images.values()])
 
@@ -3861,7 +3910,6 @@
                    deleted_files.add(path)
                    break
 
-        assert len(deleted_files) == 1, deleted_files
        deleted_files2 = set()
        for i, path2 in enumerate(list(self._loaded_data)):
            parts = Path(path2).parts
@@ -4000,7 +4048,7 @@
                if query == self._queries.get(path):
                    continue
                self._check_for_circular_queries(query, path)
-                self._queries[path] = query
+                self.set_query(path, query)
            except RecursionError as e:
                out_alerts.append(
                    dbc.Alert(
@@ -4016,6 +4064,11 @@
            return out_alerts
        return None
 
+    def set_query(self, key: str, value: str | None) -> None:
+        self._queries[key] = value
+        if key in self._nc:
+            self._nc[key].code_block = value
+
    @time_method_call(_PROFILE_DICT)
    def _get_unique_stem(self, path) -> str:
        name = _get_stem(path)
@@ -4023,19 +4076,6 @@
            name = _get_stem_from_parent(path)
        return name
 
-    def _open_img_path_as_xarray(self, img_path, selected_path, clipped_bounds):
-        if is_netcdf(img_path):
-            return self._nc[selected_path].filter_ds(
-                ds=self._loaded_data[img_path],
-                bounds=clipped_bounds.bounds,
-                code_block=self._queries.get(selected_path),
-            )
-        else:
-            return rasterio_to_xarray(
-                img_path, clipped_bounds, code_block=self._queries.get(selected_path)
-            )
-
-    @property
    def _columns(self) -> dict[str, set[str]]:
        return {path: set(dtypes) for path, dtypes in self._dtypes.items()} | {
            path: {"value"} for path in self._nc
@@ -4759,74 +4799,6 @@ def get_numeric_colors(values_no_nans_unique, values_no_nans, cmap, k):
    return color_dict, bins
 
 
-def rasterio_to_numpy(
-    img_path, bbox, return_attrs: list[str] | None = None
-) -> np.ndarray | tuple[Any]:
-    with rasterio.open(img_path) as src:
-        bounds_in_img_crs = GeoSeries([bbox], crs=4326).to_crs(src.crs).total_bounds
-        window = rasterio.windows.from_bounds(
-            *bounds_in_img_crs, transform=src.transform
-        )
-        arr = src.read(window=window, boundless=False, masked=False)
-        if not return_attrs:
-            return arr
-        return (arr, *[getattr(src, attr) for attr in return_attrs])
-
-
-def rasterio_to_xarray(img_path, bbox, code_block):
-    import xarray as xr
-    from rioxarray.rioxarray import _generate_spatial_coords
-
-    arr, crs, descriptions = rasterio_to_numpy(
-        img_path, bbox, return_attrs=["crs", "descriptions"]
-    )
-    bounds_in_img_crs = GeoSeries([bbox], crs=4326).to_crs(crs).total_bounds
-
-    if not all(arr.shape):
-        return xr.DataArray(
-            arr,
-            dims=["y", "x"],
-            attrs={"crs": crs},
-        )
-
-    if len(arr.shape) == 2:
-        height, width = arr.shape
-    elif len(arr.shape) == 3 and arr.shape[0] == 1:
-        arr = arr[0]
-        height, width = arr.shape
-    elif len(arr.shape) == 3:
-        height, width = arr.shape[1:]
-    else:
-        raise ValueError(arr.shape)
-
-    transform = rasterio.transform.from_bounds(*bounds_in_img_crs, width, height)
-    coords = _generate_spatial_coords(transform, width, height)
-
-    if len(arr.shape) == 2:
-        ds = xr.DataArray(
-            arr,
-            coords=coords,
-            dims=["y", "x"],
-            attrs={"crs": crs},
-        )
-    else:
-        if len(descriptions) != arr.shape[0]:
-            descriptions = range(arr.shape[0])
-        ds = xr.Dataset(
-            {
-                desc: xr.DataArray(
-                    arr[i],
-                    coords=coords,
-                    dims=["y", "x"],
-                    attrs={"crs": crs},
-                    name=desc,
-                )
-                for i, desc in enumerate(descriptions)
-            }
-        )
-    return _run_code_block(ds, code_block)
-
-
 def as_sized_array(arr: np.ndarray) -> np.ndarray:
    try:
        len(arr)
@@ -0,0 +1,298 @@
+import abc
+from typing import Any
+from typing import ClassVar
+
+import numpy as np
+import pandas as pd
+import pyproj
+import rasterio
+import shapely
+from geopandas import GeoDataFrame
+from geopandas import GeoSeries
+from shapely.geometry import Polygon
+
+try:
+    from xarray import DataArray
+    from xarray import Dataset
+except ImportError:
+
+    class DataArray:
+        """Placeholder."""
+
+    class Dataset:
+        """Placeholder."""
+
+
+from .utils import _PROFILE_DICT
+from .utils import get_xarray_bounds
+from .utils import time_method_call
+
+
+class AbstractImageConfig(abc.ABC):
+    rgb_bands: ClassVar[list[str] | None] = None
+
+    def __init__(self, code_block: str | None = None) -> None:
+        self.code_block = code_block
+        self.validate_code_block()
+
+    @abc.abstractmethod
+    def get_crs(self, ds: Dataset, path: str) -> pyproj.CRS:
+        pass
+
+    @abc.abstractmethod
+    def get_bounds(self, ds: Dataset, path: str) -> tuple[float, float, float, float]:
+        pass
+
+    @time_method_call(_PROFILE_DICT)
+    def filter_ds(
+        self,
+        ds: Dataset,
+        path: str,
+        bounds: tuple[float, float, float, float],
+    ) -> GeoDataFrame | None:
+        crs = self.get_crs(ds, None)
+        ds_bounds = get_xarray_bounds(ds)
+
+        bbox_correct_crs = (
+            GeoSeries([shapely.box(*bounds)], crs=4326).to_crs(crs).union_all()
+        )
+        clipped_bbox = bbox_correct_crs.intersection(shapely.box(*ds_bounds))
+        minx, miny, maxx, maxy = clipped_bbox.bounds
+
+        x = "x" if "x" in list(ds.coords) else "X"
+        y = "y" if "y" in list(ds.coords) else "Y"
+        ds = ds.sel(**{x: slice(minx, maxx), y: slice(maxy, miny)})
+        return self._run_code_block(ds)
+
+    @time_method_call(_PROFILE_DICT)
+    def to_numpy(self, xarr: Dataset | DataArray) -> GeoDataFrame | None:
+        if isinstance(xarr, Dataset) and len(xarr.data_vars) == 1:
+            xarr = xarr[next(iter(xarr.data_vars))]
+        elif isinstance(xarr, Dataset):
+            try:
+                xarr = xarr[self.rgb_bands]
+            except Exception:
+                pass
+
+        if "time" in set(xarr.dims) and (
+            not hasattr(xarr["time"].values, "__len__") or len(xarr["time"].values) > 1
+        ):
+            print("Note: selecting first DataArray in time dimension.")
+            xarr = xarr.isel(time=0)
+
+        if isinstance(xarr, Dataset):
+            try:
+                return np.array([xarr[band].values for band in self.rgb_bands])
+            except Exception:
+                pass
+            if len(xarr.data_vars) == 3:
+                try:
+                    return np.array([xarr[band].values for band in xarr.data_vars])
+                except Exception:
+                    pass
+            arrs = []
+            for var in xarr.data_vars:
+                arr = xarr[var].values
+                if len(arr.shape) == 2:
+                    arrs.append(arr)
+                if len(arr.shape) == 3:
+                    arrs.append(arr[0])
+            return arrs[np.argmax([arr.shape for arr in arrs])]
+
+        if isinstance(xarr, np.ndarray):
+            return xarr
+
+        return xarr.values
+
+    def validate_code_block(self) -> None:
+        if not self.code_block:
+            return
+        try:
+            assert "ds" in self.code_block
+            eval(self.code_block)
+        except NameError:
+            return
+        except SyntaxError:
+            pass
+        if "xarr=" not in self.code_block.replace(" ", "") or not any(
+            txt in self.code_block.replace(" ", "") for txt in ("=ds", "(ds")
+        ):
+            raise ValueError(
+                "'code_block' must be a piece of code that takes the xarray object 'ds' and defines the object 'xarr'. "
+                f"Got '{self.code_block}'"
+            )
+
+    def _run_code_block(self, ds: DataArray | Dataset) -> Dataset | DataArray:
+        if not self.code_block:
+            return ds
+
+        try:
+            xarr = eval(self.code_block)
+            if callable(xarr):
+                xarr = xarr(ds)
+        except SyntaxError:
+            loc = {}
+            exec(self.code_block, globals() | {"ds": ds}, loc)
+            xarr = loc["xarr"]
+
+        if isinstance(xarr, np.ndarray) and isinstance(ds, DataArray):
+            ds.values = xarr
+            return ds
+        elif isinstance(xarr, np.ndarray) and isinstance(ds, Dataset):
+            raise ValueError(
+                "Cannot return np.ndarray from 'code_block' if ds is xarray.Dataset."
+            )
+        return xarr
+
+    def __str__(self) -> str:
+        code_block = f"'{self.code_block}'" if self.code_block else None
+        return f"{self.__class__.__name__}({code_block})"
+
+    def __repr__(self) -> str:
+        return str(self)
+
+
+class NetCDFConfig(AbstractImageConfig):
+    """Sets the configuration for reading NetCDF files and getting crs and bounds.
+
+    Args:
+        code_block: String of Python code that takes an xarray.Dataset (ds) and returns an xarray.DataArray (xarr).
+            Note that the input Dataset must be references as 'ds' and the output must be assigned to 'xarr'.
+    """
+
+    rgb_bands: ClassVar[list[str]] = ["B4", "B3", "B2"]
+
+    def get_bounds(self, ds, path) -> tuple[float, float, float, float]:
+        return get_xarray_bounds(ds)
+
+    def get_crs(self, ds: Dataset, path: str) -> pyproj.CRS:
+        attrs = [
+            x
+            for x in set(ds.attrs).union(
+                set(ds.data_vars if isinstance(ds, Dataset) else set())
+            )
+            if "projection" in x.lower() or "crs" in x.lower() or "utm" in x.lower()
+        ]
+        if not attrs:
+            raise ValueError(f"Could not find CRS attribute/data_var in dataset: {ds}")
+
+        def getattr_xarray(ds, attr):
+            x = ds.attrs.get(attr, ds.get(attr))
+            if isinstance(x, DataArray):
+                try:
+                    return pyproj.CRS(str(x.attrs["proj4"]))
+                except Exception:
+                    return pyproj.CRS(str(x.values))
+            elif x is not None:
+                return pyproj.CRS(x)
+            raise ValueError(f"Could not find CRS attribute/data_var in dataset: {ds}")
+
+        for i, attr in enumerate(attrs):
+            try:
+                return getattr_xarray(ds, attr)
+            except Exception as e:
+                if i == len(attrs) - 1:
+                    attrs_dict = {attr: ds.attrs.get(attr, ds[attr]) for attr in attrs}
+                    raise ValueError(
+                        f"No valid CRS attribute found among {attrs_dict}"
+                    ) from e
+
+
+class NBSNetCDFConfig(NetCDFConfig):
+    def get_crs(self, ds: Dataset, path: str) -> pyproj.CRS:
+        return pyproj.CRS(ds.UTM_projection.epsg_code)
+
+
+class Sentinel2NBSNetCDFConfig(NBSNetCDFConfig):
+    rgb_bands: ClassVar[list[str]] = ["B4", "B3", "B2"]
+
+
+class GeoTIFFConfig(AbstractImageConfig):
+    def get_crs(self, ds, path):
+        with rasterio.open(path) as src:
+            return src.crs
+
+    def get_bounds(self, ds, path) -> tuple[float, float, float, float]:
+        with rasterio.open(path) as src:
+            return tuple(
+                GeoSeries([shapely.box(*src.bounds)], crs=src.crs)
+                .to_crs(4326)
+                .total_bounds
+            )
+
+    def filter_ds(
+        self,
+        ds: None,
+        path: str,
+        bounds: tuple[float, float, float, float],
+    ) -> DataArray | Dataset:
+        from rioxarray.rioxarray import _generate_spatial_coords
+
+        bbox = shapely.box(*bounds)
+        arr, crs, descriptions = rasterio_to_numpy(
+            path, bbox, return_attrs=["crs", "descriptions"]
+        )
+        bounds_in_img_crs = GeoSeries([bbox], crs=4326).to_crs(crs).total_bounds
+
+        if not all(arr.shape):
+            return DataArray(
+                arr,
+                dims=["y", "x"],
+                attrs={"crs": crs},
+            )
+
+        if len(arr.shape) == 2:
+            height, width = arr.shape
+        elif len(arr.shape) == 3 and arr.shape[0] == 1:
+            arr = arr[0]
+            height, width = arr.shape
+        elif len(arr.shape) == 3:
+            height, width = arr.shape[1:]
+        else:
+            raise ValueError(arr.shape)
+
+        transform = rasterio.transform.from_bounds(*bounds_in_img_crs, width, height)
+        coords = _generate_spatial_coords(transform, width, height)
+
+        if len(arr.shape) == 2:
+            ds = DataArray(
+                arr,
+                coords=coords,
+                dims=["y", "x"],
+                attrs={"crs": crs},
+            )
+        else:
+            if len(descriptions) != arr.shape[0]:
+                descriptions = range(arr.shape[0])
+            ds = Dataset(
+                {
+                    desc: DataArray(
+                        arr[i],
+                        coords=coords,
+                        dims=["y", "x"],
+                        attrs={"crs": crs},
+                        name=desc,
+                    )
+                    for i, desc in enumerate(descriptions)
+                }
+            )
+        return self._run_code_block(ds)
+
+
+def rasterio_to_numpy(
+    img_path: str, bbox: Polygon, return_attrs: list[str] | None = None
+) -> np.ndarray | tuple[Any]:
+    with rasterio.open(img_path) as src:
+        bounds_in_img_crs = GeoSeries([bbox], crs=4326).to_crs(src.crs).total_bounds
+        window = rasterio.windows.from_bounds(
+            *bounds_in_img_crs, transform=src.transform
+        )
+        arr = src.read(window=window, boundless=False, masked=False)
+        if not return_attrs:
+            return arr
+        return (arr, *[getattr(src, attr) for attr in return_attrs])
+
+
+def _pd():
+    """Function that makes sure 'pd' is not removed by 'ruff' fixes. Because pd is useful in code_block."""
+    pd
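
For orientation only, not part of the package diff: the new config classes above take a code_block string that must reference the opened xarray object as 'ds' and bind the filtered result to 'xarr', as checked by validate_code_block and executed by _run_code_block. A minimal sketch, assuming the module is importable as geo_explorer.img:

    from geo_explorer.img import NetCDFConfig

    # References 'ds' and assigns to 'xarr', so validate_code_block() accepts it;
    # filter_ds() later exec()s it with the clipped Dataset bound to 'ds'.
    config = NetCDFConfig(code_block="xarr = ds['B4'] / ds['B8']")
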
@@ -87,12 +87,45 @@ if not DEBUG:
        return decorator
 
 
+def get_xarray_resolution(ds) -> int:
+    minx, miny, maxx, maxy = _get_raw_xarray_bounds(ds)
+    diffx = maxx - minx
+    diffy = maxy - miny
+    x = "x" if "x" in list(ds.coords) else "X"
+    y = "y" if "y" in list(ds.coords) else "Y"
+    try:
+        resx = diffx / (ds.sizes[x] - 1)
+        resy = diffy / (ds.sizes[y] - 1)
+    except ZeroDivisionError:
+        raise ValueError(
+            f"Cannot calculate resolution for Dataset with {diffx=}, {diffy=}, {ds.sizes['x']=}, {ds.sizes['y']=}"
+        )
+    if resx != resy:
+        raise ValueError(
+            f"x and y resolution differ: resx={resx}, resy={resy} for Dataset: {ds}"
+        )
+    return resx
+
+
 def get_xarray_bounds(ds) -> tuple[float, float, float, float]:
+    res = get_xarray_resolution(ds)
+    minx, miny, maxx, maxy = _get_raw_xarray_bounds(ds)
+    return (
+        minx - res / 2,
+        miny - res / 2,
+        maxx + res / 2,
+        maxy + res / 2,
+    )
+
+
+def _get_raw_xarray_bounds(ds) -> tuple[float, float, float, float]:
+    x = "x" if "x" in list(ds.coords) else "X"
+    y = "y" if "y" in list(ds.coords) else "Y"
     return (
-        float(ds["x"].min().values),
-        float(ds["y"].min().values),
-        float(ds["x"].max().values),
-        float(ds["y"].max().values),
+        float(ds[x].min().values),
+        float(ds[y].min().values),
+        float(ds[x].max().values),
+        float(ds[y].max().values),
     )
 
 
@@ -1,199 +0,0 @@
-import abc
-from typing import ClassVar
-
-import numpy as np
-import pandas as pd
-import pyproj
-import rasterio
-import shapely
-from geopandas import GeoDataFrame
-from geopandas import GeoSeries
-from shapely.geometry import Polygon
-
-try:
-    from xarray import DataArray
-    from xarray import Dataset
-except ImportError:
-
-    class DataArray:
-        """Placeholder."""
-
-    class Dataset:
-        """Placeholder."""
-
-
-from .utils import _PROFILE_DICT
-from .utils import get_xarray_bounds
-from .utils import time_method_call
-
-
-class AbstractImageConfig(abc.ABC):
-    rgb_bands: ClassVar[list[str] | None] = None
-    reducer: ClassVar[str | None] = None
-
-    def __init__(self, code_block: str | None = None) -> None:
-        self._code_block = code_block
-        self.code_block = code_block  # trigger setter
-
-    @property
-    def code_block(self) -> str | None:
-        return self._code_block
-
-    @code_block.setter
-    def code_block(self, value: str | None):
-        if value and (
-            "xarr=" not in value.replace(" ", "")
-            or not any(txt in value.replace(" ", "") for txt in ("=ds", "(ds"))
-        ):
-            raise ValueError(
-                "'code_block' must be a piece of code that takes the xarray object 'ds' and defines the object 'xarr'. "
-                f"Got '{value}'"
-            )
-        self._code_block = value
-
-    @abc.abstractmethod
-    def get_crs(self, ds: Dataset, path: str) -> pyproj.CRS:
-        pass
-
-    @abc.abstractmethod
-    def get_bounds(self, ds: Dataset, path: str) -> tuple[float, float, float, float]:
-        pass
-
-    @time_method_call(_PROFILE_DICT)
-    def filter_ds(
-        self,
-        ds: Dataset,
-        bounds: tuple[float, float, float, float],
-        code_block: str | None,
-    ) -> GeoDataFrame | None:
-        crs = self.get_crs(ds, None)
-        ds_bounds = get_xarray_bounds(ds)
-
-        bbox_correct_crs = (
-            GeoSeries([shapely.box(*bounds)], crs=4326).to_crs(crs).union_all()
-        )
-        clipped_bbox = bbox_correct_crs.intersection(shapely.box(*ds_bounds))
-        minx, miny, maxx, maxy = clipped_bbox.bounds
-
-        ds = ds.sel(
-            x=slice(minx, maxx),
-            y=slice(maxy, miny),
-        )
-
-        return _run_code_block(ds, code_block)
-
-    @time_method_call(_PROFILE_DICT)
-    def to_numpy(self, xarr: Dataset | DataArray) -> GeoDataFrame | None:
-        if isinstance(xarr, Dataset) and len(xarr.data_vars) == 1:
-            xarr = xarr[next(iter(xarr.data_vars))]
-        elif isinstance(xarr, Dataset):
-            try:
-                xarr = xarr[self.rgb_bands]
-            except Exception:
-                pass
-
-        if "time" in set(xarr.dims) and (
-            not hasattr(xarr["time"].values, "__len__") or len(xarr["time"].values) > 1
-        ):
-            if self.reducer is None:
-                xarr = xarr.isel(time=0)
-            else:
-                xarr = getattr(xarr, self.reducer)(dim="time")
-
-        if isinstance(xarr, Dataset) and self.rgb_bands:
-            return np.array([xarr[band].values for band in self.rgb_bands])
-        elif isinstance(xarr, Dataset):
-            return np.array([xarr[var].values for var in xarr.data_vars])
-
-        if isinstance(xarr, np.ndarray):
-            return xarr
-
-        return xarr.values
-
-    def __str__(self) -> str:
-        code_block = f"'{self.code_block}'" if self.code_block else None
-        return f"{self.__class__.__name__}({code_block})"
-
-    def __repr__(self) -> str:
-        return str(self)
-
-
-def _run_code_block(
-    ds: DataArray | Dataset, code_block: str | None
-) -> Dataset | DataArray:
-    if not code_block:
-        return ds
-
-    try:
-        xarr = eval(code_block)
-        if callable(xarr):
-            xarr = xarr(ds)
-    except SyntaxError:
-        loc = {}
-        exec(code_block, globals() | {"ds": ds}, loc)
-        xarr = loc["xarr"]
-
-    if isinstance(xarr, np.ndarray) and isinstance(ds, DataArray):
-        ds.values = xarr
-        return ds
-    elif isinstance(xarr, np.ndarray) and isinstance(ds, Dataset):
-        raise ValueError(
-            "Cannot return np.ndarray from 'code_block' if ds is xarray.Dataset."
-        )
-    return xarr
-
-
-class GeoTIFFConfig(AbstractImageConfig):
-    def get_crs(self, ds, path):
-        with rasterio.open(path) as src:
-            return src.crs
-
-    def get_bounds(self, ds, path) -> tuple[float, float, float, float]:
-        with rasterio.open(path) as src:
-            return tuple(
-                GeoSeries([shapely.box(*src.bounds)], crs=src.crs)
-                .to_crs(4326)
-                .total_bounds
-            )
-
-
-class NetCDFConfig(AbstractImageConfig):
-    """Sets the configuration for reading NetCDF files and getting crs and bounds.
-
-    Args:
-        code_block: String of Python code that takes an xarray.Dataset (ds) and returns an xarray.DataArray (xarr).
-            Note that the input Dataset must be references as 'ds' and the output must be assigned to 'xarr'.
-    """
-
-    rgb_bands: ClassVar[list[str]] = ["B4", "B3", "B2"]
-
-    def get_bounds(self, ds, path) -> tuple[float, float, float, float]:
-        return get_xarray_bounds(ds)
-
-    def get_crs(self, ds: Dataset, path: str) -> pyproj.CRS:
-        attrs = [x for x in ds.attrs if "projection" in x.lower() or "crs" in x.lower()]
-        if not attrs:
-            raise ValueError(f"Could not find CRS attribute in dataset: {ds}")
-        for i, attr in enumerate(attrs):
-            try:
-                return pyproj.CRS(ds.attrs[attr])
-            except Exception as e:
-                if i == len(attrs) - 1:
-                    attrs_dict = {attr: ds.attrs[attr] for attr in attrs}
-                    raise ValueError(
-                        f"No valid CRS attribute found among {attrs_dict}"
-                    ) from e
-
-
-class NBSNetCDFConfig(NetCDFConfig):
-    def get_crs(self, ds: Dataset, path: str) -> pyproj.CRS:
-        return pyproj.CRS(ds.UTM_projection.epsg_code)
-
-
-class Sentinel2NBSNetCDFConfig(NBSNetCDFConfig):
-    rgb_bands: ClassVar[list[str]] = ["B4", "B3", "B2"]
-
-
-def _pd():
-    """Function that makes sure 'pd' is not removed by 'ruff' fixes. Because pd is useful in code_block."""
-    pd
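Illustrative only, not part of the diff: a sketch of how the image configs and the new sum_partition_sizes flag from 0.9.12 might be wired into GeoExplorer. The top-level import path, the start_dir keyword, and the run() call are assumptions; only data=, sum_partition_sizes= and the config classes appear in the diff above.

    from geo_explorer import GeoExplorer  # import path assumed
    from geo_explorer.img import GeoTIFFConfig, NetCDFConfig

    explorer = GeoExplorer(
        start_dir="data",  # name taken from the __init__ body in the diff; treated as an assumption here
        data={
            "data/s2_scene.nc": NetCDFConfig(code_block="xarr = ds['B4']"),
            "data/ortho.tif": GeoTIFFConfig(),
        },
        sum_partition_sizes=False,  # new in 0.9.12: skip summing sizes of partitioned datasets in the file browser
    )
    explorer.run()  # entry point assumed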