segment-geospatial 1.1.0__py2.py3-none-any.whl → 1.2.0__py2.py3-none-any.whl
This diff shows the content changes between publicly released versions of the package, as published to one of the supported registries, and is provided for informational purposes only.
- samgeo/__init__.py +1 -1
- samgeo/common.py +218 -1
- samgeo/samgeo3.py +436 -6
- {segment_geospatial-1.1.0.dist-info → segment_geospatial-1.2.0.dist-info}/METADATA +2 -1
- {segment_geospatial-1.1.0.dist-info → segment_geospatial-1.2.0.dist-info}/RECORD +8 -8
- {segment_geospatial-1.1.0.dist-info → segment_geospatial-1.2.0.dist-info}/WHEEL +0 -0
- {segment_geospatial-1.1.0.dist-info → segment_geospatial-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {segment_geospatial-1.1.0.dist-info → segment_geospatial-1.2.0.dist-info}/top_level.txt +0 -0
samgeo/__init__.py
CHANGED
samgeo/common.py
CHANGED
@@ -4,7 +4,7 @@ The source code is adapted from https://github.com/aliaksandr960/segment-anythin
 
 import os
 import tempfile
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import cv2
 import geopandas as gpd
@@ -4055,3 +4055,220 @@ def get_device() -> torch.device:
         return torch.device("mps")
     else:
         return torch.device("cpu")
+
+
+def get_raster_info(raster_path: str) -> Dict[str, Any]:
+    """Display basic information about a raster dataset.
+
+    Args:
+        raster_path (str): Path to the raster file
+
+    Returns:
+        dict: Dictionary containing the basic information about the raster
+    """
+    # Open the raster dataset
+    with rasterio.open(raster_path) as src:
+        # Get basic metadata
+        info = {
+            "driver": src.driver,
+            "width": src.width,
+            "height": src.height,
+            "count": src.count,
+            "dtype": src.dtypes[0],
+            "crs": src.crs.to_string() if src.crs else "No CRS defined",
+            "transform": src.transform,
+            "bounds": src.bounds,
+            "resolution": (src.transform[0], -src.transform[4]),
+            "nodata": src.nodata,
+        }
+
+        # Calculate statistics for each band
+        stats = []
+        for i in range(1, src.count + 1):
+            band = src.read(i, masked=True)
+            band_stats = {
+                "band": i,
+                "min": float(band.min()),
+                "max": float(band.max()),
+                "mean": float(band.mean()),
+                "std": float(band.std()),
+            }
+            stats.append(band_stats)
+
+        info["band_stats"] = stats
+
+    return info
+
+
+def get_raster_stats(raster_path: str, divide_by: float = 1.0) -> Dict[str, Any]:
+    """Calculate statistics for each band in a raster dataset.
+
+    This function computes min, max, mean, and standard deviation values
+    for each band in the provided raster, returning results in a dictionary
+    with lists for each statistic type.
+
+    Args:
+        raster_path (str): Path to the raster file
+        divide_by (float, optional): Value to divide pixel values by.
+            Defaults to 1.0, which keeps the original pixel values unchanged.
+
+    Returns:
+        dict: Dictionary containing lists of statistics with keys:
+            - 'min': List of minimum values for each band
+            - 'max': List of maximum values for each band
+            - 'mean': List of mean values for each band
+            - 'std': List of standard deviation values for each band
+    """
+    # Initialize the results dictionary with empty lists
+    stats = {"min": [], "max": [], "mean": [], "std": []}
+
+    # Open the raster dataset
+    with rasterio.open(raster_path) as src:
+        # Calculate statistics for each band
+        for i in range(1, src.count + 1):
+            band = src.read(i, masked=True)
+
+            # Append statistics for this band to each list
+            stats["min"].append(float(band.min()) / divide_by)
+            stats["max"].append(float(band.max()) / divide_by)
+            stats["mean"].append(float(band.mean()) / divide_by)
+            stats["std"].append(float(band.std()) / divide_by)
+
+    return stats
+
+
+def print_raster_info(
+    raster_path: str, show_preview: bool = True, figsize: Tuple[int, int] = (10, 8)
+) -> Optional[Dict[str, Any]]:
+    """Print formatted information about a raster dataset and optionally show a preview.
+
+    Args:
+        raster_path (str): Path to the raster file
+        show_preview (bool, optional): Whether to display a visual preview of the raster.
+            Defaults to True.
+        figsize (tuple, optional): Figure size as (width, height). Defaults to (10, 8).
+
+    Returns:
+        dict: Dictionary containing raster information if successful, None otherwise
+    """
+    import matplotlib.pyplot as plt
+    from rasterio.plot import show
+
+    try:
+        info = get_raster_info(raster_path)
+
+        # Print basic information
+        print(f"===== RASTER INFORMATION: {raster_path} =====")
+        print(f"Driver: {info['driver']}")
+        print(f"Dimensions: {info['width']} x {info['height']} pixels")
+        print(f"Number of bands: {info['count']}")
+        print(f"Data type: {info['dtype']}")
+        print(f"Coordinate Reference System: {info['crs']}")
+        print(f"Georeferenced Bounds: {info['bounds']}")
+        print(f"Pixel Resolution: {info['resolution'][0]}, {info['resolution'][1]}")
+        print(f"NoData Value: {info['nodata']}")
+
+        # Print band statistics
+        print("\n----- Band Statistics -----")
+        for band_stat in info["band_stats"]:
+            print(f"Band {band_stat['band']}:")
+            print(f"  Min: {band_stat['min']:.2f}")
+            print(f"  Max: {band_stat['max']:.2f}")
+            print(f"  Mean: {band_stat['mean']:.2f}")
+            print(f"  Std Dev: {band_stat['std']:.2f}")
+
+        # Show a preview if requested
+        if show_preview:
+            with rasterio.open(raster_path) as src:
+                # For multi-band images, show RGB composite or first band
+                if src.count >= 3:
+                    # Try to show RGB composite
+                    rgb = np.dstack([src.read(i) for i in range(1, 4)])
+                    plt.figure(figsize=figsize)
+                    plt.imshow(rgb)
+                    plt.title(f"RGB Preview: {raster_path}")
+                else:
+                    # Show first band for single-band images
+                    plt.figure(figsize=figsize)
+                    show(
+                        src.read(1),
+                        cmap="viridis",
+                        title=f"Band 1 Preview: {raster_path}",
+                    )
+                    plt.colorbar(label="Pixel Value")
+                plt.show()
+
+        return info
+
+    except Exception as e:
+        print(f"Error reading raster: {str(e)}")
+        return None
+
+
+def smooth_vector(
+    vector_data: Union[str, gpd.GeoDataFrame],
+    output_path: str = None,
+    segment_length: float = None,
+    smooth_iterations: int = 3,
+    num_cores: int = 0,
+    merge_collection: bool = True,
+    merge_field: str = None,
+    merge_multipolygons: bool = True,
+    preserve_area: bool = True,
+    area_tolerance: float = 0.01,
+    **kwargs: Any,
+) -> gpd.GeoDataFrame:
+    """Smooth vector data using the smoothify library.
+    See https://github.com/DPIRD-DMA/Smoothify for more details.
+
+    Args:
+        vector_data: The vector data to smooth.
+        output_path: The path to save the smoothed vector data. If None, returns the smoothed vector data.
+        segment_length: Resolution of the original raster data in map units. If None (default), automatically
+            detects by finding the minimum segment length (from a data sample). Recommended to specify explicitly when known.
+        smooth_iterations: The number of iterations to smooth the vector data.
+        num_cores: Number of cores to use for parallel processing. If 0 (default), uses all available cores.
+        merge_collection: Whether to merge/dissolve adjacent geometries in collections before smoothing.
+        merge_field: Column name to use for dissolving geometries. Only valid when merge_collection=True.
+            If None, dissolves all geometries together. If specified, dissolves geometries grouped by the column values.
+        merge_multipolygons: Whether to merge adjacent polygons within MultiPolygons before smoothing.
+        preserve_area: Whether to restore original area after smoothing via buffering (applies to Polygons only).
+        area_tolerance: Percentage of original area allowed as error (e.g., 0.01 = 0.01% error = 99.99% preservation).
+            Only affects Polygons when preserve_area=True.
+
+    Returns:
+        gpd.GeoDataFrame: The smoothed vector data.
+
+    Examples:
+        >>> from samgeo import common
+        >>> gdf = common.read_vector("path/to/vector.geojson")
+        >>> smoothed_gdf = common.smooth_vector(gdf, smooth_iterations=3, output_path="path/to/smoothed_vector.geojson")
+        >>> smoothed_gdf.head()
+        >>> smoothed_gdf.explore()
+    """
+    import leafmap
+
+    try:
+        from smoothify import smoothify
+    except ImportError:
+        install_package("smoothify")
+        from smoothify import smoothify
+
+    if isinstance(vector_data, str):
+        vector_data = leafmap.read_vector(vector_data)
+
+    smoothed_vector_data = smoothify(
+        geom=vector_data,
+        segment_length=segment_length,
+        smooth_iterations=smooth_iterations,
+        num_cores=num_cores,
+        merge_collection=merge_collection,
+        merge_field=merge_field,
+        merge_multipolygons=merge_multipolygons,
+        preserve_area=preserve_area,
+        area_tolerance=area_tolerance,
+        **kwargs,
+    )
+    if output_path is not None:
+        smoothed_vector_data.to_file(output_path)
+    return smoothed_vector_data
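For orientation, here is a short usage sketch of the four helpers this release adds to samgeo.common. It is not taken from the package; the file paths are placeholders, and the snippet assumes a rasterio-readable GeoTIFF and a GeoJSON of segmentation polygons.

```python
# Illustrative only: "example.tif" and "masks.geojson" are placeholder paths.
from samgeo import common

info = common.get_raster_info("example.tif")       # driver, size, CRS, per-band stats
print(info["crs"], info["resolution"], info["count"])

stats = common.get_raster_stats("example.tif", divide_by=255.0)  # scaled band stats
print(stats["mean"])

common.print_raster_info("example.tif", show_preview=False)      # formatted summary

smoothed = common.smooth_vector(                   # smoothify-based polygon smoothing
    "masks.geojson",
    smooth_iterations=3,
    output_path="masks_smooth.geojson",
)
```

Note that smooth_vector relies on the new smoothify dependency declared in the updated METADATA below, and installs it on the fly if it is missing.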
samgeo/samgeo3.py
CHANGED
@@ -1075,6 +1075,7 @@ class SamGeo3:
         prompt: str,
         min_size: int = 0,
         max_size: Optional[int] = None,
+        quiet: bool = False,
         **kwargs: Any,
     ) -> List[Dict[str, Any]]:
         """
@@ -1086,6 +1087,7 @@ class SamGeo3:
                 will be filtered out. Defaults to 0.
             max_size (int, optional): Maximum mask size in pixels. Masks larger than
                 this will be filtered out. Defaults to None (no maximum).
+            quiet (bool): If True, suppress progress messages. Defaults to False.
 
         Returns:
             List[Dict[str, Any]]: A list of dictionaries containing the generated masks.
@@ -1140,12 +1142,438 @@ class SamGeo3:
         self._filter_masks_by_size(min_size, max_size)
 
         num_objects = len(self.masks)
-        if
-
-
-
+        if not quiet:
+            if num_objects == 0:
+                print("No objects found. Please try a different prompt.")
+            elif num_objects == 1:
+                print("Found one object.")
+            else:
+                print(f"Found {num_objects} objects.")
+
+    def generate_masks_tiled(
+        self,
+        source: str,
+        prompt: str,
+        output: str,
+        tile_size: int = 1024,
+        overlap: int = 128,
+        min_size: int = 0,
+        max_size: Optional[int] = None,
+        unique: bool = True,
+        dtype: str = "uint32",
+        bands: Optional[List[int]] = None,
+        batch_size: int = 1,
+        verbose: bool = True,
+        **kwargs: Any,
+    ) -> str:
+        """
+        Generate masks for large GeoTIFF images using a sliding window approach.
+
+        This method processes large images tile by tile to avoid GPU memory issues.
+        The tiles are processed with overlap to ensure seamless mask merging at
+        boundaries. Each detected object gets a unique ID that is consistent
+        across the entire image.
+
+        Args:
+            source (str): Path to the input GeoTIFF image.
+            prompt (str): The text prompt describing the objects to segment.
+            output (str): Path to the output GeoTIFF file.
+            tile_size (int): Size of each tile in pixels. Defaults to 1024.
+            overlap (int): Overlap between adjacent tiles in pixels. Defaults to 128.
+                Higher overlap helps with better boundary merging but increases
+                processing time.
+            min_size (int): Minimum mask size in pixels. Masks smaller than this
+                will be filtered out. Defaults to 0.
+            max_size (int, optional): Maximum mask size in pixels. Masks larger than
+                this will be filtered out. Defaults to None (no maximum).
+            unique (bool): If True, each mask gets a unique value. If False, binary
+                mask (0 or 1). Defaults to True.
+            dtype (str): Data type for the output array. Use 'uint32' for large
+                numbers of objects, 'uint16' for up to 65535 objects, or 'uint8'
+                for up to 255 objects. Defaults to 'uint32'.
+            bands (List[int], optional): List of band indices (1-based) to use for RGB
+                when the input has more than 3 bands. If None, uses first 3 bands.
+            batch_size (int): Number of tiles to process at once (future use).
+                Defaults to 1.
+            verbose (bool): Whether to print progress information. Defaults to True.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: Path to the output GeoTIFF file.
+
+        Example:
+            >>> sam = SamGeo3(backend="meta")
+            >>> sam.generate_masks_tiled(
+            ...     source="large_satellite_image.tif",
+            ...     prompt="building",
+            ...     output="buildings_mask.tif",
+            ...     tile_size=1024,
+            ...     overlap=128,
+            ... )
+        """
+        import rasterio
+        from rasterio.windows import Window
+
+        if not source.lower().endswith((".tif", ".tiff")):
+            raise ValueError("Source must be a GeoTIFF file for tiled processing.")
+
+        if not os.path.exists(source):
+            raise ValueError(f"Source file not found: {source}")
+
+        if tile_size <= overlap:
+            raise ValueError("tile_size must be greater than overlap")
+
+        # Open the source file to get metadata
+        with rasterio.open(source) as src:
+            img_height = src.height
+            img_width = src.width
+            profile = src.profile.copy()
+
+        if verbose:
+            print(f"Processing image: {img_width} x {img_height} pixels")
+            print(f"Tile size: {tile_size}, Overlap: {overlap}")
+
+        # Calculate the number of tiles
+        step = tile_size - overlap
+        n_tiles_x = max(1, (img_width - overlap + step - 1) // step)
+        n_tiles_y = max(1, (img_height - overlap + step - 1) // step)
+        total_tiles = n_tiles_x * n_tiles_y
+
+        if verbose:
+            print(f"Total tiles to process: {total_tiles} ({n_tiles_x} x {n_tiles_y})")
+
+        # Determine output dtype
+        if dtype == "uint8":
+            np_dtype = np.uint8
+            max_objects = 255
+        elif dtype == "uint16":
+            np_dtype = np.uint16
+            max_objects = 65535
+        elif dtype == "uint32":
+            np_dtype = np.uint32
+            max_objects = 4294967295
         else:
-
+            np_dtype = np.uint32
+            max_objects = 4294967295
+
+        # Create output array in memory (for smaller images) or use memory-mapped file
+        # For very large images, you might want to use rasterio windowed writing
+        output_mask = np.zeros((img_height, img_width), dtype=np_dtype)
+
+        # Track unique object IDs across all tiles
+        current_max_id = 0
+        total_objects = 0
+
+        # Process each tile
+        tile_iterator = tqdm(
+            range(total_tiles),
+            desc="Processing tiles",
+            disable=not verbose,
+        )
+
+        for tile_idx in tile_iterator:
+            # Calculate tile position
+            tile_y = tile_idx // n_tiles_x
+            tile_x = tile_idx % n_tiles_x
+
+            # Calculate window coordinates
+            x_start = tile_x * step
+            y_start = tile_y * step
+
+            # Ensure we don't go beyond image bounds
+            x_end = min(x_start + tile_size, img_width)
+            y_end = min(y_start + tile_size, img_height)
+
+            # Adjust start if we're at the edge
+            if x_end - x_start < tile_size and x_start > 0:
+                x_start = max(0, x_end - tile_size)
+            if y_end - y_start < tile_size and y_start > 0:
+                y_start = max(0, y_end - tile_size)
+
+            window_width = x_end - x_start
+            window_height = y_end - y_start
+
+            # Read tile from source
+            with rasterio.open(source) as src:
+                window = Window(x_start, y_start, window_width, window_height)
+                if bands is not None:
+                    tile_data = np.stack(
+                        [src.read(b, window=window) for b in bands], axis=0
+                    )
+                else:
+                    tile_data = src.read(window=window)
+                    if tile_data.shape[0] >= 3:
+                        tile_data = tile_data[:3, :, :]
+                    elif tile_data.shape[0] == 1:
+                        tile_data = np.repeat(tile_data, 3, axis=0)
+                    elif tile_data.shape[0] == 2:
+                        tile_data = np.concatenate(
+                            [tile_data, tile_data[0:1, :, :]], axis=0
+                        )
+
+            # Transpose to (height, width, channels)
+            tile_data = np.transpose(tile_data, (1, 2, 0))
+
+            # Normalize to 8-bit
+            tile_data = tile_data.astype(np.float32)
+            tile_data -= tile_data.min()
+            if tile_data.max() > 0:
+                tile_data /= tile_data.max()
+            tile_data *= 255
+            tile_image = tile_data.astype(np.uint8)
+
+            # Process the tile
+            try:
+                # Set image for the tile
+                self.image = tile_image
+                self.image_height, self.image_width = tile_image.shape[:2]
+                self.source = None  # Don't need georef for individual tiles
+
+                # Initialize inference state for this tile
+                pil_image = Image.fromarray(tile_image)
+                self.pil_image = pil_image
+
+                if self.backend == "meta":
+                    self.inference_state = self.processor.set_image(pil_image)
+                else:
+                    # For transformers backend, process directly
+                    pass
+
+                # Generate masks for this tile (quiet=True to avoid per-tile messages)
+                self.generate_masks(
+                    prompt, min_size=min_size, max_size=max_size, quiet=True
+                )
+
+                # Get masks for this tile
+                tile_masks = self.masks
+
+                if tile_masks is not None and len(tile_masks) > 0:
+                    # Create a mask array for this tile
+                    tile_mask_array = np.zeros(
+                        (window_height, window_width), dtype=np_dtype
+                    )
+
+                    for mask in tile_masks:
+                        # Convert mask to numpy
+                        if hasattr(mask, "cpu"):
+                            mask_np = mask.squeeze().cpu().numpy()
+                        elif hasattr(mask, "numpy"):
+                            mask_np = mask.squeeze().numpy()
+                        else:
+                            mask_np = (
+                                mask.squeeze() if hasattr(mask, "squeeze") else mask
+                            )
+
+                        if mask_np.ndim > 2:
+                            mask_np = mask_np[0]
+
+                        # Resize mask to tile size if needed
+                        if mask_np.shape != (window_height, window_width):
+                            mask_np = cv2.resize(
+                                mask_np.astype(np.float32),
+                                (window_width, window_height),
+                                interpolation=cv2.INTER_NEAREST,
+                            )
+
+                        mask_bool = mask_np > 0
+                        mask_size = np.sum(mask_bool)
+
+                        # Filter by size
+                        if mask_size < min_size:
+                            continue
+                        if max_size is not None and mask_size > max_size:
+                            continue
+
+                        if unique:
+                            current_max_id += 1
+                            if current_max_id > max_objects:
+                                raise ValueError(
+                                    f"Maximum number of objects ({max_objects}) exceeded. "
+                                    "Consider using a larger dtype or reducing the number of objects."
+                                )
+                            tile_mask_array[mask_bool] = current_max_id
+                        else:
+                            tile_mask_array[mask_bool] = 1
+
+                        total_objects += 1
+
+                    # Merge tile mask into output mask
+                    # For overlapping regions, use the tile's values if they are non-zero
+                    # This simple approach works well for most cases
+                    self._merge_tile_mask(
+                        output_mask,
+                        tile_mask_array,
+                        x_start,
+                        y_start,
+                        x_end,
+                        y_end,
+                        overlap,
+                        tile_x,
+                        tile_y,
+                        n_tiles_x,
+                        n_tiles_y,
+                    )
+
+            except Exception as e:
+                if verbose:
+                    print(f"Warning: Failed to process tile ({tile_x}, {tile_y}): {e}")
+                continue
+
+        # Clear GPU memory
+        self.masks = None
+        self.boxes = None
+        self.scores = None
+        if hasattr(self, "inference_state"):
+            self.inference_state = None
+        # Additionally clear PyTorch CUDA cache, if available, to free GPU memory
+        try:
+            import torch
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+        except ImportError:
+            # If torch is not installed, skip CUDA cache clearing
+            pass
+        # Update output profile
+        profile.update(
+            {
+                "count": 1,
+                "dtype": dtype,
+                "compress": "deflate",
+            }
+        )
+
+        # Save the output
+        with rasterio.open(output, "w", **profile) as dst:
+            dst.write(output_mask, 1)
+
+        if verbose:
+            print(f"Saved mask to {output}")
+            print(f"Total objects found: {total_objects}")
+
+        # Store result for potential visualization
+        self.objects = output_mask
+        self.source = source
+
+        return output
+
+    def _merge_tile_mask(
+        self,
+        output_mask: np.ndarray,
+        tile_mask: np.ndarray,
+        x_start: int,
+        y_start: int,
+        x_end: int,
+        y_end: int,
+        overlap: int,
+        tile_x: int,
+        tile_y: int,
+        n_tiles_x: int,
+        n_tiles_y: int,
+    ) -> None:
+        """
+        Merge a tile mask into the output mask, handling overlapping regions.
+
+        For overlapping regions, this uses a blending approach where we prioritize
+        the current tile's mask in the non-overlapping core region, and for the
+        overlap region, we keep existing values unless they are zero.
+
+        Args:
+            output_mask: The full output mask array.
+            tile_mask: The mask from the current tile.
+            x_start, y_start: Start coordinates of the tile in the output.
+            x_end, y_end: End coordinates of the tile in the output.
+            overlap: The overlap size.
+            tile_x, tile_y: Tile indices.
+            n_tiles_x, n_tiles_y: Total number of tiles in each direction.
+        """
+        tile_height = y_end - y_start
+        tile_width = x_end - x_start
+
+        # Calculate the core region (non-overlapping part)
+        # The overlap should be split between adjacent tiles
+        left_overlap = overlap // 2 if tile_x > 0 else 0
+        right_overlap = overlap // 2 if tile_x < n_tiles_x - 1 else 0
+        top_overlap = overlap // 2 if tile_y > 0 else 0
+        bottom_overlap = overlap // 2 if tile_y < n_tiles_y - 1 else 0
+
+        # Core region in tile coordinates
+        core_x_start = left_overlap
+        core_x_end = tile_width - right_overlap
+        core_y_start = top_overlap
+        core_y_end = tile_height - bottom_overlap
+
+        # Copy core region (always overwrite)
+        out_y_start = y_start + core_y_start
+        out_y_end = y_start + core_y_end
+        out_x_start = x_start + core_x_start
+        out_x_end = x_start + core_x_end
+
+        output_mask[out_y_start:out_y_end, out_x_start:out_x_end] = tile_mask[
+            core_y_start:core_y_end, core_x_start:core_x_end
+        ]
+
+        # Handle overlap regions - only update if output is zero
+        # Top overlap
+        if top_overlap > 0:
+            region = output_mask[y_start : y_start + top_overlap, out_x_start:out_x_end]
+            tile_region = tile_mask[0:top_overlap, core_x_start:core_x_end]
+            mask = region == 0
+            region[mask] = tile_region[mask]
+
+        # Bottom overlap
+        if bottom_overlap > 0:
+            region = output_mask[out_y_end:y_end, out_x_start:out_x_end]
+            tile_region = tile_mask[core_y_end:tile_height, core_x_start:core_x_end]
+            mask = region == 0
+            region[mask] = tile_region[mask]
+
+        # Left overlap
+        if left_overlap > 0:
+            region = output_mask[
+                out_y_start:out_y_end, x_start : x_start + left_overlap
+            ]
+            tile_region = tile_mask[core_y_start:core_y_end, 0:left_overlap]
+            mask = region == 0
+            region[mask] = tile_region[mask]
+
+        # Right overlap
+        if right_overlap > 0:
+            region = output_mask[out_y_start:out_y_end, out_x_end:x_end]
+            tile_region = tile_mask[core_y_start:core_y_end, core_x_end:tile_width]
+            mask = region == 0
+            region[mask] = tile_region[mask]
+
+        # Corner overlaps
+        # Top-left
+        if top_overlap > 0 and left_overlap > 0:
+            region = output_mask[
+                y_start : y_start + top_overlap, x_start : x_start + left_overlap
+            ]
+            tile_region = tile_mask[0:top_overlap, 0:left_overlap]
+            mask = region == 0
+            region[mask] = tile_region[mask]
+
+        # Top-right
+        if top_overlap > 0 and right_overlap > 0:
+            region = output_mask[y_start : y_start + top_overlap, out_x_end:x_end]
+            tile_region = tile_mask[0:top_overlap, core_x_end:tile_width]
+            mask = region == 0
+            region[mask] = tile_region[mask]
+
+        # Bottom-left
+        if bottom_overlap > 0 and left_overlap > 0:
+            region = output_mask[out_y_end:y_end, x_start : x_start + left_overlap]
+            tile_region = tile_mask[core_y_end:tile_height, 0:left_overlap]
+            mask = region == 0
+            region[mask] = tile_region[mask]
+
+        # Bottom-right
+        if bottom_overlap > 0 and right_overlap > 0:
+            region = output_mask[out_y_end:y_end, out_x_end:x_end]
+            tile_region = tile_mask[core_y_end:tile_height, core_x_end:tile_width]
+            mask = region == 0
+            region[mask] = tile_region[mask]
 
     def generate_masks_by_boxes(
         self,
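To make the overlap handling concrete, below is a small standalone NumPy sketch (not package code) of the rule _merge_tile_mask applies along a shared edge: the non-overlapping core of each tile always overwrites the output, while pixels in the overlap strip are filled only where the output is still zero, so IDs written by an earlier tile take precedence there.

```python
# Standalone illustration (not package code) of the _merge_tile_mask overlap rule.
import numpy as np

output = np.zeros((4, 6), dtype=np.uint32)
output[:, :4] = 1                   # an earlier tile already wrote object ID 1 here

tile_b = np.full((4, 4), 2, dtype=np.uint32)  # the next tile covers columns 2..5 with ID 2
core = tile_b[:, 2:]                # non-overlapping core of the new tile
output[:, 4:6] = core               # the core always overwrites

overlap_out = output[:, 2:4]        # columns shared with the earlier tile (a view)
overlap_tile = tile_b[:, :2]
fill = overlap_out == 0             # fill only positions that are still empty
overlap_out[fill] = overlap_tile[fill]

print(output)                       # earlier IDs win in the overlap; ID 2 fills the rest
```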
@@ -2294,7 +2722,9 @@ class SamGeo3:
         # Use the same color generation as the original method for consistency
         COLORS = generate_colors(n_colors=128, n_samples=5000)
         # Convert from 0-1 float RGB to 0-255 int RGB for OpenCV
-        colors_rgb = [
+        colors_rgb = [
+            (int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)) for c in COLORS
+        ]
 
         # Create overlay for all masks
         overlay = np.zeros((h, w, 3), dtype=np.float32)
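Putting the samgeo3.py changes together, a call looks much like the docstring example above; the hedged sketch below adds a size filter and notes the new quiet flag. Paths are placeholders, and the snippet assumes an environment where the Meta backend is installed.

```python
# Sketch based on the generate_masks_tiled docstring; file paths are placeholders.
from samgeo.samgeo3 import SamGeo3

sam = SamGeo3(backend="meta")
sam.generate_masks_tiled(
    source="large_satellite_image.tif",  # processed tile by tile to limit GPU memory use
    prompt="building",
    output="buildings_mask.tif",         # single-band GeoTIFF, one unique ID per object
    tile_size=1024,
    overlap=128,                         # overlap strips are reconciled by _merge_tile_mask
    min_size=100,                        # drop masks smaller than 100 pixels
)

# generate_masks() now accepts quiet=True to suppress the "Found N objects."
# messages; generate_masks_tiled sets it internally for every tile.
```

The default uint32 output leaves room for over four billion object IDs at the cost of a larger raster; pass dtype="uint16" or dtype="uint8" when the expected object count is small.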
{segment_geospatial-1.1.0.dist-info → segment_geospatial-1.2.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: segment-geospatial
-Version: 1.
+Version: 1.2.0
 Summary: Meta AI' Segment Anything Model (SAM) for Geospatial Data.
 Author-email: Qiusheng Wu <giswqs@gmail.com>
 License: MIT license
@@ -29,6 +29,7 @@ Requires-Dist: pyproj
 Requires-Dist: rasterio
 Requires-Dist: segment_anything
 Requires-Dist: shapely
+Requires-Dist: smoothify
 Requires-Dist: torch
 Requires-Dist: torchvision
 Requires-Dist: tqdm
{segment_geospatial-1.1.0.dist-info → segment_geospatial-1.2.0.dist-info}/RECORD
CHANGED
@@ -1,16 +1,16 @@
-samgeo/__init__.py,sha256=
+samgeo/__init__.py,sha256=S1K0lIKUWsZb7mnBnLCV9qA6IpZjM9QjFQjRxO3ti6Y,245
 samgeo/caption.py,sha256=Ttn9KwIEiz4n-jJU2d4o0HCkloo_6JSDev272yULWbo,20053
-samgeo/common.py,sha256=
+samgeo/common.py,sha256=E1dvoWP2B-qf6LULH0mVHcnyj5OV6G64tilZFtvGEqs,149721
 samgeo/fast_sam.py,sha256=iFAaY4XXTtNCSSFCYFfljfWbLXZlMO4ZJQ5BsTmoeT8,10964
 samgeo/fer.py,sha256=gW4kVhxDGZlU48Fl46O45svtVmRpg13fyIXhJBS3Mi8,34652
 samgeo/hq_sam.py,sha256=qmwApENccnKXIqnHDmzA-_c_XvaUwCruVz4zWgF_sA4,34861
 samgeo/samgeo.py,sha256=ZLDGlaCYHTLFod-2W8OoIXRFLRRP1lpsBjLlIv_hwgU,37417
 samgeo/samgeo2.py,sha256=yMZwAh89pomAGNegtIVWJPCpHyJYLUOw3MFnaZZg_Ec,68896
-samgeo/samgeo3.py,sha256=
+samgeo/samgeo3.py,sha256=R3Vvf6cRjW6gD622crE2zVgzLdmYqlUmIYcji0qhXLo,203736
 samgeo/text_sam.py,sha256=nFtwEoyr_jDUrjoz19aEgHuUU9occFrcwbiHDIrEeRs,27006
 samgeo/utmconv.py,sha256=hFa7WYdZ83vJh-QX8teuu-xZEMJxZ9b795-Pwaq_rYs,5816
-segment_geospatial-1.
-segment_geospatial-1.
-segment_geospatial-1.
-segment_geospatial-1.
-segment_geospatial-1.
+segment_geospatial-1.2.0.dist-info/licenses/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
+segment_geospatial-1.2.0.dist-info/METADATA,sha256=_3CdjEGM41gkq6rncDAv4mY-TnU1hovjk__LI4pXOzs,15628
+segment_geospatial-1.2.0.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
+segment_geospatial-1.2.0.dist-info/top_level.txt,sha256=qmuxRrJ2s8MYWBdoBivlvEHHQkmhqLlXWqOyzEGP9XA,7
+segment_geospatial-1.2.0.dist-info/RECORD,,
{segment_geospatial-1.1.0.dist-info → segment_geospatial-1.2.0.dist-info}/WHEEL
File without changes
{segment_geospatial-1.1.0.dist-info → segment_geospatial-1.2.0.dist-info}/licenses/LICENSE
File without changes
{segment_geospatial-1.1.0.dist-info → segment_geospatial-1.2.0.dist-info}/top_level.txt
File without changes