satdatakit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
satdatakit/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ """SatDataKit - Unified satellite data analysis toolkit.
2
+
3
+ Author: Rafael Cañete Vazquez
4
+ License: MIT
5
+ """
6
# Package metadata, exposed as ``satdatakit.__version__`` / ``satdatakit.__author__``.
__version__, __author__ = "0.1.0", "Rafael Cañete Vazquez"
8
+
9
+ from satdatakit.core import SatelliteDataset
10
+ from satdatakit.io import read, read_collection
11
+ from satdatakit.indices import compute_index
12
+ from satdatakit.pipeline import Pipeline
13
+
14
# Names re-exported by ``from satdatakit import *``.
__all__ = [
    "SatelliteDataset",
    "read",
    "read_collection",
    "compute_index",
    "Pipeline",
]
satdatakit/core.py ADDED
@@ -0,0 +1,233 @@
1
+ """Core data model: SatelliteDataset.
2
+
3
+ Author: Rafael Cañete Vazquez
4
+ License: MIT
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import warnings
9
+ from dataclasses import dataclass, field
10
+ from datetime import datetime
11
+ from pathlib import Path
12
+ from typing import Any, Dict, List, Optional, Tuple, Union
13
+
14
+ import numpy as np
15
+ import pandas as pd
16
+ import xarray as xr
17
+ from shapely.geometry import box
18
+
19
+
20
@dataclass
class SatelliteDataset:
    """Universal container for Earth Observation raster data.

    Wraps an ``xarray.DataArray`` that must have ``band``, ``y`` and ``x``
    dimensions (extra dimensions such as ``time`` are tolerated) together with
    band names and georeferencing / acquisition metadata.

    Every transforming method (``get_bands``, ``add_band``, ``reproject``,
    ``clip``, ``mask``, ...) returns a new ``SatelliteDataset``; the receiver is
    never modified.

    Attributes:
        data: Raster values with at least ``band``, ``y`` and ``x`` dims.
        bands: One name per position along ``band``. On a count mismatch they
            are replaced by ``band_<i>`` (with a ``UserWarning``); duplicates
            are made unique by appending ``_<i>``.
        crs: CRS string handed to rioxarray (``rio.write_crs`` / ``rio.reproject``).
        resolution: Pixel size; refreshed from the raster by ``reproject``/``resample``.
        bounds: Spatial extent; refreshed from ``rio.bounds()`` by ``reproject``/``clip``.
        datetime: Acquisition time or list of times.
        sensor: Sensor name.
        platform: Platform / satellite name.
        cloud_cover: Cloud cover, rendered with a ``%`` suffix by ``info()``.
        metadata: Free-form metadata; derived datasets receive a shallow copy.
        source_format: Name of the format/backend the data came from.
        source_path: File the data was read from, if any.
    """

    data: xr.DataArray
    bands: List[str]
    crs: Optional[str] = None
    resolution: Optional[Tuple[float, float]] = None
    bounds: Optional[Tuple[float, float, float, float]] = None
    datetime: Optional[Union[datetime, List[datetime]]] = None
    sensor: Optional[str] = None
    platform: Optional[str] = None
    cloud_cover: Optional[float] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    source_format: Optional[str] = None
    source_path: Optional[Path] = None

    def __post_init__(self) -> None:
        # Validate first: it may replace ``bands`` with generic names.
        self._validate_data()
        self._normalize_bands()

    def _validate_data(self) -> None:
        """Check the required dims and reconcile ``bands`` with the band axis.

        Raises:
            ValueError: If ``band``, ``y`` or ``x`` is missing from ``data.dims``.
        """
        dims = list(self.data.dims)
        if "band" not in dims:
            raise ValueError(f"DataArray must have 'band' dimension. Got: {dims}")
        if "y" not in dims or "x" not in dims:
            raise ValueError(f"DataArray must have 'y' and 'x' dimensions. Got: {dims}")
        n_bands = self.data.sizes["band"]
        if n_bands != len(self.bands):
            warnings.warn(f"Band count mismatch: {n_bands} vs {len(self.bands)}", UserWarning)
            self.bands = [f"band_{i}" for i in range(n_bands)]

    def _normalize_bands(self) -> None:
        """Coerce band names to ``str`` and make them unique.

        A repeated name at position ``i`` becomes ``<name>_<i>``. If that name
        is also taken, ``_<k>`` counters are appended until it is unique, so
        the result never contains duplicates.
        """
        seen: set = set()
        unique: List[str] = []
        for i, name in enumerate(str(b) for b in self.bands):
            candidate = name
            if candidate in seen:
                candidate = f"{name}_{i}"
                k = 1
                while candidate in seen:
                    candidate = f"{name}_{i}_{k}"
                    k += 1
            seen.add(candidate)
            unique.append(candidate)
        self.bands = unique

    @property
    def shape(self) -> Tuple[int, ...]:
        """Size of every dimension, in ``data.dims`` order."""
        return tuple(self.data.shape)

    @property
    def n_bands(self) -> int:
        """Length of the ``band`` dimension."""
        return self.data.sizes["band"]

    @property
    def width(self) -> int:
        """Number of pixels along ``x``."""
        return self.data.sizes["x"]

    @property
    def height(self) -> int:
        """Number of pixels along ``y``."""
        return self.data.sizes["y"]

    @property
    def dtype(self) -> np.dtype:
        """Return data type."""
        return self.data.dtype

    def _derive(self, data: xr.DataArray, **overrides: Any) -> "SatelliteDataset":
        """Build a new dataset around *data* that inherits this one's metadata.

        ``bands`` defaults to a copy of ``self.bands`` and ``metadata`` to a
        shallow copy of ``self.metadata``. Any field can be overridden by keyword.
        """
        fields_: Dict[str, Any] = dict(
            bands=list(self.bands),
            crs=self.crs,
            resolution=self.resolution,
            bounds=self.bounds,
            datetime=self.datetime,
            sensor=self.sensor,
            platform=self.platform,
            cloud_cover=self.cloud_cover,
            metadata=self.metadata.copy(),
            source_format=self.source_format,
            source_path=self.source_path,
        )
        fields_.update(overrides)
        return SatelliteDataset(data=data, **fields_)

    def _with_crs(self) -> xr.DataArray:
        """Return ``self.data`` with ``self.crs`` attached when rioxarray sees no CRS.

        A new array is returned. ``self.data`` itself is not reassigned, so
        reading or reprojecting never mutates this dataset.
        """
        import rioxarray  # noqa: F401  -- registers the ``.rio`` accessor

        data = self.data
        if data.rio.crs is None and self.crs is not None:
            data = data.rio.write_crs(self.crs)
        return data

    def _resolution_of(self, raster: xr.DataArray) -> Optional[Tuple[float, float]]:
        """Pixel size of *raster* as positive ``(x, y)`` values.

        Falls back to ``self.resolution`` when rioxarray cannot compute it.
        """
        # NOTE(review): sign convention assumed positive; confirm it matches how
        # the readers in ``satdatakit.io`` populate ``resolution``.
        try:
            res_x, res_y = raster.rio.resolution()
        except (RuntimeError, ValueError):
            return self.resolution
        return (abs(float(res_x)), abs(float(res_y)))

    def __getitem__(self, key: Union[str, int]) -> xr.DataArray:
        """Return one band, selected by name or by (possibly negative) position.

        Raises:
            KeyError: If a band name is unknown.
            TypeError: If *key* is neither a ``str`` nor an integer.
        """
        if isinstance(key, str):
            if key not in self.bands:
                raise KeyError(f"Band '{key}' not found. Available: {self.bands}")
            idx = self.bands.index(key)
        elif isinstance(key, (int, np.integer)):
            idx = int(key)
        else:
            raise TypeError(f"Key must be str or int, got {type(key)}")
        return self.data.isel(band=idx)

    def get_bands(self, names: List[str]) -> "SatelliteDataset":
        """Return a dataset holding only *names*, in the order given.

        Names that are not in ``self.bands`` are silently skipped.
        """
        indices = [self.bands.index(n) for n in names if n in self.bands]
        return self._derive(
            self.data.isel(band=indices),
            bands=[self.bands[i] for i in indices],
        )

    def to_numpy(self) -> np.ndarray:
        """Return the raw values (``data.values``)."""
        return self.data.values

    def to_xarray(self) -> xr.DataArray:
        """Return the underlying DataArray (not a copy)."""
        return self.data

    def to_dataset(self) -> xr.Dataset:
        """Split the band axis into one ``xr.Dataset`` variable per band name.

        ``attrs["crs"]`` is set when ``crs`` is known.
        """
        # ``errors="ignore"``: arrays built without a 'band' coordinate (only the
        # dimension) have nothing to drop after ``isel``.
        variables = {
            name: self.data.isel(band=i).drop_vars("band", errors="ignore")
            for i, name in enumerate(self.bands)
        }
        ds = xr.Dataset(variables)
        if self.crs:
            ds.attrs["crs"] = self.crs
        return ds

    def add_band(self, name: str, data: Union[np.ndarray, xr.DataArray]) -> "SatelliteDataset":
        """Return a dataset with *data* appended as a new band called *name*.

        Args:
            name: Name of the new band; must not already exist.
            data: A 2-D ``(y, x)`` NumPy array, or a DataArray without a band
                axis (a scalar or length-1 ``band`` is also accepted).

        Raises:
            ValueError: If *name* already exists.
        """
        if name in self.bands:
            raise ValueError(f"Band '{name}' already exists.")
        if isinstance(data, np.ndarray):
            data = xr.DataArray(data, dims=["y", "x"])
        if "band" in data.dims:
            # Single-band array that still carries its band axis.
            data = data.squeeze("band", drop=True)
        else:
            # Scalar 'band' coordinate left over from ``isel(band=i)``.
            data = data.drop_vars("band", errors="ignore")
        data = data.expand_dims(band=[name])
        new_data = xr.concat([self.data, data], dim="band")
        return self._derive(new_data, bands=self.bands + [name])

    def remove_band(self, name: str) -> "SatelliteDataset":
        """Return a dataset without the band *name*.

        Raises:
            KeyError: If *name* is not a band.
        """
        if name not in self.bands:
            raise KeyError(f"Band '{name}' not found.")
        idx = self.bands.index(name)
        new_data = self.data.drop_isel(band=idx)
        return self._derive(new_data, bands=[b for b in self.bands if b != name])

    def rename_bands(self, mapping: Dict[str, str]) -> "SatelliteDataset":
        """Return a copy with bands renamed via *mapping*; unmapped names are kept."""
        new_bands = [mapping.get(b, b) for b in self.bands]
        new_data = self.data.copy().assign_coords(band=new_bands)
        return self._derive(new_data, bands=new_bands)

    def reproject(self, dst_crs: Union[str, int], **kwargs) -> "SatelliteDataset":
        """Reproject to *dst_crs* with ``rio.reproject``; *kwargs* are forwarded.

        ``crs``, ``bounds`` and ``resolution`` of the result describe the
        reprojected raster.

        Raises:
            ValueError: If ``crs`` is not set.
        """
        import rioxarray  # noqa: F401  -- registers the ``.rio`` accessor

        if self.crs is None:
            raise ValueError("Source CRS is not set. Cannot reproject.")
        reprojected = self._with_crs().rio.reproject(dst_crs, **kwargs)
        out_crs = reprojected.rio.crs
        return self._derive(
            reprojected,
            # Normalised string (e.g. 4326 -> "EPSG:4326") that rioxarray accepts again.
            crs=out_crs.to_string() if out_crs is not None else str(dst_crs),
            resolution=self._resolution_of(reprojected),
            bounds=reprojected.rio.bounds(),
        )

    def resample(self, resolution: Union[float, Tuple[float, float]], **kwargs) -> "SatelliteDataset":
        """Resample to *resolution* in the current CRS (via ``reproject``).

        Raises:
            ValueError: If ``crs`` is not set.
        """
        if self.crs is None:
            raise ValueError("CRS must be set to resample.")
        return self.reproject(dst_crs=self.crs, resolution=resolution, **kwargs)

    def clip(self, geometry, crs=None, drop=True, **kwargs) -> "SatelliteDataset":
        """Clip to *geometry* with ``rio.clip``.

        Args:
            geometry: A geometry accepted by ``rio.clip`` (GeoJSON-like mapping
                or object with ``__geo_interface__``), or a list/tuple of them.
            crs: CRS of *geometry*, forwarded to ``rio.clip``.
            drop: Forwarded to ``rio.clip``.
            **kwargs: Forwarded to ``rio.clip``. ``all_touched`` defaults to True
                and may be overridden.
        """
        import rioxarray  # noqa: F401  -- registers the ``.rio`` accessor

        geometries = list(geometry) if isinstance(geometry, (list, tuple)) else [geometry]
        kwargs.setdefault("all_touched", True)
        clipped = self._with_crs().rio.clip(geometries, crs=crs, drop=drop, **kwargs)
        return self._derive(clipped, bounds=clipped.rio.bounds())

    def mask(self, mask_array: np.ndarray, fill_value: float = np.nan) -> "SatelliteDataset":
        """Replace pixels where *mask_array* is falsy with *fill_value*.

        *mask_array* must have shape ``(height, width)``. It is matched to the
        data by the ``y``/``x`` dimension names, so the position of those dims
        in ``data.dims`` does not matter.

        Raises:
            ValueError: If the mask shape differs from ``(height, width)``.
        """
        mask_array = np.asarray(mask_array)
        spatial_shape = (self.height, self.width)
        if mask_array.shape != spatial_shape:
            raise ValueError(f"Mask shape {mask_array.shape} does not match data spatial shape {spatial_shape}")
        condition = xr.DataArray(mask_array, dims=("y", "x"))
        return self._derive(self.data.where(condition, fill_value))

    def to_geotiff(self, path: Union[str, Path], **kwargs) -> None:
        """Write to GeoTIFF with ``rio.to_raster``, creating parent directories."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        self._with_crs().rio.to_raster(path, **kwargs)

    def to_netcdf(self, path: Union[str, Path], **kwargs) -> None:
        """Write ``to_dataset()`` to NetCDF, creating parent directories."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        self.to_dataset().to_netcdf(path, **kwargs)

    def __repr__(self) -> str:
        return f"SatelliteDataset(shape={self.shape}, bands={self.bands}, crs={self.crs!r})"

    def info(self) -> str:
        """Return a multi-line human-readable summary."""
        cloud = "None" if self.cloud_cover is None else f"{self.cloud_cover}%"
        lines = [
            "=" * 50,
            "SatelliteDataset Information",
            "=" * 50,
            f"Shape: {self.shape}",
            f"Bands: {self.n_bands} ({self.bands})",
            f"Width: {self.width} px",
            f"Height: {self.height} px",
            f"CRS: {self.crs}",
            f"Resolution: {self.resolution}",
            f"Bounds: {self.bounds}",
            f"Sensor: {self.sensor}",
            f"Platform: {self.platform}",
            f"Datetime: {self.datetime}",
            f"Cloud cover: {cloud}",
            f"Dtype: {self.dtype}",
            "=" * 50,
        ]
        return "\n".join(lines)
@@ -0,0 +1,26 @@
1
+ """SatDataKit extensions — optional add-ons for scalability.
2
+
3
+ Extensions load on demand and do not modify core code.
4
+ """
5
+
6
__version__ = "0.1.0"


def list_extensions():
    """Return the extension names accepted by ``enable`` (a fresh list per call)."""
    names = ("dask", "stac", "zarr")
    return list(names)
12
+
13
+
14
def enable(extension: str):
    """Activate the optional extension called *extension*.

    The matching sub-module is imported only when its extension is requested.

    Raises:
        ValueError: If *extension* is not one of the names from ``list_extensions()``.
    """
    if extension == "dask":
        from .dask_ext import enable_dask as activate
    elif extension == "stac":
        from .stac_ext import enable_stac as activate
    elif extension == "zarr":
        from .zarr_ext import enable_zarr as activate
    else:
        raise ValueError(f"Unknown extension: {extension}. Available: {list_extensions()}")
    activate()
@@ -0,0 +1,157 @@
1
+ """Dask extension for SatDataKit — parallel processing.
2
+
3
+ Usage:
4
+ from satdatakit.extensions.dask_ext import enable_dask, read_dask
5
+ enable_dask()
6
+
7
+ ds = read_dask(["file.tif"], chunks={"x": 1024})
8
+ ds = ds.to_dask(chunks={"x": 256})
9
+ ds = compute_index(ds, "NDVI")
10
+ ds = ds.compute()
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from pathlib import Path
16
+ from typing import Any, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import xarray as xr
20
+
21
+ from satdatakit.core import SatelliteDataset
22
+
23
+
24
+ def _check_dask():
25
+ """Verify Dask is installed."""
26
+ try:
27
+ import dask.array as da
28
+ from dask.distributed import Client
29
+ except ImportError:
30
+ raise ImportError(
31
+ "Dask not installed. Run: pip install satdatakit[dask]"
32
+ ) from None
33
+
34
+
35
def enable_dask():
    """Activate the Dask backend by patching methods onto ``SatelliteDataset``.

    Adds:
        * ``SatelliteDataset.to_dask(chunks=None)``: a Dask-backed version of the dataset.
        * ``SatelliteDataset.compute()``: materialises Dask-backed data in memory.

    Calling it again re-binds the same methods. A short summary is printed to stdout.

    Raises:
        ImportError: If Dask is not installed.
    """
    _check_dask()

    def _with_data(ds: SatelliteDataset, new_data: xr.DataArray) -> SatelliteDataset:
        """Copy of *ds* around *new_data*, with a fresh band list and metadata dict."""
        return SatelliteDataset(
            data=new_data,
            bands=ds.bands.copy(),
            crs=ds.crs,
            resolution=ds.resolution,
            bounds=ds.bounds,
            datetime=ds.datetime,
            sensor=ds.sensor,
            platform=ds.platform,
            cloud_cover=ds.cloud_cover,
            metadata=ds.metadata.copy(),
            source_format=ds.source_format,
            source_path=ds.source_path,
        )

    # --- Monkey patch: SatelliteDataset.to_dask ---
    def to_dask(self: SatelliteDataset, chunks: Optional[dict] = None) -> SatelliteDataset:
        """Return a Dask-backed version of this dataset.

        * NumPy-backed data is chunked with *chunks* (default ``{"x": 1024, "y": 1024}``).
        * Data that is already Dask-backed is rechunked when *chunks* is given,
          and ``self`` is returned unchanged when it is not.
        """
        import dask.array as da

        is_lazy = isinstance(self.data.data, da.Array)
        if is_lazy and chunks is None:
            return self
        if chunks is None:
            chunks = {"x": 1024, "y": 1024}
        return _with_data(self, self.data.chunk(chunks))

    SatelliteDataset.to_dask = to_dask

    # --- Monkey patch: SatelliteDataset.compute ---
    def compute(self: SatelliteDataset) -> SatelliteDataset:
        """Load Dask-backed data into memory. In-memory datasets are returned as-is."""
        import dask.array as da

        if isinstance(self.data.data, da.Array):
            return _with_data(self, self.data.compute())
        return self

    SatelliteDataset.compute = compute

    print("✅ Dask extension enabled:")
    print(" - ds.to_dask(chunks={...})")
    print(" - ds.compute()")
