segment-geospatial 1.1.0__py2.py3-none-any.whl → 1.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
samgeo/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  __author__ = """Qiusheng Wu"""
4
4
  __email__ = "giswqs@gmail.com"
5
- __version__ = "1.1.0"
5
+ __version__ = "1.2.0"
6
6
 
7
7
 
8
8
  from .samgeo import * # noqa: F403
samgeo/common.py CHANGED
@@ -4,7 +4,7 @@ The source code is adapted from https://github.com/aliaksandr960/segment-anythin
4
4
 
5
5
  import os
6
6
  import tempfile
7
- from typing import Any, List, Optional, Tuple, Union
7
+ from typing import Any, Dict, List, Optional, Tuple, Union
8
8
 
9
9
  import cv2
10
10
  import geopandas as gpd
@@ -4055,3 +4055,220 @@ def get_device() -> torch.device:
4055
4055
  return torch.device("mps")
4056
4056
  else:
4057
4057
  return torch.device("cpu")
4058
+
4059
+
4060
def get_raster_info(raster_path: str) -> Dict[str, Any]:
    """Collect basic metadata and per-band statistics for a raster dataset.

    Args:
        raster_path (str): Path to the raster file.

    Returns:
        dict: Raster metadata under the keys ``driver``, ``width``, ``height``,
            ``count``, ``dtype`` (data type of the first band), ``crs``,
            ``transform``, ``bounds``, ``resolution`` and ``nodata``, plus
            ``band_stats``: a list holding one dictionary per band with the
            keys ``band``, ``min``, ``max``, ``mean`` and ``std``.
    """

    def _summarize(dataset, band_index: int) -> Dict[str, Any]:
        # The band is read as a masked array, so the reductions below are the
        # masked-array versions of min/max/mean/std.
        values = dataset.read(band_index, masked=True)
        return {
            "band": band_index,
            "min": float(values.min()),
            "max": float(values.max()),
            "mean": float(values.mean()),
            "std": float(values.std()),
        }

    with rasterio.open(raster_path) as src:
        affine = src.transform
        if src.crs:
            crs_text = src.crs.to_string()
        else:
            crs_text = "No CRS defined"

        info = {
            "driver": src.driver,
            "width": src.width,
            "height": src.height,
            "count": src.count,
            "dtype": src.dtypes[0],
            "crs": crs_text,
            "transform": affine,
            "bounds": src.bounds,
            # Pixel size from affine terms 0 and 4; term 4 is sign-flipped.
            "resolution": (affine[0], -affine[4]),
            "nodata": src.nodata,
        }

        # Band numbers are 1-based.
        info["band_stats"] = [
            _summarize(src, band_index) for band_index in range(1, src.count + 1)
        ]

    return info
4101
+
4102
+
4103
def get_raster_stats(raster_path: str, divide_by: float = 1.0) -> Dict[str, Any]:
    """Compute per-band min, max, mean and standard deviation of a raster.

    Every band is read as a masked array, reduced with the four statistics,
    and each result is divided by ``divide_by`` before being stored.

    Args:
        raster_path (str): Path to the raster file.
        divide_by (float, optional): Divisor applied to every statistic.
            Defaults to 1.0, which leaves the values unchanged.

    Returns:
        dict: Dictionary with the keys ``'min'``, ``'max'``, ``'mean'`` and
            ``'std'``. Each value is a list with one float per band, in band
            order.
    """
    reducers = ("min", "max", "mean", "std")
    stats = {name: [] for name in reducers}

    with rasterio.open(raster_path) as src:
        # Band numbers are 1-based.
        for band_index in range(1, src.count + 1):
            values = src.read(band_index, masked=True)
            for name in reducers:
                # e.g. values.min(), values.max(), ... on the masked array
                stats[name].append(float(getattr(values, name)()) / divide_by)

    return stats
