geoai-py 0.26.0__py2.py3-none-any.whl → 0.28.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/prithvi.py CHANGED
@@ -28,6 +28,16 @@ NO_DATA_FLOAT = 0.0001
28
28
  OFFSET = 0
29
29
  PERCENTILE = 99.9
30
30
 
31
+ # Available Prithvi models
32
+ AVAILABLE_MODELS = [
33
+ "Prithvi-EO-2.0-tiny-TL", # tiny transfer learning, embed_dim=192, depth=12, with coords
34
+ "Prithvi-EO-2.0-100M-TL", # 100M transfer learning, embed_dim=768, depth=12, with coords
35
+ "Prithvi-EO-2.0-300M", # 300M base model, embed_dim=1024, depth=24, no coords
36
+ "Prithvi-EO-2.0-300M-TL", # 300M transfer learning, embed_dim=768, depth=12, with coords
37
+ "Prithvi-EO-2.0-600M", # 600M base model, embed_dim=1280, depth=32, no coords
38
+ "Prithvi-EO-2.0-600M-TL", # 600M transfer learning, embed_dim=1280, depth=32, with coords
39
+ ]
40
+
31
41
 
32
42
  def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
33
43
  """Create 3D sin/cos positional embeddings.
@@ -622,8 +632,22 @@ class PrithviMAE(nn.Module):
622
632
  class PrithviProcessor:
623
633
  """Prithvi EO 2.0 processor with GeoTIFF input/output support.
624
634
 
625
- https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL
626
- https://github.com/NASA-IMPACT/Prithvi-EO-2.0
635
+ Supports multiple model variants:
636
+ - Prithvi-EO-2.0-tiny-TL (tiny transfer learning)
637
+ - Prithvi-EO-2.0-100M-TL (100M transfer learning)
638
+ - Prithvi-EO-2.0-300M (300M base model)
639
+ - Prithvi-EO-2.0-300M-TL (300M transfer learning)
640
+ - Prithvi-EO-2.0-600M (600M base model)
641
+ - Prithvi-EO-2.0-600M-TL (600M transfer learning)
642
+
643
+ References:
644
+ - tiny-TL: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-tiny-TL
645
+ - 100M-TL: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-100M-TL
646
+ - 300M: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M
647
+ - 300M-TL: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL
648
+ - 600M: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M
649
+ - 600M-TL: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M-TL
650
+ - GitHub: https://github.com/NASA-IMPACT/Prithvi-EO-2.0
627
651
  """
628
652
 
629
653
  def __init__(
@@ -637,7 +661,14 @@ class PrithviProcessor:
637
661
  """Initialize Prithvi processor.
638
662
 
639
663
  Args:
640
- model_name: Name of the Prithvi model to download from HuggingFace Hub
664
+ model_name: Name of the Prithvi model to download from HuggingFace Hub.
665
+ Options:
666
+ - "Prithvi-EO-2.0-tiny-TL" (tiny, 192 dim, 12 layers)
667
+ - "Prithvi-EO-2.0-100M-TL" (100M, 768 dim, 12 layers)
668
+ - "Prithvi-EO-2.0-300M" (base, 1024 dim, 24 layers)
669
+ - "Prithvi-EO-2.0-300M-TL" (default, 768 dim, 12 layers)
670
+ - "Prithvi-EO-2.0-600M" (base, 1280 dim, 32 layers)
671
+ - "Prithvi-EO-2.0-600M-TL" (1280 dim, 32 layers)
641
672
  config_path: Path to config file (optional, downloads if not provided)
642
673
  checkpoint_path: Path to checkpoint file (optional, downloads if not provided)
643
674
  device: Torch device to use
