dea-tools 0.4.8.dev13__tar.gz → 0.4.9.dev2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/PKG-INFO +1 -1
  2. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/app/wetlandsinsighttool.py +12 -4
  3. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/bandindices.py +11 -10
  4. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/classification.py +325 -224
  5. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/wetlands.py +3 -7
  6. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/.gitignore +0 -0
  7. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/LICENSE +0 -0
  8. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/README.md +0 -0
  9. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/__init__.py +0 -0
  10. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/__main__.py +0 -0
  11. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/app/__init__.py +0 -0
  12. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/app/animations.py +0 -0
  13. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/app/changefilmstrips.py +0 -0
  14. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/app/crophealth.py +0 -0
  15. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/app/deacoastlines.py +0 -0
  16. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/app/geomedian.py +0 -0
  17. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/app/imageexport.py +0 -0
  18. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/app/miningrehab.py +0 -0
  19. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/app/widgetconstructors.py +0 -0
  20. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/bom.py +0 -0
  21. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/coastal.py +0 -0
  22. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/dask.py +0 -0
  23. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/datahandling.py +0 -0
  24. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/landcover.py +0 -0
  25. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/maps.py +0 -0
  26. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/mosaics/README.md +0 -0
  27. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/mosaics/__init__.py +0 -0
  28. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/mosaics/cog.py +0 -0
  29. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/mosaics/styling.py +0 -0
  30. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/mosaics/utils.py +0 -0
  31. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/mosaics/vrt.py +0 -0
  32. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/plotting.py +0 -0
  33. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/spatial.py +0 -0
  34. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/temporal.py +0 -0
  35. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/validation.py +0 -0
  36. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/Tools/dea_tools/waterbodies.py +0 -0
  37. {dea_tools-0.4.8.dev13 → dea_tools-0.4.9.dev2}/pyproject.toml +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dea-tools
3
- Version: 0.4.8.dev13
3
+ Version: 0.4.9.dev2
4
4
  Summary: Open-source tools for geospatial analysis with Digital Earth Australia, Open Data Cube, and Xarray
5
5
  Project-URL: Homepage, https://knowledge.dea.ga.gov.au/notebooks/Tools/
6
6
  Project-URL: Repository, https://github.com/GeoscienceAustralia/dea-notebooks
@@ -488,8 +488,7 @@ class wit_app(HBox):
488
488
  # ---Plotting------------------------------
489
489
  if df is not None:
490
490
  with self.wit_plot:
491
- fontsize = 17
492
- plt.rcParams.update({"font.size": fontsize})
491
+
493
492
  # set up color palette
494
493
  pal = [
495
494
  sns.xkcd_rgb["cobalt blue"],
@@ -543,8 +542,17 @@ class wit_app(HBox):
543
542
  hatch="//",
544
543
  )
545
544
 
546
- ax.xaxis.set_major_locator(mdates.MonthLocator())
547
- ax.xaxis.set_major_formatter(mdates.DateFormatter("%b-%Y"))
545
+ # calculate how many years of data have been loaded
546
+ date_range_years = (df["date"].max() - df["date"].min()).days / 365.25
547
+
548
+ if date_range_years > 5:
549
+ # show only years on x-axis
550
+ ax.xaxis.set_major_locator(mdates.YearLocator())
551
+ ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
552
+ else:
553
+ # show months and years on x-axis
554
+ ax.xaxis.set_major_locator(mdates.MonthLocator())
555
+ ax.xaxis.set_major_formatter(mdates.DateFormatter("%b-%Y"))
548
556
 
549
557
  # Rotates and right-aligns the x labels so they don't crowd each other.
550
558
  for label in ax.get_xticklabels(which="major"):
@@ -15,7 +15,7 @@ here: https://gis.stackexchange.com/questions/tagged/open-data-cube).
15
15
  If you would like to report an issue with this script, you can file one