94
+
95
+
96
def disable_dask():
    """Undo ``enable_dask`` by deleting the patched methods that are present."""
    for attr in ("to_dask", "compute"):
        if hasattr(SatelliteDataset, attr):
            delattr(SatelliteDataset, attr)
    print("✅ Dask extension disabled")
103
+
104
+
105
def read_dask(
    paths: List[Union[str, Path]],
    bands: Optional[List[str]] = None,
    chunks: Optional[dict] = None,
    concat_dim: str = "time",
    **kwargs: Any,
) -> SatelliteDataset:
    """
    Read one or more rasters lazily (Dask-backed) into a SatelliteDataset.

    Args:
        paths: Raster files opened with ``rioxarray.open_rasterio``; must not be empty.
        bands: Optional band names to use instead of the files' band coordinate
            values. A length mismatch is handled by ``SatelliteDataset``
            (warning plus generic names).
        chunks: Dask chunk sizes. Defaults to ``{"x": 1024, "y": 1024}``.
        concat_dim: Dimension along which several files are stacked.
        **kwargs: Forwarded to ``rioxarray.open_rasterio``.

    Returns:
        A SatelliteDataset with ``source_format="dask"`` and ``source_path``
        set to the first path.

    Raises:
        ImportError: If Dask is not installed.
        ValueError: If *paths* is empty.

    Usage:
        from satdatakit.extensions.dask_ext import read_dask
        ds = read_dask(["file1.tif", "file2.tif"], chunks={"x": 1024})
    """
    _check_dask()

    import rioxarray

    paths = list(paths)
    if not paths:
        raise ValueError("read_dask() requires at least one path.")
    if chunks is None:
        chunks = {"x": 1024, "y": 1024}

    arrays = [rioxarray.open_rasterio(p, chunks=chunks, **kwargs) for p in paths]
    stacked = arrays[0] if len(arrays) == 1 else xr.concat(arrays, dim=concat_dim)

    # Band names: explicit argument, else the band coordinate, else generic names.
    if bands is not None:
        band_names = [str(b) for b in bands]
    elif "band" in stacked.coords:
        band_names = [str(b) for b in stacked.coords["band"].values]
    else:
        band_names = [f"band_{i}" for i in range(stacked.sizes["band"])]

    # ``rio.crs`` is None for rasters without georeferencing.
    crs_obj = stacked.rio.crs

    return SatelliteDataset(
        data=stacked,
        bands=band_names,
        crs=crs_obj.to_string() if crs_obj is not None else None,
        resolution=None,
        bounds=None,
        datetime=None,
        sensor=None,
        platform=None,
        cloud_cover=None,
        metadata={},
        source_format="dask",
        source_path=Path(paths[0]),
    )
