geoai-py 0.15.0__py2.py3-none-any.whl → 0.18.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,67 @@
1
+ """Structured output models for STAC catalog search results."""
2
+
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
class STACCollectionInfo(BaseModel):
    """Summary metadata describing a single STAC collection.

    Only ``id`` and ``title`` must be supplied; every other attribute is an
    optional plain string that defaults to ``None``.
    """

    # Required identity fields.
    id: str = Field(description="Collection identifier")
    title: str = Field(description="Collection title")

    # Optional descriptive metadata, each kept as a plain string.
    description: Optional[str] = Field(
        default=None, description="Collection description"
    )
    license: Optional[str] = Field(default=None, description="License information")
    temporal_extent: Optional[str] = Field(
        default=None, description="Temporal extent (start/end dates)"
    )
    spatial_extent: Optional[str] = Field(
        default=None, description="Spatial bounding box"
    )
    providers: Optional[str] = Field(default=None, description="Data providers")
    keywords: Optional[str] = Field(default=None, description="Keywords")
21
+
22
+
23
class STACAssetInfo(BaseModel):
    """One asset attached to a STAC item, identified by its asset key."""

    # Key under which the asset appears in the item's asset mapping.
    key: str = Field(description="Asset key/identifier")
    # Human-readable title, when the catalog provides one.
    title: Optional[str] = Field(default=None, description="Asset title")
28
+
29
+
30
class STACItemInfo(BaseModel):
    """Compact description of a single STAC item.

    ``id`` and ``collection`` are required. ``assets`` defaults to a fresh empty
    list per instance; the remaining fields default to ``None``.
    """

    id: str = Field(description="Item identifier")
    collection: str = Field(description="Collection ID")
    # Acquisition time, stored as a string rather than a datetime object.
    datetime: Optional[str] = Field(default=None, description="Acquisition datetime")
    bbox: Optional[List[float]] = Field(
        default=None, description="Bounding box [west, south, east, north]"
    )
    assets: List[STACAssetInfo] = Field(
        default_factory=list, description="Available assets"
    )
    # NOTE(review): raw item properties are not modelled here (a `properties`
    # field was previously disabled). Add an Optional[Dict[str, Any]] field if
    # full item metadata is ever required.
45
+
46
+
47
class STACSearchResult(BaseModel):
    """Result envelope for a STAC item search, echoing the search inputs."""

    # Echo of what was searched.
    query: str = Field(description="Original search query")
    collection: Optional[str] = Field(default=None, description="Collection searched")

    # What the search returned.
    item_count: int = Field(description="Number of items found")
    items: List[STACItemInfo] = Field(
        default_factory=list, description="List of STAC items"
    )

    # Spatial/temporal filters that were applied, if any.
    bbox: Optional[List[float]] = Field(
        default=None, description="Search bounding box used"
    )
    time_range: Optional[str] = Field(
        default=None, description="Time range used for search"
    )
58
+
59
+
60
class LocationInfo(BaseModel):
    """A named place together with its bounding box and center point.

    All three fields are required.
    """

    name: str = Field(description="Location name")
    # Ordered [west, south, east, north].
    bbox: List[float] = Field(description="Bounding box [west, south, east, north]")
    # Ordered [lon, lat].
    center: List[float] = Field(description="Center coordinates [lon, lat]")
