geoai-py 0.18.2__py2.py3-none-any.whl → 0.20.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/map_widgets.py CHANGED
@@ -1,10 +1,29 @@
1
1
  """Interactive widget for GeoAI."""
2
2
 
3
+ import os
4
+ import string
5
+ import random
6
+ import tempfile
7
+ from typing import Any, Optional
8
+
3
9
  import ipywidgets as widgets
4
10
 
5
11
  from .utils import dict_to_image, dict_to_rioxarray
6
12
 
7
13
 
14
def random_string(string_length: int = 6) -> str:
    """Build a string of random lowercase ASCII letters.

    Each character is drawn independently with ``random.choice`` from
    ``string.ascii_lowercase``. The module-level ``random`` generator is
    used, so output is reproducible after ``random.seed`` (and is not
    suitable for security-sensitive tokens).

    Args:
        string_length: How many characters to produce. Defaults to 6.

    Returns:
        The generated string; empty when ``string_length`` is 0 or negative.
    """
    alphabet = string.ascii_lowercase
    picked = [random.choice(alphabet) for _ in range(string_length)]
    return "".join(picked)
25
+
26
+
8
27
  class DINOv3GUI(widgets.VBox):
9
28
  """Interactive widget for DINOv3."""
10
29
 
@@ -172,3 +191,540 @@ class DINOv3GUI(widgets.VBox):
172
191
 
173
192
  host_map.on_interaction(handle_map_interaction)
174
193
  host_map.default_style = {"cursor": "crosshair"}