4138
+
4139
+
4140
def print_raster_info(
    raster_path: str, show_preview: bool = True, figsize: Tuple[int, int] = (10, 8)
) -> Optional[Dict[str, Any]]:
    """Print formatted information about a raster and optionally show a preview.

    Args:
        raster_path (str): Path to the raster file.
        show_preview (bool, optional): Whether to display a visual preview of
            the raster. Defaults to True.
        figsize (tuple, optional): Figure size as (width, height).
            Defaults to (10, 8).

    Returns:
        dict: The dictionary produced by ``get_raster_info`` when the raster
            could be read, None otherwise. A failure while drawing the preview
            is reported on stdout but does not discard the information.
    """
    try:
        info = get_raster_info(raster_path)
    except Exception as e:
        print(f"Error reading raster: {str(e)}")
        return None

    # Print basic information
    print(f"===== RASTER INFORMATION: {raster_path} =====")
    print(f"Driver: {info['driver']}")
    print(f"Dimensions: {info['width']} x {info['height']} pixels")
    print(f"Number of bands: {info['count']}")
    print(f"Data type: {info['dtype']}")
    print(f"Coordinate Reference System: {info['crs']}")
    print(f"Georeferenced Bounds: {info['bounds']}")
    print(f"Pixel Resolution: {info['resolution'][0]}, {info['resolution'][1]}")
    print(f"NoData Value: {info['nodata']}")

    # Print band statistics
    print("\n----- Band Statistics -----")
    for band_stat in info["band_stats"]:
        print(f"Band {band_stat['band']}:")
        print(f" Min: {band_stat['min']:.2f}")
        print(f" Max: {band_stat['max']:.2f}")
        print(f" Mean: {band_stat['mean']:.2f}")
        print(f" Std Dev: {band_stat['std']:.2f}")

    if show_preview:
        # Plotting libraries are only needed when a preview is requested.
        import matplotlib.pyplot as plt
        from rasterio.plot import show

        try:
            with rasterio.open(raster_path) as src:
                fig, ax = plt.subplots(figsize=figsize)
                if src.count >= 3:
                    # Stack the first three bands into an image array.
                    rgb = np.dstack([src.read(i) for i in range(1, 4)])
                    ax.imshow(rgb)
                    ax.set_title(f"RGB Preview: {raster_path}")
                else:
                    # Passing ``ax`` keeps rasterio from calling plt.show()
                    # itself, so the colorbar can still be attached.
                    show(
                        src.read(1),
                        ax=ax,
                        cmap="viridis",
                        title=f"Band 1 Preview: {raster_path}",
                    )
                    # Hand the image to colorbar explicitly: rasterio draws
                    # via Axes.imshow, which does not register a "current
                    # image" for a bare plt.colorbar() to find.
                    fig.colorbar(ax.images[-1], ax=ax, label="Pixel Value")
                plt.show()
        except Exception as e:
            # The preview is best-effort; the information gathered above is
            # still valid and is returned.
            print(f"Error displaying preview: {str(e)}")

    return info