@@ -0,0 +1,435 @@
1
+ """Tools for STAC catalog search and interaction."""
2
+
3
+ import ast
4
+ import json
5
+ from typing import Any, Dict, List, Optional, Union
6
+
7
+ from strands import tool
8
+
9
+ from ..download import pc_collection_list, pc_stac_search
10
+ from .stac_models import (
11
+ LocationInfo,
12
+ STACAssetInfo,
13
+ STACCollectionInfo,
14
+ STACItemInfo,
15
+ STACSearchResult,
16
+ )
17
+
18
+
19
class STACTools:
    """Agent tools for searching and inspecting STAC catalogs.

    Each public method is exposed to an agent through the ``strands.tool``
    decorator and returns a JSON string. Failures are reported as a JSON
    object with an ``"error"`` key instead of being raised, so the calling
    agent can read the message and react to it.
    """

    # Static bounding boxes for frequently requested places, consulted before
    # any network geocoding. Keys are lower-case with single spaces, matching
    # the normalization applied in ``geocode_location``.
    _LOCATION_CACHE = {
        "san francisco": {
            "name": "San Francisco",
            "bbox": [-122.5155, 37.7034, -122.3549, 37.8324],
            "center": [-122.4194, 37.7749],
        },
        # Covers all five boroughs (Staten Island in the west/south to eastern
        # Queens and the northern Bronx).
        "new york": {
            "name": "New York",
            "bbox": [-74.2591, 40.4774, -73.7004, 40.9176],
            "center": [-73.9352, 40.7306],
        },
        "new york city": {
            "name": "New York City",
            "bbox": [-74.2591, 40.4774, -73.7004, 40.9176],
            "center": [-73.9352, 40.7306],
        },
        "paris": {
            "name": "Paris",
            "bbox": [2.2241, 48.8156, 2.4698, 48.9022],
            "center": [2.3522, 48.8566],
        },
        "london": {
            "name": "London",
            "bbox": [-0.5103, 51.2868, 0.3340, 51.6919],
            "center": [-0.1276, 51.5074],
        },
        "tokyo": {
            "name": "Tokyo",
            "bbox": [139.5694, 35.5232, 139.9182, 35.8173],
            "center": [139.6917, 35.6895],
        },
        "los angeles": {
            "name": "Los Angeles",
            "bbox": [-118.6682, 33.7037, -118.1553, 34.3373],
            "center": [-118.2437, 34.0522],
        },
        "chicago": {
            "name": "Chicago",
            "bbox": [-87.9401, 41.6445, -87.5241, 42.0230],
            "center": [-87.6298, 41.8781],
        },
        "seattle": {
            "name": "Seattle",
            "bbox": [-122.4595, 47.4810, -122.2244, 47.7341],
            "center": [-122.3321, 47.6062],
        },
        "california": {
            "name": "California",
            "bbox": [-124.4820, 32.5288, -114.1315, 42.0095],
            "center": [-119.4179, 36.7783],
        },
        "las vegas": {
            "name": "Las Vegas",
            "bbox": [-115.3711, 35.9630, -114.9372, 36.2610],
            "center": [-115.1400, 36.1177],
        },
    }

    def __init__(
        self,
        endpoint: str = "https://planetarycomputer.microsoft.com/api/stac/v1",
    ) -> None:
        """Initialize STAC tools.

        Args:
            endpoint: STAC API endpoint URL. Defaults to Microsoft Planetary Computer.
        """
        self.endpoint = endpoint
        # Runtime cache of geocoding responses (already-serialized JSON strings),
        # keyed by normalized location name. Grows for the instance's lifetime.
        self._geocode_cache: Dict[str, str] = {}

    @staticmethod
    def _optional_str(value: Any) -> Optional[str]:
        """Coerce a DataFrame cell to ``str`` or ``None``.

        ``DataFrame.to_dict("records")`` yields ``NaN`` for missing cells, which
        the ``Optional[str]`` fields of ``STACCollectionInfo`` reject. Missing
        values become ``None`` and non-string values are stringified instead of
        failing validation for the whole listing.
        """
        if value is None:
            return None
        # NaN is the only float that is not equal to itself.
        if isinstance(value, float) and value != value:
            return None
        if isinstance(value, str):
            return value
        return str(value)

    @staticmethod
    def _parse_query(query_str: str) -> Any:
        """Parse an LLM-produced query string into a Python object.

        Unbalanced braces are repaired first: missing closing braces are
        appended, and surplus closing braces are trimmed from the end one at a
        time. Strict JSON (which allows ``true``/``false``/``null``) is tried
        before Python-literal syntax (which allows e.g. single-quoted keys).
        A blank string means "no query" and returns ``None``.

        Raises:
            ValueError, SyntaxError: If the text is neither JSON nor a Python literal.
        """
        text = query_str.strip()
        if not text:
            return None

        opening, closing = text.count("{"), text.count("}")
        if opening > closing:
            text += "}" * (opening - closing)
        elif closing > opening:
            # Remove exactly one trailing brace per surplus brace; str.rstrip("}")
            # would strip every trailing brace at once and unbalance the text.
            for _ in range(closing - opening):
                if not text.endswith("}"):
                    break
                text = text[:-1].rstrip()

        try:
            return json.loads(text)
        except ValueError:
            # ast.literal_eval only evaluates literals, so it is safe on untrusted text.
            return ast.literal_eval(text)

    @tool(
        description="List and search available STAC collections from Planetary Computer"
    )
    def list_collections(
        self,
        filter_keyword: Optional[str] = None,
        detailed: bool = False,
    ) -> str:
        """List available STAC collections from Planetary Computer.

        Args:
            filter_keyword: Optional keyword to filter collections (case-insensitive
                substring match against id, title, and description).
            detailed: If True, return detailed information including temporal extent, license, etc.

        Returns:
            JSON string containing list of collections with their metadata.
        """
        try:
            df = pc_collection_list(
                endpoint=self.endpoint,
                detailed=detailed,
                filter_by=None,
                sort_by="id",
            )

            if filter_keyword:
                mask = None
                for column in ("id", "title", "description"):
                    if column not in df.columns:
                        continue
                    # The "string" dtype keeps missing cells as <NA> (counted as
                    # no match via na=False) and avoids .str failing on all-NaN
                    # float columns. regex=False treats the keyword literally.
                    matches = (
                        df[column]
                        .astype("string")
                        .str.contains(
                            filter_keyword, case=False, na=False, regex=False
                        )
                    )
                    mask = matches if mask is None else (mask | matches)
                # No searchable column means nothing can match.
                df = df[mask] if mask is not None else df.iloc[0:0]

            collection_models = [
                STACCollectionInfo(
                    id=self._optional_str(col.get("id")) or "",
                    title=self._optional_str(col.get("title")) or "",
                    description=self._optional_str(col.get("description")),
                    license=self._optional_str(col.get("license")),
                    temporal_extent=self._optional_str(col.get("temporal_extent")),
                    spatial_extent=self._optional_str(col.get("bbox")),
                    providers=self._optional_str(col.get("providers")),
                    keywords=self._optional_str(col.get("keywords")),
                )
                for col in df.to_dict("records")
            ]

            result = {
                "count": len(collection_models),
                "filter_keyword": filter_keyword,
                "collections": [c.model_dump() for c in collection_models],
            }
            return json.dumps(result, indent=2)

        except Exception as e:
            return json.dumps({"error": str(e)})

    @tool(
        description="Search for STAC items in a specific collection with optional filters"
    )
    def search_items(
        self,
        collection: str,
        bbox: Optional[Union[str, List[float]]] = None,
        time_range: Optional[str] = None,
        query: Optional[Union[str, Dict[str, Any]]] = None,
        limit: Optional[Union[str, int]] = 10,
        max_items: Optional[Union[str, int]] = 1,
    ) -> str:
        """Search for STAC items in the Planetary Computer catalog.

        Args:
            collection: Collection ID to search (e.g., "sentinel-2-l2a", "naip", "landsat-c2-l2").
            bbox: Bounding box as [west, south, east, north] in WGS84 coordinates.
                Example: [-122.5, 37.7, -122.3, 37.8] for San Francisco area.
            time_range: Time range as "start/end" string in ISO format.
                Example: "2024-09-01/2024-09-30" or "2024-09-01/2024-09-01" for single day.
            query: Query parameters for filtering.
                Example: {"eo:cloud_cover": {"lt": 10}} for cloud cover less than 10%.
            limit: Number of items to return per page.
                Example: 10 for 10 items per page.
            max_items: Maximum number of items to return (default: 1).

        Returns:
            JSON string containing search results with item details including IDs, URLs, and metadata.
        """
        try:
            # Arguments may arrive as strings when produced by an LLM.
            if isinstance(bbox, str):
                bbox = ast.literal_eval(bbox) if bbox.strip() else None
            # Unwrap a singly nested box: [[w, s, e, n]] -> [w, s, e, n]
            if (
                isinstance(bbox, (list, tuple))
                and len(bbox) == 1
                and isinstance(bbox[0], (list, tuple))
            ):
                bbox = list(bbox[0])

            if isinstance(query, str):
                query = self._parse_query(query)
            if isinstance(limit, str):
                limit = ast.literal_eval(limit)
            if isinstance(max_items, str):
                max_items = ast.literal_eval(max_items)

            items = pc_stac_search(
                collection=collection,
                bbox=bbox,
                time_range=time_range,
                query=query,
                limit=limit,
                max_items=max_items,
                quiet=True,
                endpoint=self.endpoint,
            )

            item_models = []
            for item in items:
                assets = [
                    STACAssetInfo(key=key, title=asset.title)
                    for key, asset in item.assets.items()
                ]
                item_models.append(
                    STACItemInfo(
                        id=item.id,
                        # collection_id may be None; fall back to the searched collection.
                        collection=item.collection_id or collection,
                        datetime=str(item.datetime) if item.datetime else None,
                        bbox=list(item.bbox) if item.bbox else None,
                        assets=assets,
                    )
                )

            result = STACSearchResult(
                query=f"Collection: {collection}",
                collection=collection,
                item_count=len(item_models),
                items=item_models,
                bbox=bbox,
                time_range=time_range,
            )
            return json.dumps(result.model_dump(), indent=2)

        except Exception as e:
            return json.dumps({"error": str(e)})

    @tool(description="Get detailed information about a specific STAC item")
    def get_item_info(
        self,
        item_id: str,
        collection: str,
    ) -> str:
        """Get detailed information about a specific STAC item.

        Args:
            item_id: The STAC item ID to retrieve.
            collection: The collection ID containing the item.

        Returns:
            JSON string with detailed item information including all assets and metadata.
        """
        try:
            items = pc_stac_search(
                collection=collection,
                bbox=None,
                time_range=None,
                query={"id": {"eq": item_id}},
                limit=1,
                max_items=1,
                quiet=True,
                endpoint=self.endpoint,
            )

            # Not every STAC API may honour an "id" filter through the query
            # extension; confirm the match instead of trusting the first result.
            item = next((i for i in items if i.id == item_id), None)
            if item is None:
                return json.dumps(
                    {"error": f"Item {item_id} not found in collection {collection}"}
                )

            assets = [
                {
                    "key": key,
                    "href": asset.href,
                    "type": asset.media_type,
                    "title": asset.title,
                    "description": getattr(asset, "description", None),
                    "roles": getattr(asset, "roles", None),
                }
                for key, asset in item.assets.items()
            ]

            result = {
                "id": item.id,
                "collection": item.collection_id or collection,
                "datetime": str(item.datetime) if item.datetime else None,
                "bbox": list(item.bbox) if item.bbox else None,
                "assets": assets,
            }
            return json.dumps(result, indent=2)

        except Exception as e:
            return json.dumps({"error": str(e)})

    @tool(description="Parse a location name and return its bounding box coordinates")
    def geocode_location(self, location_name: str) -> str:
        """Convert a location name to geographic coordinates and bounding box.

        Common locations are answered from a built-in table; other names are
        looked up with the OpenStreetMap Nominatim geocoding service, and the
        response (including "not found") is cached per instance.

        Args:
            location_name: Name of the location (e.g., "San Francisco", "New York", "Paris, France").

        Returns:
            JSON string with location info including bounding box and center coordinates.
        """
        try:
            # Lower-case and collapse whitespace so "New  York " and "new york"
            # share one cache entry.
            location_key = " ".join(location_name.lower().split())

            cached = self._LOCATION_CACHE.get(location_key)
            if cached is not None:
                location_info = LocationInfo(
                    name=cached["name"],
                    bbox=cached["bbox"],
                    center=cached["center"],
                )
                return json.dumps(location_info.model_dump(), indent=2)

            if location_key in self._geocode_cache:
                return self._geocode_cache[location_key]

            import requests

            response = requests.get(
                "https://nominatim.openstreetmap.org/search",
                params={"q": location_name, "format": "json", "limit": 1},
                # Nominatim's usage policy requires an identifying User-Agent.
                headers={"User-Agent": "GeoAI-STAC-Agent/1.0"},
                timeout=10,
            )
            response.raise_for_status()
            results = response.json()

            if not results:
                error_result = json.dumps(
                    {"error": f"Location '{location_name}' not found"}
                )
                self._geocode_cache[location_key] = error_result
                return error_result

            result = results[0]
            # Nominatim orders boundingbox as [south, north, west, east] (strings).
            south, north, west, east = (float(v) for v in result["boundingbox"][:4])
            location_info = LocationInfo(
                name=result.get("display_name", location_name),
                bbox=[west, south, east, north],
                center=[float(result["lon"]), float(result["lat"])],
            )

            result_json = json.dumps(location_info.model_dump(), indent=2)
            self._geocode_cache[location_key] = result_json
            return result_json

        except Exception as e:
            return json.dumps({"error": f"Geocoding error: {str(e)}"})

    @tool(
        description="Get common STAC collection IDs for different satellite/aerial imagery types"
    )
    def get_common_collections(self) -> str:
        """Get a list of commonly used STAC collections from Planetary Computer.

        Returns:
            JSON string with collection IDs and descriptions for popular datasets.
        """
        common_collections = {
            "sentinel-2-l2a": "Sentinel-2 Level-2A - Multispectral imagery (10m-60m resolution, global coverage)",
            "landsat-c2-l2": "Landsat Collection 2 Level-2 - Multispectral imagery (30m resolution, global coverage)",
            "naip": "NAIP - National Agriculture Imagery Program (1m resolution, USA only)",
            "sentinel-1-grd": "Sentinel-1 GRD - Synthetic Aperture Radar imagery (global coverage)",
            "aster-l1t": "ASTER L1T - Multispectral and thermal imagery (15m-90m resolution)",
            "cop-dem-glo-30": "Copernicus DEM - Global Digital Elevation Model (30m resolution)",
            # "hgb" is the Harmonized Global Biomass dataset; building footprints
            # are published as "ms-buildings".
            "hgb": "HGB - Harmonized Global Biomass (above- and belowground biomass carbon density, 300m resolution)",
            "ms-buildings": "Microsoft Building Footprints - Building footprint polygons (global coverage)",
            "io-lulc": "Impact Observatory Land Use/Land Cover - Annual 10m resolution land cover",
            # There is no bare "modis" collection; each MODIS product has its own id.
            "modis-09A1-061": "MODIS Surface Reflectance 8-Day (500m resolution); other MODIS products use ids such as 'modis-13Q1-061'",
            "daymet-daily-hi": "Daymet - Daily surface weather data for Hawaii",
        }

        result = {
            "count": len(common_collections),
            "collections": [
                {"id": k, "description": v} for k, v in common_collections.items()
            ],
        }
        return json.dumps(result, indent=2)
