terrakio-core 0.3.3__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of terrakio-core might be problematic. See the package registry's advisory page for more details.

terrakio_core/client.py CHANGED
@@ -1,1731 +1,36 @@
1
- import json
2
- import asyncio
3
- from io import BytesIO
4
- from typing import Dict, Any, Optional, Union
5
- from functools import wraps
6
-
7
- import requests
8
- import aiohttp
9
- import pandas as pd
10
- import geopandas as gpd
11
- import xarray as xr
12
- import nest_asyncio
13
- from shapely.geometry import shape, mapping
14
- from shapely.geometry.base import BaseGeometry as ShapelyGeometry
15
- from google.cloud import storage
16
- from .exceptions import APIError, ConfigurationError
17
- from .decorators import admin_only_params
1
+ from typing import Optional
18
2
  import logging
19
- import textwrap
3
+ from terrakio_core.config import read_config_file, DEFAULT_CONFIG_FILE
4
+ from abc import abstractmethod
20
5
 
6
+ class BaseClient():
7
+ def __init__(self, url: Optional[str] = None, api_key: Optional[str] = None, verbose: bool = False):
21
8
 
22
- def require_api_key(func):
23
- """
24
- Decorator to ensure an API key is available before a method can be executed.
25
- Will check for the presence of an API key before allowing the function to be called.
26
- """
27
- @wraps(func)
28
- def wrapper(self, *args, **kwargs):
29
- # Check if the API key is set
30
- if not self.key:
31
- error_msg = "No API key found. Please provide an API key to use this client."
32
- if not self.quiet:
33
- print(error_msg)
34
- print("API key is required for this function.")
35
- print("You can set an API key by either:")
36
- print("1. Loading from a config file")
37
- print("2. Using the login() method: client.login(email='your-email@example.com', password='your-password')")
38
- raise ConfigurationError(error_msg)
39
-
40
- # Check if URL is set (required for API calls)
41
- if not self.url:
42
- if not self.quiet:
43
- print("API URL is not set. Using default API URL: https://api.terrak.io")
44
- self.url = "https://api.terrak.io"
45
-
46
- return func(self, *args, **kwargs)
47
-
48
- # Mark as decorated to avoid double decoration
49
- wrapper._is_decorated = True
50
- return wrapper
9
+ self.verbose = verbose
10
+ self.logger = logging.getLogger("terrakio")
11
+ if verbose:
12
+ self.logger.setLevel(logging.INFO)
13
+ else:
14
+ self.logger.setLevel(logging.WARNING)
51
15
 
16
+ self.timeout = 300
17
+ self.retry = 3
18
+
19
+ self.session = None
52
20
 
53
- class BaseClient:
54
- def __init__(self, url: Optional[str] = None,
55
- auth_url: Optional[str] = "https://dev-au.terrak.io",
56
- quiet: bool = False, config_file: Optional[str] = None,
57
- verify: bool = True, timeout: int = 300):
58
- nest_asyncio.apply()
59
- self.quiet = quiet
60
- self.verify = verify
61
- self.timeout = timeout
62
- self.auth_client = None
63
- # Initialize session early to avoid AttributeError
64
- self.session = requests.Session()
65
- self.user_management = None
66
- self.dataset_management = None
67
- self.mass_stats = None
68
- self._aiohttp_session = None
69
-
70
- # if auth_url:
71
- # from terrakio_core.auth import AuthClient
72
- # self.auth_client = AuthClient(
73
- # base_url=auth_url,
74
- # verify=verify,
75
- # timeout=timeout
76
- # )
77
21
  self.url = url
78
- self.key = None
79
-
80
- # Load configuration from file
81
- from terrakio_core.config import read_config_file, DEFAULT_CONFIG_FILE
82
-
83
- if config_file is None:
84
- config_file = DEFAULT_CONFIG_FILE
85
-
86
- # Get config using the read_config_file function that now handles all cases
87
- config = read_config_file(config_file, quiet=self.quiet)
88
-
89
- # If URL is not provided, try to get it from config
22
+ self.key = api_key
23
+
24
+ config = read_config_file( DEFAULT_CONFIG_FILE, logger = self.logger)
90
25
  if self.url is None:
91
26
  self.url = config.get('url')
92
27
 
93
- # Get API key from config file (never from parameters)
94
28
  self.key = config.get('key')
95
29
 
96
30
  self.token = config.get('token')
