terrakio-core 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of terrakio-core might be problematic. Click here for more details.
- terrakio_core/__init__.py +1 -1
- terrakio_core/auth.py +0 -15
- terrakio_core/client.py +489 -170
- terrakio_core/config.py +63 -21
- terrakio_core/dataset_management.py +62 -10
- terrakio_core/mass_stats.py +81 -73
- {terrakio_core-0.3.1.dist-info → terrakio_core-0.3.3.dist-info}/METADATA +2 -1
- terrakio_core-0.3.3.dist-info/RECORD +16 -0
- terrakio_core-0.3.1.dist-info/RECORD +0 -16
- {terrakio_core-0.3.1.dist-info → terrakio_core-0.3.3.dist-info}/WHEEL +0 -0
- {terrakio_core-0.3.1.dist-info → terrakio_core-0.3.3.dist-info}/top_level.txt +0 -0
terrakio_core/client.py
CHANGED
|
@@ -2,6 +2,7 @@ import json
|
|
|
2
2
|
import asyncio
|
|
3
3
|
from io import BytesIO
|
|
4
4
|
from typing import Dict, Any, Optional, Union
|
|
5
|
+
from functools import wraps
|
|
5
6
|
|
|
6
7
|
import requests
|
|
7
8
|
import aiohttp
|
|
@@ -18,8 +19,39 @@ import logging
|
|
|
18
19
|
import textwrap
|
|
19
20
|
|
|
20
21
|
|
|
22
|
+
def require_api_key(func):
|
|
23
|
+
"""
|
|
24
|
+
Decorator to ensure an API key is available before a method can be executed.
|
|
25
|
+
Will check for the presence of an API key before allowing the function to be called.
|
|
26
|
+
"""
|
|
27
|
+
@wraps(func)
|
|
28
|
+
def wrapper(self, *args, **kwargs):
|
|
29
|
+
# Check if the API key is set
|
|
30
|
+
if not self.key:
|
|
31
|
+
error_msg = "No API key found. Please provide an API key to use this client."
|
|
32
|
+
if not self.quiet:
|
|
33
|
+
print(error_msg)
|
|
34
|
+
print("API key is required for this function.")
|
|
35
|
+
print("You can set an API key by either:")
|
|
36
|
+
print("1. Loading from a config file")
|
|
37
|
+
print("2. Using the login() method: client.login(email='your-email@example.com', password='your-password')")
|
|
38
|
+
raise ConfigurationError(error_msg)
|
|
39
|
+
|
|
40
|
+
# Check if URL is set (required for API calls)
|
|
41
|
+
if not self.url:
|
|
42
|
+
if not self.quiet:
|
|
43
|
+
print("API URL is not set. Using default API URL: https://api.terrak.io")
|
|
44
|
+
self.url = "https://api.terrak.io"
|
|
45
|
+
|
|
46
|
+
return func(self, *args, **kwargs)
|
|
47
|
+
|
|
48
|
+
# Mark as decorated to avoid double decoration
|
|
49
|
+
wrapper._is_decorated = True
|
|
50
|
+
return wrapper
|
|
51
|
+
|
|
52
|
+
|
|
21
53
|
class BaseClient:
|
|
22
|
-
def __init__(self, url: Optional[str] = None,
|
|
54
|
+
def __init__(self, url: Optional[str] = None,
|
|
23
55
|
auth_url: Optional[str] = "https://dev-au.terrak.io",
|
|
24
56
|
quiet: bool = False, config_file: Optional[str] = None,
|
|
25
57
|
verify: bool = True, timeout: int = 300):
|
|
@@ -28,57 +60,96 @@ class BaseClient:
|
|
|
28
60
|
self.verify = verify
|
|
29
61
|
self.timeout = timeout
|
|
30
62
|
self.auth_client = None
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
63
|
+
# Initialize session early to avoid AttributeError
|
|
64
|
+
self.session = requests.Session()
|
|
65
|
+
self.user_management = None
|
|
66
|
+
self.dataset_management = None
|
|
67
|
+
self.mass_stats = None
|
|
68
|
+
self._aiohttp_session = None
|
|
69
|
+
|
|
70
|
+
# if auth_url:
|
|
71
|
+
# from terrakio_core.auth import AuthClient
|
|
72
|
+
# self.auth_client = AuthClient(
|
|
73
|
+
# base_url=auth_url,
|
|
74
|
+
# verify=verify,
|
|
75
|
+
# timeout=timeout
|
|
76
|
+
# )
|
|
38
77
|
self.url = url
|
|
39
|
-
self.key =
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
78
|
+
self.key = None
|
|
79
|
+
|
|
80
|
+
# Load configuration from file
|
|
81
|
+
from terrakio_core.config import read_config_file, DEFAULT_CONFIG_FILE
|
|
82
|
+
|
|
83
|
+
if config_file is None:
|
|
84
|
+
config_file = DEFAULT_CONFIG_FILE
|
|
85
|
+
|
|
86
|
+
# Get config using the read_config_file function that now handles all cases
|
|
87
|
+
config = read_config_file(config_file, quiet=self.quiet)
|
|
88
|
+
|
|
89
|
+
# If URL is not provided, try to get it from config
|
|
90
|
+
if self.url is None:
|
|
91
|
+
self.url = config.get('url')
|
|
92
|
+
|
|
93
|
+
# Get API key from config file (never from parameters)
|
|
94
|
+
self.key = config.get('key')
|
|
95
|
+
|
|
96
|
+
self.token = config.get('token')
|
|
97
|
+
|
|
98
|
+
# Update auth_client with token if it exists
|
|
99
|
+
if self.auth_client and self.token:
|
|
100
|
+
self.auth_client.token = self.token
|
|
101
|
+
self.auth_client.session.headers.update({
|
|
102
|
+
"Authorization": self.token
|
|
103
|
+
})
|
|
104
|
+
|
|
105
|
+
# If we have a key, we're good to go even if not logged in
|
|
106
|
+
if self.key:
|
|
107
|
+
# If URL is missing but we have a key, use the default URL
|
|
108
|
+
if not self.url:
|
|
109
|
+
self.url = "https://api.terrak.io"
|
|
110
|
+
print("the token is! ", self.token)
|
|
111
|
+
print("the key is! ", self.key)
|
|
112
|
+
# Update session headers with the key
|
|
113
|
+
headers = {
|
|
114
|
+
'Content-Type': 'application/json',
|
|
115
|
+
'x-api-key': self.key
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
# Only add Authorization header if token exists and is not None
|
|
119
|
+
if self.token:
|
|
120
|
+
headers['Authorization'] = self.token
|
|
121
|
+
|
|
122
|
+
self.session.headers.update(headers)
|
|
123
|
+
print("the session headers are ", self.session.headers)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
if not self.quiet and config.get('user_email'):
|
|
127
|
+
print(f"Using API key for: {config.get('user_email')}")
|
|
128
|
+
|
|
129
|
+
return
|
|
130
|
+
|
|
131
|
+
# Check if we have the required configuration
|
|
66
132
|
if not self.key:
|
|
67
|
-
|
|
133
|
+
# No API key available - inform the user
|
|
134
|
+
if not self.quiet:
|
|
135
|
+
print("API key is required to use this client.")
|
|
136
|
+
print("You can set an API key by either:")
|
|
137
|
+
print("1. Loading from a config file")
|
|
138
|
+
print("2. Using the login() method: client.login(email='your-email@example.com', password='your-password')")
|
|
139
|
+
return
|
|
140
|
+
|
|
68
141
|
self.url = self.url.rstrip('/')
|
|
69
142
|
if not self.quiet:
|
|
70
143
|
print(f"Using Terrakio API at: {self.url}")
|
|
71
|
-
|
|
144
|
+
|
|
145
|
+
# Update the session headers with API key
|
|
72
146
|
self.session.headers.update({
|
|
73
147
|
'Content-Type': 'application/json',
|
|
74
148
|
'x-api-key': self.key
|
|
75
149
|
})
|
|
76
|
-
self.user_management = None
|
|
77
|
-
self.dataset_management = None
|
|
78
|
-
self.mass_stats = None
|
|
79
|
-
self._aiohttp_session = None
|
|
80
150
|
|
|
81
151
|
@property
|
|
152
|
+
@require_api_key
|
|
82
153
|
async def aiohttp_session(self):
|
|
83
154
|
if self._aiohttp_session is None or self._aiohttp_session.closed:
|
|
84
155
|
self._aiohttp_session = aiohttp.ClientSession(
|
|
@@ -90,6 +161,7 @@ class BaseClient:
|
|
|
90
161
|
)
|
|
91
162
|
return self._aiohttp_session
|
|
92
163
|
|
|
164
|
+
@require_api_key
|
|
93
165
|
async def wcs_async(self, expr: str, feature: Union[Dict[str, Any], ShapelyGeometry],
|
|
94
166
|
in_crs: str = "epsg:4326", out_crs: str = "epsg:4326",
|
|
95
167
|
output: str = "csv", resolution: int = -1, buffer: bool = False,
|
|
@@ -130,6 +202,7 @@ class BaseClient:
|
|
|
130
202
|
"resolution": resolution,
|
|
131
203
|
**kwargs
|
|
132
204
|
}
|
|
205
|
+
print("the payload is ", payload)
|
|
133
206
|
request_url = f"{self.url}/geoquery"
|
|
134
207
|
for attempt in range(retry + 1):
|
|
135
208
|
try:
|
|
@@ -186,38 +259,67 @@ class BaseClient:
|
|
|
186
259
|
raise
|
|
187
260
|
continue
|
|
188
261
|
|
|
262
|
+
@require_api_key
|
|
189
263
|
async def close_async(self):
|
|
190
264
|
"""Close the aiohttp session"""
|
|
191
265
|
if self._aiohttp_session and not self._aiohttp_session.closed:
|
|
192
266
|
await self._aiohttp_session.close()
|
|
193
267
|
self._aiohttp_session = None
|
|
194
268
|
|
|
269
|
+
@require_api_key
|
|
195
270
|
async def __aenter__(self):
|
|
196
271
|
return self
|
|
197
272
|
|
|
273
|
+
@require_api_key
|
|
198
274
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
199
275
|
await self.close_async()
|
|
200
276
|
|
|
201
277
|
def signup(self, email: str, password: str) -> Dict[str, Any]:
|
|
202
278
|
if not self.auth_client:
|
|
203
|
-
|
|
279
|
+
from terrakio_core.auth import AuthClient
|
|
280
|
+
self.auth_client = AuthClient(
|
|
281
|
+
base_url=self.url,
|
|
282
|
+
verify=self.verify,
|
|
283
|
+
timeout=self.timeout
|
|
284
|
+
)
|
|
285
|
+
self.auth_client.session = self.session
|
|
204
286
|
return self.auth_client.signup(email, password)
|
|
205
287
|
|
|
206
288
|
def login(self, email: str, password: str) -> Dict[str, str]:
|
|
289
|
+
|
|
207
290
|
if not self.auth_client:
|
|
208
|
-
|
|
209
|
-
|
|
291
|
+
from terrakio_core.auth import AuthClient
|
|
292
|
+
self.auth_client = AuthClient(
|
|
293
|
+
base_url=self.url,
|
|
294
|
+
verify=self.verify,
|
|
295
|
+
timeout=self.timeout
|
|
296
|
+
)
|
|
297
|
+
self.auth_client.session = self.session
|
|
210
298
|
try:
|
|
211
299
|
# First attempt to login
|
|
212
300
|
token_response = self.auth_client.login(email, password)
|
|
213
301
|
|
|
214
|
-
print("the token response is ", token_response)
|
|
215
302
|
# Only proceed with API key retrieval if login was successful
|
|
216
303
|
if token_response:
|
|
217
304
|
# After successful login, get the API key
|
|
218
|
-
api_key_response = self.view_api_key()
|
|
305
|
+
api_key_response = self.auth_client.view_api_key()
|
|
219
306
|
self.key = api_key_response
|
|
220
307
|
|
|
308
|
+
# Make sure URL is set
|
|
309
|
+
if not self.url:
|
|
310
|
+
self.url = "https://api.terrak.io"
|
|
311
|
+
|
|
312
|
+
# Make sure session headers are updated with the new key
|
|
313
|
+
self.session.headers.update({
|
|
314
|
+
'Content-Type': 'application/json',
|
|
315
|
+
'x-api-key': self.key
|
|
316
|
+
})
|
|
317
|
+
|
|
318
|
+
# Set URL if not already set
|
|
319
|
+
if not self.url:
|
|
320
|
+
self.url = "https://api.terrak.io"
|
|
321
|
+
self.url = self.url.rstrip('/')
|
|
322
|
+
|
|
221
323
|
# Save email and API key to config file
|
|
222
324
|
import os
|
|
223
325
|
import json
|
|
@@ -229,6 +331,7 @@ class BaseClient:
|
|
|
229
331
|
config = json.load(f)
|
|
230
332
|
config["EMAIL"] = email
|
|
231
333
|
config["TERRAKIO_API_KEY"] = self.key
|
|
334
|
+
config["PERSONAL_TOKEN"] = token_response
|
|
232
335
|
|
|
233
336
|
os.makedirs(os.path.dirname(config_path), exist_ok=True)
|
|
234
337
|
with open(config_path, 'w') as f:
|
|
@@ -236,6 +339,7 @@ class BaseClient:
|
|
|
236
339
|
|
|
237
340
|
if not self.quiet:
|
|
238
341
|
print(f"Successfully authenticated as: {email}")
|
|
342
|
+
print(f"Using Terrakio API at: {self.url}")
|
|
239
343
|
print(f"API key saved to {config_path}")
|
|
240
344
|
except Exception as e:
|
|
241
345
|
if not self.quiet:
|
|
@@ -247,47 +351,81 @@ class BaseClient:
|
|
|
247
351
|
print(f"Login failed: {str(e)}")
|
|
248
352
|
raise
|
|
249
353
|
|
|
354
|
+
@require_api_key
|
|
250
355
|
def refresh_api_key(self) -> str:
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
356
|
+
# If we have auth_client and it has a token, refresh the key
|
|
357
|
+
if self.auth_client and self.auth_client.token:
|
|
358
|
+
self.key = self.auth_client.refresh_api_key()
|
|
359
|
+
self.session.headers.update({'x-api-key': self.key})
|
|
360
|
+
|
|
361
|
+
# Update the config file with the new key
|
|
362
|
+
import os
|
|
363
|
+
config_path = os.path.join(os.environ.get("HOME", ""), ".tkio_config.json")
|
|
364
|
+
try:
|
|
365
|
+
config = {"EMAIL": "", "TERRAKIO_API_KEY": ""}
|
|
366
|
+
if os.path.exists(config_path):
|
|
367
|
+
with open(config_path, 'r') as f:
|
|
368
|
+
config = json.load(f)
|
|
369
|
+
config["TERRAKIO_API_KEY"] = self.key
|
|
370
|
+
os.makedirs(os.path.dirname(config_path), exist_ok=True)
|
|
371
|
+
with open(config_path, 'w') as f:
|
|
372
|
+
json.dump(config, f, indent=4)
|
|
373
|
+
if not self.quiet:
|
|
374
|
+
print(f"API key generated successfully and updated in {config_path}")
|
|
375
|
+
except Exception as e:
|
|
376
|
+
if not self.quiet:
|
|
377
|
+
print(f"Warning: Failed to update config file: {e}")
|
|
378
|
+
return self.key
|
|
379
|
+
else:
|
|
380
|
+
# If we don't have auth_client with a token but have a key already, return it
|
|
381
|
+
if self.key:
|
|
382
|
+
if not self.quiet:
|
|
383
|
+
print("Using existing API key from config file.")
|
|
384
|
+
return self.key
|
|
385
|
+
else:
|
|
386
|
+
raise ConfigurationError("No authentication token available. Please login first to refresh the API key.")
|
|
387
|
+
|
|
388
|
+
@require_api_key
|
|
389
|
+
def view_api_key(self) -> str:
|
|
390
|
+
# If we have the auth client and token, refresh the key
|
|
391
|
+
if self.auth_client and self.auth_client.token:
|
|
392
|
+
self.key = self.auth_client.view_api_key()
|
|
393
|
+
self.session.headers.update({'x-api-key': self.key})
|
|
394
|
+
|
|
273
395
|
return self.key
|
|
274
396
|
|
|
397
|
+
@require_api_key
|
|
275
398
|
def view_api_key(self) -> str:
|
|
399
|
+
# If we have the auth client and token, refresh the key
|
|
276
400
|
if not self.auth_client:
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
401
|
+
from terrakio_core.auth import AuthClient
|
|
402
|
+
self.auth_client = AuthClient(
|
|
403
|
+
base_url=self.url,
|
|
404
|
+
verify=self.verify,
|
|
405
|
+
timeout=self.timeout
|
|
406
|
+
)
|
|
407
|
+
self.auth_client.session = self.session
|
|
408
|
+
return self.auth_client.view_api_key()
|
|
283
409
|
|
|
410
|
+
# @require_api_key
|
|
411
|
+
# def get_user_info(self) -> Dict[str, Any]:
|
|
412
|
+
# return self.auth_client.get_user_info()
|
|
413
|
+
@require_api_key
|
|
284
414
|
def get_user_info(self) -> Dict[str, Any]:
|
|
415
|
+
# Initialize auth_client if it doesn't exist
|
|
285
416
|
if not self.auth_client:
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
417
|
+
from terrakio_core.auth import AuthClient
|
|
418
|
+
self.auth_client = AuthClient(
|
|
419
|
+
base_url=self.url,
|
|
420
|
+
verify=self.verify,
|
|
421
|
+
timeout=self.timeout
|
|
422
|
+
)
|
|
423
|
+
# Use the same session as the base client
|
|
424
|
+
self.auth_client.session = self.session
|
|
425
|
+
|
|
289
426
|
return self.auth_client.get_user_info()
|
|
290
427
|
|
|
428
|
+
@require_api_key
|
|
291
429
|
def wcs(self, expr: str, feature: Union[Dict[str, Any], ShapelyGeometry], in_crs: str = "epsg:4326",
|
|
292
430
|
out_crs: str = "epsg:4326", output: str = "csv", resolution: int = -1,
|
|
293
431
|
**kwargs):
|
|
@@ -340,6 +478,7 @@ class BaseClient:
|
|
|
340
478
|
raise APIError(f"Request failed: {str(e)}")
|
|
341
479
|
|
|
342
480
|
# Admin/protected methods
|
|
481
|
+
@require_api_key
|
|
343
482
|
def _get_user_by_id(self, user_id: str):
|
|
344
483
|
if not self.user_management:
|
|
345
484
|
from terrakio_core.user_management import UserManagement
|
|
@@ -353,6 +492,7 @@ class BaseClient:
|
|
|
353
492
|
)
|
|
354
493
|
return self.user_management.get_user_by_id(user_id)
|
|
355
494
|
|
|
495
|
+
@require_api_key
|
|
356
496
|
def _get_user_by_email(self, email: str):
|
|
357
497
|
if not self.user_management:
|
|
358
498
|
from terrakio_core.user_management import UserManagement
|
|
@@ -366,6 +506,7 @@ class BaseClient:
|
|
|
366
506
|
)
|
|
367
507
|
return self.user_management.get_user_by_email(email)
|
|
368
508
|
|
|
509
|
+
@require_api_key
|
|
369
510
|
def _list_users(self, substring: str = None, uid: bool = False):
|
|
370
511
|
if not self.user_management:
|
|
371
512
|
from terrakio_core.user_management import UserManagement
|
|
@@ -379,6 +520,7 @@ class BaseClient:
|
|
|
379
520
|
)
|
|
380
521
|
return self.user_management.list_users(substring=substring, uid=uid)
|
|
381
522
|
|
|
523
|
+
@require_api_key
|
|
382
524
|
def _edit_user(self, user_id: str, uid: str = None, email: str = None, role: str = None, apiKey: str = None, groups: list = None, quota: int = None):
|
|
383
525
|
if not self.user_management:
|
|
384
526
|
from terrakio_core.user_management import UserManagement
|
|
@@ -400,6 +542,7 @@ class BaseClient:
|
|
|
400
542
|
quota=quota
|
|
401
543
|
)
|
|
402
544
|
|
|
545
|
+
@require_api_key
|
|
403
546
|
def _reset_quota(self, email: str, quota: int = None):
|
|
404
547
|
if not self.user_management:
|
|
405
548
|
from terrakio_core.user_management import UserManagement
|
|
@@ -413,6 +556,7 @@ class BaseClient:
|
|
|
413
556
|
)
|
|
414
557
|
return self.user_management.reset_quota(email=email, quota=quota)
|
|
415
558
|
|
|
559
|
+
@require_api_key
|
|
416
560
|
def _delete_user(self, uid: str):
|
|
417
561
|
if not self.user_management:
|
|
418
562
|
from terrakio_core.user_management import UserManagement
|
|
@@ -427,6 +571,7 @@ class BaseClient:
|
|
|
427
571
|
return self.user_management.delete_user(uid=uid)
|
|
428
572
|
|
|
429
573
|
# Dataset management protected methods
|
|
574
|
+
@require_api_key
|
|
430
575
|
def _get_dataset(self, name: str, collection: str = "terrakio-datasets"):
|
|
431
576
|
if not self.dataset_management:
|
|
432
577
|
from terrakio_core.dataset_management import DatasetManagement
|
|
@@ -440,6 +585,7 @@ class BaseClient:
|
|
|
440
585
|
)
|
|
441
586
|
return self.dataset_management.get_dataset(name=name, collection=collection)
|
|
442
587
|
|
|
588
|
+
@require_api_key
|
|
443
589
|
def _list_datasets(self, substring: str = None, collection: str = "terrakio-datasets"):
|
|
444
590
|
if not self.dataset_management:
|
|
445
591
|
from terrakio_core.dataset_management import DatasetManagement
|
|
@@ -453,6 +599,7 @@ class BaseClient:
|
|
|
453
599
|
)
|
|
454
600
|
return self.dataset_management.list_datasets(substring=substring, collection=collection)
|
|
455
601
|
|
|
602
|
+
@require_api_key
|
|
456
603
|
def _create_dataset(self, name: str, collection: str = "terrakio-datasets", **kwargs):
|
|
457
604
|
if not self.dataset_management:
|
|
458
605
|
from terrakio_core.dataset_management import DatasetManagement
|
|
@@ -466,6 +613,7 @@ class BaseClient:
|
|
|
466
613
|
)
|
|
467
614
|
return self.dataset_management.create_dataset(name=name, collection=collection, **kwargs)
|
|
468
615
|
|
|
616
|
+
@require_api_key
|
|
469
617
|
def _update_dataset(self, name: str, append: bool = True, collection: str = "terrakio-datasets", **kwargs):
|
|
470
618
|
if not self.dataset_management:
|
|
471
619
|
from terrakio_core.dataset_management import DatasetManagement
|
|
@@ -479,6 +627,7 @@ class BaseClient:
|
|
|
479
627
|
)
|
|
480
628
|
return self.dataset_management.update_dataset(name=name, append=append, collection=collection, **kwargs)
|
|
481
629
|
|
|
630
|
+
@require_api_key
|
|
482
631
|
def _overwrite_dataset(self, name: str, collection: str = "terrakio-datasets", **kwargs):
|
|
483
632
|
if not self.dataset_management:
|
|
484
633
|
from terrakio_core.dataset_management import DatasetManagement
|
|
@@ -492,6 +641,7 @@ class BaseClient:
|
|
|
492
641
|
)
|
|
493
642
|
return self.dataset_management.overwrite_dataset(name=name, collection=collection, **kwargs)
|
|
494
643
|
|
|
644
|
+
@require_api_key
|
|
495
645
|
def _delete_dataset(self, name: str, collection: str = "terrakio-datasets"):
|
|
496
646
|
if not self.dataset_management:
|
|
497
647
|
from terrakio_core.dataset_management import DatasetManagement
|
|
@@ -505,6 +655,7 @@ class BaseClient:
|
|
|
505
655
|
)
|
|
506
656
|
return self.dataset_management.delete_dataset(name=name, collection=collection)
|
|
507
657
|
|
|
658
|
+
@require_api_key
|
|
508
659
|
def close(self):
|
|
509
660
|
"""Close all client sessions"""
|
|
510
661
|
self.session.close()
|
|
@@ -531,13 +682,16 @@ class BaseClient:
|
|
|
531
682
|
# Event loop may already be closed, ignore
|
|
532
683
|
pass
|
|
533
684
|
|
|
685
|
+
@require_api_key
|
|
534
686
|
def __enter__(self):
|
|
535
687
|
return self
|
|
536
688
|
|
|
689
|
+
@require_api_key
|
|
537
690
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
538
691
|
self.close()
|
|
539
692
|
|
|
540
693
|
@admin_only_params('location', 'force_loc', 'server')
|
|
694
|
+
@require_api_key
|
|
541
695
|
def execute_job(self, name, region, output, config, overwrite=False, skip_existing=False, request_json=None, manifest_json=None, location=None, force_loc=None, server="dev-au.terrak.io"):
|
|
542
696
|
if not self.mass_stats:
|
|
543
697
|
from terrakio_core.mass_stats import MassStats
|
|
@@ -552,6 +706,7 @@ class BaseClient:
|
|
|
552
706
|
return self.mass_stats.execute_job(name, region, output, config, overwrite, skip_existing, request_json, manifest_json, location, force_loc, server)
|
|
553
707
|
|
|
554
708
|
|
|
709
|
+
@require_api_key
|
|
555
710
|
def get_mass_stats_task_id(self, name, stage, uid=None):
|
|
556
711
|
if not self.mass_stats:
|
|
557
712
|
from terrakio_core.mass_stats import MassStats
|
|
@@ -565,7 +720,8 @@ class BaseClient:
|
|
|
565
720
|
)
|
|
566
721
|
return self.mass_stats.get_task_id(name, stage, uid)
|
|
567
722
|
|
|
568
|
-
|
|
723
|
+
@require_api_key
|
|
724
|
+
def track_mass_stats_job(self, ids: Optional[list] = None):
|
|
569
725
|
if not self.mass_stats:
|
|
570
726
|
from terrakio_core.mass_stats import MassStats
|
|
571
727
|
if not self.url or not self.key:
|
|
@@ -578,6 +734,7 @@ class BaseClient:
|
|
|
578
734
|
)
|
|
579
735
|
return self.mass_stats.track_job(ids)
|
|
580
736
|
|
|
737
|
+
@require_api_key
|
|
581
738
|
def get_mass_stats_history(self, limit=100):
|
|
582
739
|
if not self.mass_stats:
|
|
583
740
|
from terrakio_core.mass_stats import MassStats
|
|
@@ -591,6 +748,7 @@ class BaseClient:
|
|
|
591
748
|
)
|
|
592
749
|
return self.mass_stats.get_history(limit)
|
|
593
750
|
|
|
751
|
+
@require_api_key
|
|
594
752
|
def start_mass_stats_post_processing(self, process_name, data_name, output, consumer_path, overwrite=False):
|
|
595
753
|
if not self.mass_stats:
|
|
596
754
|
from terrakio_core.mass_stats import MassStats
|
|
@@ -604,6 +762,7 @@ class BaseClient:
|
|
|
604
762
|
)
|
|
605
763
|
return self.mass_stats.start_post_processing(process_name, data_name, output, consumer_path, overwrite)
|
|
606
764
|
|
|
765
|
+
@require_api_key
|
|
607
766
|
def download_mass_stats_results(self, id=None, force_loc=False, **kwargs):
|
|
608
767
|
if not self.mass_stats:
|
|
609
768
|
from terrakio_core.mass_stats import MassStats
|
|
@@ -617,6 +776,7 @@ class BaseClient:
|
|
|
617
776
|
)
|
|
618
777
|
return self.mass_stats.download_results(id, force_loc, **kwargs)
|
|
619
778
|
|
|
779
|
+
@require_api_key
|
|
620
780
|
def cancel_mass_stats_job(self, id):
|
|
621
781
|
if not self.mass_stats:
|
|
622
782
|
from terrakio_core.mass_stats import MassStats
|
|
@@ -630,6 +790,7 @@ class BaseClient:
|
|
|
630
790
|
)
|
|
631
791
|
return self.mass_stats.cancel_job(id)
|
|
632
792
|
|
|
793
|
+
@require_api_key
|
|
633
794
|
def cancel_all_mass_stats_jobs(self):
|
|
634
795
|
if not self.mass_stats:
|
|
635
796
|
from terrakio_core.mass_stats import MassStats
|
|
@@ -643,6 +804,7 @@ class BaseClient:
|
|
|
643
804
|
)
|
|
644
805
|
return self.mass_stats.cancel_all_jobs()
|
|
645
806
|
|
|
807
|
+
@require_api_key
|
|
646
808
|
def _create_pyramids(self, name, levels, config):
|
|
647
809
|
if not self.mass_stats:
|
|
648
810
|
from terrakio_core.mass_stats import MassStats
|
|
@@ -656,6 +818,7 @@ class BaseClient:
|
|
|
656
818
|
)
|
|
657
819
|
return self.mass_stats.create_pyramids(name, levels, config)
|
|
658
820
|
|
|
821
|
+
@require_api_key
|
|
659
822
|
def random_sample(self, name, **kwargs):
|
|
660
823
|
if not self.mass_stats:
|
|
661
824
|
from terrakio_core.mass_stats import MassStats
|
|
@@ -669,6 +832,7 @@ class BaseClient:
|
|
|
669
832
|
)
|
|
670
833
|
return self.mass_stats.random_sample(name, **kwargs)
|
|
671
834
|
|
|
835
|
+
@require_api_key
|
|
672
836
|
async def zonal_stats_async(self, gdb, expr, conc=20, inplace=False, output="csv",
|
|
673
837
|
in_crs="epsg:4326", out_crs="epsg:4326", resolution=-1, buffer=False):
|
|
674
838
|
"""
|
|
@@ -857,6 +1021,7 @@ class BaseClient:
|
|
|
857
1021
|
else:
|
|
858
1022
|
return result_gdf
|
|
859
1023
|
|
|
1024
|
+
@require_api_key
|
|
860
1025
|
def zonal_stats(self, gdb, expr, conc=20, inplace=False, output="csv",
|
|
861
1026
|
in_crs="epsg:4326", out_crs="epsg:4326", resolution=-1, buffer=False):
|
|
862
1027
|
"""
|
|
@@ -919,75 +1084,80 @@ class BaseClient:
|
|
|
919
1084
|
return result
|
|
920
1085
|
|
|
921
1086
|
# Group access management protected methods
|
|
1087
|
+
@require_api_key
|
|
922
1088
|
def _get_group_users_and_datasets(self, group_name: str):
|
|
923
|
-
if not
|
|
1089
|
+
if not self.group_access:
|
|
924
1090
|
from terrakio_core.group_access_management import GroupAccessManagement
|
|
925
1091
|
if not self.url or not self.key:
|
|
926
1092
|
raise ConfigurationError("Group access management client not initialized. Make sure API URL and key are set.")
|
|
927
|
-
self.
|
|
1093
|
+
self.group_access = GroupAccessManagement(
|
|
928
1094
|
api_url=self.url,
|
|
929
1095
|
api_key=self.key,
|
|
930
1096
|
verify=self.verify,
|
|
931
1097
|
timeout=self.timeout
|
|
932
1098
|
)
|
|
933
|
-
return self.
|
|
1099
|
+
return self.group_access.get_group_users_and_datasets(group_name=group_name)
|
|
934
1100
|
|
|
1101
|
+
@require_api_key
|
|
935
1102
|
def _add_group_to_dataset(self, dataset: str, group: str):
|
|
936
|
-
if not
|
|
1103
|
+
if not self.group_access:
|
|
937
1104
|
from terrakio_core.group_access_management import GroupAccessManagement
|
|
938
1105
|
if not self.url or not self.key:
|
|
939
1106
|
raise ConfigurationError("Group access management client not initialized. Make sure API URL and key are set.")
|
|
940
|
-
self.
|
|
1107
|
+
self.group_access = GroupAccessManagement(
|
|
941
1108
|
api_url=self.url,
|
|
942
1109
|
api_key=self.key,
|
|
943
1110
|
verify=self.verify,
|
|
944
1111
|
timeout=self.timeout
|
|
945
1112
|
)
|
|
946
|
-
return self.
|
|
1113
|
+
return self.group_access.add_group_to_dataset(dataset=dataset, group=group)
|
|
947
1114
|
|
|
1115
|
+
@require_api_key
|
|
948
1116
|
def _add_group_to_user(self, uid: str, group: str):
|
|
949
|
-
if not
|
|
1117
|
+
if not self.group_access:
|
|
950
1118
|
from terrakio_core.group_access_management import GroupAccessManagement
|
|
951
1119
|
if not self.url or not self.key:
|
|
952
1120
|
raise ConfigurationError("Group access management client not initialized. Make sure API URL and key are set.")
|
|
953
|
-
self.
|
|
1121
|
+
self.group_access = GroupAccessManagement(
|
|
954
1122
|
api_url=self.url,
|
|
955
1123
|
api_key=self.key,
|
|
956
1124
|
verify=self.verify,
|
|
957
1125
|
timeout=self.timeout
|
|
958
1126
|
)
|
|
959
|
-
|
|
960
|
-
return self.group_access_management.add_group_to_user(uid, group)
|
|
1127
|
+
return self.group_access.add_group_to_user(uid=uid, group=group)
|
|
961
1128
|
|
|
1129
|
+
@require_api_key
|
|
962
1130
|
def _delete_group_from_user(self, uid: str, group: str):
|
|
963
|
-
if not
|
|
1131
|
+
if not self.group_access:
|
|
964
1132
|
from terrakio_core.group_access_management import GroupAccessManagement
|
|
965
1133
|
if not self.url or not self.key:
|
|
966
1134
|
raise ConfigurationError("Group access management client not initialized. Make sure API URL and key are set.")
|
|
967
|
-
self.
|
|
1135
|
+
self.group_access = GroupAccessManagement(
|
|
968
1136
|
api_url=self.url,
|
|
969
1137
|
api_key=self.key,
|
|
970
1138
|
verify=self.verify,
|
|
971
1139
|
timeout=self.timeout
|
|
972
1140
|
)
|
|
973
|
-
return self.
|
|
1141
|
+
return self.group_access.delete_group_from_user(uid=uid, group=group)
|
|
974
1142
|
|
|
1143
|
+
@require_api_key
|
|
975
1144
|
def _delete_group_from_dataset(self, dataset: str, group: str):
|
|
976
|
-
if not
|
|
1145
|
+
if not self.group_access:
|
|
977
1146
|
from terrakio_core.group_access_management import GroupAccessManagement
|
|
978
1147
|
if not self.url or not self.key:
|
|
979
1148
|
raise ConfigurationError("Group access management client not initialized. Make sure API URL and key are set.")
|
|
980
|
-
self.
|
|
1149
|
+
self.group_access = GroupAccessManagement(
|
|
981
1150
|
api_url=self.url,
|
|
982
1151
|
api_key=self.key,
|
|
983
1152
|
verify=self.verify,
|
|
984
1153
|
timeout=self.timeout
|
|
985
1154
|
)
|
|
986
|
-
return self.
|
|
1155
|
+
return self.group_access.delete_group_from_dataset(dataset=dataset, group=group)
|
|
987
1156
|
|
|
988
1157
|
# Space management protected methods
|
|
1158
|
+
@require_api_key
|
|
989
1159
|
def _get_total_space_used(self):
|
|
990
|
-
if not
|
|
1160
|
+
if not self.space_management:
|
|
991
1161
|
from terrakio_core.space_management import SpaceManagement
|
|
992
1162
|
if not self.url or not self.key:
|
|
993
1163
|
raise ConfigurationError("Space management client not initialized. Make sure API URL and key are set.")
|
|
@@ -999,8 +1169,9 @@ class BaseClient:
|
|
|
999
1169
|
)
|
|
1000
1170
|
return self.space_management.get_total_space_used()
|
|
1001
1171
|
|
|
1172
|
+
@require_api_key
|
|
1002
1173
|
def _get_space_used_by_job(self, name: str, region: str = None):
|
|
1003
|
-
if not
|
|
1174
|
+
if not self.space_management:
|
|
1004
1175
|
from terrakio_core.space_management import SpaceManagement
|
|
1005
1176
|
if not self.url or not self.key:
|
|
1006
1177
|
raise ConfigurationError("Space management client not initialized. Make sure API URL and key are set.")
|
|
@@ -1012,8 +1183,9 @@ class BaseClient:
|
|
|
1012
1183
|
)
|
|
1013
1184
|
return self.space_management.get_space_used_by_job(name, region)
|
|
1014
1185
|
|
|
1186
|
+
@require_api_key
|
|
1015
1187
|
def _delete_user_job(self, name: str, region: str = None):
|
|
1016
|
-
if not
|
|
1188
|
+
if not self.space_management:
|
|
1017
1189
|
from terrakio_core.space_management import SpaceManagement
|
|
1018
1190
|
if not self.url or not self.key:
|
|
1019
1191
|
raise ConfigurationError("Space management client not initialized. Make sure API URL and key are set.")
|
|
@@ -1025,8 +1197,9 @@ class BaseClient:
|
|
|
1025
1197
|
)
|
|
1026
1198
|
return self.space_management.delete_user_job(name, region)
|
|
1027
1199
|
|
|
1200
|
+
@require_api_key
|
|
1028
1201
|
def _delete_data_in_path(self, path: str, region: str = None):
|
|
1029
|
-
if not
|
|
1202
|
+
if not self.space_management:
|
|
1030
1203
|
from terrakio_core.space_management import SpaceManagement
|
|
1031
1204
|
if not self.url or not self.key:
|
|
1032
1205
|
raise ConfigurationError("Space management client not initialized. Make sure API URL and key are set.")
|
|
@@ -1038,6 +1211,22 @@ class BaseClient:
|
|
|
1038
1211
|
)
|
|
1039
1212
|
return self.space_management.delete_data_in_path(path, region)
|
|
1040
1213
|
|
|
1214
|
+
@require_api_key
|
|
1215
|
+
def start_mass_stats_job(self, task_id):
|
|
1216
|
+
if not self.mass_stats:
|
|
1217
|
+
from terrakio_core.mass_stats import MassStats
|
|
1218
|
+
if not self.url or not self.key:
|
|
1219
|
+
raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
|
|
1220
|
+
self.mass_stats = MassStats(
|
|
1221
|
+
base_url=self.url,
|
|
1222
|
+
api_key=self.key,
|
|
1223
|
+
verify=self.verify,
|
|
1224
|
+
timeout=self.timeout
|
|
1225
|
+
)
|
|
1226
|
+
return self.mass_stats.start_job(task_id)
|
|
1227
|
+
|
|
1228
|
+
|
|
1229
|
+
@require_api_key
|
|
1041
1230
|
def generate_ai_dataset(
|
|
1042
1231
|
self,
|
|
1043
1232
|
name: str,
|
|
@@ -1107,44 +1296,32 @@ class BaseClient:
|
|
|
1107
1296
|
overwrite=True
|
|
1108
1297
|
)["task_id"]
|
|
1109
1298
|
print("the task id is ", task_id)
|
|
1110
|
-
task_id = self.start_mass_stats_job(task_id)
|
|
1111
|
-
print("the task id is ", task_id)
|
|
1112
|
-
return task_id
|
|
1113
1299
|
|
|
1300
|
+
# Wait for job completion
|
|
1301
|
+
import time
|
|
1302
|
+
|
|
1303
|
+
while True:
|
|
1304
|
+
result = self.track_mass_stats_job(ids=[task_id])
|
|
1305
|
+
status = result[task_id]['status']
|
|
1306
|
+
print(f"Job status: {status}")
|
|
1307
|
+
|
|
1308
|
+
if status == "Completed":
|
|
1309
|
+
break
|
|
1310
|
+
elif status == "Error":
|
|
1311
|
+
raise Exception(f"Job {task_id} encountered an error")
|
|
1312
|
+
|
|
1313
|
+
# Wait 30 seconds before checking again
|
|
1314
|
+
time.sleep(30)
|
|
1114
1315
|
|
|
1115
|
-
|
|
1116
|
-
|
|
1117
|
-
|
|
1118
|
-
|
|
1119
|
-
# Args:
|
|
1120
|
-
# model_name (str): The name of the model to train.
|
|
1121
|
-
# training_data (dict): Dictionary containing training data parameters.
|
|
1122
|
-
|
|
1123
|
-
# Returns:
|
|
1124
|
-
# dict: The response from the model training API.
|
|
1125
|
-
# """
|
|
1126
|
-
# endpoint = "https://modeltraining-573248941006.australia-southeast1.run.app/train_model"
|
|
1127
|
-
# payload = {
|
|
1128
|
-
# "model_name": model_name,
|
|
1129
|
-
# "training_data": training_data
|
|
1130
|
-
# }
|
|
1131
|
-
# try:
|
|
1132
|
-
# response = self.session.post(endpoint, json=payload, timeout=self.timeout, verify=self.verify)
|
|
1133
|
-
# if not response.ok:
|
|
1134
|
-
# error_msg = f"Model training request failed: {response.status_code} {response.reason}"
|
|
1135
|
-
# try:
|
|
1136
|
-
# error_data = response.json()
|
|
1137
|
-
# if "detail" in error_data:
|
|
1138
|
-
# error_msg += f" - {error_data['detail']}"
|
|
1139
|
-
# except Exception:
|
|
1140
|
-
# if response.text:
|
|
1141
|
-
# error_msg += f" - {response.text}"
|
|
1142
|
-
# raise APIError(error_msg)
|
|
1143
|
-
# return response.json()
|
|
1144
|
-
# except requests.RequestException as e:
|
|
1145
|
-
# raise APIError(f"Model training request failed: {str(e)}")
|
|
1316
|
+
# print("the result is ", result)
|
|
1317
|
+
# after all the random sample jos are done, we then start the mass stats job
|
|
1318
|
+
task_id = self.start_mass_stats_job(task_id)
|
|
1319
|
+
# now we hav ethe random sampel
|
|
1146
1320
|
|
|
1321
|
+
# print("the task id is ", task_id)
|
|
1322
|
+
return task_id
|
|
1147
1323
|
|
|
1324
|
+
@require_api_key
|
|
1148
1325
|
def train_model(self, model_name: str, training_dataset: str, task_type: str, model_category: str, architecture: str, region: str, hyperparameters: dict = None) -> dict:
|
|
1149
1326
|
"""
|
|
1150
1327
|
Train a model using the external model training API.
|
|
@@ -1189,6 +1366,7 @@ class BaseClient:
|
|
|
1189
1366
|
raise APIError(f"Model training request failed: {str(e)}")
|
|
1190
1367
|
|
|
1191
1368
|
# Mass Stats methods
|
|
1369
|
+
@require_api_key
|
|
1192
1370
|
def combine_tiles(self,
|
|
1193
1371
|
data_name: str,
|
|
1194
1372
|
usezarr: bool,
|
|
@@ -1209,7 +1387,8 @@ class BaseClient:
|
|
|
1209
1387
|
|
|
1210
1388
|
|
|
1211
1389
|
|
|
1212
|
-
|
|
1390
|
+
@require_api_key
|
|
1391
|
+
def create_dataset_file(
|
|
1213
1392
|
self,
|
|
1214
1393
|
name: str,
|
|
1215
1394
|
aoi: str,
|
|
@@ -1224,9 +1403,22 @@ class BaseClient:
|
|
|
1224
1403
|
skip_existing: bool = False,
|
|
1225
1404
|
non_interactive: bool = True,
|
|
1226
1405
|
usezarr: bool = False,
|
|
1227
|
-
poll_interval: int = 30
|
|
1406
|
+
poll_interval: int = 30,
|
|
1407
|
+
download_path: str = "/home/user/Downloads",
|
|
1228
1408
|
) -> dict:
|
|
1229
1409
|
|
|
1410
|
+
if not self.mass_stats:
|
|
1411
|
+
from terrakio_core.mass_stats import MassStats
|
|
1412
|
+
if not self.url or not self.key:
|
|
1413
|
+
raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
|
|
1414
|
+
self.mass_stats = MassStats(
|
|
1415
|
+
base_url=self.url,
|
|
1416
|
+
api_key=self.key,
|
|
1417
|
+
verify=self.verify,
|
|
1418
|
+
timeout=self.timeout
|
|
1419
|
+
)
|
|
1420
|
+
|
|
1421
|
+
|
|
1230
1422
|
from terrakio_core.generation.tiles import tiles
|
|
1231
1423
|
import tempfile
|
|
1232
1424
|
import time
|
|
@@ -1315,44 +1507,70 @@ class BaseClient:
|
|
|
1315
1507
|
os.unlink(tempmanifestname)
|
|
1316
1508
|
|
|
1317
1509
|
|
|
1510
|
+
# return self.mass_stats.combine_tiles(body["name"], usezarr, body["overwrite"], body["output"])
|
|
1511
|
+
|
|
1318
1512
|
# Start combining tiles
|
|
1319
|
-
|
|
1320
|
-
|
|
1321
|
-
if not self.url or not self.key:
|
|
1322
|
-
raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
|
|
1323
|
-
self.mass_stats = MassStats(
|
|
1324
|
-
base_url=self.url,
|
|
1325
|
-
api_key=self.key,
|
|
1326
|
-
verify=self.verify,
|
|
1327
|
-
timeout=self.timeout
|
|
1328
|
-
)
|
|
1329
|
-
|
|
1330
|
-
return self.mass_stats.combine_tiles(body["name"], usezarr, body["overwrite"], body["output"])
|
|
1331
|
-
|
|
1332
|
-
|
|
1513
|
+
combine_result = self.mass_stats.combine_tiles(body["name"], usezarr, body["overwrite"], body["output"])
|
|
1514
|
+
combine_task_id = combine_result.get("task_id")
|
|
1333
1515
|
|
|
1516
|
+
# Poll combine_tiles job status
|
|
1517
|
+
combine_start_time = time.time()
|
|
1518
|
+
while True:
|
|
1519
|
+
try:
|
|
1520
|
+
trackinfo = self.mass_stats.track_job([combine_task_id])
|
|
1521
|
+
download_file_name = trackinfo[combine_task_id]['folder'] + '.nc'
|
|
1522
|
+
bucket = trackinfo[combine_task_id]['bucket']
|
|
1523
|
+
combine_status = trackinfo[combine_task_id]['status']
|
|
1524
|
+
if combine_status == 'Completed':
|
|
1525
|
+
print('Tiles combined successfully!')
|
|
1526
|
+
break
|
|
1527
|
+
elif combine_status in ['Failed', 'Cancelled', 'Error']:
|
|
1528
|
+
raise RuntimeError(f"Combine job {combine_task_id} failed with status: {combine_status}")
|
|
1529
|
+
else:
|
|
1530
|
+
elapsed_time = time.time() - combine_start_time
|
|
1531
|
+
print(f"Combine job status: {combine_status} - Elapsed time: {elapsed_time:.1f}s", end='\r')
|
|
1532
|
+
time.sleep(poll_interval)
|
|
1533
|
+
except KeyboardInterrupt:
|
|
1534
|
+
print(f"\nInterrupted! Combine job {combine_task_id} is still running in the background.")
|
|
1535
|
+
raise
|
|
1536
|
+
except Exception as e:
|
|
1537
|
+
print(f"\nError tracking combine job: {e}")
|
|
1538
|
+
raise
|
|
1334
1539
|
|
|
1540
|
+
# Download the resulting file
|
|
1541
|
+
if download_path:
|
|
1542
|
+
self.mass_stats.download_file(body["name"], bucket, download_file_name, download_path)
|
|
1543
|
+
else:
|
|
1544
|
+
path = f"{body['name']}/combinedOutput/{download_file_name}"
|
|
1545
|
+
print(f"Combined file is available at {path}")
|
|
1335
1546
|
|
|
1547
|
+
return {"generation_task_id": task_id, "combine_task_id": combine_task_id}
|
|
1336
1548
|
|
|
1337
1549
|
|
|
1550
|
+
# taskid = self.mass_stats.get_task_id(job_name, stage).get('task_id')
|
|
1551
|
+
# trackinfo = self.mass_stats.track_job([taskid])
|
|
1552
|
+
# bucket = trackinfo[taskid]['bucket']
|
|
1553
|
+
# return self.mass_stats.download_file(job_name, bucket, file_name, output_path)
|
|
1338
1554
|
|
|
1339
1555
|
|
|
1556
|
+
@require_api_key
|
|
1340
1557
|
def deploy_model(self, dataset: str, product:str, model_name:str, input_expression: str, model_training_job_name: str, uid: str, dates_iso8601: list):
|
|
1341
1558
|
script_content = self._generate_script(model_name, product, model_training_job_name, uid)
|
|
1342
1559
|
script_name = f"{product}.py"
|
|
1343
1560
|
self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
|
|
1344
|
-
# after uploading the script, we need to create a new virtual dataset
|
|
1345
1561
|
self._create_dataset(name = dataset, collection = "terrakio-datasets", products = [product], path = f"gs://terrakio-mass-requests/{uid}/{model_training_job_name}/inference_scripts", input = input_expression, dates_iso8601 = dates_iso8601, padding = 0)
|
|
1346
1562
|
|
|
1563
|
+
@require_api_key
|
|
1347
1564
|
def _generate_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
|
|
1348
1565
|
return textwrap.dedent(f'''
|
|
1349
1566
|
import logging
|
|
1350
1567
|
from io import BytesIO
|
|
1351
|
-
|
|
1352
|
-
from onnxruntime import InferenceSession
|
|
1568
|
+
|
|
1353
1569
|
import numpy as np
|
|
1570
|
+
import pandas as pd
|
|
1354
1571
|
import xarray as xr
|
|
1355
|
-
import
|
|
1572
|
+
from google.cloud import storage
|
|
1573
|
+
from onnxruntime import InferenceSession
|
|
1356
1574
|
|
|
1357
1575
|
logging.basicConfig(
|
|
1358
1576
|
level=logging.INFO
|
|
@@ -1360,47 +1578,96 @@ class BaseClient:
|
|
|
1360
1578
|
|
|
1361
1579
|
def get_model():
|
|
1362
1580
|
logging.info("Loading model for {model_name}...")
|
|
1363
|
-
|
|
1581
|
+
|
|
1364
1582
|
client = storage.Client()
|
|
1365
1583
|
bucket = client.get_bucket('terrakio-mass-requests')
|
|
1366
1584
|
blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')
|
|
1367
|
-
|
|
1585
|
+
|
|
1368
1586
|
model = BytesIO()
|
|
1369
1587
|
blob.download_to_file(model)
|
|
1370
1588
|
model.seek(0)
|
|
1371
|
-
|
|
1589
|
+
|
|
1372
1590
|
session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
|
|
1373
1591
|
return session
|
|
1374
1592
|
|
|
1375
1593
|
def {product}(*bands, model):
|
|
1376
1594
|
logging.info("start preparing data")
|
|
1595
|
+
print("the bands are ", bands)
|
|
1377
1596
|
|
|
1378
|
-
|
|
1379
|
-
logging.info(f"Original shape: {{original_shape}}")
|
|
1597
|
+
data_arrays = list(bands)
|
|
1380
1598
|
|
|
1381
|
-
|
|
1382
|
-
for band in bands:
|
|
1383
|
-
transformed_band = band.values.reshape(-1,1)
|
|
1384
|
-
transformed_bands.append(transformed_band)
|
|
1599
|
+
print("the data arrays are ", [da.name for da in data_arrays])
|
|
1385
1600
|
|
|
1386
|
-
|
|
1601
|
+
reference_array = data_arrays[0]
|
|
1602
|
+
original_shape = reference_array.shape
|
|
1603
|
+
logging.info(f"Original shape: {{original_shape}}")
|
|
1387
1604
|
|
|
1605
|
+
if 'time' in reference_array.dims:
|
|
1606
|
+
time_coords = reference_array.coords['time']
|
|
1607
|
+
if len(time_coords) == 1:
|
|
1608
|
+
output_timestamp = time_coords[0]
|
|
1609
|
+
else:
|
|
1610
|
+
years = [pd.to_datetime(t).year for t in time_coords.values]
|
|
1611
|
+
unique_years = set(years)
|
|
1612
|
+
|
|
1613
|
+
if len(unique_years) == 1:
|
|
1614
|
+
year = list(unique_years)[0]
|
|
1615
|
+
output_timestamp = pd.Timestamp(f"{{year}}-01-01")
|
|
1616
|
+
else:
|
|
1617
|
+
latest_year = max(unique_years)
|
|
1618
|
+
output_timestamp = pd.Timestamp(f"{{latest_year}}-01-01")
|
|
1619
|
+
else:
|
|
1620
|
+
output_timestamp = pd.Timestamp("1970-01-01")
|
|
1621
|
+
|
|
1622
|
+
averaged_bands = []
|
|
1623
|
+
for data_array in data_arrays:
|
|
1624
|
+
if 'time' in data_array.dims:
|
|
1625
|
+
averaged_band = np.mean(data_array.values, axis=0)
|
|
1626
|
+
logging.info(f"Averaged band from {{data_array.shape}} to {{averaged_band.shape}}")
|
|
1627
|
+
else:
|
|
1628
|
+
averaged_band = data_array.values
|
|
1629
|
+
logging.info(f"No time dimension, shape: {{averaged_band.shape}}")
|
|
1630
|
+
|
|
1631
|
+
flattened_band = averaged_band.reshape(-1, 1)
|
|
1632
|
+
averaged_bands.append(flattened_band)
|
|
1633
|
+
|
|
1634
|
+
input_data = np.hstack(averaged_bands)
|
|
1635
|
+
|
|
1388
1636
|
logging.info(f"Final input shape: {{input_data.shape}}")
|
|
1389
|
-
|
|
1637
|
+
|
|
1390
1638
|
output = model.run(None, {{"float_input": input_data.astype(np.float32)}})[0]
|
|
1391
|
-
|
|
1639
|
+
|
|
1392
1640
|
logging.info(f"Model output shape: {{output.shape}}")
|
|
1393
1641
|
|
|
1394
|
-
|
|
1642
|
+
if len(original_shape) >= 3:
|
|
1643
|
+
spatial_shape = original_shape[1:]
|
|
1644
|
+
else:
|
|
1645
|
+
spatial_shape = original_shape
|
|
1646
|
+
|
|
1647
|
+
output_reshaped = output.reshape(spatial_shape)
|
|
1648
|
+
|
|
1649
|
+
output_with_time = np.expand_dims(output_reshaped, axis=0)
|
|
1650
|
+
|
|
1651
|
+
if 'time' in reference_array.dims:
|
|
1652
|
+
spatial_dims = [dim for dim in reference_array.dims if dim != 'time']
|
|
1653
|
+
spatial_coords = {{dim: reference_array.coords[dim] for dim in spatial_dims if dim in reference_array.coords}}
|
|
1654
|
+
else:
|
|
1655
|
+
spatial_dims = list(reference_array.dims)
|
|
1656
|
+
spatial_coords = dict(reference_array.coords)
|
|
1657
|
+
|
|
1395
1658
|
result = xr.DataArray(
|
|
1396
|
-
data=
|
|
1397
|
-
dims=
|
|
1398
|
-
coords=
|
|
1659
|
+
data=output_with_time.astype(np.float32),
|
|
1660
|
+
dims=['time'] + list(spatial_dims),
|
|
1661
|
+
coords={
|
|
1662
|
+
'time': [output_timestamp.values],
|
|
1663
|
+
'y': spatial_coords['y'].values,
|
|
1664
|
+
'x': spatial_coords['x'].values
|
|
1665
|
+
}
|
|
1399
1666
|
)
|
|
1400
|
-
|
|
1401
1667
|
return result
|
|
1402
1668
|
''').strip()
|
|
1403
|
-
|
|
1669
|
+
|
|
1670
|
+
@require_api_key
|
|
1404
1671
|
def _upload_script_to_bucket(self, script_content: str, script_name: str, model_training_job_name: str, uid: str):
|
|
1405
1672
|
"""Upload the generated script to Google Cloud Storage"""
|
|
1406
1673
|
|
|
@@ -1410,3 +1677,55 @@ class BaseClient:
|
|
|
1410
1677
|
blob.upload_from_string(script_content, content_type='text/plain')
|
|
1411
1678
|
logging.info(f"Script uploaded successfully to {uid}/{model_training_job_name}/inference_scripts/{script_name}")
|
|
1412
1679
|
|
|
1680
|
+
|
|
1681
|
+
|
|
1682
|
+
|
|
1683
|
+
@require_api_key
|
|
1684
|
+
def download_file_to_path(self, job_name, stage, file_name, output_path):
|
|
1685
|
+
if not self.mass_stats:
|
|
1686
|
+
from terrakio_core.mass_stats import MassStats
|
|
1687
|
+
if not self.url or not self.key:
|
|
1688
|
+
raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
|
|
1689
|
+
self.mass_stats = MassStats(
|
|
1690
|
+
base_url=self.url,
|
|
1691
|
+
api_key=self.key,
|
|
1692
|
+
verify=self.verify,
|
|
1693
|
+
timeout=self.timeout
|
|
1694
|
+
)
|
|
1695
|
+
|
|
1696
|
+
# fetch bucket info based on job name and stage
|
|
1697
|
+
|
|
1698
|
+
taskid = self.mass_stats.get_task_id(job_name, stage).get('task_id')
|
|
1699
|
+
trackinfo = self.mass_stats.track_job([taskid])
|
|
1700
|
+
bucket = trackinfo[taskid]['bucket']
|
|
1701
|
+
return self.mass_stats.download_file(job_name, bucket, file_name, output_path)
|
|
1702
|
+
|
|
1703
|
+
|
|
1704
|
+
# Apply the @require_api_key decorator to ALL methods in BaseClient
|
|
1705
|
+
# except only the absolute minimum that shouldn't require auth
|
|
1706
|
+
def _apply_api_key_decorator():
|
|
1707
|
+
# Only these methods can be used without API key
|
|
1708
|
+
skip_auth_methods = [
|
|
1709
|
+
'__init__', 'login', 'signup'
|
|
1710
|
+
]
|
|
1711
|
+
|
|
1712
|
+
# Get all attributes of BaseClient
|
|
1713
|
+
for attr_name in dir(BaseClient):
|
|
1714
|
+
# Skip special methods and methods in skip list
|
|
1715
|
+
if attr_name.startswith('__') and attr_name not in ['__enter__', '__exit__', '__aenter__', '__aexit__'] or attr_name in skip_auth_methods:
|
|
1716
|
+
continue
|
|
1717
|
+
|
|
1718
|
+
# Get the attribute
|
|
1719
|
+
attr = getattr(BaseClient, attr_name)
|
|
1720
|
+
|
|
1721
|
+
# Skip if not callable (not a method) or already decorated
|
|
1722
|
+
if not callable(attr) or hasattr(attr, '_is_decorated'):
|
|
1723
|
+
continue
|
|
1724
|
+
|
|
1725
|
+
# Apply decorator to EVERY method not in skip list
|
|
1726
|
+
setattr(BaseClient, attr_name, require_api_key(attr))
|
|
1727
|
+
# Mark as decorated to avoid double decoration
|
|
1728
|
+
getattr(BaseClient, attr_name)._is_decorated = True
|
|
1729
|
+
|
|
1730
|
+
# Run the decorator application
|
|
1731
|
+
_apply_api_key_decorator()
|