geoai-py 0.10.0__py2.py3-none-any.whl → 0.11.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +2 -1
- geoai/dinov3.py +1146 -0
- geoai/geoai.py +11 -0
- geoai/map_widgets.py +174 -0
- geoai/train.py +8 -2
- geoai/utils.py +13 -1
- {geoai_py-0.10.0.dist-info → geoai_py-0.11.1.dist-info}/METADATA +1 -1
- geoai_py-0.11.1.dist-info/RECORD +21 -0
- geoai_py-0.10.0.dist-info/RECORD +0 -19
- {geoai_py-0.10.0.dist-info → geoai_py-0.11.1.dist-info}/WHEEL +0 -0
- {geoai_py-0.10.0.dist-info → geoai_py-0.11.1.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.10.0.dist-info → geoai_py-0.11.1.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.10.0.dist-info → geoai_py-0.11.1.dist-info}/top_level.txt +0 -0
geoai/dinov3.py
ADDED
@@ -0,0 +1,1146 @@
|
|
1
|
+
"""DINOv3 module for patch similarity analysis with GeoTIFF support.
|
2
|
+
|
3
|
+
This module provides tools for computing patch similarity using DINOv3 features
|
4
|
+
on geospatial imagery stored in GeoTIFF format.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import json
|
8
|
+
import math
|
9
|
+
import os
|
10
|
+
import sys
|
11
|
+
from typing import Tuple, Optional, Dict, List, Union
|
12
|
+
|
13
|
+
import numpy as np
|
14
|
+
from PIL import Image
|
15
|
+
import torch
|
16
|
+
import torch.nn.functional as F
|
17
|
+
import torchvision.transforms as transforms
|
18
|
+
import rasterio
|
19
|
+
from rasterio.windows import Window
|
20
|
+
from rasterio.io import DatasetReader
|
21
|
+
import matplotlib.pyplot as plt
|
22
|
+
import matplotlib.patches as patches
|
23
|
+
|
24
|
+
from huggingface_hub import hf_hub_download
|
25
|
+
|
26
|
+
from .utils import get_device, coords_to_xy, dict_to_image, dict_to_rioxarray
|
27
|
+
|
28
|
+
|
29
|
+
class DINOv3GeoProcessor:
|
30
|
+
"""DINOv3 processor with GeoTIFF input/output support.
|
31
|
+
https://github.com/facebookresearch/dinov3
|
32
|
+
"""
|
33
|
+
|
34
|
+
def __init__(
    self,
    model_name: str = "dinov3_vitl16",
    weights_path: Optional[str] = None,
    device: Optional[torch.device] = None,
):
    """Create the processor: resolve the DINOv3 code location, load the model
    and build the image transform.

    The DINOv3 code is loaded through ``torch.hub``. By default it comes from
    the ``facebookresearch/dinov3`` GitHub repository; if the environment
    variable ``DINOV3_LOCATION`` is set, its value is used instead, treated as
    a local source and appended to ``sys.path`` when not already present.

    Args:
        model_name: Name of the DINOv3 model. Can be "dinov3_vits16",
            "dinov3_vits16plus", "dinov3_vitb16", "dinov3_vitl16",
            "dinov3_vith16plus", "dinov3_vit7b16", "dinov3_convnext_tiny",
            "dinov3_convnext_small", "dinov3_convnext_base" or
            "dinov3_convnext_large".
            See https://github.com/facebookresearch/dinov3 for more details.
        weights_path: Path to model weights (optional); forwarded to
            ``_load_model``.
        device: Torch device to use. Falls back to ``get_device()`` when
            omitted.
    """
    github_repo = "facebookresearch/dinov3"

    # An explicit DINOV3_LOCATION overrides the GitHub default.
    location = os.getenv("DINOV3_LOCATION")
    if location is None:
        location = github_repo
    is_local = location != github_repo

    self.dinov3_location = location
    self.dinov3_source = "local" if is_local else "github"

    self.device = device or get_device()
    self.model_name = model_name

    # Make a local checkout importable.
    if is_local and location not in sys.path:
        sys.path.append(location)

    # Load the model and cache the attributes used for patch arithmetic.
    self.model = self._load_model(weights_path)
    self.patch_size = self.model.patch_size
    self.embed_dim = self.model.embed_dim

    # Satellite imagery normalization (SAT-493M statistics).
    sat_normalize = transforms.Normalize(
        mean=(0.430, 0.411, 0.296),
        std=(0.213, 0.156, 0.143),
    )
    self.transform = transforms.Compose([transforms.ToTensor(), sat_normalize])
|
89
|
+
|
90
|
+
def _download_model_from_hf(
    self, model_path: Optional[str] = None, repo_id: Optional[str] = None
) -> str:
    """
    Download a DINOv3 weights file from Hugging Face.

    Args:
        model_path: Filename of the weights inside the repository.
            Defaults to "dinov3_vitl16_sat493m.pth".
        repo_id: Hugging Face repository ID. Defaults to "giswqs/geoai".

    Returns:
        Path to the downloaded model file, as returned by
        ``hf_hub_download``.

    Raises:
        Exception: Any download error is re-raised unchanged after a hint
            has been printed.
    """
    repo = "giswqs/geoai" if repo_id is None else repo_id
    filename = "dinov3_vitl16_sat493m.pth" if model_path is None else model_path

    try:
        return hf_hub_download(repo_id=repo, filename=filename)
    except Exception as e:
        print(f"Error downloading model from Hugging Face: {e}")
        print("Please specify a local model path or ensure internet connectivity.")
        raise
|
121
|
+
|
122
|
+
def _load_model(self, weights_path: Optional[str] = None) -> torch.nn.Module:
    """Load the DINOv3 architecture via torch.hub and apply a weights file.

    Args:
        weights_path: Path to a local state-dict file. When it is omitted or
            does not exist on disk, the default weights are downloaded with
            ``_download_model_from_hf``. A warning is emitted if a path was
            given but is missing, so the fallback is not silent.

    Returns:
        The model, moved to ``self.device`` and switched to eval mode.

    Raises:
        RuntimeError: If downloading, building or loading the model fails;
            the original exception is chained as the cause.
    """
    import warnings

    if weights_path and not os.path.exists(weights_path):
        warnings.warn(
            f"weights_path '{weights_path}' does not exist; "
            "downloading the default weights from Hugging Face instead.",
            stacklevel=2,
        )

    try:
        if not (weights_path and os.path.exists(weights_path)):
            # No usable local file: download the default weights.
            weights_path = self._download_model_from_hf()

        model = torch.hub.load(
            repo_or_dir=self.dinov3_location,
            model=self.model_name,
            source=self.dinov3_source,
        )

        # NOTE: torch.load unpickles the file; only load weights from
        # sources you trust.
        state_dict = torch.load(weights_path, map_location=self.device)
        # strict=False: keys that do not match the architecture are ignored.
        model.load_state_dict(state_dict, strict=False)

        model = model.to(self.device)
        model.eval()
        return model
    except Exception as e:
        raise RuntimeError(f"Failed to load DINOv3 model: {e}") from e
|
152
|
+
|
153
|
+
def load_regular_image(
    self,
    image_path: str,
) -> Tuple[np.ndarray, dict]:
    """Load regular image file (PNG, JPG, etc.) as a band-first RGB array.

    Args:
        image_path: Path to image file

    Returns:
        Tuple of (image array, metadata). The array is uint8 in (C, H, W)
        order with C == 3. The metadata mimics the structure produced by
        ``load_geotiff`` but carries no georeferencing: ``crs`` and
        ``transform`` are None and ``bounds`` is ``(0, 0, width, height)``.

    Raises:
        RuntimeError: If the file cannot be opened or decoded; the original
            exception is chained as the cause.
    """
    try:
        # The context manager releases the file handle as soon as the pixels
        # have been decoded by convert().
        with Image.open(image_path) as img:
            rgb = np.array(img.convert("RGB"))  # (H, W, C)

        # Convert to (C, H, W) format to match GeoTIFF format
        data = np.transpose(rgb, (2, 0, 1)).astype(np.uint8)

        height, width = rgb.shape[:2]
        metadata = {
            "profile": {
                "driver": "PNG",
                "dtype": "uint8",
                "nodata": None,
                "width": width,
                "height": height,
                "count": 3,
                "crs": None,
                "transform": None,
            },
            "crs": None,
            "transform": None,
            "bounds": (0, 0, width, height),
        }

        return data, metadata

    except Exception as e:
        raise RuntimeError(f"Failed to load image {image_path}: {e}") from e
|
197
|
+
|
198
|
+
def load_geotiff(
    self,
    source: Union[str, DatasetReader],
    window: Optional[Window] = None,
    bands: Optional[List[int]] = None,
) -> Tuple[np.ndarray, dict]:
    """Load GeoTIFF file.

    Args:
        source: Path to GeoTIFF file (str) or an open rasterio.DatasetReader.
            A dataset passed in by the caller is left open; a path opened
            here is always closed again.
        window: Rasterio window for reading subset
        bands: List of bands to read (1-indexed). An empty list or None
            reads all bands.

    Returns:
        Tuple of (image array, metadata). ``metadata`` has the keys
        ``profile``, ``crs``, ``transform`` and ``bounds``. The profile's
        ``count`` matches the bands actually read, and with a window its
        ``height``, ``width`` and ``transform`` describe the window.

    Raises:
        TypeError: If ``source`` is neither a str nor a DatasetReader.
    """
    if isinstance(source, str):
        src = rasterio.open(source)
        should_close = True  # we opened it, so we close it
    elif isinstance(source, DatasetReader):
        src = source
        should_close = False
    else:
        raise TypeError("source must be a str path or a rasterio.DatasetReader")

    try:
        # Read specified bands or all bands
        if bands:
            data = src.read(bands, window=window)
        else:
            data = src.read(window=window)

        profile = src.profile.copy()
        # Keep the profile in sync with what was read, so a band subset is
        # not described with the full band count of the source.
        profile["count"] = data.shape[0] if data.ndim == 3 else 1

        has_window = window is not None
        if has_window:
            profile.update(
                {
                    "height": window.height,
                    "width": window.width,
                    "transform": src.window_transform(window),
                }
            )

        metadata = {
            "profile": profile,
            "crs": src.crs,
            "transform": profile["transform"],
            "bounds": (
                rasterio.windows.bounds(window, src.transform)
                if has_window
                else src.bounds
            ),
        }
    finally:
        if should_close:
            src.close()

    return data, metadata
|
257
|
+
|
258
|
+
def load_image(
    self,
    source: Union[str, DatasetReader],
    window: Optional[Window] = None,
    bands: Optional[List[int]] = None,
) -> Tuple[np.ndarray, dict]:
    """Load an image, dispatching to the GeoTIFF or the regular-image loader.

    A string path is probed with rasterio. It is read with ``load_geotiff``
    when the dataset has a CRS or the file extension is ``tif``/``tiff``;
    otherwise, or when rasterio cannot open it, ``load_regular_image`` is
    used. An already open dataset always goes to ``load_geotiff``.

    Args:
        source: Path to image file (str) or an open rasterio.DatasetReader
        window: Rasterio window for reading subset (only applies to GeoTIFF)
        bands: List of bands to read (only applies to GeoTIFF)

    Returns:
        Tuple of (image array, metadata)

    Raises:
        TypeError: If ``source`` is neither a str nor a DatasetReader.
    """
    # Already opened rasterio dataset
    if isinstance(source, DatasetReader):
        return self.load_geotiff(source, window, bands)

    if not isinstance(source, str):
        raise TypeError("source must be a str path or a rasterio.DatasetReader")

    try:
        with rasterio.open(source) as src:
            georeferenced = src.crs is not None
            extension = source.lower().rsplit(".", 1)[-1]
            # TIFFs without a CRS are still read through rasterio.
            if georeferenced or extension in ("tif", "tiff"):
                return self.load_geotiff(source, window, bands)
            return self.load_regular_image(source)
    except (rasterio.RasterioIOError, rasterio.errors.RasterioIOError):
        # If rasterio fails, try as regular image
        return self.load_regular_image(source)
|
298
|
+
|
299
|
+
def save_geotiff(
    self, data: np.ndarray, output_path: str, metadata: dict, dtype: str = "float32"
) -> None:
    """Write an array to disk with rasterio, reusing a source profile.

    The profile stored under ``metadata["profile"]`` is copied and its
    ``dtype``, ``count``, ``height`` and ``width`` are overwritten to match
    ``data`` before the file is created.

    Args:
        data: Array to save; 2-D arrays are written as band 1, other arrays
            are passed to ``write`` as-is.
        output_path: Output file path
        metadata: Metadata from original file (must contain "profile")
        dtype: Output data type
    """
    if data.ndim >= 2:
        out_height, out_width = data.shape[-2], data.shape[-1]
    else:
        out_height, out_width = data.shape[0], 1
    band_count = data.shape[0] if data.ndim == 3 else 1

    profile = metadata["profile"].copy()
    profile.update(
        dtype=dtype,
        count=band_count,
        height=out_height,
        width=out_width,
    )

    with rasterio.open(output_path, "w", **profile) as dst:
        if data.ndim == 2:
            dst.write(data, 1)
        else:
            dst.write(data)
|
325
|
+
|
326
|
+
def save_similarity_as_image(
    self, similarity_data: np.ndarray, output_path: str, colormap: str = "turbo"
) -> None:
    """Render a similarity array through a colormap and save it as an image.

    Args:
        similarity_data: 2D similarity array
        output_path: Output file path
        colormap: Matplotlib colormap name
    """
    # The colormap call yields RGBA; keep RGB only and scale to 8-bit.
    rgba = plt.get_cmap(colormap)(similarity_data)
    rgb_bytes = (rgba[..., :3] * 255).astype(np.uint8)

    Image.fromarray(rgb_bytes).save(output_path)
|
348
|
+
|
349
|
+
def preprocess_image_for_dinov3(
    self,
    data: np.ndarray,
    target_size: int = 896,
    normalize_percentile: bool = True,
) -> Image.Image:
    """Preprocess image data for DINOv3.

    The array is reduced or padded to three bands, scaled to 0-1, converted
    to an 8-bit RGB PIL image and resized with ``resize_to_patch_aligned``.

    Args:
        data: Input array (C, H, W) or (H, W). With more than three bands
            only the first three are used. With fewer than three, the last
            available band is repeated, so a single band becomes a gray
            RGB image.
        target_size: Target size for resizing
        normalize_percentile: If True, stretch each band between its 2nd and
            98th percentile; otherwise apply one global min-max scaling.

    Returns:
        PIL Image ready for DINOv3
    """
    # Handle different input shapes
    if data.ndim == 2:
        data = data[np.newaxis, :, :]  # Add channel dimension
    elif data.ndim == 3 and data.shape[0] > 3:
        data = data[:3, :, :]

    if normalize_percentile:
        normalized = np.zeros(data.shape, dtype=np.float32)
        for i, band in enumerate(data):
            p2, p98 = np.percentile(band, [2, 98])
            span = p98 - p2
            # A constant band has no range to stretch; leave it at zero
            # instead of dividing by zero.
            if span > 0:
                normalized[i] = np.clip((band - p2) / span, 0, 1)
    else:
        # Work in float64 so narrow integer dtypes cannot overflow.
        as_float = data.astype(np.float64)
        lowest = as_float.min()
        span = as_float.max() - lowest
        if span > 0:
            normalized = (as_float - lowest) / span
        else:
            normalized = np.zeros(data.shape, dtype=np.float32)

    # PIL needs exactly three channels for an RGB image.
    missing = 3 - normalized.shape[0]
    if missing > 0:
        padding = np.repeat(normalized[-1:], missing, axis=0)
        normalized = np.concatenate([normalized, padding], axis=0)

    # Transpose to HWC format and convert to uint8
    hwc = np.transpose(normalized, (1, 2, 0))
    image = Image.fromarray((hwc * 255).astype(np.uint8))

    # Resize to patch-aligned dimensions
    return self.resize_to_patch_aligned(image, target_size)
|
400
|
+
|
401
|
+
def resize_to_patch_aligned(
    self, image: Image.Image, target_size: int = 896
) -> Image.Image:
    """Resize an image so both sides are multiples of ``self.patch_size``.

    The shorter side is set to ``target_size`` and the longer side is scaled
    by the same factor (truncated to an integer). Both sides are then rounded
    up to the next multiple of the patch size before resampling with LANCZOS.
    """
    width, height = image.size
    patch = self.patch_size

    if width > height:
        scaled_h = target_size
        scaled_w = int((width * target_size) / height)
    else:
        scaled_w = target_size
        scaled_h = int((height * target_size) / width)

    # Integer ceiling division: round up to a whole number of patches.
    aligned_h = -(-scaled_h // patch) * patch
    aligned_w = -(-scaled_w // patch) * patch

    return image.resize((aligned_w, aligned_h), Image.Resampling.LANCZOS)
|
420
|
+
|
421
|
+
def extract_features(self, image: Image.Image) -> Tuple[torch.Tensor, int, int]:
    """Extract patch features from image.

    Args:
        image: PIL image. A file path (str) or a numpy array is also
            accepted and converted to a PIL image first. PIL images that are
            not in RGB mode are converted to RGB, because the normalization
            in ``self.transform`` uses three-channel statistics.

    Returns:
        Tuple ``(features, h_patches, w_patches)`` where ``features`` has
        shape (h_patches, w_patches, embed_dim).
    """
    if isinstance(image, str):
        # Decode inside the context manager so the file handle is released.
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # PIL images expose their mode as a string; other objects pass through.
    mode = getattr(image, "mode", None)
    if isinstance(mode, str) and mode != "RGB":
        image = image.convert("RGB")

    # Transform image and add the batch dimension
    img_tensor = self.transform(image).unsqueeze(0).to(self.device)

    with torch.no_grad():
        # Features from the last layer: [1, embed_dim, h_patches, w_patches]
        last_layer = self.model.get_intermediate_layers(
            img_tensor, n=1, reshape=True, norm=True
        )[0]

    # Rearrange to [h_patches, w_patches, embed_dim]
    features = last_layer.squeeze(0).permute(1, 2, 0)
    h_patches, w_patches = features.shape[:2]

    return features, h_patches, w_patches
|
446
|
+
|
447
|
+
def compute_patch_similarity(
    self, features: torch.Tensor, patch_x: int, patch_y: int
) -> torch.Tensor:
    """Compute cosine similarity between selected patch and all patches.

    Args:
        features: Feature tensor of shape (h_patches, w_patches, embed_dim).
            It does not need to be contiguous.
        patch_x: Column index of the query patch.
        patch_y: Row index of the query patch.

    Returns:
        Tensor of shape (h_patches, w_patches) with the cosine similarity to
        the query patch, rescaled from [-1, 1] to [0, 1].
    """
    h_patches, w_patches, embed_dim = features.shape

    # Get query patch feature, shape [embed_dim]
    query_feature = features[patch_y, patch_x]

    # reshape (not view) so permuted or sliced tensors are accepted:
    # view raises when the memory layout cannot be flattened without a copy.
    all_features = features.reshape(-1, embed_dim)

    similarities = F.cosine_similarity(
        query_feature.unsqueeze(0),  # [1, embed_dim]
        all_features,  # [h_patches * w_patches, embed_dim]
        dim=1,
    )

    # Back to the patch grid, then normalize to 0-1 range
    similarities = similarities.reshape(h_patches, w_patches)
    return (similarities + 1) / 2
|
475
|
+
|
476
|
+
def compute_similarity(
    self,
    source: str = None,
    features: torch.Tensor = None,
    query_coords: Tuple[float, float] = None,
    output_dir: str = None,
    window: Optional[Window] = None,
    bands: Optional[List[int]] = None,
    target_size: int = 896,
    save_features: bool = False,
    coord_crs: str = None,
    use_interpolation: bool = True,
) -> Dict[str, np.ndarray]:
    """Process an image for patch similarity analysis and save the result.

    The image is loaded and preprocessed, patch features are extracted
    (unless supplied), the query point is mapped to a patch, and the cosine
    similarity of every patch to that patch is computed. The similarity map
    is resized to the raw image size and written to ``output_dir`` as a
    GeoTIFF when the input has a CRS, otherwise as a PNG.

    Args:
        source: Path to input GeoTIFF or rasterio dataset (required)
        features: Pre-extracted features (h_patches, w_patches, embed_dim)
        query_coords: (x, y) in raw image pixel space, or coordinates in
            ``coord_crs`` when that is given (required)
        output_dir: Output directory for results (required; created if
            missing)
        window: Optional window for reading subset
        bands: Optional list of bands to use
        target_size: Target size for processing
        save_features: Whether to also save the features (.npy) and a JSON
            metadata file
        coord_crs: Coordinate CRS of the query coordinates
        use_interpolation: If True the similarity map is resized with
            interpolation, otherwise with nearest neighbour

    Returns:
        Dictionary with the keys ``similarities`` (patch-grid array),
        ``patch_coords``, ``patch_grid_size``, ``image_size``, ``metadata``,
        ``image_dict`` (crs, bounds and the resized map) and
        ``output_paths``. In ``output_paths`` the ``metadata`` and
        ``features`` entries are None unless ``save_features`` is True.

    Raises:
        TypeError: If ``source``, ``query_coords`` or ``output_dir`` is None.
        ValueError: If ``query_coords`` does not have exactly two elements.
    """
    # The signature defaults are None only for keyword convenience; fail
    # early with a clear message instead of deep inside os/rasterio.
    if source is None:
        raise TypeError("source is required")
    if query_coords is None:
        raise TypeError("query_coords is required")
    if output_dir is None:
        raise TypeError("output_dir is required")
    if len(query_coords) != 2:
        raise ValueError("query_coords must contain exactly two values (x, y)")

    os.makedirs(output_dir, exist_ok=True)

    # Load image (GeoTIFF or regular image)
    data, metadata = self.load_image(source, window, bands)
    raw_img_h, raw_img_w = data.shape[-2], data.shape[-1]

    # Preprocess for DINOv3 (needed for the resized size even when
    # features are supplied)
    image = self.preprocess_image_for_dinov3(data, target_size)

    if features is None:
        features, h_patches, w_patches = self.extract_features(image)
    else:
        h_patches, w_patches = features.shape[:2]

    img_w, img_h = image.size

    # NOTE(review): with a window, coords_to_xy presumably returns
    # full-image pixel coordinates while raw_img_w/h describe the window
    # only - confirm against coords_to_xy before combining the two options.
    if coord_crs is not None:
        [query_coords] = coords_to_xy(source, [query_coords], coord_crs)

    # Raw pixel space -> resized image space -> patch grid
    x_pixel = math.floor(query_coords[0] / raw_img_w * img_w)
    y_pixel = math.floor(query_coords[1] / raw_img_h * img_h)
    query_coords = [x_pixel, y_pixel]

    patch_x = math.floor((x_pixel / img_w) * w_patches)
    patch_y = math.floor((y_pixel / img_h) * h_patches)

    # Clamp to valid range
    patch_x = max(0, min(w_patches - 1, patch_x))
    patch_y = max(0, min(h_patches - 1, patch_y))

    similarities = self.compute_patch_similarity(features, patch_x, patch_y)
    sim_array = similarities.cpu().numpy()

    results = {
        "similarities": sim_array,
        "patch_coords": (patch_x, patch_y),
        "patch_grid_size": (h_patches, w_patches),
        "image_size": (img_w, img_h),
        "metadata": metadata,
    }

    def _resize_to_raw(sim: np.ndarray) -> np.ndarray:
        """Resize the patch-grid map to (raw_img_h, raw_img_w)."""
        try:
            from skimage.transform import resize
        except ImportError:
            # Fallback to PIL if scikit-image is not available; the map is
            # quantized to 8 bits for the resize.
            resample = (
                Image.Resampling.LANCZOS
                if use_interpolation
                else Image.Resampling.NEAREST
            )
            sim_pil = Image.fromarray((sim * 255).astype(np.uint8))
            sim_pil = sim_pil.resize((raw_img_w, raw_img_h), resample)
            return np.array(sim_pil, dtype=np.float32) / 255.0

        if use_interpolation:
            return resize(
                sim,
                (raw_img_h, raw_img_w),
                preserve_range=True,
                anti_aliasing=True,
            )
        return resize(
            sim,
            (raw_img_h, raw_img_w),
            preserve_range=True,
            anti_aliasing=False,
            order=0,  # Nearest neighbor interpolation
        )

    sim_resized = _resize_to_raw(sim_array)

    # GeoTIFF for georeferenced data, PNG for regular images
    if metadata["crs"] is not None:
        similarity_path = os.path.join(
            output_dir, f"similarity_patch_{patch_x}_{patch_y}.tif"
        )
        self.save_geotiff(
            sim_resized[np.newaxis, :, :],
            similarity_path,
            metadata,
            dtype="float32",
        )
    else:
        similarity_path = os.path.join(
            output_dir, f"similarity_patch_{patch_x}_{patch_y}.png"
        )
        self.save_similarity_as_image(sim_resized, similarity_path)

    results["image_dict"] = {
        "crs": metadata["crs"],
        "bounds": metadata["bounds"],
        "image": sim_resized[np.newaxis, :, :],
    }

    # Both stay None unless save_features is set, so output_paths can
    # always be built.
    features_path = None
    metadata_path = None
    if save_features:
        features_path = os.path.join(
            output_dir, f"features_patch_{patch_x}_{patch_y}.npy"
        )
        np.save(features_path, features.cpu().numpy())

        metadata_dict = {
            # An open dataset is not JSON serializable; store its name.
            "input_path": (
                source
                if isinstance(source, str)
                else getattr(source, "name", str(source))
            ),
            "query_coords": query_coords,
            "patch_coords": (patch_x, patch_y),
            "patch_grid_size": (h_patches, w_patches),
            "image_size": (img_w, img_h),
            "similarity_stats": {
                "max": float(sim_array.max()),
                "min": float(sim_array.min()),
                "mean": float(sim_array.mean()),
                "std": float(sim_array.std()),
            },
        }
        metadata_path = os.path.join(
            output_dir, f"metadata_patch_{patch_x}_{patch_y}.json"
        )
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(metadata_dict, f, indent=2)

    results["output_paths"] = {
        "similarity": similarity_path,
        "metadata": metadata_path,
        "features": features_path,
    }

    return results
|
660
|
+
|
661
|
+
def visualize_similarity(
    self,
    source: str,
    similarity_data: np.ndarray,
    query_coords: Tuple[float, float] = None,
    patch_coords: Tuple[int, int] = None,
    figsize: Tuple[int, int] = (15, 6),
    colormap: str = "turbo",
    alpha: float = 0.7,
    save_path: str = None,
    show_query_point: bool = True,
    overlay: bool = False,
) -> plt.Figure:
    """Visualize original image and similarity map side by side or as overlay.

    Args:
        source: Path to original image
        similarity_data: 2D similarity array; resized to the image size if
            the shapes differ
        query_coords: Query coordinates in pixel space (x, y)
        patch_coords: Patch coordinates (patch_x, patch_y). Currently not
            used by this method; kept for API compatibility.
        figsize: Figure size for the side-by-side layout. In overlay mode a
            square figure with side ``figsize[1]`` is created.
        colormap: Colormap for similarity visualization
        alpha: Transparency for overlay mode
        save_path: Optional path to save the visualization
        show_query_point: Whether to show the query point marker
        overlay: If True, overlay similarity on original image; if False,
            show side by side

    Returns:
        Matplotlib figure object
    """
    data, _ = self.load_image(source)

    # Convert (C, H, W) to a displayable (H, W, 3) or (H, W) array.
    if data.ndim == 3:
        display_img = np.transpose(data[:3], (1, 2, 0))
        if display_img.shape[2] != 3:
            # imshow rejects (H, W, 1) and (H, W, 2); show the first band
            # as a grayscale image instead.
            display_img = display_img[:, :, 0]
    else:
        display_img = data

    def _stretch(band: np.ndarray) -> np.ndarray:
        """2nd-98th percentile stretch to 0-1; constant bands map to 0."""
        p2, p98 = np.percentile(band, [2, 98])
        span = p98 - p2
        if not span > 0:
            return np.zeros(band.shape, dtype=np.float32)
        return np.clip((band - p2) / span, 0, 1).astype(np.float32)

    # Normalize image for display
    if display_img.dtype != np.uint8:
        if display_img.ndim == 3:
            display_img = np.stack(
                [_stretch(display_img[:, :, i]) for i in range(display_img.shape[2])],
                axis=2,
            )
        else:
            display_img = _stretch(display_img)
    else:
        display_img = display_img / 255.0

    # Ensure similarity data matches image dimensions
    if similarity_data.shape != display_img.shape[:2]:
        sim_pil = Image.fromarray((similarity_data * 255).astype(np.uint8))
        sim_pil = sim_pil.resize(
            (display_img.shape[1], display_img.shape[0]), Image.Resampling.LANCZOS
        )
        similarity_data = np.array(sim_pil, dtype=np.float32) / 255.0

    base_cmap = "gray" if display_img.ndim == 2 else None

    if overlay:
        # Single plot with overlay
        fig, ax = plt.subplots(1, 1, figsize=(figsize[1], figsize[1]))
        ax.imshow(display_img, cmap=base_cmap)
        im_sim = ax.imshow(
            similarity_data, cmap=colormap, alpha=alpha, vmin=0, vmax=1
        )
        cbar = fig.colorbar(im_sim, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label("Similarity", rotation=270, labelpad=20)
        ax.set_title("Image with Similarity Overlay")
        marker_axes = [ax]
    else:
        # Side-by-side visualization
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

        ax1.imshow(display_img, cmap=base_cmap)
        ax1.set_title("Original Image")
        ax1.axis("off")

        im_sim = ax2.imshow(similarity_data, cmap=colormap, vmin=0, vmax=1)
        ax2.set_title("Similarity Map")
        ax2.axis("off")

        cbar = fig.colorbar(im_sim, ax=ax2, fraction=0.046, pad=0.04)
        cbar.set_label("Similarity", rotation=270, labelpad=20)
        marker_axes = [ax1, ax2]

    # Mark query point if provided: a white-edged star with a plain red
    # star drawn on top of it.
    if show_query_point and query_coords is not None:
        x, y = query_coords
        for axis in marker_axes:
            axis.plot(
                x,
                y,
                "r*",
                markersize=15,
                markeredgecolor="white",
                markeredgewidth=2,
            )
            axis.plot(x, y, "r*", markersize=12)

    fig.tight_layout()

    # Save through the figure itself so the result does not depend on which
    # figure pyplot considers current.
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")

    return fig
|
814
|
+
|
815
|
+
def visualize_patches(
    self,
    image: Image.Image,
    features: torch.Tensor,
    patch_coords: Tuple[int, int],
    add_text: bool = False,
    figsize: Tuple[int, int] = (12, 8),
    save_path: str = None,
) -> plt.Figure:
    """Draw the patch grid over an image and outline one patch in red.

    Args:
        image: PIL Image
        features: Feature tensor (h_patches, w_patches, embed_dim); only its
            first two dimensions are used, to size the grid
        patch_coords: Selected patch coordinates (patch_x, patch_y)
        add_text: Whether to label the selected patch with its coordinates
        figsize: Figure size
        save_path: Optional path to save visualization

    Returns:
        Matplotlib figure object
    """
    fig, ax = plt.subplots(1, 1, figsize=figsize)

    ax.imshow(image)
    ax.set_title("Image with Patch Grid")
    ax.axis("off")

    img_w, img_h = image.size
    h_patches, w_patches = features.shape[:2]
    patch_x, patch_y = patch_coords

    # Size of one grid cell in image pixels
    cell_w = img_w / w_patches
    cell_h = img_h / h_patches

    # Grid lines, including both outer borders
    grid_style = dict(color="white", alpha=0.3, linewidth=0.5)
    for col in range(w_patches + 1):
        ax.axvline(x=col * cell_w, **grid_style)
    for row in range(h_patches + 1):
        ax.axhline(y=row * cell_h, **grid_style)

    # Outline of the selected patch
    left = patch_x * cell_w
    top = patch_y * cell_h
    ax.add_patch(
        patches.Rectangle(
            (left, top),
            cell_w,
            cell_h,
            linewidth=3,
            edgecolor="red",
            facecolor="none",
        )
    )

    if add_text:
        ax.text(
            left + cell_w / 2,
            top + cell_h / 2,
            f"({patch_x}, {patch_y})",
            color="red",
            fontsize=12,
            ha="center",
            va="center",
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
        )

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")

    return fig
|
894
|
+
|
895
|
+
def create_similarity_overlay(
    self,
    source: str,
    similarity_data: np.ndarray,
    colormap: str = "turbo",
    alpha: float = 0.7,
) -> np.ndarray:
    """Blend a colour-mapped similarity map over the original image.

    The image is loaded with ``self.load_image``, converted to a float RGB
    array in [0, 1], and alpha-blended with ``similarity_data`` rendered
    through a matplotlib colormap.

    Args:
        source: Path to original image
        similarity_data: 2D similarity array. Values are expected to lie in
            [0, 1]; when the array has to be resized, values outside that
            range are clipped first.
        colormap: Name of the matplotlib colormap used for the similarity map
        alpha: Weight of the similarity layer in the blend (0 = image only,
            1 = similarity only)

    Returns:
        Float array of shape (height, width, 3) with values in [0, 1]
    """

    def _stretch(band: np.ndarray) -> np.ndarray:
        """Rescale ``band`` to [0, 1] between its 2nd and 98th percentiles."""
        low, high = np.percentile(band, [2, 98])
        span = high - low
        # A constant band gives span == 0 (and NaN input gives span == NaN);
        # dividing would fill the overlay with NaN/inf, so fall back to black.
        if not span > 0:
            return np.zeros(band.shape, dtype=np.float32)
        return np.clip((band - low) / span, 0, 1)

    # Load original image
    data, _ = self.load_image(source)

    # Convert to display format.
    # NOTE(review): the transpose assumes load_image returns 3D data as
    # (bands, height, width) - confirm against load_image.
    if data.ndim == 3:
        # Keep at most the first three bands.
        display_img = np.transpose(data[:3], (1, 2, 0))
    else:
        display_img = data

    # Normalize image to floats in [0, 1]
    if display_img.dtype == np.uint8:
        base_img = display_img / 255.0
    elif display_img.ndim == 3:
        # Each band is stretched independently.
        base_img = np.zeros_like(display_img, dtype=np.float32)
        for band_index in range(display_img.shape[2]):
            base_img[:, :, band_index] = _stretch(display_img[:, :, band_index])
    else:
        base_img = _stretch(display_img)

    # A 1- or 2-band raster cannot be blended with the 3-channel colour map
    # as-is; show its first band as grayscale instead.
    if base_img.ndim == 3 and base_img.shape[2] != 3:
        base_img = base_img[:, :, 0]

    # Convert grayscale to RGB if needed
    if base_img.ndim == 2:
        base_img = np.stack([base_img] * 3, axis=2)

    # Resize similarity data to match image
    if similarity_data.shape != base_img.shape[:2]:
        # Clip before the 8-bit conversion: out-of-range values would
        # otherwise wrap around when cast to uint8.
        sim_u8 = (np.clip(similarity_data, 0, 1) * 255).astype(np.uint8)
        resized = Image.fromarray(sim_u8).resize(
            (base_img.shape[1], base_img.shape[0]), Image.LANCZOS
        )
        similarity_data = np.array(resized, dtype=np.float32) / 255.0

    # Apply colormap to similarity data
    cmap = plt.get_cmap(colormap)
    colored_similarity = cmap(similarity_data)[:, :, :3]  # Remove alpha channel

    # Blend images
    overlay_img = (1 - alpha) * base_img + alpha * colored_similarity

    return np.clip(overlay_img, 0, 1)
|
962
|
+
|
963
|
+
def batch_similarity_analysis(
    self,
    input_path: str,
    query_points: List[Tuple[float, float]],
    output_dir: str,
    window: Optional[Window] = None,
    bands: Optional[List[int]] = None,
    target_size: int = 896,
) -> List[Dict[str, np.ndarray]]:
    """Run ``compute_similarity`` once per query point.

    Each point writes into its own sub-directory of ``output_dir`` named
    ``point_<index>``, where ``<index>`` is the point's position in
    ``query_points``.

    Args:
        input_path: Path to input GeoTIFF
        query_points: List of (x, y) coordinates
        output_dir: Parent directory for the per-point output directories
        window: Optional window for reading subset
        bands: Optional list of bands to use
        target_size: Target size for processing

    Returns:
        One result dictionary per query point, in input order
    """
    return [
        self.compute_similarity(
            source=input_path,
            query_coords=point,
            output_dir=os.path.join(output_dir, f"point_{index}"),
            window=window,
            bands=bands,
            target_size=target_size,
        )
        for index, point in enumerate(query_points)
    ]
|
999
|
+
|
1000
|
+
|
1001
|
+
def create_similarity_map(
    input_image: str,
    query_coords: Tuple[float, float],
    output_dir: str,
    model_name: str = "dinov3_vitl16",
    weights_path: Optional[str] = None,
    window: Optional[Window] = None,
    bands: Optional[List[int]] = None,
    target_size: int = 896,
    save_features: bool = False,
    coord_crs: str = None,
    use_interpolation: bool = True,
) -> Dict[str, np.ndarray]:
    """Build a processor and compute one similarity map in a single call.

    A new ``DINOv3GeoProcessor`` is created on every call and all remaining
    arguments are handed to its ``compute_similarity`` method unchanged.

    Args:
        input_image: Path to input image file (GeoTIFF, PNG, JPG, etc.)
        query_coords: Query location as (x, y)
        output_dir: Output directory
        model_name: DINOv3 model name
        weights_path: Optional path to model weights
        window: Optional rasterio window (only applies to GeoTIFF)
        bands: Optional list of bands to use (only applies to GeoTIFF)
        target_size: Target size for processing
        save_features: Whether to save extracted features
        coord_crs: Coordinate CRS of the query coordinates (only applies to GeoTIFF)
        use_interpolation: Whether to use interpolation when resizing similarity map

    Returns:
        The dictionary returned by ``compute_similarity``
    """
    # Options passed straight through to compute_similarity.
    similarity_options = dict(
        window=window,
        bands=bands,
        target_size=target_size,
        save_features=save_features,
        coord_crs=coord_crs,
        use_interpolation=use_interpolation,
    )

    geo_processor = DINOv3GeoProcessor(
        model_name=model_name, weights_path=weights_path
    )
    similarity_result = geo_processor.compute_similarity(
        source=input_image,
        query_coords=query_coords,
        output_dir=output_dir,
        **similarity_options,
    )
    return similarity_result
|
1045
|
+
|
1046
|
+
|
1047
|
+
def analyze_image_patches(
    input_image: str,
    query_points: List[Tuple[float, float]],
    output_dir: str,
    model_name: str = "dinov3_vitl16",
    weights_path: Optional[str] = None,
    window: Optional[Window] = None,
    bands: Optional[List[int]] = None,
    target_size: int = 896,
) -> List[Dict[str, np.ndarray]]:
    """Analyze multiple patches in an image file.

    Creates a ``DINOv3GeoProcessor`` and delegates to its
    ``batch_similarity_analysis`` method. ``window``, ``bands`` and
    ``target_size`` are forwarded so that callers of this convenience
    function have the same control as callers of the method itself; their
    defaults match the method's defaults.

    Args:
        input_image: Path to input image file (GeoTIFF, PNG, JPG, etc.)
        query_points: List of query coordinates
        output_dir: Output directory
        model_name: DINOv3 model name
        weights_path: Optional path to model weights
        window: Optional rasterio window for reading a subset
        bands: Optional list of bands to use
        target_size: Target size for processing

    Returns:
        List of result dictionaries
    """
    processor = DINOv3GeoProcessor(model_name=model_name, weights_path=weights_path)

    return processor.batch_similarity_analysis(
        input_image,
        query_points,
        output_dir,
        window=window,
        bands=bands,
        target_size=target_size,
    )
|
1069
|
+
|
1070
|
+
|
1071
|
+
def visualize_similarity_results(
    input_image: str,
    query_coords: Tuple[float, float],
    output_dir: str = None,
    model_name: str = "dinov3_vitl16",
    weights_path: Optional[str] = None,
    figsize: Tuple[int, int] = (15, 6),
    colormap: str = "turbo",
    alpha: float = 0.7,
    save_path: str = None,
    show_query_point: bool = True,
    overlay: bool = False,
    target_size: int = 896,
    coord_crs: str = None,
    use_interpolation: bool = True,
) -> Dict:
    """Compute a similarity map and render it in a single call.

    Runs ``compute_similarity`` followed by ``visualize_similarity`` on a
    freshly created ``DINOv3GeoProcessor`` and stores the resulting figure
    in the returned dictionary under the ``"visualization"`` key.

    Args:
        input_image: Path to input image file (GeoTIFF, PNG, JPG, etc.)
        query_coords: Query location as (x, y)
        output_dir: Output directory for similarity map files. When None, a
            temporary directory with the prefix ``dinov3_similarity_`` is
            created; this function does not remove it afterwards.
        model_name: DINOv3 model name
        weights_path: Optional path to model weights
        figsize: Figure size for visualization
        colormap: Colormap for similarity visualization
        alpha: Transparency for overlay mode
        save_path: Optional path to save the visualization
        show_query_point: Whether to show the query point marker
        overlay: If True, overlay similarity on original image; if False, show side by side
        target_size: Target size for processing
        coord_crs: Coordinate CRS of the query coordinates
        use_interpolation: Whether to use interpolation when resizing similarity map

    Returns:
        The ``compute_similarity`` result dictionary with the matplotlib
        figure added under ``"visualization"``
    """
    geo_processor = DINOv3GeoProcessor(
        model_name=model_name, weights_path=weights_path
    )

    # Fall back to a scratch directory when the caller did not choose one.
    if output_dir is None:
        import tempfile

        output_dir = tempfile.mkdtemp(prefix="dinov3_similarity_")

    # Step 1: compute the similarity map.
    results = geo_processor.compute_similarity(
        source=input_image,
        query_coords=query_coords,
        output_dir=output_dir,
        target_size=target_size,
        coord_crs=coord_crs,
        use_interpolation=use_interpolation,
    )

    # Step 2: render it. Rendering options are passed through unchanged.
    render_options = dict(
        figsize=figsize,
        colormap=colormap,
        alpha=alpha,
        save_path=save_path,
        show_query_point=show_query_point,
        overlay=overlay,
    )
    results["visualization"] = geo_processor.visualize_similarity(
        source=input_image,
        # Index 0 drops the leading channel dimension of the similarity image.
        similarity_data=results["image_dict"]["image"][0],
        query_coords=query_coords,
        patch_coords=results["patch_coords"],
        **render_options,
    )

    return results
|