terrakio-core 0.3.4__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of terrakio-core might be problematic; see the registry's advisory page for more details.

terrakio_core/client.py CHANGED
@@ -1,1727 +1,36 @@
1
- import json
2
- import asyncio
3
- from io import BytesIO
4
- from typing import Dict, Any, Optional, Union
5
- from functools import wraps
6
-
7
- import requests
8
- import aiohttp
9
- import pandas as pd
10
- import geopandas as gpd
11
- import xarray as xr
12
- import nest_asyncio
13
- from shapely.geometry import shape, mapping
14
- from shapely.geometry.base import BaseGeometry as ShapelyGeometry
15
- from google.cloud import storage
16
- from .exceptions import APIError, ConfigurationError
17
- from .decorators import admin_only_params
1
+ from typing import Optional
18
2
  import logging
19
- import textwrap
3
+ from terrakio_core.config import read_config_file, DEFAULT_CONFIG_FILE
4
+ from abc import abstractmethod
20
5
 
6
+ class BaseClient():
7
+ def __init__(self, url: Optional[str] = None, api_key: Optional[str] = None, verbose: bool = False):
21
8
 
22
- def require_api_key(func):
23
- """
24
- Decorator to ensure an API key is available before a method can be executed.
25
- Will check for the presence of an API key before allowing the function to be called.
26
- """
27
- @wraps(func)
28
- def wrapper(self, *args, **kwargs):
29
- # Check if the API key is set
30
- if not self.key:
31
- error_msg = "No API key found. Please provide an API key to use this client."
32
- if not self.quiet:
33
- print(error_msg)
34
- print("API key is required for this function.")
35
- print("You can set an API key by either:")
36
- print("1. Loading from a config file")
37
- print("2. Using the login() method: client.login(email='your-email@example.com', password='your-password')")
38
- raise ConfigurationError(error_msg)
39
-
40
- # Check if URL is set (required for API calls)
41
- if not self.url:
42
- if not self.quiet:
43
- print("API URL is not set. Using default API URL: https://api.terrak.io")
44
- self.url = "https://api.terrak.io"
45
-
46
- return func(self, *args, **kwargs)
47
-
48
- # Mark as decorated to avoid double decoration
49
- wrapper._is_decorated = True
50
- return wrapper
9
+ self.verbose = verbose
10
+ self.logger = logging.getLogger("terrakio")
11
+ if verbose:
12
+ self.logger.setLevel(logging.INFO)
13
+ else:
14
+ self.logger.setLevel(logging.WARNING)
51
15
 
16
+ self.timeout = 300
17
+ self.retry = 3
18
+
19
+ self.session = None
52
20
 
53
- class BaseClient:
54
- def __init__(self, url: Optional[str] = None,
55
- auth_url: Optional[str] = "https://dev-au.terrak.io",
56
- quiet: bool = False, config_file: Optional[str] = None,
57
- verify: bool = True, timeout: int = 300):
58
- nest_asyncio.apply()
59
- self.quiet = quiet
60
- self.verify = verify
61
- self.timeout = timeout
62
- self.auth_client = None
63
- # Initialize session early to avoid AttributeError
64
- self.session = requests.Session()
65
- self.user_management = None
66
- self.dataset_management = None
67
- self.mass_stats = None
68
- self._aiohttp_session = None
69
-
70
- # if auth_url:
71
- # from terrakio_core.auth import AuthClient
72
- # self.auth_client = AuthClient(
73
- # base_url=auth_url,
74
- # verify=verify,
75
- # timeout=timeout
76
- # )
77
21
  self.url = url
78
- self.key = None
79
-
80
- # Load configuration from file
81
- from terrakio_core.config import read_config_file, DEFAULT_CONFIG_FILE
82
-
83
- if config_file is None:
84
- config_file = DEFAULT_CONFIG_FILE
85
-
86
- # Get config using the read_config_file function that now handles all cases
87
- config = read_config_file(config_file, quiet=self.quiet)
88
-
89
- # If URL is not provided, try to get it from config
22
+ self.key = api_key
23
+
24
+ config = read_config_file( DEFAULT_CONFIG_FILE, logger = self.logger)
90
25
  if self.url is None:
91
26
  self.url = config.get('url')
92
27
 
93
- # Get API key from config file (never from parameters)
94
28
  self.key = config.get('key')
95
29
 
96
30
  self.token = config.get('token')