194
+
195
+
196
def moondream_gui(
    moondream,
    basemap: Optional[str] = "SATELLITE",
    out_dir: Optional[str] = None,
    opacity: float = 0.5,
    **kwargs: Any,
):
    """Display an interactive GUI for using Moondream with leafmap.

    This function creates an interactive map interface for using Moondream
    vision language model capabilities including:
    - Image captioning (short, normal, long)
    - Visual question answering (query)
    - Object detection with bounding boxes displayed on map
    - Point detection for locating objects with markers on map

    Caption and query calls are made with ``stream=False``; the text result is
    shown in the toolbar's output panel once the model call returns.

    Args:
        moondream (MoondreamGeo): The MoondreamGeo object with a loaded image.
            Must have called load_image() or load_geotiff() first.
        basemap (str, optional): The basemap to use, or None to add no
            basemap. Defaults to "SATELLITE".
        out_dir (str, optional): Directory the save dialog opens in. Used only
            when it is an existing directory; otherwise (including the default
            None) the dialog opens in the current working directory.
        opacity (float, optional): Initial opacity of overlay layers; the
            Reset button restores this value. Defaults to 0.5.
        **kwargs: Additional keyword arguments passed to leafmap.Map().

    Returns:
        leafmap.Map: The interactive map with the Moondream GUI. After a run,
        ``m.last_result`` holds the raw result dict of the most recent
        operation and ``m.last_result_as_gdf`` holds its features reprojected
        to EPSG:4326 (None when that operation produced no features).

    Raises:
        ImportError: If leafmap, ipyevents, ipyleaflet or ipyfilechooser is
            not installed.

    Example:
        >>> from geoai import MoondreamGeo, moondream_gui
        >>> moondream = MoondreamGeo()
        >>> moondream.load_image("image.tif")
        >>> m = moondream_gui(moondream)
        >>> m
    """
    import html
    import traceback
    import warnings

    try:
        import ipyevents
        import ipyleaflet
        import leafmap
        from ipyfilechooser import FileChooser
    except ImportError as err:
        raise ImportError(
            "The moondream_gui function requires additional packages. "
            "Please install them with: pip install leafmap ipyevents ipyfilechooser"
        ) from err

    # Create the map
    m = leafmap.Map(**kwargs)
    m.default_style = {"cursor": "crosshair"}
    if basemap is not None:
        m.add_basemap(basemap, show=False)

    # Show the source image when possible; the GUI still works without it,
    # so a failure here is reported rather than raised.
    if moondream._source_path is not None:
        try:
            m.add_raster(moondream._source_path, layer_name="Image")
        except Exception as err:
            warnings.warn(
                f"Could not add {moondream._source_path!r} to the map: {err}",
                stacklevel=2,
            )

    # Attributes exposed on the map object (kept for backward compatibility).
    m.detection_markers = []
    m.point_markers = []
    m.last_result = None
    m.last_result_as_gdf = None

    # Question of the most recent successful query. It is saved alongside the
    # answer, so later edits to the prompt box (or a mode switch, which clears
    # the box) cannot change what gets written to disk.
    last_question = ""

    # OGR drivers for the vector formats offered in the save dialog.
    save_drivers = {
        ".geojson": "GeoJSON",
        ".shp": "ESRI Shapefile",
        ".gpkg": "GPKG",
    }

    # Widget styling
    widget_width = "300px"
    button_width = "90px"
    padding = "0px 4px 0px 4px"
    style = {"description_width": "initial"}

    # Create toolbar buttons
    toolbar_button = widgets.ToggleButton(
        value=True,
        tooltip="Toolbar",
        icon="gear",
        layout=widgets.Layout(width="28px", height="28px", padding="0px 0px 0px 4px"),
    )

    close_button = widgets.ToggleButton(
        value=False,
        tooltip="Close the tool",
        icon="times",
        button_style="primary",
        layout=widgets.Layout(height="28px", width="28px", padding="0px 0px 0px 4px"),
    )

    # Mode selection
    mode_dropdown = widgets.Dropdown(
        options=["Caption", "Query", "Detect", "Point"],
        value="Caption",
        description="Mode:",
        style=style,
        layout=widgets.Layout(width=widget_width, padding=padding),
    )

    # Text prompt input
    text_prompt = widgets.Text(
        description="Prompt:",
        placeholder="Enter text prompt...",
        style=style,
        layout=widgets.Layout(width=widget_width, padding=padding),
    )

    # Caption length selector (only visible in Caption mode)
    caption_length = widgets.Dropdown(
        options=["short", "normal", "long"],
        value="normal",
        description="Length:",
        style=style,
        layout=widgets.Layout(width=widget_width, padding=padding),
    )

    # Opacity slider for overlays
    opacity_slider = widgets.FloatSlider(
        description="Opacity:",
        min=0,
        max=1,
        value=opacity,
        readout=True,
        continuous_update=True,
        layout=widgets.Layout(width=widget_width, padding=padding),
        style=style,
    )

    # Color picker for detection/point markers
    colorpicker = widgets.ColorPicker(
        concise=False,
        description="Color:",
        value="#ff0000",
        layout=widgets.Layout(width="150px", padding=padding),
        style=style,
    )

    # Action buttons (toggle buttons that reset themselves when handled)
    run_button = widgets.ToggleButton(
        description="Run",
        value=False,
        button_style="primary",
        layout=widgets.Layout(padding=padding, width=button_width),
    )

    save_button = widgets.ToggleButton(
        description="Save",
        value=False,
        button_style="primary",
        layout=widgets.Layout(width=button_width),
    )

    reset_button = widgets.ToggleButton(
        description="Reset",
        value=False,
        button_style="primary",
        layout=widgets.Layout(width=button_width),
    )

    # Output area for displaying results - using HTML for better text display
    output_html = widgets.HTML(
        value="",
        layout=widgets.Layout(
            width=widget_width,
            padding=padding,
            max_width=widget_width,
            min_height="0px",
            max_height="300px",
            overflow="auto",
        ),
    )

    # Build the toolbar layout
    toolbar_header = widgets.HBox()
    toolbar_header.children = [close_button, toolbar_button]

    toolbar_footer = widgets.VBox()
    toolbar_footer.children = [
        mode_dropdown,
        text_prompt,
        caption_length,
        opacity_slider,
        colorpicker,
        widgets.HBox(
            [run_button, save_button, reset_button],
            layout=widgets.Layout(padding="0px 4px 0px 4px"),
        ),
        output_html,
    ]

    toolbar_widget = widgets.VBox()
    toolbar_widget.children = [toolbar_header, toolbar_footer]

    # Event handling for toolbar collapse/expand
    toolbar_event = ipyevents.Event(
        source=toolbar_widget, watched_events=["mouseenter", "mouseleave"]
    )

    def update_ui_visibility(change=None):
        """Update UI element visibility based on selected mode."""
        mode = mode_dropdown.value

        # Clear prompt and output when mode changes
        text_prompt.value = ""
        output_html.value = ""

        if mode == "Caption":
            text_prompt.layout.display = "none"
            caption_length.layout.display = "flex"
        elif mode == "Query":
            text_prompt.layout.display = "flex"
            text_prompt.placeholder = "Ask a question about the image..."
            caption_length.layout.display = "none"
        elif mode == "Detect":
            text_prompt.layout.display = "flex"
            text_prompt.placeholder = "Object type to detect (e.g., building, trees)..."
            caption_length.layout.display = "none"
        elif mode == "Point":
            text_prompt.layout.display = "flex"
            text_prompt.placeholder = "Object description to locate..."
            caption_length.layout.display = "none"

    mode_dropdown.observe(update_ui_visibility, "value")
    update_ui_visibility()  # Initial update

    def handle_toolbar_event(event):
        """Expand the toolbar on hover; collapse it on leave unless pinned."""
        if event["type"] == "mouseenter":
            toolbar_widget.children = [toolbar_header, toolbar_footer]
        elif event["type"] == "mouseleave":
            if not toolbar_button.value:
                toolbar_widget.children = [toolbar_button]
                toolbar_button.value = False
                close_button.value = False

    toolbar_event.on_dom_event(handle_toolbar_event)

    def toolbar_btn_click(change):
        """Pin (expand) or unpin (collapse) the toolbar."""
        if change["new"]:
            close_button.value = False
            toolbar_widget.children = [toolbar_header, toolbar_footer]
        else:
            if not close_button.value:
                toolbar_widget.children = [toolbar_button]

    toolbar_button.observe(toolbar_btn_click, "value")

    def remove_save_control():
        """Remove the file-chooser control from the map if it is present."""
        if hasattr(m, "save_control"):
            if m.save_control in m.controls:
                m.remove_control(m.save_control)
            delattr(m, "save_control")

    def close_btn_click(change):
        """Remove the toolbar (and any open save dialog) from the map."""
        if change["new"]:
            toolbar_button.value = False
            remove_save_control()
            if m.toolbar_control in m.controls:
                m.remove_control(m.toolbar_control)
            toolbar_widget.close()

    close_button.observe(close_btn_click, "value")

    def has_features(result):
        """Return True when ``result`` carries a non-empty ``gdf`` entry."""
        return "gdf" in result and len(result["gdf"]) > 0

    def remember(result, with_features=False):
        """Record ``result`` as the latest result on the map object.

        ``m.last_result_as_gdf`` is always refreshed so it never refers to the
        features of an earlier operation.
        """
        m.last_result = result
        if with_features and has_features(result):
            m.last_result_as_gdf = result["gdf"].to_crs("EPSG:4326")
        else:
            m.last_result_as_gdf = None

    def clear_detections():
        """Remove the detection layer from the map, if present."""
        if "Detections" in m.get_layer_names():
            m.remove_layer(m.find_layer("Detections"))

    def clear_points():
        """Remove the point-marker layer from the map, if present."""
        if "Points" in m.get_layer_names():
            m.remove_layer(m.find_layer("Points"))

    def add_detection_boxes(result, color="#ff0000"):
        """Replace the detection layer with the boxes in ``result["gdf"]``."""
        clear_detections()

        if has_features(result):
            gdf = result["gdf"].copy()
            m.add_gdf(
                gdf,
                layer_name="Detections",
                style={
                    "color": color,
                    "fillColor": color,
                    "fillOpacity": opacity_slider.value,
                    "weight": 2,
                },
                info_mode=None,
            )

    def add_point_markers(result, color="#ff0000", marker_opacity=0.5):
        """Replace the point layer with circle markers for ``result["gdf"]``."""
        clear_points()

        if has_features(result):
            gdf = result["gdf"].copy().to_crs("EPSG:4326")
            centroids = gdf.geometry.centroid
            gdf["x"] = centroids.x
            gdf["y"] = centroids.y

            m.add_circle_markers_from_xy(
                gdf,
                "x",
                "y",
                radius=6,
                color=color,
                fill_color=color,
                fill_opacity=marker_opacity,
                layer_name="Points",
            )

    def update_output(text, append=False):
        """Show ``text`` (HTML-escaped, newlines as <br>) in the output panel.

        With ``append=True`` the text is added after the current content
        instead of replacing it.
        """
        formatted = html.escape(text).replace("\n", "<br>")
        css = "font-family: monospace; font-size: 12px; word-wrap: break-word;"

        current = output_html.value
        if append and current and "<div" in current:
            # Content sits between the end of the opening <div ...> tag and
            # the final </div>; the style string contains no ">".
            start = current.find(">") + 1
            end = current.rfind("</div>")
            formatted = current[start:end] + formatted
        output_html.value = f'<div style="{css}">{formatted}</div>'

    def read_prompt(missing_message):
        """Return the stripped prompt text, or None after showing a hint."""
        prompt = text_prompt.value.strip()
        if not prompt:
            update_output(missing_message)
            return None
        return prompt

    def run_button_click(change):
        """Run the selected Moondream operation on the loaded image."""
        nonlocal last_question
        if not change["new"]:
            return
        run_button.value = False
        mode = mode_dropdown.value

        if moondream._source_path is None and moondream._metadata is None:
            update_output(
                "Please load an image first using load_image() or load_geotiff()."
            )
            return

        try:
            if mode == "Caption":
                update_output(f"Generating caption ({caption_length.value})...")

                result = moondream.caption(
                    moondream._source_path,
                    length=caption_length.value,
                    stream=False,
                )
                caption_text = result.get("caption", str(result))
                update_output(f"Caption ({caption_length.value}):\n{caption_text}")
                remember(result)

            elif mode == "Query":
                question = read_prompt("Please enter a question in the prompt field.")
                if question is None:
                    return

                update_output(f"Q: {question}\nGenerating answer...")

                result = moondream.query(
                    question,
                    source=moondream._source_path,
                    stream=False,
                )
                answer_text = result.get("answer", str(result))
                update_output(f"Q: {question}\nA: {answer_text}")
                last_question = question
                remember(result)

            elif mode == "Detect":
                prompt = read_prompt("Please enter an object type to detect.")
                if prompt is None:
                    return

                update_output(f"Detecting: {prompt}...")

                result = moondream.detect(moondream._source_path, prompt)
                num_objects = len(result.get("objects", []))

                # Show detection info
                info_text = f"Detecting: {prompt}\nFound {num_objects} object(s)."
                if has_features(result):
                    info_text += (
                        f"\nAdded {len(result['gdf'])} bounding box(es) to map."
                    )
                update_output(info_text)

                if num_objects > 0:
                    add_detection_boxes(result, colorpicker.value)
                else:
                    # Nothing found: drop boxes left over from a previous run.
                    clear_detections()
                remember(result, with_features=True)

            elif mode == "Point":
                prompt = read_prompt("Please enter an object description to locate.")
                if prompt is None:
                    return

                update_output(f"Locating: {prompt}...")

                result = moondream.point(moondream._source_path, prompt)
                num_points = len(result.get("points", []))
                update_output(f"Locating: {prompt}\nFound {num_points} point(s).")

                if num_points > 0:
                    add_point_markers(result, colorpicker.value, opacity_slider.value)
                else:
                    # Nothing found: drop markers left over from a previous run.
                    clear_points()
                remember(result, with_features=True)
        except Exception as e:
            update_output(f"Error: {e}\n\n{traceback.format_exc()}")

    run_button.observe(run_button_click, "value")

    def filechooser_callback(chooser):
        """Write the last result to the file picked in the save dialog."""
        if chooser.selected is not None:
            filename = chooser.selected
            result = getattr(m, "last_result", None)
            try:
                if not result:
                    update_output("Please run an operation first.")
                elif "gdf" in result:
                    gdf = result["gdf"]
                    if len(gdf) == 0:
                        update_output("No features to save.")
                    else:
                        ext = os.path.splitext(filename)[1].lower()
                        driver = save_drivers.get(ext)
                        if driver is None:
                            # Let the writer infer the format from the name.
                            gdf.to_file(filename)
                        else:
                            gdf.to_file(filename, driver=driver)
                        update_output(f"Saved {len(gdf)} features to {filename}")
                elif "caption" in result:
                    with open(filename, "w", encoding="utf-8") as f:
                        f.write(result["caption"])
                    update_output(f"Saved caption to {filename}")
                elif "answer" in result:
                    with open(filename, "w", encoding="utf-8") as f:
                        f.write(f"Q: {last_question}\n")
                        f.write(f"A: {result['answer']}")
                    update_output(f"Saved Q&A to {filename}")
                else:
                    update_output("The last result has nothing to save.")
            except Exception as e:
                update_output(f"Error saving: {e}")

        remove_save_control()
        save_button.value = False

    def save_button_click(change):
        """Open (or close) the save dialog for the last result."""
        if change["new"]:
            result = getattr(m, "last_result", None)
            if result is None:
                update_output("Please run an operation first.")
                save_button.value = False
                return

            mode = mode_dropdown.value

            # Determine default filename and filter
            if mode in ["Detect", "Point"] and "gdf" in result:
                default_filename = f"{mode.lower()}_{random_string()}.geojson"
                filter_pattern = ["*.geojson", "*.gpkg", "*.shp"]
            else:
                default_filename = f"{mode.lower()}_{random_string()}.txt"
                filter_pattern = ["*.txt"]

            # Honor an explicit, existing out_dir; otherwise use the cwd.
            if out_dir is not None and os.path.isdir(out_dir):
                start_dir = out_dir
            else:
                start_dir = os.getcwd()

            sandbox_path = os.environ.get("SANDBOX_PATH")
            filechooser = FileChooser(
                path=start_dir,
                filename=default_filename,
                sandbox_path=sandbox_path,
                layout=widgets.Layout(width="454px"),
            )
            filechooser.use_dir_icons = True
            filechooser.filter_pattern = filter_pattern
            filechooser.register_callback(filechooser_callback)
            save_control = ipyleaflet.WidgetControl(
                widget=filechooser, position="topright"
            )
            m.add_control(save_control)
            m.save_control = save_control
        else:
            remove_save_control()

    save_button.observe(save_button_click, "value")

    def reset_button_click(change):
        """Restore widget defaults and clear overlays and stored results."""
        nonlocal last_question
        if change["new"]:
            run_button.value = False
            save_button.value = False
            reset_button.value = False
            text_prompt.value = ""
            caption_length.value = "normal"
            # Restore the caller-supplied opacity, not a hard-coded value.
            opacity_slider.value = opacity
            colorpicker.value = "#ff0000"
            output_html.value = ""

            # Clear all markers and detection boxes
            clear_detections()
            clear_points()

            # Clear stored results so Save cannot write stale data.
            m.last_result = None
            m.last_result_as_gdf = None
            last_question = ""

    reset_button.observe(reset_button_click, "value")

    # Add the toolbar control to the map
    toolbar_control = ipyleaflet.WidgetControl(
        widget=toolbar_widget, position="topright"
    )
    m.add_control(toolbar_control)
    m.toolbar_control = toolbar_control

    return m