geoai-py 0.10.0__py2.py3-none-any.whl → 0.11.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/dinov3.py ADDED
@@ -0,0 +1,1146 @@
1
+ """DINOv3 module for patch similarity analysis with GeoTIFF support.
2
+
3
+ This module provides tools for computing patch similarity using DINOv3 features
4
+ on geospatial imagery stored in GeoTIFF format.
5
+ """
6
+
7
+ import json
8
+ import math
9
+ import os
10
+ import sys
11
+ from typing import Tuple, Optional, Dict, List, Union
12
+
13
+ import numpy as np
14
+ from PIL import Image
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import torchvision.transforms as transforms
18
+ import rasterio
19
+ from rasterio.windows import Window
20
+ from rasterio.io import DatasetReader
21
+ import matplotlib.pyplot as plt
22
+ import matplotlib.patches as patches
23
+
24
+ from huggingface_hub import hf_hub_download
25
+
26
+ from .utils import get_device, coords_to_xy, dict_to_image, dict_to_rioxarray
27
+
28
+
29
+ class DINOv3GeoProcessor:
30
+ """DINOv3 processor with GeoTIFF input/output support.
31
+ https://github.com/facebookresearch/dinov3
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ model_name: str = "dinov3_vitl16",
37
+ weights_path: Optional[str] = None,
38
+ device: Optional[torch.device] = None,
39
+ ):
40
+ """Initialize DINOv3 processor.
41
+
42
+ Args:
43
+ model_name: Name of the DINOv3 model. Can be "dinov3_vits16", "dinov3_vits16plus",
44
+ "dinov3_vitb16", "dinov3_vitl16", "dinov3_vith16plus", "dinov3_vit7b16", "dinov3_convnext_tiny",
45
+ "dinov3_convnext_small", "dinov3_convnext_base", "dinov3_convnext_large",
46
+ "dinov3dinov3_vitl16", and "dinov3_vit7b16".
47
+ See https://github.com/facebookresearch/dinov3 for more details.
48
+ weights_path: Path to model weights (optional)
49
+ device: Torch device to use. Set the DINOV3_LOCATION environment
50
+ variable to load DINOv3 from a local repository instead of GitHub.
51
+ """
52
+
53
+ dinov3_github_location = "facebookresearch/dinov3"
54
+
55
+ if os.getenv("DINOV3_LOCATION") is not None:
56
+ dinov3_location = os.getenv("DINOV3_LOCATION")
57
+ else:
58
+ dinov3_location = dinov3_github_location
59
+
60
+ self.dinov3_location = dinov3_location
61
+ self.dinov3_source = (
62
+ "local" if dinov3_location != dinov3_github_location else "github"
63
+ )
64
+
65
+ self.device = device or get_device()
66
+ self.model_name = model_name
67
+
68
+ # Add DINOv3 to path if needed
69
+ if dinov3_location != dinov3_github_location and (
70
+ dinov3_location not in sys.path
71
+ ):
72
+ sys.path.append(dinov3_location)
73
+
74
+ # Load model
75
+ self.model = self._load_model(weights_path)
76
+ self.patch_size = self.model.patch_size
77
+ self.embed_dim = self.model.embed_dim
78
+
79
+ # Image transforms - satellite imagery normalization
80
+ self.transform = transforms.Compose(
81
+ [
82
+ transforms.ToTensor(),
83
+ transforms.Normalize(
84
+ mean=(0.430, 0.411, 0.296), # SAT-493M normalization
85
+ std=(0.213, 0.156, 0.143),
86
+ ),
87
+ ]
88
+ )
89
+
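# Usage sketch (not part of the diff): constructing the processor. Assumes
# geoai is installed and the weights download from the giswqs/geoai Hugging
# Face repo succeeds; "/path/to/dinov3" is a hypothetical local clone, only
# needed when working offline.
import os

os.environ["DINOV3_LOCATION"] = "/path/to/dinov3"  # optional; default is GitHub
processor = DINOv3GeoProcessor(model_name="dinov3_vitl16")
print(processor.patch_size, processor.embed_dim)  # 16 and 1024 for ViT-L/16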
90
+ def _download_model_from_hf(
91
+ self, model_path: Optional[str] = None, repo_id: Optional[str] = None
92
+ ) -> str:
93
+ """
94
+ Download DINOv3 model weights from Hugging Face.
95
+
96
+ Args:
97
+ model_path: Filename of the weights file in the Hugging Face repo.
98
+ repo_id: Hugging Face repository ID.
99
+
100
+ Returns:
101
+ Path to the downloaded model file
102
+ """
103
+ try:
104
+
105
+ # Define the repository ID and model filename
106
+ if repo_id is None:
107
+ repo_id = "giswqs/geoai"
108
+
109
+ if model_path is None:
110
+ model_path = "dinov3_vitl16_sat493m.pth"
111
+
112
+ # Download the model
113
+ model_path = hf_hub_download(repo_id=repo_id, filename=model_path)
114
+
115
+ return model_path
116
+
117
+ except Exception as e:
118
+ print(f"Error downloading model from Hugging Face: {e}")
119
+ print("Please specify a local model path or ensure internet connectivity.")
120
+ raise
121
+
122
+ def _load_model(self, weights_path: Optional[str] = None) -> torch.nn.Module:
123
+ """Load DINOv3 model."""
124
+ try:
125
+ if weights_path and os.path.exists(weights_path):
126
+ # Load with custom weights
127
+ model = torch.hub.load(
128
+ repo_or_dir=self.dinov3_location,
129
+ model=self.model_name,
130
+ source=self.dinov3_source,
131
+ )
132
+ # Load state dict manually
133
+ state_dict = torch.load(weights_path, map_location=self.device)
134
+ model.load_state_dict(state_dict, strict=False)
135
+ else:
136
+ # Download weights and load manually
137
+ weights_path = self._download_model_from_hf()
138
+ model = torch.hub.load(
139
+ repo_or_dir=self.dinov3_location,
140
+ model=self.model_name,
141
+ source=self.dinov3_source,
142
+ )
143
+ # Load state dict manually
144
+ state_dict = torch.load(weights_path, map_location=self.device)
145
+ model.load_state_dict(state_dict, strict=False)
146
+
147
+ model = model.to(self.device)
148
+ model.eval()
149
+ return model
150
+ except Exception as e:
151
+ raise RuntimeError(f"Failed to load DINOv3 model: {e}") from e
152
+
153
+ def load_regular_image(
154
+ self,
155
+ image_path: str,
156
+ ) -> Tuple[np.ndarray, dict]:
157
+ """Load regular image file (PNG, JPG, etc.).
158
+
159
+ Args:
160
+ image_path: Path to image file
161
+
162
+ Returns:
163
+ Tuple of (image array, metadata)
164
+ """
165
+ try:
166
+ # Load image using PIL
167
+ image = Image.open(image_path).convert("RGB")
168
+
169
+ # Convert to numpy array (H, W, C)
170
+ img_array = np.array(image)
171
+
172
+ # Convert to (C, H, W) format to match GeoTIFF format
173
+ data = np.transpose(img_array, (2, 0, 1)).astype(np.uint8)
174
+
175
+ # Create basic metadata
176
+ height, width = img_array.shape[:2]
177
+ metadata = {
178
+ "profile": {
179
+ "driver": "PNG",
180
+ "dtype": "uint8",
181
+ "nodata": None,
182
+ "width": width,
183
+ "height": height,
184
+ "count": 3,
185
+ "crs": None,
186
+ "transform": None,
187
+ },
188
+ "crs": None,
189
+ "transform": None,
190
+ "bounds": (0, 0, width, height),
191
+ }
192
+
193
+ return data, metadata
194
+
195
+ except Exception as e:
196
+ raise RuntimeError(f"Failed to load image {image_path}: {e}")
197
+
198
+ def load_geotiff(
199
+ self,
200
+ source: Union[str, DatasetReader],
201
+ window: Optional[Window] = None,
202
+ bands: Optional[List[int]] = None,
203
+ ) -> Tuple[np.ndarray, dict]:
204
+ """Load GeoTIFF file.
205
+
206
+ Args:
207
+ source: Path to GeoTIFF file (str) or an open rasterio.DatasetReader
208
+ window: Rasterio window for reading subset
209
+ bands: List of bands to read (1-indexed)
210
+
211
+ Returns:
212
+ Tuple of (image array, metadata)
213
+ """
214
+ # Flag to determine if we need to close the dataset afterwards
215
+ should_close = False
216
+ if isinstance(source, str):
217
+ src = rasterio.open(source)
218
+ should_close = True
219
+ elif isinstance(source, DatasetReader):
220
+ src = source
221
+ else:
222
+ raise TypeError("source must be a str path or a rasterio.DatasetReader")
223
+
224
+ try:
225
+ # Read specified bands or all bands
226
+ if bands:
227
+ data = src.read(bands, window=window)
228
+ else:
229
+ data = src.read(window=window)
230
+
231
+ # Get metadata
232
+ profile = src.profile.copy()
233
+ if window:
234
+ profile.update(
235
+ {
236
+ "height": window.height,
237
+ "width": window.width,
238
+ "transform": src.window_transform(window),
239
+ }
240
+ )
241
+
242
+ metadata = {
243
+ "profile": profile,
244
+ "crs": src.crs,
245
+ "transform": profile["transform"],
246
+ "bounds": (
247
+ src.bounds
248
+ if not window
249
+ else rasterio.windows.bounds(window, src.transform)
250
+ ),
251
+ }
252
+ finally:
253
+ if should_close:
254
+ src.close()
255
+
256
+ return data, metadata
257
+
258
+ def load_image(
259
+ self,
260
+ source: Union[str, DatasetReader],
261
+ window: Optional[Window] = None,
262
+ bands: Optional[List[int]] = None,
263
+ ) -> Tuple[np.ndarray, dict]:
264
+ """Load image file (GeoTIFF or regular image).
265
+
266
+ Args:
267
+ source: Path to image file (str) or an open rasterio.DatasetReader
268
+ window: Rasterio window for reading subset (only applies to GeoTIFF)
269
+ bands: List of bands to read (only applies to GeoTIFF)
270
+
271
+ Returns:
272
+ Tuple of (image array, metadata)
273
+ """
274
+ if isinstance(source, str):
275
+ # Check if it's a GeoTIFF file
276
+ try:
277
+ # Try to open with rasterio first
278
+ with rasterio.open(source) as src:
279
+ # If successful and has CRS, treat as GeoTIFF
280
+ if src.crs is not None:
281
+ return self.load_geotiff(source, window, bands)
282
+ # If no CRS, it might be a regular image opened by rasterio
283
+ else:
284
+ # Check file extension
285
+ file_ext = source.lower().split(".")[-1]
286
+ if file_ext in ["tif", "tiff"]:
287
+ return self.load_geotiff(source, window, bands)
288
+ else:
289
+ return self.load_regular_image(source)
290
+ except rasterio.errors.RasterioIOError:
291
+ # If rasterio fails, try as regular image
292
+ return self.load_regular_image(source)
293
+ elif isinstance(source, DatasetReader):
294
+ # Already opened rasterio dataset
295
+ return self.load_geotiff(source, window, bands)
296
+ else:
297
+ raise TypeError("source must be a str path or a rasterio.DatasetReader")
298
+
299
+ def save_geotiff(
300
+ self, data: np.ndarray, output_path: str, metadata: dict, dtype: str = "float32"
301
+ ) -> None:
302
+ """Save array as GeoTIFF.
303
+
304
+ Args:
305
+ data: Array to save
306
+ output_path: Output file path
307
+ metadata: Metadata from original file
308
+ dtype: Output data type
309
+ """
310
+ profile = metadata["profile"].copy()
311
+ profile.update(
312
+ {
313
+ "dtype": dtype,
314
+ "count": data.shape[0] if data.ndim == 3 else 1,
315
+ "height": data.shape[-2] if data.ndim >= 2 else data.shape[0],
316
+ "width": data.shape[-1] if data.ndim >= 2 else 1,
317
+ }
318
+ )
319
+
320
+ with rasterio.open(output_path, "w", **profile) as dst:
321
+ if data.ndim == 2:
322
+ dst.write(data, 1)
323
+ else:
324
+ dst.write(data)
325
+
326
+ def save_similarity_as_image(
327
+ self, similarity_data: np.ndarray, output_path: str, colormap: str = "turbo"
328
+ ) -> None:
329
+ """Save similarity array as PNG image with colormap.
330
+
331
+ Args:
332
+ similarity_data: 2D similarity array
333
+ output_path: Output file path
334
+ colormap: Matplotlib colormap name
335
+ """
336
+ import matplotlib.pyplot as plt
337
+
338
+ # Apply colormap
339
+ cmap = plt.get_cmap(colormap)
340
+ colored_data = cmap(similarity_data)
341
+
342
+ # Convert to uint8 image (remove alpha channel)
343
+ img_data = (colored_data[..., :3] * 255).astype(np.uint8)
344
+
345
+ # Save as PNG
346
+ img = Image.fromarray(img_data)
347
+ img.save(output_path)
348
+
349
+ def preprocess_image_for_dinov3(
350
+ self,
351
+ data: np.ndarray,
352
+ target_size: int = 896,
353
+ normalize_percentile: bool = True,
354
+ ) -> Image.Image:
355
+ """Preprocess image data for DINOv3.
356
+
357
+ Args:
358
+ data: Input array (C, H, W) or (H, W)
359
+ target_size: Target size for resizing
360
+ normalize_percentile: Whether to normalize using percentiles
361
+
362
+ Returns:
363
+ PIL Image ready for DINOv3
364
+ """
365
+ # Handle different input shapes
366
+ if data.ndim == 2:
367
+ data = data[np.newaxis, :, :] # Add channel dimension
368
+ elif data.ndim == 3 and data.shape[0] > 3:
369
+ # Take first 3 bands if more than 3 channels
370
+ data = data[:3, :, :]
371
+
372
+ # Normalize data
373
+ if normalize_percentile:
374
+ # Normalize each band using percentiles
375
+ normalized_data = np.zeros_like(data, dtype=np.float32)
376
+ for i in range(data.shape[0]):
377
+ band = data[i]
378
+ p2, p98 = np.percentile(band, [2, 98])
379
+ normalized_data[i] = np.clip((band - p2) / (p98 - p2), 0, 1)
380
+ else:
381
+ # Simple min-max normalization
382
+ normalized_data = (data - data.min()) / (data.max() - data.min())
383
+
384
+ # Convert to PIL Image
385
+ if normalized_data.shape[0] == 1:
386
+ # Grayscale - repeat to 3 channels
387
+ img_array = np.repeat(normalized_data, 3, axis=0)  # (1, H, W) -> (3, H, W)
388
+ else:
389
+ img_array = normalized_data
390
+
391
+ # Transpose to HWC format and convert to uint8
392
+ img_array = np.transpose(img_array, (1, 2, 0))
393
+ img_array = (img_array * 255).astype(np.uint8)
394
+
395
+ # Create PIL Image
396
+ image = Image.fromarray(img_array)
397
+
398
+ # Resize to patch-aligned dimensions
399
+ return self.resize_to_patch_aligned(image, target_size)
400
+
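# A small sketch of the percentile stretch used above: each band is clipped to
# its 2nd-98th percentile range and rescaled to [0, 1], which suppresses sensor
# outliers that would otherwise wash out the 8-bit conversion.
import numpy as np

band = np.random.default_rng(0).normal(1000.0, 200.0, (64, 64))
p2, p98 = np.percentile(band, [2, 98])
stretched = np.clip((band - p2) / (p98 - p2), 0.0, 1.0)
assert stretched.min() == 0.0 and stretched.max() == 1.0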
401
+ def resize_to_patch_aligned(
402
+ self, image: Image.Image, target_size: int = 896
403
+ ) -> Image.Image:
404
+ """Resize image to be aligned with patch grid."""
405
+ w, h = image.size
406
+
407
+ # Calculate new dimensions that are multiples of patch_size
408
+ if w > h:
409
+ new_h = target_size
410
+ new_w = int((w * target_size) / h)
411
+ else:
412
+ new_w = target_size
413
+ new_h = int((h * target_size) / w)
414
+
415
+ # Round to nearest multiple of patch_size
416
+ new_h = ((new_h + self.patch_size - 1) // self.patch_size) * self.patch_size
417
+ new_w = ((new_w + self.patch_size - 1) // self.patch_size) * self.patch_size
418
+
419
+ return image.resize((new_w, new_h), Image.Resampling.LANCZOS)
420
+
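# Worked example of the resize arithmetic above, assuming patch_size == 16:
# a 1000x750 image gets its short side scaled to 896 and the long side to
# 1194, then both dimensions are rounded up to multiples of 16.
patch_size = 16
w, h = 1000, 750
new_h, new_w = 896, int(w * 896 / h)  # -> 896, 1194
new_h = ((new_h + patch_size - 1) // patch_size) * patch_size  # -> 896
new_w = ((new_w + patch_size - 1) // patch_size) * patch_size  # -> 1200
assert new_w % patch_size == 0 and new_h % patch_size == 0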
421
+ def extract_features(self, image: Union[Image.Image, str, np.ndarray]) -> Tuple[torch.Tensor, int, int]:
422
+ """Extract patch features from image."""
423
+
424
+ if isinstance(image, str):
425
+ image = Image.open(image).convert("RGB")
426
+
427
+ if isinstance(image, np.ndarray):
428
+ image = Image.fromarray(image)
429
+
430
+ # Transform image
431
+ img_tensor = self.transform(image).unsqueeze(0).to(self.device)
432
+
433
+ with torch.no_grad():
434
+ # Extract features from last layer
435
+ features = self.model.get_intermediate_layers(
436
+ img_tensor, n=1, reshape=True, norm=True
437
+ )[0]  # Shape: [1, embed_dim, h_patches, w_patches]
440
+
441
+ # Rearrange to [h_patches, w_patches, embed_dim]
442
+ features = features.squeeze(0).permute(1, 2, 0)
443
+ h_patches, w_patches = features.shape[:2]
444
+
445
+ return features, h_patches, w_patches
446
+
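# Shape sketch: a patch-aligned 896x896 input with patch_size 16 yields a
# 56x56 grid, and dinov3_vitl16 has embed_dim 1024, so `features` should be
# [56, 56, 1024]. Assumes `processor` from the earlier sketch and a
# hypothetical "aerial.tif":
data, meta = processor.load_image("aerial.tif")
image = processor.preprocess_image_for_dinov3(data, target_size=896)
features, h_patches, w_patches = processor.extract_features(image)
assert features.shape == (h_patches, w_patches, processor.embed_dim)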
447
+ def compute_patch_similarity(
448
+ self, features: torch.Tensor, patch_x: int, patch_y: int
449
+ ) -> torch.Tensor:
450
+ """Compute cosine similarity between selected patch and all patches."""
451
+ h_patches, w_patches, embed_dim = features.shape
452
+
453
+ # Get query patch feature
454
+ query_feature = features[patch_y, patch_x] # Shape: [embed_dim]
455
+
456
+ # Reshape features for batch computation
457
+ all_features = features.view(
458
+ -1, embed_dim
459
+ ) # Shape: [h_patches * w_patches, embed_dim]
460
+
461
+ # Compute cosine similarity
462
+ similarities = F.cosine_similarity(
463
+ query_feature.unsqueeze(0), # Shape: [1, embed_dim]
464
+ all_features, # Shape: [h_patches * w_patches, embed_dim]
465
+ dim=1,
466
+ )
467
+
468
+ # Reshape back to patch grid
469
+ similarities = similarities.view(h_patches, w_patches)
470
+
471
+ # Normalize to 0-1 range
472
+ similarities = (similarities + 1) / 2
473
+
474
+ return similarities
475
+
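# The similarity math in miniature: cosine similarity lies in [-1, 1], and the
# (s + 1) / 2 shift above maps it to [0, 1]; the query patch always scores 1.0
# against itself. A self-contained toy check:
import torch
import torch.nn.functional as F

feats = torch.randn(4, 4, 8)  # toy [h_patches, w_patches, embed_dim]
query = feats[1, 2].unsqueeze(0)
sims = F.cosine_similarity(query, feats.view(-1, 8), dim=1).view(4, 4)
sims = (sims + 1) / 2
assert torch.isclose(sims[1, 2], torch.tensor(1.0))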
476
+ def compute_similarity(
477
+ self,
478
+ source: str = None,
479
+ features: torch.Tensor = None,
480
+ query_coords: Tuple[float, float] = None,
481
+ output_dir: str = None,
482
+ window: Optional[Window] = None,
483
+ bands: Optional[List[int]] = None,
484
+ target_size: int = 896,
485
+ save_features: bool = False,
486
+ coord_crs: str = None,
487
+ use_interpolation: bool = True,
488
+ ) -> Dict[str, np.ndarray]:
489
+ """Process GeoTIFF for patch similarity analysis.
490
+
491
+ Args:
492
+ source: Path to input GeoTIFF or rasterio dataset
493
+ features: Pre-extracted features (h_patches, w_patches, embed_dim)
494
+ query_coords: (x, y) coordinates in image pixel space or (lon, lat) in geographic space
495
+ output_dir: Output directory for results (required)
496
+ window: Optional window for reading subset
497
+ bands: Optional list of bands to use
498
+ target_size: Target size for processing
499
+ save_features: Whether to save extracted features
500
+ coord_crs: Coordinate CRS of the query coordinates
501
+ use_interpolation: Whether to use interpolation when resizing similarity map
502
+
503
+ Returns:
504
+ Dictionary containing similarity results and metadata
505
+ """
506
+ os.makedirs(output_dir, exist_ok=True)
507
+
508
+ # Load image (GeoTIFF or regular image)
509
+ data, metadata = self.load_image(source, window, bands)
510
+ raw_img_w, raw_img_h = data.shape[-1], data.shape[-2]
511
+
512
+ # Preprocess for DINOv3
513
+ image = self.preprocess_image_for_dinov3(data, target_size)
514
+
515
+ # Extract features
516
+ if features is None:
517
+ features, h_patches, w_patches = self.extract_features(image)
518
+ else:
519
+ h_patches, w_patches = features.shape[:2]
520
+
521
+ # Convert coordinates to patch space
522
+ img_w, img_h = image.size
523
+ if len(query_coords) == 2:
524
+ # Map query coords (geographic if coord_crs is given) into resized-image pixel space
525
+ if coord_crs is not None:
526
+ [query_coords] = coords_to_xy(source, [query_coords], coord_crs)
527
+
528
+ new_x = math.floor(query_coords[0] / raw_img_w * img_w)
529
+ new_y = math.floor(query_coords[1] / raw_img_h * img_h)
530
+ query_coords = [new_x, new_y]
531
+
532
+ x_pixel, y_pixel = query_coords
533
+ patch_x = math.floor((x_pixel / img_w) * w_patches)
534
+ patch_y = math.floor((y_pixel / img_h) * h_patches)
535
+
536
+ # Clamp to valid range
537
+ patch_x = max(0, min(w_patches - 1, patch_x))
538
+ patch_y = max(0, min(h_patches - 1, patch_y))
539
+
540
+ # Compute similarity
541
+ similarities = self.compute_patch_similarity(features, patch_x, patch_y)
542
+
543
+ # Prepare results
544
+ results = {
545
+ "similarities": similarities.cpu().numpy(),
546
+ "patch_coords": (patch_x, patch_y),
547
+ "patch_grid_size": (h_patches, w_patches),
548
+ "image_size": (img_w, img_h),
549
+ "metadata": metadata,
550
+ }
551
+
552
+ # Save similarity as GeoTIFF
553
+ sim_array = similarities.cpu().numpy()
554
+
555
+ # Resize similarity to original data dimensions
556
+ if use_interpolation:
557
+ try:
558
+ from skimage.transform import resize
559
+
560
+ sim_resized = resize(
561
+ sim_array,
562
+ (data.shape[-2], data.shape[-1]),
563
+ preserve_range=True,
564
+ anti_aliasing=True,
565
+ )
566
+ except ImportError:
567
+ # Fallback to PIL if scikit-image not available
568
+ from PIL import Image as PILImage
569
+
570
+ sim_pil = PILImage.fromarray((sim_array * 255).astype(np.uint8))
571
+ sim_pil = sim_pil.resize(
572
+ (data.shape[-1], data.shape[-2]), PILImage.LANCZOS
573
+ )
574
+ sim_resized = np.array(sim_pil, dtype=np.float32) / 255.0
575
+ else:
576
+ # Resize without interpolation (nearest neighbor)
577
+ try:
578
+ from skimage.transform import resize
579
+
580
+ sim_resized = resize(
581
+ sim_array,
582
+ (data.shape[-2], data.shape[-1]),
583
+ preserve_range=True,
584
+ anti_aliasing=False,
585
+ order=0, # Nearest neighbor interpolation
586
+ )
587
+ except ImportError:
588
+ # Fallback to PIL with nearest neighbor
589
+ from PIL import Image as PILImage
590
+
591
+ sim_pil = PILImage.fromarray((sim_array * 255).astype(np.uint8))
592
+ sim_pil = sim_pil.resize(
593
+ (data.shape[-1], data.shape[-2]), PILImage.NEAREST
594
+ )
595
+ sim_resized = np.array(sim_pil, dtype=np.float32) / 255.0
596
+
597
+ # Save similarity map
598
+ if metadata["crs"] is not None:
599
+ # Save as GeoTIFF for georeferenced data
600
+ similarity_path = os.path.join(
601
+ output_dir, f"similarity_patch_{patch_x}_{patch_y}.tif"
602
+ )
603
+ self.save_geotiff(
604
+ sim_resized[np.newaxis, :, :],
605
+ similarity_path,
606
+ metadata,
607
+ dtype="float32",
608
+ )
609
+ else:
610
+ # Save as PNG for regular images
611
+ similarity_path = os.path.join(
612
+ output_dir, f"similarity_patch_{patch_x}_{patch_y}.png"
613
+ )
614
+ self.save_similarity_as_image(sim_resized, similarity_path)
615
+
616
+ image_dict = {
617
+ "crs": metadata["crs"],
618
+ "bounds": metadata["bounds"],
619
+ "image": sim_resized[np.newaxis, :, :],
620
+ }
621
+ results["image_dict"] = image_dict
622
+
623
+ # Save features if requested
624
+ if save_features:
625
+ features_np = features.cpu().numpy()
626
+ features_path = os.path.join(
627
+ output_dir, f"features_patch_{patch_x}_{patch_y}.npy"
628
+ )
629
+ np.save(features_path, features_np)
630
+
631
+ # Save metadata
632
+ metadata_dict = {
633
+ "input_path": source,
634
+ "query_coords": query_coords,
635
+ "patch_coords": (patch_x, patch_y),
636
+ "patch_grid_size": (h_patches, w_patches),
637
+ "image_size": (img_w, img_h),
638
+ "similarity_stats": {
639
+ "max": float(sim_array.max()),
640
+ "min": float(sim_array.min()),
641
+ "mean": float(sim_array.mean()),
642
+ "std": float(sim_array.std()),
643
+ },
644
+ }
645
+
646
+ if save_features:
647
+ metadata_path = os.path.join(
648
+ output_dir, f"metadata_patch_{patch_x}_{patch_y}.json"
649
+ )
650
+ with open(metadata_path, "w", encoding="utf-8") as f:
651
+ json.dump(metadata_dict, f, indent=2)
652
+
653
+ results["output_paths"] = {
654
+ "similarity": similarity_path,
655
+ "metadata": metadata_path,
656
+ "features": features_path if save_features else None,
657
+ }
658
+
659
+ return results
660
+
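# End-to-end sketch, assuming the `processor` from the earlier sketch, a
# hypothetical georeferenced "naip.tif", and a lon/lat query that coords_to_xy
# converts to pixel space before the patch lookup:
results = processor.compute_similarity(
    source="naip.tif",
    query_coords=(-117.58, 33.89),  # (lon, lat); hypothetical location
    coord_crs="EPSG:4326",
    output_dir="similarity_out",
)
sim = results["similarities"]  # [h_patches, w_patches] array in [0, 1]
print(results["patch_coords"], sim.max())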
661
+ def visualize_similarity(
662
+ self,
663
+ source: str,
664
+ similarity_data: np.ndarray,
665
+ query_coords: Tuple[float, float] = None,
666
+ patch_coords: Tuple[int, int] = None,
667
+ figsize: Tuple[int, int] = (15, 6),
668
+ colormap: str = "turbo",
669
+ alpha: float = 0.7,
670
+ save_path: str = None,
671
+ show_query_point: bool = True,
672
+ overlay: bool = False,
673
+ ) -> plt.Figure:
674
+ """Visualize original image and similarity map side by side or as overlay.
675
+
676
+ Args:
677
+ source: Path to original image
678
+ similarity_data: 2D similarity array
679
+ query_coords: Query coordinates in pixel space (x, y)
680
+ patch_coords: Patch coordinates (patch_x, patch_y); accepted but not currently drawn
681
+ figsize: Figure size for visualization
682
+ colormap: Colormap for similarity visualization
683
+ alpha: Transparency for overlay mode
684
+ save_path: Optional path to save the visualization
685
+ show_query_point: Whether to show the query point marker
686
+ overlay: If True, overlay similarity on original image; if False, show side by side
687
+
688
+ Returns:
689
+ Matplotlib figure object
690
+ """
691
+ # Load original image
692
+ data, metadata = self.load_image(source)
693
+
694
+ # Convert image data to displayable format
695
+ if data.ndim == 3:
696
+ if data.shape[0] <= 3:
697
+ # Standard RGB/grayscale image (C, H, W)
698
+ display_img = np.transpose(data, (1, 2, 0))
699
+ else:
700
+ # Multi-band image, take first 3 bands
701
+ display_img = np.transpose(data[:3], (1, 2, 0))
702
+ else:
703
+ # Single band image
704
+ display_img = data
705
+
706
+ # Normalize image for display
707
+ if display_img.dtype != np.uint8:
708
+ # Normalize using percentiles
709
+ if display_img.ndim == 3:
710
+ normalized_img = np.zeros_like(display_img, dtype=np.float32)
711
+ for i in range(display_img.shape[2]):
712
+ band = display_img[:, :, i]
713
+ p2, p98 = np.percentile(band, [2, 98])
714
+ normalized_img[:, :, i] = np.clip((band - p2) / (p98 - p2), 0, 1)
715
+ else:
716
+ p2, p98 = np.percentile(display_img, [2, 98])
717
+ normalized_img = np.clip((display_img - p2) / (p98 - p2), 0, 1)
718
+ display_img = normalized_img
719
+ else:
720
+ display_img = display_img / 255.0
721
+
722
+ # Ensure similarity data matches image dimensions
723
+ if similarity_data.shape != display_img.shape[:2]:
724
+ from PIL import Image as PILImage
725
+
726
+ sim_pil = PILImage.fromarray((similarity_data * 255).astype(np.uint8))
727
+ sim_pil = sim_pil.resize(
728
+ (display_img.shape[1], display_img.shape[0]), PILImage.LANCZOS
729
+ )
730
+ similarity_data = np.array(sim_pil, dtype=np.float32) / 255.0
731
+
732
+ if overlay:
733
+ # Single plot with overlay
734
+ fig, ax = plt.subplots(1, 1, figsize=(figsize[1], figsize[1]))
735
+
736
+ # Show original image
737
+ if display_img.ndim == 2:
738
+ ax.imshow(display_img, cmap="gray")
739
+ else:
740
+ ax.imshow(display_img)
741
+
742
+ # Overlay similarity map
743
+ im_sim = ax.imshow(
744
+ similarity_data, cmap=colormap, alpha=alpha, vmin=0, vmax=1
745
+ )
746
+
747
+ # Add colorbar for similarity
748
+ cbar = plt.colorbar(im_sim, ax=ax, fraction=0.046, pad=0.04)
749
+ cbar.set_label("Similarity", rotation=270, labelpad=20)
750
+
751
+ ax.set_title("Image with Similarity Overlay")
752
+
753
+ else:
754
+ # Side-by-side visualization
755
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
756
+
757
+ # Original image
758
+ if display_img.ndim == 2:
759
+ ax1.imshow(display_img, cmap="gray")
760
+ else:
761
+ ax1.imshow(display_img)
762
+ ax1.set_title("Original Image")
763
+ ax1.axis("off")
764
+
765
+ # Similarity map
766
+ im_sim = ax2.imshow(similarity_data, cmap=colormap, vmin=0, vmax=1)
767
+ ax2.set_title("Similarity Map")
768
+ ax2.axis("off")
769
+
770
+ # Add colorbar
771
+ cbar = plt.colorbar(im_sim, ax=ax2, fraction=0.046, pad=0.04)
772
+ cbar.set_label("Similarity", rotation=270, labelpad=20)
773
+
774
+ # Mark query point if provided
775
+ if show_query_point and query_coords is not None:
776
+ x, y = query_coords
777
+ if overlay:
778
+ ax.plot(
779
+ x,
780
+ y,
781
+ "r*",
782
+ markersize=15,
783
+ markeredgecolor="white",
784
+ markeredgewidth=2,
785
+ )
786
+ ax.plot(x, y, "r*", markersize=12)
787
+ else:
788
+ ax1.plot(
789
+ x,
790
+ y,
791
+ "r*",
792
+ markersize=15,
793
+ markeredgecolor="white",
794
+ markeredgewidth=2,
795
+ )
796
+ ax1.plot(x, y, "r*", markersize=12)
797
+ ax2.plot(
798
+ x,
799
+ y,
800
+ "r*",
801
+ markersize=15,
802
+ markeredgecolor="white",
803
+ markeredgewidth=2,
804
+ )
805
+ ax2.plot(x, y, "r*", markersize=12)
806
+
807
+ plt.tight_layout()
808
+
809
+ # Save if path provided
810
+ if save_path:
811
+ plt.savefig(save_path, dpi=150, bbox_inches="tight")
812
+
813
+ return fig
814
+
815
+ def visualize_patches(
816
+ self,
817
+ image: Image.Image,
818
+ features: torch.Tensor,
819
+ patch_coords: Tuple[int, int],
820
+ add_text: bool = False,
821
+ figsize: Tuple[int, int] = (12, 8),
822
+ save_path: str = None,
823
+ ) -> plt.Figure:
824
+ """Visualize image with patch grid and highlight selected patch.
825
+
826
+ Args:
827
+ image: PIL Image
828
+ features: Feature tensor (h_patches, w_patches, embed_dim)
829
+ patch_coords: Selected patch coordinates (patch_x, patch_y)
830
+ add_text: Whether to add text to the patch
831
+ figsize: Figure size
832
+ save_path: Optional path to save visualization
833
+
834
+ Returns:
835
+ Matplotlib figure object
836
+ """
837
+ fig, ax = plt.subplots(1, 1, figsize=figsize)
838
+
839
+ # Display image
840
+ ax.imshow(image)
841
+ ax.set_title("Image with Patch Grid")
842
+ ax.axis("off")
843
+
844
+ # Get dimensions
845
+ img_w, img_h = image.size
846
+ h_patches, w_patches = features.shape[:2]
847
+ patch_x, patch_y = patch_coords
848
+
849
+ # Calculate patch size in pixels
850
+ patch_w = img_w / w_patches
851
+ patch_h = img_h / h_patches
852
+
853
+ # Draw patch grid
854
+ for i in range(w_patches + 1):
855
+ x = i * patch_w
856
+ ax.axvline(x=x, color="white", alpha=0.3, linewidth=0.5)
857
+
858
+ for i in range(h_patches + 1):
859
+ y = i * patch_h
860
+ ax.axhline(y=y, color="white", alpha=0.3, linewidth=0.5)
861
+
862
+ # Highlight selected patch
863
+ rect_x = patch_x * patch_w
864
+ rect_y = patch_y * patch_h
865
+ rect = patches.Rectangle(
866
+ (rect_x, rect_y),
867
+ patch_w,
868
+ patch_h,
869
+ linewidth=3,
870
+ edgecolor="red",
871
+ facecolor="none",
872
+ )
873
+ ax.add_patch(rect)
874
+
875
+ # Add patch coordinate text
876
+ if add_text:
877
+ ax.text(
878
+ rect_x + patch_w / 2,
879
+ rect_y + patch_h / 2,
880
+ f"({patch_x}, {patch_y})",
881
+ color="red",
882
+ fontsize=12,
883
+ ha="center",
884
+ va="center",
885
+ bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
886
+ )
887
+
888
+ plt.tight_layout()
889
+
890
+ if save_path:
891
+ plt.savefig(save_path, dpi=150, bbox_inches="tight")
892
+
893
+ return fig
894
+
895
+ def create_similarity_overlay(
896
+ self,
897
+ source: str,
898
+ similarity_data: np.ndarray,
899
+ colormap: str = "turbo",
900
+ alpha: float = 0.7,
901
+ ) -> np.ndarray:
902
+ """Create an overlay of similarity map on original image.
903
+
904
+ Args:
905
+ source: Path to original image
906
+ similarity_data: 2D similarity array
907
+ colormap: Colormap for similarity visualization
908
+ alpha: Transparency for overlay
909
+
910
+ Returns:
911
+ RGB overlay image as numpy array
912
+ """
913
+ # Load original image
914
+ data, _ = self.load_image(source)
915
+
916
+ # Convert to display format
917
+ if data.ndim == 3:
918
+ if data.shape[0] <= 3:
919
+ display_img = np.transpose(data, (1, 2, 0))
920
+ else:
921
+ display_img = np.transpose(data[:3], (1, 2, 0))
922
+ else:
923
+ display_img = data
924
+
925
+ # Normalize image
926
+ if display_img.dtype != np.uint8:
927
+ if display_img.ndim == 3:
928
+ normalized_img = np.zeros_like(display_img, dtype=np.float32)
929
+ for i in range(display_img.shape[2]):
930
+ band = display_img[:, :, i]
931
+ p2, p98 = np.percentile(band, [2, 98])
932
+ normalized_img[:, :, i] = np.clip((band - p2) / (p98 - p2), 0, 1)
933
+ else:
934
+ p2, p98 = np.percentile(display_img, [2, 98])
935
+ normalized_img = np.clip((display_img - p2) / (p98 - p2), 0, 1)
936
+ base_img = normalized_img
937
+ else:
938
+ base_img = display_img / 255.0
939
+
940
+ # Convert grayscale to RGB if needed
941
+ if base_img.ndim == 2:
942
+ base_img = np.stack([base_img] * 3, axis=2)
943
+
944
+ # Resize similarity data to match image
945
+ if similarity_data.shape != base_img.shape[:2]:
946
+ from PIL import Image as PILImage
947
+
948
+ sim_pil = PILImage.fromarray((similarity_data * 255).astype(np.uint8))
949
+ sim_pil = sim_pil.resize(
950
+ (base_img.shape[1], base_img.shape[0]), PILImage.LANCZOS
951
+ )
952
+ similarity_data = np.array(sim_pil, dtype=np.float32) / 255.0
953
+
954
+ # Apply colormap to similarity data
955
+ cmap = plt.get_cmap(colormap)
956
+ colored_similarity = cmap(similarity_data)[:, :, :3] # Remove alpha channel
957
+
958
+ # Blend images
959
+ overlay_img = (1 - alpha) * base_img + alpha * colored_similarity
960
+
961
+ return np.clip(overlay_img, 0, 1)
962
+
963
+ def batch_similarity_analysis(
964
+ self,
965
+ input_path: str,
966
+ query_points: List[Tuple[float, float]],
967
+ output_dir: str,
968
+ window: Optional[Window] = None,
969
+ bands: Optional[List[int]] = None,
970
+ target_size: int = 896,
971
+ ) -> List[Dict[str, np.ndarray]]:
972
+ """Process multiple query points for similarity analysis.
973
+
974
+ Args:
975
+ input_path: Path to input GeoTIFF
976
+ query_points: List of (x, y) coordinates
977
+ output_dir: Output directory for results
978
+ window: Optional window for reading subset
979
+ bands: Optional list of bands to use
980
+ target_size: Target size for processing
981
+
982
+ Returns:
983
+ List of result dictionaries
984
+ """
985
+ results = []
986
+ for i, coords in enumerate(query_points):
987
+ point_output_dir = os.path.join(output_dir, f"point_{i}")
988
+ result = self.compute_similarity(
989
+ source=input_path,
990
+ query_coords=coords,
991
+ output_dir=point_output_dir,
992
+ window=window,
993
+ bands=bands,
994
+ target_size=target_size,
995
+ )
996
+ results.append(result)
997
+
998
+ return results
999
+
1000
+
1001
+ def create_similarity_map(
1002
+ input_image: str,
1003
+ query_coords: Tuple[float, float],
1004
+ output_dir: str,
1005
+ model_name: str = "dinov3_vitl16",
1006
+ weights_path: Optional[str] = None,
1007
+ window: Optional[Window] = None,
1008
+ bands: Optional[List[int]] = None,
1009
+ target_size: int = 896,
1010
+ save_features: bool = False,
1011
+ coord_crs: str = None,
1012
+ use_interpolation: bool = True,
1013
+ ) -> Dict[str, np.ndarray]:
1014
+ """Convenience function to create similarity map from image file.
1015
+
1016
+ Args:
1017
+ input_image: Path to input image file (GeoTIFF, PNG, JPG, etc.)
1018
+ query_coords: Query coordinates (x, y) in pixel space, or in coord_crs units if coord_crs is given
1019
+ output_dir: Output directory
1020
+ model_name: DINOv3 model name
1021
+ weights_path: Optional path to model weights
1022
+ window: Optional rasterio window (only applies to GeoTIFF)
1023
+ bands: Optional list of bands to use (only applies to GeoTIFF)
1024
+ target_size: Target size for processing
1025
+ save_features: Whether to save extracted features
1026
+ coord_crs: Coordinate CRS of the query coordinates (only applies to GeoTIFF)
1027
+ use_interpolation: Whether to use interpolation when resizing similarity map
1028
+
1029
+ Returns:
1030
+ Dictionary containing results
1031
+ """
1032
+ processor = DINOv3GeoProcessor(model_name=model_name, weights_path=weights_path)
1033
+
1034
+ return processor.compute_similarity(
1035
+ source=input_image,
1036
+ query_coords=query_coords,
1037
+ output_dir=output_dir,
1038
+ window=window,
1039
+ bands=bands,
1040
+ target_size=target_size,
1041
+ save_features=save_features,
1042
+ coord_crs=coord_crs,
1043
+ use_interpolation=use_interpolation,
1044
+ )
1045
+
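# Convenience-function sketch (same assumptions as above; paths hypothetical).
# Note that each call builds a fresh DINOv3GeoProcessor, so for many queries
# reuse one processor (or batch_similarity_analysis) instead:
results = create_similarity_map(
    input_image="naip.tif",
    query_coords=(1203, 845),  # pixel (x, y)
    output_dir="similarity_out",
)
print(results["patch_grid_size"], results["output_paths"]["similarity"])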
1046
+
1047
+ def analyze_image_patches(
1048
+ input_image: str,
1049
+ query_points: List[Tuple[float, float]],
1050
+ output_dir: str,
1051
+ model_name: str = "dinov3_vitl16",
1052
+ weights_path: Optional[str] = None,
1053
+ ) -> List[Dict[str, np.ndarray]]:
1054
+ """Analyze multiple patches in an image file.
1055
+
1056
+ Args:
1057
+ input_image: Path to input image file (GeoTIFF, PNG, JPG, etc.)
1058
+ query_points: List of query coordinates
1059
+ output_dir: Output directory
1060
+ model_name: DINOv3 model name
1061
+ weights_path: Optional path to model weights
1062
+
1063
+ Returns:
1064
+ List of result dictionaries
1065
+ """
1066
+ processor = DINOv3GeoProcessor(model_name=model_name, weights_path=weights_path)
1067
+
1068
+ return processor.batch_similarity_analysis(input_image, query_points, output_dir)
1069
+
1070
+
1071
+ def visualize_similarity_results(
1072
+ input_image: str,
1073
+ query_coords: Tuple[float, float],
1074
+ output_dir: str = None,
1075
+ model_name: str = "dinov3_vitl16",
1076
+ weights_path: Optional[str] = None,
1077
+ figsize: Tuple[int, int] = (15, 6),
1078
+ colormap: str = "turbo",
1079
+ alpha: float = 0.7,
1080
+ save_path: str = None,
1081
+ show_query_point: bool = True,
1082
+ overlay: bool = False,
1083
+ target_size: int = 896,
1084
+ coord_crs: str = None,
1085
+ use_interpolation: bool = True,
1086
+ ) -> Dict:
1087
+ """Create similarity map and visualize results in one function.
1088
+
1089
+ Args:
1090
+ input_image: Path to input image file (GeoTIFF, PNG, JPG, etc.)
1091
+ query_coords: Query coordinates (x, y) in pixel space, or in coord_crs units if coord_crs is given
1092
+ output_dir: Output directory for similarity map files (optional)
1093
+ model_name: DINOv3 model name
1094
+ weights_path: Optional path to model weights
1095
+ figsize: Figure size for visualization
1096
+ colormap: Colormap for similarity visualization
1097
+ alpha: Transparency for overlay mode
1098
+ save_path: Optional path to save the visualization
1099
+ show_query_point: Whether to show the query point marker
1100
+ overlay: If True, overlay similarity on original image; if False, show side by side
1101
+ target_size: Target size for processing
1102
+ coord_crs: Coordinate CRS of the query coordinates
1103
+ use_interpolation: Whether to use interpolation when resizing similarity map
1104
+
1105
+ Returns:
1106
+ Dictionary containing similarity results, metadata, and matplotlib figure
1107
+ """
1108
+ processor = DINOv3GeoProcessor(model_name=model_name, weights_path=weights_path)
1109
+
1110
+ # Create temporary output directory if not provided
1111
+ if output_dir is None:
1112
+ import tempfile
1113
+
1114
+ output_dir = tempfile.mkdtemp(prefix="dinov3_similarity_")
1115
+
1116
+ # Compute similarity
1117
+ results = processor.compute_similarity(
1118
+ source=input_image,
1119
+ query_coords=query_coords,
1120
+ output_dir=output_dir,
1121
+ target_size=target_size,
1122
+ coord_crs=coord_crs,
1123
+ use_interpolation=use_interpolation,
1124
+ )
1125
+
1126
+ # Get similarity data from results
1127
+ similarity_data = results["image_dict"]["image"][0] # Remove channel dimension
1128
+
1129
+ # Create visualization
1130
+ fig = processor.visualize_similarity(
1131
+ source=input_image,
1132
+ similarity_data=similarity_data,
1133
+ query_coords=query_coords,
1134
+ patch_coords=results["patch_coords"],
1135
+ figsize=figsize,
1136
+ colormap=colormap,
1137
+ alpha=alpha,
1138
+ save_path=save_path,
1139
+ show_query_point=show_query_point,
1140
+ overlay=overlay,
1141
+ )
1142
+
1143
+ # Add figure to results
1144
+ results["visualization"] = fig
1145
+
1146
+ return results
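# One-call sketch combining similarity and plotting (hypothetical paths; the
# returned dict carries the matplotlib figure under "visualization"):
out = visualize_similarity_results(
    input_image="naip.tif",
    query_coords=(1203, 845),
    overlay=True,
    save_path="similarity_overlay.png",
)
fig = out["visualization"]  # matplotlib Figure, already saved to save_path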