97
-
98
- # Update auth_client with token if it exists
99
- if self.auth_client and self.token:
100
- self.auth_client.token = self.token
101
- self.auth_client.session.headers.update({
102
- "Authorization": self.token
103
- })
104
-
105
- # If we have a key, we're good to go even if not logged in
106
- if self.key:
107
- # If URL is missing but we have a key, use the default URL
108
- if not self.url:
109
- self.url = "https://api.terrak.io"
110
- print("the token is! ", self.token)
111
- print("the key is! ", self.key)
112
- # Update session headers with the key
113
- headers = {
114
- 'Content-Type': 'application/json',
115
- 'x-api-key': self.key
116
- }
117
-
118
- # Only add Authorization header if token exists and is not None
119
- if self.token:
120
- headers['Authorization'] = self.token
121
-
122
- self.session.headers.update(headers)
123
- print("the session headers are ", self.session.headers)
124
-
125
-
126
- if not self.quiet and config.get('user_email'):
127
- print(f"Using API key for: {config.get('user_email')}")
128
-
129
- return
130
-
131
- # Check if we have the required configuration
132
- if not self.key:
133
- # No API key available - inform the user
134
- if not self.quiet:
135
- print("API key is required to use this client.")
136
- print("You can set an API key by either:")
137
- print("1. Loading from a config file")
138
- print("2. Using the login() method: client.login(email='your-email@example.com', password='your-password')")
139
- return
140
-
141
- self.url = self.url.rstrip('/')
142
- if not self.quiet:
143
- print(f"Using Terrakio API at: {self.url}")
144
-
145
- # Update the session headers with API key
146
- self.session.headers.update({
147
- 'Content-Type': 'application/json',
148
- 'x-api-key': self.key
149
- })
150
-
151
- @property
152
- @require_api_key
153
- async def aiohttp_session(self):
154
- if self._aiohttp_session is None or self._aiohttp_session.closed:
155
- self._aiohttp_session = aiohttp.ClientSession(
156
- headers={
157
- 'Content-Type': 'application/json',
158
- 'x-api-key': self.key
159
- },
160
- timeout=aiohttp.ClientTimeout(total=self.timeout)
161
- )
162
- return self._aiohttp_session
163
-
164
- @require_api_key
165
- async def wcs_async(self, expr: str, feature: Union[Dict[str, Any], ShapelyGeometry],
166
- in_crs: str = "epsg:4326", out_crs: str = "epsg:4326",
167
- output: str = "csv", resolution: int = -1, buffer: bool = False,
168
- retry: int = 3, **kwargs):
169
- """
170
- Asynchronous version of the wcs() method using aiohttp.
171
-
172
- Args:
173
- expr (str): The WCS expression to evaluate
174
- feature (Union[Dict[str, Any], ShapelyGeometry]): The geographic feature
175
- in_crs (str): Input coordinate reference system
176
- out_crs (str): Output coordinate reference system
177
- output (str): Output format ('csv' or 'netcdf')
178
- resolution (int): Resolution parameter
179
- buffer (bool): Whether to buffer the request (default True)
180
- retry (int): Number of retry attempts (default 3)
181
- **kwargs: Additional parameters to pass to the WCS request
182
-
183
- Returns:
184
- Union[pd.DataFrame, xr.Dataset, bytes]: The response data in the requested format
185
- """
186
- if hasattr(feature, 'is_valid'):
187
- from shapely.geometry import mapping
188
- feature = {
189
- "type": "Feature",
190
- "geometry": mapping(feature),
191
- "properties": {}
192
- }
193
-
194
- payload = {
195
- "feature": feature,
196
- "in_crs": in_crs,
197
- "out_crs": out_crs,
198
- "output": output,
199
- "resolution": resolution,
200
- "expr": expr,
201
- "buffer": buffer,
202
- "resolution": resolution,
203
- **kwargs
204
- }
205
- request_url = f"{self.url}/geoquery"
206
- for attempt in range(retry + 1):
207
- try:
208
- session = await self.aiohttp_session
209
- async with session.post(request_url, json=payload, ssl=self.verify) as response:
210
- if not response.ok:
211
- should_retry = False
212
- if response.status in [408, 502, 503, 504]:
213
- should_retry = True
214
- elif response.status == 500:
215
- try:
216
- response_text = await response.text()
217
- if "Internal server error" not in response_text:
218
- should_retry = True
219
- except:
220
- should_retry = True
221
-
222
- if should_retry and attempt < retry:
223
- continue
224
- else:
225
- error_msg = f"API request failed: {response.status} {response.reason}"
226
- try:
227
- error_data = await response.json()
228
- if "detail" in error_data:
229
- error_msg += f" - {error_data['detail']}"
230
- except:
231
- pass
232
- raise APIError(error_msg)
233
-
234
- content = await response.read()
235
-
236
- if output.lower() == "csv":
237
- import pandas as pd
238
- df = pd.read_csv(BytesIO(content))
239
- return df
240
- elif output.lower() == "netcdf":
241
- return xr.open_dataset(BytesIO(content))
242
- else:
243
- try:
244
- return xr.open_dataset(BytesIO(content))
245
- except ValueError:
246
- import pandas as pd
247
- try:
248
- return pd.read_csv(BytesIO(content))
249
- except:
250
- return content
251
-
252
- except aiohttp.ClientError as e:
253
- if attempt == retry:
254
- raise APIError(f"Request failed: {str(e)}")
255
- continue
256
- except Exception as e:
257
- if attempt == retry:
258
- raise
259
- continue
260
-
261
- @require_api_key
262
- async def close_async(self):
263
- """Close the aiohttp session"""
264
- if self._aiohttp_session and not self._aiohttp_session.closed:
265
- await self._aiohttp_session.close()
266
- self._aiohttp_session = None
267
-
268
- @require_api_key
269
- async def __aenter__(self):
270
- return self
271
-
272
- @require_api_key
273
- async def __aexit__(self, exc_type, exc_val, exc_tb):
274
- await self.close_async()
275
-
276
- def signup(self, email: str, password: str) -> Dict[str, Any]:
277
- if not self.auth_client:
278
- from terrakio_core.auth import AuthClient
279
- self.auth_client = AuthClient(
280
- base_url=self.url,
281
- verify=self.verify,
282
- timeout=self.timeout
283
- )
284
- self.auth_client.session = self.session
285
- return self.auth_client.signup(email, password)
286
-
287
- def login(self, email: str, password: str) -> Dict[str, str]:
288
-
289
- if not self.auth_client:
290
- from terrakio_core.auth import AuthClient
291
- self.auth_client = AuthClient(
292
- base_url=self.url,
293
- verify=self.verify,
294
- timeout=self.timeout
295
- )
296
- self.auth_client.session = self.session
297
- try:
298
- # First attempt to login
299
- token_response = self.auth_client.login(email, password)
300
-
301
- # Only proceed with API key retrieval if login was successful
302
- if token_response:
303
- # After successful login, get the API key
304
- api_key_response = self.auth_client.view_api_key()
305
- self.key = api_key_response
306
-
307
- # Make sure URL is set
308
- if not self.url:
309
- self.url = "https://api.terrak.io"
310
-
311
- # Make sure session headers are updated with the new key
312
- self.session.headers.update({
313
- 'Content-Type': 'application/json',
314
- 'x-api-key': self.key
315
- })
316
-
317
- # Set URL if not already set
318
- if not self.url:
319
- self.url = "https://api.terrak.io"
320
- self.url = self.url.rstrip('/')
321
-
322
- # Save email and API key to config file
323
- import os
324
- import json
325
- config_path = os.path.join(os.environ.get("HOME", ""), ".tkio_config.json")
326
- try:
327
- config = {"EMAIL": email, "TERRAKIO_API_KEY": self.key}
328
- if os.path.exists(config_path):
329
- with open(config_path, 'r') as f:
330
- config = json.load(f)
331
- config["EMAIL"] = email
332
- config["TERRAKIO_API_KEY"] = self.key
333
- config["PERSONAL_TOKEN"] = token_response
334
-
335
- os.makedirs(os.path.dirname(config_path), exist_ok=True)
336
- with open(config_path, 'w') as f:
337
- json.dump(config, f, indent=4)
338
-
339
- if not self.quiet:
340
- print(f"Successfully authenticated as: {email}")
341
- print(f"Using Terrakio API at: {self.url}")
342
- print(f"API key saved to {config_path}")
343
- except Exception as e:
344
- if not self.quiet:
345
- print(f"Warning: Failed to update config file: {e}")
346
-
347
- return {"token": token_response} if token_response else {"error": "Login failed"}
348
- except Exception as e:
349
- if not self.quiet:
350
- print(f"Login failed: {str(e)}")
351
- raise
352
-
353
- @require_api_key
354
- def refresh_api_key(self) -> str:
355
- # If we have auth_client and it has a token, refresh the key
356
- if self.auth_client and self.auth_client.token:
357
- self.key = self.auth_client.refresh_api_key()
358
- self.session.headers.update({'x-api-key': self.key})
359
-
360
- # Update the config file with the new key
361
- import os
362
- config_path = os.path.join(os.environ.get("HOME", ""), ".tkio_config.json")
363
- try:
364
- config = {"EMAIL": "", "TERRAKIO_API_KEY": ""}
365
- if os.path.exists(config_path):
366
- with open(config_path, 'r') as f:
367
- config = json.load(f)
368
- config["TERRAKIO_API_KEY"] = self.key
369
- os.makedirs(os.path.dirname(config_path), exist_ok=True)
370
- with open(config_path, 'w') as f:
371
- json.dump(config, f, indent=4)
372
- if not self.quiet:
373
- print(f"API key generated successfully and updated in {config_path}")
374
- except Exception as e:
375
- if not self.quiet:
376
- print(f"Warning: Failed to update config file: {e}")
377
- return self.key
378
- else:
379
- # If we don't have auth_client with a token but have a key already, return it
380
- if self.key:
381
- if not self.quiet:
382
- print("Using existing API key from config file.")
383
- return self.key
384
- else:
385
- raise ConfigurationError("No authentication token available. Please login first to refresh the API key.")
386
-
387
- @require_api_key
388
- def view_api_key(self) -> str:
389
- # If we have the auth client and token, refresh the key
390
- if self.auth_client and self.auth_client.token:
391
- self.key = self.auth_client.view_api_key()
392
- self.session.headers.update({'x-api-key': self.key})
393
-
394
- return self.key
395
-
396
- @require_api_key
397
- def view_api_key(self) -> str:
398
- # If we have the auth client and token, refresh the key
399
- if not self.auth_client:
400
- from terrakio_core.auth import AuthClient
401
- self.auth_client = AuthClient(
402
- base_url=self.url,
403
- verify=self.verify,
404
- timeout=self.timeout
405
- )
406
- self.auth_client.session = self.session
407
- return self.auth_client.view_api_key()
408
-
409
- # @require_api_key
410
- # def get_user_info(self) -> Dict[str, Any]:
411
- # return self.auth_client.get_user_info()
412
- @require_api_key
413
- def get_user_info(self) -> Dict[str, Any]:
414
- # Initialize auth_client if it doesn't exist
415
- if not self.auth_client:
416
- from terrakio_core.auth import AuthClient
417
- self.auth_client = AuthClient(
418
- base_url=self.url,
419
- verify=self.verify,
420
- timeout=self.timeout
421
- )
422
- # Use the same session as the base client
423
- self.auth_client.session = self.session
424
-
425
- return self.auth_client.get_user_info()
426
-
427
- @require_api_key
428
- def wcs(self, expr: str, feature: Union[Dict[str, Any], ShapelyGeometry], in_crs: str = "epsg:4326",
429
- out_crs: str = "epsg:4326", output: str = "csv", resolution: int = -1,
430
- **kwargs):
431
- if hasattr(feature, 'is_valid'):
432
- from shapely.geometry import mapping
433
- feature = {
434
- "type": "Feature",
435
- "geometry": mapping(feature),
436
- "properties": {}
437
- }
438
- payload = {
439
- "feature": feature,
440
- "in_crs": in_crs,
441
- "out_crs": out_crs,
442
- "output": output,
443
- "resolution": resolution,
444
- "expr": expr,
445
- **kwargs
446
- }
447
- request_url = f"{self.url}/geoquery"
448
- try:
449
- response = self.session.post(request_url, json=payload, timeout=self.timeout, verify=self.verify)
450
- print("the response is ", response.text)
451
- if not response.ok:
452
- error_msg = f"API request failed: {response.status_code} {response.reason}"
453
- try:
454
- error_data = response.json()
455
- if "detail" in error_data:
456
- error_msg += f" - {error_data['detail']}"
457
- except:
458
- pass
459
- raise APIError(error_msg)
460
- if output.lower() == "csv":
461
- import pandas as pd
462
- return pd.read_csv(BytesIO(response.content))
463
- elif output.lower() == "netcdf":
464
- return xr.open_dataset(BytesIO(response.content))
465
- else:
466
- try:
467
- return xr.open_dataset(BytesIO(response.content))
468
- except ValueError:
469
- import pandas as pd
470
- try:
471
- return pd.read_csv(BytesIO(response.content))
472
- except:
473
- return response.content
474
- except requests.RequestException as e:
475
- raise APIError(f"Request failed: {str(e)}")
476
-
477
- # Admin/protected methods
478
- @require_api_key
479
- def _get_user_by_id(self, user_id: str):
480
- if not self.user_management:
481
- from terrakio_core.user_management import UserManagement
482
- if not self.url or not self.key:
483
- raise ConfigurationError("User management client not initialized. Make sure API URL and key are set.")
484
- self.user_management = UserManagement(
485
- api_url=self.url,
486
- api_key=self.key,
487
- verify=self.verify,
488
- timeout=self.timeout
489
- )
490
- return self.user_management.get_user_by_id(user_id)
491
-
492
- @require_api_key
493
- def _get_user_by_email(self, email: str):
494
- if not self.user_management:
495
- from terrakio_core.user_management import UserManagement
496
- if not self.url or not self.key:
497
- raise ConfigurationError("User management client not initialized. Make sure API URL and key are set.")
498
- self.user_management = UserManagement(
499
- api_url=self.url,
500
- api_key=self.key,
501
- verify=self.verify,
502
- timeout=self.timeout
503
- )
504
- return self.user_management.get_user_by_email(email)
505
-
506
- @require_api_key
507
- def _list_users(self, substring: str = None, uid: bool = False):
508
- if not self.user_management:
509
- from terrakio_core.user_management import UserManagement
510
- if not self.url or not self.key:
511
- raise ConfigurationError("User management client not initialized. Make sure API URL and key are set.")
512
- self.user_management = UserManagement(
513
- api_url=self.url,
514
- api_key=self.key,
515
- verify=self.verify,
516
- timeout=self.timeout
517
- )
518
- return self.user_management.list_users(substring=substring, uid=uid)
519
-
520
- @require_api_key
521
- def _edit_user(self, user_id: str, uid: str = None, email: str = None, role: str = None, apiKey: str = None, groups: list = None, quota: int = None):
522
- if not self.user_management:
523
- from terrakio_core.user_management import UserManagement
524
- if not self.url or not self.key:
525
- raise ConfigurationError("User management client not initialized. Make sure API URL and key are set.")
526
- self.user_management = UserManagement(
527
- api_url=self.url,
528
- api_key=self.key,
529
- verify=self.verify,
530
- timeout=self.timeout
531
- )
532
- return self.user_management.edit_user(
533
- user_id=user_id,
534
- uid=uid,
535
- email=email,
536
- role=role,
537
- apiKey=apiKey,
538
- groups=groups,
539
- quota=quota
540
- )
541
-
542
- @require_api_key
543
- def _reset_quota(self, email: str, quota: int = None):
544
- if not self.user_management:
545
- from terrakio_core.user_management import UserManagement
546
- if not self.url or not self.key:
547
- raise ConfigurationError("User management client not initialized. Make sure API URL and key are set.")
548
- self.user_management = UserManagement(
549
- api_url=self.url,
550
- api_key=self.key,
551
- verify=self.verify,
552
- timeout=self.timeout
553
- )
554
- return self.user_management.reset_quota(email=email, quota=quota)
555
-
556
- @require_api_key
557
- def _delete_user(self, uid: str):
558
- if not self.user_management:
559
- from terrakio_core.user_management import UserManagement
560
- if not self.url or not self.key:
561
- raise ConfigurationError("User management client not initialized. Make sure API URL and key are set.")
562
- self.user_management = UserManagement(
563
- api_url=self.url,
564
- api_key=self.key,
565
- verify=self.verify,
566
- timeout=self.timeout
567
- )
568
- return self.user_management.delete_user(uid=uid)
569
-
570
- # Dataset management protected methods
571
- @require_api_key
572
- def _get_dataset(self, name: str, collection: str = "terrakio-datasets"):
573
- if not self.dataset_management:
574
- from terrakio_core.dataset_management import DatasetManagement
575
- if not self.url or not self.key:
576
- raise ConfigurationError("Dataset management client not initialized. Make sure API URL and key are set.")
577
- self.dataset_management = DatasetManagement(
578
- api_url=self.url,
579
- api_key=self.key,
580
- verify=self.verify,
581
- timeout=self.timeout
582
- )
583
- return self.dataset_management.get_dataset(name=name, collection=collection)
584
-
585
- @require_api_key
586
- def _list_datasets(self, substring: str = None, collection: str = "terrakio-datasets"):
587
- if not self.dataset_management:
588
- from terrakio_core.dataset_management import DatasetManagement
589
- if not self.url or not self.key:
590
- raise ConfigurationError("Dataset management client not initialized. Make sure API URL and key are set.")
591
- self.dataset_management = DatasetManagement(
592
- api_url=self.url,
593
- api_key=self.key,
594
- verify=self.verify,
595
- timeout=self.timeout
596
- )
597
- return self.dataset_management.list_datasets(substring=substring, collection=collection)
598
-
599
- @require_api_key
600
- def _create_dataset(self, name: str, collection: str = "terrakio-datasets", **kwargs):
601
- if not self.dataset_management:
602
- from terrakio_core.dataset_management import DatasetManagement
603
- if not self.url or not self.key:
604
- raise ConfigurationError("Dataset management client not initialized. Make sure API URL and key are set.")
605
- self.dataset_management = DatasetManagement(
606
- api_url=self.url,
607
- api_key=self.key,
608
- verify=self.verify,
609
- timeout=self.timeout
610
- )
611
- return self.dataset_management.create_dataset(name=name, collection=collection, **kwargs)
612
-
613
- @require_api_key
614
- def _update_dataset(self, name: str, append: bool = True, collection: str = "terrakio-datasets", **kwargs):
615
- if not self.dataset_management:
616
- from terrakio_core.dataset_management import DatasetManagement
617
- if not self.url or not self.key:
618
- raise ConfigurationError("Dataset management client not initialized. Make sure API URL and key are set.")
619
- self.dataset_management = DatasetManagement(
620
- api_url=self.url,
621
- api_key=self.key,
622
- verify=self.verify,
623
- timeout=self.timeout
624
- )
625
- return self.dataset_management.update_dataset(name=name, append=append, collection=collection, **kwargs)
626
-
627
- @require_api_key
628
- def _overwrite_dataset(self, name: str, collection: str = "terrakio-datasets", **kwargs):
629
- if not self.dataset_management:
630
- from terrakio_core.dataset_management import DatasetManagement
631
- if not self.url or not self.key:
632
- raise ConfigurationError("Dataset management client not initialized. Make sure API URL and key are set.")
633
- self.dataset_management = DatasetManagement(
634
- api_url=self.url,
635
- api_key=self.key,
636
- verify=self.verify,
637
- timeout=self.timeout
638
- )
639
- return self.dataset_management.overwrite_dataset(name=name, collection=collection, **kwargs)
640
-
641
- @require_api_key
642
- def _delete_dataset(self, name: str, collection: str = "terrakio-datasets"):
643
- if not self.dataset_management:
644
- from terrakio_core.dataset_management import DatasetManagement
645
- if not self.url or not self.key:
646
- raise ConfigurationError("Dataset management client not initialized. Make sure API URL and key are set.")
647
- self.dataset_management = DatasetManagement(
648
- api_url=self.url,
649
- api_key=self.key,
650
- verify=self.verify,
651
- timeout=self.timeout
652
- )
653
- return self.dataset_management.delete_dataset(name=name, collection=collection)
654
-
655
- @require_api_key
656
- def close(self):
657
- """Close all client sessions"""
658
- self.session.close()
659
- if self.auth_client:
660
- self.auth_client.session.close()
661
- # Close aiohttp session if it exists
662
- if self._aiohttp_session and not self._aiohttp_session.closed:
663
- try:
664
- nest_asyncio.apply()
665
- asyncio.run(self.close_async())
666
- except ImportError:
667
- try:
668
- asyncio.run(self.close_async())
669
- except RuntimeError as e:
670
- if "cannot be called from a running event loop" in str(e):
671
- # In Jupyter, we can't properly close the async session
672
- # Log a warning or handle gracefully
673
- import warnings
674
- warnings.warn("Cannot properly close aiohttp session in Jupyter environment. "
675
- "Consider using 'await client.close_async()' instead.")
676
- else:
677
- raise
678
- except RuntimeError:
679
- # Event loop may already be closed, ignore
680
- pass
681
-
682
- @require_api_key
683
- def __enter__(self):
684
- return self
685
-
686
- @require_api_key
687
- def __exit__(self, exc_type, exc_val, exc_tb):
688
- self.close()
689
-
690
- @admin_only_params('location', 'force_loc', 'server')
691
- @require_api_key
692
- def execute_job(self, name, region, output, config, overwrite=False, skip_existing=False, request_json=None, manifest_json=None, location=None, force_loc=None, server="dev-au.terrak.io"):
693
- if not self.mass_stats:
694
- from terrakio_core.mass_stats import MassStats
695
- if not self.url or not self.key:
696
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
697
- self.mass_stats = MassStats(
698
- base_url=self.url,
699
- api_key=self.key,
700
- verify=self.verify,
701
- timeout=self.timeout
702
- )
703
- return self.mass_stats.execute_job(name, region, output, config, overwrite, skip_existing, request_json, manifest_json, location, force_loc, server)
704
-
705
-
706
- @require_api_key
707
- def get_mass_stats_task_id(self, name, stage, uid=None):
708
- if not self.mass_stats:
709
- from terrakio_core.mass_stats import MassStats
710
- if not self.url or not self.key:
711
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
712
- self.mass_stats = MassStats(
713
- base_url=self.url,
714
- api_key=self.key,
715
- verify=self.verify,
716
- timeout=self.timeout
717
- )
718
- return self.mass_stats.get_task_id(name, stage, uid)
719
-
720
- @require_api_key
721
- def track_mass_stats_job(self, ids: Optional[list] = None):
722
- if not self.mass_stats:
723
- from terrakio_core.mass_stats import MassStats
724
- if not self.url or not self.key:
725
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
726
- self.mass_stats = MassStats(
727
- base_url=self.url,
728
- api_key=self.key,
729
- verify=self.verify,
730
- timeout=self.timeout
731
- )
732
- return self.mass_stats.track_job(ids)
733
-
734
- @require_api_key
735
- def get_mass_stats_history(self, limit=100):
736
- if not self.mass_stats:
737
- from terrakio_core.mass_stats import MassStats
738
- if not self.url or not self.key:
739
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
740
- self.mass_stats = MassStats(
741
- base_url=self.url,
742
- api_key=self.key,
743
- verify=self.verify,
744
- timeout=self.timeout
745
- )
746
- return self.mass_stats.get_history(limit)
747
-
748
- @require_api_key
749
- def start_mass_stats_post_processing(self, process_name, data_name, output, consumer_path, overwrite=False):
750
- if not self.mass_stats:
751
- from terrakio_core.mass_stats import MassStats
752
- if not self.url or not self.key:
753
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
754
- self.mass_stats = MassStats(
755
- base_url=self.url,
756
- api_key=self.key,
757
- verify=self.verify,
758
- timeout=self.timeout
759
- )
760
- return self.mass_stats.start_post_processing(process_name, data_name, output, consumer_path, overwrite)
761
-
762
- @require_api_key
763
- def download_mass_stats_results(self, id=None, force_loc=False, **kwargs):
764
- if not self.mass_stats:
765
- from terrakio_core.mass_stats import MassStats
766
- if not self.url or not self.key:
767
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
768
- self.mass_stats = MassStats(
769
- base_url=self.url,
770
- api_key=self.key,
771
- verify=self.verify,
772
- timeout=self.timeout
773
- )
774
- return self.mass_stats.download_results(id, force_loc, **kwargs)
775
-
776
- @require_api_key
777
- def cancel_mass_stats_job(self, id):
778
- if not self.mass_stats:
779
- from terrakio_core.mass_stats import MassStats
780
- if not self.url or not self.key:
781
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
782
- self.mass_stats = MassStats(
783
- base_url=self.url,
784
- api_key=self.key,
785
- verify=self.verify,
786
- timeout=self.timeout
787
- )
788
- return self.mass_stats.cancel_job(id)
789
-
790
- @require_api_key
791
- def cancel_all_mass_stats_jobs(self):
792
- if not self.mass_stats:
793
- from terrakio_core.mass_stats import MassStats
794
- if not self.url or not self.key:
795
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
796
- self.mass_stats = MassStats(
797
- base_url=self.url,
798
- api_key=self.key,
799
- verify=self.verify,
800
- timeout=self.timeout
801
- )
802
- return self.mass_stats.cancel_all_jobs()
803
-
804
- @require_api_key
805
- def _create_pyramids(self, name, levels, config):
806
- if not self.mass_stats:
807
- from terrakio_core.mass_stats import MassStats
808
- if not self.url or not self.key:
809
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
810
- self.mass_stats = MassStats(
811
- base_url=self.url,
812
- api_key=self.key,
813
- verify=self.verify,
814
- timeout=self.timeout
815
- )
816
- return self.mass_stats.create_pyramids(name, levels, config)
817
-
818
- @require_api_key
819
- def random_sample(self, name, **kwargs):
820
- if not self.mass_stats:
821
- from terrakio_core.mass_stats import MassStats
822
- if not self.url or not self.key:
823
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
824
- self.mass_stats = MassStats(
825
- base_url=self.url,
826
- api_key=self.key,
827
- verify=self.verify,
828
- timeout=self.timeout
829
- )
830
- return self.mass_stats.random_sample(name, **kwargs)
831
-
832
- @require_api_key
833
- async def zonal_stats_async(self, gdb, expr, conc=20, inplace=False, output="csv",
834
- in_crs="epsg:4326", out_crs="epsg:4326", resolution=-1, buffer=False):
835
- """
836
- Compute zonal statistics for all geometries in a GeoDataFrame using asyncio for concurrency.
837
-
838
- Args:
839
- gdb (geopandas.GeoDataFrame): GeoDataFrame containing geometries
840
- expr (str): Terrakio expression to evaluate, can include spatial aggregations
841
- conc (int): Number of concurrent requests to make
842
- inplace (bool): Whether to modify the input GeoDataFrame in place
843
- output (str): Output format (csv or netcdf)
844
- in_crs (str): Input coordinate reference system
845
- out_crs (str): Output coordinate reference system
846
- resolution (int): Resolution parameter
847
- buffer (bool): Whether to buffer the request (default True)
848
-
849
- Returns:
850
- geopandas.GeoDataFrame: GeoDataFrame with added columns for results, or None if inplace=True
851
- """
852
- if conc > 100:
853
- raise ValueError("Concurrency (conc) is too high. Please set conc to 100 or less.")
854
-
855
- # Process geometries in batches
856
- all_results = []
857
- row_indices = []
858
-
859
- # Calculate total batches for progress reporting
860
- total_geometries = len(gdb)
861
- total_batches = (total_geometries + conc - 1) // conc # Ceiling division
862
- completed_batches = 0
863
-
864
- print(f"Processing {total_geometries} geometries with concurrency {conc}")
865
-
866
- async def process_geometry(geom, index):
867
- """Process a single geometry"""
868
- try:
869
- feature = {
870
- "type": "Feature",
871
- "geometry": mapping(geom),
872
- "properties": {"index": index}
873
- }
874
- result = await self.wcs_async(expr=expr, feature=feature, output=output,
875
- in_crs=in_crs, out_crs=out_crs, resolution=resolution, buffer=buffer)
876
- # Add original index to track which geometry this result belongs to
877
- if isinstance(result, pd.DataFrame):
878
- result['_geometry_index'] = index
879
- return result
880
- except Exception as e:
881
- raise
882
-
883
- async def process_batch(batch_indices):
884
- """Process a batch of geometries concurrently using TaskGroup"""
885
- try:
886
- async with asyncio.TaskGroup() as tg:
887
- tasks = []
888
- for idx in batch_indices:
889
- geom = gdb.geometry.iloc[idx]
890
- task = tg.create_task(process_geometry(geom, idx))
891
- tasks.append(task)
892
-
893
- # Get results from completed tasks
894
- results = []
895
- for task in tasks:
896
- try:
897
- result = task.result()
898
- results.append(result)
899
- except Exception as e:
900
- raise
901
-
902
- return results
903
- except* Exception as e:
904
- # Get the actual exceptions from the tasks
905
- for task in tasks:
906
- if task.done() and task.exception():
907
- raise task.exception()
908
- raise
909
-
910
- # Process in batches to control concurrency
911
- for i in range(0, len(gdb), conc):
912
- batch_indices = range(i, min(i + conc, len(gdb)))
913
- try:
914
- batch_results = await process_batch(batch_indices)
915
- all_results.extend(batch_results)
916
- row_indices.extend(batch_indices)
917
-
918
- # Update progress
919
- completed_batches += 1
920
- processed_geometries = min(i + conc, total_geometries)
921
- print(f"Progress: {completed_batches}/{total_batches} completed ({processed_geometries}/{total_geometries} geometries processed)")
922
-
923
- except Exception as e:
924
- if hasattr(e, 'response'):
925
- raise APIError(f"API request failed: {e.response.text}")
926
- raise
927
-
928
- print("All batches completed! Processing results...")
929
-
930
- if not all_results:
931
- raise ValueError("No valid results were returned for any geometry")
932
-
933
- # Combine all results
934
- combined_df = pd.concat(all_results, ignore_index=True)
935
-
936
- # Check if we have temporal results
937
- has_time = 'time' in combined_df.columns
938
-
939
- # Create a result GeoDataFrame
940
- if has_time:
941
- # For temporal data, we'll create a hierarchical index
942
- # First make sure we have the geometry index and time columns
943
- if '_geometry_index' not in combined_df.columns:
944
- raise ValueError("Missing geometry index in results")
945
-
946
- # Create hierarchical index on geometry_index and time
947
- combined_df.set_index(['_geometry_index', 'time'], inplace=True)
948
-
949
- # For each unique geometry index, we need the corresponding geometry
950
- geometry_series = gdb.geometry.copy()
951
-
952
- # Get columns that will become new attributes (exclude index/utility columns)
953
- result_cols = combined_df.columns
954
-
955
- # Create a new GeoDataFrame with multi-index
956
- result_rows = []
957
- geometries = []
958
-
959
- # Iterate through the hierarchical index
960
- for (geom_idx, time_val), row in combined_df.iterrows():
961
- # Create a new row with geometry properties + result columns
962
- new_row = {}
963
-
964
- # Add original GeoDataFrame columns (except geometry)
965
- for col in gdb.columns:
966
- if col != 'geometry':
967
- new_row[col] = gdb.loc[geom_idx, col]
968
-
969
- # Add result columns
970
- for col in result_cols:
971
- new_row[col] = row[col]
972
-
973
- result_rows.append(new_row)
974
- geometries.append(gdb.geometry.iloc[geom_idx])
975
-
976
- # Create a new GeoDataFrame with multi-index
977
- multi_index = pd.MultiIndex.from_tuples(
978
- combined_df.index.tolist(),
979
- names=['geometry_index', 'time']
980
- )
981
-
982
- result_gdf = gpd.GeoDataFrame(
983
- result_rows,
984
- geometry=geometries,
985
- index=multi_index
986
- )
987
-
988
- if inplace:
989
- # Can't really do inplace with multi-temporal results as we're changing the structure
990
- return result_gdf
991
- else:
992
- return result_gdf
993
- else:
994
- # Non-temporal data - just add new columns to the existing GeoDataFrame
995
- result_gdf = gdb.copy() if not inplace else gdb
996
-
997
- # Get column names from the results (excluding utility columns)
998
- result_cols = [col for col in combined_df.columns if col not in ['_geometry_index']]
999
-
1000
- # Create a mapping from geometry index to result rows
1001
- geom_idx_to_row = {}
1002
- for idx, row in combined_df.iterrows():
1003
- geom_idx = int(row['_geometry_index'])
1004
- geom_idx_to_row[geom_idx] = row
1005
-
1006
- # Add results as new columns to the GeoDataFrame
1007
- for col in result_cols:
1008
- # Initialize the column with None or appropriate default
1009
- if col not in result_gdf.columns:
1010
- result_gdf[col] = None
1011
-
1012
- # Fill in values from results
1013
- for geom_idx, row in geom_idx_to_row.items():
1014
- result_gdf.loc[geom_idx, col] = row[col]
1015
-
1016
- if inplace:
1017
- return None
1018
- else:
1019
- return result_gdf
1020
31
 
