terrakio-core 0.4.3__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.

Potentially problematic release.


This version of terrakio-core might be problematic. Click here for more details.

@@ -1,36 +1,36 @@
1
- import os
1
+ # Standard library imports
2
+ import ast
2
3
  import json
3
- import time
4
4
  import textwrap
5
- import logging
6
- from typing import Dict, Any, Union, Tuple, Optional
5
+ import time
7
6
  from io import BytesIO
8
- import numpy as np
9
- from google.cloud import storage
10
- import ast
11
- from ..helper.decorators import require_token, require_api_key, require_auth
7
+ from typing import Optional, Tuple
8
+ import onnxruntime as ort
9
+
10
+ # Internal imports
11
+ from ..helper.decorators import require_api_key
12
+
13
+ # Optional dependency flags
12
14
  TORCH_AVAILABLE = False
13
15
  SKL2ONNX_AVAILABLE = False
14
16
 
17
+ # PyTorch imports
15
18
  try:
16
19
  import torch
17
20
  TORCH_AVAILABLE = True
18
21
  except ImportError:
19
22
  torch = None
20
23
 
24
+ # Scikit-learn and ONNX conversion imports
21
25
  try:
26
+ from sklearn.base import BaseEstimator
22
27
  from skl2onnx import convert_sklearn
23
28
  from skl2onnx.common.data_types import FloatTensorType
24
- from sklearn.base import BaseEstimator
25
29
  SKL2ONNX_AVAILABLE = True
26
30
  except ImportError:
31
+ BaseEstimator = None
27
32
  convert_sklearn = None
28
33
  FloatTensorType = None
29
- BaseEstimator = None
30
-
31
- from io import BytesIO
32
- from typing import Tuple
33
-
34
34
 
35
35
  class ModelManagement:
36
36
  def __init__(self, client):
@@ -128,14 +128,12 @@ class ModelManagement:
128
128
  )
129
129
  task_id = task_response["task_id"]
130
130
 
131
- # Wait for job completion with progress bar
132
131
  while True:
133
132
  result = await self._client.mass_stats.track_job(ids=[task_id])
134
133
  status = result[task_id]['status']
135
134
  completed = result[task_id].get('completed', 0)
136
135
  total = result[task_id].get('total', 1)
137
136
 
138
- # Create progress bar
139
137
  progress = completed / total if total > 0 else 0
140
138
  bar_length = 50
141
139
  filled_length = int(bar_length * progress)
@@ -151,526 +149,456 @@ class ModelManagement:
151
149
  self._client.logger.info("Job encountered an error")
152
150
  raise Exception(f"Job {task_id} encountered an error")
153
151
 
154
- # Wait 5 seconds before checking again
155
152
  time.sleep(5)
156
153
 
157
- # after all the random sample jobs are done, we then start the mass stats job
158
- # task_id = self._client.mass_stats.start_mass_stats_job(task_id)
159
154
  task_id = await self._client.mass_stats.start_job(task_id)
160
155
  return task_id
161
- # the folder that is being created is not under the jobs folder, its directly under the UID folder
162
156
 
163
- # @require_api_key
164
- # async def upload_model(self, model, model_name: str, input_shape: Tuple[int, ...] = None):
165
- # """
166
- # Upload a model to the bucket so that it can be used for inference.
167
- # Converts PyTorch and scikit-learn models to ONNX format before uploading.
168
-
169
- # Args:
170
- # model: The model object (PyTorch model or scikit-learn model)
171
- # model_name: Name for the model (without extension)
172
- # input_shape: Shape of input data for ONNX conversion (e.g., (1, 10) for batch_size=1, features=10)
173
- # Required for PyTorch models, optional for scikit-learn models
174
-
175
- # Raises:
176
- # APIError: If the API request fails
177
- # ValueError: If model type is not supported or input_shape is missing for PyTorch models
178
- # ImportError: If required libraries (torch or skl2onnx) are not installed
179
- # """
180
- # uid = (await self._client.auth.get_user_info())["uid"]
181
- # # above line is getting the uid,
182
-
183
- # client = storage.Client()
184
- # bucket = client.get_bucket('terrakio-mass-requests')
185
-
186
- # # Convert model to ONNX format
187
- # onnx_bytes = self._convert_model_to_onnx(model, model_name, input_shape)
188
-
189
- # # Upload ONNX model to bucket
190
- # # blob = bucket.blob(f'{uid}/{model_name}/models/{model_name}.onnx')
191
- # # we don't need to upload the model to the bucket
192
- # # so the stuff is stored under the virtual datasets folder
193
- # # the model name and the virtual dataset name should be the same
194
- # virtual_dataset_name = model_name
195
- # blob = bucket.blob(f'{uid}/virtual_datasets/{virtual_dataset_name}/{model_name}.onnx')
196
- # # wer are uploading the model to the virtual dataset folder
197
-
198
- # blob.upload_from_string(onnx_bytes, content_type='application/octet-stream')
199
-
200
- # self._client.logger.info(f"Model uploaded successfully to {uid}/virtual_datasets/{virtual_dataset_name}/{model_name}.onnx")
201
-
202
- # this is the upload model function, I think we need to upload to the user, under the virutal_datasets folder, and create the virtual dataset
203
-
204
-
205
157
  @require_api_key
206
- async def upload_model(self, model, model_name: str, input_shape: Tuple[int, ...] = None):
158
+ async def _get_url_for_upload_model_and_script(self, expression: str, model_name: str, script_name: str) -> str:
207
159
  """
208
- Upload a model to the bucket so that it can be used for inference.
209
- Converts PyTorch and scikit-learn models to ONNX format before uploading.
210
-
160
+ Get the url for the upload of the model
211
161
  Args:
212
- model: The model object (PyTorch model or scikit-learn model)
213
- model_name: Name for the model (without extension)
214
- input_shape: Shape of input data for ONNX conversion (e.g., (1, 10) for batch_size=1, features=10)
215
- Required for PyTorch models, optional for scikit-learn models
216
-
217
- Raises:
218
- APIError: If the API request fails
219
- ValueError: If model type is not supported or input_shape is missing for PyTorch models
220
- ImportError: If required libraries (torch or skl2onnx) are not installed
162
+ expression: The expression to use for the upload(for deciding which bucket to upload to)
163
+ model_name: The name of the model to upload
164
+ script_name: The name of the script to upload
165
+ Returns:
166
+ The url for the upload of the model
221
167
  """
222
- uid = (await self._client.auth.get_user_info())["uid"]
223
-
224
- client = storage.Client()
225
- bucket = client.get_bucket('terrakio-mass-requests')
226
-
227
- # Convert model to ONNX format
228
- onnx_bytes = self._convert_model_to_onnx(model, model_name, input_shape)
229
-
230
- # Upload ONNX model to bucket
231
- blob = bucket.blob(f'{uid}/{model_name}/models/{model_name}.onnx')
232
-
233
- blob.upload_from_string(onnx_bytes, content_type='application/octet-stream')
234
- self._client.logger.info(f"Model uploaded successfully to {uid}/{model_name}/models/{model_name}.onnx")
168
+ payload = {
169
+ "model_name": model_name,
170
+ "expression": expression,
171
+ "script_name": script_name
172
+ }
173
+ return await self._client._terrakio_request("POST", "models/upload", json=payload)
235
174
 
236
- def _convert_model_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...] = None) -> bytes:
175
+ async def _upload_model_to_url(self, upload_model_url: str, model: bytes):
237
176
  """
238
- Convert a model to ONNX format and return as bytes.
239
-
177
+ Upload a model to a given URL.
240
178
  Args:
241
- model: The model object (PyTorch or scikit-learn)
242
- model_name: Name of the model for logging
243
- input_shape: Shape of input data
244
-
179
+ model_url: The url to upload the model to
180
+ model: The model to upload
181
+
245
182
  Returns:
246
- bytes: ONNX model as bytes
247
-
248
- Raises:
249
- ValueError: If model type is not supported
250
- ImportError: If required libraries are not installed
183
+ The response from the server
251
184
  """