geoai/change_detection.py CHANGED
@@ -13,7 +13,8 @@ from skimage.transform import resize
13
13
  try:
14
14
  from torchange.models.segment_any_change import AnyChange, show_change_masks
15
15
  except ImportError:
16
- print("torchange requires Python 3.11 or higher")
16
+ AnyChange = None
17
+ show_change_masks = None
17
18
 
18
19
  from .utils import download_file
19
20
 
@@ -36,6 +37,13 @@ class ChangeDetection:
36
37
 
37
38
  def _init_model(self):
38
39
  """Initialize the AnyChange model."""
40
+ if AnyChange is None:
41
+ raise ImportError(
42
+ "The 'torchange' package is required for change detection. "
43
+ "Please install it using: pip install torchange\n"
44
+ "Note: torchange requires Python 3.11 or higher."
45
+ )
46
+
39
47
  if self.sam_checkpoint is None:
40
48
  self.sam_checkpoint = download_checkpoint(self.sam_model_type)
41
49
 
@@ -551,6 +559,13 @@ class ChangeDetection:
551
559
  Returns:
552
560
  matplotlib.figure.Figure: The figure object
553
561
  """
562
+ if show_change_masks is None:
563
+ raise ImportError(
564
+ "The 'torchange' package is required for change detection visualization. "
565
+ "Please install it using: pip install torchange\n"
566
+ "Note: torchange requires Python 3.11 or higher."
567
+ )
568
+
554
569
  change_masks, img1, img2 = self.detect_changes(
555
570
  image1_path, image2_path, return_results=True
556
571
  )
@@ -561,7 +576,15 @@ class ChangeDetection:
561
576
 
562
577
  return fig
563
578
 
564
- def visualize_results(self, image1_path, image2_path, binary_path, prob_path):
579
+ def visualize_results(
580
+ self,
581
+ image1_path,
582
+ image2_path,
583
+ binary_path,
584
+ prob_path,
585
+ title1="Earlier Image",
586
+ title2="Later Image",
587
+ ):
565
588
  """Create enhanced visualization with probability analysis."""
566
589
 
567
590
  # Load data
@@ -594,11 +617,11 @@ class ChangeDetection:
594
617
 
595
618
  # Row 1: Original and overlays
596
619
  axes[0, 0].imshow(img1_crop)
597
- axes[0, 0].set_title("2019 Image", fontweight="bold")
620
+ axes[0, 0].set_title(title1, fontweight="bold")
598
621
  axes[0, 0].axis("off")
599
622
 
600
623
  axes[0, 1].imshow(img2_crop)
601
- axes[0, 1].set_title("2022 Image", fontweight="bold")
624
+ axes[0, 1].set_title(title2, fontweight="bold")
602
625
  axes[0, 1].axis("off")
603
626
 
604
627
  # Binary overlay
@@ -708,6 +731,8 @@ class ChangeDetection:
708
731
  image2_path,
709
732
  binary_path,
710
733
  prob_path,
734
+ title1="Earlier Image",
735
+ title2="Later Image",
711
736
  output_path="split_comparison.png",
712
737
  ):
713
738
  """Create a split comparison visualization showing before/after with change overlay."""
@@ -742,7 +767,7 @@ class ChangeDetection:
742
767
  # Create split comparison
743
768
  fig, ax = plt.subplots(1, 1, figsize=(15, 10))
744
769
 
745
- # Create combined image - left half is 2019, right half is 2022
770
+ # Create combined image - left half is earlier, right half is later
746
771
  combined_img = np.zeros_like(img1)
747
772
  combined_img[:, : w // 2] = img1[:, : w // 2]
748
773
  combined_img[:, w // 2 :] = img2[:, w // 2 :]
@@ -763,7 +788,7 @@ class ChangeDetection:
763
788
  ax.text(
764
789
  w // 4,
765
790
  50,
766
- "2019",
791
+ title1,
767
792
  fontsize=20,
768
793
  color="white",
769
794
  ha="center",
@@ -772,7 +797,7 @@ class ChangeDetection:
772
797
  ax.text(
773
798
  3 * w // 4,
774
799
  50,
775
- "2022",
800
+ title2,
776
801
  fontsize=20,
777
802
  color="white",
778
803
  ha="center",
geoai/download.py CHANGED
@@ -1,5 +1,6 @@
1
1
  """This module provides functions to download data, including NAIP imagery and building data from Overture Maps."""
2
2
 
3
+ import datetime
3
4
  import logging
4
5
  import os
5
6
  import subprocess
@@ -819,6 +820,7 @@ def pc_stac_search(
819
820
  query: Optional[Dict[str, Any]] = None,
820
821
  limit: int = 10,
821
822
  max_items: Optional[int] = None,
823
+ quiet: bool = False,
822
824
  endpoint: str = "https://planetarycomputer.microsoft.com/api/stac/v1",
823
825
  ) -> List["pystac.Item"]:
824
826
  """