1021
- @require_api_key
1022
- def zonal_stats(self, gdb, expr, conc=20, inplace=False, output="csv",
1023
- in_crs="epsg:4326", out_crs="epsg:4326", resolution=-1, buffer=False):
1024
- """
1025
- Compute zonal statistics for all geometries in a GeoDataFrame.
1026
-
1027
- Args:
1028
- gdb (geopandas.GeoDataFrame): GeoDataFrame containing geometries
1029
- expr (str): Terrakio expression to evaluate, can include spatial aggregations
1030
- conc (int): Number of concurrent requests to make
1031
- inplace (bool): Whether to modify the input GeoDataFrame in place
1032
- output (str): Output format (csv or netcdf)
1033
- in_crs (str): Input coordinate reference system
1034
- out_crs (str): Output coordinate reference system
1035
- resolution (int): Resolution parameter
1036
- buffer (bool): Whether to buffer the request (default True)
1037
-
1038
- Returns:
1039
- geopandas.GeoDataFrame: GeoDataFrame with added columns for results, or None if inplace=True
1040
- """
1041
- if conc > 100:
1042
- raise ValueError("Concurrency (conc) is too high. Please set conc to 100 or less.")
1043
- import asyncio
1044
-
1045
- print(f"Starting zonal statistics computation for expression: {expr}")
1046
-
1047
- # Check if we're in a Jupyter environment or already have an event loop
1048
- try:
1049
- loop = asyncio.get_running_loop()
1050
- # We're in an async context (like Jupyter), use create_task
1051
- nest_asyncio.apply()
1052
- result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output,
1053
- in_crs, out_crs, resolution, buffer))
1054
- except RuntimeError:
1055
- # No running event loop, safe to use asyncio.run()
1056
- result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output,
1057
- in_crs, out_crs, resolution, buffer))
1058
- except ImportError:
1059
- # nest_asyncio not available, try alternative approach
1060
- try:
1061
- loop = asyncio.get_running_loop()
1062
- # Create task in existing loop
1063
- task = loop.create_task(self.zonal_stats_async(gdb, expr, conc, inplace, output,
1064
- in_crs, out_crs, resolution, buffer))
1065
- # This won't work directly - we need a different approach
1066
- raise RuntimeError("Cannot run async code in Jupyter without nest_asyncio. Please install: pip install nest-asyncio")
1067
- except RuntimeError:
1068
- # No event loop, use asyncio.run
1069
- result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output,
1070
- in_crs, out_crs, resolution, buffer))
1071
32
 
