terrakio-core: terrakio_core-0.3.9-py3-none-any.whl → terrakio_core-0.4.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of terrakio-core might be problematic. See the registry's advisory notice for this release for more details.
- terrakio_core/__init__.py +1 -1
- terrakio_core/async_client.py +21 -2
- terrakio_core/client.py +101 -5
- terrakio_core/convenience_functions/convenience_functions.py +280 -29
- terrakio_core/endpoints/mass_stats.py +71 -16
- terrakio_core/endpoints/model_management.py +424 -217
- terrakio_core/endpoints/user_management.py +5 -5
- terrakio_core/sync_client.py +106 -185
- {terrakio_core-0.3.9.dist-info → terrakio_core-0.4.2.dist-info}/METADATA +1 -1
- {terrakio_core-0.3.9.dist-info → terrakio_core-0.4.2.dist-info}/RECORD +12 -12
- {terrakio_core-0.3.9.dist-info → terrakio_core-0.4.2.dist-info}/WHEEL +0 -0
- {terrakio_core-0.3.9.dist-info → terrakio_core-0.4.2.dist-info}/top_level.txt +0 -0
|
@@ -3,10 +3,11 @@ import json
|
|
|
3
3
|
import time
|
|
4
4
|
import textwrap
|
|
5
5
|
import logging
|
|
6
|
-
from typing import Dict, Any, Union, Tuple
|
|
6
|
+
from typing import Dict, Any, Union, Tuple, Optional
|
|
7
7
|
from io import BytesIO
|
|
8
8
|
import numpy as np
|
|
9
9
|
from google.cloud import storage
|
|
10
|
+
import ast
|
|
10
11
|
from ..helper.decorators import require_token, require_api_key, require_auth
|
|
11
12
|
TORCH_AVAILABLE = False
|
|
12
13
|
SKL2ONNX_AVAILABLE = False
|
|
@@ -30,6 +31,7 @@ except ImportError:
|
|
|
30
31
|
from io import BytesIO
|
|
31
32
|
from typing import Tuple
|
|
32
33
|
|
|
34
|
+
|
|
33
35
|
class ModelManagement:
|
|
34
36
|
def __init__(self, client):
|
|
35
37
|
self._client = client
|
|
@@ -347,9 +349,10 @@ class ModelManagement:
|
|
|
347
349
|
|
|
348
350
|
except Exception as e:
|
|
349
351
|
raise ValueError(f"Failed to convert scikit-learn model {model_name} to ONNX: {str(e)}")
|
|
350
|
-
|
|
352
|
+
|
|
353
|
+
|
|
351
354
|
@require_api_key
|
|
352
|
-
async def upload_and_deploy_cnn_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None):
|
|
355
|
+
async def upload_and_deploy_cnn_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None, processing_script_path: Optional[str] = None):
|
|
353
356
|
"""
|
|
354
357
|
Upload a CNN model to the bucket and deploy it.
|
|
355
358
|
|
|
@@ -361,6 +364,7 @@ class ModelManagement:
|
|
|
361
364
|
input_expression: Input expression for the dataset
|
|
362
365
|
dates_iso8601: List of dates in ISO8601 format
|
|
363
366
|
input_shape: Shape of input data for ONNX conversion (required for PyTorch models)
|
|
367
|
+
processing_script_path: Path to the processing script, if not provided, no processing will be done
|
|
364
368
|
|
|
365
369
|
Raises:
|
|
366
370
|
APIError: If the API request fails
|
|
@@ -369,7 +373,8 @@ class ModelManagement:
|
|
|
369
373
|
"""
|
|
370
374
|
await self.upload_model(model=model, model_name=model_name, input_shape=input_shape)
|
|
371
375
|
# so the uploading process is kinda similar, but the deployment step is kinda different
|
|
372
|
-
|
|
376
|
+
# we should pass the processing script path to the deploy cnn model function
|
|
377
|
+
await self.deploy_cnn_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601, processing_script_path=processing_script_path)
|
|
373
378
|
|
|
374
379
|
@require_api_key
|
|
375
380
|
async def upload_and_deploy_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None):
|
|
@@ -475,6 +480,88 @@ class ModelManagement:
|
|
|
475
480
|
padding=0
|
|
476
481
|
)
|
|
477
482
|
|
|
483
|
+
def _parse_processing_script(self, script_path: str) -> Tuple[Optional[str], Optional[str]]:
|
|
484
|
+
"""
|
|
485
|
+
Parse a Python file and extract preprocessing and postprocessing function bodies.
|
|
486
|
+
|
|
487
|
+
Args:
|
|
488
|
+
script_path: Path to the Python file containing processing functions
|
|
489
|
+
|
|
490
|
+
Returns:
|
|
491
|
+
Tuple of (preprocessing_code, postprocessing_code) where each can be None
|
|
492
|
+
"""
|
|
493
|
+
try:
|
|
494
|
+
with open(script_path, 'r', encoding='utf-8') as f:
|
|
495
|
+
script_content = f.read()
|
|
496
|
+
except FileNotFoundError:
|
|
497
|
+
raise FileNotFoundError(f"Processing script not found: {script_path}")
|
|
498
|
+
except Exception as e:
|
|
499
|
+
raise ValueError(f"Error reading processing script: {e}")
|
|
500
|
+
|
|
501
|
+
# Handle empty file
|
|
502
|
+
if not script_content.strip():
|
|
503
|
+
self._client.logger.info(f"Processing script {script_path} is empty")
|
|
504
|
+
return None, None
|
|
505
|
+
|
|
506
|
+
try:
|
|
507
|
+
# Parse the Python file
|
|
508
|
+
tree = ast.parse(script_content)
|
|
509
|
+
except SyntaxError as e:
|
|
510
|
+
raise ValueError(f"Syntax error in processing script: {e}")
|
|
511
|
+
|
|
512
|
+
preprocessing_code = None
|
|
513
|
+
postprocessing_code = None
|
|
514
|
+
|
|
515
|
+
# Find function definitions
|
|
516
|
+
function_names = []
|
|
517
|
+
for node in ast.walk(tree):
|
|
518
|
+
if isinstance(node, ast.FunctionDef):
|
|
519
|
+
function_names.append(node.name)
|
|
520
|
+
if node.name == 'preprocessing':
|
|
521
|
+
preprocessing_code = self._extract_function_body(script_content, node)
|
|
522
|
+
elif node.name == 'postprocessing':
|
|
523
|
+
postprocessing_code = self._extract_function_body(script_content, node)
|
|
524
|
+
|
|
525
|
+
# Log what was found for debugging
|
|
526
|
+
if not function_names:
|
|
527
|
+
self._client.logger.warning(f"No functions found in processing script: {script_path}")
|
|
528
|
+
else:
|
|
529
|
+
found_functions = [name for name in function_names if name in ['preprocessing', 'postprocessing']]
|
|
530
|
+
if found_functions:
|
|
531
|
+
self._client.logger.info(f"Found processing functions: {found_functions}")
|
|
532
|
+
else:
|
|
533
|
+
self._client.logger.warning(f"No 'preprocessing' or 'postprocessing' functions found in {script_path}. "
|
|
534
|
+
f"Available functions: {function_names}")
|
|
535
|
+
|
|
536
|
+
return preprocessing_code, postprocessing_code
|
|
537
|
+
|
|
538
|
+
def _extract_function_body(self, script_content: str, func_node: ast.FunctionDef) -> str:
|
|
539
|
+
"""Extract the body of a function from the script content."""
|
|
540
|
+
lines = script_content.split('\n')
|
|
541
|
+
|
|
542
|
+
# AST line numbers are 1-indexed, convert to 0-indexed
|
|
543
|
+
start_line = func_node.lineno - 1 # This is the 'def' line (0-indexed)
|
|
544
|
+
end_line = func_node.end_lineno - 1 if hasattr(func_node, 'end_lineno') else len(lines) - 1
|
|
545
|
+
|
|
546
|
+
# Extract ONLY the body lines (skip the def line entirely)
|
|
547
|
+
body_lines = []
|
|
548
|
+
for i in range(start_line + 1, end_line + 1): # +1 to skip the 'def' line
|
|
549
|
+
if i < len(lines):
|
|
550
|
+
body_lines.append(lines[i])
|
|
551
|
+
|
|
552
|
+
if not body_lines:
|
|
553
|
+
return ""
|
|
554
|
+
|
|
555
|
+
# Join and dedent to remove function-level indentation
|
|
556
|
+
body_text = '\n'.join(body_lines)
|
|
557
|
+
cleaned_body = textwrap.dedent(body_text).strip()
|
|
558
|
+
|
|
559
|
+
# Handle empty function body
|
|
560
|
+
if not cleaned_body or cleaned_body in ['pass', 'return', 'return None']:
|
|
561
|
+
return ""
|
|
562
|
+
|
|
563
|
+
return cleaned_body
|
|
564
|
+
|
|
478
565
|
@require_api_key
|
|
479
566
|
async def deploy_cnn_model(
|
|
480
567
|
self,
|
|
@@ -483,7 +570,8 @@ class ModelManagement:
|
|
|
483
570
|
model_name: str,
|
|
484
571
|
input_expression: str,
|
|
485
572
|
model_training_job_name: str,
|
|
486
|
-
dates_iso8601: list
|
|
573
|
+
dates_iso8601: list,
|
|
574
|
+
processing_script_path: Optional[str] = None
|
|
487
575
|
) -> Dict[str, Any]:
|
|
488
576
|
"""
|
|
489
577
|
Deploy a CNN model by generating inference script and creating dataset.
|
|
@@ -495,7 +583,7 @@ class ModelManagement:
|
|
|
495
583
|
input_expression: Input expression for the dataset
|
|
496
584
|
model_training_job_name: Name of the training job
|
|
497
585
|
dates_iso8601: List of dates in ISO8601 format
|
|
498
|
-
|
|
586
|
+
processing_script_path: Path to the processing script, if not provided, no processing will be done
|
|
499
587
|
Returns:
|
|
500
588
|
dict: Response from the deployment process
|
|
501
589
|
|
|
@@ -506,8 +594,31 @@ class ModelManagement:
|
|
|
506
594
|
user_info = await self._client.auth.get_user_info()
|
|
507
595
|
uid = user_info["uid"]
|
|
508
596
|
|
|
597
|
+
preprocessing_code, postprocessing_code = None, None
|
|
598
|
+
if processing_script_path:
|
|
599
|
+
# if there is a function that is being passed in
|
|
600
|
+
try:
|
|
601
|
+
preprocessing_code, postprocessing_code = self._parse_processing_script(processing_script_path)
|
|
602
|
+
if preprocessing_code:
|
|
603
|
+
self._client.logger.info(f"Using custom preprocessing from: {processing_script_path}")
|
|
604
|
+
if postprocessing_code:
|
|
605
|
+
self._client.logger.info(f"Using custom postprocessing from: {processing_script_path}")
|
|
606
|
+
if not preprocessing_code and not postprocessing_code:
|
|
607
|
+
self._client.logger.warning(f"No preprocessing or postprocessing functions found in {processing_script_path}")
|
|
608
|
+
self._client.logger.info("Deployment will continue without custom processing")
|
|
609
|
+
except Exception as e:
|
|
610
|
+
raise ValueError(f"Failed to load processing script: {str(e)}")
|
|
611
|
+
# so we already have the preprocessing code and the post processing code, I need to pass them to the generate cnn script function
|
|
509
612
|
# Generate and upload script
|
|
510
|
-
|
|
613
|
+
# Build preprocessing section with CONSISTENT 8-space indentation
|
|
614
|
+
preprocessing_section = ""
|
|
615
|
+
if preprocessing_code and preprocessing_code.strip():
|
|
616
|
+
# First dedent the preprocessing code to remove any existing indentation
|
|
617
|
+
clean_preprocessing = preprocessing_code
|
|
618
|
+
# Then add consistent 8-space indentation to match the template
|
|
619
|
+
preprocessing_section = f"""{textwrap.indent(clean_preprocessing, '')}""" # 8 spaces
|
|
620
|
+
print(preprocessing_section)
|
|
621
|
+
script_content = self.generate_cnn_script(model_name, product, model_training_job_name, uid, preprocessing_code, postprocessing_code)
|
|
511
622
|
script_name = f"{product}.py"
|
|
512
623
|
self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
|
|
513
624
|
# Create dataset
|
|
@@ -536,109 +647,109 @@ class ModelManagement:
|
|
|
536
647
|
str: Generated Python script content
|
|
537
648
|
"""
|
|
538
649
|
return textwrap.dedent(f'''
|
|
539
|
-
|
|
540
|
-
|
|
650
|
+
import logging
|
|
651
|
+
from io import BytesIO
|
|
541
652
|
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
653
|
+
import numpy as np
|
|
654
|
+
import pandas as pd
|
|
655
|
+
import xarray as xr
|
|
656
|
+
from google.cloud import storage
|
|
657
|
+
from onnxruntime import InferenceSession
|
|
547
658
|
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
659
|
+
logging.basicConfig(
|
|
660
|
+
level=logging.INFO
|
|
661
|
+
)
|
|
551
662
|
|
|
552
|
-
|
|
553
|
-
|
|
663
|
+
def get_model():
|
|
664
|
+
logging.info("Loading model for {model_name}...")
|
|
554
665
|
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
666
|
+
client = storage.Client()
|
|
667
|
+
bucket = client.get_bucket('terrakio-mass-requests')
|
|
668
|
+
blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')
|
|
558
669
|
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
670
|
+
model = BytesIO()
|
|
671
|
+
blob.download_to_file(model)
|
|
672
|
+
model.seek(0)
|
|
562
673
|
|
|
563
|
-
|
|
564
|
-
|
|
674
|
+
session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
|
|
675
|
+
return session
|
|
565
676
|
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
else:
|
|
580
|
-
years = [pd.to_datetime(t).year for t in time_coords.values]
|
|
581
|
-
unique_years = set(years)
|
|
582
|
-
|
|
583
|
-
if len(unique_years) == 1:
|
|
584
|
-
year = list(unique_years)[0]
|
|
585
|
-
output_timestamp = pd.Timestamp(f"{{year}}-01-01")
|
|
586
|
-
else:
|
|
587
|
-
latest_year = max(unique_years)
|
|
588
|
-
output_timestamp = pd.Timestamp(f"{{latest_year}}-01-01")
|
|
677
|
+
def {product}(*bands, model):
|
|
678
|
+
logging.info("start preparing data")
|
|
679
|
+
|
|
680
|
+
data_arrays = list(bands)
|
|
681
|
+
|
|
682
|
+
reference_array = data_arrays[0]
|
|
683
|
+
original_shape = reference_array.shape
|
|
684
|
+
logging.info(f"Original shape: {{original_shape}}")
|
|
685
|
+
|
|
686
|
+
if 'time' in reference_array.dims:
|
|
687
|
+
time_coords = reference_array.coords['time']
|
|
688
|
+
if len(time_coords) == 1:
|
|
689
|
+
output_timestamp = time_coords[0]
|
|
589
690
|
else:
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
logging.info(f"Averaged band from {{data_array.shape}} to {{averaged_band.shape}}")
|
|
691
|
+
years = [pd.to_datetime(t).year for t in time_coords.values]
|
|
692
|
+
unique_years = set(years)
|
|
693
|
+
|
|
694
|
+
if len(unique_years) == 1:
|
|
695
|
+
year = list(unique_years)[0]
|
|
696
|
+
output_timestamp = pd.Timestamp(f"{{year}}-01-01")
|
|
597
697
|
else:
|
|
598
|
-
|
|
599
|
-
|
|
698
|
+
latest_year = max(unique_years)
|
|
699
|
+
output_timestamp = pd.Timestamp(f"{{latest_year}}-01-01")
|
|
700
|
+
else:
|
|
701
|
+
output_timestamp = pd.Timestamp("1970-01-01")
|
|
600
702
|
|
|
601
|
-
|
|
602
|
-
|
|
703
|
+
averaged_bands = []
|
|
704
|
+
for data_array in data_arrays:
|
|
705
|
+
if 'time' in data_array.dims:
|
|
706
|
+
averaged_band = np.mean(data_array.values, axis=0)
|
|
707
|
+
logging.info(f"Averaged band from {{data_array.shape}} to {{averaged_band.shape}}")
|
|
708
|
+
else:
|
|
709
|
+
averaged_band = data_array.values
|
|
710
|
+
logging.info(f"No time dimension, shape: {{averaged_band.shape}}")
|
|
603
711
|
|
|
604
|
-
|
|
712
|
+
flattened_band = averaged_band.reshape(-1, 1)
|
|
713
|
+
averaged_bands.append(flattened_band)
|
|
605
714
|
|
|
606
|
-
|
|
715
|
+
input_data = np.hstack(averaged_bands)
|
|
607
716
|
|
|
608
|
-
|
|
717
|
+
logging.info(f"Final input shape: {{input_data.shape}}")
|
|
609
718
|
|
|
610
|
-
|
|
719
|
+
output = model.run(None, {{"float_input": input_data.astype(np.float32)}})[0]
|
|
611
720
|
|
|
612
|
-
|
|
613
|
-
spatial_shape = original_shape[1:]
|
|
614
|
-
else:
|
|
615
|
-
spatial_shape = original_shape
|
|
721
|
+
logging.info(f"Model output shape: {{output.shape}}")
|
|
616
722
|
|
|
617
|
-
|
|
723
|
+
if len(original_shape) >= 3:
|
|
724
|
+
spatial_shape = original_shape[1:]
|
|
725
|
+
else:
|
|
726
|
+
spatial_shape = original_shape
|
|
618
727
|
|
|
619
|
-
|
|
728
|
+
output_reshaped = output.reshape(spatial_shape)
|
|
620
729
|
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
730
|
+
output_with_time = np.expand_dims(output_reshaped, axis=0)
|
|
731
|
+
|
|
732
|
+
if 'time' in reference_array.dims:
|
|
733
|
+
spatial_dims = [dim for dim in reference_array.dims if dim != 'time']
|
|
734
|
+
spatial_coords = {{dim: reference_array.coords[dim] for dim in spatial_dims if dim in reference_array.coords}}
|
|
735
|
+
else:
|
|
736
|
+
spatial_dims = list(reference_array.dims)
|
|
737
|
+
spatial_coords = dict(reference_array.coords)
|
|
738
|
+
|
|
739
|
+
result = xr.DataArray(
|
|
740
|
+
data=output_with_time.astype(np.float32),
|
|
741
|
+
dims=['time'] + list(spatial_dims),
|
|
742
|
+
coords={{
|
|
743
|
+
'time': [output_timestamp.values],
|
|
744
|
+
'y': spatial_coords['y'].values,
|
|
745
|
+
'x': spatial_coords['x'].values
|
|
746
|
+
}}
|
|
747
|
+
)
|
|
748
|
+
return result
|
|
638
749
|
''').strip()
|
|
639
750
|
|
|
640
751
|
@require_api_key
|
|
641
|
-
def generate_cnn_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
|
|
752
|
+
def generate_cnn_script(self, model_name: str, product: str, model_training_job_name: str, uid: str, preprocessing_code: Optional[str] = None, postprocessing_code: Optional[str] = None) -> str:
|
|
642
753
|
"""
|
|
643
754
|
Generate Python inference script for CNN model with time-stacked bands.
|
|
644
755
|
|
|
@@ -647,137 +758,233 @@ class ModelManagement:
|
|
|
647
758
|
product: Product name
|
|
648
759
|
model_training_job_name: Training job name
|
|
649
760
|
uid: User ID
|
|
650
|
-
|
|
761
|
+
preprocessing_code: Preprocessing code
|
|
762
|
+
postprocessing_code: Postprocessing code
|
|
651
763
|
Returns:
|
|
652
764
|
str: Generated Python script content
|
|
653
765
|
"""
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
)
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
766
|
+
import textwrap
|
|
767
|
+
|
|
768
|
+
# Build preprocessing section with CONSISTENT 4-space indentation
|
|
769
|
+
preprocessing_section = ""
|
|
770
|
+
if preprocessing_code and preprocessing_code.strip():
|
|
771
|
+
clean_preprocessing = textwrap.dedent(preprocessing_code)
|
|
772
|
+
preprocessing_section = textwrap.indent(clean_preprocessing, ' ')
|
|
773
|
+
|
|
774
|
+
# Build postprocessing section with CONSISTENT 4-space indentation
|
|
775
|
+
postprocessing_section = ""
|
|
776
|
+
if postprocessing_code and postprocessing_code.strip():
|
|
777
|
+
clean_postprocessing = textwrap.dedent(postprocessing_code)
|
|
778
|
+
postprocessing_section = textwrap.indent(clean_postprocessing, ' ')
|
|
779
|
+
|
|
780
|
+
# Build the template WITHOUT dedenting the whole thing, so indentation is preserved
|
|
781
|
+
script_lines = [
|
|
782
|
+
"import logging",
|
|
783
|
+
"from io import BytesIO",
|
|
784
|
+
"import numpy as np",
|
|
785
|
+
"import pandas as pd",
|
|
786
|
+
"import xarray as xr",
|
|
787
|
+
"from google.cloud import storage",
|
|
788
|
+
"from onnxruntime import InferenceSession",
|
|
789
|
+
"from typing import Tuple",
|
|
790
|
+
"",
|
|
791
|
+
"logging.basicConfig(",
|
|
792
|
+
" level=logging.INFO",
|
|
793
|
+
")",
|
|
794
|
+
"",
|
|
795
|
+
]
|
|
796
|
+
|
|
797
|
+
# Add preprocessing function definition BEFORE the main function
|
|
798
|
+
if preprocessing_section:
|
|
799
|
+
script_lines.extend([
|
|
800
|
+
"def preprocessing(array: Tuple[xr.DataArray, ...]) -> Tuple[xr.DataArray, ...]:",
|
|
801
|
+
preprocessing_section,
|
|
802
|
+
"",
|
|
803
|
+
])
|
|
804
|
+
|
|
805
|
+
# Add postprocessing function definition BEFORE the main function
|
|
806
|
+
if postprocessing_section:
|
|
807
|
+
script_lines.extend([
|
|
808
|
+
"def postprocessing(array: xr.DataArray) -> xr.DataArray:",
|
|
809
|
+
postprocessing_section,
|
|
810
|
+
"",
|
|
811
|
+
])
|
|
812
|
+
|
|
813
|
+
# Add the get_model function
|
|
814
|
+
script_lines.extend([
|
|
815
|
+
"def get_model():",
|
|
816
|
+
f" logging.info(\"Loading CNN model for {model_name}...\")",
|
|
817
|
+
"",
|
|
818
|
+
" client = storage.Client()",
|
|
819
|
+
" bucket = client.get_bucket('terrakio-mass-requests')",
|
|
820
|
+
f" blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')",
|
|
821
|
+
"",
|
|
822
|
+
" model = BytesIO()",
|
|
823
|
+
" blob.download_to_file(model)",
|
|
824
|
+
" model.seek(0)",
|
|
825
|
+
"",
|
|
826
|
+
" session = InferenceSession(model.read(), providers=[\"CPUExecutionProvider\"])",
|
|
827
|
+
" return session",
|
|
828
|
+
"",
|
|
829
|
+
f"def {product}(*bands, model):",
|
|
830
|
+
" logging.info(\"Start preparing CNN data with time-stacked bands\")",
|
|
831
|
+
" data_arrays = list(bands)",
|
|
832
|
+
" ",
|
|
833
|
+
" if not data_arrays:",
|
|
834
|
+
" raise ValueError(\"No bands provided\")",
|
|
835
|
+
" ",
|
|
836
|
+
])
|
|
837
|
+
|
|
838
|
+
# Add preprocessing call if preprocessing exists
|
|
839
|
+
if preprocessing_section:
|
|
840
|
+
script_lines.extend([
|
|
841
|
+
" # Apply preprocessing",
|
|
842
|
+
" data_arrays = preprocessing(tuple(data_arrays))",
|
|
843
|
+
" data_arrays = list(data_arrays) # Convert back to list for processing",
|
|
844
|
+
" ",
|
|
845
|
+
])
|
|
846
|
+
|
|
847
|
+
# Continue with the rest of the processing logic
|
|
848
|
+
script_lines.extend([
|
|
849
|
+
" reference_array = data_arrays[0]",
|
|
850
|
+
" original_shape = reference_array.shape",
|
|
851
|
+
" logging.info(f\"Original shape: {original_shape}\")",
|
|
852
|
+
" ",
|
|
853
|
+
" # Get time coordinates - all bands should have the same time dimension",
|
|
854
|
+
" if 'time' not in reference_array.dims:",
|
|
855
|
+
" raise ValueError(\"Time dimension is required for CNN processing\")",
|
|
856
|
+
" ",
|
|
857
|
+
" time_coords = reference_array.coords['time']",
|
|
858
|
+
" num_timestamps = len(time_coords)",
|
|
859
|
+
" logging.info(f\"Number of timestamps: {num_timestamps}\")",
|
|
860
|
+
" ",
|
|
861
|
+
" # Get spatial dimensions",
|
|
862
|
+
" spatial_dims = [dim for dim in reference_array.dims if dim != 'time']",
|
|
863
|
+
" height = reference_array.sizes[spatial_dims[0]] # assuming first spatial dim is height",
|
|
864
|
+
" width = reference_array.sizes[spatial_dims[1]] # assuming second spatial dim is width",
|
|
865
|
+
" logging.info(f\"Spatial dimensions: {height} x {width}\")",
|
|
866
|
+
" ",
|
|
867
|
+
" # Stack bands across time dimension",
|
|
868
|
+
" # Result will be: (num_bands * num_timestamps, height, width)",
|
|
869
|
+
" stacked_channels = []",
|
|
870
|
+
" ",
|
|
871
|
+
" for band_idx, data_array in enumerate(data_arrays):",
|
|
872
|
+
" logging.info(f\"Processing band {band_idx + 1}/{len(data_arrays)}\")",
|
|
873
|
+
" ",
|
|
874
|
+
" # Ensure consistent time coordinates across bands",
|
|
875
|
+
" if not np.array_equal(data_array.coords['time'].values, time_coords.values):",
|
|
876
|
+
" logging.warning(f\"Band {band_idx} has different time coordinates, aligning...\")",
|
|
877
|
+
" data_array = data_array.sel(time=time_coords, method='nearest')",
|
|
878
|
+
" ",
|
|
879
|
+
" # Extract values and ensure proper ordering (time, height, width)",
|
|
880
|
+
" band_values = data_array.values",
|
|
881
|
+
" if band_values.ndim == 3:",
|
|
882
|
+
" # Reorder dimensions if needed to ensure (time, height, width)",
|
|
883
|
+
" time_dim_idx = data_array.dims.index('time')",
|
|
884
|
+
" if time_dim_idx != 0:",
|
|
885
|
+
" axes_order = [time_dim_idx] + [i for i in range(len(data_array.dims)) if i != time_dim_idx]",
|
|
886
|
+
" band_values = np.transpose(band_values, axes_order)",
|
|
887
|
+
" ",
|
|
888
|
+
" # Add each timestamp of this band to the channel stack",
|
|
889
|
+
" for t in range(num_timestamps):",
|
|
890
|
+
" stacked_channels.append(band_values[t])",
|
|
891
|
+
" ",
|
|
892
|
+
" # Stack all channels: (num_bands * num_timestamps, height, width)",
|
|
893
|
+
" input_channels = np.stack(stacked_channels, axis=0)",
|
|
894
|
+
" total_channels = len(data_arrays) * num_timestamps",
|
|
895
|
+
" logging.info(f\"Stacked channels shape: {input_channels.shape}\")",
|
|
896
|
+
" logging.info(f\"Total channels: {total_channels} ({len(data_arrays)} bands × {num_timestamps} timestamps)\")",
|
|
897
|
+
" ",
|
|
898
|
+
" # Add batch dimension: (1, num_channels, height, width)",
|
|
899
|
+
" input_data = np.expand_dims(input_channels, axis=0).astype(np.float32)",
|
|
900
|
+
" logging.info(f\"Final input shape for CNN: {input_data.shape}\")",
|
|
901
|
+
" ",
|
|
902
|
+
" # Run inference",
|
|
903
|
+
" output = model.run(None, {\"float_input\": input_data})[0]",
|
|
904
|
+
" logging.info(f\"Model output shape: {output.shape}\")",
|
|
905
|
+
" ",
|
|
906
|
+
" # UPDATED: Handle multi-class CNN output properly",
|
|
907
|
+
" if output.ndim == 4:",
|
|
908
|
+
" if output.shape[1] == 1:",
|
|
909
|
+
" # Single class output (regression or binary classification)",
|
|
910
|
+
" output_2d = output[0, 0]",
|
|
911
|
+
" logging.info(\"Single channel output detected\")",
|
|
912
|
+
" else:",
|
|
913
|
+
" # Multi-class output - convert logits/probabilities to class predictions",
|
|
914
|
+
" output_classes = np.argmax(output, axis=1) # Shape: (1, height, width)",
|
|
915
|
+
" output_2d = output_classes[0] # Shape: (height, width)",
|
|
916
|
+
" ",
|
|
917
|
+
" # Apply class merging: merge class 6 into class 3",
|
|
918
|
+
" output_2d = np.where(output_2d == 6, 3, output_2d)",
|
|
919
|
+
" ",
|
|
920
|
+
" logging.info(f\"Multi-class output processed. Original classes: {output.shape[1]}\")",
|
|
921
|
+
" logging.info(f\"Unique classes in output: {np.unique(output_2d)}\")",
|
|
922
|
+
" logging.info(f\"Class distribution: {np.bincount(output_2d.flatten())}\")",
|
|
923
|
+
" elif output.ndim == 3:",
|
|
924
|
+
" # Remove batch dimension",
|
|
925
|
+
" output_2d = output[0]",
|
|
926
|
+
" logging.info(\"3D output detected, removed batch dimension\")",
|
|
927
|
+
" else:",
|
|
928
|
+
" # Handle other cases",
|
|
929
|
+
" output_2d = np.squeeze(output)",
|
|
930
|
+
" if output_2d.ndim != 2:",
|
|
931
|
+
" logging.error(f\"Cannot process output shape: {output.shape}\")",
|
|
932
|
+
" logging.error(f\"After squeeze: {output_2d.shape}\")",
|
|
933
|
+
" raise ValueError(f\"Unexpected output shape after processing: {output_2d.shape}\")",
|
|
934
|
+
" logging.info(\"Applied squeeze to output\")",
|
|
935
|
+
" ",
|
|
936
|
+
" # Ensure output is 2D",
|
|
937
|
+
" if output_2d.ndim != 2:",
|
|
938
|
+
" raise ValueError(f\"Final output must be 2D, got shape: {output_2d.shape}\")",
|
|
939
|
+
" ",
|
|
940
|
+
" # Determine output timestamp (use the latest timestamp)",
|
|
941
|
+
" output_timestamp = time_coords[-1]",
|
|
942
|
+
" ",
|
|
943
|
+
" # Get spatial coordinates from reference array",
|
|
944
|
+
" spatial_coords = {dim: reference_array.coords[dim] for dim in spatial_dims}",
|
|
945
|
+
" ",
|
|
946
|
+
" # Create output DataArray with appropriate data type",
|
|
947
|
+
" # Use int32 for classification, float32 for regression",
|
|
948
|
+
" is_multiclass = output.ndim == 4 and output.shape[1] > 1",
|
|
949
|
+
" if is_multiclass:",
|
|
950
|
+
" # Multi-class classification - use integer type",
|
|
951
|
+
" output_dtype = np.int32",
|
|
952
|
+
" output_type = 'classification'",
|
|
953
|
+
" else:",
|
|
954
|
+
" # Single output - use float type",
|
|
955
|
+
" output_dtype = np.float32",
|
|
956
|
+
" output_type = 'regression'",
|
|
957
|
+
" ",
|
|
958
|
+
" result = xr.DataArray(",
|
|
959
|
+
" data=np.expand_dims(output_2d.astype(output_dtype), axis=0),",
|
|
960
|
+
" dims=['time'] + spatial_dims,",
|
|
961
|
+
" coords={",
|
|
962
|
+
" 'time': [output_timestamp.values],",
|
|
963
|
+
" spatial_dims[0]: spatial_coords[spatial_dims[0]].values,",
|
|
964
|
+
" spatial_dims[1]: spatial_coords[spatial_dims[1]].values",
|
|
965
|
+
" },",
|
|
966
|
+
" attrs={",
|
|
967
|
+
" 'description': 'CNN model prediction',",
|
|
968
|
+
" }",
|
|
969
|
+
" )",
|
|
970
|
+
" ",
|
|
971
|
+
" logging.info(f\"Final result shape: {result.shape}\")",
|
|
972
|
+
" logging.info(f\"Final result data type: {result.dtype}\")",
|
|
973
|
+
" logging.info(f\"Final result value range: {result.values.min()} to {result.values.max()}\")",
|
|
974
|
+
])
|
|
975
|
+
|
|
976
|
+
# Add postprocessing call if postprocessing exists
|
|
977
|
+
if postprocessing_section:
|
|
978
|
+
script_lines.extend([
|
|
979
|
+
" # Apply postprocessing",
|
|
980
|
+
" result = postprocessing(result)",
|
|
981
|
+
" ",
|
|
982
|
+
])
|
|
983
|
+
|
|
984
|
+
# Single return statement at the end
|
|
985
|
+
script_lines.append(" return result")
|
|
986
|
+
|
|
987
|
+
return "\n".join(script_lines)
|
|
781
988
|
|
|
782
989
|
@require_api_key
|
|
783
990
|
def _upload_script_to_bucket(self, script_content: str, script_name: str, model_training_job_name: str, uid: str):
|