terrakio-core 0.3.4__py3-none-any.whl → 0.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of terrakio-core might be problematic.
- terrakio_core/__init__.py +10 -1
- terrakio_core/async_client.py +304 -0
- terrakio_core/client.py +22 -1713
- terrakio_core/config.py +8 -15
- terrakio_core/convenience_functions/convenience_functions.py +296 -0
- terrakio_core/endpoints/auth.py +180 -0
- terrakio_core/endpoints/dataset_management.py +371 -0
- terrakio_core/endpoints/group_management.py +228 -0
- terrakio_core/endpoints/mass_stats.py +594 -0
- terrakio_core/endpoints/model_management.py +790 -0
- terrakio_core/endpoints/space_management.py +72 -0
- terrakio_core/endpoints/user_management.py +131 -0
- terrakio_core/exceptions.py +4 -2
- terrakio_core/helper/bounded_taskgroup.py +20 -0
- terrakio_core/helper/decorators.py +58 -0
- terrakio_core/{generation → helper}/tiles.py +1 -12
- terrakio_core/sync_client.py +370 -0
- {terrakio_core-0.3.4.dist-info → terrakio_core-0.3.7.dist-info}/METADATA +7 -1
- terrakio_core-0.3.7.dist-info/RECORD +21 -0
- terrakio_core/auth.py +0 -223
- terrakio_core/dataset_management.py +0 -287
- terrakio_core/decorators.py +0 -18
- terrakio_core/group_access_management.py +0 -232
- terrakio_core/mass_stats.py +0 -504
- terrakio_core/space_management.py +0 -101
- terrakio_core/user_management.py +0 -227
- terrakio_core-0.3.4.dist-info/RECORD +0 -16
- {terrakio_core-0.3.4.dist-info → terrakio_core-0.3.7.dist-info}/WHEEL +0 -0
- {terrakio_core-0.3.4.dist-info → terrakio_core-0.3.7.dist-info}/top_level.txt +0 -0
terrakio_core/endpoints/model_management.py
@@ -0,0 +1,790 @@
import os
import json
import time
import textwrap
import logging
from typing import Dict, Any, Union, Tuple
from io import BytesIO
import numpy as np
from google.cloud import storage
from ..helper.decorators import require_token, require_api_key, require_auth
TORCH_AVAILABLE = False
SKL2ONNX_AVAILABLE = False

try:
    import torch
    TORCH_AVAILABLE = True
except ImportError:
    torch = None

try:
    from skl2onnx import convert_sklearn
    from skl2onnx.common.data_types import FloatTensorType
    from sklearn.base import BaseEstimator
    SKL2ONNX_AVAILABLE = True
except ImportError:
    convert_sklearn = None
    FloatTensorType = None
    BaseEstimator = None

from io import BytesIO
from typing import Tuple

class ModelManagement:
    def __init__(self, client):
        self._client = client

    @require_api_key
    def generate_ai_dataset(
        self,
        name: str,
        aoi_geojson: str,
        expression_x: str,
        filter_x_rate: float,
        filter_y_rate: float,
        samples: int,
        tile_size: int,
        expression_y: str = "skip",
        filter_x: str = "skip",
        filter_y: str = "skip",
        crs: str = "epsg:4326",
        res: float = 0.001,
        region: str = "aus",
        start_year: int = None,
        end_year: int = None,
    ) -> dict:
        """
        Generate an AI dataset using specified parameters.

        Args:
            name (str): Name of the dataset to generate
            aoi_geojson (str): Path to GeoJSON file containing area of interest
            expression_x (str): Expression for X variable (e.g. "MSWX.air_temperature@(year=2021, month=1)")
            filter_x (str): Filter for X variable (e.g. "MSWX.air_temperature@(year=2021, month=1)")
            filter_x_rate (float): Filter rate for X variable (e.g. 0.5)
            expression_y (str): Expression for Y variable with {year} placeholder
            filter_y (str): Filter for Y variable (e.g. "MSWX.air_temperature@(year=2021, month=1)")
            filter_y_rate (float): Filter rate for Y variable (e.g. 0.5)
            samples (int): Number of samples to generate
            tile_size (int): Size of tiles in degrees
            crs (str, optional): Coordinate reference system. Defaults to "epsg:4326"
            res (float, optional): Resolution in degrees. Defaults to 0.001
            region (str, optional): Region code. Defaults to "aus"
            start_year (int, optional): Start year for data generation. Required if end_year provided
            end_year (int, optional): End year for data generation. Required if start_year provided

        Returns:
            dict: Response from the AI dataset generation API

        Raises:
            APIError: If the API request fails
        """
        # Build config for expressions and filters
        config = {
            "expressions": [{"expr": expression_x, "res": res, "prefix": "x"}],
            "filters": []
        }

        if expression_y != "skip":
            config["expressions"].append({"expr": expression_y, "res": res, "prefix": "y"})

        if filter_x != "skip":
            config["filters"].append({"expr": filter_x, "res": res, "rate": filter_x_rate})
        if filter_y != "skip":
            config["filters"].append({"expr": filter_y, "res": res, "rate": filter_y_rate})

        # Replace year placeholders if start_year is provided
        if start_year is not None:
            expression_x = expression_x.replace("{year}", str(start_year))
            if expression_y != "skip":
                expression_y = expression_y.replace("{year}", str(start_year))
            if filter_x != "skip":
                filter_x = filter_x.replace("{year}", str(start_year))
            if filter_y != "skip":
                filter_y = filter_y.replace("{year}", str(start_year))

        # Load AOI GeoJSON
        with open(aoi_geojson, 'r') as f:
            aoi_data = json.load(f)

        task_response = self._client.mass_stats.random_sample(
            name=name,
            config=config,
            aoi=aoi_data,
            samples=samples,
            year_range=[start_year, end_year],
            crs=crs,
            tile_size=tile_size,
            res=res,
            region=region,
            output="netcdf",
            server=self._client.url,
            bucket="terrakio-mass-requests",
            overwrite=True
        )
        task_id = task_response["task_id"]

        # Wait for job completion with progress bar
        while True:
            result = self._client.track_mass_stats_job(ids=[task_id])
            status = result[task_id]['status']
            completed = result[task_id].get('completed', 0)
            total = result[task_id].get('total', 1)

            # Create progress bar
            progress = completed / total if total > 0 else 0
            bar_length = 50
            filled_length = int(bar_length * progress)
            bar = '█' * filled_length + '░' * (bar_length - filled_length)
            percentage = progress * 100

            self._client.logger.info(f"Job status: {status} [{bar}] {percentage:.1f}% ({completed}/{total})")

            if status == "Completed":
                self._client.logger.info("Job completed successfully!")
                break
            elif status == "Error":
                self._client.logger.info("Job encountered an error")
                raise Exception(f"Job {task_id} encountered an error")

            # Wait 5 seconds before checking again
            time.sleep(5)

        # after all the random sample jobs are done, we then start the mass stats job
        task_id = self._client.mass_stats.start_mass_stats_job(task_id)
        return task_id

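    # Hypothetical usage sketch (editorial illustration, not part of the diff).
    # It assumes the async client exposes this class under an attribute such as
    # `client.model_management`; the dataset name, AOI path and sampling values
    # are placeholders. The call blocks, polling the sampling job every 5 s,
    # then starts the follow-up mass-stats job and returns its identifier:
    #
    #   job_id = client.model_management.generate_ai_dataset(
    #       name="veg_training_set",
    #       aoi_geojson="aoi.geojson",
    #       expression_x="MSWX.air_temperature@(year={year}, month=1)",
    #       filter_x_rate=0.5,
    #       filter_y_rate=0.5,
    #       samples=1000,
    #       tile_size=1,
    #       start_year=2020,
    #       end_year=2021,
    #   )
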
    @require_api_key
    async def upload_model(self, model, model_name: str, input_shape: Tuple[int, ...] = None):
        """
        Upload a model to the bucket so that it can be used for inference.
        Converts PyTorch and scikit-learn models to ONNX format before uploading.

        Args:
            model: The model object (PyTorch model or scikit-learn model)
            model_name: Name for the model (without extension)
            input_shape: Shape of input data for ONNX conversion (e.g., (1, 10) for batch_size=1, features=10)
                Required for PyTorch models, optional for scikit-learn models

        Raises:
            APIError: If the API request fails
            ValueError: If model type is not supported or input_shape is missing for PyTorch models
            ImportError: If required libraries (torch or skl2onnx) are not installed
        """
        uid = (await self._client.auth.get_user_info())["uid"]

        client = storage.Client()
        bucket = client.get_bucket('terrakio-mass-requests')

        # Convert model to ONNX format
        onnx_bytes = self._convert_model_to_onnx(model, model_name, input_shape)

        # Upload ONNX model to bucket
        blob = bucket.blob(f'{uid}/{model_name}/models/{model_name}.onnx')

        blob.upload_from_string(onnx_bytes, content_type='application/octet-stream')
        self._client.logger.info(f"Model uploaded successfully to {uid}/{model_name}/models/{model_name}.onnx")

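    # Hypothetical usage sketch (editorial illustration, not part of the diff).
    # Assumes `client.model_management` is an instance of this class and that a
    # fitted estimator is at hand; names are placeholders. upload_model is a
    # coroutine, so it must be awaited:
    #
    #   from sklearn.ensemble import RandomForestRegressor
    #   rf = RandomForestRegressor().fit(X_train, y_train)
    #   await client.model_management.upload_model(model=rf, model_name="rf_demo")
    #
    #   # A PyTorch model additionally needs input_shape, e.g. (1, 10) for ten features:
    #   # await client.model_management.upload_model(model=net, model_name="mlp_demo",
    #   #                                            input_shape=(1, 10))
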
    def _convert_model_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...] = None) -> bytes:
        """
        Convert a model to ONNX format and return as bytes.

        Args:
            model: The model object (PyTorch or scikit-learn)
            model_name: Name of the model for logging
            input_shape: Shape of input data

        Returns:
            bytes: ONNX model as bytes

        Raises:
            ValueError: If model type is not supported
            ImportError: If required libraries are not installed
        """
        # Early check for any conversion capability
        if not (TORCH_AVAILABLE or SKL2ONNX_AVAILABLE):
            raise ImportError(
                "ONNX conversion requires additional dependencies. Install with:\n"
                "  pip install torch  # For PyTorch models\n"
                "  pip install skl2onnx  # For scikit-learn models\n"
                "  pip install torch skl2onnx  # For both"
            )

        # Check if it's a PyTorch model using isinstance (preferred) with fallback
        is_pytorch = False
        if TORCH_AVAILABLE:
            is_pytorch = (isinstance(model, torch.nn.Module) or
                          hasattr(model, 'state_dict'))

        # Check if it's a scikit-learn model
        is_sklearn = False
        if SKL2ONNX_AVAILABLE:
            is_sklearn = (isinstance(model, BaseEstimator) or
                          (hasattr(model, 'fit') and hasattr(model, 'predict')))

        if is_pytorch and TORCH_AVAILABLE:
            return self._convert_pytorch_to_onnx(model, model_name, input_shape)
        elif is_sklearn and SKL2ONNX_AVAILABLE:
            return self._convert_sklearn_to_onnx(model, model_name, input_shape)
        else:
            # Provide helpful error message
            model_type = type(model).__name__
            model_module = type(model).__module__
            available_types = []
            missing_deps = []

            if TORCH_AVAILABLE:
                available_types.append("PyTorch (torch.nn.Module)")
            else:
                missing_deps.append("torch")

            if SKL2ONNX_AVAILABLE:
                available_types.append("scikit-learn (BaseEstimator)")
            else:
                missing_deps.append("skl2onnx")

            if missing_deps:
                raise ImportError(
                    f"Model type {model_type} from {model_module} detected, but required dependencies missing: {', '.join(missing_deps)}. "
                    f"Install with: pip install {' '.join(missing_deps)}"
                )
            else:
                raise ValueError(
                    f"Unsupported model type: {model_type} from {model_module}. "
                    f"Supported types: {', '.join(available_types)}"
                )

    def _convert_pytorch_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...]) -> bytes:
        """Convert PyTorch model to ONNX format with dynamic input dimensions."""
        if input_shape is None:
            raise ValueError("input_shape is required for PyTorch models")

        self._client.logger.info(f"Converting PyTorch model {model_name} to ONNX...")

        try:
            # Set model to evaluation mode
            model.eval()

            # Create dummy input
            dummy_input = torch.randn(input_shape)

            # Use BytesIO to avoid creating temporary files
            onnx_buffer = BytesIO()

            # Determine dynamic axes based on input shape
            # Common patterns for different input types:
            if len(input_shape) == 4:  # Convolutional input: (batch, channels, height, width)
                dynamic_axes = {
                    'float_input': {
                        0: 'batch_size',
                        2: 'height',  # Make height dynamic for variable input sizes
                        3: 'width'    # Make width dynamic for variable input sizes
                    },
                    'output': {0: 'batch_size'}
                }
            elif len(input_shape) == 3:  # Could be (batch, sequence, features) or (batch, height, width)
                dynamic_axes = {
                    'float_input': {
                        0: 'batch_size',
                        1: 'dim1',  # Generic dynamic dimension
                        2: 'dim2'   # Generic dynamic dimension
                    },
                    'output': {0: 'batch_size'}
                }
            elif len(input_shape) == 2:  # Likely (batch, features)
                dynamic_axes = {
                    'float_input': {
                        0: 'batch_size'
                        # Don't make features dynamic as it usually affects model architecture
                    },
                    'output': {0: 'batch_size'}
                }
            else:
                # For other shapes, just make batch size dynamic
                dynamic_axes = {
                    'float_input': {0: 'batch_size'},
                    'output': {0: 'batch_size'}
                }

            torch.onnx.export(
                model,
                dummy_input,
                onnx_buffer,
                export_params=True,
                opset_version=11,
                do_constant_folding=True,
                input_names=['float_input'],
                output_names=['output'],
                dynamic_axes=dynamic_axes
            )

            self._client.logger.info(f"Successfully converted {model_name} with dynamic axes: {dynamic_axes}")
            return onnx_buffer.getvalue()

        except Exception as e:
            raise ValueError(f"Failed to convert PyTorch model {model_name} to ONNX: {str(e)}")


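    # Hypothetical verification sketch (editorial illustration, not part of the
    # diff). Because height and width are exported as dynamic axes for 4-D
    # inputs, the ONNX bytes returned above should accept tiles of a different
    # spatial size than the dummy input used for export, e.g.:
    #
    #   import numpy as np
    #   from onnxruntime import InferenceSession
    #   sess = InferenceSession(onnx_bytes, providers=["CPUExecutionProvider"])
    #   out = sess.run(None, {"float_input": np.zeros((1, 3, 256, 256), np.float32)})[0]
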
    def _convert_sklearn_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...] = None) -> bytes:
        """Convert scikit-learn model to ONNX format."""
        self._client.logger.info(f"Converting scikit-learn model {model_name} to ONNX...")

        # Try to infer input shape if not provided
        if input_shape is None:
            if hasattr(model, 'n_features_in_'):
                input_shape = (1, model.n_features_in_)
            else:
                raise ValueError(
                    "input_shape is required for scikit-learn models when n_features_in_ is not available. "
                    "This usually happens with older sklearn versions or models not fitted yet."
                )

        try:
            # Convert scikit-learn model to ONNX
            initial_type = [('float_input', FloatTensorType(input_shape))]
            onnx_model = convert_sklearn(model, initial_types=initial_type)
            return onnx_model.SerializeToString()

        except Exception as e:
            raise ValueError(f"Failed to convert scikit-learn model {model_name} to ONNX: {str(e)}")

    @require_api_key
    async def upload_and_deploy_cnn_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None):
        """
        Upload a CNN model to the bucket and deploy it.

        Args:
            model: The model object (PyTorch model or scikit-learn model)
            model_name: Name for the model (without extension)
            dataset: Name of the dataset to create
            product: Product name for the inference
            input_expression: Input expression for the dataset
            dates_iso8601: List of dates in ISO8601 format
            input_shape: Shape of input data for ONNX conversion (required for PyTorch models)

        Raises:
            APIError: If the API request fails
            ValueError: If model type is not supported or input_shape is missing for PyTorch models
            ImportError: If required libraries (torch or skl2onnx) are not installed
        """
        await self.upload_model(model=model, model_name=model_name, input_shape=input_shape)
        # so the uploading process is kinda similar, but the deployment step is kinda different
        await self.deploy_cnn_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601)

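    # Hypothetical usage sketch (editorial illustration, not part of the diff).
    # The attribute name and every value below are placeholders; for a PyTorch
    # CNN the input_shape must be 4-D (batch, channels, height, width), with the
    # channel count matching bands × timestamps fed by the generated script:
    #
    #   await client.model_management.upload_and_deploy_cnn_model(
    #       model=cnn,
    #       model_name="cnn_demo",
    #       dataset="cnn_demo_dataset",
    #       product="cnn_demo_product",
    #       input_expression="MSWX.air_temperature@(year=2021, month=1)",
    #       dates_iso8601=["2021-01-01"],
    #       input_shape=(1, 12, 128, 128),
    #   )
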
    @require_api_key
    async def upload_and_deploy_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None):
        """
        Upload a model to the bucket and deploy it.

        Args:
            model: The model object (PyTorch model or scikit-learn model)
            model_name: Name for the model (without extension)
            dataset: Name of the dataset to create
            product: Product name for the inference
            input_expression: Input expression for the dataset
            dates_iso8601: List of dates in ISO8601 format
            input_shape: Shape of input data for ONNX conversion (required for PyTorch models)
        """
        await self.upload_model(model=model, model_name=model_name, input_shape=input_shape)
        await self.deploy_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601)

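    # Hypothetical usage sketch (editorial illustration, not part of the diff);
    # attribute name and argument values are placeholders. This is the one-call
    # path: it awaits upload_model and then deploy_model, with
    # model_training_job_name set to model_name:
    #
    #   await client.model_management.upload_and_deploy_model(
    #       model=rf,
    #       model_name="rf_demo",
    #       dataset="rf_demo_dataset",
    #       product="rf_demo_product",
    #       input_expression="MSWX.air_temperature@(year=2021, month=1)",
    #       dates_iso8601=["2021-01-01"],
    #   )
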
    @require_api_key
    def train_model(
        self,
        model_name: str,
        training_dataset: str,
        task_type: str,
        model_category: str,
        architecture: str,
        region: str,
        hyperparameters: dict = None
    ) -> dict:
        """
        Train a model using the external model training API.

        Args:
            model_name (str): The name of the model to train.
            training_dataset (str): The training dataset identifier.
            task_type (str): The type of ML task (e.g., regression, classification).
            model_category (str): The category of model (e.g., random_forest).
            architecture (str): The model architecture.
            region (str): The region identifier.
            hyperparameters (dict, optional): Additional hyperparameters for training.

        Returns:
            dict: The response from the model training API.

        Raises:
            APIError: If the API request fails
        """
        payload = {
            "model_name": model_name,
            "training_dataset": training_dataset,
            "task_type": task_type,
            "model_category": model_category,
            "architecture": architecture,
            "region": region,
            "hyperparameters": hyperparameters
        }
        return self._client._terrakio_request("POST", "/train_model", json=payload)

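    # Hypothetical usage sketch (editorial illustration, not part of the diff);
    # all values are placeholders, with task_type/model_category following the
    # examples given in the docstring. train_model is synchronous and simply
    # POSTs the payload to /train_model:
    #
    #   resp = client.model_management.train_model(
    #       model_name="rf_demo",
    #       training_dataset="veg_training_set",
    #       task_type="regression",
    #       model_category="random_forest",
    #       architecture="random_forest",
    #       region="aus",
    #       hyperparameters={"n_estimators": 200},
    #   )
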
    @require_api_key
    async def deploy_model(
        self,
        dataset: str,
        product: str,
        model_name: str,
        input_expression: str,
        model_training_job_name: str,
        dates_iso8601: list
    ) -> Dict[str, Any]:
        """
        Deploy a model by generating inference script and creating dataset.

        Args:
            dataset: Name of the dataset to create
            product: Product name for the inference
            model_name: Name of the trained model
            input_expression: Input expression for the dataset
            model_training_job_name: Name of the training job
            dates_iso8601: List of dates in ISO8601 format

        Returns:
            dict: Response from the deployment process

        Raises:
            APIError: If the API request fails
        """
        # Get user info to get UID
        user_info = await self._client.auth.get_user_info()
        uid = user_info["uid"]

        # Generate and upload script
        script_content = self._generate_script(model_name, product, model_training_job_name, uid)
        script_name = f"{product}.py"
        self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)

        # Create dataset
        return await self._client.datasets.create_dataset(
            name=dataset,
            collection="terrakio-datasets",
            products=[product],
            path=f"gs://terrakio-mass-requests/{uid}/{model_training_job_name}/inference_scripts",
            input=input_expression,
            dates_iso8601=dates_iso8601,
            padding=0
        )

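    # Hypothetical usage sketch (editorial illustration, not part of the diff);
    # attribute name and values are placeholders. deploy_model writes the
    # generated inference script under
    # gs://terrakio-mass-requests/<uid>/<model_training_job_name>/inference_scripts
    # and then creates a dataset pointing at that path:
    #
    #   await client.model_management.deploy_model(
    #       dataset="rf_demo_dataset",
    #       product="rf_demo_product",
    #       model_name="rf_demo",
    #       input_expression="MSWX.air_temperature@(year=2021, month=1)",
    #       model_training_job_name="rf_demo",
    #       dates_iso8601=["2021-01-01"],
    #   )
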
    @require_api_key
    async def deploy_cnn_model(
        self,
        dataset: str,
        product: str,
        model_name: str,
        input_expression: str,
        model_training_job_name: str,
        dates_iso8601: list
    ) -> Dict[str, Any]:
        """
        Deploy a CNN model by generating inference script and creating dataset.

        Args:
            dataset: Name of the dataset to create
            product: Product name for the inference
            model_name: Name of the trained model
            input_expression: Input expression for the dataset
            model_training_job_name: Name of the training job
            dates_iso8601: List of dates in ISO8601 format

        Returns:
            dict: Response from the deployment process

        Raises:
            APIError: If the API request fails
        """
        # Get user info to get UID
        user_info = await self._client.auth.get_user_info()
        uid = user_info["uid"]

        # Generate and upload script
        script_content = self.generate_cnn_script(model_name, product, model_training_job_name, uid)
        script_name = f"{product}.py"
        self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
        # Create dataset
        return await self._client.datasets.create_dataset(
            name=dataset,
            collection="terrakio-datasets",
            products=[product],
            path=f"gs://terrakio-mass-requests/{uid}/{model_training_job_name}/inference_scripts",
            input=input_expression,
            dates_iso8601=dates_iso8601,
            padding=0
        )

    @require_api_key
    def _generate_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
        """
        Generate Python inference script for the model.

        Args:
            model_name: Name of the model
            product: Product name
            model_training_job_name: Training job name
            uid: User ID

        Returns:
            str: Generated Python script content
        """
        return textwrap.dedent(f'''
        import logging
        from io import BytesIO

        import numpy as np
        import pandas as pd
        import xarray as xr
        from google.cloud import storage
        from onnxruntime import InferenceSession

        logging.basicConfig(
            level=logging.INFO
        )

        def get_model():
            logging.info("Loading model for {model_name}...")

            client = storage.Client()
            bucket = client.get_bucket('terrakio-mass-requests')
            blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')

            model = BytesIO()
            blob.download_to_file(model)
            model.seek(0)

            session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
            return session

        def {product}(*bands, model):
            logging.info("start preparing data")

            data_arrays = list(bands)

            reference_array = data_arrays[0]
            original_shape = reference_array.shape
            logging.info(f"Original shape: {{original_shape}}")

            if 'time' in reference_array.dims:
                time_coords = reference_array.coords['time']
                if len(time_coords) == 1:
                    output_timestamp = time_coords[0]
                else:
                    years = [pd.to_datetime(t).year for t in time_coords.values]
                    unique_years = set(years)

                    if len(unique_years) == 1:
                        year = list(unique_years)[0]
                        output_timestamp = pd.Timestamp(f"{{year}}-01-01")
                    else:
                        latest_year = max(unique_years)
                        output_timestamp = pd.Timestamp(f"{{latest_year}}-01-01")
            else:
                output_timestamp = pd.Timestamp("1970-01-01")

            averaged_bands = []
            for data_array in data_arrays:
                if 'time' in data_array.dims:
                    averaged_band = np.mean(data_array.values, axis=0)
                    logging.info(f"Averaged band from {{data_array.shape}} to {{averaged_band.shape}}")
                else:
                    averaged_band = data_array.values
                    logging.info(f"No time dimension, shape: {{averaged_band.shape}}")

                flattened_band = averaged_band.reshape(-1, 1)
                averaged_bands.append(flattened_band)

            input_data = np.hstack(averaged_bands)

            logging.info(f"Final input shape: {{input_data.shape}}")

            output = model.run(None, {{"float_input": input_data.astype(np.float32)}})[0]

            logging.info(f"Model output shape: {{output.shape}}")

            if len(original_shape) >= 3:
                spatial_shape = original_shape[1:]
            else:
                spatial_shape = original_shape

            output_reshaped = output.reshape(spatial_shape)

            output_with_time = np.expand_dims(output_reshaped, axis=0)

            if 'time' in reference_array.dims:
                spatial_dims = [dim for dim in reference_array.dims if dim != 'time']
                spatial_coords = {{dim: reference_array.coords[dim] for dim in spatial_dims if dim in reference_array.coords}}
            else:
                spatial_dims = list(reference_array.dims)
                spatial_coords = dict(reference_array.coords)

            result = xr.DataArray(
                data=output_with_time.astype(np.float32),
                dims=['time'] + list(spatial_dims),
                coords={{
                    'time': [output_timestamp.values],
                    'y': spatial_coords['y'].values,
                    'x': spatial_coords['x'].values
                }}
            )
            return result
        ''').strip()

    @require_api_key
    def generate_cnn_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
        """
        Generate Python inference script for CNN model with time-stacked bands.

        Args:
            model_name: Name of the model
            product: Product name
            model_training_job_name: Training job name
            uid: User ID

        Returns:
            str: Generated Python script content
        """
        return textwrap.dedent(f'''
        import logging
        from io import BytesIO

        import numpy as np
        import pandas as pd
        import xarray as xr
        from google.cloud import storage
        from onnxruntime import InferenceSession

        logging.basicConfig(
            level=logging.INFO
        )

        def get_model():
            logging.info("Loading CNN model for {model_name}...")

            client = storage.Client()
            bucket = client.get_bucket('terrakio-mass-requests')
            blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')

            model = BytesIO()
            blob.download_to_file(model)
            model.seek(0)

            session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
            return session

        def {product}(*bands, model):
            logging.info("Start preparing CNN data with time-stacked bands")

            data_arrays = list(bands)

            if not data_arrays:
                raise ValueError("No bands provided")

            reference_array = data_arrays[0]
            original_shape = reference_array.shape
            logging.info(f"Original shape: {{original_shape}}")

            # Get time coordinates - all bands should have the same time dimension
            if 'time' not in reference_array.dims:
                raise ValueError("Time dimension is required for CNN processing")

            time_coords = reference_array.coords['time']
            num_timestamps = len(time_coords)
            logging.info(f"Number of timestamps: {{num_timestamps}}")

            # Get spatial dimensions
            spatial_dims = [dim for dim in reference_array.dims if dim != 'time']
            height = reference_array.sizes[spatial_dims[0]]  # assuming first spatial dim is height
            width = reference_array.sizes[spatial_dims[1]]  # assuming second spatial dim is width
            logging.info(f"Spatial dimensions: {{height}} x {{width}}")

            # Stack bands across time dimension
            # Result will be: (num_bands * num_timestamps, height, width)
            stacked_channels = []

            for band_idx, data_array in enumerate(data_arrays):
                logging.info(f"Processing band {{band_idx + 1}}/{{len(data_arrays)}}")

                # Ensure consistent time coordinates across bands
                if not np.array_equal(data_array.coords['time'].values, time_coords.values):
                    logging.warning(f"Band {{band_idx}} has different time coordinates, aligning...")
                    data_array = data_array.sel(time=time_coords, method='nearest')

                # Extract values and ensure proper ordering (time, height, width)
                band_values = data_array.values
                if band_values.ndim == 3:
                    # Reorder dimensions if needed to ensure (time, height, width)
                    time_dim_idx = data_array.dims.index('time')
                    if time_dim_idx != 0:
                        axes_order = [time_dim_idx] + [i for i in range(len(data_array.dims)) if i != time_dim_idx]
                        band_values = np.transpose(band_values, axes_order)

                # Add each timestamp of this band to the channel stack
                for t in range(num_timestamps):
                    stacked_channels.append(band_values[t])

            # Stack all channels: (num_bands * num_timestamps, height, width)
            input_channels = np.stack(stacked_channels, axis=0)
            total_channels = len(data_arrays) * num_timestamps
            logging.info(f"Stacked channels shape: {{input_channels.shape}}")
            logging.info(f"Total channels: {{total_channels}} ({{len(data_arrays)}} bands × {{num_timestamps}} timestamps)")

            # Add batch dimension: (1, num_channels, height, width)
            input_data = np.expand_dims(input_channels, axis=0).astype(np.float32)
            logging.info(f"Final input shape for CNN: {{input_data.shape}}")

            # Run inference
            output = model.run(None, {{"float_input": input_data}})[0]
            logging.info(f"Model output shape: {{output.shape}}")

            # Process output back to xarray format
            # Assuming output is (1, height, width) or (1, 1, height, width)
            if output.ndim == 4 and output.shape[1] == 1:
                # Remove channel dimension if it's 1
                output_2d = output[0, 0]
            elif output.ndim == 3:
                # Remove batch dimension
                output_2d = output[0]
            else:
                # Handle other cases
                output_2d = np.squeeze(output)
                if output_2d.ndim != 2:
                    raise ValueError(f"Unexpected output shape after processing: {{output_2d.shape}}")

            # Determine output timestamp (use the latest timestamp)
            output_timestamp = time_coords[-1]

            # Get spatial coordinates from reference array
            spatial_coords = {{dim: reference_array.coords[dim] for dim in spatial_dims}}

            # Create output DataArray
            result = xr.DataArray(
                data=np.expand_dims(output_2d.astype(np.float32), axis=0),
                dims=['time'] + spatial_dims,
                coords={{
                    'time': [output_timestamp.values],
                    spatial_dims[0]: spatial_coords[spatial_dims[0]].values,
                    spatial_dims[1]: spatial_coords[spatial_dims[1]].values
                }}
            )

            logging.info(f"Final result shape: {{result.shape}}")
            return result
        ''').strip()

    @require_api_key
    def _upload_script_to_bucket(self, script_content: str, script_name: str, model_training_job_name: str, uid: str):
        """Upload the generated script to Google Cloud Storage"""

        client = storage.Client()
        bucket = client.get_bucket('terrakio-mass-requests')
        blob = bucket.blob(f'{uid}/{model_training_job_name}/inference_scripts/{script_name}')
        blob.upload_from_string(script_content, content_type='text/plain')
        logging.info(f"Script uploaded successfully to {uid}/{model_training_job_name}/inference_scripts/{script_name}")