16
16
  on GitHub (https://github.com/GeoscienceAustralia/dea-notebooks/issues/new).
17
17
 
18
- Last modified: June 2023
18
+ Last modified: March 2026
19
19
  """
20
20
 
21
21
  # Import required packages
@@ -87,6 +87,8 @@ def calculate_indices(
87
87
  * ``'NIRv'`` (Near-Infrared Reflectance of Vegetation, Badgley et al. 2017)
88
88
  * ``'kNDVI'`` (Kernel Normalized Difference Vegetation Index, Camps-Valls et al. 2021)
89
89
  * ``'SAVI'`` (Soil Adjusted Vegetation Index, Huete 1988)
90
+ * ``'SRVI'`` (Symbolic Regression Vegetation Index, Chrysostomou 2026)
91
+ * ``'SRWI'`` (Symbolic Regression Water Index, Chrysostomou 2026)
90
92
  * ``'TCB'`` (Tasseled Cap Brightness, Crist 1985)
91
93
  * ``'TCG'`` (Tasseled Cap Greeness, Crist 1985)
92
94
  * ``'TCW'`` (Tasseled Cap Wetness, Crist 1985)
@@ -94,8 +96,6 @@ def calculate_indices(
94
96
  * ``'TCG_GSO'`` (Tasseled Cap Greeness, Nedkov 2017)
95
97
  * ``'TCW_GSO'`` (Tasseled Cap Wetness, Nedkov 2017)
96
98
  * ``'WI'`` (Water Index, Fisher 2016)
97
- * ``'kNDVI'`` (Non-linear Normalised Difference Vegation Index,
98
- Camps-Valls et al. 2021)
99
99
 
100
100
  collection : str
101
101
  An string that tells the function what data collection is
@@ -107,7 +107,6 @@ def calculate_indices(
107
107
 
108
108
  * ``'ga_ls_3'`` (for GA Landsat Collection 3)
109
109
  * ``'ga_s2_3'`` (for GA Sentinel 2 Collection 3)
110
- * ``'ga_gm_3'`` (for GA Geomedian Collection 3)
111
110
 
112
111
  custom_varname : str, optional
113
112
  By default, the original dataset will be returned with
@@ -295,6 +294,12 @@ def calculate_indices(
295
294
  "FMR": lambda ds: (ds.swir1 / ds.nir),
296
295
  # Iron Oxide Ratio, Segal 1982
297
296
  "IOR": lambda ds: (ds.red / ds.blue),
297
+ # Symbolic Regression Water Index, Chrysostomou 2026
298
+ "SRWI": lambda ds: ((ds.green + ds.blue) - (ds.nir + ds.swir1))
299
+ / ((ds.green + ds.blue) + (ds.nir + ds.swir1)),
300
+ # Symbolic Regression Vegetation Index, Chrysostomou 2026
301
+ "SRVI": lambda ds: ((2 * ds.nir) - (3 * ds.red))
302
+ / (ds.nir + ds.red + (0.5 * (ds.green + ds.swir1))),
298
303
  }
299
304
 
300
305
  # If index supplied is not a list, convert to list. This allows us to
@@ -343,7 +348,7 @@ def calculate_indices(
343
348
  if collection is None:
344
349
  raise ValueError(
345
350
  "'No `collection` was provided. Please specify "
346
- "either 'ga_ls_3', 'ga_s2_3' or 'ga_gm_3' "
351
+ "either 'ga_ls_3' or 'ga_s2_3' "
347
352
  "to ensure the function calculates indices "
348
353
  "using the correct spectral bands"
349
354
  )
@@ -396,16 +401,12 @@ def calculate_indices(
396
401
  a: b for a, b in bandnames_dict.items() if a in ds.variables
397
402
  }
398
403
 
399
- elif collection == "ga_gm_3":
400
- # Pass an empty dict as no bands need renaming
401
- bands_to_rename = {}
402
-
403
404
  # Raise error if no valid collection name is provided:
404
405
  else:
405
406
  raise ValueError(
406
407
  f"'{collection}' is not a valid option for "
407
408
  "`collection`. Please specify either \n"
408
- "'ga_ls_3', 'ga_s2_3' or 'ga_gm_3'"
409
+ "'ga_ls_3' or 'ga_s2_3'"
409
410
  )
410
411
 
411
412
  # Apply index function
@@ -17,35 +17,37 @@ here: https://gis.stackexchange.com/questions/tagged/open-data-cube).
17
17
  If you would like to report an issue with this script, you can file one
18
18
  on GitHub (https://github.com/GeoscienceAustralia/dea-notebooks/issues/new).
19
19
 
20
- Last modified: May 2021
20
+ Last modified: February 2026
21
21
  """
22
22
 
23
- import multiprocessing as mp
24
23
  import os
25
24
  import sys
26
25
  import time
27
- import warnings
28
- from abc import ABCMeta, abstractmethod
29
- from copy import deepcopy
30
- from datetime import timedelta
31
- from typing import Any, Callable, Dict, List, Optional, Tuple
32
-
33
- import dask.array as da
34
- import dask.distributed as dd
35
- import geopandas as gpd
26
+ import pyproj
36
27
  import joblib
28
+ import warnings
37
29
  import numpy as np
38
30
  import pandas as pd
39
31
  import xarray as xr
32
+ import dask.array as da
33
+ import geopandas as gpd
34
+ from tqdm.auto import tqdm
35
+ import multiprocessing as mp
36
+ import dask.distributed as dd
37
+ from functools import partial
38
+ from datetime import datetime, timedelta
39
+ from abc import ABCMeta, abstractmethod
40
40
  from dask_ml.wrappers import ParallelPostFit
41
+
41
42
  from odc.geo.geom import Geometry
42
43
  from odc.geo.xr import assign_crs
43
44
  from sklearn.base import ClusterMixin
44
- from sklearn.cluster import AgglomerativeClustering, KMeans
45
+ from sklearn.utils import check_random_state
45
46
  from sklearn.mixture import GaussianMixture
47
+ from sklearn.cluster import AgglomerativeClustering, KMeans
46
48
  from sklearn.model_selection import BaseCrossValidator, KFold, ShuffleSplit
47
- from sklearn.utils import check_random_state
48
- from tqdm.auto import tqdm
49
+
50
+ from typing import Any, Callable, Dict, List, Optional, Tuple
49
51
 
50
52
  from dea_tools.spatial import xr_rasterize
51
53
 
@@ -57,8 +59,7 @@ def sklearn_flatten(input_xr):
57
59
  dimensions flattened into one dimension.
58
60
 
59
61
  This flattening procedure enables DataArrays and Datasets to be used
60
- to train and predict
61
- with sklearn models.
62
+ to train and predict with sklearn models.
62
63
 
63
64
  Last modified: September 2019
64
65
 
@@ -83,7 +84,11 @@ def sklearn_flatten(input_xr):
83
84
  input_xr = input_xr.to_array()
84
85
 
85
86
  # stack across pixel dimensions, handling timeseries if necessary
86
- stacked = input_xr.stack(z=["x", "y", "time"]) if "time" in input_xr.dims else input_xr.stack(z=["x", "y"])
87
+ stacked = (
88
+ input_xr.stack(z=["x", "y", "time"])
89
+ if "time" in input_xr.dims
90
+ else input_xr.stack(z=["x", "y"])
91
+ )
87
92
 
88
93
  # finding 'bands' dimensions in each pixel - these will not be
89
94
  # flattened as their context is important for sklearn
@@ -146,7 +151,11 @@ def sklearn_unflatten(output_np, input_xr):
146
151
  input_xr = input_xr.to_array()
147
152
 
148
153
  # generate the same mask we used to create the input to the sklearn model
149
- stacked = input_xr.stack(z=["x", "y", "time"]) if "time" in input_xr.dims else input_xr.stack(z=["x", "y"])
154
+ stacked = (
155
+ input_xr.stack(z=["x", "y", "time"])
156
+ if "time" in input_xr.dims
157
+ else input_xr.stack(z=["x", "y"])
158
+ )
150
159
 
151
160
  pxdims = []
152
161
  for dim in stacked.dims:
@@ -287,7 +296,9 @@ def predict_xr(
287
296
  input_data_flattened = da.array(input_data_flattened).transpose()
288
297
 
289
298
  if clean:
290
- input_data_flattened = da.where(da.isfinite(input_data_flattened), input_data_flattened, 0)
299
+ input_data_flattened = da.where(
300
+ da.isfinite(input_data_flattened), input_data_flattened, 0
301
+ )
291
302
 
292
303
  if proba & persist:
293
304
  # persisting data so we don't require loading all the data twice
@@ -319,18 +330,24 @@ def predict_xr(
319
330
  out_proba = da.max(out_proba, axis=1) * 100.0
320
331
  out_proba = out_proba.reshape(len(y), len(x))
321
332
 
322
- out_proba = xr.DataArray(out_proba, coords={"x": x, "y": y}, dims=["y", "x"])
333
+ out_proba = xr.DataArray(
334
+ out_proba, coords={"x": x, "y": y}, dims=["y", "x"]
335
+ )
323
336
  output_xr["Probabilities"] = out_proba
324
337
  else:
325
338
  print(" returning class probability array")
326
339
  out_proba = out_proba * 100.0
327
- class_names = model.classes_ # Get the unique class names from the fitted classifier
340
+ class_names = (
341
+ model.classes_
342
+ ) # Get the unique class names from the fitted classifier
328
343
 
329
344
  # Loop through each class (band)
330
345
  probabilities_dataset = xr.Dataset()
331
346
  for i, class_name in enumerate(class_names):
332
347
  reshaped_band = out_proba[:, i].reshape(len(y), len(x))
333
- reshaped_da = xr.DataArray(reshaped_band, coords={"x": x, "y": y}, dims=["y", "x"])
348
+ reshaped_da = xr.DataArray(
349
+ reshaped_band, coords={"x": x, "y": y}, dims=["y", "x"]
350
+ )
334
351
  probabilities_dataset[f"prob_{class_name}"] = reshaped_da
335
352
 
336
353
  # merge in the probabilities
@@ -351,7 +368,9 @@ def predict_xr(
351
368
  if len(input_data_flattened.shape[1:]):
352
369
  output_px_shape = input_data_flattened.shape[1:]
353
370
 
354
- output_features = input_data_flattened.reshape((len(stacked.z), *output_px_shape))
371
+ output_features = input_data_flattened.reshape(
372
+ (len(stacked.z), *output_px_shape)
373
+ )
355
374
 
356
375
  # set the stacked coordinate to match the input
357
376
  output_features = xr.DataArray(
@@ -366,7 +385,9 @@ def predict_xr(
366
385
  # convert to dataset and rename arrays
367
386
  output_features = output_features.to_dataset(dim="output_dim_0")
368
387
  data_vars = list(input_xr.data_vars)
369
- output_features = output_features.rename({i: j for i, j in zip(output_features.data_vars, data_vars)})
388
+ output_features = output_features.rename(
389
+ {i: j for i, j in zip(output_features.data_vars, data_vars)}
390
+ )
370
391
 
371
392
  # merge with predictions
372
393
  output_xr = xr.merge([output_xr, output_features], compat="override")
@@ -377,10 +398,14 @@ def predict_xr(
377
398
  # convert model to dask predict
378
399
  model = ParallelPostFit(model)
379
400
  with joblib.parallel_backend("dask"):
380
- output_xr = _predict_func(model, input_xr, persist, proba, max_proba, clean, return_input)
401
+ output_xr = _predict_func(
402
+ model, input_xr, persist, proba, max_proba, clean, return_input
403
+ )
381
404
 
382
405
  else:
383
- output_xr = _predict_func(model, input_xr, persist, proba, max_proba, clean, return_input).compute()
406
+ output_xr = _predict_func(
407
+ model, input_xr, persist, proba, max_proba, clean, return_input
408
+ ).compute()
384
409
 
385
410
  return output_xr
386
411
 
@@ -400,42 +425,35 @@ class HiddenPrints:
400
425
 
401
426
 
402
427
  def _get_training_data_for_shp(
403
- gdf: gpd.GeoDataFrame,
404
- index: int,
405
428
  row: gpd.GeoSeries,
406
- out_arrs: List[np.ndarray],
407
- out_vars: List[List[str]],
429
+ crs: pyproj.CRS,
408
430
  dc_query: Dict,
409
431
  return_coords: bool,
410
- feature_func: Optional[callable] = None,
432
+ return_time_coords: bool,
433
+ feature_func: callable = None,
411
434
  field: Optional[str] = None,
412
435
  zonal_stats: Optional[str] = None,
413
436
  time_field: Optional[str] = None,
414
- time_delta: Optional[timedelta] = None,
415
- ):
437
+ ) -> pd.DataFrame:
416
438
  """
417
439
  This is the core function that is triggered by `collect_training_data`.
418
440
  The `collect_training_data` function loops through geometries in a geopandas
419
- geodataframe and runs the code within `_get_training_data_for_shp`.
420
- Parameters are inherited from `collect_training_data`.
421
- See that function for information on the other params not listed below.
441
+ geodataframe and runs this function. See `collect_training_data` for more
442
+ information on the parameters than is detailed below.
422
443
 
423
444
  Parameters
424
445
  ----------
425
- gdf : gpd.GeoDataFrame
426
- Geopandas GeoDataFrame containing geometries.
427
- index : int
428
- Index of the current geometry in the GeoDataFrame.
429
446
  row : gpd.GeoSeries
430
447
  GeoSeries representing the current row in the GeoDataFrame.
431
- out_arrs : List[np.ndarray]
432
- An empty list into which the training data arrays are stored.
433
- out_vars : List[List[str]]
434
- An empty list into which the data variable names are stored.
448
+ crs : pyproj.CRS
449
+ Coordinate reference system information extracted from a GeoDataFrame
450
+ e.g., crs=gdf.crs
435
451
  dc_query : Dict
436
- ODC query.
452
+ ODC query object.
437
453
  return_coords : bool
438
- Flag indicating whether to return coordinates in the dataset.
454
+ Flag indicating whether to return x,y coordinates in the dataset.
455
+ return_time_coords : bool
456
+ Flag indicating whether to return time coordinates in the dataset
439
457
  feature_func : callable, optional
440
458
  Optional function to extract data based on `dc_query`. Defaults to None.
441
459
  field : str, optional
@@ -443,65 +461,52 @@ def _get_training_data_for_shp(
443
461
  zonal_stats : str, optional
444
462
  Zonal statistics method. Defaults to None.
445
463
  time_field : str, optional
446
- Name of the column containing timestamp data in the input gdf. Defaults to None.
447
- time_delta : timedelta, optional
448
- Time delta used to match a data point with all the scenes falling between
449
- `time_stamp - time_delta` and `time_stamp + time_delta`. Defaults to None.
450
-
464
+ Name of the column containing time(range) data in the input gdf, for the case where each row
465
+ should load from a different time(range). If loading from the same time(range) for
466
+ all rows, then its preferable to pass time as a key:variable in the 'dc_query'.
467
+ Note the time values must be in a format that datacube.load() accepts. For example, as a
468
+ tuple with strings ('2017-01-01', '2017-01-31'). Defaults to None.
451
469
 
452
470
  Returns
453
471
  --------
454
- Two lists, a list of numpy.arrays containing classes and extracted data for
455
- each pixel or polygon, and another containing the data variable names.
472
+ pd.DataFrame
456
473
 
457
474
  """
458
475
 
459
476
  # prevent function altering dictionary kwargs
460
- dc_query = deepcopy(dc_query)
477
+ dc_query = {**dc_query} # shallow copy is faster
461
478
 
462
479
  # remove dask chunks if supplied as using
463
- # mulitprocessing for parallization
480
+ # multiprocessing for parallelisation
464
481
  if "dask_chunks" in dc_query:
465
482
  dc_query.pop("dask_chunks", None)
466
483
 
467
- # set up query based on polygon
468
- geom = Geometry(geom=gdf.iloc[index].geometry, crs=gdf.crs)
469
- q = {"geopolygon": geom}
470
- # merge polygon query with user supplied query params
471
- dc_query.update(q)
472
-
473
- # Update time range if a time window is specified
474
- if time_delta is not None:
475
- timestamp = gdf.loc[index][time_field]
476
- start_time = timestamp - time_delta
477
- end_time = timestamp + time_delta
478
- timestamp = {"time": (start_time, end_time)}
479
- # merge time query with user supplied query params
480
- dc_query.update(timestamp)
481
-
482
- # Use input feature function
484
+ # set up query based on row geometry
485
+ geom = Geometry(geom=row["geometry"], crs=crs)
486
+ dc_query.update({"geopolygon": geom})
487
+
488
+ if time_field is not None:
489
+ timerange = getattr(row, time_field)
490
+ dc_query.update({"time": timerange})
491
+
492
+ # Use input feature function and run checks on output
483
493
  data = feature_func(dc_query)
484
494
 
485
- # if no data is present then return
486
- if len(data) == 0:
487
- return
495
+ if not isinstance(data, (xr.Dataset, xr.DataArray)):
496
+ raise TypeError("feature_func must return xarray Dataset or DataArray")
488
497
 
489
- if gdf.iloc[[index]].geometry.geom_type.values != "Point":
490
- # If the geometry type is a polygon extract all pixels
491
- # create polygon mask
492
- mask = xr_rasterize(gdf.iloc[[index]], data)
493
- data = data.where(mask)
498
+ if len(data.data_vars) == 0:
499
+ raise ValueError(
500
+ "feature_func returned an empty dataset, "
501
+ "this can happen if a geometry is not within data bounds"
502
+ )
494
503
 
495
- # Check that feature_func has removed time
496
- if "time" in data.dims:
497
- t = data.dims["time"]
498
- if t > 1 and time_delta is not None:
499
- raise ValueError(
500
- "After running the feature_func, the dataset still has "
501
- + str(t)
502
- + " time-steps, dataset must only have"
503
- + " x and y dimensions."
504
- )
504
+ # If the geometry type is a polygon extract all pixels
505
+ if row["geometry"].geom_type != "Point":
506
+ # create polygon mask (requires gdf)
507
+ dff = gpd.GeoDataFrame(row.to_frame().T, geometry="geometry", crs=crs)
508
+ mask = xr_rasterize(dff, data)
509
+ data = data.where(mask)
505
510
 
506
511
  if return_coords:
507
512
  # turn coords into a variable in the ds
@@ -510,46 +515,50 @@ def _get_training_data_for_shp(
510
515
 
511
516
  # append ID measurement to dataset for tracking failures
512
517
  band = list(data.data_vars)[0]
513
- _id = xr.zeros_like(data[band])
514
- data["id"] = _id
515
- data["id"] = data["id"] + gdf.iloc[index]["id"]
518
+ data["_training_id"] = xr.zeros_like(data[band])
519
+ data["_training_id"] = data["_training_id"] + row["_training_id"]
520
+
521
+ if "time" in data.sizes:
522
+ if return_time_coords:
523
+ data["time_coord"] = data.time
516
524
 
517
- # If no zonal stats were requested then extract all pixel values
525
+ # If no zonal stats were requested then extract all pixel values.
518
526
  if zonal_stats is None:
519
- flat_train = sklearn_flatten(data)
520
- flat_val = np.repeat(row[field], flat_train.shape[0])
521
- stacked = np.hstack((np.expand_dims(flat_val, axis=1), flat_train))
527
+ stacked = data.to_dataframe().reset_index(drop=True)
528
+ stacked[field] = row[field]
522
529
 
523
530
  elif zonal_stats in ["mean", "median", "max", "min"]:
524
531
  method_to_call = getattr(data, zonal_stats)
525
- flat_train = method_to_call()
526
- flat_train = flat_train.to_array()
527
- stacked = np.hstack((row[field], flat_train))
532
+ stacked = method_to_call(["x", "y"]) # will keep time as dim if present
533
+ stacked = stacked.to_dataframe().reset_index(drop=True)
534
+ stacked[field] = row[field]
528
535
 
529
536
  else:
530
537
  raise Exception(
531
- zonal_stats + " is not one of the supported" + " reduce functions ('mean','median','max','min')"
538
+ f"{zonal_stats} is not one of the supported reduce functions: 'mean','median','max','min'"
532
539
  )
533
540
 
534
- out_arrs.append(stacked)
535
- out_vars.append([field] + list(data.data_vars))
541
+ if "spatial_ref" in stacked.columns:
542
+ stacked = stacked.drop("spatial_ref", axis=1)
543
+
544
+ return stacked
536
545
 
537
546
 
538
547
  def _get_training_data_parallel(
539
548
  gdf: gpd.GeoDataFrame,
540
- dc_query: str,
549
+ dc_query: dict,
541
550
  ncpus: int,
542
- return_coords: bool,
543
- feature_func: Optional[Callable] = None,
551
+ return_coords: bool = False,
552
+ return_time_coords: bool = False,
553
+ feature_func: callable = None,
544
554
  field: Optional[str] = None,
545
555
  zonal_stats: Optional[str] = None,
546
556
  time_field: Optional[str] = None,
547
- time_delta: Optional[int] = None,
548
- ) -> Tuple[List[str], List[Any]]:
557
+ ) -> pd.DataFrame:
549
558
  """
550
559
  Function passing the '_get_training_data_for_shp' function
551
560
  to a mulitprocessing.Pool.
552
- Inherits variables from 'collect_training_data()'.
561
+ Inherits variables from 'collect_training_data'.
553
562
 
554
563
  """
555
564
  # Check if dask-client is running
@@ -561,18 +570,26 @@ def _get_training_data_parallel(
561
570
 
562
571
  if zx is not None:
563
572
  raise ValueError(
564
- "You have a Dask Client running, which prevents \nthis function from multiprocessing. Close the client."
573
+ "You have a Dask Client running, which prevents"
574
+ "this function from multiprocessing. Close the client."
565
575
  )
566
576
 
567
- # instantiate lists that can be shared across processes
568
- manager = mp.Manager()
569
- results = manager.list()
570
- column_names = manager.list()
577
+ crs = gdf.crs
578
+
579
+ # instantiate results list
580
+ results = []
571
581
 
572
582
  # progress bar
573
583
  pbar = tqdm(total=len(gdf))
574
584
 
575
- def update(*a):
585
+ # what to do with the results
586
+ def results_update(df):
587
+ results.append(df)
588
+ pbar.update()
589
+
590
+ # What to do with errors
591
+ def handle_error(index, e):
592
+ print(f"Worker failed on row {index}", str(e))
576
593
  pbar.update()
577
594
 
578
595
  with mp.Pool(ncpus) as pool:
@@ -580,131 +597,159 @@ def _get_training_data_parallel(
580
597
  pool.apply_async(
581
598
  _get_training_data_for_shp,
582
599
  [
583
- gdf,
584
- index,
585
600
  row,
586
- results,
587
- column_names,
601
+ crs,
588
602
  dc_query,
589
603
  return_coords,
604
+ return_time_coords,
590
605
  feature_func,
591
606
  field,
592
607
  zonal_stats,
593
608
  time_field,
594
- time_delta,
595
609
  ],
596
- callback=update,
610
+ callback=results_update,
611
+ error_callback=partial(handle_error, index),
597
612
  )
598
613
 
599
614
  pool.close()
600
615
  pool.join()
601
- pbar.close()
602
616
 
603
- return column_names, results
617
+ pbar.close()
618
+
619
+ return results
604
620
 
605
621
 
606
622
  def collect_training_data(
607
623
  gdf: gpd.GeoDataFrame,
608
- dc_query: dict,
624
+ dc_query: dict[str, Any],
609
625
  ncpus: int = 1,
610
626
  return_coords: bool = False,
627
+ return_time_coords: bool = False,
611
628
  feature_func: callable = None,
612
629
  field: str = None,
613
- zonal_stats: str = None,
630
+ zonal_stats: Optional[str] = None,
614
631
  clean: bool = True,
615
- fail_threshold: float = 0.02,
632
+ fail_threshold: float = 0.05,
616
633
  fail_ratio: float = 0.5,
617
- max_retries: int = 3,
618
- time_field: str = None,
619
- time_delta: timedelta = None,
620
- ) -> Tuple[List[np.ndarray], List[str]]:
634
+ max_retries: int = 2,
635
+ time_field: Optional[str] = None,
636
+ ) -> pd.DataFrame:
621
637
  """
622
- This function provides methods for gathering training data from the ODC over
638
+ This function provides methods for gathering training/validation data from the ODC over
623
639
  geometries stored within a geopandas geodataframe. The function will return a
624
- 'model_input' array containing stacked training data arrays with all NaNs & Infs removed.
625
- In the instance where ncpus > 1, a parallel version of the function will be run
626
- (functions are passed to a mp.Pool()). This function can conduct zonal statistics if
627
- the supplied shapefile contains polygons. The 'feature_func' parameter defines what
628
- features to produce.
640
+ pandas.DataFrame where the index contains class labels and the columns contain
641
+ feature values generated by a user-defined `feature_func`.
642
+
643
+ - In the instance where ncpus > 1, the function will automatically run in parallel.
644
+ - Zonal statistics are supported where the provided vector file contains polygons, otherwise all
645
+ pixel values are returned.
646
+ - Individual points/polygons can be loaded from different time ranges by passing the `time_field`
647
+ parameter.
648
+ - Implements a retry queue for samples that may fail due to i/o limitations or s3 read failures.
629
649
 
630
650
  Parameters
631
651
  ----------
632
652
  gdf : geopandas geodataframe
633
- geometry data in the form of a geopandas geodataframe
653
+ geometry data in the form of a geopandas geodataframe. Must contain a class labels column,
654
+ can optionally contain a column with time stamps, specified with the`time_field` param.
634
655
  dc_query : dictionary
635
- Datacube query object, should not contain lat and long (x or y)
636
- variables as these are supplied by the 'gdf' variable
656
+ Datacube query object, should not contain lat and long (x or y) variables as these
657
+ are supplied by the geopolygon column in the 'gdf'.
637
658
  ncpus : int
638
659
  The number of cpus/processes over which to parallelize the gathering
639
- of training data (only if ncpus is > 1). Use 'mp.cpu_count()' to determine the number of
640
- cpus available on a machine. Defaults to 1.
641
- return_coords : bool
642
- If True, then the training data will contain two extra columns 'x_coord' and
643
- 'y_coord' corresponding to the x,y coordinate of each sample. This variable can
644
- be useful for handling spatial autocorrelation between samples later in the ML workflow.
660
+ of training data (only if ncpus is > 1). Defaults to 1.
645
661
  feature_func : function
646
662
  A function for generating feature layers that is applied to the data within
647
663
  the bounds of the input geometry. The 'feature_func' must accept a 'dc_query'
648
- object, and return a single xarray.Dataset or xarray.DataArray containing
649
- 2D coordinates (i.e x, y - no time dimension).
650
- e.g.
664
+ object, and return a single xarray.Dataset or xarray.DataArray:
665
+
651
666
  def feature_function(query):
652
667
  dc = datacube.Datacube(app='feature_layers')
653
668
  ds = dc.load(**query)
654
669
  ds = ds.mean('time')
655
670
  return ds
671
+
656
672
  field : str
657
673
  Name of the column in the gdf that contains the class labels
674
+ return_coords : bool
675
+ If True, then the output data will contain two extra columns 'x_coord' and
676
+ 'y_coord' corresponding to the x,y coordinate of each sample.
677
+ return_time_coords : bool
678
+ If True, then the output data will contain an extra column 'time_coord',
679
+ corresponding to the time stamp of each sample.
658
680
  zonal_stats : string, optional
659
681
  An optional string giving the names of zonal statistics to calculate
660
682
  for each polygon. Default is None (all pixel values are returned). Supported
661
683
  values are 'mean', 'median', 'max', 'min'.
662
684
  clean : bool
663
- Whether or not to remove missing values in the training dataset. If True,
664
- training labels with any NaNs or Infs in the feature layers will be dropped
665
- from the dataset.
666
- fail_threshold : float, default 0.02
685
+ Whether or not to remove missing values in the returned dataset. If True (default),
686
+ rows with any NaNs or Infs in any numeric columns will be dropped from the dataset.
687
+ time_field : str, optional
688
+ Name of the column containing time(range) data in the input gdf, for the case where each row
689
+ should load from a different time(range). If loading from the same time(range) for
690
+ all rows, then its preferable to pass time as a key:variable in the 'dc_query'.
691
+ Note the time values must be in a format that datacube.load() accepts. For example, as a
692
+ tuple with strings ('2017-01-01', '2017-01-31'). Defaults to None.
693
+ fail_threshold : float, default 0.05
667
694
  Silent read fails on S3 can result in some rows of the returned data containing NaN values.
668
695
  The'fail_threshold' fraction specifies a % of acceptable fails.
669
696
  e.g. Setting 'fail_threshold' to 0.05 means if >5% of the samples in the training dataset
670
- fail then those samples will be reutnred to the multiprocessing queue. Below this fraction
697
+ fail then those samples will be returned to the multiprocessing queue. Below this fraction
671
698
  the function will accept the failures and return the results.
672
699
  fail_ratio: float
673
700
  A float between 0 and 1 that defines if a given training sample has failed.
674
701
  Default is 0.5, which means if 50 % of the measurements in a given sample return null
675
702
  values, and the number of total fails is more than the fail_threshold, the samplewill be
676
703
  passed to the retry queue.
677
- max_retries: int, default 3
704
+ max_retries: int, default 2
678
705
  Maximum number of times to retry collecting samples. This number is invoked
679
706
  if the 'fail_threshold' is not reached.
680
- time_field: str
681
- The name of the attribute in the input dataframe containing capture timestamp
682
- time_delta: time_delta
683
- The size of the window used as timestamp +/- time_delta.
684
- This is used to allow matching a single field data point with multiple scenes
685
707
 
686
708
  Returns
687
709
  --------
688
- Two lists, a list of numpy.arrays containing classes and extracted data for
689
- each pixel or polygon, and another containing the data variable names.
710
+ pandas.DataFrame
711
+ Where the index contains class labels and the columns contain feature values
690
712
 
691
713
  """
714
+ # --------Conduct various checks before running the function--------
715
+ if feature_func is None:
716
+ raise ValueError(
717
+ "Please supply a feature layer function through the "
718
+ + "parameter 'feature_func'"
719
+ )
720
+
721
+ if field is None:
722
+ raise ValueError("Parameter 'field' must be provided")
723
+
724
+ if field not in gdf.columns:
725
+ raise ValueError(f"Column '{field}' not found in GeoDataFrame")
692
726
 
693
727
  # check the dtype of the class field
694
728
  if not np.issubdtype(gdf[field].dtype, np.integer):
695
- raise ValueError(f'The "{field}" column of the input vector must contain integer dtypes')
729
+ raise ValueError(
730
+ f'The "{field}" column of the input vector must contain integer dtypes'
731
+ )
696
732
 
697
- # check for feature_func
698
- if feature_func is None:
699
- raise ValueError("Please supply a feature layer function through the " + "parameter 'feature_func'")
733
+ # check time-field params
734
+ if time_field is not None:
735
+
736
+ if "time" in dc_query:
737
+ raise ValueError(
738
+ f"You have passed both 'dc_query['time']' and 'time_field', "
739
+ "only pass one of these options"
740
+ )
700
741
 
701
- if zonal_stats is not None:
702
- print("Taking zonal statistic: " + zonal_stats)
742
+ if time_field not in gdf.columns:
743
+ raise ValueError(f"Column '{time_field}' not found in GeoDataFrame")
703
744
 
745
+ if zonal_stats:
746
+ print(f"Applying zonal statistic: {zonal_stats}")
747
+
748
+ # ----------------------------------------------------------------
704
749
  # add unique id to gdf to help with indexing failed rows
705
750
  # during multiprocessing
706
- # if zonal_stats is not None:
707
- gdf["id"] = range(0, len(gdf))
751
+ gdf = gdf.copy() # only modify copy
752
+ gdf["_training_id"] = np.arange(len(gdf))
708
753
 
709
754
  if ncpus == 1:
710
755
  # progress indicator
@@ -713,104 +758,106 @@ def collect_training_data(
713
758
 
714
759
  # list to store results
715
760
  results = []
716
- column_names = []
717
761
 
718
762
  # loop through polys and extract training data
719
763
  for index, row in gdf.iterrows():
720
764
  print(" Feature {:04}/{:04}\r".format(i + 1, len(gdf)), end="")
721
765
 
722
- _get_training_data_for_shp(
723
- gdf,
724
- index,
766
+ stacked = _get_training_data_for_shp(
725
767
  row,
726
- results,
727
- column_names,
768
+ gdf.crs,
728
769
  dc_query,
729
770
  return_coords,
771
+ return_time_coords,
730
772
  feature_func,
731
773
  field,
732
774
  zonal_stats,
733
775
  time_field,
734
- time_delta,
735
776
  )
777
+
778
+ results.append(stacked)
736
779
  i += 1
737
780
 
738
781
  else:
739
782
  print("Collecting training data in parallel mode")
740
- column_names, results = _get_training_data_parallel(
783
+ results = _get_training_data_parallel(
741
784
  gdf=gdf,
742
785
  dc_query=dc_query,
743
786
  ncpus=ncpus,
744
787
  return_coords=return_coords,
788
+ return_time_coords=return_time_coords,
745
789
  feature_func=feature_func,
746
790
  field=field,
747
791
  zonal_stats=zonal_stats,
748
792
  time_field=time_field,
749
- time_delta=time_delta,
750
793
  )
751
794
 
752
- # column names are appended during each iteration
753
- # but they are identical, grab only the first instance
754
- column_names = column_names[0]
795
+ if not results:
796
+ raise RuntimeError("No samples returned from feature extraction.")
755
797
 
756
- # Stack the extracted training data for each feature into a single array
757
- model_input = np.vstack(results)
798
+ # join all results into a single df
799
+ df = pd.concat(results)
758
800
 
759
801
  # this code block below iteratively retries failed rows
760
802
  # up to max_retries or until fail_threshold is
761
- # reached - whichever occurs first
803
+ # reached, whichever occurs first.
762
804
  if ncpus > 1:
763
805
  i = 1
764
806
  while i <= max_retries:
765
- # Find % of fails (null values) in data. Use Pandas for simplicity
766
- df = pd.DataFrame(data=model_input[:, 0:-1], index=model_input[:, -1])
807
+ # Find % of fails (null values) in data
808
+ dff = df.set_index("_training_id")
767
809
  # how many nan values per id?
768
- num_nans = df.isnull().sum(axis=1)
810
+ num_nans = dff.isnull().sum(axis=1)
769
811
  num_nans = num_nans.groupby(num_nans.index).sum()
770
812
  # how many valid values per id?
771
- num_valid = df.notnull().sum(axis=1)
813
+ num_valid = dff.notnull().sum(axis=1)
772
814
  num_valid = num_valid.groupby(num_valid.index).sum()
773
815
  # find fail rate
774
816
  perc_fail = num_nans / (num_nans + num_valid)
775
817
  fail_ids = perc_fail[perc_fail > fail_ratio]
818
+
776
819
  fail_rate = len(fail_ids) / len(gdf)
777
820
 
778
- print("Percentage of possible fails after run " + str(i) + " = " + str(round(fail_rate * 100, 2)) + " %")
821
+ print(
822
+ "Percentage of possible fails after run "
823
+ + str(i)
824
+ + " = "
825
+ + str(round(fail_rate * 100, 2))
826
+ + " %"
827
+ )
779
828
 
780
829
  if fail_rate > fail_threshold:
781
830
  print("Recollecting samples that failed")
782
831
 
783
832
  fail_ids = list(fail_ids.index)
784
- # keep only the ids in model_input object that didn't fail
785
- model_input = model_input[~np.isin(model_input[:, -1], fail_ids)]
833
+
834
+ # keep only the ids in df object that didn't fail
835
+ df = df.loc[~df["_training_id"].isin(fail_ids)]
786
836
 
787
837
  # index out the fail_ids from the original gdf
788
- gdf_rerun = gdf.loc[gdf["id"].isin(fail_ids)]
838
+ gdf_rerun = gdf.loc[gdf["_training_id"].isin(fail_ids)]
789
839
  gdf_rerun = gdf_rerun.reset_index(drop=True)
790
840
 
791
- time.sleep(5) # sleep for 5s to rest api
841
+ time.sleep(3) # sleep for 3s to rest api
792
842
 
793
843
  # recollect failed rows
794
- (
795
- column_names_again,
796
- results_again,
797
- ) = _get_training_data_parallel(
844
+ results_again = _get_training_data_parallel(
798
845
  gdf=gdf_rerun,
799
846
  dc_query=dc_query,
800
847
  ncpus=ncpus,
801
848
  return_coords=return_coords,
849
+ return_time_coords=return_time_coords,
802
850
  feature_func=feature_func,
803
851
  field=field,
804
852
  zonal_stats=zonal_stats,
805
853
  time_field=time_field,
806
- time_delta=time_delta,
807
854
  )
808
855
 
809
856
  # Stack the extracted training data for each feature into a single array
810
- model_input_again = np.vstack(results_again)
857
+ df_again = pd.concat(results_again)
811
858
 
812
859
  # merge results of the re-run with original run
813
- model_input = np.vstack((model_input, model_input_again))
860
+ df = pd.concat([df, df_again])
814
861
 
815
862
  i += 1
816
863
 
@@ -818,24 +865,30 @@ def collect_training_data(
818
865
  break
819
866
 
820
867
  # -----------------------------------------------
821
-
822
868
  # remove id column
823
- idx_var = column_names[0:-1]
824
- model_col_indices = [column_names.index(var_name) for var_name in idx_var]
825
- model_input = model_input[:, model_col_indices]
869
+ df = df.drop("_training_id", axis=1)
826
870
 
827
871
  if clean:
828
- num = np.count_nonzero(np.isnan(model_input).any(axis=1))
829
- model_input = model_input[~np.isnan(model_input).any(axis=1)]
830
- model_input = model_input[~np.isinf(model_input).any(axis=1)]
831
- print("Removed " + str(num) + " rows wth NaNs &/or Infs")
832
- print("Output shape: ", model_input.shape)
872
+ # Identify which columns have numeric data
873
+ numeric_cols = df.select_dtypes(include=[np.number]).columns
874
+
875
+ # do we have any numeric columns to clean
876
+ if len(numeric_cols) == 0:
877
+ print("No numeric columns to clean; leaving DataFrame unchanged.")
878
+ else:
879
+ # Build invalid mask on numeric columns only, NaN or Inf
880
+ invalid_mask = ~np.isfinite(df[numeric_cols]).all(axis=1)
881
+ num_removed = invalid_mask.sum()
882
+ df = df[~invalid_mask]
883
+
884
+ print(f"Removed {num_removed} rows with NaNs or Infs in numeric columns")
885
+ print("Output shape:", df.shape)
833
886
 
834
887
  else:
835
888
  print("Returning data without cleaning")
836
- print("Output shape: ", model_input.shape)
889
+ print("Output shape: ", df.shape)
837
890
 
838
- return column_names[0:-1], model_input
891
+ return df.set_index(field)
839
892
 
840
893
 
841
894
  class KMeans_tree(ClusterMixin):
@@ -866,7 +919,8 @@ class KMeans_tree(ClusterMixin):
866
919
  # make child models
867
920
  if n_levels > 1:
868
921
  self.branches = [
869
- KMeans_tree(n_levels=n_levels - 1, n_clusters=n_clusters, **kwargs) for _ in range(n_clusters)
922
+ KMeans_tree(n_levels=n_levels - 1, n_clusters=n_clusters, **kwargs)
923
+ for _ in range(n_clusters)
870
924
  ]
871
925
 
872
926
  def fit(self, X, y=None, sample_weight=None):
@@ -898,7 +952,11 @@ class KMeans_tree(ClusterMixin):
898
952
  # fit child models on their corresponding partition of the training set
899
953
  self.branches[clu].fit(
900
954
  X[labels_old == clu],
901
- sample_weight=(sample_weight[labels_old == clu] if sample_weight is not None else None),
955
+ sample_weight=(
956
+ sample_weight[labels_old == clu]
957
+ if sample_weight is not None
958
+ else None
959
+ ),
902
960
  )
903
961
  self.labels_[labels_old == clu] += self.branches[clu].labels_
904
962
 
@@ -934,13 +992,24 @@ class KMeans_tree(ClusterMixin):
934
992
  for clu in range(self.n_clusters):
935
993
  result[rescpy == clu] += self.branches[clu].predict(
936
994
  X[rescpy == clu],
937
- sample_weight=(sample_weight[rescpy == clu] if sample_weight is not None else None),
995
+ sample_weight=(
996
+ sample_weight[rescpy == clu]
997
+ if sample_weight is not None
998
+ else None
999
+ ),
938
1000
  )
939
1001
 
940
1002
  return result
941
1003
 
942
1004
 
943
- def spatial_clusters(coordinates, method="Hierarchical", max_distance=None, n_groups=None, verbose=False, **kwargs):
1005
+ def spatial_clusters(
1006
+ coordinates,
1007
+ method="Hierarchical",
1008
+ max_distance=None,
1009
+ n_groups=None,
1010
+ verbose=False,
1011
+ **kwargs,
1012
+ ):
944
1013
  """
945
1014
  Create spatial groups on coorindate data using either KMeans clustering
946
1015
  or a Gaussian Mixture model
@@ -974,21 +1043,28 @@ def spatial_clusters(coordinates, method="Hierarchical", max_distance=None, n_gr
974
1043
  raise ValueError("Method must be one of: 'Hierarchical','KMeans' or 'GMM'")
975
1044
 
976
1045
  if (method in ["GMM", "KMeans"]) & (n_groups is None):
977
- raise ValueError("The 'GMM' and 'KMeans' methods requires explicitly setting 'n_groups'")
1046
+ raise ValueError(
1047
+ "The 'GMM' and 'KMeans' methods requires explicitly setting 'n_groups'"
1048
+ )
978
1049
 
979
1050
  if (method == "Hierarchical") & (max_distance is None):
980
1051
  raise ValueError("The 'Hierarchical' method requires setting max_distance")
981
1052
 
982
1053
  if method == "Hierarchical":
983
1054
  cluster_label = AgglomerativeClustering(
984
- n_clusters=None, linkage="complete", distance_threshold=max_distance, **kwargs
1055
+ n_clusters=None,
1056
+ linkage="complete",
1057
+ distance_threshold=max_distance,
1058
+ **kwargs,
985
1059
  ).fit_predict(coordinates)
986
1060
 
987
1061
  if method == "KMeans":
988
1062
  cluster_label = KMeans(n_clusters=n_groups, **kwargs).fit_predict(coordinates)
989
1063
 
990
1064
  if method == "GMM":
991
- cluster_label = GaussianMixture(n_components=n_groups, **kwargs).fit_predict(coordinates)
1065
+ cluster_label = GaussianMixture(n_components=n_groups, **kwargs).fit_predict(
1066
+ coordinates
1067
+ )
992
1068
  if verbose:
993
1069
  print("n clusters = " + str(len(np.unique(cluster_label))))
994
1070
 
@@ -1217,7 +1293,9 @@ def spatial_train_test_split(
1217
1293
 
1218
1294
  if kfold_method == "SpatialKFold":
1219
1295
  if n_splits is None:
1220
- raise ValueError("n_splits parameter requires an integer value, eg. 'n_splits=5'")
1296
+ raise ValueError(
1297
+ "n_splits parameter requires an integer value, eg. 'n_splits=5'"
1298
+ )
1221
1299
  if (test_size is not None) or (train_size is not None):
1222
1300
  warnings.warn(
1223
1301
  "With the 'SpatialKFold' method, controlling the test/train ratio "
@@ -1268,7 +1346,11 @@ def _partition_by_sum(array, parts):
1268
1346
  """
1269
1347
  array = np.atleast_1d(array).ravel()
1270
1348
  if parts > array.size:
1271
- raise ValueError("Cannot partition an array of size {} into {} parts of equal sum.".format(array.size, parts))
1349
+ raise ValueError(
1350
+ "Cannot partition an array of size {} into {} parts of equal sum.".format(
1351
+ array.size, parts
1352
+ )
1353
+ )
1272
1354
  cumulative_sum = array.cumsum()
1273
1355
  # Ideally, we want each part to have the same number of points (total /
1274
1356
  # parts).
@@ -1279,7 +1361,11 @@ def _partition_by_sum(array, parts):
1279
1361
  # Check for repeated split points, which indicates that there is no way to
1280
1362
  # split the array.
1281
1363
  if np.unique(indices).size != indices.size:
1282
- raise ValueError("Could not find partition points to split the array into {} parts of equal sum.".format(parts))
1364
+ raise ValueError(
1365
+ "Could not find partition points to split the array into {} parts of equal sum.".format(
1366
+ parts
1367
+ )
1368
+ )
1283
1369
  return indices
1284
1370
 
1285
1371
 
@@ -1337,7 +1423,11 @@ class _BaseSpatialCrossValidator(BaseCrossValidator, metaclass=ABCMeta):
1337
1423
  The testing set indices for that split.
1338
1424
  """
1339
1425
  if X.shape[1] != 2:
1340
- raise ValueError("X (the coordinate data) must have exactly 2 columns ({} given).".format(X.shape[1]))
1426
+ raise ValueError(
1427
+ "X (the coordinate data) must have exactly 2 columns ({} given).".format(
1428
+ X.shape[1]
1429
+ )
1430
+ )
1341
1431
  for train, test in super().split(X, y, groups):
1342
1432
  yield train, test
1343
1433
 
@@ -1471,7 +1561,9 @@ class _SpatialShuffleSplit(_BaseSpatialCrossValidator):
1471
1561
  **kwargs,
1472
1562
  )
1473
1563
  if balance < 1:
1474
- raise ValueError("The *balance* argument must be >= 1. To disable balance, use 1.")
1564
+ raise ValueError(
1565
+ "The *balance* argument must be >= 1. To disable balance, use 1."
1566
+ )
1475
1567
  self.test_size = test_size
1476
1568
  self.train_size = train_size
1477
1569
  self.random_state = random_state
@@ -1530,7 +1622,12 @@ class _SpatialShuffleSplit(_BaseSpatialCrossValidator):
1530
1622
  test_points = np.where(np.isin(labels, cluster_ids[test_clusters]))[0]
1531
1623
  # The proportion of data points assigned to each group should
1532
1624
  # be close the proportion of clusters assigned to each group.
1533
- balance.append(abs(train_points.size / test_points.size - train_clusters.size / test_clusters.size))
1625
+ balance.append(
1626
+ abs(
1627
+ train_points.size / test_points.size
1628
+ - train_clusters.size / test_clusters.size
1629
+ )
1630
+ )
1534
1631
  test_sets.append(test_points)
1535
1632
  best = np.argmin(balance)
1536
1633
  yield test_sets[best]
@@ -1612,7 +1709,11 @@ class _SpatialKFold(_BaseSpatialCrossValidator):
1612
1709
  )
1613
1710
 
1614
1711
  if n_splits < 2:
1615
- raise ValueError("Number of splits must be >=2 for clusterKFold. Given {}.".format(n_splits))
1712
+ raise ValueError(
1713
+ "Number of splits must be >=2 for clusterKFold. Given {}.".format(
1714
+ n_splits
1715
+ )
1716
+ )
1616
1717
  self.test_size = test_size
1617
1718
  self.shuffle = shuffle
1618
1719
  self.random_state = random_state
@@ -44,14 +44,10 @@ from dea_tools.dask import create_local_dask_cluster
44
44
  from dea_tools.datahandling import load_ard
45
45
  from dea_tools.spatial import xr_rasterize
46
46
 
47
- # Create local dask cluster to improve data load time
48
- client = create_local_dask_cluster(return_client=True)
49
-
50
47
  # disable DeprecationWarning for chained assignments in conversion to
51
48
  # datetime format
52
49
  pd.options.mode.chained_assignment = None # default='warn'
53
50
 
54
-
55
51
  def normalise_wit(polygon_base_df):
56
52
  """
57
53
  This function is to normalise the Fractional Cover vegetation
@@ -396,7 +392,7 @@ def WIT_drill(
396
392
  # Connect to the datacube
397
393
  dc = datacube.Datacube(app="WIT_drill")
398
394
 
399
- # load landsat 5,7,8 data
395
+ # load Landsat 5,7,8,9 data
400
396
  warnings.filterwarnings("ignore")
401
397
 
402
398
  # load wetland polygon and specify the coordinate reference system of the polygon
@@ -410,10 +406,10 @@ def WIT_drill(
410
406
  if verbose_progress:
411
407
  print("Loading Landsat data")
412
408
 
413
- # Load Landsat 5, 7 and 8 data. Not including Landsat 7 SLC off period (31-05-2003 to 06-04-2022)
409
+ # Load Landsat 5, 7, 8 and 9 data. Not including Landsat 7 SLC off period (31-05-2003 to 06-04-2022)
414
410
  ds_ls = load_ard(
415
411
  dc,
416
- products=["ga_ls8c_ard_3", "ga_ls7e_ard_3", "ga_ls5t_ard_3"],
412
+ products=["ga_ls9c_ard_3", "ga_ls8c_ard_3", "ga_ls7e_ard_3", "ga_ls5t_ard_3"],
417
413
  ls7_slc_off=False,
418
414
  measurements=bands,
419
415
  geopolygon=gpgon,
File without changes