97
-
98
- # Update auth_client with token if it exists
99
- if self.auth_client and self.token:
100
- self.auth_client.token = self.token
101
- self.auth_client.session.headers.update({
102
- "Authorization": self.token
103
- })
104
-
105
- # If we have a key, we're good to go even if not logged in
106
- if self.key:
107
- # If URL is missing but we have a key, use the default URL
108
- if not self.url:
109
- self.url = "https://api.terrak.io"
110
- print("the token is! ", self.token)
111
- print("the key is! ", self.key)
112
- # Update session headers with the key
113
- headers = {
114
- 'Content-Type': 'application/json',
115
- 'x-api-key': self.key
116
- }
117
-
118
- # Only add Authorization header if token exists and is not None
119
- if self.token:
120
- headers['Authorization'] = self.token
121
-
122
- self.session.headers.update(headers)
123
- print("the session headers are ", self.session.headers)
124
-
125
-
126
- if not self.quiet and config.get('user_email'):
127
- print(f"Using API key for: {config.get('user_email')}")
128
-
129
- return
130
-
131
- # Check if we have the required configuration
132
- if not self.key:
133
- # No API key available - inform the user
134
- if not self.quiet:
135
- print("API key is required to use this client.")
136
- print("You can set an API key by either:")
137
- print("1. Loading from a config file")
138
- print("2. Using the login() method: client.login(email='your-email@example.com', password='your-password')")
139
- return
140
-
141
- self.url = self.url.rstrip('/')
142
- if not self.quiet:
143
- print(f"Using Terrakio API at: {self.url}")
144
-
145
- # Update the session headers with API key
146
- self.session.headers.update({
147
- 'Content-Type': 'application/json',
148
- 'x-api-key': self.key
149
- })
150
-
151
- @property
152
- @require_api_key
153
- async def aiohttp_session(self):
154
- if self._aiohttp_session is None or self._aiohttp_session.closed:
155
- self._aiohttp_session = aiohttp.ClientSession(
156
- headers={
157
- 'Content-Type': 'application/json',
158
- 'x-api-key': self.key
159
- },
160
- timeout=aiohttp.ClientTimeout(total=self.timeout)
161
- )
162
- return self._aiohttp_session
163
-
164
- @require_api_key
165
- async def wcs_async(self, expr: str, feature: Union[Dict[str, Any], ShapelyGeometry],
166
- in_crs: str = "epsg:4326", out_crs: str = "epsg:4326",
167
- output: str = "csv", resolution: int = -1, buffer: bool = False,
168
- retry: int = 3, **kwargs):
169
- """
170
- Asynchronous version of the wcs() method using aiohttp.
171
-
172
- Args:
173
- expr (str): The WCS expression to evaluate
174
- feature (Union[Dict[str, Any], ShapelyGeometry]): The geographic feature
175
- in_crs (str): Input coordinate reference system
176
- out_crs (str): Output coordinate reference system
177
- output (str): Output format ('csv' or 'netcdf')
178
- resolution (int): Resolution parameter
179
- buffer (bool): Whether to buffer the request (default True)
180
- retry (int): Number of retry attempts (default 3)
181
- **kwargs: Additional parameters to pass to the WCS request
182
-
183
- Returns:
184
- Union[pd.DataFrame, xr.Dataset, bytes]: The response data in the requested format
185
- """
186
- if hasattr(feature, 'is_valid'):
187
- from shapely.geometry import mapping
188
- feature = {
189
- "type": "Feature",
190
- "geometry": mapping(feature),
191
- "properties": {}
192
- }
193
-
194
- payload = {
195
- "feature": feature,
196
- "in_crs": in_crs,
197
- "out_crs": out_crs,
198
- "output": output,
199
- "resolution": resolution,
200
- "expr": expr,
201
- "buffer": buffer,
202
- "resolution": resolution,
203
- **kwargs
204
- }
205
- print("the payload is ", payload)
206
- request_url = f"{self.url}/geoquery"
207
- for attempt in range(retry + 1):
208
- try:
209
- session = await self.aiohttp_session
210
- async with session.post(request_url, json=payload, ssl=self.verify) as response:
211
- if not response.ok:
212
- should_retry = False
213
- if response.status in [408, 502, 503, 504]:
214
- should_retry = True
215
- elif response.status == 500:
216
- try:
217
- response_text = await response.text()
218
- if "Internal server error" not in response_text:
219
- should_retry = True
220
- except:
221
- should_retry = True
222
-
223
- if should_retry and attempt < retry:
224
- continue
225
- else:
226
- error_msg = f"API request failed: {response.status} {response.reason}"
227
- try:
228
- error_data = await response.json()
229
- if "detail" in error_data:
230
- error_msg += f" - {error_data['detail']}"
231
- except:
232
- pass
233
- raise APIError(error_msg)
234
-
235
- content = await response.read()
236
-
237
- if output.lower() == "csv":
238
- import pandas as pd
239
- df = pd.read_csv(BytesIO(content))
240
- return df
241
- elif output.lower() == "netcdf":
242
- return xr.open_dataset(BytesIO(content))
243
- else:
244
- try:
245
- return xr.open_dataset(BytesIO(content))
246
- except ValueError:
247
- import pandas as pd
248
- try:
249
- return pd.read_csv(BytesIO(content))
250
- except:
251
- return content
252
-
253
- except aiohttp.ClientError as e:
254
- if attempt == retry:
255
- raise APIError(f"Request failed: {str(e)}")
256
- continue
257
- except Exception as e:
258
- if attempt == retry:
259
- raise
260
- continue
261
-
262
- @require_api_key
263
- async def close_async(self):
264
- """Close the aiohttp session"""
265
- if self._aiohttp_session and not self._aiohttp_session.closed:
266
- await self._aiohttp_session.close()
267
- self._aiohttp_session = None
268
-
269
- @require_api_key
270
- async def __aenter__(self):
271
- return self
272
-
273
- @require_api_key
274
- async def __aexit__(self, exc_type, exc_val, exc_tb):
275
- await self.close_async()
276
-
277
- def signup(self, email: str, password: str) -> Dict[str, Any]:
278
- if not self.auth_client:
279
- from terrakio_core.auth import AuthClient
280
- self.auth_client = AuthClient(
281
- base_url=self.url,
282
- verify=self.verify,
283
- timeout=self.timeout
284
- )
285
- self.auth_client.session = self.session
286
- return self.auth_client.signup(email, password)
287
-
288
- def login(self, email: str, password: str) -> Dict[str, str]:
289
-
290
- if not self.auth_client:
291
- from terrakio_core.auth import AuthClient
292
- self.auth_client = AuthClient(
293
- base_url=self.url,
294
- verify=self.verify,
295
- timeout=self.timeout
296
- )
297
- self.auth_client.session = self.session
298
- try:
299
- # First attempt to login
300
- token_response = self.auth_client.login(email, password)
301
-
302
- # Only proceed with API key retrieval if login was successful
303
- if token_response:
304
- # After successful login, get the API key
305
- api_key_response = self.auth_client.view_api_key()
306
- self.key = api_key_response
307
-
308
- # Make sure URL is set
309
- if not self.url:
310
- self.url = "https://api.terrak.io"
311
-
312
- # Make sure session headers are updated with the new key
313
- self.session.headers.update({
314
- 'Content-Type': 'application/json',
315
- 'x-api-key': self.key
316
- })
317
-
318
- # Set URL if not already set
319
- if not self.url:
320
- self.url = "https://api.terrak.io"
321
- self.url = self.url.rstrip('/')
322
-
323
- # Save email and API key to config file
324
- import os
325
- import json
326
- config_path = os.path.join(os.environ.get("HOME", ""), ".tkio_config.json")
327
- try:
328
- config = {"EMAIL": email, "TERRAKIO_API_KEY": self.key}
329
- if os.path.exists(config_path):
330
- with open(config_path, 'r') as f:
331
- config = json.load(f)
332
- config["EMAIL"] = email
333
- config["TERRAKIO_API_KEY"] = self.key
334
- config["PERSONAL_TOKEN"] = token_response
335
-
336
- os.makedirs(os.path.dirname(config_path), exist_ok=True)
337
- with open(config_path, 'w') as f:
338
- json.dump(config, f, indent=4)
339
-
340
- if not self.quiet:
341
- print(f"Successfully authenticated as: {email}")
342
- print(f"Using Terrakio API at: {self.url}")
343
- print(f"API key saved to {config_path}")
344
- except Exception as e:
345
- if not self.quiet:
346
- print(f"Warning: Failed to update config file: {e}")
347
-
348
- return {"token": token_response} if token_response else {"error": "Login failed"}
349
- except Exception as e:
350
- if not self.quiet:
351
- print(f"Login failed: {str(e)}")
352
- raise
353
-
354
- @require_api_key
355
- def refresh_api_key(self) -> str:
356
- # If we have auth_client and it has a token, refresh the key
357
- if self.auth_client and self.auth_client.token:
358
- self.key = self.auth_client.refresh_api_key()
359
- self.session.headers.update({'x-api-key': self.key})
360
-
361
- # Update the config file with the new key
362
- import os
363
- config_path = os.path.join(os.environ.get("HOME", ""), ".tkio_config.json")
364
- try:
365
- config = {"EMAIL": "", "TERRAKIO_API_KEY": ""}
366
- if os.path.exists(config_path):
367
- with open(config_path, 'r') as f:
368
- config = json.load(f)
369
- config["TERRAKIO_API_KEY"] = self.key
370
- os.makedirs(os.path.dirname(config_path), exist_ok=True)
371
- with open(config_path, 'w') as f:
372
- json.dump(config, f, indent=4)
373
- if not self.quiet:
374
- print(f"API key generated successfully and updated in {config_path}")
375
- except Exception as e:
376
- if not self.quiet:
377
- print(f"Warning: Failed to update config file: {e}")
378
- return self.key
379
- else:
380
- # If we don't have auth_client with a token but have a key already, return it
381
- if self.key:
382
- if not self.quiet:
383
- print("Using existing API key from config file.")
384
- return self.key
385
- else:
386
- raise ConfigurationError("No authentication token available. Please login first to refresh the API key.")
387
-
388
- @require_api_key
389
- def view_api_key(self) -> str:
390
- # If we have the auth client and token, refresh the key
391
- if self.auth_client and self.auth_client.token:
392
- self.key = self.auth_client.view_api_key()
393
- self.session.headers.update({'x-api-key': self.key})
394
-
395
- return self.key
396
-
397
- @require_api_key
398
- def view_api_key(self) -> str:
399
- # If we have the auth client and token, refresh the key
400
- if not self.auth_client:
401
- from terrakio_core.auth import AuthClient
402
- self.auth_client = AuthClient(
403
- base_url=self.url,
404
- verify=self.verify,
405
- timeout=self.timeout
406
- )
407
- self.auth_client.session = self.session
408
- return self.auth_client.view_api_key()
409
-
410
- # @require_api_key
411
- # def get_user_info(self) -> Dict[str, Any]:
412
- # return self.auth_client.get_user_info()
413
- @require_api_key
414
- def get_user_info(self) -> Dict[str, Any]:
415
- # Initialize auth_client if it doesn't exist
416
- if not self.auth_client:
417
- from terrakio_core.auth import AuthClient
418
- self.auth_client = AuthClient(
419
- base_url=self.url,
420
- verify=self.verify,
421
- timeout=self.timeout
422
- )
423
- # Use the same session as the base client
424
- self.auth_client.session = self.session
425
-
426
- return self.auth_client.get_user_info()
427
-
428
- @require_api_key
429
- def wcs(self, expr: str, feature: Union[Dict[str, Any], ShapelyGeometry], in_crs: str = "epsg:4326",
430
- out_crs: str = "epsg:4326", output: str = "csv", resolution: int = -1,
431
- **kwargs):
432
- if hasattr(feature, 'is_valid'):
433
- from shapely.geometry import mapping
434
- feature = {
435
- "type": "Feature",
436
- "geometry": mapping(feature),
437
- "properties": {}
438
- }
439
- payload = {
440
- "feature": feature,
441
- "in_crs": in_crs,
442
- "out_crs": out_crs,
443
- "output": output,
444
- "resolution": resolution,
445
- "expr": expr,
446
- **kwargs
447
- }
448
- request_url = f"{self.url}/geoquery"
449
- try:
450
- print("the request url is ", request_url)
451
- print("the payload is ", payload)
452
- response = self.session.post(request_url, json=payload, timeout=self.timeout, verify=self.verify)
453
- print("the response is ", response.text)
454
- if not response.ok:
455
- error_msg = f"API request failed: {response.status_code} {response.reason}"
456
- try:
457
- error_data = response.json()
458
- if "detail" in error_data:
459
- error_msg += f" - {error_data['detail']}"
460
- except:
461
- pass
462
- raise APIError(error_msg)
463
- if output.lower() == "csv":
464
- import pandas as pd
465
- return pd.read_csv(BytesIO(response.content))
466
- elif output.lower() == "netcdf":
467
- return xr.open_dataset(BytesIO(response.content))
468
- else:
469
- try:
470
- return xr.open_dataset(BytesIO(response.content))
471
- except ValueError:
472
- import pandas as pd
473
- try:
474
- return pd.read_csv(BytesIO(response.content))
475
- except:
476
- return response.content
477
- except requests.RequestException as e:
478
- raise APIError(f"Request failed: {str(e)}")
479
-
480
- # Admin/protected methods
481
- @require_api_key
482
- def _get_user_by_id(self, user_id: str):
483
- if not self.user_management:
484
- from terrakio_core.user_management import UserManagement
485
- if not self.url or not self.key:
486
- raise ConfigurationError("User management client not initialized. Make sure API URL and key are set.")
487
- self.user_management = UserManagement(
488
- api_url=self.url,
489
- api_key=self.key,
490
- verify=self.verify,
491
- timeout=self.timeout
492
- )
493
- return self.user_management.get_user_by_id(user_id)
494
-
495
- @require_api_key
496
- def _get_user_by_email(self, email: str):
497
- if not self.user_management:
498
- from terrakio_core.user_management import UserManagement
499
- if not self.url or not self.key:
500
- raise ConfigurationError("User management client not initialized. Make sure API URL and key are set.")
501
- self.user_management = UserManagement(
502
- api_url=self.url,
503
- api_key=self.key,
504
- verify=self.verify,
505
- timeout=self.timeout
506
- )
507
- return self.user_management.get_user_by_email(email)
508
-
509
- @require_api_key
510
- def _list_users(self, substring: str = None, uid: bool = False):
511
- if not self.user_management:
512
- from terrakio_core.user_management import UserManagement
513
- if not self.url or not self.key:
514
- raise ConfigurationError("User management client not initialized. Make sure API URL and key are set.")
515
- self.user_management = UserManagement(
516
- api_url=self.url,
517
- api_key=self.key,
518
- verify=self.verify,
519
- timeout=self.timeout
520
- )
521
- return self.user_management.list_users(substring=substring, uid=uid)
522
-
523
- @require_api_key
524
- def _edit_user(self, user_id: str, uid: str = None, email: str = None, role: str = None, apiKey: str = None, groups: list = None, quota: int = None):
525
- if not self.user_management:
526
- from terrakio_core.user_management import UserManagement
527
- if not self.url or not self.key:
528
- raise ConfigurationError("User management client not initialized. Make sure API URL and key are set.")
529
- self.user_management = UserManagement(
530
- api_url=self.url,
531
- api_key=self.key,
532
- verify=self.verify,
533
- timeout=self.timeout
534
- )
535
- return self.user_management.edit_user(
536
- user_id=user_id,
537
- uid=uid,
538
- email=email,
539
- role=role,
540
- apiKey=apiKey,
541
- groups=groups,
542
- quota=quota
543
- )
544
-
545
- @require_api_key
546
- def _reset_quota(self, email: str, quota: int = None):
547
- if not self.user_management:
548
- from terrakio_core.user_management import UserManagement
549
- if not self.url or not self.key:
550
- raise ConfigurationError("User management client not initialized. Make sure API URL and key are set.")
551
- self.user_management = UserManagement(
552
- api_url=self.url,
553
- api_key=self.key,
554
- verify=self.verify,
555
- timeout=self.timeout
556
- )
557
- return self.user_management.reset_quota(email=email, quota=quota)
558
-
559
- @require_api_key
560
- def _delete_user(self, uid: str):
561
- if not self.user_management:
562
- from terrakio_core.user_management import UserManagement
563
- if not self.url or not self.key:
564
- raise ConfigurationError("User management client not initialized. Make sure API URL and key are set.")
565
- self.user_management = UserManagement(
566
- api_url=self.url,
567
- api_key=self.key,
568
- verify=self.verify,
569
- timeout=self.timeout
570
- )
571
- return self.user_management.delete_user(uid=uid)
572
-
573
- # Dataset management protected methods
574
- @require_api_key
575
- def _get_dataset(self, name: str, collection: str = "terrakio-datasets"):
576
- if not self.dataset_management:
577
- from terrakio_core.dataset_management import DatasetManagement
578
- if not self.url or not self.key:
579
- raise ConfigurationError("Dataset management client not initialized. Make sure API URL and key are set.")
580
- self.dataset_management = DatasetManagement(
581
- api_url=self.url,
582
- api_key=self.key,
583
- verify=self.verify,
584
- timeout=self.timeout
585
- )
586
- return self.dataset_management.get_dataset(name=name, collection=collection)
587
-
588
- @require_api_key
589
- def _list_datasets(self, substring: str = None, collection: str = "terrakio-datasets"):
590
- if not self.dataset_management:
591
- from terrakio_core.dataset_management import DatasetManagement
592
- if not self.url or not self.key:
593
- raise ConfigurationError("Dataset management client not initialized. Make sure API URL and key are set.")
594
- self.dataset_management = DatasetManagement(
595
- api_url=self.url,
596
- api_key=self.key,
597
- verify=self.verify,
598
- timeout=self.timeout
599
- )
600
- return self.dataset_management.list_datasets(substring=substring, collection=collection)
601
-
602
- @require_api_key
603
- def _create_dataset(self, name: str, collection: str = "terrakio-datasets", **kwargs):
604
- if not self.dataset_management:
605
- from terrakio_core.dataset_management import DatasetManagement
606
- if not self.url or not self.key:
607
- raise ConfigurationError("Dataset management client not initialized. Make sure API URL and key are set.")
608
- self.dataset_management = DatasetManagement(
609
- api_url=self.url,
610
- api_key=self.key,
611
- verify=self.verify,
612
- timeout=self.timeout
613
- )
614
- return self.dataset_management.create_dataset(name=name, collection=collection, **kwargs)
615
-
616
- @require_api_key
617
- def _update_dataset(self, name: str, append: bool = True, collection: str = "terrakio-datasets", **kwargs):
618
- if not self.dataset_management:
619
- from terrakio_core.dataset_management import DatasetManagement
620
- if not self.url or not self.key:
621
- raise ConfigurationError("Dataset management client not initialized. Make sure API URL and key are set.")
622
- self.dataset_management = DatasetManagement(
623
- api_url=self.url,
624
- api_key=self.key,
625
- verify=self.verify,
626
- timeout=self.timeout
627
- )
628
- return self.dataset_management.update_dataset(name=name, append=append, collection=collection, **kwargs)
629
-
630
- @require_api_key
631
- def _overwrite_dataset(self, name: str, collection: str = "terrakio-datasets", **kwargs):
632
- if not self.dataset_management:
633
- from terrakio_core.dataset_management import DatasetManagement
634
- if not self.url or not self.key:
635
- raise ConfigurationError("Dataset management client not initialized. Make sure API URL and key are set.")
636
- self.dataset_management = DatasetManagement(
637
- api_url=self.url,
638
- api_key=self.key,
639
- verify=self.verify,
640
- timeout=self.timeout
641
- )
642
- return self.dataset_management.overwrite_dataset(name=name, collection=collection, **kwargs)
643
-
644
- @require_api_key
645
- def _delete_dataset(self, name: str, collection: str = "terrakio-datasets"):
646
- if not self.dataset_management:
647
- from terrakio_core.dataset_management import DatasetManagement
648
- if not self.url or not self.key:
649
- raise ConfigurationError("Dataset management client not initialized. Make sure API URL and key are set.")
650
- self.dataset_management = DatasetManagement(
651
- api_url=self.url,
652
- api_key=self.key,
653
- verify=self.verify,
654
- timeout=self.timeout
655
- )
656
- return self.dataset_management.delete_dataset(name=name, collection=collection)
657
-
658
- @require_api_key
659
- def close(self):
660
- """Close all client sessions"""
661
- self.session.close()
662
- if self.auth_client:
663
- self.auth_client.session.close()
664
- # Close aiohttp session if it exists
665
- if self._aiohttp_session and not self._aiohttp_session.closed:
666
- try:
667
- nest_asyncio.apply()
668
- asyncio.run(self.close_async())
669
- except ImportError:
670
- try:
671
- asyncio.run(self.close_async())
672
- except RuntimeError as e:
673
- if "cannot be called from a running event loop" in str(e):
674
- # In Jupyter, we can't properly close the async session
675
- # Log a warning or handle gracefully
676
- import warnings
677
- warnings.warn("Cannot properly close aiohttp session in Jupyter environment. "
678
- "Consider using 'await client.close_async()' instead.")
679
- else:
680
- raise
681
- except RuntimeError:
682
- # Event loop may already be closed, ignore
683
- pass
684
-
685
- @require_api_key
686
- def __enter__(self):
687
- return self
688
-
689
- @require_api_key
690
- def __exit__(self, exc_type, exc_val, exc_tb):
691
- self.close()
692
-
693
- @admin_only_params('location', 'force_loc', 'server')
694
- @require_api_key
695
- def execute_job(self, name, region, output, config, overwrite=False, skip_existing=False, request_json=None, manifest_json=None, location=None, force_loc=None, server="dev-au.terrak.io"):
696
- if not self.mass_stats:
697
- from terrakio_core.mass_stats import MassStats
698
- if not self.url or not self.key:
699
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
700
- self.mass_stats = MassStats(
701
- base_url=self.url,
702
- api_key=self.key,
703
- verify=self.verify,
704
- timeout=self.timeout
705
- )
706
- return self.mass_stats.execute_job(name, region, output, config, overwrite, skip_existing, request_json, manifest_json, location, force_loc, server)
707
-
708
-
709
- @require_api_key
710
- def get_mass_stats_task_id(self, name, stage, uid=None):
711
- if not self.mass_stats:
712
- from terrakio_core.mass_stats import MassStats
713
- if not self.url or not self.key:
714
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
715
- self.mass_stats = MassStats(
716
- base_url=self.url,
717
- api_key=self.key,
718
- verify=self.verify,
719
- timeout=self.timeout
720
- )
721
- return self.mass_stats.get_task_id(name, stage, uid)
722
-
723
- @require_api_key
724
- def track_mass_stats_job(self, ids: Optional[list] = None):
725
- if not self.mass_stats:
726
- from terrakio_core.mass_stats import MassStats
727
- if not self.url or not self.key:
728
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
729
- self.mass_stats = MassStats(
730
- base_url=self.url,
731
- api_key=self.key,
732
- verify=self.verify,
733
- timeout=self.timeout
734
- )
735
- return self.mass_stats.track_job(ids)
736
-
737
- @require_api_key
738
- def get_mass_stats_history(self, limit=100):
739
- if not self.mass_stats:
740
- from terrakio_core.mass_stats import MassStats
741
- if not self.url or not self.key:
742
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
743
- self.mass_stats = MassStats(
744
- base_url=self.url,
745
- api_key=self.key,
746
- verify=self.verify,
747
- timeout=self.timeout
748
- )
749
- return self.mass_stats.get_history(limit)
750
-
751
- @require_api_key
752
- def start_mass_stats_post_processing(self, process_name, data_name, output, consumer_path, overwrite=False):
753
- if not self.mass_stats:
754
- from terrakio_core.mass_stats import MassStats
755
- if not self.url or not self.key:
756
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
757
- self.mass_stats = MassStats(
758
- base_url=self.url,
759
- api_key=self.key,
760
- verify=self.verify,
761
- timeout=self.timeout
762
- )
763
- return self.mass_stats.start_post_processing(process_name, data_name, output, consumer_path, overwrite)
764
-
765
- @require_api_key
766
- def download_mass_stats_results(self, id=None, force_loc=False, **kwargs):
767
- if not self.mass_stats:
768
- from terrakio_core.mass_stats import MassStats
769
- if not self.url or not self.key:
770
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
771
- self.mass_stats = MassStats(
772
- base_url=self.url,
773
- api_key=self.key,
774
- verify=self.verify,
775
- timeout=self.timeout
776
- )
777
- return self.mass_stats.download_results(id, force_loc, **kwargs)
778
-
779
- @require_api_key
780
- def cancel_mass_stats_job(self, id):
781
- if not self.mass_stats:
782
- from terrakio_core.mass_stats import MassStats
783
- if not self.url or not self.key:
784
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
785
- self.mass_stats = MassStats(
786
- base_url=self.url,
787
- api_key=self.key,
788
- verify=self.verify,
789
- timeout=self.timeout
790
- )
791
- return self.mass_stats.cancel_job(id)
792
-
793
- @require_api_key
794
- def cancel_all_mass_stats_jobs(self):
795
- if not self.mass_stats:
796
- from terrakio_core.mass_stats import MassStats
797
- if not self.url or not self.key:
798
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
799
- self.mass_stats = MassStats(
800
- base_url=self.url,
801
- api_key=self.key,
802
- verify=self.verify,
803
- timeout=self.timeout
804
- )
805
- return self.mass_stats.cancel_all_jobs()
806
-
807
- @require_api_key
808
- def _create_pyramids(self, name, levels, config):
809
- if not self.mass_stats:
810
- from terrakio_core.mass_stats import MassStats
811
- if not self.url or not self.key:
812
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
813
- self.mass_stats = MassStats(
814
- base_url=self.url,
815
- api_key=self.key,
816
- verify=self.verify,
817
- timeout=self.timeout
818
- )
819
- return self.mass_stats.create_pyramids(name, levels, config)
820
-
821
- @require_api_key
822
- def random_sample(self, name, **kwargs):
823
- if not self.mass_stats:
824
- from terrakio_core.mass_stats import MassStats
825
- if not self.url or not self.key:
826
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
827
- self.mass_stats = MassStats(
828
- base_url=self.url,
829
- api_key=self.key,
830
- verify=self.verify,
831
- timeout=self.timeout
832
- )
833
- return self.mass_stats.random_sample(name, **kwargs)
834
-
835
- @require_api_key
836
- async def zonal_stats_async(self, gdb, expr, conc=20, inplace=False, output="csv",
837
- in_crs="epsg:4326", out_crs="epsg:4326", resolution=-1, buffer=False):
838
- """
839
- Compute zonal statistics for all geometries in a GeoDataFrame using asyncio for concurrency.
840
-
841
- Args:
842
- gdb (geopandas.GeoDataFrame): GeoDataFrame containing geometries
843
- expr (str): Terrakio expression to evaluate, can include spatial aggregations
844
- conc (int): Number of concurrent requests to make
845
- inplace (bool): Whether to modify the input GeoDataFrame in place
846
- output (str): Output format (csv or netcdf)
847
- in_crs (str): Input coordinate reference system
848
- out_crs (str): Output coordinate reference system
849
- resolution (int): Resolution parameter
850
- buffer (bool): Whether to buffer the request (default True)
851
-
852
- Returns:
853
- geopandas.GeoDataFrame: GeoDataFrame with added columns for results, or None if inplace=True
854
- """
855
- if conc > 100:
856
- raise ValueError("Concurrency (conc) is too high. Please set conc to 100 or less.")
857
-
858
- # Process geometries in batches
859
- all_results = []
860
- row_indices = []
861
-
862
- # Calculate total batches for progress reporting
863
- total_geometries = len(gdb)
864
- total_batches = (total_geometries + conc - 1) // conc # Ceiling division
865
- completed_batches = 0
866
-
867
- print(f"Processing {total_geometries} geometries with concurrency {conc}")
868
-
869
- async def process_geometry(geom, index):
870
- """Process a single geometry"""
871
- try:
872
- feature = {
873
- "type": "Feature",
874
- "geometry": mapping(geom),
875
- "properties": {"index": index}
876
- }
877
- result = await self.wcs_async(expr=expr, feature=feature, output=output,
878
- in_crs=in_crs, out_crs=out_crs, resolution=resolution, buffer=buffer)
879
- # Add original index to track which geometry this result belongs to
880
- if isinstance(result, pd.DataFrame):
881
- result['_geometry_index'] = index
882
- return result
883
- except Exception as e:
884
- raise
885
-
886
- async def process_batch(batch_indices):
887
- """Process a batch of geometries concurrently using TaskGroup"""
888
- try:
889
- async with asyncio.TaskGroup() as tg:
890
- tasks = []
891
- for idx in batch_indices:
892
- geom = gdb.geometry.iloc[idx]
893
- task = tg.create_task(process_geometry(geom, idx))
894
- tasks.append(task)
895
-
896
- # Get results from completed tasks
897
- results = []
898
- for task in tasks:
899
- try:
900
- result = task.result()
901
- results.append(result)
902
- except Exception as e:
903
- raise
904
-
905
- return results
906
- except* Exception as e:
907
- # Get the actual exceptions from the tasks
908
- for task in tasks:
909
- if task.done() and task.exception():
910
- raise task.exception()
911
- raise
912
-
913
- # Process in batches to control concurrency
914
- for i in range(0, len(gdb), conc):
915
- batch_indices = range(i, min(i + conc, len(gdb)))
916
- try:
917
- batch_results = await process_batch(batch_indices)
918
- all_results.extend(batch_results)
919
- row_indices.extend(batch_indices)
920
-
921
- # Update progress
922
- completed_batches += 1
923
- processed_geometries = min(i + conc, total_geometries)
924
- print(f"Progress: {completed_batches}/{total_batches} completed ({processed_geometries}/{total_geometries} geometries processed)")
925
-
926
- except Exception as e:
927
- if hasattr(e, 'response'):
928
- raise APIError(f"API request failed: {e.response.text}")
929
- raise
930
-
931
- print("All batches completed! Processing results...")
932
-
933
- if not all_results:
934
- raise ValueError("No valid results were returned for any geometry")
935
-
936
- # Combine all results
937
- combined_df = pd.concat(all_results, ignore_index=True)
938
-
939
- # Check if we have temporal results
940
- has_time = 'time' in combined_df.columns
941
-
942
- # Create a result GeoDataFrame
943
- if has_time:
944
- # For temporal data, we'll create a hierarchical index
945
- # First make sure we have the geometry index and time columns
946
- if '_geometry_index' not in combined_df.columns:
947
- raise ValueError("Missing geometry index in results")
948
-
949
- # Create hierarchical index on geometry_index and time
950
- combined_df.set_index(['_geometry_index', 'time'], inplace=True)
951
-
952
- # For each unique geometry index, we need the corresponding geometry
953
- geometry_series = gdb.geometry.copy()
954
-
955
- # Get columns that will become new attributes (exclude index/utility columns)
956
- result_cols = combined_df.columns
957
-
958
- # Create a new GeoDataFrame with multi-index
959
- result_rows = []
960
- geometries = []
961
-
962
- # Iterate through the hierarchical index
963
- for (geom_idx, time_val), row in combined_df.iterrows():
964
- # Create a new row with geometry properties + result columns
965
- new_row = {}
966
-
967
- # Add original GeoDataFrame columns (except geometry)
968
- for col in gdb.columns:
969
- if col != 'geometry':
970
- new_row[col] = gdb.loc[geom_idx, col]
971
-
972
- # Add result columns
973
- for col in result_cols:
974
- new_row[col] = row[col]
975
-
976
- result_rows.append(new_row)
977
- geometries.append(gdb.geometry.iloc[geom_idx])
978
-
979
- # Create a new GeoDataFrame with multi-index
980
- multi_index = pd.MultiIndex.from_tuples(
981
- combined_df.index.tolist(),
982
- names=['geometry_index', 'time']
983
- )
984
-
985
- result_gdf = gpd.GeoDataFrame(
986
- result_rows,
987
- geometry=geometries,
988
- index=multi_index
989
- )
990
-
991
- if inplace:
992
- # Can't really do inplace with multi-temporal results as we're changing the structure
993
- return result_gdf
994
- else:
995
- return result_gdf
996
- else:
997
- # Non-temporal data - just add new columns to the existing GeoDataFrame
998
- result_gdf = gdb.copy() if not inplace else gdb
999
-
1000
- # Get column names from the results (excluding utility columns)
1001
- result_cols = [col for col in combined_df.columns if col not in ['_geometry_index']]
1002
-
1003
- # Create a mapping from geometry index to result rows
1004
- geom_idx_to_row = {}
1005
- for idx, row in combined_df.iterrows():
1006
- geom_idx = int(row['_geometry_index'])
1007
- geom_idx_to_row[geom_idx] = row
1008
-
1009
- # Add results as new columns to the GeoDataFrame
1010
- for col in result_cols:
1011
- # Initialize the column with None or appropriate default
1012
- if col not in result_gdf.columns:
1013
- result_gdf[col] = None
1014
-
1015
- # Fill in values from results
1016
- for geom_idx, row in geom_idx_to_row.items():
1017
- result_gdf.loc[geom_idx, col] = row[col]
1018
-
1019
- if inplace:
1020
- return None
1021
- else:
1022
- return result_gdf
1023
31
 