@@ -679,7 +710,13 @@ class PrithviProcessor:
679
710
  """Download Prithvi model from HuggingFace Hub.
680
711
 
681
712
  Args:
682
- model_name: Name of the model
713
+ model_name: Name of the model. Options:
714
+ - "Prithvi-EO-2.0-tiny-TL"
715
+ - "Prithvi-EO-2.0-100M-TL"
716
+ - "Prithvi-EO-2.0-300M" (base model)
717
+ - "Prithvi-EO-2.0-300M-TL" (default)
718
+ - "Prithvi-EO-2.0-600M" (base model)
719
+ - "Prithvi-EO-2.0-600M-TL"
683
720
  cache_dir: Directory to cache files
684
721
 
685
722
  Returns:
@@ -773,7 +810,7 @@ class PrithviProcessor:
773
810
  meta = src.meta
774
811
  try:
775
812
  coords = src.tags()
776
- except:
813
+ except Exception:
777
814
  coords = None
778
815
 
779
816
  return img, meta, coords
@@ -1208,6 +1245,20 @@ class PrithviProcessor:
1208
1245
  dest.write(image[i], i + 1)
1209
1246
 
1210
1247
 
1248
+ def get_available_prithvi_models() -> List[str]:
1249
+ """Get list of available Prithvi model names.
1250
+
1251
+ Returns:
1252
+ List of available model names
1253
+
1254
+ Example:
1255
+ >>> models = get_available_prithvi_models()
1256
+ >>> print(models)
1257
+ ['Prithvi-EO-2.0-tiny-TL', 'Prithvi-EO-2.0-100M-TL', 'Prithvi-EO-2.0-300M', 'Prithvi-EO-2.0-300M-TL', 'Prithvi-EO-2.0-600M', 'Prithvi-EO-2.0-600M-TL']
1258
+ """
1259
+ return AVAILABLE_MODELS.copy()
1260
+
1261
+
1211
1262
  def load_prithvi_model(
1212
1263
  model_name: str = "Prithvi-EO-2.0-300M-TL",
1213
1264
  device: Optional[str] = None,
@@ -1216,12 +1267,32 @@ def load_prithvi_model(
1216
1267
  """Load Prithvi model (convenience function).
1217
1268
 
1218
1269
  Args:
1219
- model_name: Name of the model
1270
+ model_name: Name of the model. Options:
1271
+ - "Prithvi-EO-2.0-tiny-TL"
1272
+ - "Prithvi-EO-2.0-100M-TL"
1273
+ - "Prithvi-EO-2.0-300M" (base)
1274
+ - "Prithvi-EO-2.0-300M-TL" (default)
1275
+ - "Prithvi-EO-2.0-600M" (base)
1276
+ - "Prithvi-EO-2.0-600M-TL"
1220
1277
  device: Device to use ('cuda' or 'cpu')
1221
1278
  cache_dir: Cache directory
1222
1279
 
1223
1280
  Returns:
1224
1281
  PrithviProcessor instance
1282
+
1283
+ Example:
1284
+ >>> # Load tiny-TL model
1285
+ >>> processor = load_prithvi_model("Prithvi-EO-2.0-tiny-TL")
1286
+ >>> # Load 100M-TL model
1287
+ >>> processor = load_prithvi_model("Prithvi-EO-2.0-100M-TL")
1288
+ >>> # Load 300M base model
1289
+ >>> processor = load_prithvi_model("Prithvi-EO-2.0-300M")
1290
+ >>> # Load 300M-TL model
1291
+ >>> processor = load_prithvi_model("Prithvi-EO-2.0-300M-TL")
1292
+ >>> # Load 600M base model
1293
+ >>> processor = load_prithvi_model("Prithvi-EO-2.0-600M")
1294
+ >>> # Load 600M-TL model
1295
+ >>> processor = load_prithvi_model("Prithvi-EO-2.0-600M-TL")
1225
1296
  """
1226
1297
  if device is not None:
1227
1298
  device = torch.device(device)
@@ -1245,9 +1316,23 @@ def prithvi_inference(
1245
1316
  Args:
1246
1317
  file_paths: List of input GeoTIFF files
1247
1318
  output_dir: Output directory
1248
- model_name: Name of the model
1319
+ model_name: Name of the model. Options:
1320
+ - "Prithvi-EO-2.0-tiny-TL"
1321
+ - "Prithvi-EO-2.0-100M-TL"
1322
+ - "Prithvi-EO-2.0-300M" (base)
1323
+ - "Prithvi-EO-2.0-300M-TL" (default)
1324
+ - "Prithvi-EO-2.0-600M" (base)
1325
+ - "Prithvi-EO-2.0-600M-TL"
1249
1326
  mask_ratio: Optional mask ratio
1250
1327
  device: Device to use
1328
+
1329
+ Example:
1330
+ >>> # Use tiny-TL model
1331
+ >>> prithvi_inference(
1332
+ ... file_paths=["img1.tif", "img2.tif", "img3.tif", "img4.tif"],
1333
+ ... model_name="Prithvi-EO-2.0-tiny-TL",
1334
+ ... output_dir="output_tiny"
1335
+ ... )
1251
1336
  """
