terrakio-core 0.2.6__tar.gz → 0.2.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of terrakio-core might be problematic.
- {terrakio_core-0.2.6 → terrakio_core-0.2.8}/PKG-INFO +1 -1
- {terrakio_core-0.2.6 → terrakio_core-0.2.8}/pyproject.toml +1 -1
- terrakio_core-0.2.8/terrakio_core/__init__.py +7 -0
- {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/client.py +284 -86
- {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/dataset_management.py +1 -1
- {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core.egg-info/PKG-INFO +1 -1
- terrakio_core-0.2.6/terrakio_core/__init__.py +0 -0
- {terrakio_core-0.2.6 → terrakio_core-0.2.8}/README.md +0 -0
- {terrakio_core-0.2.6 → terrakio_core-0.2.8}/setup.cfg +0 -0
- {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/auth.py +0 -0
- {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/config.py +0 -0
- {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/exceptions.py +0 -0
- {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/group_access_management.py +0 -0
- {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/mass_stats.py +0 -0
- {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/space_management.py +0 -0
- {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/user_management.py +0 -0
- {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core.egg-info/SOURCES.txt +0 -0
- {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core.egg-info/dependency_links.txt +0 -0
- {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core.egg-info/requires.txt +0 -0
- {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core.egg-info/top_level.txt +0 -0
{terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/client.py

@@ -11,14 +11,17 @@ import xarray as xr
 import nest_asyncio
 from shapely.geometry import shape, mapping
 from shapely.geometry.base import BaseGeometry as ShapelyGeometry
-
+from google.cloud import storage
 from .exceptions import APIError, ConfigurationError
+import logging
+import textwrap
+

 class BaseClient:
     def __init__(self, url: Optional[str] = None, key: Optional[str] = None,
                  auth_url: Optional[str] = "https://dev-au.terrak.io",
                  quiet: bool = False, config_file: Optional[str] = None,
-                 verify: bool = True, timeout: int =
+                 verify: bool = True, timeout: int = 300):
         nest_asyncio.apply()
         self.quiet = quiet
         self.verify = verify
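For orientation, a minimal construction sketch against the new __init__ signature. The import path is an assumption (the class lives in terrakio_core/client.py, but 0.2.8's __init__.py may re-export it differently), and the URL and key are placeholders:

    from terrakio_core.client import BaseClient  # assumed import path

    # timeout now defaults to 300 seconds; verify and quiet keep their previous defaults
    client = BaseClient(
        url="https://api.example.terrak.io",  # hypothetical API endpoint
        key="YOUR_API_KEY",                   # placeholder credential
        quiet=True,
    )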
@@ -86,9 +89,10 @@ class BaseClient:
         )
         return self._aiohttp_session

-    async def wcs_async(self, expr: str, feature: Union[Dict[str, Any], ShapelyGeometry],
-
-
+    async def wcs_async(self, expr: str, feature: Union[Dict[str, Any], ShapelyGeometry],
+                        in_crs: str = "epsg:4326", out_crs: str = "epsg:4326",
+                        output: str = "csv", resolution: int = -1, buffer: bool = False,
+                        retry: int = 3, **kwargs):
         """
         Asynchronous version of the wcs() method using aiohttp.

@@ -99,6 +103,8 @@ class BaseClient:
             out_crs (str): Output coordinate reference system
             output (str): Output format ('csv' or 'netcdf')
             resolution (int): Resolution parameter
+            buffer (bool): Whether to buffer the request (default True)
+            retry (int): Number of retry attempts (default 3)
             **kwargs: Additional parameters to pass to the WCS request

         Returns:
@@ -111,8 +117,7 @@ class BaseClient:
                 "geometry": mapping(feature),
                 "properties": {}
             }
-
-
+
         payload = {
             "feature": feature,
             "in_crs": in_crs,
@@ -120,47 +125,65 @@ class BaseClient:
             "output": output,
             "resolution": resolution,
             "expr": expr,
+            "buffer": buffer,
+            "resolution": resolution,
             **kwargs
         }
-
         request_url = f"{self.url}/geoquery"
-   [removed lines 127–153 of the previous implementation; their content is not shown in this diff view]
+        for attempt in range(retry + 1):
+            try:
+                session = await self.aiohttp_session
+                async with session.post(request_url, json=payload, ssl=self.verify) as response:
+                    if not response.ok:
+                        should_retry = False
+                        if response.status in [408, 502, 503, 504]:
+                            should_retry = True
+                        elif response.status == 500:
+                            try:
+                                response_text = await response.text()
+                                if "Internal server error" not in response_text:
+                                    should_retry = True
+                            except:
+                                should_retry = True
+
+                        if should_retry and attempt < retry:
+                            continue
+                        else:
+                            error_msg = f"API request failed: {response.status} {response.reason}"
+                            try:
+                                error_data = await response.json()
+                                if "detail" in error_data:
+                                    error_msg += f" - {error_data['detail']}"
+                            except:
+                                pass
+                            raise APIError(error_msg)
+
+                    content = await response.read()
+
+                    if output.lower() == "csv":
                         import pandas as pd
+                        df = pd.read_csv(BytesIO(content))
+                        return df
+                    elif output.lower() == "netcdf":
+                        return xr.open_dataset(BytesIO(content))
+                    else:
                         try:
-                            return
-                        except:
-   [removed lines 158–163; their content is not shown in this diff view]
+                            return xr.open_dataset(BytesIO(content))
+                        except ValueError:
+                            import pandas as pd
+                            try:
+                                return pd.read_csv(BytesIO(content))
+                            except:
+                                return content
+
+            except aiohttp.ClientError as e:
+                if attempt == retry:
+                    raise APIError(f"Request failed: {str(e)}")
+                continue
+            except Exception as e:
+                if attempt == retry:
+                    raise
+                continue

     async def close_async(self):
         """Close the aiohttp session"""
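For reference, a minimal sketch of calling the retry-aware wcs_async added above. The client construction and import path are assumed, and the geometry is a placeholder; the expression string is the example given in the 0.2.8 docstring. Note that the docstring states buffer defaults to True while the signature default is False; the sketch passes it explicitly:

    import asyncio
    from terrakio_core.client import BaseClient  # assumed import path

    point = {
        "type": "Feature",
        "geometry": {"type": "Point", "coordinates": [149.1, -35.3]},  # placeholder location
        "properties": {},
    }

    async def main():
        client = BaseClient(url="https://api.example.terrak.io", key="YOUR_API_KEY")
        # retry=3 re-issues the POST on 408/502/503/504 (and some 500) responses;
        # output="csv" parses the response body into a pandas DataFrame
        df = await client.wcs_async(
            expr="MSWX.air_temperature@(year=2021, month=1)",
            feature=point, output="csv", buffer=False, retry=3,
        )
        await client.close_async()
        return df

    asyncio.run(main())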
@@ -174,41 +197,6 @@ class BaseClient:
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         await self.close_async()

-    def validate_feature(self, feature: Dict[str, Any]) -> None:
-        if hasattr(feature, 'is_valid'):
-            from shapely.geometry import mapping
-            feature = {
-                "type": "Feature",
-                "geometry": mapping(feature),
-                "properties": {}
-            }
-        if not isinstance(feature, dict):
-            raise ValueError("Feature must be a dictionary or a Shapely geometry")
-        if feature.get("type") != "Feature":
-            raise ValueError("GeoJSON object must be of type 'Feature'")
-        if "geometry" not in feature:
-            raise ValueError("Feature must contain a 'geometry' field")
-        if "properties" not in feature:
-            raise ValueError("Feature must contain a 'properties' field")
-        try:
-            geometry = shape(feature["geometry"])
-        except Exception as e:
-            raise ValueError(f"Invalid geometry format: {str(e)}")
-        if not geometry.is_valid:
-            raise ValueError(f"Invalid geometry: {geometry.is_valid_reason}")
-        geom_type = feature["geometry"]["type"]
-        if geom_type == "Point":
-            if len(feature["geometry"]["coordinates"]) != 2:
-                raise ValueError("Point must have exactly 2 coordinates")
-        elif geom_type == "Polygon":
-            if not geometry.is_simple:
-                raise ValueError("Polygon must be simple (not self-intersecting)")
-            if geometry.area == 0:
-                raise ValueError("Polygon must have non-zero area")
-            coords = feature["geometry"]["coordinates"][0]
-            if coords[0] != coords[-1]:
-                raise ValueError("Polygon must be closed (first and last points must match)")
-
     def signup(self, email: str, password: str) -> Dict[str, Any]:
         if not self.auth_client:
             raise ConfigurationError("Authentication client not initialized. Please provide auth_url during client initialization.")
@@ -309,7 +297,6 @@ class BaseClient:
                 "geometry": mapping(feature),
                 "properties": {}
             }
-        self.validate_feature(feature)
         payload = {
             "feature": feature,
             "in_crs": in_crs,
@@ -321,7 +308,10 @@ class BaseClient:
         }
         request_url = f"{self.url}/geoquery"
         try:
+            print("the request url is ", request_url)
+            print("the payload is ", payload)
             response = self.session.post(request_url, json=payload, timeout=self.timeout, verify=self.verify)
+            print("the response is ", response.text)
             if not response.ok:
                 error_msg = f"API request failed: {response.status_code} {response.reason}"
                 try:
@@ -690,15 +680,39 @@ class BaseClient:
         )
         return self.mass_stats.random_sample(name, **kwargs)

-    async def zonal_stats_async(self, gdb, expr, conc=20, inplace=False, output="csv"
+    async def zonal_stats_async(self, gdb, expr, conc=20, inplace=False, output="csv",
+                                in_crs="epsg:4326", out_crs="epsg:4326", resolution=-1, buffer=False):
         """
         Compute zonal statistics for all geometries in a GeoDataFrame using asyncio for concurrency.
+
+        Args:
+            gdb (geopandas.GeoDataFrame): GeoDataFrame containing geometries
+            expr (str): Terrakio expression to evaluate, can include spatial aggregations
+            conc (int): Number of concurrent requests to make
+            inplace (bool): Whether to modify the input GeoDataFrame in place
+            output (str): Output format (csv or netcdf)
+            in_crs (str): Input coordinate reference system
+            out_crs (str): Output coordinate reference system
+            resolution (int): Resolution parameter
+            buffer (bool): Whether to buffer the request (default True)
+
+        Returns:
+            geopandas.GeoDataFrame: GeoDataFrame with added columns for results, or None if inplace=True
         """
+        if conc > 100:
+            raise ValueError("Concurrency (conc) is too high. Please set conc to 100 or less.")

         # Process geometries in batches
         all_results = []
         row_indices = []

+        # Calculate total batches for progress reporting
+        total_geometries = len(gdb)
+        total_batches = (total_geometries + conc - 1) // conc  # Ceiling division
+        completed_batches = 0
+
+        print(f"Processing {total_geometries} geometries with concurrency {conc}")
+
         async def process_geometry(geom, index):
             """Process a single geometry"""
             try:
@@ -707,7 +721,8 @@ class BaseClient:
                     "geometry": mapping(geom),
                     "properties": {"index": index}
                 }
-                result = await self.wcs_async(expr=expr, feature=feature, output=output
+                result = await self.wcs_async(expr=expr, feature=feature, output=output,
+                                              in_crs=in_crs, out_crs=out_crs, resolution=resolution, buffer=buffer)
                 # Add original index to track which geometry this result belongs to
                 if isinstance(result, pd.DataFrame):
                     result['_geometry_index'] = index
@@ -749,11 +764,19 @@ class BaseClient:
                 batch_results = await process_batch(batch_indices)
                 all_results.extend(batch_results)
                 row_indices.extend(batch_indices)
+
+                # Update progress
+                completed_batches += 1
+                processed_geometries = min(i + conc, total_geometries)
+                print(f"Progress: {completed_batches}/{total_batches} completed ({processed_geometries}/{total_geometries} geometries processed)")
+
             except Exception as e:
                 if hasattr(e, 'response'):
                     raise APIError(f"API request failed: {e.response.text}")
                 raise

+        print("All batches completed! Processing results...")
+
         if not all_results:
             raise ValueError("No valid results were returned for any geometry")

@@ -845,7 +868,8 @@ class BaseClient:
         else:
             return result_gdf

-    def zonal_stats(self, gdb, expr, conc=20, inplace=False, output="csv"
+    def zonal_stats(self, gdb, expr, conc=20, inplace=False, output="csv",
+                    in_crs="epsg:4326", out_crs="epsg:4326", resolution=-1, buffer=False):
         """
         Compute zonal statistics for all geometries in a GeoDataFrame.

@@ -855,32 +879,44 @@ class BaseClient:
             conc (int): Number of concurrent requests to make
             inplace (bool): Whether to modify the input GeoDataFrame in place
             output (str): Output format (csv or netcdf)
+            in_crs (str): Input coordinate reference system
+            out_crs (str): Output coordinate reference system
+            resolution (int): Resolution parameter
+            buffer (bool): Whether to buffer the request (default True)

         Returns:
             geopandas.GeoDataFrame: GeoDataFrame with added columns for results, or None if inplace=True
         """
+        if conc > 100:
+            raise ValueError("Concurrency (conc) is too high. Please set conc to 100 or less.")
         import asyncio

+        print(f"Starting zonal statistics computation for expression: {expr}")
+
         # Check if we're in a Jupyter environment or already have an event loop
         try:
             loop = asyncio.get_running_loop()
             # We're in an async context (like Jupyter), use create_task
             nest_asyncio.apply()
-            result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output
+            result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output,
+                                                        in_crs, out_crs, resolution, buffer))
         except RuntimeError:
             # No running event loop, safe to use asyncio.run()
-            result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output
+            result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output,
+                                                        in_crs, out_crs, resolution, buffer))
         except ImportError:
             # nest_asyncio not available, try alternative approach
             try:
                 loop = asyncio.get_running_loop()
                 # Create task in existing loop
-                task = loop.create_task(self.zonal_stats_async(gdb, expr, conc, inplace, output
+                task = loop.create_task(self.zonal_stats_async(gdb, expr, conc, inplace, output,
+                                                               in_crs, out_crs, resolution, buffer))
                 # This won't work directly - we need a different approach
                 raise RuntimeError("Cannot run async code in Jupyter without nest_asyncio. Please install: pip install nest-asyncio")
             except RuntimeError:
                 # No event loop, use asyncio.run
-                result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output
+                result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output,
+                                                            in_crs, out_crs, resolution, buffer))

         # Ensure aiohttp session is closed after running async code
         try:
@@ -890,6 +926,7 @@ class BaseClient:
             # Event loop may already be closed, ignore
             pass

+        print("Zonal statistics computation completed!")
         return result

     # Group access management protected methods
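A minimal sketch of calling the extended zonal_stats shown above. The client construction, input file, and expression syntax are all placeholders (the docstring only says the expression may include spatial aggregations):

    import geopandas as gpd
    from terrakio_core.client import BaseClient  # assumed import path

    client = BaseClient(url="https://api.example.terrak.io", key="YOUR_API_KEY")
    gdf = gpd.read_file("paddocks.geojson")  # hypothetical input file

    # conc must be 100 or less in 0.2.8; in_crs/out_crs/resolution/buffer are forwarded to wcs_async
    result = client.zonal_stats(
        gdf,
        expr="mean(MSWX.air_temperature@(year=2021, month=1))",  # illustrative expression
        conc=20, output="csv",
        in_crs="epsg:4326", out_crs="epsg:4326",
    )
    print(result.head())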
@@ -1011,6 +1048,80 @@ class BaseClient:
             timeout=self.timeout
         )
         return self.space_management.delete_data_in_path(path, region)
+
+    def generate_ai_dataset(
+        self,
+        name: str,
+        aoi_geojson: str,
+        expression_x: str,
+        expression_y: str,
+        samples: int,
+        tile_size: int,
+        crs: str = "epsg:4326",
+        res: float = 0.001,
+        region: str = "aus",
+        start_year: int = None,
+        end_year: int = None,
+    ) -> dict:
+        """
+        Generate an AI dataset using specified parameters.
+
+        Args:
+            name (str): Name of the dataset to generate
+            aoi_geojson (str): Path to GeoJSON file containing area of interest
+            expression_x (str): Expression for X variable (e.g. "MSWX.air_temperature@(year=2021, month=1)")
+            expression_y (str): Expression for Y variable with {year} placeholder
+            samples (int): Number of samples to generate
+            tile_size (int): Size of tiles in degrees
+            crs (str, optional): Coordinate reference system. Defaults to "epsg:4326"
+            res (float, optional): Resolution in degrees. Defaults to 0.001
+            region (str, optional): Region code. Defaults to "aus"
+            start_year (int, optional): Start year for data generation. Required if end_year provided
+            end_year (int, optional): End year for data generation. Required if start_year provided
+            overwrite (bool, optional): Whether to overwrite existing dataset. Defaults to False
+
+        Returns:
+            dict: Response from the AI dataset generation API
+
+        Raises:
+            ValidationError: If required parameters are missing or invalid
+            APIError: If the API request fails
+        """
+
+        # we have the parameters, let pass the parameters to the random sample function
+        # task_id = self.random_sample(name, aoi_geojson, expression_x, expression_y, samples, tile_size, crs, res, region, start_year, end_year, overwrite)
+        config = {
+            "expressions" : [{"expr": expression_x, "res": res, "prefix": "x"}],
+            "filters" : []
+        }
+        config["expressions"].append({"expr": expression_y, "res" : res, "prefix": "y"})
+
+        expression_x = expression_x.replace("{year}", str(start_year))
+        expression_y = expression_y.replace("{year}", str(start_year))
+        print("the aoi geojson is ", aoi_geojson)
+        with open(aoi_geojson, 'r') as f:
+            aoi_data = json.load(f)
+        print("the config is ", config)
+        task_id = self.random_sample(
+            name=name,
+            config=config,
+            aoi=aoi_data,
+            samples=samples,
+            year_range=[start_year, end_year],
+            crs=crs,
+            tile_size=tile_size,
+            res=res,
+            region=region,
+            output="netcdf",
+            server=self.url,
+            bucket="terrakio-mass-requests",
+            overwrite=True
+        )["task_id"]
+        print("the task id is ", task_id)
+        task_id = self.start_mass_stats_job(task_id)
+        print("the task id is ", task_id)
+        return task_id
+

     def train_model(self, model_name: str, training_data: dict) -> dict:
         """
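A usage sketch for the new generate_ai_dataset helper, which reads the AOI GeoJSON from disk, submits a random_sample mass-stats job, starts it, and returns the task id. Dataset name, file path, Y expression, and years below are placeholders; the X expression is the docstring's own example:

    client = BaseClient(url="https://api.example.terrak.io", key="YOUR_API_KEY")

    task_id = client.generate_ai_dataset(
        name="canopy-height-training",          # hypothetical dataset name
        aoi_geojson="aoi.geojson",               # local GeoJSON file with the area of interest
        expression_x="MSWX.air_temperature@(year=2021, month=1)",
        expression_y="CanopyHeight.height@(year={year})",  # illustrative expression with {year} placeholder
        samples=500,
        tile_size=1,
        start_year=2020,
        end_year=2021,
    )
    # task_id is the identifier of the mass-stats job started for the sampling run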
@@ -1044,3 +1155,90 @@ class BaseClient:
         except requests.RequestException as e:
             raise APIError(f"Model training request failed: {str(e)}")

+    def deploy_model(self, dataset: str, product:str, model_name:str, input_expression: str, model_training_job_name: str, uid: str, dates_iso8601: list):
+        # we have the dataset and we have the product, and we have the model name, we need to create a new json file and add that to the dataset as our virtual dataset
+        # upload the script to the bucket, the script should be able to download the model and do the inferencing
+        # we need to upload the the json to the to the dataset as our virtual dataset
+        # then we do nothing and wait for the user to make the request call to the explorer
+        # we should have a uniform script for the random forest deployment
+        # create a script for each model
+        # upload script to google bucket,
+        #
+
+        script_content = self._generate_script(model_name, product, model_training_job_name, uid)
+        # self.create_dataset(collection = "terrakio-datasets", input = input, )
+        # we have the script, we need to upload it to the bucket
+        script_name = f"{product}.py"
+        print("the script content is ", script_content)
+        print("the script name is ", script_name)
+        self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
+        # after uploading the script, we need to create a new virtual dataset
+        self._create_dataset(name = dataset, collection = "terrakio-datasets", products = [product], path = f"gs://terrakio-mass-requests/{uid}/{model_training_job_name}/inference_scripts", input = input_expression, dates_iso8601 = dates_iso8601, padding = 0)
+
+    def _generate_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
+        return textwrap.dedent(f'''
+            import logging
+            from io import BytesIO
+            from google.cloud import storage
+            from onnxruntime import InferenceSession
+            import numpy as np
+            import xarray as xr
+            import datetime
+
+            logging.basicConfig(
+                level=logging.INFO
+            )
+
+            def get_model():
+                logging.info("Loading model for {model_name}...")
+
+                client = storage.Client()
+                bucket = client.get_bucket('terrakio-mass-requests')
+                blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')
+
+                model = BytesIO()
+                blob.download_to_file(model)
+                model.seek(0)
+
+                session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
+                return session
+
+            def {product}(*bands, model):
+                logging.info("start preparing data")
+
+                original_shape = bands[0].shape
+                logging.info(f"Original shape: {{original_shape}}")
+
+                transformed_bands = []
+                for band in bands:
+                    transformed_band = band.values.reshape(-1,1)
+                    transformed_bands.append(transformed_band)
+
+                input_data = np.hstack(transformed_bands)
+
+                logging.info(f"Final input shape: {{input_data.shape}}")
+
+                output = model.run(None, {{"float_input": input_data.astype(np.float32)}})[0]
+
+                logging.info(f"Model output shape: {{output.shape}}")
+
+                output_reshaped = output.reshape(original_shape)
+                result = xr.DataArray(
+                    data=output_reshaped,
+                    dims=bands[0].dims,
+                    coords=bands[0].coords
+                )
+
+                return result
+        ''').strip()
+
+    def _upload_script_to_bucket(self, script_content: str, script_name: str, model_training_job_name: str, uid: str):
+        """Upload the generated script to Google Cloud Storage"""
+
+        client = storage.Client()
+        bucket = client.get_bucket('terrakio-mass-requests')
+        blob = bucket.blob(f'{uid}/{model_training_job_name}/inference_scripts/{script_name}')
+        # the first layer is the uid, the second layer is the model training job name
+        blob.upload_from_string(script_content, content_type='text/plain')
+        logging.info(f"Script uploaded successfully to {uid}/{model_training_job_name}/inference_scripts/{script_name}")
+
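A sketch of how deploy_model above would be invoked. It generates an ONNX inference script named after the product, uploads it with google.cloud.storage (so Google Cloud credentials must be available in the environment), and then registers a virtual dataset pointing at the uploaded script. Every argument value here is a placeholder:

    client = BaseClient(url="https://api.example.terrak.io", key="YOUR_API_KEY")

    client.deploy_model(
        dataset="my_virtual_dataset",             # virtual dataset to create
        product="canopy_height",                  # becomes the generated script/function name
        model_name="rf_canopy_v1",                # ONNX model stored under .../models/ in the bucket
        input_expression="S2.B04@(year=2023)",    # illustrative input expression
        model_training_job_name="train-2023-01",  # folder under the user's bucket prefix
        uid="user-1234",                          # bucket prefix for the user
        dates_iso8601=["2023-01-01"],
    )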
{terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/dataset_management.py

@@ -117,7 +117,7 @@ class DatasetManagement:

         # Add optional parameters if provided
         for param in ["products", "dates_iso8601", "bucket", "path", "data_type",
-                      "no_data", "l_max", "y_size", "x_size", "proj4", "abstract", "geotransform"]:
+                      "no_data", "l_max", "y_size", "x_size", "proj4", "abstract", "geotransform", "input"]:
             if param in kwargs:
                 payload[param] = kwargs[param]

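The only change in dataset_management.py is that "input" is now accepted as an optional dataset field, which is what deploy_model relies on when it passes input=input_expression to _create_dataset. A hedged sketch of the corresponding call, assuming the client exposes a create_dataset wrapper as the commented-out call in deploy_model suggests; all values are placeholders:

    client.create_dataset(
        name="my_virtual_dataset",
        collection="terrakio-datasets",
        products=["canopy_height"],
        path="gs://terrakio-mass-requests/user-1234/train-2023-01/inference_scripts",
        input="S2.B04@(year=2023)",    # newly forwarded into the dataset payload
        dates_iso8601=["2023-01-01"],
    )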