terrakio-core 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of terrakio-core might be problematic.

@@ -1,43 +1,43 @@
- import os
+ # Standard library imports
+ import ast
  import json
- import time
  import textwrap
- import logging
- from typing import Dict, Any, Union, Tuple, Optional
+ import time
  from io import BytesIO
- import numpy as np
- from google.cloud import storage
- import ast
- from ..helper.decorators import require_token, require_api_key, require_auth
+ from typing import Optional, Tuple
+ import onnxruntime as ort
+
+ # Internal imports
+ from ..helper.decorators import require_api_key
+
+ # Optional dependency flags
  TORCH_AVAILABLE = False
  SKL2ONNX_AVAILABLE = False
 
+ # PyTorch imports
  try:
      import torch
      TORCH_AVAILABLE = True
  except ImportError:
      torch = None
 
+ # Scikit-learn and ONNX conversion imports
  try:
+     from sklearn.base import BaseEstimator
      from skl2onnx import convert_sklearn
      from skl2onnx.common.data_types import FloatTensorType
-     from sklearn.base import BaseEstimator
      SKL2ONNX_AVAILABLE = True
  except ImportError:
+     BaseEstimator = None
      convert_sklearn = None
      FloatTensorType = None
-     BaseEstimator = None
-
- from io import BytesIO
- from typing import Tuple
-
 
  class ModelManagement:
      def __init__(self, client):
          self._client = client
 
      @require_api_key
-     def generate_ai_dataset(
+     async def generate_ai_dataset(
          self,
          name: str,
          aoi_geojson: str,
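
The rewritten header drops the module-level numpy and google.cloud.storage imports and gates torch and skl2onnx behind try/except, recording availability in TORCH_AVAILABLE and SKL2ONNX_AVAILABLE. A minimal sketch of how conversion code can branch on these module-level flags (the helper name below is illustrative, not part of the package):

def _pick_converter(model):
    # Relies on the flags and guarded imports defined at module level above.
    if TORCH_AVAILABLE and isinstance(model, torch.nn.Module):
        return "torch"
    if SKL2ONNX_AVAILABLE and isinstance(model, BaseEstimator):
        return "sklearn"
    raise ImportError(
        "ONNX conversion needs torch (PyTorch models) or skl2onnx (scikit-learn models)"
    )
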
@@ -51,7 +51,8 @@ class ModelManagement:
          filter_y: str = "skip",
          crs: str = "epsg:4326",
          res: float = 0.001,
-         region: str = "aus",
+         region: str = None,
+         bucket: str = None,
          start_year: int = None,
          end_year: int = None,
      ) -> dict:
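
Because generate_ai_dataset is now a coroutine and region no longer defaults to "aus", callers must await it and pass the region and destination bucket explicitly. A hypothetical caller-side sketch (the model_management attribute name and the argument values are assumptions; remaining required parameters such as tile_size are omitted):

import asyncio

async def build_dataset(client):
    task_id = await client.model_management.generate_ai_dataset(
        name="training-set",
        aoi_geojson="aoi.geojson",
        region="aus",                # previously the implicit default
        bucket="my-results-bucket",  # new: destination bucket for job output
        start_year=2020,
        end_year=2022,
        # ...remaining parameters (e.g. tile_size, filters) omitted
    )
    return task_id

# asyncio.run(build_dataset(client))
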
@@ -71,7 +72,8 @@ class ModelManagement:
              tile_size (int): Size of tiles in degrees
              crs (str, optional): Coordinate reference system. Defaults to "epsg:4326"
              res (float, optional): Resolution in degrees. Defaults to 0.001
-             region (str, optional): Region code. Defaults to "aus"
+             region (str, optional): Region code. Defaults to None
+             bucket (str, optional): Bucket name. Defaults to None
              start_year (int, optional): Start year for data generation. Required if end_year provided
              end_year (int, optional): End year for data generation. Required if start_year provided
 
@@ -109,7 +111,7 @@ class ModelManagement:
          with open(aoi_geojson, 'r') as f:
              aoi_data = json.load(f)
 
-         task_response = self._client.mass_stats.random_sample(
+         task_response = await self._client.mass_stats.random_sample(
              name=name,
              config=config,
              aoi=aoi_data,
@@ -121,19 +123,17 @@ class ModelManagement:
              region=region,
              output="netcdf",
              server=self._client.url,
-             bucket="terrakio-mass-requests",
+             bucket=bucket,
              overwrite=True
          )
          task_id = task_response["task_id"]
 
-         # Wait for job completion with progress bar
          while True:
-             result = self._client.track_mass_stats_job(ids=[task_id])
+             result = await self._client.mass_stats.track_job(ids=[task_id])
              status = result[task_id]['status']
              completed = result[task_id].get('completed', 0)
              total = result[task_id].get('total', 1)
 
-             # Create progress bar
              progress = completed / total if total > 0 else 0
              bar_length = 50
              filled_length = int(bar_length * progress)
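
The loop above still pauses with time.sleep(5) even though the method is now async. A sketch, assuming an asyncio event loop (the "completed"/"error" status strings are placeholders, not confirmed by the diff), of the same track_job/progress-bar pattern with a non-blocking sleep:

import asyncio

async def wait_for_job(client, task_id, poll_seconds=5):
    while True:
        result = await client.mass_stats.track_job(ids=[task_id])
        status = result[task_id]["status"]
        completed = result[task_id].get("completed", 0)
        total = result[task_id].get("total", 1)
        progress = completed / total if total > 0 else 0
        filled = int(50 * progress)
        print(f"\r[{'#' * filled}{'-' * (50 - filled)}] {status}", end="")
        if status in ("completed", "error"):  # placeholder status values
            print()
            return status
        await asyncio.sleep(poll_seconds)  # yields to the event loop, unlike time.sleep
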
@@ -149,454 +149,456 @@ class ModelManagement:
149
149
  self._client.logger.info("Job encountered an error")
150
150
  raise Exception(f"Job {task_id} encountered an error")
151
151
 
152
- # Wait 5 seconds before checking again
153
152
  time.sleep(5)
154
153
 
155
- # after all the random sample jobs are done, we then start the mass stats job
156
- task_id = self._client.mass_stats.start_mass_stats_job(task_id)
154
+ task_id = await self._client.mass_stats.start_job(task_id)
157
155
  return task_id
158
156
 
159
157
  @require_api_key
160
- async def upload_model(self, model, model_name: str, input_shape: Tuple[int, ...] = None):
158
+ async def _get_url_for_upload_model_and_script(self, expression: str, model_name: str, script_name: str) -> str:
161
159
  """
162
- Upload a model to the bucket so that it can be used for inference.
163
- Converts PyTorch and scikit-learn models to ONNX format before uploading.
164
-
160
+ Get the upload URLs for the model and its inference script.
165
161
  Args:
166
- model: The model object (PyTorch model or scikit-learn model)
167
- model_name: Name for the model (without extension)
168
- input_shape: Shape of input data for ONNX conversion (e.g., (1, 10) for batch_size=1, features=10)
169
- Required for PyTorch models, optional for scikit-learn models
170
-
171
- Raises:
172
- APIError: If the API request fails
173
- ValueError: If model type is not supported or input_shape is missing for PyTorch models
174
- ImportError: If required libraries (torch or skl2onnx) are not installed
162
+ expression: The expression used for the upload (determines which bucket to upload to)
163
+ model_name: The name of the model to upload
164
+ script_name: The name of the script to upload
165
+ Returns:
166
+ The server response containing the model and script upload URLs and the bucket name
175
167
  """
176
- uid = (await self._client.auth.get_user_info())["uid"]
177
-
178
- client = storage.Client()
179
- bucket = client.get_bucket('terrakio-mass-requests')
180
-
181
- # Convert model to ONNX format
182
- onnx_bytes = self._convert_model_to_onnx(model, model_name, input_shape)
183
-
184
- # Upload ONNX model to bucket
185
- blob = bucket.blob(f'{uid}/{model_name}/models/{model_name}.onnx')
186
-
187
- blob.upload_from_string(onnx_bytes, content_type='application/octet-stream')
188
- self._client.logger.info(f"Model uploaded successfully to {uid}/{model_name}/models/{model_name}.onnx")
168
+ payload = {
169
+ "model_name": model_name,
170
+ "expression": expression,
171
+ "script_name": script_name
172
+ }
173
+ return await self._client._terrakio_request("POST", "models/upload", json=payload)
189
174
 
190
- def _convert_model_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...] = None) -> bytes:
175
+ async def _upload_model_to_url(self, upload_model_url: str, model: bytes):
191
176
  """
192
- Convert a model to ONNX format and return as bytes.
193
-
177
+ Upload a model to a given URL.
194
178
  Args:
195
- model: The model object (PyTorch or scikit-learn)
196
- model_name: Name of the model for logging
197
- input_shape: Shape of input data
198
-
179
+ upload_model_url: The URL to upload the model to
180
+ model: The model to upload
181
+
199
182
  Returns:
200
- bytes: ONNX model as bytes
201
-
202
- Raises:
203
- ValueError: If model type is not supported
204
- ImportError: If required libraries are not installed
183
+ The response from the server
205
184
  """
206
- # Early check for any conversion capability
207
- if not (TORCH_AVAILABLE or SKL2ONNX_AVAILABLE):
208
- raise ImportError(
209
- "ONNX conversion requires additional dependencies. Install with:\n"
210
- " pip install torch # For PyTorch models\n"
211
- " pip install skl2onnx # For scikit-learn models\n"
212
- " pip install torch skl2onnx # For both"
213
- )
214
-
215
- # Check if it's a PyTorch model using isinstance (preferred) with fallback
216
- is_pytorch = False
217
- if TORCH_AVAILABLE:
218
- is_pytorch = (isinstance(model, torch.nn.Module) or
219
- hasattr(model, 'state_dict'))
220
-
221
- # Check if it's a scikit-learn model
222
- is_sklearn = False
223
- if SKL2ONNX_AVAILABLE:
224
- is_sklearn = (isinstance(model, BaseEstimator) or
225
- (hasattr(model, 'fit') and hasattr(model, 'predict')))
226
-
227
- if is_pytorch and TORCH_AVAILABLE:
228
- return self._convert_pytorch_to_onnx(model, model_name, input_shape)
229
- elif is_sklearn and SKL2ONNX_AVAILABLE:
230
- return self._convert_sklearn_to_onnx(model, model_name, input_shape)
231
- else:
232
- # Provide helpful error message
233
- model_type = type(model).__name__
234
- model_module = type(model).__module__
235
- available_types = []
236
- missing_deps = []
237
-
238
- if TORCH_AVAILABLE:
239
- available_types.append("PyTorch (torch.nn.Module)")
240
- else:
241
- missing_deps.append("torch")
242
-
243
- if SKL2ONNX_AVAILABLE:
244
- available_types.append("scikit-learn (BaseEstimator)")
245
- else:
246
- missing_deps.append("skl2onnx")
247
-
248
- if missing_deps:
249
- raise ImportError(
250
- f"Model type {model_type} from {model_module} detected, but required dependencies missing: {', '.join(missing_deps)}. "
251
- f"Install with: pip install {' '.join(missing_deps)}"
252
- )
253
- else:
254
- raise ValueError(
255
- f"Unsupported model type: {model_type} from {model_module}. "
256
- f"Supported types: {', '.join(available_types)}"
257
- )
258
-
259
- def _convert_pytorch_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...]) -> bytes:
260
- """Convert PyTorch model to ONNX format with dynamic input dimensions."""
261
- if input_shape is None:
262
- raise ValueError("input_shape is required for PyTorch models")
263
-
264
- self._client.logger.info(f"Converting PyTorch model {model_name} to ONNX...")
265
-
266
- try:
267
- # Set model to evaluation mode
268
- model.eval()
269
-
270
- # Create dummy input
271
- dummy_input = torch.randn(input_shape)
272
-
273
- # Use BytesIO to avoid creating temporary files
274
- onnx_buffer = BytesIO()
275
-
276
- # Determine dynamic axes based on input shape
277
- # Common patterns for different input types:
278
- if len(input_shape) == 4: # Convolutional input: (batch, channels, height, width)
279
- dynamic_axes = {
280
- 'float_input': {
281
- 0: 'batch_size',
282
- 2: 'height', # Make height dynamic for variable input sizes
283
- 3: 'width' # Make width dynamic for variable input sizes
284
- },
285
- 'output': {0: 'batch_size'}
286
- }
287
- elif len(input_shape) == 3: # Could be (batch, sequence, features) or (batch, height, width)
288
- dynamic_axes = {
289
- 'float_input': {
290
- 0: 'batch_size',
291
- 1: 'dim1', # Generic dynamic dimension
292
- 2: 'dim2' # Generic dynamic dimension
293
- },
294
- 'output': {0: 'batch_size'}
295
- }
296
- elif len(input_shape) == 2: # Likely (batch, features)
297
- dynamic_axes = {
298
- 'float_input': {
299
- 0: 'batch_size'
300
- # Don't make features dynamic as it usually affects model architecture
301
- },
302
- 'output': {0: 'batch_size'}
303
- }
304
- else:
305
- # For other shapes, just make batch size dynamic
306
- dynamic_axes = {
307
- 'float_input': {0: 'batch_size'},
308
- 'output': {0: 'batch_size'}
309
- }
310
-
311
- torch.onnx.export(
312
- model,
313
- dummy_input,
314
- onnx_buffer,
315
- export_params=True,
316
- opset_version=11,
317
- do_constant_folding=True,
318
- input_names=['float_input'],
319
- output_names=['output'],
320
- dynamic_axes=dynamic_axes
321
- )
322
-
323
- self._client.logger.info(f"Successfully converted {model_name} with dynamic axes: {dynamic_axes}")
324
- return onnx_buffer.getvalue()
325
-
326
- except Exception as e:
327
- raise ValueError(f"Failed to convert PyTorch model {model_name} to ONNX: {str(e)}")
328
-
329
-
330
- def _convert_sklearn_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...] = None) -> bytes:
331
- """Convert scikit-learn model to ONNX format."""
332
- self._client.logger.info(f"Converting scikit-learn model {model_name} to ONNX...")
333
-
334
- # Try to infer input shape if not provided
335
- if input_shape is None:
336
- if hasattr(model, 'n_features_in_'):
337
- input_shape = (1, model.n_features_in_)
338
- else:
339
- raise ValueError(
340
- "input_shape is required for scikit-learn models when n_features_in_ is not available. "
341
- "This usually happens with older sklearn versions or models not fitted yet."
342
- )
343
-
344
- try:
345
- # Convert scikit-learn model to ONNX
346
- initial_type = [('float_input', FloatTensorType(input_shape))]
347
- onnx_model = convert_sklearn(model, initial_types=initial_type)
348
- return onnx_model.SerializeToString()
349
-
350
- except Exception as e:
351
- raise ValueError(f"Failed to convert scikit-learn model {model_name} to ONNX: {str(e)}")
352
-
185
+ headers = {
186
+ "Content-Type": "application/octet-stream",
187
+ "Content-Length": str(len(model))
188
+ }
189
+ response = await self._client._regular_request("PUT", endpoint = upload_model_url, data=model, headers=headers)
190
+ return response
191
+
192
+ @require_api_key
193
+ async def _upload_script_to_url(self, upload_script_url: str, script_content: str):
194
+ """
195
+ Upload the generated script to the given URL.
196
+ Args:
197
+ upload_script_url: The URL to upload the script to
198
+ script_content: Content of the script
199
+ Returns:
200
+ None
201
+ """
202
+ script_bytes = script_content.encode('utf-8')
203
+ headers = {
204
+ "Content-Type": "text/x-python",
205
+ "Content-Length": str(len(script_bytes))
206
+ }
207
+ response = await self._client._regular_request("PUT", endpoint=upload_script_url, data=script_bytes, headers=headers)
208
+ return response
353
209
 
354
210
  @require_api_key
355
- async def upload_and_deploy_cnn_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None, processing_script_path: Optional[str] = None):
211
+ async def _upload_model_and_script(self, model, model_name: str, script_name: str, input_expression: str, input_shape: Tuple[int, ...] = None, processing_script_path: Optional[str] = None, model_type: Optional[str] = None):
356
212
  """
357
- Upload a CNN model to the bucket and deploy it.
358
-
213
+ Upload a model and script to the bucket
359
214
  Args:
360
215
  model: The model object (PyTorch model or scikit-learn model)
361
216
  model_name: Name for the model (without extension)
362
- dataset: Name of the dataset to create
363
- product: Product name for the inference
217
+ script_name: Name for the script (without extension)
364
218
  input_expression: Input expression for the dataset
365
- dates_iso8601: List of dates in ISO8601 format
366
219
  input_shape: Shape of input data for ONNX conversion (required for PyTorch models)
367
220
  processing_script_path: Path to the processing script, if not provided, no processing will be done
368
-
221
+ model_type: The type of model to upload ('neural_network' or 'random_forest')
369
222
  Raises:
370
223
  APIError: If the API request fails
371
224
  ValueError: If model type is not supported or input_shape is missing for PyTorch models
372
- ImportError: If required libraries (torch or skl2onnx) are not installed
225
+
226
+ Returns:
227
+ bucket_name: Name of the bucket where the model is stored
373
228
  """
374
- await self.upload_model(model=model, model_name=model_name, input_shape=input_shape)
375
- # so the uploading process is kinda similar, but the deployment step is kinda different
376
- # we should pass the processing script path to the deploy cnn model function
377
- await self.deploy_cnn_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601, processing_script_path=processing_script_path)
229
+ response = await self._get_url_for_upload_model_and_script(expression = input_expression, model_name = model_name, script_name = script_name)
230
+ model_url, script_url, bucket_name = response.get("model_upload_url"), response.get("script_upload_url"), response.get("bucket_name")
231
+ if not model_url or not script_url:
232
+ raise ValueError("No url returned from the server for the upload process")
233
+ try:
234
+ model_in_onnx_bytes, model_type = self._convert_model_to_onnx(model = model, input_shape = input_shape, model_type = model_type)
235
+ if model_type == "neural_network":
236
+ script_content = await self._generate_cnn_script(bucket_name = bucket_name, virtual_dataset_name = model_name, virtual_product_name = script_name, processing_script_path = processing_script_path)
237
+ elif model_type == "random_forest":
238
+ script_content = await self._generate_random_forest_script(bucket_name = bucket_name, virtual_dataset_name = model_name, virtual_product_name = script_name, processing_script_path = processing_script_path)
239
+ else:
240
+ raise ValueError(f"Unsupported model type: {model_type}. Supported types: neural_network, random_forest")
241
+ script_upload_response = await self._upload_script_to_url( upload_script_url = script_url, script_content = script_content)
242
+ if script_upload_response.status not in [200, 201, 204]:
243
+ self._client.logger.error(f"Script upload error: {script_upload_response.text()}")
244
+ raise Exception(f"Failed to upload script: {script_upload_response.text()}")
245
+ model_upload_response = await self._upload_model_to_url(upload_model_url = model_url, model = model_in_onnx_bytes)
246
+ if model_upload_response.status not in [200, 201, 204]:
247
+ self._client.logger.error(f"Model upload error: {model_upload_response.text()}")
248
+ raise Exception(f"Failed to upload model: {model_upload_response.text()}")
249
+ except Exception as e:
250
+ raise Exception(f"Error uploading model: {e}")
251
+ self._client.logger.info(f"Model and Script uploaded successfully to {model_url}")
252
+ return bucket_name
378
253
 
379
254
  @require_api_key
380
- async def upload_and_deploy_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None):
255
+ async def upload_and_deploy_model(self, model, virtual_dataset_name: str, virtual_product_name: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None, processing_script_path: Optional[str] = None, model_type: Optional[str] = None):
381
256
  """
382
257
  Upload a model to the bucket and deploy it.
383
-
384
258
  Args:
385
259
  model: The model object (PyTorch model or scikit-learn model)
386
- model_name: Name for the model (without extension)
387
- dataset: Name of the dataset to create
388
- product: Product name for the inference
260
+ virtual_dataset_name: Name for the virtual dataset (also used as the model name)
261
+ virtual_product_name: Product name for the inference
389
262
  input_expression: Input expression for the dataset
390
263
  dates_iso8601: List of dates in ISO8601 format
391
264
  input_shape: Shape of input data for ONNX conversion (required for PyTorch models)
392
- """
393
- await self.upload_model(model=model, model_name=model_name, input_shape=input_shape)
394
- await self.deploy_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601)
265
+ processing_script_path: Path to the processing script, if not provided, no processing will be done
266
+ model_type: The type of model to upload ('neural_network' or 'random_forest')
395
267
 
396
- @require_api_key
397
- def train_model(
398
- self,
399
- model_name: str,
400
- training_dataset: str,
401
- task_type: str,
402
- model_category: str,
403
- architecture: str,
404
- region: str,
405
- hyperparameters: dict = None
406
- ) -> dict:
407
- """
408
- Train a model using the external model training API.
409
-
410
- Args:
411
- model_name (str): The name of the model to train.
412
- training_dataset (str): The training dataset identifier.
413
- task_type (str): The type of ML task (e.g., regression, classification).
414
- model_category (str): The category of model (e.g., random_forest).
415
- architecture (str): The model architecture.
416
- region (str): The region identifier.
417
- hyperparameters (dict, optional): Additional hyperparameters for training.
418
-
419
- Returns:
420
- dict: The response from the model training API.
421
-
422
268
  Raises:
423
269
  APIError: If the API request fails
424
- """
425
- payload = {
426
- "model_name": model_name,
427
- "training_dataset": training_dataset,
428
- "task_type": task_type,
429
- "model_category": model_category,
430
- "architecture": architecture,
431
- "region": region,
432
- "hyperparameters": hyperparameters
433
- }
434
- return self._client._terrakio_request("POST", "/train_model", json=payload)
270
+ ValueError: If model type is not supported or input_shape is missing for PyTorch models
271
+ ImportError: If required libraries (torch or skl2onnx) are not installed
435
272
 
436
- @require_api_key
437
- async def deploy_model(
438
- self,
439
- dataset: str,
440
- product: str,
441
- model_name: str,
442
- input_expression: str,
443
- model_training_job_name: str,
444
- dates_iso8601: list
445
- ) -> Dict[str, Any]:
446
- """
447
- Deploy a model by generating inference script and creating dataset.
448
-
449
- Args:
450
- dataset: Name of the dataset to create
451
- product: Product name for the inference
452
- model_name: Name of the trained model
453
- input_expression: Input expression for the dataset
454
- model_training_job_name: Name of the training job
455
- dates_iso8601: List of dates in ISO8601 format
456
-
457
273
  Returns:
458
- dict: Response from the deployment process
459
-
460
- Raises:
461
- APIError: If the API request fails
274
+ None
462
275
  """
463
- # Get user info to get UID
276
+ bucket_name = await self._upload_model_and_script(model=model, model_name=virtual_dataset_name, script_name= virtual_product_name, input_shape=input_shape, input_expression=input_expression, processing_script_path=processing_script_path, model_type= model_type)
464
277
  user_info = await self._client.auth.get_user_info()
465
278
  uid = user_info["uid"]
466
-
467
- # Generate and upload script
468
- script_content = self._generate_script(model_name, product, model_training_job_name, uid)
469
- script_name = f"{product}.py"
470
- self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
471
-
472
- # Create dataset
473
- return await self._client.datasets.create_dataset(
474
- name=dataset,
279
+ await self._client.datasets.create_dataset(
280
+ name=virtual_dataset_name,
475
281
  collection="terrakio-datasets",
476
- products=[product],
477
- path=f"gs://terrakio-mass-requests/{uid}/{model_training_job_name}/inference_scripts",
282
+ products=[virtual_product_name],
283
+ path=f"gs://{bucket_name}/{uid}/virtual_datasets/{virtual_dataset_name}/inference_scripts",
478
284
  input=input_expression,
479
285
  dates_iso8601=dates_iso8601,
480
286
  padding=0
481
287
  )
482
288
 
483
- def _parse_processing_script(self, script_path: str) -> Tuple[Optional[str], Optional[str]]:
289
+ @require_api_key
290
+ async def _generate_random_forest_script(self, bucket_name: str, virtual_dataset_name: str, virtual_product_name: str, processing_script_path: Optional[str] = None) -> str:
484
291
  """
485
- Parse a Python file and extract preprocessing and postprocessing function bodies.
292
+ Generate Python inference script for the Random Forest model.
486
293
 
487
294
  Args:
488
- script_path: Path to the Python file containing processing functions
295
+ bucket_name: Name of the bucket where the model is stored
296
+ virtual_dataset_name: Name of the virtual dataset and the model
297
+ virtual_product_name: Name of the virtual product
298
+ processing_script_path: Path to the processing script, if not provided, no processing will be done
489
299
 
490
300
  Returns:
491
- Tuple of (preprocessing_code, postprocessing_code) where each can be None
301
+ str: Generated Python script content
492
302
  """
493
- try:
494
- with open(script_path, 'r', encoding='utf-8') as f:
495
- script_content = f.read()
496
- except FileNotFoundError:
497
- raise FileNotFoundError(f"Processing script not found: {script_path}")
498
- except Exception as e:
499
- raise ValueError(f"Error reading processing script: {e}")
500
-
501
- # Handle empty file
502
- if not script_content.strip():
503
- self._client.logger.info(f"Processing script {script_path} is empty")
504
- return None, None
505
-
506
- try:
507
- # Parse the Python file
508
- tree = ast.parse(script_content)
509
- except SyntaxError as e:
510
- raise ValueError(f"Syntax error in processing script: {e}")
303
+ user_info = await self._client.auth.get_user_info()
304
+ uid = user_info["uid"]
305
+ preprocessing_code, postprocessing_code = None, None
306
+
307
+ if processing_script_path:
308
+ try:
309
+ preprocessing_code, postprocessing_code = self._parse_processing_script(processing_script_path)
310
+ if preprocessing_code:
311
+ self._client.logger.info(f"Using custom preprocessing from: {processing_script_path}")
312
+ if postprocessing_code:
313
+ self._client.logger.info(f"Using custom postprocessing from: {processing_script_path}")
314
+ if not preprocessing_code and not postprocessing_code:
315
+ self._client.logger.warning(f"No preprocessing or postprocessing functions found in {processing_script_path}")
316
+ self._client.logger.info("Deployment will continue without custom processing")
317
+ except Exception as e:
318
+ raise ValueError(f"Failed to load processing script: {str(e)}")
319
+
320
+ preprocessing_section = ""
321
+ if preprocessing_code and preprocessing_code.strip():
322
+ clean_preprocessing = textwrap.dedent(preprocessing_code)
323
+ preprocessing_section = textwrap.indent(clean_preprocessing, ' ')
511
324
 
512
- preprocessing_code = None
513
- postprocessing_code = None
325
+ postprocessing_section = ""
326
+ if postprocessing_code and postprocessing_code.strip():
327
+ clean_postprocessing = textwrap.dedent(postprocessing_code)
328
+ postprocessing_section = textwrap.indent(clean_postprocessing, ' ')
329
+
330
+ script_lines = [
331
+ "import logging",
332
+ "from io import BytesIO",
333
+ "import numpy as np",
334
+ "import pandas as pd",
335
+ "import xarray as xr",
336
+ "from google.cloud import storage",
337
+ "from onnxruntime import InferenceSession",
338
+ "from typing import Tuple",
339
+ "",
340
+ "logging.basicConfig(",
341
+ " level=logging.INFO",
342
+ ")",
343
+ "",
344
+ ]
514
345
 
515
- # Find function definitions
516
- function_names = []
517
- for node in ast.walk(tree):
518
- if isinstance(node, ast.FunctionDef):
519
- function_names.append(node.name)
520
- if node.name == 'preprocessing':
521
- preprocessing_code = self._extract_function_body(script_content, node)
522
- elif node.name == 'postprocessing':
523
- postprocessing_code = self._extract_function_body(script_content, node)
346
+ if preprocessing_section:
347
+ script_lines.extend([
348
+ "def validate_preprocessing_output(data_arrays):",
349
+ " \"\"\"",
350
+ " Validate preprocessing output coordinates and data type.",
351
+ " ",
352
+ " Args:",
353
+ " data_arrays: List of xarray DataArrays from preprocessing",
354
+ " ",
355
+ " Returns:",
356
+ " str: Validation signature symbol",
357
+ " ",
358
+ " Raises:",
359
+ " ValueError: If validation fails",
360
+ " \"\"\"",
361
+ " import numpy as np",
362
+ " ",
363
+ " if not data_arrays:",
364
+ " raise ValueError(\"No data arrays provided from preprocessing\")",
365
+ " ",
366
+ " reference_shape = None",
367
+ " ",
368
+ " for i, data_array in enumerate(data_arrays):",
369
+ " # Check if it's an xarray DataArray",
370
+ " if not hasattr(data_array, 'dims') or not hasattr(data_array, 'coords'):",
371
+ " raise ValueError(f\"Channel {i+1} is not a valid xarray DataArray\")",
372
+ " ",
373
+ " # Check coordinates",
374
+ " if 'time' not in data_array.coords:",
375
+ " raise ValueError(f\"Channel {i+1} missing time coordinate\")",
376
+ " ",
377
+ " spatial_dims = [dim for dim in data_array.dims if dim != 'time']",
378
+ " if len(spatial_dims) != 2:",
379
+ " raise ValueError(f\"Channel {i+1} must have exactly 2 spatial dimensions, got {spatial_dims}\")",
380
+ " ",
381
+ " for dim in spatial_dims:",
382
+ " if dim not in data_array.coords:",
383
+ " raise ValueError(f\"Channel {i+1} missing coordinate: {dim}\")",
384
+ " ",
385
+ " # Check shape consistency",
386
+ " shape = data_array.shape",
387
+ " if reference_shape is None:",
388
+ " reference_shape = shape",
389
+ " else:",
390
+ " if shape != reference_shape:",
391
+ " raise ValueError(f\"Channel {i+1} shape {shape} doesn't match reference {reference_shape}\")",
392
+ " ",
393
+ " # Generate validation signature",
394
+ " signature_components = [",
395
+ " f\"CH{len(data_arrays)}\", # Channel count",
396
+ " f\"T{reference_shape[0]}\", # Time dimension",
397
+ " f\"S{reference_shape[1]}x{reference_shape[2]}\", # Spatial dimensions",
398
+ " f\"DT{data_arrays[0].values.dtype}\", # Data type",
399
+ " ]",
400
+ " ",
401
+ " signature = \"★PRE_\" + \"_\".join(signature_components) + \"★\"",
402
+ " ",
403
+ " return signature",
404
+ "",
405
+ ])
524
406
 
525
- # Log what was found for debugging
526
- if not function_names:
527
- self._client.logger.warning(f"No functions found in processing script: {script_path}")
528
- else:
529
- found_functions = [name for name in function_names if name in ['preprocessing', 'postprocessing']]
530
- if found_functions:
531
- self._client.logger.info(f"Found processing functions: {found_functions}")
532
- else:
533
- self._client.logger.warning(f"No 'preprocessing' or 'postprocessing' functions found in {script_path}. "
534
- f"Available functions: {function_names}")
407
+ if postprocessing_section:
408
+ script_lines.extend([
409
+ "def validate_postprocessing_output(result_array):",
410
+ " \"\"\"",
411
+ " Validate postprocessing output coordinates and data type.",
412
+ " ",
413
+ " Args:",
414
+ " result_array: xarray DataArray from postprocessing",
415
+ " ",
416
+ " Returns:",
417
+ " str: Validation signature symbol",
418
+ " ",
419
+ " Raises:",
420
+ " ValueError: If validation fails",
421
+ " \"\"\"",
422
+ " import numpy as np",
423
+ " ",
424
+ " # Check if it's an xarray DataArray",
425
+ " if not hasattr(result_array, 'dims') or not hasattr(result_array, 'coords'):",
426
+ " raise ValueError(\"Postprocessing output is not a valid xarray DataArray\")",
427
+ " ",
428
+ " # Check required coordinates",
429
+ " if 'time' not in result_array.coords:",
430
+ " raise ValueError(\"Missing time coordinate\")",
431
+ " ",
432
+ " spatial_dims = [dim for dim in result_array.dims if dim != 'time']",
433
+ " if len(spatial_dims) != 2:",
434
+ " raise ValueError(f\"Expected 2 spatial dimensions, got {len(spatial_dims)}: {spatial_dims}\")",
435
+ " ",
436
+ " for dim in spatial_dims:",
437
+ " if dim not in result_array.coords:",
438
+ " raise ValueError(f\"Missing spatial coordinate: {dim}\")",
439
+ " ",
440
+ " # Check shape",
441
+ " shape = result_array.shape",
442
+ " ",
443
+ " # Generate validation signature",
444
+ " signature_components = [",
445
+ " f\"T{shape[0]}\", # Time dimension",
446
+ " f\"S{shape[1]}x{shape[2]}\", # Spatial dimensions",
447
+ " f\"DT{result_array.values.dtype}\", # Data type",
448
+ " ]",
449
+ " ",
450
+ " signature = \"★POST_\" + \"_\".join(signature_components) + \"★\"",
451
+ " ",
452
+ " return signature",
453
+ "",
454
+ ])
535
455
 
536
- return preprocessing_code, postprocessing_code
537
-
538
- def _extract_function_body(self, script_content: str, func_node: ast.FunctionDef) -> str:
539
- """Extract the body of a function from the script content."""
540
- lines = script_content.split('\n')
456
+ if preprocessing_section:
457
+ script_lines.extend([
458
+ "def preprocessing(array: Tuple[xr.DataArray, ...]) -> Tuple[xr.DataArray, ...]:",
459
+ preprocessing_section,
460
+ "",
461
+ ])
541
462
 
542
- # AST line numbers are 1-indexed, convert to 0-indexed
543
- start_line = func_node.lineno - 1 # This is the 'def' line (0-indexed)
544
- end_line = func_node.end_lineno - 1 if hasattr(func_node, 'end_lineno') else len(lines) - 1
463
+ if postprocessing_section:
464
+ script_lines.extend([
465
+ "def postprocessing(array: xr.DataArray) -> xr.DataArray:",
466
+ postprocessing_section,
467
+ "",
468
+ ])
545
469
 
546
- # Extract ONLY the body lines (skip the def line entirely)
547
- body_lines = []
548
- for i in range(start_line + 1, end_line + 1): # +1 to skip the 'def' line
549
- if i < len(lines):
550
- body_lines.append(lines[i])
470
+ script_lines.extend([
471
+ "def get_model():",
472
+ f" logging.info(\"Loading Random Forest model for {virtual_dataset_name}...\")",
473
+ "",
474
+ " client = storage.Client()",
475
+ f" bucket = client.get_bucket('{bucket_name}')",
476
+ f" blob = bucket.blob('{uid}/virtual_datasets/{virtual_dataset_name}/{virtual_dataset_name}.onnx')",
477
+ "",
478
+ " model = BytesIO()",
479
+ " blob.download_to_file(model)",
480
+ " model.seek(0)",
481
+ "",
482
+ " session = InferenceSession(model.read(), providers=[\"CPUExecutionProvider\"])",
483
+ " return session",
484
+ "",
485
+ f"def {virtual_product_name}(*bands, model):",
486
+ " logging.info(\"Start preparing Random Forest data\")",
487
+ " data_arrays = list(bands)",
488
+ " ",
489
+ " if not data_arrays:",
490
+ " raise ValueError(\"No bands provided\")",
491
+ " ",
492
+ ])
551
493
 
552
- if not body_lines:
553
- return ""
494
+ if preprocessing_section:
495
+ script_lines.extend([
496
+ " # Apply preprocessing",
497
+ " data_arrays = preprocessing(tuple(data_arrays))",
498
+ " data_arrays = list(data_arrays) # Convert back to list for processing",
499
+ " ",
500
+ " # Validate preprocessing output",
501
+ " preprocessing_signature = validate_preprocessing_output(data_arrays)",
502
+ " ",
503
+ ])
554
504
 
555
- # Join and dedent to remove function-level indentation
556
- body_text = '\n'.join(body_lines)
557
- cleaned_body = textwrap.dedent(body_text).strip()
505
+ script_lines.extend([
506
+ " reference_array = data_arrays[0]",
507
+ " original_shape = reference_array.shape",
508
+ " ",
509
+ " if 'time' in reference_array.dims:",
510
+ " time_coords = reference_array.coords['time']",
511
+ " if len(time_coords) == 1:",
512
+ " output_timestamp = time_coords[0]",
513
+ " else:",
514
+ " years = [pd.to_datetime(t).year for t in time_coords.values]",
515
+ " unique_years = set(years)",
516
+ " ",
517
+ " if len(unique_years) == 1:",
518
+ " year = list(unique_years)[0]",
519
+ " output_timestamp = pd.Timestamp(f\"{year}-01-01\")",
520
+ " else:",
521
+ " latest_year = max(unique_years)",
522
+ " output_timestamp = pd.Timestamp(f\"{latest_year}-01-01\")",
523
+ " else:",
524
+ " output_timestamp = pd.Timestamp(\"1970-01-01\")",
525
+ "",
526
+ " averaged_bands = []",
527
+ " for data_array in data_arrays:",
528
+ " if 'time' in data_array.dims:",
529
+ " averaged_band = np.mean(data_array.values, axis=0)",
530
+ " else:",
531
+ " averaged_band = data_array.values",
532
+ "",
533
+ " flattened_band = averaged_band.reshape(-1, 1)",
534
+ " averaged_bands.append(flattened_band)",
535
+ "",
536
+ " input_data = np.hstack(averaged_bands)",
537
+ "",
538
+ " output = model.run(None, {\"float_input\": input_data.astype(np.float32)})[0]",
539
+ "",
540
+ " if len(original_shape) >= 3:",
541
+ " spatial_shape = original_shape[1:]",
542
+ " else:",
543
+ " spatial_shape = original_shape",
544
+ "",
545
+ " output_reshaped = output.reshape(spatial_shape)",
546
+ "",
547
+ " output_with_time = np.expand_dims(output_reshaped, axis=0)",
548
+ "",
549
+ " if 'time' in reference_array.dims:",
550
+ " spatial_dims = [dim for dim in reference_array.dims if dim != 'time']",
551
+ " spatial_coords = {dim: reference_array.coords[dim] for dim in spatial_dims if dim in reference_array.coords}",
552
+ " else:",
553
+ " spatial_dims = list(reference_array.dims)",
554
+ " spatial_coords = dict(reference_array.coords)",
555
+ "",
556
+ " result = xr.DataArray(",
557
+ " data=output_with_time.astype(np.float32),",
558
+ " dims=['time'] + list(spatial_dims),",
559
+ " coords={",
560
+ " 'time': [output_timestamp.values],",
561
+ " 'y': spatial_coords['y'].values,",
562
+ " 'x': spatial_coords['x'].values",
563
+ " },",
564
+ " attrs={",
565
+ " 'description': 'Random Forest model prediction',",
566
+ " }",
567
+ " )",
568
+ ])
558
569
 
559
- # Handle empty function body
560
- if not cleaned_body or cleaned_body in ['pass', 'return', 'return None']:
561
- return ""
570
+ if postprocessing_section:
571
+ script_lines.extend([
572
+ " # Apply postprocessing",
573
+ " result = postprocessing(result)",
574
+ " ",
575
+ " # Validate postprocessing output",
576
+ " postprocessing_signature = validate_postprocessing_output(result)",
577
+ " ",
578
+ ])
562
579
 
563
- return cleaned_body
580
+ script_lines.append(" return result")
564
581
 
582
+ return "\n".join(script_lines)
583
+
565
584
  @require_api_key
566
- async def deploy_cnn_model(
567
- self,
568
- dataset: str,
569
- product: str,
570
- model_name: str,
571
- input_expression: str,
572
- model_training_job_name: str,
573
- dates_iso8601: list,
574
- processing_script_path: Optional[str] = None
575
- ) -> Dict[str, Any]:
585
+ async def _generate_cnn_script(self, bucket_name: str, virtual_dataset_name: str, virtual_product_name: str, processing_script_path: Optional[str] = None) -> str:
576
586
  """
577
- Deploy a CNN model by generating inference script and creating dataset.
587
+ Generate Python inference script for CNN model with time-stacked bands.
578
588
 
579
589
  Args:
580
- dataset: Name of the dataset to create
581
- product: Product name for the inference
582
- model_name: Name of the trained model
583
- input_expression: Input expression for the dataset
584
- model_training_job_name: Name of the training job
585
- dates_iso8601: List of dates in ISO8601 format
590
+ bucket_name: Name of the bucket where the model is stored
591
+ virtual_dataset_name: Name of the virtual dataset and the model
592
+ virtual_product_name: Name of the virtual product
586
593
  processing_script_path: Path to the processing script, if not provided, no processing will be done
587
594
  Returns:
588
- dict: Response from the deployment process
589
-
590
- Raises:
591
- APIError: If the API request fails
595
+ str: Generated Python script content
592
596
  """
593
- # Get user info to get UID
594
597
  user_info = await self._client.auth.get_user_info()
595
598
  uid = user_info["uid"]
596
-
597
599
  preprocessing_code, postprocessing_code = None, None
600
+
598
601
  if processing_script_path:
599
- # if there is a function that is being passed in
600
602
  try:
601
603
  preprocessing_code, postprocessing_code = self._parse_processing_script(processing_script_path)
602
604
  if preprocessing_code:
@@ -608,176 +610,17 @@ class ModelManagement:
608
610
  self._client.logger.info("Deployment will continue without custom processing")
609
611
  except Exception as e:
610
612
  raise ValueError(f"Failed to load processing script: {str(e)}")
611
- # so we already have the preprocessing code and the post processing code, I need to pass them to the generate cnn script function
612
- # Generate and upload script
613
- # Build preprocessing section with CONSISTENT 8-space indentation
614
- preprocessing_section = ""
615
- if preprocessing_code and preprocessing_code.strip():
616
- # First dedent the preprocessing code to remove any existing indentation
617
- clean_preprocessing = preprocessing_code
618
- # Then add consistent 8-space indentation to match the template
619
- preprocessing_section = f"""{textwrap.indent(clean_preprocessing, '')}""" # 8 spaces
620
- print(preprocessing_section)
621
- script_content = self.generate_cnn_script(model_name, product, model_training_job_name, uid, preprocessing_code, postprocessing_code)
622
- script_name = f"{product}.py"
623
- self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
624
- # Create dataset
625
- return await self._client.datasets.create_dataset(
626
- name=dataset,
627
- collection="terrakio-datasets",
628
- products=[product],
629
- path=f"gs://terrakio-mass-requests/{uid}/{model_training_job_name}/inference_scripts",
630
- input=input_expression,
631
- dates_iso8601=dates_iso8601,
632
- padding=0
633
- )
634
-
635
- @require_api_key
636
- def _generate_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
637
- """
638
- Generate Python inference script for the model.
639
-
640
- Args:
641
- model_name: Name of the model
642
- product: Product name
643
- model_training_job_name: Training job name
644
- uid: User ID
645
-
646
- Returns:
647
- str: Generated Python script content
648
- """
649
- return textwrap.dedent(f'''
650
- import logging
651
- from io import BytesIO
652
-
653
- import numpy as np
654
- import pandas as pd
655
- import xarray as xr
656
- from google.cloud import storage
657
- from onnxruntime import InferenceSession
658
-
659
- logging.basicConfig(
660
- level=logging.INFO
661
- )
662
-
663
- def get_model():
664
- logging.info("Loading model for {model_name}...")
665
-
666
- client = storage.Client()
667
- bucket = client.get_bucket('terrakio-mass-requests')
668
- blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')
669
-
670
- model = BytesIO()
671
- blob.download_to_file(model)
672
- model.seek(0)
673
-
674
- session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
675
- return session
676
-
677
- def {product}(*bands, model):
678
- logging.info("start preparing data")
679
-
680
- data_arrays = list(bands)
681
-
682
- reference_array = data_arrays[0]
683
- original_shape = reference_array.shape
684
- logging.info(f"Original shape: {{original_shape}}")
685
613
 
686
- if 'time' in reference_array.dims:
687
- time_coords = reference_array.coords['time']
688
- if len(time_coords) == 1:
689
- output_timestamp = time_coords[0]
690
- else:
691
- years = [pd.to_datetime(t).year for t in time_coords.values]
692
- unique_years = set(years)
693
-
694
- if len(unique_years) == 1:
695
- year = list(unique_years)[0]
696
- output_timestamp = pd.Timestamp(f"{{year}}-01-01")
697
- else:
698
- latest_year = max(unique_years)
699
- output_timestamp = pd.Timestamp(f"{{latest_year}}-01-01")
700
- else:
701
- output_timestamp = pd.Timestamp("1970-01-01")
702
-
703
- averaged_bands = []
704
- for data_array in data_arrays:
705
- if 'time' in data_array.dims:
706
- averaged_band = np.mean(data_array.values, axis=0)
707
- logging.info(f"Averaged band from {{data_array.shape}} to {{averaged_band.shape}}")
708
- else:
709
- averaged_band = data_array.values
710
- logging.info(f"No time dimension, shape: {{averaged_band.shape}}")
711
-
712
- flattened_band = averaged_band.reshape(-1, 1)
713
- averaged_bands.append(flattened_band)
714
-
715
- input_data = np.hstack(averaged_bands)
716
-
717
- logging.info(f"Final input shape: {{input_data.shape}}")
718
-
719
- output = model.run(None, {{"float_input": input_data.astype(np.float32)}})[0]
720
-
721
- logging.info(f"Model output shape: {{output.shape}}")
722
-
723
- if len(original_shape) >= 3:
724
- spatial_shape = original_shape[1:]
725
- else:
726
- spatial_shape = original_shape
727
-
728
- output_reshaped = output.reshape(spatial_shape)
729
-
730
- output_with_time = np.expand_dims(output_reshaped, axis=0)
731
-
732
- if 'time' in reference_array.dims:
733
- spatial_dims = [dim for dim in reference_array.dims if dim != 'time']
734
- spatial_coords = {{dim: reference_array.coords[dim] for dim in spatial_dims if dim in reference_array.coords}}
735
- else:
736
- spatial_dims = list(reference_array.dims)
737
- spatial_coords = dict(reference_array.coords)
738
-
739
- result = xr.DataArray(
740
- data=output_with_time.astype(np.float32),
741
- dims=['time'] + list(spatial_dims),
742
- coords={{
743
- 'time': [output_timestamp.values],
744
- 'y': spatial_coords['y'].values,
745
- 'x': spatial_coords['x'].values
746
- }}
747
- )
748
- return result
749
- ''').strip()
750
-
751
- @require_api_key
752
- def generate_cnn_script(self, model_name: str, product: str, model_training_job_name: str, uid: str, preprocessing_code: Optional[str] = None, postprocessing_code: Optional[str] = None) -> str:
753
- """
754
- Generate Python inference script for CNN model with time-stacked bands.
755
-
756
- Args:
757
- model_name: Name of the model
758
- product: Product name
759
- model_training_job_name: Training job name
760
- uid: User ID
761
- preprocessing_code: Preprocessing code
762
- postprocessing_code: Postprocessing code
763
- Returns:
764
- str: Generated Python script content
765
- """
766
- import textwrap
767
-
768
- # Build preprocessing section with CONSISTENT 4-space indentation
769
614
  preprocessing_section = ""
770
615
  if preprocessing_code and preprocessing_code.strip():
771
616
  clean_preprocessing = textwrap.dedent(preprocessing_code)
772
617
  preprocessing_section = textwrap.indent(clean_preprocessing, ' ')
773
618
 
774
- # Build postprocessing section with CONSISTENT 4-space indentation
775
619
  postprocessing_section = ""
776
620
  if postprocessing_code and postprocessing_code.strip():
777
621
  clean_postprocessing = textwrap.dedent(postprocessing_code)
778
622
  postprocessing_section = textwrap.indent(clean_postprocessing, ' ')
779
623
 
780
- # Build the template WITHOUT dedenting the whole thing, so indentation is preserved
781
624
  script_lines = [
782
625
  "import logging",
783
626
  "from io import BytesIO",
@@ -794,7 +637,116 @@ class ModelManagement:
794
637
  "",
795
638
  ]
796
639
 
797
- # Add preprocessing function definition BEFORE the main function
640
+ if preprocessing_section:
641
+ script_lines.extend([
642
+ "def validate_preprocessing_output(data_arrays):",
643
+ " \"\"\"",
644
+ " Validate preprocessing output coordinates and data type.",
645
+ " ",
646
+ " Args:",
647
+ " data_arrays: List of xarray DataArrays from preprocessing",
648
+ " ",
649
+ " Returns:",
650
+ " str: Validation signature symbol",
651
+ " ",
652
+ " Raises:",
653
+ " ValueError: If validation fails",
654
+ " \"\"\"",
655
+ " import numpy as np",
656
+ " ",
657
+ " if not data_arrays:",
658
+ " raise ValueError(\"No data arrays provided from preprocessing\")",
659
+ " ",
660
+ " reference_shape = None",
661
+ " ",
662
+ " for i, data_array in enumerate(data_arrays):",
663
+ " # Check if it's an xarray DataArray",
664
+ " if not hasattr(data_array, 'dims') or not hasattr(data_array, 'coords'):",
665
+ " raise ValueError(f\"Channel {i+1} is not a valid xarray DataArray\")",
666
+ " ",
667
+ " # Check coordinates",
668
+ " if 'time' not in data_array.coords:",
669
+ " raise ValueError(f\"Channel {i+1} missing time coordinate\")",
670
+ " ",
671
+ " spatial_dims = [dim for dim in data_array.dims if dim != 'time']",
672
+ " if len(spatial_dims) != 2:",
673
+ " raise ValueError(f\"Channel {i+1} must have exactly 2 spatial dimensions, got {spatial_dims}\")",
674
+ " ",
675
+ " for dim in spatial_dims:",
676
+ " if dim not in data_array.coords:",
677
+ " raise ValueError(f\"Channel {i+1} missing coordinate: {dim}\")",
678
+ " ",
679
+ " # Check shape consistency",
680
+ " shape = data_array.shape",
681
+ " if reference_shape is None:",
682
+ " reference_shape = shape",
683
+ " else:",
684
+ " if shape != reference_shape:",
685
+ " raise ValueError(f\"Channel {i+1} shape {shape} doesn't match reference {reference_shape}\")",
686
+ " ",
687
+ " # Generate validation signature",
688
+ " signature_components = [",
689
+ " f\"CH{len(data_arrays)}\", # Channel count",
690
+ " f\"T{reference_shape[0]}\", # Time dimension",
691
+ " f\"S{reference_shape[1]}x{reference_shape[2]}\", # Spatial dimensions",
692
+ " f\"DT{data_arrays[0].values.dtype}\", # Data type",
693
+ " ]",
694
+ " ",
695
+ " signature = \"★PRE_\" + \"_\".join(signature_components) + \"★\"",
696
+ " ",
697
+ " return signature",
698
+ "",
699
+ ])
700
+
701
+ if postprocessing_section:
702
+ script_lines.extend([
703
+ "def validate_postprocessing_output(result_array):",
704
+ " \"\"\"",
705
+ " Validate postprocessing output coordinates and data type.",
706
+ " ",
707
+ " Args:",
708
+ " result_array: xarray DataArray from postprocessing",
709
+ " ",
710
+ " Returns:",
711
+ " str: Validation signature symbol",
712
+ " ",
713
+ " Raises:",
714
+ " ValueError: If validation fails",
715
+ " \"\"\"",
716
+ " import numpy as np",
717
+ " ",
718
+ " # Check if it's an xarray DataArray",
719
+ " if not hasattr(result_array, 'dims') or not hasattr(result_array, 'coords'):",
720
+ " raise ValueError(\"Postprocessing output is not a valid xarray DataArray\")",
721
+ " ",
722
+ " # Check required coordinates",
723
+ " if 'time' not in result_array.coords:",
724
+ " raise ValueError(\"Missing time coordinate\")",
725
+ " ",
726
+ " spatial_dims = [dim for dim in result_array.dims if dim != 'time']",
727
+ " if len(spatial_dims) != 2:",
728
+ " raise ValueError(f\"Expected 2 spatial dimensions, got {len(spatial_dims)}: {spatial_dims}\")",
729
+ " ",
730
+ " for dim in spatial_dims:",
731
+ " if dim not in result_array.coords:",
732
+ " raise ValueError(f\"Missing spatial coordinate: {dim}\")",
733
+ " ",
734
+ " # Check shape",
735
+ " shape = result_array.shape",
736
+ " ",
737
+ " # Generate validation signature",
738
+ " signature_components = [",
739
+ " f\"T{shape[0]}\", # Time dimension",
740
+ " f\"S{shape[1]}x{shape[2]}\", # Spatial dimensions",
741
+ " f\"DT{result_array.values.dtype}\", # Data type",
742
+ " ]",
743
+ " ",
744
+ " signature = \"★POST_\" + \"_\".join(signature_components) + \"★\"",
745
+ " ",
746
+ " return signature",
747
+ "",
748
+ ])
749
+
798
750
  if preprocessing_section:
799
751
  script_lines.extend([
800
752
  "def preprocessing(array: Tuple[xr.DataArray, ...]) -> Tuple[xr.DataArray, ...]:",
@@ -802,7 +754,6 @@ class ModelManagement:
802
754
  "",
803
755
  ])
804
756
 
805
- # Add postprocessing function definition BEFORE the main function
806
757
  if postprocessing_section:
807
758
  script_lines.extend([
808
759
  "def postprocessing(array: xr.DataArray) -> xr.DataArray:",
@@ -810,14 +761,13 @@ class ModelManagement:
810
761
  "",
811
762
  ])
812
763
 
813
- # Add the get_model function
814
764
  script_lines.extend([
815
765
  "def get_model():",
816
- f" logging.info(\"Loading CNN model for {model_name}...\")",
766
+ f" logging.info(\"Loading CNN model for {virtual_dataset_name}...\")",
817
767
  "",
818
768
  " client = storage.Client()",
819
- " bucket = client.get_bucket('terrakio-mass-requests')",
820
- f" blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')",
769
+ f" bucket = client.get_bucket('{bucket_name}')",
770
+ f" blob = bucket.blob('{uid}/virtual_datasets/{virtual_dataset_name}/{virtual_dataset_name}.onnx')",
821
771
  "",
822
772
  " model = BytesIO()",
823
773
  " blob.download_to_file(model)",
@@ -826,7 +776,7 @@ class ModelManagement:
826
776
  " session = InferenceSession(model.read(), providers=[\"CPUExecutionProvider\"])",
827
777
  " return session",
828
778
  "",
829
- f"def {product}(*bands, model):",
779
+ f"def {virtual_product_name}(*bands, model):",
830
780
  " logging.info(\"Start preparing CNN data with time-stacked bands\")",
831
781
  " data_arrays = list(bands)",
832
782
  " ",
@@ -835,20 +785,20 @@ class ModelManagement:
835
785
  " ",
836
786
  ])
837
787
 
838
- # Add preprocessing call if preprocessing exists
839
788
  if preprocessing_section:
840
789
  script_lines.extend([
841
790
  " # Apply preprocessing",
842
791
  " data_arrays = preprocessing(tuple(data_arrays))",
843
792
  " data_arrays = list(data_arrays) # Convert back to list for processing",
844
793
  " ",
794
+ " # Validate preprocessing output",
795
+ " preprocessing_signature = validate_preprocessing_output(data_arrays)",
796
+ " ",
845
797
  ])
846
798
 
847
- # Continue with the rest of the processing logic
848
799
  script_lines.extend([
849
800
  " reference_array = data_arrays[0]",
850
801
  " original_shape = reference_array.shape",
851
- " logging.info(f\"Original shape: {original_shape}\")",
852
802
  " ",
853
803
  " # Get time coordinates - all bands should have the same time dimension",
854
804
  " if 'time' not in reference_array.dims:",
@@ -856,24 +806,19 @@ class ModelManagement:
856
806
  " ",
857
807
  " time_coords = reference_array.coords['time']",
858
808
  " num_timestamps = len(time_coords)",
859
- " logging.info(f\"Number of timestamps: {num_timestamps}\")",
860
809
  " ",
861
810
  " # Get spatial dimensions",
862
811
  " spatial_dims = [dim for dim in reference_array.dims if dim != 'time']",
863
812
  " height = reference_array.sizes[spatial_dims[0]] # assuming first spatial dim is height",
864
813
  " width = reference_array.sizes[spatial_dims[1]] # assuming second spatial dim is width",
865
- " logging.info(f\"Spatial dimensions: {height} x {width}\")",
866
814
  " ",
867
815
  " # Stack bands across time dimension",
868
816
  " # Result will be: (num_bands * num_timestamps, height, width)",
869
817
  " stacked_channels = []",
870
818
  " ",
871
819
  " for band_idx, data_array in enumerate(data_arrays):",
872
- " logging.info(f\"Processing band {band_idx + 1}/{len(data_arrays)}\")",
873
- " ",
874
820
  " # Ensure consistent time coordinates across bands",
875
821
  " if not np.array_equal(data_array.coords['time'].values, time_coords.values):",
876
- " logging.warning(f\"Band {band_idx} has different time coordinates, aligning...\")",
877
822
  " data_array = data_array.sel(time=time_coords, method='nearest')",
878
823
  " ",
879
824
  " # Extract values and ensure proper ordering (time, height, width)",
@@ -892,23 +837,18 @@ class ModelManagement:
892
837
  " # Stack all channels: (num_bands * num_timestamps, height, width)",
893
838
  " input_channels = np.stack(stacked_channels, axis=0)",
894
839
  " total_channels = len(data_arrays) * num_timestamps",
895
- " logging.info(f\"Stacked channels shape: {input_channels.shape}\")",
896
- " logging.info(f\"Total channels: {total_channels} ({len(data_arrays)} bands × {num_timestamps} timestamps)\")",
897
840
  " ",
898
841
  " # Add batch dimension: (1, num_channels, height, width)",
899
842
  " input_data = np.expand_dims(input_channels, axis=0).astype(np.float32)",
900
- " logging.info(f\"Final input shape for CNN: {input_data.shape}\")",
901
843
  " ",
902
844
  " # Run inference",
903
845
  " output = model.run(None, {\"float_input\": input_data})[0]",
904
- " logging.info(f\"Model output shape: {output.shape}\")",
905
846
  " ",
906
- " # UPDATED: Handle multi-class CNN output properly",
847
+ " # Handle multi-class CNN output properly",
907
848
  " if output.ndim == 4:",
908
849
  " if output.shape[1] == 1:",
909
850
  " # Single class output (regression or binary classification)",
910
851
  " output_2d = output[0, 0]",
911
- " logging.info(\"Single channel output detected\")",
912
852
  " else:",
913
853
  " # Multi-class output - convert logits/probabilities to class predictions",
914
854
  " output_classes = np.argmax(output, axis=1) # Shape: (1, height, width)",
@@ -916,22 +856,14 @@ class ModelManagement:
916
856
  " ",
917
857
  " # Apply class merging: merge class 6 into class 3",
918
858
  " output_2d = np.where(output_2d == 6, 3, output_2d)",
919
- " ",
920
- " logging.info(f\"Multi-class output processed. Original classes: {output.shape[1]}\")",
921
- " logging.info(f\"Unique classes in output: {np.unique(output_2d)}\")",
922
- " logging.info(f\"Class distribution: {np.bincount(output_2d.flatten())}\")",
923
859
  " elif output.ndim == 3:",
924
860
  " # Remove batch dimension",
925
861
  " output_2d = output[0]",
926
- " logging.info(\"3D output detected, removed batch dimension\")",
927
862
  " else:",
928
863
  " # Handle other cases",
929
864
  " output_2d = np.squeeze(output)",
930
865
  " if output_2d.ndim != 2:",
931
- " logging.error(f\"Cannot process output shape: {output.shape}\")",
932
- " logging.error(f\"After squeeze: {output_2d.shape}\")",
933
866
  " raise ValueError(f\"Unexpected output shape after processing: {output_2d.shape}\")",
934
- " logging.info(\"Applied squeeze to output\")",
935
867
  " ",
936
868
  " # Ensure output is 2D",
937
869
  " if output_2d.ndim != 2:",
@@ -949,11 +881,9 @@ class ModelManagement:
949
881
  " if is_multiclass:",
950
882
  " # Multi-class classification - use integer type",
951
883
  " output_dtype = np.int32",
952
- " output_type = 'classification'",
953
884
  " else:",
954
885
  " # Single output - use float type",
955
886
  " output_dtype = np.float32",
956
- " output_type = 'regression'",
957
887
  " ",
958
888
  " result = xr.DataArray(",
959
889
  " data=np.expand_dims(output_2d.astype(output_dtype), axis=0),",
@@ -967,31 +897,232 @@ class ModelManagement:
967
897
  " 'description': 'CNN model prediction',",
968
898
  " }",
969
899
  " )",
970
- " ",
971
- " logging.info(f\"Final result shape: {result.shape}\")",
972
- " logging.info(f\"Final result data type: {result.dtype}\")",
973
- " logging.info(f\"Final result value range: {result.values.min()} to {result.values.max()}\")",
974
900
  ])
975
901
 
976
- # Add postprocessing call if postprocessing exists
977
902
  if postprocessing_section:
978
903
  script_lines.extend([
979
904
  " # Apply postprocessing",
980
905
  " result = postprocessing(result)",
981
906
  " ",
907
+ " # Validate postprocessing output",
908
+ " postprocessing_signature = validate_postprocessing_output(result)",
909
+ " ",
982
910
  ])
983
911
 
984
- # Single return statement at the end
985
912
  script_lines.append(" return result")
986
913
 
987
914
  return "\n".join(script_lines)
915
+
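For orientation, a minimal sketch of the channel stacking that the generated inference script performs: every band's timestamps become separate input channels, a batch dimension is added before the ONNX call, and multi-class logits are reduced with argmax. The sketch is illustrative only, not part of the package; array names and sizes are made up.

import numpy as np

num_bands, num_timestamps, height, width = 3, 4, 64, 64
data_arrays = [np.random.rand(num_timestamps, height, width) for _ in range(num_bands)]

# Stack each band's timestamps as separate channels: (num_bands * num_timestamps, height, width)
stacked_channels = []
for data_array in data_arrays:
    for t in range(num_timestamps):
        stacked_channels.append(data_array[t])

input_channels = np.stack(stacked_channels, axis=0)
input_data = np.expand_dims(input_channels, axis=0).astype(np.float32)  # (1, channels, height, width)
assert input_data.shape == (1, num_bands * num_timestamps, height, width)

# Multi-class output handling mirrors the generated script: argmax over the class axis,
# then merge class 6 into class 3
output = np.random.rand(1, 7, height, width)  # stand-in for model.run(...)[0]
output_2d = np.argmax(output, axis=1)[0]
output_2d = np.where(output_2d == 6, 3, output_2d)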
916
+ def _parse_processing_script(self, script_path: str) -> Tuple[Optional[str], Optional[str]]:
917
+ """
918
+ Parse a Python file and extract preprocessing and postprocessing function bodies.
919
+
920
+ Args:
921
+ script_path: Path to the Python file containing processing functions
922
+
923
+ Returns:
924
+ Tuple of (preprocessing_code, postprocessing_code) where each can be None
925
+ """
926
+ try:
927
+ with open(script_path, 'r', encoding='utf-8') as f:
928
+ script_content = f.read()
929
+ except FileNotFoundError:
930
+ raise FileNotFoundError(f"Processing script not found: {script_path}")
931
+ except Exception as e:
932
+ raise ValueError(f"Error reading processing script: {e}")
933
+
934
+ if not script_content.strip():
935
+ self._client.logger.info(f"Processing script {script_path} is empty")
936
+ return None, None
937
+
938
+ try:
939
+ tree = ast.parse(script_content)
940
+ except SyntaxError as e:
941
+ raise ValueError(f"Syntax error in processing script: {e}")
942
+
943
+ preprocessing_code = None
944
+ postprocessing_code = None
945
+
946
+ function_names = []
947
+ for node in ast.walk(tree):
948
+ if isinstance(node, ast.FunctionDef):
949
+ function_names.append(node.name)
950
+ if node.name == 'preprocessing':
951
+ preprocessing_code = self._extract_function_body(script_content, node)
952
+ elif node.name == 'postprocessing':
953
+ postprocessing_code = self._extract_function_body(script_content, node)
954
+
955
+ if not function_names:
956
+ self._client.logger.warning(f"No functions found in processing script: {script_path}")
957
+ else:
958
+ found_functions = [name for name in function_names if name in ['preprocessing', 'postprocessing']]
959
+ if found_functions:
960
+ self._client.logger.info(f"Found processing functions: {found_functions}")
961
+ else:
962
+ self._client.logger.warning(f"No 'preprocessing' or 'postprocessing' functions found in {script_path}. "
963
+ f"Available functions: {function_names}")
964
+
965
+ return preprocessing_code, postprocessing_code
966
+
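To illustrate what _parse_processing_script expects, here is a hypothetical user-supplied processing script (not shipped with the package). Only functions named preprocessing and postprocessing are picked up; their dedented bodies are returned as strings and everything else is ignored.

# my_processing.py (hypothetical example)
def preprocessing(data_arrays):
    # e.g. clip negative reflectance values before the bands are stacked
    return [da.clip(min=0) for da in data_arrays]

def postprocessing(result):
    # e.g. mask a sentinel no-data value in the prediction
    return result.where(result != -9999)

def unrelated_helper():
    # not named 'preprocessing' or 'postprocessing', so the parser ignores it
    pass

Calling _parse_processing_script('my_processing.py') would return the two function bodies extracted by _extract_function_body, or None for a function that is missing, empty, or consists only of pass/return.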
967
+ def _extract_function_body(self, script_content: str, func_node: ast.FunctionDef) -> str:
968
+ """Extract the body of a function from the script content."""
969
+ lines = script_content.split('\n')
970
+
971
+ start_line = func_node.lineno - 1
972
+ end_line = func_node.end_lineno - 1 if hasattr(func_node, 'end_lineno') else len(lines) - 1
973
+
974
+ body_lines = []
975
+ for i in range(start_line + 1, end_line + 1):
976
+ if i < len(lines):
977
+ body_lines.append(lines[i])
978
+
979
+ if not body_lines:
980
+ return ""
981
+
982
+ body_text = '\n'.join(body_lines)
983
+ cleaned_body = textwrap.dedent(body_text).strip()
984
+
985
+ if not cleaned_body or cleaned_body in ['pass', 'return', 'return None']:
986
+ return ""
987
+
988
+ return cleaned_body
988
989
 
990
+ def _convert_model_to_onnx(self, model, input_shape: Optional[Tuple[int, ...]] = None, model_type: Optional[str] = None) -> Tuple[bytes, str]:
991
+ """
992
+ Convert a model to ONNX format and return it as bytes together with the resolved model type.
993
+
994
+ Args:
995
+ model: The model object (PyTorch or scikit-learn)
996
+ input_shape: Shape of input data
997
+ model_type: Type of model ('neural_network' or 'random_forest'); only used when the model is already an ONNX InferenceSession
998
+ Returns:
999
+ Tuple[bytes, str]: ONNX model as bytes and the resolved model type
1000
+
1001
+ Raises:
1002
+ ValueError: If model type is not supported
1003
+ ImportError: If required libraries are not installed
1004
+ """
1005
+ if TORCH_AVAILABLE and isinstance(model, torch.nn.Module):
1006
+ if not TORCH_AVAILABLE:
1007
+ raise ImportError("PyTorch is not installed. Please install it with: pip install torch")
1008
+ return self._convert_pytorch_to_onnx(model, input_shape), "neural_network"
1009
+ elif SKL2ONNX_AVAILABLE and isinstance(model, BaseEstimator):
1010
+ if not SKL2ONNX_AVAILABLE:
1011
+ raise ImportError("skl2onnx is not installed. Please install it with: pip install skl2onnx")
1012
+ return self._convert_sklearn_to_onnx(model, input_shape), "random_forest"
1013
+ elif isinstance(model, ort.InferenceSession):
1014
+ if model_type is None:
1015
+ raise ValueError(
1016
+ "For ONNX InferenceSession models, you must specify the 'model_type' parameter. Currently 'nerual network' and 'random forest' are supported."
1017
+ "Example: model_type='random forest' or model_type='neural network'"
1018
+ )
1019
+ return model.SerializeToString(), model_type
1020
+ else:
1021
+ model_type = type(model).__name__
1022
+ raise ValueError(f"Unsupported model type: {model_type}. Supported types: PyTorch nn.Module, sklearn BaseEstimator")
1023
+
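A hedged usage sketch of the dispatch above; mgr stands in for a ModelManagement instance and is not a name defined by the package.

from sklearn.ensemble import RandomForestRegressor

# Illustrative only: each supported branch returns (onnx_bytes, resolved_model_type)
forest = RandomForestRegressor(n_estimators=5).fit([[0.0, 1.0], [1.0, 0.0]], [0.0, 1.0])
onnx_bytes, resolved_type = mgr._convert_model_to_onnx(forest, input_shape=(None, 2))
# resolved_type == "random_forest"

# An already-loaded onnxruntime InferenceSession must declare its type explicitly:
# onnx_bytes, resolved_type = mgr._convert_model_to_onnx(session, model_type="neural_network")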
1024
+ def _convert_pytorch_to_onnx(self, model, input_shape: Tuple[int, ...]) -> bytes:
1025
+ try:
1026
+ model.eval()
1027
+ dummy_input = torch.randn(input_shape)
1028
+ onnx_buffer = BytesIO()
1029
+
1030
+ if len(input_shape) == 4:
1031
+ dynamic_axes = {
1032
+ 'float_input': {
1033
+ 0: 'batch_size',
1034
+ 2: 'height',
1035
+ 3: 'width'
1036
+ }
1037
+ }
1038
+
1039
+ elif len(input_shape) == 5:
1040
+ dynamic_axes = {
1041
+ 'float_input': {
1042
+ 0: 'batch_size',
1043
+ 3: 'height',
1044
+ 4: 'width'
1045
+ }
1046
+ }
1047
+
1048
+ else:
1049
+ dynamic_axes = {
1050
+ 'float_input': {
1051
+ 0: 'batch_size'
1052
+ }
1053
+ }
1054
+
1055
+ torch.onnx.export(
1056
+ model,
1057
+ dummy_input,
1058
+ onnx_buffer,
1059
+ input_names=['float_input'],
1060
+ dynamic_axes=dynamic_axes
1061
+ )
1062
+
1063
+ return onnx_buffer.getvalue()
1064
+ except Exception as e:
1065
+ raise ValueError(f"Failed to convert PyTorch model to ONNX: {str(e)}")
1066
+
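The dynamic_axes above keep the batch size and spatial dimensions flexible, so the exported graph can score tiles of different sizes. A minimal round-trip sketch, assuming torch and onnxruntime are installed; the single Conv2d layer and mgr (a ModelManagement instance) are stand-ins.

import numpy as np
import torch
import onnxruntime as ort

toy_model = torch.nn.Conv2d(in_channels=12, out_channels=1, kernel_size=3, padding=1)
onnx_bytes = mgr._convert_pytorch_to_onnx(toy_model, input_shape=(1, 12, 64, 64))

session = ort.InferenceSession(onnx_bytes)
# A tile with a different height/width than the dummy input still runs:
tile = np.random.rand(1, 12, 128, 96).astype(np.float32)
prediction = session.run(None, {"float_input": tile})[0]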
1067
+ def _convert_sklearn_to_onnx(self, model, input_shape: Tuple[int, ...]) -> bytes:
1068
+ """
1069
+ Convert a scikit-learn model (assumed to be a random forest) to ONNX format.
1070
+
1071
+ Args:
1072
+ model: The scikit-learn model object
1073
+ input_shape: Shape of input data (required)
1074
+
1075
+ Returns:
1076
+ bytes: ONNX model as bytes
1077
+
1078
+ Raises:
1079
+ ValueError: If conversion fails
1080
+ """
1081
+ self._client.logger.info(f"Converting random forest model to ONNX...")
1082
+
1083
+ try:
1084
+ initial_type = [('float_input', FloatTensorType(input_shape))]
1085
+ onnx_model = convert_sklearn(model, initial_types=initial_type)
1086
+ return onnx_model.SerializeToString()
1087
+ except Exception as e:
1088
+ raise ValueError(f"Failed to convert scikit-learn model to ONNX: {str(e)}")
1089
+
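For context, a sketch of what the convert_sklearn call above produces and how the resulting bytes can be sanity-checked with onnxruntime; the toy data and feature count are assumptions.

import numpy as np
import onnxruntime as ort
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from sklearn.ensemble import RandomForestRegressor

X = np.random.rand(100, 4).astype(np.float32)
y = np.random.rand(100)
forest = RandomForestRegressor(n_estimators=20).fit(X, y)

# None in the first dimension keeps the batch size dynamic
initial_type = [("float_input", FloatTensorType([None, 4]))]
onnx_bytes = convert_sklearn(forest, initial_types=initial_type).SerializeToString()

session = ort.InferenceSession(onnx_bytes)
predictions = session.run(None, {"float_input": X[:5]})[0]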
989
1090
  @require_api_key
990
- def _upload_script_to_bucket(self, script_content: str, script_name: str, model_training_job_name: str, uid: str):
991
- """Upload the generated script to Google Cloud Storage"""
992
-
993
- client = storage.Client()
994
- bucket = client.get_bucket('terrakio-mass-requests')
995
- blob = bucket.blob(f'{uid}/{model_training_job_name}/inference_scripts/{script_name}')
996
- blob.upload_from_string(script_content, content_type='text/plain')
997
- logging.info(f"Script uploaded successfully to {uid}/{model_training_job_name}/inference_scripts/{script_name}")
1091
+ def train_model(
1092
+ self,
1093
+ model_name: str,
1094
+ training_dataset: str,
1095
+ task_type: str,
1096
+ model_category: str,
1097
+ architecture: str,
1098
+ region: str,
1099
+ hyperparameters: dict = None
1100
+ ) -> dict:
1101
+ """
1102
+ Train a model using the external model training API.
1103
+
1104
+ Args:
1105
+ model_name (str): The name of the model to train.
1106
+ training_dataset (str): The training dataset identifier.
1107
+ task_type (str): The type of ML task (e.g., regression, classification).
1108
+ model_category (str): The category of model (e.g., random_forest).
1109
+ architecture (str): The model architecture.
1110
+ region (str): The region identifier.
1111
+ hyperparameters (dict, optional): Additional hyperparameters for training.
1112
+
1113
+ Returns:
1114
+ dict: The response from the model training API.
1115
+
1116
+ Raises:
1117
+ APIError: If the API request fails
1118
+ """
1119
+ payload = {
1120
+ "model_name": model_name,
1121
+ "training_dataset": training_dataset,
1122
+ "task_type": task_type,
1123
+ "model_category": model_category,
1124
+ "architecture": architecture,
1125
+ "region": region,
1126
+ "hyperparameters": hyperparameters
1127
+ }
1128
+ return self._client._terrakio_request("POST", "/train_model", json=payload)
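A hedged usage sketch of the endpoint above; the client attribute name and all argument values are assumptions, not taken from the package.

# Hypothetical call, assuming the ModelManagement instance is exposed as client.model_management
response = client.model_management.train_model(
    model_name="canopy-height-rf",
    training_dataset="my-training-dataset",
    task_type="regression",
    model_category="random_forest",
    architecture="random_forest",
    region="aus",
    hyperparameters={"n_estimators": 200, "max_depth": 12},
)
# The payload is POSTed to /train_model and the API's JSON response is returned as a dict.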