@@ -0,0 +1,74 @@
1
+ """STAC extension for SatDataKit — cloud catalog access.
2
+
3
+ Usage:
4
+ from satdatakit.extensions.stac_ext import enable_stac
5
+ enable_stac()
6
+
7
+ ds = satdatakit.read_stac(
8
+ catalog_url="https://earth-search.aws.element84.com/v1",
9
+ collection="sentinel-2-l2a",
10
+ bbox=(-100, 25, -99, 26),
11
+ datetime_range="2024-01-01/2024-06-01"
12
+ )
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ from typing import Any, List, Optional, Tuple, Union
18
+
19
+
20
def enable_stac():
    """Activate the STAC backend by exposing ``satdatakit.read_stac``.

    Prints a short summary to stdout.

    Raises:
        ImportError: If ``pystac``, ``pystac_client`` or ``stackstac`` is missing.
    """
    try:
        import pystac  # noqa: F401  -- dependency check only
        import pystac_client
        import stackstac
    except ImportError:
        raise ImportError(
            "STAC dependencies not installed. Run: pip install satdatakit[stac]"
        ) from None

    from satdatakit.core import SatelliteDataset

    def read_stac(
        catalog_url: str,
        collection: str,
        bbox: Tuple[float, float, float, float],
        datetime_range: str,
        bands: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> SatelliteDataset:
        """Search a STAC API and stack the matching items into a SatelliteDataset.

        Args:
            catalog_url: STAC API root URL.
            collection: Collection id to search.
            bbox: Search bounding box, passed to the catalog search.
            datetime_range: Datetime or interval string, passed to the catalog search.
            bands: Asset keys to load (all assets if None).
            **kwargs: Forwarded to ``stackstac.stack``.

        Raises:
            ValueError: If the search returns no items.
        """
        catalog = pystac_client.Client.open(catalog_url)
        search = catalog.search(
            collections=[collection],
            bbox=bbox,
            datetime=datetime_range,
        )
        # ``items()`` supersedes the deprecated ``get_items()``; keep the
        # fallback for older pystac-client releases.
        item_iter = search.items() if hasattr(search, "items") else search.get_items()
        items = list(item_iter)

        if not items:
            raise ValueError("No items found for query")

        data = stackstac.stack(items, assets=bands, **kwargs)

        return SatelliteDataset(
            data=data,
            bands=bands or list(data.coords["band"].values),
            # stackstac returns pixels in the items' (or requested) projection and
            # records it in attrs["crs"]. It is not necessarily EPSG:4326.
            crs=data.attrs.get("crs"),
            source_format="stac",
        )

    import satdatakit
    satdatakit.read_stac = read_stac

    print("✅ STAC extension enabled:")
    print(" - satdatakit.read_stac(catalog_url=..., collection=...)")
+
68
+
69
def disable_stac():
    """Undo ``enable_stac`` by removing ``satdatakit.read_stac`` when present."""
    import satdatakit

    try:
        delattr(satdatakit, "read_stac")
    except AttributeError:
        # Extension was never enabled; nothing to remove.
        pass
    print("✅ STAC extension disabled")
@@ -0,0 +1,89 @@
1
+ """Zarr extension for SatDataKit — cloud-native format.
2
+
3
+ Usage:
4
+ from satdatakit.extensions.zarr_ext import enable_zarr
5
+ enable_zarr()
6
+
7
+ ds = satdatakit.read_zarr("s3://bucket/dataset.zarr")
8
+ ds.to_zarr("/local/output.zarr")
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from pathlib import Path
14
+ from typing import Any, Optional, Union
15
+
16
+
17
+ def enable_zarr():
18
+ """Activate Zarr backend."""
19
+ try:
20
+ import zarr
21
+ except ImportError:
22
+ raise ImportError(
23
+ "Zarr not installed. Run: pip install satdatakit[zarr]"
24
+ ) from None
25
+
26
+ from satdatakit.core import SatelliteDataset
27
+
28
+ # --- Monkey patch: SatelliteDataset.to_zarr ---
29
+ def to_zarr(self: SatelliteDataset, path: Union[str, Path], **kwargs: Any) -> None:
30
+ """Export dataset to Zarr format."""
31
+ ds = self.to_dataset()
32
+ ds.to_zarr(path, **kwargs)
33
+
34
+ SatelliteDataset.to_zarr = to_zarr
35
+
36
+ # --- Global function: read_zarr ---
37
+ def read_zarr(path: Union[str, Path], **kwargs: Any) -> SatelliteDataset:
38
+ """Read Zarr store into SatelliteDataset."""
39
+ import xarray as xr
40
+
41
+ ds = xr.open_zarr(path, **kwargs)
42
+
43
+ # Detectar variables espaciales
44
+ spatial_vars = []
45
+ for name, var in ds.data_vars.items():
46
+ if any(d in var.dims for d in ["x", "y", "lon", "lat"]):
47
+ spatial_vars.append(name)
48
+
49
+ if not spatial_vars:
50
+ raise ValueError("No spatial variables found in Zarr store")
51
+
52
+ # Convertir a DataArray con dimensión band
53
+ data_arrays = []
54
+ for var in spatial_vars:
55
+ da = ds[var]
56
+ if "band" not in da.dims:
57
+ da = da.expand_dims(band=[var])
58
+ data_arrays.append(da)
59
+
60
+ if len(data_arrays) == 1:
61
+ data = data_arrays[0]
62
+ else:
63
+ data = xr.concat(data_arrays, dim="band")
64
+
65
+ return SatelliteDataset(
66
+ data=data,
67
+ bands=spatial_vars,
68
+ crs=ds.attrs.get("crs"),
69
+ source_format="zarr",
70
+ source_path=Path(path),
71
+ )
72
+
73
+ import satdatakit
74
+ satdatakit.read_zarr = read_zarr
75
+
76
+ print("✅ Zarr extension enabled:")
77
+ print(" - satdatakit.read_zarr(path)")
78
+ print(" - ds.to_zarr(path)")
79
+
80
+
81
def disable_zarr():
    """Undo ``enable_zarr``: drop ``SatelliteDataset.to_zarr`` and ``satdatakit.read_zarr``."""
    from satdatakit.core import SatelliteDataset
    import satdatakit

    patched = ((SatelliteDataset, "to_zarr"), (satdatakit, "read_zarr"))
    for owner, attr in patched:
        if hasattr(owner, attr):
            delattr(owner, attr)
    print("✅ Zarr extension disabled")