1252
1337
  processor = load_prithvi_model(model_name, device)
1253
1338
  processor.process_files(file_paths, output_dir, mask_ratio)
geoai/sam.py CHANGED
@@ -5,7 +5,6 @@ The SamGeo class provides an interface for segmenting geospatial data using the
5
5
  import os
6
6
  from typing import Any, Dict, List, Optional, Tuple, Union
7
7
 
8
- import cv2
9
8
  import numpy as np
10
9
  import torch
11
10
  from leafmap import array_to_image, blend_images
@@ -125,6 +124,7 @@ class SamGeo:
125
124
  Raises:
126
125
  ValueError: If the input source is not a valid path or numpy array.
127
126
  """
127
+ import cv2 # Lazy import to avoid QGIS opencv conflicts
128
128
 
129
129
  if isinstance(source, str):
130
130
  if source.startswith("http"):
@@ -399,6 +399,7 @@ class SamGeo:
399
399
  Raises:
400
400
  ValueError: If no masks are available and `save_masks()` cannot generate them.
401
401
  """
402
+ import cv2 # Lazy import to avoid QGIS opencv conflicts
402
403
  import matplotlib.pyplot as plt
403
404
 
404
405
  if self.batch:
geoai/segment.py CHANGED
@@ -4,7 +4,6 @@ import os
4
4
  from dataclasses import dataclass
5
5
  from typing import Any, Dict, List, Optional, Tuple, Union
6
6
 
7
- import cv2
8
7
  import geopandas as gpd
9
8
  import numpy as np
10
9
  import rasterio
@@ -174,6 +173,8 @@ class GroundedSAM:
174
173
  Returns:
175
174
  List[DetectionResult]: Filtered detection results.
176
175
  """
176
+ import cv2 # Lazy import to avoid QGIS opencv conflicts
177
+
177
178
  if not detections:
178
179
  return detections
179
180
 
@@ -235,6 +236,8 @@ class GroundedSAM:
235
236
 
236
237
  def _mask_to_polygon(self, mask: np.ndarray) -> List[List[int]]:
237
238
  """Convert mask to polygon coordinates."""
239
+ import cv2 # Lazy import to avoid QGIS opencv conflicts
240
+
238
241
  # Find contours in the binary mask
239
242
  contours, _ = cv2.findContours(
240
243
  mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
@@ -255,6 +258,8 @@ class GroundedSAM:
255
258
  self, polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]
256
259
  ) -> np.ndarray:
257
260
  """Convert polygon to mask."""
261
+ import cv2 # Lazy import to avoid QGIS opencv conflicts
262
+
258
263
  # Create an empty mask
259
264
  mask = np.zeros(image_shape, dtype=np.uint8)
260
265
 
@@ -279,6 +284,8 @@ class GroundedSAM:
279
284
  Returns:
280
285
  List[np.ndarray]: List of individual instance masks.
281
286
  """
287
+ import cv2 # Lazy import to avoid QGIS opencv conflicts
288
+
282
289
  # Find connected components
283
290
  num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
284
291
  mask.astype(np.uint8), connectivity=8
@@ -320,6 +327,8 @@ class GroundedSAM:
320
327
  Returns:
321
328
  List[Dict]: List of polygon dictionaries with geometry and properties.
322
329
  """
330
+ import cv2 # Lazy import to avoid QGIS opencv conflicts
331
+
323
332
  polygons = []
324
333
 
325
334
  # Get individual instances