geoai-py 0.18.2__py2.py3-none-any.whl → 0.19.0__py2.py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- geoai/__init__.py +23 -1
- geoai/agents/__init__.py +1 -0
- geoai/agents/geo_agents.py +74 -29
- geoai/geoai.py +2 -0
- geoai/landcover_train.py +685 -0
- geoai/landcover_utils.py +383 -0
- geoai/map_widgets.py +556 -0
- geoai/moondream.py +990 -0
- geoai/tools/__init__.py +11 -0
- geoai/tools/sr.py +194 -0
- geoai/utils.py +329 -1881
- {geoai_py-0.18.2.dist-info → geoai_py-0.19.0.dist-info}/METADATA +3 -1
- {geoai_py-0.18.2.dist-info → geoai_py-0.19.0.dist-info}/RECORD +17 -13
- {geoai_py-0.18.2.dist-info → geoai_py-0.19.0.dist-info}/WHEEL +0 -0
- {geoai_py-0.18.2.dist-info → geoai_py-0.19.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.18.2.dist-info → geoai_py-0.19.0.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.18.2.dist-info → geoai_py-0.19.0.dist-info}/top_level.txt +0 -0
geoai/moondream.py (ADDED, +990 lines)

```python
"""Moondream Vision Language Model module for GeoAI.
|
|
2
|
+
|
|
3
|
+
This module provides an interface for using Moondream vision language models
|
|
4
|
+
(moondream2 and moondream3-preview) with geospatial imagery, supporting
|
|
5
|
+
GeoTIFF input and georeferenced output.
|
|
6
|
+
|
|
7
|
+
Supported models:
|
|
8
|
+
- moondream2: https://huggingface.co/vikhyatk/moondream2
|
|
9
|
+
- moondream3-preview: https://huggingface.co/moondream/moondream3-preview
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import os
|
|
13
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
14
|
+
|
|
15
|
+
import geopandas as gpd
|
|
16
|
+
import numpy as np
|
|
17
|
+
import rasterio
|
|
18
|
+
import torch
|
|
19
|
+
from PIL import Image
|
|
20
|
+
from shapely.geometry import Point, box
|
|
21
|
+
from transformers.utils import logging as hf_logging
|
|
22
|
+
|
|
23
|
+
from .utils import get_device
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
hf_logging.set_verbosity_error() # silence HF load reports
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class MoondreamGeo:
|
|
30
|
+
"""Moondream Vision Language Model processor with GeoTIFF support.
|
|
31
|
+
|
|
32
|
+
This class provides an interface for using Moondream models for
|
|
33
|
+
geospatial image analysis, including captioning, visual querying,
|
|
34
|
+
object detection, and pointing.
|
|
35
|
+
|
|
36
|
+
Attributes:
|
|
37
|
+
model: The loaded Moondream model.
|
|
38
|
+
model_name: Name of the model being used.
|
|
39
|
+
device: Torch device for inference.
|
|
40
|
+
model_version: Either "moondream2" or "moondream3".
|
|
41
|
+
"""
|
|
42
|
+
|
|
43
|
+
def __init__(
|
|
44
|
+
self,
|
|
45
|
+
model_name: str = "vikhyatk/moondream2",
|
|
46
|
+
revision: Optional[str] = None,
|
|
47
|
+
device: Optional[str] = None,
|
|
48
|
+
compile_model: bool = False,
|
|
49
|
+
**kwargs: Any,
|
|
50
|
+
) -> None:
|
|
51
|
+
"""Initialize the Moondream processor.
|
|
52
|
+
|
|
53
|
+
Args:
|
|
54
|
+
model_name: HuggingFace model name. Options:
|
|
55
|
+
- "vikhyatk/moondream2" (default)
|
|
56
|
+
- "moondream/moondream3-preview"
|
|
57
|
+
revision: Model revision/checkpoint date. For moondream2, recommended
|
|
58
|
+
to use a specific date like "2025-06-21" for reproducibility.
|
|
59
|
+
device: Device for inference ("cuda", "mps", or "cpu").
|
|
60
|
+
If None, automatically selects the best available device.
|
|
61
|
+
compile_model: Whether to compile the model (recommended for
|
|
62
|
+
moondream3-preview for faster inference).
|
|
63
|
+
**kwargs: Additional arguments passed to from_pretrained.
|
|
64
|
+
|
|
65
|
+
Raises:
|
|
66
|
+
ImportError: If transformers is not installed.
|
|
67
|
+
RuntimeError: If model loading fails.
|
|
68
|
+
"""
|
|
69
|
+
self.model_name = model_name
|
|
70
|
+
self.device = device or get_device()
|
|
71
|
+
self._source_path: Optional[str] = None
|
|
72
|
+
self._metadata: Optional[Dict] = None
|
|
73
|
+
|
|
74
|
+
# Determine model version
|
|
75
|
+
if "moondream3" in model_name.lower():
|
|
76
|
+
self.model_version = "moondream3"
|
|
77
|
+
else:
|
|
78
|
+
self.model_version = "moondream2"
|
|
79
|
+
|
|
80
|
+
# Load the model
|
|
81
|
+
self.model = self._load_model(revision, compile_model, **kwargs)
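
    # Example (sketch; the pinned revision follows the docstring's
    # recommendation, and the device is auto-selected when omitted):
    #
    #   md = MoondreamGeo(revision="2025-06-21")
    #   md3 = MoondreamGeo("moondream/moondream3-preview", compile_model=True)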

    def _load_model(
        self,
        revision: Optional[str],
        compile_model: bool,
        **kwargs: Any,
    ) -> Any:
        """Load the Moondream model.

        Args:
            revision: Model revision/checkpoint.
            compile_model: Whether to compile the model.
            **kwargs: Additional arguments for from_pretrained.

        Returns:
            Loaded model instance.

        Raises:
            RuntimeError: If model loading fails.
        """
        try:
            from transformers import AutoModelForCausalLM

            # Default kwargs
            load_kwargs = {
                "trust_remote_code": True,
            }

            # Try to use device_map with accelerate, fall back to manual device placement
            use_device_map = False
            try:
                import accelerate  # noqa: F401

                # Build device map
                if isinstance(self.device, str):
                    device_map = {"": self.device}
                else:
                    device_map = {"": str(self.device)}
                load_kwargs["device_map"] = device_map
                use_device_map = True
            except ImportError:
                # accelerate not available, will move model to device manually
                pass

            # Add revision if specified
            if revision:
                load_kwargs["revision"] = revision

            # For moondream3, use bfloat16
            if self.model_version == "moondream3":
                load_kwargs["torch_dtype"] = torch.bfloat16
                # Note: moondream3 uses dtype instead of torch_dtype in some versions
                load_kwargs["dtype"] = torch.bfloat16

            load_kwargs.update(kwargs)

            print(f"Loading {self.model_name}...")

            # Try to load with potential transformers 5.0 compatibility fix
            try:
                model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    **load_kwargs,
                )
            except AttributeError as attr_err:
                # Handle transformers 5.0 compatibility issue with custom models
                if "all_tied_weights_keys" in str(attr_err):
                    # print(
                    #     "Note: Detected transformers 5.0+ compatibility issue. "
                    #     "Attempting workaround..."
                    # )
                    # Try patching the model class
                    model = self._load_with_patch(load_kwargs)
                else:
                    raise

            # Move model to device if accelerate wasn't used
            if not use_device_map:
                device = (
                    self.device
                    if isinstance(self.device, torch.device)
                    else torch.device(self.device)
                )
                model = model.to(device)
                # print(f"Model moved to {device}")

            # Set model to evaluation mode
            model.eval()

            # Compile model if requested (recommended for moondream3)
            if compile_model and hasattr(model, "compile"):
                print("Compiling model for faster inference...")
                model.compile()

            print(f"Using device: {self.device}")
            return model

        except Exception as e:
            # Provide helpful error message
            error_msg = str(e)
            if "all_tied_weights_keys" in error_msg:
                error_msg = (
                    f"Failed to load Moondream model due to transformers version "
                    f"incompatibility. The model's custom code may not be compatible "
                    f"with your current transformers version. Try: "
                    f"1) Wait for an updated model revision, or "
                    f"2) Use a compatible transformers version. "
                    f"Original error: {e}"
                )
            raise RuntimeError(f"Failed to load Moondream model: {error_msg}") from e

    def _load_with_patch(self, load_kwargs: Dict) -> Any:
        """Load model with compatibility patch for transformers 5.0+.

        Args:
            load_kwargs: Keyword arguments for from_pretrained.

        Returns:
            Loaded model instance.
        """
        from transformers import AutoModelForCausalLM, PreTrainedModel

        # Patch the PreTrainedModel class to add missing attribute
        original_getattr = PreTrainedModel.__getattr__

        def patched_getattr(self, name):
            if name == "all_tied_weights_keys":
                # Return empty dict to satisfy the check
                if not hasattr(self, "_all_tied_weights_keys"):
                    self._all_tied_weights_keys = {}
                return self._all_tied_weights_keys
            return original_getattr(self, name)

        # Apply patch temporarily
        PreTrainedModel.__getattr__ = patched_getattr

        try:
            model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                **load_kwargs,
            )
            return model
        finally:
            # Restore original
            PreTrainedModel.__getattr__ = original_getattr

    def load_geotiff(
        self,
        source: str,
        bands: Optional[List[int]] = None,
    ) -> Tuple[Image.Image, Dict]:
        """Load a GeoTIFF file and return a PIL Image with metadata.

        Args:
            source: Path to GeoTIFF file.
            bands: List of band indices to read (1-indexed). If None, reads
                first 3 bands for RGB or first band for grayscale.

        Returns:
            Tuple of (PIL Image, metadata dict).

        Raises:
            FileNotFoundError: If source file doesn't exist.
            RuntimeError: If loading fails.
        """
        if not os.path.exists(source):
            raise FileNotFoundError(f"File not found: {source}")

        try:
            with rasterio.open(source) as src:
                # Store metadata
                metadata = {
                    "profile": src.profile.copy(),
                    "crs": src.crs,
                    "transform": src.transform,
                    "bounds": src.bounds,
                    "width": src.width,
                    "height": src.height,
                }

                # Read bands
                if bands is None:
                    if src.count >= 3:
                        bands = [1, 2, 3]  # RGB
                    else:
                        bands = [1]  # Grayscale

                data = src.read(bands)

                # Convert to RGB image
                if len(bands) == 1:
                    # Grayscale to RGB
                    img_array = np.repeat(data[0:1], 3, axis=0)
                elif len(bands) >= 3:
                    img_array = data[:3]
                else:
                    # Pad to 3 channels
                    img_array = np.zeros((3, data.shape[1], data.shape[2]))
                    img_array[: data.shape[0]] = data

                # Normalize to 0-255 range
                img_array = self._normalize_image(img_array)

                # Convert to PIL Image (HWC format)
                img_array = np.transpose(img_array, (1, 2, 0))
                image = Image.fromarray(img_array.astype(np.uint8))

                self._source_path = source
                self._metadata = metadata

                return image, metadata

        except Exception as e:
            raise RuntimeError(f"Failed to load GeoTIFF: {e}") from e

    def load_image(
        self,
        source: Union[str, Image.Image, np.ndarray],
        bands: Optional[List[int]] = None,
    ) -> Tuple[Image.Image, Optional[Dict]]:
        """Load an image from various sources.

        Args:
            source: Image source - can be a file path (GeoTIFF, PNG, JPG),
                PIL Image, or numpy array.
            bands: Band indices for GeoTIFF (1-indexed).

        Returns:
            Tuple of (PIL Image, metadata dict or None).
        """
        if isinstance(source, Image.Image):
            self._source_path = None
            self._metadata = None
            return source, None

        if isinstance(source, np.ndarray):
            if source.ndim == 2:
                # Grayscale
                source = np.stack([source] * 3, axis=-1)
            elif source.ndim == 3 and source.shape[0] <= 4:
                # CHW format
                source = np.transpose(source[:3], (1, 2, 0))

            source = self._normalize_image(source)
            image = Image.fromarray(source.astype(np.uint8))
            self._source_path = None
            self._metadata = None
            return image, None

        if isinstance(source, str):
            if source.startswith(("http://", "https://")):
                # URL - download and load
                from .utils import download_file

                source = download_file(source)

            # Check if it's a GeoTIFF
            try:
                with rasterio.open(source) as src:
                    if src.crs is not None or source.lower().endswith(
                        (".tif", ".tiff")
                    ):
                        return self.load_geotiff(source, bands)
            except rasterio.RasterioIOError:
                pass

            # Regular image
            image = Image.open(source).convert("RGB")
            self._source_path = source
            self._metadata = {
                "width": image.width,
                "height": image.height,
                "crs": None,
                "transform": None,
                "bounds": None,
            }
            return image, self._metadata

    def _normalize_image(self, data: np.ndarray) -> np.ndarray:
        """Normalize image data to 0-255 range using percentile stretching.

        Args:
            data: Input array (can be CHW or HWC format).

        Returns:
            Normalized array in uint8 range.
        """
        if data.dtype == np.uint8:
            return data

        # Determine if CHW or HWC
        if data.ndim == 3 and data.shape[0] <= 4:
            # CHW format - normalize each channel
            normalized = np.zeros_like(data, dtype=np.float32)
            for i in range(data.shape[0]):
                band = data[i].astype(np.float32)
                p2, p98 = np.percentile(band, [2, 98])
                if p98 > p2:
                    normalized[i] = np.clip((band - p2) / (p98 - p2) * 255, 0, 255)
                else:
                    normalized[i] = np.clip(band, 0, 255)
        else:
            # HWC format or 2D
            data = data.astype(np.float32)
            p2, p98 = np.percentile(data, [2, 98])
            if p98 > p2:
                normalized = np.clip((data - p2) / (p98 - p2) * 255, 0, 255)
            else:
                normalized = np.clip(data, 0, 255)

        return normalized.astype(np.uint8)
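
    # Worked example of the stretch: for a band with p2 = 120.0 and
    # p98 = 3050.0, a pixel value of 1585.0 maps to
    # (1585 - 120) / (3050 - 120) * 255 = 127.5, which is clipped to
    # [0, 255] and cast to uint8 (127).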

    def encode_image(
        self,
        source: Union[str, Image.Image, np.ndarray],
        bands: Optional[List[int]] = None,
    ) -> Any:
        """Pre-encode an image for efficient multiple inferences.

        Use this when you plan to run multiple queries on the same image.

        Args:
            source: Image source.
            bands: Band indices for GeoTIFF.

        Returns:
            Encoded image that can be passed to query, caption, etc.
        """
        image, _ = self.load_image(source, bands)

        if hasattr(self.model, "encode_image"):
            return self.model.encode_image(image)
        return image
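
    # Example (sketch): encode once, then reuse the encoding for several
    # calls on the same scene ("scene.tif" is a hypothetical path):
    #
    #   encoded = md.encode_image("scene.tif")
    #   md.caption(encoded, length="short")
    #   md.query("How many buildings are visible?", source=encoded)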

    def caption(
        self,
        source: Union[str, Image.Image, np.ndarray, Any],
        length: str = "normal",
        stream: bool = False,
        bands: Optional[List[int]] = None,
        settings: Optional[Dict] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Generate a caption for an image.

        Args:
            source: Image source or pre-encoded image.
            length: Caption length - "short", "normal", or "long".
            stream: Whether to stream the output.
            bands: Band indices for GeoTIFF.
            settings: Additional settings (temperature, top_p, max_tokens).
            **kwargs: Additional arguments for the model.

        Returns:
            Dictionary with "caption" key containing the generated caption.
        """
        # Load image if not pre-encoded
        if isinstance(source, (str, Image.Image, np.ndarray)):
            image, _ = self.load_image(source, bands)
        else:
            image = source  # Pre-encoded

        call_kwargs = {"length": length, "stream": stream}
        if settings:
            call_kwargs["settings"] = settings
        call_kwargs.update(kwargs)

        return self.model.caption(image, **call_kwargs)
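
    # Example (sketch; settings keys follow the docstring above):
    #
    #   result = md.caption("scene.tif", length="long",
    #                       settings={"temperature": 0.2})
    #   print(result["caption"])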

    def query(
        self,
        question: str,
        source: Optional[Union[str, Image.Image, np.ndarray, Any]] = None,
        reasoning: Optional[bool] = None,
        stream: bool = False,
        bands: Optional[List[int]] = None,
        settings: Optional[Dict] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Ask a question about an image, or run a text-only query.

        Args:
            question: The question to ask.
            source: Image source or pre-encoded image. If None, performs a
                text-only query (moondream3 only).
            reasoning: Enable reasoning mode for more complex tasks
                (moondream3 only, default True).
            stream: Whether to stream the output.
            bands: Band indices for GeoTIFF.
            settings: Additional settings (temperature, top_p, max_tokens).
            **kwargs: Additional arguments for the model.

        Returns:
            Dictionary with "answer" key containing the response.
        """
        call_kwargs = {"question": question, "stream": stream}

        if source is not None:
            if isinstance(source, (str, Image.Image, np.ndarray)):
                image, _ = self.load_image(source, bands)
            else:
                image = source
            call_kwargs["image"] = image

        if reasoning is not None and self.model_version == "moondream3":
            call_kwargs["reasoning"] = reasoning

        if settings:
            call_kwargs["settings"] = settings
        call_kwargs.update(kwargs)

        return self.model.query(**call_kwargs)
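
    # Example (sketch; reasoning is only honored for moondream3):
    #
    #   result = md.query("What land cover dominates this tile?", "scene.tif")
    #   print(result["answer"])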

    def detect(
        self,
        source: Union[str, Image.Image, np.ndarray, Any],
        object_type: str,
        bands: Optional[List[int]] = None,
        output_path: Optional[str] = None,
        settings: Optional[Dict] = None,
        stream: bool = False,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Detect objects of a specific type in an image.

        Args:
            source: Image source or pre-encoded image.
            object_type: Type of object to detect (e.g., "car", "building").
            bands: Band indices for GeoTIFF.
            output_path: Path to save results as GeoJSON/Shapefile/GeoPackage.
            settings: Additional settings (max_objects, etc.).
            stream: Whether to stream the output (moondream2 only).
            **kwargs: Additional arguments for the model.

        Returns:
            Dictionary with "objects" key containing a list of bounding boxes
            with normalized coordinates (x_min, y_min, x_max, y_max).
            If georeferenced, also includes "gdf" (GeoDataFrame) and
            "crs", "bounds" keys.
        """
        # Load image
        if isinstance(source, (str, Image.Image, np.ndarray)):
            image, metadata = self.load_image(source, bands)
        else:
            image = source
            metadata = self._metadata

        call_kwargs = {}
        if settings:
            call_kwargs["settings"] = settings
        if self.model_version == "moondream2" and stream:
            call_kwargs["stream"] = stream
        call_kwargs.update(kwargs)

        result = self.model.detect(image, object_type, **call_kwargs)

        # Convert to georeferenced if possible
        if metadata and metadata.get("crs") and metadata.get("transform"):
            result = self._georef_detections(result, metadata)

        # Only georeferenced results carry a "gdf" to save
        if output_path and "gdf" in result:
            self._save_vector(result["gdf"], output_path)

        return result
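
    # Example (sketch): detections on a georeferenced input come back with
    # a GeoDataFrame and can be exported directly (paths are hypothetical):
    #
    #   result = md.detect("naip.tif", "building",
    #                      output_path="buildings.geojson")
    #   print(len(result["objects"]), result.get("crs"))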

    def point(
        self,
        source: Union[str, Image.Image, np.ndarray, Any],
        object_description: str,
        bands: Optional[List[int]] = None,
        output_path: Optional[str] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Find points (x, y coordinates) for objects in an image.

        Args:
            source: Image source or pre-encoded image.
            object_description: Description of objects to find.
            bands: Band indices for GeoTIFF.
            output_path: Path to save results as GeoJSON/Shapefile/GeoPackage.
            **kwargs: Additional arguments for the model.

        Returns:
            Dictionary with "points" key containing a list of points
            with normalized coordinates (x, y in 0-1 range).
            If georeferenced, also includes "gdf" (GeoDataFrame).
        """
        # Load image
        if isinstance(source, (str, Image.Image, np.ndarray)):
            image, metadata = self.load_image(source, bands)
        else:
            image = source
            metadata = self._metadata

        result = self.model.point(image, object_description, **kwargs)

        # Convert to georeferenced if possible
        if metadata and metadata.get("crs") and metadata.get("transform"):
            result = self._georef_points(result, metadata)

        # Only georeferenced results carry a "gdf" to save
        if output_path and "gdf" in result:
            self._save_vector(result["gdf"], output_path)

        return result
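
    # Example (sketch; paths are hypothetical):
    #
    #   result = md.point("naip.tif", "solar panel",
    #                     output_path="panels.gpkg")
    #   result.get("gdf")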

    def _georef_detections(
        self,
        result: Dict[str, Any],
        metadata: Dict,
    ) -> Dict[str, Any]:
        """Convert detection results to georeferenced format.

        Args:
            result: Detection result from model.
            metadata: Image metadata with CRS and transform.

        Returns:
            Updated result dictionary with GeoDataFrame.
        """
        objects = result.get("objects", [])
        if not objects:
            result["gdf"] = gpd.GeoDataFrame(
                columns=["geometry", "x_min", "y_min", "x_max", "y_max"],
                crs=metadata["crs"],
            )
            result["crs"] = metadata["crs"]
            result["bounds"] = metadata["bounds"]
            return result

        width = metadata["width"]
        height = metadata["height"]
        transform = metadata["transform"]

        geometries = []
        records = []

        for obj in objects:
            # Convert normalized coords to pixel coords
            px_x_min = obj["x_min"] * width
            px_y_min = obj["y_min"] * height
            px_x_max = obj["x_max"] * width
            px_y_max = obj["y_max"] * height

            # Convert pixel coords to geographic coords
            geo_x_min, geo_y_min = transform * (px_x_min, px_y_max)
            geo_x_max, geo_y_max = transform * (px_x_max, px_y_min)

            # Create polygon
            geom = box(geo_x_min, geo_y_min, geo_x_max, geo_y_max)
            geometries.append(geom)

            records.append(
                {
                    "x_min": obj["x_min"],
                    "y_min": obj["y_min"],
                    "x_max": obj["x_max"],
                    "y_max": obj["y_max"],
                    "px_x_min": int(px_x_min),
                    "px_y_min": int(px_y_min),
                    "px_x_max": int(px_x_max),
                    "px_y_max": int(px_y_max),
                }
            )

        gdf = gpd.GeoDataFrame(records, geometry=geometries, crs=metadata["crs"])

        result["gdf"] = gdf
        result["crs"] = metadata["crs"]
        result["bounds"] = metadata["bounds"]

        return result

    def _georef_points(
        self,
        result: Dict[str, Any],
        metadata: Dict,
    ) -> Dict[str, Any]:
        """Convert point results to georeferenced format.

        Args:
            result: Point result from model.
            metadata: Image metadata with CRS and transform.

        Returns:
            Updated result dictionary with GeoDataFrame.
        """
        points = result.get("points", [])
        if not points:
            result["gdf"] = gpd.GeoDataFrame(
                columns=["geometry", "x", "y"],
                crs=metadata["crs"],
            )
            result["crs"] = metadata["crs"]
            result["bounds"] = metadata["bounds"]
            return result

        width = metadata["width"]
        height = metadata["height"]
        transform = metadata["transform"]

        geometries = []
        records = []

        for pt in points:
            # Convert normalized coords to pixel coords
            px_x = pt["x"] * width
            px_y = pt["y"] * height

            # Convert pixel coords to geographic coords
            geo_x, geo_y = transform * (px_x, px_y)

            # Create point
            geom = Point(geo_x, geo_y)
            geometries.append(geom)

            records.append(
                {
                    "x": pt["x"],
                    "y": pt["y"],
                    "px_x": int(px_x),
                    "px_y": int(px_y),
                }
            )

        gdf = gpd.GeoDataFrame(records, geometry=geometries, crs=metadata["crs"])

        result["gdf"] = gdf
        result["crs"] = metadata["crs"]
        result["bounds"] = metadata["bounds"]

        return result

    def _save_vector(
        self,
        gdf: gpd.GeoDataFrame,
        output_path: str,
    ) -> None:
        """Save GeoDataFrame to vector file.

        Args:
            gdf: GeoDataFrame to save.
            output_path: Output file path. Extension determines format:
                .geojson -> GeoJSON
                .shp -> Shapefile
                .gpkg -> GeoPackage
        """
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

        ext = os.path.splitext(output_path)[1].lower()
        if ext == ".geojson":
            gdf.to_file(output_path, driver="GeoJSON")
        elif ext == ".shp":
            gdf.to_file(output_path, driver="ESRI Shapefile")
        elif ext == ".gpkg":
            gdf.to_file(output_path, driver="GPKG")
        else:
            gdf.to_file(output_path)

        print(f"Saved {len(gdf)} features to {output_path}")

    def create_detection_mask(
        self,
        source: Union[str, Image.Image, np.ndarray],
        object_type: str,
        output_path: Optional[str] = None,
        bands: Optional[List[int]] = None,
        **kwargs: Any,
    ) -> Tuple[np.ndarray, Optional[Dict]]:
        """Create a binary mask from object detections.

        Args:
            source: Image source.
            object_type: Type of object to detect.
            output_path: Path to save mask as GeoTIFF.
            bands: Band indices for GeoTIFF.
            **kwargs: Additional arguments for detect().

        Returns:
            Tuple of (mask array, metadata dict).
        """
        # Load image to get dimensions
        image, metadata = self.load_image(source, bands)
        width, height = image.size

        # Detect objects
        result = self.detect(source, object_type, bands=bands, **kwargs)
        objects = result.get("objects", [])

        # Create mask
        mask = np.zeros((height, width), dtype=np.uint8)

        for obj in objects:
            x_min = int(obj["x_min"] * width)
            y_min = int(obj["y_min"] * height)
            x_max = int(obj["x_max"] * width)
            y_max = int(obj["y_max"] * height)

            mask[y_min:y_max, x_min:x_max] = 255

        # Save as GeoTIFF if requested
        if output_path and metadata and metadata.get("crs"):
            self._save_mask_geotiff(mask, output_path, metadata)
        elif output_path:
            # Save as regular image
            Image.fromarray(mask).save(output_path)

        return mask, metadata
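
    # Example (sketch): rasterize building detections to a georeferenced
    # mask ("naip.tif" and the output path are hypothetical):
    #
    #   mask, meta = md.create_detection_mask(
    #       "naip.tif", "building", output_path="building_mask.tif"
    #   )
    #   print(mask.shape, mask.max())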

    def _save_mask_geotiff(
        self,
        mask: np.ndarray,
        output_path: str,
        metadata: Dict,
    ) -> None:
        """Save mask as GeoTIFF with georeferencing.

        Args:
            mask: 2D mask array.
            output_path: Output file path.
            metadata: Image metadata with profile.
        """
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

        profile = metadata["profile"].copy()
        profile.update(
            {
                "dtype": "uint8",
                "count": 1,
                "height": mask.shape[0],
                "width": mask.shape[1],
                "compress": "lzw",
            }
        )

        with rasterio.open(output_path, "w", **profile) as dst:
            dst.write(mask, 1)

        print(f"Saved mask to {output_path}")

    def show_gui(
        self,
        basemap: str = "SATELLITE",
        out_dir: Optional[str] = None,
        opacity: float = 0.5,
        **kwargs: Any,
    ) -> Any:
        """Display an interactive GUI for using Moondream with leafmap.

        This method creates an interactive map interface for using Moondream
        vision language model capabilities including:
        - Image captioning (short, normal, long)
        - Visual question answering (query)
        - Object detection with bounding boxes
        - Point detection for locating objects

        Args:
            basemap: The basemap to use. Defaults to "SATELLITE".
            out_dir: The output directory for saving results.
                Defaults to None (uses temp directory).
            opacity: The opacity of overlay layers. Defaults to 0.5.
            **kwargs: Additional keyword arguments passed to leafmap.Map().

        Returns:
            leafmap.Map: The interactive map with the Moondream GUI.

        Example:
            >>> moondream = MoondreamGeo()
            >>> moondream.load_image("image.tif")
            >>> m = moondream.show_gui()
            >>> m
        """
        from .map_widgets import moondream_gui

        return moondream_gui(
            self,
            basemap=basemap,
            out_dir=out_dir,
            opacity=opacity,
            **kwargs,
        )

    def get_last_result(self) -> Optional[gpd.GeoDataFrame]:
        """Return the GeoDataFrame from the most recent result, if any.

        Returns:
            The GeoDataFrame stored under "gdf" in ``self.last_result``,
            or None if no such result is available.
        """
        if hasattr(self, "last_result") and "gdf" in self.last_result:
            return self.last_result["gdf"]
        return None


def moondream_caption(
    source: Union[str, Image.Image, np.ndarray],
    model_name: str = "vikhyatk/moondream2",
    revision: Optional[str] = None,
    length: str = "normal",
    bands: Optional[List[int]] = None,
    device: Optional[str] = None,
    **kwargs: Any,
) -> str:
    """Convenience function to generate a caption for an image.

    Args:
        source: Image source (file path, PIL Image, or numpy array).
        model_name: Moondream model name.
        revision: Model revision.
        length: Caption length ("short", "normal", "long").
        bands: Band indices for GeoTIFF.
        device: Device for inference.
        **kwargs: Additional arguments.

    Returns:
        Generated caption string.
    """
    processor = MoondreamGeo(model_name=model_name, revision=revision, device=device)
    result = processor.caption(source, length=length, bands=bands, **kwargs)
    return result["caption"]


def moondream_query(
    question: str,
    source: Optional[Union[str, Image.Image, np.ndarray]] = None,
    model_name: str = "vikhyatk/moondream2",
    revision: Optional[str] = None,
    reasoning: Optional[bool] = None,
    bands: Optional[List[int]] = None,
    device: Optional[str] = None,
    **kwargs: Any,
) -> str:
    """Convenience function to ask a question about an image.

    Args:
        question: Question to ask.
        source: Image source (optional for text-only queries).
        model_name: Moondream model name.
        revision: Model revision.
        reasoning: Enable reasoning mode (moondream3 only).
        bands: Band indices for GeoTIFF.
        device: Device for inference.
        **kwargs: Additional arguments.

    Returns:
        Answer string.
    """
    processor = MoondreamGeo(model_name=model_name, revision=revision, device=device)
    result = processor.query(
        question, source=source, reasoning=reasoning, bands=bands, **kwargs
    )
    return result["answer"]


def moondream_detect(
    source: Union[str, Image.Image, np.ndarray],
    object_type: str,
    model_name: str = "vikhyatk/moondream2",
    revision: Optional[str] = None,
    output_path: Optional[str] = None,
    bands: Optional[List[int]] = None,
    device: Optional[str] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Convenience function to detect objects in an image.

    Args:
        source: Image source.
        object_type: Type of object to detect.
        model_name: Moondream model name.
        revision: Model revision.
        output_path: Path to save results as vector file.
        bands: Band indices for GeoTIFF.
        device: Device for inference.
        **kwargs: Additional arguments.

    Returns:
        Detection results dictionary with "objects" and optionally "gdf".
    """
    processor = MoondreamGeo(model_name=model_name, revision=revision, device=device)
    return processor.detect(
        source, object_type, output_path=output_path, bands=bands, **kwargs
    )


def moondream_point(
    source: Union[str, Image.Image, np.ndarray],
    object_description: str,
    model_name: str = "vikhyatk/moondream2",
    revision: Optional[str] = None,
    output_path: Optional[str] = None,
    bands: Optional[List[int]] = None,
    device: Optional[str] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Convenience function to find points for objects in an image.

    Args:
        source: Image source.
        object_description: Description of objects to find.
        model_name: Moondream model name.
        revision: Model revision.
        output_path: Path to save results as vector file.
        bands: Band indices for GeoTIFF.
        device: Device for inference.
        **kwargs: Additional arguments.

    Returns:
        Point results dictionary with "points" and optionally "gdf".
    """
    processor = MoondreamGeo(model_name=model_name, revision=revision, device=device)
    return processor.point(
        source, object_description, output_path=output_path, bands=bands, **kwargs
    )
```
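
Each module-level helper (`moondream_caption`, `moondream_query`, `moondream_detect`, `moondream_point`) constructs a fresh `MoondreamGeo` per call, so the class is preferable when running several operations against the same model. A minimal sketch of the one-off style, assuming a hypothetical local georeferenced `naip.tif` and the pinned revision suggested in the docstrings:

```python
from geoai.moondream import moondream_caption, moondream_detect

# One-off caption: loads the model, runs a single inference, returns a string.
caption = moondream_caption("naip.tif", revision="2025-06-21", length="short")

# One-off detection: also writes a GeoJSON when the input is georeferenced.
result = moondream_detect("naip.tif", "building", output_path="buildings.geojson")

print(caption)
print(len(result["objects"]))
```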