4206
+
4207
+
4208
def smooth_vector(
    vector_data: Union[str, gpd.GeoDataFrame],
    output_path: Optional[str] = None,
    segment_length: Optional[float] = None,
    smooth_iterations: int = 3,
    num_cores: int = 0,
    merge_collection: bool = True,
    merge_field: Optional[str] = None,
    merge_multipolygons: bool = True,
    preserve_area: bool = True,
    area_tolerance: float = 0.01,
    **kwargs: Any,
) -> gpd.GeoDataFrame:
    """Smooth a vector dataset using the smoothify library.
    See https://github.com/DPIRD-DMA/Smoothify for more details.

    Args:
        vector_data: The vector data to smooth, either a GeoDataFrame or a path
            that is read with ``leafmap.read_vector``.
        output_path: The path to save the smoothed vector data. If None, nothing
            is written. The smoothed data is returned in either case.
        segment_length: Resolution of the original raster data in map units. If None (default), automatically
            detects by finding the minimum segment length (from a data sample). Recommended to specify explicitly when known.
        smooth_iterations: The number of iterations to smooth the vector data.
        num_cores: Number of cores to use for parallel processing. If 0 (default), uses all available cores.
        merge_collection: Whether to merge/dissolve adjacent geometries in collections before smoothing.
        merge_field: Column name to use for dissolving geometries. Only valid when merge_collection=True.
            If None, dissolves all geometries together. If specified, dissolves geometries grouped by the column values.
        merge_multipolygons: Whether to merge adjacent polygons within MultiPolygons before smoothing
        preserve_area: Whether to restore original area after smoothing via buffering (applies to Polygons only)
        area_tolerance: Percentage of original area allowed as error (e.g., 0.01 = 0.01% error = 99.99% preservation).
            Only affects Polygons when preserve_area=True
        **kwargs: Additional keyword arguments forwarded unchanged to
            ``smoothify``.

    Returns:
        gpd.GeoDataFrame: The smoothed vector data.

    Examples:
        >>> from samgeo import common
        >>> gdf = common.read_vector("path/to/vector.geojson")
        >>> smoothed_gdf = common.smooth_vector(gdf, smooth_iterations=3, output_path="path/to/smoothed_vector.geojson")
        >>> smoothed_gdf.head()
        >>> smoothed_gdf.explore()
    """
    try:
        from smoothify import smoothify
    except ImportError:
        # Install on demand, then retry the import.
        install_package("smoothify")
        from smoothify import smoothify

    if isinstance(vector_data, str):
        # leafmap is only needed to load data from a path, so it is imported
        # lazily and not required when a GeoDataFrame is passed in.
        import leafmap

        vector_data = leafmap.read_vector(vector_data)

    smoothed_vector_data = smoothify(
        geom=vector_data,
        segment_length=segment_length,
        smooth_iterations=smooth_iterations,
        num_cores=num_cores,
        merge_collection=merge_collection,
        merge_field=merge_field,
        merge_multipolygons=merge_multipolygons,
        preserve_area=preserve_area,
        area_tolerance=area_tolerance,
        **kwargs,
    )
    if output_path is not None:
        smoothed_vector_data.to_file(output_path)
    return smoothed_vector_data
samgeo/samgeo3.py CHANGED
@@ -1075,6 +1075,7 @@ class SamGeo3:
1075
1075
  prompt: str,
1076
1076
  min_size: int = 0,
1077
1077
  max_size: Optional[int] = None,
1078
+ quiet: bool = False,
1078
1079
  **kwargs: Any,
1079
1080
  ) -> List[Dict[str, Any]]:
