terrakio-core 0.3.6__py3-none-any.whl → 0.3.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of terrakio-core might be problematic.

terrakio_core/__init__.py CHANGED
@@ -5,7 +5,7 @@ Terrakio Core
 Core components for Terrakio API clients.
 """

-__version__ = "0.3.4"
+__version__ = "0.3.8"

 from .async_client import AsyncClient
 from .sync_client import SyncClient

terrakio_core/async_client.py CHANGED
@@ -46,65 +46,65 @@ class AsyncClient(BaseClient):
         else:
             return await self._make_request_with_retry(self._session, method, endpoint, **kwargs)

-
     async def _make_request_with_retry(self, session: aiohttp.ClientSession, method: str, endpoint: str, **kwargs) -> Dict[Any, Any]:
         url = f"{self.url}/{endpoint.lstrip('/')}"
+        last_exception = None
+
         for attempt in range(self.retry + 1):
-            try:
+            try:
                 async with session.request(method, url, **kwargs) as response:
-                    response_text = await response.text()
+                    if not response.ok and self._should_retry(response.status, attempt):
+                        self.logger.info(f"Request failed (attempt {attempt+1}/{self.retry+1}): {response.status}. Retrying...")
+                        continue
                     if not response.ok:
-                        should_retry = False
-
-                        if response.status == 400:
-                            should_retry = False
-                        else:
-                            if response.status in [408, 502, 503, 504]:
-                                should_retry = True
-                            elif response.status == 500:
-                                try:
-                                    response_text = await response.text()
-                                    if "Internal server error" not in response_text:
-                                        should_retry = True
-                                except:
-                                    should_retry = True
-
-                        if should_retry and attempt < self.retry:
-                            self.logger.info(f"Request failed (attempt {attempt+1}/{self.retry+1}): {response.status} {response.reason}. Retrying...")
-                            continue
-                        else:
-                            error_msg = f"API request failed: {response.status} {response.reason}"
-                            try:
-                                error_data = await response.json()
-                                if "detail" in error_data:
-                                    error_msg += f" - {error_data['detail']}"
-                            except:
-                                pass
-                            raise APIError(error_msg, status_code=response.status)
-
-                    content_type = response.headers.get('content-type', '').lower()
-                    content = await response.read()
-                    if 'json' in content_type:
-                        return json.loads(content.decode('utf-8'))
-                    elif 'csv' in content_type:
-                        return pd.read_csv(BytesIO(content))
-                    elif 'image/' in content_type:
-                        return content
-                    elif 'text' in content_type:
-                        return content.decode('utf-8')
-                    else:
+                        error_msg = f"API request failed: {response.status} {response.reason}"
                         try:
-                            return xr.open_dataset(BytesIO(content))
+                            error_data = await response.json()
+                            if "detail" in error_data:
+                                error_msg += f" - {error_data['detail']}"
                         except:
-                            raise APIError(f"Unknown response format. Content-Type: {response.headers.get('content-type', 'unknown')}", status_code=response.status)
+                            pass
+                        raise APIError(error_msg, status_code=response.status)
+                    return await self._parse_response(response)
+
             except aiohttp.ClientError as e:
+                last_exception = e
                 if attempt < self.retry:
-                    self.logger.info(f"Request failed (attempt {attempt+1}/{self.retry+1}): {e}. Retrying...")
+                    self.logger.info(f"Networking error (attempt {attempt+1}/{self.retry+1}): {e}. Retrying...")
                     continue
                 else:
-                    raise APIError(f"Request failed after {self.retry+1} attempts: {e}", status_code=None)
-
+                    break
+
+        raise APIError(f"Networking error, request failed after {self.retry+1} attempts: {last_exception}", status_code=None)
+
+    def _should_retry(self, status_code: int, attempt: int) -> bool:
+        """Determine if the request should be retried based on status code."""
+        if attempt >= self.retry:
+            return False
+        elif status_code in [408, 502, 503, 504]:
+            return True
+        else:
+            return False
+
+    async def _parse_response(self, response) -> Any:
+        """Parse response based on content type."""
+        content_type = response.headers.get('content-type', '').lower()
+        content = await response.read()

+        if 'json' in content_type:
+            return json.loads(content.decode('utf-8'))
+        elif 'csv' in content_type:
+            return pd.read_csv(BytesIO(content))
+        elif 'image/' in content_type:
+            return content
+        elif 'text' in content_type:
+            return content.decode('utf-8')
+        else:
+            try:
+                return xr.open_dataset(BytesIO(content))
+            except:
+                raise APIError(f"Unknown response format: {content_type}", status_code=response.status)
+
     async def _regular_request(self, method: str, endpoint: str, **kwargs):
         url = endpoint.lstrip('/')
         if self._session is None:
@@ -134,6 +134,7 @@ class AsyncClient(BaseClient):
         output: str = "csv",
         resolution: int = -1,
         geom_fix: bool = False,
+        validated: bool = True,
         **kwargs
     ):
         """
@@ -147,6 +148,7 @@
             output (str): Output format ('csv' or 'netcdf')
             resolution (int): Resolution parameter
             geom_fix (bool): Whether to fix the geometry (default False)
+            validated (bool): Whether to use validated data (default True)
             **kwargs: Additional parameters to pass to the WCS request

         Returns:
@@ -169,6 +171,7 @@
             "resolution": resolution,
             "expr": expr,
             "buffer": geom_fix,
+            "validated": validated,
             **kwargs
         }
         return await self._terrakio_request("POST", "geoquery", json=payload)
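
Compared with 0.3.6, the rewritten retry loop narrows what gets retried: only HTTP 408/502/503/504 (via the new _should_retry helper) and aiohttp.ClientError networking failures; every other error status raises APIError immediately. A minimal standalone sketch of the same pattern, assuming nothing beyond aiohttp (all names here are illustrative, not part of terrakio-core):

    import asyncio
    import aiohttp

    RETRYABLE_STATUSES = {408, 502, 503, 504}  # mirrors _should_retry in this release

    async def fetch_with_retry(url: str, retries: int = 3) -> dict:
        """Retry transient server errors and networking failures, fail fast otherwise."""
        last_exc = None
        async with aiohttp.ClientSession() as session:
            for attempt in range(retries + 1):
                try:
                    async with session.get(url) as resp:
                        if resp.ok:
                            return await resp.json()
                        if resp.status in RETRYABLE_STATUSES and attempt < retries:
                            continue  # transient server/timeout status: try again
                        raise RuntimeError(f"request failed: {resp.status} {resp.reason}")
                except aiohttp.ClientError as exc:
                    last_exc = exc
                    if attempt < retries:
                        continue  # networking error: try again
                    break
        raise RuntimeError(f"request failed after {retries + 1} attempts: {last_exc}")

    # asyncio.run(fetch_with_retry("https://example.com/api"))  # usage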

terrakio_core/endpoints/dataset_management.py CHANGED
@@ -42,7 +42,7 @@ class DatasetManagement:
         return self._client._terrakio_request("GET", f"/datasets/{name}", params = params)

     @require_api_key
-    def create_dataset(
+    async def create_dataset(
         self,
         name: str,
         collection: str = "terrakio-datasets",
@@ -59,7 +59,8 @@ class DatasetManagement:
         proj4: Optional[str] = None,
         abstract: Optional[str] = None,
         geotransform: Optional[List[float]] = None,
-        padding: Optional[Any] = None
+        padding: Optional[Any] = None,
+        input: Optional[str] = None
     ) -> Dict[str, Any]:
         """
         Create a new dataset.
@@ -104,12 +105,13 @@ class DatasetManagement:
             "proj4": proj4,
             "abstract": abstract,
             "geotransform": geotransform,
-            "padding": padding
+            "padding": padding,
+            "input": input
         }
         for param, value in param_mapping.items():
             if value is not None:
                 payload[param] = value
-        return self._client._terrakio_request("POST", "/datasets", params = params, json = payload)
+        return await self._client._terrakio_request("POST", "/datasets", params = params, json = payload)

     @require_api_key
     def update_dataset(
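
Because create_dataset is now a coroutine, existing callers must be updated to await it (the new optional input expression can be passed directly). A hedged sketch, assuming an already-initialized async client bound to client; the argument values are placeholders:

    # create_dataset must now be awaited
    result = await client.datasets.create_dataset(
        name="example_dataset",        # illustrative name
        collection="terrakio-datasets",
        padding=0,
    )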

terrakio_core/endpoints/mass_stats.py CHANGED
@@ -12,7 +12,7 @@ class MassStats:
         self._client = client

     @require_api_key
-    async def upload_request(
+    async def _upload_request(
         self,
         name: str,
         size: int,
@@ -220,7 +220,7 @@
         return self._client._terrakio_request("GET", "mass_stats/download", params=params)

     @require_api_key
-    async def upload_file(self, file_path: str, url: str, use_gzip: bool = False):
+    async def _upload_file(self, file_path: str, url: str, use_gzip: bool = False):
         """
         Helper method to upload a JSON file to a signed URL.

@@ -427,7 +427,7 @@
            return e
        except json.JSONDecodeError as e:
            return e
-        upload_result = await self.upload_request(name = name, size = size, region = region, output = output, config = config, location = location, force_loc = force_loc, overwrite = overwrite, server = server, skip_existing = skip_existing)
+        upload_result = await self._upload_request(name = name, size = size, region = region, output = output, config = config, location = location, force_loc = force_loc, overwrite = overwrite, server = server, skip_existing = skip_existing)
        requests_url = upload_result.get('requests_url')
        manifest_url = upload_result.get('manifest_url')
        if not requests_url:
@@ -436,7 +436,7 @@
        try:
            # in this place we are uploading the request json file, we need to check whether the json is in the correct format or not
            self.validate_request(request_json)
-            requests_response = await self.upload_file(request_json, requests_url, use_gzip=True)
+            requests_response = await self._upload_file(request_json, requests_url, use_gzip=True)
            if requests_response.status not in [200, 201, 204]:
                self._client.logger.error(f"Requests upload error: {requests_response.text()}")
                raise Exception(f"Failed to upload request JSON: {requests_response.text()}")
@@ -447,7 +447,7 @@
            raise ValueError("No manifest_url returned from server for manifest JSON upload")

        try:
-            manifest_response = await self.upload_file(manifest_json, manifest_url, use_gzip=False)
+            manifest_response = await self._upload_file(manifest_json, manifest_url, use_gzip=False)
            if manifest_response.status not in [200, 201, 204]:
                self._client.logger.error(f"Manifest upload error: {manifest_response.text()}")
                raise Exception(f"Failed to upload manifest JSON: {manifest_response.text()}")

terrakio_core/endpoints/model_management.py CHANGED
@@ -3,9 +3,32 @@ import json
 import time
 import textwrap
 import logging
-from typing import Dict, Any
+from typing import Dict, Any, Union, Tuple
+from io import BytesIO
+import numpy as np
 from google.cloud import storage
 from ..helper.decorators import require_token, require_api_key, require_auth
+TORCH_AVAILABLE = False
+SKL2ONNX_AVAILABLE = False
+
+try:
+    import torch
+    TORCH_AVAILABLE = True
+except ImportError:
+    torch = None
+
+try:
+    from skl2onnx import convert_sklearn
+    from skl2onnx.common.data_types import FloatTensorType
+    from sklearn.base import BaseEstimator
+    SKL2ONNX_AVAILABLE = True
+except ImportError:
+    convert_sklearn = None
+    FloatTensorType = None
+    BaseEstimator = None
+
+from io import BytesIO
+from typing import Tuple

 class ModelManagement:
     def __init__(self, client):
@@ -115,14 +138,13 @@ class ModelManagement:
             bar = '█' * filled_length + '░' * (bar_length - filled_length)
             percentage = progress * 100

-            # Print status with progress bar
-            print(f"\rJob status: {status} [{bar}] {percentage:.1f}% ({completed}/{total})", end='')
+            self._client.logger.info(f"Job status: {status} [{bar}] {percentage:.1f}% ({completed}/{total})")

             if status == "Completed":
-                print("\nJob completed successfully!")
+                self._client.logger.info("Job completed successfully!")
                 break
             elif status == "Error":
-                print("\n") # New line before error message
+                self._client.logger.info("Job encountered an error")
                 raise Exception(f"Job {task_id} encountered an error")

             # Wait 5 seconds before checking again
@@ -133,43 +155,238 @@ class ModelManagement:
         return task_id

     @require_api_key
-    async def upload_model(self, model_path: str):
+    async def upload_model(self, model, model_name: str, input_shape: Tuple[int, ...] = None):
         """
         Upload a model to the bucket so that it can be used for inference.
+        Converts PyTorch and scikit-learn models to ONNX format before uploading.

         Args:
-            model_path: Path to the model file
+            model: The model object (PyTorch model or scikit-learn model)
+            model_name: Name for the model (without extension)
+            input_shape: Shape of input data for ONNX conversion (e.g., (1, 10) for batch_size=1, features=10)
+                Required for PyTorch models, optional for scikit-learn models

         Raises:
             APIError: If the API request fails
+            ValueError: If model type is not supported or input_shape is missing for PyTorch models
+            ImportError: If required libraries (torch or skl2onnx) are not installed
         """
         uid = (await self._client.auth.get_user_info())["uid"]
-        model_name = os.path.basename(model_path)

         client = storage.Client()
         bucket = client.get_bucket('terrakio-mass-requests')
-        model_file_name = os.path.splitext(model_name)[0]
-        blob = bucket.blob(f'{uid}/{model_file_name}/models/{model_name}')

-        blob.upload_from_filename(model_path)
-        self._client.logger.info(f"Model uploaded successfully to {uid}/{model_name}/models/{model_name}")
+        # Convert model to ONNX format
+        onnx_bytes = self._convert_model_to_onnx(model, model_name, input_shape)
+
+        # Upload ONNX model to bucket
+        blob = bucket.blob(f'{uid}/{model_name}/models/{model_name}.onnx')
+
+        blob.upload_from_string(onnx_bytes, content_type='application/octet-stream')
+        self._client.logger.info(f"Model uploaded successfully to {uid}/{model_name}/models/{model_name}.onnx")
+
+    def _convert_model_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...] = None) -> bytes:
+        """
+        Convert a model to ONNX format and return as bytes.
+
+        Args:
+            model: The model object (PyTorch or scikit-learn)
+            model_name: Name of the model for logging
+            input_shape: Shape of input data
+
+        Returns:
+            bytes: ONNX model as bytes
+
+        Raises:
+            ValueError: If model type is not supported
+            ImportError: If required libraries are not installed
+        """
+        # Early check for any conversion capability
+        if not (TORCH_AVAILABLE or SKL2ONNX_AVAILABLE):
+            raise ImportError(
+                "ONNX conversion requires additional dependencies. Install with:\n"
+                "  pip install torch     # For PyTorch models\n"
+                "  pip install skl2onnx  # For scikit-learn models\n"
+                "  pip install torch skl2onnx  # For both"
+            )
+
+        # Check if it's a PyTorch model using isinstance (preferred) with fallback
+        is_pytorch = False
+        if TORCH_AVAILABLE:
+            is_pytorch = (isinstance(model, torch.nn.Module) or
+                          hasattr(model, 'state_dict'))
+
+        # Check if it's a scikit-learn model
+        is_sklearn = False
+        if SKL2ONNX_AVAILABLE:
+            is_sklearn = (isinstance(model, BaseEstimator) or
+                          (hasattr(model, 'fit') and hasattr(model, 'predict')))
+
+        if is_pytorch and TORCH_AVAILABLE:
+            return self._convert_pytorch_to_onnx(model, model_name, input_shape)
+        elif is_sklearn and SKL2ONNX_AVAILABLE:
+            return self._convert_sklearn_to_onnx(model, model_name, input_shape)
+        else:
+            # Provide helpful error message
+            model_type = type(model).__name__
+            model_module = type(model).__module__
+            available_types = []
+            missing_deps = []
+
+            if TORCH_AVAILABLE:
+                available_types.append("PyTorch (torch.nn.Module)")
+            else:
+                missing_deps.append("torch")
+
+            if SKL2ONNX_AVAILABLE:
+                available_types.append("scikit-learn (BaseEstimator)")
+            else:
+                missing_deps.append("skl2onnx")
+
+            if missing_deps:
+                raise ImportError(
+                    f"Model type {model_type} from {model_module} detected, but required dependencies missing: {', '.join(missing_deps)}. "
+                    f"Install with: pip install {' '.join(missing_deps)}"
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported model type: {model_type} from {model_module}. "
+                    f"Supported types: {', '.join(available_types)}"
+                )

+    def _convert_pytorch_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...]) -> bytes:
+        """Convert PyTorch model to ONNX format with dynamic input dimensions."""
+        if input_shape is None:
+            raise ValueError("input_shape is required for PyTorch models")
+
+        self._client.logger.info(f"Converting PyTorch model {model_name} to ONNX...")
+
+        try:
+            # Set model to evaluation mode
+            model.eval()
+
+            # Create dummy input
+            dummy_input = torch.randn(input_shape)
+
+            # Use BytesIO to avoid creating temporary files
+            onnx_buffer = BytesIO()
+
+            # Determine dynamic axes based on input shape
+            # Common patterns for different input types:
+            if len(input_shape) == 4:  # Convolutional input: (batch, channels, height, width)
+                dynamic_axes = {
+                    'float_input': {
+                        0: 'batch_size',
+                        2: 'height',  # Make height dynamic for variable input sizes
+                        3: 'width'    # Make width dynamic for variable input sizes
+                    },
+                    'output': {0: 'batch_size'}
+                }
+            elif len(input_shape) == 3:  # Could be (batch, sequence, features) or (batch, height, width)
+                dynamic_axes = {
+                    'float_input': {
+                        0: 'batch_size',
+                        1: 'dim1',  # Generic dynamic dimension
+                        2: 'dim2'   # Generic dynamic dimension
+                    },
+                    'output': {0: 'batch_size'}
+                }
+            elif len(input_shape) == 2:  # Likely (batch, features)
+                dynamic_axes = {
+                    'float_input': {
+                        0: 'batch_size'
+                        # Don't make features dynamic as it usually affects model architecture
+                    },
+                    'output': {0: 'batch_size'}
+                }
+            else:
+                # For other shapes, just make batch size dynamic
+                dynamic_axes = {
+                    'float_input': {0: 'batch_size'},
+                    'output': {0: 'batch_size'}
+                }
+
+            torch.onnx.export(
+                model,
+                dummy_input,
+                onnx_buffer,
+                export_params=True,
+                opset_version=11,
+                do_constant_folding=True,
+                input_names=['float_input'],
+                output_names=['output'],
+                dynamic_axes=dynamic_axes
+            )
+
+            self._client.logger.info(f"Successfully converted {model_name} with dynamic axes: {dynamic_axes}")
+            return onnx_buffer.getvalue()
+
+        except Exception as e:
+            raise ValueError(f"Failed to convert PyTorch model {model_name} to ONNX: {str(e)}")
+
+
+    def _convert_sklearn_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...] = None) -> bytes:
+        """Convert scikit-learn model to ONNX format."""
+        self._client.logger.info(f"Converting scikit-learn model {model_name} to ONNX...")
+
+        # Try to infer input shape if not provided
+        if input_shape is None:
+            if hasattr(model, 'n_features_in_'):
+                input_shape = (1, model.n_features_in_)
+            else:
+                raise ValueError(
+                    "input_shape is required for scikit-learn models when n_features_in_ is not available. "
+                    "This usually happens with older sklearn versions or models not fitted yet."
+                )
+
+        try:
+            # Convert scikit-learn model to ONNX
+            initial_type = [('float_input', FloatTensorType(input_shape))]
+            onnx_model = convert_sklearn(model, initial_types=initial_type)
+            return onnx_model.SerializeToString()
+
+        except Exception as e:
+            raise ValueError(f"Failed to convert scikit-learn model {model_name} to ONNX: {str(e)}")
+
     @require_api_key
-    def upload_and_deploy_model(self, model_path: str, dataset: str, product: str, input_expression: str, dates_iso8601: list):
+    async def upload_and_deploy_cnn_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None):
+        """
+        Upload a CNN model to the bucket and deploy it.

+        Args:
+            model: The model object (PyTorch model or scikit-learn model)
+            model_name: Name for the model (without extension)
+            dataset: Name of the dataset to create
+            product: Product name for the inference
+            input_expression: Input expression for the dataset
+            dates_iso8601: List of dates in ISO8601 format
+            input_shape: Shape of input data for ONNX conversion (required for PyTorch models)
+
+        Raises:
+            APIError: If the API request fails
+            ValueError: If model type is not supported or input_shape is missing for PyTorch models
+            ImportError: If required libraries (torch or skl2onnx) are not installed
+        """
+        await self.upload_model(model=model, model_name=model_name, input_shape=input_shape)
+        # so the uploading process is kinda similar, but the deployment step is kinda different
+        await self.deploy_cnn_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601)
+
+    @require_api_key
+    async def upload_and_deploy_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None):
         """
         Upload a model to the bucket and deploy it.

         Args:
-            model_path: Path to the model file
-            dataset Name of the dataset to create
+            model: The model object (PyTorch model or scikit-learn model)
+            model_name: Name for the model (without extension)
+            dataset: Name of the dataset to create
             product: Product name for the inference
             input_expression: Input expression for the dataset
             dates_iso8601: List of dates in ISO8601 format
+            input_shape: Shape of input data for ONNX conversion (required for PyTorch models)
         """
-        self.upload_model(model_path = model_path)
-        model_name = os.path.basename(model_path)
-        self.deploy_model(dataset = dataset, product = product, model_name = model_name, input_expression = input_expression, model_training_job_name = model_name, dates_iso8601 = dates_iso8601)
+        await self.upload_model(model=model, model_name=model_name, input_shape=input_shape)
+        await self.deploy_model(dataset=dataset, product=product, model_name=model_name, input_expression=input_expression, model_training_job_name=model_name, dates_iso8601=dates_iso8601)

     @require_api_key
     def train_model(
@@ -212,7 +429,7 @@ class ModelManagement:
         return self._client._terrakio_request("POST", "/train_model", json=payload)

     @require_api_key
-    def deploy_model(
+    async def deploy_model(
         self,
         dataset: str,
         product: str,
@@ -239,7 +456,7 @@
             APIError: If the API request fails
         """
         # Get user info to get UID
-        user_info = self._client.get_user_info()
+        user_info = await self._client.auth.get_user_info()
         uid = user_info["uid"]

         # Generate and upload script
@@ -248,7 +465,53 @@
         self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)

         # Create dataset
-        return self._client.datasets.create_dataset(
+        return await self._client.datasets.create_dataset(
+            name=dataset,
+            collection="terrakio-datasets",
+            products=[product],
+            path=f"gs://terrakio-mass-requests/{uid}/{model_training_job_name}/inference_scripts",
+            input=input_expression,
+            dates_iso8601=dates_iso8601,
+            padding=0
+        )
+
+    @require_api_key
+    async def deploy_cnn_model(
+        self,
+        dataset: str,
+        product: str,
+        model_name: str,
+        input_expression: str,
+        model_training_job_name: str,
+        dates_iso8601: list
+    ) -> Dict[str, Any]:
+        """
+        Deploy a CNN model by generating inference script and creating dataset.
+
+        Args:
+            dataset: Name of the dataset to create
+            product: Product name for the inference
+            model_name: Name of the trained model
+            input_expression: Input expression for the dataset
+            model_training_job_name: Name of the training job
+            dates_iso8601: List of dates in ISO8601 format
+
+        Returns:
+            dict: Response from the deployment process
+
+        Raises:
+            APIError: If the API request fails
+        """
+        # Get user info to get UID
+        user_info = await self._client.auth.get_user_info()
+        uid = user_info["uid"]
+
+        # Generate and upload script
+        script_content = self.generate_cnn_script(model_name, product, model_training_job_name, uid)
+        script_name = f"{product}.py"
+        self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
+        # Create dataset
+        return await self._client.datasets.create_dataset(
             name=dataset,
             collection="terrakio-datasets",
             products=[product],
@@ -374,6 +637,148 @@ class ModelManagement:
             return result
         ''').strip()

+    @require_api_key
+    def generate_cnn_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
+        """
+        Generate Python inference script for CNN model with time-stacked bands.
+
+        Args:
+            model_name: Name of the model
+            product: Product name
+            model_training_job_name: Training job name
+            uid: User ID
+
+        Returns:
+            str: Generated Python script content
+        """
+        return textwrap.dedent(f'''
+            import logging
+            from io import BytesIO
+
+            import numpy as np
+            import pandas as pd
+            import xarray as xr
+            from google.cloud import storage
+            from onnxruntime import InferenceSession
+
+            logging.basicConfig(
+                level=logging.INFO
+            )
+
+            def get_model():
+                logging.info("Loading CNN model for {model_name}...")
+
+                client = storage.Client()
+                bucket = client.get_bucket('terrakio-mass-requests')
+                blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')
+
+                model = BytesIO()
+                blob.download_to_file(model)
+                model.seek(0)
+
+                session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
+                return session
+
+            def {product}(*bands, model):
+                logging.info("Start preparing CNN data with time-stacked bands")
+
+                data_arrays = list(bands)
+
+                if not data_arrays:
+                    raise ValueError("No bands provided")
+
+                reference_array = data_arrays[0]
+                original_shape = reference_array.shape
+                logging.info(f"Original shape: {{original_shape}}")
+
+                # Get time coordinates - all bands should have the same time dimension
+                if 'time' not in reference_array.dims:
+                    raise ValueError("Time dimension is required for CNN processing")
+
+                time_coords = reference_array.coords['time']
+                num_timestamps = len(time_coords)
+                logging.info(f"Number of timestamps: {{num_timestamps}}")
+
+                # Get spatial dimensions
+                spatial_dims = [dim for dim in reference_array.dims if dim != 'time']
+                height = reference_array.sizes[spatial_dims[0]]  # assuming first spatial dim is height
+                width = reference_array.sizes[spatial_dims[1]]   # assuming second spatial dim is width
+                logging.info(f"Spatial dimensions: {{height}} x {{width}}")
+
+                # Stack bands across time dimension
+                # Result will be: (num_bands * num_timestamps, height, width)
+                stacked_channels = []
+
+                for band_idx, data_array in enumerate(data_arrays):
+                    logging.info(f"Processing band {{band_idx + 1}}/{{len(data_arrays)}}")
+
+                    # Ensure consistent time coordinates across bands
+                    if not np.array_equal(data_array.coords['time'].values, time_coords.values):
+                        logging.warning(f"Band {{band_idx}} has different time coordinates, aligning...")
+                        data_array = data_array.sel(time=time_coords, method='nearest')
+
+                    # Extract values and ensure proper ordering (time, height, width)
+                    band_values = data_array.values
+                    if band_values.ndim == 3:
+                        # Reorder dimensions if needed to ensure (time, height, width)
+                        time_dim_idx = data_array.dims.index('time')
+                        if time_dim_idx != 0:
+                            axes_order = [time_dim_idx] + [i for i in range(len(data_array.dims)) if i != time_dim_idx]
+                            band_values = np.transpose(band_values, axes_order)
+
+                    # Add each timestamp of this band to the channel stack
+                    for t in range(num_timestamps):
+                        stacked_channels.append(band_values[t])
+
+                # Stack all channels: (num_bands * num_timestamps, height, width)
+                input_channels = np.stack(stacked_channels, axis=0)
+                total_channels = len(data_arrays) * num_timestamps
+                logging.info(f"Stacked channels shape: {{input_channels.shape}}")
+                logging.info(f"Total channels: {{total_channels}} ({{len(data_arrays)}} bands × {{num_timestamps}} timestamps)")
+
+                # Add batch dimension: (1, num_channels, height, width)
+                input_data = np.expand_dims(input_channels, axis=0).astype(np.float32)
+                logging.info(f"Final input shape for CNN: {{input_data.shape}}")
+
+                # Run inference
+                output = model.run(None, {{"float_input": input_data}})[0]
+                logging.info(f"Model output shape: {{output.shape}}")
+
+                # Process output back to xarray format
+                # Assuming output is (1, height, width) or (1, 1, height, width)
+                if output.ndim == 4 and output.shape[1] == 1:
+                    # Remove channel dimension if it's 1
+                    output_2d = output[0, 0]
+                elif output.ndim == 3:
+                    # Remove batch dimension
+                    output_2d = output[0]
+                else:
+                    # Handle other cases
+                    output_2d = np.squeeze(output)
+                    if output_2d.ndim != 2:
+                        raise ValueError(f"Unexpected output shape after processing: {{output_2d.shape}}")
+
+                # Determine output timestamp (use the latest timestamp)
+                output_timestamp = time_coords[-1]
+
+                # Get spatial coordinates from reference array
+                spatial_coords = {{dim: reference_array.coords[dim] for dim in spatial_dims}}
+
+                # Create output DataArray
+                result = xr.DataArray(
+                    data=np.expand_dims(output_2d.astype(np.float32), axis=0),
+                    dims=['time'] + spatial_dims,
+                    coords={{
+                        'time': [output_timestamp.values],
+                        spatial_dims[0]: spatial_coords[spatial_dims[0]].values,
+                        spatial_dims[1]: spatial_coords[spatial_dims[1]].values
+                    }}
+                )
+
+                logging.info(f"Final result shape: {{result.shape}}")
+                return result
+            ''').strip()
+
     @require_api_key
     def _upload_script_to_bucket(self, script_content: str, script_name: str, model_training_job_name: str, uid: str):
         """Upload the generated script to Google Cloud Storage"""

terrakio_core/exceptions.py CHANGED
@@ -1,7 +1,9 @@
 class APIError(Exception):
     """Exception raised for errors in the API responses."""
-    pass
-
+
+    def __init__(self, message, status_code=None):
+        super().__init__(message)
+        self.status_code = status_code

 class ConfigurationError(Exception):
     """Exception raised for errors in the configuration."""

terrakio_core-0.3.6.dist-info/METADATA → terrakio_core-0.3.8.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: terrakio-core
-Version: 0.3.6
+Version: 0.3.8
 Summary: Core components for Terrakio API clients
 Author-email: Yupeng Chao <yupeng@haizea.com.au>
 Project-URL: Homepage, https://github.com/HaizeaAnalytics/terrakio-python-api
@@ -23,6 +23,12 @@ Requires-Dist: shapely>=2.0.0
 Requires-Dist: geopandas>=0.13.0
 Requires-Dist: google-cloud-storage>=2.0.0
 Requires-Dist: nest_asyncio
+Provides-Extra: ml
+Requires-Dist: torch>=2.7.1; extra == "ml"
+Requires-Dist: scikit-learn>=1.7.0; extra == "ml"
+Requires-Dist: skl2onnx>=1.19.1; extra == "ml"
+Requires-Dist: onnx>=1.18.0; extra == "ml"
+Requires-Dist: onnxruntime>=1.10.0; extra == "ml"

 # Terrakio Core

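The ML-related dependencies above are gated behind the new ml extra, so the base install stays lean; enabling the ONNX conversion path is:

    pip install "terrakio-core[ml]"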

terrakio_core-0.3.6.dist-info/RECORD → terrakio_core-0.3.8.dist-info/RECORD CHANGED
@@ -1,21 +1,21 @@
-terrakio_core/__init__.py,sha256=3MTDxMAQZCgXE5_gREyxTGTNK4WDZHlW91Syr0axr1M,242
-terrakio_core/async_client.py,sha256=0Zz-g5X4B2NZHRJrTuJAwNBUfO596Zvn1Nga3J-PiaE,13092
+terrakio_core/__init__.py,sha256=nUk_Q29ij_1R32AjR8Ygwy0_ID4-zdVkQexvBz7reM4,242
+terrakio_core/async_client.py,sha256=wyuIJGMAzfFwfM5BYgCTMuY0YAKNYJqRNB3IF5NpVow,12730
 terrakio_core/client.py,sha256=h8GW88g6RlGwNFW6MW48c_3BnaeT9nSd19LI1jCn1GU,1008
 terrakio_core/config.py,sha256=r8NARVYOca4AuM88VP_j-8wQxOk1s7VcRdyEdseBlLE,4193
-terrakio_core/exceptions.py,sha256=9S-I20-QiDRj1qgjFyYUwYM7BLic_bxurcDOIm2Fu_0,410
+terrakio_core/exceptions.py,sha256=4qnpOM1gOxsNIXDXY4qwY1d3I4Myhp7HBh7b2D0SVrU,529
 terrakio_core/sync_client.py,sha256=v1mcBtUaKWACqZgw8dTTVPMxUfKfiY0kjtBKzDwtGTU,13634
 terrakio_core/convenience_functions/convenience_functions.py,sha256=U7bLGwfBF-FUYc0nv49pAViPsBQ6LgPlV6c6b-zeKo8,10616
 terrakio_core/endpoints/auth.py,sha256=e_hdNE6JOGhRVlQMFdEoOmoMHp5EzK6CclOEnc_AmZw,5863
-terrakio_core/endpoints/dataset_management.py,sha256=8uf6cxlSSevqnQWcldtA9Cd24D5VrmWyxkE7Ngx3IEw,13084
+terrakio_core/endpoints/dataset_management.py,sha256=BUm8IIlW_Q45vDiQp16CiJGeSLheI8uWRVRQtMdhaNk,13161
 terrakio_core/endpoints/group_management.py,sha256=VFl3jakjQa9OPi351D3DZvLU9M7fHdfjCzGhmyJsx3U,6309
-terrakio_core/endpoints/mass_stats.py,sha256=KDmIlMYy4nkehPU5Ejtb_WN9Cz5mkt_rIsyDZkTWOLA,21351
-terrakio_core/endpoints/model_management.py,sha256=1ZYymaTQ7IY191sLSS7MWvhrHLmy2VeAM2A1Ty5NhU0,15346
+terrakio_core/endpoints/mass_stats.py,sha256=y1w3QLkDD0sKP1tBcFDqgLYLNxX94I-LYbNotaKhLYM,21356
+terrakio_core/endpoints/model_management.py,sha256=Q2bqsVfBILu-hZVw1tr5WjOR68qoYF6m326YJXgAOeo,33886
 terrakio_core/endpoints/space_management.py,sha256=YWb55nkJnFJGlALJ520DvurxDqVqwYtsvqQPWzxzhDs,2266
 terrakio_core/endpoints/user_management.py,sha256=x0JW6VET7eokngmkhZPukegxoJNR1X09BVehJt2nIdI,3781
 terrakio_core/helper/bounded_taskgroup.py,sha256=wiTH10jhKZgrsgrFUNG6gig8bFkUEPHkGRT2XY7Rgmo,677
 terrakio_core/helper/decorators.py,sha256=L6om7wmWNgCei3Wy5U0aZ-70OzsCwclkjIf7SfQuhCg,2289
 terrakio_core/helper/tiles.py,sha256=xNtp3oDD912PN_FQV5fb6uQYhwfHANuXyIcxoVCCfZU,2632
-terrakio_core-0.3.6.dist-info/METADATA,sha256=b0a6IvGiQAjaN-iBUu7gmE4-oxaZLIhfw8KQ_xV8tOs,1476
-terrakio_core-0.3.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-terrakio_core-0.3.6.dist-info/top_level.txt,sha256=5cBj6O7rNWyn97ND4YuvvXm0Crv4RxttT4JZvNdOG6Q,14
-terrakio_core-0.3.6.dist-info/RECORD,,
+terrakio_core-0.3.8.dist-info/METADATA,sha256=oZlZhEda5qq8myogGbxvlV0ZJExcEf5kMaYWEVES0BE,1728
+terrakio_core-0.3.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+terrakio_core-0.3.8.dist-info/top_level.txt,sha256=5cBj6O7rNWyn97ND4YuvvXm0Crv4RxttT4JZvNdOG6Q,14
+terrakio_core-0.3.8.dist-info/RECORD,,