terrakio-core 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of terrakio-core might be problematic. Click here for more details.
- terrakio_core/__init__.py +3 -1
- terrakio_core/accessors.py +477 -0
- terrakio_core/async_client.py +24 -39
- terrakio_core/client.py +83 -84
- terrakio_core/convenience_functions/convenience_functions.py +316 -324
- terrakio_core/endpoints/auth.py +8 -1
- terrakio_core/endpoints/mass_stats.py +19 -15
- terrakio_core/endpoints/model_management.py +728 -597
- terrakio_core/sync_client.py +341 -33
- {terrakio_core-0.4.2.dist-info → terrakio_core-0.4.4.dist-info}/METADATA +2 -1
- terrakio_core-0.4.4.dist-info/RECORD +22 -0
- terrakio_core-0.4.2.dist-info/RECORD +0 -21
- {terrakio_core-0.4.2.dist-info → terrakio_core-0.4.4.dist-info}/WHEEL +0 -0
- {terrakio_core-0.4.2.dist-info → terrakio_core-0.4.4.dist-info}/top_level.txt +0 -0
|
@@ -1,43 +1,43 @@
|
|
|
1
|
-
|
|
1
|
+
# Standard library imports
|
|
2
|
+
import ast
|
|
2
3
|
import json
|
|
3
|
-
import time
|
|
4
4
|
import textwrap
|
|
5
|
-
import
|
|
6
|
-
from typing import Dict, Any, Union, Tuple, Optional
|
|
5
|
+
import time
|
|
7
6
|
from io import BytesIO
|
|
8
|
-
import
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
7
|
+
from typing import Optional, Tuple
|
|
8
|
+
import onnxruntime as ort
|
|
9
|
+
|
|
10
|
+
# Internal imports
|
|
11
|
+
from ..helper.decorators import require_api_key
|
|
12
|
+
|
|
13
|
+
# Optional dependency flags
|
|
12
14
|
TORCH_AVAILABLE = False
|
|
13
15
|
SKL2ONNX_AVAILABLE = False
|
|
14
16
|
|
|
17
|
+
# PyTorch imports
|
|
15
18
|
try:
|
|
16
19
|
import torch
|
|
17
20
|
TORCH_AVAILABLE = True
|
|
18
21
|
except ImportError:
|
|
19
22
|
torch = None
|
|
20
23
|
|
|
24
|
+
# Scikit-learn and ONNX conversion imports
|
|
21
25
|
try:
|
|
26
|
+
from sklearn.base import BaseEstimator
|
|
22
27
|
from skl2onnx import convert_sklearn
|
|
23
28
|
from skl2onnx.common.data_types import FloatTensorType
|
|
24
|
-
from sklearn.base import BaseEstimator
|
|
25
29
|
SKL2ONNX_AVAILABLE = True
|
|
26
30
|
except ImportError:
|
|
31
|
+
BaseEstimator = None
|
|
27
32
|
convert_sklearn = None
|
|
28
33
|
FloatTensorType = None
|
|
29
|
-
BaseEstimator = None
|
|
30
|
-
|
|
31
|
-
from io import BytesIO
|
|
32
|
-
from typing import Tuple
|
|
33
|
-
|
|
34
34
|
|
|
35
35
|
class ModelManagement:
|
|
36
36
|
def __init__(self, client):
|
|
37
37
|
self._client = client
|
|
38
38
|
|
|
39
39
|
@require_api_key
|
|
40
|
-
def generate_ai_dataset(
|
|
40
|
+
async def generate_ai_dataset(
|
|
41
41
|
self,
|
|
42
42
|
name: str,
|
|
43
43
|
aoi_geojson: str,
|
|
@@ -51,7 +51,8 @@ class ModelManagement:
|
|
|
51
51
|
filter_y: str = "skip",
|
|
52
52
|
crs: str = "epsg:4326",
|
|
53
53
|
res: float = 0.001,
|
|
54
|
-
region: str =
|
|
54
|
+
region: str = None,
|
|
55
|
+
bucket: str = None,
|
|
55
56
|
start_year: int = None,
|
|
56
57
|
end_year: int = None,
|
|
57
58
|
) -> dict:
|
|
@@ -71,7 +72,8 @@ class ModelManagement:
|
|
|
71
72
|
tile_size (int): Size of tiles in degrees
|
|
72
73
|
crs (str, optional): Coordinate reference system. Defaults to "epsg:4326"
|
|
73
74
|
res (float, optional): Resolution in degrees. Defaults to 0.001
|
|
74
|
-
region (str, optional): Region code. Defaults to
|
|
75
|
+
region (str, optional): Region code. Defaults to None
|
|
76
|
+
bucket (str, optional): Bucket name. Defaults to None
|
|
75
77
|
start_year (int, optional): Start year for data generation. Required if end_year provided
|
|
76
78
|
end_year (int, optional): End year for data generation. Required if start_year provided
|
|
77
79
|
|
|
@@ -109,7 +111,7 @@ class ModelManagement:
|
|
|
109
111
|
with open(aoi_geojson, 'r') as f:
|
|
110
112
|
aoi_data = json.load(f)
|
|
111
113
|
|
|
112
|
-
task_response = self._client.mass_stats.random_sample(
|
|
114
|
+
task_response = await self._client.mass_stats.random_sample(
|
|
113
115
|
name=name,
|
|
114
116
|
config=config,
|
|
115
117
|
aoi=aoi_data,
|
|
@@ -121,19 +123,17 @@ class ModelManagement:
|
|
|
121
123
|
region=region,
|
|
122
124
|
output="netcdf",
|
|
123
125
|
server=self._client.url,
|
|
124
|
-
bucket=
|
|
126
|
+
bucket=bucket,
|
|
125
127
|
overwrite=True
|
|
126
128
|
)
|
|
127
129
|
task_id = task_response["task_id"]
|
|
128
130
|
|
|
129
|
-
# Wait for job completion with progress bar
|
|
130
131
|
while True:
|
|
131
|
-
result = self._client.
|
|
132
|
+
result = await self._client.mass_stats.track_job(ids=[task_id])
|
|
132
133
|
status = result[task_id]['status']
|
|
133
134
|
completed = result[task_id].get('completed', 0)
|
|
134
135
|
total = result[task_id].get('total', 1)
|
|
135
136
|
|
|
136
|
-
# Create progress bar
|
|
137
137
|
progress = completed / total if total > 0 else 0
|
|
138
138
|
bar_length = 50
|
|
139
139
|
filled_length = int(bar_length * progress)
|
|
@@ -149,454 +149,456 @@ class ModelManagement:
|
|
|
149
149
|
self._client.logger.info("Job encountered an error")
|
|
150
150
|
raise Exception(f"Job {task_id} encountered an error")
|
|
151
151
|
|
|
152
|
-
# Wait 5 seconds before checking again
|
|
153
152
|
time.sleep(5)
|
|
154
153
|
|
|
155
|
-
|
|
156
|
-
task_id = self._client.mass_stats.start_mass_stats_job(task_id)
|
|
154
|
+
task_id = await self._client.mass_stats.start_job(task_id)
|
|
157
155
|
return task_id
|
|
158
156
|
|
|
159
157
|
@require_api_key
|
|
160
|
-
async def
|
|
158
|
+
async def _get_url_for_upload_model_and_script(self, expression: str, model_name: str, script_name: str) -> str:
|
|
161
159
|
"""
|
|
162
|
-
|
|
163
|
-
Converts PyTorch and scikit-learn models to ONNX format before uploading.
|
|
164
|
-
|
|
160
|
+
Get the url for the upload of the model
|
|
165
161
|
Args:
|
|
166
|
-
|
|
167
|
-
model_name:
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
Raises:
|
|
172
|
-
APIError: If the API request fails
|
|
173
|
-
ValueError: If model type is not supported or input_shape is missing for PyTorch models
|
|
174
|
-
ImportError: If required libraries (torch or skl2onnx) are not installed
|
|
162
|
+
expression: The expression to use for the upload(for deciding which bucket to upload to)
|
|
163
|
+
model_name: The name of the model to upload
|
|
164
|
+
script_name: The name of the script to upload
|
|
165
|
+
Returns:
|
|
166
|
+
The url for the upload of the model
|
|
175
167
|
"""
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
onnx_bytes = self._convert_model_to_onnx(model, model_name, input_shape)
|
|
183
|
-
|
|
184
|
-
# Upload ONNX model to bucket
|
|
185
|
-
blob = bucket.blob(f'{uid}/{model_name}/models/{model_name}.onnx')
|
|
186
|
-
|
|
187
|
-
blob.upload_from_string(onnx_bytes, content_type='application/octet-stream')
|
|
188
|
-
self._client.logger.info(f"Model uploaded successfully to {uid}/{model_name}/models/{model_name}.onnx")
|
|
168
|
+
payload = {
|
|
169
|
+
"model_name": model_name,
|
|
170
|
+
"expression": expression,
|
|
171
|
+
"script_name": script_name
|
|
172
|
+
}
|
|
173
|
+
return await self._client._terrakio_request("POST", "models/upload", json=payload)
|
|
189
174
|
|
|
190
|
-
def
|
|
175
|
+
async def _upload_model_to_url(self, upload_model_url: str, model: bytes):
|
|
191
176
|
"""
|
|
192
|
-
|
|
193
|
-
|
|
177
|
+
Upload a model to a given URL.
|
|
194
178
|
Args:
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
179
|
+
model_url: The url to upload the model to
|
|
180
|
+
model: The model to upload
|
|
181
|
+
|
|
199
182
|
Returns:
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
Raises:
|
|
203
|
-
ValueError: If model type is not supported
|
|
204
|
-
ImportError: If required libraries are not installed
|
|
183
|
+
The response from the server
|
|
205
184
|
"""
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
return self._convert_sklearn_to_onnx(model, model_name, input_shape)
|
|
231
|
-
else:
|
|
232
|
-
# Provide helpful error message
|
|
233
|
-
model_type = type(model).__name__
|
|
234
|
-
model_module = type(model).__module__
|
|
235
|
-
available_types = []
|
|
236
|
-
missing_deps = []
|
|
237
|
-
|
|
238
|
-
if TORCH_AVAILABLE:
|
|
239
|
-
available_types.append("PyTorch (torch.nn.Module)")
|
|
240
|
-
else:
|
|
241
|
-
missing_deps.append("torch")
|
|
242
|
-
|
|
243
|
-
if SKL2ONNX_AVAILABLE:
|
|
244
|
-
available_types.append("scikit-learn (BaseEstimator)")
|
|
245
|
-
else:
|
|
246
|
-
missing_deps.append("skl2onnx")
|
|
247
|
-
|
|
248
|
-
if missing_deps:
|
|
249
|
-
raise ImportError(
|
|
250
|
-
f"Model type {model_type} from {model_module} detected, but required dependencies missing: {', '.join(missing_deps)}. "
|
|
251
|
-
f"Install with: pip install {' '.join(missing_deps)}"
|
|
252
|
-
)
|
|
253
|
-
else:
|
|
254
|
-
raise ValueError(
|
|
255
|
-
f"Unsupported model type: {model_type} from {model_module}. "
|
|
256
|
-
f"Supported types: {', '.join(available_types)}"
|
|
257
|
-
)
|
|
258
|
-
|
|
259
|
-
def _convert_pytorch_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...]) -> bytes:
|
|
260
|
-
"""Convert PyTorch model to ONNX format with dynamic input dimensions."""
|
|
261
|
-
if input_shape is None:
|
|
262
|
-
raise ValueError("input_shape is required for PyTorch models")
|
|
263
|
-
|
|
264
|
-
self._client.logger.info(f"Converting PyTorch model {model_name} to ONNX...")
|
|
265
|
-
|
|
266
|
-
try:
|
|
267
|
-
# Set model to evaluation mode
|
|
268
|
-
model.eval()
|
|
269
|
-
|
|
270
|
-
# Create dummy input
|
|
271
|
-
dummy_input = torch.randn(input_shape)
|
|
272
|
-
|
|
273
|
-
# Use BytesIO to avoid creating temporary files
|
|
274
|
-
onnx_buffer = BytesIO()
|
|
275
|
-
|
|
276
|
-
# Determine dynamic axes based on input shape
|
|
277
|
-
# Common patterns for different input types:
|
|
278
|
-
if len(input_shape) == 4: # Convolutional input: (batch, channels, height, width)
|
|
279
|
-
dynamic_axes = {
|
|
280
|
-
'float_input': {
|
|
281
|
-
0: 'batch_size',
|
|
282
|
-
2: 'height', # Make height dynamic for variable input sizes
|
|
283
|
-
3: 'width' # Make width dynamic for variable input sizes
|
|
284
|
-
},
|
|
285
|
-
'output': {0: 'batch_size'}
|
|
286
|
-
}
|
|
287
|
-
elif len(input_shape) == 3: # Could be (batch, sequence, features) or (batch, height, width)
|
|
288
|
-
dynamic_axes = {
|
|
289
|
-
'float_input': {
|
|
290
|
-
0: 'batch_size',
|
|
291
|
-
1: 'dim1', # Generic dynamic dimension
|
|
292
|
-
2: 'dim2' # Generic dynamic dimension
|
|
293
|
-
},
|
|
294
|
-
'output': {0: 'batch_size'}
|
|
295
|
-
}
|
|
296
|
-
elif len(input_shape) == 2: # Likely (batch, features)
|
|
297
|
-
dynamic_axes = {
|
|
298
|
-
'float_input': {
|
|
299
|
-
0: 'batch_size'
|
|
300
|
-
# Don't make features dynamic as it usually affects model architecture
|
|
301
|
-
},
|
|
302
|
-
'output': {0: 'batch_size'}
|
|
303
|
-
}
|
|
304
|
-
else:
|
|
305
|
-
# For other shapes, just make batch size dynamic
|
|
306
|
-
dynamic_axes = {
|
|
307
|
-
'float_input': {0: 'batch_size'},
|
|
308
|
-
'output': {0: 'batch_size'}
|
|
309
|
-
}
|
|
310
|
-
|
|
311
|
-
torch.onnx.export(
|
|
312
|
-
model,
|
|
313
|
-
dummy_input,
|
|
314
|
-
onnx_buffer,
|
|
315
|
-
export_params=True,
|
|
316
|
-
opset_version=11,
|
|
317
|
-
do_constant_folding=True,
|
|
318
|
-
input_names=['float_input'],
|
|
319
|
-
output_names=['output'],
|
|
320
|
-
dynamic_axes=dynamic_axes
|
|
321
|
-
)
|
|
322
|
-
|
|
323
|
-
self._client.logger.info(f"Successfully converted {model_name} with dynamic axes: {dynamic_axes}")
|
|
324
|
-
return onnx_buffer.getvalue()
|
|
325
|
-
|
|
326
|
-
except Exception as e:
|
|
327
|
-
raise ValueError(f"Failed to convert PyTorch model {model_name} to ONNX: {str(e)}")
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
def _convert_sklearn_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...] = None) -> bytes:
|
|
331
|
-
"""Convert scikit-learn model to ONNX format."""
|
|
332
|
-
self._client.logger.info(f"Converting scikit-learn model {model_name} to ONNX...")
|
|
333
|
-
|
|
334
|
-
# Try to infer input shape if not provided
|
|
335
|
-
if input_shape is None:
|
|
336
|
-
if hasattr(model, 'n_features_in_'):
|
|
337
|
-
input_shape = (1, model.n_features_in_)
|
|
338
|
-
else:
|
|
339
|
-
raise ValueError(
|
|
340
|
-
"input_shape is required for scikit-learn models when n_features_in_ is not available. "
|
|
341
|
-
"This usually happens with older sklearn versions or models not fitted yet."
|
|
342
|
-
)
|
|
343
|
-
|
|
344
|
-
try:
|
|
345
|
-
# Convert scikit-learn model to ONNX
|
|
346
|
-
initial_type = [('float_input', FloatTensorType(input_shape))]
|
|
347
|
-
onnx_model = convert_sklearn(model, initial_types=initial_type)
|
|
348
|
-
return onnx_model.SerializeToString()
|
|
349
|
-
|
|
350
|
-
except Exception as e:
|
|
351
|
-
raise ValueError(f"Failed to convert scikit-learn model {model_name} to ONNX: {str(e)}")
|
|
352
|
-
|
|
185
|
+
headers = {
|
|
186
|
+
"Content-Type": "application/octet-stream",
|
|
187
|
+
"Content-Length": str(len(model))
|
|
188
|
+
}
|
|
189
|
+
response = await self._client._regular_request("PUT", endpoint = upload_model_url, data=model, headers=headers)
|
|
190
|
+
return response
|
|
191
|
+
|
|
192
|
+
@require_api_key
|
|
193
|
+
async def _upload_script_to_url(self, upload_script_url: str, script_content: str):
|
|
194
|
+
"""
|
|
195
|
+
Upload the generated script to the url
|
|
196
|
+
Args:
|
|
197
|
+
url: Url for the upload of the script
|
|
198
|
+
script_content: Content of the script
|
|
199
|
+
returns:
|
|
200
|
+
None
|
|
201
|
+
"""
|
|
202
|
+
script_bytes = script_content.encode('utf-8')
|
|
203
|
+
headers = {
|
|
204
|
+
"Content-Type": "text/x-python",
|
|
205
|
+
"Content-Length": str(len(script_bytes))
|
|
206
|
+
}
|
|
207
|
+
response = await self._client._regular_request("PUT", endpoint=upload_script_url, data=script_bytes, headers=headers)
|
|
208
|
+
return response
|
|
353
209
|
|
|
354
210
|
@require_api_key
|
|
355
|
-
async def
|
|
211
|
+
async def _upload_model_and_script(self, model, model_name: str, script_name: str, input_expression: str, input_shape: Tuple[int, ...] = None, processing_script_path: Optional[str] = None, model_type: Optional[str] = None):
|
|
356
212
|
"""
|
|
357
|
-
Upload a
|
|
358
|
-
|
|
213
|
+
Upload a model and script to the bucket
|
|
359
214
|
Args:
|
|
360
215
|
model: The model object (PyTorch model or scikit-learn model)
|
|
361
216
|
model_name: Name for the model (without extension)
|
|
362
|
-
|
|
363
|
-
product: Product name for the inference
|
|
217
|
+
script_name: Name for the script (without extension)
|
|
364
218
|
input_expression: Input expression for the dataset
|
|
365
|
-
dates_iso8601: List of dates in ISO8601 format
|
|
366
219
|
input_shape: Shape of input data for ONNX conversion (required for PyTorch models)
|
|
367
220
|
processing_script_path: Path to the processing script, if not provided, no processing will be done
|
|
368
|
-
|
|
221
|
+
model_type: The type of the model we want to upload
|
|
369
222
|
Raises:
|
|
370
223
|
APIError: If the API request fails
|
|
371
224
|
ValueError: If model type is not supported or input_shape is missing for PyTorch models
|
|
372
|
-
|
|
225
|
+
|
|
226
|
+
Returns:
|
|
227
|
+
bucket_name: Name of the bucket where the model is stored
|
|
373
228
|
"""
|
|
374
|
-
await self.
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
229
|
+
response = await self._get_url_for_upload_model_and_script(expression = input_expression, model_name = model_name, script_name = script_name)
|
|
230
|
+
model_url, script_url, bucket_name = response.get("model_upload_url"), response.get("script_upload_url"), response.get("bucket_name")
|
|
231
|
+
if not model_url or not script_url:
|
|
232
|
+
raise ValueError("No url returned from the server for the upload process")
|
|
233
|
+
try:
|
|
234
|
+
model_in_onnx_bytes, model_type = self._convert_model_to_onnx(model = model, input_shape = input_shape, model_type = model_type)
|
|
235
|
+
if model_type == "neural_network":
|
|
236
|
+
script_content = await self._generate_cnn_script(bucket_name = bucket_name, virtual_dataset_name = model_name, virtual_product_name = script_name, processing_script_path = processing_script_path)
|
|
237
|
+
elif model_type == "random_forest":
|
|
238
|
+
script_content = await self._generate_random_forest_script(bucket_name = bucket_name, virtual_dataset_name = model_name, virtual_product_name = script_name, processing_script_path = processing_script_path)
|
|
239
|
+
else:
|
|
240
|
+
raise ValueError(f"Unsupported model type: {model_type}. Supported types: neural_network, random_forest")
|
|
241
|
+
script_upload_response = await self._upload_script_to_url( upload_script_url = script_url, script_content = script_content)
|
|
242
|
+
if script_upload_response.status not in [200, 201, 204]:
|
|
243
|
+
self._client.logger.error(f"Script upload error: {script_upload_response.text()}")
|
|
244
|
+
raise Exception(f"Failed to upload script: {script_upload_response.text()}")
|
|
245
|
+
model_upload_response = await self._upload_model_to_url(upload_model_url = model_url, model = model_in_onnx_bytes)
|
|
246
|
+
if model_upload_response.status not in [200, 201, 204]:
|
|
247
|
+
self._client.logger.error(f"Model upload error: {model_upload_response.text()}")
|
|
248
|
+
raise Exception(f"Failed to upload model: {model_upload_response.text()}")
|
|
249
|
+
except Exception as e:
|
|
250
|
+
raise Exception(f"Error uploading model: {e}")
|
|
251
|
+
self._client.logger.info(f"Model and Script uploaded successfully to {model_url}")
|
|
252
|
+
return bucket_name
|
|
378
253
|
|
|
379
254
|
@require_api_key
|
|
380
|
-
async def upload_and_deploy_model(self, model,
|
|
255
|
+
async def upload_and_deploy_model(self, model, virtual_dataset_name: str, virtual_product_name: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None, processing_script_path: Optional[str] = None, model_type: Optional[str] = None):
|
|
381
256
|
"""
|
|
382
257
|
Upload a model to the bucket and deploy it.
|
|
383
|
-
|
|
384
258
|
Args:
|
|
385
259
|
model: The model object (PyTorch model or scikit-learn model)
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
product: Product name for the inference
|
|
260
|
+
virtual_dataset_name: Name for the virtual dataset (without extension)
|
|
261
|
+
virtual_product_name: Product name for the inference
|
|
389
262
|
input_expression: Input expression for the dataset
|
|
390
263
|
dates_iso8601: List of dates in ISO8601 format
|
|
391
264
|
input_shape: Shape of input data for ONNX conversion (required for PyTorch models)
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
await self.deploy_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601)
|
|
265
|
+
processing_script_path: Path to the processing script, if not provided, no processing will be done
|
|
266
|
+
model_type: The type of the model we want to upload
|
|
395
267
|
|
|
396
|
-
@require_api_key
|
|
397
|
-
def train_model(
|
|
398
|
-
self,
|
|
399
|
-
model_name: str,
|
|
400
|
-
training_dataset: str,
|
|
401
|
-
task_type: str,
|
|
402
|
-
model_category: str,
|
|
403
|
-
architecture: str,
|
|
404
|
-
region: str,
|
|
405
|
-
hyperparameters: dict = None
|
|
406
|
-
) -> dict:
|
|
407
|
-
"""
|
|
408
|
-
Train a model using the external model training API.
|
|
409
|
-
|
|
410
|
-
Args:
|
|
411
|
-
model_name (str): The name of the model to train.
|
|
412
|
-
training_dataset (str): The training dataset identifier.
|
|
413
|
-
task_type (str): The type of ML task (e.g., regression, classification).
|
|
414
|
-
model_category (str): The category of model (e.g., random_forest).
|
|
415
|
-
architecture (str): The model architecture.
|
|
416
|
-
region (str): The region identifier.
|
|
417
|
-
hyperparameters (dict, optional): Additional hyperparameters for training.
|
|
418
|
-
|
|
419
|
-
Returns:
|
|
420
|
-
dict: The response from the model training API.
|
|
421
|
-
|
|
422
268
|
Raises:
|
|
423
269
|
APIError: If the API request fails
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
"model_name": model_name,
|
|
427
|
-
"training_dataset": training_dataset,
|
|
428
|
-
"task_type": task_type,
|
|
429
|
-
"model_category": model_category,
|
|
430
|
-
"architecture": architecture,
|
|
431
|
-
"region": region,
|
|
432
|
-
"hyperparameters": hyperparameters
|
|
433
|
-
}
|
|
434
|
-
return self._client._terrakio_request("POST", "/train_model", json=payload)
|
|
270
|
+
ValueError: If model type is not supported or input_shape is missing for PyTorch models
|
|
271
|
+
ImportError: If required libraries (torch or skl2onnx) are not installed
|
|
435
272
|
|
|
436
|
-
@require_api_key
|
|
437
|
-
async def deploy_model(
|
|
438
|
-
self,
|
|
439
|
-
dataset: str,
|
|
440
|
-
product: str,
|
|
441
|
-
model_name: str,
|
|
442
|
-
input_expression: str,
|
|
443
|
-
model_training_job_name: str,
|
|
444
|
-
dates_iso8601: list
|
|
445
|
-
) -> Dict[str, Any]:
|
|
446
|
-
"""
|
|
447
|
-
Deploy a model by generating inference script and creating dataset.
|
|
448
|
-
|
|
449
|
-
Args:
|
|
450
|
-
dataset: Name of the dataset to create
|
|
451
|
-
product: Product name for the inference
|
|
452
|
-
model_name: Name of the trained model
|
|
453
|
-
input_expression: Input expression for the dataset
|
|
454
|
-
model_training_job_name: Name of the training job
|
|
455
|
-
dates_iso8601: List of dates in ISO8601 format
|
|
456
|
-
|
|
457
273
|
Returns:
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
Raises:
|
|
461
|
-
APIError: If the API request fails
|
|
274
|
+
None
|
|
462
275
|
"""
|
|
463
|
-
|
|
276
|
+
bucket_name = await self._upload_model_and_script(model=model, model_name=virtual_dataset_name, script_name= virtual_product_name, input_shape=input_shape, input_expression=input_expression, processing_script_path=processing_script_path, model_type= model_type)
|
|
464
277
|
user_info = await self._client.auth.get_user_info()
|
|
465
278
|
uid = user_info["uid"]
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
script_content = self._generate_script(model_name, product, model_training_job_name, uid)
|
|
469
|
-
script_name = f"{product}.py"
|
|
470
|
-
self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
|
|
471
|
-
|
|
472
|
-
# Create dataset
|
|
473
|
-
return await self._client.datasets.create_dataset(
|
|
474
|
-
name=dataset,
|
|
279
|
+
await self._client.datasets.create_dataset(
|
|
280
|
+
name=virtual_dataset_name,
|
|
475
281
|
collection="terrakio-datasets",
|
|
476
|
-
products=[
|
|
477
|
-
path=f"gs://
|
|
282
|
+
products=[virtual_product_name],
|
|
283
|
+
path=f"gs://{bucket_name}/{uid}/virtual_datasets/{virtual_dataset_name}/inference_scripts",
|
|
478
284
|
input=input_expression,
|
|
479
285
|
dates_iso8601=dates_iso8601,
|
|
480
286
|
padding=0
|
|
481
287
|
)
|
|
482
288
|
|
|
483
|
-
|
|
289
|
+
@require_api_key
|
|
290
|
+
async def _generate_random_forest_script(self, bucket_name: str, virtual_dataset_name: str, virtual_product_name: str, processing_script_path: Optional[str] = None) -> str:
|
|
484
291
|
"""
|
|
485
|
-
|
|
292
|
+
Generate Python inference script for the Random Forest model.
|
|
486
293
|
|
|
487
294
|
Args:
|
|
488
|
-
|
|
295
|
+
bucket_name: Name of the bucket where the model is stored
|
|
296
|
+
virtual_dataset_name: Name of the virtual dataset and the model
|
|
297
|
+
virtual_product_name: Name of the virtual product
|
|
298
|
+
processing_script_path: Path to the processing script, if not provided, no processing will be done
|
|
489
299
|
|
|
490
300
|
Returns:
|
|
491
|
-
|
|
301
|
+
str: Generated Python script content
|
|
492
302
|
"""
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
303
|
+
user_info = await self._client.auth.get_user_info()
|
|
304
|
+
uid = user_info["uid"]
|
|
305
|
+
preprocessing_code, postprocessing_code = None, None
|
|
306
|
+
|
|
307
|
+
if processing_script_path:
|
|
308
|
+
try:
|
|
309
|
+
preprocessing_code, postprocessing_code = self._parse_processing_script(processing_script_path)
|
|
310
|
+
if preprocessing_code:
|
|
311
|
+
self._client.logger.info(f"Using custom preprocessing from: {processing_script_path}")
|
|
312
|
+
if postprocessing_code:
|
|
313
|
+
self._client.logger.info(f"Using custom postprocessing from: {processing_script_path}")
|
|
314
|
+
if not preprocessing_code and not postprocessing_code:
|
|
315
|
+
self._client.logger.warning(f"No preprocessing or postprocessing functions found in {processing_script_path}")
|
|
316
|
+
self._client.logger.info("Deployment will continue without custom processing")
|
|
317
|
+
except Exception as e:
|
|
318
|
+
raise ValueError(f"Failed to load processing script: {str(e)}")
|
|
319
|
+
|
|
320
|
+
preprocessing_section = ""
|
|
321
|
+
if preprocessing_code and preprocessing_code.strip():
|
|
322
|
+
clean_preprocessing = textwrap.dedent(preprocessing_code)
|
|
323
|
+
preprocessing_section = textwrap.indent(clean_preprocessing, ' ')
|
|
511
324
|
|
|
512
|
-
|
|
513
|
-
postprocessing_code
|
|
325
|
+
postprocessing_section = ""
|
|
326
|
+
if postprocessing_code and postprocessing_code.strip():
|
|
327
|
+
clean_postprocessing = textwrap.dedent(postprocessing_code)
|
|
328
|
+
postprocessing_section = textwrap.indent(clean_postprocessing, ' ')
|
|
329
|
+
|
|
330
|
+
script_lines = [
|
|
331
|
+
"import logging",
|
|
332
|
+
"from io import BytesIO",
|
|
333
|
+
"import numpy as np",
|
|
334
|
+
"import pandas as pd",
|
|
335
|
+
"import xarray as xr",
|
|
336
|
+
"from google.cloud import storage",
|
|
337
|
+
"from onnxruntime import InferenceSession",
|
|
338
|
+
"from typing import Tuple",
|
|
339
|
+
"",
|
|
340
|
+
"logging.basicConfig(",
|
|
341
|
+
" level=logging.INFO",
|
|
342
|
+
")",
|
|
343
|
+
"",
|
|
344
|
+
]
|
|
514
345
|
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
346
|
+
if preprocessing_section:
|
|
347
|
+
script_lines.extend([
|
|
348
|
+
"def validate_preprocessing_output(data_arrays):",
|
|
349
|
+
" \"\"\"",
|
|
350
|
+
" Validate preprocessing output coordinates and data type.",
|
|
351
|
+
" ",
|
|
352
|
+
" Args:",
|
|
353
|
+
" data_arrays: List of xarray DataArrays from preprocessing",
|
|
354
|
+
" ",
|
|
355
|
+
" Returns:",
|
|
356
|
+
" str: Validation signature symbol",
|
|
357
|
+
" ",
|
|
358
|
+
" Raises:",
|
|
359
|
+
" ValueError: If validation fails",
|
|
360
|
+
" \"\"\"",
|
|
361
|
+
" import numpy as np",
|
|
362
|
+
" ",
|
|
363
|
+
" if not data_arrays:",
|
|
364
|
+
" raise ValueError(\"No data arrays provided from preprocessing\")",
|
|
365
|
+
" ",
|
|
366
|
+
" reference_shape = None",
|
|
367
|
+
" ",
|
|
368
|
+
" for i, data_array in enumerate(data_arrays):",
|
|
369
|
+
" # Check if it's an xarray DataArray",
|
|
370
|
+
" if not hasattr(data_array, 'dims') or not hasattr(data_array, 'coords'):",
|
|
371
|
+
" raise ValueError(f\"Channel {i+1} is not a valid xarray DataArray\")",
|
|
372
|
+
" ",
|
|
373
|
+
" # Check coordinates",
|
|
374
|
+
" if 'time' not in data_array.coords:",
|
|
375
|
+
" raise ValueError(f\"Channel {i+1} missing time coordinate\")",
|
|
376
|
+
" ",
|
|
377
|
+
" spatial_dims = [dim for dim in data_array.dims if dim != 'time']",
|
|
378
|
+
" if len(spatial_dims) != 2:",
|
|
379
|
+
" raise ValueError(f\"Channel {i+1} must have exactly 2 spatial dimensions, got {spatial_dims}\")",
|
|
380
|
+
" ",
|
|
381
|
+
" for dim in spatial_dims:",
|
|
382
|
+
" if dim not in data_array.coords:",
|
|
383
|
+
" raise ValueError(f\"Channel {i+1} missing coordinate: {dim}\")",
|
|
384
|
+
" ",
|
|
385
|
+
" # Check shape consistency",
|
|
386
|
+
" shape = data_array.shape",
|
|
387
|
+
" if reference_shape is None:",
|
|
388
|
+
" reference_shape = shape",
|
|
389
|
+
" else:",
|
|
390
|
+
" if shape != reference_shape:",
|
|
391
|
+
" raise ValueError(f\"Channel {i+1} shape {shape} doesn't match reference {reference_shape}\")",
|
|
392
|
+
" ",
|
|
393
|
+
" # Generate validation signature",
|
|
394
|
+
" signature_components = [",
|
|
395
|
+
" f\"CH{len(data_arrays)}\", # Channel count",
|
|
396
|
+
" f\"T{reference_shape[0]}\", # Time dimension",
|
|
397
|
+
" f\"S{reference_shape[1]}x{reference_shape[2]}\", # Spatial dimensions",
|
|
398
|
+
" f\"DT{data_arrays[0].values.dtype}\", # Data type",
|
|
399
|
+
" ]",
|
|
400
|
+
" ",
|
|
401
|
+
" signature = \"★PRE_\" + \"_\".join(signature_components) + \"★\"",
|
|
402
|
+
" ",
|
|
403
|
+
" return signature",
|
|
404
|
+
"",
|
|
405
|
+
])
|
|
524
406
|
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
407
|
+
if postprocessing_section:
|
|
408
|
+
script_lines.extend([
|
|
409
|
+
"def validate_postprocessing_output(result_array):",
|
|
410
|
+
" \"\"\"",
|
|
411
|
+
" Validate postprocessing output coordinates and data type.",
|
|
412
|
+
" ",
|
|
413
|
+
" Args:",
|
|
414
|
+
" result_array: xarray DataArray from postprocessing",
|
|
415
|
+
" ",
|
|
416
|
+
" Returns:",
|
|
417
|
+
" str: Validation signature symbol",
|
|
418
|
+
" ",
|
|
419
|
+
" Raises:",
|
|
420
|
+
" ValueError: If validation fails",
|
|
421
|
+
" \"\"\"",
|
|
422
|
+
" import numpy as np",
|
|
423
|
+
" ",
|
|
424
|
+
" # Check if it's an xarray DataArray",
|
|
425
|
+
" if not hasattr(result_array, 'dims') or not hasattr(result_array, 'coords'):",
|
|
426
|
+
" raise ValueError(\"Postprocessing output is not a valid xarray DataArray\")",
|
|
427
|
+
" ",
|
|
428
|
+
" # Check required coordinates",
|
|
429
|
+
" if 'time' not in result_array.coords:",
|
|
430
|
+
" raise ValueError(\"Missing time coordinate\")",
|
|
431
|
+
" ",
|
|
432
|
+
" spatial_dims = [dim for dim in result_array.dims if dim != 'time']",
|
|
433
|
+
" if len(spatial_dims) != 2:",
|
|
434
|
+
" raise ValueError(f\"Expected 2 spatial dimensions, got {len(spatial_dims)}: {spatial_dims}\")",
|
|
435
|
+
" ",
|
|
436
|
+
" for dim in spatial_dims:",
|
|
437
|
+
" if dim not in result_array.coords:",
|
|
438
|
+
" raise ValueError(f\"Missing spatial coordinate: {dim}\")",
|
|
439
|
+
" ",
|
|
440
|
+
" # Check shape",
|
|
441
|
+
" shape = result_array.shape",
|
|
442
|
+
" ",
|
|
443
|
+
" # Generate validation signature",
|
|
444
|
+
" signature_components = [",
|
|
445
|
+
" f\"T{shape[0]}\", # Time dimension",
|
|
446
|
+
" f\"S{shape[1]}x{shape[2]}\", # Spatial dimensions",
|
|
447
|
+
" f\"DT{result_array.values.dtype}\", # Data type",
|
|
448
|
+
" ]",
|
|
449
|
+
" ",
|
|
450
|
+
" signature = \"★POST_\" + \"_\".join(signature_components) + \"★\"",
|
|
451
|
+
" ",
|
|
452
|
+
" return signature",
|
|
453
|
+
"",
|
|
454
|
+
])
|
|
535
455
|
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
456
|
+
if preprocessing_section:
|
|
457
|
+
script_lines.extend([
|
|
458
|
+
"def preprocessing(array: Tuple[xr.DataArray, ...]) -> Tuple[xr.DataArray, ...]:",
|
|
459
|
+
preprocessing_section,
|
|
460
|
+
"",
|
|
461
|
+
])
|
|
541
462
|
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
463
|
+
if postprocessing_section:
|
|
464
|
+
script_lines.extend([
|
|
465
|
+
"def postprocessing(array: xr.DataArray) -> xr.DataArray:",
|
|
466
|
+
postprocessing_section,
|
|
467
|
+
"",
|
|
468
|
+
])
|
|
545
469
|
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
470
|
+
script_lines.extend([
|
|
471
|
+
"def get_model():",
|
|
472
|
+
f" logging.info(\"Loading Random Forest model for {virtual_dataset_name}...\")",
|
|
473
|
+
"",
|
|
474
|
+
" client = storage.Client()",
|
|
475
|
+
f" bucket = client.get_bucket('{bucket_name}')",
|
|
476
|
+
f" blob = bucket.blob('{uid}/virtual_datasets/{virtual_dataset_name}/{virtual_dataset_name}.onnx')",
|
|
477
|
+
"",
|
|
478
|
+
" model = BytesIO()",
|
|
479
|
+
" blob.download_to_file(model)",
|
|
480
|
+
" model.seek(0)",
|
|
481
|
+
"",
|
|
482
|
+
" session = InferenceSession(model.read(), providers=[\"CPUExecutionProvider\"])",
|
|
483
|
+
" return session",
|
|
484
|
+
"",
|
|
485
|
+
f"def {virtual_product_name}(*bands, model):",
|
|
486
|
+
" logging.info(\"Start preparing Random Forest data\")",
|
|
487
|
+
" data_arrays = list(bands)",
|
|
488
|
+
" ",
|
|
489
|
+
" if not data_arrays:",
|
|
490
|
+
" raise ValueError(\"No bands provided\")",
|
|
491
|
+
" ",
|
|
492
|
+
])
|
|
551
493
|
|
|
552
|
-
if
|
|
553
|
-
|
|
494
|
+
if preprocessing_section:
|
|
495
|
+
script_lines.extend([
|
|
496
|
+
" # Apply preprocessing",
|
|
497
|
+
" data_arrays = preprocessing(tuple(data_arrays))",
|
|
498
|
+
" data_arrays = list(data_arrays) # Convert back to list for processing",
|
|
499
|
+
" ",
|
|
500
|
+
" # Validate preprocessing output",
|
|
501
|
+
" preprocessing_signature = validate_preprocessing_output(data_arrays)",
|
|
502
|
+
" ",
|
|
503
|
+
])
|
|
554
504
|
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
505
|
+
script_lines.extend([
|
|
506
|
+
" reference_array = data_arrays[0]",
|
|
507
|
+
" original_shape = reference_array.shape",
|
|
508
|
+
" ",
|
|
509
|
+
" if 'time' in reference_array.dims:",
|
|
510
|
+
" time_coords = reference_array.coords['time']",
|
|
511
|
+
" if len(time_coords) == 1:",
|
|
512
|
+
" output_timestamp = time_coords[0]",
|
|
513
|
+
" else:",
|
|
514
|
+
" years = [pd.to_datetime(t).year for t in time_coords.values]",
|
|
515
|
+
" unique_years = set(years)",
|
|
516
|
+
" ",
|
|
517
|
+
" if len(unique_years) == 1:",
|
|
518
|
+
" year = list(unique_years)[0]",
|
|
519
|
+
" output_timestamp = pd.Timestamp(f\"{year}-01-01\")",
|
|
520
|
+
" else:",
|
|
521
|
+
" latest_year = max(unique_years)",
|
|
522
|
+
" output_timestamp = pd.Timestamp(f\"{latest_year}-01-01\")",
|
|
523
|
+
" else:",
|
|
524
|
+
" output_timestamp = pd.Timestamp(\"1970-01-01\")",
|
|
525
|
+
"",
|
|
526
|
+
" averaged_bands = []",
|
|
527
|
+
" for data_array in data_arrays:",
|
|
528
|
+
" if 'time' in data_array.dims:",
|
|
529
|
+
" averaged_band = np.mean(data_array.values, axis=0)",
|
|
530
|
+
" else:",
|
|
531
|
+
" averaged_band = data_array.values",
|
|
532
|
+
"",
|
|
533
|
+
" flattened_band = averaged_band.reshape(-1, 1)",
|
|
534
|
+
" averaged_bands.append(flattened_band)",
|
|
535
|
+
"",
|
|
536
|
+
" input_data = np.hstack(averaged_bands)",
|
|
537
|
+
"",
|
|
538
|
+
" output = model.run(None, {\"float_input\": input_data.astype(np.float32)})[0]",
|
|
539
|
+
"",
|
|
540
|
+
" if len(original_shape) >= 3:",
|
|
541
|
+
" spatial_shape = original_shape[1:]",
|
|
542
|
+
" else:",
|
|
543
|
+
" spatial_shape = original_shape",
|
|
544
|
+
"",
|
|
545
|
+
" output_reshaped = output.reshape(spatial_shape)",
|
|
546
|
+
"",
|
|
547
|
+
" output_with_time = np.expand_dims(output_reshaped, axis=0)",
|
|
548
|
+
"",
|
|
549
|
+
" if 'time' in reference_array.dims:",
|
|
550
|
+
" spatial_dims = [dim for dim in reference_array.dims if dim != 'time']",
|
|
551
|
+
" spatial_coords = {dim: reference_array.coords[dim] for dim in spatial_dims if dim in reference_array.coords}",
|
|
552
|
+
" else:",
|
|
553
|
+
" spatial_dims = list(reference_array.dims)",
|
|
554
|
+
" spatial_coords = dict(reference_array.coords)",
|
|
555
|
+
"",
|
|
556
|
+
" result = xr.DataArray(",
|
|
557
|
+
" data=output_with_time.astype(np.float32),",
|
|
558
|
+
" dims=['time'] + list(spatial_dims),",
|
|
559
|
+
" coords={",
|
|
560
|
+
" 'time': [output_timestamp.values],",
|
|
561
|
+
" 'y': spatial_coords['y'].values,",
|
|
562
|
+
" 'x': spatial_coords['x'].values",
|
|
563
|
+
" },",
|
|
564
|
+
" attrs={",
|
|
565
|
+
" 'description': 'Random Forest model prediction',",
|
|
566
|
+
" }",
|
|
567
|
+
" )",
|
|
568
|
+
])
|
|
558
569
|
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
570
|
+
if postprocessing_section:
|
|
571
|
+
script_lines.extend([
|
|
572
|
+
" # Apply postprocessing",
|
|
573
|
+
" result = postprocessing(result)",
|
|
574
|
+
" ",
|
|
575
|
+
" # Validate postprocessing output",
|
|
576
|
+
" postprocessing_signature = validate_postprocessing_output(result)",
|
|
577
|
+
" ",
|
|
578
|
+
])
|
|
562
579
|
|
|
563
|
-
return
|
|
580
|
+
script_lines.append(" return result")
|
|
564
581
|
|
|
582
|
+
return "\n".join(script_lines)
|
|
583
|
+
|
|
565
584
|
@require_api_key
|
|
566
|
-
async def
|
|
567
|
-
self,
|
|
568
|
-
dataset: str,
|
|
569
|
-
product: str,
|
|
570
|
-
model_name: str,
|
|
571
|
-
input_expression: str,
|
|
572
|
-
model_training_job_name: str,
|
|
573
|
-
dates_iso8601: list,
|
|
574
|
-
processing_script_path: Optional[str] = None
|
|
575
|
-
) -> Dict[str, Any]:
|
|
585
|
+
async def _generate_cnn_script(self, bucket_name: str, virtual_dataset_name: str, virtual_product_name: str, processing_script_path: Optional[str] = None) -> str:
|
|
576
586
|
"""
|
|
577
|
-
|
|
587
|
+
Generate Python inference script for CNN model with time-stacked bands.
|
|
578
588
|
|
|
579
589
|
Args:
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
input_expression: Input expression for the dataset
|
|
584
|
-
model_training_job_name: Name of the training job
|
|
585
|
-
dates_iso8601: List of dates in ISO8601 format
|
|
590
|
+
bucket_name: Name of the bucket where the model is stored
|
|
591
|
+
virtual_dataset_name: Name of the virtual dataset and the model
|
|
592
|
+
virtual_product_name: Name of the virtual product
|
|
586
593
|
processing_script_path: Path to the processing script, if not provided, no processing will be done
|
|
587
594
|
Returns:
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
Raises:
|
|
591
|
-
APIError: If the API request fails
|
|
595
|
+
str: Generated Python script content
|
|
592
596
|
"""
|
|
593
|
-
# Get user info to get UID
|
|
594
597
|
user_info = await self._client.auth.get_user_info()
|
|
595
598
|
uid = user_info["uid"]
|
|
596
|
-
|
|
597
599
|
preprocessing_code, postprocessing_code = None, None
|
|
600
|
+
|
|
598
601
|
if processing_script_path:
|
|
599
|
-
# if there is a function that is being passed in
|
|
600
602
|
try:
|
|
601
603
|
preprocessing_code, postprocessing_code = self._parse_processing_script(processing_script_path)
|
|
602
604
|
if preprocessing_code:
|
|
@@ -608,176 +610,17 @@ class ModelManagement:
|
|
|
608
610
|
self._client.logger.info("Deployment will continue without custom processing")
|
|
609
611
|
except Exception as e:
|
|
610
612
|
raise ValueError(f"Failed to load processing script: {str(e)}")
|
|
611
|
-
# so we already have the preprocessing code and the post processing code, I need to pass them to the generate cnn script function
|
|
612
|
-
# Generate and upload script
|
|
613
|
-
# Build preprocessing section with CONSISTENT 8-space indentation
|
|
614
|
-
preprocessing_section = ""
|
|
615
|
-
if preprocessing_code and preprocessing_code.strip():
|
|
616
|
-
# First dedent the preprocessing code to remove any existing indentation
|
|
617
|
-
clean_preprocessing = preprocessing_code
|
|
618
|
-
# Then add consistent 8-space indentation to match the template
|
|
619
|
-
preprocessing_section = f"""{textwrap.indent(clean_preprocessing, '')}""" # 8 spaces
|
|
620
|
-
print(preprocessing_section)
|
|
621
|
-
script_content = self.generate_cnn_script(model_name, product, model_training_job_name, uid, preprocessing_code, postprocessing_code)
|
|
622
|
-
script_name = f"{product}.py"
|
|
623
|
-
self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
|
|
624
|
-
# Create dataset
|
|
625
|
-
return await self._client.datasets.create_dataset(
|
|
626
|
-
name=dataset,
|
|
627
|
-
collection="terrakio-datasets",
|
|
628
|
-
products=[product],
|
|
629
|
-
path=f"gs://terrakio-mass-requests/{uid}/{model_training_job_name}/inference_scripts",
|
|
630
|
-
input=input_expression,
|
|
631
|
-
dates_iso8601=dates_iso8601,
|
|
632
|
-
padding=0
|
|
633
|
-
)
|
|
634
|
-
|
|
635
|
-
@require_api_key
|
|
636
|
-
def _generate_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
|
|
637
|
-
"""
|
|
638
|
-
Generate Python inference script for the model.
|
|
639
|
-
|
|
640
|
-
Args:
|
|
641
|
-
model_name: Name of the model
|
|
642
|
-
product: Product name
|
|
643
|
-
model_training_job_name: Training job name
|
|
644
|
-
uid: User ID
|
|
645
|
-
|
|
646
|
-
Returns:
|
|
647
|
-
str: Generated Python script content
|
|
648
|
-
"""
|
|
649
|
-
return textwrap.dedent(f'''
|
|
650
|
-
import logging
|
|
651
|
-
from io import BytesIO
|
|
652
|
-
|
|
653
|
-
import numpy as np
|
|
654
|
-
import pandas as pd
|
|
655
|
-
import xarray as xr
|
|
656
|
-
from google.cloud import storage
|
|
657
|
-
from onnxruntime import InferenceSession
|
|
658
|
-
|
|
659
|
-
logging.basicConfig(
|
|
660
|
-
level=logging.INFO
|
|
661
|
-
)
|
|
662
|
-
|
|
663
|
-
def get_model():
|
|
664
|
-
logging.info("Loading model for {model_name}...")
|
|
665
|
-
|
|
666
|
-
client = storage.Client()
|
|
667
|
-
bucket = client.get_bucket('terrakio-mass-requests')
|
|
668
|
-
blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')
|
|
669
|
-
|
|
670
|
-
model = BytesIO()
|
|
671
|
-
blob.download_to_file(model)
|
|
672
|
-
model.seek(0)
|
|
673
|
-
|
|
674
|
-
session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
|
|
675
|
-
return session
|
|
676
|
-
|
|
677
|
-
def {product}(*bands, model):
|
|
678
|
-
logging.info("start preparing data")
|
|
679
|
-
|
|
680
|
-
data_arrays = list(bands)
|
|
681
|
-
|
|
682
|
-
reference_array = data_arrays[0]
|
|
683
|
-
original_shape = reference_array.shape
|
|
684
|
-
logging.info(f"Original shape: {{original_shape}}")
|
|
685
613
|
|
|
686
|
-
if 'time' in reference_array.dims:
|
|
687
|
-
time_coords = reference_array.coords['time']
|
|
688
|
-
if len(time_coords) == 1:
|
|
689
|
-
output_timestamp = time_coords[0]
|
|
690
|
-
else:
|
|
691
|
-
years = [pd.to_datetime(t).year for t in time_coords.values]
|
|
692
|
-
unique_years = set(years)
|
|
693
|
-
|
|
694
|
-
if len(unique_years) == 1:
|
|
695
|
-
year = list(unique_years)[0]
|
|
696
|
-
output_timestamp = pd.Timestamp(f"{{year}}-01-01")
|
|
697
|
-
else:
|
|
698
|
-
latest_year = max(unique_years)
|
|
699
|
-
output_timestamp = pd.Timestamp(f"{{latest_year}}-01-01")
|
|
700
|
-
else:
|
|
701
|
-
output_timestamp = pd.Timestamp("1970-01-01")
|
|
702
|
-
|
|
703
|
-
averaged_bands = []
|
|
704
|
-
for data_array in data_arrays:
|
|
705
|
-
if 'time' in data_array.dims:
|
|
706
|
-
averaged_band = np.mean(data_array.values, axis=0)
|
|
707
|
-
logging.info(f"Averaged band from {{data_array.shape}} to {{averaged_band.shape}}")
|
|
708
|
-
else:
|
|
709
|
-
averaged_band = data_array.values
|
|
710
|
-
logging.info(f"No time dimension, shape: {{averaged_band.shape}}")
|
|
711
|
-
|
|
712
|
-
flattened_band = averaged_band.reshape(-1, 1)
|
|
713
|
-
averaged_bands.append(flattened_band)
|
|
714
|
-
|
|
715
|
-
input_data = np.hstack(averaged_bands)
|
|
716
|
-
|
|
717
|
-
logging.info(f"Final input shape: {{input_data.shape}}")
|
|
718
|
-
|
|
719
|
-
output = model.run(None, {{"float_input": input_data.astype(np.float32)}})[0]
|
|
720
|
-
|
|
721
|
-
logging.info(f"Model output shape: {{output.shape}}")
|
|
722
|
-
|
|
723
|
-
if len(original_shape) >= 3:
|
|
724
|
-
spatial_shape = original_shape[1:]
|
|
725
|
-
else:
|
|
726
|
-
spatial_shape = original_shape
|
|
727
|
-
|
|
728
|
-
output_reshaped = output.reshape(spatial_shape)
|
|
729
|
-
|
|
730
|
-
output_with_time = np.expand_dims(output_reshaped, axis=0)
|
|
731
|
-
|
|
732
|
-
if 'time' in reference_array.dims:
|
|
733
|
-
spatial_dims = [dim for dim in reference_array.dims if dim != 'time']
|
|
734
|
-
spatial_coords = {{dim: reference_array.coords[dim] for dim in spatial_dims if dim in reference_array.coords}}
|
|
735
|
-
else:
|
|
736
|
-
spatial_dims = list(reference_array.dims)
|
|
737
|
-
spatial_coords = dict(reference_array.coords)
|
|
738
|
-
|
|
739
|
-
result = xr.DataArray(
|
|
740
|
-
data=output_with_time.astype(np.float32),
|
|
741
|
-
dims=['time'] + list(spatial_dims),
|
|
742
|
-
coords={{
|
|
743
|
-
'time': [output_timestamp.values],
|
|
744
|
-
'y': spatial_coords['y'].values,
|
|
745
|
-
'x': spatial_coords['x'].values
|
|
746
|
-
}}
|
|
747
|
-
)
|
|
748
|
-
return result
|
|
749
|
-
''').strip()
|
|
750
|
-
|
|
751
|
-
@require_api_key
|
|
752
|
-
def generate_cnn_script(self, model_name: str, product: str, model_training_job_name: str, uid: str, preprocessing_code: Optional[str] = None, postprocessing_code: Optional[str] = None) -> str:
|
|
753
|
-
"""
|
|
754
|
-
Generate Python inference script for CNN model with time-stacked bands.
|
|
755
|
-
|
|
756
|
-
Args:
|
|
757
|
-
model_name: Name of the model
|
|
758
|
-
product: Product name
|
|
759
|
-
model_training_job_name: Training job name
|
|
760
|
-
uid: User ID
|
|
761
|
-
preprocessing_code: Preprocessing code
|
|
762
|
-
postprocessing_code: Postprocessing code
|
|
763
|
-
Returns:
|
|
764
|
-
str: Generated Python script content
|
|
765
|
-
"""
|
|
766
|
-
import textwrap
|
|
767
|
-
|
|
768
|
-
# Build preprocessing section with CONSISTENT 4-space indentation
|
|
769
614
|
preprocessing_section = ""
|
|
770
615
|
if preprocessing_code and preprocessing_code.strip():
|
|
771
616
|
clean_preprocessing = textwrap.dedent(preprocessing_code)
|
|
772
617
|
preprocessing_section = textwrap.indent(clean_preprocessing, ' ')
|
|
773
618
|
|
|
774
|
-
# Build postprocessing section with CONSISTENT 4-space indentation
|
|
775
619
|
postprocessing_section = ""
|
|
776
620
|
if postprocessing_code and postprocessing_code.strip():
|
|
777
621
|
clean_postprocessing = textwrap.dedent(postprocessing_code)
|
|
778
622
|
postprocessing_section = textwrap.indent(clean_postprocessing, ' ')
|
|
779
623
|
|
|
780
|
-
# Build the template WITHOUT dedenting the whole thing, so indentation is preserved
|
|
781
624
|
script_lines = [
|
|
782
625
|
"import logging",
|
|
783
626
|
"from io import BytesIO",
|
|
@@ -794,7 +637,116 @@ class ModelManagement:
|
|
|
794
637
|
"",
|
|
795
638
|
]
|
|
796
639
|
|
|
797
|
-
|
|
640
|
+
if preprocessing_section:
|
|
641
|
+
script_lines.extend([
|
|
642
|
+
"def validate_preprocessing_output(data_arrays):",
|
|
643
|
+
" \"\"\"",
|
|
644
|
+
" Validate preprocessing output coordinates and data type.",
|
|
645
|
+
" ",
|
|
646
|
+
" Args:",
|
|
647
|
+
" data_arrays: List of xarray DataArrays from preprocessing",
|
|
648
|
+
" ",
|
|
649
|
+
" Returns:",
|
|
650
|
+
" str: Validation signature symbol",
|
|
651
|
+
" ",
|
|
652
|
+
" Raises:",
|
|
653
|
+
" ValueError: If validation fails",
|
|
654
|
+
" \"\"\"",
|
|
655
|
+
" import numpy as np",
|
|
656
|
+
" ",
|
|
657
|
+
" if not data_arrays:",
|
|
658
|
+
" raise ValueError(\"No data arrays provided from preprocessing\")",
|
|
659
|
+
" ",
|
|
660
|
+
" reference_shape = None",
|
|
661
|
+
" ",
|
|
662
|
+
" for i, data_array in enumerate(data_arrays):",
|
|
663
|
+
" # Check if it's an xarray DataArray",
|
|
664
|
+
" if not hasattr(data_array, 'dims') or not hasattr(data_array, 'coords'):",
|
|
665
|
+
" raise ValueError(f\"Channel {i+1} is not a valid xarray DataArray\")",
|
|
666
|
+
" ",
|
|
667
|
+
" # Check coordinates",
|
|
668
|
+
" if 'time' not in data_array.coords:",
|
|
669
|
+
" raise ValueError(f\"Channel {i+1} missing time coordinate\")",
|
|
670
|
+
" ",
|
|
671
|
+
" spatial_dims = [dim for dim in data_array.dims if dim != 'time']",
|
|
672
|
+
" if len(spatial_dims) != 2:",
|
|
673
|
+
" raise ValueError(f\"Channel {i+1} must have exactly 2 spatial dimensions, got {spatial_dims}\")",
|
|
674
|
+
" ",
|
|
675
|
+
" for dim in spatial_dims:",
|
|
676
|
+
" if dim not in data_array.coords:",
|
|
677
|
+
" raise ValueError(f\"Channel {i+1} missing coordinate: {dim}\")",
|
|
678
|
+
" ",
|
|
679
|
+
" # Check shape consistency",
|
|
680
|
+
" shape = data_array.shape",
|
|
681
|
+
" if reference_shape is None:",
|
|
682
|
+
" reference_shape = shape",
|
|
683
|
+
" else:",
|
|
684
|
+
" if shape != reference_shape:",
|
|
685
|
+
" raise ValueError(f\"Channel {i+1} shape {shape} doesn't match reference {reference_shape}\")",
|
|
686
|
+
" ",
|
|
687
|
+
" # Generate validation signature",
|
|
688
|
+
" signature_components = [",
|
|
689
|
+
" f\"CH{len(data_arrays)}\", # Channel count",
|
|
690
|
+
" f\"T{reference_shape[0]}\", # Time dimension",
|
|
691
|
+
" f\"S{reference_shape[1]}x{reference_shape[2]}\", # Spatial dimensions",
|
|
692
|
+
" f\"DT{data_arrays[0].values.dtype}\", # Data type",
|
|
693
|
+
" ]",
|
|
694
|
+
" ",
|
|
695
|
+
" signature = \"★PRE_\" + \"_\".join(signature_components) + \"★\"",
|
|
696
|
+
" ",
|
|
697
|
+
" return signature",
|
|
698
|
+
"",
|
|
699
|
+
])
|
|
700
|
+
|
|
701
|
+
if postprocessing_section:
|
|
702
|
+
script_lines.extend([
|
|
703
|
+
"def validate_postprocessing_output(result_array):",
|
|
704
|
+
" \"\"\"",
|
|
705
|
+
" Validate postprocessing output coordinates and data type.",
|
|
706
|
+
" ",
|
|
707
|
+
" Args:",
|
|
708
|
+
" result_array: xarray DataArray from postprocessing",
|
|
709
|
+
" ",
|
|
710
|
+
" Returns:",
|
|
711
|
+
" str: Validation signature symbol",
|
|
712
|
+
" ",
|
|
713
|
+
" Raises:",
|
|
714
|
+
" ValueError: If validation fails",
|
|
715
|
+
" \"\"\"",
|
|
716
|
+
" import numpy as np",
|
|
717
|
+
" ",
|
|
718
|
+
" # Check if it's an xarray DataArray",
|
|
719
|
+
" if not hasattr(result_array, 'dims') or not hasattr(result_array, 'coords'):",
|
|
720
|
+
" raise ValueError(\"Postprocessing output is not a valid xarray DataArray\")",
|
|
721
|
+
" ",
|
|
722
|
+
" # Check required coordinates",
|
|
723
|
+
" if 'time' not in result_array.coords:",
|
|
724
|
+
" raise ValueError(\"Missing time coordinate\")",
|
|
725
|
+
" ",
|
|
726
|
+
" spatial_dims = [dim for dim in result_array.dims if dim != 'time']",
|
|
727
|
+
" if len(spatial_dims) != 2:",
|
|
728
|
+
" raise ValueError(f\"Expected 2 spatial dimensions, got {len(spatial_dims)}: {spatial_dims}\")",
|
|
729
|
+
" ",
|
|
730
|
+
" for dim in spatial_dims:",
|
|
731
|
+
" if dim not in result_array.coords:",
|
|
732
|
+
" raise ValueError(f\"Missing spatial coordinate: {dim}\")",
|
|
733
|
+
" ",
|
|
734
|
+
" # Check shape",
|
|
735
|
+
" shape = result_array.shape",
|
|
736
|
+
" ",
|
|
737
|
+
" # Generate validation signature",
|
|
738
|
+
" signature_components = [",
|
|
739
|
+
" f\"T{shape[0]}\", # Time dimension",
|
|
740
|
+
" f\"S{shape[1]}x{shape[2]}\", # Spatial dimensions",
|
|
741
|
+
" f\"DT{result_array.values.dtype}\", # Data type",
|
|
742
|
+
" ]",
|
|
743
|
+
" ",
|
|
744
|
+
" signature = \"★POST_\" + \"_\".join(signature_components) + \"★\"",
|
|
745
|
+
" ",
|
|
746
|
+
" return signature",
|
|
747
|
+
"",
|
|
748
|
+
])
|
|
749
|
+
|
|
798
750
|
if preprocessing_section:
|
|
799
751
|
script_lines.extend([
|
|
800
752
|
"def preprocessing(array: Tuple[xr.DataArray, ...]) -> Tuple[xr.DataArray, ...]:",
|
|
@@ -802,7 +754,6 @@ class ModelManagement:
|
|
|
802
754
|
"",
|
|
803
755
|
])
|
|
804
756
|
|
|
805
|
-
# Add postprocessing function definition BEFORE the main function
|
|
806
757
|
if postprocessing_section:
|
|
807
758
|
script_lines.extend([
|
|
808
759
|
"def postprocessing(array: xr.DataArray) -> xr.DataArray:",
|
|
@@ -810,14 +761,13 @@ class ModelManagement:
|
|
|
810
761
|
"",
|
|
811
762
|
])
|
|
812
763
|
|
|
813
|
-
# Add the get_model function
|
|
814
764
|
script_lines.extend([
|
|
815
765
|
"def get_model():",
|
|
816
|
-
f" logging.info(\"Loading CNN model for {
|
|
766
|
+
f" logging.info(\"Loading CNN model for {virtual_dataset_name}...\")",
|
|
817
767
|
"",
|
|
818
768
|
" client = storage.Client()",
|
|
819
|
-
" bucket = client.get_bucket('
|
|
820
|
-
f" blob = bucket.blob('{uid}/{
|
|
769
|
+
f" bucket = client.get_bucket('{bucket_name}')",
|
|
770
|
+
f" blob = bucket.blob('{uid}/virtual_datasets/{virtual_dataset_name}/{virtual_dataset_name}.onnx')",
|
|
821
771
|
"",
|
|
822
772
|
" model = BytesIO()",
|
|
823
773
|
" blob.download_to_file(model)",
|
|
@@ -826,7 +776,7 @@ class ModelManagement:
|
|
|
826
776
|
" session = InferenceSession(model.read(), providers=[\"CPUExecutionProvider\"])",
|
|
827
777
|
" return session",
|
|
828
778
|
"",
|
|
829
|
-
f"def {
|
|
779
|
+
f"def {virtual_product_name}(*bands, model):",
|
|
830
780
|
" logging.info(\"Start preparing CNN data with time-stacked bands\")",
|
|
831
781
|
" data_arrays = list(bands)",
|
|
832
782
|
" ",
|
|
@@ -835,20 +785,20 @@ class ModelManagement:
|
|
|
835
785
|
" ",
|
|
836
786
|
])
|
|
837
787
|
|
|
838
|
-
# Add preprocessing call if preprocessing exists
|
|
839
788
|
if preprocessing_section:
|
|
840
789
|
script_lines.extend([
|
|
841
790
|
" # Apply preprocessing",
|
|
842
791
|
" data_arrays = preprocessing(tuple(data_arrays))",
|
|
843
792
|
" data_arrays = list(data_arrays) # Convert back to list for processing",
|
|
844
793
|
" ",
|
|
794
|
+
" # Validate preprocessing output",
|
|
795
|
+
" preprocessing_signature = validate_preprocessing_output(data_arrays)",
|
|
796
|
+
" ",
|
|
845
797
|
])
|
|
846
798
|
|
|
847
|
-
# Continue with the rest of the processing logic
|
|
848
799
|
script_lines.extend([
|
|
849
800
|
" reference_array = data_arrays[0]",
|
|
850
801
|
" original_shape = reference_array.shape",
|
|
851
|
-
" logging.info(f\"Original shape: {original_shape}\")",
|
|
852
802
|
" ",
|
|
853
803
|
" # Get time coordinates - all bands should have the same time dimension",
|
|
854
804
|
" if 'time' not in reference_array.dims:",
|
|
@@ -856,24 +806,19 @@ class ModelManagement:
|
|
|
856
806
|
" ",
|
|
857
807
|
" time_coords = reference_array.coords['time']",
|
|
858
808
|
" num_timestamps = len(time_coords)",
|
|
859
|
-
" logging.info(f\"Number of timestamps: {num_timestamps}\")",
|
|
860
809
|
" ",
|
|
861
810
|
" # Get spatial dimensions",
|
|
862
811
|
" spatial_dims = [dim for dim in reference_array.dims if dim != 'time']",
|
|
863
812
|
" height = reference_array.sizes[spatial_dims[0]] # assuming first spatial dim is height",
|
|
864
813
|
" width = reference_array.sizes[spatial_dims[1]] # assuming second spatial dim is width",
|
|
865
|
-
" logging.info(f\"Spatial dimensions: {height} x {width}\")",
|
|
866
814
|
" ",
|
|
867
815
|
" # Stack bands across time dimension",
|
|
868
816
|
" # Result will be: (num_bands * num_timestamps, height, width)",
|
|
869
817
|
" stacked_channels = []",
|
|
870
818
|
" ",
|
|
871
819
|
" for band_idx, data_array in enumerate(data_arrays):",
|
|
872
|
-
" logging.info(f\"Processing band {band_idx + 1}/{len(data_arrays)}\")",
|
|
873
|
-
" ",
|
|
874
820
|
" # Ensure consistent time coordinates across bands",
|
|
875
821
|
" if not np.array_equal(data_array.coords['time'].values, time_coords.values):",
|
|
876
|
-
" logging.warning(f\"Band {band_idx} has different time coordinates, aligning...\")",
|
|
877
822
|
" data_array = data_array.sel(time=time_coords, method='nearest')",
|
|
878
823
|
" ",
|
|
879
824
|
" # Extract values and ensure proper ordering (time, height, width)",
|
|
@@ -892,23 +837,18 @@ class ModelManagement:
|
|
|
892
837
|
" # Stack all channels: (num_bands * num_timestamps, height, width)",
|
|
893
838
|
" input_channels = np.stack(stacked_channels, axis=0)",
|
|
894
839
|
" total_channels = len(data_arrays) * num_timestamps",
|
|
895
|
-
" logging.info(f\"Stacked channels shape: {input_channels.shape}\")",
|
|
896
|
-
" logging.info(f\"Total channels: {total_channels} ({len(data_arrays)} bands × {num_timestamps} timestamps)\")",
|
|
897
840
|
" ",
|
|
898
841
|
" # Add batch dimension: (1, num_channels, height, width)",
|
|
899
842
|
" input_data = np.expand_dims(input_channels, axis=0).astype(np.float32)",
|
|
900
|
-
" logging.info(f\"Final input shape for CNN: {input_data.shape}\")",
|
|
901
843
|
" ",
|
|
902
844
|
" # Run inference",
|
|
903
845
|
" output = model.run(None, {\"float_input\": input_data})[0]",
|
|
904
|
-
" logging.info(f\"Model output shape: {output.shape}\")",
|
|
905
846
|
" ",
|
|
906
|
-
" #
|
|
847
|
+
" # Handle multi-class CNN output properly",
|
|
907
848
|
" if output.ndim == 4:",
|
|
908
849
|
" if output.shape[1] == 1:",
|
|
909
850
|
" # Single class output (regression or binary classification)",
|
|
910
851
|
" output_2d = output[0, 0]",
|
|
911
|
-
" logging.info(\"Single channel output detected\")",
|
|
912
852
|
" else:",
|
|
913
853
|
" # Multi-class output - convert logits/probabilities to class predictions",
|
|
914
854
|
" output_classes = np.argmax(output, axis=1) # Shape: (1, height, width)",
|
|
@@ -916,22 +856,14 @@ class ModelManagement:
|
|
|
916
856
|
" ",
|
|
917
857
|
" # Apply class merging: merge class 6 into class 3",
|
|
918
858
|
" output_2d = np.where(output_2d == 6, 3, output_2d)",
|
|
919
|
-
" ",
|
|
920
|
-
" logging.info(f\"Multi-class output processed. Original classes: {output.shape[1]}\")",
|
|
921
|
-
" logging.info(f\"Unique classes in output: {np.unique(output_2d)}\")",
|
|
922
|
-
" logging.info(f\"Class distribution: {np.bincount(output_2d.flatten())}\")",
|
|
923
859
|
" elif output.ndim == 3:",
|
|
924
860
|
" # Remove batch dimension",
|
|
925
861
|
" output_2d = output[0]",
|
|
926
|
-
" logging.info(\"3D output detected, removed batch dimension\")",
|
|
927
862
|
" else:",
|
|
928
863
|
" # Handle other cases",
|
|
929
864
|
" output_2d = np.squeeze(output)",
|
|
930
865
|
" if output_2d.ndim != 2:",
|
|
931
|
-
" logging.error(f\"Cannot process output shape: {output.shape}\")",
|
|
932
|
-
" logging.error(f\"After squeeze: {output_2d.shape}\")",
|
|
933
866
|
" raise ValueError(f\"Unexpected output shape after processing: {output_2d.shape}\")",
|
|
934
|
-
" logging.info(\"Applied squeeze to output\")",
|
|
935
867
|
" ",
|
|
936
868
|
" # Ensure output is 2D",
|
|
937
869
|
" if output_2d.ndim != 2:",
|
|
@@ -949,11 +881,9 @@ class ModelManagement:
|
|
|
949
881
|
" if is_multiclass:",
|
|
950
882
|
" # Multi-class classification - use integer type",
|
|
951
883
|
" output_dtype = np.int32",
|
|
952
|
-
" output_type = 'classification'",
|
|
953
884
|
" else:",
|
|
954
885
|
" # Single output - use float type",
|
|
955
886
|
" output_dtype = np.float32",
|
|
956
|
-
" output_type = 'regression'",
|
|
957
887
|
" ",
|
|
958
888
|
" result = xr.DataArray(",
|
|
959
889
|
" data=np.expand_dims(output_2d.astype(output_dtype), axis=0),",
|
|
@@ -967,31 +897,232 @@ class ModelManagement:
|
|
|
967
897
|
" 'description': 'CNN model prediction',",
|
|
968
898
|
" }",
|
|
969
899
|
" )",
|
|
970
|
-
" ",
|
|
971
|
-
" logging.info(f\"Final result shape: {result.shape}\")",
|
|
972
|
-
" logging.info(f\"Final result data type: {result.dtype}\")",
|
|
973
|
-
" logging.info(f\"Final result value range: {result.values.min()} to {result.values.max()}\")",
|
|
974
900
|
])
|
|
975
901
|
|
|
976
|
-
# Add postprocessing call if postprocessing exists
|
|
977
902
|
if postprocessing_section:
|
|
978
903
|
script_lines.extend([
|
|
979
904
|
" # Apply postprocessing",
|
|
980
905
|
" result = postprocessing(result)",
|
|
981
906
|
" ",
|
|
907
|
+
" # Validate postprocessing output",
|
|
908
|
+
" postprocessing_signature = validate_postprocessing_output(result)",
|
|
909
|
+
" ",
|
|
982
910
|
])
|
|
983
911
|
|
|
984
|
-
# Single return statement at the end
|
|
985
912
|
script_lines.append(" return result")
|
|
986
913
|
|
|
987
914
|
return "\n".join(script_lines)
|
|
915
|
+
|
|
916
|
+
def _parse_processing_script(self, script_path: str) -> Tuple[Optional[str], Optional[str]]:
|
|
917
|
+
"""
|
|
918
|
+
Parse a Python file and extract preprocessing and postprocessing function bodies.
|
|
919
|
+
|
|
920
|
+
Args:
|
|
921
|
+
script_path: Path to the Python file containing processing functions
|
|
922
|
+
|
|
923
|
+
Returns:
|
|
924
|
+
Tuple of (preprocessing_code, postprocessing_code) where each can be None
|
|
925
|
+
"""
|
|
926
|
+
try:
|
|
927
|
+
with open(script_path, 'r', encoding='utf-8') as f:
|
|
928
|
+
script_content = f.read()
|
|
929
|
+
except FileNotFoundError:
|
|
930
|
+
raise FileNotFoundError(f"Processing script not found: {script_path}")
|
|
931
|
+
except Exception as e:
|
|
932
|
+
raise ValueError(f"Error reading processing script: {e}")
|
|
933
|
+
|
|
934
|
+
if not script_content.strip():
|
|
935
|
+
self._client.logger.info(f"Processing script {script_path} is empty")
|
|
936
|
+
return None, None
|
|
937
|
+
|
|
938
|
+
try:
|
|
939
|
+
tree = ast.parse(script_content)
|
|
940
|
+
except SyntaxError as e:
|
|
941
|
+
raise ValueError(f"Syntax error in processing script: {e}")
|
|
942
|
+
|
|
943
|
+
preprocessing_code = None
|
|
944
|
+
postprocessing_code = None
|
|
945
|
+
|
|
946
|
+
function_names = []
|
|
947
|
+
for node in ast.walk(tree):
|
|
948
|
+
if isinstance(node, ast.FunctionDef):
|
|
949
|
+
function_names.append(node.name)
|
|
950
|
+
if node.name == 'preprocessing':
|
|
951
|
+
preprocessing_code = self._extract_function_body(script_content, node)
|
|
952
|
+
elif node.name == 'postprocessing':
|
|
953
|
+
postprocessing_code = self._extract_function_body(script_content, node)
|
|
954
|
+
|
|
955
|
+
if not function_names:
|
|
956
|
+
self._client.logger.warning(f"No functions found in processing script: {script_path}")
|
|
957
|
+
else:
|
|
958
|
+
found_functions = [name for name in function_names if name in ['preprocessing', 'postprocessing']]
|
|
959
|
+
if found_functions:
|
|
960
|
+
self._client.logger.info(f"Found processing functions: {found_functions}")
|
|
961
|
+
else:
|
|
962
|
+
self._client.logger.warning(f"No 'preprocessing' or 'postprocessing' functions found in {script_path}. "
|
|
963
|
+
f"Available functions: {function_names}")
|
|
964
|
+
|
|
965
|
+
return preprocessing_code, postprocessing_code
|
|
966
|
+
|
|
967
|
+
def _extract_function_body(self, script_content: str, func_node: ast.FunctionDef) -> str:
|
|
968
|
+
"""Extract the body of a function from the script content."""
|
|
969
|
+
lines = script_content.split('\n')
|
|
970
|
+
|
|
971
|
+
start_line = func_node.lineno - 1
|
|
972
|
+
end_line = func_node.end_lineno - 1 if hasattr(func_node, 'end_lineno') else len(lines) - 1
|
|
973
|
+
|
|
974
|
+
body_lines = []
|
|
975
|
+
for i in range(start_line + 1, end_line + 1):
|
|
976
|
+
if i < len(lines):
|
|
977
|
+
body_lines.append(lines[i])
|
|
978
|
+
|
|
979
|
+
if not body_lines:
|
|
980
|
+
return ""
|
|
981
|
+
|
|
982
|
+
body_text = '\n'.join(body_lines)
|
|
983
|
+
cleaned_body = textwrap.dedent(body_text).strip()
|
|
984
|
+
|
|
985
|
+
if not cleaned_body or cleaned_body in ['pass', 'return', 'return None']:
|
|
986
|
+
return ""
|
|
987
|
+
|
|
988
|
+
return cleaned_body
|
|
988
989
|
|
|
990
|
+
def _convert_model_to_onnx(self, model, input_shape: Tuple[int, ...] = None, model_type: Optional[str] = None) -> Tuple[bytes, str]:
    """
    Convert a model to ONNX format and return it together with its model type.

    Args:
        model: The model object (PyTorch nn.Module, scikit-learn estimator,
            or ONNX Runtime InferenceSession)
        input_shape: Shape of input data, forwarded to the PyTorch /
            scikit-learn converters
        model_type: Type of model (neural_network, random_forest); only used
            when ``model`` is already an ONNX InferenceSession

    Returns:
        Tuple[bytes, str]: The serialized ONNX model and the model type
            ("neural_network" for PyTorch models, "random_forest" for
            scikit-learn models, the caller-supplied ``model_type`` otherwise)

    Raises:
        ValueError: If the model type is not supported, or ``model_type`` is
            missing for an InferenceSession
        ImportError: If a scikit-learn model is given but skl2onnx is not installed
    """
    # Check the availability flag first: when the optional import failed the
    # name ``torch`` is unbound, and evaluating ``torch.nn.Module`` would raise
    # NameError for every caller, including those passing non-PyTorch models.
    if TORCH_AVAILABLE and isinstance(model, torch.nn.Module):
        return self._convert_pytorch_to_onnx(model, input_shape), "neural_network"

    # Recognise scikit-learn estimators through the MRO instead of naming
    # ``BaseEstimator`` directly, so the check cannot fail with NameError.
    # NOTE(review): the import of BaseEstimator is not visible from this
    # method; presumably it is part of an optional-import guard - confirm.
    is_sklearn_estimator = any(
        base.__module__ == "sklearn.base" and base.__name__ == "BaseEstimator"
        for base in type(model).__mro__
    )
    if is_sklearn_estimator:
        if not SKL2ONNX_AVAILABLE:
            raise ImportError("skl2onnx is not installed. Please install it with: pip install skl2onnx")
        return self._convert_sklearn_to_onnx(model, input_shape), "random_forest"

    if isinstance(model, ort.InferenceSession):
        if model_type is None:
            raise ValueError(
                "For ONNX InferenceSession models, you must specify the 'model_type' parameter. "
                "Currently 'neural_network' and 'random_forest' are supported. "
                "Example: model_type='random_forest' or model_type='neural_network'"
            )
        # NOTE(review): SerializeToString is an onnx ModelProto method; confirm
        # that onnxruntime.InferenceSession actually exposes it.
        return model.SerializeToString(), model_type

    unsupported_name = type(model).__name__
    raise ValueError(
        f"Unsupported model type: {unsupported_name}. "
        "Supported types: PyTorch nn.Module, sklearn BaseEstimator, ONNX Runtime InferenceSession"
    )
|
|
1023
|
+
|
|
1024
|
+
def _convert_pytorch_to_onnx(self, model, input_shape: Tuple[int, ...]) -> bytes:
    """
    Convert a PyTorch model to ONNX format.

    The model is exported in evaluation mode using a random dummy input of
    ``input_shape``. The input is named ``float_input``; its axis 0 is always
    exported as a dynamic ``batch_size`` axis, and for 4-D and 5-D inputs the
    last two axes are additionally exported as dynamic ``height`` / ``width``.

    Args:
        model: The PyTorch model object
        input_shape: Shape of the dummy input used to trace the model (required)

    Returns:
        bytes: ONNX model as bytes

    Raises:
        ValueError: If ``input_shape`` is missing or the conversion fails
    """
    if input_shape is None:
        raise ValueError("input_shape is required to convert a PyTorch model to ONNX")

    # Remember the caller's mode so the export does not leave their model
    # permanently switched to evaluation mode.
    was_training = getattr(model, "training", False)
    try:
        model.eval()
        dummy_input = torch.randn(input_shape)
        onnx_buffer = BytesIO()

        rank = len(input_shape)
        input_axes = {0: 'batch_size'}
        if rank in (4, 5):
            # The spatial axes are the trailing two for both layouts.
            input_axes[rank - 2] = 'height'
            input_axes[rank - 1] = 'width'

        torch.onnx.export(
            model,
            dummy_input,
            onnx_buffer,
            input_names=['float_input'],
            dynamic_axes={'float_input': input_axes}
        )

        return onnx_buffer.getvalue()
    except Exception as e:
        raise ValueError(f"Failed to convert PyTorch model to ONNX: {str(e)}") from e
    finally:
        if was_training:
            model.train()
|
|
1066
|
+
|
|
1067
|
+
def _convert_sklearn_to_onnx(self, model, input_shape: Tuple[int, ...]) -> bytes:
    """
    Convert a scikit-learn model (assumed to be a random forest) to ONNX format.

    The model is converted with a single float tensor input named
    ``float_input`` whose shape is ``input_shape``.

    Args:
        model: The scikit-learn model object
        input_shape: Shape of input data (required)

    Returns:
        bytes: ONNX model as bytes

    Raises:
        ValueError: If conversion fails; the underlying error is attached as
            the exception's cause
    """
    self._client.logger.info("Converting random forest model to ONNX...")

    try:
        initial_type = [('float_input', FloatTensorType(input_shape))]
        onnx_model = convert_sklearn(model, initial_types=initial_type)
        return onnx_model.SerializeToString()
    except Exception as e:
        # Chain explicitly so the converter's traceback is reported as the cause.
        raise ValueError(f"Failed to convert scikit-learn model to ONNX: {str(e)}") from e
|
|
1089
|
+
|
|
989
1090
|
@require_api_key
|
|
990
|
-
def
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
1091
|
+
def train_model(
    self,
    model_name: str,
    training_dataset: str,
    task_type: str,
    model_category: str,
    architecture: str,
    region: str,
    hyperparameters: dict = None
) -> dict:
    """
    Submit a training job to the model training API.

    All arguments are forwarded unchanged as the JSON body of a POST request
    to ``/train_model``; ``hyperparameters`` is sent as ``None`` when omitted.

    Args:
        model_name (str): Name of the model to train.
        training_dataset (str): Identifier of the training dataset.
        task_type (str): Kind of ML task (e.g., regression, classification).
        model_category (str): Category of model (e.g., random_forest).
        architecture (str): The model architecture.
        region (str): The region identifier.
        hyperparameters (dict, optional): Extra hyperparameters for training.

    Returns:
        dict: Whatever the client's request helper returns for the call.

    Raises:
        Any error raised by the client's request helper propagates unchanged.
    """
    request_body = dict(
        model_name=model_name,
        training_dataset=training_dataset,
        task_type=task_type,
        model_category=model_category,
        architecture=architecture,
        region=region,
        hyperparameters=hyperparameters,
    )
    response = self._client._terrakio_request("POST", "/train_model", json=request_body)
    return response
|