1024
- @require_api_key
1025
- def zonal_stats(self, gdb, expr, conc=20, inplace=False, output="csv",
1026
- in_crs="epsg:4326", out_crs="epsg:4326", resolution=-1, buffer=False):
1027
- """
1028
- Compute zonal statistics for all geometries in a GeoDataFrame.
1029
-
1030
- Args:
1031
- gdb (geopandas.GeoDataFrame): GeoDataFrame containing geometries
1032
- expr (str): Terrakio expression to evaluate, can include spatial aggregations
1033
- conc (int): Number of concurrent requests to make
1034
- inplace (bool): Whether to modify the input GeoDataFrame in place
1035
- output (str): Output format (csv or netcdf)
1036
- in_crs (str): Input coordinate reference system
1037
- out_crs (str): Output coordinate reference system
1038
- resolution (int): Resolution parameter
1039
- buffer (bool): Whether to buffer the request (default True)
1040
-
1041
- Returns:
1042
- geopandas.GeoDataFrame: GeoDataFrame with added columns for results, or None if inplace=True
1043
- """
1044
- if conc > 100:
1045
- raise ValueError("Concurrency (conc) is too high. Please set conc to 100 or less.")
1046
- import asyncio
1047
-
1048
- print(f"Starting zonal statistics computation for expression: {expr}")
1049
-
1050
- # Check if we're in a Jupyter environment or already have an event loop
1051
- try:
1052
- loop = asyncio.get_running_loop()
1053
- # We're in an async context (like Jupyter), use create_task
1054
- nest_asyncio.apply()
1055
- result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output,
1056
- in_crs, out_crs, resolution, buffer))
1057
- except RuntimeError:
1058
- # No running event loop, safe to use asyncio.run()
1059
- result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output,
1060
- in_crs, out_crs, resolution, buffer))
1061
- except ImportError:
1062
- # nest_asyncio not available, try alternative approach
1063
- try:
1064
- loop = asyncio.get_running_loop()
1065
- # Create task in existing loop
1066
- task = loop.create_task(self.zonal_stats_async(gdb, expr, conc, inplace, output,
1067
- in_crs, out_crs, resolution, buffer))
1068
- # This won't work directly - we need a different approach
1069
- raise RuntimeError("Cannot run async code in Jupyter without nest_asyncio. Please install: pip install nest-asyncio")
1070
- except RuntimeError:
1071
- # No event loop, use asyncio.run
1072
- result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output,
1073
- in_crs, out_crs, resolution, buffer))
1074
32
 