252
- # Early check for any conversion capability
253
- if not (TORCH_AVAILABLE or SKL2ONNX_AVAILABLE):
254
- raise ImportError(
255
- "ONNX conversion requires additional dependencies. Install with:\n"
256
- " pip install torch # For PyTorch models\n"
257
- " pip install skl2onnx # For scikit-learn models\n"
258
- " pip install torch skl2onnx # For both"
259
- )
260
-
261
- # Check if it's a PyTorch model using isinstance (preferred) with fallback
262
- is_pytorch = False
263
- if TORCH_AVAILABLE:
264
- is_pytorch = (isinstance(model, torch.nn.Module) or
265
- hasattr(model, 'state_dict'))
266
-
267
- # Check if it's a scikit-learn model
268
- is_sklearn = False
269
- if SKL2ONNX_AVAILABLE:
270
- is_sklearn = (isinstance(model, BaseEstimator) or
271
- (hasattr(model, 'fit') and hasattr(model, 'predict')))
272
-
273
- if is_pytorch and TORCH_AVAILABLE:
274
- return self._convert_pytorch_to_onnx(model, model_name, input_shape)
275
- elif is_sklearn and SKL2ONNX_AVAILABLE:
276
- return self._convert_sklearn_to_onnx(model, model_name, input_shape)
277
- else:
278
- # Provide helpful error message
279
- model_type = type(model).__name__
280
- model_module = type(model).__module__
281
- available_types = []
282
- missing_deps = []
283
-
284
- if TORCH_AVAILABLE:
285
- available_types.append("PyTorch (torch.nn.Module)")
286
- else:
287
- missing_deps.append("torch")
288
-
289
- if SKL2ONNX_AVAILABLE:
290
- available_types.append("scikit-learn (BaseEstimator)")
291
- else:
292
- missing_deps.append("skl2onnx")
293
-
294
- if missing_deps:
295
- raise ImportError(
296
- f"Model type {model_type} from {model_module} detected, but required dependencies missing: {', '.join(missing_deps)}. "
297
- f"Install with: pip install {' '.join(missing_deps)}"
298
- )
299
- else:
300
- raise ValueError(
301
- f"Unsupported model type: {model_type} from {model_module}. "
302
- f"Supported types: {', '.join(available_types)}"
303
- )
304
-
305
- def _convert_pytorch_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...]) -> bytes:
306
- """Convert PyTorch model to ONNX format with dynamic input dimensions."""
307
- if input_shape is None:
308
- raise ValueError("input_shape is required for PyTorch models")
309
-
310
- self._client.logger.info(f"Converting PyTorch model {model_name} to ONNX...")
311
-
312
- try:
313
- # Set model to evaluation mode
314
- model.eval()
315
-
316
- # Create dummy input
317
- dummy_input = torch.randn(input_shape)
318
-
319
- # Use BytesIO to avoid creating temporary files
320
- onnx_buffer = BytesIO()
321
-
322
- # Determine dynamic axes based on input shape
323
- # Common patterns for different input types:
324
- if len(input_shape) == 4: # Convolutional input: (batch, channels, height, width)
325
- dynamic_axes = {
326
- 'float_input': {
327
- 0: 'batch_size',
328
- 2: 'height', # Make height dynamic for variable input sizes
329
- 3: 'width' # Make width dynamic for variable input sizes
330
- },
331
- 'output': {0: 'batch_size'}
332
- }
333
- elif len(input_shape) == 3: # Could be (batch, sequence, features) or (batch, height, width)
334
- dynamic_axes = {
335
- 'float_input': {
336
- 0: 'batch_size',
337
- 1: 'dim1', # Generic dynamic dimension
338
- 2: 'dim2' # Generic dynamic dimension
339
- },
340
- 'output': {0: 'batch_size'}
341
- }
342
- elif len(input_shape) == 2: # Likely (batch, features)
343
- dynamic_axes = {
344
- 'float_input': {
345
- 0: 'batch_size'
346
- # Don't make features dynamic as it usually affects model architecture
347
- },
348
- 'output': {0: 'batch_size'}
349
- }
350
- else:
351
- # For other shapes, just make batch size dynamic
352
- dynamic_axes = {
353
- 'float_input': {0: 'batch_size'},
354
- 'output': {0: 'batch_size'}
355
- }
356
-
357
- torch.onnx.export(
358
- model,
359
- dummy_input,
360
- onnx_buffer,
361
- export_params=True,
362
- opset_version=11,
363
- do_constant_folding=True,
364
- input_names=['float_input'],
365
- output_names=['output'],
366
- dynamic_axes=dynamic_axes
367
- )
368
-
369
- self._client.logger.info(f"Successfully converted {model_name} with dynamic axes: {dynamic_axes}")
370
- return onnx_buffer.getvalue()
371
-
372
- except Exception as e:
373
- raise ValueError(f"Failed to convert PyTorch model {model_name} to ONNX: {str(e)}")
374
-
375
-
376
- def _convert_sklearn_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...] = None) -> bytes:
377
- """Convert scikit-learn model to ONNX format."""
378
- self._client.logger.info(f"Converting scikit-learn model {model_name} to ONNX...")
379
-
380
- # Try to infer input shape if not provided
381
- if input_shape is None:
382
- if hasattr(model, 'n_features_in_'):
383
- input_shape = (1, model.n_features_in_)
384
- else:
385
- raise ValueError(
386
- "input_shape is required for scikit-learn models when n_features_in_ is not available. "
387
- "This usually happens with older sklearn versions or models not fitted yet."
388
- )
389
-
390
- try:
391
- # Convert scikit-learn model to ONNX
392
- initial_type = [('float_input', FloatTensorType(input_shape))]
393
- onnx_model = convert_sklearn(model, initial_types=initial_type)
394
- return onnx_model.SerializeToString()
395
-
396
- except Exception as e:
397
- raise ValueError(f"Failed to convert scikit-learn model {model_name} to ONNX: {str(e)}")
398
-
399
-
400
- # we do not need to pass in both the model name and the dataset name, since the model name should the same as the virtual dataset name
401
- # but we are gonna have multiple products for the same virtual dataset
402
- # @require_api_key
403
- # async def upload_and_deploy_cnn_model(self, model, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None, processing_script_path: Optional[str] = None):
404
- # """
405
- # Upload a CNN model to the bucket and deploy it.
406
-
407
- # Args:
408
- # model: The model object (PyTorch model or scikit-learn model)
409
- # model_name: Name for the model (without extension)
410
- # dataset: Name of the dataset to create
411
- # product: Product name for the inference
412
- # input_expression: Input expression for the dataset
413
- # dates_iso8601: List of dates in ISO8601 format
414
- # input_shape: Shape of input data for ONNX conversion (required for PyTorch models)
415
- # processing_script_path: Path to the processing script, if not provided, no processing will be done
416
-
417
- # Raises:
418
- # APIError: If the API request fails
419
- # ValueError: If model type is not supported or input_shape is missing for PyTorch models
420
- # ImportError: If required libraries (torch or skl2onnx) are not installed
421
- # """
422
- # await self.upload_model(model=model, model_name=dataset, input_shape=input_shape)
423
- # # so the uploading process is kinda similar, but the deployment step is kinda different
424
- # # we should pass the processing script path to the deploy cnn model function
425
- # await self.deploy_cnn_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601, processing_script_path=processing_script_path)
185
+ headers = {
186
+ "Content-Type": "application/octet-stream",
187
+ "Content-Length": str(len(model))
188
+ }
189
+ response = await self._client._regular_request("PUT", endpoint = upload_model_url, data=model, headers=headers)
190
+ return response
191
+
192
+ @require_api_key
193
+ async def _upload_script_to_url(self, upload_script_url: str, script_content: str):
194
+ """
195
+ Upload the generated script to the url
196
+ Args:
197
+ url: Url for the upload of the script
198
+ script_content: Content of the script
199
+ returns:
200
+ None
201
+ """
202
+ script_bytes = script_content.encode('utf-8')
203
+ headers = {
204
+ "Content-Type": "text/x-python",
205
+ "Content-Length": str(len(script_bytes))
206
+ }
207
+ response = await self._client._regular_request("PUT", endpoint=upload_script_url, data=script_bytes, headers=headers)
208
+ return response
426
209
 
427
210
  @require_api_key
428
- async def upload_and_deploy_cnn_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None, processing_script_path: Optional[str] = None):
211
+ async def _upload_model_and_script(self, model, model_name: str, script_name: str, input_expression: str, input_shape: Tuple[int, ...] = None, processing_script_path: Optional[str] = None, model_type: Optional[str] = None):
429
212
  """
430
- Upload a CNN model to the bucket and deploy it.
431
-
213
+ Upload a model and script to the bucket
432
214
  Args:
433
215
  model: The model object (PyTorch model or scikit-learn model)
434
216
  model_name: Name for the model (without extension)
435
- dataset: Name of the dataset to create
436
- product: Product name for the inference
217
+ script_name: Name for the script (without extension)
437
218
  input_expression: Input expression for the dataset
438
- dates_iso8601: List of dates in ISO8601 format
439
219
  input_shape: Shape of input data for ONNX conversion (required for PyTorch models)
440
220
  processing_script_path: Path to the processing script, if not provided, no processing will be done
441
-
221
+ model_type: The type of the model we want to upload
442
222
  Raises:
443
223
  APIError: If the API request fails
444
224
  ValueError: If model type is not supported or input_shape is missing for PyTorch models
445
- ImportError: If required libraries (torch or skl2onnx) are not installed
446
- """
447
- await self.upload_model(model=model, model_name=model_name, input_shape=input_shape)
448
- # so the uploading process is kinda similar, but the deployment step is kinda different
449
- # we should pass the processing script path to the deploy cnn model function
450
- await self.deploy_cnn_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601, processing_script_path=processing_script_path)
451
225
 
226
+ Returns:
227
+ bucket_name: Name of the bucket where the model is stored
228
+ """
229
+ response = await self._get_url_for_upload_model_and_script(expression = input_expression, model_name = model_name, script_name = script_name)
230
+ model_url, script_url, bucket_name = response.get("model_upload_url"), response.get("script_upload_url"), response.get("bucket_name")
231
+ if not model_url or not script_url:
232
+ raise ValueError("No url returned from the server for the upload process")
233
+ try:
234
+ model_in_onnx_bytes, model_type = self._convert_model_to_onnx(model = model, input_shape = input_shape, model_type = model_type)
235
+ if model_type == "neural_network":
236
+ script_content = await self._generate_cnn_script(bucket_name = bucket_name, virtual_dataset_name = model_name, virtual_product_name = script_name, processing_script_path = processing_script_path)
237
+ elif model_type == "random_forest":
238
+ script_content = await self._generate_random_forest_script(bucket_name = bucket_name, virtual_dataset_name = model_name, virtual_product_name = script_name, processing_script_path = processing_script_path)
239
+ else:
240
+ raise ValueError(f"Unsupported model type: {model_type}. Supported types: neural_network, random_forest")
241
+ script_upload_response = await self._upload_script_to_url( upload_script_url = script_url, script_content = script_content)
242
+ if script_upload_response.status not in [200, 201, 204]:
243
+ self._client.logger.error(f"Script upload error: {script_upload_response.text()}")
244
+ raise Exception(f"Failed to upload script: {script_upload_response.text()}")
245
+ model_upload_response = await self._upload_model_to_url(upload_model_url = model_url, model = model_in_onnx_bytes)
246
+ if model_upload_response.status not in [200, 201, 204]:
247
+ self._client.logger.error(f"Model upload error: {model_upload_response.text()}")
248
+ raise Exception(f"Failed to upload model: {model_upload_response.text()}")
249
+ except Exception as e:
250
+ raise Exception(f"Error uploading model: {e}")
251
+ self._client.logger.info(f"Model and Script uploaded successfully to {model_url}")
252
+ return bucket_name
452
253
 
453
254
  @require_api_key
454
- async def upload_and_deploy_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None):
255
+ async def upload_and_deploy_model(self, model, virtual_dataset_name: str, virtual_product_name: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None, processing_script_path: Optional[str] = None, model_type: Optional[str] = None):
455
256
  """
456
257
  Upload a model to the bucket and deploy it.
457
-
458
258
  Args:
459
259
  model: The model object (PyTorch model or scikit-learn model)
460
- model_name: Name for the model (without extension)
461
- dataset: Name of the dataset to create
462
- product: Product name for the inference
260
+ virtual_dataset_name: Name for the virtual dataset (without extension)
261
+ virtual_product_name: Product name for the inference
463
262
  input_expression: Input expression for the dataset
464
263
  dates_iso8601: List of dates in ISO8601 format
465
264
  input_shape: Shape of input data for ONNX conversion (required for PyTorch models)
466
- """
467
- await self.upload_model(model=model, model_name=model_name, input_shape=input_shape)
468
- await self.deploy_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601)
265
+ processing_script_path: Path to the processing script, if not provided, no processing will be done
266
+ model_type: The type of the model we want to upload
469
267
 
470
- @require_api_key
471
- def train_model(
472
- self,
473
- model_name: str,
474
- training_dataset: str,
475
- task_type: str,
476
- model_category: str,
477
- architecture: str,
478
- region: str,
479
- hyperparameters: dict = None
480
- ) -> dict:
481
- """
482
- Train a model using the external model training API.
483
-
484
- Args:
485
- model_name (str): The name of the model to train.
486
- training_dataset (str): The training dataset identifier.
487
- task_type (str): The type of ML task (e.g., regression, classification).
488
- model_category (str): The category of model (e.g., random_forest).
489
- architecture (str): The model architecture.
490
- region (str): The region identifier.
491
- hyperparameters (dict, optional): Additional hyperparameters for training.
492
-
493
- Returns:
494
- dict: The response from the model training API.
495
-
496
268
  Raises:
497
269
  APIError: If the API request fails
498
- """
499
- payload = {
500
- "model_name": model_name,
501
- "training_dataset": training_dataset,
502
- "task_type": task_type,
503
- "model_category": model_category,
504
- "architecture": architecture,
505
- "region": region,
506
- "hyperparameters": hyperparameters
507
- }
508
- return self._client._terrakio_request("POST", "/train_model", json=payload)
270
+ ValueError: If model type is not supported or input_shape is missing for PyTorch models
271
+ ImportError: If required libraries (torch or skl2onnx) are not installed
509
272
 
510
- @require_api_key
511
- async def deploy_model(
512
- self,
513
- dataset: str,
514
- product: str,
515
- model_name: str,
516
- input_expression: str,
517
- model_training_job_name: str,
518
- dates_iso8601: list
519
- ) -> Dict[str, Any]:
520
- """
521
- Deploy a model by generating inference script and creating dataset.
522
-
523
- Args:
524
- dataset: Name of the dataset to create
525
- product: Product name for the inference
526
- model_name: Name of the trained model
527
- input_expression: Input expression for the dataset
528
- model_training_job_name: Name of the training job
529
- dates_iso8601: List of dates in ISO8601 format
530
-
531
273
  Returns:
532
- dict: Response from the deployment process
533
-
534
- Raises:
535
- APIError: If the API request fails
274
+ None
536
275
  """
537
- # Get user info to get UID
276
+ bucket_name = await self._upload_model_and_script(model=model, model_name=virtual_dataset_name, script_name= virtual_product_name, input_shape=input_shape, input_expression=input_expression, processing_script_path=processing_script_path, model_type= model_type)
538
277
  user_info = await self._client.auth.get_user_info()
539
278
  uid = user_info["uid"]
540
-
541
- # Generate and upload script
542
- script_content = self._generate_script(model_name, product, model_training_job_name, uid)
543
- script_name = f"{product}.py"
544
- self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
545
-
546
- # Create dataset
547
- return await self._client.datasets.create_dataset(
548
- name=dataset,
279
+ await self._client.datasets.create_dataset(
280
+ name=virtual_dataset_name,
549
281
  collection="terrakio-datasets",
550
- products=[product],
551
- path=f"gs://terrakio-mass-requests/{uid}/{model_training_job_name}/inference_scripts",
282
+ products=[virtual_product_name],
283
+ path=f"gs://{bucket_name}/{uid}/virtual_datasets/{virtual_dataset_name}/inference_scripts",
552
284
  input=input_expression,
553
285
  dates_iso8601=dates_iso8601,
554
286
  padding=0
555
287
  )
556
288
 
557
- def _parse_processing_script(self, script_path: str) -> Tuple[Optional[str], Optional[str]]:
289
+ @require_api_key
290
+ async def _generate_random_forest_script(self, bucket_name: str, virtual_dataset_name: str, virtual_product_name: str, processing_script_path: Optional[str] = None) -> str:
558
291
  """
559
- Parse a Python file and extract preprocessing and postprocessing function bodies.
292
+ Generate Python inference script for the Random Forest model.
560
293
 
561
294
  Args:
562
- script_path: Path to the Python file containing processing functions
295
+ bucket_name: Name of the bucket where the model is stored
296
+ virtual_dataset_name: Name of the virtual dataset and the model
297
+ virtual_product_name: Name of the virtual product
298
+ processing_script_path: Path to the processing script, if not provided, no processing will be done
563
299
 
564
300
  Returns:
