terrakio-core 0.4.0__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of terrakio-core has been flagged as a potentially problematic release.

terrakio_core/__init__.py CHANGED
@@ -5,7 +5,7 @@ Terrakio Core
  Core components for Terrakio API clients.
  """

- __version__ = "0.4.0"
+ __version__ = "0.4.3"

  from .async_client import AsyncClient
  from .sync_client import SyncClient as Client

terrakio_core/async_client.py CHANGED
@@ -47,7 +47,7 @@ class AsyncClient(BaseClient):
  return await self._make_request_with_retry(self._session, method, endpoint, **kwargs)

  async def _make_request_with_retry(self, session: aiohttp.ClientSession, method: str, endpoint: str, **kwargs) -> Dict[Any, Any]:
- url = f"{self.url}/{endpoint.lstrip('/')}"
+ url = f"{self.url}/{endpoint.lstrip('/')}"
  last_exception = None

  for attempt in range(self.retry + 1):

terrakio_core/endpoints/mass_stats.py CHANGED
@@ -67,7 +67,7 @@ class MassStats:


  @require_api_key
- def start_job(self, id: str) -> Dict[str, Any]:
+ async def start_job(self, id: str) -> Dict[str, Any]:
  """
  Start a mass stats job by task ID.

@@ -78,7 +78,7 @@ class MassStats:
  API response as a dictionary

  """
- return self._client._terrakio_request("POST", f"mass_stats/start/{id}")
+ return await self._client._terrakio_request("POST", f"mass_stats/start/{id}")

  @require_api_key
  def get_task_id(self, name: str, stage: str, uid: Optional[str] = None) -> Dict[str, Any]:
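`start_job` is now a coroutine, so it has to be awaited from an async context. A minimal usage sketch, assuming an already constructed and authenticated `AsyncClient` and an existing mass-stats task ID (both placeholders here):

```python
from terrakio_core import AsyncClient

async def start_existing_job(client: AsyncClient, task_id: str) -> dict:
    # In 0.4.0 this was a synchronous call; as of 0.4.3 it returns a coroutine
    # and must be awaited. The client exposes the endpoint under `mass_stats`.
    return await client.mass_stats.start_job(task_id)

# e.g. asyncio.run(start_existing_job(client, "my-task-id"))  # client construction not shown
```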
@@ -542,7 +542,7 @@ class MassStats:
  return self._client._terrakio_request("POST", "mass_stats/cancel")

  @require_api_key
- def random_sample(
+ async def random_sample(
  self,
  name: str,
  config: dict,
@@ -556,7 +556,7 @@ class MassStats:
  year_range: list[int] = None,
  overwrite: bool = False,
  server: str = None,
- bucket: str = None,
+ bucket: str = None
  ) -> Dict[str, Any]:
  """
  Submit a random sample job.
@@ -591,18 +591,18 @@ class MassStats:
  "tile_size": tile_size,
  "res": res,
  "output": output,
- "region": region,
  "overwrite": str(overwrite).lower(),
  }
  payload_mapping = {
  "year_range": year_range,
  "server": server,
- "bucket": bucket
+ "region": region,
+ "bucket": bucket,
  }
  for key, value in payload_mapping.items():
  if value is not None:
  payload[key] = value
- return self._client._terrakio_request("POST", "random_sample", json=payload)
+ return await self._client._terrakio_request("POST", "random_sample", json=payload)


  @require_api_key
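`random_sample` is likewise async now, and `region` has moved from the always-sent payload into `payload_mapping`, so it is only included in the request body when a non-None value is passed, the same treatment `year_range`, `server`, and `bucket` receive. A small self-contained sketch of that assembly pattern, using placeholder values:

```python
# Placeholder values; only non-None optional fields end up in the request body.
payload = {"tile_size": 128, "res": 0.001, "output": "netcdf", "overwrite": "true"}
payload_mapping = {"year_range": None, "server": None, "region": "aus", "bucket": None}
for key, value in payload_mapping.items():
    if value is not None:
        payload[key] = value
print(payload)  # region is added; year_range, server and bucket are omitted
```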

terrakio_core/endpoints/model_management.py CHANGED
@@ -37,7 +37,7 @@ class ModelManagement:
  self._client = client

  @require_api_key
- def generate_ai_dataset(
+ async def generate_ai_dataset(
  self,
  name: str,
  aoi_geojson: str,
@@ -51,7 +51,8 @@ class ModelManagement:
  filter_y: str = "skip",
  crs: str = "epsg:4326",
  res: float = 0.001,
- region: str = "aus",
+ region: str = None,
+ bucket: str = None,
  start_year: int = None,
  end_year: int = None,
  ) -> dict:
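`generate_ai_dataset` is now a coroutine, `region` no longer defaults to `"aus"`, and a `bucket` parameter has been added; both are optional and forwarded to the underlying sampling job only when supplied. A hedged call sketch with placeholder arguments; the `model_management` accessor name and the omission of the remaining keyword arguments are assumptions made for illustration:

```python
async def build_dataset(client, name: str, aoi_path: str) -> dict:
    # `client.model_management` is a hypothetical accessor; adjust to however
    # ModelManagement is exposed on your client. The remaining keyword arguments
    # from the signature above are left at their defaults here.
    return await client.model_management.generate_ai_dataset(
        name=name,
        aoi_geojson=aoi_path,
        region="aus",               # formerly the implicit default, now opt-in
        bucket="my-output-bucket",  # hypothetical bucket name, also optional
    )
```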
@@ -71,7 +72,8 @@ class ModelManagement:
  tile_size (int): Size of tiles in degrees
  crs (str, optional): Coordinate reference system. Defaults to "epsg:4326"
  res (float, optional): Resolution in degrees. Defaults to 0.001
- region (str, optional): Region code. Defaults to "aus"
+ region (str, optional): Region code. Defaults to None
+ bucket (str, optional): Bucket name. Defaults to None
  start_year (int, optional): Start year for data generation. Required if end_year provided
  end_year (int, optional): End year for data generation. Required if start_year provided

@@ -109,7 +111,7 @@ class ModelManagement:
  with open(aoi_geojson, 'r') as f:
  aoi_data = json.load(f)

- task_response = self._client.mass_stats.random_sample(
+ task_response = await self._client.mass_stats.random_sample(
  name=name,
  config=config,
  aoi=aoi_data,
@@ -121,14 +123,14 @@ class ModelManagement:
  region=region,
  output="netcdf",
  server=self._client.url,
- bucket="terrakio-mass-requests",
+ bucket=bucket,
  overwrite=True
  )
  task_id = task_response["task_id"]

  # Wait for job completion with progress bar
  while True:
- result = self._client.track_mass_stats_job(ids=[task_id])
+ result = await self._client.mass_stats.track_job(ids=[task_id])
  status = result[task_id]['status']
  completed = result[task_id].get('completed', 0)
  total = result[task_id].get('total', 1)
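Progress tracking now goes through the awaited `mass_stats.track_job` endpoint instead of the old `track_mass_stats_job` client method. A standalone polling sketch following the same access pattern; the terminal status values are assumptions, not taken from this diff:

```python
import asyncio

async def wait_for_job(client, task_id: str, poll_seconds: float = 5.0) -> dict:
    # track_job returns a mapping keyed by task ID with 'status', 'completed'
    # and 'total' fields, as used in the polling loop above.
    while True:
        result = await client.mass_stats.track_job(ids=[task_id])
        info = result[task_id]
        print(f"{info.get('completed', 0)}/{info.get('total', 1)} ({info['status']})")
        if info["status"] not in ("pending", "running"):  # assumed terminal states
            return info
        await asyncio.sleep(poll_seconds)
```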
@@ -153,9 +155,53 @@ class ModelManagement:
  time.sleep(5)

  # after all the random sample jobs are done, we then start the mass stats job
- task_id = self._client.mass_stats.start_mass_stats_job(task_id)
+ # task_id = self._client.mass_stats.start_mass_stats_job(task_id)
+ task_id = await self._client.mass_stats.start_job(task_id)
  return task_id
+ # the folder that is being created is not under the jobs folder, its directly under the UID folder

+ # @require_api_key
+ # async def upload_model(self, model, model_name: str, input_shape: Tuple[int, ...] = None):
+ # """
+ # Upload a model to the bucket so that it can be used for inference.
+ # Converts PyTorch and scikit-learn models to ONNX format before uploading.
+
+ # Args:
+ # model: The model object (PyTorch model or scikit-learn model)
+ # model_name: Name for the model (without extension)
+ # input_shape: Shape of input data for ONNX conversion (e.g., (1, 10) for batch_size=1, features=10)
+ # Required for PyTorch models, optional for scikit-learn models
+
+ # Raises:
+ # APIError: If the API request fails
+ # ValueError: If model type is not supported or input_shape is missing for PyTorch models
+ # ImportError: If required libraries (torch or skl2onnx) are not installed
+ # """
+ # uid = (await self._client.auth.get_user_info())["uid"]
+ # # above line is getting the uid,
+
+ # client = storage.Client()
+ # bucket = client.get_bucket('terrakio-mass-requests')
+
+ # # Convert model to ONNX format
+ # onnx_bytes = self._convert_model_to_onnx(model, model_name, input_shape)
+
+ # # Upload ONNX model to bucket
+ # # blob = bucket.blob(f'{uid}/{model_name}/models/{model_name}.onnx')
+ # # we don't need to upload the model to the bucket
+ # # so the stuff is stored under the virtual datasets folder
+ # # the model name and the virtual dataset name should be the same
+ # virtual_dataset_name = model_name
+ # blob = bucket.blob(f'{uid}/virtual_datasets/{virtual_dataset_name}/{model_name}.onnx')
+ # # wer are uploading the model to the virtual dataset folder
+
+ # blob.upload_from_string(onnx_bytes, content_type='application/octet-stream')
+
+ # self._client.logger.info(f"Model uploaded successfully to {uid}/virtual_datasets/{virtual_dataset_name}/{model_name}.onnx")
+
+ # this is the upload model function, I think we need to upload to the user, under the virutal_datasets folder, and create the virtual dataset
+
+
  @require_api_key
  async def upload_model(self, model, model_name: str, input_shape: Tuple[int, ...] = None):
  """
@@ -351,6 +397,33 @@ class ModelManagement:
  raise ValueError(f"Failed to convert scikit-learn model {model_name} to ONNX: {str(e)}")


+ # we do not need to pass in both the model name and the dataset name, since the model name should the same as the virtual dataset name
+ # but we are gonna have multiple products for the same virtual dataset
+ # @require_api_key
+ # async def upload_and_deploy_cnn_model(self, model, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None, processing_script_path: Optional[str] = None):
+ # """
+ # Upload a CNN model to the bucket and deploy it.
+
+ # Args:
+ # model: The model object (PyTorch model or scikit-learn model)
+ # model_name: Name for the model (without extension)
+ # dataset: Name of the dataset to create
+ # product: Product name for the inference
+ # input_expression: Input expression for the dataset
+ # dates_iso8601: List of dates in ISO8601 format
+ # input_shape: Shape of input data for ONNX conversion (required for PyTorch models)
+ # processing_script_path: Path to the processing script, if not provided, no processing will be done
+
+ # Raises:
+ # APIError: If the API request fails
+ # ValueError: If model type is not supported or input_shape is missing for PyTorch models
+ # ImportError: If required libraries (torch or skl2onnx) are not installed
+ # """
+ # await self.upload_model(model=model, model_name=dataset, input_shape=input_shape)
+ # # so the uploading process is kinda similar, but the deployment step is kinda different
+ # # we should pass the processing script path to the deploy cnn model function
+ # await self.deploy_cnn_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601, processing_script_path=processing_script_path)
+
  @require_api_key
  async def upload_and_deploy_cnn_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None, processing_script_path: Optional[str] = None):
  """
@@ -376,6 +449,7 @@ class ModelManagement:
  # we should pass the processing script path to the deploy cnn model function
  await self.deploy_cnn_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601, processing_script_path=processing_script_path)

+
  @require_api_key
  async def upload_and_deploy_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None):
  """
@@ -617,7 +691,7 @@ class ModelManagement:
  clean_preprocessing = preprocessing_code
  # Then add consistent 8-space indentation to match the template
  preprocessing_section = f"""{textwrap.indent(clean_preprocessing, '')}""" # 8 spaces
- print(preprocessing_section)
+ # print(preprocessing_section)
  script_content = self.generate_cnn_script(model_name, product, model_training_job_name, uid, preprocessing_code, postprocessing_code)
  script_name = f"{product}.py"
  self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
@@ -748,6 +822,245 @@ class ModelManagement:
  return result
  ''').strip()

+ # @require_api_key
+ # def generate_cnn_script(self, model_name: str, product: str, model_training_job_name: str, uid: str, preprocessing_code: Optional[str] = None, postprocessing_code: Optional[str] = None) -> str:
+ # """
+ # Generate Python inference script for CNN model with time-stacked bands.
+
+ # Args:
+ # model_name: Name of the model
+ # product: Product name
+ # model_training_job_name: Training job name
+ # uid: User ID
+ # preprocessing_code: Preprocessing code
+ # postprocessing_code: Postprocessing code
+ # Returns:
+ # str: Generated Python script content
+ # """
+ # import textwrap
+
+ # # Build preprocessing section with CONSISTENT 4-space indentation
+ # preprocessing_section = ""
+ # if preprocessing_code and preprocessing_code.strip():
+ # clean_preprocessing = textwrap.dedent(preprocessing_code)
+ # preprocessing_section = textwrap.indent(clean_preprocessing, ' ')
+
+ # # Build postprocessing section with CONSISTENT 4-space indentation
+ # postprocessing_section = ""
+ # if postprocessing_code and postprocessing_code.strip():
+ # clean_postprocessing = textwrap.dedent(postprocessing_code)
+ # postprocessing_section = textwrap.indent(clean_postprocessing, ' ')
+
+ # # Build the template WITHOUT dedenting the whole thing, so indentation is preserved
+ # script_lines = [
+ # "import logging",
+ # "from io import BytesIO",
+ # "import numpy as np",
+ # "import pandas as pd",
+ # "import xarray as xr",
+ # "from google.cloud import storage",
+ # "from onnxruntime import InferenceSession",
+ # "from typing import Tuple",
+ # "",
+ # "logging.basicConfig(",
+ # "    level=logging.INFO",
+ # ")",
+ # "",
+ # ]
+
+ # # Add preprocessing function definition BEFORE the main function
+ # if preprocessing_section:
+ # script_lines.extend([
+ # "def preprocessing(array: Tuple[xr.DataArray, ...]) -> Tuple[xr.DataArray, ...]:",
+ # preprocessing_section,
+ # "",
+ # ])
+
+ # # Add postprocessing function definition BEFORE the main function
+ # if postprocessing_section:
+ # script_lines.extend([
+ # "def postprocessing(array: xr.DataArray) -> xr.DataArray:",
+ # postprocessing_section,
+ # "",
+ # ])
+
+ # # Add the get_model function
+ # script_lines.extend([
+ # "def get_model():",
+ # f" logging.info(\"Loading CNN model for {model_name}...\")",
+ # "",
+ # " client = storage.Client()",
+ # " bucket = client.get_bucket('terrakio-mass-requests')",
+ # f" blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')",
+ # "",
+ # " model = BytesIO()",
+ # " blob.download_to_file(model)",
+ # " model.seek(0)",
+ # "",
+ # " session = InferenceSession(model.read(), providers=[\"CPUExecutionProvider\"])",
+ # " return session",
+ # "",
+ # f"def {product}(*bands, model):",
+ # " logging.info(\"Start preparing CNN data with time-stacked bands\")",
+ # " data_arrays = list(bands)",
+ # " ",
+ # " if not data_arrays:",
+ # " raise ValueError(\"No bands provided\")",
+ # " ",
+ # ])
+
+ # # Add preprocessing call if preprocessing exists
+ # if preprocessing_section:
+ # script_lines.extend([
+ # " # Apply preprocessing",
+ # " data_arrays = preprocessing(tuple(data_arrays))",
+ # " data_arrays = list(data_arrays) # Convert back to list for processing",
+ # " ",
+ # ])
+
+ # # Continue with the rest of the processing logic
+ # script_lines.extend([
+ # " reference_array = data_arrays[0]",
+ # " original_shape = reference_array.shape",
+ # " logging.info(f\"Original shape: {original_shape}\")",
+ # " ",
+ # " # Get time coordinates - all bands should have the same time dimension",
+ # " if 'time' not in reference_array.dims:",
+ # " raise ValueError(\"Time dimension is required for CNN processing\")",
+ # " ",
+ # " time_coords = reference_array.coords['time']",
+ # " num_timestamps = len(time_coords)",
+ # " logging.info(f\"Number of timestamps: {num_timestamps}\")",
+ # " ",
+ # " # Get spatial dimensions",
+ # " spatial_dims = [dim for dim in reference_array.dims if dim != 'time']",
+ # " height = reference_array.sizes[spatial_dims[0]] # assuming first spatial dim is height",
+ # " width = reference_array.sizes[spatial_dims[1]] # assuming second spatial dim is width",
+ # " logging.info(f\"Spatial dimensions: {height} x {width}\")",
+ # " ",
+ # " # Stack bands across time dimension",
+ # " # Result will be: (num_bands * num_timestamps, height, width)",
+ # " stacked_channels = []",
+ # " ",
+ # " for band_idx, data_array in enumerate(data_arrays):",
+ # " logging.info(f\"Processing band {band_idx + 1}/{len(data_arrays)}\")",
+ # " ",
+ # " # Ensure consistent time coordinates across bands",
+ # " if not np.array_equal(data_array.coords['time'].values, time_coords.values):",
+ # " logging.warning(f\"Band {band_idx} has different time coordinates, aligning...\")",
+ # " data_array = data_array.sel(time=time_coords, method='nearest')",
+ # " ",
+ # " # Extract values and ensure proper ordering (time, height, width)",
+ # " band_values = data_array.values",
+ # " if band_values.ndim == 3:",
+ # " # Reorder dimensions if needed to ensure (time, height, width)",
+ # " time_dim_idx = data_array.dims.index('time')",
+ # " if time_dim_idx != 0:",
+ # " axes_order = [time_dim_idx] + [i for i in range(len(data_array.dims)) if i != time_dim_idx]",
+ # " band_values = np.transpose(band_values, axes_order)",
+ # " ",
+ # " # Add each timestamp of this band to the channel stack",
+ # " for t in range(num_timestamps):",
+ # " stacked_channels.append(band_values[t])",
+ # " ",
+ # " # Stack all channels: (num_bands * num_timestamps, height, width)",
+ # " input_channels = np.stack(stacked_channels, axis=0)",
+ # " total_channels = len(data_arrays) * num_timestamps",
+ # " logging.info(f\"Stacked channels shape: {input_channels.shape}\")",
+ # " logging.info(f\"Total channels: {total_channels} ({len(data_arrays)} bands × {num_timestamps} timestamps)\")",
+ # " ",
+ # " # Add batch dimension: (1, num_channels, height, width)",
+ # " input_data = np.expand_dims(input_channels, axis=0).astype(np.float32)",
+ # " logging.info(f\"Final input shape for CNN: {input_data.shape}\")",
+ # " ",
+ # " # Run inference",
+ # " output = model.run(None, {\"float_input\": input_data})[0]",
+ # " logging.info(f\"Model output shape: {output.shape}\")",
+ # " ",
+ # " # UPDATED: Handle multi-class CNN output properly",
+ # " if output.ndim == 4:",
+ # " if output.shape[1] == 1:",
+ # " # Single class output (regression or binary classification)",
+ # " output_2d = output[0, 0]",
+ # " logging.info(\"Single channel output detected\")",
+ # " else:",
+ # " # Multi-class output - convert logits/probabilities to class predictions",
+ # " output_classes = np.argmax(output, axis=1) # Shape: (1, height, width)",
+ # " output_2d = output_classes[0] # Shape: (height, width)",
+ # " ",
+ # " # Apply class merging: merge class 6 into class 3",
+ # " output_2d = np.where(output_2d == 6, 3, output_2d)",
+ # " ",
+ # " logging.info(f\"Multi-class output processed. Original classes: {output.shape[1]}\")",
+ # " logging.info(f\"Unique classes in output: {np.unique(output_2d)}\")",
+ # " logging.info(f\"Class distribution: {np.bincount(output_2d.flatten())}\")",
+ # " elif output.ndim == 3:",
+ # " # Remove batch dimension",
+ # " output_2d = output[0]",
+ # " logging.info(\"3D output detected, removed batch dimension\")",
+ # " else:",
+ # " # Handle other cases",
+ # " output_2d = np.squeeze(output)",
+ # " if output_2d.ndim != 2:",
+ # " logging.error(f\"Cannot process output shape: {output.shape}\")",
+ # " logging.error(f\"After squeeze: {output_2d.shape}\")",
+ # " raise ValueError(f\"Unexpected output shape after processing: {output_2d.shape}\")",
+ # " logging.info(\"Applied squeeze to output\")",
+ # " ",
+ # " # Ensure output is 2D",
+ # " if output_2d.ndim != 2:",
+ # " raise ValueError(f\"Final output must be 2D, got shape: {output_2d.shape}\")",
+ # " ",
+ # " # Determine output timestamp (use the latest timestamp)",
+ # " output_timestamp = time_coords[-1]",
+ # " ",
+ # " # Get spatial coordinates from reference array",
+ # " spatial_coords = {dim: reference_array.coords[dim] for dim in spatial_dims}",
+ # " ",
+ # " # Create output DataArray with appropriate data type",
+ # " # Use int32 for classification, float32 for regression",
+ # " is_multiclass = output.ndim == 4 and output.shape[1] > 1",
+ # " if is_multiclass:",
+ # " # Multi-class classification - use integer type",
+ # " output_dtype = np.int32",
+ # " output_type = 'classification'",
+ # " else:",
+ # " # Single output - use float type",
+ # " output_dtype = np.float32",
+ # " output_type = 'regression'",
+ # " ",
+ # " result = xr.DataArray(",
+ # " data=np.expand_dims(output_2d.astype(output_dtype), axis=0),",
+ # " dims=['time'] + spatial_dims,",
+ # " coords={",
+ # " 'time': [output_timestamp.values],",
+ # " spatial_dims[0]: spatial_coords[spatial_dims[0]].values,",
+ # " spatial_dims[1]: spatial_coords[spatial_dims[1]].values",
+ # " },",
+ # " attrs={",
+ # " 'description': 'CNN model prediction',",
+ # " }",
+ # " )",
+ # " ",
+ # " logging.info(f\"Final result shape: {result.shape}\")",
+ # " logging.info(f\"Final result data type: {result.dtype}\")",
+ # " logging.info(f\"Final result value range: {result.values.min()} to {result.values.max()}\")",
+ # ])
+
+ # # Add postprocessing call if postprocessing exists
+ # if postprocessing_section:
+ # script_lines.extend([
+ # " # Apply postprocessing",
+ # " result = postprocessing(result)",
+ # " ",
+ # ])
+
+ # # Single return statement at the end
+ # script_lines.append(" return result")
+
+ # return "\n".join(script_lines)
+
+
  @require_api_key
  def generate_cnn_script(self, model_name: str, product: str, model_training_job_name: str, uid: str, preprocessing_code: Optional[str] = None, postprocessing_code: Optional[str] = None) -> str:
  """
@@ -794,6 +1107,160 @@ class ModelManagement:
  "",
  ]

+ # Add preprocessing validation function if preprocessing exists
+ if preprocessing_section:
+ script_lines.extend([
+ "def validate_preprocessing_output(data_arrays):",
+ " \"\"\"",
+ " Validate preprocessing output coordinates and data type.",
+ " ",
+ " Args:",
+ " data_arrays: List of xarray DataArrays from preprocessing",
+ " ",
+ " Returns:",
+ " str: Validation signature symbol",
+ " ",
+ " Raises:",
+ " ValueError: If validation fails",
+ " \"\"\"",
+ " import numpy as np",
+ " ",
+ " logging.info(\"=\" * 60)",
+ " logging.info(\"VALIDATING PREPROCESSING OUTPUT\")",
+ " logging.info(\"=\" * 60)",
+ " ",
+ " if not data_arrays:",
+ " raise ValueError(\"No data arrays provided from preprocessing\")",
+ " ",
+ " reference_shape = None",
+ " ",
+ " for i, data_array in enumerate(data_arrays):",
+ " logging.info(f\"Validating channel {i+1}/{len(data_arrays)}: {data_array.name}\")",
+ " ",
+ " # Check if it's an xarray DataArray",
+ " if not hasattr(data_array, 'dims') or not hasattr(data_array, 'coords'):",
+ " raise ValueError(f\"Channel {i+1} is not a valid xarray DataArray\")",
+ " ",
+ " # Check coordinates",
+ " if 'time' not in data_array.coords:",
+ " raise ValueError(f\"Channel {i+1} missing time coordinate\")",
+ " ",
+ " spatial_dims = [dim for dim in data_array.dims if dim != 'time']",
+ " if len(spatial_dims) != 2:",
+ " raise ValueError(f\"Channel {i+1} must have exactly 2 spatial dimensions, got {spatial_dims}\")",
+ " ",
+ " for dim in spatial_dims:",
+ " if dim not in data_array.coords:",
+ " raise ValueError(f\"Channel {i+1} missing coordinate: {dim}\")",
+ " ",
+ " logging.info(f\" Coordinates: {list(data_array.coords.keys())}\")",
+ " ",
+ " # Check data type",
+ " data_values = data_array.values",
+ " logging.info(f\" Data type: {data_values.dtype}\")",
+ " ",
+ " # Check shape consistency",
+ " shape = data_array.shape",
+ " if reference_shape is None:",
+ " reference_shape = shape",
+ " else:",
+ " if shape != reference_shape:",
+ " raise ValueError(f\"Channel {i+1} shape {shape} doesn't match reference {reference_shape}\")",
+ " ",
+ " logging.info(f\" Shape: {shape}\")",
+ " ",
+ " # Generate validation signature",
+ " signature_components = [",
+ " f\"CH{len(data_arrays)}\", # Channel count",
+ " f\"T{reference_shape[0]}\", # Time dimension",
+ " f\"S{reference_shape[1]}x{reference_shape[2]}\", # Spatial dimensions",
+ " f\"DT{data_arrays[0].values.dtype}\", # Data type",
+ " ]",
+ " ",
+ " signature = \"★PRE_\" + \"_\".join(signature_components) + \"★\"",
+ " ",
+ " logging.info(\"-\" * 60)",
+ " logging.info(\"PREPROCESSING VALIDATION SUMMARY\")",
+ " logging.info(\"-\" * 60)",
+ " logging.info(f\"Channels validated: {len(data_arrays)}\")",
+ " logging.info(f\"Common shape: {reference_shape}\")",
+ " logging.info(f\"Validation signature: {signature}\")",
+ " logging.info(\"=\" * 60)",
+ " ",
+ " return signature",
+ "",
+ ])
+
+ # Add postprocessing validation function if postprocessing exists
+ if postprocessing_section:
+ script_lines.extend([
+ "def validate_postprocessing_output(result_array):",
+ " \"\"\"",
+ " Validate postprocessing output coordinates and data type.",
+ " ",
+ " Args:",
+ " result_array: xarray DataArray from postprocessing",
+ " ",
+ " Returns:",
+ " str: Validation signature symbol",
+ " ",
+ " Raises:",
+ " ValueError: If validation fails",
+ " \"\"\"",
+ " import numpy as np",
+ " ",
+ " logging.info(\"=\" * 60)",
+ " logging.info(\"VALIDATING POSTPROCESSING OUTPUT\")",
+ " logging.info(\"=\" * 60)",
+ " ",
+ " # Check if it's an xarray DataArray",
+ " if not hasattr(result_array, 'dims') or not hasattr(result_array, 'coords'):",
+ " raise ValueError(\"Postprocessing output is not a valid xarray DataArray\")",
+ " ",
+ " # Check required coordinates",
+ " if 'time' not in result_array.coords:",
+ " raise ValueError(\"Missing time coordinate\")",
+ " ",
+ " spatial_dims = [dim for dim in result_array.dims if dim != 'time']",
+ " if len(spatial_dims) != 2:",
+ " raise ValueError(f\"Expected 2 spatial dimensions, got {len(spatial_dims)}: {spatial_dims}\")",
+ " ",
+ " for dim in spatial_dims:",
+ " if dim not in result_array.coords:",
+ " raise ValueError(f\"Missing spatial coordinate: {dim}\")",
+ " ",
+ " logging.info(f\"Coordinates found: {list(result_array.coords.keys())}\")",
+ " ",
+ " # Check data type",
+ " data_values = result_array.values",
+ " logging.info(f\"Data type: {data_values.dtype}\")",
+ " ",
+ " # Check shape",
+ " shape = result_array.shape",
+ " logging.info(f\"Shape: {shape}\")",
+ " ",
+ " # Generate validation signature",
+ " signature_components = [",
+ " f\"T{shape[0]}\", # Time dimension",
+ " f\"S{shape[1]}x{shape[2]}\", # Spatial dimensions",
+ " f\"DT{data_values.dtype}\", # Data type",
+ " ]",
+ " ",
+ " signature = \"★POST_\" + \"_\".join(signature_components) + \"★\"",
+ " ",
+ " logging.info(\"-\" * 60)",
+ " logging.info(\"POSTPROCESSING VALIDATION SUMMARY\")",
+ " logging.info(\"-\" * 60)",
+ " logging.info(f\"Final shape: {shape}\")",
+ " logging.info(f\"Final coordinates: {list(result_array.coords.keys())}\")",
+ " logging.info(f\"Data type: {data_values.dtype}\")",
+ " logging.info(f\"Validation signature: {signature}\")",
+ " logging.info(\"=\" * 60)",
+ " ",
+ " return signature",
+ "",
+ ])
+
  # Add preprocessing function definition BEFORE the main function
  if preprocessing_section:
  script_lines.extend([
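The generated inference script now embeds `validate_preprocessing_output` and `validate_postprocessing_output`, each of which builds a compact signature string from channel count, time length, spatial size, and dtype. A standalone sketch of that signature construction, with illustrative values:

```python
# Example values only; the generated script derives these from the actual arrays.
num_channels, time_len, height, width, dtype = 3, 4, 128, 128, "float32"
signature_components = [
    f"CH{num_channels}",      # channel count
    f"T{time_len}",           # time dimension
    f"S{height}x{width}",     # spatial dimensions
    f"DT{dtype}",             # data type
]
signature = "★PRE_" + "_".join(signature_components) + "★"
print(signature)  # -> ★PRE_CH3_T4_S128x128_DTfloat32★
```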
@@ -835,13 +1302,17 @@ class ModelManagement:
  " ",
  ])

- # Add preprocessing call if preprocessing exists
+ # Add preprocessing call and validation if preprocessing exists
  if preprocessing_section:
  script_lines.extend([
  " # Apply preprocessing",
  " data_arrays = preprocessing(tuple(data_arrays))",
  " data_arrays = list(data_arrays) # Convert back to list for processing",
  " ",
+ " # Validate preprocessing output",
+ " preprocessing_signature = validate_preprocessing_output(data_arrays)",
+ " logging.info(f\"Preprocessing validation signature: {preprocessing_signature}\")",
+ " ",
  ])

  # Continue with the rest of the processing logic
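The hunk below replaces the old single-channel assumption in the generated script with explicit multi-class handling: 4-D outputs are argmax-ed over the class axis, class 6 is merged into class 3, and extra diagnostics are logged. A self-contained sketch of that step, using a dummy logits array (shape and class count are illustrative):

```python
import numpy as np

# Dummy (1, num_classes, height, width) logits; real values come from ONNX inference.
output = np.random.rand(1, 7, 4, 4).astype(np.float32)

output_classes = np.argmax(output, axis=1)           # (1, height, width) class indices
output_2d = output_classes[0]                        # (height, width)
output_2d = np.where(output_2d == 6, 3, output_2d)   # merge class 6 into class 3
print(np.unique(output_2d))
```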
@@ -903,19 +1374,39 @@ class ModelManagement:
  " output = model.run(None, {\"float_input\": input_data})[0]",
  " logging.info(f\"Model output shape: {output.shape}\")",
  " ",
- " # Process output back to xarray format",
- " # Assuming output is (1, height, width) or (1, 1, height, width)",
- " if output.ndim == 4 and output.shape[1] == 1:",
- " # Remove channel dimension if it's 1",
- " output_2d = output[0, 0]",
+ " # UPDATED: Handle multi-class CNN output properly",
+ " if output.ndim == 4:",
+ " if output.shape[1] == 1:",
+ " # Single class output (regression or binary classification)",
+ " output_2d = output[0, 0]",
+ " logging.info(\"Single channel output detected\")",
+ " else:",
+ " # Multi-class output - convert logits/probabilities to class predictions",
+ " output_classes = np.argmax(output, axis=1) # Shape: (1, height, width)",
+ " output_2d = output_classes[0] # Shape: (height, width)",
+ " ",
+ " # Apply class merging: merge class 6 into class 3",
+ " output_2d = np.where(output_2d == 6, 3, output_2d)",
+ " ",
+ " logging.info(f\"Multi-class output processed. Original classes: {output.shape[1]}\")",
+ " logging.info(f\"Unique classes in output: {np.unique(output_2d)}\")",
+ " logging.info(f\"Class distribution: {np.bincount(output_2d.flatten())}\")",
  " elif output.ndim == 3:",
  " # Remove batch dimension",
  " output_2d = output[0]",
+ " logging.info(\"3D output detected, removed batch dimension\")",
  " else:",
  " # Handle other cases",
  " output_2d = np.squeeze(output)",
  " if output_2d.ndim != 2:",
+ " logging.error(f\"Cannot process output shape: {output.shape}\")",
+ " logging.error(f\"After squeeze: {output_2d.shape}\")",
  " raise ValueError(f\"Unexpected output shape after processing: {output_2d.shape}\")",
+ " logging.info(\"Applied squeeze to output\")",
+ " ",
+ " # Ensure output is 2D",
+ " if output_2d.ndim != 2:",
+ " raise ValueError(f\"Final output must be 2D, got shape: {output_2d.shape}\")",
  " ",
  " # Determine output timestamp (use the latest timestamp)",
  " output_timestamp = time_coords[-1]",
@@ -923,33 +1414,53 @@ class ModelManagement:
  " # Get spatial coordinates from reference array",
  " spatial_coords = {dim: reference_array.coords[dim] for dim in spatial_dims}",
  " ",
- " # Create output DataArray",
+ " # Create output DataArray with appropriate data type",
+ " # Use int32 for classification, float32 for regression",
+ " is_multiclass = output.ndim == 4 and output.shape[1] > 1",
+ " if is_multiclass:",
+ " # Multi-class classification - use integer type",
+ " output_dtype = np.int32",
+ " output_type = 'classification'",
+ " else:",
+ " # Single output - use float type",
+ " output_dtype = np.float32",
+ " output_type = 'regression'",
+ " ",
  " result = xr.DataArray(",
- " data=np.expand_dims(output_2d.astype(np.float32), axis=0),",
+ " data=np.expand_dims(output_2d.astype(output_dtype), axis=0),",
  " dims=['time'] + spatial_dims,",
  " coords={",
  " 'time': [output_timestamp.values],",
  " spatial_dims[0]: spatial_coords[spatial_dims[0]].values,",
  " spatial_dims[1]: spatial_coords[spatial_dims[1]].values",
+ " },",
+ " attrs={",
+ " 'description': 'CNN model prediction',",
  " }",
  " )",
  " ",
  " logging.info(f\"Final result shape: {result.shape}\")",
+ " logging.info(f\"Final result data type: {result.dtype}\")",
+ " logging.info(f\"Final result value range: {result.values.min()} to {result.values.max()}\")",
  ])

- # Add postprocessing call if postprocessing exists
+ # Add postprocessing call and validation if postprocessing exists
  if postprocessing_section:
  script_lines.extend([
  " # Apply postprocessing",
  " result = postprocessing(result)",
  " ",
+ " # Validate postprocessing output",
+ " postprocessing_signature = validate_postprocessing_output(result)",
+ " logging.info(f\"Postprocessing validation signature: {postprocessing_signature}\")",
+ " ",
  ])

  # Single return statement at the end
  script_lines.append(" return result")

  return "\n".join(script_lines)
-
+
  @require_api_key
  def _upload_script_to_bucket(self, script_content: str, script_name: str, model_training_job_name: str, uid: str):
  """Upload the generated script to Google Cloud Storage"""

terrakio_core-0.4.0.dist-info/METADATA → terrakio_core-0.4.3.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: terrakio-core
- Version: 0.4.0
+ Version: 0.4.3
  Summary: Core components for Terrakio API clients
  Author-email: Yupeng Chao <yupeng@haizea.com.au>
  Project-URL: Homepage, https://github.com/HaizeaAnalytics/terrakio-python-api

terrakio_core-0.4.0.dist-info/RECORD → terrakio_core-0.4.3.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
- terrakio_core/__init__.py,sha256=-9YH041kcLF7sIqWy3V0rtFc7q14fVXx6JFK0U5Ey6s,248
- terrakio_core/async_client.py,sha256=T0AEyhh5EwZ2_okRiHuXkLgCqovxl-fuQSejH7hXU0I,13801
+ terrakio_core/__init__.py,sha256=GDUYT7cudfxltFZZ-2_5xM4Pd7o-pkQx1O8K9HrE5f0,248
+ terrakio_core/async_client.py,sha256=ffLKMbzclHfyBZBJn4uR8nlMW5E-96PX-4O5EghaSPE,13805
  terrakio_core/client.py,sha256=-tGffOKGMyuowsvBwaV7Wtc_EZSWuSwv26_I5FkUank,5446
  terrakio_core/config.py,sha256=r8NARVYOca4AuM88VP_j-8wQxOk1s7VcRdyEdseBlLE,4193
  terrakio_core/exceptions.py,sha256=4qnpOM1gOxsNIXDXY4qwY1d3I4Myhp7HBh7b2D0SVrU,529
@@ -8,14 +8,14 @@ terrakio_core/convenience_functions/convenience_functions.py,sha256=sBY2g7Vv3jak
  terrakio_core/endpoints/auth.py,sha256=e_hdNE6JOGhRVlQMFdEoOmoMHp5EzK6CclOEnc_AmZw,5863
  terrakio_core/endpoints/dataset_management.py,sha256=BUm8IIlW_Q45vDiQp16CiJGeSLheI8uWRVRQtMdhaNk,13161
  terrakio_core/endpoints/group_management.py,sha256=VFl3jakjQa9OPi351D3DZvLU9M7fHdfjCzGhmyJsx3U,6309
- terrakio_core/endpoints/mass_stats.py,sha256=yhLCYRrdQPiWwJVCIPbzU5NV3xU5m62pxhYY1FucYjI,23130
- terrakio_core/endpoints/model_management.py,sha256=uzyIHCRgyOwaQFConO0Ur6C0bnMdj4VDpyjiMG8R1Mc,42303
+ terrakio_core/endpoints/mass_stats.py,sha256=IZEozQ9GyOmUhd7V8M66Bz2OWsyq-VKzOw_sj_i-dng,23154
+ terrakio_core/endpoints/model_management.py,sha256=PF-2f6mW_RtkPWSL-N56kbgLeB6Z4EhU2N2qFqPan7o,70365
  terrakio_core/endpoints/space_management.py,sha256=YWb55nkJnFJGlALJ520DvurxDqVqwYtsvqQPWzxzhDs,2266
  terrakio_core/endpoints/user_management.py,sha256=WlFr3EfK8iI6DfkpMuYLHZUPk2n7_DHHO6z1hndmZB4,3816
  terrakio_core/helper/bounded_taskgroup.py,sha256=wiTH10jhKZgrsgrFUNG6gig8bFkUEPHkGRT2XY7Rgmo,677
  terrakio_core/helper/decorators.py,sha256=L6om7wmWNgCei3Wy5U0aZ-70OzsCwclkjIf7SfQuhCg,2289
  terrakio_core/helper/tiles.py,sha256=xNtp3oDD912PN_FQV5fb6uQYhwfHANuXyIcxoVCCfZU,2632
- terrakio_core-0.4.0.dist-info/METADATA,sha256=ctBxSZybuLE-4Mwh0-FTVZKCCkfXZl-jTRfIqth5bKc,1756
- terrakio_core-0.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- terrakio_core-0.4.0.dist-info/top_level.txt,sha256=5cBj6O7rNWyn97ND4YuvvXm0Crv4RxttT4JZvNdOG6Q,14
- terrakio_core-0.4.0.dist-info/RECORD,,
+ terrakio_core-0.4.3.dist-info/METADATA,sha256=201kTjM26SSTmtRPBohs9GlBsP_3U3tUw6ja2X0D6uM,1756
+ terrakio_core-0.4.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ terrakio_core-0.4.3.dist-info/top_level.txt,sha256=5cBj6O7rNWyn97ND4YuvvXm0Crv4RxttT4JZvNdOG6Q,14
+ terrakio_core-0.4.3.dist-info/RECORD,,