1072
- # Ensure aiohttp session is closed after running async code
1073
- try:
1074
- if self._aiohttp_session and not self._aiohttp_session.closed:
1075
- asyncio.run(self.close_async())
1076
- except RuntimeError:
1077
- # Event loop may already be closed, ignore
1078
- pass
1079
-
1080
- print("Zonal statistics computation completed!")
1081
- return result
1082
-
1083
- # Group access management protected methods
1084
- @require_api_key
1085
- def _get_group_users_and_datasets(self, group_name: str):
1086
- if not self.group_access:
1087
- from terrakio_core.group_access_management import GroupAccessManagement
1088
- if not self.url or not self.key:
1089
- raise ConfigurationError("Group access management client not initialized. Make sure API URL and key are set.")
1090
- self.group_access = GroupAccessManagement(
1091
- api_url=self.url,
1092
- api_key=self.key,
1093
- verify=self.verify,
1094
- timeout=self.timeout
1095
- )
1096
- return self.group_access.get_group_users_and_datasets(group_name=group_name)
1097
-
1098
- @require_api_key
1099
- def _add_group_to_dataset(self, dataset: str, group: str):
1100
- if not self.group_access:
1101
- from terrakio_core.group_access_management import GroupAccessManagement
1102
- if not self.url or not self.key:
1103
- raise ConfigurationError("Group access management client not initialized. Make sure API URL and key are set.")
1104
- self.group_access = GroupAccessManagement(
1105
- api_url=self.url,
1106
- api_key=self.key,
1107
- verify=self.verify,
1108
- timeout=self.timeout
1109
- )
1110
- return self.group_access.add_group_to_dataset(dataset=dataset, group=group)
1111
-
1112
- @require_api_key
1113
- def _add_group_to_user(self, uid: str, group: str):
1114
- if not self.group_access:
1115
- from terrakio_core.group_access_management import GroupAccessManagement
1116
- if not self.url or not self.key:
1117
- raise ConfigurationError("Group access management client not initialized. Make sure API URL and key are set.")
1118
- self.group_access = GroupAccessManagement(
1119
- api_url=self.url,
1120
- api_key=self.key,
1121
- verify=self.verify,
1122
- timeout=self.timeout
1123
- )
1124
- return self.group_access.add_group_to_user(uid=uid, group=group)
1125
-
1126
- @require_api_key
1127
- def _delete_group_from_user(self, uid: str, group: str):
1128
- if not self.group_access:
1129
- from terrakio_core.group_access_management import GroupAccessManagement
1130
- if not self.url or not self.key:
1131
- raise ConfigurationError("Group access management client not initialized. Make sure API URL and key are set.")
1132
- self.group_access = GroupAccessManagement(
1133
- api_url=self.url,
1134
- api_key=self.key,
1135
- verify=self.verify,
1136
- timeout=self.timeout
1137
- )
1138
- return self.group_access.delete_group_from_user(uid=uid, group=group)
1139
-
1140
- @require_api_key
1141
- def _delete_group_from_dataset(self, dataset: str, group: str):
1142
- if not self.group_access:
1143
- from terrakio_core.group_access_management import GroupAccessManagement
1144
- if not self.url or not self.key:
1145
- raise ConfigurationError("Group access management client not initialized. Make sure API URL and key are set.")
1146
- self.group_access = GroupAccessManagement(
1147
- api_url=self.url,
1148
- api_key=self.key,
1149
- verify=self.verify,
1150
- timeout=self.timeout
1151
- )
1152
- return self.group_access.delete_group_from_dataset(dataset=dataset, group=group)
1153
-
1154
- # Space management protected methods
1155
- @require_api_key
1156
- def _get_total_space_used(self):
1157
- if not self.space_management:
1158
- from terrakio_core.space_management import SpaceManagement
1159
- if not self.url or not self.key:
1160
- raise ConfigurationError("Space management client not initialized. Make sure API URL and key are set.")
1161
- self.space_management = SpaceManagement(
1162
- api_url=self.url,
1163
- api_key=self.key,
1164
- verify=self.verify,
1165
- timeout=self.timeout
1166
- )
1167
- return self.space_management.get_total_space_used()
1168
-
1169
- @require_api_key
1170
- def _get_space_used_by_job(self, name: str, region: str = None):
1171
- if not self.space_management:
1172
- from terrakio_core.space_management import SpaceManagement
1173
- if not self.url or not self.key:
1174
- raise ConfigurationError("Space management client not initialized. Make sure API URL and key are set.")
1175
- self.space_management = SpaceManagement(
1176
- api_url=self.url,
1177
- api_key=self.key,
1178
- verify=self.verify,
1179
- timeout=self.timeout
1180
- )
1181
- return self.space_management.get_space_used_by_job(name, region)
1182
-
1183
- @require_api_key
1184
- def _delete_user_job(self, name: str, region: str = None):
1185
- if not self.space_management:
1186
- from terrakio_core.space_management import SpaceManagement
1187
- if not self.url or not self.key:
1188
- raise ConfigurationError("Space management client not initialized. Make sure API URL and key are set.")
1189
- self.space_management = SpaceManagement(
1190
- api_url=self.url,
1191
- api_key=self.key,
1192
- verify=self.verify,
1193
- timeout=self.timeout
1194
- )
1195
- return self.space_management.delete_user_job(name, region)
1196
-
1197
- @require_api_key
1198
- def _delete_data_in_path(self, path: str, region: str = None):
1199
- if not self.space_management:
1200
- from terrakio_core.space_management import SpaceManagement
1201
- if not self.url or not self.key:
1202
- raise ConfigurationError("Space management client not initialized. Make sure API URL and key are set.")
1203
- self.space_management = SpaceManagement(
1204
- api_url=self.url,
1205
- api_key=self.key,
1206
- verify=self.verify,
1207
- timeout=self.timeout
1208
- )
1209
- return self.space_management.delete_data_in_path(path, region)
1210
-
1211
- @require_api_key
1212
- def start_mass_stats_job(self, task_id):
1213
- if not self.mass_stats:
1214
- from terrakio_core.mass_stats import MassStats
1215
- if not self.url or not self.key:
1216
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
1217
- self.mass_stats = MassStats(
1218
- base_url=self.url,
1219
- api_key=self.key,
1220
- verify=self.verify,
1221
- timeout=self.timeout
1222
- )
1223
- return self.mass_stats.start_job(task_id)
1224
-
1225
-
1226
- @require_api_key
1227
- def generate_ai_dataset(
1228
- self,
1229
- name: str,
1230
- aoi_geojson: str,
1231
- expression_x: str,
1232
- expression_y: str,
1233
- samples: int,
1234
- tile_size: int,
1235
- crs: str = "epsg:4326",
1236
- res: float = 0.001,
1237
- region: str = "aus",
1238
- start_year: int = None,
1239
- end_year: int = None,
1240
- ) -> dict:
1241
- """
1242
- Generate an AI dataset using specified parameters.
1243
-
1244
- Args:
1245
- name (str): Name of the dataset to generate
1246
- aoi_geojson (str): Path to GeoJSON file containing area of interest
1247
- expression_x (str): Expression for X variable (e.g. "MSWX.air_temperature@(year=2021, month=1)")
1248
- expression_y (str): Expression for Y variable with {year} placeholder
1249
- samples (int): Number of samples to generate
1250
- tile_size (int): Size of tiles in degrees
1251
- crs (str, optional): Coordinate reference system. Defaults to "epsg:4326"
1252
- res (float, optional): Resolution in degrees. Defaults to 0.001
1253
- region (str, optional): Region code. Defaults to "aus"
1254
- start_year (int, optional): Start year for data generation. Required if end_year provided
1255
- end_year (int, optional): End year for data generation. Required if start_year provided
1256
- overwrite (bool, optional): Whether to overwrite existing dataset. Defaults to False
1257
-
1258
- Returns:
1259
- dict: Response from the AI dataset generation API
1260
-
1261
- Raises:
1262
- ValidationError: If required parameters are missing or invalid
1263
- APIError: If the API request fails
1264
- """
1265
-
1266
- # we have the parameters, let pass the parameters to the random sample function
1267
- # task_id = self.random_sample(name, aoi_geojson, expression_x, expression_y, samples, tile_size, crs, res, region, start_year, end_year, overwrite)
1268
- config = {
1269
- "expressions" : [{"expr": expression_x, "res": res, "prefix": "x"}],
1270
- "filters" : []
1271
- }
1272
- config["expressions"].append({"expr": expression_y, "res" : res, "prefix": "y"})
1273
-
1274
- expression_x = expression_x.replace("{year}", str(start_year))
1275
- expression_y = expression_y.replace("{year}", str(start_year))
1276
- print("the aoi geojson is ", aoi_geojson)
1277
- with open(aoi_geojson, 'r') as f:
1278
- aoi_data = json.load(f)
1279
- print("the config is ", config)
1280
- task_id = self.random_sample(
1281
- name=name,
1282
- config=config,
1283
- aoi=aoi_data,
1284
- samples=samples,
1285
- year_range=[start_year, end_year],
1286
- crs=crs,
1287
- tile_size=tile_size,
1288
- res=res,
1289
- region=region,
1290
- output="netcdf",
1291
- server=self.url,
1292
- bucket="terrakio-mass-requests",
1293
- overwrite=True
1294
- )["task_id"]
1295
- print("the task id is ", task_id)
1296
-
1297
- # Wait for job completion
1298
- import time
1299
-
1300
- while True:
1301
- result = self.track_mass_stats_job(ids=[task_id])
1302
- status = result[task_id]['status']
1303
- print(f"Job status: {status}")
1304
-
1305
- if status == "Completed":
1306
- break
1307
- elif status == "Error":
1308
- raise Exception(f"Job {task_id} encountered an error")
1309
-
1310
- # Wait 30 seconds before checking again
1311
- time.sleep(30)
1312
-
1313
- # print("the result is ", result)
1314
- # after all the random sample jos are done, we then start the mass stats job
1315
- task_id = self.start_mass_stats_job(task_id)
1316
- # now we hav ethe random sampel
1317
-
1318
- # print("the task id is ", task_id)
1319
- return task_id
1320
-
1321
- @require_api_key
1322
- def train_model(self, model_name: str, training_dataset: str, task_type: str, model_category: str, architecture: str, region: str, hyperparameters: dict = None) -> dict:
1323
- """
1324
- Train a model using the external model training API.
1325
-
1326
- Args:
1327
- model_name (str): The name of the model to train.
1328
- training_dataset (str): The training dataset identifier.
1329
- task_type (str): The type of ML task (e.g., regression, classification).
1330
- model_category (str): The category of model (e.g., random_forest).
1331
- architecture (str): The model architecture.
1332
- region (str): The region identifier.
1333
- hyperparameters (dict, optional): Additional hyperparameters for training.
1334
-
1335
- Returns:
1336
- dict: The response from the model training API.
1337
- """
1338
- payload = {
1339
- "model_name": model_name,
1340
- "training_dataset": training_dataset,
1341
- "task_type": task_type,
1342
- "model_category": model_category,
1343
- "architecture": architecture,
1344
- "region": region,
1345
- "hyperparameters": hyperparameters
1346
- }
1347
- endpoint = f"{self.url.rstrip('/')}/train_model"
1348
- try:
1349
- response = self.session.post(endpoint, json=payload, timeout=self.timeout, verify=self.verify)
1350
- if not response.ok:
1351
- error_msg = f"Model training request failed: {response.status_code} {response.reason}"
1352
- try:
1353
- error_data = response.json()
1354
- if "detail" in error_data:
1355
- error_msg += f" - {error_data['detail']}"
1356
- except Exception:
1357
- if response.text:
1358
- error_msg += f" - {response.text}"
1359
- raise APIError(error_msg)
1360
- return response.json()
1361
- except requests.RequestException as e:
1362
- raise APIError(f"Model training request failed: {str(e)}")
1363
-
1364
- # Mass Stats methods
1365
- @require_api_key
1366
- def combine_tiles(self,
1367
- data_name: str,
1368
- usezarr: bool,
1369
- overwrite: bool,
1370
- output : str) -> dict:
1371
-
1372
- if not self.mass_stats:
1373
- from terrakio_core.mass_stats import MassStats
1374
- if not self.url or not self.key:
1375
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
1376
- self.mass_stats = MassStats(
1377
- base_url=self.url,
1378
- api_key=self.key,
1379
- verify=self.verify,
1380
- timeout=self.timeout
1381
- )
1382
- return self.mass_stats.combine_tiles(data_name, usezarr, overwrite, output)
1383
-
1384
-
1385
-
1386
- @require_api_key
1387
- def create_dataset_file(
1388
- self,
1389
- name: str,
1390
- aoi: str,
1391
- expression: str,
1392
- output: str,
1393
- tile_size: float = 128.0,
1394
- crs: str = "epsg:4326",
1395
- res: float = 0.0001,
1396
- region: str = "aus",
1397
- to_crs: str = "epsg:4326",
1398
- overwrite: bool = True,
1399
- skip_existing: bool = False,
1400
- non_interactive: bool = True,
1401
- usezarr: bool = False,
1402
- poll_interval: int = 30,
1403
- download_path: str = "/home/user/Downloads",
1404
- ) -> dict:
1405
-
1406
- if not self.mass_stats:
1407
- from terrakio_core.mass_stats import MassStats
1408
- if not self.url or not self.key:
1409
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
1410
- self.mass_stats = MassStats(
1411
- base_url=self.url,
1412
- api_key=self.key,
1413
- verify=self.verify,
1414
- timeout=self.timeout
1415
- )
1416
-
1417
-
1418
- from terrakio_core.generation.tiles import tiles
1419
- import tempfile
1420
- import time
1421
-
1422
- body, reqs, groups = tiles(
1423
- name = name,
1424
- aoi = aoi,
1425
- expression = expression,
1426
- output = output,
1427
- tile_size = tile_size,
1428
- crs = crs,
1429
- res = res,
1430
- region = region,
1431
- to_crs = to_crs,
1432
- fully_cover = True,
1433
- overwrite = overwrite,
1434
- skip_existing = skip_existing,
1435
- non_interactive = non_interactive
1436
- )
1437
-
1438
- # Create temp json files before upload
1439
- with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tempreq:
1440
- tempreq.write(reqs)
1441
- tempreqname = tempreq.name
1442
- with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tempmanifest:
1443
- tempmanifest.write(groups)
1444
- tempmanifestname = tempmanifest.name
1445
-
1446
- if not self.mass_stats:
1447
- from terrakio_core.mass_stats import MassStats
1448
- if not self.url or not self.key:
1449
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
1450
- self.mass_stats = MassStats(
1451
- base_url=self.url,
1452
- api_key=self.key,
1453
- verify=self.verify,
1454
- timeout=self.timeout
1455
- )
1456
-
1457
- task_id = self.mass_stats.execute_job(
1458
- name=body["name"],
1459
- region=body["region"],
1460
- output=body["output"],
1461
- config = {},
1462
- overwrite=body["overwrite"],
1463
- skip_existing=body["skip_existing"],
1464
- request_json=tempreqname,
1465
- manifest_json=tempmanifestname,
1466
- )
1467
-
1468
- ### Start combining tiles when generation-tiles job is done
1469
- start_time = time.time()
1470
- status = None
1471
-
1472
- while True:
1473
- try:
1474
- taskid = task_id['task_id']
1475
- trackinfo = self.mass_stats.track_job([taskid])
1476
- status = trackinfo[taskid]['status']
1477
-
1478
- # Check completion states
1479
- if status == 'Completed':
1480
- print('Tiles generated successfully!')
1481
- break
1482
- elif status in ['Failed', 'Cancelled', 'Error']:
1483
- raise RuntimeError(f"Job {taskid} failed with status: {status}")
1484
- else:
1485
- # Job is still running
1486
- elapsed_time = time.time() - start_time
1487
- print(f"Job status: {status} - Elapsed time: {elapsed_time:.1f}s", end='\r')
1488
-
1489
- # Sleep before next check
1490
- time.sleep(poll_interval)
1491
-
1492
-
1493
- except KeyboardInterrupt:
1494
- print(f"\nInterrupted! Job {taskid} is still running in the background.")
1495
- raise
1496
- except Exception as e:
1497
- print(f"\nError tracking job: {e}")
1498
- raise
1499
-
1500
- # Clean up temporary files
1501
- import os
1502
- os.unlink(tempreqname)
1503
- os.unlink(tempmanifestname)
1504
-
1505
-
1506
- # return self.mass_stats.combine_tiles(body["name"], usezarr, body["overwrite"], body["output"])
1507
-
1508
- # Start combining tiles
1509
- combine_result = self.mass_stats.combine_tiles(body["name"], usezarr, body["overwrite"], body["output"])
1510
- combine_task_id = combine_result.get("task_id")
1511
-
1512
- # Poll combine_tiles job status
1513
- combine_start_time = time.time()
1514
- while True:
1515
- try:
1516
- trackinfo = self.mass_stats.track_job([combine_task_id])
1517
- download_file_name = trackinfo[combine_task_id]['folder'] + '.nc'
1518
- bucket = trackinfo[combine_task_id]['bucket']
1519
- combine_status = trackinfo[combine_task_id]['status']
1520
- if combine_status == 'Completed':
1521
- print('Tiles combined successfully!')
1522
- break
1523
- elif combine_status in ['Failed', 'Cancelled', 'Error']:
1524
- raise RuntimeError(f"Combine job {combine_task_id} failed with status: {combine_status}")
1525
- else:
1526
- elapsed_time = time.time() - combine_start_time
1527
- print(f"Combine job status: {combine_status} - Elapsed time: {elapsed_time:.1f}s", end='\r')
1528
- time.sleep(poll_interval)
1529
- except KeyboardInterrupt:
1530
- print(f"\nInterrupted! Combine job {combine_task_id} is still running in the background.")
1531
- raise
1532
- except Exception as e:
1533
- print(f"\nError tracking combine job: {e}")
1534
- raise
1535
-
1536
- # Download the resulting file
1537
- if download_path:
1538
- self.mass_stats.download_file(body["name"], bucket, download_file_name, download_path)
1539
- else:
1540
- path = f"{body['name']}/combinedOutput/{download_file_name}"
1541
- print(f"Combined file is available at {path}")
1542
-
1543
- return {"generation_task_id": task_id, "combine_task_id": combine_task_id}
1544
-
1545
-
1546
- # taskid = self.mass_stats.get_task_id(job_name, stage).get('task_id')
1547
- # trackinfo = self.mass_stats.track_job([taskid])
1548
- # bucket = trackinfo[taskid]['bucket']
1549
- # return self.mass_stats.download_file(job_name, bucket, file_name, output_path)
1550
-
1551
-
1552
- @require_api_key
1553
- def deploy_model(self, dataset: str, product:str, model_name:str, input_expression: str, model_training_job_name: str, uid: str, dates_iso8601: list):
1554
- script_content = self._generate_script(model_name, product, model_training_job_name, uid)
1555
- script_name = f"{product}.py"
1556
- self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
1557
- self._create_dataset(name = dataset, collection = "terrakio-datasets", products = [product], path = f"gs://terrakio-mass-requests/{uid}/{model_training_job_name}/inference_scripts", input = input_expression, dates_iso8601 = dates_iso8601, padding = 0)
1558
-
1559
- @require_api_key
1560
- def _generate_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
1561
- return textwrap.dedent(f'''
1562
- import logging
1563
- from io import BytesIO
1564
-
1565
- import numpy as np
1566
- import pandas as pd
1567
- import xarray as xr
1568
- from google.cloud import storage
1569
- from onnxruntime import InferenceSession
1570
-
1571
- logging.basicConfig(
1572
- level=logging.INFO
1573
- )
1574
-
1575
- def get_model():
1576
- logging.info("Loading model for {model_name}...")
1577
-
1578
- client = storage.Client()
1579
- bucket = client.get_bucket('terrakio-mass-requests')
1580
- blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')
1581
-
1582
- model = BytesIO()
1583
- blob.download_to_file(model)
1584
- model.seek(0)
1585
-
1586
- session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
1587
- return session
1588
-
1589
- def {product}(*bands, model):
1590
- logging.info("start preparing data")
1591
- print("the bands are ", bands)
1592
-
1593
- data_arrays = list(bands)
1594
-
1595
- print("the data arrays are ", [da.name for da in data_arrays])
1596
-
1597
- reference_array = data_arrays[0]
1598
- original_shape = reference_array.shape
1599
- logging.info(f"Original shape: {{original_shape}}")
1600
-
1601
- if 'time' in reference_array.dims:
1602
- time_coords = reference_array.coords['time']
1603
- if len(time_coords) == 1:
1604
- output_timestamp = time_coords[0]
1605
- else:
1606
- years = [pd.to_datetime(t).year for t in time_coords.values]
1607
- unique_years = set(years)
1608
-
1609
- if len(unique_years) == 1:
1610
- year = list(unique_years)[0]
1611
- output_timestamp = pd.Timestamp(f"{{year}}-01-01")
1612
- else:
1613
- latest_year = max(unique_years)
1614
- output_timestamp = pd.Timestamp(f"{{latest_year}}-01-01")
1615
- else:
1616
- output_timestamp = pd.Timestamp("1970-01-01")
1617
-
1618
- averaged_bands = []
1619
- for data_array in data_arrays:
1620
- if 'time' in data_array.dims:
1621
- averaged_band = np.mean(data_array.values, axis=0)
1622
- logging.info(f"Averaged band from {{data_array.shape}} to {{averaged_band.shape}}")
1623
- else:
1624
- averaged_band = data_array.values
1625
- logging.info(f"No time dimension, shape: {{averaged_band.shape}}")
1626
-
1627
- flattened_band = averaged_band.reshape(-1, 1)
1628
- averaged_bands.append(flattened_band)
1629
-
1630
- input_data = np.hstack(averaged_bands)
1631
-
1632
- logging.info(f"Final input shape: {{input_data.shape}}")
1633
-
1634
- output = model.run(None, {{"float_input": input_data.astype(np.float32)}})[0]
1635
-
1636
- logging.info(f"Model output shape: {{output.shape}}")
1637
-
1638
- if len(original_shape) >= 3:
1639
- spatial_shape = original_shape[1:]
1640
- else:
1641
- spatial_shape = original_shape
1642
-
1643
- output_reshaped = output.reshape(spatial_shape)
1644
-
1645
- output_with_time = np.expand_dims(output_reshaped, axis=0)
1646
-
1647
- if 'time' in reference_array.dims:
1648
- spatial_dims = [dim for dim in reference_array.dims if dim != 'time']
1649
- spatial_coords = {{dim: reference_array.coords[dim] for dim in spatial_dims if dim in reference_array.coords}}
1650
- else:
1651
- spatial_dims = list(reference_array.dims)
1652
- spatial_coords = dict(reference_array.coords)
1653
-
1654
- result = xr.DataArray(
1655
- data=output_with_time.astype(np.float32),
1656
- dims=['time'] + list(spatial_dims),
1657
- coords={
1658
- 'time': [output_timestamp.values],
1659
- 'y': spatial_coords['y'].values,
1660
- 'x': spatial_coords['x'].values
1661
- }
1662
- )
1663
- return result
1664
- ''').strip()
1665
-
1666
- @require_api_key
1667
- def _upload_script_to_bucket(self, script_content: str, script_name: str, model_training_job_name: str, uid: str):
1668
- """Upload the generated script to Google Cloud Storage"""
1669
-
1670
- client = storage.Client()
1671
- bucket = client.get_bucket('terrakio-mass-requests')
1672
- blob = bucket.blob(f'{uid}/{model_training_job_name}/inference_scripts/{script_name}')
1673
- blob.upload_from_string(script_content, content_type='text/plain')
1674
- logging.info(f"Script uploaded successfully to {uid}/{model_training_job_name}/inference_scripts/{script_name}")
1675
-
1676
-
1677
-
1678
-
1679
- @require_api_key
1680
- def download_file_to_path(self, job_name, stage, file_name, output_path):
1681
- if not self.mass_stats:
1682
- from terrakio_core.mass_stats import MassStats
1683
- if not self.url or not self.key:
1684
- raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
1685
- self.mass_stats = MassStats(
1686
- base_url=self.url,
1687
- api_key=self.key,
1688
- verify=self.verify,
1689
- timeout=self.timeout
1690
- )
1691
-
1692
- # fetch bucket info based on job name and stage
1693
-
1694
- taskid = self.mass_stats.get_task_id(job_name, stage).get('task_id')
1695
- trackinfo = self.mass_stats.track_job([taskid])
1696
- bucket = trackinfo[taskid]['bucket']
1697
- return self.mass_stats.download_file(job_name, bucket, file_name, output_path)
1698
-
1699
-
1700
- # Apply the @require_api_key decorator to ALL methods in BaseClient
1701
- # except only the absolute minimum that shouldn't require auth
1702
- def _apply_api_key_decorator():
1703
- # Only these methods can be used without API key
1704
- skip_auth_methods = [
1705
- '__init__', 'login', 'signup'
1706
- ]
1707
-
1708
- # Get all attributes of BaseClient
1709
- for attr_name in dir(BaseClient):
1710
- # Skip special methods and methods in skip list
1711
- if attr_name.startswith('__') and attr_name not in ['__enter__', '__exit__', '__aenter__', '__aexit__'] or attr_name in skip_auth_methods:
1712
- continue
1713
-
1714
- # Get the attribute
1715
- attr = getattr(BaseClient, attr_name)
1716
-
1717
- # Skip if not callable (not a method) or already decorated
1718
- if not callable(attr) or hasattr(attr, '_is_decorated'):
1719
- continue
1720
-
1721
- # Apply decorator to EVERY method not in skip list
1722
- setattr(BaseClient, attr_name, require_api_key(attr))
1723
- # Mark as decorated to avoid double decoration
1724
- getattr(BaseClient, attr_name)._is_decorated = True
1725
-
1726
- # Run the decorator application
1727
- _apply_api_key_decorator()
33
+ @abstractmethod
34
+ def _setup_session(self):
35
+ """Initialize the HTTP session - implemented by sync/async clients"""
36
+ pass