565
- Tuple of (preprocessing_code, postprocessing_code) where each can be None
301
+ str: Generated Python script content
566
302
  """
567
- try:
568
- with open(script_path, 'r', encoding='utf-8') as f:
569
- script_content = f.read()
570
- except FileNotFoundError:
571
- raise FileNotFoundError(f"Processing script not found: {script_path}")
572
- except Exception as e:
573
- raise ValueError(f"Error reading processing script: {e}")
574
-
575
- # Handle empty file
576
- if not script_content.strip():
577
- self._client.logger.info(f"Processing script {script_path} is empty")
578
- return None, None
579
-
580
- try:
581
- # Parse the Python file
582
- tree = ast.parse(script_content)
583
- except SyntaxError as e:
584
- raise ValueError(f"Syntax error in processing script: {e}")
303
+ user_info = await self._client.auth.get_user_info()
304
+ uid = user_info["uid"]
305
+ preprocessing_code, postprocessing_code = None, None
306
+
307
+ if processing_script_path:
308
+ try:
309
+ preprocessing_code, postprocessing_code = self._parse_processing_script(processing_script_path)
310
+ if preprocessing_code:
311
+ self._client.logger.info(f"Using custom preprocessing from: {processing_script_path}")
312
+ if postprocessing_code:
313
+ self._client.logger.info(f"Using custom postprocessing from: {processing_script_path}")
314
+ if not preprocessing_code and not postprocessing_code:
315
+ self._client.logger.warning(f"No preprocessing or postprocessing functions found in {processing_script_path}")
316
+ self._client.logger.info("Deployment will continue without custom processing")
317
+ except Exception as e:
318
+ raise ValueError(f"Failed to load processing script: {str(e)}")
319
+
320
+ preprocessing_section = ""
321
+ if preprocessing_code and preprocessing_code.strip():
322
+ clean_preprocessing = textwrap.dedent(preprocessing_code)
323
+ preprocessing_section = textwrap.indent(clean_preprocessing, ' ')
585
324
 
586
- preprocessing_code = None
587
- postprocessing_code = None
325
+ postprocessing_section = ""
326
+ if postprocessing_code and postprocessing_code.strip():
327
+ clean_postprocessing = textwrap.dedent(postprocessing_code)
328
+ postprocessing_section = textwrap.indent(clean_postprocessing, ' ')
329
+
330
+ script_lines = [
331
+ "import logging",
332
+ "from io import BytesIO",
333
+ "import numpy as np",
334
+ "import pandas as pd",
335
+ "import xarray as xr",
336
+ "from google.cloud import storage",
337
+ "from onnxruntime import InferenceSession",
338
+ "from typing import Tuple",
339
+ "",
340
+ "logging.basicConfig(",
341
+ " level=logging.INFO",
342
+ ")",
343
+ "",
344
+ ]
588
345
 
589
- # Find function definitions
590
- function_names = []
591
- for node in ast.walk(tree):
592
- if isinstance(node, ast.FunctionDef):
593
- function_names.append(node.name)
594
- if node.name == 'preprocessing':
595
- preprocessing_code = self._extract_function_body(script_content, node)
596
- elif node.name == 'postprocessing':
597
- postprocessing_code = self._extract_function_body(script_content, node)
346
+ if preprocessing_section:
347
+ script_lines.extend([
348
+ "def validate_preprocessing_output(data_arrays):",
349
+ " \"\"\"",
350
+ " Validate preprocessing output coordinates and data type.",
351
+ " ",
352
+ " Args:",
353
+ " data_arrays: List of xarray DataArrays from preprocessing",
354
+ " ",
355
+ " Returns:",
356
+ " str: Validation signature symbol",
357
+ " ",
358
+ " Raises:",
359
+ " ValueError: If validation fails",
360
+ " \"\"\"",
361
+ " import numpy as np",
362
+ " ",
363
+ " if not data_arrays:",
364
+ " raise ValueError(\"No data arrays provided from preprocessing\")",
365
+ " ",
366
+ " reference_shape = None",
367
+ " ",
368
+ " for i, data_array in enumerate(data_arrays):",
369
+ " # Check if it's an xarray DataArray",
370
+ " if not hasattr(data_array, 'dims') or not hasattr(data_array, 'coords'):",
371
+ " raise ValueError(f\"Channel {i+1} is not a valid xarray DataArray\")",
372
+ " ",
373
+ " # Check coordinates",
374
+ " if 'time' not in data_array.coords:",
375
+ " raise ValueError(f\"Channel {i+1} missing time coordinate\")",
376
+ " ",
377
+ " spatial_dims = [dim for dim in data_array.dims if dim != 'time']",
378
+ " if len(spatial_dims) != 2:",
379
+ " raise ValueError(f\"Channel {i+1} must have exactly 2 spatial dimensions, got {spatial_dims}\")",
380
+ " ",
381
+ " for dim in spatial_dims:",
382
+ " if dim not in data_array.coords:",
383
+ " raise ValueError(f\"Channel {i+1} missing coordinate: {dim}\")",
384
+ " ",
385
+ " # Check shape consistency",
386
+ " shape = data_array.shape",
387
+ " if reference_shape is None:",
388
+ " reference_shape = shape",
389
+ " else:",
390
+ " if shape != reference_shape:",
391
+ " raise ValueError(f\"Channel {i+1} shape {shape} doesn't match reference {reference_shape}\")",
392
+ " ",
393
+ " # Generate validation signature",
394
+ " signature_components = [",
395
+ " f\"CH{len(data_arrays)}\", # Channel count",
396
+ " f\"T{reference_shape[0]}\", # Time dimension",
397
+ " f\"S{reference_shape[1]}x{reference_shape[2]}\", # Spatial dimensions",
398
+ " f\"DT{data_arrays[0].values.dtype}\", # Data type",
399
+ " ]",
400
+ " ",
401
+ " signature = \"★PRE_\" + \"_\".join(signature_components) + \"★\"",
402
+ " ",
403
+ " return signature",
404
+ "",
405
+ ])
598
406
 
599
- # Log what was found for debugging
600
- if not function_names:
601
- self._client.logger.warning(f"No functions found in processing script: {script_path}")
602
- else:
603
- found_functions = [name for name in function_names if name in ['preprocessing', 'postprocessing']]
604
- if found_functions:
605
- self._client.logger.info(f"Found processing functions: {found_functions}")
606
- else:
607
- self._client.logger.warning(f"No 'preprocessing' or 'postprocessing' functions found in {script_path}. "
608
- f"Available functions: {function_names}")
407
+ if postprocessing_section:
408
+ script_lines.extend([
409
+ "def validate_postprocessing_output(result_array):",
410
+ " \"\"\"",
411
+ " Validate postprocessing output coordinates and data type.",
412
+ " ",
413
+ " Args:",
414
+ " result_array: xarray DataArray from postprocessing",
415
+ " ",
416
+ " Returns:",
417
+ " str: Validation signature symbol",
418
+ " ",
419
+ " Raises:",
420
+ " ValueError: If validation fails",
421
+ " \"\"\"",
422
+ " import numpy as np",
423
+ " ",
424
+ " # Check if it's an xarray DataArray",
425
+ " if not hasattr(result_array, 'dims') or not hasattr(result_array, 'coords'):",
426
+ " raise ValueError(\"Postprocessing output is not a valid xarray DataArray\")",
427
+ " ",
428
+ " # Check required coordinates",
429
+ " if 'time' not in result_array.coords:",
430
+ " raise ValueError(\"Missing time coordinate\")",
431
+ " ",
432
+ " spatial_dims = [dim for dim in result_array.dims if dim != 'time']",
433
+ " if len(spatial_dims) != 2:",
434
+ " raise ValueError(f\"Expected 2 spatial dimensions, got {len(spatial_dims)}: {spatial_dims}\")",
435
+ " ",
436
+ " for dim in spatial_dims:",
437
+ " if dim not in result_array.coords:",
438
+ " raise ValueError(f\"Missing spatial coordinate: {dim}\")",
439
+ " ",
440
+ " # Check shape",
441
+ " shape = result_array.shape",
442
+ " ",
443
+ " # Generate validation signature",
444
+ " signature_components = [",
445
+ " f\"T{shape[0]}\", # Time dimension",
446
+ " f\"S{shape[1]}x{shape[2]}\", # Spatial dimensions",
447
+ " f\"DT{result_array.values.dtype}\", # Data type",
448
+ " ]",
449
+ " ",
450
+ " signature = \"★POST_\" + \"_\".join(signature_components) + \"★\"",
451
+ " ",
452
+ " return signature",
453
+ "",
454
+ ])
609
455
 
610
- return preprocessing_code, postprocessing_code
611
-
612
- def _extract_function_body(self, script_content: str, func_node: ast.FunctionDef) -> str:
613
- """Extract the body of a function from the script content."""
614
- lines = script_content.split('\n')
456
+ if preprocessing_section:
457
+ script_lines.extend([
458
+ "def preprocessing(array: Tuple[xr.DataArray, ...]) -> Tuple[xr.DataArray, ...]:",
459
+ preprocessing_section,
460
+ "",
461
+ ])
615
462
 
616
- # AST line numbers are 1-indexed, convert to 0-indexed
617
- start_line = func_node.lineno - 1 # This is the 'def' line (0-indexed)
618
- end_line = func_node.end_lineno - 1 if hasattr(func_node, 'end_lineno') else len(lines) - 1
463
+ if postprocessing_section:
464
+ script_lines.extend([
465
+ "def postprocessing(array: xr.DataArray) -> xr.DataArray:",
466
+ postprocessing_section,
467
+ "",
468
+ ])
619
469
 
620
- # Extract ONLY the body lines (skip the def line entirely)
621
- body_lines = []
622
- for i in range(start_line + 1, end_line + 1): # +1 to skip the 'def' line
623
- if i < len(lines):
624
- body_lines.append(lines[i])
470
+ script_lines.extend([
471
+ "def get_model():",
472
+ f" logging.info(\"Loading Random Forest model for {virtual_dataset_name}...\")",
473
+ "",
474
+ " client = storage.Client()",
475
+ f" bucket = client.get_bucket('{bucket_name}')",
476
+ f" blob = bucket.blob('{uid}/virtual_datasets/{virtual_dataset_name}/{virtual_dataset_name}.onnx')",
477
+ "",
478
+ " model = BytesIO()",
479
+ " blob.download_to_file(model)",
480
+ " model.seek(0)",
481
+ "",
482
+ " session = InferenceSession(model.read(), providers=[\"CPUExecutionProvider\"])",
483
+ " return session",
484
+ "",
485
+ f"def {virtual_product_name}(*bands, model):",
486
+ " logging.info(\"Start preparing Random Forest data\")",
487
+ " data_arrays = list(bands)",
488
+ " ",
489
+ " if not data_arrays:",
490
+ " raise ValueError(\"No bands provided\")",
491
+ " ",
492
+ ])
625
493
 
626
- if not body_lines:
627
- return ""
494
+ if preprocessing_section:
495
+ script_lines.extend([
496
+ " # Apply preprocessing",
497
+ " data_arrays = preprocessing(tuple(data_arrays))",
498
+ " data_arrays = list(data_arrays) # Convert back to list for processing",
499
+ " ",
500
+ " # Validate preprocessing output",
501
+ " preprocessing_signature = validate_preprocessing_output(data_arrays)",
502
+ " ",
503
+ ])
628
504
 
629
- # Join and dedent to remove function-level indentation
630
- body_text = '\n'.join(body_lines)
631
- cleaned_body = textwrap.dedent(body_text).strip()
505
+ script_lines.extend([
506
+ " reference_array = data_arrays[0]",
507
+ " original_shape = reference_array.shape",
508
+ " ",
509
+ " if 'time' in reference_array.dims:",
510
+ " time_coords = reference_array.coords['time']",
511
+ " if len(time_coords) == 1:",
512
+ " output_timestamp = time_coords[0]",
513
+ " else:",
514
+ " years = [pd.to_datetime(t).year for t in time_coords.values]",
515
+ " unique_years = set(years)",
516
+ " ",
517
+ " if len(unique_years) == 1:",
518
+ " year = list(unique_years)[0]",
519
+ " output_timestamp = pd.Timestamp(f\"{year}-01-01\")",
520
+ " else:",
521
+ " latest_year = max(unique_years)",
522
+ " output_timestamp = pd.Timestamp(f\"{latest_year}-01-01\")",
523
+ " else:",
524
+ " output_timestamp = pd.Timestamp(\"1970-01-01\")",
525
+ "",
526
+ " averaged_bands = []",
527
+ " for data_array in data_arrays:",
528
+ " if 'time' in data_array.dims:",
529
+ " averaged_band = np.mean(data_array.values, axis=0)",
530
+ " else:",
531
+ " averaged_band = data_array.values",
532
+ "",
533
+ " flattened_band = averaged_band.reshape(-1, 1)",
534
+ " averaged_bands.append(flattened_band)",
535
+ "",
536
+ " input_data = np.hstack(averaged_bands)",
537
+ "",
538
+ " output = model.run(None, {\"float_input\": input_data.astype(np.float32)})[0]",
539
+ "",
540
+ " if len(original_shape) >= 3:",
541
+ " spatial_shape = original_shape[1:]",
542
+ " else:",
543
+ " spatial_shape = original_shape",
544
+ "",
545
+ " output_reshaped = output.reshape(spatial_shape)",
546
+ "",
547
+ " output_with_time = np.expand_dims(output_reshaped, axis=0)",
548
+ "",
549
+ " if 'time' in reference_array.dims:",
550
+ " spatial_dims = [dim for dim in reference_array.dims if dim != 'time']",
551
+ " spatial_coords = {dim: reference_array.coords[dim] for dim in spatial_dims if dim in reference_array.coords}",
552
+ " else:",
553
+ " spatial_dims = list(reference_array.dims)",
554
+ " spatial_coords = dict(reference_array.coords)",
555
+ "",
556
+ " result = xr.DataArray(",
557
+ " data=output_with_time.astype(np.float32),",
558
+ " dims=['time'] + list(spatial_dims),",
559
+ " coords={",
560
+ " 'time': [output_timestamp.values],",
561
+ " 'y': spatial_coords['y'].values,",
562
+ " 'x': spatial_coords['x'].values",
563
+ " },",
564
+ " attrs={",
565
+ " 'description': 'Random Forest model prediction',",
566
+ " }",
567
+ " )",
568
+ ])
632
569
 
633
- # Handle empty function body
634
- if not cleaned_body or cleaned_body in ['pass', 'return', 'return None']:
635
- return ""
570
+ if postprocessing_section:
571
+ script_lines.extend([
572
+ " # Apply postprocessing",
573
+ " result = postprocessing(result)",
574
+ " ",
575
+ " # Validate postprocessing output",
576
+ " postprocessing_signature = validate_postprocessing_output(result)",
577
+ " ",
578
+ ])
636
579
 
637
- return cleaned_body
580
+ script_lines.append(" return result")
638
581
 
582
+ return "\n".join(script_lines)
583
+
639
584
  @require_api_key
640
- async def deploy_cnn_model(
641
- self,
642
- dataset: str,
643
- product: str,
644
- model_name: str,
645
- input_expression: str,
646
- model_training_job_name: str,
647
- dates_iso8601: list,
648
- processing_script_path: Optional[str] = None
649
- ) -> Dict[str, Any]:
585
+ async def _generate_cnn_script(self, bucket_name: str, virtual_dataset_name: str, virtual_product_name: str, processing_script_path: Optional[str] = None) -> str:
650
586
  """
651
- Deploy a CNN model by generating inference script and creating dataset.
587
+ Generate Python inference script for CNN model with time-stacked bands.
652
588
 
653
589
  Args:
654
- dataset: Name of the dataset to create
655
- product: Product name for the inference
656
- model_name: Name of the trained model
657
- input_expression: Input expression for the dataset
658
- model_training_job_name: Name of the training job
659
- dates_iso8601: List of dates in ISO8601 format
590
+ bucket_name: Name of the bucket where the model is stored
591
+ virtual_dataset_name: Name of the virtual dataset and the model
592
+ virtual_product_name: Name of the virtual product
660
593
  processing_script_path: Path to the processing script, if not provided, no processing will be done
661
594
  Returns:
662
- dict: Response from the deployment process
663
-
664
- Raises:
665
- APIError: If the API request fails
595
+ str: Generated Python script content
666
596
  """
667
- # Get user info to get UID
668
597
  user_info = await self._client.auth.get_user_info()
669
598
  uid = user_info["uid"]
670
-
671
599
  preprocessing_code, postprocessing_code = None, None
600
+
672
601
  if processing_script_path:
673
- # if there is a function that is being passed in
674
602
  try:
675
603
  preprocessing_code, postprocessing_code = self._parse_processing_script(processing_script_path)
676
604
  if preprocessing_code:
@@ -682,415 +610,17 @@ class ModelManagement:
682
610
  self._client.logger.info("Deployment will continue without custom processing")
683
611
  except Exception as e:
684
612
  raise ValueError(f"Failed to load processing script: {str(e)}")
685
- # so we already have the preprocessing code and the post processing code, I need to pass them to the generate cnn script function
686
- # Generate and upload script
687
- # Build preprocessing section with CONSISTENT 8-space indentation
688
- preprocessing_section = ""
689
- if preprocessing_code and preprocessing_code.strip():
690
- # First dedent the preprocessing code to remove any existing indentation
691
- clean_preprocessing = preprocessing_code
692
- # Then add consistent 8-space indentation to match the template
693
- preprocessing_section = f"""{textwrap.indent(clean_preprocessing, '')}""" # 8 spaces
694
- # print(preprocessing_section)
695
- script_content = self.generate_cnn_script(model_name, product, model_training_job_name, uid, preprocessing_code, postprocessing_code)
696
- script_name = f"{product}.py"
697
- self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
698
- # Create dataset
699
- return await self._client.datasets.create_dataset(
700
- name=dataset,
701
- collection="terrakio-datasets",
702
- products=[product],
703
- path=f"gs://terrakio-mass-requests/{uid}/{model_training_job_name}/inference_scripts",
704
- input=input_expression,
705
- dates_iso8601=dates_iso8601,
706
- padding=0
707
- )
708
-
709
- @require_api_key
710
- def _generate_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
711
- """
712
- Generate Python inference script for the model.
713
-
714
- Args:
715
- model_name: Name of the model
716
- product: Product name
717
- model_training_job_name: Training job name
718
- uid: User ID
719
-
720
- Returns:
721
- str: Generated Python script content
722
- """
723
- return textwrap.dedent(f'''
724
- import logging
725
- from io import BytesIO
726
-
727
- import numpy as np
728
- import pandas as pd
729
- import xarray as xr
730
- from google.cloud import storage
731
- from onnxruntime import InferenceSession
732
-
733
- logging.basicConfig(
734
- level=logging.INFO
735
- )
736
-
737
- def get_model():
738
- logging.info("Loading model for {model_name}...")
739
-
740
- client = storage.Client()
741
- bucket = client.get_bucket('terrakio-mass-requests')
742
- blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')
743
-
744
- model = BytesIO()
745
- blob.download_to_file(model)
746
- model.seek(0)
747
-
748
- session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
749
- return session
750
-
751
- def {product}(*bands, model):
752
- logging.info("start preparing data")
753
-
754
- data_arrays = list(bands)
755
-
756
- reference_array = data_arrays[0]
757
- original_shape = reference_array.shape
758
- logging.info(f"Original shape: {{original_shape}}")
759
613
 
760
- if 'time' in reference_array.dims:
761
- time_coords = reference_array.coords['time']
762
- if len(time_coords) == 1:
763
- output_timestamp = time_coords[0]
764
- else:
765
- years = [pd.to_datetime(t).year for t in time_coords.values]
766
- unique_years = set(years)
767
-
768
- if len(unique_years) == 1:
769
- year = list(unique_years)[0]
770
- output_timestamp = pd.Timestamp(f"{{year}}-01-01")
771
- else:
772
- latest_year = max(unique_years)
773
- output_timestamp = pd.Timestamp(f"{{latest_year}}-01-01")
774
- else:
775
- output_timestamp = pd.Timestamp("1970-01-01")
776
-
777
- averaged_bands = []
778
- for data_array in data_arrays:
779
- if 'time' in data_array.dims:
780
- averaged_band = np.mean(data_array.values, axis=0)
781
- logging.info(f"Averaged band from {{data_array.shape}} to {{averaged_band.shape}}")
782
- else:
783
- averaged_band = data_array.values
784
- logging.info(f"No time dimension, shape: {{averaged_band.shape}}")
785
-
786
- flattened_band = averaged_band.reshape(-1, 1)
787
- averaged_bands.append(flattened_band)
788
-
789
- input_data = np.hstack(averaged_bands)
790
-
791
- logging.info(f"Final input shape: {{input_data.shape}}")
792
-
793
- output = model.run(None, {{"float_input": input_data.astype(np.float32)}})[0]
794
-
795
- logging.info(f"Model output shape: {{output.shape}}")
796
-
797
- if len(original_shape) >= 3:
798
- spatial_shape = original_shape[1:]
799
- else:
800
- spatial_shape = original_shape
801
-
802
- output_reshaped = output.reshape(spatial_shape)
803
-
804
- output_with_time = np.expand_dims(output_reshaped, axis=0)
805
-
806
- if 'time' in reference_array.dims:
807
- spatial_dims = [dim for dim in reference_array.dims if dim != 'time']
808
- spatial_coords = {{dim: reference_array.coords[dim] for dim in spatial_dims if dim in reference_array.coords}}
809
- else:
810
- spatial_dims = list(reference_array.dims)
811
- spatial_coords = dict(reference_array.coords)
812
-
813
- result = xr.DataArray(
814
- data=output_with_time.astype(np.float32),
815
- dims=['time'] + list(spatial_dims),
816
- coords={{
817
- 'time': [output_timestamp.values],
818
- 'y': spatial_coords['y'].values,
819
- 'x': spatial_coords['x'].values
820
- }}
821
- )
822
- return result
823
- ''').strip()
824
-
825
- # @require_api_key
826
- # def generate_cnn_script(self, model_name: str, product: str, model_training_job_name: str, uid: str, preprocessing_code: Optional[str] = None, postprocessing_code: Optional[str] = None) -> str:
827
- # """
828
- # Generate Python inference script for CNN model with time-stacked bands.
829
-
830
- # Args:
831
- # model_name: Name of the model
832
- # product: Product name
833
- # model_training_job_name: Training job name
834
- # uid: User ID
835
- # preprocessing_code: Preprocessing code
836
- # postprocessing_code: Postprocessing code
837
- # Returns:
838
- # str: Generated Python script content
839
- # """
840
- # import textwrap
841
-
842
- # # Build preprocessing section with CONSISTENT 4-space indentation
843
- # preprocessing_section = ""
844
- # if preprocessing_code and preprocessing_code.strip():
845
- # clean_preprocessing = textwrap.dedent(preprocessing_code)
846
- # preprocessing_section = textwrap.indent(clean_preprocessing, ' ')
847
-
848
- # # Build postprocessing section with CONSISTENT 4-space indentation
849
- # postprocessing_section = ""
850
- # if postprocessing_code and postprocessing_code.strip():
851
- # clean_postprocessing = textwrap.dedent(postprocessing_code)
852
- # postprocessing_section = textwrap.indent(clean_postprocessing, ' ')
853
-
854
- # # Build the template WITHOUT dedenting the whole thing, so indentation is preserved
855
- # script_lines = [
856
- # "import logging",
857
- # "from io import BytesIO",
858
- # "import numpy as np",
859
- # "import pandas as pd",
860
- # "import xarray as xr",
861
- # "from google.cloud import storage",
862
- # "from onnxruntime import InferenceSession",
863
- # "from typing import Tuple",
864
- # "",
865
- # "logging.basicConfig(",
866
- # " level=logging.INFO",
867
- # ")",
868
- # "",
869
- # ]
870
-
871
- # # Add preprocessing function definition BEFORE the main function
872
- # if preprocessing_section:
873
- # script_lines.extend([
874
- # "def preprocessing(array: Tuple[xr.DataArray, ...]) -> Tuple[xr.DataArray, ...]:",
875
- # preprocessing_section,
876
- # "",
877
- # ])
878
-
879
- # # Add postprocessing function definition BEFORE the main function
880
- # if postprocessing_section:
881
- # script_lines.extend([
882
- # "def postprocessing(array: xr.DataArray) -> xr.DataArray:",
883
- # postprocessing_section,
884
- # "",
885
- # ])
886
-
887
- # # Add the get_model function
888
- # script_lines.extend([
889
- # "def get_model():",
890
- # f" logging.info(\"Loading CNN model for {model_name}...\")",
891
- # "",
892
- # " client = storage.Client()",
893
- # " bucket = client.get_bucket('terrakio-mass-requests')",
894
- # f" blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')",
895
- # "",
896
- # " model = BytesIO()",
897
- # " blob.download_to_file(model)",
898
- # " model.seek(0)",
899
- # "",
900
- # " session = InferenceSession(model.read(), providers=[\"CPUExecutionProvider\"])",
901
- # " return session",
902
- # "",
903
- # f"def {product}(*bands, model):",
904
- # " logging.info(\"Start preparing CNN data with time-stacked bands\")",
905
- # " data_arrays = list(bands)",
906
- # " ",
907
- # " if not data_arrays:",
908
- # " raise ValueError(\"No bands provided\")",
909
- # " ",
910
- # ])
911
-
912
- # # Add preprocessing call if preprocessing exists
913
- # if preprocessing_section:
914
- # script_lines.extend([
915
- # " # Apply preprocessing",
916
- # " data_arrays = preprocessing(tuple(data_arrays))",
917
- # " data_arrays = list(data_arrays) # Convert back to list for processing",
918
- # " ",
919
- # ])
920
-
921
- # # Continue with the rest of the processing logic
922
- # script_lines.extend([
923
- # " reference_array = data_arrays[0]",
924
- # " original_shape = reference_array.shape",
925
- # " logging.info(f\"Original shape: {original_shape}\")",
926
- # " ",
927
- # " # Get time coordinates - all bands should have the same time dimension",
928
- # " if 'time' not in reference_array.dims:",
929
- # " raise ValueError(\"Time dimension is required for CNN processing\")",
930
- # " ",
931
- # " time_coords = reference_array.coords['time']",
932
- # " num_timestamps = len(time_coords)",
933
- # " logging.info(f\"Number of timestamps: {num_timestamps}\")",
934
- # " ",
935
- # " # Get spatial dimensions",
936
- # " spatial_dims = [dim for dim in reference_array.dims if dim != 'time']",
937
- # " height = reference_array.sizes[spatial_dims[0]] # assuming first spatial dim is height",
938
- # " width = reference_array.sizes[spatial_dims[1]] # assuming second spatial dim is width",
939
- # " logging.info(f\"Spatial dimensions: {height} x {width}\")",
940
- # " ",
941
- # " # Stack bands across time dimension",
942
- # " # Result will be: (num_bands * num_timestamps, height, width)",
943
- # " stacked_channels = []",
944
- # " ",
945
- # " for band_idx, data_array in enumerate(data_arrays):",
946
- # " logging.info(f\"Processing band {band_idx + 1}/{len(data_arrays)}\")",
947
- # " ",
948
- # " # Ensure consistent time coordinates across bands",
949
- # " if not np.array_equal(data_array.coords['time'].values, time_coords.values):",
950
- # " logging.warning(f\"Band {band_idx} has different time coordinates, aligning...\")",
951
- # " data_array = data_array.sel(time=time_coords, method='nearest')",
952
- # " ",
953
- # " # Extract values and ensure proper ordering (time, height, width)",
954
- # " band_values = data_array.values",
955
- # " if band_values.ndim == 3:",
956
- # " # Reorder dimensions if needed to ensure (time, height, width)",
957
- # " time_dim_idx = data_array.dims.index('time')",
958
- # " if time_dim_idx != 0:",
959
- # " axes_order = [time_dim_idx] + [i for i in range(len(data_array.dims)) if i != time_dim_idx]",
960
- # " band_values = np.transpose(band_values, axes_order)",
961
- # " ",
962
- # " # Add each timestamp of this band to the channel stack",
963
- # " for t in range(num_timestamps):",
964
- # " stacked_channels.append(band_values[t])",
965
- # " ",
966
- # " # Stack all channels: (num_bands * num_timestamps, height, width)",
967
- # " input_channels = np.stack(stacked_channels, axis=0)",
968
- # " total_channels = len(data_arrays) * num_timestamps",
969
- # " logging.info(f\"Stacked channels shape: {input_channels.shape}\")",
970
- # " logging.info(f\"Total channels: {total_channels} ({len(data_arrays)} bands × {num_timestamps} timestamps)\")",
971
- # " ",
972
- # " # Add batch dimension: (1, num_channels, height, width)",
973
- # " input_data = np.expand_dims(input_channels, axis=0).astype(np.float32)",
974
- # " logging.info(f\"Final input shape for CNN: {input_data.shape}\")",
975
- # " ",
976
- # " # Run inference",
977
- # " output = model.run(None, {\"float_input\": input_data})[0]",
978
- # " logging.info(f\"Model output shape: {output.shape}\")",
979
- # " ",
980
- # " # UPDATED: Handle multi-class CNN output properly",
981
- # " if output.ndim == 4:",
982
- # " if output.shape[1] == 1:",
983
- # " # Single class output (regression or binary classification)",
984
- # " output_2d = output[0, 0]",
985
- # " logging.info(\"Single channel output detected\")",
986
- # " else:",
987
- # " # Multi-class output - convert logits/probabilities to class predictions",
988
- # " output_classes = np.argmax(output, axis=1) # Shape: (1, height, width)",
989
- # " output_2d = output_classes[0] # Shape: (height, width)",
990
- # " ",
991
- # " # Apply class merging: merge class 6 into class 3",
992
- # " output_2d = np.where(output_2d == 6, 3, output_2d)",
993
- # " ",
994
- # " logging.info(f\"Multi-class output processed. Original classes: {output.shape[1]}\")",
995
- # " logging.info(f\"Unique classes in output: {np.unique(output_2d)}\")",
996
- # " logging.info(f\"Class distribution: {np.bincount(output_2d.flatten())}\")",
997
- # " elif output.ndim == 3:",
998
- # " # Remove batch dimension",
999
- # " output_2d = output[0]",
1000
- # " logging.info(\"3D output detected, removed batch dimension\")",
1001
- # " else:",
1002
- # " # Handle other cases",
1003
- # " output_2d = np.squeeze(output)",
1004
- # " if output_2d.ndim != 2:",
1005
- # " logging.error(f\"Cannot process output shape: {output.shape}\")",
1006
- # " logging.error(f\"After squeeze: {output_2d.shape}\")",
1007
- # " raise ValueError(f\"Unexpected output shape after processing: {output_2d.shape}\")",
1008
- # " logging.info(\"Applied squeeze to output\")",
1009
- # " ",
1010
- # " # Ensure output is 2D",
1011
- # " if output_2d.ndim != 2:",
1012
- # " raise ValueError(f\"Final output must be 2D, got shape: {output_2d.shape}\")",
1013
- # " ",
1014
- # " # Determine output timestamp (use the latest timestamp)",
1015
- # " output_timestamp = time_coords[-1]",
1016
- # " ",
1017
- # " # Get spatial coordinates from reference array",
1018
- # " spatial_coords = {dim: reference_array.coords[dim] for dim in spatial_dims}",
1019
- # " ",
1020
- # " # Create output DataArray with appropriate data type",
1021
- # " # Use int32 for classification, float32 for regression",
1022
- # " is_multiclass = output.ndim == 4 and output.shape[1] > 1",
1023
- # " if is_multiclass:",
1024
- # " # Multi-class classification - use integer type",
1025
- # " output_dtype = np.int32",
1026
- # " output_type = 'classification'",
1027
- # " else:",
1028
- # " # Single output - use float type",
1029
- # " output_dtype = np.float32",
1030
- # " output_type = 'regression'",
1031
- # " ",
1032
- # " result = xr.DataArray(",
1033
- # " data=np.expand_dims(output_2d.astype(output_dtype), axis=0),",
1034
- # " dims=['time'] + spatial_dims,",
1035
- # " coords={",
1036
- # " 'time': [output_timestamp.values],",
1037
- # " spatial_dims[0]: spatial_coords[spatial_dims[0]].values,",
1038
- # " spatial_dims[1]: spatial_coords[spatial_dims[1]].values",
1039
- # " },",
1040
- # " attrs={",
1041
- # " 'description': 'CNN model prediction',",
1042
- # " }",
1043
- # " )",
1044
- # " ",
1045
- # " logging.info(f\"Final result shape: {result.shape}\")",
1046
- # " logging.info(f\"Final result data type: {result.dtype}\")",
1047
- # " logging.info(f\"Final result value range: {result.values.min()} to {result.values.max()}\")",
1048
- # ])
1049
-
1050
- # # Add postprocessing call if postprocessing exists
1051
- # if postprocessing_section:
1052
- # script_lines.extend([
1053
- # " # Apply postprocessing",
1054
- # " result = postprocessing(result)",
1055
- # " ",
1056
- # ])
1057
-
1058
- # # Single return statement at the end
1059
- # script_lines.append(" return result")
1060
-
1061
- # return "\n".join(script_lines)
1062
-
1063
-
1064
- @require_api_key
1065
- def generate_cnn_script(self, model_name: str, product: str, model_training_job_name: str, uid: str, preprocessing_code: Optional[str] = None, postprocessing_code: Optional[str] = None) -> str:
1066
- """
1067
- Generate Python inference script for CNN model with time-stacked bands.
1068
-
1069
- Args:
1070
- model_name: Name of the model
1071
- product: Product name
1072
- model_training_job_name: Training job name
1073
- uid: User ID
1074
- preprocessing_code: Preprocessing code
1075
- postprocessing_code: Postprocessing code
1076
- Returns:
1077
- str: Generated Python script content
1078
- """
1079
- import textwrap
1080
-
1081
- # Build preprocessing section with CONSISTENT 4-space indentation
1082
614
  preprocessing_section = ""
