segment-geospatial 1.1.0__py2.py3-none-any.whl → 1.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
samgeo/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  __author__ = """Qiusheng Wu"""
4
4
  __email__ = "giswqs@gmail.com"
5
- __version__ = "1.1.0"
5
+ __version__ = "1.2.1"
6
6
 
7
7
 
8
8
  from .samgeo import * # noqa: F403
samgeo/common.py CHANGED
@@ -4,7 +4,7 @@ The source code is adapted from https://github.com/aliaksandr960/segment-anythin
4
4
 
5
5
  import os
6
6
  import tempfile
7
- from typing import Any, List, Optional, Tuple, Union
7
+ from typing import Any, Dict, List, Optional, Tuple, Union
8
8
 
9
9
  import cv2
10
10
  import geopandas as gpd
@@ -4055,3 +4055,220 @@ def get_device() -> torch.device:
4055
4055
  return torch.device("mps")
4056
4056
  else:
4057
4057
  return torch.device("cpu")
4058
+
4059
+
4060
def get_raster_info(raster_path: str) -> Dict[str, Any]:
    """Collect basic metadata and per-band statistics for a raster dataset.

    Args:
        raster_path (str): Path to the raster file.

    Returns:
        dict: Raster metadata under the keys ``driver``, ``width``, ``height``,
            ``count``, ``dtype`` (dtype of the first band), ``crs``,
            ``transform``, ``bounds``, ``resolution`` and ``nodata``, plus
            ``band_stats``: a list with one dict per band holding ``band``,
            ``min``, ``max``, ``mean`` and ``std``.
    """
    with rasterio.open(raster_path) as dataset:
        affine = dataset.transform
        crs_text = dataset.crs.to_string() if dataset.crs else "No CRS defined"

        def _summarize(band_index: int) -> Dict[str, Any]:
            # Bands are read as masked arrays (masked=True) before the
            # statistics are taken.
            values = dataset.read(band_index, masked=True)
            return {
                "band": band_index,
                "min": float(values.min()),
                "max": float(values.max()),
                "mean": float(values.mean()),
                "std": float(values.std()),
            }

        info = {
            "driver": dataset.driver,
            "width": dataset.width,
            "height": dataset.height,
            "count": dataset.count,
            "dtype": dataset.dtypes[0],
            "crs": crs_text,
            "transform": affine,
            "bounds": dataset.bounds,
            # Pixel size taken from the affine coefficients; the fifth
            # coefficient is negated.
            "resolution": (affine[0], -affine[4]),
            "nodata": dataset.nodata,
        }

        # Band indexes are 1-based.
        info["band_stats"] = [
            _summarize(band_index) for band_index in range(1, dataset.count + 1)
        ]

    return info
4101
+
4102
+
4103
def get_raster_stats(raster_path: str, divide_by: float = 1.0) -> Dict[str, Any]:
    """Compute min, max, mean and standard deviation for every raster band.

    Each band is read as a masked array (``masked=True``); every statistic is
    converted to ``float`` and then divided by ``divide_by``.

    Args:
        raster_path (str): Path to the raster file.
        divide_by (float, optional): Value each statistic is divided by.
            Defaults to 1.0, which leaves the values unchanged.

    Returns:
        dict: Dictionary of per-band lists (in band order) with keys:
            - 'min': Minimum value of each band
            - 'max': Maximum value of each band
            - 'mean': Mean value of each band
            - 'std': Standard deviation of each band
    """
    minima: List[float] = []
    maxima: List[float] = []
    means: List[float] = []
    deviations: List[float] = []

    with rasterio.open(raster_path) as dataset:
        # Band indexes are 1-based.
        for band_index in range(1, dataset.count + 1):
            values = dataset.read(band_index, masked=True)
            minima.append(float(values.min()) / divide_by)
            maxima.append(float(values.max()) / divide_by)
            means.append(float(values.mean()) / divide_by)
            deviations.append(float(values.std()) / divide_by)

    return {"min": minima, "max": maxima, "mean": means, "std": deviations}