@@ -839,6 +841,7 @@ def pc_stac_search(
839
841
  limit (int, optional): Number of items to return per page. Defaults to 10.
840
842
  max_items (int, optional): Maximum total number of items to return.
841
843
  Defaults to None (returns all matching items).
844
+ quiet (bool, optional): Whether to suppress print statements. Defaults to False.
842
845
  endpoint (str, optional): STAC API endpoint URL.
843
846
  Defaults to "https://planetarycomputer.microsoft.com/api/stac/v1".
844
847
 
@@ -896,7 +899,8 @@ def pc_stac_search(
896
899
  except Exception as e:
897
900
  raise Exception(f"Error retrieving search results: {str(e)}")
898
901
 
899
- print(f"Found {len(items)} items matching search criteria")
902
+ if not quiet:
903
+ print(f"Found {len(items)} items matching search criteria")
900
904
 
901
905
  return items
902
906
 
geoai/geoai.py CHANGED
@@ -32,6 +32,9 @@ from .train import (
32
32
  instance_segmentation,
33
33
  instance_segmentation_batch,
34
34
  instance_segmentation_inference_on_geotiff,
35
+ lightly_embed_images,
36
+ lightly_train_model,
37
+ load_lightly_pretrained_model,
35
38
  object_detection,
36
39
  object_detection_batch,
37
40
  semantic_segmentation,
geoai/timm_segment.py CHANGED
@@ -241,7 +241,10 @@ class TimmSegmentationModel(pl.LightningModule):
241
241
  )
242
242
 
243
243
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
244
- optimizer, mode="min", factor=0.5, patience=5, verbose=True
244
+ optimizer,
245
+ mode="min",
246
+ factor=0.5,
247
+ patience=5,
245
248
  )
246
249
 
247
250
  return {