terrakio-core 0.4.2__tar.gz → 0.4.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of terrakio-core might be problematic.
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/PKG-INFO +1 -1
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/pyproject.toml +1 -1
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core/__init__.py +1 -1
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core/async_client.py +1 -1
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core/endpoints/mass_stats.py +7 -7
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core/endpoints/model_management.py +486 -11
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core.egg-info/PKG-INFO +1 -1
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/README.md +0 -0
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/setup.cfg +0 -0
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core/client.py +0 -0
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core/config.py +0 -0
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core/convenience_functions/convenience_functions.py +0 -0
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core/endpoints/auth.py +0 -0
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core/endpoints/dataset_management.py +0 -0
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core/endpoints/group_management.py +0 -0
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core/endpoints/space_management.py +0 -0
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core/endpoints/user_management.py +0 -0
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core/exceptions.py +0 -0
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core/helper/bounded_taskgroup.py +0 -0
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core/helper/decorators.py +0 -0
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core/helper/tiles.py +0 -0
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core/sync_client.py +0 -0
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core.egg-info/SOURCES.txt +0 -0
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core.egg-info/dependency_links.txt +0 -0
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core.egg-info/requires.txt +0 -0
- {terrakio_core-0.4.2 → terrakio_core-0.4.3}/terrakio_core.egg-info/top_level.txt +0 -0
@@ -47,7 +47,7 @@ class AsyncClient(BaseClient):
         return await self._make_request_with_retry(self._session, method, endpoint, **kwargs)

     async def _make_request_with_retry(self, session: aiohttp.ClientSession, method: str, endpoint: str, **kwargs) -> Dict[Any, Any]:
-        url = f"{self.url}/{endpoint.lstrip('/')}"
+        url = f"{self.url}/{endpoint.lstrip('/')}"
         last_exception = None

         for attempt in range(self.retry + 1):
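For orientation, a minimal sketch of the retry pattern that _make_request_with_retry appears to implement, based only on the context lines above; the backoff, exception handling and default retry count are assumptions, not taken from the package:

import asyncio
import aiohttp

async def make_request_with_retry(session: aiohttp.ClientSession, url: str, method: str = "GET", retry: int = 3, **kwargs) -> dict:
    # Try the request up to retry + 1 times, remembering the last failure.
    last_exception = None
    for attempt in range(retry + 1):
        try:
            async with session.request(method, url, **kwargs) as resp:
                resp.raise_for_status()
                return await resp.json()
        except aiohttp.ClientError as exc:
            last_exception = exc
            await asyncio.sleep(2 ** attempt)  # simple exponential backoff (assumed, not from the diff)
    raise last_exception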
@@ -67,7 +67,7 @@ class MassStats:


     @require_api_key
-    def start_job(self, id: str) -> Dict[str, Any]:
+    async def start_job(self, id: str) -> Dict[str, Any]:
         """
         Start a mass stats job by task ID.

@@ -78,7 +78,7 @@ class MassStats:
             API response as a dictionary

         """
-        return self._client._terrakio_request("POST", f"mass_stats/start/{id}")
+        return await self._client._terrakio_request("POST", f"mass_stats/start/{id}")

     @require_api_key
     def get_task_id(self, name: str, stage: str, uid: Optional[str] = None) -> Dict[str, Any]:
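Because start_job is now a coroutine in 0.4.3, callers must await it. A hypothetical usage sketch; the AsyncClient constructor arguments and the mass_stats attribute name are assumed from context, not verified against the package:

import asyncio
from terrakio_core.async_client import AsyncClient  # module path taken from the file list above

async def main():
    client = AsyncClient()  # auth/config arguments omitted; see the package documentation
    response = await client.mass_stats.start_job("my-task-id")
    print(response)

asyncio.run(main())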
@@ -542,7 +542,7 @@ class MassStats:
         return self._client._terrakio_request("POST", "mass_stats/cancel")

     @require_api_key
-    def random_sample(
+    async def random_sample(
         self,
         name: str,
         config: dict,
@@ -556,7 +556,7 @@ class MassStats:
         year_range: list[int] = None,
         overwrite: bool = False,
         server: str = None,
-        bucket: str = None
+        bucket: str = None
     ) -> Dict[str, Any]:
         """
         Submit a random sample job.
@@ -591,18 +591,18 @@ class MassStats:
             "tile_size": tile_size,
             "res": res,
             "output": output,
-            "region": region,
             "overwrite": str(overwrite).lower(),
         }
         payload_mapping = {
             "year_range": year_range,
             "server": server,
-            "
+            "region": region,
+            "bucket": bucket,
         }
         for key, value in payload_mapping.items():
             if value is not None:
                 payload[key] = value
-        return self._client._terrakio_request("POST", "random_sample", json=payload)
+        return await self._client._terrakio_request("POST", "random_sample", json=payload)


     @require_api_key
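The hunk above moves region and bucket into payload_mapping, so they are sent only when set rather than always. A standalone sketch of that optional-field pattern; the function and field names here are illustrative, not part of the package:

def build_payload(name: str, region: str = None, bucket: str = None, year_range: list = None) -> dict:
    # Mandatory fields are always present; optional ones are added only when not None,
    # mirroring the payload / payload_mapping split in random_sample above.
    payload = {"name": name}
    optional = {"region": region, "bucket": bucket, "year_range": year_range}
    for key, value in optional.items():
        if value is not None:
            payload[key] = value
    return payload

print(build_payload("demo", bucket="my-bucket"))  # {'name': 'demo', 'bucket': 'my-bucket'}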
@@ -37,7 +37,7 @@ class ModelManagement:
         self._client = client

     @require_api_key
-    def generate_ai_dataset(
+    async def generate_ai_dataset(
         self,
         name: str,
         aoi_geojson: str,
@@ -51,7 +51,8 @@ class ModelManagement:
         filter_y: str = "skip",
         crs: str = "epsg:4326",
         res: float = 0.001,
-        region: str =
+        region: str = None,
+        bucket: str = None,
         start_year: int = None,
         end_year: int = None,
     ) -> dict:
@@ -71,7 +72,8 @@ class ModelManagement:
             tile_size (int): Size of tiles in degrees
             crs (str, optional): Coordinate reference system. Defaults to "epsg:4326"
             res (float, optional): Resolution in degrees. Defaults to 0.001
-            region (str, optional): Region code. Defaults to
+            region (str, optional): Region code. Defaults to None
+            bucket (str, optional): Bucket name. Defaults to None
             start_year (int, optional): Start year for data generation. Required if end_year provided
             end_year (int, optional): End year for data generation. Required if start_year provided

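With region and bucket now defaulting to None, a call to generate_ai_dataset might look like the following hypothetical sketch (inside an async function); the client attribute name and values are illustrative, and other required parameters from the full signature are omitted:

result = await client.model_management.generate_ai_dataset(
    name="my_training_set",
    aoi_geojson="aoi.geojson",
    region=None,            # optional in 0.4.3, defaults to None
    bucket="my-bucket",     # new optional parameter in 0.4.3
    start_year=2020,
    end_year=2022,
)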
@@ -109,7 +111,7 @@ class ModelManagement:
         with open(aoi_geojson, 'r') as f:
             aoi_data = json.load(f)

-        task_response = self._client.mass_stats.random_sample(
+        task_response = await self._client.mass_stats.random_sample(
             name=name,
             config=config,
             aoi=aoi_data,
@@ -121,14 +123,14 @@ class ModelManagement:
             region=region,
             output="netcdf",
             server=self._client.url,
-            bucket=
+            bucket=bucket,
             overwrite=True
         )
         task_id = task_response["task_id"]

         # Wait for job completion with progress bar
         while True:
-            result = self._client.
+            result = await self._client.mass_stats.track_job(ids=[task_id])
             status = result[task_id]['status']
             completed = result[task_id].get('completed', 0)
             total = result[task_id].get('total', 1)
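The loop above polls mass_stats.track_job until the job finishes (the package itself sleeps with time.sleep(5) between polls). A self-contained sketch of the same polling pattern using asyncio.sleep; the terminal status strings are assumptions:

import asyncio

async def wait_for_job(client, task_id: str, poll_seconds: int = 5) -> None:
    # Poll the mass-stats tracker until the job reports completion.
    # The result structure ({task_id: {"status", "completed", "total"}}) follows the diff above.
    while True:
        result = await client.mass_stats.track_job(ids=[task_id])
        status = result[task_id]["status"]
        completed = result[task_id].get("completed", 0)
        total = result[task_id].get("total", 1)
        print(f"{status}: {completed}/{total}")
        if status.lower() in {"completed", "failed", "error"}:  # assumed terminal states
            break
        await asyncio.sleep(poll_seconds)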
@@ -153,9 +155,53 @@ class ModelManagement:
             time.sleep(5)

         # after all the random sample jobs are done, we then start the mass stats job
-        task_id = self._client.mass_stats.start_mass_stats_job(task_id)
+        # task_id = self._client.mass_stats.start_mass_stats_job(task_id)
+        task_id = await self._client.mass_stats.start_job(task_id)
         return task_id
+        # the folder that is being created is not under the jobs folder, its directly under the UID folder

+    # @require_api_key
+    # async def upload_model(self, model, model_name: str, input_shape: Tuple[int, ...] = None):
+    # """
+    # Upload a model to the bucket so that it can be used for inference.
+    # Converts PyTorch and scikit-learn models to ONNX format before uploading.
+
+    # Args:
+    # model: The model object (PyTorch model or scikit-learn model)
+    # model_name: Name for the model (without extension)
+    # input_shape: Shape of input data for ONNX conversion (e.g., (1, 10) for batch_size=1, features=10)
+    # Required for PyTorch models, optional for scikit-learn models
+
+    # Raises:
+    # APIError: If the API request fails
+    # ValueError: If model type is not supported or input_shape is missing for PyTorch models
+    # ImportError: If required libraries (torch or skl2onnx) are not installed
+    # """
+    # uid = (await self._client.auth.get_user_info())["uid"]
+    # # above line is getting the uid,
+
+    # client = storage.Client()
+    # bucket = client.get_bucket('terrakio-mass-requests')
+
+    # # Convert model to ONNX format
+    # onnx_bytes = self._convert_model_to_onnx(model, model_name, input_shape)
+
+    # # Upload ONNX model to bucket
+    # # blob = bucket.blob(f'{uid}/{model_name}/models/{model_name}.onnx')
+    # # we don't need to upload the model to the bucket
+    # # so the stuff is stored under the virtual datasets folder
+    # # the model name and the virtual dataset name should be the same
+    # virtual_dataset_name = model_name
+    # blob = bucket.blob(f'{uid}/virtual_datasets/{virtual_dataset_name}/{model_name}.onnx')
+    # # wer are uploading the model to the virtual dataset folder
+
+    # blob.upload_from_string(onnx_bytes, content_type='application/octet-stream')
+
+    # self._client.logger.info(f"Model uploaded successfully to {uid}/virtual_datasets/{virtual_dataset_name}/{model_name}.onnx")
+
+    # this is the upload model function, I think we need to upload to the user, under the virutal_datasets folder, and create the virtual dataset
+
+
     @require_api_key
     async def upload_model(self, model, model_name: str, input_shape: Tuple[int, ...] = None):
         """
@@ -351,6 +397,33 @@ class ModelManagement:
             raise ValueError(f"Failed to convert scikit-learn model {model_name} to ONNX: {str(e)}")


+    # we do not need to pass in both the model name and the dataset name, since the model name should the same as the virtual dataset name
+    # but we are gonna have multiple products for the same virtual dataset
+    # @require_api_key
+    # async def upload_and_deploy_cnn_model(self, model, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None, processing_script_path: Optional[str] = None):
+    # """
+    # Upload a CNN model to the bucket and deploy it.
+
+    # Args:
+    # model: The model object (PyTorch model or scikit-learn model)
+    # model_name: Name for the model (without extension)
+    # dataset: Name of the dataset to create
+    # product: Product name for the inference
+    # input_expression: Input expression for the dataset
+    # dates_iso8601: List of dates in ISO8601 format
+    # input_shape: Shape of input data for ONNX conversion (required for PyTorch models)
+    # processing_script_path: Path to the processing script, if not provided, no processing will be done
+
+    # Raises:
+    # APIError: If the API request fails
+    # ValueError: If model type is not supported or input_shape is missing for PyTorch models
+    # ImportError: If required libraries (torch or skl2onnx) are not installed
+    # """
+    # await self.upload_model(model=model, model_name=dataset, input_shape=input_shape)
+    # # so the uploading process is kinda similar, but the deployment step is kinda different
+    # # we should pass the processing script path to the deploy cnn model function
+    # await self.deploy_cnn_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601, processing_script_path=processing_script_path)
+
     @require_api_key
     async def upload_and_deploy_cnn_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None, processing_script_path: Optional[str] = None):
         """
@@ -376,6 +449,7 @@ class ModelManagement:
         # we should pass the processing script path to the deploy cnn model function
         await self.deploy_cnn_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601, processing_script_path=processing_script_path)

+
     @require_api_key
     async def upload_and_deploy_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None):
         """
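The active (uncommented) upload_and_deploy_cnn_model shown above takes both model_name and dataset. A hypothetical call, inside an async function, with placeholder values only:

await client.model_management.upload_and_deploy_cnn_model(
    model=trained_model,              # PyTorch or scikit-learn model object (placeholder)
    model_name="landcover_cnn",
    dataset="landcover",
    product="classification",
    input_expression="<expression>",  # placeholder, see the package docs for the expected syntax
    dates_iso8601=["2023-01-01"],
    input_shape=(1, 12, 256, 256),    # required for PyTorch models
    processing_script_path=None,
)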
@@ -617,7 +691,7 @@ class ModelManagement:
             clean_preprocessing = preprocessing_code
             # Then add consistent 8-space indentation to match the template
             preprocessing_section = f"""{textwrap.indent(clean_preprocessing, '')}""" # 8 spaces
-            print(preprocessing_section)
+            # print(preprocessing_section)
         script_content = self.generate_cnn_script(model_name, product, model_training_job_name, uid, preprocessing_code, postprocessing_code)
         script_name = f"{product}.py"
         self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
@@ -748,6 +822,245 @@ class ModelManagement:
             return result
         ''').strip()

+    # @require_api_key
+    # def generate_cnn_script(self, model_name: str, product: str, model_training_job_name: str, uid: str, preprocessing_code: Optional[str] = None, postprocessing_code: Optional[str] = None) -> str:
+    # """
+    # Generate Python inference script for CNN model with time-stacked bands.
+
+    # Args:
+    # model_name: Name of the model
+    # product: Product name
+    # model_training_job_name: Training job name
+    # uid: User ID
+    # preprocessing_code: Preprocessing code
+    # postprocessing_code: Postprocessing code
+    # Returns:
+    # str: Generated Python script content
+    # """
+    # import textwrap
+
+    # # Build preprocessing section with CONSISTENT 4-space indentation
+    # preprocessing_section = ""
+    # if preprocessing_code and preprocessing_code.strip():
+    # clean_preprocessing = textwrap.dedent(preprocessing_code)
+    # preprocessing_section = textwrap.indent(clean_preprocessing, ' ')
+
+    # # Build postprocessing section with CONSISTENT 4-space indentation
+    # postprocessing_section = ""
+    # if postprocessing_code and postprocessing_code.strip():
+    # clean_postprocessing = textwrap.dedent(postprocessing_code)
+    # postprocessing_section = textwrap.indent(clean_postprocessing, ' ')
+
+    # # Build the template WITHOUT dedenting the whole thing, so indentation is preserved
+    # script_lines = [
+    # "import logging",
+    # "from io import BytesIO",
+    # "import numpy as np",
+    # "import pandas as pd",
+    # "import xarray as xr",
+    # "from google.cloud import storage",
+    # "from onnxruntime import InferenceSession",
+    # "from typing import Tuple",
+    # "",
+    # "logging.basicConfig(",
+    # " level=logging.INFO",
+    # ")",
+    # "",
+    # ]
+
+    # # Add preprocessing function definition BEFORE the main function
+    # if preprocessing_section:
+    # script_lines.extend([
+    # "def preprocessing(array: Tuple[xr.DataArray, ...]) -> Tuple[xr.DataArray, ...]:",
+    # preprocessing_section,
+    # "",
+    # ])
+
+    # # Add postprocessing function definition BEFORE the main function
+    # if postprocessing_section:
+    # script_lines.extend([
+    # "def postprocessing(array: xr.DataArray) -> xr.DataArray:",
+    # postprocessing_section,
+    # "",
+    # ])
+
+    # # Add the get_model function
+    # script_lines.extend([
+    # "def get_model():",
+    # f" logging.info(\"Loading CNN model for {model_name}...\")",
+    # "",
+    # " client = storage.Client()",
+    # " bucket = client.get_bucket('terrakio-mass-requests')",
+    # f" blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')",
+    # "",
+    # " model = BytesIO()",
+    # " blob.download_to_file(model)",
+    # " model.seek(0)",
+    # "",
+    # " session = InferenceSession(model.read(), providers=[\"CPUExecutionProvider\"])",
+    # " return session",
+    # "",
+    # f"def {product}(*bands, model):",
+    # " logging.info(\"Start preparing CNN data with time-stacked bands\")",
+    # " data_arrays = list(bands)",
+    # " ",
+    # " if not data_arrays:",
+    # " raise ValueError(\"No bands provided\")",
+    # " ",
+    # ])
+
+    # # Add preprocessing call if preprocessing exists
+    # if preprocessing_section:
+    # script_lines.extend([
+    # " # Apply preprocessing",
+    # " data_arrays = preprocessing(tuple(data_arrays))",
+    # " data_arrays = list(data_arrays) # Convert back to list for processing",
+    # " ",
+    # ])
+
+    # # Continue with the rest of the processing logic
+    # script_lines.extend([
+    # " reference_array = data_arrays[0]",
+    # " original_shape = reference_array.shape",
+    # " logging.info(f\"Original shape: {original_shape}\")",
+    # " ",
+    # " # Get time coordinates - all bands should have the same time dimension",
+    # " if 'time' not in reference_array.dims:",
+    # " raise ValueError(\"Time dimension is required for CNN processing\")",
+    # " ",
+    # " time_coords = reference_array.coords['time']",
+    # " num_timestamps = len(time_coords)",
+    # " logging.info(f\"Number of timestamps: {num_timestamps}\")",
+    # " ",
+    # " # Get spatial dimensions",
+    # " spatial_dims = [dim for dim in reference_array.dims if dim != 'time']",
+    # " height = reference_array.sizes[spatial_dims[0]] # assuming first spatial dim is height",
+    # " width = reference_array.sizes[spatial_dims[1]] # assuming second spatial dim is width",
+    # " logging.info(f\"Spatial dimensions: {height} x {width}\")",
+    # " ",
+    # " # Stack bands across time dimension",
+    # " # Result will be: (num_bands * num_timestamps, height, width)",
+    # " stacked_channels = []",
+    # " ",
+    # " for band_idx, data_array in enumerate(data_arrays):",
+    # " logging.info(f\"Processing band {band_idx + 1}/{len(data_arrays)}\")",
+    # " ",
+    # " # Ensure consistent time coordinates across bands",
+    # " if not np.array_equal(data_array.coords['time'].values, time_coords.values):",
+    # " logging.warning(f\"Band {band_idx} has different time coordinates, aligning...\")",
+    # " data_array = data_array.sel(time=time_coords, method='nearest')",
+    # " ",
+    # " # Extract values and ensure proper ordering (time, height, width)",
+    # " band_values = data_array.values",
+    # " if band_values.ndim == 3:",
+    # " # Reorder dimensions if needed to ensure (time, height, width)",
+    # " time_dim_idx = data_array.dims.index('time')",
+    # " if time_dim_idx != 0:",
+    # " axes_order = [time_dim_idx] + [i for i in range(len(data_array.dims)) if i != time_dim_idx]",
+    # " band_values = np.transpose(band_values, axes_order)",
+    # " ",
+    # " # Add each timestamp of this band to the channel stack",
+    # " for t in range(num_timestamps):",
+    # " stacked_channels.append(band_values[t])",
+    # " ",
+    # " # Stack all channels: (num_bands * num_timestamps, height, width)",
+    # " input_channels = np.stack(stacked_channels, axis=0)",
+    # " total_channels = len(data_arrays) * num_timestamps",
+    # " logging.info(f\"Stacked channels shape: {input_channels.shape}\")",
+    # " logging.info(f\"Total channels: {total_channels} ({len(data_arrays)} bands × {num_timestamps} timestamps)\")",
+    # " ",
+    # " # Add batch dimension: (1, num_channels, height, width)",
+    # " input_data = np.expand_dims(input_channels, axis=0).astype(np.float32)",
+    # " logging.info(f\"Final input shape for CNN: {input_data.shape}\")",
+    # " ",
+    # " # Run inference",
+    # " output = model.run(None, {\"float_input\": input_data})[0]",
+    # " logging.info(f\"Model output shape: {output.shape}\")",
+    # " ",
+    # " # UPDATED: Handle multi-class CNN output properly",
+    # " if output.ndim == 4:",
+    # " if output.shape[1] == 1:",
+    # " # Single class output (regression or binary classification)",
+    # " output_2d = output[0, 0]",
+    # " logging.info(\"Single channel output detected\")",
+    # " else:",
+    # " # Multi-class output - convert logits/probabilities to class predictions",
+    # " output_classes = np.argmax(output, axis=1) # Shape: (1, height, width)",
+    # " output_2d = output_classes[0] # Shape: (height, width)",
+    # " ",
+    # " # Apply class merging: merge class 6 into class 3",
+    # " output_2d = np.where(output_2d == 6, 3, output_2d)",
+    # " ",
+    # " logging.info(f\"Multi-class output processed. Original classes: {output.shape[1]}\")",
+    # " logging.info(f\"Unique classes in output: {np.unique(output_2d)}\")",
+    # " logging.info(f\"Class distribution: {np.bincount(output_2d.flatten())}\")",
+    # " elif output.ndim == 3:",
+    # " # Remove batch dimension",
+    # " output_2d = output[0]",
+    # " logging.info(\"3D output detected, removed batch dimension\")",
+    # " else:",
+    # " # Handle other cases",
+    # " output_2d = np.squeeze(output)",
+    # " if output_2d.ndim != 2:",
+    # " logging.error(f\"Cannot process output shape: {output.shape}\")",
+    # " logging.error(f\"After squeeze: {output_2d.shape}\")",
+    # " raise ValueError(f\"Unexpected output shape after processing: {output_2d.shape}\")",
+    # " logging.info(\"Applied squeeze to output\")",
+    # " ",
+    # " # Ensure output is 2D",
+    # " if output_2d.ndim != 2:",
+    # " raise ValueError(f\"Final output must be 2D, got shape: {output_2d.shape}\")",
+    # " ",
+    # " # Determine output timestamp (use the latest timestamp)",
+    # " output_timestamp = time_coords[-1]",
+    # " ",
+    # " # Get spatial coordinates from reference array",
+    # " spatial_coords = {dim: reference_array.coords[dim] for dim in spatial_dims}",
+    # " ",
+    # " # Create output DataArray with appropriate data type",
+    # " # Use int32 for classification, float32 for regression",
+    # " is_multiclass = output.ndim == 4 and output.shape[1] > 1",
+    # " if is_multiclass:",
+    # " # Multi-class classification - use integer type",
+    # " output_dtype = np.int32",
+    # " output_type = 'classification'",
+    # " else:",
+    # " # Single output - use float type",
+    # " output_dtype = np.float32",
+    # " output_type = 'regression'",
+    # " ",
+    # " result = xr.DataArray(",
+    # " data=np.expand_dims(output_2d.astype(output_dtype), axis=0),",
+    # " dims=['time'] + spatial_dims,",
+    # " coords={",
+    # " 'time': [output_timestamp.values],",
+    # " spatial_dims[0]: spatial_coords[spatial_dims[0]].values,",
+    # " spatial_dims[1]: spatial_coords[spatial_dims[1]].values",
+    # " },",
+    # " attrs={",
+    # " 'description': 'CNN model prediction',",
+    # " }",
+    # " )",
+    # " ",
+    # " logging.info(f\"Final result shape: {result.shape}\")",
+    # " logging.info(f\"Final result data type: {result.dtype}\")",
+    # " logging.info(f\"Final result value range: {result.values.min()} to {result.values.max()}\")",
+    # ])
+
+    # # Add postprocessing call if postprocessing exists
+    # if postprocessing_section:
+    # script_lines.extend([
+    # " # Apply postprocessing",
+    # " result = postprocessing(result)",
+    # " ",
+    # ])
+
+    # # Single return statement at the end
+    # script_lines.append(" return result")
+
+    # return "\n".join(script_lines)
+
+
     @require_api_key
     def generate_cnn_script(self, model_name: str, product: str, model_training_job_name: str, uid: str, preprocessing_code: Optional[str] = None, postprocessing_code: Optional[str] = None) -> str:
         """
@@ -794,6 +1107,160 @@ class ModelManagement:
             "",
         ]

+        # Add preprocessing validation function if preprocessing exists
+        if preprocessing_section:
+            script_lines.extend([
+                "def validate_preprocessing_output(data_arrays):",
+                " \"\"\"",
+                " Validate preprocessing output coordinates and data type.",
+                " ",
+                " Args:",
+                " data_arrays: List of xarray DataArrays from preprocessing",
+                " ",
+                " Returns:",
+                " str: Validation signature symbol",
+                " ",
+                " Raises:",
+                " ValueError: If validation fails",
+                " \"\"\"",
+                " import numpy as np",
+                " ",
+                " logging.info(\"=\" * 60)",
+                " logging.info(\"VALIDATING PREPROCESSING OUTPUT\")",
+                " logging.info(\"=\" * 60)",
+                " ",
+                " if not data_arrays:",
+                " raise ValueError(\"No data arrays provided from preprocessing\")",
+                " ",
+                " reference_shape = None",
+                " ",
+                " for i, data_array in enumerate(data_arrays):",
+                " logging.info(f\"Validating channel {i+1}/{len(data_arrays)}: {data_array.name}\")",
+                " ",
+                " # Check if it's an xarray DataArray",
+                " if not hasattr(data_array, 'dims') or not hasattr(data_array, 'coords'):",
+                " raise ValueError(f\"Channel {i+1} is not a valid xarray DataArray\")",
+                " ",
+                " # Check coordinates",
+                " if 'time' not in data_array.coords:",
+                " raise ValueError(f\"Channel {i+1} missing time coordinate\")",
+                " ",
+                " spatial_dims = [dim for dim in data_array.dims if dim != 'time']",
+                " if len(spatial_dims) != 2:",
+                " raise ValueError(f\"Channel {i+1} must have exactly 2 spatial dimensions, got {spatial_dims}\")",
+                " ",
+                " for dim in spatial_dims:",
+                " if dim not in data_array.coords:",
+                " raise ValueError(f\"Channel {i+1} missing coordinate: {dim}\")",
+                " ",
+                " logging.info(f\" Coordinates: {list(data_array.coords.keys())}\")",
+                " ",
+                " # Check data type",
+                " data_values = data_array.values",
+                " logging.info(f\" Data type: {data_values.dtype}\")",
+                " ",
+                " # Check shape consistency",
+                " shape = data_array.shape",
+                " if reference_shape is None:",
+                " reference_shape = shape",
+                " else:",
+                " if shape != reference_shape:",
+                " raise ValueError(f\"Channel {i+1} shape {shape} doesn't match reference {reference_shape}\")",
+                " ",
+                " logging.info(f\" Shape: {shape}\")",
+                " ",
+                " # Generate validation signature",
+                " signature_components = [",
+                " f\"CH{len(data_arrays)}\", # Channel count",
+                " f\"T{reference_shape[0]}\", # Time dimension",
+                " f\"S{reference_shape[1]}x{reference_shape[2]}\", # Spatial dimensions",
+                " f\"DT{data_arrays[0].values.dtype}\", # Data type",
+                " ]",
+                " ",
+                " signature = \"★PRE_\" + \"_\".join(signature_components) + \"★\"",
+                " ",
+                " logging.info(\"-\" * 60)",
+                " logging.info(\"PREPROCESSING VALIDATION SUMMARY\")",
+                " logging.info(\"-\" * 60)",
+                " logging.info(f\"Channels validated: {len(data_arrays)}\")",
+                " logging.info(f\"Common shape: {reference_shape}\")",
+                " logging.info(f\"Validation signature: {signature}\")",
+                " logging.info(\"=\" * 60)",
+                " ",
+                " return signature",
+                "",
+            ])
+
+        # Add postprocessing validation function if postprocessing exists
+        if postprocessing_section:
+            script_lines.extend([
+                "def validate_postprocessing_output(result_array):",
+                " \"\"\"",
+                " Validate postprocessing output coordinates and data type.",
+                " ",
+                " Args:",
+                " result_array: xarray DataArray from postprocessing",
+                " ",
+                " Returns:",
+                " str: Validation signature symbol",
+                " ",
+                " Raises:",
+                " ValueError: If validation fails",
+                " \"\"\"",
+                " import numpy as np",
+                " ",
+                " logging.info(\"=\" * 60)",
+                " logging.info(\"VALIDATING POSTPROCESSING OUTPUT\")",
+                " logging.info(\"=\" * 60)",
+                " ",
+                " # Check if it's an xarray DataArray",
+                " if not hasattr(result_array, 'dims') or not hasattr(result_array, 'coords'):",
+                " raise ValueError(\"Postprocessing output is not a valid xarray DataArray\")",
+                " ",
+                " # Check required coordinates",
+                " if 'time' not in result_array.coords:",
+                " raise ValueError(\"Missing time coordinate\")",
+                " ",
+                " spatial_dims = [dim for dim in result_array.dims if dim != 'time']",
+                " if len(spatial_dims) != 2:",
+                " raise ValueError(f\"Expected 2 spatial dimensions, got {len(spatial_dims)}: {spatial_dims}\")",
+                " ",
+                " for dim in spatial_dims:",
+                " if dim not in result_array.coords:",
+                " raise ValueError(f\"Missing spatial coordinate: {dim}\")",
+                " ",
+                " logging.info(f\"Coordinates found: {list(result_array.coords.keys())}\")",
+                " ",
+                " # Check data type",
+                " data_values = result_array.values",
+                " logging.info(f\"Data type: {data_values.dtype}\")",
+                " ",
+                " # Check shape",
+                " shape = result_array.shape",
+                " logging.info(f\"Shape: {shape}\")",
+                " ",
+                " # Generate validation signature",
+                " signature_components = [",
+                " f\"T{shape[0]}\", # Time dimension",
+                " f\"S{shape[1]}x{shape[2]}\", # Spatial dimensions",
+                " f\"DT{data_values.dtype}\", # Data type",
+                " ]",
+                " ",
+                " signature = \"★POST_\" + \"_\".join(signature_components) + \"★\"",
+                " ",
+                " logging.info(\"-\" * 60)",
+                " logging.info(\"POSTPROCESSING VALIDATION SUMMARY\")",
+                " logging.info(\"-\" * 60)",
+                " logging.info(f\"Final shape: {shape}\")",
+                " logging.info(f\"Final coordinates: {list(result_array.coords.keys())}\")",
+                " logging.info(f\"Data type: {data_values.dtype}\")",
+                " logging.info(f\"Validation signature: {signature}\")",
+                " logging.info(\"=\" * 60)",
+                " ",
+                " return signature",
+                "",
+            ])
+
         # Add preprocessing function definition BEFORE the main function
         if preprocessing_section:
             script_lines.extend([
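The generated validate_preprocessing_output builds a signature of the form ★PRE_CH<channels>_T<time>_S<height>x<width>_DT<dtype>★. A small standalone sketch that reproduces just that signature format for illustration; it is not part of the generated script:

import numpy as np
import xarray as xr

def preprocessing_signature(arrays) -> str:
    # Rebuilds the signature format emitted by the generated validator:
    # ★PRE_CH<channels>_T<time>_S<height>x<width>_DT<dtype>★
    shape = arrays[0].shape
    parts = [
        f"CH{len(arrays)}",
        f"T{shape[0]}",
        f"S{shape[1]}x{shape[2]}",
        f"DT{arrays[0].values.dtype}",
    ]
    return "★PRE_" + "_".join(parts) + "★"

da = xr.DataArray(np.zeros((3, 64, 64), dtype=np.float32), dims=("time", "y", "x"))
print(preprocessing_signature([da, da]))  # ★PRE_CH2_T3_S64x64_DTfloat32★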
@@ -835,13 +1302,17 @@ class ModelManagement:
             " ",
         ])

-        # Add preprocessing call if preprocessing exists
+        # Add preprocessing call and validation if preprocessing exists
         if preprocessing_section:
             script_lines.extend([
                 " # Apply preprocessing",
                 " data_arrays = preprocessing(tuple(data_arrays))",
                 " data_arrays = list(data_arrays) # Convert back to list for processing",
                 " ",
+                " # Validate preprocessing output",
+                " preprocessing_signature = validate_preprocessing_output(data_arrays)",
+                " logging.info(f\"Preprocessing validation signature: {preprocessing_signature}\")",
+                " ",
             ])

         # Continue with the rest of the processing logic
@@ -973,19 +1444,23 @@ class ModelManagement:
             " logging.info(f\"Final result value range: {result.values.min()} to {result.values.max()}\")",
         ])

-        # Add postprocessing call if postprocessing exists
+        # Add postprocessing call and validation if postprocessing exists
         if postprocessing_section:
             script_lines.extend([
                 " # Apply postprocessing",
                 " result = postprocessing(result)",
                 " ",
+                " # Validate postprocessing output",
+                " postprocessing_signature = validate_postprocessing_output(result)",
+                " logging.info(f\"Postprocessing validation signature: {postprocessing_signature}\")",
+                " ",
             ])

         # Single return statement at the end
         script_lines.append(" return result")

         return "\n".join(script_lines)
-
+
     @require_api_key
     def _upload_script_to_bucket(self, script_content: str, script_name: str, model_training_job_name: str, uid: str):
         """Upload the generated script to Google Cloud Storage"""
All remaining files listed above (+0 -0) are unchanged between 0.4.2 and 0.4.3.