geoai-py 0.18.2__py2.py3-none-any.whl → 0.20.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/moondream.py ADDED
@@ -0,0 +1,990 @@
+ """Moondream Vision Language Model module for GeoAI.
+
+ This module provides an interface for using Moondream vision language models
+ (moondream2 and moondream3-preview) with geospatial imagery, supporting
+ GeoTIFF input and georeferenced output.
+
+ Supported models:
+ - moondream2: https://huggingface.co/vikhyatk/moondream2
+ - moondream3-preview: https://huggingface.co/moondream/moondream3-preview
+ """
+
+ import os
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import geopandas as gpd
+ import numpy as np
+ import rasterio
+ import torch
+ from PIL import Image
+ from shapely.geometry import Point, box
+ from transformers.utils import logging as hf_logging
+
+ from .utils import get_device
+
+
+ hf_logging.set_verbosity_error()  # silence HF load reports
+
+ class MoondreamGeo:
+     """Moondream Vision Language Model processor with GeoTIFF support.
+
+     This class provides an interface for using Moondream models for
+     geospatial image analysis, including captioning, visual querying,
+     object detection, and pointing.
+
+     Attributes:
+         model: The loaded Moondream model.
+         model_name: Name of the model being used.
+         device: Torch device for inference.
+         model_version: Either "moondream2" or "moondream3".
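+
+     Example:
+         A minimal usage sketch ("aerial.tif" is a placeholder path):
+
+         >>> geo = MoondreamGeo(model_name="vikhyatk/moondream2")
+         >>> result = geo.caption("aerial.tif", length="short")
+         >>> print(result["caption"])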
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ model_name: str = "vikhyatk/moondream2",
46
+ revision: Optional[str] = None,
47
+ device: Optional[str] = None,
48
+ compile_model: bool = False,
49
+ **kwargs: Any,
50
+ ) -> None:
51
+ """Initialize the Moondream processor.
52
+
53
+ Args:
54
+ model_name: HuggingFace model name. Options:
55
+ - "vikhyatk/moondream2" (default)
56
+ - "moondream/moondream3-preview"
57
+ revision: Model revision/checkpoint date. For moondream2, recommended
58
+ to use a specific date like "2025-06-21" for reproducibility.
59
+ device: Device for inference ("cuda", "mps", or "cpu").
60
+ If None, automatically selects the best available device.
61
+ compile_model: Whether to compile the model (recommended for
62
+ moondream3-preview for faster inference).
63
+ **kwargs: Additional arguments passed to from_pretrained.
64
+
65
+ Raises:
66
+ ImportError: If transformers is not installed.
67
+ RuntimeError: If model loading fails.
68
+ """
69
+ self.model_name = model_name
70
+ self.device = device or get_device()
71
+ self._source_path: Optional[str] = None
72
+ self._metadata: Optional[Dict] = None
73
+
74
+ # Determine model version
75
+ if "moondream3" in model_name.lower():
76
+ self.model_version = "moondream3"
77
+ else:
78
+ self.model_version = "moondream2"
79
+
80
+ # Load the model
81
+ self.model = self._load_model(revision, compile_model, **kwargs)
82
+
+     def _load_model(
+         self,
+         revision: Optional[str],
+         compile_model: bool,
+         **kwargs: Any,
+     ) -> Any:
+         """Load the Moondream model.
+
+         Args:
+             revision: Model revision/checkpoint.
+             compile_model: Whether to compile the model.
+             **kwargs: Additional arguments for from_pretrained.
+
+         Returns:
+             Loaded model instance.
+
+         Raises:
+             RuntimeError: If model loading fails.
+         """
+         try:
+             from transformers import AutoModelForCausalLM
+
+             # Default kwargs
+             load_kwargs = {
+                 "trust_remote_code": True,
+             }
+
+             # Try to use device_map with accelerate; fall back to manual
+             # device placement if accelerate is unavailable.
+             use_device_map = False
+             try:
+                 import accelerate  # noqa: F401
+
+                 # Build device map
+                 if isinstance(self.device, str):
+                     device_map = {"": self.device}
+                 else:
+                     device_map = {"": str(self.device)}
+                 load_kwargs["device_map"] = device_map
+                 use_device_map = True
+             except ImportError:
+                 # accelerate not available; the model is moved to the
+                 # target device manually after loading.
+                 pass
+
+             # Add revision if specified
+             if revision:
+                 load_kwargs["revision"] = revision
+
+             # For moondream3, use bfloat16. Some revisions expect "dtype"
+             # instead of "torch_dtype", so both are set for compatibility.
+             if self.model_version == "moondream3":
+                 load_kwargs["torch_dtype"] = torch.bfloat16
+                 load_kwargs["dtype"] = torch.bfloat16
+
+             load_kwargs.update(kwargs)
+
+             print(f"Loading {self.model_name}...")
+
+             # Try to load, with a compatibility fix for transformers 5.0+
+             try:
+                 model = AutoModelForCausalLM.from_pretrained(
+                     self.model_name,
+                     **load_kwargs,
+                 )
+             except AttributeError as attr_err:
+                 # Handle transformers 5.0 compatibility issue with custom models
+                 if "all_tied_weights_keys" in str(attr_err):
+                     # Retry with the patched model class
+                     model = self._load_with_patch(load_kwargs)
+                 else:
+                     raise
+
+             # Move model to device if accelerate wasn't used
+             if not use_device_map:
+                 device = (
+                     self.device
+                     if isinstance(self.device, torch.device)
+                     else torch.device(self.device)
+                 )
+                 model = model.to(device)
+
+             # Set model to evaluation mode
+             model.eval()
+
+             # Compile model if requested (recommended for moondream3)
+             if compile_model and hasattr(model, "compile"):
+                 print("Compiling model for faster inference...")
+                 model.compile()
+
+             print(f"Using device: {self.device}")
+             return model
+
+         except Exception as e:
+             # Provide a helpful error message
+             error_msg = str(e)
+             if "all_tied_weights_keys" in error_msg:
+                 error_msg = (
+                     f"Failed to load Moondream model due to a transformers version "
+                     f"incompatibility. The model's custom code may not be compatible "
+                     f"with your current transformers version. Try: "
+                     f"1) wait for an updated model revision, or "
+                     f"2) use a compatible transformers version. "
+                     f"Original error: {e}"
+                 )
+             raise RuntimeError(f"Failed to load Moondream model: {error_msg}") from e
+
+     def _load_with_patch(self, load_kwargs: Dict) -> Any:
+         """Load model with compatibility patch for transformers 5.0+.
+
+         Args:
+             load_kwargs: Keyword arguments for from_pretrained.
+
+         Returns:
+             Loaded model instance.
+         """
+         from transformers import AutoModelForCausalLM, PreTrainedModel
+
+         # Patch the PreTrainedModel class to add missing attribute
+         original_getattr = PreTrainedModel.__getattr__
+
+         def patched_getattr(self, name):
+             if name == "all_tied_weights_keys":
+                 # Return empty dict to satisfy the check
+                 if not hasattr(self, "_all_tied_weights_keys"):
+                     self._all_tied_weights_keys = {}
+                 return self._all_tied_weights_keys
+             return original_getattr(self, name)
+
+         # Apply patch temporarily
+         PreTrainedModel.__getattr__ = patched_getattr
+
+         try:
+             model = AutoModelForCausalLM.from_pretrained(
+                 self.model_name,
+                 **load_kwargs,
+             )
+             return model
+         finally:
+             # Restore original
+             PreTrainedModel.__getattr__ = original_getattr
+
+     def load_geotiff(
+         self,
+         source: str,
+         bands: Optional[List[int]] = None,
+     ) -> Tuple[Image.Image, Dict]:
+         """Load a GeoTIFF file and return a PIL Image with metadata.
+
+         Args:
+             source: Path to GeoTIFF file.
+             bands: List of band indices to read (1-indexed). If None, reads
+                 first 3 bands for RGB or first band for grayscale.
+
+         Returns:
+             Tuple of (PIL Image, metadata dict).
+
+         Raises:
+             FileNotFoundError: If source file doesn't exist.
+             RuntimeError: If loading fails.
+         """
+         if not os.path.exists(source):
+             raise FileNotFoundError(f"File not found: {source}")
+
+         try:
+             with rasterio.open(source) as src:
+                 # Store metadata
+                 metadata = {
+                     "profile": src.profile.copy(),
+                     "crs": src.crs,
+                     "transform": src.transform,
+                     "bounds": src.bounds,
+                     "width": src.width,
+                     "height": src.height,
+                 }
+
+                 # Read bands
+                 if bands is None:
+                     if src.count >= 3:
+                         bands = [1, 2, 3]  # RGB
+                     else:
+                         bands = [1]  # Grayscale
+
+                 data = src.read(bands)
+
+                 # Convert to RGB image
+                 if len(bands) == 1:
+                     # Grayscale to RGB
+                     img_array = np.repeat(data[0:1], 3, axis=0)
+                 elif len(bands) >= 3:
+                     img_array = data[:3]
+                 else:
+                     # Pad to 3 channels
+                     img_array = np.zeros((3, data.shape[1], data.shape[2]))
+                     img_array[: data.shape[0]] = data
+
+                 # Normalize to 0-255 range
+                 img_array = self._normalize_image(img_array)
+
+                 # Convert to PIL Image (HWC format)
+                 img_array = np.transpose(img_array, (1, 2, 0))
+                 image = Image.fromarray(img_array.astype(np.uint8))
+
+                 self._source_path = source
+                 self._metadata = metadata
+
+                 return image, metadata
+
+         except Exception as e:
+             raise RuntimeError(f"Failed to load GeoTIFF: {e}") from e
+
+     def load_image(
+         self,
+         source: Union[str, Image.Image, np.ndarray],
+         bands: Optional[List[int]] = None,
+     ) -> Tuple[Image.Image, Optional[Dict]]:
+         """Load an image from various sources.
+
+         Args:
+             source: Image source - can be a file path (GeoTIFF, PNG, JPG),
+                 PIL Image, or numpy array.
+             bands: Band indices for GeoTIFF (1-indexed).
+
+         Returns:
+             Tuple of (PIL Image, metadata dict or None).
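+
+         Example:
+             An illustrative sketch ("scene.tif" is a placeholder path):
+
+             >>> image, metadata = geo.load_image("scene.tif")
+             >>> metadata["crs"] is not None  # True for georeferenced rasters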
+ """
312
+ if isinstance(source, Image.Image):
313
+ self._source_path = None
314
+ self._metadata = None
315
+ return source, None
316
+
317
+ if isinstance(source, np.ndarray):
318
+ if source.ndim == 2:
319
+ # Grayscale
320
+ source = np.stack([source] * 3, axis=-1)
321
+ elif source.ndim == 3 and source.shape[0] <= 4:
322
+ # CHW format
323
+ source = np.transpose(source[:3], (1, 2, 0))
324
+
325
+ source = self._normalize_image(source)
326
+ image = Image.fromarray(source.astype(np.uint8))
327
+ self._source_path = None
328
+ self._metadata = None
329
+ return image, None
330
+
331
+ if isinstance(source, str):
332
+ if source.startswith(("http://", "https://")):
333
+ # URL - download and load
334
+ from .utils import download_file
335
+
336
+ source = download_file(source)
337
+
338
+ # Check if it's a GeoTIFF
339
+ try:
340
+ with rasterio.open(source) as src:
341
+ if src.crs is not None or source.lower().endswith(
342
+ (".tif", ".tiff")
343
+ ):
344
+ return self.load_geotiff(source, bands)
345
+ except rasterio.RasterioIOError:
346
+ pass
347
+
348
+ # Regular image
349
+ image = Image.open(source).convert("RGB")
350
+ self._source_path = source
351
+ self._metadata = {
352
+ "width": image.width,
353
+ "height": image.height,
354
+ "crs": None,
355
+ "transform": None,
356
+ "bounds": None,
357
+ }
358
+ return image, self._metadata
359
+
360
+ def _normalize_image(self, data: np.ndarray) -> np.ndarray:
+     def _normalize_image(self, data: np.ndarray) -> np.ndarray:
+         """Normalize image data to 0-255 range using percentile stretching.
+
+         Args:
+             data: Input array (can be CHW or HWC format).
+
+         Returns:
+             Normalized array in uint8 range.
+         """
+         if data.dtype == np.uint8:
+             return data
+
+         # Determine if CHW or HWC
+         if data.ndim == 3 and data.shape[0] <= 4:
+             # CHW format - normalize each channel
+             normalized = np.zeros_like(data, dtype=np.float32)
+             for i in range(data.shape[0]):
+                 band = data[i].astype(np.float32)
+                 p2, p98 = np.percentile(band, [2, 98])
+                 if p98 > p2:
+                     normalized[i] = np.clip((band - p2) / (p98 - p2) * 255, 0, 255)
+                 else:
+                     normalized[i] = np.clip(band, 0, 255)
+         else:
+             # HWC format or 2D
+             data = data.astype(np.float32)
+             p2, p98 = np.percentile(data, [2, 98])
+             if p98 > p2:
+                 normalized = np.clip((data - p2) / (p98 - p2) * 255, 0, 255)
+             else:
+                 normalized = np.clip(data, 0, 255)
+
+         return normalized.astype(np.uint8)
+
+     def encode_image(
+         self,
+         source: Union[str, Image.Image, np.ndarray],
+         bands: Optional[List[int]] = None,
+     ) -> Any:
+         """Pre-encode an image for efficient multiple inferences.
+
+         Use this when you plan to run multiple queries on the same image.
+
+         Args:
+             source: Image source.
+             bands: Band indices for GeoTIFF.
+
+         Returns:
+             Encoded image that can be passed to query, caption, etc.
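+
+         Example:
+             An illustrative sketch; encode once, then reuse the encoding:
+
+             >>> enc = geo.encode_image("scene.tif")
+             >>> geo.query("How many buildings are visible?", enc)
+             >>> geo.caption(enc, length="short")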
+ """
410
+ image, _ = self.load_image(source, bands)
411
+
412
+ if hasattr(self.model, "encode_image"):
413
+ return self.model.encode_image(image)
414
+ return image
415
+
+     def caption(
+         self,
+         source: Union[str, Image.Image, np.ndarray, Any],
+         length: str = "normal",
+         stream: bool = False,
+         bands: Optional[List[int]] = None,
+         settings: Optional[Dict] = None,
+         **kwargs: Any,
+     ) -> Dict[str, Any]:
+         """Generate a caption for an image.
+
+         Args:
+             source: Image source or pre-encoded image.
+             length: Caption length - "short", "normal", or "long".
+             stream: Whether to stream the output.
+             bands: Band indices for GeoTIFF.
+             settings: Additional settings (temperature, top_p, max_tokens).
+             **kwargs: Additional arguments for the model.
+
+         Returns:
+             Dictionary with "caption" key containing the generated caption.
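+
+         Example:
+             An illustrative sketch ("scene.tif" is a placeholder path):
+
+             >>> result = geo.caption("scene.tif", length="long")
+             >>> print(result["caption"])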
+ """
438
+ # Load image if not pre-encoded
439
+ if isinstance(source, (str, Image.Image, np.ndarray)):
440
+ image, _ = self.load_image(source, bands)
441
+ else:
442
+ image = source # Pre-encoded
443
+
444
+ call_kwargs = {"length": length, "stream": stream}
445
+ if settings:
446
+ call_kwargs["settings"] = settings
447
+ call_kwargs.update(kwargs)
448
+
449
+ return self.model.caption(image, **call_kwargs)
450
+
+     def query(
+         self,
+         question: str,
+         source: Optional[Union[str, Image.Image, np.ndarray, Any]] = None,
+         reasoning: Optional[bool] = None,
+         stream: bool = False,
+         bands: Optional[List[int]] = None,
+         settings: Optional[Dict] = None,
+         **kwargs: Any,
+     ) -> Dict[str, Any]:
+         """Ask a question about an image, or run a text-only query.
+
+         Args:
+             question: The question to ask.
+             source: Image source or pre-encoded image. If None, performs a
+                 text-only query (moondream3 only).
+             reasoning: Enable reasoning mode for more complex tasks
+                 (moondream3 only, default True).
+             stream: Whether to stream the output.
+             bands: Band indices for GeoTIFF.
+             settings: Additional settings (temperature, top_p, max_tokens).
+             **kwargs: Additional arguments for the model.
+
+         Returns:
+             Dictionary with "answer" key containing the response.
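+
+         Example:
+             An illustrative sketch ("scene.tif" is a placeholder path):
+
+             >>> result = geo.query("Is there a river in this image?", "scene.tif")
+             >>> print(result["answer"])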
+ """
477
+ call_kwargs = {"question": question, "stream": stream}
478
+
479
+ if source is not None:
480
+ if isinstance(source, (str, Image.Image, np.ndarray)):
481
+ image, _ = self.load_image(source, bands)
482
+ else:
483
+ image = source
484
+ call_kwargs["image"] = image
485
+
486
+ if reasoning is not None and self.model_version == "moondream3":
487
+ call_kwargs["reasoning"] = reasoning
488
+
489
+ if settings:
490
+ call_kwargs["settings"] = settings
491
+ call_kwargs.update(kwargs)
492
+
493
+ return self.model.query(**call_kwargs)
494
+
+     def detect(
+         self,
+         source: Union[str, Image.Image, np.ndarray, Any],
+         object_type: str,
+         bands: Optional[List[int]] = None,
+         output_path: Optional[str] = None,
+         settings: Optional[Dict] = None,
+         stream: bool = False,
+         **kwargs: Any,
+     ) -> Dict[str, Any]:
+         """Detect objects of a specific type in an image.
+
+         Args:
+             source: Image source or pre-encoded image.
+             object_type: Type of object to detect (e.g., "car", "building").
+             bands: Band indices for GeoTIFF.
+             output_path: Path to save results as GeoJSON/Shapefile/GeoPackage.
+             settings: Additional settings (max_objects, etc.).
+             stream: Whether to stream the output (moondream2 only).
+             **kwargs: Additional arguments for the model.
+
+         Returns:
+             Dictionary with "objects" key containing list of bounding boxes
+             with normalized coordinates (x_min, y_min, x_max, y_max).
+             If georeferenced, also includes "gdf" (GeoDataFrame) and
+             "crs", "bounds" keys.
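+
+         Example:
+             An illustrative sketch ("scene.tif" is a placeholder path):
+
+             >>> result = geo.detect("scene.tif", "building", output_path="buildings.geojson")
+             >>> len(result["objects"])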
+ # Load image
523
+ if isinstance(source, (str, Image.Image, np.ndarray)):
524
+ image, metadata = self.load_image(source, bands)
525
+ else:
526
+ image = source
527
+ metadata = self._metadata
528
+
529
+ call_kwargs = {}
530
+ if settings:
531
+ call_kwargs["settings"] = settings
532
+ if self.model_version == "moondream2" and stream:
533
+ call_kwargs["stream"] = stream
534
+ call_kwargs.update(kwargs)
535
+
536
+ result = self.model.detect(image, object_type, **call_kwargs)
537
+
538
+ # Convert to georeferenced if possible
539
+ if metadata and metadata.get("crs") and metadata.get("transform"):
540
+ result = self._georef_detections(result, metadata)
541
+
542
+ if output_path:
543
+ self._save_vector(result["gdf"], output_path)
544
+
545
+ return result
546
+
+     def point(
+         self,
+         source: Union[str, Image.Image, np.ndarray, Any],
+         object_description: str,
+         bands: Optional[List[int]] = None,
+         output_path: Optional[str] = None,
+         **kwargs: Any,
+     ) -> Dict[str, Any]:
+         """Find points (x, y coordinates) for objects in an image.
+
+         Args:
+             source: Image source or pre-encoded image.
+             object_description: Description of objects to find.
+             bands: Band indices for GeoTIFF.
+             output_path: Path to save results as GeoJSON/Shapefile/GeoPackage.
+             **kwargs: Additional arguments for the model.
+
+         Returns:
+             Dictionary with "points" key containing list of points
+             with normalized coordinates (x, y in 0-1 range).
+             If georeferenced, also includes "gdf" (GeoDataFrame).
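+
+         Example:
+             An illustrative sketch ("scene.tif" is a placeholder path):
+
+             >>> result = geo.point("scene.tif", "solar panel")
+             >>> result["points"][0]  # {"x": ..., "y": ...} in 0-1 range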
+ # Load image
570
+ if isinstance(source, (str, Image.Image, np.ndarray)):
571
+ image, metadata = self.load_image(source, bands)
572
+ else:
573
+ image = source
574
+ metadata = self._metadata
575
+
576
+ result = self.model.point(image, object_description, **kwargs)
577
+
578
+ # Convert to georeferenced if possible
579
+ if metadata and metadata.get("crs") and metadata.get("transform"):
580
+ result = self._georef_points(result, metadata)
581
+
582
+ if output_path:
583
+ self._save_vector(result["gdf"], output_path)
584
+
585
+ return result
586
+
+     def _georef_detections(
+         self,
+         result: Dict[str, Any],
+         metadata: Dict,
+     ) -> Dict[str, Any]:
+         """Convert detection results to georeferenced format.
+
+         Args:
+             result: Detection result from model.
+             metadata: Image metadata with CRS and transform.
+
+         Returns:
+             Updated result dictionary with GeoDataFrame.
+         """
+         objects = result.get("objects", [])
+         if not objects:
+             result["gdf"] = gpd.GeoDataFrame(
+                 columns=["geometry", "x_min", "y_min", "x_max", "y_max"],
+                 crs=metadata["crs"],
+             )
+             result["crs"] = metadata["crs"]
+             result["bounds"] = metadata["bounds"]
+             return result
+
+         width = metadata["width"]
+         height = metadata["height"]
+         transform = metadata["transform"]
+
+         geometries = []
+         records = []
+
+         for obj in objects:
+             # Convert normalized coords to pixel coords
+             px_x_min = obj["x_min"] * width
+             px_y_min = obj["y_min"] * height
+             px_x_max = obj["x_max"] * width
+             px_y_max = obj["y_max"] * height
+
+             # Convert pixel coords to geographic coords. For north-up rasters,
+             # the bottom pixel row (y_max) maps to the smaller geographic Y,
+             # hence the swapped row arguments.
+             geo_x_min, geo_y_min = transform * (px_x_min, px_y_max)
+             geo_x_max, geo_y_max = transform * (px_x_max, px_y_min)
+
+             # Create polygon
+             geom = box(geo_x_min, geo_y_min, geo_x_max, geo_y_max)
+             geometries.append(geom)
+
+             records.append(
+                 {
+                     "x_min": obj["x_min"],
+                     "y_min": obj["y_min"],
+                     "x_max": obj["x_max"],
+                     "y_max": obj["y_max"],
+                     "px_x_min": int(px_x_min),
+                     "px_y_min": int(px_y_min),
+                     "px_x_max": int(px_x_max),
+                     "px_y_max": int(px_y_max),
+                 }
+             )
+
+         gdf = gpd.GeoDataFrame(records, geometry=geometries, crs=metadata["crs"])
+
+         result["gdf"] = gdf
+         result["crs"] = metadata["crs"]
+         result["bounds"] = metadata["bounds"]
+
+         return result
+
+     def _georef_points(
+         self,
+         result: Dict[str, Any],
+         metadata: Dict,
+     ) -> Dict[str, Any]:
+         """Convert point results to georeferenced format.
+
+         Args:
+             result: Point result from model.
+             metadata: Image metadata with CRS and transform.
+
+         Returns:
+             Updated result dictionary with GeoDataFrame.
+         """
+         points = result.get("points", [])
+         if not points:
+             result["gdf"] = gpd.GeoDataFrame(
+                 columns=["geometry", "x", "y"],
+                 crs=metadata["crs"],
+             )
+             result["crs"] = metadata["crs"]
+             result["bounds"] = metadata["bounds"]
+             return result
+
+         width = metadata["width"]
+         height = metadata["height"]
+         transform = metadata["transform"]
+
+         geometries = []
+         records = []
+
+         for pt in points:
+             # Convert normalized coords to pixel coords
+             px_x = pt["x"] * width
+             px_y = pt["y"] * height
+
+             # Convert pixel coords to geographic coords
+             geo_x, geo_y = transform * (px_x, px_y)
+
+             # Create point
+             geom = Point(geo_x, geo_y)
+             geometries.append(geom)
+
+             records.append(
+                 {
+                     "x": pt["x"],
+                     "y": pt["y"],
+                     "px_x": int(px_x),
+                     "px_y": int(px_y),
+                 }
+             )
+
+         gdf = gpd.GeoDataFrame(records, geometry=geometries, crs=metadata["crs"])
+
+         result["gdf"] = gdf
+         result["crs"] = metadata["crs"]
+         result["bounds"] = metadata["bounds"]
+
+         return result
+
+     def _save_vector(
+         self,
+         gdf: gpd.GeoDataFrame,
+         output_path: str,
+     ) -> None:
+         """Save GeoDataFrame to vector file.
+
+         Args:
+             gdf: GeoDataFrame to save.
+             output_path: Output file path. Extension determines format:
+                 .geojson -> GeoJSON
+                 .shp -> Shapefile
+                 .gpkg -> GeoPackage
+         """
+         os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
+
+         ext = os.path.splitext(output_path)[1].lower()
+         if ext == ".geojson":
+             gdf.to_file(output_path, driver="GeoJSON")
+         elif ext == ".shp":
+             gdf.to_file(output_path, driver="ESRI Shapefile")
+         elif ext == ".gpkg":
+             gdf.to_file(output_path, driver="GPKG")
+         else:
+             gdf.to_file(output_path)
+
+         print(f"Saved {len(gdf)} features to {output_path}")
+
+     def create_detection_mask(
+         self,
+         source: Union[str, Image.Image, np.ndarray],
+         object_type: str,
+         output_path: Optional[str] = None,
+         bands: Optional[List[int]] = None,
+         **kwargs: Any,
+     ) -> Tuple[np.ndarray, Optional[Dict]]:
+         """Create a binary mask from object detections.
+
+         Args:
+             source: Image source.
+             object_type: Type of object to detect.
+             output_path: Path to save mask as GeoTIFF.
+             bands: Band indices for GeoTIFF.
+             **kwargs: Additional arguments for detect().
+
+         Returns:
+             Tuple of (mask array, metadata dict).
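+
+         Example:
+             An illustrative sketch ("scene.tif" is a placeholder path):
+
+             >>> mask, meta = geo.create_detection_mask(
+             ...     "scene.tif", "building", output_path="buildings_mask.tif"
+             ... )
+             >>> mask.max()  # 255 where buildings were detected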
+ """
762
+ # Load image to get dimensions
763
+ image, metadata = self.load_image(source, bands)
764
+ width, height = image.size
765
+
766
+ # Detect objects
767
+ result = self.detect(source, object_type, bands=bands, **kwargs)
768
+ objects = result.get("objects", [])
769
+
770
+ # Create mask
771
+ mask = np.zeros((height, width), dtype=np.uint8)
772
+
773
+ for obj in objects:
774
+ x_min = int(obj["x_min"] * width)
775
+ y_min = int(obj["y_min"] * height)
776
+ x_max = int(obj["x_max"] * width)
777
+ y_max = int(obj["y_max"] * height)
778
+
779
+ mask[y_min:y_max, x_min:x_max] = 255
780
+
781
+ # Save as GeoTIFF if requested
782
+ if output_path and metadata and metadata.get("crs"):
783
+ self._save_mask_geotiff(mask, output_path, metadata)
784
+ elif output_path:
785
+ # Save as regular image
786
+ Image.fromarray(mask).save(output_path)
787
+
788
+ return mask, metadata
789
+
+     def _save_mask_geotiff(
+         self,
+         mask: np.ndarray,
+         output_path: str,
+         metadata: Dict,
+     ) -> None:
+         """Save mask as GeoTIFF with georeferencing.
+
+         Args:
+             mask: 2D mask array.
+             output_path: Output file path.
+             metadata: Image metadata with profile.
+         """
+         os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
+
+         profile = metadata["profile"].copy()
+         profile.update(
+             {
+                 "dtype": "uint8",
+                 "count": 1,
+                 "height": mask.shape[0],
+                 "width": mask.shape[1],
+                 "compress": "lzw",
+             }
+         )
+
+         with rasterio.open(output_path, "w", **profile) as dst:
+             dst.write(mask, 1)
+
+         print(f"Saved mask to {output_path}")
+
+     def show_gui(
+         self,
+         basemap: str = "SATELLITE",
+         out_dir: Optional[str] = None,
+         opacity: float = 0.5,
+         **kwargs: Any,
+     ) -> Any:
+         """Display an interactive GUI for using Moondream with leafmap.
+
+         This method creates an interactive map interface for using Moondream
+         vision language model capabilities including:
+         - Image captioning (short, normal, long)
+         - Visual question answering (query)
+         - Object detection with bounding boxes
+         - Point detection for locating objects
+
+         Args:
+             basemap: The basemap to use. Defaults to "SATELLITE".
+             out_dir: The output directory for saving results.
+                 Defaults to None (uses temp directory).
+             opacity: The opacity of overlay layers. Defaults to 0.5.
+             **kwargs: Additional keyword arguments passed to leafmap.Map().
+
+         Returns:
+             leafmap.Map: The interactive map with the Moondream GUI.
+
+         Example:
+             >>> moondream = MoondreamGeo()
+             >>> moondream.load_image("image.tif")
+             >>> m = moondream.show_gui()
+             >>> m
+         """
+         from .map_widgets import moondream_gui
+
+         return moondream_gui(
+             self,
+             basemap=basemap,
+             out_dir=out_dir,
+             opacity=opacity,
+             **kwargs,
+         )
+
+     def get_last_result(self) -> Optional[gpd.GeoDataFrame]:
+         """Return the GeoDataFrame from the most recent GUI result, if any.
+
+         Returns:
+             The GeoDataFrame stored in ``last_result`` (set by the interactive
+             GUI), or None if no georeferenced result is available.
+         """
+         if hasattr(self, "last_result") and "gdf" in self.last_result:
+             return self.last_result["gdf"]
+         return None
+
+
+ def moondream_caption(
+     source: Union[str, Image.Image, np.ndarray],
+     model_name: str = "vikhyatk/moondream2",
+     revision: Optional[str] = None,
+     length: str = "normal",
+     bands: Optional[List[int]] = None,
+     device: Optional[str] = None,
+     **kwargs: Any,
+ ) -> str:
+     """Convenience function to generate a caption for an image.
+
+     Args:
+         source: Image source (file path, PIL Image, or numpy array).
+         model_name: Moondream model name.
+         revision: Model revision.
+         length: Caption length ("short", "normal", "long").
+         bands: Band indices for GeoTIFF.
+         device: Device for inference.
+         **kwargs: Additional arguments.
+
+     Returns:
+         Generated caption string.
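+
+     Example:
+         An illustrative sketch; each call loads the model anew, so prefer
+         MoondreamGeo for repeated use ("scene.tif" is a placeholder path):
+
+         >>> caption = moondream_caption("scene.tif", length="short")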
+ """
894
+ processor = MoondreamGeo(model_name=model_name, revision=revision, device=device)
895
+ result = processor.caption(source, length=length, bands=bands, **kwargs)
896
+ return result["caption"]
897
+
898
+
899
+ def moondream_query(
900
+ question: str,
901
+ source: Optional[Union[str, Image.Image, np.ndarray]] = None,
902
+ model_name: str = "vikhyatk/moondream2",
903
+ revision: Optional[str] = None,
904
+ reasoning: Optional[bool] = None,
905
+ bands: Optional[List[int]] = None,
906
+ device: Optional[str] = None,
907
+ **kwargs: Any,
908
+ ) -> str:
909
+ """Convenience function to ask a question about an image.
910
+
911
+ Args:
912
+ question: Question to ask.
913
+ source: Image source (optional for text-only queries).
914
+ model_name: Moondream model name.
915
+ revision: Model revision.
916
+ reasoning: Enable reasoning mode (moondream3 only).
917
+ bands: Band indices for GeoTIFF.
918
+ device: Device for inference.
919
+ **kwargs: Additional arguments.
920
+
921
+ Returns:
922
+ Answer string.
923
+ """
924
+ processor = MoondreamGeo(model_name=model_name, revision=revision, device=device)
925
+ result = processor.query(
926
+ question, source=source, reasoning=reasoning, bands=bands, **kwargs
927
+ )
928
+ return result["answer"]
929
+
930
+
+ def moondream_detect(
+     source: Union[str, Image.Image, np.ndarray],
+     object_type: str,
+     model_name: str = "vikhyatk/moondream2",
+     revision: Optional[str] = None,
+     output_path: Optional[str] = None,
+     bands: Optional[List[int]] = None,
+     device: Optional[str] = None,
+     **kwargs: Any,
+ ) -> Dict[str, Any]:
+     """Convenience function to detect objects in an image.
+
+     Args:
+         source: Image source.
+         object_type: Type of object to detect.
+         model_name: Moondream model name.
+         revision: Model revision.
+         output_path: Path to save results as vector file.
+         bands: Band indices for GeoTIFF.
+         device: Device for inference.
+         **kwargs: Additional arguments.
+
+     Returns:
+         Detection results dictionary with "objects" and optionally "gdf".
+     """
+     processor = MoondreamGeo(model_name=model_name, revision=revision, device=device)
+     return processor.detect(
+         source, object_type, output_path=output_path, bands=bands, **kwargs
+     )
+
+
+ def moondream_point(
+     source: Union[str, Image.Image, np.ndarray],
+     object_description: str,
+     model_name: str = "vikhyatk/moondream2",
+     revision: Optional[str] = None,
+     output_path: Optional[str] = None,
+     bands: Optional[List[int]] = None,
+     device: Optional[str] = None,
+     **kwargs: Any,
+ ) -> Dict[str, Any]:
+     """Convenience function to find points for objects in an image.
+
+     Args:
+         source: Image source.
+         object_description: Description of objects to find.
+         model_name: Moondream model name.
+         revision: Model revision.
+         output_path: Path to save results as vector file.
+         bands: Band indices for GeoTIFF.
+         device: Device for inference.
+         **kwargs: Additional arguments.
+
+     Returns:
+         Point results dictionary with "points" and optionally "gdf".
+     """
+     processor = MoondreamGeo(model_name=model_name, revision=revision, device=device)
+     return processor.point(
+         source, object_description, output_path=output_path, bands=bands, **kwargs
+     )