terrakio-core 0.3.6__py3-none-any.whl → 0.3.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of terrakio-core might be problematic.
- terrakio_core/__init__.py +1 -1
- terrakio_core/async_client.py +50 -47
- terrakio_core/endpoints/dataset_management.py +6 -4
- terrakio_core/endpoints/mass_stats.py +5 -5
- terrakio_core/endpoints/model_management.py +426 -21
- terrakio_core/exceptions.py +4 -2
- {terrakio_core-0.3.6.dist-info → terrakio_core-0.3.8.dist-info}/METADATA +7 -1
- {terrakio_core-0.3.6.dist-info → terrakio_core-0.3.8.dist-info}/RECORD +10 -10
- {terrakio_core-0.3.6.dist-info → terrakio_core-0.3.8.dist-info}/WHEEL +0 -0
- {terrakio_core-0.3.6.dist-info → terrakio_core-0.3.8.dist-info}/top_level.txt +0 -0
terrakio_core/__init__.py
CHANGED
terrakio_core/async_client.py
CHANGED
@@ -46,65 +46,65 @@ class AsyncClient(BaseClient):
         else:
             return await self._make_request_with_retry(self._session, method, endpoint, **kwargs)

-
     async def _make_request_with_retry(self, session: aiohttp.ClientSession, method: str, endpoint: str, **kwargs) -> Dict[Any, Any]:
         url = f"{self.url}/{endpoint.lstrip('/')}"
+        last_exception = None
+
         for attempt in range(self.retry + 1):
-            try:
+            try:
                 async with session.request(method, url, **kwargs) as response:
-
+                    if not response.ok and self._should_retry(response.status, attempt):
+                        self.logger.info(f"Request failed (attempt {attempt+1}/{self.retry+1}): {response.status}. Retrying...")
+                        continue
                     if not response.ok:
-
-
-                        if response.status == 400:
-                            should_retry = False
-                        else:
-                            if response.status in [408, 502, 503, 504]:
-                                should_retry = True
-                            elif response.status == 500:
-                                try:
-                                    response_text = await response.text()
-                                    if "Internal server error" not in response_text:
-                                        should_retry = True
-                                except:
-                                    should_retry = True
-
-                        if should_retry and attempt < self.retry:
-                            self.logger.info(f"Request failed (attempt {attempt+1}/{self.retry+1}): {response.status} {response.reason}. Retrying...")
-                            continue
-                        else:
-                            error_msg = f"API request failed: {response.status} {response.reason}"
-                            try:
-                                error_data = await response.json()
-                                if "detail" in error_data:
-                                    error_msg += f" - {error_data['detail']}"
-                            except:
-                                pass
-                            raise APIError(error_msg, status_code=response.status)
-
-                    content_type = response.headers.get('content-type', '').lower()
-                    content = await response.read()
-                    if 'json' in content_type:
-                        return json.loads(content.decode('utf-8'))
-                    elif 'csv' in content_type:
-                        return pd.read_csv(BytesIO(content))
-                    elif 'image/' in content_type:
-                        return content
-                    elif 'text' in content_type:
-                        return content.decode('utf-8')
-                    else:
+                        error_msg = f"API request failed: {response.status} {response.reason}"
                         try:
-
+                            error_data = await response.json()
+                            if "detail" in error_data:
+                                error_msg += f" - {error_data['detail']}"
                         except:
-
+                            pass
+                        raise APIError(error_msg, status_code=response.status)
+                    return await self._parse_response(response)
+
             except aiohttp.ClientError as e:
+                last_exception = e
                 if attempt < self.retry:
-                    self.logger.info(f"
+                    self.logger.info(f"Networking error (attempt {attempt+1}/{self.retry+1}): {e}. Retrying...")
                     continue
                 else:
-
-
+                    break
+
+        raise APIError(f"Networking error, request failed after {self.retry+1} attempts: {last_exception}", status_code=None)
+
+    def _should_retry(self, status_code: int, attempt: int) -> bool:
+        """Determine if the request should be retried based on status code."""
+        if attempt >= self.retry:
+            return False
+        elif status_code in [408, 502, 503, 504]:
+            return True
+        else:
+            return False
+
+    async def _parse_response(self, response) -> Any:
+        """Parse response based on content type."""
+        content_type = response.headers.get('content-type', '').lower()
+        content = await response.read()

+        if 'json' in content_type:
+            return json.loads(content.decode('utf-8'))
+        elif 'csv' in content_type:
+            return pd.read_csv(BytesIO(content))
+        elif 'image/' in content_type:
+            return content
+        elif 'text' in content_type:
+            return content.decode('utf-8')
+        else:
+            try:
+                return xr.open_dataset(BytesIO(content))
+            except:
+                raise APIError(f"Unknown response format: {content_type}", status_code=response.status)
+
     async def _regular_request(self, method: str, endpoint: str, **kwargs):
         url = endpoint.lstrip('/')
         if self._session is None:
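Behavioral note: the old inline bookkeeping conditionally retried some 500 responses (those whose body did not contain "Internal server error"), while the new _should_retry helper retries only 408, 502, 503 and 504; once network retries are exhausted, the client now raises APIError with status_code=None. A minimal caller-side sketch of the resulting failure modes (constructor arguments are omitted, and the request method shown is just one path routed through the retry loop):

    import asyncio
    from terrakio_core.async_client import AsyncClient
    from terrakio_core.exceptions import APIError

    async def main():
        client = AsyncClient(...)  # constructor arguments omitted; see package docs
        try:
            result = await client._terrakio_request("GET", "datasets")
        except APIError as e:
            if e.status_code is None:
                print(f"network failure after retries: {e}")
            else:
                print(f"HTTP {e.status_code}: {e}")

    asyncio.run(main())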
@@ -134,6 +134,7 @@ class AsyncClient(BaseClient):
         output: str = "csv",
         resolution: int = -1,
         geom_fix: bool = False,
+        validated: bool = True,
         **kwargs
     ):
         """
@@ -147,6 +148,7 @@ class AsyncClient(BaseClient):
             output (str): Output format ('csv' or 'netcdf')
             resolution (int): Resolution parameter
             geom_fix (bool): Whether to fix the geometry (default False)
+            validated (bool): Whether to use validated data (default True)
             **kwargs: Additional parameters to pass to the WCS request

         Returns:
@@ -169,6 +171,7 @@ class AsyncClient(BaseClient):
             "resolution": resolution,
             "expr": expr,
             "buffer": geom_fix,
+            "validated": validated,
             **kwargs
         }
         return await self._terrakio_request("POST", "geoquery", json=payload)
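The new validated flag defaults to True, so existing geoquery callers keep their behavior; the flag is simply forwarded in the POST body alongside expr and buffer. Illustrative payload only (the expr value is a placeholder, not real expression syntax):

    payload = {
        "resolution": -1,
        "expr": "<band expression>",
        "buffer": False,       # geom_fix
        "validated": False,    # new in 0.3.8, defaults to True
    }
    # await self._terrakio_request("POST", "geoquery", json=payload)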
terrakio_core/endpoints/dataset_management.py
CHANGED
@@ -42,7 +42,7 @@ class DatasetManagement:
         return self._client._terrakio_request("GET", f"/datasets/{name}", params = params)

     @require_api_key
-    def create_dataset(
+    async def create_dataset(
         self,
         name: str,
         collection: str = "terrakio-datasets",
@@ -59,7 +59,8 @@ class DatasetManagement:
         proj4: Optional[str] = None,
         abstract: Optional[str] = None,
         geotransform: Optional[List[float]] = None,
-        padding: Optional[Any] = None
+        padding: Optional[Any] = None,
+        input: Optional[str] = None
     ) -> Dict[str, Any]:
         """
         Create a new dataset.
@@ -104,12 +105,13 @@ class DatasetManagement:
             "proj4": proj4,
             "abstract": abstract,
             "geotransform": geotransform,
-            "padding": padding
+            "padding": padding,
+            "input": input
         }
         for param, value in param_mapping.items():
             if value is not None:
                 payload[param] = value
-        return self._client._terrakio_request("POST", "/datasets", params = params, json = payload)
+        return await self._client._terrakio_request("POST", "/datasets", params = params, json = payload)

     @require_api_key
     def update_dataset(
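create_dataset is now a coroutine and must be awaited; code written against 0.3.6's synchronous form will otherwise get back an un-awaited coroutine. The new optional input field is forwarded in the payload only when set, like the other optional parameters. A minimal sketch with placeholder values, to be run inside an async function:

    result = await client.datasets.create_dataset(
        name="my_dataset",
        collection="terrakio-datasets",
        input="<input expression>",  # new optional parameter in 0.3.8
        padding=0,
    )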
terrakio_core/endpoints/mass_stats.py
CHANGED
@@ -12,7 +12,7 @@ class MassStats:
         self._client = client

     @require_api_key
-    async def
+    async def _upload_request(
         self,
         name: str,
         size: int,
@@ -220,7 +220,7 @@ class MassStats:
         return self._client._terrakio_request("GET", "mass_stats/download", params=params)

     @require_api_key
-    async def
+    async def _upload_file(self, file_path: str, url: str, use_gzip: bool = False):
         """
         Helper method to upload a JSON file to a signed URL.

@@ -427,7 +427,7 @@ class MassStats:
             return e
         except json.JSONDecodeError as e:
             return e
-        upload_result = await self.
+        upload_result = await self._upload_request(name = name, size = size, region = region, output = output, config = config, location = location, force_loc = force_loc, overwrite = overwrite, server = server, skip_existing = skip_existing)
         requests_url = upload_result.get('requests_url')
         manifest_url = upload_result.get('manifest_url')
         if not requests_url:
@@ -436,7 +436,7 @@ class MassStats:
         try:
             # in this place we are uploading the request json file, we need to check whether the json is in the correct format or not
             self.validate_request(request_json)
-            requests_response = await self.
+            requests_response = await self._upload_file(request_json, requests_url, use_gzip=True)
             if requests_response.status not in [200, 201, 204]:
                 self._client.logger.error(f"Requests upload error: {requests_response.text()}")
                 raise Exception(f"Failed to upload request JSON: {requests_response.text()}")
@@ -447,7 +447,7 @@ class MassStats:
             raise ValueError("No manifest_url returned from server for manifest JSON upload")

         try:
-            manifest_response = await self.
+            manifest_response = await self._upload_file(manifest_json, manifest_url, use_gzip=False)
             if manifest_response.status not in [200, 201, 204]:
                 self._client.logger.error(f"Manifest upload error: {manifest_response.text()}")
                 raise Exception(f"Failed to upload manifest JSON: {manifest_response.text()}")
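On the new side, the two private helpers have full signatures: _upload_request obtains signed upload URLs from the server, and _upload_file PUTs a JSON file to one of them, gzip-compressed for the requests file and uncompressed for the manifest. The helper's body is outside these hunks; a generic sketch of that kind of signed-URL upload (all names are illustrative, not the package's implementation):

    import gzip
    import aiohttp

    async def upload_json(file_path: str, signed_url: str, use_gzip: bool = False) -> int:
        # Read the JSON payload and optionally gzip it before the PUT
        with open(file_path, "rb") as f:
            body = f.read()
        headers = {"Content-Type": "application/json"}
        if use_gzip:
            body = gzip.compress(body)
            headers["Content-Encoding"] = "gzip"
        async with aiohttp.ClientSession() as session:
            async with session.put(signed_url, data=body, headers=headers) as resp:
                return resp.status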
terrakio_core/endpoints/model_management.py
CHANGED
@@ -3,9 +3,32 @@ import json
 import time
 import textwrap
 import logging
-from typing import Dict, Any
+from typing import Dict, Any, Union, Tuple
+from io import BytesIO
+import numpy as np
 from google.cloud import storage
 from ..helper.decorators import require_token, require_api_key, require_auth
+TORCH_AVAILABLE = False
+SKL2ONNX_AVAILABLE = False
+
+try:
+    import torch
+    TORCH_AVAILABLE = True
+except ImportError:
+    torch = None
+
+try:
+    from skl2onnx import convert_sklearn
+    from skl2onnx.common.data_types import FloatTensorType
+    from sklearn.base import BaseEstimator
+    SKL2ONNX_AVAILABLE = True
+except ImportError:
+    convert_sklearn = None
+    FloatTensorType = None
+    BaseEstimator = None
+
+from io import BytesIO
+from typing import Tuple

 class ModelManagement:
     def __init__(self, client):
@@ -115,14 +138,13 @@ class ModelManagement:
         bar = '█' * filled_length + '░' * (bar_length - filled_length)
         percentage = progress * 100

-
-        print(f"\rJob status: {status} [{bar}] {percentage:.1f}% ({completed}/{total})", end='')
+        self._client.logger.info(f"Job status: {status} [{bar}] {percentage:.1f}% ({completed}/{total})")

         if status == "Completed":
-
+            self._client.logger.info("Job completed successfully!")
             break
         elif status == "Error":
-
+            self._client.logger.info("Job encountered an error")
             raise Exception(f"Job {task_id} encountered an error")

         # Wait 5 seconds before checking again
@@ -133,43 +155,238 @@ class ModelManagement:
         return task_id

     @require_api_key
-    async def upload_model(self,
+    async def upload_model(self, model, model_name: str, input_shape: Tuple[int, ...] = None):
         """
         Upload a model to the bucket so that it can be used for inference.
+        Converts PyTorch and scikit-learn models to ONNX format before uploading.

         Args:
-
+            model: The model object (PyTorch model or scikit-learn model)
+            model_name: Name for the model (without extension)
+            input_shape: Shape of input data for ONNX conversion (e.g., (1, 10) for batch_size=1, features=10)
+                         Required for PyTorch models, optional for scikit-learn models

         Raises:
             APIError: If the API request fails
+            ValueError: If model type is not supported or input_shape is missing for PyTorch models
+            ImportError: If required libraries (torch or skl2onnx) are not installed
         """
         uid = (await self._client.auth.get_user_info())["uid"]
-        model_name = os.path.basename(model_path)

         client = storage.Client()
         bucket = client.get_bucket('terrakio-mass-requests')
-        model_file_name = os.path.splitext(model_name)[0]
-        blob = bucket.blob(f'{uid}/{model_file_name}/models/{model_name}')

-
-        self.
+        # Convert model to ONNX format
+        onnx_bytes = self._convert_model_to_onnx(model, model_name, input_shape)
+
+        # Upload ONNX model to bucket
+        blob = bucket.blob(f'{uid}/{model_name}/models/{model_name}.onnx')
+
+        blob.upload_from_string(onnx_bytes, content_type='application/octet-stream')
+        self._client.logger.info(f"Model uploaded successfully to {uid}/{model_name}/models/{model_name}.onnx")
+
+    def _convert_model_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...] = None) -> bytes:
+        """
+        Convert a model to ONNX format and return as bytes.
+
+        Args:
+            model: The model object (PyTorch or scikit-learn)
+            model_name: Name of the model for logging
+            input_shape: Shape of input data
+
+        Returns:
+            bytes: ONNX model as bytes
+
+        Raises:
+            ValueError: If model type is not supported
+            ImportError: If required libraries are not installed
+        """
+        # Early check for any conversion capability
+        if not (TORCH_AVAILABLE or SKL2ONNX_AVAILABLE):
+            raise ImportError(
+                "ONNX conversion requires additional dependencies. Install with:\n"
+                "  pip install torch  # For PyTorch models\n"
+                "  pip install skl2onnx  # For scikit-learn models\n"
+                "  pip install torch skl2onnx  # For both"
+            )
+
+        # Check if it's a PyTorch model using isinstance (preferred) with fallback
+        is_pytorch = False
+        if TORCH_AVAILABLE:
+            is_pytorch = (isinstance(model, torch.nn.Module) or
+                          hasattr(model, 'state_dict'))
+
+        # Check if it's a scikit-learn model
+        is_sklearn = False
+        if SKL2ONNX_AVAILABLE:
+            is_sklearn = (isinstance(model, BaseEstimator) or
+                          (hasattr(model, 'fit') and hasattr(model, 'predict')))
+
+        if is_pytorch and TORCH_AVAILABLE:
+            return self._convert_pytorch_to_onnx(model, model_name, input_shape)
+        elif is_sklearn and SKL2ONNX_AVAILABLE:
+            return self._convert_sklearn_to_onnx(model, model_name, input_shape)
+        else:
+            # Provide helpful error message
+            model_type = type(model).__name__
+            model_module = type(model).__module__
+            available_types = []
+            missing_deps = []
+
+            if TORCH_AVAILABLE:
+                available_types.append("PyTorch (torch.nn.Module)")
+            else:
+                missing_deps.append("torch")
+
+            if SKL2ONNX_AVAILABLE:
+                available_types.append("scikit-learn (BaseEstimator)")
+            else:
+                missing_deps.append("skl2onnx")
+
+            if missing_deps:
+                raise ImportError(
+                    f"Model type {model_type} from {model_module} detected, but required dependencies missing: {', '.join(missing_deps)}. "
+                    f"Install with: pip install {' '.join(missing_deps)}"
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported model type: {model_type} from {model_module}. "
+                    f"Supported types: {', '.join(available_types)}"
+                )
+
+    def _convert_pytorch_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...]) -> bytes:
+        """Convert PyTorch model to ONNX format with dynamic input dimensions."""
+        if input_shape is None:
+            raise ValueError("input_shape is required for PyTorch models")
+
+        self._client.logger.info(f"Converting PyTorch model {model_name} to ONNX...")
+
+        try:
+            # Set model to evaluation mode
+            model.eval()
+
+            # Create dummy input
+            dummy_input = torch.randn(input_shape)
+
+            # Use BytesIO to avoid creating temporary files
+            onnx_buffer = BytesIO()
+
+            # Determine dynamic axes based on input shape
+            # Common patterns for different input types:
+            if len(input_shape) == 4:  # Convolutional input: (batch, channels, height, width)
+                dynamic_axes = {
+                    'float_input': {
+                        0: 'batch_size',
+                        2: 'height',  # Make height dynamic for variable input sizes
+                        3: 'width'  # Make width dynamic for variable input sizes
+                    },
+                    'output': {0: 'batch_size'}
+                }
+            elif len(input_shape) == 3:  # Could be (batch, sequence, features) or (batch, height, width)
+                dynamic_axes = {
+                    'float_input': {
+                        0: 'batch_size',
+                        1: 'dim1',  # Generic dynamic dimension
+                        2: 'dim2'  # Generic dynamic dimension
+                    },
+                    'output': {0: 'batch_size'}
+                }
+            elif len(input_shape) == 2:  # Likely (batch, features)
+                dynamic_axes = {
+                    'float_input': {
+                        0: 'batch_size'
+                        # Don't make features dynamic as it usually affects model architecture
+                    },
+                    'output': {0: 'batch_size'}
+                }
+            else:
+                # For other shapes, just make batch size dynamic
+                dynamic_axes = {
+                    'float_input': {0: 'batch_size'},
+                    'output': {0: 'batch_size'}
+                }
+
+            torch.onnx.export(
+                model,
+                dummy_input,
+                onnx_buffer,
+                export_params=True,
+                opset_version=11,
+                do_constant_folding=True,
+                input_names=['float_input'],
+                output_names=['output'],
+                dynamic_axes=dynamic_axes
+            )
+
+            self._client.logger.info(f"Successfully converted {model_name} with dynamic axes: {dynamic_axes}")
+            return onnx_buffer.getvalue()
+
+        except Exception as e:
+            raise ValueError(f"Failed to convert PyTorch model {model_name} to ONNX: {str(e)}")
+
+    def _convert_sklearn_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...] = None) -> bytes:
+        """Convert scikit-learn model to ONNX format."""
+        self._client.logger.info(f"Converting scikit-learn model {model_name} to ONNX...")
+
+        # Try to infer input shape if not provided
+        if input_shape is None:
+            if hasattr(model, 'n_features_in_'):
+                input_shape = (1, model.n_features_in_)
+            else:
+                raise ValueError(
+                    "input_shape is required for scikit-learn models when n_features_in_ is not available. "
+                    "This usually happens with older sklearn versions or models not fitted yet."
+                )
+
+        try:
+            # Convert scikit-learn model to ONNX
+            initial_type = [('float_input', FloatTensorType(input_shape))]
+            onnx_model = convert_sklearn(model, initial_types=initial_type)
+            return onnx_model.SerializeToString()
+
+        except Exception as e:
+            raise ValueError(f"Failed to convert scikit-learn model {model_name} to ONNX: {str(e)}")
+
     @require_api_key
-    def
+    async def upload_and_deploy_cnn_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None):
+        """
+        Upload a CNN model to the bucket and deploy it.

+        Args:
+            model: The model object (PyTorch model or scikit-learn model)
+            model_name: Name for the model (without extension)
+            dataset: Name of the dataset to create
+            product: Product name for the inference
+            input_expression: Input expression for the dataset
+            dates_iso8601: List of dates in ISO8601 format
+            input_shape: Shape of input data for ONNX conversion (required for PyTorch models)
+
+        Raises:
+            APIError: If the API request fails
+            ValueError: If model type is not supported or input_shape is missing for PyTorch models
+            ImportError: If required libraries (torch or skl2onnx) are not installed
+        """
+        await self.upload_model(model=model, model_name=model_name, input_shape=input_shape)
+        # so the uploading process is kinda similar, but the deployment step is kinda different
+        await self.deploy_cnn_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601)
+
+    @require_api_key
+    async def upload_and_deploy_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None):
         """
         Upload a model to the bucket and deploy it.

         Args:
-
-
+            model: The model object (PyTorch model or scikit-learn model)
+            model_name: Name for the model (without extension)
+            dataset: Name of the dataset to create
             product: Product name for the inference
             input_expression: Input expression for the dataset
             dates_iso8601: List of dates in ISO8601 format
+            input_shape: Shape of input data for ONNX conversion (required for PyTorch models)
         """
-        self.upload_model(
-        model_name =
-        self.deploy_model(dataset = dataset, product = product, model_name = model_name, input_expression = input_expression, model_training_job_name = model_name, dates_iso8601 = dates_iso8601)
+        await self.upload_model(model=model, model_name=model_name, input_shape=input_shape)
+        await self.deploy_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601)

     @require_api_key
     def train_model(
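upload_model now takes a live model object instead of a file path, converts it to ONNX in memory, and uploads the bytes to the terrakio-mass-requests bucket under {uid}/{model_name}/models/{model_name}.onnx. A usage sketch, assuming the ml extra is installed and that the client exposes this endpoint group as client.models (that attribute name is an assumption, and the call must run inside an async function):

    from sklearn.datasets import make_regression
    from sklearn.ensemble import RandomForestRegressor

    X, y = make_regression(n_samples=100, n_features=10, random_state=0)
    sk_model = RandomForestRegressor(n_estimators=10).fit(X, y)

    # scikit-learn: input_shape can be inferred from the fitted model's n_features_in_
    await client.models.upload_model(sk_model, model_name="rf_demo")

    # PyTorch: input_shape is required, e.g. a 4D CNN input
    # await client.models.upload_model(torch_model, "cnn_demo", input_shape=(1, 3, 256, 256))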
@@ -212,7 +429,7 @@
         return self._client._terrakio_request("POST", "/train_model", json=payload)

     @require_api_key
-    def deploy_model(
+    async def deploy_model(
         self,
         dataset: str,
         product: str,
@@ -239,7 +456,7 @@
             APIError: If the API request fails
         """
         # Get user info to get UID
-        user_info = self._client.get_user_info()
+        user_info = await self._client.auth.get_user_info()
         uid = user_info["uid"]

         # Generate and upload script
@@ -248,7 +465,53 @@
         self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)

         # Create dataset
-        return self._client.datasets.create_dataset(
+        return await self._client.datasets.create_dataset(
+            name=dataset,
+            collection="terrakio-datasets",
+            products=[product],
+            path=f"gs://terrakio-mass-requests/{uid}/{model_training_job_name}/inference_scripts",
+            input=input_expression,
+            dates_iso8601=dates_iso8601,
+            padding=0
+        )
+
+    @require_api_key
+    async def deploy_cnn_model(
+        self,
+        dataset: str,
+        product: str,
+        model_name: str,
+        input_expression: str,
+        model_training_job_name: str,
+        dates_iso8601: list
+    ) -> Dict[str, Any]:
+        """
+        Deploy a CNN model by generating inference script and creating dataset.
+
+        Args:
+            dataset: Name of the dataset to create
+            product: Product name for the inference
+            model_name: Name of the trained model
+            input_expression: Input expression for the dataset
+            model_training_job_name: Name of the training job
+            dates_iso8601: List of dates in ISO8601 format
+
+        Returns:
+            dict: Response from the deployment process
+
+        Raises:
+            APIError: If the API request fails
+        """
+        # Get user info to get UID
+        user_info = await self._client.auth.get_user_info()
+        uid = user_info["uid"]
+
+        # Generate and upload script
+        script_content = self.generate_cnn_script(model_name, product, model_training_job_name, uid)
+        script_name = f"{product}.py"
+        self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
+        # Create dataset
+        return await self._client.datasets.create_dataset(
             name=dataset,
             collection="terrakio-datasets",
             products=[product],
@@ -374,6 +637,148 @@
             return result
         ''').strip()

+    @require_api_key
+    def generate_cnn_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
+        """
+        Generate Python inference script for CNN model with time-stacked bands.
+
+        Args:
+            model_name: Name of the model
+            product: Product name
+            model_training_job_name: Training job name
+            uid: User ID
+
+        Returns:
+            str: Generated Python script content
+        """
+        return textwrap.dedent(f'''
+        import logging
+        from io import BytesIO
+
+        import numpy as np
+        import pandas as pd
+        import xarray as xr
+        from google.cloud import storage
+        from onnxruntime import InferenceSession
+
+        logging.basicConfig(
+            level=logging.INFO
+        )
+
+        def get_model():
+            logging.info("Loading CNN model for {model_name}...")
+
+            client = storage.Client()
+            bucket = client.get_bucket('terrakio-mass-requests')
+            blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')
+
+            model = BytesIO()
+            blob.download_to_file(model)
+            model.seek(0)
+
+            session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
+            return session
+
+        def {product}(*bands, model):
+            logging.info("Start preparing CNN data with time-stacked bands")
+
+            data_arrays = list(bands)
+
+            if not data_arrays:
+                raise ValueError("No bands provided")
+
+            reference_array = data_arrays[0]
+            original_shape = reference_array.shape
+            logging.info(f"Original shape: {{original_shape}}")
+
+            # Get time coordinates - all bands should have the same time dimension
+            if 'time' not in reference_array.dims:
+                raise ValueError("Time dimension is required for CNN processing")
+
+            time_coords = reference_array.coords['time']
+            num_timestamps = len(time_coords)
+            logging.info(f"Number of timestamps: {{num_timestamps}}")
+
+            # Get spatial dimensions
+            spatial_dims = [dim for dim in reference_array.dims if dim != 'time']
+            height = reference_array.sizes[spatial_dims[0]]  # assuming first spatial dim is height
+            width = reference_array.sizes[spatial_dims[1]]  # assuming second spatial dim is width
+            logging.info(f"Spatial dimensions: {{height}} x {{width}}")
+
+            # Stack bands across time dimension
+            # Result will be: (num_bands * num_timestamps, height, width)
+            stacked_channels = []
+
+            for band_idx, data_array in enumerate(data_arrays):
+                logging.info(f"Processing band {{band_idx + 1}}/{{len(data_arrays)}}")
+
+                # Ensure consistent time coordinates across bands
+                if not np.array_equal(data_array.coords['time'].values, time_coords.values):
+                    logging.warning(f"Band {{band_idx}} has different time coordinates, aligning...")
+                    data_array = data_array.sel(time=time_coords, method='nearest')
+
+                # Extract values and ensure proper ordering (time, height, width)
+                band_values = data_array.values
+                if band_values.ndim == 3:
+                    # Reorder dimensions if needed to ensure (time, height, width)
+                    time_dim_idx = data_array.dims.index('time')
+                    if time_dim_idx != 0:
+                        axes_order = [time_dim_idx] + [i for i in range(len(data_array.dims)) if i != time_dim_idx]
+                        band_values = np.transpose(band_values, axes_order)
+
+                # Add each timestamp of this band to the channel stack
+                for t in range(num_timestamps):
+                    stacked_channels.append(band_values[t])
+
+            # Stack all channels: (num_bands * num_timestamps, height, width)
+            input_channels = np.stack(stacked_channels, axis=0)
+            total_channels = len(data_arrays) * num_timestamps
+            logging.info(f"Stacked channels shape: {{input_channels.shape}}")
+            logging.info(f"Total channels: {{total_channels}} ({{len(data_arrays)}} bands × {{num_timestamps}} timestamps)")
+
+            # Add batch dimension: (1, num_channels, height, width)
+            input_data = np.expand_dims(input_channels, axis=0).astype(np.float32)
+            logging.info(f"Final input shape for CNN: {{input_data.shape}}")
+
+            # Run inference
+            output = model.run(None, {{"float_input": input_data}})[0]
+            logging.info(f"Model output shape: {{output.shape}}")
+
+            # Process output back to xarray format
+            # Assuming output is (1, height, width) or (1, 1, height, width)
+            if output.ndim == 4 and output.shape[1] == 1:
+                # Remove channel dimension if it's 1
+                output_2d = output[0, 0]
+            elif output.ndim == 3:
+                # Remove batch dimension
+                output_2d = output[0]
+            else:
+                # Handle other cases
+                output_2d = np.squeeze(output)
+                if output_2d.ndim != 2:
+                    raise ValueError(f"Unexpected output shape after processing: {{output_2d.shape}}")
+
+            # Determine output timestamp (use the latest timestamp)
+            output_timestamp = time_coords[-1]
+
+            # Get spatial coordinates from reference array
+            spatial_coords = {{dim: reference_array.coords[dim] for dim in spatial_dims}}
+
+            # Create output DataArray
+            result = xr.DataArray(
+                data=np.expand_dims(output_2d.astype(np.float32), axis=0),
+                dims=['time'] + spatial_dims,
+                coords={{
+                    'time': [output_timestamp.values],
+                    spatial_dims[0]: spatial_coords[spatial_dims[0]].values,
+                    spatial_dims[1]: spatial_coords[spatial_dims[1]].values
+                }}
+            )
+
+            logging.info(f"Final result shape: {{result.shape}}")
+            return result
+        ''').strip()
+
     @require_api_key
     def _upload_script_to_bucket(self, script_content: str, script_name: str, model_training_job_name: str, uid: str):
         """Upload the generated script to Google Cloud Storage"""
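The generated CNN inference script flattens time into channels: every band contributes one channel per timestamp, so a model deployed this way must expect in_channels = num_bands × num_timestamps. A worked shape check mirroring the script's stacking logic:

    import numpy as np

    num_bands, num_timestamps, height, width = 3, 4, 128, 128
    stacked_channels = [np.zeros((height, width), dtype=np.float32)
                        for _ in range(num_bands * num_timestamps)]
    input_channels = np.stack(stacked_channels, axis=0)  # (12, 128, 128)
    input_data = np.expand_dims(input_channels, axis=0)  # (1, 12, 128, 128)
    assert input_data.shape == (1, num_bands * num_timestamps, height, width)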
terrakio_core/exceptions.py
CHANGED
@@ -1,7 +1,9 @@
 class APIError(Exception):
     """Exception raised for errors in the API responses."""
-
-
+
+    def __init__(self, message, status_code=None):
+        super().__init__(message)
+        self.status_code = status_code

 class ConfigurationError(Exception):
     """Exception raised for errors in the configuration."""
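With the explicit __init__, the status code is stored on the instance (defaulting to None, which async_client uses for pure networking failures), so callers can branch on it:

    from terrakio_core.exceptions import APIError

    try:
        raise APIError("API request failed: 503 Service Unavailable", status_code=503)
    except APIError as e:
        print(str(e))         # API request failed: 503 Service Unavailable
        print(e.status_code)  # 503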
{terrakio_core-0.3.6.dist-info → terrakio_core-0.3.8.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: terrakio-core
-Version: 0.3.
+Version: 0.3.8
 Summary: Core components for Terrakio API clients
 Author-email: Yupeng Chao <yupeng@haizea.com.au>
 Project-URL: Homepage, https://github.com/HaizeaAnalytics/terrakio-python-api
@@ -23,6 +23,12 @@ Requires-Dist: shapely>=2.0.0
 Requires-Dist: geopandas>=0.13.0
 Requires-Dist: google-cloud-storage>=2.0.0
 Requires-Dist: nest_asyncio
+Provides-Extra: ml
+Requires-Dist: torch>=2.7.1; extra == "ml"
+Requires-Dist: scikit-learn>=1.7.0; extra == "ml"
+Requires-Dist: skl2onnx>=1.19.1; extra == "ml"
+Requires-Dist: onnx>=1.18.0; extra == "ml"
+Requires-Dist: onnxruntime>=1.10.0; extra == "ml"

 # Terrakio Core

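The ML stack is packaged as an optional extra, matching the guarded imports in model_management.py: the base install stays lean, and torch/skl2onnx are only needed when a model is actually converted. To opt in:

    pip install "terrakio-core[ml]"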
{terrakio_core-0.3.6.dist-info → terrakio_core-0.3.8.dist-info}/RECORD
CHANGED
@@ -1,21 +1,21 @@
-terrakio_core/__init__.py,sha256=
-terrakio_core/async_client.py,sha256=
+terrakio_core/__init__.py,sha256=nUk_Q29ij_1R32AjR8Ygwy0_ID4-zdVkQexvBz7reM4,242
+terrakio_core/async_client.py,sha256=wyuIJGMAzfFwfM5BYgCTMuY0YAKNYJqRNB3IF5NpVow,12730
 terrakio_core/client.py,sha256=h8GW88g6RlGwNFW6MW48c_3BnaeT9nSd19LI1jCn1GU,1008
 terrakio_core/config.py,sha256=r8NARVYOca4AuM88VP_j-8wQxOk1s7VcRdyEdseBlLE,4193
-terrakio_core/exceptions.py,sha256=
+terrakio_core/exceptions.py,sha256=4qnpOM1gOxsNIXDXY4qwY1d3I4Myhp7HBh7b2D0SVrU,529
 terrakio_core/sync_client.py,sha256=v1mcBtUaKWACqZgw8dTTVPMxUfKfiY0kjtBKzDwtGTU,13634
 terrakio_core/convenience_functions/convenience_functions.py,sha256=U7bLGwfBF-FUYc0nv49pAViPsBQ6LgPlV6c6b-zeKo8,10616
 terrakio_core/endpoints/auth.py,sha256=e_hdNE6JOGhRVlQMFdEoOmoMHp5EzK6CclOEnc_AmZw,5863
-terrakio_core/endpoints/dataset_management.py,sha256=
+terrakio_core/endpoints/dataset_management.py,sha256=BUm8IIlW_Q45vDiQp16CiJGeSLheI8uWRVRQtMdhaNk,13161
 terrakio_core/endpoints/group_management.py,sha256=VFl3jakjQa9OPi351D3DZvLU9M7fHdfjCzGhmyJsx3U,6309
-terrakio_core/endpoints/mass_stats.py,sha256=
-terrakio_core/endpoints/model_management.py,sha256=
+terrakio_core/endpoints/mass_stats.py,sha256=y1w3QLkDD0sKP1tBcFDqgLYLNxX94I-LYbNotaKhLYM,21356
+terrakio_core/endpoints/model_management.py,sha256=Q2bqsVfBILu-hZVw1tr5WjOR68qoYF6m326YJXgAOeo,33886
 terrakio_core/endpoints/space_management.py,sha256=YWb55nkJnFJGlALJ520DvurxDqVqwYtsvqQPWzxzhDs,2266
 terrakio_core/endpoints/user_management.py,sha256=x0JW6VET7eokngmkhZPukegxoJNR1X09BVehJt2nIdI,3781
 terrakio_core/helper/bounded_taskgroup.py,sha256=wiTH10jhKZgrsgrFUNG6gig8bFkUEPHkGRT2XY7Rgmo,677
 terrakio_core/helper/decorators.py,sha256=L6om7wmWNgCei3Wy5U0aZ-70OzsCwclkjIf7SfQuhCg,2289
 terrakio_core/helper/tiles.py,sha256=xNtp3oDD912PN_FQV5fb6uQYhwfHANuXyIcxoVCCfZU,2632
-terrakio_core-0.3.
-terrakio_core-0.3.
-terrakio_core-0.3.
-terrakio_core-0.3.
+terrakio_core-0.3.8.dist-info/METADATA,sha256=oZlZhEda5qq8myogGbxvlV0ZJExcEf5kMaYWEVES0BE,1728
+terrakio_core-0.3.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+terrakio_core-0.3.8.dist-info/top_level.txt,sha256=5cBj6O7rNWyn97ND4YuvvXm0Crv4RxttT4JZvNdOG6Q,14
+terrakio_core-0.3.8.dist-info/RECORD,,
{terrakio_core-0.3.6.dist-info → terrakio_core-0.3.8.dist-info}/WHEEL
File without changes
{terrakio_core-0.3.6.dist-info → terrakio_core-0.3.8.dist-info}/top_level.txt
File without changes