terrakio-core 0.3.4__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of terrakio-core might be problematic; see the registry's linked advisory for more details.

@@ -0,0 +1,790 @@
1
+ import os
2
+ import json
3
+ import time
4
+ import textwrap
5
+ import logging
6
+ from typing import Dict, Any, Union, Tuple
7
+ from io import BytesIO
8
+ import numpy as np
9
+ from google.cloud import storage
10
+ from ..helper.decorators import require_token, require_api_key, require_auth
11
# Optional ML dependencies used for ONNX conversion.
# Each *_AVAILABLE flag records whether its imports succeeded. On ImportError
# the imported names are bound to None so later references to them do not
# raise NameError.
try:
    import torch
except ImportError:
    torch = None
    TORCH_AVAILABLE = False
else:
    TORCH_AVAILABLE = True

try:
    from skl2onnx import convert_sklearn
    from skl2onnx.common.data_types import FloatTensorType
    from sklearn.base import BaseEstimator
except ImportError:
    # If any one of the three imports fails, all three names are reset.
    convert_sklearn = FloatTensorType = BaseEstimator = None
    SKL2ONNX_AVAILABLE = False
else:
    SKL2ONNX_AVAILABLE = True
29
+
30
+ from io import BytesIO
31
+ from typing import Tuple
32
+
33
class ModelManagement:
    """Train, convert, upload and deploy ML models through the Terrakio client.

    Uploaded models are stored as ONNX files. A deployment consists of two parts:
    a generated Python inference script and a dataset entry pointing at it. Both
    live in the ``terrakio-mass-requests`` Google Cloud Storage bucket under
    ``<uid>/<model_training_job_name>/``.
    """

    # GCS bucket holding uploaded models and generated inference scripts.
    _BUCKET_NAME = "terrakio-mass-requests"
    # Seconds to wait between status polls of a mass-stats job.
    _POLL_INTERVAL_SECONDS = 5

    def __init__(self, client):
        """Keep a reference to the parent client (API calls, auth, logger)."""
        self._client = client

    # ------------------------------------------------------------------
    # Dataset generation
    # ------------------------------------------------------------------

    @require_api_key
    def generate_ai_dataset(
        self,
        name: str,
        aoi_geojson: str,
        expression_x: str,
        filter_x_rate: float,
        filter_y_rate: float,
        samples: int,
        tile_size: int,
        expression_y: str = "skip",
        filter_x: str = "skip",
        filter_y: str = "skip",
        crs: str = "epsg:4326",
        res: float = 0.001,
        region: str = "aus",
        start_year: Union[int, None] = None,
        end_year: Union[int, None] = None,
    ) -> dict:
        """
        Generate an AI dataset by random sampling, then start the mass-stats job.

        The method works in three steps:

        1. Submit a ``mass_stats.random_sample`` task (NetCDF output,
           overwriting existing results).
        2. Block, polling every ``_POLL_INTERVAL_SECONDS`` seconds, until the
           task reports ``"Completed"``.
        3. Start the mass-stats job for that task.

        Args:
            name (str): Name of the dataset to generate
            aoi_geojson (str): Path to GeoJSON file containing area of interest
            expression_x (str): Expression for X variable (e.g. "MSWX.air_temperature@(year=2021, month=1)")
            filter_x_rate (float): Filter rate for X variable (e.g. 0.5)
            filter_y_rate (float): Filter rate for Y variable (e.g. 0.5)
            samples (int): Number of samples to generate
            tile_size (int): Size of tiles in degrees
            expression_y (str, optional): Expression for Y variable; may contain
                a {year} placeholder. The literal "skip" omits it.
            filter_x (str, optional): Filter for X variable. "skip" omits it.
            filter_y (str, optional): Filter for Y variable. "skip" omits it.
            crs (str, optional): Coordinate reference system. Defaults to "epsg:4326"
            res (float, optional): Resolution in degrees. Defaults to 0.001
            region (str, optional): Region code. Defaults to "aus"
            start_year (int, optional): Start year. Must be given together with end_year.
            end_year (int, optional): End year. Must be given together with start_year.

        Returns:
            dict: The value returned by ``mass_stats.start_mass_stats_job`` for
            the completed sampling task.

        Raises:
            ValueError: If exactly one of start_year / end_year is provided.
            RuntimeError: If the sampling job reports status "Error".
            APIError: If the API request fails
        """
        if (start_year is None) != (end_year is None):
            raise ValueError("start_year and end_year must be provided together")

        # NOTE(review): {year} placeholders are sent verbatim together with
        # year_range. Presumably the server expands them per year; confirm.
        config = self._build_dataset_config(
            expression_x, expression_y, filter_x, filter_y, filter_x_rate, filter_y_rate, res
        )

        with open(aoi_geojson, "r", encoding="utf-8") as f:
            aoi_data = json.load(f)

        task_response = self._client.mass_stats.random_sample(
            name=name,
            config=config,
            aoi=aoi_data,
            samples=samples,
            year_range=[start_year, end_year],
            crs=crs,
            tile_size=tile_size,
            res=res,
            region=region,
            output="netcdf",
            server=self._client.url,
            bucket=self._BUCKET_NAME,
            overwrite=True,
        )
        task_id = task_response["task_id"]

        self._wait_for_job(task_id)

        # Once the random-sample job is done, start the mass-stats job.
        return self._client.mass_stats.start_mass_stats_job(task_id)

    @staticmethod
    def _build_dataset_config(
        expression_x: str,
        expression_y: str,
        filter_x: str,
        filter_y: str,
        filter_x_rate: float,
        filter_y_rate: float,
        res: float,
    ) -> dict:
        """Assemble the ``config`` payload for ``random_sample``.

        Any expression or filter equal to the literal "skip" is omitted.
        """
        expressions = [{"expr": expression_x, "res": res, "prefix": "x"}]
        if expression_y != "skip":
            expressions.append({"expr": expression_y, "res": res, "prefix": "y"})
        filters = [
            {"expr": expr, "res": res, "rate": rate}
            for expr, rate in ((filter_x, filter_x_rate), (filter_y, filter_y_rate))
            if expr != "skip"
        ]
        return {"expressions": expressions, "filters": filters}

    @staticmethod
    def _format_progress(completed, total, bar_length: int = 50) -> str:
        """Render ``[████░░] 40.0% (2/5)``-style progress text.

        A falsy total (0 or None) is shown as 0% progress.
        """
        progress = completed / total if total else 0
        # Clamp only the bar fill so an over-reported count cannot overflow the bar.
        filled_length = int(bar_length * min(max(progress, 0), 1))
        bar = "█" * filled_length + "░" * (bar_length - filled_length)
        return f"[{bar}] {progress * 100:.1f}% ({completed}/{total})"

    def _wait_for_job(self, task_id) -> None:
        """Poll ``track_mass_stats_job`` until ``task_id`` completes.

        Raises:
            RuntimeError: If the job reports status "Error".
        """
        while True:
            job = self._client.track_mass_stats_job(ids=[task_id])[task_id]
            status = job["status"]
            completed = job.get("completed", 0) or 0
            total = job.get("total", 1) or 0

            self._client.logger.info(
                f"Job status: {status} {self._format_progress(completed, total)}"
            )

            if status == "Completed":
                self._client.logger.info("Job completed successfully!")
                return
            if status == "Error":
                self._client.logger.error("Job encountered an error")
                raise RuntimeError(f"Job {task_id} encountered an error")

            time.sleep(self._POLL_INTERVAL_SECONDS)

    # ------------------------------------------------------------------
    # Model upload / ONNX conversion
    # ------------------------------------------------------------------

    def _get_bucket(self):
        """Return the GCS bucket used for models and inference scripts."""
        return storage.Client().get_bucket(self._BUCKET_NAME)

    @require_api_key
    async def upload_model(self, model, model_name: str, input_shape: Tuple[int, ...] = None) -> str:
        """
        Upload a model to the bucket so that it can be used for inference.

        PyTorch and scikit-learn models are converted to ONNX before upload. The
        file is written to ``<uid>/<model_name>/models/<model_name>.onnx``.

        Args:
            model: The model object (PyTorch model or scikit-learn model)
            model_name: Name for the model (without extension)
            input_shape: Shape of input data for ONNX conversion (e.g. (1, 10) for
                batch_size=1, features=10). Required for PyTorch models, optional
                for fitted scikit-learn models. The batch dimension is exported
                as dynamic in both cases.

        Returns:
            str: Object path of the uploaded model inside the bucket.

        Raises:
            APIError: If the API request fails
            ValueError: If model type is not supported or input_shape is missing for PyTorch models
            ImportError: If required libraries (torch or skl2onnx) are not installed
        """
        # Convert first so unsupported models fail before any network I/O.
        onnx_bytes = self._convert_model_to_onnx(model, model_name, input_shape)

        uid = (await self._client.auth.get_user_info())["uid"]
        blob_path = f"{uid}/{model_name}/models/{model_name}.onnx"

        self._get_bucket().blob(blob_path).upload_from_string(
            onnx_bytes, content_type="application/octet-stream"
        )
        self._client.logger.info(f"Model uploaded successfully to {blob_path}")
        return blob_path

    def _convert_model_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...] = None) -> bytes:
        """
        Convert a model to ONNX format and return as bytes.

        PyTorch is checked before scikit-learn. Each check accepts an instance
        of the library base class, or a duck-typed object: one with
        ``state_dict`` for PyTorch, or with ``fit`` and ``predict`` for sklearn.

        Args:
            model: The model object (PyTorch or scikit-learn)
            model_name: Name of the model for logging
            input_shape: Shape of input data

        Returns:
            bytes: ONNX model as bytes

        Raises:
            ValueError: If model type is not supported
            ImportError: If required libraries are not installed
        """
        if not (TORCH_AVAILABLE or SKL2ONNX_AVAILABLE):
            raise ImportError(
                "ONNX conversion requires additional dependencies. Install with:\n"
                "  pip install torch  # For PyTorch models\n"
                "  pip install skl2onnx  # For scikit-learn models\n"
                "  pip install torch skl2onnx  # For both"
            )

        if TORCH_AVAILABLE and (isinstance(model, torch.nn.Module) or hasattr(model, "state_dict")):
            return self._convert_pytorch_to_onnx(model, model_name, input_shape)

        if SKL2ONNX_AVAILABLE and (
            isinstance(model, BaseEstimator) or (hasattr(model, "fit") and hasattr(model, "predict"))
        ):
            return self._convert_sklearn_to_onnx(model, model_name, input_shape)

        model_type = type(model).__name__
        model_module = type(model).__module__
        missing_deps = [
            dep
            for dep, available in (("torch", TORCH_AVAILABLE), ("skl2onnx", SKL2ONNX_AVAILABLE))
            if not available
        ]
        if missing_deps:
            raise ImportError(
                f"Model type {model_type} from {model_module} detected, but required dependencies missing: {', '.join(missing_deps)}. "
                f"Install with: pip install {' '.join(missing_deps)}"
            )
        # Reaching here means both libraries are installed.
        raise ValueError(
            f"Unsupported model type: {model_type} from {model_module}. "
            f"Supported types: PyTorch (torch.nn.Module), scikit-learn (BaseEstimator)"
        )

    @staticmethod
    def _pytorch_dynamic_axes(input_shape: Tuple[int, ...]) -> dict:
        """Choose which input axes ``torch.onnx.export`` should mark dynamic.

        - Rank 4 (batch, channels, height, width): batch, height and width are dynamic.
        - Rank 3: axes 0-2 are dynamic.
        - Any other rank, e.g. (batch, features): only the batch axis is dynamic.
          Feature counts usually fix the architecture.

        The output's batch axis is always dynamic.
        """
        rank = len(input_shape)
        if rank == 4:
            input_axes = {0: "batch_size", 2: "height", 3: "width"}
        elif rank == 3:
            input_axes = {0: "batch_size", 1: "dim1", 2: "dim2"}
        else:
            input_axes = {0: "batch_size"}
        return {"float_input": input_axes, "output": {0: "batch_size"}}

    def _convert_pytorch_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...]) -> bytes:
        """Convert a PyTorch model to ONNX (opset 11) with dynamic input dimensions.

        The model is switched to eval mode for export. Its previous training
        mode is restored afterwards, so the caller's model is left unchanged.

        Raises:
            ValueError: If input_shape is None or the export fails. Export
                failures are chained to the original exception.
        """
        if input_shape is None:
            raise ValueError("input_shape is required for PyTorch models")

        self._client.logger.info(f"Converting PyTorch model {model_name} to ONNX...")

        was_training = getattr(model, "training", False)
        try:
            model.eval()
            dummy_input = torch.randn(input_shape)
            # Export into memory to avoid temporary files.
            onnx_buffer = BytesIO()
            dynamic_axes = self._pytorch_dynamic_axes(input_shape)

            torch.onnx.export(
                model,
                dummy_input,
                onnx_buffer,
                export_params=True,
                opset_version=11,
                do_constant_folding=True,
                input_names=["float_input"],
                output_names=["output"],
                dynamic_axes=dynamic_axes,
            )
        except Exception as e:
            raise ValueError(f"Failed to convert PyTorch model {model_name} to ONNX: {str(e)}") from e
        finally:
            if was_training:
                model.train()

        self._client.logger.info(f"Successfully converted {model_name} with dynamic axes: {dynamic_axes}")
        return onnx_buffer.getvalue()

    def _convert_sklearn_to_onnx(self, model, model_name: str, input_shape: Tuple[int, ...] = None) -> bytes:
        """Convert a scikit-learn model to ONNX.

        The input is named ``float_input``. The feature count is taken from
        ``input_shape[1:]`` or, when input_shape is None, from the fitted
        model's ``n_features_in_``. The batch dimension is always exported as
        symbolic (None), matching the PyTorch path's dynamic ``batch_size``
        axis. This lets the deployed inference script pass any number of rows.

        Raises:
            ValueError: If the feature count cannot be determined or the
                conversion fails. Conversion failures are chained to the
                original exception.
        """
        self._client.logger.info(f"Converting scikit-learn model {model_name} to ONNX...")

        if input_shape is None:
            if not hasattr(model, "n_features_in_"):
                raise ValueError(
                    "input_shape is required for scikit-learn models when n_features_in_ is not available. "
                    "This usually happens with older sklearn versions or models not fitted yet."
                )
            input_shape = (None, model.n_features_in_)

        # A fixed batch size (e.g. 1) would reject multi-row inference input.
        shape = [None, *input_shape[1:]] if len(input_shape) > 1 else list(input_shape)

        try:
            initial_type = [("float_input", FloatTensorType(shape))]
            onnx_model = convert_sklearn(model, initial_types=initial_type)
            return onnx_model.SerializeToString()
        except Exception as e:
            raise ValueError(f"Failed to convert scikit-learn model {model_name} to ONNX: {str(e)}") from e

    # ------------------------------------------------------------------
    # Upload + deploy convenience wrappers
    # ------------------------------------------------------------------

    @require_api_key
    async def upload_and_deploy_cnn_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None):
        """
        Upload a CNN model to the bucket and deploy it.

        The model name doubles as the training-job name, so the deployed script
        loads the model from the path ``upload_model`` wrote.

        Args:
            model: The model object (PyTorch model or scikit-learn model)
            model_name: Name for the model (without extension)
            dataset: Name of the dataset to create
            product: Product name for the inference (must be a Python identifier)
            input_expression: Input expression for the dataset
            dates_iso8601: List of dates in ISO8601 format
            input_shape: Shape of input data for ONNX conversion (required for PyTorch models)

        Returns:
            The result of ``deploy_cnn_model``.

        Raises:
            APIError: If the API request fails
            ValueError: If model type is not supported, input_shape is missing for PyTorch models,
                or product is not a valid identifier
            ImportError: If required libraries (torch or skl2onnx) are not installed
        """
        await self.upload_model(model=model, model_name=model_name, input_shape=input_shape)
        return await self.deploy_cnn_model(
            dataset=dataset,
            product=product,
            model_name=model_name,
            input_expression=input_expression,
            model_training_job_name=model_name,
            dates_iso8601=dates_iso8601,
        )

    @require_api_key
    async def upload_and_deploy_model(self, model, model_name: str, dataset: str, product: str, input_expression: str, dates_iso8601: list, input_shape: Tuple[int, ...] = None):
        """
        Upload a model to the bucket and deploy it with the per-pixel inference script.

        The model name doubles as the training-job name, so the deployed script
        loads the model from the path ``upload_model`` wrote.

        Args:
            model: The model object (PyTorch model or scikit-learn model)
            model_name: Name for the model (without extension)
            dataset: Name of the dataset to create
            product: Product name for the inference (must be a Python identifier)
            input_expression: Input expression for the dataset
            dates_iso8601: List of dates in ISO8601 format
            input_shape: Shape of input data for ONNX conversion (required for PyTorch models)

        Returns:
            The result of ``deploy_model``.
        """
        await self.upload_model(model=model, model_name=model_name, input_shape=input_shape)
        return await self.deploy_model(
            dataset=dataset,
            product=product,
            model_name=model_name,
            input_expression=input_expression,
            model_training_job_name=model_name,
            dates_iso8601=dates_iso8601,
        )

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    @require_api_key
    def train_model(
        self,
        model_name: str,
        training_dataset: str,
        task_type: str,
        model_category: str,
        architecture: str,
        region: str,
        hyperparameters: dict = None,
    ) -> dict:
        """
        Train a model using the external model training API.

        POSTs the arguments as a JSON payload to ``/train_model``. A None
        ``hyperparameters`` is sent as JSON null.

        Args:
            model_name (str): The name of the model to train.
            training_dataset (str): The training dataset identifier.
            task_type (str): The type of ML task (e.g., regression, classification).
            model_category (str): The category of model (e.g., random_forest).
            architecture (str): The model architecture.
            region (str): The region identifier.
            hyperparameters (dict, optional): Additional hyperparameters for training.

        Returns:
            dict: The response from the model training API.

        Raises:
            APIError: If the API request fails
        """
        payload = {
            "model_name": model_name,
            "training_dataset": training_dataset,
            "task_type": task_type,
            "model_category": model_category,
            "architecture": architecture,
            "region": region,
            "hyperparameters": hyperparameters,
        }
        return self._client._terrakio_request("POST", "/train_model", json=payload)

    # ------------------------------------------------------------------
    # Deployment
    # ------------------------------------------------------------------

    @staticmethod
    def _validate_product_name(product: str) -> None:
        """Ensure ``product`` can name the generated inference function.

        Raises:
            ValueError: If product is not a valid, non-keyword Python identifier.
        """
        import keyword

        if not isinstance(product, str) or not product.isidentifier() or keyword.iskeyword(product):
            raise ValueError(
                f"product must be a valid Python identifier (it names the generated "
                f"inference function), got {product!r}"
            )

    async def _deploy_with_script(
        self,
        script_generator,
        dataset: str,
        product: str,
        model_name: str,
        input_expression: str,
        model_training_job_name: str,
        dates_iso8601: list,
    ) -> Dict[str, Any]:
        """Shared deploy flow: generate and upload ``<product>.py``, then create the dataset.

        The steps are:

        1. Call ``script_generator(model_name, product, model_training_job_name, uid)``
           to produce the script source.
        2. Upload it to ``<uid>/<model_training_job_name>/inference_scripts/<product>.py``.
        3. Create a dataset in the ``terrakio-datasets`` collection pointing at
           that directory.
        """
        # Fail before any network call if the product name cannot be a function name.
        self._validate_product_name(product)

        uid = (await self._client.auth.get_user_info())["uid"]

        script_content = script_generator(model_name, product, model_training_job_name, uid)
        self._upload_script_to_bucket(script_content, f"{product}.py", model_training_job_name, uid)

        return await self._client.datasets.create_dataset(
            name=dataset,
            collection="terrakio-datasets",
            products=[product],
            path=f"gs://{self._BUCKET_NAME}/{uid}/{model_training_job_name}/inference_scripts",
            input=input_expression,
            dates_iso8601=dates_iso8601,
            padding=0,
        )

    @require_api_key
    async def deploy_model(
        self,
        dataset: str,
        product: str,
        model_name: str,
        input_expression: str,
        model_training_job_name: str,
        dates_iso8601: list,
    ) -> Dict[str, Any]:
        """
        Deploy a model by generating its per-pixel inference script and creating a dataset.

        Args:
            dataset: Name of the dataset to create
            product: Product name for the inference (must be a Python identifier)
            model_name: Name of the trained model
            input_expression: Input expression for the dataset
            model_training_job_name: Name of the training job
            dates_iso8601: List of dates in ISO8601 format

        Returns:
            dict: Response from the deployment process

        Raises:
            ValueError: If product is not a valid Python identifier
            APIError: If the API request fails
        """
        return await self._deploy_with_script(
            self._generate_script,
            dataset=dataset,
            product=product,
            model_name=model_name,
            input_expression=input_expression,
            model_training_job_name=model_training_job_name,
            dates_iso8601=dates_iso8601,
        )

    @require_api_key
    async def deploy_cnn_model(
        self,
        dataset: str,
        product: str,
        model_name: str,
        input_expression: str,
        model_training_job_name: str,
        dates_iso8601: list,
    ) -> Dict[str, Any]:
        """
        Deploy a CNN model by generating its time-stacked inference script and creating a dataset.

        Args:
            dataset: Name of the dataset to create
            product: Product name for the inference (must be a Python identifier)
            model_name: Name of the trained model
            input_expression: Input expression for the dataset
            model_training_job_name: Name of the training job
            dates_iso8601: List of dates in ISO8601 format

        Returns:
            dict: Response from the deployment process

        Raises:
            ValueError: If product is not a valid Python identifier
            APIError: If the API request fails
        """
        return await self._deploy_with_script(
            self.generate_cnn_script,
            dataset=dataset,
            product=product,
            model_name=model_name,
            input_expression=input_expression,
            model_training_job_name=model_training_job_name,
            dates_iso8601=dates_iso8601,
        )

    @require_api_key
    def _generate_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
        """
        Generate the per-pixel Python inference script for a tabular model.

        The script defines ``get_model()``, which loads the ONNX model from
        ``<uid>/<model_training_job_name>/models/<model_name>.onnx``. It also
        defines ``<product>(*bands, model)``, which works as follows:

        1. Average each band over its first axis when it has a ``time`` dim.
        2. Flatten each band to one column.
        3. Run the model on the resulting ``(pixels, bands)`` matrix.
        4. Reshape the predictions back onto the spatial grid, under a single
           time label.

        Args:
            model_name: Name of the model
            product: Product name (used as the function name)
            model_training_job_name: Training job name
            uid: User ID

        Returns:
            str: Generated Python script content

        Raises:
            ValueError: If product is not a valid Python identifier
        """
        self._validate_product_name(product)
        # repr() yields valid, correctly escaped Python literals for the script.
        blob_path = f"{uid}/{model_training_job_name}/models/{model_name}.onnx"
        load_message = f"Loading model for {model_name}..."
        return textwrap.dedent(f'''
            import logging
            from io import BytesIO

            import numpy as np
            import pandas as pd
            import xarray as xr
            from google.cloud import storage
            from onnxruntime import InferenceSession

            logging.basicConfig(
                level=logging.INFO
            )

            def get_model():
                logging.info({load_message!r})

                client = storage.Client()
                bucket = client.get_bucket({self._BUCKET_NAME!r})
                blob = bucket.blob({blob_path!r})

                model = BytesIO()
                blob.download_to_file(model)
                model.seek(0)

                session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
                return session

            def {product}(*bands, model):
                logging.info("start preparing data")

                data_arrays = list(bands)

                reference_array = data_arrays[0]
                original_shape = reference_array.shape
                logging.info(f"Original shape: {{original_shape}}")

                # Pick the time label for the single output slice, as a numpy datetime64.
                if 'time' in reference_array.dims:
                    time_values = reference_array.coords['time'].values
                    if len(time_values) == 1:
                        output_time = time_values[0]
                    else:
                        # Several timestamps: label with Jan 1st of the latest year present.
                        latest_year = max(pd.to_datetime(t).year for t in time_values)
                        output_time = pd.Timestamp(f"{{latest_year}}-01-01").to_datetime64()
                else:
                    output_time = pd.Timestamp("1970-01-01").to_datetime64()

                averaged_bands = []
                for data_array in data_arrays:
                    if 'time' in data_array.dims:
                        averaged_band = np.mean(data_array.values, axis=0)
                        logging.info(f"Averaged band from {{data_array.shape}} to {{averaged_band.shape}}")
                    else:
                        averaged_band = data_array.values
                        logging.info(f"No time dimension, shape: {{averaged_band.shape}}")

                    averaged_bands.append(averaged_band.reshape(-1, 1))

                input_data = np.hstack(averaged_bands)

                logging.info(f"Final input shape: {{input_data.shape}}")

                output = model.run(None, {{"float_input": input_data.astype(np.float32)}})[0]

                logging.info(f"Model output shape: {{output.shape}}")

                if len(original_shape) >= 3:
                    spatial_shape = original_shape[1:]
                else:
                    spatial_shape = original_shape

                output_reshaped = output.reshape(spatial_shape)

                output_with_time = np.expand_dims(output_reshaped, axis=0)

                if 'time' in reference_array.dims:
                    spatial_dims = [dim for dim in reference_array.dims if dim != 'time']
                else:
                    spatial_dims = list(reference_array.dims)

                # Carry over the coordinate of every spatial dim that has one.
                spatial_coords = {{
                    dim: reference_array.coords[dim].values
                    for dim in spatial_dims
                    if dim in reference_array.coords
                }}

                result = xr.DataArray(
                    data=output_with_time.astype(np.float32),
                    dims=['time'] + spatial_dims,
                    coords={{'time': [output_time], **spatial_coords}},
                )
                return result
        ''').strip()

    @require_api_key
    def generate_cnn_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
        """
        Generate Python inference script for CNN model with time-stacked bands.

        The script defines ``get_model()``, which loads the ONNX model from
        ``<uid>/<model_training_job_name>/models/<model_name>.onnx``. It also
        defines ``<product>(*bands, model)``, which works as follows:

        1. Stack every (band, timestamp) slice into channels.
        2. Run the model on a ``(1, channels, H, W)`` input.
        3. Return a single-time DataArray labelled with the latest input
           timestamp.

        Args:
            model_name: Name of the model
            product: Product name (used as the function name)
            model_training_job_name: Training job name
            uid: User ID

        Returns:
            str: Generated Python script content

        Raises:
            ValueError: If product is not a valid Python identifier
        """
        self._validate_product_name(product)
        # repr() yields valid, correctly escaped Python literals for the script.
        blob_path = f"{uid}/{model_training_job_name}/models/{model_name}.onnx"
        load_message = f"Loading CNN model for {model_name}..."
        return textwrap.dedent(f'''
            import logging
            from io import BytesIO

            import numpy as np
            import pandas as pd
            import xarray as xr
            from google.cloud import storage
            from onnxruntime import InferenceSession

            logging.basicConfig(
                level=logging.INFO
            )

            def get_model():
                logging.info({load_message!r})

                client = storage.Client()
                bucket = client.get_bucket({self._BUCKET_NAME!r})
                blob = bucket.blob({blob_path!r})

                model = BytesIO()
                blob.download_to_file(model)
                model.seek(0)

                session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
                return session

            def {product}(*bands, model):
                logging.info("Start preparing CNN data with time-stacked bands")

                data_arrays = list(bands)

                if not data_arrays:
                    raise ValueError("No bands provided")

                reference_array = data_arrays[0]
                original_shape = reference_array.shape
                logging.info(f"Original shape: {{original_shape}}")

                # Get time coordinates - all bands should have the same time dimension
                if 'time' not in reference_array.dims:
                    raise ValueError("Time dimension is required for CNN processing")

                time_coords = reference_array.coords['time']
                num_timestamps = len(time_coords)
                logging.info(f"Number of timestamps: {{num_timestamps}}")

                # Get spatial dimensions
                spatial_dims = [dim for dim in reference_array.dims if dim != 'time']
                height = reference_array.sizes[spatial_dims[0]]  # assuming first spatial dim is height
                width = reference_array.sizes[spatial_dims[1]]  # assuming second spatial dim is width
                logging.info(f"Spatial dimensions: {{height}} x {{width}}")

                # Stack bands across time dimension
                # Result will be: (num_bands * num_timestamps, height, width)
                stacked_channels = []

                for band_idx, data_array in enumerate(data_arrays):
                    logging.info(f"Processing band {{band_idx + 1}}/{{len(data_arrays)}}")

                    # Ensure consistent time coordinates across bands
                    if not np.array_equal(data_array.coords['time'].values, time_coords.values):
                        logging.warning(f"Band {{band_idx}} has different time coordinates, aligning...")
                        data_array = data_array.sel(time=time_coords, method='nearest')

                    # Extract values and ensure proper ordering (time, height, width)
                    band_values = data_array.values
                    if band_values.ndim == 3:
                        # Reorder dimensions if needed to ensure (time, height, width)
                        time_dim_idx = data_array.dims.index('time')
                        if time_dim_idx != 0:
                            axes_order = [time_dim_idx] + [i for i in range(len(data_array.dims)) if i != time_dim_idx]
                            band_values = np.transpose(band_values, axes_order)

                    # Add each timestamp of this band to the channel stack
                    for t in range(num_timestamps):
                        stacked_channels.append(band_values[t])

                # Stack all channels: (num_bands * num_timestamps, height, width)
                input_channels = np.stack(stacked_channels, axis=0)
                total_channels = len(data_arrays) * num_timestamps
                logging.info(f"Stacked channels shape: {{input_channels.shape}}")
                logging.info(f"Total channels: {{total_channels}} ({{len(data_arrays)}} bands × {{num_timestamps}} timestamps)")

                # Add batch dimension: (1, num_channels, height, width)
                input_data = np.expand_dims(input_channels, axis=0).astype(np.float32)
                logging.info(f"Final input shape for CNN: {{input_data.shape}}")

                # Run inference
                output = model.run(None, {{"float_input": input_data}})[0]
                logging.info(f"Model output shape: {{output.shape}}")

                # Process output back to xarray format
                # Assuming output is (1, height, width) or (1, 1, height, width)
                if output.ndim == 4 and output.shape[1] == 1:
                    # Remove channel dimension if it's 1
                    output_2d = output[0, 0]
                elif output.ndim == 3:
                    # Remove batch dimension
                    output_2d = output[0]
                else:
                    # Handle other cases
                    output_2d = np.squeeze(output)
                    if output_2d.ndim != 2:
                        raise ValueError(f"Unexpected output shape after processing: {{output_2d.shape}}")

                # Determine output timestamp (use the latest timestamp)
                output_timestamp = time_coords[-1]

                # Get spatial coordinates from reference array
                spatial_coords = {{dim: reference_array.coords[dim] for dim in spatial_dims}}

                # Create output DataArray
                result = xr.DataArray(
                    data=np.expand_dims(output_2d.astype(np.float32), axis=0),
                    dims=['time'] + spatial_dims,
                    coords={{
                        'time': [output_timestamp.values],
                        spatial_dims[0]: spatial_coords[spatial_dims[0]].values,
                        spatial_dims[1]: spatial_coords[spatial_dims[1]].values
                    }}
                )

                logging.info(f"Final result shape: {{result.shape}}")
                return result
        ''').strip()

    @require_api_key
    def _upload_script_to_bucket(self, script_content: str, script_name: str, model_training_job_name: str, uid: str):
        """Upload a generated script to
        ``<uid>/<model_training_job_name>/inference_scripts/<script_name>`` in the bucket.

        The upload uses content type ``text/plain`` and is logged through the
        client's logger.
        """
        blob_path = f"{uid}/{model_training_job_name}/inference_scripts/{script_name}"
        self._get_bucket().blob(blob_path).upload_from_string(script_content, content_type="text/plain")
        self._client.logger.info(f"Script uploaded successfully to {blob_path}")