terrakio-core 0.2.6__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of terrakio-core might be problematic.
- terrakio_core/__init__.py +7 -0
- terrakio_core/client.py +267 -84
- terrakio_core/dataset_management.py +1 -1
- {terrakio_core-0.2.6.dist-info → terrakio_core-0.2.7.dist-info}/METADATA +1 -1
- {terrakio_core-0.2.6.dist-info → terrakio_core-0.2.7.dist-info}/RECORD +7 -7
- {terrakio_core-0.2.6.dist-info → terrakio_core-0.2.7.dist-info}/WHEEL +0 -0
- {terrakio_core-0.2.6.dist-info → terrakio_core-0.2.7.dist-info}/top_level.txt +0 -0
terrakio_core/__init__.py
CHANGED
terrakio_core/client.py
CHANGED
@@ -11,14 +11,17 @@ import xarray as xr
 import nest_asyncio
 from shapely.geometry import shape, mapping
 from shapely.geometry.base import BaseGeometry as ShapelyGeometry
-
+from google.cloud import storage
 from .exceptions import APIError, ConfigurationError
+import logging
+import textwrap
+
 
 class BaseClient:
     def __init__(self, url: Optional[str] = None, key: Optional[str] = None,
                  auth_url: Optional[str] = "https://dev-au.terrak.io",
                  quiet: bool = False, config_file: Optional[str] = None,
-                 verify: bool = True, timeout: int =
+                 verify: bool = True, timeout: int = 300):
         nest_asyncio.apply()
         self.quiet = quiet
         self.verify = verify
@@ -86,9 +89,10 @@ class BaseClient:
         )
         return self._aiohttp_session
 
-    async def wcs_async(self, expr: str, feature: Union[Dict[str, Any], ShapelyGeometry],
-
-
+    async def wcs_async(self, expr: str, feature: Union[Dict[str, Any], ShapelyGeometry],
+                        in_crs: str = "epsg:4326", out_crs: str = "epsg:4326",
+                        output: str = "csv", resolution: int = -1, buffer: bool = True,
+                        retry: int = 3, **kwargs):
         """
         Asynchronous version of the wcs() method using aiohttp.
 
@@ -99,6 +103,8 @@
             out_crs (str): Output coordinate reference system
             output (str): Output format ('csv' or 'netcdf')
             resolution (int): Resolution parameter
+            buffer (bool): Whether to buffer the request (default True)
+            retry (int): Number of retry attempts (default 3)
             **kwargs: Additional parameters to pass to the WCS request
 
         Returns:
@@ -111,8 +117,7 @@
                 "geometry": mapping(feature),
                 "properties": {}
             }
-
-
+
         payload = {
             "feature": feature,
             "in_crs": in_crs,
@@ -120,47 +125,68 @@
             "output": output,
             "resolution": resolution,
             "expr": expr,
+            "buffer": buffer,
+            "resolution": resolution,
             **kwargs
         }
 
         request_url = f"{self.url}/geoquery"
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        for attempt in range(retry + 1):
+            try:
+                session = await self.aiohttp_session
+                async with session.post(request_url, json=payload, ssl=self.verify) as response:
+                    if not response.ok:
+                        should_retry = False
+
+                        if response.status in [408, 502, 503, 504]:
+                            should_retry = True
+                        elif response.status == 500:
+                            try:
+                                response_text = await response.text()
+                                if "Internal server error" not in response_text:
+                                    should_retry = True
+                            except:
+                                should_retry = True
+
+                        if should_retry and attempt < retry:
+                            continue
+                        else:
+                            error_msg = f"API request failed: {response.status} {response.reason}"
+                            try:
+                                error_data = await response.json()
+                                if "detail" in error_data:
+                                    error_msg += f" - {error_data['detail']}"
+                            except:
+                                pass
+                            raise APIError(error_msg)
+
+                    content = await response.read()
+
+                    if output.lower() == "csv":
                         import pandas as pd
+                        df = pd.read_csv(BytesIO(content))
+                        return df
+                    elif output.lower() == "netcdf":
+                        return xr.open_dataset(BytesIO(content))
+                    else:
                         try:
-                            return
-                        except:
-
-
-
-
-
-
+                            return xr.open_dataset(BytesIO(content))
+                        except ValueError:
+                            import pandas as pd
+                            try:
+                                return pd.read_csv(BytesIO(content))
+                            except:
+                                return content
+
+            except aiohttp.ClientError as e:
+                if attempt == retry:
+                    raise APIError(f"Request failed: {str(e)}")
+                continue
+            except Exception as e:
+                if attempt == retry:
+                    raise
+                continue
 
     async def close_async(self):
         """Close the aiohttp session"""
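For orientation, a minimal usage sketch of the retry-aware wcs_async added above. The server URL, API key, and coordinates are placeholders, and the expression is patterned on a docstring example elsewhere in this release; none of these values come from the diff itself.

import asyncio
from terrakio_core.client import BaseClient

async def main():
    # BaseClient is usable as an async context manager (its __aexit__ is
    # shown above), so the aiohttp session is closed on exit.
    async with BaseClient(url="https://api.example.com", key="YOUR_KEY") as client:
        feature = {
            "type": "Feature",
            "geometry": {"type": "Point", "coordinates": [149.1, -35.3]},
            "properties": {},
        }
        # retry=3 re-issues the POST on 408/502/503/504 (and on some 500s);
        # buffer=True is forwarded in the payload as "buffer".
        df = await client.wcs_async("MSWX.air_temperature@(year=2021, month=1)",
                                    feature, output="csv", retry=3, buffer=True)
        print(df.head())

asyncio.run(main())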
@@ -174,41 +200,6 @@
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         await self.close_async()
 
-    def validate_feature(self, feature: Dict[str, Any]) -> None:
-        if hasattr(feature, 'is_valid'):
-            from shapely.geometry import mapping
-            feature = {
-                "type": "Feature",
-                "geometry": mapping(feature),
-                "properties": {}
-            }
-        if not isinstance(feature, dict):
-            raise ValueError("Feature must be a dictionary or a Shapely geometry")
-        if feature.get("type") != "Feature":
-            raise ValueError("GeoJSON object must be of type 'Feature'")
-        if "geometry" not in feature:
-            raise ValueError("Feature must contain a 'geometry' field")
-        if "properties" not in feature:
-            raise ValueError("Feature must contain a 'properties' field")
-        try:
-            geometry = shape(feature["geometry"])
-        except Exception as e:
-            raise ValueError(f"Invalid geometry format: {str(e)}")
-        if not geometry.is_valid:
-            raise ValueError(f"Invalid geometry: {geometry.is_valid_reason}")
-        geom_type = feature["geometry"]["type"]
-        if geom_type == "Point":
-            if len(feature["geometry"]["coordinates"]) != 2:
-                raise ValueError("Point must have exactly 2 coordinates")
-        elif geom_type == "Polygon":
-            if not geometry.is_simple:
-                raise ValueError("Polygon must be simple (not self-intersecting)")
-            if geometry.area == 0:
-                raise ValueError("Polygon must have non-zero area")
-            coords = feature["geometry"]["coordinates"][0]
-            if coords[0] != coords[-1]:
-                raise ValueError("Polygon must be closed (first and last points must match)")
-
     def signup(self, email: str, password: str) -> Dict[str, Any]:
         if not self.auth_client:
             raise ConfigurationError("Authentication client not initialized. Please provide auth_url during client initialization.")
@@ -309,7 +300,6 @@
             "geometry": mapping(feature),
             "properties": {}
         }
-        self.validate_feature(feature)
         payload = {
             "feature": feature,
             "in_crs": in_crs,
@@ -321,7 +311,10 @@
         }
         request_url = f"{self.url}/geoquery"
         try:
+            print("the request url is ", request_url)
+            print("the payload is ", payload)
             response = self.session.post(request_url, json=payload, timeout=self.timeout, verify=self.verify)
+            print("the response is ", response.text)
             if not response.ok:
                 error_msg = f"API request failed: {response.status_code} {response.reason}"
                 try:
@@ -690,10 +683,27 @@
         )
         return self.mass_stats.random_sample(name, **kwargs)
 
-    async def zonal_stats_async(self, gdb, expr, conc=20, inplace=False, output="csv"
+    async def zonal_stats_async(self, gdb, expr, conc=20, inplace=False, output="csv",
+                                in_crs="epsg:4326", out_crs="epsg:4326", resolution=0.005, buffer=True):
         """
         Compute zonal statistics for all geometries in a GeoDataFrame using asyncio for concurrency.
+
+        Args:
+            gdb (geopandas.GeoDataFrame): GeoDataFrame containing geometries
+            expr (str): Terrakio expression to evaluate, can include spatial aggregations
+            conc (int): Number of concurrent requests to make
+            inplace (bool): Whether to modify the input GeoDataFrame in place
+            output (str): Output format (csv or netcdf)
+            in_crs (str): Input coordinate reference system
+            out_crs (str): Output coordinate reference system
+            resolution (int): Resolution parameter
+            buffer (bool): Whether to buffer the request (default True)
+
+        Returns:
+            geopandas.GeoDataFrame: GeoDataFrame with added columns for results, or None if inplace=True
         """
+        if conc > 100:
+            raise ValueError("Concurrency (conc) is too high. Please set conc to 100 or less.")
 
         # Process geometries in batches
         all_results = []
@@ -707,7 +717,8 @@
                 "geometry": mapping(geom),
                 "properties": {"index": index}
             }
-            result = await self.wcs_async(expr=expr, feature=feature, output=output
+            result = await self.wcs_async(expr=expr, feature=feature, output=output,
+                                          in_crs=in_crs, out_crs=out_crs, resolution=resolution, buffer=buffer)
             # Add original index to track which geometry this result belongs to
             if isinstance(result, pd.DataFrame):
                 result['_geometry_index'] = index
@@ -845,7 +856,8 @@
         else:
             return result_gdf
 
-    def zonal_stats(self, gdb, expr, conc=20, inplace=False, output="csv"
+    def zonal_stats(self, gdb, expr, conc=20, inplace=False, output="csv",
+                    in_crs="epsg:4326", out_crs="epsg:4326", resolution=0.005, buffer=True):
         """
         Compute zonal statistics for all geometries in a GeoDataFrame.
 
@@ -855,10 +867,16 @@
             conc (int): Number of concurrent requests to make
             inplace (bool): Whether to modify the input GeoDataFrame in place
             output (str): Output format (csv or netcdf)
+            in_crs (str): Input coordinate reference system
+            out_crs (str): Output coordinate reference system
+            resolution (int): Resolution parameter
+            buffer (bool): Whether to buffer the request (default True)
 
         Returns:
             geopandas.GeoDataFrame: GeoDataFrame with added columns for results, or None if inplace=True
         """
+        if conc > 100:
+            raise ValueError("Concurrency (conc) is too high. Please set conc to 100 or less.")
         import asyncio
 
         # Check if we're in a Jupyter environment or already have an event loop
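A sketch of the synchronous zonal_stats with the new keyword arguments; the GeoDataFrame file and the expression below are invented placeholders, not values from this diff.

import geopandas as gpd
from terrakio_core.client import BaseClient

client = BaseClient(url="https://api.example.com", key="YOUR_KEY")
gdf = gpd.read_file("fields.geojson")  # placeholder input file

# The new guard rejects conc > 100 with a ValueError before any request.
result = client.zonal_stats(
    gdf, "mean(MSWX.air_temperature@(year=2021, month=1))", conc=20,
    in_crs="epsg:4326", out_crs="epsg:4326", resolution=0.005, buffer=True,
)
print(result.head())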
@@ -866,21 +884,25 @@
             loop = asyncio.get_running_loop()
             # We're in an async context (like Jupyter), use create_task
             nest_asyncio.apply()
-            result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output
+            result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output,
+                                                        in_crs, out_crs, resolution, buffer))
         except RuntimeError:
             # No running event loop, safe to use asyncio.run()
-            result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output
+            result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output,
+                                                        in_crs, out_crs, resolution, buffer))
         except ImportError:
             # nest_asyncio not available, try alternative approach
             try:
                 loop = asyncio.get_running_loop()
                 # Create task in existing loop
-                task = loop.create_task(self.zonal_stats_async(gdb, expr, conc, inplace, output
+                task = loop.create_task(self.zonal_stats_async(gdb, expr, conc, inplace, output,
+                                                               in_crs, out_crs, resolution, buffer))
                 # This won't work directly - we need a different approach
                 raise RuntimeError("Cannot run async code in Jupyter without nest_asyncio. Please install: pip install nest-asyncio")
             except RuntimeError:
                 # No event loop, use asyncio.run
-                result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output
+                result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output,
+                                                            in_crs, out_crs, resolution, buffer))
 
         # Ensure aiohttp session is closed after running async code
         try:
@@ -1011,6 +1033,80 @@
             timeout=self.timeout
         )
         return self.space_management.delete_data_in_path(path, region)
+
+    def generate_ai_dataset(
+        self,
+        name: str,
+        aoi_geojson: str,
+        expression_x: str,
+        expression_y: str,
+        samples: int,
+        tile_size: int,
+        crs: str = "epsg:4326",
+        res: float = 0.001,
+        region: str = "aus",
+        start_year: int = None,
+        end_year: int = None,
+    ) -> dict:
+        """
+        Generate an AI dataset using specified parameters.
+
+        Args:
+            name (str): Name of the dataset to generate
+            aoi_geojson (str): Path to GeoJSON file containing area of interest
+            expression_x (str): Expression for X variable (e.g. "MSWX.air_temperature@(year=2021, month=1)")
+            expression_y (str): Expression for Y variable with {year} placeholder
+            samples (int): Number of samples to generate
+            tile_size (int): Size of tiles in degrees
+            crs (str, optional): Coordinate reference system. Defaults to "epsg:4326"
+            res (float, optional): Resolution in degrees. Defaults to 0.001
+            region (str, optional): Region code. Defaults to "aus"
+            start_year (int, optional): Start year for data generation. Required if end_year provided
+            end_year (int, optional): End year for data generation. Required if start_year provided
+            overwrite (bool, optional): Whether to overwrite existing dataset. Defaults to False
+
+        Returns:
+            dict: Response from the AI dataset generation API
+
+        Raises:
+            ValidationError: If required parameters are missing or invalid
+            APIError: If the API request fails
+        """
+
+        # we have the parameters, let pass the parameters to the random sample function
+        # task_id = self.random_sample(name, aoi_geojson, expression_x, expression_y, samples, tile_size, crs, res, region, start_year, end_year, overwrite)
+        config = {
+            "expressions" : [{"expr": expression_x, "res": res, "prefix": "x"}],
+            "filters" : []
+        }
+        config["expressions"].append({"expr": expression_y, "res" : res, "prefix": "y"})
+
+        expression_x = expression_x.replace("{year}", str(start_year))
+        expression_y = expression_y.replace("{year}", str(start_year))
+        print("the aoi geojson is ", aoi_geojson)
+        with open(aoi_geojson, 'r') as f:
+            aoi_data = json.load(f)
+        print("the config is ", config)
+        task_id = self.random_sample(
+            name=name,
+            config=config,
+            aoi=aoi_data,
+            samples=samples,
+            year_range=[start_year, end_year],
+            crs=crs,
+            tile_size=tile_size,
+            res=res,
+            region=region,
+            output="netcdf",
+            server=self.url,
+            bucket="terrakio-mass-requests",
+            overwrite=True
+        )["task_id"]
+        print("the task id is ", task_id)
+        task_id = self.start_mass_stats_job(task_id)
+        print("the task id is ", task_id)
+        return task_id
+
 
     def train_model(self, model_name: str, training_data: dict) -> dict:
         """
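An illustrative call to the new generate_ai_dataset. The X expression follows the docstring's example, the Y expression shows the {year} placeholder, and the name, path, and numbers are invented for the sketch.

task_id = client.generate_ai_dataset(
    name="veg-height-training",                # placeholder dataset name
    aoi_geojson="aoi.geojson",                 # placeholder path to an AOI file
    expression_x="MSWX.air_temperature@(year={year}, month=1)",
    expression_y="VEG.height@(year={year})",   # placeholder Y expression
    samples=1000,
    tile_size=1,
    start_year=2018,
    end_year=2021,
)
print(task_id)  # identifier of the mass-stats job started on the server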
@@ -1044,3 +1140,90 @@
         except requests.RequestException as e:
             raise APIError(f"Model training request failed: {str(e)}")
 
+    def deploy_model(self, dataset: str, product:str, model_name:str, input_expression: str, model_training_job_name: str, uid: str, dates_iso8601: list):
+        # we have the dataset and we have the product, and we have the model name, we need to create a new json file and add that to the dataset as our virtual dataset
+        # upload the script to the bucket, the script should be able to download the model and do the inferencing
+        # we need to upload the the json to the to the dataset as our virtual dataset
+        # then we do nothing and wait for the user to make the request call to the explorer
+        # we should have a uniform script for the random forest deployment
+        # create a script for each model
+        # upload script to google bucket,
+        #
+
+        script_content = self._generate_script(model_name, product, model_training_job_name, uid)
+        # self.create_dataset(collection = "terrakio-datasets", input = input, )
+        # we have the script, we need to upload it to the bucket
+        script_name = f"{product}.py"
+        print("the script content is ", script_content)
+        print("the script name is ", script_name)
+        self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
+        # after uploading the script, we need to create a new virtual dataset
+        self._create_dataset(name = dataset, collection = "terrakio-datasets", products = [product], path = f"gs://terrakio-mass-requests/{uid}/{model_training_job_name}/inference_scripts", input = input_expression, dates_iso8601 = dates_iso8601, padding = 0)
+
+    def _generate_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
+        return textwrap.dedent(f'''
+        import logging
+        from io import BytesIO
+        from google.cloud import storage
+        from onnxruntime import InferenceSession
+        import numpy as np
+        import xarray as xr
+        import datetime
+
+        logging.basicConfig(
+            level=logging.INFO
+        )
+
+        def get_model():
+            logging.info("Loading model for {model_name}...")
+
+            client = storage.Client()
+            bucket = client.get_bucket('terrakio-mass-requests')
+            blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')
+
+            model = BytesIO()
+            blob.download_to_file(model)
+            model.seek(0)
+
+            session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
+            return session
+
+        def {product}(*bands, model):
+            logging.info("start preparing data")
+
+            original_shape = bands[0].shape
+            logging.info(f"Original shape: {{original_shape}}")
+
+            transformed_bands = []
+            for band in bands:
+                transformed_band = band.values.reshape(-1,1)
+                transformed_bands.append(transformed_band)
+
+            input_data = np.hstack(transformed_bands)
+
+            logging.info(f"Final input shape: {{input_data.shape}}")
+
+            output = model.run(None, {{"float_input": input_data.astype(np.float32)}})[0]
+
+            logging.info(f"Model output shape: {{output.shape}}")
+
+            output_reshaped = output.reshape(original_shape)
+            result = xr.DataArray(
+                data=output_reshaped,
+                dims=bands[0].dims,
+                coords=bands[0].coords
+            )
+
+            return result
+        ''').strip()
+
+    def _upload_script_to_bucket(self, script_content: str, script_name: str, model_training_job_name: str, uid: str):
+        """Upload the generated script to Google Cloud Storage"""
+
+        client = storage.Client()
+        bucket = client.get_bucket('terrakio-mass-requests')
+        blob = bucket.blob(f'{uid}/{model_training_job_name}/inference_scripts/{script_name}')
+        # the first layer is the uid, the second layer is the model training job name
+        blob.upload_from_string(script_content, content_type='text/plain')
+        logging.info(f"Script uploaded successfully to {uid}/{model_training_job_name}/inference_scripts/{script_name}")
+
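A hypothetical deploy_model invocation; every argument below is a placeholder meant only to show how the pieces fit together. The product name becomes the name of the generated inference function, and the model is fetched from gs://terrakio-mass-requests/{uid}/{model_training_job_name}/models/{model_name}.onnx.

client.deploy_model(
    dataset="veg_height",                    # virtual dataset to create
    product="veg_height_model",              # also names the generated function
    model_name="rf_v1",                      # expects rf_v1.onnx in the bucket
    input_expression="S2.B04, S2.B08",       # placeholder input expression
    model_training_job_name="train-job-1",
    uid="user-123",
    dates_iso8601=["2024-01-01"],
)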
terrakio_core/dataset_management.py
CHANGED
@@ -117,7 +117,7 @@ class DatasetManagement:
 
         # Add optional parameters if provided
         for param in ["products", "dates_iso8601", "bucket", "path", "data_type",
-                      "no_data", "l_max", "y_size", "x_size", "proj4", "abstract", "geotransform"]:
+                      "no_data", "l_max", "y_size", "x_size", "proj4", "abstract", "geotransform", "input"]:
             if param in kwargs:
                 payload[param] = kwargs[param]
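A minimal sketch of what this one-line change does: "input" now survives the optional-parameter whitelist and reaches the request payload. The whitelist mirrors the loop above; the kwargs values are invented.

payload = {}
kwargs = {"input": "S2.B04 + S2.B08", "not_whitelisted": 123}
for param in ["products", "dates_iso8601", "bucket", "path", "data_type",
              "no_data", "l_max", "y_size", "x_size", "proj4", "abstract",
              "geotransform", "input"]:
    if param in kwargs:
        payload[param] = kwargs[param]
print(payload)  # {'input': 'S2.B04 + S2.B08'}; unknown keys are dropped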
|
{terrakio_core-0.2.6.dist-info → terrakio_core-0.2.7.dist-info}/RECORD
CHANGED
@@ -1,14 +1,14 @@
-terrakio_core/__init__.py,sha256=
+terrakio_core/__init__.py,sha256=Un_V6wfdpzgstem38OlIcfKNKbTFR0e1ah50HHcZICU,88
 terrakio_core/auth.py,sha256=Nuj0_X3Hiy17svYgGxrSAR-LXpTlP0J0dSrfMnkPUbI,7717
-terrakio_core/client.py,sha256=
+terrakio_core/client.py,sha256=Y1nPxq3UqyLf8akYkxhlC7aOyhJMKTs9eaeRx_IV2UI,56012
 terrakio_core/config.py,sha256=AwJ1VgR5K7N32XCU5k7_Dp1nIv_FYt8MBonq9yKlGzA,2658
-terrakio_core/dataset_management.py,sha256=
+terrakio_core/dataset_management.py,sha256=LKUESSDPRu1JubQaQJWdPqHLGt-_Xv77Fpb4IM7vkzM,8751
 terrakio_core/exceptions.py,sha256=9S-I20-QiDRj1qgjFyYUwYM7BLic_bxurcDOIm2Fu_0,410
 terrakio_core/group_access_management.py,sha256=NJ7SX4keUzZAUENmJ5L6ynKf4eRlqtyir5uoKFyY17A,7315
 terrakio_core/mass_stats.py,sha256=AqYJsd6nqo2BDh4vEPUDgsv4T0UR1_TPDoXa3WO3gTU,9284
 terrakio_core/space_management.py,sha256=wlUUQrlj_4U_Lpjn9lbF5oj0Rv3NPvvnrd5mWej5kmA,4211
 terrakio_core/user_management.py,sha256=MMNWkz0V_9X7ZYjjteuRU4H4W3F16iuQw1dpA2wVTGg,7400
-terrakio_core-0.2.
-terrakio_core-0.2.
-terrakio_core-0.2.
-terrakio_core-0.2.
+terrakio_core-0.2.7.dist-info/METADATA,sha256=YEO4IiAHb9aUgWAoS8JjcBZY9wySuAQayaJaUwgd9CQ,1405
+terrakio_core-0.2.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+terrakio_core-0.2.7.dist-info/top_level.txt,sha256=5cBj6O7rNWyn97ND4YuvvXm0Crv4RxttT4JZvNdOG6Q,14
+terrakio_core-0.2.7.dist-info/RECORD,,
{terrakio_core-0.2.6.dist-info → terrakio_core-0.2.7.dist-info}/WHEEL
File without changes
{terrakio_core-0.2.6.dist-info → terrakio_core-0.2.7.dist-info}/top_level.txt
File without changes