satdatakit/indices.py ADDED
@@ -0,0 +1,100 @@
1
+ """Spectral indices computation.
2
+
3
+ Author: Rafael Cañete Vazquez
4
+ License: MIT
5
+ """
6
+ from __future__ import annotations
7
+
8
+ from typing import Any, Dict, List, Optional, Union
9
+
10
+ import numpy as np
11
+ import xarray as xr
12
+
13
+ from satdatakit.core import SatelliteDataset
14
+
15
# Supported spectral indices, keyed by upper-case index code. Each entry has:
#   name    - human-readable index name
#   formula - textual formula for display; the arithmetic itself is in _compute_formula
#   bands   - canonical band role -> candidate dataset band names, tried in order
#   params  - optional default formula parameters, overridable per call
#   range   - (min, max) used to clip results when clip_range is enabled
INDEX_DEFINITIONS: Dict[str, Dict[str, Any]] = {
    "NDVI": dict(
        name="Normalized Difference Vegetation Index",
        formula="(nir - red) / (nir + red)",
        bands=dict(red=["red", "B04"], nir=["nir", "B08"]),
        range=(-1, 1),
    ),
    "NDWI": dict(
        name="Normalized Difference Water Index",
        formula="(green - nir) / (green + nir)",
        bands=dict(green=["green", "B03"], nir=["nir", "B08"]),
        range=(-1, 1),
    ),
    "EVI": dict(
        name="Enhanced Vegetation Index",
        formula="2.5 * (nir - red) / (nir + 6*red - 7.5*blue + 1)",
        bands=dict(red=["red", "B04"], nir=["nir", "B08"], blue=["blue", "B02"]),
        range=(-1, 1),
    ),
    "SAVI": dict(
        name="Soil Adjusted Vegetation Index",
        formula="(1 + L) * (nir - red) / (nir + red + L)",
        bands=dict(red=["red", "B04"], nir=["nir", "B08"]),
        params=dict(L=0.5),
        range=(-1, 1),
    ),
}
42
+
43
+ def _resolve_band_name(dataset: SatelliteDataset, candidates: List[str]) -> Optional[str]:
44
+ for candidate in candidates:
45
+ if candidate in dataset.bands:
46
+ return candidate
47
+ return None
48
+
49
+ def compute_index(dataset: SatelliteDataset, index: str,
50
+ band_mapping: Optional[Dict[str, str]] = None,
51
+ params: Optional[Dict[str, float]] = None,
52
+ add_to_dataset: bool = True, clip_range: bool = True
53
+ ) -> Union[SatelliteDataset, xr.DataArray]:
54
+ index = index.upper()
55
+ if index not in INDEX_DEFINITIONS:
56
+ raise ValueError(f"Index '{index}' not supported.")
57
+ definition = INDEX_DEFINITIONS[index]
58
+ resolved_bands = {}
59
+ for canonical_name, candidates in definition["bands"].items():
60
+ if band_mapping and canonical_name in band_mapping:
61
+ actual_name = band_mapping[canonical_name]
62
+ else:
63
+ actual_name = _resolve_band_name(dataset, candidates)
64
+ if actual_name is None:
65
+ raise KeyError(f"Cannot resolve band for '{canonical_name}'.")
66
+ resolved_bands[canonical_name] = actual_name
67
+ band_data = {name: dataset[actual] for name, actual in resolved_bands.items()}
68
+ formula_params = definition.get("params", {}).copy()
69
+ if params:
70
+ formula_params.update(params)
71
+ result = _compute_formula(index, band_data, formula_params)
72
+ if clip_range:
73
+ valid_min, valid_max = definition.get("range", (-float("inf"), float("inf")))
74
+ result = result.clip(min=valid_min, max=valid_max)
75
+ result = result.where(np.isfinite(result))
76
+ if add_to_dataset:
77
+ return dataset.add_band(index, result)
78
+ return result
79
+
80
+ def _compute_formula(index: str, bands: Dict[str, xr.DataArray], params: Dict[str, float]) -> xr.DataArray:
81
+ if index == "NDVI":
82
+ return (bands["nir"] - bands["red"]) / (bands["nir"] + bands["red"])
83
+ elif index == "NDWI":
84
+ return (bands["green"] - bands["nir"]) / (bands["green"] + bands["nir"])
85
+ elif index == "EVI":
86
+ return 2.5 * (bands["nir"] - bands["red"]) / (bands["nir"] + 6 * bands["red"] - 7.5 * bands["blue"] + 1)
87
+ elif index == "SAVI":
88
+ L = params.get("L", 0.5)
89
+ return (1 + L) * (bands["nir"] - bands["red"]) / (bands["nir"] + bands["red"] + L)
90
+ else:
91
+ raise ValueError(f"Formula not implemented for index: {index}")
92
+
93
def list_indices() -> List[str]:
    """Return the codes of all supported indices, in definition order."""
    return [code for code in INDEX_DEFINITIONS]
95
+
96
def index_info(index: str) -> Dict[str, Any]:
    """Return a copy of the definition of *index* (case-insensitive).

    The copy is deep, so callers may mutate the nested ``bands`` / ``params``
    dicts without altering the shared ``INDEX_DEFINITIONS`` that
    ``compute_index`` reads.

    Raises:
        ValueError: If *index* is not a supported index code.
    """
    import copy

    key = index.upper()
    if key not in INDEX_DEFINITIONS:
        raise ValueError(f"Index '{key}' not found.")
    return copy.deepcopy(INDEX_DEFINITIONS[key])