terrakio-core 0.4.3__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of terrakio-core might be problematic. Click here for more details.
- terrakio_core/__init__.py +3 -1
- terrakio_core/accessors.py +477 -0
- terrakio_core/async_client.py +23 -38
- terrakio_core/client.py +83 -84
- terrakio_core/convenience_functions/convenience_functions.py +316 -324
- terrakio_core/endpoints/auth.py +8 -1
- terrakio_core/endpoints/mass_stats.py +13 -9
- terrakio_core/endpoints/model_management.py +604 -948
- terrakio_core/sync_client.py +341 -33
- {terrakio_core-0.4.3.dist-info → terrakio_core-0.4.4.dist-info}/METADATA +2 -1
- terrakio_core-0.4.4.dist-info/RECORD +22 -0
- terrakio_core-0.4.3.dist-info/RECORD +0 -21
- {terrakio_core-0.4.3.dist-info → terrakio_core-0.4.4.dist-info}/WHEEL +0 -0
- {terrakio_core-0.4.3.dist-info → terrakio_core-0.4.4.dist-info}/top_level.txt +0 -0
|
@@ -1,36 +1,36 @@
|
|
|
1
|
-
|
|
1
|
+
# Standard library imports
|
|
2
|
+
import ast
|
|
2
3
|
import json
|
|
3
|
-
import time
|
|
4
4
|
import textwrap
|
|
5
|
-
import
|
|
6
|
-
from typing import Dict, Any, Union, Tuple, Optional
|
|
5
|
+
import time
|
|
7
6
|
from io import BytesIO
|
|
8
|
-
import
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
7
|
+
from typing import Optional, Tuple
|
|
8
|
+
import onnxruntime as ort
|
|
9
|
+
|
|
10
|
+
# Internal imports
|
|
11
|
+
from ..helper.decorators import require_api_key
|
|
12
|
+
|
|
13
|
+
# Optional dependency flags
|
|
12
14
|
TORCH_AVAILABLE = False
|
|
13
15
|
SKL2ONNX_AVAILABLE = False
|
|
14
16
|
|
|
17
|
+
# PyTorch imports
|
|
15
18
|
try:
|
|
16
19
|
import torch
|
|
17
20
|
TORCH_AVAILABLE = True
|
|
18
21
|
except ImportError:
|
|
19
22
|
torch = None
|
|
20
23
|
|
|
24
|
+
# Scikit-learn and ONNX conversion imports
|
|
21
25
|
try:
|
|
26
|
+
from sklearn.base import BaseEstimator
|
|
22
27
|
from skl2onnx import convert_sklearn
|
|
23
28
|
from skl2onnx.common.data_types import FloatTensorType
|
|
24
|
-
from sklearn.base import BaseEstimator
|
|
25
29
|
SKL2ONNX_AVAILABLE = True
|
|
26
30
|
except ImportError:
|
|
31
|
+
BaseEstimator = None
|
|
27
32
|
convert_sklearn = None
|
|
28
33
|
FloatTensorType = None
|
|
29
|
-
BaseEstimator = None
|
|
30
|
-
|
|
31
|
-
from io import BytesIO
|
|
32
|
-
from typing import Tuple
|
|
33
|
-
|
|
34
34
|
|
|
35
35
|
class ModelManagement:
|
|
36
36
|
def __init__(self, client):
|
|
@@ -128,14 +128,12 @@ class ModelManagement:
|
|
|
128
128
|
)
|
|
129
129
|
task_id = task_response["task_id"]
|
|
130
130
|
|
|
131
|
-
# Wait for job completion with progress bar
|
|
132
131
|
while True:
|
|
133
132
|
result = await self._client.mass_stats.track_job(ids=[task_id])
|
|
134
133
|
status = result[task_id]['status']
|
|
135
134
|
completed = result[task_id].get('completed', 0)
|
|
136
135
|
total = result[task_id].get('total', 1)
|
|
137
136
|
|
|
138
|
-
# Create progress bar
|
|
139
137
|
progress = completed / total if total > 0 else 0
|
|
140
138
|
bar_length = 50
|
|
141
139
|
filled_length = int(bar_length * progress)
|
|
@@ -151,526 +149,456 @@ class ModelManagement:
|
|
|
151
149
|
self._client.logger.info("Job encountered an error")
|
|
152
150
|
raise Exception(f"Job {task_id} encountered an error")
|
|
153
151
|
|
|
154
|
-
# Wait 5 seconds before checking again
|
|
155
152
|
time.sleep(5)
|
|
156
153
|
|
|
157
|
-
# after all the random sample jobs are done, we then start the mass stats job
|
|
158
|
-
# task_id = self._client.mass_stats.start_mass_stats_job(task_id)
|
|
159
154
|
task_id = await self._client.mass_stats.start_job(task_id)
|
|
160
155
|
return task_id
|
|
161
|
-
# the folder that is being created is not under the jobs folder, its directly under the UID folder
|
|
162
156
|
|
|
163
|
-
# @require_api_key
|
|
164
|
-
# async def upload_model(self, model, model_name: str, input_shape: Tuple[int, ...] = None):
|
|
165
|
-
# """
|
|
166
|
-
# Upload a model to the bucket so that it can be used for inference.
|
|
167
|
-
# Converts PyTorch and scikit-learn models to ONNX format before uploading.
|
|
168
|
-
|
|
169
|
-
# Args:
|
|
170
|
-
# model: The model object (PyTorch model or scikit-learn model)
|
|
171
|
-
# model_name: Name for the model (without extension)
|
|
172
|
-
# input_shape: Shape of input data for ONNX conversion (e.g., (1, 10) for batch_size=1, features=10)
|
|
173
|
-
# Required for PyTorch models, optional for scikit-learn models
|
|
174
|
-
|
|
175
|
-
# Raises:
|
|
176
|
-
# APIError: If the API request fails
|
|
177
|
-
# ValueError: If model type is not supported or input_shape is missing for PyTorch models
|
|
178
|
-
# ImportError: If required libraries (torch or skl2onnx) are not installed
|
|
179
|
-
# """
|
|
180
|
-
# uid = (await self._client.auth.get_user_info())["uid"]
|
|
181
|
-
# # above line is getting the uid,
|
|
182
|
-
|
|
183
|
-
# client = storage.Client()
|
|
184
|
-
# bucket = client.get_bucket('terrakio-mass-requests')
|
|
185
|
-
|
|
186
|
-
# # Convert model to ONNX format
|
|
187
|
-
# onnx_bytes = self._convert_model_to_onnx(model, model_name, input_shape)
|
|
188
|
-
|
|
189
|
-
# # Upload ONNX model to bucket
|
|
190
|
-
# # blob = bucket.blob(f'{uid}/{model_name}/models/{model_name}.onnx')
|
|
191
|
-
# # we don't need to upload the model to the bucket
|
|
192
|
-
# # so the stuff is stored under the virtual datasets folder
|
|
193
|
-
# # the model name and the virtual dataset name should be the same
|
|
194
|
-
# virtual_dataset_name = model_name
|
|
195
|
-
# blob = bucket.blob(f'{uid}/virtual_datasets/{virtual_dataset_name}/{model_name}.onnx')
|
|
196
|
-
# # wer are uploading the model to the virtual dataset folder
|
|
197
|
-
|
|
198
|
-
# blob.upload_from_string(onnx_bytes, content_type='application/octet-stream')
|
|
199
|
-
|
|
200
|
-
# self._client.logger.info(f"Model uploaded successfully to {uid}/virtual_datasets/{virtual_dataset_name}/{model_name}.onnx")
|
|
201
|
-
|
|
202
|
-
# this is the upload model function, I think we need to upload to the user, under the virutal_datasets folder, and create the virtual dataset
|
|
203
|
-
|
|
204
|
-
|
|
205
157
|
@require_api_key
|
|
206
|
-
async def
|
|
158
|
+
async def _get_url_for_upload_model_and_script(self, expression: str, model_name: str, script_name: str) -> str:
|
|
207
159
|
"""
|
|
208
|
-
|
|
209
|
-
Converts PyTorch and scikit-learn models to ONNX format before uploading.
|
|
210
|
-
|
|
160
|
+
Get the url for the upload of the model
|
|
211
161
|
Args:
|
|
212
|
-
|
|
213
|
-
model_name:
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
Raises:
|
|
218
|
-
APIError: If the API request fails
|
|
219
|
-
ValueError: If model type is not supported or input_shape is missing for PyTorch models
|
|
220
|
-
ImportError: If required libraries (torch or skl2onnx) are not installed
|
|
162
|
+
expression: The expression to use for the upload(for deciding which bucket to upload to)
|
|
163
|
+
model_name: The name of the model to upload
|
|
164
|
+
script_name: The name of the script to upload
|
|
165
|
+
Returns:
|
|
166
|
+
The url for the upload of the model
|
|
221
167
|
"""
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
onnx_bytes = self._convert_model_to_onnx(model, model_name, input_shape)
|
|
229
|
-
|
|
230
|
-
# Upload ONNX model to bucket
|
|
231
|
-
blob = bucket.blob(f'{uid}/{model_name}/models/{model_name}.onnx')
|
|
232
|
-
|
|
233
|
-
blob.upload_from_string(onnx_bytes, content_type='application/octet-stream')
|
|
234
|
-
self._client.logger.info(f"Model uploaded successfully to {uid}/{model_name}/models/{model_name}.onnx")
|
|
168
|
+
payload = {
|
|
169
|
+
"model_name": model_name,
|
|
170
|
+
"expression": expression,
|
|
171
|
+
"script_name": script_name
|
|
172
|
+
}
|
|
173
|
+
return await self._client._terrakio_request("POST", "models/upload", json=payload)
|
|
235
174
|
|
|
236
|
-
def
|
|
175
|
+
async def _upload_model_to_url(self, upload_model_url: str, model: bytes):
|
|
237
176
|
"""
|
|
238
|
-
|
|
239
|
-
|
|
177
|
+
Upload a model to a given URL.
|
|
240
178
|
Args:
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
179
|
+
model_url: The url to upload the model to
|
|
180
|
+
model: The model to upload
|
|
181
|
+
|
|
245
182
|
Returns:
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
Raises:
|
|
249
|
-
ValueError: If model type is not supported
|
|
250
|
-
ImportError: If required libraries are not installed
|
|
183
|
+
The response from the server
|
|
251
184
|
"""
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
return self._convert_sklearn_to_onnx(model, model_name, input_shape)
|
|
277
|
-
else:
|
|
278
|
-
# Provide helpful error message
|
|
279
|
-
model_type = type(model).__name__
|
|
280
|
-
model_module = type(model).__module__
|
|
281
|
-
available_types = []
|
|
282
|
-
missing_deps = []
|
|
283
|
-
|
|
284
|
-
if TORCH_AVAILABLE:
|
|
285
|
-
available_types.append("PyTorch (torch.nn.Module)")
|
|
286
|
-
else:
|
|
287
|
-
missing_deps.append("torch")
|
|
288
|
-
|
|
289
|
-
if SKL2ONNX_AVAILABLE:
|
|
290
|
-
available_types.append("scikit-learn (BaseEstimator)")
|
|
291
|
-
else:
|
|
292
|
-
missing_deps.append("skl2onnx")
|
|
293
|
-
|
|
294
|
-
if missing_deps:
|
|
295
|
-
raise ImportError(
|
|
296
|
-
f"Model type {model_type} from {model_module} detected, but required dependencies missing: {', '.join(missing_deps)}. "
|
|
297
|
-
f"Install with: pip install {' '.join(missing_deps)}"
|
|
298
|
-
)
|
|
299
|
-
else:
|
|
300
|
-
raise ValueError(
|
|
301
|
-
f"Unsupported model type: {model_type} from {model_module}. "
|
|
302
|
-
f"Supported types: {', '.join(available_types)}"
|
|
303
|
-
)
|
|
304
|
-
|
|
305
|
-
def _convert_pytorch_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...]) -> bytes:
|
|
306
|
-
"""Convert PyTorch model to ONNX format with dynamic input dimensions."""
|
|
307
|
-
if input_shape is None:
|
|
308
|
-
raise ValueError("input_shape is required for PyTorch models")
|
|
309
|
-
|
|
310
|
-
self._client.logger.info(f"Converting PyTorch model {model_name} to ONNX...")
|
|
311
|
-
|
|
312
|
-
try:
|
|
313
|
-
# Set model to evaluation mode
|
|
314
|
-
model.eval()
|
|
315
|
-
|
|
316
|
-
# Create dummy input
|
|
317
|
-
dummy_input = torch.randn(input_shape)
|
|
318
|
-
|
|
319
|
-
# Use BytesIO to avoid creating temporary files
|
|
320
|
-
onnx_buffer = BytesIO()
|
|
321
|
-
|
|
322
|
-
# Determine dynamic axes based on input shape
|
|
323
|
-
# Common patterns for different input types:
|
|
324
|
-
if len(input_shape) == 4: # Convolutional input: (batch, channels, height, width)
|
|
325
|
-
dynamic_axes = {
|
|
326
|
-
'float_input': {
|
|
327
|
-
0: 'batch_size',
|
|
328
|
-
2: 'height', # Make height dynamic for variable input sizes
|
|
329
|
-
3: 'width' # Make width dynamic for variable input sizes
|
|
330
|
-
},
|
|
331
|
-
'output': {0: 'batch_size'}
|
|
332
|
-
}
|
|
333
|
-
elif len(input_shape) == 3: # Could be (batch, sequence, features) or (batch, height, width)
|
|
334
|
-
dynamic_axes = {
|
|
335
|
-
'float_input': {
|
|
336
|
-
0: 'batch_size',
|
|
337
|
-
1: 'dim1', # Generic dynamic dimension
|
|
338
|
-
2: 'dim2' # Generic dynamic dimension
|
|
339
|
-
},
|
|
340
|
-
'output': {0: 'batch_size'}
|
|
341
|
-
}
|
|
342
|
-
elif len(input_shape) == 2: # Likely (batch, features)
|
|
343
|
-
dynamic_axes = {
|
|
344
|
-
'float_input': {
|
|
345
|
-
0: 'batch_size'
|
|
346
|
-
# Don't make features dynamic as it usually affects model architecture
|
|
347
|
-
},
|
|
348
|
-
'output': {0: 'batch_size'}
|
|
349
|
-
}
|
|
350
|
-
else:
|
|
351
|
-
# For other shapes, just make batch size dynamic
|
|
352
|
-
dynamic_axes = {
|
|
353
|
-
'float_input': {0: 'batch_size'},
|
|
354
|
-
'output': {0: 'batch_size'}
|
|
355
|
-
}
|
|
356
|
-
|
|
357
|
-
torch.onnx.export(
|
|
358
|
-
model,
|
|
359
|
-
dummy_input,
|
|
360
|
-
onnx_buffer,
|
|
361
|
-
export_params=True,
|
|
362
|
-
opset_version=11,
|
|
363
|
-
do_constant_folding=True,
|
|
364
|
-
input_names=['float_input'],
|
|
365
|
-
output_names=['output'],
|
|
366
|
-
dynamic_axes=dynamic_axes
|
|
367
|
-
)
|
|
368
|
-
|
|
369
|
-
self._client.logger.info(f"Successfully converted {model_name} with dynamic axes: {dynamic_axes}")
|
|
370
|
-
return onnx_buffer.getvalue()
|
|
371
|
-
|
|
372
|
-
except Exception as e:
|
|
373
|
-
raise ValueError(f"Failed to convert PyTorch model {model_name} to ONNX: {str(e)}")
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
def _convert_sklearn_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...] = None) -> bytes:
|
|
377
|
-
"""Convert scikit-learn model to ONNX format."""
|
|
378
|
-
self._client.logger.info(f"Converting scikit-learn model {model_name} to ONNX...")
|
|
379
|
-
|
|
380
|
-
# Try to infer input shape if not provided
|
|
381
|
-
if input_shape is None:
|
|
382
|
-
if hasattr(model, 'n_features_in_'):
|
|
383
|
-
input_shape = (1, model.n_features_in_)
|
|
384
|
-
else:
|
|
385
|
-
raise ValueError(
|
|
386
|
-
"input_shape is required for scikit-learn models when n_features_in_ is not available. "
|
|
387
|
-
"This usually happens with older sklearn versions or models not fitted yet."
|
|
388
|
-
)
|
|
389
|
-
|
|
390
|
-
try:
|
|
391
|
-
# Convert scikit-learn model to ONNX
|
|
392
|
-
initial_type = [('float_input', FloatTensorType(input_shape))]
|
|
393
|
-
onnx_model = convert_sklearn(model, initial_types=initial_type)
|
|
394
|
-
return onnx_model.SerializeToString()
|
|
395
|
-
|
|
396
|
-
except Exception as e:
|
|
397
|
-
raise ValueError(f"Failed to convert scikit-learn model {model_name} to ONNX: {str(e)}")
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
# we do not need to pass in both the model name and the dataset name, since the model name should the same as the virtual dataset name
|
|
401
|
-
# but we are gonna have multiple products for the same virtual dataset
|
|
402
|
-
# @require_api_key
|
|
403
|
-
# async def upload_and_deploy_cnn_model(self, model, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None, processing_script_path: Optional[str] = None):
|
|
404
|
-
# """
|
|
405
|
-
# Upload a CNN model to the bucket and deploy it.
|
|
406
|
-
|
|
407
|
-
# Args:
|
|
408
|
-
# model: The model object (PyTorch model or scikit-learn model)
|
|
409
|
-
# model_name: Name for the model (without extension)
|
|
410
|
-
# dataset: Name of the dataset to create
|
|
411
|
-
# product: Product name for the inference
|
|
412
|
-
# input_expression: Input expression for the dataset
|
|
413
|
-
# dates_iso8601: List of dates in ISO8601 format
|
|
414
|
-
# input_shape: Shape of input data for ONNX conversion (required for PyTorch models)
|
|
415
|
-
# processing_script_path: Path to the processing script, if not provided, no processing will be done
|
|
416
|
-
|
|
417
|
-
# Raises:
|
|
418
|
-
# APIError: If the API request fails
|
|
419
|
-
# ValueError: If model type is not supported or input_shape is missing for PyTorch models
|
|
420
|
-
# ImportError: If required libraries (torch or skl2onnx) are not installed
|
|
421
|
-
# """
|
|
422
|
-
# await self.upload_model(model=model, model_name=dataset, input_shape=input_shape)
|
|
423
|
-
# # so the uploading process is kinda similar, but the deployment step is kinda different
|
|
424
|
-
# # we should pass the processing script path to the deploy cnn model function
|
|
425
|
-
# await self.deploy_cnn_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601, processing_script_path=processing_script_path)
|
|
185
|
+
headers = {
|
|
186
|
+
"Content-Type": "application/octet-stream",
|
|
187
|
+
"Content-Length": str(len(model))
|
|
188
|
+
}
|
|
189
|
+
response = await self._client._regular_request("PUT", endpoint = upload_model_url, data=model, headers=headers)
|
|
190
|
+
return response
|
|
191
|
+
|
|
192
|
+
@require_api_key
|
|
193
|
+
async def _upload_script_to_url(self, upload_script_url: str, script_content: str):
|
|
194
|
+
"""
|
|
195
|
+
Upload the generated script to the url
|
|
196
|
+
Args:
|
|
197
|
+
url: Url for the upload of the script
|
|
198
|
+
script_content: Content of the script
|
|
199
|
+
returns:
|
|
200
|
+
None
|
|
201
|
+
"""
|
|
202
|
+
script_bytes = script_content.encode('utf-8')
|
|
203
|
+
headers = {
|
|
204
|
+
"Content-Type": "text/x-python",
|
|
205
|
+
"Content-Length": str(len(script_bytes))
|
|
206
|
+
}
|
|
207
|
+
response = await self._client._regular_request("PUT", endpoint=upload_script_url, data=script_bytes, headers=headers)
|
|
208
|
+
return response
|
|
426
209
|
|
|
427
210
|
@require_api_key
|
|
428
|
-
async def
|
|
211
|
+
async def _upload_model_and_script(self, model, model_name: str, script_name: str, input_expression: str, input_shape: Tuple[int, ...] = None, processing_script_path: Optional[str] = None, model_type: Optional[str] = None):
|
|
429
212
|
"""
|
|
430
|
-
Upload a
|
|
431
|
-
|
|
213
|
+
Upload a model and script to the bucket
|
|
432
214
|
Args:
|
|
433
215
|
model: The model object (PyTorch model or scikit-learn model)
|
|
434
216
|
model_name: Name for the model (without extension)
|
|
435
|
-
|
|
436
|
-
product: Product name for the inference
|
|
217
|
+
script_name: Name for the script (without extension)
|
|
437
218
|
input_expression: Input expression for the dataset
|
|
438
|
-
dates_iso8601: List of dates in ISO8601 format
|
|
439
219
|
input_shape: Shape of input data for ONNX conversion (required for PyTorch models)
|
|
440
220
|
processing_script_path: Path to the processing script, if not provided, no processing will be done
|
|
441
|
-
|
|
221
|
+
model_type: The type of the model we want to upload
|
|
442
222
|
Raises:
|
|
443
223
|
APIError: If the API request fails
|
|
444
224
|
ValueError: If model type is not supported or input_shape is missing for PyTorch models
|
|
445
|
-
ImportError: If required libraries (torch or skl2onnx) are not installed
|
|
446
|
-
"""
|
|
447
|
-
await self.upload_model(model=model, model_name=model_name, input_shape=input_shape)
|
|
448
|
-
# so the uploading process is kinda similar, but the deployment step is kinda different
|
|
449
|
-
# we should pass the processing script path to the deploy cnn model function
|
|
450
|
-
await self.deploy_cnn_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601, processing_script_path=processing_script_path)
|
|
451
225
|
|
|
226
|
+
Returns:
|
|
227
|
+
bucket_name: Name of the bucket where the model is stored
|
|
228
|
+
"""
|
|
229
|
+
response = await self._get_url_for_upload_model_and_script(expression = input_expression, model_name = model_name, script_name = script_name)
|
|
230
|
+
model_url, script_url, bucket_name = response.get("model_upload_url"), response.get("script_upload_url"), response.get("bucket_name")
|
|
231
|
+
if not model_url or not script_url:
|
|
232
|
+
raise ValueError("No url returned from the server for the upload process")
|
|
233
|
+
try:
|
|
234
|
+
model_in_onnx_bytes, model_type = self._convert_model_to_onnx(model = model, input_shape = input_shape, model_type = model_type)
|
|
235
|
+
if model_type == "neural_network":
|
|
236
|
+
script_content = await self._generate_cnn_script(bucket_name = bucket_name, virtual_dataset_name = model_name, virtual_product_name = script_name, processing_script_path = processing_script_path)
|
|
237
|
+
elif model_type == "random_forest":
|
|
238
|
+
script_content = await self._generate_random_forest_script(bucket_name = bucket_name, virtual_dataset_name = model_name, virtual_product_name = script_name, processing_script_path = processing_script_path)
|
|
239
|
+
else:
|
|
240
|
+
raise ValueError(f"Unsupported model type: {model_type}. Supported types: neural_network, random_forest")
|
|
241
|
+
script_upload_response = await self._upload_script_to_url( upload_script_url = script_url, script_content = script_content)
|
|
242
|
+
if script_upload_response.status not in [200, 201, 204]:
|
|
243
|
+
self._client.logger.error(f"Script upload error: {script_upload_response.text()}")
|
|
244
|
+
raise Exception(f"Failed to upload script: {script_upload_response.text()}")
|
|
245
|
+
model_upload_response = await self._upload_model_to_url(upload_model_url = model_url, model = model_in_onnx_bytes)
|
|
246
|
+
if model_upload_response.status not in [200, 201, 204]:
|
|
247
|
+
self._client.logger.error(f"Model upload error: {model_upload_response.text()}")
|
|
248
|
+
raise Exception(f"Failed to upload model: {model_upload_response.text()}")
|
|
249
|
+
except Exception as e:
|
|
250
|
+
raise Exception(f"Error uploading model: {e}")
|
|
251
|
+
self._client.logger.info(f"Model and Script uploaded successfully to {model_url}")
|
|
252
|
+
return bucket_name
|
|
452
253
|
|
|
453
254
|
@require_api_key
|
|
454
|
-
async def upload_and_deploy_model(self, model,
|
|
255
|
+
async def upload_and_deploy_model(self, model, virtual_dataset_name: str, virtual_product_name: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None, processing_script_path: Optional[str] = None, model_type: Optional[str] = None):
|
|
455
256
|
"""
|
|
456
257
|
Upload a model to the bucket and deploy it.
|
|
457
|
-
|
|
458
258
|
Args:
|
|
459
259
|
model: The model object (PyTorch model or scikit-learn model)
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
product: Product name for the inference
|
|
260
|
+
virtual_dataset_name: Name for the virtual dataset (without extension)
|
|
261
|
+
virtual_product_name: Product name for the inference
|
|
463
262
|
input_expression: Input expression for the dataset
|
|
464
263
|
dates_iso8601: List of dates in ISO8601 format
|
|
465
264
|
input_shape: Shape of input data for ONNX conversion (required for PyTorch models)
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
await self.deploy_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601)
|
|
265
|
+
processing_script_path: Path to the processing script, if not provided, no processing will be done
|
|
266
|
+
model_type: The type of the model we want to upload
|
|
469
267
|
|
|
470
|
-
@require_api_key
|
|
471
|
-
def train_model(
|
|
472
|
-
self,
|
|
473
|
-
model_name: str,
|
|
474
|
-
training_dataset: str,
|
|
475
|
-
task_type: str,
|
|
476
|
-
model_category: str,
|
|
477
|
-
architecture: str,
|
|
478
|
-
region: str,
|
|
479
|
-
hyperparameters: dict = None
|
|
480
|
-
) -> dict:
|
|
481
|
-
"""
|
|
482
|
-
Train a model using the external model training API.
|
|
483
|
-
|
|
484
|
-
Args:
|
|
485
|
-
model_name (str): The name of the model to train.
|
|
486
|
-
training_dataset (str): The training dataset identifier.
|
|
487
|
-
task_type (str): The type of ML task (e.g., regression, classification).
|
|
488
|
-
model_category (str): The category of model (e.g., random_forest).
|
|
489
|
-
architecture (str): The model architecture.
|
|
490
|
-
region (str): The region identifier.
|
|
491
|
-
hyperparameters (dict, optional): Additional hyperparameters for training.
|
|
492
|
-
|
|
493
|
-
Returns:
|
|
494
|
-
dict: The response from the model training API.
|
|
495
|
-
|
|
496
268
|
Raises:
|
|
497
269
|
APIError: If the API request fails
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
"model_name": model_name,
|
|
501
|
-
"training_dataset": training_dataset,
|
|
502
|
-
"task_type": task_type,
|
|
503
|
-
"model_category": model_category,
|
|
504
|
-
"architecture": architecture,
|
|
505
|
-
"region": region,
|
|
506
|
-
"hyperparameters": hyperparameters
|
|
507
|
-
}
|
|
508
|
-
return self._client._terrakio_request("POST", "/train_model", json=payload)
|
|
270
|
+
ValueError: If model type is not supported or input_shape is missing for PyTorch models
|
|
271
|
+
ImportError: If required libraries (torch or skl2onnx) are not installed
|
|
509
272
|
|
|
510
|
-
@require_api_key
|
|
511
|
-
async def deploy_model(
|
|
512
|
-
self,
|
|
513
|
-
dataset: str,
|
|
514
|
-
product: str,
|
|
515
|
-
model_name: str,
|
|
516
|
-
input_expression: str,
|
|
517
|
-
model_training_job_name: str,
|
|
518
|
-
dates_iso8601: list
|
|
519
|
-
) -> Dict[str, Any]:
|
|
520
|
-
"""
|
|
521
|
-
Deploy a model by generating inference script and creating dataset.
|
|
522
|
-
|
|
523
|
-
Args:
|
|
524
|
-
dataset: Name of the dataset to create
|
|
525
|
-
product: Product name for the inference
|
|
526
|
-
model_name: Name of the trained model
|
|
527
|
-
input_expression: Input expression for the dataset
|
|
528
|
-
model_training_job_name: Name of the training job
|
|
529
|
-
dates_iso8601: List of dates in ISO8601 format
|
|
530
|
-
|
|
531
273
|
Returns:
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
Raises:
|
|
535
|
-
APIError: If the API request fails
|
|
274
|
+
None
|
|
536
275
|
"""
|
|
537
|
-
|
|
276
|
+
bucket_name = await self._upload_model_and_script(model=model, model_name=virtual_dataset_name, script_name= virtual_product_name, input_shape=input_shape, input_expression=input_expression, processing_script_path=processing_script_path, model_type= model_type)
|
|
538
277
|
user_info = await self._client.auth.get_user_info()
|
|
539
278
|
uid = user_info["uid"]
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
script_content = self._generate_script(model_name, product, model_training_job_name, uid)
|
|
543
|
-
script_name = f"{product}.py"
|
|
544
|
-
self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
|
|
545
|
-
|
|
546
|
-
# Create dataset
|
|
547
|
-
return await self._client.datasets.create_dataset(
|
|
548
|
-
name=dataset,
|
|
279
|
+
await self._client.datasets.create_dataset(
|
|
280
|
+
name=virtual_dataset_name,
|
|
549
281
|
collection="terrakio-datasets",
|
|
550
|
-
products=[
|
|
551
|
-
path=f"gs://
|
|
282
|
+
products=[virtual_product_name],
|
|
283
|
+
path=f"gs://{bucket_name}/{uid}/virtual_datasets/{virtual_dataset_name}/inference_scripts",
|
|
552
284
|
input=input_expression,
|
|
553
285
|
dates_iso8601=dates_iso8601,
|
|
554
286
|
padding=0
|
|
555
287
|
)
|
|
556
288
|
|
|
557
|
-
|
|
289
|
+
@require_api_key
|
|
290
|
+
async def _generate_random_forest_script(self, bucket_name: str, virtual_dataset_name: str, virtual_product_name: str, processing_script_path: Optional[str] = None) -> str:
|
|
558
291
|
"""
|
|
559
|
-
|
|
292
|
+
Generate Python inference script for the Random Forest model.
|
|
560
293
|
|
|
561
294
|
Args:
|
|
562
|
-
|
|
295
|
+
bucket_name: Name of the bucket where the model is stored
|
|
296
|
+
virtual_dataset_name: Name of the virtual dataset and the model
|
|
297
|
+
virtual_product_name: Name of the virtual product
|
|
298
|
+
processing_script_path: Path to the processing script, if not provided, no processing will be done
|
|
563
299
|
|
|
564
300
|
Returns:
|
|
565
|
-
|
|
301
|
+
str: Generated Python script content
|
|
566
302
|
"""
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
303
|
+
user_info = await self._client.auth.get_user_info()
|
|
304
|
+
uid = user_info["uid"]
|
|
305
|
+
preprocessing_code, postprocessing_code = None, None
|
|
306
|
+
|
|
307
|
+
if processing_script_path:
|
|
308
|
+
try:
|
|
309
|
+
preprocessing_code, postprocessing_code = self._parse_processing_script(processing_script_path)
|
|
310
|
+
if preprocessing_code:
|
|
311
|
+
self._client.logger.info(f"Using custom preprocessing from: {processing_script_path}")
|
|
312
|
+
if postprocessing_code:
|
|
313
|
+
self._client.logger.info(f"Using custom postprocessing from: {processing_script_path}")
|
|
314
|
+
if not preprocessing_code and not postprocessing_code:
|
|
315
|
+
self._client.logger.warning(f"No preprocessing or postprocessing functions found in {processing_script_path}")
|
|
316
|
+
self._client.logger.info("Deployment will continue without custom processing")
|
|
317
|
+
except Exception as e:
|
|
318
|
+
raise ValueError(f"Failed to load processing script: {str(e)}")
|
|
319
|
+
|
|
320
|
+
preprocessing_section = ""
|
|
321
|
+
if preprocessing_code and preprocessing_code.strip():
|
|
322
|
+
clean_preprocessing = textwrap.dedent(preprocessing_code)
|
|
323
|
+
preprocessing_section = textwrap.indent(clean_preprocessing, ' ')
|
|
585
324
|
|
|
586
|
-
|
|
587
|
-
postprocessing_code
|
|
325
|
+
postprocessing_section = ""
|
|
326
|
+
if postprocessing_code and postprocessing_code.strip():
|
|
327
|
+
clean_postprocessing = textwrap.dedent(postprocessing_code)
|
|
328
|
+
postprocessing_section = textwrap.indent(clean_postprocessing, ' ')
|
|
329
|
+
|
|
330
|
+
script_lines = [
|
|
331
|
+
"import logging",
|
|
332
|
+
"from io import BytesIO",
|
|
333
|
+
"import numpy as np",
|
|
334
|
+
"import pandas as pd",
|
|
335
|
+
"import xarray as xr",
|
|
336
|
+
"from google.cloud import storage",
|
|
337
|
+
"from onnxruntime import InferenceSession",
|
|
338
|
+
"from typing import Tuple",
|
|
339
|
+
"",
|
|
340
|
+
"logging.basicConfig(",
|
|
341
|
+
" level=logging.INFO",
|
|
342
|
+
")",
|
|
343
|
+
"",
|
|
344
|
+
]
|
|
588
345
|
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
346
|
+
if preprocessing_section:
|
|
347
|
+
script_lines.extend([
|
|
348
|
+
"def validate_preprocessing_output(data_arrays):",
|
|
349
|
+
" \"\"\"",
|
|
350
|
+
" Validate preprocessing output coordinates and data type.",
|
|
351
|
+
" ",
|
|
352
|
+
" Args:",
|
|
353
|
+
" data_arrays: List of xarray DataArrays from preprocessing",
|
|
354
|
+
" ",
|
|
355
|
+
" Returns:",
|
|
356
|
+
" str: Validation signature symbol",
|
|
357
|
+
" ",
|
|
358
|
+
" Raises:",
|
|
359
|
+
" ValueError: If validation fails",
|
|
360
|
+
" \"\"\"",
|
|
361
|
+
" import numpy as np",
|
|
362
|
+
" ",
|
|
363
|
+
" if not data_arrays:",
|
|
364
|
+
" raise ValueError(\"No data arrays provided from preprocessing\")",
|
|
365
|
+
" ",
|
|
366
|
+
" reference_shape = None",
|
|
367
|
+
" ",
|
|
368
|
+
" for i, data_array in enumerate(data_arrays):",
|
|
369
|
+
" # Check if it's an xarray DataArray",
|
|
370
|
+
" if not hasattr(data_array, 'dims') or not hasattr(data_array, 'coords'):",
|
|
371
|
+
" raise ValueError(f\"Channel {i+1} is not a valid xarray DataArray\")",
|
|
372
|
+
" ",
|
|
373
|
+
" # Check coordinates",
|
|
374
|
+
" if 'time' not in data_array.coords:",
|
|
375
|
+
" raise ValueError(f\"Channel {i+1} missing time coordinate\")",
|
|
376
|
+
" ",
|
|
377
|
+
" spatial_dims = [dim for dim in data_array.dims if dim != 'time']",
|
|
378
|
+
" if len(spatial_dims) != 2:",
|
|
379
|
+
" raise ValueError(f\"Channel {i+1} must have exactly 2 spatial dimensions, got {spatial_dims}\")",
|
|
380
|
+
" ",
|
|
381
|
+
" for dim in spatial_dims:",
|
|
382
|
+
" if dim not in data_array.coords:",
|
|
383
|
+
" raise ValueError(f\"Channel {i+1} missing coordinate: {dim}\")",
|
|
384
|
+
" ",
|
|
385
|
+
" # Check shape consistency",
|
|
386
|
+
" shape = data_array.shape",
|
|
387
|
+
" if reference_shape is None:",
|
|
388
|
+
" reference_shape = shape",
|
|
389
|
+
" else:",
|
|
390
|
+
" if shape != reference_shape:",
|
|
391
|
+
" raise ValueError(f\"Channel {i+1} shape {shape} doesn't match reference {reference_shape}\")",
|
|
392
|
+
" ",
|
|
393
|
+
" # Generate validation signature",
|
|
394
|
+
" signature_components = [",
|
|
395
|
+
" f\"CH{len(data_arrays)}\", # Channel count",
|
|
396
|
+
" f\"T{reference_shape[0]}\", # Time dimension",
|
|
397
|
+
" f\"S{reference_shape[1]}x{reference_shape[2]}\", # Spatial dimensions",
|
|
398
|
+
" f\"DT{data_arrays[0].values.dtype}\", # Data type",
|
|
399
|
+
" ]",
|
|
400
|
+
" ",
|
|
401
|
+
" signature = \"★PRE_\" + \"_\".join(signature_components) + \"★\"",
|
|
402
|
+
" ",
|
|
403
|
+
" return signature",
|
|
404
|
+
"",
|
|
405
|
+
])
|
|
598
406
|
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
407
|
+
if postprocessing_section:
|
|
408
|
+
script_lines.extend([
|
|
409
|
+
"def validate_postprocessing_output(result_array):",
|
|
410
|
+
" \"\"\"",
|
|
411
|
+
" Validate postprocessing output coordinates and data type.",
|
|
412
|
+
" ",
|
|
413
|
+
" Args:",
|
|
414
|
+
" result_array: xarray DataArray from postprocessing",
|
|
415
|
+
" ",
|
|
416
|
+
" Returns:",
|
|
417
|
+
" str: Validation signature symbol",
|
|
418
|
+
" ",
|
|
419
|
+
" Raises:",
|
|
420
|
+
" ValueError: If validation fails",
|
|
421
|
+
" \"\"\"",
|
|
422
|
+
" import numpy as np",
|
|
423
|
+
" ",
|
|
424
|
+
" # Check if it's an xarray DataArray",
|
|
425
|
+
" if not hasattr(result_array, 'dims') or not hasattr(result_array, 'coords'):",
|
|
426
|
+
" raise ValueError(\"Postprocessing output is not a valid xarray DataArray\")",
|
|
427
|
+
" ",
|
|
428
|
+
" # Check required coordinates",
|
|
429
|
+
" if 'time' not in result_array.coords:",
|
|
430
|
+
" raise ValueError(\"Missing time coordinate\")",
|
|
431
|
+
" ",
|
|
432
|
+
" spatial_dims = [dim for dim in result_array.dims if dim != 'time']",
|
|
433
|
+
" if len(spatial_dims) != 2:",
|
|
434
|
+
" raise ValueError(f\"Expected 2 spatial dimensions, got {len(spatial_dims)}: {spatial_dims}\")",
|
|
435
|
+
" ",
|
|
436
|
+
" for dim in spatial_dims:",
|
|
437
|
+
" if dim not in result_array.coords:",
|
|
438
|
+
" raise ValueError(f\"Missing spatial coordinate: {dim}\")",
|
|
439
|
+
" ",
|
|
440
|
+
" # Check shape",
|
|
441
|
+
" shape = result_array.shape",
|
|
442
|
+
" ",
|
|
443
|
+
" # Generate validation signature",
|
|
444
|
+
" signature_components = [",
|
|
445
|
+
" f\"T{shape[0]}\", # Time dimension",
|
|
446
|
+
" f\"S{shape[1]}x{shape[2]}\", # Spatial dimensions",
|
|
447
|
+
" f\"DT{result_array.values.dtype}\", # Data type",
|
|
448
|
+
" ]",
|
|
449
|
+
" ",
|
|
450
|
+
" signature = \"★POST_\" + \"_\".join(signature_components) + \"★\"",
|
|
451
|
+
" ",
|
|
452
|
+
" return signature",
|
|
453
|
+
"",
|
|
454
|
+
])
|
|
609
455
|
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
456
|
+
if preprocessing_section:
|
|
457
|
+
script_lines.extend([
|
|
458
|
+
"def preprocessing(array: Tuple[xr.DataArray, ...]) -> Tuple[xr.DataArray, ...]:",
|
|
459
|
+
preprocessing_section,
|
|
460
|
+
"",
|
|
461
|
+
])
|
|
615
462
|
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
463
|
+
if postprocessing_section:
|
|
464
|
+
script_lines.extend([
|
|
465
|
+
"def postprocessing(array: xr.DataArray) -> xr.DataArray:",
|
|
466
|
+
postprocessing_section,
|
|
467
|
+
"",
|
|
468
|
+
])
|
|
619
469
|
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
470
|
+
script_lines.extend([
|
|
471
|
+
"def get_model():",
|
|
472
|
+
f" logging.info(\"Loading Random Forest model for {virtual_dataset_name}...\")",
|
|
473
|
+
"",
|
|
474
|
+
" client = storage.Client()",
|
|
475
|
+
f" bucket = client.get_bucket('{bucket_name}')",
|
|
476
|
+
f" blob = bucket.blob('{uid}/virtual_datasets/{virtual_dataset_name}/{virtual_dataset_name}.onnx')",
|
|
477
|
+
"",
|
|
478
|
+
" model = BytesIO()",
|
|
479
|
+
" blob.download_to_file(model)",
|
|
480
|
+
" model.seek(0)",
|
|
481
|
+
"",
|
|
482
|
+
" session = InferenceSession(model.read(), providers=[\"CPUExecutionProvider\"])",
|
|
483
|
+
" return session",
|
|
484
|
+
"",
|
|
485
|
+
f"def {virtual_product_name}(*bands, model):",
|
|
486
|
+
" logging.info(\"Start preparing Random Forest data\")",
|
|
487
|
+
" data_arrays = list(bands)",
|
|
488
|
+
" ",
|
|
489
|
+
" if not data_arrays:",
|
|
490
|
+
" raise ValueError(\"No bands provided\")",
|
|
491
|
+
" ",
|
|
492
|
+
])
|
|
625
493
|
|
|
626
|
-
if
|
|
627
|
-
|
|
494
|
+
if preprocessing_section:
|
|
495
|
+
script_lines.extend([
|
|
496
|
+
" # Apply preprocessing",
|
|
497
|
+
" data_arrays = preprocessing(tuple(data_arrays))",
|
|
498
|
+
" data_arrays = list(data_arrays) # Convert back to list for processing",
|
|
499
|
+
" ",
|
|
500
|
+
" # Validate preprocessing output",
|
|
501
|
+
" preprocessing_signature = validate_preprocessing_output(data_arrays)",
|
|
502
|
+
" ",
|
|
503
|
+
])
|
|
628
504
|
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
505
|
+
script_lines.extend([
|
|
506
|
+
" reference_array = data_arrays[0]",
|
|
507
|
+
" original_shape = reference_array.shape",
|
|
508
|
+
" ",
|
|
509
|
+
" if 'time' in reference_array.dims:",
|
|
510
|
+
" time_coords = reference_array.coords['time']",
|
|
511
|
+
" if len(time_coords) == 1:",
|
|
512
|
+
" output_timestamp = time_coords[0]",
|
|
513
|
+
" else:",
|
|
514
|
+
" years = [pd.to_datetime(t).year for t in time_coords.values]",
|
|
515
|
+
" unique_years = set(years)",
|
|
516
|
+
" ",
|
|
517
|
+
" if len(unique_years) == 1:",
|
|
518
|
+
" year = list(unique_years)[0]",
|
|
519
|
+
" output_timestamp = pd.Timestamp(f\"{year}-01-01\")",
|
|
520
|
+
" else:",
|
|
521
|
+
" latest_year = max(unique_years)",
|
|
522
|
+
" output_timestamp = pd.Timestamp(f\"{latest_year}-01-01\")",
|
|
523
|
+
" else:",
|
|
524
|
+
" output_timestamp = pd.Timestamp(\"1970-01-01\")",
|
|
525
|
+
"",
|
|
526
|
+
" averaged_bands = []",
|
|
527
|
+
" for data_array in data_arrays:",
|
|
528
|
+
" if 'time' in data_array.dims:",
|
|
529
|
+
" averaged_band = np.mean(data_array.values, axis=0)",
|
|
530
|
+
" else:",
|
|
531
|
+
" averaged_band = data_array.values",
|
|
532
|
+
"",
|
|
533
|
+
" flattened_band = averaged_band.reshape(-1, 1)",
|
|
534
|
+
" averaged_bands.append(flattened_band)",
|
|
535
|
+
"",
|
|
536
|
+
" input_data = np.hstack(averaged_bands)",
|
|
537
|
+
"",
|
|
538
|
+
" output = model.run(None, {\"float_input\": input_data.astype(np.float32)})[0]",
|
|
539
|
+
"",
|
|
540
|
+
" if len(original_shape) >= 3:",
|
|
541
|
+
" spatial_shape = original_shape[1:]",
|
|
542
|
+
" else:",
|
|
543
|
+
" spatial_shape = original_shape",
|
|
544
|
+
"",
|
|
545
|
+
" output_reshaped = output.reshape(spatial_shape)",
|
|
546
|
+
"",
|
|
547
|
+
" output_with_time = np.expand_dims(output_reshaped, axis=0)",
|
|
548
|
+
"",
|
|
549
|
+
" if 'time' in reference_array.dims:",
|
|
550
|
+
" spatial_dims = [dim for dim in reference_array.dims if dim != 'time']",
|
|
551
|
+
" spatial_coords = {dim: reference_array.coords[dim] for dim in spatial_dims if dim in reference_array.coords}",
|
|
552
|
+
" else:",
|
|
553
|
+
" spatial_dims = list(reference_array.dims)",
|
|
554
|
+
" spatial_coords = dict(reference_array.coords)",
|
|
555
|
+
"",
|
|
556
|
+
" result = xr.DataArray(",
|
|
557
|
+
" data=output_with_time.astype(np.float32),",
|
|
558
|
+
" dims=['time'] + list(spatial_dims),",
|
|
559
|
+
" coords={",
|
|
560
|
+
" 'time': [output_timestamp.values],",
|
|
561
|
+
" 'y': spatial_coords['y'].values,",
|
|
562
|
+
" 'x': spatial_coords['x'].values",
|
|
563
|
+
" },",
|
|
564
|
+
" attrs={",
|
|
565
|
+
" 'description': 'Random Forest model prediction',",
|
|
566
|
+
" }",
|
|
567
|
+
" )",
|
|
568
|
+
])
|
|
632
569
|
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
570
|
+
if postprocessing_section:
|
|
571
|
+
script_lines.extend([
|
|
572
|
+
" # Apply postprocessing",
|
|
573
|
+
" result = postprocessing(result)",
|
|
574
|
+
" ",
|
|
575
|
+
" # Validate postprocessing output",
|
|
576
|
+
" postprocessing_signature = validate_postprocessing_output(result)",
|
|
577
|
+
" ",
|
|
578
|
+
])
|
|
636
579
|
|
|
637
|
-
return
|
|
580
|
+
script_lines.append(" return result")
|
|
638
581
|
|
|
582
|
+
return "\n".join(script_lines)
|
|
583
|
+
|
|
639
584
|
@require_api_key
|
|
640
|
-
async def
|
|
641
|
-
self,
|
|
642
|
-
dataset: str,
|
|
643
|
-
product: str,
|
|
644
|
-
model_name: str,
|
|
645
|
-
input_expression: str,
|
|
646
|
-
model_training_job_name: str,
|
|
647
|
-
dates_iso8601: list,
|
|
648
|
-
processing_script_path: Optional[str] = None
|
|
649
|
-
) -> Dict[str, Any]:
|
|
585
|
+
async def _generate_cnn_script(self, bucket_name: str, virtual_dataset_name: str, virtual_product_name: str, processing_script_path: Optional[str] = None) -> str:
|
|
650
586
|
"""
|
|
651
|
-
|
|
587
|
+
Generate Python inference script for CNN model with time-stacked bands.
|
|
652
588
|
|
|
653
589
|
Args:
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
input_expression: Input expression for the dataset
|
|
658
|
-
model_training_job_name: Name of the training job
|
|
659
|
-
dates_iso8601: List of dates in ISO8601 format
|
|
590
|
+
bucket_name: Name of the bucket where the model is stored
|
|
591
|
+
virtual_dataset_name: Name of the virtual dataset and the model
|
|
592
|
+
virtual_product_name: Name of the virtual product
|
|
660
593
|
processing_script_path: Path to the processing script, if not provided, no processing will be done
|
|
661
594
|
Returns:
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
Raises:
|
|
665
|
-
APIError: If the API request fails
|
|
595
|
+
str: Generated Python script content
|
|
666
596
|
"""
|
|
667
|
-
# Get user info to get UID
|
|
668
597
|
user_info = await self._client.auth.get_user_info()
|
|
669
598
|
uid = user_info["uid"]
|
|
670
|
-
|
|
671
599
|
preprocessing_code, postprocessing_code = None, None
|
|
600
|
+
|
|
672
601
|
if processing_script_path:
|
|
673
|
-
# if there is a function that is being passed in
|
|
674
602
|
try:
|
|
675
603
|
preprocessing_code, postprocessing_code = self._parse_processing_script(processing_script_path)
|
|
676
604
|
if preprocessing_code:
|
|
@@ -682,415 +610,17 @@ class ModelManagement:
|
|
|
682
610
|
self._client.logger.info("Deployment will continue without custom processing")
|
|
683
611
|
except Exception as e:
|
|
684
612
|
raise ValueError(f"Failed to load processing script: {str(e)}")
|
|
685
|
-
# so we already have the preprocessing code and the post processing code, I need to pass them to the generate cnn script function
|
|
686
|
-
# Generate and upload script
|
|
687
|
-
# Build preprocessing section with CONSISTENT 8-space indentation
|
|
688
|
-
preprocessing_section = ""
|
|
689
|
-
if preprocessing_code and preprocessing_code.strip():
|
|
690
|
-
# First dedent the preprocessing code to remove any existing indentation
|
|
691
|
-
clean_preprocessing = preprocessing_code
|
|
692
|
-
# Then add consistent 8-space indentation to match the template
|
|
693
|
-
preprocessing_section = f"""{textwrap.indent(clean_preprocessing, '')}""" # 8 spaces
|
|
694
|
-
# print(preprocessing_section)
|
|
695
|
-
script_content = self.generate_cnn_script(model_name, product, model_training_job_name, uid, preprocessing_code, postprocessing_code)
|
|
696
|
-
script_name = f"{product}.py"
|
|
697
|
-
self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
|
|
698
|
-
# Create dataset
|
|
699
|
-
return await self._client.datasets.create_dataset(
|
|
700
|
-
name=dataset,
|
|
701
|
-
collection="terrakio-datasets",
|
|
702
|
-
products=[product],
|
|
703
|
-
path=f"gs://terrakio-mass-requests/{uid}/{model_training_job_name}/inference_scripts",
|
|
704
|
-
input=input_expression,
|
|
705
|
-
dates_iso8601=dates_iso8601,
|
|
706
|
-
padding=0
|
|
707
|
-
)
|
|
708
|
-
|
|
709
|
-
@require_api_key
|
|
710
|
-
def _generate_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
|
|
711
|
-
"""
|
|
712
|
-
Generate Python inference script for the model.
|
|
713
|
-
|
|
714
|
-
Args:
|
|
715
|
-
model_name: Name of the model
|
|
716
|
-
product: Product name
|
|
717
|
-
model_training_job_name: Training job name
|
|
718
|
-
uid: User ID
|
|
719
|
-
|
|
720
|
-
Returns:
|
|
721
|
-
str: Generated Python script content
|
|
722
|
-
"""
|
|
723
|
-
return textwrap.dedent(f'''
|
|
724
|
-
import logging
|
|
725
|
-
from io import BytesIO
|
|
726
|
-
|
|
727
|
-
import numpy as np
|
|
728
|
-
import pandas as pd
|
|
729
|
-
import xarray as xr
|
|
730
|
-
from google.cloud import storage
|
|
731
|
-
from onnxruntime import InferenceSession
|
|
732
|
-
|
|
733
|
-
logging.basicConfig(
|
|
734
|
-
level=logging.INFO
|
|
735
|
-
)
|
|
736
|
-
|
|
737
|
-
def get_model():
|
|
738
|
-
logging.info("Loading model for {model_name}...")
|
|
739
|
-
|
|
740
|
-
client = storage.Client()
|
|
741
|
-
bucket = client.get_bucket('terrakio-mass-requests')
|
|
742
|
-
blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')
|
|
743
|
-
|
|
744
|
-
model = BytesIO()
|
|
745
|
-
blob.download_to_file(model)
|
|
746
|
-
model.seek(0)
|
|
747
|
-
|
|
748
|
-
session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
|
|
749
|
-
return session
|
|
750
|
-
|
|
751
|
-
def {product}(*bands, model):
|
|
752
|
-
logging.info("start preparing data")
|
|
753
|
-
|
|
754
|
-
data_arrays = list(bands)
|
|
755
|
-
|
|
756
|
-
reference_array = data_arrays[0]
|
|
757
|
-
original_shape = reference_array.shape
|
|
758
|
-
logging.info(f"Original shape: {{original_shape}}")
|
|
759
613
|
|
|
760
|
-
if 'time' in reference_array.dims:
|
|
761
|
-
time_coords = reference_array.coords['time']
|
|
762
|
-
if len(time_coords) == 1:
|
|
763
|
-
output_timestamp = time_coords[0]
|
|
764
|
-
else:
|
|
765
|
-
years = [pd.to_datetime(t).year for t in time_coords.values]
|
|
766
|
-
unique_years = set(years)
|
|
767
|
-
|
|
768
|
-
if len(unique_years) == 1:
|
|
769
|
-
year = list(unique_years)[0]
|
|
770
|
-
output_timestamp = pd.Timestamp(f"{{year}}-01-01")
|
|
771
|
-
else:
|
|
772
|
-
latest_year = max(unique_years)
|
|
773
|
-
output_timestamp = pd.Timestamp(f"{{latest_year}}-01-01")
|
|
774
|
-
else:
|
|
775
|
-
output_timestamp = pd.Timestamp("1970-01-01")
|
|
776
|
-
|
|
777
|
-
averaged_bands = []
|
|
778
|
-
for data_array in data_arrays:
|
|
779
|
-
if 'time' in data_array.dims:
|
|
780
|
-
averaged_band = np.mean(data_array.values, axis=0)
|
|
781
|
-
logging.info(f"Averaged band from {{data_array.shape}} to {{averaged_band.shape}}")
|
|
782
|
-
else:
|
|
783
|
-
averaged_band = data_array.values
|
|
784
|
-
logging.info(f"No time dimension, shape: {{averaged_band.shape}}")
|
|
785
|
-
|
|
786
|
-
flattened_band = averaged_band.reshape(-1, 1)
|
|
787
|
-
averaged_bands.append(flattened_band)
|
|
788
|
-
|
|
789
|
-
input_data = np.hstack(averaged_bands)
|
|
790
|
-
|
|
791
|
-
logging.info(f"Final input shape: {{input_data.shape}}")
|
|
792
|
-
|
|
793
|
-
output = model.run(None, {{"float_input": input_data.astype(np.float32)}})[0]
|
|
794
|
-
|
|
795
|
-
logging.info(f"Model output shape: {{output.shape}}")
|
|
796
|
-
|
|
797
|
-
if len(original_shape) >= 3:
|
|
798
|
-
spatial_shape = original_shape[1:]
|
|
799
|
-
else:
|
|
800
|
-
spatial_shape = original_shape
|
|
801
|
-
|
|
802
|
-
output_reshaped = output.reshape(spatial_shape)
|
|
803
|
-
|
|
804
|
-
output_with_time = np.expand_dims(output_reshaped, axis=0)
|
|
805
|
-
|
|
806
|
-
if 'time' in reference_array.dims:
|
|
807
|
-
spatial_dims = [dim for dim in reference_array.dims if dim != 'time']
|
|
808
|
-
spatial_coords = {{dim: reference_array.coords[dim] for dim in spatial_dims if dim in reference_array.coords}}
|
|
809
|
-
else:
|
|
810
|
-
spatial_dims = list(reference_array.dims)
|
|
811
|
-
spatial_coords = dict(reference_array.coords)
|
|
812
|
-
|
|
813
|
-
result = xr.DataArray(
|
|
814
|
-
data=output_with_time.astype(np.float32),
|
|
815
|
-
dims=['time'] + list(spatial_dims),
|
|
816
|
-
coords={{
|
|
817
|
-
'time': [output_timestamp.values],
|
|
818
|
-
'y': spatial_coords['y'].values,
|
|
819
|
-
'x': spatial_coords['x'].values
|
|
820
|
-
}}
|
|
821
|
-
)
|
|
822
|
-
return result
|
|
823
|
-
''').strip()
|
|
824
|
-
|
|
825
|
-
# @require_api_key
|
|
826
|
-
# def generate_cnn_script(self, model_name: str, product: str, model_training_job_name: str, uid: str, preprocessing_code: Optional[str] = None, postprocessing_code: Optional[str] = None) -> str:
|
|
827
|
-
# """
|
|
828
|
-
# Generate Python inference script for CNN model with time-stacked bands.
|
|
829
|
-
|
|
830
|
-
# Args:
|
|
831
|
-
# model_name: Name of the model
|
|
832
|
-
# product: Product name
|
|
833
|
-
# model_training_job_name: Training job name
|
|
834
|
-
# uid: User ID
|
|
835
|
-
# preprocessing_code: Preprocessing code
|
|
836
|
-
# postprocessing_code: Postprocessing code
|
|
837
|
-
# Returns:
|
|
838
|
-
# str: Generated Python script content
|
|
839
|
-
# """
|
|
840
|
-
# import textwrap
|
|
841
|
-
|
|
842
|
-
# # Build preprocessing section with CONSISTENT 4-space indentation
|
|
843
|
-
# preprocessing_section = ""
|
|
844
|
-
# if preprocessing_code and preprocessing_code.strip():
|
|
845
|
-
# clean_preprocessing = textwrap.dedent(preprocessing_code)
|
|
846
|
-
# preprocessing_section = textwrap.indent(clean_preprocessing, ' ')
|
|
847
|
-
|
|
848
|
-
# # Build postprocessing section with CONSISTENT 4-space indentation
|
|
849
|
-
# postprocessing_section = ""
|
|
850
|
-
# if postprocessing_code and postprocessing_code.strip():
|
|
851
|
-
# clean_postprocessing = textwrap.dedent(postprocessing_code)
|
|
852
|
-
# postprocessing_section = textwrap.indent(clean_postprocessing, ' ')
|
|
853
|
-
|
|
854
|
-
# # Build the template WITHOUT dedenting the whole thing, so indentation is preserved
|
|
855
|
-
# script_lines = [
|
|
856
|
-
# "import logging",
|
|
857
|
-
# "from io import BytesIO",
|
|
858
|
-
# "import numpy as np",
|
|
859
|
-
# "import pandas as pd",
|
|
860
|
-
# "import xarray as xr",
|
|
861
|
-
# "from google.cloud import storage",
|
|
862
|
-
# "from onnxruntime import InferenceSession",
|
|
863
|
-
# "from typing import Tuple",
|
|
864
|
-
# "",
|
|
865
|
-
# "logging.basicConfig(",
|
|
866
|
-
# " level=logging.INFO",
|
|
867
|
-
# ")",
|
|
868
|
-
# "",
|
|
869
|
-
# ]
|
|
870
|
-
|
|
871
|
-
# # Add preprocessing function definition BEFORE the main function
|
|
872
|
-
# if preprocessing_section:
|
|
873
|
-
# script_lines.extend([
|
|
874
|
-
# "def preprocessing(array: Tuple[xr.DataArray, ...]) -> Tuple[xr.DataArray, ...]:",
|
|
875
|
-
# preprocessing_section,
|
|
876
|
-
# "",
|
|
877
|
-
# ])
|
|
878
|
-
|
|
879
|
-
# # Add postprocessing function definition BEFORE the main function
|
|
880
|
-
# if postprocessing_section:
|
|
881
|
-
# script_lines.extend([
|
|
882
|
-
# "def postprocessing(array: xr.DataArray) -> xr.DataArray:",
|
|
883
|
-
# postprocessing_section,
|
|
884
|
-
# "",
|
|
885
|
-
# ])
|
|
886
|
-
|
|
887
|
-
# # Add the get_model function
|
|
888
|
-
# script_lines.extend([
|
|
889
|
-
# "def get_model():",
|
|
890
|
-
# f" logging.info(\"Loading CNN model for {model_name}...\")",
|
|
891
|
-
# "",
|
|
892
|
-
# " client = storage.Client()",
|
|
893
|
-
# " bucket = client.get_bucket('terrakio-mass-requests')",
|
|
894
|
-
# f" blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')",
|
|
895
|
-
# "",
|
|
896
|
-
# " model = BytesIO()",
|
|
897
|
-
# " blob.download_to_file(model)",
|
|
898
|
-
# " model.seek(0)",
|
|
899
|
-
# "",
|
|
900
|
-
# " session = InferenceSession(model.read(), providers=[\"CPUExecutionProvider\"])",
|
|
901
|
-
# " return session",
|
|
902
|
-
# "",
|
|
903
|
-
# f"def {product}(*bands, model):",
|
|
904
|
-
# " logging.info(\"Start preparing CNN data with time-stacked bands\")",
|
|
905
|
-
# " data_arrays = list(bands)",
|
|
906
|
-
# " ",
|
|
907
|
-
# " if not data_arrays:",
|
|
908
|
-
# " raise ValueError(\"No bands provided\")",
|
|
909
|
-
# " ",
|
|
910
|
-
# ])
|
|
911
|
-
|
|
912
|
-
# # Add preprocessing call if preprocessing exists
|
|
913
|
-
# if preprocessing_section:
|
|
914
|
-
# script_lines.extend([
|
|
915
|
-
# " # Apply preprocessing",
|
|
916
|
-
# " data_arrays = preprocessing(tuple(data_arrays))",
|
|
917
|
-
# " data_arrays = list(data_arrays) # Convert back to list for processing",
|
|
918
|
-
# " ",
|
|
919
|
-
# ])
|
|
920
|
-
|
|
921
|
-
# # Continue with the rest of the processing logic
|
|
922
|
-
# script_lines.extend([
|
|
923
|
-
# " reference_array = data_arrays[0]",
|
|
924
|
-
# " original_shape = reference_array.shape",
|
|
925
|
-
# " logging.info(f\"Original shape: {original_shape}\")",
|
|
926
|
-
# " ",
|
|
927
|
-
# " # Get time coordinates - all bands should have the same time dimension",
|
|
928
|
-
# " if 'time' not in reference_array.dims:",
|
|
929
|
-
# " raise ValueError(\"Time dimension is required for CNN processing\")",
|
|
930
|
-
# " ",
|
|
931
|
-
# " time_coords = reference_array.coords['time']",
|
|
932
|
-
# " num_timestamps = len(time_coords)",
|
|
933
|
-
# " logging.info(f\"Number of timestamps: {num_timestamps}\")",
|
|
934
|
-
# " ",
|
|
935
|
-
# " # Get spatial dimensions",
|
|
936
|
-
# " spatial_dims = [dim for dim in reference_array.dims if dim != 'time']",
|
|
937
|
-
# " height = reference_array.sizes[spatial_dims[0]] # assuming first spatial dim is height",
|
|
938
|
-
# " width = reference_array.sizes[spatial_dims[1]] # assuming second spatial dim is width",
|
|
939
|
-
# " logging.info(f\"Spatial dimensions: {height} x {width}\")",
|
|
940
|
-
# " ",
|
|
941
|
-
# " # Stack bands across time dimension",
|
|
942
|
-
# " # Result will be: (num_bands * num_timestamps, height, width)",
|
|
943
|
-
# " stacked_channels = []",
|
|
944
|
-
# " ",
|
|
945
|
-
# " for band_idx, data_array in enumerate(data_arrays):",
|
|
946
|
-
# " logging.info(f\"Processing band {band_idx + 1}/{len(data_arrays)}\")",
|
|
947
|
-
# " ",
|
|
948
|
-
# " # Ensure consistent time coordinates across bands",
|
|
949
|
-
# " if not np.array_equal(data_array.coords['time'].values, time_coords.values):",
|
|
950
|
-
# " logging.warning(f\"Band {band_idx} has different time coordinates, aligning...\")",
|
|
951
|
-
# " data_array = data_array.sel(time=time_coords, method='nearest')",
|
|
952
|
-
# " ",
|
|
953
|
-
# " # Extract values and ensure proper ordering (time, height, width)",
|
|
954
|
-
# " band_values = data_array.values",
|
|
955
|
-
# " if band_values.ndim == 3:",
|
|
956
|
-
# " # Reorder dimensions if needed to ensure (time, height, width)",
|
|
957
|
-
# " time_dim_idx = data_array.dims.index('time')",
|
|
958
|
-
# " if time_dim_idx != 0:",
|
|
959
|
-
# " axes_order = [time_dim_idx] + [i for i in range(len(data_array.dims)) if i != time_dim_idx]",
|
|
960
|
-
# " band_values = np.transpose(band_values, axes_order)",
|
|
961
|
-
# " ",
|
|
962
|
-
# " # Add each timestamp of this band to the channel stack",
|
|
963
|
-
# " for t in range(num_timestamps):",
|
|
964
|
-
# " stacked_channels.append(band_values[t])",
|
|
965
|
-
# " ",
|
|
966
|
-
# " # Stack all channels: (num_bands * num_timestamps, height, width)",
|
|
967
|
-
# " input_channels = np.stack(stacked_channels, axis=0)",
|
|
968
|
-
# " total_channels = len(data_arrays) * num_timestamps",
|
|
969
|
-
# " logging.info(f\"Stacked channels shape: {input_channels.shape}\")",
|
|
970
|
-
# " logging.info(f\"Total channels: {total_channels} ({len(data_arrays)} bands × {num_timestamps} timestamps)\")",
|
|
971
|
-
# " ",
|
|
972
|
-
# " # Add batch dimension: (1, num_channels, height, width)",
|
|
973
|
-
# " input_data = np.expand_dims(input_channels, axis=0).astype(np.float32)",
|
|
974
|
-
# " logging.info(f\"Final input shape for CNN: {input_data.shape}\")",
|
|
975
|
-
# " ",
|
|
976
|
-
# " # Run inference",
|
|
977
|
-
# " output = model.run(None, {\"float_input\": input_data})[0]",
|
|
978
|
-
# " logging.info(f\"Model output shape: {output.shape}\")",
|
|
979
|
-
# " ",
|
|
980
|
-
# " # UPDATED: Handle multi-class CNN output properly",
|
|
981
|
-
# " if output.ndim == 4:",
|
|
982
|
-
# " if output.shape[1] == 1:",
|
|
983
|
-
# " # Single class output (regression or binary classification)",
|
|
984
|
-
# " output_2d = output[0, 0]",
|
|
985
|
-
# " logging.info(\"Single channel output detected\")",
|
|
986
|
-
# " else:",
|
|
987
|
-
# " # Multi-class output - convert logits/probabilities to class predictions",
|
|
988
|
-
# " output_classes = np.argmax(output, axis=1) # Shape: (1, height, width)",
|
|
989
|
-
# " output_2d = output_classes[0] # Shape: (height, width)",
|
|
990
|
-
# " ",
|
|
991
|
-
# " # Apply class merging: merge class 6 into class 3",
|
|
992
|
-
# " output_2d = np.where(output_2d == 6, 3, output_2d)",
|
|
993
|
-
# " ",
|
|
994
|
-
# " logging.info(f\"Multi-class output processed. Original classes: {output.shape[1]}\")",
|
|
995
|
-
# " logging.info(f\"Unique classes in output: {np.unique(output_2d)}\")",
|
|
996
|
-
# " logging.info(f\"Class distribution: {np.bincount(output_2d.flatten())}\")",
|
|
997
|
-
# " elif output.ndim == 3:",
|
|
998
|
-
# " # Remove batch dimension",
|
|
999
|
-
# " output_2d = output[0]",
|
|
1000
|
-
# " logging.info(\"3D output detected, removed batch dimension\")",
|
|
1001
|
-
# " else:",
|
|
1002
|
-
# " # Handle other cases",
|
|
1003
|
-
# " output_2d = np.squeeze(output)",
|
|
1004
|
-
# " if output_2d.ndim != 2:",
|
|
1005
|
-
# " logging.error(f\"Cannot process output shape: {output.shape}\")",
|
|
1006
|
-
# " logging.error(f\"After squeeze: {output_2d.shape}\")",
|
|
1007
|
-
# " raise ValueError(f\"Unexpected output shape after processing: {output_2d.shape}\")",
|
|
1008
|
-
# " logging.info(\"Applied squeeze to output\")",
|
|
1009
|
-
# " ",
|
|
1010
|
-
# " # Ensure output is 2D",
|
|
1011
|
-
# " if output_2d.ndim != 2:",
|
|
1012
|
-
# " raise ValueError(f\"Final output must be 2D, got shape: {output_2d.shape}\")",
|
|
1013
|
-
# " ",
|
|
1014
|
-
# " # Determine output timestamp (use the latest timestamp)",
|
|
1015
|
-
# " output_timestamp = time_coords[-1]",
|
|
1016
|
-
# " ",
|
|
1017
|
-
# " # Get spatial coordinates from reference array",
|
|
1018
|
-
# " spatial_coords = {dim: reference_array.coords[dim] for dim in spatial_dims}",
|
|
1019
|
-
# " ",
|
|
1020
|
-
# " # Create output DataArray with appropriate data type",
|
|
1021
|
-
# " # Use int32 for classification, float32 for regression",
|
|
1022
|
-
# " is_multiclass = output.ndim == 4 and output.shape[1] > 1",
|
|
1023
|
-
# " if is_multiclass:",
|
|
1024
|
-
# " # Multi-class classification - use integer type",
|
|
1025
|
-
# " output_dtype = np.int32",
|
|
1026
|
-
# " output_type = 'classification'",
|
|
1027
|
-
# " else:",
|
|
1028
|
-
# " # Single output - use float type",
|
|
1029
|
-
# " output_dtype = np.float32",
|
|
1030
|
-
# " output_type = 'regression'",
|
|
1031
|
-
# " ",
|
|
1032
|
-
# " result = xr.DataArray(",
|
|
1033
|
-
# " data=np.expand_dims(output_2d.astype(output_dtype), axis=0),",
|
|
1034
|
-
# " dims=['time'] + spatial_dims,",
|
|
1035
|
-
# " coords={",
|
|
1036
|
-
# " 'time': [output_timestamp.values],",
|
|
1037
|
-
# " spatial_dims[0]: spatial_coords[spatial_dims[0]].values,",
|
|
1038
|
-
# " spatial_dims[1]: spatial_coords[spatial_dims[1]].values",
|
|
1039
|
-
# " },",
|
|
1040
|
-
# " attrs={",
|
|
1041
|
-
# " 'description': 'CNN model prediction',",
|
|
1042
|
-
# " }",
|
|
1043
|
-
# " )",
|
|
1044
|
-
# " ",
|
|
1045
|
-
# " logging.info(f\"Final result shape: {result.shape}\")",
|
|
1046
|
-
# " logging.info(f\"Final result data type: {result.dtype}\")",
|
|
1047
|
-
# " logging.info(f\"Final result value range: {result.values.min()} to {result.values.max()}\")",
|
|
1048
|
-
# ])
|
|
1049
|
-
|
|
1050
|
-
# # Add postprocessing call if postprocessing exists
|
|
1051
|
-
# if postprocessing_section:
|
|
1052
|
-
# script_lines.extend([
|
|
1053
|
-
# " # Apply postprocessing",
|
|
1054
|
-
# " result = postprocessing(result)",
|
|
1055
|
-
# " ",
|
|
1056
|
-
# ])
|
|
1057
|
-
|
|
1058
|
-
# # Single return statement at the end
|
|
1059
|
-
# script_lines.append(" return result")
|
|
1060
|
-
|
|
1061
|
-
# return "\n".join(script_lines)
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
@require_api_key
|
|
1065
|
-
def generate_cnn_script(self, model_name: str, product: str, model_training_job_name: str, uid: str, preprocessing_code: Optional[str] = None, postprocessing_code: Optional[str] = None) -> str:
|
|
1066
|
-
"""
|
|
1067
|
-
Generate Python inference script for CNN model with time-stacked bands.
|
|
1068
|
-
|
|
1069
|
-
Args:
|
|
1070
|
-
model_name: Name of the model
|
|
1071
|
-
product: Product name
|
|
1072
|
-
model_training_job_name: Training job name
|
|
1073
|
-
uid: User ID
|
|
1074
|
-
preprocessing_code: Preprocessing code
|
|
1075
|
-
postprocessing_code: Postprocessing code
|
|
1076
|
-
Returns:
|
|
1077
|
-
str: Generated Python script content
|
|
1078
|
-
"""
|
|
1079
|
-
import textwrap
|
|
1080
|
-
|
|
1081
|
-
# Build preprocessing section with CONSISTENT 4-space indentation
|
|
1082
614
|
preprocessing_section = ""
|
|
1083
615
|
if preprocessing_code and preprocessing_code.strip():
|
|
1084
616
|
clean_preprocessing = textwrap.dedent(preprocessing_code)
|
|
1085
617
|
preprocessing_section = textwrap.indent(clean_preprocessing, ' ')
|
|
1086
618
|
|
|
1087
|
-
# Build postprocessing section with CONSISTENT 4-space indentation
|
|
1088
619
|
postprocessing_section = ""
|
|
1089
620
|
if postprocessing_code and postprocessing_code.strip():
|
|
1090
621
|
clean_postprocessing = textwrap.dedent(postprocessing_code)
|
|
1091
622
|
postprocessing_section = textwrap.indent(clean_postprocessing, ' ')
|
|
1092
623
|
|
|
1093
|
-
# Build the template WITHOUT dedenting the whole thing, so indentation is preserved
|
|
1094
624
|
script_lines = [
|
|
1095
625
|
"import logging",
|
|
1096
626
|
"from io import BytesIO",
|
|
@@ -1107,7 +637,6 @@ class ModelManagement:
|
|
|
1107
637
|
"",
|
|
1108
638
|
]
|
|
1109
639
|
|
|
1110
|
-
# Add preprocessing validation function if preprocessing exists
|
|
1111
640
|
if preprocessing_section:
|
|
1112
641
|
script_lines.extend([
|
|
1113
642
|
"def validate_preprocessing_output(data_arrays):",
|
|
@@ -1125,18 +654,12 @@ class ModelManagement:
|
|
|
1125
654
|
" \"\"\"",
|
|
1126
655
|
" import numpy as np",
|
|
1127
656
|
" ",
|
|
1128
|
-
" logging.info(\"=\" * 60)",
|
|
1129
|
-
" logging.info(\"VALIDATING PREPROCESSING OUTPUT\")",
|
|
1130
|
-
" logging.info(\"=\" * 60)",
|
|
1131
|
-
" ",
|
|
1132
657
|
" if not data_arrays:",
|
|
1133
658
|
" raise ValueError(\"No data arrays provided from preprocessing\")",
|
|
1134
659
|
" ",
|
|
1135
660
|
" reference_shape = None",
|
|
1136
661
|
" ",
|
|
1137
662
|
" for i, data_array in enumerate(data_arrays):",
|
|
1138
|
-
" logging.info(f\"Validating channel {i+1}/{len(data_arrays)}: {data_array.name}\")",
|
|
1139
|
-
" ",
|
|
1140
663
|
" # Check if it's an xarray DataArray",
|
|
1141
664
|
" if not hasattr(data_array, 'dims') or not hasattr(data_array, 'coords'):",
|
|
1142
665
|
" raise ValueError(f\"Channel {i+1} is not a valid xarray DataArray\")",
|
|
@@ -1153,12 +676,6 @@ class ModelManagement:
|
|
|
1153
676
|
" if dim not in data_array.coords:",
|
|
1154
677
|
" raise ValueError(f\"Channel {i+1} missing coordinate: {dim}\")",
|
|
1155
678
|
" ",
|
|
1156
|
-
" logging.info(f\" Coordinates: {list(data_array.coords.keys())}\")",
|
|
1157
|
-
" ",
|
|
1158
|
-
" # Check data type",
|
|
1159
|
-
" data_values = data_array.values",
|
|
1160
|
-
" logging.info(f\" Data type: {data_values.dtype}\")",
|
|
1161
|
-
" ",
|
|
1162
679
|
" # Check shape consistency",
|
|
1163
680
|
" shape = data_array.shape",
|
|
1164
681
|
" if reference_shape is None:",
|
|
@@ -1166,8 +683,6 @@ class ModelManagement:
|
|
|
1166
683
|
" else:",
|
|
1167
684
|
" if shape != reference_shape:",
|
|
1168
685
|
" raise ValueError(f\"Channel {i+1} shape {shape} doesn't match reference {reference_shape}\")",
|
|
1169
|
-
" ",
|
|
1170
|
-
" logging.info(f\" Shape: {shape}\")",
|
|
1171
686
|
" ",
|
|
1172
687
|
" # Generate validation signature",
|
|
1173
688
|
" signature_components = [",
|
|
@@ -1179,19 +694,10 @@ class ModelManagement:
|
|
|
1179
694
|
" ",
|
|
1180
695
|
" signature = \"★PRE_\" + \"_\".join(signature_components) + \"★\"",
|
|
1181
696
|
" ",
|
|
1182
|
-
" logging.info(\"-\" * 60)",
|
|
1183
|
-
" logging.info(\"PREPROCESSING VALIDATION SUMMARY\")",
|
|
1184
|
-
" logging.info(\"-\" * 60)",
|
|
1185
|
-
" logging.info(f\"Channels validated: {len(data_arrays)}\")",
|
|
1186
|
-
" logging.info(f\"Common shape: {reference_shape}\")",
|
|
1187
|
-
" logging.info(f\"Validation signature: {signature}\")",
|
|
1188
|
-
" logging.info(\"=\" * 60)",
|
|
1189
|
-
" ",
|
|
1190
697
|
" return signature",
|
|
1191
698
|
"",
|
|
1192
699
|
])
|
|
1193
700
|
|
|
1194
|
-
# Add postprocessing validation function if postprocessing exists
|
|
1195
701
|
if postprocessing_section:
|
|
1196
702
|
script_lines.extend([
|
|
1197
703
|
"def validate_postprocessing_output(result_array):",
|
|
@@ -1209,10 +715,6 @@ class ModelManagement:
|
|
|
1209
715
|
" \"\"\"",
|
|
1210
716
|
" import numpy as np",
|
|
1211
717
|
" ",
|
|
1212
|
-
" logging.info(\"=\" * 60)",
|
|
1213
|
-
" logging.info(\"VALIDATING POSTPROCESSING OUTPUT\")",
|
|
1214
|
-
" logging.info(\"=\" * 60)",
|
|
1215
|
-
" ",
|
|
1216
718
|
" # Check if it's an xarray DataArray",
|
|
1217
719
|
" if not hasattr(result_array, 'dims') or not hasattr(result_array, 'coords'):",
|
|
1218
720
|
" raise ValueError(\"Postprocessing output is not a valid xarray DataArray\")",
|
|
@@ -1229,39 +731,22 @@ class ModelManagement:
|
|
|
1229
731
|
" if dim not in result_array.coords:",
|
|
1230
732
|
" raise ValueError(f\"Missing spatial coordinate: {dim}\")",
|
|
1231
733
|
" ",
|
|
1232
|
-
" logging.info(f\"Coordinates found: {list(result_array.coords.keys())}\")",
|
|
1233
|
-
" ",
|
|
1234
|
-
" # Check data type",
|
|
1235
|
-
" data_values = result_array.values",
|
|
1236
|
-
" logging.info(f\"Data type: {data_values.dtype}\")",
|
|
1237
|
-
" ",
|
|
1238
734
|
" # Check shape",
|
|
1239
735
|
" shape = result_array.shape",
|
|
1240
|
-
" logging.info(f\"Shape: {shape}\")",
|
|
1241
736
|
" ",
|
|
1242
737
|
" # Generate validation signature",
|
|
1243
738
|
" signature_components = [",
|
|
1244
739
|
" f\"T{shape[0]}\", # Time dimension",
|
|
1245
740
|
" f\"S{shape[1]}x{shape[2]}\", # Spatial dimensions",
|
|
1246
|
-
" f\"DT{
|
|
741
|
+
" f\"DT{result_array.values.dtype}\", # Data type",
|
|
1247
742
|
" ]",
|
|
1248
743
|
" ",
|
|
1249
744
|
" signature = \"★POST_\" + \"_\".join(signature_components) + \"★\"",
|
|
1250
745
|
" ",
|
|
1251
|
-
" logging.info(\"-\" * 60)",
|
|
1252
|
-
" logging.info(\"POSTPROCESSING VALIDATION SUMMARY\")",
|
|
1253
|
-
" logging.info(\"-\" * 60)",
|
|
1254
|
-
" logging.info(f\"Final shape: {shape}\")",
|
|
1255
|
-
" logging.info(f\"Final coordinates: {list(result_array.coords.keys())}\")",
|
|
1256
|
-
" logging.info(f\"Data type: {data_values.dtype}\")",
|
|
1257
|
-
" logging.info(f\"Validation signature: {signature}\")",
|
|
1258
|
-
" logging.info(\"=\" * 60)",
|
|
1259
|
-
" ",
|
|
1260
746
|
" return signature",
|
|
1261
747
|
"",
|
|
1262
748
|
])
|
|
1263
749
|
|
|
1264
|
-
# Add preprocessing function definition BEFORE the main function
|
|
1265
750
|
if preprocessing_section:
|
|
1266
751
|
script_lines.extend([
|
|
1267
752
|
"def preprocessing(array: Tuple[xr.DataArray, ...]) -> Tuple[xr.DataArray, ...]:",
|
|
@@ -1269,7 +754,6 @@ class ModelManagement:
|
|
|
1269
754
|
"",
|
|
1270
755
|
])
|
|
1271
756
|
|
|
1272
|
-
# Add postprocessing function definition BEFORE the main function
|
|
1273
757
|
if postprocessing_section:
|
|
1274
758
|
script_lines.extend([
|
|
1275
759
|
"def postprocessing(array: xr.DataArray) -> xr.DataArray:",
|
|
@@ -1277,14 +761,13 @@ class ModelManagement:
|
|
|
1277
761
|
"",
|
|
1278
762
|
])
|
|
1279
763
|
|
|
1280
|
-
# Add the get_model function
|
|
1281
764
|
script_lines.extend([
|
|
1282
765
|
"def get_model():",
|
|
1283
|
-
f" logging.info(\"Loading CNN model for {
|
|
766
|
+
f" logging.info(\"Loading CNN model for {virtual_dataset_name}...\")",
|
|
1284
767
|
"",
|
|
1285
768
|
" client = storage.Client()",
|
|
1286
|
-
" bucket = client.get_bucket('
|
|
1287
|
-
f" blob = bucket.blob('{uid}/{
|
|
769
|
+
f" bucket = client.get_bucket('{bucket_name}')",
|
|
770
|
+
f" blob = bucket.blob('{uid}/virtual_datasets/{virtual_dataset_name}/{virtual_dataset_name}.onnx')",
|
|
1288
771
|
"",
|
|
1289
772
|
" model = BytesIO()",
|
|
1290
773
|
" blob.download_to_file(model)",
|
|
@@ -1293,7 +776,7 @@ class ModelManagement:
|
|
|
1293
776
|
" session = InferenceSession(model.read(), providers=[\"CPUExecutionProvider\"])",
|
|
1294
777
|
" return session",
|
|
1295
778
|
"",
|
|
1296
|
-
f"def {
|
|
779
|
+
f"def {virtual_product_name}(*bands, model):",
|
|
1297
780
|
" logging.info(\"Start preparing CNN data with time-stacked bands\")",
|
|
1298
781
|
" data_arrays = list(bands)",
|
|
1299
782
|
" ",
|
|
@@ -1302,7 +785,6 @@ class ModelManagement:
|
|
|
1302
785
|
" ",
|
|
1303
786
|
])
|
|
1304
787
|
|
|
1305
|
-
# Add preprocessing call and validation if preprocessing exists
|
|
1306
788
|
if preprocessing_section:
|
|
1307
789
|
script_lines.extend([
|
|
1308
790
|
" # Apply preprocessing",
|
|
@@ -1311,15 +793,12 @@ class ModelManagement:
|
|
|
1311
793
|
" ",
|
|
1312
794
|
" # Validate preprocessing output",
|
|
1313
795
|
" preprocessing_signature = validate_preprocessing_output(data_arrays)",
|
|
1314
|
-
" logging.info(f\"Preprocessing validation signature: {preprocessing_signature}\")",
|
|
1315
796
|
" ",
|
|
1316
797
|
])
|
|
1317
798
|
|
|
1318
|
-
# Continue with the rest of the processing logic
|
|
1319
799
|
script_lines.extend([
|
|
1320
800
|
" reference_array = data_arrays[0]",
|
|
1321
801
|
" original_shape = reference_array.shape",
|
|
1322
|
-
" logging.info(f\"Original shape: {original_shape}\")",
|
|
1323
802
|
" ",
|
|
1324
803
|
" # Get time coordinates - all bands should have the same time dimension",
|
|
1325
804
|
" if 'time' not in reference_array.dims:",
|
|
@@ -1327,24 +806,19 @@ class ModelManagement:
|
|
|
1327
806
|
" ",
|
|
1328
807
|
" time_coords = reference_array.coords['time']",
|
|
1329
808
|
" num_timestamps = len(time_coords)",
|
|
1330
|
-
" logging.info(f\"Number of timestamps: {num_timestamps}\")",
|
|
1331
809
|
" ",
|
|
1332
810
|
" # Get spatial dimensions",
|
|
1333
811
|
" spatial_dims = [dim for dim in reference_array.dims if dim != 'time']",
|
|
1334
812
|
" height = reference_array.sizes[spatial_dims[0]] # assuming first spatial dim is height",
|
|
1335
813
|
" width = reference_array.sizes[spatial_dims[1]] # assuming second spatial dim is width",
|
|
1336
|
-
" logging.info(f\"Spatial dimensions: {height} x {width}\")",
|
|
1337
814
|
" ",
|
|
1338
815
|
" # Stack bands across time dimension",
|
|
1339
816
|
" # Result will be: (num_bands * num_timestamps, height, width)",
|
|
1340
817
|
" stacked_channels = []",
|
|
1341
818
|
" ",
|
|
1342
819
|
" for band_idx, data_array in enumerate(data_arrays):",
|
|
1343
|
-
" logging.info(f\"Processing band {band_idx + 1}/{len(data_arrays)}\")",
|
|
1344
|
-
" ",
|
|
1345
820
|
" # Ensure consistent time coordinates across bands",
|
|
1346
821
|
" if not np.array_equal(data_array.coords['time'].values, time_coords.values):",
|
|
1347
|
-
" logging.warning(f\"Band {band_idx} has different time coordinates, aligning...\")",
|
|
1348
822
|
" data_array = data_array.sel(time=time_coords, method='nearest')",
|
|
1349
823
|
" ",
|
|
1350
824
|
" # Extract values and ensure proper ordering (time, height, width)",
|
|
@@ -1363,23 +837,18 @@ class ModelManagement:
|
|
|
1363
837
|
" # Stack all channels: (num_bands * num_timestamps, height, width)",
|
|
1364
838
|
" input_channels = np.stack(stacked_channels, axis=0)",
|
|
1365
839
|
" total_channels = len(data_arrays) * num_timestamps",
|
|
1366
|
-
" logging.info(f\"Stacked channels shape: {input_channels.shape}\")",
|
|
1367
|
-
" logging.info(f\"Total channels: {total_channels} ({len(data_arrays)} bands × {num_timestamps} timestamps)\")",
|
|
1368
840
|
" ",
|
|
1369
841
|
" # Add batch dimension: (1, num_channels, height, width)",
|
|
1370
842
|
" input_data = np.expand_dims(input_channels, axis=0).astype(np.float32)",
|
|
1371
|
-
" logging.info(f\"Final input shape for CNN: {input_data.shape}\")",
|
|
1372
843
|
" ",
|
|
1373
844
|
" # Run inference",
|
|
1374
845
|
" output = model.run(None, {\"float_input\": input_data})[0]",
|
|
1375
|
-
" logging.info(f\"Model output shape: {output.shape}\")",
|
|
1376
846
|
" ",
|
|
1377
|
-
" #
|
|
847
|
+
" # Handle multi-class CNN output properly",
|
|
1378
848
|
" if output.ndim == 4:",
|
|
1379
849
|
" if output.shape[1] == 1:",
|
|
1380
850
|
" # Single class output (regression or binary classification)",
|
|
1381
851
|
" output_2d = output[0, 0]",
|
|
1382
|
-
" logging.info(\"Single channel output detected\")",
|
|
1383
852
|
" else:",
|
|
1384
853
|
" # Multi-class output - convert logits/probabilities to class predictions",
|
|
1385
854
|
" output_classes = np.argmax(output, axis=1) # Shape: (1, height, width)",
|
|
@@ -1387,22 +856,14 @@ class ModelManagement:
|
|
|
1387
856
|
" ",
|
|
1388
857
|
" # Apply class merging: merge class 6 into class 3",
|
|
1389
858
|
" output_2d = np.where(output_2d == 6, 3, output_2d)",
|
|
1390
|
-
" ",
|
|
1391
|
-
" logging.info(f\"Multi-class output processed. Original classes: {output.shape[1]}\")",
|
|
1392
|
-
" logging.info(f\"Unique classes in output: {np.unique(output_2d)}\")",
|
|
1393
|
-
" logging.info(f\"Class distribution: {np.bincount(output_2d.flatten())}\")",
|
|
1394
859
|
" elif output.ndim == 3:",
|
|
1395
860
|
" # Remove batch dimension",
|
|
1396
861
|
" output_2d = output[0]",
|
|
1397
|
-
" logging.info(\"3D output detected, removed batch dimension\")",
|
|
1398
862
|
" else:",
|
|
1399
863
|
" # Handle other cases",
|
|
1400
864
|
" output_2d = np.squeeze(output)",
|
|
1401
865
|
" if output_2d.ndim != 2:",
|
|
1402
|
-
" logging.error(f\"Cannot process output shape: {output.shape}\")",
|
|
1403
|
-
" logging.error(f\"After squeeze: {output_2d.shape}\")",
|
|
1404
866
|
" raise ValueError(f\"Unexpected output shape after processing: {output_2d.shape}\")",
|
|
1405
|
-
" logging.info(\"Applied squeeze to output\")",
|
|
1406
867
|
" ",
|
|
1407
868
|
" # Ensure output is 2D",
|
|
1408
869
|
" if output_2d.ndim != 2:",
|
|
@@ -1420,11 +881,9 @@ class ModelManagement:
|
|
|
1420
881
|
" if is_multiclass:",
|
|
1421
882
|
" # Multi-class classification - use integer type",
|
|
1422
883
|
" output_dtype = np.int32",
|
|
1423
|
-
" output_type = 'classification'",
|
|
1424
884
|
" else:",
|
|
1425
885
|
" # Single output - use float type",
|
|
1426
886
|
" output_dtype = np.float32",
|
|
1427
|
-
" output_type = 'regression'",
|
|
1428
887
|
" ",
|
|
1429
888
|
" result = xr.DataArray(",
|
|
1430
889
|
" data=np.expand_dims(output_2d.astype(output_dtype), axis=0),",
|
|
@@ -1438,13 +897,8 @@ class ModelManagement:
|
|
|
1438
897
|
" 'description': 'CNN model prediction',",
|
|
1439
898
|
" }",
|
|
1440
899
|
" )",
|
|
1441
|
-
" ",
|
|
1442
|
-
" logging.info(f\"Final result shape: {result.shape}\")",
|
|
1443
|
-
" logging.info(f\"Final result data type: {result.dtype}\")",
|
|
1444
|
-
" logging.info(f\"Final result value range: {result.values.min()} to {result.values.max()}\")",
|
|
1445
900
|
])
|
|
1446
901
|
|
|
1447
|
-
# Add postprocessing call and validation if postprocessing exists
|
|
1448
902
|
if postprocessing_section:
|
|
1449
903
|
script_lines.extend([
|
|
1450
904
|
" # Apply postprocessing",
|
|
@@ -1452,21 +906,223 @@ class ModelManagement:
|
|
|
1452
906
|
" ",
|
|
1453
907
|
" # Validate postprocessing output",
|
|
1454
908
|
" postprocessing_signature = validate_postprocessing_output(result)",
|
|
1455
|
-
" logging.info(f\"Postprocessing validation signature: {postprocessing_signature}\")",
|
|
1456
909
|
" ",
|
|
1457
910
|
])
|
|
1458
911
|
|
|
1459
|
-
# Single return statement at the end
|
|
1460
912
|
script_lines.append(" return result")
|
|
1461
913
|
|
|
1462
914
|
return "\n".join(script_lines)
|
|
1463
915
|
|
|
1464
|
-
|
|
1465
|
-
|
|
1466
|
-
|
|
916
|
+
def _parse_processing_script(self, script_path: str) -> Tuple[Optional[str], Optional[str]]:
|
|
917
|
+
"""
|
|
918
|
+
Parse a Python file and extract preprocessing and postprocessing function bodies.
|
|
919
|
+
|
|
920
|
+
Args:
|
|
921
|
+
script_path: Path to the Python file containing processing functions
|
|
922
|
+
|
|
923
|
+
Returns:
|
|
924
|
+
Tuple of (preprocessing_code, postprocessing_code) where each can be None
|
|
925
|
+
"""
|
|
926
|
+
try:
|
|
927
|
+
with open(script_path, 'r', encoding='utf-8') as f:
|
|
928
|
+
script_content = f.read()
|
|
929
|
+
except FileNotFoundError:
|
|
930
|
+
raise FileNotFoundError(f"Processing script not found: {script_path}")
|
|
931
|
+
except Exception as e:
|
|
932
|
+
raise ValueError(f"Error reading processing script: {e}")
|
|
933
|
+
|
|
934
|
+
if not script_content.strip():
|
|
935
|
+
self._client.logger.info(f"Processing script {script_path} is empty")
|
|
936
|
+
return None, None
|
|
937
|
+
|
|
938
|
+
try:
|
|
939
|
+
tree = ast.parse(script_content)
|
|
940
|
+
except SyntaxError as e:
|
|
941
|
+
raise ValueError(f"Syntax error in processing script: {e}")
|
|
942
|
+
|
|
943
|
+
preprocessing_code = None
|
|
944
|
+
postprocessing_code = None
|
|
945
|
+
|
|
946
|
+
function_names = []
|
|
947
|
+
for node in ast.walk(tree):
|
|
948
|
+
if isinstance(node, ast.FunctionDef):
|
|
949
|
+
function_names.append(node.name)
|
|
950
|
+
if node.name == 'preprocessing':
|
|
951
|
+
preprocessing_code = self._extract_function_body(script_content, node)
|
|
952
|
+
elif node.name == 'postprocessing':
|
|
953
|
+
postprocessing_code = self._extract_function_body(script_content, node)
|
|
954
|
+
|
|
955
|
+
if not function_names:
|
|
956
|
+
self._client.logger.warning(f"No functions found in processing script: {script_path}")
|
|
957
|
+
else:
|
|
958
|
+
found_functions = [name for name in function_names if name in ['preprocessing', 'postprocessing']]
|
|
959
|
+
if found_functions:
|
|
960
|
+
self._client.logger.info(f"Found processing functions: {found_functions}")
|
|
961
|
+
else:
|
|
962
|
+
self._client.logger.warning(f"No 'preprocessing' or 'postprocessing' functions found in {script_path}. "
|
|
963
|
+
f"Available functions: {function_names}")
|
|
964
|
+
|
|
965
|
+
return preprocessing_code, postprocessing_code
|
|
966
|
+
|
|
967
|
+
def _extract_function_body(self, script_content: str, func_node: ast.FunctionDef) -> str:
|
|
968
|
+
"""Extract the body of a function from the script content."""
|
|
969
|
+
lines = script_content.split('\n')
|
|
970
|
+
|
|
971
|
+
start_line = func_node.lineno - 1
|
|
972
|
+
end_line = func_node.end_lineno - 1 if hasattr(func_node, 'end_lineno') else len(lines) - 1
|
|
973
|
+
|
|
974
|
+
body_lines = []
|
|
975
|
+
for i in range(start_line + 1, end_line + 1):
|
|
976
|
+
if i < len(lines):
|
|
977
|
+
body_lines.append(lines[i])
|
|
978
|
+
|
|
979
|
+
if not body_lines:
|
|
980
|
+
return ""
|
|
981
|
+
|
|
982
|
+
body_text = '\n'.join(body_lines)
|
|
983
|
+
cleaned_body = textwrap.dedent(body_text).strip()
|
|
984
|
+
|
|
985
|
+
if not cleaned_body or cleaned_body in ['pass', 'return', 'return None']:
|
|
986
|
+
return ""
|
|
987
|
+
|
|
988
|
+
return cleaned_body
|
|
989
|
+
|
|
990
|
+
def _convert_model_to_onnx(self, model, input_shape: Tuple[int, ...] = None, model_type: Optional[str] = None) -> bytes:
|
|
991
|
+
"""
|
|
992
|
+
Convert a model to ONNX format and return as bytes.
|
|
993
|
+
|
|
994
|
+
Args:
|
|
995
|
+
model: The model object (PyTorch or scikit-learn)
|
|
996
|
+
input_shape: Shape of input data
|
|
997
|
+
model_type: Type of model (neural_network, random_forest), only used for onnx model generation
|
|
998
|
+
Returns:
|
|
999
|
+
bytes: ONNX model as bytes
|
|
1000
|
+
|
|
1001
|
+
Raises:
|
|
1002
|
+
ValueError: If model type is not supported
|
|
1003
|
+
ImportError: If required libraries are not installed
|
|
1004
|
+
"""
|
|
1005
|
+
if isinstance(model, torch.nn.Module):
|
|
1006
|
+
if not TORCH_AVAILABLE:
|
|
1007
|
+
raise ImportError("PyTorch is not installed. Please install it with: pip install torch")
|
|
1008
|
+
return self._convert_pytorch_to_onnx(model, input_shape), "neural_network"
|
|
1009
|
+
elif isinstance(model, BaseEstimator):
|
|
1010
|
+
if not SKL2ONNX_AVAILABLE:
|
|
1011
|
+
raise ImportError("skl2onnx is not installed. Please install it with: pip install skl2onnx")
|
|
1012
|
+
return self._convert_sklearn_to_onnx(model, input_shape), "random_forest"
|
|
1013
|
+
elif isinstance(model, ort.InferenceSession):
|
|
1014
|
+
if model_type is None:
|
|
1015
|
+
raise ValueError(
|
|
1016
|
+
"For ONNX InferenceSession models, you must specify the 'model_type' parameter. Currently 'nerual network' and 'random forest' are supported."
|
|
1017
|
+
"Example: model_type='random forest' or model_type='neural network'"
|
|
1018
|
+
)
|
|
1019
|
+
return model.SerializeToString(), model_type
|
|
1020
|
+
else:
|
|
1021
|
+
model_type = type(model).__name__
|
|
1022
|
+
raise ValueError(f"Unsupported model type: {model_type}. Supported types: PyTorch nn.Module, sklearn BaseEstimator")
|
|
1023
|
+
|
|
1024
|
+
def _convert_pytorch_to_onnx(self, model, input_shape: Tuple[int, ...]) -> bytes:
|
|
1025
|
+
try:
|
|
1026
|
+
model.eval()
|
|
1027
|
+
dummy_input = torch.randn(input_shape)
|
|
1028
|
+
onnx_buffer = BytesIO()
|
|
1029
|
+
|
|
1030
|
+
if len(input_shape) == 4:
|
|
1031
|
+
dynamic_axes = {
|
|
1032
|
+
'float_input': {
|
|
1033
|
+
0: 'batch_size',
|
|
1034
|
+
2: 'height',
|
|
1035
|
+
3: 'width'
|
|
1036
|
+
}
|
|
1037
|
+
}
|
|
1038
|
+
|
|
1039
|
+
elif len(input_shape) == 5:
|
|
1040
|
+
dynamic_axes = {
|
|
1041
|
+
'float_input': {
|
|
1042
|
+
0: 'batch_size',
|
|
1043
|
+
3: 'height',
|
|
1044
|
+
4: 'width'
|
|
1045
|
+
}
|
|
1046
|
+
}
|
|
1047
|
+
|
|
1048
|
+
else:
|
|
1049
|
+
dynamic_axes = {
|
|
1050
|
+
'float_input': {
|
|
1051
|
+
0: 'batch_size'
|
|
1052
|
+
}
|
|
1053
|
+
}
|
|
1054
|
+
|
|
1055
|
+
torch.onnx.export(
|
|
1056
|
+
model,
|
|
1057
|
+
dummy_input,
|
|
1058
|
+
onnx_buffer,
|
|
1059
|
+
input_names=['float_input'],
|
|
1060
|
+
dynamic_axes=dynamic_axes
|
|
1061
|
+
)
|
|
1062
|
+
|
|
1063
|
+
return onnx_buffer.getvalue()
|
|
1064
|
+
except Exception as e:
|
|
1065
|
+
raise ValueError(f"Failed to convert PyTorch model to ONNX: {str(e)}")
|
|
1066
|
+
|
|
1067
|
+
def _convert_sklearn_to_onnx(self, model, input_shape: Tuple[int, ...]) -> bytes:
|
|
1068
|
+
"""
|
|
1069
|
+
Convert scikit-learn model(assume it is a random forest model) to ONNX format.
|
|
1070
|
+
|
|
1071
|
+
Args:
|
|
1072
|
+
model: The scikit-learn model object
|
|
1073
|
+
input_shape: Shape of input data (required)
|
|
1074
|
+
|
|
1075
|
+
Returns:
|
|
1076
|
+
bytes: ONNX model as bytes
|
|
1077
|
+
|
|
1078
|
+
Raises:
|
|
1079
|
+
ValueError: If conversion fails
|
|
1080
|
+
"""
|
|
1081
|
+
self._client.logger.info(f"Converting random forest model to ONNX...")
|
|
1082
|
+
|
|
1083
|
+
try:
|
|
1084
|
+
initial_type = [('float_input', FloatTensorType(input_shape))]
|
|
1085
|
+
onnx_model = convert_sklearn(model, initial_types=initial_type)
|
|
1086
|
+
return onnx_model.SerializeToString()
|
|
1087
|
+
except Exception as e:
|
|
1088
|
+
raise ValueError(f"Failed to convert scikit-learn model to ONNX: {str(e)}")
|
|
1467
1089
|
|
|
1468
|
-
|
|
1469
|
-
|
|
1470
|
-
|
|
1471
|
-
|
|
1472
|
-
|
|
1090
|
+
@require_api_key
|
|
1091
|
+
def train_model(
|
|
1092
|
+
self,
|
|
1093
|
+
model_name: str,
|
|
1094
|
+
training_dataset: str,
|
|
1095
|
+
task_type: str,
|
|
1096
|
+
model_category: str,
|
|
1097
|
+
architecture: str,
|
|
1098
|
+
region: str,
|
|
1099
|
+
hyperparameters: dict = None
|
|
1100
|
+
) -> dict:
|
|
1101
|
+
"""
|
|
1102
|
+
Train a model using the external model training API.
|
|
1103
|
+
|
|
1104
|
+
Args:
|
|
1105
|
+
model_name (str): The name of the model to train.
|
|
1106
|
+
training_dataset (str): The training dataset identifier.
|
|
1107
|
+
task_type (str): The type of ML task (e.g., regression, classification).
|
|
1108
|
+
model_category (str): The category of model (e.g., random_forest).
|
|
1109
|
+
architecture (str): The model architecture.
|
|
1110
|
+
region (str): The region identifier.
|
|
1111
|
+
hyperparameters (dict, optional): Additional hyperparameters for training.
|
|
1112
|
+
|
|
1113
|
+
Returns:
|
|
1114
|
+
dict: The response from the model training API.
|
|
1115
|
+
|
|
1116
|
+
Raises:
|
|
1117
|
+
APIError: If the API request fails
|
|
1118
|
+
"""
|
|
1119
|
+
payload = {
|
|
1120
|
+
"model_name": model_name,
|
|
1121
|
+
"training_dataset": training_dataset,
|
|
1122
|
+
"task_type": task_type,
|
|
1123
|
+
"model_category": model_category,
|
|
1124
|
+
"architecture": architecture,
|
|
1125
|
+
"region": region,
|
|
1126
|
+
"hyperparameters": hyperparameters
|
|
1127
|
+
}
|
|
1128
|
+
return self._client._terrakio_request("POST", "/train_model", json=payload)
|