openforis-whisp 2.0.0b3__py3-none-any.whl → 3.0.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,7 +4,7 @@ from shapely.geometry import shape
  from pathlib import Path

  # Existing imports
- from typing import List, Any
+ from typing import List, Any, Union
  from geojson import Feature, FeatureCollection, Polygon, Point
  import json
  import os
@@ -13,32 +13,32 @@ import ee


  def convert_geojson_to_ee(
-     geojson_filepath: Any, enforce_wgs84: bool = True, strip_z_coords: bool = True
+     geojson_filepath: Union[str, Path, dict],
+     enforce_wgs84: bool = True,
+     strip_z_coords: bool = True,
  ) -> ee.FeatureCollection:
      """
-     Reads a GeoJSON file from the given path and converts it to an Earth Engine FeatureCollection.
+     Converts GeoJSON data to an Earth Engine FeatureCollection.
+     Accepts either a file path or a GeoJSON dictionary object.
      Optionally checks and converts the CRS to WGS 84 (EPSG:4326) if needed.
      Automatically handles 3D coordinates by stripping Z values when necessary.

      Args:
-         geojson_filepath (Any): The filepath to the GeoJSON file.
+         geojson_filepath (Union[str, Path, dict]): The filepath to the GeoJSON file (str or Path)
+             or a GeoJSON dictionary object.
          enforce_wgs84 (bool): Whether to enforce WGS 84 projection (EPSG:4326). Defaults to True.
+             Only applies when input is a file path (dicts are assumed to be in WGS84).
          strip_z_coords (bool): Whether to automatically strip Z coordinates from 3D geometries. Defaults to True.

      Returns:
          ee.FeatureCollection: Earth Engine FeatureCollection created from the GeoJSON.
      """
-     if isinstance(geojson_filepath, (str, Path)):
+     if isinstance(geojson_filepath, dict):
+         # Input is already a GeoJSON dictionary - skip file reading
+         geojson_data = geojson_filepath
+     elif isinstance(geojson_filepath, (str, Path)):
          file_path = os.path.abspath(geojson_filepath)

-         # Apply print_once deduplication for file reading message
-         if not hasattr(convert_geojson_to_ee, "_printed_file_messages"):
-             convert_geojson_to_ee._printed_file_messages = set()
-
-         if file_path not in convert_geojson_to_ee._printed_file_messages:
-             print(f"Reading GeoJSON file from: {file_path}")
-             convert_geojson_to_ee._printed_file_messages.add(file_path)
-
          # Use GeoPandas to read the file and handle CRS
          gdf = gpd.read_file(file_path)

@@ -56,15 +56,17 @@ def convert_geojson_to_ee(
          # Check and convert CRS if needed
          if enforce_wgs84:
              if gdf.crs is None:
-                 print("Warning: Input GeoJSON has no CRS defined, assuming WGS 84")
+                 # Assuming WGS 84 if no CRS defined
+                 pass
              elif gdf.crs != "EPSG:4326":
-                 print(f"Converting CRS from {gdf.crs} to WGS 84 (EPSG:4326)")
                  gdf = gdf.to_crs("EPSG:4326")

          # Convert to GeoJSON
          geojson_data = json.loads(gdf.to_json())
      else:
-         raise ValueError("Input must be a file path (str or Path)")
+         raise ValueError(
+             "Input must be a file path (str or Path) or a GeoJSON dictionary object (dict)"
+         )

      validation_errors = validate_geojson(geojson_data)
      if validation_errors:
@@ -101,7 +103,7 @@ def convert_geojson_to_ee(

              success_message_key = f"z_coords_success_{file_path}"
              if success_message_key not in convert_geojson_to_ee._printed_z_messages:
-                 print("Successfully converted after stripping Z coordinates")
+                 print("Successfully converted after stripping Z coordinates")
                  convert_geojson_to_ee._printed_z_messages.add(success_message_key)

      return feature_collection
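
Since convert_geojson_to_ee now also accepts an in-memory GeoJSON dictionary, callers no longer need to write a temporary file first. A minimal usage sketch, assuming an authenticated and initialized Earth Engine session; the import path is assumed, not shown in this diff:

import ee
from openforis_whisp import convert_geojson_to_ee  # import path assumed

ee.Initialize()

# From a file path: the CRS is checked and reprojected to EPSG:4326 if needed
fc_from_file = convert_geojson_to_ee("plots.geojson")

# From an in-memory dict: assumed to already be in WGS 84, no file I/O
geojson_dict = {
    "type": "FeatureCollection",
    "features": [
        {
            "type": "Feature",
            "properties": {"id": 1},
            "geometry": {"type": "Point", "coordinates": [15.3, -4.4]},
        }
    ],
}
fc_from_dict = convert_geojson_to_ee(geojson_dict)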
@@ -250,6 +252,58 @@ def convert_shapefile_to_ee(shapefile_path):
      return roi


+ # def convert_ee_to_df(
+ #     ee_object,
+ #     columns=None,
+ #     remove_geom=False,
+ #     **kwargs,
+ # ):
+ #     """Converts an ee.FeatureCollection to pandas dataframe.
+
+ #     Args:
+ #         ee_object (ee.FeatureCollection): ee.FeatureCollection.
+ #         columns (list): List of column names. Defaults to None.
+ #         remove_geom (bool): Whether to remove the geometry column. Defaults to True.
+ #         kwargs: Additional arguments passed to ee.data.computeFeature.
+
+ #     Raises:
+ #         TypeError: ee_object must be an ee.FeatureCollection
+
+ #     Returns:
+ #         pd.DataFrame: pandas DataFrame
+ #     """
+ #     if isinstance(ee_object, ee.Feature):
+ #         ee_object = ee.FeatureCollection([ee_object])
+
+ #     if not isinstance(ee_object, ee.FeatureCollection):
+ #         raise TypeError("ee_object must be an ee.FeatureCollection")
+
+ #     try:
+ #         if remove_geom:
+ #             data = ee_object.map(
+ #                 lambda f: ee.Feature(None, f.toDictionary(f.propertyNames().sort()))
+ #             )
+ #         else:
+ #             data = ee_object
+
+ #         kwargs["expression"] = data
+ #         kwargs["fileFormat"] = "PANDAS_DATAFRAME"
+
+ #         df = ee.data.computeFeatures(kwargs)
+
+ #         if isinstance(columns, list):
+ #             df = df[columns]
+
+ #         if remove_geom and ("geometry" in df.columns):
+ #             df = df.drop(columns=["geometry"], axis=1)
+
+ #         # Sorting columns is not supported server-side and is removed from this function.
+
+ #         return df
+ #     except Exception as e:
+ #         raise Exception(e)
+
+
  def convert_ee_to_df(
      ee_object,
      columns=None,
@@ -257,49 +311,37 @@ def convert_ee_to_df(
      sort_columns=False,
      **kwargs,
  ):
-     """Converts an ee.FeatureCollection to pandas dataframe.
+     """
+     Converts an ee.FeatureCollection to pandas DataFrame, maximizing server-side operations.

      Args:
          ee_object (ee.FeatureCollection): ee.FeatureCollection.
-         columns (list): List of column names. Defaults to None.
-         remove_geom (bool): Whether to remove the geometry column. Defaults to True.
-         sort_columns (bool): Whether to sort the column names. Defaults to False.
-         kwargs: Additional arguments passed to ee.data.computeFeature.
-
-     Raises:
-         TypeError: ee_object must be an ee.FeatureCollection
+         columns (list): List of column names to select (server-side if possible).
+         remove_geom (bool): Remove geometry column server-side.
+         kwargs: Additional arguments for ee.data.computeFeatures.

      Returns:
          pd.DataFrame: pandas DataFrame
      """
+     import ee
+
      if isinstance(ee_object, ee.Feature):
          ee_object = ee.FeatureCollection([ee_object])

      if not isinstance(ee_object, ee.FeatureCollection):
          raise TypeError("ee_object must be an ee.FeatureCollection")

-     try:
-         if remove_geom:
-             data = ee_object.map(
-                 lambda f: ee.Feature(None, f.toDictionary(f.propertyNames().sort()))
-             )
-         else:
-             data = ee_object
+     # Server-side: select columns and remove geometry
+     if columns is not None:
+         ee_object = ee_object.select(columns)
+     if remove_geom:
+         ee_object = ee_object.map(lambda f: ee.Feature(None, f.toDictionary()))

-         kwargs["expression"] = data
+     try:
+         kwargs["expression"] = ee_object
          kwargs["fileFormat"] = "PANDAS_DATAFRAME"
-
          df = ee.data.computeFeatures(kwargs)

-         if isinstance(columns, list):
-             df = df[columns]
-
-         if remove_geom and ("geometry" in df.columns):
-             df = df.drop(columns=["geometry"], axis=1)
-
-         if sort_columns:
-             df = df.reindex(sorted(df.columns), axis=1)
-
          return df
      except Exception as e:
          raise Exception(e)
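
The rewritten convert_ee_to_df pushes column selection and geometry removal to the server before a single computeFeatures call. A hedged usage sketch, assuming an initialized Earth Engine session and that the function is importable from the package:

import ee

ee.Initialize()

fc = ee.FeatureCollection("FAO/GAUL/2015/level0").limit(5)

# Column selection and geometry removal now happen server-side,
# so only the requested properties are transferred by computeFeatures.
df = convert_ee_to_df(fc, columns=["ADM0_NAME"], remove_geom=True)
print(df.head())

Note that sort_columns remains in the signature in this version but is no longer applied to the returned DataFrame.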
@@ -443,7 +485,7 @@ def convert_csv_to_geojson(
      try:
          df = pd.read_csv(csv_filepath)

-         df_to_geojson(df, geojson_filepath, geo_column)
+         convert_df_to_geojson(df, geojson_filepath, geo_column)

      except Exception as e:
          print(f"An error occurred while converting CSV to GeoJSON: {e}")
@@ -177,7 +177,7 @@ def g_jrc_tmf_plantation_prep():
      plantation_2020 = plantation.where(
          deforestation_year.gte(2021), 0
      )  # update from https://github.com/forestdatapartnership/whisp/issues/42
-     return plantation_2020.rename("TMF_plant")
+     return plantation_2020.rename("TMF_plant").selfMask()


  # # Oil_palm_Descals
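
This is the first of many hunks in this module that append .selfMask() to a binary layer; the call masks out zero-valued pixels so only "presence" pixels remain unmasked for the reducers run later. A small illustrative sketch of the pattern, using a public asset for illustration rather than one of the Whisp layers (asset ID assumed current):

import ee

ee.Initialize()

# Hypothetical binary layer: 1 where any forest loss was mapped, 0 elsewhere
gfc = ee.Image("UMD/hansen/global_forest_change_2023_v1_11")
loss_flag = gfc.select("lossyear").gt(0).rename("loss_flag")

# selfMask() turns the zeros into masked pixels, so statistics and the later
# pixelArea() multiplication only operate on presence pixels.
loss_presence_only = loss_flag.selfMask()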
@@ -390,6 +390,7 @@ def g_radd_year_prep():
              .updateMask(radd_date.lte(end))
              .gt(0)
              .rename("RADD_year_" + "20" + str(year))
+             .selfMask()
          )
          return ee.Image(img_stack).addBands(radd_year)

@@ -403,6 +404,7 @@ def g_radd_year_prep():
              .updateMask(radd_date.lte(end))
              .gt(0)
              .rename(band_name)
+             .selfMask()
          )

      def make_band(year, img_stack):
@@ -415,6 +417,7 @@ def g_radd_year_prep():
              .updateMask(radd_date.lte(end))
              .gt(0)
              .rename(band_name)
+             .selfMask()
          )
          return ee.Image(img_stack).addBands(radd_year)

@@ -431,7 +434,7 @@ def g_tmf_def_per_year_prep():
      for i in range(0, 24 + 1):
          year_num = ee.Number(2000 + i)
          band_name = ee.String("TMF_def_").cat(year_num.format("%d"))
-         tmf_def_year = tmf_def.eq(year_num).rename(band_name)
+         tmf_def_year = tmf_def.eq(year_num).rename(band_name).selfMask()
          if img_stack is None:
              img_stack = tmf_def_year
          else:
@@ -448,7 +451,7 @@ def g_tmf_deg_per_year_prep():
      for i in range(0, 24 + 1):
          year_num = ee.Number(2000 + i)
          band_name = ee.String("TMF_deg_").cat(year_num.format("%d"))
-         tmf_def_year = tmf_def.eq(year_num).rename(band_name)
+         tmf_def_year = tmf_def.eq(year_num).rename(band_name).selfMask()
          if img_stack is None:
              img_stack = tmf_def_year
          else:
@@ -468,7 +471,7 @@ def g_glad_gfc_loss_per_year_prep():
          gfc_loss_year = (
              gfc.select(["lossyear"]).eq(i).And(gfc.select(["treecover2000"]).gt(10))
          )
-         gfc_loss_year = gfc_loss_year.rename(band_name)
+         gfc_loss_year = gfc_loss_year.rename(band_name).selfMask()
          if img_stack is None:
              img_stack = gfc_loss_year
          else:
@@ -499,6 +502,7 @@ def g_modis_fire_prep():
              .select(["BurnDate"])
              .gte(0)
              .rename(band_name)
+             .selfMask()
          )
          img_stack = modis_year if img_stack is None else img_stack.addBands(modis_year)

@@ -528,6 +532,7 @@ def g_esa_fire_prep():
              .select(["BurnDate"])
              .gte(0)
              .rename(band_name)
+             .selfMask()
          )
          img_stack = esa_year if img_stack is None else img_stack.addBands(esa_year)

@@ -1155,10 +1160,55 @@ def nci_ocs2020_prep():
      ).selfMask()  # cocoa from national land cover map for Côte d'Ivoire


+ # ============================================================================
+ # CONTEXT BANDS (Administrative boundaries and water mask)
+ # ============================================================================
+
+
+ def g_gaul_admin_code():
+     """
+     GAUL 2024 Level 1 administrative boundary codes (500m resolution).
+     Used for spatial context and administrative aggregation.
+
+     Returns
+     -------
+     ee.Image
+         Image with admin codes renamed to 'admin_code' (as int32)
+     """
+     admin_image = ee.Image(
+         "projects/ee-andyarnellgee/assets/gaul_2024_level_1_code_500m"
+     )
+     # Cast to int32 to ensure integer GAUL codes, then rename
+     return admin_image.rename("admin_code")
+
+
+ def g_water_mask_prep():
+     """
+     Water mask from JRC/USGS combined dataset.
+     Used to identify water bodies for downstream filtering and context.
+
+     Multiplied by pixel area to get water area in hectares.
+
+     Returns
+     -------
+     ee.Image
+         Binary water mask image renamed to In_waterbody (will be multiplied by pixel area)
+     """
+     from openforis_whisp.parameters.config_runtime import water_flag
+
+     water_mask_image = ee.Image("projects/ee-andyarnellgee/assets/water_mask_jrc_usgs")
+     return water_mask_image.selfMask().rename(water_flag)
+
+
  ###Combining datasets


- def combine_datasets(national_codes=None, validate_bands=False):
+ def combine_datasets(
+     national_codes=None,
+     validate_bands=False,
+     include_context_bands=True,
+     auto_recovery=False,
+ ):
      """
      Combines datasets into a single multiband image, with fallback if assets are missing.

@@ -1169,48 +1219,76 @@ def combine_datasets(national_codes=None, validate_bands=False):
      validate_bands : bool, optional
          If True, validates band names with a slow .getInfo() call (default: False)
          Only enable for debugging. Normal operation relies on exception handling.
+     include_context_bands : bool, optional
+         If True (default), includes context bands (admin_code, water_flag) in the output.
+         Set to False when using stats.py implementations that compile datasets differently.
+     auto_recovery : bool, optional
+         If True (default), automatically enables validate_bands when an error is detected
+         during initial assembly. This allows graceful recovery from missing/broken datasets.

      Returns
      -------
      ee.Image
-         Combined multiband image with all datasets
+         Combined multiband image with all datasets (and optionally context bands)
      """
-     img_combined = ee.Image(1).rename(geometry_area_column)
-
-     # Combine images directly
-     for img in [func() for func in list_functions(national_codes=national_codes)]:
+     # Step 1: Combine all main dataset images
+     all_images = [ee.Image(1).rename(geometry_area_column)]
+     for func in list_functions(national_codes=national_codes):
          try:
-             img_combined = img_combined.addBands(img)
-             # img_combined = img_combined.addBands(img)
+             all_images.append(func())
          except ee.EEException as e:
-             # logger.error(f"Error adding image: {e}")
-             print(f"Error adding image: {e}")
+             print(f"Error loading image: {e}")
+
+     img_combined = ee.Image.cat(all_images)

-     # OPTIMIZATION: Removed slow .getInfo() call for band validation
-     # The validation is now optional and disabled by default
-     # Image processing will fail downstream if there's an issue, which is handled by exception blocks
-     if validate_bands:
+     # Step 2: Determine if validation needed
+     should_validate = validate_bands
+     if auto_recovery and not validate_bands:
          try:
-             # This is SLOW - only use for debugging
-             img_combined.bandNames().getInfo()
+             # Fast error detection: batch check main + context bands in one call
+             bands_to_check = [img_combined.bandNames().get(0)]
+             if include_context_bands:
+                 admin_image = g_gaul_admin_code()
+                 water_mask = g_water_mask_prep()
+                 bands_to_check.extend(
+                     [admin_image.bandNames().get(0), water_mask.bandNames().get(0)]
+                 )
+             ee.List(bands_to_check).getInfo()  # trigger error if any band is invalid
          except ee.EEException as e:
-             # logger.error(f"Error validating band names: {e}")
-             # logger.info("Running code for filtering to only valid datasets due to error in input")
-             print("using valid datasets filter due to error in validation")
-             # Validate images
-             images_to_test = [
-                 func() for func in list_functions(national_codes=national_codes)
-             ]
-             valid_imgs = keep_valid_images(images_to_test)  # Validate images
-
-             # Retry combining images after validation
-             img_combined = ee.Image(1).rename(geometry_area_column)
-             for img in valid_imgs:
-                 img_combined = img_combined.addBands(img)
+             print(f"Error detected, enabling recovery mode: {str(e)[:80]}...")
+             should_validate = True

+     # Step 3: Validate and recover if needed
+     if should_validate:
+         try:
+             img_combined.bandNames().getInfo()  # check all bands
+         except ee.EEException as e:
+             print("Using valid datasets filter due to error in validation")
+             valid_imgs = keep_valid_images(
+                 [func() for func in list_functions(national_codes=national_codes)]
+             )
+             all_images_retry = [ee.Image(1).rename(geometry_area_column)]
+             all_images_retry.extend(valid_imgs)
+             img_combined = ee.Image.cat(all_images_retry)
+
+     # Step 4: Multiply main datasets by pixel area
      img_combined = img_combined.multiply(ee.Image.pixelArea())
-     print("Whisp multiband image compiled")

+     # Step 5: Add context bands (admin_code only - water mask is now in prep functions)
+     if include_context_bands:
+         for band_func, band_name in [
+             (g_gaul_admin_code, "admin_code"),
+             (g_water_mask_prep, "In_waterbody"),
+         ]:
+             try:
+                 band_img = band_func()
+                 if should_validate:
+                     band_img.bandNames().getInfo()
+                 img_combined = img_combined.addBands(band_img)
+             except ee.EEException as e:
+                 print(f"Warning: Could not add {band_name} band: {e}")
+
+     print("Whisp multiband image compiled")
      return img_combined


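
combine_datasets now assembles the stack with a single ee.Image.cat() call, optionally appends the two context bands after the pixel-area scaling, and can switch itself into the validation/recovery path when auto_recovery is set. A hedged usage sketch; the import path and the national code shown are assumptions, and the context-band assets must be readable from your Earth Engine account:

import ee

ee.Initialize()

# Default behaviour: global datasets plus admin_code and In_waterbody context bands
img = combine_datasets(national_codes=["ci"])  # "ci" assumed as an available national code
print(img.bandNames().getInfo())

# Leaner stack for stats implementations that compile datasets themselves,
# with recovery enabled in case an asset is missing or broken
img_plain = combine_datasets(include_context_bands=False, auto_recovery=True)

# Context bands are added after the pixelArea() multiplication,
# so admin_code keeps its original (non area-scaled) values
admin = img.select("admin_code")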
@@ -1230,9 +1308,12 @@ def combine_datasets(national_codes=None, validate_bands=False):
  def list_functions(national_codes=None):
      """
      Returns a list of functions that end with "_prep" and either:
-     - Start with "g_" (global/regional products)
+     - Start with "g_" (global/regional products, excluding context bands)
      - Start with any provided national code prefix (nXX_)

+     Context band functions (g_gaul_admin_code, g_water_mask_prep) are handled
+     separately and excluded from this list to avoid duplication.
+
      Args:
          national_codes: List of ISO2 country codes (without the 'n' prefix)
      """
@@ -1243,15 +1324,19 @@ def list_functions(national_codes=None):
      if national_codes is None:
          national_codes = []

+     # Context band functions that are handled separately
+     context_functions = {"g_gaul_admin_code", "g_water_mask_prep"}
+
      # Create prefixes list with proper formatting ('n' + code + '_')
      allowed_prefixes = ["g_"] + [f"n{code.lower()}_" for code in national_codes]

-     # Filter functions in a single pass
+     # Filter functions in a single pass, excluding context band functions
      functions = [
          func
          for name, func in inspect.getmembers(current_module, inspect.isfunction)
          if name.endswith("_prep")
          and any(name.startswith(prefix) for prefix in allowed_prefixes)
+         and name not in context_functions
      ]

      return functions
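
With the exclusion set in place, list_functions returns only the dataset prep functions; the context-band helpers are appended later by combine_datasets. A brief sketch of the expected behaviour, assuming the module is importable and "ci" is an available national code:

# Hypothetical check: context-band prep functions are no longer in the list
funcs = list_functions(national_codes=["ci"])
names = [f.__name__ for f in funcs]
assert "g_water_mask_prep" not in names and "g_gaul_admin_code" not in names
assert all(n.endswith("_prep") and (n.startswith("g_") or n.startswith("nci_")) for n in names)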
@@ -1335,3 +1420,6 @@ def combine_custom_bands(custom_images, custom_bands_info):
      custom_ee_image = custom_ee_image.multiply(ee.Image.pixelArea())

      return custom_ee_image  # Only return the image
+
+
+ # %%
openforis_whisp/logger.py CHANGED
@@ -34,6 +34,19 @@ class StdoutLogger:
      def setLevel(self, level):
          self.logger.setLevel(level)

+     @property
+     def level(self):
+         """Return the logger's effective level."""
+         return self.logger.level
+
+     def hasHandlers(self):
+         """Check if the logger has any handlers."""
+         return self.logger.hasHandlers()
+
+     def addHandler(self, handler):
+         """Add a handler to the logger."""
+         self.logger.addHandler(handler)
+

  class FileLogger:
      def __init__(
@@ -73,3 +86,16 @@ class FileLogger:

      def setLevel(self, level):
          self.logger.setLevel(level)
+
+     @property
+     def level(self):
+         """Return the logger's effective level."""
+         return self.logger.level
+
+     def hasHandlers(self):
+         """Check if the logger has any handlers."""
+         return self.logger.hasHandlers()
+
+     def addHandler(self, handler):
+         """Add a handler to the logger."""
+         self.logger.addHandler(handler)
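
The added level property and the hasHandlers/addHandler pass-throughs let StdoutLogger and FileLogger stand in for a plain logging.Logger wherever only that subset of the API is used. A minimal sketch; the wrapper constructors are not shown in this diff, so the call below is hypothetical:

import logging
from openforis_whisp.logger import StdoutLogger

log = StdoutLogger()  # hypothetical: pass whatever arguments the constructor actually requires

# Code written against logging.Logger-style objects now works unchanged
if not log.hasHandlers():
    handler = logging.StreamHandler()
    handler.setLevel(log.level)
    log.addHandler(handler)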