1083
615
  if preprocessing_code and preprocessing_code.strip():
1084
616
  clean_preprocessing = textwrap.dedent(preprocessing_code)
1085
617
  preprocessing_section = textwrap.indent(clean_preprocessing, ' ')
1086
618
 
1087
- # Build postprocessing section with CONSISTENT 4-space indentation
1088
619
  postprocessing_section = ""
1089
620
  if postprocessing_code and postprocessing_code.strip():
1090
621
  clean_postprocessing = textwrap.dedent(postprocessing_code)
1091
622
  postprocessing_section = textwrap.indent(clean_postprocessing, ' ')
1092
623
 
1093
- # Build the template WITHOUT dedenting the whole thing, so indentation is preserved
1094
624
  script_lines = [
1095
625
  "import logging",
1096
626
  "from io import BytesIO",
@@ -1107,7 +637,6 @@ class ModelManagement:
1107
637
  "",
1108
638
  ]
1109
639
 
1110
- # Add preprocessing validation function if preprocessing exists
1111
640
  if preprocessing_section:
1112
641
  script_lines.extend([
1113
642
  "def validate_preprocessing_output(data_arrays):",
@@ -1125,18 +654,12 @@ class ModelManagement:
1125
654
  " \"\"\"",
1126
655
  " import numpy as np",
1127
656
  " ",
1128
- " logging.info(\"=\" * 60)",
1129
- " logging.info(\"VALIDATING PREPROCESSING OUTPUT\")",
1130
- " logging.info(\"=\" * 60)",
1131
- " ",
1132
657
  " if not data_arrays:",
1133
658
  " raise ValueError(\"No data arrays provided from preprocessing\")",
1134
659
  " ",
1135
660
  " reference_shape = None",
1136
661
  " ",
1137
662
  " for i, data_array in enumerate(data_arrays):",
1138
- " logging.info(f\"Validating channel {i+1}/{len(data_arrays)}: {data_array.name}\")",
1139
- " ",
1140
663
  " # Check if it's an xarray DataArray",
1141
664
  " if not hasattr(data_array, 'dims') or not hasattr(data_array, 'coords'):",
1142
665
  " raise ValueError(f\"Channel {i+1} is not a valid xarray DataArray\")",
@@ -1153,12 +676,6 @@ class ModelManagement:
1153
676
  " if dim not in data_array.coords:",
1154
677
  " raise ValueError(f\"Channel {i+1} missing coordinate: {dim}\")",
1155
678
  " ",
1156
- " logging.info(f\" Coordinates: {list(data_array.coords.keys())}\")",
1157
- " ",
1158
- " # Check data type",
1159
- " data_values = data_array.values",
1160
- " logging.info(f\" Data type: {data_values.dtype}\")",
1161
- " ",
1162
679
  " # Check shape consistency",
1163
680
  " shape = data_array.shape",
1164
681
  " if reference_shape is None:",
@@ -1166,8 +683,6 @@ class ModelManagement:
1166
683
  " else:",
1167
684
  " if shape != reference_shape:",
1168
685
  " raise ValueError(f\"Channel {i+1} shape {shape} doesn't match reference {reference_shape}\")",
1169
- " ",
1170
- " logging.info(f\" Shape: {shape}\")",
1171
686
  " ",
1172
687
  " # Generate validation signature",
1173
688
  " signature_components = [",
@@ -1179,19 +694,10 @@ class ModelManagement:
1179
694
  " ",
1180
695
  " signature = \"★PRE_\" + \"_\".join(signature_components) + \"★\"",
1181
696
  " ",
1182
- " logging.info(\"-\" * 60)",
1183
- " logging.info(\"PREPROCESSING VALIDATION SUMMARY\")",
1184
- " logging.info(\"-\" * 60)",
1185
- " logging.info(f\"Channels validated: {len(data_arrays)}\")",
1186
- " logging.info(f\"Common shape: {reference_shape}\")",
1187
- " logging.info(f\"Validation signature: {signature}\")",
1188
- " logging.info(\"=\" * 60)",
1189
- " ",
1190
697
  " return signature",
1191
698
  "",
1192
699
  ])
1193
700
 
1194
- # Add postprocessing validation function if postprocessing exists
1195
701
  if postprocessing_section:
1196
702
  script_lines.extend([
1197
703
  "def validate_postprocessing_output(result_array):",
@@ -1209,10 +715,6 @@ class ModelManagement:
1209
715
  " \"\"\"",
1210
716
  " import numpy as np",
1211
717
  " ",
1212
- " logging.info(\"=\" * 60)",
1213
- " logging.info(\"VALIDATING POSTPROCESSING OUTPUT\")",
1214
- " logging.info(\"=\" * 60)",
1215
- " ",
1216
718
  " # Check if it's an xarray DataArray",
1217
719
  " if not hasattr(result_array, 'dims') or not hasattr(result_array, 'coords'):",
1218
720
  " raise ValueError(\"Postprocessing output is not a valid xarray DataArray\")",
@@ -1229,39 +731,22 @@ class ModelManagement:
1229
731
  " if dim not in result_array.coords:",
1230
732
  " raise ValueError(f\"Missing spatial coordinate: {dim}\")",
1231
733
  " ",
1232
- " logging.info(f\"Coordinates found: {list(result_array.coords.keys())}\")",
1233
- " ",
1234
- " # Check data type",
1235
- " data_values = result_array.values",
1236
- " logging.info(f\"Data type: {data_values.dtype}\")",
1237
- " ",
1238
734
  " # Check shape",
1239
735
  " shape = result_array.shape",
1240
- " logging.info(f\"Shape: {shape}\")",
1241
736
  " ",
1242
737
  " # Generate validation signature",
1243
738
  " signature_components = [",
1244
739
  " f\"T{shape[0]}\", # Time dimension",
1245
740
  " f\"S{shape[1]}x{shape[2]}\", # Spatial dimensions",
1246
- " f\"DT{data_values.dtype}\", # Data type",
741
+ " f\"DT{result_array.values.dtype}\", # Data type",
1247
742
  " ]",
1248
743
  " ",
1249
744
  " signature = \"★POST_\" + \"_\".join(signature_components) + \"★\"",
1250
745
  " ",
1251
- " logging.info(\"-\" * 60)",
1252
- " logging.info(\"POSTPROCESSING VALIDATION SUMMARY\")",
1253
- " logging.info(\"-\" * 60)",
1254
- " logging.info(f\"Final shape: {shape}\")",
1255
- " logging.info(f\"Final coordinates: {list(result_array.coords.keys())}\")",
1256
- " logging.info(f\"Data type: {data_values.dtype}\")",
1257
- " logging.info(f\"Validation signature: {signature}\")",
1258
- " logging.info(\"=\" * 60)",
1259
- " ",
1260
746
  " return signature",
1261
747
  "",
1262
748
  ])
1263
749
 
1264
- # Add preprocessing function definition BEFORE the main function
1265
750
  if preprocessing_section:
1266
751
  script_lines.extend([
1267
752
  "def preprocessing(array: Tuple[xr.DataArray, ...]) -> Tuple[xr.DataArray, ...]:",
@@ -1269,7 +754,6 @@ class ModelManagement:
1269
754
  "",
1270
755
  ])
1271
756
 
1272
- # Add postprocessing function definition BEFORE the main function
1273
757
  if postprocessing_section:
1274
758
  script_lines.extend([
1275
759
  "def postprocessing(array: xr.DataArray) -> xr.DataArray:",
@@ -1277,14 +761,13 @@ class ModelManagement:
1277
761
  "",
1278
762
  ])
1279
763
 
1280
- # Add the get_model function
1281
764
  script_lines.extend([
1282
765
  "def get_model():",
1283
- f" logging.info(\"Loading CNN model for {model_name}...\")",
766
+ f" logging.info(\"Loading CNN model for {virtual_dataset_name}...\")",
1284
767
  "",
1285
768
  " client = storage.Client()",
1286
- " bucket = client.get_bucket('terrakio-mass-requests')",
1287
- f" blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')",
769
+ f" bucket = client.get_bucket('{bucket_name}')",
770
+ f" blob = bucket.blob('{uid}/virtual_datasets/{virtual_dataset_name}/{virtual_dataset_name}.onnx')",
1288
771
  "",
1289
772
  " model = BytesIO()",
1290
773
  " blob.download_to_file(model)",
@@ -1293,7 +776,7 @@ class ModelManagement:
1293
776
  " session = InferenceSession(model.read(), providers=[\"CPUExecutionProvider\"])",
1294
777
  " return session",
1295
778
  "",
1296
- f"def {product}(*bands, model):",
779
+ f"def {virtual_product_name}(*bands, model):",
1297
780
  " logging.info(\"Start preparing CNN data with time-stacked bands\")",
1298
781
  " data_arrays = list(bands)",
1299
782
  " ",
@@ -1302,7 +785,6 @@ class ModelManagement:
1302
785
  " ",
1303
786
  ])
1304
787
 
1305
- # Add preprocessing call and validation if preprocessing exists
1306
788
  if preprocessing_section:
1307
789
  script_lines.extend([
1308
790
  " # Apply preprocessing",
@@ -1311,15 +793,12 @@ class ModelManagement:
1311
793
  " ",
1312
794
  " # Validate preprocessing output",
1313
795
  " preprocessing_signature = validate_preprocessing_output(data_arrays)",
1314
- " logging.info(f\"Preprocessing validation signature: {preprocessing_signature}\")",
1315
796
  " ",
1316
797
  ])
1317
798
 
1318
- # Continue with the rest of the processing logic
1319
799
  script_lines.extend([
1320
800
  " reference_array = data_arrays[0]",
1321
801
  " original_shape = reference_array.shape",
1322
- " logging.info(f\"Original shape: {original_shape}\")",
1323
802
  " ",
1324
803
  " # Get time coordinates - all bands should have the same time dimension",
1325
804
  " if 'time' not in reference_array.dims:",
@@ -1327,24 +806,19 @@ class ModelManagement:
1327
806
  " ",
1328
807
  " time_coords = reference_array.coords['time']",
1329
808
  " num_timestamps = len(time_coords)",
1330
- " logging.info(f\"Number of timestamps: {num_timestamps}\")",
1331
809
  " ",
1332
810
  " # Get spatial dimensions",
1333
811
  " spatial_dims = [dim for dim in reference_array.dims if dim != 'time']",
1334
812
  " height = reference_array.sizes[spatial_dims[0]] # assuming first spatial dim is height",
1335
813
  " width = reference_array.sizes[spatial_dims[1]] # assuming second spatial dim is width",
1336
- " logging.info(f\"Spatial dimensions: {height} x {width}\")",
1337
814
  " ",
1338
815
  " # Stack bands across time dimension",
1339
816
  " # Result will be: (num_bands * num_timestamps, height, width)",
1340
817
  " stacked_channels = []",
1341
818
  " ",
1342
819
  " for band_idx, data_array in enumerate(data_arrays):",
1343
- " logging.info(f\"Processing band {band_idx + 1}/{len(data_arrays)}\")",
1344
- " ",
1345
820
  " # Ensure consistent time coordinates across bands",
1346
821
  " if not np.array_equal(data_array.coords['time'].values, time_coords.values):",
1347
- " logging.warning(f\"Band {band_idx} has different time coordinates, aligning...\")",
1348
822
  " data_array = data_array.sel(time=time_coords, method='nearest')",
1349
823
  " ",
1350
824
  " # Extract values and ensure proper ordering (time, height, width)",
@@ -1363,23 +837,18 @@ class ModelManagement:
1363
837
  " # Stack all channels: (num_bands * num_timestamps, height, width)",
1364
838
  " input_channels = np.stack(stacked_channels, axis=0)",
1365
839
  " total_channels = len(data_arrays) * num_timestamps",
1366
- " logging.info(f\"Stacked channels shape: {input_channels.shape}\")",
1367
- " logging.info(f\"Total channels: {total_channels} ({len(data_arrays)} bands × {num_timestamps} timestamps)\")",
1368
840
  " ",
1369
841
  " # Add batch dimension: (1, num_channels, height, width)",
1370
842
  " input_data = np.expand_dims(input_channels, axis=0).astype(np.float32)",
1371
- " logging.info(f\"Final input shape for CNN: {input_data.shape}\")",
1372
843
  " ",
1373
844
  " # Run inference",
1374
845
  " output = model.run(None, {\"float_input\": input_data})[0]",
1375
- " logging.info(f\"Model output shape: {output.shape}\")",
1376
846
  " ",
1377
- " # UPDATED: Handle multi-class CNN output properly",
847
+ " # Handle multi-class CNN output properly",
1378
848
  " if output.ndim == 4:",
1379
849
  " if output.shape[1] == 1:",
1380
850
  " # Single class output (regression or binary classification)",
1381
851
  " output_2d = output[0, 0]",
1382
- " logging.info(\"Single channel output detected\")",
1383
852
  " else:",
1384
853
  " # Multi-class output - convert logits/probabilities to class predictions",
1385
854
  " output_classes = np.argmax(output, axis=1) # Shape: (1, height, width)",
@@ -1387,22 +856,14 @@ class ModelManagement:
1387
856
  " ",
1388
857
  " # Apply class merging: merge class 6 into class 3",
1389
858
  " output_2d = np.where(output_2d == 6, 3, output_2d)",
1390
- " ",
1391
- " logging.info(f\"Multi-class output processed. Original classes: {output.shape[1]}\")",
1392
- " logging.info(f\"Unique classes in output: {np.unique(output_2d)}\")",
1393
- " logging.info(f\"Class distribution: {np.bincount(output_2d.flatten())}\")",
1394
859
  " elif output.ndim == 3:",
1395
860
  " # Remove batch dimension",
1396
861
  " output_2d = output[0]",
1397
- " logging.info(\"3D output detected, removed batch dimension\")",
1398
862
  " else:",
1399
863
  " # Handle other cases",
1400
864
  " output_2d = np.squeeze(output)",
1401
865
  " if output_2d.ndim != 2:",
1402
- " logging.error(f\"Cannot process output shape: {output.shape}\")",
1403
- " logging.error(f\"After squeeze: {output_2d.shape}\")",
1404
866
  " raise ValueError(f\"Unexpected output shape after processing: {output_2d.shape}\")",
1405
- " logging.info(\"Applied squeeze to output\")",
1406
867
  " ",
1407
868
  " # Ensure output is 2D",
1408
869
  " if output_2d.ndim != 2:",
@@ -1420,11 +881,9 @@ class ModelManagement:
1420
881
  " if is_multiclass:",
1421
882
  " # Multi-class classification - use integer type",
1422
883
  " output_dtype = np.int32",
1423
- " output_type = 'classification'",
1424
884
  " else:",
1425
885
  " # Single output - use float type",
1426
886
  " output_dtype = np.float32",
1427
- " output_type = 'regression'",
1428
887
  " ",
1429
888
  " result = xr.DataArray(",
1430
889
  " data=np.expand_dims(output_2d.astype(output_dtype), axis=0),",
@@ -1438,13 +897,8 @@ class ModelManagement:
1438
897
  " 'description': 'CNN model prediction',",
1439
898
  " }",
1440
899
  " )",
1441
- " ",
1442
- " logging.info(f\"Final result shape: {result.shape}\")",
1443
- " logging.info(f\"Final result data type: {result.dtype}\")",
1444
- " logging.info(f\"Final result value range: {result.values.min()} to {result.values.max()}\")",
1445
900
  ])
