terrakio-core 0.3.3__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,594 @@
+ from typing import Dict, Any, Optional
+ import json
+ import gzip
+ import os
+ from pathlib import Path
+ from urllib.parse import urlparse
+ from ..helper.decorators import require_api_key
+ import aiohttp
+
+ class MassStats:
+     def __init__(self, client):
+         self._client = client
+
+     @require_api_key
+     async def upload_request(
+         self,
+         name: str,
+         size: int,
+         region: str,
+         output: str,
+         config: Dict[str, Any],
+         overwrite: bool = False,
+         skip_existing: bool = False,
+         location: Optional[str] = None,
+         force_loc: Optional[bool] = None,
+         server: Optional[str] = "dev-au.terrak.io",
+     ) -> Dict[str, Any]:
+         """
+         Upload a request to the mass stats server.
+
+         Args:
+             name: The name of the job
+             size: The number of requests in the job
+             region: The region of the job
+             output: The output format of the job
+             config: The config of the job
+             overwrite: Whether to overwrite an existing job with the same name
+             skip_existing: Whether to skip existing jobs
+             location: The location of the job
+             force_loc: Whether to force the location
+             server: The server to use
+
+         Returns:
+             API response as a dictionary
+
+         Raises:
+             APIError: If the API request fails
+         """
+         payload = {
+             "name": name,
+             "size": size,
+             "region": region,
+             "output": output,
+             "config": config,
+             "overwrite": overwrite,
+             "skip_existing": skip_existing,
+             "server": server
+         }
+         if location is not None:
+             payload["location"] = location
+         if force_loc is not None:
+             # The API expects booleans as lowercase strings; the location string
+             # itself is passed through unchanged, since paths are case sensitive
+             payload["force_loc"] = str(force_loc).lower()
+         return await self._client._terrakio_request("POST", "mass_stats/upload", json=payload)
+
+     @require_api_key
+     def start_job(self, id: str) -> Dict[str, Any]:
+         """
+         Start a mass stats job by task ID.
+
+         Args:
+             id: The ID of the task to start
+
+         Returns:
+             API response as a dictionary
+         """
+         return self._client._terrakio_request("POST", f"mass_stats/start/{id}")
+
+     @require_api_key
+     def get_task_id(self, name: str, stage: str, uid: Optional[str] = None) -> Dict[str, Any]:
+         """
+         Get the task ID for a mass stats job by name and stage (and optionally user ID).
+
+         Args:
+             name: The name of the job
+             stage: The stage of the job
+             uid: The user ID of the job owner
+
+         Returns:
+             API response as a dictionary
+
+         Raises:
+             APIError: If the API request fails
+         """
+         url = f"mass_stats/job_id?name={name}&stage={stage}"
+         if uid is not None:
+             url += f"&uid={uid}"
+         return self._client._terrakio_request("GET", url)
+
+     @require_api_key
+     async def track_job(self, ids: Optional[list] = None) -> Dict[str, Any]:
+         """
+         Track the status of one or more mass stats jobs.
+
+         Args:
+             ids: The IDs of the jobs to track; all jobs are tracked if omitted
+
+         Returns:
+             API response as a dictionary
+
+         Raises:
+             APIError: If the API request fails
+         """
+         data = {"ids": ids} if ids is not None else {}
+         return await self._client._terrakio_request("POST", "mass_stats/track", json=data)
+
+     @require_api_key
+     def get_history(self, limit: Optional[int] = 100) -> Dict[str, Any]:
+         """
+         Get the history of mass stats jobs.
+
+         Args:
+             limit: The number of jobs to return
+
+         Returns:
+             API response as a dictionary
+
+         Raises:
+             APIError: If the API request fails
+         """
+         params = {"limit": limit}
+         return self._client._terrakio_request("GET", "mass_stats/history", params=params)
+
+     @require_api_key
+     def start_post_processing(
+         self,
+         process_name: str,
+         data_name: str,
+         output: str,
+         consumer: str,
+         overwrite: bool = False
+     ) -> Dict[str, Any]:
+         """
+         Start post processing for a mass stats job.
+
+         Args:
+             process_name: The name of the post processing process
+             data_name: The name of the data to process
+             output: The output of the post processing
+             consumer: The consumer of the post processing
+             overwrite: Whether to overwrite the post processing
+
+         Returns:
+             API response as a dictionary
+
+         Raises:
+             APIError: If the API request fails
+         """
+         payload = {
+             "process_name": process_name,
+             "data_name": data_name,
+             "output": output,
+             "consumer": consumer,
+             "overwrite": overwrite
+         }
+         return self._client._terrakio_request("POST", "post_processing/start", json=payload)
+
+     @require_api_key
+     def download_results(
+         self,
+         file_name: str,
+         id: Optional[str] = None,
+         force_loc: Optional[bool] = None,
+         bucket: Optional[str] = None,
+         location: Optional[str] = None,
+         output: Optional[str] = None
+     ) -> Dict[str, Any]:
+         """
+         Download results from a mass stats job or arbitrary results if force_loc is True.
+
+         Args:
+             file_name: File name of resulting zip file (required)
+             id: Post processing id. Can't be used with 'force_loc'
+             force_loc: Download arbitrary results not connected to a mass-stats job id. Can't be used with 'id'
+             bucket: Bucket name (required if force_loc is True)
+             location: Path to folder in bucket (required if force_loc is True)
+             output: Output type (required if force_loc is True)
+
+         Returns:
+             API response as a dictionary
+
+         Raises:
+             APIError: If the API request fails
+             ValueError: If validation fails for parameter combinations
+         """
+         if id is not None and force_loc is True:
+             raise ValueError("Cannot use both 'id' and 'force_loc' parameters simultaneously")
+
+         if id is None and force_loc is not True:
+             raise ValueError("Either 'id' or 'force_loc=True' must be provided")
+
+         if force_loc is True:
+             if bucket is None or location is None or output is None:
+                 raise ValueError("When force_loc is True, 'bucket', 'location', and 'output' must be provided")
+
+         params = {"file_name": file_name}
+
+         if id is not None:
+             params["id"] = id
+         if force_loc is True:
+             params["force_loc"] = force_loc
+             params["bucket"] = bucket
+             params["location"] = location
+             params["output"] = output
+
+         return self._client._terrakio_request("GET", "mass_stats/download", params=params)
+
+     @require_api_key
+     async def upload_file(self, file_path: str, url: str, use_gzip: bool = False):
+         """
+         Helper method to upload a JSON file to a signed URL.
+
+         Args:
+             file_path: Path to the JSON file
+             url: Signed URL to upload to
+             use_gzip: Whether to compress the file with gzip
+         """
+         try:
+             with open(file_path, 'r') as file:
+                 json_data = json.load(file)
+         except FileNotFoundError:
+             raise FileNotFoundError(f"JSON file not found: {file_path}")
+         except json.JSONDecodeError as e:
+             raise ValueError(f"Invalid JSON in file {file_path}: {e}")
+
+         # Pass ignore_nan=True only when the json module in scope supports it
+         # (simplejson does; the standard library json.dumps does not)
+         if hasattr(json, 'dumps') and 'ignore_nan' in json.dumps.__code__.co_varnames:
+             dumps_kwargs = {'ignore_nan': True}
+         else:
+             dumps_kwargs = {}
+
+         if use_gzip:
+             body = gzip.compress(json.dumps(json_data, **dumps_kwargs).encode('utf-8'))
+             headers = {
+                 'Content-Type': 'application/json',
+                 'Content-Encoding': 'gzip'
+             }
+         else:
+             body = json.dumps(json_data, **dumps_kwargs).encode('utf-8')
+             headers = {
+                 'Content-Type': 'application/json'
+             }
+         response = await self._client._regular_request("PUT", url, data=body, headers=headers)
+         return response
+
+     @require_api_key
+     async def download_file(
+         self,
+         job_name: str,
+         bucket: str,
+         file_type: str,
+         output_path: str,
+         page_size: Optional[int] = None,
+     ) -> list:
+         """
+         Download files from a mass stats job by job name.
+
+         Args:
+             job_name: Name of the job
+             bucket: Bucket in which the job results are stored
+             file_type: Either 'raw' or 'processed'
+             output_path: Directory where the files should be saved
+             page_size: Number of files per page for download (required for 'raw')
+
+         Returns:
+             list: Paths to the downloaded files
+         """
+         import aiofiles
+
+         if file_type not in ("raw", "processed"):
+             raise ValueError("file_type must be 'raw' or 'processed'.")
+
+         if file_type == "raw" and page_size is None:
+             raise ValueError("page_size is required to define pagination size when downloading raw files.")
+
+         request_body = {
+             "job_name": job_name,
+             "bucket": bucket,
+             "file_type": file_type
+         }
+
+         output_dir = Path(output_path)
+         output_dir.mkdir(parents=True, exist_ok=True)
+         output_files = []
+
+         async def download_urls_batch(download_urls, session):
+             for url in download_urls:
+                 self._client.logger.info(f"Processing download URL: {url}")
+                 parsed = urlparse(url)
+                 path_parts = Path(parsed.path).parts
+                 try:
+                     # Mirror the bucket layout below 'data' (raw) or 'outputs' (processed)
+                     data_idx = path_parts.index("data") if file_type == "raw" else path_parts.index("outputs")
+                     subpath = Path(*path_parts[data_idx + 1:])
+                 except ValueError:
+                     subpath = Path(path_parts[-1])
+                 file_save_path = output_dir / subpath
+                 file_save_path.parent.mkdir(parents=True, exist_ok=True)
+                 self._client.logger.info(f"Downloading file to {file_save_path}")
+
+                 async with session.get(url) as resp:
+                     resp.raise_for_status()
+                     async with aiofiles.open(file_save_path, 'wb') as file:
+                         async for chunk in resp.content.iter_chunked(1048576):  # 1 MB
+                             if chunk:
+                                 await file.write(chunk)
+
+                 if not os.path.exists(file_save_path):
+                     raise Exception(f"File was not written to {file_save_path}")
+
+                 file_size = os.path.getsize(file_save_path)
+                 self._client.logger.info(f"File downloaded successfully to {file_save_path} (size: {file_size / (1024 * 1024):.4f} MB)")
+                 output_files.append(str(file_save_path))
+
+         try:
+             page = 1
+             total_files = None
+             downloaded_files = 0
+             async with aiohttp.ClientSession() as session:
+                 while True:
+                     params = {"page": page}
+                     if page_size is not None:
+                         params["page_size"] = page_size
+                     data = await self._client._terrakio_request("POST", "mass_stats/download_files", json=request_body, params=params)
+                     self._client.logger.info(f"Download endpoint response: {data}")
+                     download_urls = data.get('download_urls')
+                     if not download_urls:
+                         break
+                     await download_urls_batch(download_urls, session)
+                     if total_files is None:
+                         total_files = data.get('subdir_total_files')
+                     downloaded_files += len(download_urls)
+                     if total_files is not None and downloaded_files >= total_files:
+                         break
+                     if page_size is None or len(download_urls) < page_size:
+                         break  # Last page
+                     page += 1
+             return output_files
+         except Exception as e:
+             raise Exception(f"Error in download process: {e}") from e
+
+     def validate_request(self, request_json_path: str):
+         with open(request_json_path, 'r') as file:
+             request_data = json.load(file)
+         if not isinstance(request_data, list):
+             raise ValueError(f"Request JSON file {request_json_path} should contain a list of dictionaries")
+         for i, request in enumerate(request_data):
+             if not isinstance(request, dict):
+                 raise ValueError(f"Request {i} should be a dictionary")
+             required_keys = ["request", "group", "file"]
+             for key in required_keys:
+                 if key not in request:
+                     raise ValueError(f"Request {i} should contain {key}")
+             try:
+                 str(request["group"])
+             except ValueError:
+                 raise ValueError("Group must be a string or convertible to a string")
+             if not isinstance(request["request"], dict):
+                 raise ValueError("Request must be a dictionary")
+             if not isinstance(request["file"], (str, int, list)):
+                 raise ValueError("'file' must be a string, an int, or a list of strings")
+             # Only validate the first four requests (indices 0-3)
+             if i == 3:
+                 break
+
+     @require_api_key
+     async def execute_job(
+         self,
+         name: str,
+         region: str,
+         output: str,
+         config: Dict[str, Any],
+         request_json: str,
+         manifest_json: str,
+         overwrite: bool = False,
+         skip_existing: bool = False,
+         location: Optional[str] = None,
+         force_loc: Optional[bool] = None,
+         server: str = "dev-au.terrak.io"
+     ) -> Dict[str, Any]:
+         """
+         Execute a mass stats job.
+
+         Args:
+             name: The name of the job
+             region: The region of the job
+             output: The output of the job
+             config: The config of the job
+             request_json: Path to the request JSON file
+             manifest_json: Path to the manifest JSON file
+             overwrite: Whether to overwrite the job
+             skip_existing: Whether to skip existing jobs
+             location: The location of the job
+             force_loc: Whether to force the location
+             server: The server to use
+
+         Returns:
+             API response as a dictionary
+
+         Raises:
+             APIError: If the API request fails
+         """
+         try:
+             with open(request_json, 'r') as file:
+                 request_data = json.load(file)
+             if isinstance(request_data, list):
+                 size = len(request_data)
+             else:
+                 raise ValueError(f"Request JSON file {request_json} should contain a list of dictionaries")
+         except FileNotFoundError:
+             raise FileNotFoundError(f"Request JSON file not found: {request_json}")
+         except json.JSONDecodeError as e:
+             raise ValueError(f"Invalid JSON in file {request_json}: {e}")
+         upload_result = await self.upload_request(
+             name=name,
+             size=size,
+             region=region,
+             output=output,
+             config=config,
+             location=location,
+             force_loc=force_loc,
+             overwrite=overwrite,
+             server=server,
+             skip_existing=skip_existing
+         )
+         requests_url = upload_result.get('requests_url')
+         manifest_url = upload_result.get('manifest_url')
+         if not requests_url:
+             raise ValueError("No requests_url returned from server for request JSON upload")
+
+         try:
+             # Validate the request JSON format before uploading it
+             self.validate_request(request_json)
+             requests_response = await self.upload_file(request_json, requests_url, use_gzip=True)
+             if requests_response.status not in [200, 201, 204]:
+                 error_text = await requests_response.text()
+                 self._client.logger.error(f"Requests upload error: {error_text}")
+                 raise Exception(f"Failed to upload request JSON: {error_text}")
+         except Exception as e:
+             raise Exception(f"Error uploading request JSON file {request_json}: {e}")
+
+         if not manifest_url:
+             raise ValueError("No manifest_url returned from server for manifest JSON upload")
+
+         try:
+             manifest_response = await self.upload_file(manifest_json, manifest_url, use_gzip=False)
+             if manifest_response.status not in [200, 201, 204]:
+                 error_text = await manifest_response.text()
+                 self._client.logger.error(f"Manifest upload error: {error_text}")
+                 raise Exception(f"Failed to upload manifest JSON: {error_text}")
+         except Exception as e:
+             raise Exception(f"Error uploading manifest JSON file {manifest_json}: {e}")
+
+         start_job_task_id = await self.start_job(upload_result.get("id"))
+         return start_job_task_id
+
+     @require_api_key
+     def cancel_job(self, id: str) -> Dict[str, Any]:
+         """
+         Cancel a mass stats job by ID.
+
+         Args:
+             id: The ID of the mass stats job to cancel
+
+         Returns:
+             API response as a dictionary
+
+         Raises:
+             APIError: If the API request fails
+         """
+         return self._client._terrakio_request("POST", f"mass_stats/cancel/{id}")
+
+     @require_api_key
+     def cancel_all_jobs(self) -> Dict[str, Any]:
+         """
+         Cancel all mass stats jobs.
+
+         Returns:
+             API response as a dictionary
+
+         Raises:
+             APIError: If the API request fails
+         """
+         return self._client._terrakio_request("POST", "mass_stats/cancel")
+
+     @require_api_key
+     def random_sample(
+         self,
+         name: str,
+         config: dict,
+         aoi: dict,
+         samples: int,
+         crs: str,
+         tile_size: int,
+         res: float,
+         output: str,
+         region: str,
+         year_range: Optional[list[int]] = None,
+         overwrite: bool = False,
+         server: Optional[str] = None,
+         bucket: Optional[str] = None,
+     ) -> Dict[str, Any]:
+         """
+         Submit a random sample job.
+
+         Args:
+             name: The name of the job
+             config: The config of the job
+             aoi: The AOI of the job
+             samples: The number of samples to take
+             crs: The CRS of the job
+             tile_size: The tile size of the job
+             res: The resolution of the job
+             output: The output of the job
+             region: The region of the job
+             year_range: The year range of the job
+             overwrite: Whether to overwrite the job
+             server: The server to use
+             bucket: The bucket to use
+
+         Returns:
+             API response as a dictionary
+
+         Raises:
+             APIError: If the API request fails
+         """
+         payload = {
+             "name": name,
+             "config": config,
+             "aoi": aoi,
+             "samples": samples,
+             "crs": crs,
+             "tile_size": tile_size,
+             "res": res,
+             "output": output,
+             "region": region,
+             "overwrite": str(overwrite).lower(),
+         }
+         payload_mapping = {
+             "year_range": year_range,
+             "server": server,
+             "bucket": bucket
+         }
+         for key, value in payload_mapping.items():
+             if value is not None:
+                 payload[key] = value
+         return self._client._terrakio_request("POST", "random_sample", json=payload)
+
+     @require_api_key
+     def create_pyramids(self, name: str, levels: int, config: dict) -> Dict[str, Any]:
+         """
+         Create pyramids for a dataset.
+
+         Args:
+             name: The name of the job
+             levels: The levels of the pyramids
+             config: The config of the job
+
+         Returns:
+             API response as a dictionary
+         """
+         payload = {
+             "name": name,
+             "levels": levels,
+             "config": config
+         }
+         return self._client._terrakio_request("POST", "pyramids/create", json=payload)
+
+     @require_api_key
+     async def combine_tiles(self, data_name: str, overwrite: bool = True, output: str = "netcdf") -> Dict[str, Any]:
+         """
+         Combine tiles for a dataset.
+
+         Args:
+             data_name: The name of the dataset
+             overwrite: Whether to overwrite the dataset
+             output: The output of the dataset
+
+         Returns:
+             API response as a dictionary
+
+         Raises:
+             APIError: If the API request fails
+         """
+         payload = {
+             'data_name': data_name,
+             'output': output,
+             'overwrite': str(overwrite).lower()
+         }
+         return await self._client._terrakio_request("POST", "mass_stats/combine_tiles", json=payload)
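For orientation, the methods added in this diff chain together as upload/validate, start, track, download. Below is a minimal usage sketch of that flow. It assumes an async client object that wires up the `_terrakio_request`/`_regular_request` methods and exposes this class as `client.mass_stats`; the `AsyncClient` import, the `mass_stats` attribute, and every argument value are illustrative assumptions, not part of the published package.

import asyncio

async def main():
    # Hypothetical entry point; the real constructor and attribute names may differ.
    from terrakio_core import AsyncClient
    client = AsyncClient(api_key="YOUR_API_KEY")
    ms = client.mass_stats

    # execute_job reads requests.json (a list of request dicts), uploads the
    # request and manifest files to signed URLs, then starts the job.
    job = await ms.execute_job(
        name="ndvi-summary",
        region="aus",                    # illustrative region
        output="csv",
        config={"expr": "mean(NDVI)"},   # illustrative config
        request_json="requests.json",    # path to the request JSON file
        manifest_json="manifest.json",   # path to the manifest JSON file
    )

    # Poll job status, assuming the start response carries the job id.
    status = await ms.track_job([job.get("id")])
    print(status)

    # Pull the processed outputs locally once the job completes.
    files = await ms.download_file(
        job_name="ndvi-summary",
        bucket="my-bucket",
        file_type="processed",
        output_path="./results",
    )
    print(f"Downloaded {len(files)} files")

asyncio.run(main())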