openforis-whisp 2.0.0b3-py3-none-any.whl → 3.0.0a2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,7 +4,7 @@ from shapely.geometry import shape
  from pathlib import Path
 
  # Existing imports
- from typing import List, Any
+ from typing import List, Any, Union
  from geojson import Feature, FeatureCollection, Polygon, Point
  import json
  import os
@@ -12,65 +12,81 @@ import geopandas as gpd
  import ee
 
 
- def convert_geojson_to_ee(
-     geojson_filepath: Any, enforce_wgs84: bool = True, strip_z_coords: bool = True
- ) -> ee.FeatureCollection:
+ # ============================================================================
+ # HELPER FUNCTIONS FOR UNIFIED PROCESSING PATHWAY
+ # ============================================================================
+
+
+ def _sanitize_geodataframe(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
      """
-     Reads a GeoJSON file from the given path and converts it to an Earth Engine FeatureCollection.
-     Optionally checks and converts the CRS to WGS 84 (EPSG:4326) if needed.
-     Automatically handles 3D coordinates by stripping Z values when necessary.
+     Sanitize GeoDataFrame data types for JSON serialization.
+
+     Converts problematic data types that cannot be directly serialized:
+     - DateTime/Timestamp columns → ISO format strings
+     - Object columns → strings
+     - Skips geometry column
 
      Args:
-         geojson_filepath (Any): The filepath to the GeoJSON file.
-         enforce_wgs84 (bool): Whether to enforce WGS 84 projection (EPSG:4326). Defaults to True.
-         strip_z_coords (bool): Whether to automatically strip Z coordinates from 3D geometries. Defaults to True.
+         gdf (gpd.GeoDataFrame): Input GeoDataFrame
 
      Returns:
-         ee.FeatureCollection: Earth Engine FeatureCollection created from the GeoJSON.
+         gpd.GeoDataFrame: GeoDataFrame with sanitized data types
+     """
+     gdf = gdf.copy()
+     for col in gdf.columns:
+         if col != gdf.geometry.name:  # Skip geometry column
+             # Handle datetime/timestamp columns
+             if pd.api.types.is_datetime64_any_dtype(gdf[col]):
+                 gdf[col] = gdf[col].dt.strftime("%Y-%m-%d %H:%M:%S").fillna("")
+             # Handle other problematic types
+             elif gdf[col].dtype == "object":
+                 # Convert any remaining non-serializable objects to strings
+                 gdf[col] = gdf[col].astype(str)
+     return gdf
+
+
+ def _ensure_wgs84_crs(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
      """
-     if isinstance(geojson_filepath, (str, Path)):
-         file_path = os.path.abspath(geojson_filepath)
+     Ensure GeoDataFrame uses WGS 84 (EPSG:4326) coordinate reference system.
 
-         # Apply print_once deduplication for file reading message
-         if not hasattr(convert_geojson_to_ee, "_printed_file_messages"):
-             convert_geojson_to_ee._printed_file_messages = set()
+     - If CRS is None, assumes WGS 84
+     - If CRS is not WGS 84, converts to WGS 84
+     - If already WGS 84, returns unchanged
 
-         if file_path not in convert_geojson_to_ee._printed_file_messages:
-             print(f"Reading GeoJSON file from: {file_path}")
-             convert_geojson_to_ee._printed_file_messages.add(file_path)
+     Args:
+         gdf (gpd.GeoDataFrame): Input GeoDataFrame
 
-         # Use GeoPandas to read the file and handle CRS
-         gdf = gpd.read_file(file_path)
+     Returns:
+         gpd.GeoDataFrame: GeoDataFrame in WGS 84
+     """
+     if gdf.crs is None:
+         # Assuming WGS 84 if no CRS defined
+         return gdf
+     elif gdf.crs != "EPSG:4326":
+         return gdf.to_crs("EPSG:4326")
+     return gdf
 
-         # NEW: Handle problematic data types before JSON conversion
-         for col in gdf.columns:
-             if col != gdf.geometry.name:  # Skip geometry column
-                 # Handle datetime/timestamp columns
-                 if pd.api.types.is_datetime64_any_dtype(gdf[col]):
-                     gdf[col] = gdf[col].dt.strftime("%Y-%m-%d %H:%M:%S").fillna("")
-                 # Handle other problematic types
-                 elif gdf[col].dtype == "object":
-                     # Convert any remaining non-serializable objects to strings
-                     gdf[col] = gdf[col].astype(str)
-
-         # Check and convert CRS if needed
-         if enforce_wgs84:
-             if gdf.crs is None:
-                 print("Warning: Input GeoJSON has no CRS defined, assuming WGS 84")
-             elif gdf.crs != "EPSG:4326":
-                 print(f"Converting CRS from {gdf.crs} to WGS 84 (EPSG:4326)")
-                 gdf = gdf.to_crs("EPSG:4326")
-
-         # Convert to GeoJSON
-         geojson_data = json.loads(gdf.to_json())
-     else:
-         raise ValueError("Input must be a file path (str or Path)")
 
-     validation_errors = validate_geojson(geojson_data)
-     if validation_errors:
-         raise ValueError(f"GeoJSON validation errors: {validation_errors}")
+ def _create_ee_feature_collection(
+     geojson_data: dict, strip_z_coords: bool = True, input_source: str = "input"
+ ) -> ee.FeatureCollection:
+     """
+     Create Earth Engine FeatureCollection from GeoJSON dict with error recovery.
+
+     Attempts to create EE FeatureCollection. If it fails due to 3D coordinates
+     and strip_z_coords is True, automatically strips Z values and retries.
+
+     Args:
+         geojson_data (dict): GeoJSON data dictionary
+         strip_z_coords (bool): Whether to retry with 2D geometries on failure
+         input_source (str): Description of input source for logging
 
-     # Try to create the feature collection, handle 3D coordinate issues automatically
+     Returns:
+         ee.FeatureCollection: Earth Engine FeatureCollection
+
+     Raises:
+         ee.EEException: If conversion fails even after retries
+     """
      try:
          feature_collection = ee.FeatureCollection(
              create_feature_collection(geojson_data)
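To illustrate the two new helpers in isolation, a minimal sketch (not from the package; the toy columns "surveyed" and "note" are made up for this example):

import geopandas as gpd
import pandas as pd
from shapely.geometry import Point

# Hypothetical GeoDataFrame with a timestamp column, an object column and no CRS
gdf = gpd.GeoDataFrame(
    {"surveyed": pd.to_datetime(["2023-01-05"]), "note": [{"source": "survey"}]},
    geometry=[Point(10.0, 5.0)],
)

gdf = _sanitize_geodataframe(gdf)  # "surveyed" -> "2023-01-05 00:00:00", "note" -> str
gdf = _ensure_wgs84_crs(gdf)       # CRS is None, so it is assumed to be WGS 84 and left as is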
@@ -79,16 +95,16 @@ def convert_geojson_to_ee(
      except ee.EEException as e:
          if "Invalid GeoJSON geometry" in str(e) and strip_z_coords:
              # Apply print_once deduplication for Z-coordinate stripping messages
-             if not hasattr(convert_geojson_to_ee, "_printed_z_messages"):
-                 convert_geojson_to_ee._printed_z_messages = set()
+             if not hasattr(_create_ee_feature_collection, "_printed_z_messages"):
+                 _create_ee_feature_collection._printed_z_messages = set()
 
-             z_message_key = f"z_coords_{file_path}"
-             if z_message_key not in convert_geojson_to_ee._printed_z_messages:
+             z_message_key = f"z_coords_{input_source}"
+             if z_message_key not in _create_ee_feature_collection._printed_z_messages:
                  print(
                      "Warning: Invalid GeoJSON geometry detected, likely due to 3D coordinates."
                  )
                  print("Attempting to fix by stripping Z coordinates...")
-                 convert_geojson_to_ee._printed_z_messages.add(z_message_key)
+                 _create_ee_feature_collection._printed_z_messages.add(z_message_key)
 
              # Apply Z-coordinate stripping
              geojson_data_fixed = _strip_z_coordinates_from_geojson(geojson_data)
@@ -99,10 +115,15 @@
                      create_feature_collection(geojson_data_fixed)
                  )
 
-                 success_message_key = f"z_coords_success_{file_path}"
-                 if success_message_key not in convert_geojson_to_ee._printed_z_messages:
-                     print("✓ Successfully converted after stripping Z coordinates")
-                     convert_geojson_to_ee._printed_z_messages.add(success_message_key)
+                 success_message_key = f"z_coords_success_{input_source}"
+                 if (
+                     success_message_key
+                     not in _create_ee_feature_collection._printed_z_messages
+                 ):
+                     print("Successfully converted after stripping Z coordinates")
+                     _create_ee_feature_collection._printed_z_messages.add(
+                         success_message_key
+                     )
 
                  return feature_collection
              except Exception as retry_error:
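The deduplicated warnings above rely on storing a set of already-printed keys as an attribute on the function object itself; a stripped-down sketch of the same pattern (the warn_once name is hypothetical, not part of the package):

def warn_once(key: str, message: str) -> None:
    # State lives on the function object, so repeated calls with the same key stay silent
    if not hasattr(warn_once, "_printed"):
        warn_once._printed = set()
    if key not in warn_once._printed:
        print(message)
        warn_once._printed.add(key)

warn_once("z_coords_example.geojson", "Stripping Z coordinates...")  # prints
warn_once("z_coords_example.geojson", "Stripping Z coordinates...")  # silent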
@@ -113,6 +134,82 @@ def convert_geojson_to_ee(
              raise e
 
 
+ def convert_geojson_to_ee(
+     geojson_input: Union[str, Path, dict, gpd.GeoDataFrame],
+     enforce_wgs84: bool = True,
+     strip_z_coords: bool = True,
+ ) -> ee.FeatureCollection:
+     """
+     Converts GeoJSON data to an Earth Engine FeatureCollection.
+
+     Accepts flexible input types with a unified processing pathway:
+     - File path (str or Path) → loads with GeoPandas
+     - GeoJSON dict → uses directly
+     - GeoDataFrame → uses directly
+
+     Automatically handles:
+     - CRS conversion to WGS 84 (EPSG:4326) if needed
+     - DateTime/Timestamp columns → converts to ISO strings before JSON serialization
+     - Non-serializable objects → converts to strings
+     - 3D coordinates → strips Z values when necessary
+     - Z-coordinate errors → retries with 2D geometries if enabled
+
+     Args:
+         geojson_input (Union[str, Path, dict, gpd.GeoDataFrame]):
+             - File path (str or Path) to GeoJSON file
+             - GeoJSON dictionary object
+             - GeoPandas GeoDataFrame
+         enforce_wgs84 (bool): Whether to enforce WGS 84 projection (EPSG:4326).
+             Defaults to True. Only applies to file path and GeoDataFrame inputs.
+         strip_z_coords (bool): Whether to automatically strip Z coordinates from 3D geometries.
+             Defaults to True.
+
+     Returns:
+         ee.FeatureCollection: Earth Engine FeatureCollection created from the GeoJSON.
+
+     Raises:
+         ValueError: If input type is unsupported or GeoJSON validation fails.
+         ee.EEException: If GeoJSON cannot be converted even after retries.
+     """
+     # UNIFIED INPUT NORMALIZATION: Convert all inputs to GeoDataFrame first
+     if isinstance(geojson_input, gpd.GeoDataFrame):
+         gdf = geojson_input.copy()
+         input_source = "GeoDataFrame"
+     elif isinstance(geojson_input, dict):
+         # Convert dict to GeoDataFrame for unified processing
+         gdf = gpd.GeoDataFrame.from_features(geojson_input.get("features", []))
+         input_source = "dict"
+     elif isinstance(geojson_input, (str, Path)):
+         # Load file and convert to GeoDataFrame
+         file_path = os.path.abspath(geojson_input)
+         gdf = gpd.read_file(file_path)
+         input_source = f"file ({file_path})"
+     else:
+         raise ValueError(
+             f"Input must be a file path (str or Path), GeoJSON dict, or GeoDataFrame. "
+             f"Got {type(geojson_input).__name__}"
+         )
+
+     # UNIFIED DATA SANITIZATION PATHWAY
+     # Handle problematic data types before JSON conversion
+     gdf = _sanitize_geodataframe(gdf)
+
+     # UNIFIED CRS HANDLING
+     if enforce_wgs84:
+         gdf = _ensure_wgs84_crs(gdf)
+
+     # UNIFIED GEOJSON CONVERSION
+     geojson_data = json.loads(gdf.to_json())
+
+     # UNIFIED VALIDATION
+     validation_errors = validate_geojson(geojson_data)
+     if validation_errors:
+         raise ValueError(f"GeoJSON validation errors: {validation_errors}")
+
+     # UNIFIED EE CONVERSION with error recovery
+     return _create_ee_feature_collection(geojson_data, strip_z_coords, input_source)
+
+
  def _strip_z_coordinates_from_geojson(geojson_data: dict) -> dict:
      """
      Helper function to strip Z coordinates from GeoJSON data.
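A minimal usage sketch of the widened convert_geojson_to_ee signature, assuming an authenticated and initialized Earth Engine session; plots.geojson is a made-up file name:

import json
import geopandas as gpd

fc_from_path = convert_geojson_to_ee("plots.geojson")          # str or Path

with open("plots.geojson") as f:
    fc_from_dict = convert_geojson_to_ee(json.load(f))         # GeoJSON dict

gdf = gpd.read_file("plots.geojson")
fc_from_gdf = convert_geojson_to_ee(gdf, enforce_wgs84=True)   # GeoDataFrame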
@@ -250,6 +347,58 @@ def convert_shapefile_to_ee(shapefile_path):
      return roi
 
 
+ # def convert_ee_to_df(
+ #     ee_object,
+ #     columns=None,
+ #     remove_geom=False,
+ #     **kwargs,
+ # ):
+ #     """Converts an ee.FeatureCollection to pandas dataframe.
+
+ #     Args:
+ #         ee_object (ee.FeatureCollection): ee.FeatureCollection.
+ #         columns (list): List of column names. Defaults to None.
+ #         remove_geom (bool): Whether to remove the geometry column. Defaults to True.
+ #         kwargs: Additional arguments passed to ee.data.computeFeature.
+
+ #     Raises:
+ #         TypeError: ee_object must be an ee.FeatureCollection
+
+ #     Returns:
+ #         pd.DataFrame: pandas DataFrame
+ #     """
+ #     if isinstance(ee_object, ee.Feature):
+ #         ee_object = ee.FeatureCollection([ee_object])
+
+ #     if not isinstance(ee_object, ee.FeatureCollection):
+ #         raise TypeError("ee_object must be an ee.FeatureCollection")
+
+ #     try:
+ #         if remove_geom:
+ #             data = ee_object.map(
+ #                 lambda f: ee.Feature(None, f.toDictionary(f.propertyNames().sort()))
+ #             )
+ #         else:
+ #             data = ee_object
+
+ #         kwargs["expression"] = data
+ #         kwargs["fileFormat"] = "PANDAS_DATAFRAME"
+
+ #         df = ee.data.computeFeatures(kwargs)
+
+ #         if isinstance(columns, list):
+ #             df = df[columns]
+ #
+ #         if remove_geom and ("geometry" in df.columns):
+ #             df = df.drop(columns=["geometry"], axis=1)
+
+ #         # Sorting columns is not supported server-side and is removed from this function.
+
+ #         return df
+ #     except Exception as e:
+ #         raise Exception(e)
+
+
  def convert_ee_to_df(
      ee_object,
      columns=None,
@@ -257,49 +406,37 @@ def convert_ee_to_df(
      sort_columns=False,
      **kwargs,
  ):
-     """Converts an ee.FeatureCollection to pandas dataframe.
+     """
+     Converts an ee.FeatureCollection to pandas DataFrame, maximizing server-side operations.
 
      Args:
          ee_object (ee.FeatureCollection): ee.FeatureCollection.
-         columns (list): List of column names. Defaults to None.
-         remove_geom (bool): Whether to remove the geometry column. Defaults to True.
-         sort_columns (bool): Whether to sort the column names. Defaults to False.
-         kwargs: Additional arguments passed to ee.data.computeFeature.
-
-     Raises:
-         TypeError: ee_object must be an ee.FeatureCollection
+         columns (list): List of column names to select (server-side if possible).
+         remove_geom (bool): Remove geometry column server-side.
+         kwargs: Additional arguments for ee.data.computeFeatures.
 
      Returns:
          pd.DataFrame: pandas DataFrame
      """
+     import ee
+
      if isinstance(ee_object, ee.Feature):
          ee_object = ee.FeatureCollection([ee_object])
 
      if not isinstance(ee_object, ee.FeatureCollection):
          raise TypeError("ee_object must be an ee.FeatureCollection")
 
-     try:
-         if remove_geom:
-             data = ee_object.map(
-                 lambda f: ee.Feature(None, f.toDictionary(f.propertyNames().sort()))
-             )
-         else:
-             data = ee_object
+     # Server-side: select columns and remove geometry
+     if columns is not None:
+         ee_object = ee_object.select(columns)
+     if remove_geom:
+         ee_object = ee_object.map(lambda f: ee.Feature(None, f.toDictionary()))
 
-         kwargs["expression"] = data
+     try:
+         kwargs["expression"] = ee_object
          kwargs["fileFormat"] = "PANDAS_DATAFRAME"
-
          df = ee.data.computeFeatures(kwargs)
 
-         if isinstance(columns, list):
-             df = df[columns]
-
-         if remove_geom and ("geometry" in df.columns):
-             df = df.drop(columns=["geometry"], axis=1)
-
-         if sort_columns:
-             df = df.reindex(sorted(df.columns), axis=1)
-
          return df
      except Exception as e:
          raise Exception(e)
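A usage sketch of the reworked convert_ee_to_df, assuming ee.Initialize() has been called; the asset id and column names are placeholders:

import ee

ee.Initialize()
fc = ee.FeatureCollection("projects/example/assets/plots")  # placeholder asset id

# Column selection and geometry removal now happen server-side,
# before ee.data.computeFeatures() returns the table as a pandas DataFrame.
df = convert_ee_to_df(fc, columns=["plot_id", "Area"], remove_geom=True)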
@@ -443,7 +580,7 @@ def convert_csv_to_geojson(
      try:
          df = pd.read_csv(csv_filepath)
 
-         df_to_geojson(df, geojson_filepath, geo_column)
+         convert_df_to_geojson(df, geojson_filepath, geo_column)
 
      except Exception as e:
          print(f"An error occurred while converting CSV to GeoJSON: {e}")
@@ -177,7 +177,7 @@ def g_jrc_tmf_plantation_prep():
      plantation_2020 = plantation.where(
          deforestation_year.gte(2021), 0
      )  # update from https://github.com/forestdatapartnership/whisp/issues/42
-     return plantation_2020.rename("TMF_plant")
+     return plantation_2020.rename("TMF_plant").selfMask()
 
 
  # # Oil_palm_Descals
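The .selfMask() calls added in this and the following dataset prep hunks all serve the same purpose: zero-valued pixels become masked (no-data), so once the band stack is later multiplied by ee.Image.pixelArea() only flagged pixels contribute area. A minimal, self-contained sketch of the behaviour, assuming an initialized Earth Engine session:

import ee

ee.Initialize()

flag = ee.Image.pixelLonLat().select("longitude").gt(0)  # toy 0/1 band
masked = flag.selfMask()                                  # 0s become masked, 1s remain
area = masked.multiply(ee.Image.pixelArea())              # per-pixel area only where flagged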
@@ -390,6 +390,7 @@ def g_radd_year_prep():
              .updateMask(radd_date.lte(end))
              .gt(0)
              .rename("RADD_year_" + "20" + str(year))
+             .selfMask()
          )
          return ee.Image(img_stack).addBands(radd_year)
 
@@ -403,6 +404,7 @@ def g_radd_year_prep():
              .updateMask(radd_date.lte(end))
              .gt(0)
              .rename(band_name)
+             .selfMask()
          )
 
      def make_band(year, img_stack):
@@ -415,6 +417,7 @@ def g_radd_year_prep():
              .updateMask(radd_date.lte(end))
              .gt(0)
              .rename(band_name)
+             .selfMask()
          )
          return ee.Image(img_stack).addBands(radd_year)
 
@@ -431,7 +434,7 @@ def g_tmf_def_per_year_prep():
      for i in range(0, 24 + 1):
          year_num = ee.Number(2000 + i)
          band_name = ee.String("TMF_def_").cat(year_num.format("%d"))
-         tmf_def_year = tmf_def.eq(year_num).rename(band_name)
+         tmf_def_year = tmf_def.eq(year_num).rename(band_name).selfMask()
          if img_stack is None:
              img_stack = tmf_def_year
          else:
@@ -448,7 +451,7 @@ def g_tmf_deg_per_year_prep():
      for i in range(0, 24 + 1):
          year_num = ee.Number(2000 + i)
          band_name = ee.String("TMF_deg_").cat(year_num.format("%d"))
-         tmf_def_year = tmf_def.eq(year_num).rename(band_name)
+         tmf_def_year = tmf_def.eq(year_num).rename(band_name).selfMask()
          if img_stack is None:
              img_stack = tmf_def_year
          else:
@@ -468,7 +471,7 @@ def g_glad_gfc_loss_per_year_prep():
          gfc_loss_year = (
              gfc.select(["lossyear"]).eq(i).And(gfc.select(["treecover2000"]).gt(10))
          )
-         gfc_loss_year = gfc_loss_year.rename(band_name)
+         gfc_loss_year = gfc_loss_year.rename(band_name).selfMask()
          if img_stack is None:
              img_stack = gfc_loss_year
          else:
@@ -499,6 +502,7 @@ def g_modis_fire_prep():
              .select(["BurnDate"])
              .gte(0)
              .rename(band_name)
+             .selfMask()
          )
          img_stack = modis_year if img_stack is None else img_stack.addBands(modis_year)
 
@@ -528,6 +532,7 @@ def g_esa_fire_prep():
              .select(["BurnDate"])
              .gte(0)
              .rename(band_name)
+             .selfMask()
          )
          img_stack = esa_year if img_stack is None else img_stack.addBands(esa_year)
 
@@ -1155,10 +1160,55 @@ def nci_ocs2020_prep():
      ).selfMask()  # cocoa from national land cover map for Côte d'Ivoire
 
 
+ # ============================================================================
+ # CONTEXT BANDS (Administrative boundaries and water mask)
+ # ============================================================================
+
+
+ def g_gaul_admin_code():
+     """
+     GAUL 2024 Level 1 administrative boundary codes (500m resolution).
+     Used for spatial context and administrative aggregation.
+
+     Returns
+     -------
+     ee.Image
+         Image with admin codes renamed to 'admin_code' (as int32)
+     """
+     admin_image = ee.Image(
+         "projects/ee-andyarnellgee/assets/gaul_2024_level_1_code_500m"
+     )
+     # Cast to int32 to ensure integer GAUL codes, then rename
+     return admin_image.rename("admin_code")
+
+
+ def g_water_mask_prep():
+     """
+     Water mask from JRC/USGS combined dataset.
+     Used to identify water bodies for downstream filtering and context.
+
+     Multiplied by pixel area to get water area in hectares.
+
+     Returns
+     -------
+     ee.Image
+         Binary water mask image renamed to In_waterbody (will be multiplied by pixel area)
+     """
+     from openforis_whisp.parameters.config_runtime import water_flag
+
+     water_mask_image = ee.Image("projects/ee-andyarnellgee/assets/water_mask_jrc_usgs")
+     return water_mask_image.selfMask().rename(water_flag)
+
+
  ###Combining datasets
 
 
- def combine_datasets(national_codes=None, validate_bands=False):
+ def combine_datasets(
+     national_codes=None,
+     validate_bands=False,
+     include_context_bands=True,
+     auto_recovery=False,
+ ):
      """
      Combines datasets into a single multiband image, with fallback if assets are missing.
 
@@ -1169,48 +1219,76 @@ def combine_datasets(national_codes=None, validate_bands=False):
      validate_bands : bool, optional
          If True, validates band names with a slow .getInfo() call (default: False)
          Only enable for debugging. Normal operation relies on exception handling.
+     include_context_bands : bool, optional
+         If True (default), includes context bands (admin_code, water_flag) in the output.
+         Set to False when using stats.py implementations that compile datasets differently.
+     auto_recovery : bool, optional
+         If True (default), automatically enables validate_bands when an error is detected
+         during initial assembly. This allows graceful recovery from missing/broken datasets.
 
      Returns
      -------
      ee.Image
-         Combined multiband image with all datasets
+         Combined multiband image with all datasets (and optionally context bands)
      """
-     img_combined = ee.Image(1).rename(geometry_area_column)
-
-     # Combine images directly
-     for img in [func() for func in list_functions(national_codes=national_codes)]:
+     # Step 1: Combine all main dataset images
+     all_images = [ee.Image(1).rename(geometry_area_column)]
+     for func in list_functions(national_codes=national_codes):
          try:
-             img_combined = img_combined.addBands(img)
-             # img_combined = img_combined.addBands(img)
+             all_images.append(func())
          except ee.EEException as e:
-             # logger.error(f"Error adding image: {e}")
-             print(f"Error adding image: {e}")
+             print(f"Error loading image: {e}")
+
+     img_combined = ee.Image.cat(all_images)
 
-     # OPTIMIZATION: Removed slow .getInfo() call for band validation
-     # The validation is now optional and disabled by default
-     # Image processing will fail downstream if there's an issue, which is handled by exception blocks
-     if validate_bands:
+     # Step 2: Determine if validation needed
+     should_validate = validate_bands
+     if auto_recovery and not validate_bands:
          try:
-             # This is SLOW - only use for debugging
-             img_combined.bandNames().getInfo()
+             # Fast error detection: batch check main + context bands in one call
+             bands_to_check = [img_combined.bandNames().get(0)]
+             if include_context_bands:
+                 admin_image = g_gaul_admin_code()
+                 water_mask = g_water_mask_prep()
+                 bands_to_check.extend(
+                     [admin_image.bandNames().get(0), water_mask.bandNames().get(0)]
+                 )
+             ee.List(bands_to_check).getInfo()  # trigger error if any band is invalid
          except ee.EEException as e:
-             # logger.error(f"Error validating band names: {e}")
-             # logger.info("Running code for filtering to only valid datasets due to error in input")
-             print("using valid datasets filter due to error in validation")
-             # Validate images
-             images_to_test = [
-                 func() for func in list_functions(national_codes=national_codes)
-             ]
-             valid_imgs = keep_valid_images(images_to_test)  # Validate images
-
-             # Retry combining images after validation
-             img_combined = ee.Image(1).rename(geometry_area_column)
-             for img in valid_imgs:
-                 img_combined = img_combined.addBands(img)
+             print(f"Error detected, enabling recovery mode: {str(e)[:80]}...")
+             should_validate = True
 
+     # Step 3: Validate and recover if needed
+     if should_validate:
+         try:
+             img_combined.bandNames().getInfo()  # check all bands
+         except ee.EEException as e:
+             print("Using valid datasets filter due to error in validation")
+             valid_imgs = keep_valid_images(
+                 [func() for func in list_functions(national_codes=national_codes)]
+             )
+             all_images_retry = [ee.Image(1).rename(geometry_area_column)]
+             all_images_retry.extend(valid_imgs)
+             img_combined = ee.Image.cat(all_images_retry)
+
+     # Step 4: Multiply main datasets by pixel area
      img_combined = img_combined.multiply(ee.Image.pixelArea())
-     print("Whisp multiband image compiled")
 
+     # Step 5: Add context bands (admin_code only - water mask is now in prep functions)
+     if include_context_bands:
+         for band_func, band_name in [
+             (g_gaul_admin_code, "admin_code"),
+             (g_water_mask_prep, "In_waterbody"),
+         ]:
+             try:
+                 band_img = band_func()
+                 if should_validate:
+                     band_img.bandNames().getInfo()
+                 img_combined = img_combined.addBands(band_img)
+             except ee.EEException as e:
+                 print(f"Warning: Could not add {band_name} band: {e}")
+
+     print("Whisp multiband image compiled")
      return img_combined
 
 
@@ -1230,9 +1308,12 @@ def combine_datasets(national_codes=None, validate_bands=False):
  def list_functions(national_codes=None):
      """
      Returns a list of functions that end with "_prep" and either:
-     - Start with "g_" (global/regional products)
+     - Start with "g_" (global/regional products, excluding context bands)
      - Start with any provided national code prefix (nXX_)
 
+     Context band functions (g_gaul_admin_code, g_water_mask_prep) are handled
+     separately and excluded from this list to avoid duplication.
+
      Args:
          national_codes: List of ISO2 country codes (without the 'n' prefix)
      """
@@ -1243,15 +1324,19 @@
      if national_codes is None:
          national_codes = []
 
+     # Context band functions that are handled separately
+     context_functions = {"g_gaul_admin_code", "g_water_mask_prep"}
+
      # Create prefixes list with proper formatting ('n' + code + '_')
      allowed_prefixes = ["g_"] + [f"n{code.lower()}_" for code in national_codes]
 
-     # Filter functions in a single pass
+     # Filter functions in a single pass, excluding context band functions
      functions = [
          func
          for name, func in inspect.getmembers(current_module, inspect.isfunction)
          if name.endswith("_prep")
          and any(name.startswith(prefix) for prefix in allowed_prefixes)
+         and name not in context_functions
      ]
 
      return functions
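A usage sketch tying the pieces together, assuming an initialized Earth Engine session and access to the referenced assets; the behaviour described in the comments follows the code above:

# For national_codes=["ci"], allowed_prefixes becomes ["g_", "nci_"], so the
# g_*_prep and nci_*_prep functions are collected, while g_gaul_admin_code and
# g_water_mask_prep are excluded here and appended by combine_datasets instead.
img = combine_datasets(
    national_codes=["ci"],
    include_context_bands=True,  # append admin_code and the water band
    auto_recovery=True,          # fall back to valid datasets if an asset is broken
)
print(img.bandNames().getInfo())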
@@ -1335,3 +1420,6 @@ def combine_custom_bands(custom_images, custom_bands_info):
      custom_ee_image = custom_ee_image.multiply(ee.Image.pixelArea())
 
      return custom_ee_image  # Only return the image
+
+
+ # %%