terrakio-core 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -0,0 +1,385 @@
+ import os
+ import json
+ import time
+ import textwrap
+ import logging
+ from typing import Dict, Any, Optional
+ from google.cloud import storage
+ from ..helper.decorators import require_token, require_api_key, require_auth
+
+ class ModelManagement:
+     def __init__(self, client):
+         self._client = client
+
+     @require_api_key
+     def generate_ai_dataset(
+         self,
+         name: str,
+         aoi_geojson: str,
+         expression_x: str,
+         filter_x_rate: float,
+         filter_y_rate: float,
+         samples: int,
+         tile_size: int,
+         expression_y: str = "skip",
+         filter_x: str = "skip",
+         filter_y: str = "skip",
+         crs: str = "epsg:4326",
+         res: float = 0.001,
+         region: str = "aus",
+         start_year: Optional[int] = None,
+         end_year: Optional[int] = None,
+     ) -> dict:
+         """
+         Generate an AI dataset using the specified parameters.
+
+         Args:
+             name (str): Name of the dataset to generate
+             aoi_geojson (str): Path to a GeoJSON file containing the area of interest
+             expression_x (str): Expression for the X variable (e.g. "MSWX.air_temperature@(year=2021, month=1)")
+             filter_x_rate (float): Filter rate for the X variable (e.g. 0.5)
+             filter_y_rate (float): Filter rate for the Y variable (e.g. 0.5)
+             samples (int): Number of samples to generate
+             tile_size (int): Size of tiles in degrees
+             expression_y (str, optional): Expression for the Y variable with a {year} placeholder. Defaults to "skip"
+             filter_x (str, optional): Filter for the X variable (e.g. "MSWX.air_temperature@(year=2021, month=1)"). Defaults to "skip"
+             filter_y (str, optional): Filter for the Y variable. Defaults to "skip"
+             crs (str, optional): Coordinate reference system. Defaults to "epsg:4326"
+             res (float, optional): Resolution in degrees. Defaults to 0.001
+             region (str, optional): Region code. Defaults to "aus"
+             start_year (int, optional): Start year for data generation. Required if end_year is provided
+             end_year (int, optional): End year for data generation. Required if start_year is provided
+
+         Returns:
+             dict: Response from the AI dataset generation API
+
+         Raises:
+             APIError: If the API request fails
+         """
+         # Substitute the {year} placeholder first, so the config below is built
+         # from the resolved expressions rather than the raw templates
+         if start_year is not None:
+             expression_x = expression_x.replace("{year}", str(start_year))
+             if expression_y != "skip":
+                 expression_y = expression_y.replace("{year}", str(start_year))
+             if filter_x != "skip":
+                 filter_x = filter_x.replace("{year}", str(start_year))
+             if filter_y != "skip":
+                 filter_y = filter_y.replace("{year}", str(start_year))
+
+         # Build config for expressions and filters
+         config = {
+             "expressions": [{"expr": expression_x, "res": res, "prefix": "x"}],
+             "filters": []
+         }
+
+         if expression_y != "skip":
+             config["expressions"].append({"expr": expression_y, "res": res, "prefix": "y"})
+
+         if filter_x != "skip":
+             config["filters"].append({"expr": filter_x, "res": res, "rate": filter_x_rate})
+         if filter_y != "skip":
+             config["filters"].append({"expr": filter_y, "res": res, "rate": filter_y_rate})
+
+         # Load AOI GeoJSON
+         with open(aoi_geojson, 'r') as f:
+             aoi_data = json.load(f)
+
+         task_response = self._client.mass_stats.random_sample(
+             name=name,
+             config=config,
+             aoi=aoi_data,
+             samples=samples,
+             year_range=[start_year, end_year],
+             crs=crs,
+             tile_size=tile_size,
+             res=res,
+             region=region,
+             output="netcdf",
+             server=self._client.url,
+             bucket="terrakio-mass-requests",
+             overwrite=True
+         )
+         task_id = task_response["task_id"]
+
+         # Poll for job completion, rendering a simple progress bar
+         while True:
+             result = self._client.track_mass_stats_job(ids=[task_id])
+             status = result[task_id]['status']
+             completed = result[task_id].get('completed', 0)
+             total = result[task_id].get('total', 1)
+
+             progress = completed / total if total > 0 else 0
+             bar_length = 50
+             filled_length = int(bar_length * progress)
+             bar = '█' * filled_length + '░' * (bar_length - filled_length)
+             percentage = progress * 100
+
+             print(f"\rJob status: {status} [{bar}] {percentage:.1f}% ({completed}/{total})", end='')
+
+             if status == "Completed":
+                 print("\nJob completed successfully!")
+                 break
+             elif status == "Error":
+                 print()  # New line before the error message
+                 raise Exception(f"Job {task_id} encountered an error")
+
+             # Wait 5 seconds before checking again
+             time.sleep(5)
+
+         # Once all the random-sample jobs are done, start the mass stats job
+         task_id = self._client.mass_stats.start_mass_stats_job(task_id)
+         return task_id
+
+     @require_api_key
+     async def upload_model(self, model_path: str):
+         """
+         Upload a model to the bucket so that it can be used for inference.
+
+         Args:
+             model_path: Path to the model file
+
+         Raises:
+             APIError: If the API request fails
+         """
+         uid = (await self._client.auth.get_user_info())["uid"]
+         model_name = os.path.basename(model_path)
+
+         client = storage.Client()
+         bucket = client.get_bucket('terrakio-mass-requests')
+         model_file_name = os.path.splitext(model_name)[0]
+         blob = bucket.blob(f'{uid}/{model_file_name}/models/{model_name}')
+
+         blob.upload_from_filename(model_path)
+         self._client.logger.info(f"Model uploaded successfully to {uid}/{model_file_name}/models/{model_name}")
+
+     @require_api_key
+     async def upload_and_deploy_model(self, model_path: str, dataset: str, product: str, input_expression: str, dates_iso8601: list):
+         """
+         Upload a model to the bucket and deploy it.
+
+         Args:
+             model_path: Path to the model file
+             dataset: Name of the dataset to create
+             product: Product name for the inference
+             input_expression: Input expression for the dataset
+             dates_iso8601: List of dates in ISO8601 format
+         """
+         # upload_model is a coroutine, so this wrapper must be async and await it
+         await self.upload_model(model_path=model_path)
+         model_name = os.path.basename(model_path)
+         self.deploy_model(
+             dataset=dataset,
+             product=product,
+             model_name=model_name,
+             input_expression=input_expression,
+             model_training_job_name=model_name,
+             dates_iso8601=dates_iso8601,
+         )
+
+     @require_api_key
+     def train_model(
+         self,
+         model_name: str,
+         training_dataset: str,
+         task_type: str,
+         model_category: str,
+         architecture: str,
+         region: str,
+         hyperparameters: Optional[dict] = None
+     ) -> dict:
+         """
+         Train a model using the external model training API.
+
+         Args:
+             model_name (str): The name of the model to train.
+             training_dataset (str): The training dataset identifier.
+             task_type (str): The type of ML task (e.g., regression, classification).
+             model_category (str): The category of model (e.g., random_forest).
+             architecture (str): The model architecture.
+             region (str): The region identifier.
+             hyperparameters (dict, optional): Additional hyperparameters for training.
+
+         Returns:
+             dict: The response from the model training API.
+
+         Raises:
+             APIError: If the API request fails
+         """
+         payload = {
+             "model_name": model_name,
+             "training_dataset": training_dataset,
+             "task_type": task_type,
+             "model_category": model_category,
+             "architecture": architecture,
+             "region": region,
+             "hyperparameters": hyperparameters
+         }
+         return self._client._terrakio_request("POST", "/train_model", json=payload)
+
+     @require_api_key
+     def deploy_model(
+         self,
+         dataset: str,
+         product: str,
+         model_name: str,
+         input_expression: str,
+         model_training_job_name: str,
+         dates_iso8601: list
+     ) -> Dict[str, Any]:
+         """
+         Deploy a model by generating an inference script and creating a dataset.
+
+         Args:
+             dataset: Name of the dataset to create
+             product: Product name for the inference
+             model_name: Name of the trained model
+             input_expression: Input expression for the dataset
+             model_training_job_name: Name of the training job
+             dates_iso8601: List of dates in ISO8601 format
+
+         Returns:
+             dict: Response from the deployment process
+
+         Raises:
+             APIError: If the API request fails
+         """
+         # Get user info to obtain the UID
+         user_info = self._client.get_user_info()
+         uid = user_info["uid"]
+
+         # Generate and upload the inference script
+         script_content = self._generate_script(model_name, product, model_training_job_name, uid)
+         script_name = f"{product}.py"
+         self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
+
+         # Create the dataset
+         return self._client.datasets.create_dataset(
+             name=dataset,
+             collection="terrakio-datasets",
+             products=[product],
+             path=f"gs://terrakio-mass-requests/{uid}/{model_training_job_name}/inference_scripts",
+             input=input_expression,
+             dates_iso8601=dates_iso8601,
+             padding=0
+         )
+
+     @require_api_key
+     def _generate_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
+         """
+         Generate the Python inference script for the model.
+
+         Args:
+             model_name: Name of the model
+             product: Product name
+             model_training_job_name: Training job name
+             uid: User ID
+
+         Returns:
+             str: Generated Python script content
+         """
+         return textwrap.dedent(f'''
+             import logging
+             from io import BytesIO
+
+             import numpy as np
+             import pandas as pd
+             import xarray as xr
+             from google.cloud import storage
+             from onnxruntime import InferenceSession
+
+             logging.basicConfig(
+                 level=logging.INFO
+             )
+
+             def get_model():
+                 logging.info("Loading model for {model_name}...")
+
+                 client = storage.Client()
+                 bucket = client.get_bucket('terrakio-mass-requests')
+                 blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')
+
+                 model = BytesIO()
+                 blob.download_to_file(model)
+                 model.seek(0)
+
+                 session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
+                 return session
+
+             def {product}(*bands, model):
+                 logging.info("start preparing data")
+
+                 data_arrays = list(bands)
+
+                 reference_array = data_arrays[0]
+                 original_shape = reference_array.shape
+                 logging.info(f"Original shape: {{original_shape}}")
+
+                 # Pick a single output timestamp from the input time coordinates
+                 if 'time' in reference_array.dims:
+                     time_coords = reference_array.coords['time']
+                     if len(time_coords) == 1:
+                         output_timestamp = pd.Timestamp(time_coords.values[0])
+                     else:
+                         years = [pd.to_datetime(t).year for t in time_coords.values]
+                         unique_years = set(years)
+
+                         if len(unique_years) == 1:
+                             year = list(unique_years)[0]
+                             output_timestamp = pd.Timestamp(f"{{year}}-01-01")
+                         else:
+                             latest_year = max(unique_years)
+                             output_timestamp = pd.Timestamp(f"{{latest_year}}-01-01")
+                 else:
+                     output_timestamp = pd.Timestamp("1970-01-01")
+
+                 # Average each band over time and flatten it into one feature column
+                 averaged_bands = []
+                 for data_array in data_arrays:
+                     if 'time' in data_array.dims:
+                         averaged_band = np.mean(data_array.values, axis=0)
+                         logging.info(f"Averaged band from {{data_array.shape}} to {{averaged_band.shape}}")
+                     else:
+                         averaged_band = data_array.values
+                         logging.info(f"No time dimension, shape: {{averaged_band.shape}}")
+
+                     flattened_band = averaged_band.reshape(-1, 1)
+                     averaged_bands.append(flattened_band)
+
+                 input_data = np.hstack(averaged_bands)
+
+                 logging.info(f"Final input shape: {{input_data.shape}}")
+
+                 output = model.run(None, {{"float_input": input_data.astype(np.float32)}})[0]
+
+                 logging.info(f"Model output shape: {{output.shape}}")
+
+                 if len(original_shape) >= 3:
+                     spatial_shape = original_shape[1:]
+                 else:
+                     spatial_shape = original_shape
+
+                 output_reshaped = output.reshape(spatial_shape)
+
+                 # Re-attach a leading time axis of length one
+                 output_with_time = np.expand_dims(output_reshaped, axis=0)
+
+                 if 'time' in reference_array.dims:
+                     spatial_dims = [dim for dim in reference_array.dims if dim != 'time']
+                     spatial_coords = {{dim: reference_array.coords[dim] for dim in spatial_dims if dim in reference_array.coords}}
+                 else:
+                     spatial_dims = list(reference_array.dims)
+                     spatial_coords = dict(reference_array.coords)
+
+                 result = xr.DataArray(
+                     data=output_with_time.astype(np.float32),
+                     dims=['time'] + list(spatial_dims),
+                     coords={{
+                         # output_timestamp is a pd.Timestamp in every branch above
+                         'time': [output_timestamp],
+                         'y': spatial_coords['y'].values,
+                         'x': spatial_coords['x'].values
+                     }}
+                 )
+                 return result
+         ''').strip()
+
+     @require_api_key
+     def _upload_script_to_bucket(self, script_content: str, script_name: str, model_training_job_name: str, uid: str):
+         """Upload the generated script to Google Cloud Storage."""
+         client = storage.Client()
+         bucket = client.get_bucket('terrakio-mass-requests')
+         blob = bucket.blob(f'{uid}/{model_training_job_name}/inference_scripts/{script_name}')
+         blob.upload_from_string(script_content, content_type='text/plain')
+         logging.info(f"Script uploaded successfully to {uid}/{model_training_job_name}/inference_scripts/{script_name}")
@@ -0,0 +1,72 @@
+ from typing import Dict, Any, Optional
+ from ..helper.decorators import require_token, require_api_key, require_auth
+
+ class SpaceManagement:
+     def __init__(self, client):
+         self._client = client
+
+     @require_api_key
+     def get_total_space_used(self) -> Dict[str, Any]:
+         """
+         Get the total space used by the user.
+
+         Returns:
+             Dict[str, Any]: Total space used by the user.
+
+         Raises:
+             APIError: If the API request fails
+         """
+         return self._client._terrakio_request("GET", "/users/jobs")
+
+     @require_api_key
+     def get_space_used_by_job(self, name: str, region: str) -> Dict[str, Any]:
+         """
+         Get the space used by a specific job.
+
+         Args:
+             name: The name of the job
+             region: The region of the job
+
+         Returns:
+             Dict[str, Any]: Space used by the job.
+
+         Raises:
+             APIError: If the API request fails
+         """
+         params = {"region": region}
+         return self._client._terrakio_request("GET", f"/users/jobs/{name}", params=params)
+
+     @require_api_key
+     def delete_user_job(self, name: str, region: str) -> Dict[str, Any]:
+         """
+         Delete a user job by name and region.
+
+         Args:
+             name: The name of the job
+             region: The region of the job
+
+         Returns:
+             Dict[str, Any]: Response from the delete operation.
+
+         Raises:
+             APIError: If the API request fails
+         """
+         params = {"region": region}
+         return self._client._terrakio_request("DELETE", f"/users/jobs/{name}", params=params)
+
+     @require_api_key
+     def delete_data_in_path(self, path: str, region: str) -> Dict[str, Any]:
+         """
+         Delete data in a GCS path for a given region.
+
+         Args:
+             path: The GCS path to delete data from
+             region: The region where the data is located
+
+         Returns:
+             Dict[str, Any]: Response from the delete operation.
+
+         Raises:
+             APIError: If the API request fails
+         """
+         params = {"path": path, "region": region}
+         return self._client._terrakio_request("DELETE", "/users/jobs", params=params)
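
A hedged usage sketch follows; the `client` object is assumed to be constructed as above, and the job name and region are placeholders:

    # Hypothetical sketch; `client` construction is an assumption, not confirmed API.
    spaces = SpaceManagement(client)

    usage = spaces.get_total_space_used()                          # GET /users/jobs
    job_usage = spaces.get_space_used_by_job(name="my-job", region="aus")
    spaces.delete_user_job(name="my-job", region="aus")            # frees the job's storage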
@@ -0,0 +1,131 @@
+ from typing import Dict, Any, List, Optional
+ from ..helper.decorators import require_token, require_api_key, require_auth
+
+ class UserManagement:
+     def __init__(self, client):
+         self._client = client
+
+     @require_api_key
+     def get_user_by_id(self, id: str) -> Dict[str, Any]:
+         """
+         Get a user by ID.
+
+         Args:
+             id: User ID
+
+         Returns:
+             User information
+
+         Raises:
+             APIError: If the API request fails
+         """
+         return self._client._terrakio_request("GET", f"admin/users/{id}")
+
+     @require_api_key
+     def get_user_by_email(self, email: str) -> Dict[str, Any]:
+         """
+         Get a user by email.
+
+         Args:
+             email: User email
+
+         Returns:
+             User information
+
+         Raises:
+             APIError: If the API request fails
+         """
+         return self._client._terrakio_request("GET", f"admin/users/email/{email}")
+
+     @require_api_key
+     def list_users(self, substring: Optional[str] = None, uid: bool = False) -> List[Dict[str, Any]]:
+         """
+         List users, optionally filtering by a substring.
+
+         Args:
+             substring: Optional substring to filter users by
+             uid: If True, include the user ID in the response (default: False)
+
+         Returns:
+             List of users
+
+         Raises:
+             APIError: If the API request fails
+         """
+         params = {"uid": str(uid).lower()}
+         if substring:
+             params['substring'] = substring
+         return self._client._terrakio_request("GET", "admin/users", params=params)
+
+     @require_api_key
+     def edit_user(
+         self,
+         uid: str,
+         email: Optional[str] = None,
+         role: Optional[str] = None,
+         apiKey: Optional[str] = None,
+         groups: Optional[List[str]] = None,
+         quota: Optional[int] = None
+     ) -> Dict[str, Any]:
+         """
+         Edit user info. Only the provided fields are updated.
+
+         Args:
+             uid: User ID
+             email: New user email
+             role: New user role
+             apiKey: New API key
+             groups: New list of groups
+             quota: New quota
+
+         Returns:
+             Updated user information
+
+         Raises:
+             APIError: If the API request fails
+         """
+         payload = {"uid": uid}
+         payload_mapping = {
+             "email": email,
+             "role": role,
+             "apiKey": apiKey,
+             "groups": groups,
+             "quota": quota
+         }
+         for key, value in payload_mapping.items():
+             if value is not None:
+                 payload[key] = value
+         return self._client._terrakio_request("PATCH", "admin/users", json=payload)
+
+     @require_api_key
+     def reset_quota(self, email: str, quota: Optional[int] = None) -> Dict[str, Any]:
+         """
+         Reset the quota for a user by email.
+
+         Args:
+             email: The user's email (required)
+             quota: The new quota value (optional)
+
+         Returns:
+             API response as a dictionary
+
+         Raises:
+             APIError: If the API request fails
+         """
+         payload = {"email": email}
+         if quota is not None:
+             payload["quota"] = quota
+         return self._client._terrakio_request("PATCH", f"admin/users/reset_quota/{email}", json=payload)
+
+     @require_api_key
+     def delete_user(self, uid: str) -> Dict[str, Any]:
+         """
+         Delete a user by UID.
+
+         Args:
+             uid: The user's UID (required)
+
+         Returns:
+             API response as a dictionary
+
+         Raises:
+             APIError: If the API request fails
+         """
+         return self._client._terrakio_request("DELETE", f"admin/users/{uid}")
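
A hedged sketch of the partial-update behaviour of edit_user: only non-None fields land in the PATCH payload, so unspecified fields stay unchanged. The `client` wiring and the values are illustrative assumptions:

    # Hypothetical sketch; `client` construction is assumed.
    users = UserManagement(client)

    # Only `role` and `quota` are sent; email, apiKey, and groups are left untouched.
    users.edit_user(uid="abc123", role="admin", quota=10000)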
@@ -0,0 +1,20 @@
+ import asyncio
+
+ # Adapted from https://discuss.python.org/t/boundedtaskgroup-to-control-parallelism/27171
+
+ class BoundedTaskGroup(asyncio.TaskGroup):
+     def __init__(self, *args, max_concurrency=0, **kwargs):
+         # Forward **kwargs too, so TaskGroup options are not silently dropped
+         super().__init__(*args, **kwargs)
+         if max_concurrency:
+             self._sem = asyncio.Semaphore(max_concurrency)
+         else:
+             self._sem = None
+
+     def create_task(self, coro, *args, **kwargs):
+         # Wrap the coroutine so it acquires the semaphore before running
+         if self._sem:
+             async def _wrapped_coro(sem, coro):
+                 async with sem:
+                     return await coro
+             coro = _wrapped_coro(self._sem, coro)
+
+         return super().create_task(coro, *args, **kwargs)
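
A minimal, self-contained sketch of the semaphore gating (Python 3.11+ for asyncio.TaskGroup); the `fetch` coroutine is hypothetical:

    # Assumes BoundedTaskGroup is importable from this module.
    import asyncio

    async def fetch(i):
        await asyncio.sleep(1)   # stand-in for real I/O
        return i

    async def main():
        # At most 3 of the 10 tasks run at any one time
        async with BoundedTaskGroup(max_concurrency=3) as tg:
            tasks = [tg.create_task(fetch(i)) for i in range(10)]
        print([t.result() for t in tasks])

    asyncio.run(main())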
@@ -0,0 +1,58 @@
+ # terrakio_core/helper/decorators.py
+ from functools import wraps
+ from ..exceptions import ConfigurationError
+
+ def require_token(func):
+     """Decorator to ensure a token is available before a method can be executed."""
+     @wraps(func)
+     def wrapper(self, *args, **kwargs):
+         # Check both the direct token and the client token
+         has_token = False
+         if hasattr(self, 'token') and self.token:
+             has_token = True
+         elif hasattr(self, '_client') and hasattr(self._client, 'token') and self._client.token:
+             has_token = True
+
+         if not has_token:
+             raise ConfigurationError("Authentication token required. Please login first.")
+         return func(self, *args, **kwargs)
+
+     wrapper._is_decorated = True
+     return wrapper
+
+ def require_api_key(func):
+     """Decorator to ensure an API key is available before a method can be executed."""
+     @wraps(func)
+     def wrapper(self, *args, **kwargs):
+         # Check both the direct key and the client key
+         has_key = False
+         if hasattr(self, 'key') and self.key:
+             has_key = True
+         elif hasattr(self, '_client') and hasattr(self._client, 'key') and self._client.key:
+             has_key = True
+
+         if not has_key:
+             raise ConfigurationError("API key required. Please provide an API key or login first.")
+         return func(self, *args, **kwargs)
+
+     wrapper._is_decorated = True
+     return wrapper
+
+ def require_auth(func):
+     """Decorator that requires either a token OR an API key."""
+     @wraps(func)
+     def wrapper(self, *args, **kwargs):
+         # Check both direct auth and client auth
+         has_token = (hasattr(self, 'token') and self.token) or \
+                     (hasattr(self, '_client') and hasattr(self._client, 'token') and self._client.token)
+         has_api_key = (hasattr(self, 'key') and self.key) or \
+                       (hasattr(self, '_client') and hasattr(self._client, 'key') and self._client.key)
+
+         if not has_token and not has_api_key:
+             raise ConfigurationError(
+                 "Authentication required. Please provide either an API key or login to get a token."
+             )
+         return func(self, *args, **kwargs)
+
+     wrapper._is_decorated = True
+     return wrapper
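
A minimal sketch of the decorator contract; the `Demo` class is hypothetical. Any object exposing a truthy `key` attribute (or a `_client` that has one) passes the check, anything else raises ConfigurationError:

    # Hypothetical demo class; not part of the package.
    class Demo:
        def __init__(self, key=None):
            self.key = key

        @require_api_key
        def whoami(self):
            return "ok"

    Demo(key="abc").whoami()       # passes the check
    try:
        Demo().whoami()            # no key anywhere
    except ConfigurationError as e:
        print(e)                   # "API key required. ..."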