1075
- # Ensure aiohttp session is closed after running async code
1076
- try:
1077
- if self._aiohttp_session and not self._aiohttp_session.closed:
1078
- asyncio.run(self.close_async())
1079
- except RuntimeError:
1080
- # Event loop may already be closed, ignore
1081
- pass
1082
-
1083
- print("Zonal statistics computation completed!")
1084
- return result
1085
-
1086
- # Group access management protected methods
1087
- @require_api_key
1088
- def _get_group_users_and_datasets(self, group_name: str):
1089
- if not self.group_access:
1090
- from terrakio_core.group_access_management import GroupAccessManagement
1091
- if not self.url or not self.key:
1092
- raise ConfigurationError("Group access management client not initialized. Make sure API URL and key are set.")
1093
- self.group_access = GroupAccessManagement(
1094
- api_url=self.url,
1095
- api_key=self.key,
1096
- verify=self.verify,
1097
- timeout=self.timeout
1098
- )
1099
- return self.group_access.get_group_users_and_datasets(group_name=group_name)
1100
-
1101
- @require_api_key
1102
- def _add_group_to_dataset(self, dataset: str, group: str):
1103
- if not self.group_access:
1104
- from terrakio_core.group_access_management import GroupAccessManagement
1105
- if not self.url or not self.key:
1106
- raise ConfigurationError("Group access management client not initialized. Make sure API URL and key are set.")
1107
- self.group_access = GroupAccessManagement(
1108
- api_url=self.url,
1109
- api_key=self.key,
1110
- verify=self.verify,
1111
- timeout=self.timeout
1112
- )
1113
- return self.group_access.add_group_to_dataset(dataset=dataset, group=group)
1114
-
1115
- @require_api_key
1116
- def _add_group_to_user(self, uid: str, group: str):
1117
- if not self.group_access:
1118
- from terrakio_core.group_access_management import GroupAccessManagement
1119
- if not self.url or not self.key:
1120
- raise ConfigurationError("Group access management client not initialized. Make sure API URL and key are set.")
1121
- self.group_access = GroupAccessManagement(
1122
- api_url=self.url,
1123
- api_key=self.key,
1124
- verify=self.verify,
1125
- timeout=self.timeout
1126
- )
1127
- return self.group_access.add_group_to_user(uid=uid, group=group)
1128
-
1129
- @require_api_key
1130
- def _delete_group_from_user(self, uid: str, group: str):
1131
- if not self.group_access:
1132
- from terrakio_core.group_access_management import GroupAccessManagement
1133
- if not self.url or not self.key:
1134
- raise ConfigurationError("Group access management client not initialized. Make sure API URL and key are set.")
1135
- self.group_access = GroupAccessManagement(
1136
- api_url=self.url,
1137
- api_key=self.key,
1138
- verify=self.verify,
1139
- timeout=self.timeout
1140
- )
1141
- return self.group_access.delete_group_from_user(uid=uid, group=group)
1142
-
1143
- @require_api_key
1144
- def _delete_group_from_dataset(self, dataset: str, group: str):
1145
- if not self.group_access:
1146
- from terrakio_core.group_access_management import GroupAccessManagement
1147
- if not self.url or not self.key:
1148
- raise ConfigurationError("Group access management client not initialized. Make sure API URL and key are set.")
1149
- self.group_access = GroupAccessManagement(
1150
- api_url=self.url,
1151
- api_key=self.key,
1152
- verify=self.verify,
1153
- timeout=self.timeout
1154
- )
1155
- return self.group_access.delete_group_from_dataset(dataset=dataset, group=group)
1156
-
1157
- # Space management protected methods
1158
- @require_api_key
1159
- def _get_total_space_used(self):
1160
- if not self.space_management:
1161
- from terrakio_core.space_management import SpaceManagement
1162
- if not self.url or not self.key:
1163
- raise ConfigurationError("Space management client not initialized. Make sure API URL and key are set.")
1164
- self.space_management = SpaceManagement(
1165
- api_url=self.url,
1166
- api_key=self.key,
1167
- verify=self.verify,
1168
- timeout=self.timeout
1169
- )
1170
- return self.space_management.get_total_space_used()
1171
-
1172
- @require_api_key
1173
- def _get_space_used_by_job(self, name: str, region: str = None):
1174
- if not self.space_management:
1175
- from terrakio_core.space_management import SpaceManagement
1176
- if not self.url or not self.key:
1177
- raise ConfigurationError("Space management client not initialized. Make sure API URL and key are set.")
1178
- self.space_management = SpaceManagement(
1179
- api_url=self.url,
1180
- api_key=self.key,
1181
- verify=self.verify,
1182
- timeout=self.timeout
1183
- )
1184
- return self.space_management.get_space_used_by_job(name, region)
1185
-
1186
- @require_api_key
1187
- def _delete_user_job(self, name: str, region: str = None):
1188
- if not self.space_management:
1189
- from terrakio_core.space_management import SpaceManagement
1190
- if not self.url or not self.key:
1191
- raise ConfigurationError("Space management client not initialized. Make sure API URL and key are set.")
1192
- self.space_management = SpaceManagement(
1193
- api_url=self.url,
1194
- api_key=self.key,
1195
- verify=self.verify,
1196
- timeout=self.timeout
1197
- )
1198
- return self.space_management.delete_user_job(name, region)
1199
-
1200
- @require_api_key
1201
- def _delete_data_in_path(self, path: str, region: str = None):
1202
- if not self.space_management:
1203
- from terrakio_core.space_management import SpaceManagement
1204
- if not self.url or not self.key:
1205
- raise ConfigurationError("Space management client not initialized. Make sure API URL and key are set.")
1206
- self.space_management = SpaceManagement(
1207
- api_url=self.url,
1208
- api_key=self.key,
1209
- verify=self.verify,
1210
- timeout=self.timeout
1211
- )
1212
- return self.space_management.delete_data_in_path(path, region)
1213
-
1214
- @require_api_key
1215
- def start_mass_stats_job(self, task_id):
1216
- if not self.mass_stats:
1217
- from terrakio_core.mass_stats import MassStats
1218
- if not self.url or not self.key:
1219
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
1220
- self.mass_stats = MassStats(
1221
- base_url=self.url,
1222
- api_key=self.key,
1223
- verify=self.verify,
1224
- timeout=self.timeout
1225
- )
1226
- return self.mass_stats.start_job(task_id)
1227
-
1228
-
1229
- @require_api_key
1230
- def generate_ai_dataset(
1231
- self,
1232
- name: str,
1233
- aoi_geojson: str,
1234
- expression_x: str,
1235
- expression_y: str,
1236
- samples: int,
1237
- tile_size: int,
1238
- crs: str = "epsg:4326",
1239
- res: float = 0.001,
1240
- region: str = "aus",
1241
- start_year: int = None,
1242
- end_year: int = None,
1243
- ) -> dict:
1244
- """
1245
- Generate an AI dataset using specified parameters.
1246
-
1247
- Args:
1248
- name (str): Name of the dataset to generate
1249
- aoi_geojson (str): Path to GeoJSON file containing area of interest
1250
- expression_x (str): Expression for X variable (e.g. "MSWX.air_temperature@(year=2021, month=1)")
1251
- expression_y (str): Expression for Y variable with {year} placeholder
1252
- samples (int): Number of samples to generate
1253
- tile_size (int): Size of tiles in degrees
1254
- crs (str, optional): Coordinate reference system. Defaults to "epsg:4326"
1255
- res (float, optional): Resolution in degrees. Defaults to 0.001
1256
- region (str, optional): Region code. Defaults to "aus"
1257
- start_year (int, optional): Start year for data generation. Required if end_year provided
1258
- end_year (int, optional): End year for data generation. Required if start_year provided
1259
- overwrite (bool, optional): Whether to overwrite existing dataset. Defaults to False
1260
-
1261
- Returns:
1262
- dict: Response from the AI dataset generation API
1263
-
1264
- Raises:
1265
- ValidationError: If required parameters are missing or invalid
1266
- APIError: If the API request fails
1267
- """
1268
-
1269
- # we have the parameters, let pass the parameters to the random sample function
1270
- # task_id = self.random_sample(name, aoi_geojson, expression_x, expression_y, samples, tile_size, crs, res, region, start_year, end_year, overwrite)
1271
- config = {
1272
- "expressions" : [{"expr": expression_x, "res": res, "prefix": "x"}],
1273
- "filters" : []
1274
- }
1275
- config["expressions"].append({"expr": expression_y, "res" : res, "prefix": "y"})
1276
-
1277
- expression_x = expression_x.replace("{year}", str(start_year))
1278
- expression_y = expression_y.replace("{year}", str(start_year))
1279
- print("the aoi geojson is ", aoi_geojson)
1280
- with open(aoi_geojson, 'r') as f:
1281
- aoi_data = json.load(f)
1282
- print("the config is ", config)
1283
- task_id = self.random_sample(
1284
- name=name,
1285
- config=config,
1286
- aoi=aoi_data,
1287
- samples=samples,
1288
- year_range=[start_year, end_year],
1289
- crs=crs,
1290
- tile_size=tile_size,
1291
- res=res,
1292
- region=region,
1293
- output="netcdf",
1294
- server=self.url,
1295
- bucket="terrakio-mass-requests",
1296
- overwrite=True
1297
- )["task_id"]
1298
- print("the task id is ", task_id)
1299
-
1300
- # Wait for job completion
1301
- import time
1302
-
1303
- while True:
1304
- result = self.track_mass_stats_job(ids=[task_id])
1305
- status = result[task_id]['status']
1306
- print(f"Job status: {status}")
1307
-
1308
- if status == "Completed":
1309
- break
1310
- elif status == "Error":
1311
- raise Exception(f"Job {task_id} encountered an error")
1312
-
1313
- # Wait 30 seconds before checking again
1314
- time.sleep(30)
1315
-
1316
- # print("the result is ", result)
1317
- # after all the random sample jos are done, we then start the mass stats job
1318
- task_id = self.start_mass_stats_job(task_id)
1319
- # now we hav ethe random sampel
1320
-
1321
- # print("the task id is ", task_id)
1322
- return task_id
1323
-
1324
- @require_api_key
1325
- def train_model(self, model_name: str, training_dataset: str, task_type: str, model_category: str, architecture: str, region: str, hyperparameters: dict = None) -> dict:
1326
- """
1327
- Train a model using the external model training API.
1328
-
1329
- Args:
1330
- model_name (str): The name of the model to train.
1331
- training_dataset (str): The training dataset identifier.
1332
- task_type (str): The type of ML task (e.g., regression, classification).
1333
- model_category (str): The category of model (e.g., random_forest).
1334
- architecture (str): The model architecture.
1335
- region (str): The region identifier.
1336
- hyperparameters (dict, optional): Additional hyperparameters for training.
1337
-
1338
- Returns:
1339
- dict: The response from the model training API.
1340
- """
1341
- payload = {
1342
- "model_name": model_name,
1343
- "training_dataset": training_dataset,
1344
- "task_type": task_type,
1345
- "model_category": model_category,
1346
- "architecture": architecture,
1347
- "region": region,
1348
- "hyperparameters": hyperparameters
1349
- }
1350
- endpoint = f"{self.url.rstrip('/')}/train_model"
1351
- print("the payload is ", payload)
1352
- try:
1353
- response = self.session.post(endpoint, json=payload, timeout=self.timeout, verify=self.verify)
1354
- if not response.ok:
1355
- error_msg = f"Model training request failed: {response.status_code} {response.reason}"
1356
- try:
1357
- error_data = response.json()
1358
- if "detail" in error_data:
1359
- error_msg += f" - {error_data['detail']}"
1360
- except Exception:
1361
- if response.text:
1362
- error_msg += f" - {response.text}"
1363
- raise APIError(error_msg)
1364
- return response.json()
1365
- except requests.RequestException as e:
1366
- raise APIError(f"Model training request failed: {str(e)}")
1367
-
1368
- # Mass Stats methods
1369
- @require_api_key
1370
- def combine_tiles(self,
1371
- data_name: str,
1372
- usezarr: bool,
1373
- overwrite: bool,
1374
- output : str) -> dict:
1375
-
1376
- if not self.mass_stats:
1377
- from terrakio_core.mass_stats import MassStats
1378
- if not self.url or not self.key:
1379
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
1380
- self.mass_stats = MassStats(
1381
- base_url=self.url,
1382
- api_key=self.key,
1383
- verify=self.verify,
1384
- timeout=self.timeout
1385
- )
1386
- return self.mass_stats.combine_tiles(data_name, usezarr, overwrite, output)
1387
-
1388
-
1389
-
1390
- @require_api_key
1391
- def create_dataset_file(
1392
- self,
1393
- name: str,
1394
- aoi: str,
1395
- expression: str,
1396
- output: str,
1397
- tile_size: float = 128.0,
1398
- crs: str = "epsg:4326",
1399
- res: float = 0.0001,
1400
- region: str = "aus",
1401
- to_crs: str = "epsg:4326",
1402
- overwrite: bool = True,
1403
- skip_existing: bool = False,
1404
- non_interactive: bool = True,
1405
- usezarr: bool = False,
1406
- poll_interval: int = 30,
1407
- download_path: str = "/home/user/Downloads",
1408
- ) -> dict:
1409
-
1410
- if not self.mass_stats:
1411
- from terrakio_core.mass_stats import MassStats
1412
- if not self.url or not self.key:
1413
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
1414
- self.mass_stats = MassStats(
1415
- base_url=self.url,
1416
- api_key=self.key,
1417
- verify=self.verify,
1418
- timeout=self.timeout
1419
- )
1420
-
1421
-
1422
- from terrakio_core.generation.tiles import tiles
1423
- import tempfile
1424
- import time
1425
-
1426
- body, reqs, groups = tiles(
1427
- name = name,
1428
- aoi = aoi,
1429
- expression = expression,
1430
- output = output,
1431
- tile_size = tile_size,
1432
- crs = crs,
1433
- res = res,
1434
- region = region,
1435
- to_crs = to_crs,
1436
- fully_cover = True,
1437
- overwrite = overwrite,
1438
- skip_existing = skip_existing,
1439
- non_interactive = non_interactive
1440
- )
1441
-
1442
- # Create temp json files before upload
1443
- with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tempreq:
1444
- tempreq.write(reqs)
1445
- tempreqname = tempreq.name
1446
- with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tempmanifest:
1447
- tempmanifest.write(groups)
1448
- tempmanifestname = tempmanifest.name
1449
-
1450
- if not self.mass_stats:
1451
- from terrakio_core.mass_stats import MassStats
1452
- if not self.url or not self.key:
1453
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
1454
- self.mass_stats = MassStats(
1455
- base_url=self.url,
1456
- api_key=self.key,
1457
- verify=self.verify,
1458
- timeout=self.timeout
1459
- )
1460
-
1461
- task_id = self.mass_stats.execute_job(
1462
- name=body["name"],
1463
- region=body["region"],
1464
- output=body["output"],
1465
- config = {},
1466
- overwrite=body["overwrite"],
1467
- skip_existing=body["skip_existing"],
1468
- request_json=tempreqname,
1469
- manifest_json=tempmanifestname,
1470
- )
1471
-
1472
- ### Start combining tiles when generation-tiles job is done
1473
- start_time = time.time()
1474
- status = None
1475
-
1476
- while True:
1477
- try:
1478
- taskid = task_id['task_id']
1479
- trackinfo = self.mass_stats.track_job([taskid])
1480
- status = trackinfo[taskid]['status']
1481
-
1482
- # Check completion states
1483
- if status == 'Completed':
1484
- print('Tiles generated successfully!')
1485
- break
1486
- elif status in ['Failed', 'Cancelled', 'Error']:
1487
- raise RuntimeError(f"Job {taskid} failed with status: {status}")
1488
- else:
1489
- # Job is still running
1490
- elapsed_time = time.time() - start_time
1491
- print(f"Job status: {status} - Elapsed time: {elapsed_time:.1f}s", end='\r')
1492
-
1493
- # Sleep before next check
1494
- time.sleep(poll_interval)
1495
-
1496
-
1497
- except KeyboardInterrupt:
1498
- print(f"\nInterrupted! Job {taskid} is still running in the background.")
1499
- raise
1500
- except Exception as e:
1501
- print(f"\nError tracking job: {e}")
1502
- raise
1503
-
1504
- # Clean up temporary files
1505
- import os
1506
- os.unlink(tempreqname)
1507
- os.unlink(tempmanifestname)
1508
-
1509
-
1510
- # return self.mass_stats.combine_tiles(body["name"], usezarr, body["overwrite"], body["output"])
1511
-
1512
- # Start combining tiles
1513
- combine_result = self.mass_stats.combine_tiles(body["name"], usezarr, body["overwrite"], body["output"])
1514
- combine_task_id = combine_result.get("task_id")
1515
-
1516
- # Poll combine_tiles job status
1517
- combine_start_time = time.time()
1518
- while True:
1519
- try:
1520
- trackinfo = self.mass_stats.track_job([combine_task_id])
1521
- download_file_name = trackinfo[combine_task_id]['folder'] + '.nc'
1522
- bucket = trackinfo[combine_task_id]['bucket']
1523
- combine_status = trackinfo[combine_task_id]['status']
1524
- if combine_status == 'Completed':
1525
- print('Tiles combined successfully!')
1526
- break
1527
- elif combine_status in ['Failed', 'Cancelled', 'Error']:
1528
- raise RuntimeError(f"Combine job {combine_task_id} failed with status: {combine_status}")
1529
- else:
1530
- elapsed_time = time.time() - combine_start_time
1531
- print(f"Combine job status: {combine_status} - Elapsed time: {elapsed_time:.1f}s", end='\r')
1532
- time.sleep(poll_interval)
1533
- except KeyboardInterrupt:
1534
- print(f"\nInterrupted! Combine job {combine_task_id} is still running in the background.")
1535
- raise
1536
- except Exception as e:
1537
- print(f"\nError tracking combine job: {e}")
1538
- raise
1539
-
1540
- # Download the resulting file
1541
- if download_path:
1542
- self.mass_stats.download_file(body["name"], bucket, download_file_name, download_path)
1543
- else:
1544
- path = f"{body['name']}/combinedOutput/{download_file_name}"
1545
- print(f"Combined file is available at {path}")
1546
-
1547
- return {"generation_task_id": task_id, "combine_task_id": combine_task_id}
1548
-
1549
-
1550
- # taskid = self.mass_stats.get_task_id(job_name, stage).get('task_id')
1551
- # trackinfo = self.mass_stats.track_job([taskid])
1552
- # bucket = trackinfo[taskid]['bucket']
1553
- # return self.mass_stats.download_file(job_name, bucket, file_name, output_path)
1554
-
1555
-
1556
- @require_api_key
1557
- def deploy_model(self, dataset: str, product:str, model_name:str, input_expression: str, model_training_job_name: str, uid: str, dates_iso8601: list):
1558
- script_content = self._generate_script(model_name, product, model_training_job_name, uid)
1559
- script_name = f"{product}.py"
1560
- self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
1561
- self._create_dataset(name = dataset, collection = "terrakio-datasets", products = [product], path = f"gs://terrakio-mass-requests/{uid}/{model_training_job_name}/inference_scripts", input = input_expression, dates_iso8601 = dates_iso8601, padding = 0)
1562
-
1563
- @require_api_key
1564
- def _generate_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
1565
- return textwrap.dedent(f'''
1566
- import logging
1567
- from io import BytesIO
1568
-
1569
- import numpy as np
1570
- import pandas as pd
1571
- import xarray as xr
1572
- from google.cloud import storage
1573
- from onnxruntime import InferenceSession
1574
-
1575
- logging.basicConfig(
1576
- level=logging.INFO
1577
- )
1578
-
1579
- def get_model():
1580
- logging.info("Loading model for {model_name}...")
1581
-
1582
- client = storage.Client()
1583
- bucket = client.get_bucket('terrakio-mass-requests')
1584
- blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')
1585
-
1586
- model = BytesIO()
1587
- blob.download_to_file(model)
1588
- model.seek(0)
1589
-
1590
- session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
1591
- return session
1592
-
1593
- def {product}(*bands, model):
1594
- logging.info("start preparing data")
1595
- print("the bands are ", bands)
1596
-
1597
- data_arrays = list(bands)
1598
-
1599
- print("the data arrays are ", [da.name for da in data_arrays])
1600
-
1601
- reference_array = data_arrays[0]
1602
- original_shape = reference_array.shape
1603
- logging.info(f"Original shape: {{original_shape}}")
1604
-
1605
- if 'time' in reference_array.dims:
1606
- time_coords = reference_array.coords['time']
1607
- if len(time_coords) == 1:
1608
- output_timestamp = time_coords[0]
1609
- else:
1610
- years = [pd.to_datetime(t).year for t in time_coords.values]
1611
- unique_years = set(years)
1612
-
1613
- if len(unique_years) == 1:
1614
- year = list(unique_years)[0]
1615
- output_timestamp = pd.Timestamp(f"{{year}}-01-01")
1616
- else:
1617
- latest_year = max(unique_years)
1618
- output_timestamp = pd.Timestamp(f"{{latest_year}}-01-01")
1619
- else:
1620
- output_timestamp = pd.Timestamp("1970-01-01")
1621
-
1622
- averaged_bands = []
1623
- for data_array in data_arrays:
1624
- if 'time' in data_array.dims:
1625
- averaged_band = np.mean(data_array.values, axis=0)
1626
- logging.info(f"Averaged band from {{data_array.shape}} to {{averaged_band.shape}}")
1627
- else:
1628
- averaged_band = data_array.values
1629
- logging.info(f"No time dimension, shape: {{averaged_band.shape}}")
1630
-
1631
- flattened_band = averaged_band.reshape(-1, 1)
1632
- averaged_bands.append(flattened_band)
1633
-
1634
- input_data = np.hstack(averaged_bands)
1635
-
1636
- logging.info(f"Final input shape: {{input_data.shape}}")
1637
-
1638
- output = model.run(None, {{"float_input": input_data.astype(np.float32)}})[0]
1639
-
1640
- logging.info(f"Model output shape: {{output.shape}}")
1641
-
1642
- if len(original_shape) >= 3:
1643
- spatial_shape = original_shape[1:]
1644
- else:
1645
- spatial_shape = original_shape
1646
-
1647
- output_reshaped = output.reshape(spatial_shape)
1648
-
1649
- output_with_time = np.expand_dims(output_reshaped, axis=0)
1650
-
1651
- if 'time' in reference_array.dims:
1652
- spatial_dims = [dim for dim in reference_array.dims if dim != 'time']
1653
- spatial_coords = {{dim: reference_array.coords[dim] for dim in spatial_dims if dim in reference_array.coords}}
1654
- else:
1655
- spatial_dims = list(reference_array.dims)
1656
- spatial_coords = dict(reference_array.coords)
1657
-
1658
- result = xr.DataArray(
1659
- data=output_with_time.astype(np.float32),
1660
- dims=['time'] + list(spatial_dims),
1661
- coords={
1662
- 'time': [output_timestamp.values],
1663
- 'y': spatial_coords['y'].values,
1664
- 'x': spatial_coords['x'].values
1665
- }
1666
- )
1667
- return result
1668
- ''').strip()
1669
-
1670
- @require_api_key
1671
- def _upload_script_to_bucket(self, script_content: str, script_name: str, model_training_job_name: str, uid: str):
1672
- """Upload the generated script to Google Cloud Storage"""
1673
-
1674
- client = storage.Client()
1675
- bucket = client.get_bucket('terrakio-mass-requests')
1676
- blob = bucket.blob(f'{uid}/{model_training_job_name}/inference_scripts/{script_name}')
1677
- blob.upload_from_string(script_content, content_type='text/plain')
1678
- logging.info(f"Script uploaded successfully to {uid}/{model_training_job_name}/inference_scripts/{script_name}")
1679
-
1680
-
1681
-
1682
-
1683
- @require_api_key
1684
- def download_file_to_path(self, job_name, stage, file_name, output_path):
1685
- if not self.mass_stats:
1686
- from terrakio_core.mass_stats import MassStats
1687
- if not self.url or not self.key:
1688
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
1689
- self.mass_stats = MassStats(
1690
- base_url=self.url,
1691
- api_key=self.key,
1692
- verify=self.verify,
1693
- timeout=self.timeout
1694
- )
1695
-
1696
- # fetch bucket info based on job name and stage
1697
-
1698
- taskid = self.mass_stats.get_task_id(job_name, stage).get('task_id')
1699
- trackinfo = self.mass_stats.track_job([taskid])
1700
- bucket = trackinfo[taskid]['bucket']
1701
- return self.mass_stats.download_file(job_name, bucket, file_name, output_path)
1702
-
1703
-
1704
- # Apply the @require_api_key decorator to ALL methods in BaseClient
1705
- # except only the absolute minimum that shouldn't require auth
1706
- def _apply_api_key_decorator():
1707
- # Only these methods can be used without API key
1708
- skip_auth_methods = [
1709
- '__init__', 'login', 'signup'
1710
- ]
1711
-
1712
- # Get all attributes of BaseClient
1713
- for attr_name in dir(BaseClient):
1714
- # Skip special methods and methods in skip list
1715
- if attr_name.startswith('__') and attr_name not in ['__enter__', '__exit__', '__aenter__', '__aexit__'] or attr_name in skip_auth_methods:
1716
- continue
1717
-
1718
- # Get the attribute
1719
- attr = getattr(BaseClient, attr_name)
1720
-
1721
- # Skip if not callable (not a method) or already decorated
1722
- if not callable(attr) or hasattr(attr, '_is_decorated'):
1723
- continue
1724
-
1725
- # Apply decorator to EVERY method not in skip list
1726
- setattr(BaseClient, attr_name, require_api_key(attr))
1727
- # Mark as decorated to avoid double decoration
1728
- getattr(BaseClient, attr_name)._is_decorated = True
1729
-
1730
- # Run the decorator application
1731
- _apply_api_key_decorator()
33
+ @abstractmethod
34
+ def _setup_session(self):
35
+ """Initialize the HTTP session - implemented by sync/async clients"""
36
+ pass