dea-tools 0.4.8.dev13__tar.gz → 0.4.9.dev2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/PKG-INFO +1 -1
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/app/wetlandsinsighttool.py +12 -4
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/bandindices.py +11 -10
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/classification.py +325 -224
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/wetlands.py +3 -7
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/.gitignore +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/LICENSE +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/README.md +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/__init__.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/__main__.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/app/__init__.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/app/animations.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/app/changefilmstrips.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/app/crophealth.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/app/deacoastlines.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/app/geomedian.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/app/imageexport.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/app/miningrehab.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/app/widgetconstructors.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/bom.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/coastal.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/dask.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/datahandling.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/landcover.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/maps.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/mosaics/README.md +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/mosaics/__init__.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/mosaics/cog.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/mosaics/styling.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/mosaics/utils.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/mosaics/vrt.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/plotting.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/spatial.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/temporal.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/validation.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/waterbodies.py +0 -0
- {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/pyproject.toml +0 -0
--- dea_tools-0.4.8.dev13/PKG-INFO
+++ dea_tools-0.4.9.dev2/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dea-tools
-Version: 0.4.8.dev13
+Version: 0.4.9.dev2
 Summary: Open-source tools for geospatial analysis with Digital Earth Australia, Open Data Cube, and Xarray
 Project-URL: Homepage, https://knowledge.dea.ga.gov.au/notebooks/Tools/
 Project-URL: Repository, https://github.com/GeoscienceAustralia/dea-notebooks
--- dea_tools-0.4.8.dev13/Tools/dea_tools/app/wetlandsinsighttool.py
+++ dea_tools-0.4.9.dev2/Tools/dea_tools/app/wetlandsinsighttool.py
@@ -488,8 +488,7 @@ class wit_app(HBox):
 # ---Plotting------------------------------
 if df is not None:
     with self.wit_plot:
-
-        plt.rcParams.update({"font.size": fontsize})
+
         # set up color palette
         pal = [
             sns.xkcd_rgb["cobalt blue"],
@@ -543,8 +542,17 @@ class wit_app(HBox):
             hatch="//",
         )
 
-
-
+        # calculate how many years of data have been loaded
+        date_range_years = (df["date"].max() - df["date"].min()).days / 365.25
+
+        if date_range_years > 5:
+            # show only years on x-axis
+            ax.xaxis.set_major_locator(mdates.YearLocator())
+            ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
+        else:
+            # show months and years on x-axis
+            ax.xaxis.set_major_locator(mdates.MonthLocator())
+            ax.xaxis.set_major_formatter(mdates.DateFormatter("%b-%Y"))
 
         # Rotates and right-aligns the x labels so they don't crowd each other.
         for label in ax.get_xticklabels(which="major"):
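The new axis-formatting logic above switches between yearly and monthly tick labels depending on how many years of Wetlands Insight Tool data were loaded. A minimal standalone sketch of the same idea, using a hypothetical DataFrame rather than the `wit_app` internals:

```python
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical stand-in for the WIT results table, which has a "date" column
df = pd.DataFrame({"date": pd.date_range("2015-01-01", "2023-12-01", freq="MS")})
df["water"] = range(len(df))

fig, ax = plt.subplots()
ax.plot(df["date"], df["water"])

# Calculate how many years of data have been loaded
date_range_years = (df["date"].max() - df["date"].min()).days / 365.25

if date_range_years > 5:
    # Long records: label years only
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
else:
    # Short records: label month and year
    ax.xaxis.set_major_locator(mdates.MonthLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%b-%Y"))

# Rotate and right-align the labels so they don't crowd each other
for label in ax.get_xticklabels(which="major"):
    label.set(rotation=30, horizontalalignment="right")
```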
--- dea_tools-0.4.8.dev13/Tools/dea_tools/bandindices.py
+++ dea_tools-0.4.9.dev2/Tools/dea_tools/bandindices.py
@@ -15,7 +15,7 @@ here: https://gis.stackexchange.com/questions/tagged/open-data-cube).
 If you would like to report an issue with this script, you can file one
 on GitHub (https://github.com/GeoscienceAustralia/dea-notebooks/issues/new).
 
-Last modified:
+Last modified: March 2026
 """
 
 # Import required packages
@@ -87,6 +87,8 @@ def calculate_indices(
 * ``'NIRv'`` (Near-Infrared Reflectance of Vegetation, Badgley et al. 2017)
 * ``'kNDVI'`` (Kernel Normalized Difference Vegetation Index, Camps-Valls et al. 2021)
 * ``'SAVI'`` (Soil Adjusted Vegetation Index, Huete 1988)
+* ``'SRVI'`` (Symbolic Regression Vegetation Index, Chrysostomou 2026)
+* ``'SRWI'`` (Symbolic Regression Water Index, Chrysostomou 2026)
 * ``'TCB'`` (Tasseled Cap Brightness, Crist 1985)
 * ``'TCG'`` (Tasseled Cap Greeness, Crist 1985)
 * ``'TCW'`` (Tasseled Cap Wetness, Crist 1985)
@@ -94,8 +96,6 @@ def calculate_indices(
 * ``'TCG_GSO'`` (Tasseled Cap Greeness, Nedkov 2017)
 * ``'TCW_GSO'`` (Tasseled Cap Wetness, Nedkov 2017)
 * ``'WI'`` (Water Index, Fisher 2016)
-* ``'kNDVI'`` (Non-linear Normalised Difference Vegation Index,
-  Camps-Valls et al. 2021)
 
 collection : str
     An string that tells the function what data collection is
@@ -107,7 +107,6 @@ def calculate_indices(
 
 * ``'ga_ls_3'`` (for GA Landsat Collection 3)
 * ``'ga_s2_3'`` (for GA Sentinel 2 Collection 3)
-* ``'ga_gm_3'`` (for GA Geomedian Collection 3)
 
 custom_varname : str, optional
     By default, the original dataset will be returned with
@@ -295,6 +294,12 @@ def calculate_indices(
 "FMR": lambda ds: (ds.swir1 / ds.nir),
 # Iron Oxide Ratio, Segal 1982
 "IOR": lambda ds: (ds.red / ds.blue),
+# Symbolic Regression Water Index, Chrysostomou 2026
+"SRWI": lambda ds: ((ds.green + ds.blue) - (ds.nir + ds.swir1))
+/ ((ds.green + ds.blue) + (ds.nir + ds.swir1)),
+# Symbolic Regression Vegetation Index, Chrysostomou 2026
+"SRVI": lambda ds: ((2 * ds.nir) - (3 * ds.red))
+/ (ds.nir + ds.red + (0.5 * (ds.green + ds.swir1))),
 }
 
 # If index supplied is not a list, convert to list. This allows us to
@@ -343,7 +348,7 @@ def calculate_indices(
 if collection is None:
     raise ValueError(
         "'No `collection` was provided. Please specify "
-        "either 'ga_ls_3'
+        "either 'ga_ls_3' or 'ga_s2_3' "
         "to ensure the function calculates indices "
         "using the correct spectral bands"
     )
@@ -396,16 +401,12 @@ def calculate_indices(
     a: b for a, b in bandnames_dict.items() if a in ds.variables
 }
 
-elif collection == "ga_gm_3":
-    # Pass an empty dict as no bands need renaming
-    bands_to_rename = {}
-
 # Raise error if no valid collection name is provided:
 else:
     raise ValueError(
         f"'{collection}' is not a valid option for "
         "`collection`. Please specify either \n"
-        "'ga_ls_3'
+        "'ga_ls_3' or 'ga_s2_3'"
     )
 
 # Apply index function
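Both indices added above are simple band-ratio expressions, so they plug into `calculate_indices` like the existing indices. A hedged sketch with a small synthetic dataset standing in for a real datacube load (band values and array sizes are made up):

```python
import numpy as np
import xarray as xr
from dea_tools.bandindices import calculate_indices

# Synthetic surface-reflectance bands standing in for a real load_ard() result
ds = xr.Dataset(
    {
        band: (("y", "x"), np.random.rand(5, 5))
        for band in ["blue", "green", "red", "nir", "swir1"]
    }
)

# Request the new symbolic-regression indices; `collection` selects the
# band-naming convention ('ga_ls_3' or 'ga_s2_3' in this release)
ds = calculate_indices(ds, index=["SRVI", "SRWI"], collection="ga_ls_3")

# Each requested index is appended to the dataset as a new variable
print(ds[["SRVI", "SRWI"]])
```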
--- dea_tools-0.4.8.dev13/Tools/dea_tools/classification.py
+++ dea_tools-0.4.9.dev2/Tools/dea_tools/classification.py
@@ -17,35 +17,37 @@ here: https://gis.stackexchange.com/questions/tagged/open-data-cube).
 If you would like to report an issue with this script, you can file one
 on GitHub (https://github.com/GeoscienceAustralia/dea-notebooks/issues/new).
 
-Last modified:
+Last modified: February 2026
 """
 
-import multiprocessing as mp
 import os
 import sys
 import time
-import
-from abc import ABCMeta, abstractmethod
-from copy import deepcopy
-from datetime import timedelta
-from typing import Any, Callable, Dict, List, Optional, Tuple
-
-import dask.array as da
-import dask.distributed as dd
-import geopandas as gpd
+import pyproj
 import joblib
+import warnings
 import numpy as np
 import pandas as pd
 import xarray as xr
+import dask.array as da
+import geopandas as gpd
+from tqdm.auto import tqdm
+import multiprocessing as mp
+import dask.distributed as dd
+from functools import partial
+from datetime import datetime, timedelta
+from abc import ABCMeta, abstractmethod
 from dask_ml.wrappers import ParallelPostFit
+
 from odc.geo.geom import Geometry
 from odc.geo.xr import assign_crs
 from sklearn.base import ClusterMixin
-from sklearn.
+from sklearn.utils import check_random_state
 from sklearn.mixture import GaussianMixture
+from sklearn.cluster import AgglomerativeClustering, KMeans
 from sklearn.model_selection import BaseCrossValidator, KFold, ShuffleSplit
-
-from
+
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 from dea_tools.spatial import xr_rasterize
 
@@ -57,8 +59,7 @@ def sklearn_flatten(input_xr):
    dimensions flattened into one dimension.
 
    This flattening procedure enables DataArrays and Datasets to be used
-   to train and predict
-   with sklearn models.
+   to train and predict with sklearn models.
 
    Last modified: September 2019
 
@@ -83,7 +84,11 @@ def sklearn_flatten(input_xr):
    input_xr = input_xr.to_array()
 
    # stack across pixel dimensions, handling timeseries if necessary
-   stacked =
+   stacked = (
+       input_xr.stack(z=["x", "y", "time"])
+       if "time" in input_xr.dims
+       else input_xr.stack(z=["x", "y"])
+   )
 
    # finding 'bands' dimensions in each pixel - these will not be
    # flattened as their context is important for sklearn
@@ -146,7 +151,11 @@ def sklearn_unflatten(output_np, input_xr):
    input_xr = input_xr.to_array()
 
    # generate the same mask we used to create the input to the sklearn model
-   stacked =
+   stacked = (
+       input_xr.stack(z=["x", "y", "time"])
+       if "time" in input_xr.dims
+       else input_xr.stack(z=["x", "y"])
+   )
 
    pxdims = []
    for dim in stacked.dims:
@@ -287,7 +296,9 @@ def predict_xr(
    input_data_flattened = da.array(input_data_flattened).transpose()
 
    if clean:
-       input_data_flattened = da.where(
+       input_data_flattened = da.where(
+           da.isfinite(input_data_flattened), input_data_flattened, 0
+       )
 
    if proba & persist:
        # persisting data so we don't require loading all the data twice
@@ -319,18 +330,24 @@ def predict_xr(
        out_proba = da.max(out_proba, axis=1) * 100.0
        out_proba = out_proba.reshape(len(y), len(x))
 
-       out_proba = xr.DataArray(
+       out_proba = xr.DataArray(
+           out_proba, coords={"x": x, "y": y}, dims=["y", "x"]
+       )
        output_xr["Probabilities"] = out_proba
    else:
        print(" returning class probability array")
        out_proba = out_proba * 100.0
-       class_names =
+       class_names = (
+           model.classes_
+       )  # Get the unique class names from the fitted classifier
 
        # Loop through each class (band)
        probabilities_dataset = xr.Dataset()
        for i, class_name in enumerate(class_names):
            reshaped_band = out_proba[:, i].reshape(len(y), len(x))
-           reshaped_da = xr.DataArray(
+           reshaped_da = xr.DataArray(
+               reshaped_band, coords={"x": x, "y": y}, dims=["y", "x"]
+           )
            probabilities_dataset[f"prob_{class_name}"] = reshaped_da
 
        # merge in the probabilities
@@ -351,7 +368,9 @@ def predict_xr(
    if len(input_data_flattened.shape[1:]):
        output_px_shape = input_data_flattened.shape[1:]
 
-   output_features = input_data_flattened.reshape(
+   output_features = input_data_flattened.reshape(
+       (len(stacked.z), *output_px_shape)
+   )
 
    # set the stacked coordinate to match the input
    output_features = xr.DataArray(
@@ -366,7 +385,9 @@ def predict_xr(
    # convert to dataset and rename arrays
    output_features = output_features.to_dataset(dim="output_dim_0")
    data_vars = list(input_xr.data_vars)
-   output_features = output_features.rename(
+   output_features = output_features.rename(
+       {i: j for i, j in zip(output_features.data_vars, data_vars)}
+   )
 
    # merge with predictions
    output_xr = xr.merge([output_xr, output_features], compat="override")
@@ -377,10 +398,14 @@ def predict_xr(
    # convert model to dask predict
    model = ParallelPostFit(model)
    with joblib.parallel_backend("dask"):
-       output_xr = _predict_func(
+       output_xr = _predict_func(
+           model, input_xr, persist, proba, max_proba, clean, return_input
+       )
 
 else:
-   output_xr = _predict_func(
+   output_xr = _predict_func(
+       model, input_xr, persist, proba, max_proba, clean, return_input
+   ).compute()
 
 return output_xr
 
@@ -400,42 +425,35 @@ class HiddenPrints:
 
 
 def _get_training_data_for_shp(
-   gdf: gpd.GeoDataFrame,
-   index: int,
    row: gpd.GeoSeries,
-
-   out_vars: List[List[str]],
+   crs: pyproj.CRS,
    dc_query: Dict,
    return_coords: bool,
-
+   return_time_coords: bool,
+   feature_func: callable = None,
    field: Optional[str] = None,
    zonal_stats: Optional[str] = None,
    time_field: Optional[str] = None,
-
-):
+) -> pd.DataFrame:
    """
    This is the core function that is triggered by `collect_training_data`.
    The `collect_training_data` function loops through geometries in a geopandas
-   geodataframe and runs
-
-   See that function for information on the other params not listed below.
+   geodataframe and runs this function. See `collect_training_data` for more
+   information on the parameters than is detailed below.
 
    Parameters
    ----------
-   gdf : gpd.GeoDataFrame
-       Geopandas GeoDataFrame containing geometries.
-   index : int
-       Index of the current geometry in the GeoDataFrame.
    row : gpd.GeoSeries
        GeoSeries representing the current row in the GeoDataFrame.
-
-
-
-       An empty list into which the data variable names are stored.
+   crs : pyrpoj.CRS
+       Coordinate reference system information extracted from a GeoDataFrame
+       e.g., crs=gdf.crs
    dc_query : Dict
-       ODC query.
+       ODC query object.
    return_coords : bool
-       Flag indicating whether to return coordinates in the dataset.
+       Flag indicating whether to return x,y coordinates in the dataset.
+   return_time_coords : bool
+       Flag indicating whether to return time coordinates in the dataset
    feature_func : callable, optional
        Optional function to extract data based on `dc_query`. Defaults to None.
    field : str, optional
@@ -443,65 +461,52 @@ def _get_training_data_for_shp(
    zonal_stats : str, optional
        Zonal statistics method. Defaults to None.
    time_field : str, optional
-       Name of the column containing
-
-
-
-
+       Name of the column containing time(range) data in the input gdf, for the case where each row
+       should load from a different time(range). If loading from the same time(range) for
+       all rows, then its preferable to pass time as a key:variable in the 'dc_query'.
+       Note the time values must be in a format that datacube.load() accepts. For example, as a
+       tuple with strings ('2017-01-01', '2017-01-31'). Defaults to None.
 
    Returns
    --------
-
-       each pixel or polygon, and another containing the data variable names.
+   pd.DataFrame
 
    """
 
    # prevent function altering dictionary kwargs
-   dc_query =
+   dc_query = {**dc_query}  # shallow copy is faster
 
    # remove dask chunks if supplied as using
-   # mulitprocessing for
+   # mulitprocessing for parallelisation
    if "dask_chunks" in dc_query:
        dc_query.pop("dask_chunks", None)
 
-   # set up query based on
-   geom = Geometry(geom=
-
-
-
-   start_time = timestamp - time_delta
-   end_time = timestamp + time_delta
-   timestamp = {"time": (start_time, end_time)}
-   # merge time query with user supplied query params
-   dc_query.update(timestamp)
-
-   # Use input feature function
+   # set up query based on row geometry
+   geom = Geometry(geom=row["geometry"], crs=crs)
+   dc_query.update({"geopolygon": geom})
+
+   if time_field is not None:
+       timerange = getattr(row, time_field)
+       dc_query.update({"time": timerange})
+
+   # Use input feature function and run checks on output
    data = feature_func(dc_query)
 
-
-
-   return
+   if not isinstance(data, (xr.Dataset, xr.DataArray)):
+       raise TypeError("feature_func must return xarray Dataset or DataArray")
 
-   if
-
-
-
+   if len(data.data_vars) == 0:
+       raise ValueError(
+           "feature_func returned an empty dataset, "
+           "this can happen if a geometry is not within data bounds"
+       )
 
-   #
-   if "
-
-
-   + str(t)
-   + " time-steps, dataset must only have"
-   + " x and y dimensions."
-   )
+   # If the geometry type is a polygon extract all pixels
+   if row["geometry"].geom_type != "Point":
+       # create polygon mask (requires gdf)
+       dff = gpd.GeoDataFrame(row.to_frame().T, geometry="geometry", crs=crs)
+       mask = xr_rasterize(dff, data)
+       data = data.where(mask)
 
    if return_coords:
        # turn coords into a variable in the ds
@@ -510,46 +515,50 @@ def _get_training_data_for_shp(
 
    # append ID measurement to dataset for tracking failures
    band = list(data.data_vars)[0]
-
-   data["
-
+   data["_training_id"] = xr.zeros_like(data[band])
+   data["_training_id"] = data["_training_id"] + row["_training_id"]
+
+   if "time" in data.sizes:
+       if return_time_coords:
+           data["time_coord"] = data.time
 
-   # If no zonal stats were requested then extract all pixel values
+   # If no zonal stats were requested then extract all pixel values.
    if zonal_stats is None:
-
-
-       stacked = np.hstack((np.expand_dims(flat_val, axis=1), flat_train))
+       stacked = data.to_dataframe().reset_index(drop=True)
+       stacked[field] = row[field]
 
    elif zonal_stats in ["mean", "median", "max", "min"]:
        method_to_call = getattr(data, zonal_stats)
-
-
-       stacked =
+       stacked = method_to_call(["x", "y"])  # will keep time as dim if present
+       stacked = stacked.to_dataframe().reset_index(drop=True)
+       stacked[field] = row[field]
 
    else:
        raise Exception(
-           zonal_stats
+           f"{zonal_stats} is not one of the supported reduce functions: 'mean','median','max','min'"
        )
 
-
-
+   if "spatial_ref" in stacked.columns:
+       stacked = stacked.drop("spatial_ref", axis=1)
+
+   return stacked
 
 
 def _get_training_data_parallel(
    gdf: gpd.GeoDataFrame,
-   dc_query:
+   dc_query: dict,
    ncpus: int,
-   return_coords: bool,
-
+   return_coords: bool = False,
+   return_time_coords: bool = False,
+   feature_func: callable = None,
    field: Optional[str] = None,
    zonal_stats: Optional[str] = None,
    time_field: Optional[str] = None,
-
-) -> Tuple[List[str], List[Any]]:
+) -> pd.DataFrame:
    """
    Function passing the '_get_training_data_for_shp' function
    to a mulitprocessing.Pool.
-   Inherits variables from 'collect_training_data
+   Inherits variables from 'collect_training_data'.
 
    """
    # Check if dask-client is running
@@ -561,18 +570,26 @@ def _get_training_data_parallel(
 
    if zx is not None:
        raise ValueError(
-           "You have a Dask Client running, which prevents
+           "You have a Dask Client running, which prevents"
+           "this function from multiprocessing. Close the client."
        )
 
-
-
-   results
-
+   crs = gdf.crs
+
+   # instantiate results list
+   results = []
 
    # progress bar
    pbar = tqdm(total=len(gdf))
 
-
+   # what to do with the results
+   def results_update(df):
+       results.append(df)
+       pbar.update()
+
+   # What to do with errors
+   def handle_error(index, e):
+       print(f"Worker failed on row {index}", str(e))
        pbar.update()
 
    with mp.Pool(ncpus) as pool:
@@ -580,131 +597,159 @@ def _get_training_data_parallel(
        pool.apply_async(
            _get_training_data_for_shp,
            [
-               gdf,
-               index,
                row,
-
-               column_names,
+               crs,
                dc_query,
                return_coords,
+               return_time_coords,
                feature_func,
                field,
                zonal_stats,
                time_field,
-               time_delta,
            ],
-           callback=
+           callback=results_update,
+           error_callback=partial(handle_error, index),
        )
 
    pool.close()
    pool.join()
-   pbar.close()
 
-
+   pbar.close()
+
+   return results
 
 
 def collect_training_data(
    gdf: gpd.GeoDataFrame,
-   dc_query: dict,
+   dc_query: dict[str, Any],
    ncpus: int = 1,
    return_coords: bool = False,
+   return_time_coords: bool = False,
    feature_func: callable = None,
    field: str = None,
-   zonal_stats: str = None,
+   zonal_stats: Optional[str] = None,
    clean: bool = True,
-   fail_threshold: float = 0.
+   fail_threshold: float = 0.05,
    fail_ratio: float = 0.5,
-   max_retries: int =
-   time_field: str = None,
-
-) -> Tuple[List[np.ndarray], List[str]]:
+   max_retries: int = 2,
+   time_field: Optional[str] = None,
+) -> pd.DataFrame:
    """
-   This function provides methods for gathering training data from the ODC over
+   This function provides methods for gathering training/validation data from the ODC over
    geometries stored within a geopandas geodataframe. The function will return a
-
-
-
-   the
-
+   pandas.DataFrame where the index contains class labels and the columns contain
+   feature values generated by a user-defined `feature_func`.
+
+   - In the instance where ncpus > 1, the function will automatically run in parallel.
+   - Zonal statistics are supported where the provided vector file contains polygons, otherwise all
+     pixel values are returned.
+   - Individual points/polygons can be loaded from different time ranges by passing the `time_field`
+     parameter.
+   - Implements a retry queue for samples that may fail due to i/o limitations or s3 read failures.
 
    Parameters
    ----------
    gdf : geopandas geodataframe
-       geometry data in the form of a geopandas geodataframe
+       geometry data in the form of a geopandas geodataframe. Must contain a class labels column,
+       can optionally contain a column with time stamps, specified with the`time_field` param.
    dc_query : dictionary
-       Datacube query object, should not contain lat and long (x or y)
-
+       Datacube query object, should not contain lat and long (x or y) variables as these
+       are supplied by the geopolygon column in the 'gdf'.
    ncpus : int
        The number of cpus/processes over which to parallelize the gathering
-       of training data (only if ncpus is > 1).
-       cpus available on a machine. Defaults to 1.
-   return_coords : bool
-       If True, then the training data will contain two extra columns 'x_coord' and
-       'y_coord' corresponding to the x,y coordinate of each sample. This variable can
-       be useful for handling spatial autocorrelation between samples later in the ML workflow.
+       of training data (only if ncpus is > 1). Defaults to 1.
    feature_func : function
        A function for generating feature layers that is applied to the data within
        the bounds of the input geometry. The 'feature_func' must accept a 'dc_query'
-       object, and return a single xarray.Dataset or xarray.DataArray
-
-       e.g.
+       object, and return a single xarray.Dataset or xarray.DataArray:
+
            def feature_function(query):
                dc = datacube.Datacube(app='feature_layers')
                ds = dc.load(**query)
                ds = ds.mean('time')
                return ds
+
    field : str
        Name of the column in the gdf that contains the class labels
+   return_coords : bool
+       If True, then the output data will contain two extra columns 'x_coord' and
+       'y_coord' corresponding to the x,y coordinate of each sample.
+   return_time_coords : bool
+       If True, then the output data will contain an extra column 'time_coord',
+       corresponding to the time stamp of each sample.
    zonal_stats : string, optional
        An optional string giving the names of zonal statistics to calculate
        for each polygon. Default is None (all pixel values are returned). Supported
        values are 'mean', 'median', 'max', 'min'.
    clean : bool
-       Whether or not to remove missing values in the
-
-
-
+       Whether or not to remove missing values in the returned dataset. If True (default),
+       rows with any NaNs or Infs in any numeric columns will be dropped from the dataset.
+   time_field : str, optional
+       Name of the column containing time(range) data in the input gdf, for the case where each row
+       should load from a different time(range). If loading from the same time(range) for
+       all rows, then its preferable to pass time as a key:variable in the 'dc_query'.
+       Note the time values must be in a format that datacube.load() accepts. For example, as a
+       tuple with strings ('2017-01-01', '2017-01-31'). Defaults to None.
+   fail_threshold : float, default 0.05
        Silent read fails on S3 can result in some rows of the returned data containing NaN values.
        The'fail_threshold' fraction specifies a % of acceptable fails.
        e.g. Setting 'fail_threshold' to 0.05 means if >5% of the samples in the training dataset
-       fail then those samples will be
+       fail then those samples will be returned to the multiprocessing queue. Below this fraction
        the function will accept the failures and return the results.
    fail_ratio: float
        A float between 0 and 1 that defines if a given training sample has failed.
        Default is 0.5, which means if 50 % of the measurements in a given sample return null
        values, and the number of total fails is more than the fail_threshold, the samplewill be
        passed to the retry queue.
-   max_retries: int, default
+   max_retries: int, default 2
        Maximum number of times to retry collecting samples. This number is invoked
        if the 'fail_threshold' is not reached.
-   time_field: str
-       The name of the attribute in the input dataframe containing capture timestamp
-   time_delta: time_delta
-       The size of the window used as timestamp +/- time_delta.
-       This is used to allow matching a single field data point with multiple scenes
 
    Returns
    --------
-
-
+   pandas.DataFrame
+       Where the index contains class labels and the columns contain feature values
 
    """
+   # --------Conduct various checks before running the function--------
+   if feature_func is None:
+       raise ValueError(
+           "Please supply a feature layer function through the "
+           + "parameter 'feature_func'"
+       )
+
+   if field is None:
+       raise ValueError("Parameter 'field' must be provided")
+
+   if field not in gdf.columns:
+       raise ValueError(f"Column '{field}' not found in GeoDataFrame")
 
    # check the dtype of the class field
    if not np.issubdtype(gdf[field].dtype, np.integer):
-       raise ValueError(
+       raise ValueError(
+           f'The "{field}" column of the input vector must contain integer dtypes'
+       )
 
-   # check
-   if
-
+   # check time-field params
+   if time_field is not None:
+
+       if "time" in dc_query:
+           raise ValueError(
+               f"You have passed both 'dc_query['time']' and 'time_field', "
+               "only pass one of these options"
+           )
 
-
-
+       if time_field not in gdf.columns:
+           raise ValueError(f"Column '{time_field}' not found in GeoDataFrame")
 
+   if zonal_stats:
+       print(f"Applying zonal statistic: {zonal_stats}")
+
+   # ----------------------------------------------------------------
    # add unique id to gdf to help with indexing failed rows
    # during multiprocessing
-
-   gdf["
+   gdf = gdf.copy()  # only modify copy
+   gdf["_training_id"] = np.arange(len(gdf))
 
    if ncpus == 1:
        # progress indicator
@@ -713,104 +758,106 @@ def collect_training_data(
 
        # list to store results
        results = []
-       column_names = []
 
        # loop through polys and extract training data
        for index, row in gdf.iterrows():
            print(" Feature {:04}/{:04}\r".format(i + 1, len(gdf)), end="")
 
-           _get_training_data_for_shp(
-               gdf,
-               index,
+           stacked = _get_training_data_for_shp(
                row,
-
-               column_names,
+               gdf.crs,
               dc_query,
               return_coords,
+               return_time_coords,
               feature_func,
               field,
               zonal_stats,
               time_field,
-               time_delta,
           )
+
+           results.append(stacked)
           i += 1
 
    else:
        print("Collecting training data in parallel mode")
-
+       results = _get_training_data_parallel(
            gdf=gdf,
            dc_query=dc_query,
            ncpus=ncpus,
            return_coords=return_coords,
+           return_time_coords=return_time_coords,
            feature_func=feature_func,
            field=field,
            zonal_stats=zonal_stats,
            time_field=time_field,
-           time_delta=time_delta,
        )
 
-
-
-   column_names = column_names[0]
+   if not results:
+       raise RuntimeError("No samples returned from feature extraction.")
 
-   #
-
+   # join all results into a single df
+   df = pd.concat(results)
 
    # this code block below iteratively retries failed rows
    # up to max_retries or until fail_threshold is
-   # reached
+   # reached, whichever occurs first.
    if ncpus > 1:
        i = 1
        while i <= max_retries:
-           # Find % of fails (null values) in data
-
+           # Find % of fails (null values) in data
+           dff = df.set_index("_training_id")
            # how many nan values per id?
-           num_nans =
+           num_nans = dff.isnull().sum(axis=1)
            num_nans = num_nans.groupby(num_nans.index).sum()
            # how many valid values per id?
-           num_valid =
+           num_valid = dff.notnull().sum(axis=1)
            num_valid = num_valid.groupby(num_valid.index).sum()
            # find fail rate
            perc_fail = num_nans / (num_nans + num_valid)
            fail_ids = perc_fail[perc_fail > fail_ratio]
+
            fail_rate = len(fail_ids) / len(gdf)
 
-           print(
+           print(
+               "Percentage of possible fails after run "
+               + str(i)
+               + " = "
+               + str(round(fail_rate * 100, 2))
+               + " %"
+           )
 
            if fail_rate > fail_threshold:
                print("Recollecting samples that failed")
 
                fail_ids = list(fail_ids.index)
-
-
+
+               # keep only the ids in df object that didn't fail
+               df = df.loc[~df["_training_id"].isin(fail_ids)]
 
                # index out the fail_ids from the original gdf
-               gdf_rerun = gdf.loc[gdf["
+               gdf_rerun = gdf.loc[gdf["_training_id"].isin(fail_ids)]
                gdf_rerun = gdf_rerun.reset_index(drop=True)
 
-               time.sleep(
+               time.sleep(3)  # sleep for 3s to rest api
 
                # recollect failed rows
-               (
-                   column_names_again,
-                   results_again,
-               ) = _get_training_data_parallel(
+               results_again = _get_training_data_parallel(
                    gdf=gdf_rerun,
                    dc_query=dc_query,
                    ncpus=ncpus,
                    return_coords=return_coords,
+                   return_time_coords=return_time_coords,
                    feature_func=feature_func,
                    field=field,
                    zonal_stats=zonal_stats,
                    time_field=time_field,
-                   time_delta=time_delta,
                )
 
                # Stack the extracted training data for each feature into a single array
-
+               df_again = pd.concat(results_again)
 
                # merge results of the re-run with original run
-
+               df = pd.concat([df, df_again])
 
                i += 1
 
@@ -818,24 +865,30 @@ def collect_training_data(
                break
 
    # -----------------------------------------------
-
    # remove id column
-
-   model_col_indices = [column_names.index(var_name) for var_name in idx_var]
-   model_input = model_input[:, model_col_indices]
+   df = df.drop("_training_id", axis=1)
 
    if clean:
-
-
-
-
-
+       # Identify which columns have numeric data
+       numeric_cols = df.select_dtypes(include=[np.number]).columns
+
+       # do we have any numeric columns to clean
+       if len(numeric_cols) == 0:
+           print("No numeric columns to clean; leaving DataFrame unchanged.")
+       else:
+           # Build invalid mask on numeric columns only, NaN or Inf
+           invalid_mask = ~np.isfinite(df[numeric_cols]).all(axis=1)
+           num_removed = invalid_mask.sum()
+           df = df[~invalid_mask]
+
+           print(f"Removed {num_removed} rows with NaNs or Infs in numeric columns")
+           print("Output shape:", df.shape)
 
    else:
        print("Returning data without cleaning")
-       print("Output shape: ",
+       print("Output shape: ", df.shape)
 
-   return
+   return df.set_index(field)
 
 
 class KMeans_tree(ClusterMixin):
@@ -866,7 +919,8 @@ class KMeans_tree(ClusterMixin):
        # make child models
        if n_levels > 1:
            self.branches = [
-               KMeans_tree(n_levels=n_levels - 1, n_clusters=n_clusters, **kwargs)
+               KMeans_tree(n_levels=n_levels - 1, n_clusters=n_clusters, **kwargs)
+               for _ in range(n_clusters)
            ]
 
    def fit(self, X, y=None, sample_weight=None):
@@ -898,7 +952,11 @@ class KMeans_tree(ClusterMixin):
        # fit child models on their corresponding partition of the training set
        self.branches[clu].fit(
            X[labels_old == clu],
-           sample_weight=(
+           sample_weight=(
+               sample_weight[labels_old == clu]
+               if sample_weight is not None
+               else None
+           ),
        )
        self.labels_[labels_old == clu] += self.branches[clu].labels_
 
@@ -934,13 +992,24 @@ class KMeans_tree(ClusterMixin):
        for clu in range(self.n_clusters):
            result[rescpy == clu] += self.branches[clu].predict(
                X[rescpy == clu],
-               sample_weight=(
+               sample_weight=(
+                   sample_weight[rescpy == clu]
+                   if sample_weight is not None
+                   else None
+               ),
            )
 
        return result
 
 
-def spatial_clusters(
+def spatial_clusters(
+    coordinates,
+    method="Hierarchical",
+    max_distance=None,
+    n_groups=None,
+    verbose=False,
+    **kwargs,
+):
    """
    Create spatial groups on coorindate data using either KMeans clustering
    or a Gaussian Mixture model
@@ -974,21 +1043,28 @@ def spatial_clusters(coordinates, method="Hierarchical", max_distance=None, n_gr
    raise ValueError("Method must be one of: 'Hierarchical','KMeans' or 'GMM'")
 
    if (method in ["GMM", "KMeans"]) & (n_groups is None):
-       raise ValueError(
+       raise ValueError(
+           "The 'GMM' and 'KMeans' methods requires explicitly setting 'n_groups'"
+       )
 
    if (method == "Hierarchical") & (max_distance is None):
        raise ValueError("The 'Hierarchical' method requires setting max_distance")
 
    if method == "Hierarchical":
        cluster_label = AgglomerativeClustering(
-           n_clusters=None,
+           n_clusters=None,
+           linkage="complete",
+           distance_threshold=max_distance,
+           **kwargs,
        ).fit_predict(coordinates)
 
    if method == "KMeans":
        cluster_label = KMeans(n_clusters=n_groups, **kwargs).fit_predict(coordinates)
 
    if method == "GMM":
-       cluster_label = GaussianMixture(n_components=n_groups, **kwargs).fit_predict(
+       cluster_label = GaussianMixture(n_components=n_groups, **kwargs).fit_predict(
+           coordinates
+       )
    if verbose:
        print("n clusters = " + str(len(np.unique(cluster_label))))
 
@@ -1217,7 +1293,9 @@ def spatial_train_test_split(
 
    if kfold_method == "SpatialKFold":
        if n_splits is None:
-           raise ValueError(
+           raise ValueError(
+               "n_splits parameter requires an integer value, eg. 'n_splits=5'"
+           )
        if (test_size is not None) or (train_size is not None):
            warnings.warn(
                "With the 'SpatialKFold' method, controlling the test/train ratio "
@@ -1268,7 +1346,11 @@ def _partition_by_sum(array, parts):
    """
    array = np.atleast_1d(array).ravel()
    if parts > array.size:
-       raise ValueError(
+       raise ValueError(
+           "Cannot partition an array of size {} into {} parts of equal sum.".format(
+               array.size, parts
+           )
+       )
    cumulative_sum = array.cumsum()
    # Ideally, we want each part to have the same number of points (total /
    # parts).
@@ -1279,7 +1361,11 @@ def _partition_by_sum(array, parts):
    # Check for repeated split points, which indicates that there is no way to
    # split the array.
    if np.unique(indices).size != indices.size:
-       raise ValueError(
+       raise ValueError(
+           "Could not find partition points to split the array into {} parts of equal sum.".format(
+               parts
+           )
+       )
    return indices
 
 
@@ -1337,7 +1423,11 @@ class _BaseSpatialCrossValidator(BaseCrossValidator, metaclass=ABCMeta):
        The testing set indices for that split.
        """
        if X.shape[1] != 2:
-           raise ValueError(
+           raise ValueError(
+               "X (the coordinate data) must have exactly 2 columns ({} given).".format(
+                   X.shape[1]
+               )
+           )
        for train, test in super().split(X, y, groups):
            yield train, test
 
@@ -1471,7 +1561,9 @@ class _SpatialShuffleSplit(_BaseSpatialCrossValidator):
            **kwargs,
        )
        if balance < 1:
-           raise ValueError(
+           raise ValueError(
+               "The *balance* argument must be >= 1. To disable balance, use 1."
+           )
        self.test_size = test_size
        self.train_size = train_size
        self.random_state = random_state
@@ -1530,7 +1622,12 @@ class _SpatialShuffleSplit(_BaseSpatialCrossValidator):
        test_points = np.where(np.isin(labels, cluster_ids[test_clusters]))[0]
        # The proportion of data points assigned to each group should
        # be close the proportion of clusters assigned to each group.
-       balance.append(
+       balance.append(
+           abs(
+               train_points.size / test_points.size
+               - train_clusters.size / test_clusters.size
+           )
+       )
        test_sets.append(test_points)
    best = np.argmin(balance)
    yield test_sets[best]
@@ -1612,7 +1709,11 @@ class _SpatialKFold(_BaseSpatialCrossValidator):
        )
 
        if n_splits < 2:
-           raise ValueError(
+           raise ValueError(
+               "Number of splits must be >=2 for clusterKFold. Given {}.".format(
+                   n_splits
+               )
+           )
        self.test_size = test_size
        self.shuffle = shuffle
        self.random_state = random_state

--- dea_tools-0.4.8.dev13/Tools/dea_tools/wetlands.py
+++ dea_tools-0.4.9.dev2/Tools/dea_tools/wetlands.py
@@ -44,14 +44,10 @@ from dea_tools.dask import create_local_dask_cluster
 from dea_tools.datahandling import load_ard
 from dea_tools.spatial import xr_rasterize
 
-# Create local dask cluster to improve data load time
-client = create_local_dask_cluster(return_client=True)
-
 # disable DeprecationWarning for chained assignments in conversion to
 # datetime format
 pd.options.mode.chained_assignment = None  # default='warn'
 
-
 def normalise_wit(polygon_base_df):
    """
    This function is to normalise the Fractional Cover vegetation
@@ -396,7 +392,7 @@ def WIT_drill(
    # Connect to the datacube
    dc = datacube.Datacube(app="WIT_drill")
 
-   # load
+   # load Landsat 5,7,8,9 data
    warnings.filterwarnings("ignore")
 
    # load wetland polygon and specify the coordinate reference system of the polygon
@@ -410,10 +406,10 @@ def WIT_drill(
    if verbose_progress:
        print("Loading Landsat data")
 
-   # Load Landsat 5, 7 and
+   # Load Landsat 5, 7, 8 and 9 data. Not including Landsat 7 SLC off period (31-05-2003 to 06-04-2022)
    ds_ls = load_ard(
        dc,
-       products=["ga_ls8c_ard_3", "ga_ls7e_ard_3", "ga_ls5t_ard_3"],
+       products=["ga_ls9c_ard_3", "ga_ls8c_ard_3", "ga_ls7e_ard_3", "ga_ls5t_ard_3"],
        ls7_slc_off=False,
        measurements=bands,
        geopolygon=gpgon,
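The classification.py changes above rework `collect_training_data` so that it returns a single pandas.DataFrame indexed by the class-label column, instead of the old column-names/array pair. A hedged usage sketch; the product name, file path and query values below are illustrative only, not taken from the package:

```python
import datacube
import geopandas as gpd
from dea_tools.classification import collect_training_data

# Training polygons with an integer class-label column named "class"
# (path and column name are hypothetical)
gdf = gpd.read_file("training_polygons.geojson")

def feature_function(query):
    # feature_func must accept a datacube query dict and return an
    # xarray.Dataset/DataArray, as described in the docstring above
    dc = datacube.Datacube(app="feature_layers")
    ds = dc.load(product="ga_ls8c_ard_3", **query)  # example product only
    return ds.mean("time")

query = {
    "time": ("2020-01-01", "2020-12-31"),
    "measurements": ["nbart_red", "nbart_nir"],
    "resolution": (-30, 30),
    "output_crs": "EPSG:3577",
}

# The index holds class labels and the columns hold feature values;
# ncpus > 1 switches on the multiprocessing path with the new retry queue
df = collect_training_data(
    gdf=gdf,
    dc_query=query,
    ncpus=4,
    feature_func=feature_function,
    field="class",
    zonal_stats="median",
)
print(df.head())
```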
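Similarly, the reformatted `spatial_clusters` signature now spells out its grouping options (method, max_distance, n_groups, verbose). A small sketch on synthetic coordinates; the distances are arbitrary:

```python
import numpy as np
from dea_tools.classification import spatial_clusters

# Synthetic x,y coordinates standing in for training-sample locations (metres)
coords = np.random.uniform(0, 10_000, size=(200, 2))

# 'Hierarchical' groups samples that fall within max_distance of each other;
# 'KMeans' and 'GMM' instead require n_groups to be set
labels = spatial_clusters(
    coordinates=coords,
    method="Hierarchical",
    max_distance=1_000,
    verbose=True,
)
print("n spatial groups:", len(np.unique(labels)))
```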
All remaining files listed above with +0 -0 are unchanged between dea_tools-0.4.8.dev13 and dea_tools-0.4.9.dev2.