terrakio-core 0.3.3__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of terrakio-core might be problematic. Click here for more details.
- terrakio_core/__init__.py +10 -1
- terrakio_core/async_client.py +304 -0
- terrakio_core/client.py +22 -1717
- terrakio_core/config.py +8 -15
- terrakio_core/convenience_functions/convenience_functions.py +296 -0
- terrakio_core/endpoints/auth.py +180 -0
- terrakio_core/endpoints/dataset_management.py +369 -0
- terrakio_core/endpoints/group_management.py +228 -0
- terrakio_core/endpoints/mass_stats.py +594 -0
- terrakio_core/endpoints/model_management.py +385 -0
- terrakio_core/endpoints/space_management.py +72 -0
- terrakio_core/endpoints/user_management.py +131 -0
- terrakio_core/helper/bounded_taskgroup.py +20 -0
- terrakio_core/helper/decorators.py +58 -0
- terrakio_core/{generation → helper}/tiles.py +1 -12
- terrakio_core/sync_client.py +370 -0
- {terrakio_core-0.3.3.dist-info → terrakio_core-0.3.6.dist-info}/METADATA +1 -1
- terrakio_core-0.3.6.dist-info/RECORD +21 -0
- terrakio_core/auth.py +0 -223
- terrakio_core/dataset_management.py +0 -287
- terrakio_core/decorators.py +0 -18
- terrakio_core/group_access_management.py +0 -232
- terrakio_core/mass_stats.py +0 -504
- terrakio_core/space_management.py +0 -101
- terrakio_core/user_management.py +0 -227
- terrakio_core-0.3.3.dist-info/RECORD +0 -16
- {terrakio_core-0.3.3.dist-info → terrakio_core-0.3.6.dist-info}/WHEEL +0 -0
- {terrakio_core-0.3.3.dist-info → terrakio_core-0.3.6.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,385 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
import time
|
|
4
|
+
import textwrap
|
|
5
|
+
import logging
|
|
6
|
+
from typing import Dict, Any
|
|
7
|
+
from google.cloud import storage
|
|
8
|
+
from ..helper.decorators import require_token, require_api_key, require_auth
|
|
9
|
+
|
|
10
|
+
class ModelManagement:
|
|
11
|
+
def __init__(self, client):
|
|
12
|
+
self._client = client
|
|
13
|
+
|
|
14
|
+
@require_api_key
def generate_ai_dataset(
    self,
    name: str,
    aoi_geojson: str,
    expression_x: str,
    filter_x_rate: float,
    filter_y_rate: float,
    samples: int,
    tile_size: int,
    expression_y: str = "skip",
    filter_x: str = "skip",
    filter_y: str = "skip",
    crs: str = "epsg:4326",
    res: float = 0.001,
    region: str = "aus",
    start_year: int = None,
    end_year: int = None,
) -> dict:
    """
    Generate an AI dataset: run a random-sample job, wait for it, then start
    the mass-stats job that consumes it.

    Args:
        name (str): Name of the dataset to generate
        aoi_geojson (str): Path to GeoJSON file containing area of interest
        expression_x (str): Expression for X variable
            (e.g. "MSWX.air_temperature@(year=2021, month=1)")
        filter_x_rate (float): Filter rate for X variable (e.g. 0.5)
        filter_y_rate (float): Filter rate for Y variable (e.g. 0.5)
        samples (int): Number of samples to generate
        tile_size (int): Size of tiles in degrees
        expression_y (str): Expression for Y variable with {year} placeholder.
            The literal "skip" (default) omits it.
        filter_x (str): Filter for X variable. "skip" (default) omits it.
        filter_y (str): Filter for Y variable. "skip" (default) omits it.
        crs (str, optional): Coordinate reference system. Defaults to "epsg:4326"
        res (float, optional): Resolution in degrees. Defaults to 0.001
        region (str, optional): Region code. Defaults to "aus"
        start_year (int, optional): Start year for data generation. Required if end_year provided
        end_year (int, optional): End year for data generation. Required if start_year provided

    Returns:
        The value returned by ``mass_stats.start_mass_stats_job`` for the
        finished random-sample task.

    Raises:
        RuntimeError: If the tracked job reports the status "Error".
        APIError: If the API request fails
    """
    # Build config for expressions and filters; "skip" means "leave it out".
    config = {
        "expressions": [{"expr": expression_x, "res": res, "prefix": "x"}],
        "filters": [],
    }
    if expression_y != "skip":
        config["expressions"].append({"expr": expression_y, "res": res, "prefix": "y"})
    if filter_x != "skip":
        config["filters"].append({"expr": filter_x, "res": res, "rate": filter_x_rate})
    if filter_y != "skip":
        config["filters"].append({"expr": filter_y, "res": res, "rate": filter_y_rate})

    # The expressions are sent exactly as given. An earlier revision replaced
    # "{year}" with start_year in local copies *after* config was built; those
    # copies were never read, so the step had no effect and is dropped.
    # year_range below carries the years to the service.

    # Load AOI GeoJSON (decoded explicitly as UTF-8 instead of the locale default).
    with open(aoi_geojson, "r", encoding="utf-8") as f:
        aoi_data = json.load(f)

    task_response = self._client.mass_stats.random_sample(
        name=name,
        config=config,
        aoi=aoi_data,
        samples=samples,
        year_range=[start_year, end_year],
        crs=crs,
        tile_size=tile_size,
        res=res,
        region=region,
        output="netcdf",
        server=self._client.url,
        bucket="terrakio-mass-requests",
        overwrite=True,
    )
    task_id = task_response["task_id"]

    # Poll until the job finishes, redrawing a one-line progress bar.
    bar_length = 50
    while True:
        job = self._client.track_mass_stats_job(ids=[task_id])[task_id]
        status = job['status']
        completed = job.get('completed', 0)
        total = job.get('total', 1)

        progress = completed / total if total > 0 else 0
        filled_length = int(bar_length * progress)
        bar = '█' * filled_length + '░' * (bar_length - filled_length)
        percentage = progress * 100

        # flush=True: the line ends with no newline, so without a flush a
        # buffered stdout may not show the bar until much later.
        print(
            f"\rJob status: {status} [{bar}] {percentage:.1f}% ({completed}/{total})",
            end='',
            flush=True,
        )

        if status == "Completed":
            print("\nJob completed successfully!")
            break
        if status == "Error":
            print("\n")  # New line before error message
            raise RuntimeError(f"Job {task_id} encountered an error")

        # Wait 5 seconds before checking again
        time.sleep(5)

    # After the random-sample job is done, start the mass-stats job on it.
    return self._client.mass_stats.start_mass_stats_job(task_id)
|
|
134
|
+
|
|
135
|
+
@require_api_key
async def upload_model(self, model_path: str):
    """
    Upload a model to the bucket so that it can be used for inference.

    The file is stored in the 'terrakio-mass-requests' bucket under
    ``{uid}/{stem}/models/{file name}``, where ``stem`` is the file name
    without its extension and ``uid`` comes from
    ``self._client.auth.get_user_info()``.

    Args:
        model_path: Path to the model file

    Raises:
        APIError: If the API request fails
    """
    uid = (await self._client.auth.get_user_info())["uid"]
    model_name = os.path.basename(model_path)
    model_file_name = os.path.splitext(model_name)[0]
    blob_path = f'{uid}/{model_file_name}/models/{model_name}'

    client = storage.Client()
    bucket = client.get_bucket('terrakio-mass-requests')
    bucket.blob(blob_path).upload_from_filename(model_path)

    # Log the path that was actually written; the previous message used the
    # full file name as the folder, which is not where the blob lives.
    self._client.logger.info(f"Model uploaded successfully to {blob_path}")
|
|
156
|
+
|
|
157
|
+
@require_api_key
async def upload_and_deploy_model(self, model_path: str, dataset: str, product: str, input_expression: str, dates_iso8601: list):
    """
    Upload a model to the bucket and deploy it.

    This is a coroutine because ``upload_model`` is one: it has to be
    awaited for the upload to happen. Previously it was called without
    ``await``, so the model file was never uploaded before deployment.

    Args:
        model_path: Path to the model file
        dataset: Name of the dataset to create
        product: Product name for the inference
        input_expression: Input expression for the dataset
        dates_iso8601: List of dates in ISO8601 format

    Returns:
        The result of ``deploy_model`` (awaited first if it is awaitable).
    """
    await self.upload_model(model_path=model_path)

    # upload_model stores the file under "<uid>/<stem>/models/<file name>",
    # and the generated inference script loads
    # "<uid>/<job name>/models/<model name>.onnx". Both only line up when the
    # stem (file name without extension) is used as model and job name.
    model_name = os.path.splitext(os.path.basename(model_path))[0]

    result = self.deploy_model(
        dataset=dataset,
        product=product,
        model_name=model_name,
        input_expression=input_expression,
        model_training_job_name=model_name,
        dates_iso8601=dates_iso8601,
    )
    # deploy_model returns whatever the dataset endpoint returns; resolve it
    # here if that turns out to be an awaitable.
    if hasattr(result, "__await__"):
        result = await result
    return result
|
|
173
|
+
|
|
174
|
+
@require_api_key
def train_model(
    self,
    model_name: str,
    training_dataset: str,
    task_type: str,
    model_category: str,
    architecture: str,
    region: str,
    hyperparameters: dict = None
) -> dict:
    """
    Submit a training request to the "/train_model" endpoint.

    Args:
        model_name (str): The name of the model to train.
        training_dataset (str): The training dataset identifier.
        task_type (str): The type of ML task (e.g., regression, classification).
        model_category (str): The category of model (e.g., random_forest).
        architecture (str): The model architecture.
        region (str): The region identifier.
        hyperparameters (dict, optional): Additional hyperparameters for
            training. Sent as-is, including when it is None.

    Returns:
        dict: The response from the model training API.

    Raises:
        APIError: If the API request fails
    """
    request_body = dict(
        model_name=model_name,
        training_dataset=training_dataset,
        task_type=task_type,
        model_category=model_category,
        architecture=architecture,
        region=region,
        hyperparameters=hyperparameters,
    )
    return self._client._terrakio_request("POST", "/train_model", json=request_body)
|
|
213
|
+
|
|
214
|
+
@require_api_key
def deploy_model(
    self,
    dataset: str,
    product: str,
    model_name: str,
    input_expression: str,
    model_training_job_name: str,
    dates_iso8601: list
) -> Dict[str, Any]:
    """
    Deploy a model: generate its inference script, upload the script, and
    create the dataset that points at it.

    Args:
        dataset: Name of the dataset to create
        product: Product name for the inference; also the script's file
            name ("<product>.py")
        model_name: Name of the trained model
        input_expression: Input expression for the dataset
        model_training_job_name: Name of the training job; used as the
            folder name under the user's UID in the bucket
        dates_iso8601: List of dates in ISO8601 format

    Returns:
        dict: Response from the dataset creation call

    Raises:
        APIError: If the API request fails
    """
    # NOTE(review): upload_model awaits self._client.auth.get_user_info();
    # this call is neither awaited nor routed through .auth — confirm the
    # client exposes a synchronous get_user_info.
    uid = self._client.get_user_info()["uid"]

    # Generate and upload the inference script.
    inference_script = self._generate_script(model_name, product, model_training_job_name, uid)
    self._upload_script_to_bucket(inference_script, f"{product}.py", model_training_job_name, uid)

    # Create the dataset that references the uploaded script folder.
    scripts_path = f"gs://terrakio-mass-requests/{uid}/{model_training_job_name}/inference_scripts"
    return self._client.datasets.create_dataset(
        name=dataset,
        collection="terrakio-datasets",
        products=[product],
        path=scripts_path,
        input=input_expression,
        dates_iso8601=dates_iso8601,
        padding=0
    )
|
|
260
|
+
|
|
261
|
+
@require_api_key
def _generate_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
    """
    Generate Python inference script for the model.

    The returned source defines two functions:

    * ``get_model()`` downloads
      ``{uid}/{model_training_job_name}/models/{model_name}.onnx`` from the
      'terrakio-mass-requests' bucket and returns an ONNX ``InferenceSession``.
    * ``{product}(*bands, model)`` averages each band over its 'time'
      dimension when present, runs the model on the flattened bands, and
      returns the prediction as a DataArray with a single time step.

    The time-coordinate handling differs from the earlier template: the
    output time is always a plain datetime64 value. Before, two of the three
    branches produced a ``pd.Timestamp`` and the script then read
    ``.values`` from it, which ``pd.Timestamp`` does not have.

    Args:
        model_name: Name of the model (without the ".onnx" suffix)
        product: Product name; becomes the name of the inference function
        model_training_job_name: Training job name
        uid: User ID

    Returns:
        str: Generated Python script content
    """
    return textwrap.dedent(f'''
        import logging
        from io import BytesIO

        import numpy as np
        import pandas as pd
        import xarray as xr
        from google.cloud import storage
        from onnxruntime import InferenceSession

        logging.basicConfig(
            level=logging.INFO
        )

        def get_model():
            logging.info("Loading model for {model_name}...")

            client = storage.Client()
            bucket = client.get_bucket('terrakio-mass-requests')
            blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')

            model = BytesIO()
            blob.download_to_file(model)
            model.seek(0)

            session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
            return session

        def {product}(*bands, model):
            logging.info("start preparing data")

            data_arrays = list(bands)

            reference_array = data_arrays[0]
            original_shape = reference_array.shape
            logging.info(f"Original shape: {{original_shape}}")

            if 'time' in reference_array.dims:
                time_coords = reference_array.coords['time']
                if len(time_coords) == 1:
                    output_time = time_coords[0].values
                else:
                    years = [pd.to_datetime(t).year for t in time_coords.values]
                    latest_year = max(years)
                    output_time = pd.Timestamp(f"{{latest_year}}-01-01").to_datetime64()
            else:
                output_time = pd.Timestamp("1970-01-01").to_datetime64()

            averaged_bands = []
            for data_array in data_arrays:
                if 'time' in data_array.dims:
                    averaged_band = np.mean(data_array.values, axis=0)
                    logging.info(f"Averaged band from {{data_array.shape}} to {{averaged_band.shape}}")
                else:
                    averaged_band = data_array.values
                    logging.info(f"No time dimension, shape: {{averaged_band.shape}}")

                flattened_band = averaged_band.reshape(-1, 1)
                averaged_bands.append(flattened_band)

            input_data = np.hstack(averaged_bands)

            logging.info(f"Final input shape: {{input_data.shape}}")

            output = model.run(None, {{"float_input": input_data.astype(np.float32)}})[0]

            logging.info(f"Model output shape: {{output.shape}}")

            if len(original_shape) >= 3:
                spatial_shape = original_shape[1:]
            else:
                spatial_shape = original_shape

            output_reshaped = output.reshape(spatial_shape)

            output_with_time = np.expand_dims(output_reshaped, axis=0)

            if 'time' in reference_array.dims:
                spatial_dims = [dim for dim in reference_array.dims if dim != 'time']
                spatial_coords = {{dim: reference_array.coords[dim] for dim in spatial_dims if dim in reference_array.coords}}
            else:
                spatial_dims = list(reference_array.dims)
                spatial_coords = dict(reference_array.coords)

            result = xr.DataArray(
                data=output_with_time.astype(np.float32),
                dims=['time'] + list(spatial_dims),
                coords={{
                    'time': [output_time],
                    'y': spatial_coords['y'].values,
                    'x': spatial_coords['x'].values
                }}
            )
            return result
    ''').strip()
|
|
376
|
+
|
|
377
|
+
@require_api_key
def _upload_script_to_bucket(self, script_content: str, script_name: str, model_training_job_name: str, uid: str):
    """Upload the generated script to Google Cloud Storage.

    The script is written as 'text/plain' to the 'terrakio-mass-requests'
    bucket at ``{uid}/{model_training_job_name}/inference_scripts/{script_name}``.

    Args:
        script_content: Source text of the script
        script_name: File name to store it under
        model_training_job_name: Training job name (folder under the UID)
        uid: User ID (top-level folder)
    """
    gcs = storage.Client()
    target_bucket = gcs.get_bucket('terrakio-mass-requests')
    target_blob = target_bucket.blob(f'{uid}/{model_training_job_name}/inference_scripts/{script_name}')
    target_blob.upload_from_string(script_content, content_type='text/plain')
    logging.info(f"Script uploaded successfully to {uid}/{model_training_job_name}/inference_scripts/{script_name}")
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
from typing import Dict, Any, Optional
|
|
2
|
+
from ..helper.decorators import require_token, require_api_key, require_auth
|
|
3
|
+
class SpaceManagement:
    """Storage-usage endpoints ("/users/jobs") issued through a client."""

    def __init__(self, client):
        """Keep the client whose ``_terrakio_request`` sends every call."""
        self._client = client

    @require_api_key
    def get_total_space_used(self) -> Dict[str, Any]:
        """
        Get total space used by the user (GET "/users/jobs").

        Returns:
            Dict[str, Any]: Total space used by the user.

        Raises:
            APIError: If the API request fails
        """
        request = self._client._terrakio_request
        return request("GET", "/users/jobs")

    @require_api_key
    def get_space_used_by_job(self, name: str, region: str) -> Dict[str, Any]:
        """
        Get space used by a specific job (GET "/users/jobs/{name}").

        Args:
            name: The name of the job; placed in the URL path
            region: The region of the job; sent as the "region" query parameter

        Returns:
            Dict[str, Any]: Space used by the job.

        Raises:
            APIError: If the API request fails
        """
        endpoint = f"/users/jobs/{name}"
        return self._client._terrakio_request("GET", endpoint, params={"region": region})

    @require_api_key
    def delete_user_job(self, name: str, region: str) -> Dict[str, Any]:
        """
        Delete a user job by name and region (DELETE "/users/jobs/{name}").

        Args:
            name: The name of the job; placed in the URL path
            region: The region of the job; sent as the "region" query parameter

        Returns:
            Dict[str, Any]: Response from the delete operation.

        Raises:
            APIError: If the API request fails
        """
        endpoint = f"/users/jobs/{name}"
        return self._client._terrakio_request("DELETE", endpoint, params={"region": region})

    @require_api_key
    def delete_data_in_path(self, path: str, region: str) -> Dict[str, Any]:
        """
        Delete data in a GCS path for a given region (DELETE "/users/jobs").

        Args:
            path: The GCS path to delete data from
            region: The region where the data is located

        Returns:
            Dict[str, Any]: Response from the delete operation.

        Raises:
            APIError: If the API request fails
        """
        query = {"path": path, "region": region}
        return self._client._terrakio_request("DELETE", "/users/jobs", params=query)
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
from typing import Dict, Any, List, Optional
|
|
2
|
+
from ..helper.decorators import require_token, require_api_key, require_auth
|
|
3
|
+
|
|
4
|
+
class UserManagement:
    """Admin user endpoints ("admin/users...") issued through a client."""

    def __init__(self, client):
        """Keep the client whose ``_terrakio_request`` sends every call."""
        self._client = client

    @require_api_key
    def get_user_by_id(self, id: str) -> Dict[str, Any]:
        """
        Get user by ID.

        Args:
            id: User ID; placed in the URL path

        Returns:
            User information

        Raises:
            APIError: If the API request fails
        """
        return self._client._terrakio_request("GET", f"admin/users/{id}")

    @require_api_key
    def get_user_by_email(self, email: str) -> Dict[str, Any]:
        """
        Get user by email.

        Args:
            email: User email; placed in the URL path

        Returns:
            User information

        Raises:
            APIError: If the API request fails
        """
        return self._client._terrakio_request("GET", f"admin/users/email/{email}")

    @require_api_key
    def list_users(self, substring: Optional[str] = None, uid: bool = False) -> List[Dict[str, Any]]:
        """
        List users, optionally filtering by a substring.

        Args:
            substring: Optional substring to filter users; omitted from the
                query when empty or None
            uid: If True, includes the user ID in the response (default: False).
                Sent as the lowercase string "true" or "false".

        Returns:
            List of users

        Raises:
            APIError: If the API request fails
        """
        params = {"uid": str(uid).lower()}
        if substring:
            params['substring'] = substring
        return self._client._terrakio_request("GET", "admin/users", params=params)

    @require_api_key
    def edit_user(
        self,
        uid: str,
        email: Optional[str] = None,
        role: Optional[str] = None,
        apiKey: Optional[str] = None,
        groups: Optional[List[str]] = None,
        quota: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Edit user info. Only provided fields will be updated.

        Every field except ``uid`` defaults to None, and fields left as None
        are not included in the request, so a caller can pass just the
        values that should change.

        Args:
            uid: User ID
            email: New user email
            role: New user role
            apiKey: New API key
            groups: New list of groups
            quota: New quota

        Returns:
            Updated user information

        Raises:
            APIError: If the API request fails
        """
        optional_fields = {
            "email": email,
            "role": role,
            "apiKey": apiKey,
            "groups": groups,
            "quota": quota,
        }
        payload = {"uid": uid}
        payload.update(
            {key: value for key, value in optional_fields.items() if value is not None}
        )
        return self._client._terrakio_request("PATCH", "admin/users", json=payload)

    @require_api_key
    def reset_quota(self, email: str, quota: Optional[int] = None) -> Dict[str, Any]:
        """
        Reset the quota for a user by email.

        Args:
            email: The user's email (required); sent in both the URL path
                and the request body
            quota: The new quota value (optional); left out of the body
                when None

        Returns:
            API response as a dictionary
        """
        payload = {"email": email}
        if quota is not None:
            payload["quota"] = quota
        return self._client._terrakio_request("PATCH", f"admin/users/reset_quota/{email}", json=payload)

    @require_api_key
    def delete_user(self, uid: str) -> Dict[str, Any]:
        """
        Delete a user by UID.

        Args:
            uid: The user's UID (required)

        Returns:
            API response as a dictionary

        Raises:
            APIError: If the API request fails
        """
        return self._client._terrakio_request("DELETE", f"admin/users/{uid}")
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
|
|
3
|
+
# Adapted from https://discuss.python.org/t/boundedtaskgroup-to-control-parallelism/27171
|
|
4
|
+
|
|
5
|
+
class BoundedTaskGroup(asyncio.TaskGroup):
|
|
6
|
+
def __init__(self, *args, max_concurrency = 0, **kwargs):
|
|
7
|
+
super().__init__(*args)
|
|
8
|
+
if max_concurrency:
|
|
9
|
+
self._sem = asyncio.Semaphore(max_concurrency)
|
|
10
|
+
else:
|
|
11
|
+
self._sem = None
|
|
12
|
+
|
|
13
|
+
def create_task(self, coro, *args, **kwargs):
|
|
14
|
+
if self._sem:
|
|
15
|
+
async def _wrapped_coro(sem, coro):
|
|
16
|
+
async with sem:
|
|
17
|
+
return await coro
|
|
18
|
+
coro = _wrapped_coro(self._sem, coro)
|
|
19
|
+
|
|
20
|
+
return super().create_task(coro, *args, **kwargs)
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
# terrakio_core/helper/decorators.py
|
|
2
|
+
from functools import wraps
|
|
3
|
+
from ..exceptions import ConfigurationError
|
|
4
|
+
|
|
5
|
+
def require_token(func):
    """Decorator to ensure a token is available before a method can be executed.

    The token is looked up on the instance (``self.token``) and, if that is
    missing or falsy, on ``self._client.token``. When neither yields a
    truthy value a ``ConfigurationError`` is raised and ``func`` is not
    called. The returned wrapper carries ``_is_decorated = True``.
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        token = getattr(self, 'token', None)
        if not token:
            # Fall back to the token held by the wrapped client, if any.
            token = getattr(getattr(self, '_client', None), 'token', None)
        if not token:
            raise ConfigurationError("Authentication token required. Please login first.")
        return func(self, *args, **kwargs)

    wrapper._is_decorated = True
    return wrapper
|
|
22
|
+
|
|
23
|
+
def require_api_key(func):
    """Decorator to ensure an API key is available before a method can be executed.

    The key is looked up on the instance (``self.key``) and, if that is
    missing or falsy, on ``self._client.key``. When neither yields a truthy
    value a ``ConfigurationError`` is raised and ``func`` is not called.
    The returned wrapper carries ``_is_decorated = True``.
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        api_key = getattr(self, 'key', None)
        if not api_key:
            # Fall back to the key held by the wrapped client, if any.
            api_key = getattr(getattr(self, '_client', None), 'key', None)
        if not api_key:
            raise ConfigurationError("API key required. Please provide an API key or login first.")
        return func(self, *args, **kwargs)

    wrapper._is_decorated = True
    return wrapper
|
|
40
|
+
|
|
41
|
+
def require_auth(func):
    """Decorator that requires either a token OR an API key.

    Each credential is looked up on the instance first (``self.token`` /
    ``self.key``) and then on ``self._client``. If no truthy token and no
    truthy key is found, a ``ConfigurationError`` is raised and ``func`` is
    not called. The returned wrapper carries ``_is_decorated = True``.
    """
    def _credential(owner, attr):
        # Value of `attr` on the instance, else on its client, else None.
        direct = getattr(owner, attr, None)
        if direct:
            return direct
        return getattr(getattr(owner, '_client', None), attr, None)

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if not (_credential(self, 'token') or _credential(self, 'key')):
            raise ConfigurationError(
                "Authentication required. Please provide either an API key or login to get a token."
            )
        return func(self, *args, **kwargs)

    wrapper._is_decorated = True
    return wrapper
|