4138
+
4139
+
4140
def print_raster_info(
    raster_path: str, show_preview: bool = True, figsize: Tuple[int, int] = (10, 8)
) -> Optional[Dict[str, Any]]:
    """Print formatted information about a raster dataset and optionally show a preview.

    The preview is drawn on an explicitly created Axes. For rasters with fewer
    than three bands the first band is rendered through ``rasterio.plot.show``
    with ``ax=`` supplied, so that the colorbar is attached to the image that
    was actually drawn instead of relying on pyplot's "current image" state.

    Args:
        raster_path (str): Path to the raster file
        show_preview (bool, optional): Whether to display a visual preview of the raster.
            Defaults to True.
        figsize (tuple, optional): Figure size as (width, height). Defaults to (10, 8).

    Returns:
        dict: Dictionary containing raster information if successful, None if
            any exception was raised while reading, printing or previewing
            (the error message is printed).

    Raises:
        ImportError: If ``show_preview`` is True and matplotlib or
            ``rasterio.plot`` cannot be imported.
    """
    if show_preview:
        # Plotting libraries are only needed for the preview, so text-only
        # calls do not pay for (or depend on) the matplotlib import.
        import matplotlib.pyplot as plt
        from rasterio.plot import show

    try:
        info = get_raster_info(raster_path)

        # Print basic information
        print(f"===== RASTER INFORMATION: {raster_path} =====")
        print(f"Driver: {info['driver']}")
        print(f"Dimensions: {info['width']} x {info['height']} pixels")
        print(f"Number of bands: {info['count']}")
        print(f"Data type: {info['dtype']}")
        print(f"Coordinate Reference System: {info['crs']}")
        print(f"Georeferenced Bounds: {info['bounds']}")
        print(f"Pixel Resolution: {info['resolution'][0]}, {info['resolution'][1]}")
        print(f"NoData Value: {info['nodata']}")

        # Print band statistics
        print("\n----- Band Statistics -----")
        for band_stat in info["band_stats"]:
            print(f"Band {band_stat['band']}:")
            print(f" Min: {band_stat['min']:.2f}")
            print(f" Max: {band_stat['max']:.2f}")
            print(f" Mean: {band_stat['mean']:.2f}")
            print(f" Std Dev: {band_stat['std']:.2f}")

        # Show a preview if requested
        if show_preview:
            with rasterio.open(raster_path) as src:
                fig, ax = plt.subplots(figsize=figsize)
                if src.count >= 3:
                    # Composite of the first three bands, stacked as
                    # (rows, cols, 3).
                    rgb = np.dstack([src.read(i) for i in range(1, 4)])
                    mappable = ax.imshow(rgb)
                    ax.set_title(f"RGB Preview: {raster_path}")
                else:
                    # Passing ``ax`` makes rasterio draw on our Axes rather
                    # than creating and showing a figure of its own, which
                    # would leave the colorbar call below without an image.
                    show(
                        src.read(1),
                        ax=ax,
                        cmap="viridis",
                        title=f"Band 1 Preview: {raster_path}",
                    )
                    mappable = ax.images[-1]
                fig.colorbar(mappable, ax=ax, label="Pixel Value")
                plt.show()

        return info

    except Exception as e:
        print(f"Error reading raster: {str(e)}")
        return None
