terrakio-core 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
terrakio_core/__init__.py CHANGED
@@ -4,4 +4,4 @@ Terrakio Core
 Core components for Terrakio API clients.
 """
 
-__version__ = "0.3.0"
+__version__ = "0.3.2"
terrakio_core/client.py CHANGED
@@ -13,6 +13,7 @@ from shapely.geometry import shape, mapping
 from shapely.geometry.base import BaseGeometry as ShapelyGeometry
 from google.cloud import storage
 from .exceptions import APIError, ConfigurationError
+from .decorators import admin_only_params
 import logging
 import textwrap
 
@@ -129,6 +130,7 @@ class BaseClient:
             "resolution": resolution,
             **kwargs
         }
+        print("the payload is ", payload)
         request_url = f"{self.url}/geoquery"
         for attempt in range(retry + 1):
             try:
@@ -536,8 +538,8 @@ class BaseClient:
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.close()
 
-    # Mass Stats methods
-    def upload_mass_stats(self, name, size, bucket, output, location=None, **kwargs):
+    @admin_only_params('location', 'force_loc', 'server')
+    def execute_job(self, name, region, output, config, overwrite=False, skip_existing=False, request_json=None, manifest_json=None, location=None, force_loc=None, server="dev-au.terrak.io"):
         if not self.mass_stats:
             from terrakio_core.mass_stats import MassStats
             if not self.url or not self.key:
@@ -548,20 +550,8 @@ class BaseClient:
                 verify=self.verify,
                 timeout=self.timeout
             )
-        return self.mass_stats.upload_request(name, size, bucket, output, location, **kwargs)
+        return self.mass_stats.execute_job(name, region, output, config, overwrite, skip_existing, request_json, manifest_json, location, force_loc, server)
 
-    def start_mass_stats_job(self, task_id):
-        if not self.mass_stats:
-            from terrakio_core.mass_stats import MassStats
-            if not self.url or not self.key:
-                raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
-            self.mass_stats = MassStats(
-                base_url=self.url,
-                api_key=self.key,
-                verify=self.verify,
-                timeout=self.timeout
-            )
-        return self.mass_stats.start_job(task_id)
 
     def get_mass_stats_task_id(self, name, stage, uid=None):
         if not self.mass_stats:
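Note: a minimal sketch of how the new client-level execute_job wrapper might be called, assuming a client already configured with an API URL and key; the job name and file paths are hypothetical, and the parameter names come from the signature above:

    client.execute_job(
        name="ndvi_tiles",               # hypothetical job name
        region="aus",
        output="netcdf",
        config={},
        overwrite=True,
        request_json="requests.json",    # local JSON file: list of request dicts
        manifest_json="manifest.json",   # local JSON file: list of group names
    )

Non-admin callers omit location, force_loc and server; the admin_only_params decorator raises PermissionError if they pass any of them.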
@@ -576,7 +566,7 @@ class BaseClient:
             )
         return self.mass_stats.get_task_id(name, stage, uid)
 
-    def track_mass_stats_job(self, ids=None):
+    def track_mass_stats_job(self, ids: Optional[list] = None):
         if not self.mass_stats:
             from terrakio_core.mass_stats import MassStats
             if not self.url or not self.key:
@@ -1049,6 +1039,20 @@ class BaseClient:
             )
         return self.space_management.delete_data_in_path(path, region)
 
+    def start_mass_stats_job(self, task_id):
+        if not self.mass_stats:
+            from terrakio_core.mass_stats import MassStats
+            if not self.url or not self.key:
+                raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
+            self.mass_stats = MassStats(
+                base_url=self.url,
+                api_key=self.key,
+                verify=self.verify,
+                timeout=self.timeout
+            )
+        return self.mass_stats.start_job(task_id)
+
     def generate_ai_dataset(
         self,
         name: str,
@@ -1118,27 +1122,58 @@ class BaseClient:
             overwrite=True
         )["task_id"]
         print("the task id is ", task_id)
+
+        # Wait for job completion
+        import time
+
+        while True:
+            result = self.track_mass_stats_job(ids=[task_id])
+            status = result[task_id]['status']
+            print(f"Job status: {status}")
+
+            if status == "Completed":
+                break
+            elif status == "Error":
+                raise Exception(f"Job {task_id} encountered an error")
+
+            # Wait 30 seconds before checking again
+            time.sleep(30)
+
+        # after all the random sample jobs are done, we then start the mass stats job
         task_id = self.start_mass_stats_job(task_id)
-        print("the task id is ", task_id)
-        return task_id
+        # now we have the random sample
 
+        return task_id
 
-    def train_model(self, model_name: str, training_data: dict) -> dict:
+    def train_model(self, model_name: str, training_dataset: str, task_type: str, model_category: str, architecture: str, region: str, hyperparameters: dict = None) -> dict:
         """
         Train a model using the external model training API.
 
         Args:
             model_name (str): The name of the model to train.
-            training_data (dict): Dictionary containing training data parameters.
+            training_dataset (str): The training dataset identifier.
+            task_type (str): The type of ML task (e.g., regression, classification).
+            model_category (str): The category of model (e.g., random_forest).
+            architecture (str): The model architecture.
+            region (str): The region identifier.
+            hyperparameters (dict, optional): Additional hyperparameters for training.
 
         Returns:
             dict: The response from the model training API.
         """
-        endpoint = "https://modeltraining-573248941006.australia-southeast1.run.app/train_model"
         payload = {
             "model_name": model_name,
-            "training_data": training_data
+            "training_dataset": training_dataset,
+            "task_type": task_type,
+            "model_category": model_category,
+            "architecture": architecture,
+            "region": region,
+            "hyperparameters": hyperparameters
         }
+        endpoint = f"{self.url.rstrip('/')}/train_model"
+        print("the payload is ", payload)
         try:
             response = self.session.post(endpoint, json=payload, timeout=self.timeout, verify=self.verify)
             if not response.ok:
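Note: a sketch of a call against the new train_model signature; all values are illustrative, and the endpoint is now derived from the client's configured URL rather than hard-coded:

    response = client.train_model(
        model_name="canopy_height_rf",        # hypothetical
        training_dataset="canopy_training_v1",
        task_type="regression",
        model_category="random_forest",
        architecture="random_forest",
        region="aus",
        hyperparameters={"n_estimators": 100},
    )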
@@ -1155,35 +1190,163 @@ class BaseClient:
         except requests.RequestException as e:
             raise APIError(f"Model training request failed: {str(e)}")
 
-    def deploy_model(self, dataset: str, product:str, model_name:str, input_expression: str, model_training_job_name: str, uid: str, dates_iso8601: list):
-        # we have the dataset and we have the product, and we have the model name, we need to create a new json file and add that to the dataset as our virtual dataset
-        # upload the script to the bucket, the script should be able to download the model and do the inferencing
-        # we need to upload the the json to the to the dataset as our virtual dataset
-        # then we do nothing and wait for the user to make the request call to the explorer
-        # we should have a uniform script for the random forest deployment
-        # create a script for each model
-        # upload script to google bucket,
-        #
+    # Mass Stats methods
+    def combine_tiles(self,
+                      data_name: str,
+                      usezarr: bool,
+                      overwrite: bool,
+                      output: str) -> dict:
+
+        if not self.mass_stats:
+            from terrakio_core.mass_stats import MassStats
+            if not self.url or not self.key:
+                raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
+            self.mass_stats = MassStats(
+                base_url=self.url,
+                api_key=self.key,
+                verify=self.verify,
+                timeout=self.timeout
+            )
+        return self.mass_stats.combine_tiles(data_name, usezarr, overwrite, output)
+
+    def create_dataset_file(
+        self,
+        name: str,
+        aoi: str,
+        expression: str,
+        output: str,
+        tile_size: float = 128.0,
+        crs: str = "epsg:4326",
+        res: float = 0.0001,
+        region: str = "aus",
+        to_crs: str = "epsg:4326",
+        overwrite: bool = True,
+        skip_existing: bool = False,
+        non_interactive: bool = True,
+        usezarr: bool = False,
+        poll_interval: int = 30  # seconds between job status checks
+    ) -> dict:
+
+        from terrakio_core.generation.tiles import tiles
+        import tempfile
+        import time
+
+        body, reqs, groups = tiles(
+            name=name,
+            aoi=aoi,
+            expression=expression,
+            output=output,
+            tile_size=tile_size,
+            crs=crs,
+            res=res,
+            region=region,
+            to_crs=to_crs,
+            fully_cover=True,
+            overwrite=overwrite,
+            skip_existing=skip_existing,
+            non_interactive=non_interactive
+        )
+
+        # Create temp json files before upload
+        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tempreq:
+            tempreq.write(reqs)
+            tempreqname = tempreq.name
+        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tempmanifest:
+            tempmanifest.write(groups)
+            tempmanifestname = tempmanifest.name
+
+        if not self.mass_stats:
+            from terrakio_core.mass_stats import MassStats
+            if not self.url or not self.key:
+                raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
+            self.mass_stats = MassStats(
+                base_url=self.url,
+                api_key=self.key,
+                verify=self.verify,
+                timeout=self.timeout
+            )
 
+        task_id = self.mass_stats.execute_job(
+            name=body["name"],
+            region=body["region"],
+            output=body["output"],
+            config={},
+            overwrite=body["overwrite"],
+            skip_existing=body["skip_existing"],
+            request_json=tempreqname,
+            manifest_json=tempmanifestname,
+        )
+
+        ### Start combining tiles when generation-tiles job is done
+        start_time = time.time()
+        status = None
+
+        while True:
+            try:
+                taskid = task_id['task_id']
+                trackinfo = self.mass_stats.track_job([taskid])
+                status = trackinfo[taskid]['status']
+
+                # Check completion states
+                if status == 'Completed':
+                    print('Tiles generated successfully!')
+                    break
+                elif status in ['Failed', 'Cancelled', 'Error']:
+                    raise RuntimeError(f"Job {taskid} failed with status: {status}")
+                else:
+                    # Job is still running
+                    elapsed_time = time.time() - start_time
+                    print(f"Job status: {status} - Elapsed time: {elapsed_time:.1f}s", end='\r')
+
+                    # Sleep before next check
+                    time.sleep(poll_interval)
+
+            except KeyboardInterrupt:
+                print(f"\nInterrupted! Job {taskid} is still running in the background.")
+                raise
+            except Exception as e:
+                print(f"\nError tracking job: {e}")
+                raise
+
+        # Clean up temporary files
+        import os
+        os.unlink(tempreqname)
+        os.unlink(tempmanifestname)
+
+        # Start combining tiles
+        if not self.mass_stats:
+            from terrakio_core.mass_stats import MassStats
+            if not self.url or not self.key:
+                raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
+            self.mass_stats = MassStats(
+                base_url=self.url,
+                api_key=self.key,
+                verify=self.verify,
+                timeout=self.timeout
+            )
+
+        return self.mass_stats.combine_tiles(body["name"], usezarr, body["overwrite"], body["output"])
+
+    def deploy_model(self, dataset: str, product: str, model_name: str, input_expression: str, model_training_job_name: str, uid: str, dates_iso8601: list):
         script_content = self._generate_script(model_name, product, model_training_job_name, uid)
-        # self.create_dataset(collection = "terrakio-datasets", input = input, )
-        # we have the script, we need to upload it to the bucket
         script_name = f"{product}.py"
-        print("the script content is ", script_content)
-        print("the script name is ", script_name)
         self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
-        # after uploading the script, we need to create a new virtual dataset
         self._create_dataset(name = dataset, collection = "terrakio-datasets", products = [product], path = f"gs://terrakio-mass-requests/{uid}/{model_training_job_name}/inference_scripts", input = input_expression, dates_iso8601 = dates_iso8601, padding = 0)
 
     def _generate_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
         return textwrap.dedent(f'''
             import logging
             from io import BytesIO
-            from google.cloud import storage
-            from onnxruntime import InferenceSession
+
             import numpy as np
+            import pandas as pd
             import xarray as xr
-            import datetime
+            from google.cloud import storage
+            from onnxruntime import InferenceSession
 
             logging.basicConfig(
                 level=logging.INFO
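Note: a sketch of driving the new create_dataset_file pipeline end to end (generate tiles, poll the job, combine tiles); the client is assumed to be configured, and the AOI path and job name are hypothetical:

    result = client.create_dataset_file(
        name="sentinel_red_2024",    # hypothetical job name
        aoi="aoi.geojson",           # any vector file geopandas can read
        expression="red=S2v2#(year,median).red@(year=2024) \n red",
        output="netcdf",
        tile_size=512,
        crs="epsg:3577",
        res=10,
        region="aus",
        poll_interval=30,            # seconds between status checks
    )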
@@ -1191,54 +1354,122 @@ class BaseClient:
 
             def get_model():
                 logging.info("Loading model for {model_name}...")
 
                 client = storage.Client()
                 bucket = client.get_bucket('terrakio-mass-requests')
                 blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')
 
                 model = BytesIO()
                 blob.download_to_file(model)
                 model.seek(0)
 
                 session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
                 return session
 
             def {product}(*bands, model):
                 logging.info("start preparing data")
+                print("the bands are ", bands)
 
-                original_shape = bands[0].shape
-                logging.info(f"Original shape: {{original_shape}}")
+                data_arrays = list(bands)
 
-                transformed_bands = []
-                for band in bands:
-                    transformed_band = band.values.reshape(-1,1)
-                    transformed_bands.append(transformed_band)
+                print("the data arrays are ", [da.name for da in data_arrays])
 
-                input_data = np.hstack(transformed_bands)
+                reference_array = data_arrays[0]
+                original_shape = reference_array.shape
+                logging.info(f"Original shape: {{original_shape}}")
 
+                if 'time' in reference_array.dims:
+                    time_coords = reference_array.coords['time']
+                    if len(time_coords) == 1:
+                        output_timestamp = time_coords[0]
+                    else:
+                        years = [pd.to_datetime(t).year for t in time_coords.values]
+                        unique_years = set(years)
+
+                        if len(unique_years) == 1:
+                            year = list(unique_years)[0]
+                            output_timestamp = pd.Timestamp(f"{{year}}-01-01")
+                        else:
+                            latest_year = max(unique_years)
+                            output_timestamp = pd.Timestamp(f"{{latest_year}}-01-01")
+                else:
+                    output_timestamp = pd.Timestamp("1970-01-01")
+
+                averaged_bands = []
+                for data_array in data_arrays:
+                    if 'time' in data_array.dims:
+                        averaged_band = np.mean(data_array.values, axis=0)
+                        logging.info(f"Averaged band from {{data_array.shape}} to {{averaged_band.shape}}")
+                    else:
+                        averaged_band = data_array.values
+                        logging.info(f"No time dimension, shape: {{averaged_band.shape}}")
+
+                    flattened_band = averaged_band.reshape(-1, 1)
+                    averaged_bands.append(flattened_band)
+
+                input_data = np.hstack(averaged_bands)
+
                 logging.info(f"Final input shape: {{input_data.shape}}")
 
                 output = model.run(None, {{"float_input": input_data.astype(np.float32)}})[0]
 
                 logging.info(f"Model output shape: {{output.shape}}")
 
-                output_reshaped = output.reshape(original_shape)
+                if len(original_shape) >= 3:
+                    spatial_shape = original_shape[1:]
+                else:
+                    spatial_shape = original_shape
+
+                output_reshaped = output.reshape(spatial_shape)
+
+                output_with_time = np.expand_dims(output_reshaped, axis=0)
+
+                if 'time' in reference_array.dims:
+                    spatial_dims = [dim for dim in reference_array.dims if dim != 'time']
+                    spatial_coords = {{dim: reference_array.coords[dim] for dim in spatial_dims if dim in reference_array.coords}}
+                else:
+                    spatial_dims = list(reference_array.dims)
+                    spatial_coords = dict(reference_array.coords)
+
                 result = xr.DataArray(
-                    data=output_reshaped,
-                    dims=bands[0].dims,
-                    coords=bands[0].coords
+                    data=output_with_time.astype(np.float32),
+                    dims=['time'] + list(spatial_dims),
+                    coords={{
+                        'time': [output_timestamp.values],
+                        'y': spatial_coords['y'].values,
+                        'x': spatial_coords['x'].values
+                    }}
                 )
                 return result
             ''').strip()
 
     def _upload_script_to_bucket(self, script_content: str, script_name: str, model_training_job_name: str, uid: str):
         """Upload the generated script to Google Cloud Storage"""
 
         client = storage.Client()
         bucket = client.get_bucket('terrakio-mass-requests')
         blob = bucket.blob(f'{uid}/{model_training_job_name}/inference_scripts/{script_name}')
-        # the first layer is the uid, the second layer is the model training job name
         blob.upload_from_string(script_content, content_type='text/plain')
         logging.info(f"Script uploaded successfully to {uid}/{model_training_job_name}/inference_scripts/{script_name}")
 
+    def download_file_to_path(self, job_name, stage, file_name, output_path):
+        if not self.mass_stats:
+            from terrakio_core.mass_stats import MassStats
+            if not self.url or not self.key:
+                raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
+            self.mass_stats = MassStats(
+                base_url=self.url,
+                api_key=self.key,
+                verify=self.verify,
+                timeout=self.timeout
+            )
+
+        # fetch bucket info based on job name and stage
+        taskid = self.mass_stats.get_task_id(job_name, stage).get('task_id')
+        trackinfo = self.mass_stats.track_job([taskid])
+        bucket = trackinfo[taskid]['bucket']
+        return self.mass_stats.download_file(job_name, bucket, file_name, output_path)
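Note: a sketch of fetching a result file with the new download_file_to_path helper; the job name, stage and file name are hypothetical — the method resolves the bucket itself from the task's tracking info:

    path = client.download_file_to_path(
        job_name="sentinel_red_2024",
        stage="processing",                    # hypothetical stage label
        file_name="sentinel_red_2024.nc",
        output_path="./outputs/sentinel_red_2024.nc",
    )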
terrakio_core/dataset_management.py CHANGED
@@ -83,10 +83,63 @@ class DatasetManagement:
         except requests.RequestException as e:
             raise APIError(f"Request failed: {str(e)}")
 
+    # def create_dataset(self, name: str, collection: str = "terrakio-datasets", **kwargs) -> Dict[str, Any]:
+    #     """
+    #     Create a new dataset.
+
+    #     Args:
+    #         name: Name of the dataset (required)
+    #         collection: Dataset collection (default: 'terrakio-datasets')
+    #         **kwargs: Additional dataset parameters including:
+    #             - products: List of products
+    #             - dates_iso8601: List of dates
+    #             - bucket: Storage bucket
+    #             - path: Storage path
+    #             - data_type: Data type
+    #             - no_data: No data value
+    #             - l_max: Maximum level
+    #             - y_size: Y size
+    #             - x_size: X size
+    #             - proj4: Projection string
+    #             - abstract: Dataset abstract
+    #             - geotransform: Geotransform parameters
+
+    #     Returns:
+    #         Created dataset information
+
+    #     Raises:
+    #         APIError: If the API request fails
+    #     """
+    #     endpoint = f"{self.api_url}/datasets"
+    #     params = {"collection": collection}
+    #     # Create payload with required name parameter
+    #     payload = {"name": name}
+
+    #     # Add optional parameters if provided
+    #     for param in ["products", "dates_iso8601", "bucket", "path", "data_type",
+    #                   "no_data", "l_max", "y_size", "x_size", "proj4", "abstract", "geotransform", "input"]:
+    #         if param in kwargs:
+    #             payload[param] = kwargs[param]
+
+    #     try:
+    #         response = self.session.post(
+    #             endpoint,
+    #             params=params,
+    #             json=payload,
+    #             timeout=self.timeout,
+    #             verify=self.verify
+    #         )
+
+    #         if not response.ok:
+    #             raise APIError(f"API request failed: {response.status_code} {response.reason}")
+    #         return response.json()
+    #     except requests.RequestException as e:
+    #         raise APIError(f"Request failed: {str(e)}")
+
     def create_dataset(self, name: str, collection: str = "terrakio-datasets", **kwargs) -> Dict[str, Any]:
         """
         Create a new dataset.
 
         Args:
             name: Name of the dataset (required)
             collection: Dataset collection (default: 'terrakio-datasets')
@@ -103,24 +156,23 @@ class DatasetManagement:
             - proj4: Projection string
             - abstract: Dataset abstract
             - geotransform: Geotransform parameters
+            - padding: Padding value
 
         Returns:
             Created dataset information
 
         Raises:
             APIError: If the API request fails
         """
         endpoint = f"{self.api_url}/datasets"
         params = {"collection": collection}
-        # Create payload with required name parameter
         payload = {"name": name}
-
-        # Add optional parameters if provided
-        for param in ["products", "dates_iso8601", "bucket", "path", "data_type",
-                      "no_data", "l_max", "y_size", "x_size", "proj4", "abstract", "geotransform", "input"]:
+
+        for param in ["products", "dates_iso8601", "bucket", "path", "data_type",
+                      "no_data", "l_max", "y_size", "x_size", "proj4", "abstract", "geotransform", "input", "padding"]:
             if param in kwargs:
                 payload[param] = kwargs[param]
 
         try:
             response = self.session.post(
                 endpoint,
@@ -129,7 +181,7 @@ class DatasetManagement:
                 timeout=self.timeout,
                 verify=self.verify
             )
 
             if not response.ok:
                 raise APIError(f"API request failed: {response.status_code} {response.reason}")
             return response.json()
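Note: the new padding key is passed through create_dataset like any other optional parameter; a sketch with illustrative values (dm stands in for a DatasetManagement instance):

    dm.create_dataset(
        name="my_virtual_dataset",       # hypothetical
        collection="terrakio-datasets",
        products=["canopy_height"],
        padding=0,
    )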
terrakio_core/decorators.py ADDED
@@ -0,0 +1,18 @@
+# terrakio_core/decorators.py
+def admin_only_params(*restricted_params):
+    """
+    Decorator factory for restricting method parameters to admin users only.
+    """
+    def decorator(func):
+        def wrapper(self, *args, **kwargs):
+            if hasattr(self, '_is_admin') and self._is_admin:
+                return func(self, *args, **kwargs)
+
+            admin_params_used = set(kwargs.keys()) & set(restricted_params)
+            if admin_params_used:
+                raise PermissionError(f"Parameters {admin_params_used} are only available to admin users")
+
+            filtered_kwargs = {k: v for k, v in kwargs.items() if k not in restricted_params}
+            return func(self, *args, **filtered_kwargs)
+        return wrapper
+    return decorator
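Note: a small self-contained sketch of how the decorator behaves; the Demo class is a stand-in, not part of the package. Instances with _is_admin set pass all arguments through; everyone else gets a PermissionError on restricted keywords:

    from terrakio_core.decorators import admin_only_params

    class Demo:
        _is_admin = False

        @admin_only_params('server')
        def run(self, name, server=None):
            return name, server

    d = Demo()
    d.run("job")                 # fine: returns ("job", None)
    d.run("job", server="dev")   # raises PermissionError for non-admins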
terrakio_core/generation/tiles.py ADDED
@@ -0,0 +1,95 @@
+### implementing generation-tiles in python api
+### function should just generate the json file for mass_stats to pick up.
+
+import geopandas as gpd
+import shapely.geometry
+import json
+from rich import print
+
+def escape_newline(string):
+    if isinstance(string, list):
+        return [s.replace('\\n', '\n') for s in string]
+    else:
+        return string.replace('\\n', '\n')
+
+def get_bounds(aoi, crs, to_crs=None):
+    aoi: gpd.GeoDataFrame = gpd.read_file(aoi)
+    aoi = aoi.set_crs(crs, allow_override=True)
+    if to_crs:
+        aoi = aoi.to_crs(to_crs)
+    bounds = aoi.geometry[0].bounds
+    return *bounds, aoi
+
+def tile_generator(x_min, y_min, x_max, y_max, aoi, crs, res, tile_size, expression, output, fully_cover=True):
+    i_max = int((x_max-x_min)/(tile_size*res))
+    j_max = int((y_max-y_min)/(tile_size*res))
+    if fully_cover:
+        i_max += 1
+        j_max += 1
+    for j in range(0, int(j_max)):
+        for i in range(0, int(i_max)):
+            #print(f"Processing tile {i} {j}")
+            x = x_min + i*(tile_size*res)
+            y = y_max - j*(tile_size*res)
+            bbox = shapely.geometry.box(x, y-(tile_size*res), x + (tile_size*res), y)
+            if not aoi.geometry[0].intersects(bbox):
+                continue
+            feat = {"type": "Feature", "geometry": bbox.__geo_interface__}
+            data = {
+                "feature": feat,
+                "in_crs": crs,
+                "out_crs": crs,
+                "resolution": res,
+                "expr": expression,
+                "output": output,
+            }
+            yield data, i, j
+
+
+def tiles(
+    name: str,
+    aoi: str,
+    expression: str = "red=S2v2#(year,median).red@(year =2024) \n red",
+    output: str = "netcdf",
+    tile_size: float = 512,
+    crs: str = "epsg:3577",
+    res: float = 10,
+    region: str = "eu",
+    to_crs: str = None,
+    fully_cover: bool = True,
+    overwrite: bool = False,
+    skip_existing: bool = False,
+    non_interactive: bool = False,
+):
+
+    # Create requests for each tile
+    reqs = []
+    x_min, y_min, x_max, y_max, aoi = get_bounds(aoi, crs, to_crs)
+    #print(f"Bounds: {x_min}, {y_min}, {x_max}, {y_max}")
+
+    if to_crs is None:
+        to_crs = crs
+    for tile_req, i, j in tile_generator(x_min, y_min, x_max, y_max, aoi, to_crs, res, tile_size, expression, output, fully_cover):
+        req_name = f"{name}_{i:02d}_{j:02d}"
+        reqs.append({"group": "tiles", "file": req_name, "request": tile_req})
+
+    #print(f"Generated {len(reqs)} tile requests.")
+
+    count = len(reqs)
+    groups = list(set(dic["group"] for dic in reqs))
+
+    body = {
+        "name": name,
+        "output": output,
+        "region": region,
+        "size": count,
+        "overwrite": overwrite,
+        "non_interactive": non_interactive,
+        "skip_existing": skip_existing,
+    }
+    request_json = json.dumps(reqs)
+    manifest_json = json.dumps(groups)
+
+    return body, request_json, manifest_json
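Note: the tiles helper can also be used on its own; a sketch assuming a local AOI file with a single polygon (the path and name are hypothetical):

    from terrakio_core.generation.tiles import tiles

    body, request_json, manifest_json = tiles(
        name="demo",
        aoi="aoi.geojson",     # must contain one geometry readable by geopandas
        expression="red=S2v2#(year,median).red@(year=2024) \n red",
        output="netcdf",
        tile_size=512,
        crs="epsg:3577",
        res=10,
        region="aus",
    )
    print(body["size"], "tile requests generated")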
terrakio_core/mass_stats.py CHANGED
@@ -1,5 +1,8 @@
 import requests
-from typing import Optional, Dict, Any
+from typing import Optional, Dict, Any, List
+import json
+import json as json_lib
+import gzip
 
 class MassStats:
     def __init__(self, base_url: str, api_key: str, verify: bool = True, timeout: int = 60):
@@ -12,61 +15,272 @@ class MassStats:
             'x-api-key': self.api_key
         })
 
+    def _upload_file(self, file_path: str, url: str, use_gzip: bool = False):
+        """
+        Helper method to upload a JSON file to a signed URL.
+
+        Args:
+            file_path: Path to the JSON file
+            url: Signed URL to upload to
+            use_gzip: Whether to compress the file with gzip
+        """
+        try:
+            with open(file_path, 'r') as file:
+                json_data = json_lib.load(file)
+        except FileNotFoundError:
+            raise FileNotFoundError(f"JSON file not found: {file_path}")
+        except json.JSONDecodeError as e:
+            raise ValueError(f"Invalid JSON in file {file_path}: {e}")
+
+        # Check if using simplejson and support ignore_nan
+        if hasattr(json_lib, 'dumps') and 'ignore_nan' in json_lib.dumps.__code__.co_varnames:
+            dumps_kwargs = {'ignore_nan': True}
+        else:
+            dumps_kwargs = {}
+
+        if use_gzip:
+            # Serialize and compress the JSON data
+            body = gzip.compress(json_lib.dumps(json_data, **dumps_kwargs).encode('utf-8'))
+            headers = {
+                'Content-Type': 'application/json',
+                'Content-Encoding': 'gzip'
+            }
+        else:
+            body = json_lib.dumps(json_data, **dumps_kwargs).encode('utf-8')
+            headers = {
+                'Content-Type': 'application/json'
+            }
+
+        # Make the PUT request to the signed URL
+        response = requests.put(
+            url,
+            data=body,
+            headers=headers
+        )
+
+        return response
+
+    def download_file(self, job_name: str, bucket: str, file_name: str, output_path: str) -> str:
+        """
+        Download a file from mass_stats using job name and file name.
+
+        Args:
+            job_name: Name of the job
+            bucket: Storage bucket the job wrote to
+            file_name: Name of the file to download
+            output_path: Path where the file should be saved
+
+        Returns:
+            str: Path to the downloaded file
+        """
+        import os
+        from pathlib import Path
+
+        endpoint_url = f"{self.base_url}/mass_stats/download_files"
+        request_body = {
+            "job_name": job_name,
+            "bucket": bucket,
+            "file_name": file_name
+        }
+
+        try:
+            # Get signed URL
+            response = self.session.post(
+                endpoint_url,
+                json=request_body,
+                verify=self.verify,
+                timeout=self.timeout
+            )
+            signed_url = response.json().get('download_url')
+            if not signed_url:
+                raise Exception("No download URL received from server")
+            print("Generated signed URL for download")
+
+            # Create output directory if it doesn't exist
+            output_dir = Path(output_path).parent
+            output_dir.mkdir(parents=True, exist_ok=True)
+
+            # Download the file using the signed URL
+            download_response = self.session.get(
+                signed_url,
+                verify=self.verify,
+                timeout=self.timeout,
+                stream=True  # Stream for large files
+            )
+            download_response.raise_for_status()
+
+            # Check if file exists in the response (content-length header)
+            content_length = download_response.headers.get('content-length')
+            if content_length and int(content_length) == 0:
+                raise Exception("File appears to be empty")
+
+            # Write the file
+            with open(output_path, 'wb') as file:
+                for chunk in download_response.iter_content(chunk_size=8192):
+                    if chunk:
+                        file.write(chunk)
+
+            # Verify file was written
+            if not os.path.exists(output_path):
+                raise Exception(f"File was not written to {output_path}")
+
+            file_size = os.path.getsize(output_path)
+            print(f"File downloaded successfully to {output_path} (size: {file_size / (1024 * 1024):.4f} mb)")
+
+            return output_path
+
+        except requests.exceptions.RequestException as e:
+            if hasattr(e, 'response') and e.response is not None:
+                error_detail = e.response.text
+                raise Exception(f"Error getting signed URL: {e}. Details: {error_detail}")
+            raise Exception(f"Error in download process: {e}")
+        except IOError as e:
+            raise Exception(f"Error writing file to {output_path}: {e}")
+        except Exception:
+            # Clean up partial file if it exists
+            if os.path.exists(output_path):
+                try:
+                    os.remove(output_path)
+                except OSError:
+                    pass
+            raise
+
     def upload_request(
         self,
         name: str,
         size: int,
-        bucket: str,
+        region: List[str],
         output: str,
+        config: Dict[str, Any],
         location: Optional[str] = None,
-        force_loc: bool = False,
-        config: Optional[Dict[str, Any]] = None,
+        force_loc: Optional[bool] = None,
         overwrite: bool = False,
         server: Optional[str] = None,
-        skip_existing: bool = False
+        skip_existing: bool = False,
     ) -> Dict[str, Any]:
         """
         Initiate a mass stats upload job.
 
         Args:
             name: Name of the job
-            size: Size of the data
-            bucket: Storage bucket
-            output: Output path or identifier
+            size: Size of the job
+            region: Region to run job [aus, eu, us]
+            output: Output type
+            config: Configuration dictionary
             location: (Optional) Location for the upload
             force_loc: Force location usage
-            config: Optional configuration dictionary
             overwrite: Overwrite existing data
             server: Optional server
             skip_existing: Skip existing files
         """
         url = f"{self.base_url}/mass_stats/upload"
+
         data = {
             "name": name,
             "size": size,
-            "bucket": bucket,
+            "region": region,
             "output": output,
-            "force_loc": force_loc,
+            "config": config,
             "overwrite": overwrite,
             "skip_existing": skip_existing
         }
+
         if location is not None:
             data["location"] = location
-        if config is not None:
-            data["config"] = config
+        if force_loc is not None:
+            data["force_loc"] = force_loc
         if server is not None:
             data["server"] = server
-        response = self.session.post(url, json=data, verify=self.verify, timeout=self.timeout)
-        print("the response is ", response.text)
-        # response.raise_for_status()
+        response = self.session.post(
+            url,
+            json=data,
+            verify=self.verify,
+            timeout=self.timeout
+        )
         return response.json()
 
+    def execute_job(
+        self,
+        name: str,
+        region: str,
+        output: str,
+        config: Dict[str, Any],
+        overwrite: bool = False,
+        skip_existing: bool = False,
+        request_json: Optional[str] = None,
+        manifest_json: Optional[str] = None,
+        location: Optional[str] = None,
+        force_loc: Optional[bool] = None,
+        server: Optional[str] = None
+    ) -> Dict[str, Any]:
+        # Step 1: Calculate size from request JSON file if provided
+        size = 0
+        if request_json is not None:
+            try:
+                with open(request_json, 'r') as file:
+                    request_data = json_lib.load(file)
+
+                if isinstance(request_data, list):
+                    size = len(request_data)
+                else:
+                    raise ValueError(f"Request JSON file {request_json} should contain a list of dictionaries")
+
+            except FileNotFoundError:
+                raise FileNotFoundError(f"Request JSON file not found: {request_json}")
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Invalid JSON in request file {request_json}: {e}")
+
+        # Step 2: Create the upload job and get signed URLs
+        upload_result = self.upload_request(name, size, region, output, config, location, force_loc, overwrite, server, skip_existing)
+
+        # Step 3: Upload JSON files if provided
+        if request_json is not None or manifest_json is not None:
+            requests_url = upload_result.get('requests_url')
+            manifest_url = upload_result.get('manifest_url')
+
+            if request_json is not None:
+                if not requests_url:
+                    raise ValueError("No requests_url returned from server for request JSON upload")
+
+                try:
+                    requests_response = self._upload_file(request_json, requests_url, use_gzip=True)
+                    if requests_response.status_code not in [200, 201, 204]:
+                        print(f"Requests upload error: {requests_response.text}")
+                        raise Exception(f"Failed to upload request JSON: {requests_response.text}")
+                except Exception as e:
+                    raise Exception(f"Error uploading request JSON file {request_json}: {e}")
+
+            if manifest_json is not None:
+                if not manifest_url:
+                    raise ValueError("No manifest_url returned from server for manifest JSON upload")
+
+                try:
+                    manifest_response = self._upload_file(manifest_json, manifest_url, use_gzip=False)
+                    if manifest_response.status_code not in [200, 201, 204]:
+                        print(f"Manifest upload error: {manifest_response.text}")
+                        raise Exception(f"Failed to upload manifest JSON: {manifest_response.text}")
+                except Exception as e:
+                    raise Exception(f"Error uploading manifest JSON file {manifest_json}: {e}")
+
+        start_job_task_id = self.start_job(upload_result.get("id"))
+        return start_job_task_id
+
     def start_job(self, task_id: str) -> Dict[str, Any]:
         """
         Start a mass stats job by task ID.
         """
         url = f"{self.base_url}/mass_stats/start/{task_id}"
-        print("the self session header is ", self.session.headers)
         response = self.session.post(url, verify=self.verify, timeout=self.timeout)
         response.raise_for_status()
         return response.json()
@@ -79,7 +293,7 @@ class MassStats:
         if uid is not None:
             url += f"&uid={uid}"
         response = self.session.get(url, verify=self.verify, timeout=self.timeout)
-        print("response text is ", response.text)
+        #print("response text is ", response.text)
         return response.json()
 
     def track_job(self, ids: Optional[list] = None) -> Dict[str, Any]:
@@ -259,4 +473,32 @@ class MassStats:
         print("Response text:", response.text)
         # response.raise_for_status()
         return response.json()
+
+    ### Adding the wrapper function to call endpoint /mass_stats/combine_tiles
+    def combine_tiles(
+        self,
+        data_name: str,
+        usezarr: bool = False,
+        overwrite: bool = True,
+        output: str = "netcdf"
+    ) -> Dict[str, Any]:
+
+        url = f"{self.base_url}/mass_stats/combine_tiles"
+        request_body = {
+            'data_name': data_name,
+            'usezarr': str(usezarr).lower(),
+            'output': output,
+            'overwrite': str(overwrite).lower()
+        }
+        print(f"Request body: {json.dumps(request_body, indent=2)}")
+        response = self.session.post(url, json=request_body, verify=self.verify, timeout=self.timeout)
+        print(f"Response text: {response.text}")
+        return response.json()
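Note: a sketch of calling the new wrapper directly; the base URL and key are placeholders, and data_name presumably matches the name of the tile-generation job whose outputs are being combined:

    ms = MassStats(base_url="https://api.example.com", api_key="...")
    result = ms.combine_tiles(
        data_name="demo",
        usezarr=False,
        overwrite=True,
        output="netcdf",
    )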
terrakio_core-0.3.2.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: terrakio-core
-Version: 0.3.0
+Version: 0.3.2
 Summary: Core components for Terrakio API clients
 Author-email: Yupeng Chao <yupeng@haizea.com.au>
 Project-URL: Homepage, https://github.com/HaizeaAnalytics/terrakio-python-api
@@ -22,6 +22,7 @@ Requires-Dist: xarray>=2023.1.0
 Requires-Dist: shapely>=2.0.0
 Requires-Dist: geopandas>=0.13.0
 Requires-Dist: google-cloud-storage>=2.0.0
+Requires-Dist: nest_asyncio
 
 # Terrakio Core
 
terrakio_core-0.3.2.dist-info/RECORD ADDED
@@ -0,0 +1,16 @@
+terrakio_core/__init__.py,sha256=faOeYeL7Lmg3aTRVxVEBT6Dbhi62N4eeXK3mAluD3pA,88
+terrakio_core/auth.py,sha256=Nuj0_X3Hiy17svYgGxrSAR-LXpTlP0J0dSrfMnkPUbI,7717
+terrakio_core/client.py,sha256=oJPg4hcdCBaPwYN0SKQy57YUsJBTTA-SbQ4aaHaNq8E,65578
+terrakio_core/config.py,sha256=AwJ1VgR5K7N32XCU5k7_Dp1nIv_FYt8MBonq9yKlGzA,2658
+terrakio_core/dataset_management.py,sha256=Hdk3nkwd70jw3lBNEaGixrqNVhUWOmsIYktzm_8vXdc,10913
+terrakio_core/decorators.py,sha256=QeNOUX6WEAmdgBL5Igt5DXyYduh3jnmLbodttmwvXhE,785
+terrakio_core/exceptions.py,sha256=9S-I20-QiDRj1qgjFyYUwYM7BLic_bxurcDOIm2Fu_0,410
+terrakio_core/group_access_management.py,sha256=NJ7SX4keUzZAUENmJ5L6ynKf4eRlqtyir5uoKFyY17A,7315
+terrakio_core/mass_stats.py,sha256=UGZo8BH4hzWe3k7pevsYAdRwnVZl-08lXjTlHD4nMgQ,18212
+terrakio_core/space_management.py,sha256=wlUUQrlj_4U_Lpjn9lbF5oj0Rv3NPvvnrd5mWej5kmA,4211
+terrakio_core/user_management.py,sha256=MMNWkz0V_9X7ZYjjteuRU4H4W3F16iuQw1dpA2wVTGg,7400
+terrakio_core/generation/tiles.py,sha256=eiiMNzqaga-c42kG_7zHXTF2o8ZInCPUj0Vu4Ye30Ts,2980
+terrakio_core-0.3.2.dist-info/METADATA,sha256=QaQJrukRnVgy2jAqNr-BDCUaybrqFMbjOlr2KqLlhJI,1476
+terrakio_core-0.3.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+terrakio_core-0.3.2.dist-info/top_level.txt,sha256=5cBj6O7rNWyn97ND4YuvvXm0Crv4RxttT4JZvNdOG6Q,14
+terrakio_core-0.3.2.dist-info/RECORD,,
terrakio_core-0.3.0.dist-info/RECORD DELETED
@@ -1,14 +0,0 @@
-terrakio_core/__init__.py,sha256=iguSJomKouzVNPOB2_Ox-FGnQBUQ0ykx8CshjVzU1QM,88
-terrakio_core/auth.py,sha256=Nuj0_X3Hiy17svYgGxrSAR-LXpTlP0J0dSrfMnkPUbI,7717
-terrakio_core/client.py,sha256=CQ1qiR_8tWKEGX-UT2wLeatk8fYMpyo9KseMpCapw7c,56813
-terrakio_core/config.py,sha256=AwJ1VgR5K7N32XCU5k7_Dp1nIv_FYt8MBonq9yKlGzA,2658
-terrakio_core/dataset_management.py,sha256=LKUESSDPRu1JubQaQJWdPqHLGt-_Xv77Fpb4IM7vkzM,8751
-terrakio_core/exceptions.py,sha256=9S-I20-QiDRj1qgjFyYUwYM7BLic_bxurcDOIm2Fu_0,410
-terrakio_core/group_access_management.py,sha256=NJ7SX4keUzZAUENmJ5L6ynKf4eRlqtyir5uoKFyY17A,7315
-terrakio_core/mass_stats.py,sha256=AqYJsd6nqo2BDh4vEPUDgsv4T0UR1_TPDoXa3WO3gTU,9284
-terrakio_core/space_management.py,sha256=wlUUQrlj_4U_Lpjn9lbF5oj0Rv3NPvvnrd5mWej5kmA,4211
-terrakio_core/user_management.py,sha256=MMNWkz0V_9X7ZYjjteuRU4H4W3F16iuQw1dpA2wVTGg,7400
-terrakio_core-0.3.0.dist-info/METADATA,sha256=8mS_NJQUoFcr1lE3iUQXQi5VwSZo07t3XF0pCL7VNSI,1448
-terrakio_core-0.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-terrakio_core-0.3.0.dist-info/top_level.txt,sha256=5cBj6O7rNWyn97ND4YuvvXm0Crv4RxttT4JZvNdOG6Q,14
-terrakio_core-0.3.0.dist-info/RECORD,,