1446
901
 
1447
- # Add postprocessing call and validation if postprocessing exists
1448
902
  if postprocessing_section:
1449
903
  script_lines.extend([
1450
904
  " # Apply postprocessing",
@@ -1452,21 +906,223 @@ class ModelManagement:
1452
906
  " ",
1453
907
  " # Validate postprocessing output",
1454
908
  " postprocessing_signature = validate_postprocessing_output(result)",
1455
- " logging.info(f\"Postprocessing validation signature: {postprocessing_signature}\")",
1456
909
  " ",
1457
910
  ])
1458
911
 
1459
- # Single return statement at the end
1460
912
  script_lines.append(" return result")
1461
913
 
1462
914
  return "\n".join(script_lines)
1463
915
 
1464
- @require_api_key
1465
- def _upload_script_to_bucket(self, script_content: str, script_name: str, model_training_job_name: str, uid: str):
1466
- """Upload the generated script to Google Cloud Storage"""
916
+ def _parse_processing_script(self, script_path: str) -> Tuple[Optional[str], Optional[str]]:
917
+ """
918
+ Parse a Python file and extract preprocessing and postprocessing function bodies.
919
+
920
+ Args:
921
+ script_path: Path to the Python file containing processing functions
922
+
923
+ Returns:
924
+ Tuple of (preprocessing_code, postprocessing_code) where each can be None
925
+ """
926
+ try:
927
+ with open(script_path, 'r', encoding='utf-8') as f:
928
+ script_content = f.read()
929
+ except FileNotFoundError:
930
+ raise FileNotFoundError(f"Processing script not found: {script_path}")
931
+ except Exception as e:
932
+ raise ValueError(f"Error reading processing script: {e}")
933
+
934
+ if not script_content.strip():
935
+ self._client.logger.info(f"Processing script {script_path} is empty")
936
+ return None, None
937
+
938
+ try:
939
+ tree = ast.parse(script_content)
940
+ except SyntaxError as e:
941
+ raise ValueError(f"Syntax error in processing script: {e}")
942
+
943
+ preprocessing_code = None
944
+ postprocessing_code = None
945
+
946
+ function_names = []
947
+ for node in ast.walk(tree):
948
+ if isinstance(node, ast.FunctionDef):
949
+ function_names.append(node.name)
950
+ if node.name == 'preprocessing':
951
+ preprocessing_code = self._extract_function_body(script_content, node)
952
+ elif node.name == 'postprocessing':
953
+ postprocessing_code = self._extract_function_body(script_content, node)
954
+
955
+ if not function_names:
956
+ self._client.logger.warning(f"No functions found in processing script: {script_path}")
957
+ else:
958
+ found_functions = [name for name in function_names if name in ['preprocessing', 'postprocessing']]
959
+ if found_functions:
960
+ self._client.logger.info(f"Found processing functions: {found_functions}")
961
+ else:
962
+ self._client.logger.warning(f"No 'preprocessing' or 'postprocessing' functions found in {script_path}. "
963
+ f"Available functions: {function_names}")
964
+
965
+ return preprocessing_code, postprocessing_code
966
+
967
+ def _extract_function_body(self, script_content: str, func_node: ast.FunctionDef) -> str:
968
+ """Extract the body of a function from the script content."""
969
+ lines = script_content.split('\n')
970
+
971
+ start_line = func_node.lineno - 1
972
+ end_line = func_node.end_lineno - 1 if hasattr(func_node, 'end_lineno') else len(lines) - 1
973
+
974
+ body_lines = []
975
+ for i in range(start_line + 1, end_line + 1):
976
+ if i < len(lines):
977
+ body_lines.append(lines[i])
978
+
979
+ if not body_lines:
980
+ return ""
981
+
982
+ body_text = '\n'.join(body_lines)
983
+ cleaned_body = textwrap.dedent(body_text).strip()
984
+
985
+ if not cleaned_body or cleaned_body in ['pass', 'return', 'return None']:
986
+ return ""
987
+
988
+ return cleaned_body
989
+
990
def _convert_model_to_onnx(self, model, input_shape: Optional[Tuple[int, ...]] = None, model_type: Optional[str] = None) -> Tuple[bytes, str]:
    """
    Convert a model to ONNX format and return it together with its model type.

    Args:
        model: The model object (PyTorch nn.Module, scikit-learn estimator or
            onnxruntime InferenceSession)
        input_shape: Shape of input data, forwarded to the PyTorch and
            scikit-learn converters
        model_type: Type of model (neural_network, random_forest). Only read
            for InferenceSession models, where it is required and returned
            unchanged

    Returns:
        Tuple of (ONNX model as bytes, model type). The type is
        "neural_network" for PyTorch models, "random_forest" for scikit-learn
        estimators, and the supplied model_type for InferenceSession models.

    Raises:
        ValueError: If the model type is not supported, or model_type is
            missing for an InferenceSession
        ImportError: If a scikit-learn model is passed but the ONNX conversion
            libraries are not installed
    """
    # torch / BaseEstimator are None when the optional dependency is missing,
    # so the availability flag must be tested before using them in isinstance().
    if TORCH_AVAILABLE and isinstance(model, torch.nn.Module):
        return self._convert_pytorch_to_onnx(model, input_shape), "neural_network"

    if SKL2ONNX_AVAILABLE and isinstance(model, BaseEstimator):
        return self._convert_sklearn_to_onnx(model, input_shape), "random_forest"

    if isinstance(model, ort.InferenceSession):
        if model_type is None:
            raise ValueError(
                "For ONNX InferenceSession models, you must specify the 'model_type' parameter. "
                "Currently 'neural_network' and 'random_forest' are supported. "
                "Example: model_type='random_forest' or model_type='neural_network'"
            )
        # NOTE(review): SerializeToString() is a method of onnx ModelProto; the
        # documented onnxruntime InferenceSession API does not appear to list
        # it -- confirm this call works against the pinned onnxruntime version.
        return model.SerializeToString(), model_type

    # scikit-learn can be installed while skl2onnx is not; in that case the
    # model is a real estimator that simply cannot be converted. Detect it by
    # the packages in its class hierarchy, since BaseEstimator is unavailable.
    # (No equivalent check is needed for PyTorch: an nn.Module instance cannot
    # exist unless torch is importable.)
    model_packages = {cls.__module__.split('.')[0] for cls in type(model).__mro__}
    if not SKL2ONNX_AVAILABLE and 'sklearn' in model_packages:
        raise ImportError("skl2onnx is not installed. Please install it with: pip install skl2onnx")

    unsupported_name = type(model).__name__
    raise ValueError(
        f"Unsupported model type: {unsupported_name}. "
        "Supported types: PyTorch nn.Module, sklearn BaseEstimator, onnxruntime InferenceSession"
    )
1023
+
1024
+ def _convert_pytorch_to_onnx(self, model, input_shape: Tuple[int, ...]) -> bytes:
1025
+ try:
1026
+ model.eval()
1027
+ dummy_input = torch.randn(input_shape)
1028
+ onnx_buffer = BytesIO()
1029
+
1030
+ if len(input_shape) == 4:
1031
+ dynamic_axes = {
1032
+ 'float_input': {
1033
+ 0: 'batch_size',
1034
+ 2: 'height',
1035
+ 3: 'width'
1036
+ }
1037
+ }
1038
+
1039
+ elif len(input_shape) == 5:
1040
+ dynamic_axes = {
1041
+ 'float_input': {
1042
+ 0: 'batch_size',
1043
+ 3: 'height',
1044
+ 4: 'width'
1045
+ }
1046
+ }
1047
+
1048
+ else:
1049
+ dynamic_axes = {
1050
+ 'float_input': {
1051
+ 0: 'batch_size'
1052
+ }
1053
+ }
1054
+
1055
+ torch.onnx.export(
1056
+ model,
1057
+ dummy_input,
1058
+ onnx_buffer,
1059
+ input_names=['float_input'],
1060
+ dynamic_axes=dynamic_axes
1061
+ )
1062
+
1063
+ return onnx_buffer.getvalue()
1064
+ except Exception as e:
1065
+ raise ValueError(f"Failed to convert PyTorch model to ONNX: {str(e)}")
1066
+
1067
+ def _convert_sklearn_to_onnx(self, model, input_shape: Tuple[int, ...]) -> bytes:
1068
+ """
1069
+ Convert scikit-learn model(assume it is a random forest model) to ONNX format.
1070
+
1071
+ Args:
1072
+ model: The scikit-learn model object
1073
+ input_shape: Shape of input data (required)
1074
+
1075
+ Returns:
1076
+ bytes: ONNX model as bytes
1077
+
1078
+ Raises:
1079
+ ValueError: If conversion fails
1080
+ """
1081
+ self._client.logger.info(f"Converting random forest model to ONNX...")
1082
+
1083
+ try:
1084
+ initial_type = [('float_input', FloatTensorType(input_shape))]
1085
+ onnx_model = convert_sklearn(model, initial_types=initial_type)
1086
+ return onnx_model.SerializeToString()
1087
+ except Exception as e:
1088
+ raise ValueError(f"Failed to convert scikit-learn model to ONNX: {str(e)}")
1467
1089
 
1468
- client = storage.Client()
1469
- bucket = client.get_bucket('terrakio-mass-requests')
1470
- blob = bucket.blob(f'{uid}/{model_training_job_name}/inference_scripts/{script_name}')
1471
- blob.upload_from_string(script_content, content_type='text/plain')
1472
- logging.info(f"Script uploaded successfully to {uid}/{model_training_job_name}/inference_scripts/{script_name}")
1090
@require_api_key
def train_model(
    self,
    model_name: str,
    training_dataset: str,
    task_type: str,
    model_category: str,
    architecture: str,
    region: str,
    hyperparameters: dict = None
) -> dict:
    """
    Submit a model training request to the external model training API.

    All arguments are forwarded unchanged as the JSON body of a POST request
    to the "/train_model" endpoint. ``hyperparameters`` is sent as given,
    including when it is None.

    Args:
        model_name (str): The name of the model to train.
        training_dataset (str): The training dataset identifier.
        task_type (str): The type of ML task (e.g., regression, classification).
        model_category (str): The category of model (e.g., random_forest).
        architecture (str): The model architecture.
        region (str): The region identifier.
        hyperparameters (dict, optional): Additional hyperparameters for training.

    Returns:
        dict: The response from the model training API.

    Raises:
        Whatever the client's request helper raises for a failed request;
        nothing is caught here.
    """
    request_body = dict(
        model_name=model_name,
        training_dataset=training_dataset,
        task_type=task_type,
        model_category=model_category,
        architecture=architecture,
        region=region,
        hyperparameters=hyperparameters,
    )
    response = self._client._terrakio_request("POST", "/train_model", json=request_body)
    return response