terrakio-core 0.3.8__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



@@ -3,10 +3,11 @@ import json
 import time
 import textwrap
 import logging
-from typing import Dict, Any, Union, Tuple
+from typing import Dict, Any, Union, Tuple, Optional
 from io import BytesIO
 import numpy as np
 from google.cloud import storage
+import ast
 from ..helper.decorators import require_token, require_api_key, require_auth
 TORCH_AVAILABLE = False
 SKL2ONNX_AVAILABLE = False
@@ -30,6 +31,7 @@ except ImportError:
     from io import BytesIO
     from typing import Tuple
 
+
 class ModelManagement:
     def __init__(self, client):
         self._client = client
@@ -347,9 +349,10 @@ class ModelManagement:
 
         except Exception as e:
             raise ValueError(f"Failed to convert scikit-learn model {model_name} to ONNX: {str(e)}")
-
+
+
     @require_api_key
-    async def upload_and_deploy_cnn_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None):
+    async def upload_and_deploy_cnn_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None, processing_script_path: Optional[str] = None):
         """
         Upload a CNN model to the bucket and deploy it.
 
@@ -361,6 +364,7 @@ class ModelManagement:
             input_expression: Input expression for the dataset
             dates_iso8601: List of dates in ISO8601 format
             input_shape: Shape of input data for ONNX conversion (required for PyTorch models)
+            processing_script_path: Path to the processing script, if not provided, no processing will be done
 
         Raises:
             APIError: If the API request fails
@@ -369,7 +373,8 @@ class ModelManagement:
         """
         await self.upload_model(model=model, model_name=model_name, input_shape=input_shape)
         # so the uploading process is kinda similar, but the deployment step is kinda different
-        await self.deploy_cnn_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601)
+        # we should pass the processing script path to the deploy cnn model function
+        await self.deploy_cnn_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601, processing_script_path=processing_script_path)
 
     @require_api_key
     async def upload_and_deploy_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None):
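
Note: the new processing_script_path argument points at a plain Python file from which the client later extracts the bodies of functions named preprocessing and postprocessing (see the _parse_processing_script hunk below). A minimal sketch of such a file, assuming the hypothetical name processing.py and purely illustrative scaling/clipping logic; only the two function names and the array parameter matter:

    # processing.py -- hypothetical example
    from typing import Tuple
    import xarray as xr

    def preprocessing(array: Tuple[xr.DataArray, ...]) -> Tuple[xr.DataArray, ...]:
        # Scale every input band before inference (illustrative only)
        return tuple(band / 10000.0 for band in array)

    def postprocessing(array: xr.DataArray) -> xr.DataArray:
        # Clip the model output to a valid range (illustrative only)
        return array.clip(min=0.0, max=1.0)

The extracted bodies are re-indented into the generated inference script, which already imports numpy, pandas, xarray and the ONNX runtime, so the bodies can rely on those names; the imports above only make the standalone file valid on its own.
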
@@ -475,6 +480,88 @@ class ModelManagement:
             padding=0
         )
 
+    def _parse_processing_script(self, script_path: str) -> Tuple[Optional[str], Optional[str]]:
+        """
+        Parse a Python file and extract preprocessing and postprocessing function bodies.
+
+        Args:
+            script_path: Path to the Python file containing processing functions
+
+        Returns:
+            Tuple of (preprocessing_code, postprocessing_code) where each can be None
+        """
+        try:
+            with open(script_path, 'r', encoding='utf-8') as f:
+                script_content = f.read()
+        except FileNotFoundError:
+            raise FileNotFoundError(f"Processing script not found: {script_path}")
+        except Exception as e:
+            raise ValueError(f"Error reading processing script: {e}")
+
+        # Handle empty file
+        if not script_content.strip():
+            self._client.logger.info(f"Processing script {script_path} is empty")
+            return None, None
+
+        try:
+            # Parse the Python file
+            tree = ast.parse(script_content)
+        except SyntaxError as e:
+            raise ValueError(f"Syntax error in processing script: {e}")
+
+        preprocessing_code = None
+        postprocessing_code = None
+
+        # Find function definitions
+        function_names = []
+        for node in ast.walk(tree):
+            if isinstance(node, ast.FunctionDef):
+                function_names.append(node.name)
+                if node.name == 'preprocessing':
+                    preprocessing_code = self._extract_function_body(script_content, node)
+                elif node.name == 'postprocessing':
+                    postprocessing_code = self._extract_function_body(script_content, node)
+
+        # Log what was found for debugging
+        if not function_names:
+            self._client.logger.warning(f"No functions found in processing script: {script_path}")
+        else:
+            found_functions = [name for name in function_names if name in ['preprocessing', 'postprocessing']]
+            if found_functions:
+                self._client.logger.info(f"Found processing functions: {found_functions}")
+            else:
+                self._client.logger.warning(f"No 'preprocessing' or 'postprocessing' functions found in {script_path}. "
+                                            f"Available functions: {function_names}")
+
+        return preprocessing_code, postprocessing_code
+
+    def _extract_function_body(self, script_content: str, func_node: ast.FunctionDef) -> str:
+        """Extract the body of a function from the script content."""
+        lines = script_content.split('\n')
+
+        # AST line numbers are 1-indexed, convert to 0-indexed
+        start_line = func_node.lineno - 1 # This is the 'def' line (0-indexed)
+        end_line = func_node.end_lineno - 1 if hasattr(func_node, 'end_lineno') else len(lines) - 1
+
+        # Extract ONLY the body lines (skip the def line entirely)
+        body_lines = []
+        for i in range(start_line + 1, end_line + 1): # +1 to skip the 'def' line
+            if i < len(lines):
+                body_lines.append(lines[i])
+
+        if not body_lines:
+            return ""
+
+        # Join and dedent to remove function-level indentation
+        body_text = '\n'.join(body_lines)
+        cleaned_body = textwrap.dedent(body_text).strip()
+
+        # Handle empty function body
+        if not cleaned_body or cleaned_body in ['pass', 'return', 'return None']:
+            return ""
+
+        return cleaned_body
+
     @require_api_key
     async def deploy_cnn_model(
         self,
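
For reference, a standalone sketch of what the extraction above does (the sample source string is invented; the real code reads it from script_path and also handles the empty-body and error cases shown in the hunk):

    import ast
    import textwrap

    source = textwrap.dedent("""
        def preprocessing(array):
            scaled = tuple(band / 10000.0 for band in array)
            return scaled
    """)

    func = next(n for n in ast.walk(ast.parse(source)) if isinstance(n, ast.FunctionDef))
    # Keep only the lines after the 'def' line, then dedent and strip them
    body = "\n".join(source.split("\n")[func.lineno:func.end_lineno])
    print(textwrap.dedent(body).strip())
    # scaled = tuple(band / 10000.0 for band in array)
    # return scaled
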
@@ -483,7 +570,8 @@ class ModelManagement:
         model_name: str,
         input_expression: str,
         model_training_job_name: str,
-        dates_iso8601: list
+        dates_iso8601: list,
+        processing_script_path: Optional[str] = None
     ) -> Dict[str, Any]:
         """
         Deploy a CNN model by generating inference script and creating dataset.
@@ -495,7 +583,7 @@ class ModelManagement:
             input_expression: Input expression for the dataset
             model_training_job_name: Name of the training job
             dates_iso8601: List of dates in ISO8601 format
-
+            processing_script_path: Path to the processing script, if not provided, no processing will be done
         Returns:
             dict: Response from the deployment process
 
@@ -506,8 +594,31 @@ class ModelManagement:
         user_info = await self._client.auth.get_user_info()
         uid = user_info["uid"]
 
+        preprocessing_code, postprocessing_code = None, None
+        if processing_script_path:
+            # if there is a function that is being passed in
+            try:
+                preprocessing_code, postprocessing_code = self._parse_processing_script(processing_script_path)
+                if preprocessing_code:
+                    self._client.logger.info(f"Using custom preprocessing from: {processing_script_path}")
+                if postprocessing_code:
+                    self._client.logger.info(f"Using custom postprocessing from: {processing_script_path}")
+                if not preprocessing_code and not postprocessing_code:
+                    self._client.logger.warning(f"No preprocessing or postprocessing functions found in {processing_script_path}")
+                    self._client.logger.info("Deployment will continue without custom processing")
+            except Exception as e:
+                raise ValueError(f"Failed to load processing script: {str(e)}")
+        # so we already have the preprocessing code and the post processing code, I need to pass them to the generate cnn script function
         # Generate and upload script
-        script_content = self.generate_cnn_script(model_name, product, model_training_job_name, uid)
+        # Build preprocessing section with CONSISTENT 8-space indentation
+        preprocessing_section = ""
+        if preprocessing_code and preprocessing_code.strip():
+            # First dedent the preprocessing code to remove any existing indentation
+            clean_preprocessing = preprocessing_code
+            # Then add consistent 8-space indentation to match the template
+            preprocessing_section = f"""{textwrap.indent(clean_preprocessing, '')}""" # 8 spaces
+        print(preprocessing_section)
+        script_content = self.generate_cnn_script(model_name, product, model_training_job_name, uid, preprocessing_code, postprocessing_code)
         script_name = f"{product}.py"
         self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
         # Create dataset
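
Taken together with the earlier hunks, the user-facing call now looks roughly like this (a sketch only; client, model, expr and dates are assumed to exist, the string values are placeholders, and the decorators assume an API key is already configured on the client):

    # Hypothetical end-to-end call
    mgmt = ModelManagement(client)
    await mgmt.upload_and_deploy_cnn_model(
        model=model,
        model_name="my_cnn",
        dataset="my_dataset",
        product="ndvi_cnn",
        input_expression=expr,
        dates_iso8601=dates,
        input_shape=(1, 4, 128, 128),   # required for PyTorch models, per the docstring
        processing_script_path="processing.py",
    )
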
@@ -536,109 +647,109 @@ class ModelManagement:
             str: Generated Python script content
         """
         return textwrap.dedent(f'''
-        import logging
-        from io import BytesIO
+        import logging
+        from io import BytesIO
 
-        import numpy as np
-        import pandas as pd
-        import xarray as xr
-        from google.cloud import storage
-        from onnxruntime import InferenceSession
+        import numpy as np
+        import pandas as pd
+        import xarray as xr
+        from google.cloud import storage
+        from onnxruntime import InferenceSession
 
-        logging.basicConfig(
-            level=logging.INFO
-        )
+        logging.basicConfig(
+            level=logging.INFO
+        )
 
-        def get_model():
-            logging.info("Loading model for {model_name}...")
+        def get_model():
+            logging.info("Loading model for {model_name}...")
 
-            client = storage.Client()
-            bucket = client.get_bucket('terrakio-mass-requests')
-            blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')
+            client = storage.Client()
+            bucket = client.get_bucket('terrakio-mass-requests')
+            blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')
 
-            model = BytesIO()
-            blob.download_to_file(model)
-            model.seek(0)
+            model = BytesIO()
+            blob.download_to_file(model)
+            model.seek(0)
 
-            session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
-            return session
+            session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
+            return session
 
-        def {product}(*bands, model):
-            logging.info("start preparing data")
-
-            data_arrays = list(bands)
-
-            reference_array = data_arrays[0]
-            original_shape = reference_array.shape
-            logging.info(f"Original shape: {{original_shape}}")
-
-            if 'time' in reference_array.dims:
-                time_coords = reference_array.coords['time']
-                if len(time_coords) == 1:
-                    output_timestamp = time_coords[0]
-                else:
-                    years = [pd.to_datetime(t).year for t in time_coords.values]
-                    unique_years = set(years)
-
-                    if len(unique_years) == 1:
-                        year = list(unique_years)[0]
-                        output_timestamp = pd.Timestamp(f"{{year}}-01-01")
-                    else:
-                        latest_year = max(unique_years)
-                        output_timestamp = pd.Timestamp(f"{{latest_year}}-01-01")
+        def {product}(*bands, model):
+            logging.info("start preparing data")
+
+            data_arrays = list(bands)
+
+            reference_array = data_arrays[0]
+            original_shape = reference_array.shape
+            logging.info(f"Original shape: {{original_shape}}")
+
+            if 'time' in reference_array.dims:
+                time_coords = reference_array.coords['time']
+                if len(time_coords) == 1:
+                    output_timestamp = time_coords[0]
                 else:
-                output_timestamp = pd.Timestamp("1970-01-01")
-
-            averaged_bands = []
-            for data_array in data_arrays:
-                if 'time' in data_array.dims:
-                    averaged_band = np.mean(data_array.values, axis=0)
-                    logging.info(f"Averaged band from {{data_array.shape}} to {{averaged_band.shape}}")
+                    years = [pd.to_datetime(t).year for t in time_coords.values]
+                    unique_years = set(years)
+
+                    if len(unique_years) == 1:
+                        year = list(unique_years)[0]
+                        output_timestamp = pd.Timestamp(f"{{year}}-01-01")
                     else:
-                    averaged_band = data_array.values
-                    logging.info(f"No time dimension, shape: {{averaged_band.shape}}")
+                        latest_year = max(unique_years)
+                        output_timestamp = pd.Timestamp(f"{{latest_year}}-01-01")
+            else:
+                output_timestamp = pd.Timestamp("1970-01-01")
 
-                flattened_band = averaged_band.reshape(-1, 1)
-                averaged_bands.append(flattened_band)
+            averaged_bands = []
+            for data_array in data_arrays:
+                if 'time' in data_array.dims:
+                    averaged_band = np.mean(data_array.values, axis=0)
+                    logging.info(f"Averaged band from {{data_array.shape}} to {{averaged_band.shape}}")
+                else:
+                    averaged_band = data_array.values
+                    logging.info(f"No time dimension, shape: {{averaged_band.shape}}")
 
-            input_data = np.hstack(averaged_bands)
+                flattened_band = averaged_band.reshape(-1, 1)
+                averaged_bands.append(flattened_band)
 
-            logging.info(f"Final input shape: {{input_data.shape}}")
+            input_data = np.hstack(averaged_bands)
 
-            output = model.run(None, {{"float_input": input_data.astype(np.float32)}})[0]
+            logging.info(f"Final input shape: {{input_data.shape}}")
 
-            logging.info(f"Model output shape: {{output.shape}}")
+            output = model.run(None, {{"float_input": input_data.astype(np.float32)}})[0]
 
-            if len(original_shape) >= 3:
-                spatial_shape = original_shape[1:]
-            else:
-                spatial_shape = original_shape
+            logging.info(f"Model output shape: {{output.shape}}")
 
-            output_reshaped = output.reshape(spatial_shape)
+            if len(original_shape) >= 3:
+                spatial_shape = original_shape[1:]
+            else:
+                spatial_shape = original_shape
 
-            output_with_time = np.expand_dims(output_reshaped, axis=0)
+            output_reshaped = output.reshape(spatial_shape)
 
-            if 'time' in reference_array.dims:
-                spatial_dims = [dim for dim in reference_array.dims if dim != 'time']
-                spatial_coords = {{dim: reference_array.coords[dim] for dim in spatial_dims if dim in reference_array.coords}}
-            else:
-                spatial_dims = list(reference_array.dims)
-                spatial_coords = dict(reference_array.coords)
-
-            result = xr.DataArray(
-                data=output_with_time.astype(np.float32),
-                dims=['time'] + list(spatial_dims),
-                coords={{
-                    'time': [output_timestamp.values],
-                    'y': spatial_coords['y'].values,
-                    'x': spatial_coords['x'].values
-                }}
-            )
-            return result
+            output_with_time = np.expand_dims(output_reshaped, axis=0)
+
+            if 'time' in reference_array.dims:
+                spatial_dims = [dim for dim in reference_array.dims if dim != 'time']
+                spatial_coords = {{dim: reference_array.coords[dim] for dim in spatial_dims if dim in reference_array.coords}}
+            else:
+                spatial_dims = list(reference_array.dims)
+                spatial_coords = dict(reference_array.coords)
+
+            result = xr.DataArray(
+                data=output_with_time.astype(np.float32),
+                dims=['time'] + list(spatial_dims),
+                coords={{
+                    'time': [output_timestamp.values],
+                    'y': spatial_coords['y'].values,
+                    'x': spatial_coords['x'].values
+                }}
+            )
+            return result
         ''').strip()
 
     @require_api_key
-    def generate_cnn_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
+    def generate_cnn_script(self, model_name: str, product: str, model_training_job_name: str, uid: str, preprocessing_code: Optional[str] = None, postprocessing_code: Optional[str] = None) -> str:
         """
         Generate Python inference script for CNN model with time-stacked bands.
 
@@ -647,137 +758,197 @@ class ModelManagement:
             product: Product name
             model_training_job_name: Training job name
             uid: User ID
-
+            preprocessing_code: Preprocessing code
+            postprocessing_code: Postprocessing code
         Returns:
             str: Generated Python script content
         """
-        return textwrap.dedent(f'''
-        import logging
-        from io import BytesIO
-
-        import numpy as np
-        import pandas as pd
-        import xarray as xr
-        from google.cloud import storage
-        from onnxruntime import InferenceSession
-
-        logging.basicConfig(
-            level=logging.INFO
-        )
-
-        def get_model():
-            logging.info("Loading CNN model for {model_name}...")
-
-            client = storage.Client()
-            bucket = client.get_bucket('terrakio-mass-requests')
-            blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')
-
-            model = BytesIO()
-            blob.download_to_file(model)
-            model.seek(0)
-
-            session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
-            return session
-
-        def {product}(*bands, model):
-            logging.info("Start preparing CNN data with time-stacked bands")
-
-            data_arrays = list(bands)
-
-            if not data_arrays:
-                raise ValueError("No bands provided")
-
-            reference_array = data_arrays[0]
-            original_shape = reference_array.shape
-            logging.info(f"Original shape: {{original_shape}}")
-
-            # Get time coordinates - all bands should have the same time dimension
-            if 'time' not in reference_array.dims:
-                raise ValueError("Time dimension is required for CNN processing")
-
-            time_coords = reference_array.coords['time']
-            num_timestamps = len(time_coords)
-            logging.info(f"Number of timestamps: {{num_timestamps}}")
-
-            # Get spatial dimensions
-            spatial_dims = [dim for dim in reference_array.dims if dim != 'time']
-            height = reference_array.sizes[spatial_dims[0]] # assuming first spatial dim is height
-            width = reference_array.sizes[spatial_dims[1]] # assuming second spatial dim is width
-            logging.info(f"Spatial dimensions: {{height}} x {{width}}")
-
-            # Stack bands across time dimension
-            # Result will be: (num_bands * num_timestamps, height, width)
-            stacked_channels = []
-
-            for band_idx, data_array in enumerate(data_arrays):
-                logging.info(f"Processing band {{band_idx + 1}}/{{len(data_arrays)}}")
-
-                # Ensure consistent time coordinates across bands
-                if not np.array_equal(data_array.coords['time'].values, time_coords.values):
-                    logging.warning(f"Band {{band_idx}} has different time coordinates, aligning...")
-                    data_array = data_array.sel(time=time_coords, method='nearest')
-
-                # Extract values and ensure proper ordering (time, height, width)
-                band_values = data_array.values
-                if band_values.ndim == 3:
-                    # Reorder dimensions if needed to ensure (time, height, width)
-                    time_dim_idx = data_array.dims.index('time')
-                    if time_dim_idx != 0:
-                        axes_order = [time_dim_idx] + [i for i in range(len(data_array.dims)) if i != time_dim_idx]
-                        band_values = np.transpose(band_values, axes_order)
-
-                # Add each timestamp of this band to the channel stack
-                for t in range(num_timestamps):
-                    stacked_channels.append(band_values[t])
-
-            # Stack all channels: (num_bands * num_timestamps, height, width)
-            input_channels = np.stack(stacked_channels, axis=0)
-            total_channels = len(data_arrays) * num_timestamps
-            logging.info(f"Stacked channels shape: {{input_channels.shape}}")
-            logging.info(f"Total channels: {{total_channels}} ({{len(data_arrays)}} bands × {{num_timestamps}} timestamps)")
-
-            # Add batch dimension: (1, num_channels, height, width)
-            input_data = np.expand_dims(input_channels, axis=0).astype(np.float32)
-            logging.info(f"Final input shape for CNN: {{input_data.shape}}")
-
-            # Run inference
-            output = model.run(None, {{"float_input": input_data}})[0]
-            logging.info(f"Model output shape: {{output.shape}}")
-
-            # Process output back to xarray format
-            # Assuming output is (1, height, width) or (1, 1, height, width)
-            if output.ndim == 4 and output.shape[1] == 1:
-                # Remove channel dimension if it's 1
-                output_2d = output[0, 0]
-            elif output.ndim == 3:
-                # Remove batch dimension
-                output_2d = output[0]
-            else:
-                # Handle other cases
-                output_2d = np.squeeze(output)
-                if output_2d.ndim != 2:
-                    raise ValueError(f"Unexpected output shape after processing: {{output_2d.shape}}")
-
-            # Determine output timestamp (use the latest timestamp)
-            output_timestamp = time_coords[-1]
-
-            # Get spatial coordinates from reference array
-            spatial_coords = {{dim: reference_array.coords[dim] for dim in spatial_dims}}
-
-            # Create output DataArray
-            result = xr.DataArray(
-                data=np.expand_dims(output_2d.astype(np.float32), axis=0),
-                dims=['time'] + spatial_dims,
-                coords={{
-                    'time': [output_timestamp.values],
-                    spatial_dims[0]: spatial_coords[spatial_dims[0]].values,
-                    spatial_dims[1]: spatial_coords[spatial_dims[1]].values
-                }}
-            )
-
-            logging.info(f"Final result shape: {{result.shape}}")
-            return result
-        ''').strip()
+        import textwrap
+
+        # Build preprocessing section with CONSISTENT 4-space indentation
+        preprocessing_section = ""
+        if preprocessing_code and preprocessing_code.strip():
+            clean_preprocessing = textwrap.dedent(preprocessing_code)
+            preprocessing_section = textwrap.indent(clean_preprocessing, '    ')
+
+        # Build postprocessing section with CONSISTENT 4-space indentation
+        postprocessing_section = ""
+        if postprocessing_code and postprocessing_code.strip():
+            clean_postprocessing = textwrap.dedent(postprocessing_code)
+            postprocessing_section = textwrap.indent(clean_postprocessing, '    ')
+
+        # Build the template WITHOUT dedenting the whole thing, so indentation is preserved
+        script_lines = [
+            "import logging",
+            "from io import BytesIO",
+            "import numpy as np",
+            "import pandas as pd",
+            "import xarray as xr",
+            "from google.cloud import storage",
+            "from onnxruntime import InferenceSession",
+            "from typing import Tuple",
+            "",
+            "logging.basicConfig(",
+            "    level=logging.INFO",
+            ")",
+            "",
+        ]
+
+        # Add preprocessing function definition BEFORE the main function
+        if preprocessing_section:
+            script_lines.extend([
+                "def preprocessing(array: Tuple[xr.DataArray, ...]) -> Tuple[xr.DataArray, ...]:",
+                preprocessing_section,
+                "",
+            ])
+
+        # Add postprocessing function definition BEFORE the main function
+        if postprocessing_section:
+            script_lines.extend([
+                "def postprocessing(array: xr.DataArray) -> xr.DataArray:",
+                postprocessing_section,
+                "",
+            ])
+
+        # Add the get_model function
+        script_lines.extend([
+            "def get_model():",
+            f"    logging.info(\"Loading CNN model for {model_name}...\")",
+            "",
+            "    client = storage.Client()",
+            "    bucket = client.get_bucket('terrakio-mass-requests')",
+            f"    blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')",
+            "",
+            "    model = BytesIO()",
+            "    blob.download_to_file(model)",
+            "    model.seek(0)",
+            "",
+            "    session = InferenceSession(model.read(), providers=[\"CPUExecutionProvider\"])",
+            "    return session",
+            "",
+            f"def {product}(*bands, model):",
+            "    logging.info(\"Start preparing CNN data with time-stacked bands\")",
+            "    data_arrays = list(bands)",
+            "    ",
+            "    if not data_arrays:",
+            "        raise ValueError(\"No bands provided\")",
+            "    ",
+        ])
+
+        # Add preprocessing call if preprocessing exists
+        if preprocessing_section:
+            script_lines.extend([
+                "    # Apply preprocessing",
+                "    data_arrays = preprocessing(tuple(data_arrays))",
+                "    data_arrays = list(data_arrays) # Convert back to list for processing",
+                "    ",
+            ])
+
+        # Continue with the rest of the processing logic
+        script_lines.extend([
+            "    reference_array = data_arrays[0]",
+            "    original_shape = reference_array.shape",
+            "    logging.info(f\"Original shape: {original_shape}\")",
+            "    ",
+            "    # Get time coordinates - all bands should have the same time dimension",
+            "    if 'time' not in reference_array.dims:",
+            "        raise ValueError(\"Time dimension is required for CNN processing\")",
+            "    ",
+            "    time_coords = reference_array.coords['time']",
+            "    num_timestamps = len(time_coords)",
+            "    logging.info(f\"Number of timestamps: {num_timestamps}\")",
+            "    ",
+            "    # Get spatial dimensions",
+            "    spatial_dims = [dim for dim in reference_array.dims if dim != 'time']",
+            "    height = reference_array.sizes[spatial_dims[0]] # assuming first spatial dim is height",
+            "    width = reference_array.sizes[spatial_dims[1]] # assuming second spatial dim is width",
+            "    logging.info(f\"Spatial dimensions: {height} x {width}\")",
+            "    ",
+            "    # Stack bands across time dimension",
+            "    # Result will be: (num_bands * num_timestamps, height, width)",
+            "    stacked_channels = []",
+            "    ",
+            "    for band_idx, data_array in enumerate(data_arrays):",
+            "        logging.info(f\"Processing band {band_idx + 1}/{len(data_arrays)}\")",
+            "        ",
+            "        # Ensure consistent time coordinates across bands",
+            "        if not np.array_equal(data_array.coords['time'].values, time_coords.values):",
+            "            logging.warning(f\"Band {band_idx} has different time coordinates, aligning...\")",
+            "            data_array = data_array.sel(time=time_coords, method='nearest')",
+            "        ",
+            "        # Extract values and ensure proper ordering (time, height, width)",
+            "        band_values = data_array.values",
+            "        if band_values.ndim == 3:",
+            "            # Reorder dimensions if needed to ensure (time, height, width)",
+            "            time_dim_idx = data_array.dims.index('time')",
+            "            if time_dim_idx != 0:",
+            "                axes_order = [time_dim_idx] + [i for i in range(len(data_array.dims)) if i != time_dim_idx]",
+            "                band_values = np.transpose(band_values, axes_order)",
+            "        ",
+            "        # Add each timestamp of this band to the channel stack",
+            "        for t in range(num_timestamps):",
+            "            stacked_channels.append(band_values[t])",
+            "    ",
+            "    # Stack all channels: (num_bands * num_timestamps, height, width)",
+            "    input_channels = np.stack(stacked_channels, axis=0)",
+            "    total_channels = len(data_arrays) * num_timestamps",
+            "    logging.info(f\"Stacked channels shape: {input_channels.shape}\")",
+            "    logging.info(f\"Total channels: {total_channels} ({len(data_arrays)} bands × {num_timestamps} timestamps)\")",
+            "    ",
+            "    # Add batch dimension: (1, num_channels, height, width)",
+            "    input_data = np.expand_dims(input_channels, axis=0).astype(np.float32)",
+            "    logging.info(f\"Final input shape for CNN: {input_data.shape}\")",
+            "    ",
+            "    # Run inference",
+            "    output = model.run(None, {\"float_input\": input_data})[0]",
+            "    logging.info(f\"Model output shape: {output.shape}\")",
+            "    ",
+            "    # Process output back to xarray format",
+            "    # Assuming output is (1, height, width) or (1, 1, height, width)",
+            "    if output.ndim == 4 and output.shape[1] == 1:",
+            "        # Remove channel dimension if it's 1",
+            "        output_2d = output[0, 0]",
+            "    elif output.ndim == 3:",
+            "        # Remove batch dimension",
+            "        output_2d = output[0]",
+            "    else:",
+            "        # Handle other cases",
+            "        output_2d = np.squeeze(output)",
+            "        if output_2d.ndim != 2:",
+            "            raise ValueError(f\"Unexpected output shape after processing: {output_2d.shape}\")",
+            "    ",
+            "    # Determine output timestamp (use the latest timestamp)",
+            "    output_timestamp = time_coords[-1]",
+            "    ",
+            "    # Get spatial coordinates from reference array",
+            "    spatial_coords = {dim: reference_array.coords[dim] for dim in spatial_dims}",
+            "    ",
+            "    # Create output DataArray",
+            "    result = xr.DataArray(",
+            "        data=np.expand_dims(output_2d.astype(np.float32), axis=0),",
+            "        dims=['time'] + spatial_dims,",
+            "        coords={",
+            "            'time': [output_timestamp.values],",
+            "            spatial_dims[0]: spatial_coords[spatial_dims[0]].values,",
+            "            spatial_dims[1]: spatial_coords[spatial_dims[1]].values",
+            "        }",
+            "    )",
+            "    ",
+            "    logging.info(f\"Final result shape: {result.shape}\")",
+        ])
+
+        # Add postprocessing call if postprocessing exists
+        if postprocessing_section:
+            script_lines.extend([
+                "    # Apply postprocessing",
+                "    result = postprocessing(result)",
+                "    ",
+            ])
+
+        # Single return statement at the end
+        script_lines.append("    return result")
+
+        return "\n".join(script_lines)
 
     @require_api_key
     def _upload_script_to_bucket(self, script_content: str, script_name: str, model_training_job_name: str, uid: str):
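
For orientation, when a processing script supplies only a preprocessing function, the line-based builder above emits a generated script whose head looks roughly like this (abridged; the preprocessing body is the illustrative one from the earlier sketch, and exact whitespace depends on how the extracted body is re-indented):

    import logging
    from io import BytesIO
    import numpy as np
    import pandas as pd
    import xarray as xr
    from google.cloud import storage
    from onnxruntime import InferenceSession
    from typing import Tuple

    logging.basicConfig(
        level=logging.INFO
    )

    def preprocessing(array: Tuple[xr.DataArray, ...]) -> Tuple[xr.DataArray, ...]:
        return tuple(band / 10000.0 for band in array)

    def get_model():
        logging.info("Loading CNN model for my_cnn...")
        ...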