4206
+
4207
+
4208
def smooth_vector(
    vector_data: Union[str, gpd.GeoDataFrame],
    output_path: Optional[str] = None,
    segment_length: Optional[float] = None,
    smooth_iterations: int = 3,
    num_cores: int = 0,
    merge_collection: bool = True,
    merge_field: Optional[str] = None,
    merge_multipolygons: bool = True,
    preserve_area: bool = True,
    area_tolerance: float = 0.01,
    **kwargs: Any,
) -> gpd.GeoDataFrame:
    """Smooth a vector data using the smoothify library.
    See https://github.com/DPIRD-DMA/Smoothify for more details.

    Args:
        vector_data: The vector data to smooth. A string is treated as a path
            and loaded with ``leafmap.read_vector``; anything else is passed
            to smoothify as-is.
        output_path: The path to save the smoothed vector data. If None, nothing
            is written. The smoothed data is returned in either case.
        segment_length: Resolution of the original raster data in map units. If None (default), automatically
            detects by finding the minimum segment length (from a data sample). Recommended to specify explicitly when known.
        smooth_iterations: The number of iterations to smooth the vector data.
        num_cores: Number of cores to use for parallel processing. If 0 (default), uses all available cores.
        merge_collection: Whether to merge/dissolve adjacent geometries in collections before smoothing.
        merge_field: Column name to use for dissolving geometries. Only valid when merge_collection=True.
            If None, dissolves all geometries together. If specified, dissolves geometries grouped by the column values.
        merge_multipolygons: Whether to merge adjacent polygons within MultiPolygons before smoothing
        preserve_area: Whether to restore original area after smoothing via buffering (applies to Polygons only)
        area_tolerance: Percentage of original area allowed as error (e.g., 0.01 = 0.01% error = 99.99% preservation).
            Only affects Polygons when preserve_area=True
        **kwargs: Additional keyword arguments forwarded unchanged to
            ``smoothify``.

    Returns:
        gpd.GeoDataFrame: The smoothed vector data.

    Examples:
        >>> from samgeo import common
        >>> gdf = common.read_vector("path/to/vector.geojson")
        >>> smoothed_gdf = common.smooth_vector(gdf, smooth_iterations=3, output_path="path/to/smoothed_vector.geojson")
        >>> smoothed_gdf.head()
        >>> smoothed_gdf.explore()
    """
    try:
        from smoothify import smoothify
    except ImportError:
        # Best-effort: install the missing dependency, then retry the import.
        install_package("smoothify")
        from smoothify import smoothify

    if isinstance(vector_data, str):
        # leafmap is only needed to load data from a path, so it is imported
        # here rather than unconditionally.
        import leafmap

        vector_data = leafmap.read_vector(vector_data)

    smoothed_vector_data = smoothify(
        geom=vector_data,
        segment_length=segment_length,
        smooth_iterations=smooth_iterations,
        num_cores=num_cores,
        merge_collection=merge_collection,
        merge_field=merge_field,
        merge_multipolygons=merge_multipolygons,
        preserve_area=preserve_area,
        area_tolerance=area_tolerance,
        **kwargs,
    )
    if output_path is not None:
        smoothed_vector_data.to_file(output_path)
    return smoothed_vector_data
samgeo/samgeo3.py CHANGED
@@ -200,6 +200,15 @@ class SamGeo3:
200
200
  url = "https://huggingface.co/datasets/giswqs/geospatial/resolve/main/bpe_simple_vocab_16e6.txt.gz"
201
201
  bpe_path = common.download_file(url, bpe_path, quiet=True)
202
202
 
203
+ if os.environ.get("SAM3_CHECKPOINT_PATH") is not None:
204
+ checkpoint_path = os.environ.get("SAM3_CHECKPOINT_PATH")
205
+ load_from_HF = False
206
+
207
+ if checkpoint_path is not None:
208
+ if not os.path.exists(checkpoint_path):
209
+ raise ValueError(f"Checkpoint path {checkpoint_path} does not exist.")
210
+ load_from_HF = False
211
+
203
212
  model = build_sam3_image_model(
204
213
  bpe_path=bpe_path,
205
214
  device=device,
@@ -1075,6 +1084,7 @@ class SamGeo3:
1075
1084
  prompt: str,
1076
1085
  min_size: int = 0,
1077
1086
  max_size: Optional[int] = None,
1087
+ quiet: bool = False,
1078
1088
  **kwargs: Any,
1079
1089
  ) -> List[Dict[str, Any]]:
1080
1090
  """
@@ -1086,6 +1096,7 @@ class SamGeo3:
1086
1096
  will be filtered out. Defaults to 0.
1087
1097
  max_size (int, optional): Maximum mask size in pixels. Masks larger than
1088
1098
  this will be filtered out. Defaults to None (no maximum).
1099
+ quiet (bool): If True, suppress progress messages. Defaults to False.
1089
1100
 
1090
1101
  Returns:
1091
1102
  List[Dict[str, Any]]: A list of dictionaries containing the generated masks.
@@ -1140,12 +1151,438 @@ class SamGeo3:
1140
1151
  self._filter_masks_by_size(min_size, max_size)
1141
1152
 
1142
1153
  num_objects = len(self.masks)
1143
- if num_objects == 0:
1144
- print("No objects found. Please try a different prompt.")
1145
- elif num_objects == 1:
1146
- print("Found one object.")
1154
+ if not quiet:
1155
+ if num_objects == 0:
1156
+ print("No objects found. Please try a different prompt.")
1157
+ elif num_objects == 1:
1158
+ print("Found one object.")
1159
+ else:
1160
+ print(f"Found {num_objects} objects.")
1161
+
1162
def generate_masks_tiled(
    self,
    source: str,
    prompt: str,
    output: str,
    tile_size: int = 1024,
    overlap: int = 128,
    min_size: int = 0,
    max_size: Optional[int] = None,
    unique: bool = True,
    dtype: str = "uint32",
    bands: Optional[List[int]] = None,
    batch_size: int = 1,
    verbose: bool = True,
    **kwargs: Any,
) -> str:
    """
    Generate masks for large GeoTIFF images using a sliding window approach.

    The image is read window by window; every window is stretched to 8-bit,
    segmented with ``generate_masks`` and merged into one full-size label
    array through ``_merge_tile_mask``. The label array is held in memory and
    written to ``output`` as a single-band, deflate-compressed GeoTIFF that
    reuses the source profile.

    A tile whose processing raises is reported (when ``verbose``) and skipped;
    the remaining tiles are still processed.

    Args:
        source (str): Path to the input GeoTIFF image.
        prompt (str): The text prompt describing the objects to segment.
        output (str): Path to the output GeoTIFF file.
        tile_size (int): Size of each tile in pixels. Defaults to 1024.
        overlap (int): Overlap between adjacent tiles in pixels. Defaults to 128.
            Higher overlap helps with better boundary merging but increases
            processing time.
        min_size (int): Minimum mask size in pixels. Masks smaller than this
            will be filtered out. Defaults to 0.
        max_size (int, optional): Maximum mask size in pixels. Masks larger than
            this will be filtered out. Defaults to None (no maximum).
        unique (bool): If True, each mask gets a unique value. If False, binary
            mask (0 or 1). Defaults to True.
        dtype (str): Data type for the output array. Use 'uint32' for large
            numbers of objects, 'uint16' for up to 65535 objects, or 'uint8'
            for up to 255 objects. Any other value falls back to 'uint32',
            for both the array and the written file. Defaults to 'uint32'.
        bands (List[int], optional): List of band indices (1-based) to use for RGB
            when the input has more than 3 bands. If None, uses first 3 bands.
        batch_size (int): Number of tiles to process at once (future use).
            Currently unused. Defaults to 1.
        verbose (bool): Whether to print progress information. Defaults to True.
        **kwargs: Additional keyword arguments. Currently unused.

    Returns:
        str: Path to the output GeoTIFF file.

    Raises:
        ValueError: If ``source`` is not a ``.tif``/``.tiff`` path, does not
            exist, ``tile_size`` is not greater than ``overlap``, or (with
            ``unique=True``) more objects are found than ``dtype`` can label.

    Example:
        >>> sam = SamGeo3(backend="meta")
        >>> sam.generate_masks_tiled(
        ...     source="large_satellite_image.tif",
        ...     prompt="building",
        ...     output="buildings_mask.tif",
        ...     tile_size=1024,
        ...     overlap=128,
        ... )
    """
    import rasterio
    from rasterio.windows import Window

    if not source.lower().endswith((".tif", ".tiff")):
        raise ValueError("Source must be a GeoTIFF file for tiled processing.")

    if not os.path.exists(source):
        raise ValueError(f"Source file not found: {source}")

    if tile_size <= overlap:
        raise ValueError("tile_size must be greater than overlap")

    # Resolve the output dtype once. The resolved name is also what gets
    # written to the output profile, so the array and the file always agree,
    # including when an unsupported dtype falls back to uint32.
    supported_dtypes = {"uint8": np.uint8, "uint16": np.uint16, "uint32": np.uint32}
    try:
        dtype_name = np.dtype(dtype).name
    except (TypeError, ValueError):
        dtype_name = None
    if dtype_name not in supported_dtypes:
        dtype_name = "uint32"
    np_dtype = supported_dtypes[dtype_name]
    max_objects = int(np.iinfo(np_dtype).max)

    # Track unique object IDs across all tiles
    current_max_id = 0
    total_objects = 0

    # The source stays open for the whole run instead of being reopened for
    # every tile.
    with rasterio.open(source) as src:
        img_height = src.height
        img_width = src.width
        profile = src.profile.copy()

        if verbose:
            print(f"Processing image: {img_width} x {img_height} pixels")
            print(f"Tile size: {tile_size}, Overlap: {overlap}")

        # Calculate the number of tiles
        step = tile_size - overlap
        n_tiles_x = max(1, (img_width - overlap + step - 1) // step)
        n_tiles_y = max(1, (img_height - overlap + step - 1) // step)
        total_tiles = n_tiles_x * n_tiles_y

        if verbose:
            print(f"Total tiles to process: {total_tiles} ({n_tiles_x} x {n_tiles_y})")

        # The full-size label array is kept in memory.
        output_mask = np.zeros((img_height, img_width), dtype=np_dtype)

        tile_iterator = tqdm(
            range(total_tiles),
            desc="Processing tiles",
            disable=not verbose,
        )

        for tile_idx in tile_iterator:
            # Row-major tile position
            tile_y = tile_idx // n_tiles_x
            tile_x = tile_idx % n_tiles_x

            # Calculate window coordinates
            x_start = tile_x * step
            y_start = tile_y * step

            # Ensure we don't go beyond image bounds
            x_end = min(x_start + tile_size, img_width)
            y_end = min(y_start + tile_size, img_height)

            # A clipped edge tile is shifted back so it is still tile_size
            # wide/high where the image allows it.
            if x_end - x_start < tile_size and x_start > 0:
                x_start = max(0, x_end - tile_size)
            if y_end - y_start < tile_size and y_start > 0:
                y_start = max(0, y_end - tile_size)

            window_width = x_end - x_start
            window_height = y_end - y_start

            # Read tile from source (read errors are not tile-skippable and
            # propagate to the caller).
            window = Window(x_start, y_start, window_width, window_height)
            if bands is not None:
                tile_data = np.stack(
                    [src.read(b, window=window) for b in bands], axis=0
                )
            else:
                tile_data = src.read(window=window)
                if tile_data.shape[0] >= 3:
                    tile_data = tile_data[:3, :, :]
                elif tile_data.shape[0] == 1:
                    tile_data = np.repeat(tile_data, 3, axis=0)
                elif tile_data.shape[0] == 2:
                    # Pad two-band data to three channels by repeating band 1.
                    tile_data = np.concatenate(
                        [tile_data, tile_data[0:1, :, :]], axis=0
                    )

            # Transpose to (height, width, channels)
            tile_data = np.transpose(tile_data, (1, 2, 0))

            # Normalize to 8-bit using this tile's own min/max
            tile_data = tile_data.astype(np.float32)
            tile_data -= tile_data.min()
            if tile_data.max() > 0:
                tile_data /= tile_data.max()
            tile_data *= 255
            tile_image = tile_data.astype(np.uint8)

            # Set when the label range of the output dtype is exhausted; the
            # error is raised after the per-tile handler so that it reaches
            # the caller instead of being reported as a skipped tile.
            id_limit_hit = False

            try:
                # Set image for the tile
                self.image = tile_image
                self.image_height, self.image_width = tile_image.shape[:2]
                self.source = None  # Don't need georef for individual tiles

                # Initialize inference state for this tile
                pil_image = Image.fromarray(tile_image)
                self.pil_image = pil_image

                if self.backend == "meta":
                    self.inference_state = self.processor.set_image(pil_image)
                # No extra per-tile setup is done here for other backends.

                # Generate masks for this tile (quiet=True to avoid per-tile messages)
                self.generate_masks(
                    prompt, min_size=min_size, max_size=max_size, quiet=True
                )

                tile_masks = self.masks

                if tile_masks is not None and len(tile_masks) > 0:
                    tile_mask_array = np.zeros(
                        (window_height, window_width), dtype=np_dtype
                    )

                    for mask in tile_masks:
                        # Convert mask to numpy
                        if hasattr(mask, "cpu"):
                            mask_np = mask.squeeze().cpu().numpy()
                        elif hasattr(mask, "numpy"):
                            mask_np = mask.squeeze().numpy()
                        else:
                            mask_np = (
                                mask.squeeze() if hasattr(mask, "squeeze") else mask
                            )

                        if mask_np.ndim > 2:
                            mask_np = mask_np[0]

                        # Resize mask to tile size if needed
                        if mask_np.shape != (window_height, window_width):
                            mask_np = cv2.resize(
                                mask_np.astype(np.float32),
                                (window_width, window_height),
                                interpolation=cv2.INTER_NEAREST,
                            )

                        mask_bool = mask_np > 0
                        mask_size = np.sum(mask_bool)

                        # Filter by size
                        if mask_size < min_size:
                            continue
                        if max_size is not None and mask_size > max_size:
                            continue

                        if unique:
                            if current_max_id >= max_objects:
                                id_limit_hit = True
                                break
                            current_max_id += 1
                            tile_mask_array[mask_bool] = current_max_id
                        else:
                            tile_mask_array[mask_bool] = 1

                        total_objects += 1

                    if not id_limit_hit:
                        # Merge tile mask into output mask
                        self._merge_tile_mask(
                            output_mask,
                            tile_mask_array,
                            x_start,
                            y_start,
                            x_end,
                            y_end,
                            overlap,
                            tile_x,
                            tile_y,
                            n_tiles_x,
                            n_tiles_y,
                        )

            except Exception as e:
                # Best-effort: a failing tile is skipped, not fatal.
                if verbose:
                    print(f"Warning: Failed to process tile ({tile_x}, {tile_y}): {e}")

            finally:
                # Release per-tile results on success and on failure alike.
                self.masks = None
                self.boxes = None
                self.scores = None
                if hasattr(self, "inference_state"):
                    self.inference_state = None
                # Additionally clear PyTorch CUDA cache, if available
                try:
                    import torch

                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                except ImportError:
                    # If torch is not installed, skip CUDA cache clearing
                    pass

            if id_limit_hit:
                raise ValueError(
                    f"Maximum number of objects ({max_objects}) exceeded. "
                    "Consider using a larger dtype or reducing the number of objects."
                )

    # Update output profile
    profile.update(
        {
            "count": 1,
            "dtype": dtype_name,
            "compress": "deflate",
        }
    )

    # The profile is copied from the source, so it may carry a nodata value
    # (negative, NaN, too large) that the unsigned output dtype cannot hold.
    # NOTE(review): rasterio is expected to reject such a value when the
    # dataset is opened for writing - confirm; dropping it is safe either way.
    source_nodata = profile.get("nodata")
    if source_nodata is not None and not (0 <= source_nodata <= max_objects):
        profile["nodata"] = None

    # Save the output
    with rasterio.open(output, "w", **profile) as dst:
        dst.write(output_mask, 1)

    if verbose:
        print(f"Saved mask to {output}")
        print(f"Total objects found: {total_objects}")

    # Store result for potential visualization
    self.objects = output_mask
    self.source = source

    return output
1468
+
1469
def _merge_tile_mask(
    self,
    output_mask: np.ndarray,
    tile_mask: np.ndarray,
    x_start: int,
    y_start: int,
    x_end: int,
    y_end: int,
    overlap: int,
    tile_x: int,
    tile_y: int,
    n_tiles_x: int,
    n_tiles_y: int,
) -> None:
    """
    Merge a tile mask into the output mask in place, handling overlapping regions.

    On every side where the tile has a neighbouring tile, a margin of
    ``overlap // 2`` pixels is set aside. The remaining core of the tile is
    copied into the output unconditionally. Inside the margins (edge strips
    and corners) tile values are written only where the output is still zero,
    so values that are already present are kept.

    Args:
        output_mask: The full output mask array (modified in place).
        tile_mask: The mask from the current tile.
        x_start, y_start: Start coordinates of the tile in the output.
        x_end, y_end: End coordinates of the tile in the output.
        overlap: The overlap size.
        tile_x, tile_y: Tile indices.
        n_tiles_x, n_tiles_y: Total number of tiles in each direction.
    """
    height = y_end - y_start
    width = x_end - x_start
    half = overlap // 2

    # A margin exists only on sides that face another tile.
    pad_left = half if tile_x > 0 else 0
    pad_right = half if tile_x < n_tiles_x - 1 else 0
    pad_top = half if tile_y > 0 else 0
    pad_bottom = half if tile_y < n_tiles_y - 1 else 0

    # Bands are (start, stop, is_core) in tile coordinates; the core band
    # comes first so the unconditional copy happens before the margin fills.
    row_bands = [(pad_top, height - pad_bottom, True)]
    if pad_top > 0:
        row_bands.append((0, pad_top, False))
    if pad_bottom > 0:
        row_bands.append((height - pad_bottom, height, False))

    col_bands = [(pad_left, width - pad_right, True)]
    if pad_left > 0:
        col_bands.append((0, pad_left, False))
    if pad_right > 0:
        col_bands.append((width - pad_right, width, False))

    for row_lo, row_hi, row_is_core in row_bands:
        for col_lo, col_hi, col_is_core in col_bands:
            incoming = tile_mask[row_lo:row_hi, col_lo:col_hi]
            # Output coordinates are the tile coordinates shifted by the
            # tile's origin.
            out_rows = slice(y_start + row_lo, y_start + row_hi)
            out_cols = slice(x_start + col_lo, x_start + col_hi)

            if row_is_core and col_is_core:
                # Core region: always overwrite.
                output_mask[out_rows, out_cols] = incoming
            else:
                # Margin: fill only pixels that are still zero. ``existing``
                # is a view, so the assignment writes into ``output_mask``.
                existing = output_mask[out_rows, out_cols]
                unset = existing == 0
                existing[unset] = incoming[unset]
1149
1586
 
1150
1587
  def generate_masks_by_boxes(
1151
1588
  self,
@@ -2294,7 +2731,9 @@ class SamGeo3:
2294
2731
  # Use the same color generation as the original method for consistency
2295
2732
  COLORS = generate_colors(n_colors=128, n_samples=5000)
2296
2733
  # Convert from 0-1 float RGB to 0-255 int RGB for OpenCV
2297
- colors_rgb = [(int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)) for c in COLORS]
2734
+ colors_rgb = [
2735
+ (int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)) for c in COLORS
2736
+ ]
2298
2737
 
2299
2738
  # Create overlay for all masks
2300
2739
  overlay = np.zeros((h, w, 3), dtype=np.float32)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: segment-geospatial
3
- Version: 1.1.0
3
+ Version: 1.2.1
4
4
  Summary: Meta AI's Segment Anything Model (SAM) for Geospatial Data.
5
5
  Author-email: Qiusheng Wu <giswqs@gmail.com>
6
6
  License: MIT license
@@ -29,6 +29,7 @@ Requires-Dist: pyproj
29
29
  Requires-Dist: rasterio
30
30
  Requires-Dist: segment_anything
31
31
  Requires-Dist: shapely
32
+ Requires-Dist: smoothify
32
33
  Requires-Dist: torch
33
34
  Requires-Dist: torchvision
34
35
  Requires-Dist: tqdm
@@ -1,16 +1,16 @@
1
- samgeo/__init__.py,sha256=7AUmKzGEYDeMUe2yWuusYTKSJrjcofmEn8AdmeA96IE,245
1
+ samgeo/__init__.py,sha256=qqGHrggLmdcOnXr0yhNmBTp0F_0qkwmZT2MOS1p19zQ,245
2
2
  samgeo/caption.py,sha256=Ttn9KwIEiz4n-jJU2d4o0HCkloo_6JSDev272yULWbo,20053
3
- samgeo/common.py,sha256=Scam3v-nKxLa5FIObLOsKSzYwnQ9lxQw23arG-lXB24,141207
3
+ samgeo/common.py,sha256=E1dvoWP2B-qf6LULH0mVHcnyj5OV6G64tilZFtvGEqs,149721
4
4
  samgeo/fast_sam.py,sha256=iFAaY4XXTtNCSSFCYFfljfWbLXZlMO4ZJQ5BsTmoeT8,10964
5
5
  samgeo/fer.py,sha256=gW4kVhxDGZlU48Fl46O45svtVmRpg13fyIXhJBS3Mi8,34652
6
6
  samgeo/hq_sam.py,sha256=qmwApENccnKXIqnHDmzA-_c_XvaUwCruVz4zWgF_sA4,34861
7
7
  samgeo/samgeo.py,sha256=ZLDGlaCYHTLFod-2W8OoIXRFLRRP1lpsBjLlIv_hwgU,37417
8
8
  samgeo/samgeo2.py,sha256=yMZwAh89pomAGNegtIVWJPCpHyJYLUOw3MFnaZZg_Ec,68896
9
- samgeo/samgeo3.py,sha256=uPFmMzYVVLS-2JV-mc6a38U-bW1fJL_U0Oy3Du2SlMo,186812
9
+ samgeo/samgeo3.py,sha256=d2IaeCHHXJHbWnD8HpktSDOb7mGqksxNCBMn4FDo3jA,204115
10
10
  samgeo/text_sam.py,sha256=nFtwEoyr_jDUrjoz19aEgHuUU9occFrcwbiHDIrEeRs,27006
11
11
  samgeo/utmconv.py,sha256=hFa7WYdZ83vJh-QX8teuu-xZEMJxZ9b795-Pwaq_rYs,5816
12
- segment_geospatial-1.1.0.dist-info/licenses/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
13
- segment_geospatial-1.1.0.dist-info/METADATA,sha256=YMBLvKZolHXV44VtFG_29bcsJ6rTJEJkouOzPppqtFI,15603
14
- segment_geospatial-1.1.0.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
15
- segment_geospatial-1.1.0.dist-info/top_level.txt,sha256=qmuxRrJ2s8MYWBdoBivlvEHHQkmhqLlXWqOyzEGP9XA,7
16
- segment_geospatial-1.1.0.dist-info/RECORD,,
12
+ segment_geospatial-1.2.1.dist-info/licenses/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
13
+ segment_geospatial-1.2.1.dist-info/METADATA,sha256=icPJ2wWNU_T7_11cvM8v3FrNofspW-nJdxpHCY2Y5w4,15628
14
+ segment_geospatial-1.2.1.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
15
+ segment_geospatial-1.2.1.dist-info/top_level.txt,sha256=qmuxRrJ2s8MYWBdoBivlvEHHQkmhqLlXWqOyzEGP9XA,7
16
+ segment_geospatial-1.2.1.dist-info/RECORD,,