1080
1081
  """
@@ -1086,6 +1087,7 @@ class SamGeo3:
1086
1087
  will be filtered out. Defaults to 0.
1087
1088
  max_size (int, optional): Maximum mask size in pixels. Masks larger than
1088
1089
  this will be filtered out. Defaults to None (no maximum).
1090
+ quiet (bool): If True, suppress progress messages. Defaults to False.
1089
1091
 
1090
1092
  Returns:
1091
1093
  List[Dict[str, Any]]: A list of dictionaries containing the generated masks.
@@ -1140,12 +1142,438 @@ class SamGeo3:
1140
1142
  self._filter_masks_by_size(min_size, max_size)
1141
1143
 
1142
1144
  num_objects = len(self.masks)
1143
- if num_objects == 0:
1144
- print("No objects found. Please try a different prompt.")
1145
- elif num_objects == 1:
1146
- print("Found one object.")
1145
+ if not quiet:
1146
+ if num_objects == 0:
1147
+ print("No objects found. Please try a different prompt.")
1148
+ elif num_objects == 1:
1149
+ print("Found one object.")
1150
+ else:
1151
+ print(f"Found {num_objects} objects.")
1152
+
1153
def generate_masks_tiled(
    self,
    source: str,
    prompt: str,
    output: str,
    tile_size: int = 1024,
    overlap: int = 128,
    min_size: int = 0,
    max_size: Optional[int] = None,
    unique: bool = True,
    dtype: str = "uint32",
    bands: Optional[List[int]] = None,
    batch_size: int = 1,
    verbose: bool = True,
    **kwargs: Any,
) -> str:
    """
    Generate masks for large GeoTIFF images using a sliding window approach.

    The image is processed tile by tile to avoid GPU memory issues. Adjacent
    tiles overlap and each tile's result is merged into one full-size label
    array by ``_merge_tile_mask``. With ``unique=True`` every mask kept from
    every tile receives its own ID from a single running counter. IDs are not
    reconciled between tiles, so an object detected in two overlapping tiles
    may appear under two IDs and is counted twice in the reported total.

    Args:
        source (str): Path to the input GeoTIFF image.
        prompt (str): The text prompt describing the objects to segment.
        output (str): Path to the output GeoTIFF file.
        tile_size (int): Size of each tile in pixels. Defaults to 1024.
        overlap (int): Overlap between adjacent tiles in pixels. Defaults to 128.
            Higher overlap helps with better boundary merging but increases
            processing time.
        min_size (int): Minimum mask size in pixels. Masks smaller than this
            will be filtered out. Defaults to 0.
        max_size (int, optional): Maximum mask size in pixels. Masks larger than
            this will be filtered out. Defaults to None (no maximum).
        unique (bool): If True, each mask gets a unique value. If False, binary
            mask (0 or 1). Defaults to True.
        dtype (str): Data type for the output array. Use 'uint32' for large
            numbers of objects, 'uint16' for up to 65535 objects, or 'uint8'
            for up to 255 objects. Any other value falls back to 'uint32'.
            Defaults to 'uint32'.
        bands (List[int], optional): List of band indices (1-based) to use for RGB
            when the input has more than 3 bands. If None, uses first 3 bands.
        batch_size (int): Reserved for future use; currently not read.
            Defaults to 1.
        verbose (bool): Whether to print progress information. Defaults to True.
        **kwargs: Additional keyword arguments. Currently not read.

    Returns:
        str: Path to the output GeoTIFF file.

    Raises:
        ValueError: If ``source`` is not a ``.tif``/``.tiff`` path or does not
            exist, if ``overlap`` is negative, if ``tile_size`` is not greater
            than ``overlap``, or if more objects are found than ``dtype`` can
            label.

    Example:
        >>> sam = SamGeo3(backend="meta")
        >>> sam.generate_masks_tiled(
        ...     source="large_satellite_image.tif",
        ...     prompt="building",
        ...     output="buildings_mask.tif",
        ...     tile_size=1024,
        ...     overlap=128,
        ... )
    """
    if not source.lower().endswith((".tif", ".tiff")):
        raise ValueError("Source must be a GeoTIFF file for tiled processing.")

    if not os.path.exists(source):
        raise ValueError(f"Source file not found: {source}")

    if overlap < 0:
        raise ValueError("overlap must be non-negative")

    if tile_size <= overlap:
        raise ValueError("tile_size must be greater than overlap")

    # Imported after argument validation so bad arguments fail fast.
    import rasterio
    from rasterio.windows import Window

    class _LabelOverflow(ValueError):
        """Raised when the label counter no longer fits the output dtype."""

    # Resolve the output dtype. Unknown names fall back to uint32 for BOTH the
    # in-memory array and the profile written to disk, so the two always agree.
    supported_dtypes = {"uint8": np.uint8, "uint16": np.uint16, "uint32": np.uint32}
    if dtype not in supported_dtypes:
        if verbose:
            print(f"Warning: unsupported dtype '{dtype}', using 'uint32' instead.")
        dtype = "uint32"
    np_dtype = supported_dtypes[dtype]
    max_objects = int(np.iinfo(np_dtype).max)

    step = tile_size - overlap

    # The source stays open for the whole loop instead of being reopened for
    # every tile.
    with rasterio.open(source) as src:
        img_height = src.height
        img_width = src.width
        profile = src.profile.copy()

        if verbose:
            print(f"Processing image: {img_width} x {img_height} pixels")
            print(f"Tile size: {tile_size}, Overlap: {overlap}")

        # Calculate the number of tiles
        n_tiles_x = max(1, (img_width - overlap + step - 1) // step)
        n_tiles_y = max(1, (img_height - overlap + step - 1) // step)
        total_tiles = n_tiles_x * n_tiles_y

        if verbose:
            print(f"Total tiles to process: {total_tiles} ({n_tiles_x} x {n_tiles_y})")

        # The full-size label array is held in memory.
        output_mask = np.zeros((img_height, img_width), dtype=np_dtype)

        # Running label counter shared by all tiles, and number of masks kept.
        current_max_id = 0
        total_objects = 0

        tile_iterator = tqdm(
            range(total_tiles),
            desc="Processing tiles",
            disable=not verbose,
        )

        for tile_idx in tile_iterator:
            # Tiles are visited row by row.
            tile_y, tile_x = divmod(tile_idx, n_tiles_x)

            x_start = tile_x * step
            y_start = tile_y * step

            # Ensure we don't go beyond image bounds
            x_end = min(x_start + tile_size, img_width)
            y_end = min(y_start + tile_size, img_height)

            # Shift edge tiles back so they keep the full tile size if possible
            if x_end - x_start < tile_size and x_start > 0:
                x_start = max(0, x_end - tile_size)
            if y_end - y_start < tile_size and y_start > 0:
                y_start = max(0, y_end - tile_size)

            window_width = x_end - x_start
            window_height = y_end - y_start

            # Read tile from source
            window = Window(x_start, y_start, window_width, window_height)
            if bands is not None:
                tile_data = np.stack(
                    [src.read(b, window=window) for b in bands], axis=0
                )
            else:
                tile_data = src.read(window=window)
                if tile_data.shape[0] >= 3:
                    tile_data = tile_data[:3, :, :]
                elif tile_data.shape[0] == 1:
                    tile_data = np.repeat(tile_data, 3, axis=0)
                elif tile_data.shape[0] == 2:
                    # Pad two-band input with a copy of the first band.
                    tile_data = np.concatenate(
                        [tile_data, tile_data[0:1, :, :]], axis=0
                    )

            # Transpose to (height, width, channels)
            tile_data = np.transpose(tile_data, (1, 2, 0))

            # Min-max stretch this tile to 8-bit
            tile_data = tile_data.astype(np.float32)
            tile_data -= tile_data.min()
            if tile_data.max() > 0:
                tile_data /= tile_data.max()
            tile_data *= 255
            tile_image = tile_data.astype(np.uint8)

            try:
                # Set image for the tile
                self.image = tile_image
                self.image_height, self.image_width = tile_image.shape[:2]
                self.source = None  # Don't need georef for individual tiles

                pil_image = Image.fromarray(tile_image)
                self.pil_image = pil_image

                # Only the "meta" backend needs an inference state here.
                if self.backend == "meta":
                    self.inference_state = self.processor.set_image(pil_image)

                # quiet=True avoids one status message per tile
                self.generate_masks(
                    prompt, min_size=min_size, max_size=max_size, quiet=True
                )

                tile_masks = self.masks

                if tile_masks is not None and len(tile_masks) > 0:
                    tile_mask_array = np.zeros(
                        (window_height, window_width), dtype=np_dtype
                    )

                    for mask in tile_masks:
                        # Convert mask to numpy
                        if hasattr(mask, "cpu"):
                            mask_np = mask.squeeze().cpu().numpy()
                        elif hasattr(mask, "numpy"):
                            mask_np = mask.squeeze().numpy()
                        else:
                            mask_np = (
                                mask.squeeze() if hasattr(mask, "squeeze") else mask
                            )

                        if mask_np.ndim > 2:
                            mask_np = mask_np[0]

                        # Resize mask to tile size if needed
                        if mask_np.shape != (window_height, window_width):
                            mask_np = cv2.resize(
                                mask_np.astype(np.float32),
                                (window_width, window_height),
                                interpolation=cv2.INTER_NEAREST,
                            )

                        mask_bool = mask_np > 0
                        mask_size = np.sum(mask_bool)

                        # Filter by size
                        if mask_size < min_size:
                            continue
                        if max_size is not None and mask_size > max_size:
                            continue

                        if unique:
                            current_max_id += 1
                            if current_max_id > max_objects:
                                raise _LabelOverflow(
                                    f"Maximum number of objects ({max_objects}) exceeded. "
                                    "Consider using a larger dtype or reducing the number of objects."
                                )
                            tile_mask_array[mask_bool] = current_max_id
                        else:
                            tile_mask_array[mask_bool] = 1

                        total_objects += 1

                    # Merge tile mask into output mask
                    self._merge_tile_mask(
                        output_mask,
                        tile_mask_array,
                        x_start,
                        y_start,
                        x_end,
                        y_end,
                        overlap,
                        tile_x,
                        tile_y,
                        n_tiles_x,
                        n_tiles_y,
                    )

            except _LabelOverflow:
                # Running out of label values is not a per-tile failure; it
                # must reach the caller, not be downgraded to a warning.
                raise
            except Exception as e:
                # Best effort: a failing tile is reported and skipped.
                if verbose:
                    print(f"Warning: Failed to process tile ({tile_x}, {tile_y}): {e}")
            finally:
                # Drop per-tile results after every tile, including failed ones.
                self.masks = None
                self.boxes = None
                self.scores = None
                if hasattr(self, "inference_state"):
                    self.inference_state = None
                # Additionally clear PyTorch CUDA cache, if available
                try:
                    import torch

                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                except ImportError:
                    # If torch is not installed, skip CUDA cache clearing
                    pass

    # Update output profile
    profile.update(
        {
            "count": 1,
            "dtype": dtype,
            "compress": "deflate",
        }
    )

    # The nodata value copied from the source may not fit the unsigned output
    # type (for example a negative or NaN nodata); drop it in that case.
    nodata = profile.get("nodata")
    if nodata is not None and not (
        float(nodata).is_integer() and 0 <= nodata <= max_objects
    ):
        profile["nodata"] = None

    # Save the output
    with rasterio.open(output, "w", **profile) as dst:
        dst.write(output_mask, 1)

    if verbose:
        print(f"Saved mask to {output}")
        print(f"Total objects found: {total_objects}")

    # Store result for potential visualization
    self.objects = output_mask
    self.source = source

    return output
1459
+
1460
def _merge_tile_mask(
    self,
    output_mask: np.ndarray,
    tile_mask: np.ndarray,
    x_start: int,
    y_start: int,
    x_end: int,
    y_end: int,
    overlap: int,
    tile_x: int,
    tile_y: int,
    n_tiles_x: int,
    n_tiles_y: int,
) -> None:
    """
    Merge a tile mask into the output mask, handling overlapping regions.

    The tile is split into a core and a border. On every side that has a
    neighbouring tile the border is ``overlap // 2`` pixels wide; sides on the
    edge of the tile grid have no border. Core pixels always overwrite the
    output, including with zeros. Border pixels are copied only where the
    output is still zero, so values written earlier are kept there.
    ``output_mask`` is modified in place.

    Args:
        output_mask: The full output mask array.
        tile_mask: The mask from the current tile.
        x_start, y_start: Start coordinates of the tile in the output.
        x_end, y_end: End coordinates of the tile in the output.
        overlap: The overlap size.
        tile_x, tile_y: Tile indices.
        n_tiles_x, n_tiles_y: Total number of tiles in each direction.
    """
    tile_height = y_end - y_start
    tile_width = x_end - x_start

    # The overlap is split between adjacent tiles; no border on grid edges.
    half_overlap = overlap // 2
    left_overlap = half_overlap if tile_x > 0 else 0
    right_overlap = half_overlap if tile_x < n_tiles_x - 1 else 0
    top_overlap = half_overlap if tile_y > 0 else 0
    bottom_overlap = half_overlap if tile_y < n_tiles_y - 1 else 0

    # Core region in tile coordinates
    core_x_start = left_overlap
    core_x_end = tile_width - right_overlap
    core_y_start = top_overlap
    core_y_end = tile_height - bottom_overlap

    # View on the tile's footprint: writing to it writes to output_mask.
    window = output_mask[y_start:y_end, x_start:x_end]
    tile = tile_mask[:tile_height, :tile_width]

    # Pixels to copy: everything still empty in the output, plus the whole
    # core. One masked assignment covers the core, edge strips and corners.
    take = window == 0
    take[core_y_start:core_y_end, core_x_start:core_x_end] = True
    window[take] = tile[take]
1149
1577
 
1150
1578
  def generate_masks_by_boxes(
1151
1579
  self,
@@ -2294,7 +2722,9 @@ class SamGeo3:
2294
2722
  # Use the same color generation as the original method for consistency
2295
2723
  COLORS = generate_colors(n_colors=128, n_samples=5000)
2296
2724
  # Convert from 0-1 float RGB to 0-255 int RGB for OpenCV
2297
- colors_rgb = [(int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)) for c in COLORS]
2725
+ colors_rgb = [
2726
+ (int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)) for c in COLORS
2727
+ ]
2298
2728
 
2299
2729
  # Create overlay for all masks
2300
2730
  overlay = np.zeros((h, w, 3), dtype=np.float32)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: segment-geospatial
3
- Version: 1.1.0
3
+ Version: 1.2.0
4
4
  Summary: Meta AI' Segment Anything Model (SAM) for Geospatial Data.
5
5
  Author-email: Qiusheng Wu <giswqs@gmail.com>
6
6
  License: MIT license
@@ -29,6 +29,7 @@ Requires-Dist: pyproj
29
29
  Requires-Dist: rasterio
30
30
  Requires-Dist: segment_anything
31
31
  Requires-Dist: shapely
32
+ Requires-Dist: smoothify
32
33
  Requires-Dist: torch
33
34
  Requires-Dist: torchvision
34
35
  Requires-Dist: tqdm
@@ -1,16 +1,16 @@
1
- samgeo/__init__.py,sha256=7AUmKzGEYDeMUe2yWuusYTKSJrjcofmEn8AdmeA96IE,245
1
+ samgeo/__init__.py,sha256=S1K0lIKUWsZb7mnBnLCV9qA6IpZjM9QjFQjRxO3ti6Y,245
2
2
  samgeo/caption.py,sha256=Ttn9KwIEiz4n-jJU2d4o0HCkloo_6JSDev272yULWbo,20053
3
- samgeo/common.py,sha256=Scam3v-nKxLa5FIObLOsKSzYwnQ9lxQw23arG-lXB24,141207
3
+ samgeo/common.py,sha256=E1dvoWP2B-qf6LULH0mVHcnyj5OV6G64tilZFtvGEqs,149721
4
4
  samgeo/fast_sam.py,sha256=iFAaY4XXTtNCSSFCYFfljfWbLXZlMO4ZJQ5BsTmoeT8,10964
5
5
  samgeo/fer.py,sha256=gW4kVhxDGZlU48Fl46O45svtVmRpg13fyIXhJBS3Mi8,34652
6
6
  samgeo/hq_sam.py,sha256=qmwApENccnKXIqnHDmzA-_c_XvaUwCruVz4zWgF_sA4,34861
7
7
  samgeo/samgeo.py,sha256=ZLDGlaCYHTLFod-2W8OoIXRFLRRP1lpsBjLlIv_hwgU,37417
8
8
  samgeo/samgeo2.py,sha256=yMZwAh89pomAGNegtIVWJPCpHyJYLUOw3MFnaZZg_Ec,68896
9
- samgeo/samgeo3.py,sha256=uPFmMzYVVLS-2JV-mc6a38U-bW1fJL_U0Oy3Du2SlMo,186812
9
+ samgeo/samgeo3.py,sha256=R3Vvf6cRjW6gD622crE2zVgzLdmYqlUmIYcji0qhXLo,203736
10
10
  samgeo/text_sam.py,sha256=nFtwEoyr_jDUrjoz19aEgHuUU9occFrcwbiHDIrEeRs,27006
11
11
  samgeo/utmconv.py,sha256=hFa7WYdZ83vJh-QX8teuu-xZEMJxZ9b795-Pwaq_rYs,5816
12
- segment_geospatial-1.1.0.dist-info/licenses/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
13
- segment_geospatial-1.1.0.dist-info/METADATA,sha256=YMBLvKZolHXV44VtFG_29bcsJ6rTJEJkouOzPppqtFI,15603
14
- segment_geospatial-1.1.0.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
15
- segment_geospatial-1.1.0.dist-info/top_level.txt,sha256=qmuxRrJ2s8MYWBdoBivlvEHHQkmhqLlXWqOyzEGP9XA,7
16
- segment_geospatial-1.1.0.dist-info/RECORD,,
12
+ segment_geospatial-1.2.0.dist-info/licenses/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
13
+ segment_geospatial-1.2.0.dist-info/METADATA,sha256=_3CdjEGM41gkq6rncDAv4mY-TnU1hovjk__LI4pXOzs,15628
14
+ segment_geospatial-1.2.0.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
15
+ segment_geospatial-1.2.0.dist-info/top_level.txt,sha256=qmuxRrJ2s8MYWBdoBivlvEHHQkmhqLlXWqOyzEGP9XA,7
16
+ segment_geospatial-1.2.0.dist-info/RECORD,,