terrakio-core 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- terrakio_core/__init__.py +1 -1
- terrakio_core/client.py +291 -60
- terrakio_core/dataset_management.py +62 -10
- terrakio_core/decorators.py +18 -0
- terrakio_core/generation/tiles.py +95 -0
- terrakio_core/mass_stats.py +260 -18
- {terrakio_core-0.3.0.dist-info → terrakio_core-0.3.2.dist-info}/METADATA +2 -1
- terrakio_core-0.3.2.dist-info/RECORD +16 -0
- terrakio_core-0.3.0.dist-info/RECORD +0 -14
- {terrakio_core-0.3.0.dist-info → terrakio_core-0.3.2.dist-info}/WHEEL +0 -0
- {terrakio_core-0.3.0.dist-info → terrakio_core-0.3.2.dist-info}/top_level.txt +0 -0
terrakio_core/__init__.py
CHANGED
terrakio_core/client.py
CHANGED
```diff
@@ -13,6 +13,7 @@ from shapely.geometry import shape, mapping
 from shapely.geometry.base import BaseGeometry as ShapelyGeometry
 from google.cloud import storage
 from .exceptions import APIError, ConfigurationError
+from .decorators import admin_only_params
 import logging
 import textwrap
 
```
```diff
@@ -129,6 +130,7 @@ class BaseClient:
             "resolution": resolution,
             **kwargs
         }
+        print("the payload is ", payload)
        request_url = f"{self.url}/geoquery"
        for attempt in range(retry + 1):
            try:
```
```diff
@@ -536,8 +538,8 @@ class BaseClient:
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.close()
 
-
-    def
+    @admin_only_params('location', 'force_loc', 'server')
+    def execute_job(self, name, region, output, config, overwrite=False, skip_existing=False, request_json=None, manifest_json=None, location=None, force_loc=None, server="dev-au.terrak.io"):
         if not self.mass_stats:
             from terrakio_core.mass_stats import MassStats
             if not self.url or not self.key:
```
```diff
@@ -548,20 +550,8 @@ class BaseClient:
                 verify=self.verify,
                 timeout=self.timeout
             )
-        return self.mass_stats.
+        return self.mass_stats.execute_job(name, region, output, config, overwrite, skip_existing, request_json, manifest_json, location, force_loc, server)
 
-    def start_mass_stats_job(self, task_id):
-        if not self.mass_stats:
-            from terrakio_core.mass_stats import MassStats
-            if not self.url or not self.key:
-                raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
-            self.mass_stats = MassStats(
-                base_url=self.url,
-                api_key=self.key,
-                verify=self.verify,
-                timeout=self.timeout
-            )
-        return self.mass_stats.start_job(task_id)
 
     def get_mass_stats_task_id(self, name, stage, uid=None):
         if not self.mass_stats:
```
```diff
@@ -576,7 +566,7 @@ class BaseClient:
             )
         return self.mass_stats.get_task_id(name, stage, uid)
 
-    def track_mass_stats_job(self, ids=None):
+    def track_mass_stats_job(self, ids: Optional[list] = None):
         if not self.mass_stats:
             from terrakio_core.mass_stats import MassStats
             if not self.url or not self.key:
```
```diff
@@ -1049,6 +1039,20 @@ class BaseClient:
             )
         return self.space_management.delete_data_in_path(path, region)
 
+    def start_mass_stats_job(self, task_id):
+        if not self.mass_stats:
+            from terrakio_core.mass_stats import MassStats
+            if not self.url or not self.key:
+                raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
+            self.mass_stats = MassStats(
+                base_url=self.url,
+                api_key=self.key,
+                verify=self.verify,
+                timeout=self.timeout
+            )
+        return self.mass_stats.start_job(task_id)
+
+
     def generate_ai_dataset(
         self,
         name: str,
```
```diff
@@ -1118,27 +1122,58 @@ class BaseClient:
             overwrite=True
         )["task_id"]
         print("the task id is ", task_id)
+
+        # Wait for job completion
+        import time
+
+        while True:
+            result = self.track_mass_stats_job(ids=[task_id])
+            status = result[task_id]['status']
+            print(f"Job status: {status}")
+
+            if status == "Completed":
+                break
+            elif status == "Error":
+                raise Exception(f"Job {task_id} encountered an error")
+
+            # Wait 30 seconds before checking again
+            time.sleep(30)
+
+        # print("the result is ", result)
+        # after all the random sample jobs are done, we then start the mass stats job
         task_id = self.start_mass_stats_job(task_id)
-
-        return task_id
+        # now we have the random sample
 
+        # print("the task id is ", task_id)
+        return task_id
 
-    def train_model(self, model_name: str,
+    def train_model(self, model_name: str, training_dataset: str, task_type: str, model_category: str, architecture: str, region: str, hyperparameters: dict = None) -> dict:
         """
         Train a model using the external model training API.
-
+
         Args:
             model_name (str): The name of the model to train.
-
-
+            training_dataset (str): The training dataset identifier.
+            task_type (str): The type of ML task (e.g., regression, classification).
+            model_category (str): The category of model (e.g., random_forest).
+            architecture (str): The model architecture.
+            region (str): The region identifier.
+            hyperparameters (dict, optional): Additional hyperparameters for training.
+
         Returns:
             dict: The response from the model training API.
         """
-        endpoint = "https://modeltraining-573248941006.australia-southeast1.run.app/train_model"
         payload = {
             "model_name": model_name,
-            "
+            "training_dataset": training_dataset,
+            "task_type": task_type,
+            "model_category": model_category,
+            "architecture": architecture,
+            "region": region,
+            "hyperparameters": hyperparameters
         }
+        endpoint = f"{self.url.rstrip('/')}/train_model"
+        print("the payload is ", payload)
         try:
             response = self.session.post(endpoint, json=payload, timeout=self.timeout, verify=self.verify)
             if not response.ok:
```
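In 0.3.2, `train_model` posts the full training specification to the client's own base URL rather than a hard-coded Cloud Run endpoint. A hypothetical call against the new signature follows; the client construction and every value below are illustrative, not taken from the diff:

```python
# Illustrative only: BaseClient() is assumed to pick up the API URL and key
# from local configuration; all argument values are placeholders.
from terrakio_core.client import BaseClient

client = BaseClient()
result = client.train_model(
    model_name="canopy-height-rf",
    training_dataset="canopy_height_samples",
    task_type="regression",
    model_category="random_forest",
    architecture="rf",
    region="aus",
    hyperparameters={"n_estimators": 100, "max_depth": 10},
)
print(result)
```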
```diff
@@ -1155,35 +1190,163 @@ class BaseClient:
         except requests.RequestException as e:
             raise APIError(f"Model training request failed: {str(e)}")
 
-
-
-
-
-
-
-
-
-
+    # Mass Stats methods
+    def combine_tiles(self,
+        data_name: str,
+        usezarr: bool,
+        overwrite: bool,
+        output : str) -> dict:
+
+        if not self.mass_stats:
+            from terrakio_core.mass_stats import MassStats
+            if not self.url or not self.key:
+                raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
+            self.mass_stats = MassStats(
+                base_url=self.url,
+                api_key=self.key,
+                verify=self.verify,
+                timeout=self.timeout
+            )
+        return self.mass_stats.combine_tiles(data_name, usezarr, overwrite, output)
+
+
+
+    def create_dataset_file(
+        self,
+        name: str,
+        aoi: str,
+        expression: str,
+        output: str,
+        tile_size: float = 128.0,
+        crs: str = "epsg:4326",
+        res: float = 0.0001,
+        region: str = "aus",
+        to_crs: str = "epsg:4326",
+        overwrite: bool = True,
+        skip_existing: bool = False,
+        non_interactive: bool = True,
+        usezarr: bool = False,
+        poll_interval: int = 30  # seconds between job status checks
+    ) -> dict:
+
+        from terrakio_core.generation.tiles import tiles
+        import tempfile
+        import time
+
+        body, reqs, groups = tiles(
+            name = name,
+            aoi = aoi,
+            expression = expression,
+            output = output,
+            tile_size = tile_size,
+            crs = crs,
+            res = res,
+            region = region,
+            to_crs = to_crs,
+            fully_cover = True,
+            overwrite = overwrite,
+            skip_existing = skip_existing,
+            non_interactive = non_interactive
+        )
+
+        # Create temp json files before upload
+        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tempreq:
+            tempreq.write(reqs)
+            tempreqname = tempreq.name
+        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tempmanifest:
+            tempmanifest.write(groups)
+            tempmanifestname = tempmanifest.name
+
+        if not self.mass_stats:
+            from terrakio_core.mass_stats import MassStats
+            if not self.url or not self.key:
+                raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
+            self.mass_stats = MassStats(
+                base_url=self.url,
+                api_key=self.key,
+                verify=self.verify,
+                timeout=self.timeout
+            )
 
+        task_id = self.mass_stats.execute_job(
+            name=body["name"],
+            region=body["region"],
+            output=body["output"],
+            config = {},
+            overwrite=body["overwrite"],
+            skip_existing=body["skip_existing"],
+            request_json=tempreqname,
+            manifest_json=tempmanifestname,
+        )
+
+        ### Start combining tiles when generation-tiles job is done
+        start_time = time.time()
+        status = None
+
+        while True:
+            try:
+                taskid = task_id['task_id']
+                trackinfo = self.mass_stats.track_job([taskid])
+                status = trackinfo[taskid]['status']
+
+                # Check completion states
+                if status == 'Completed':
+                    print('Tiles generated successfully!')
+                    break
+                elif status in ['Failed', 'Cancelled', 'Error']:
+                    raise RuntimeError(f"Job {taskid} failed with status: {status}")
+                else:
+                    # Job is still running
+                    elapsed_time = time.time() - start_time
+                    print(f"Job status: {status} - Elapsed time: {elapsed_time:.1f}s", end='\r')
+
+                # Sleep before next check
+                time.sleep(poll_interval)
+
+
+            except KeyboardInterrupt:
+                print(f"\nInterrupted! Job {taskid} is still running in the background.")
+                raise
+            except Exception as e:
+                print(f"\nError tracking job: {e}")
+                raise
+
+        # Clean up temporary files
+        import os
+        os.unlink(tempreqname)
+        os.unlink(tempmanifestname)
+
+
+        # Start combining tiles
+        if not self.mass_stats:
+            from terrakio_core.mass_stats import MassStats
+            if not self.url or not self.key:
+                raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
+            self.mass_stats = MassStats(
+                base_url=self.url,
+                api_key=self.key,
+                verify=self.verify,
+                timeout=self.timeout
+            )
+
+        return self.mass_stats.combine_tiles(body["name"], usezarr, body["overwrite"], body["output"])
+
+    def deploy_model(self, dataset: str, product:str, model_name:str, input_expression: str, model_training_job_name: str, uid: str, dates_iso8601: list):
         script_content = self._generate_script(model_name, product, model_training_job_name, uid)
-        # self.create_dataset(collection = "terrakio-datasets", input = input, )
-        # we have the script, we need to upload it to the bucket
         script_name = f"{product}.py"
-        print("the script content is ", script_content)
-        print("the script name is ", script_name)
         self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
-        # after uploading the script, we need to create a new virtual dataset
         self._create_dataset(name = dataset, collection = "terrakio-datasets", products = [product], path = f"gs://terrakio-mass-requests/{uid}/{model_training_job_name}/inference_scripts", input = input_expression, dates_iso8601 = dates_iso8601, padding = 0)
 
     def _generate_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
         return textwrap.dedent(f'''
         import logging
         from io import BytesIO
-
-        from onnxruntime import InferenceSession
+
         import numpy as np
+        import pandas as pd
         import xarray as xr
-        import
+        from google.cloud import storage
+        from onnxruntime import InferenceSession
 
         logging.basicConfig(
             level=logging.INFO
```
```diff
@@ -1191,54 +1354,122 @@ class BaseClient:
 
         def get_model():
             logging.info("Loading model for {model_name}...")
-
+
             client = storage.Client()
             bucket = client.get_bucket('terrakio-mass-requests')
             blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')
-
+
             model = BytesIO()
             blob.download_to_file(model)
             model.seek(0)
-
+
             session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
             return session
 
         def {product}(*bands, model):
             logging.info("start preparing data")
+            print("the bands are ", bands)
 
-
-            logging.info(f"Original shape: {{original_shape}}")
+            data_arrays = list(bands)
 
-
-            for band in bands:
-                transformed_band = band.values.reshape(-1,1)
-                transformed_bands.append(transformed_band)
+            print("the data arrays are ", [da.name for da in data_arrays])
 
-
+            reference_array = data_arrays[0]
+            original_shape = reference_array.shape
+            logging.info(f"Original shape: {{original_shape}}")
 
+            if 'time' in reference_array.dims:
+                time_coords = reference_array.coords['time']
+                if len(time_coords) == 1:
+                    output_timestamp = time_coords[0]
+                else:
+                    years = [pd.to_datetime(t).year for t in time_coords.values]
+                    unique_years = set(years)
+
+                    if len(unique_years) == 1:
+                        year = list(unique_years)[0]
+                        output_timestamp = pd.Timestamp(f"{{year}}-01-01")
+                    else:
+                        latest_year = max(unique_years)
+                        output_timestamp = pd.Timestamp(f"{{latest_year}}-01-01")
+            else:
+                output_timestamp = pd.Timestamp("1970-01-01")
+
+            averaged_bands = []
+            for data_array in data_arrays:
+                if 'time' in data_array.dims:
+                    averaged_band = np.mean(data_array.values, axis=0)
+                    logging.info(f"Averaged band from {{data_array.shape}} to {{averaged_band.shape}}")
+                else:
+                    averaged_band = data_array.values
+                    logging.info(f"No time dimension, shape: {{averaged_band.shape}}")
+
+                flattened_band = averaged_band.reshape(-1, 1)
+                averaged_bands.append(flattened_band)
+
+            input_data = np.hstack(averaged_bands)
+
             logging.info(f"Final input shape: {{input_data.shape}}")
-
+
             output = model.run(None, {{"float_input": input_data.astype(np.float32)}})[0]
-
+
             logging.info(f"Model output shape: {{output.shape}}")
 
-
+            if len(original_shape) >= 3:
+                spatial_shape = original_shape[1:]
+            else:
+                spatial_shape = original_shape
+
+            output_reshaped = output.reshape(spatial_shape)
+
+            output_with_time = np.expand_dims(output_reshaped, axis=0)
+
+            if 'time' in reference_array.dims:
+                spatial_dims = [dim for dim in reference_array.dims if dim != 'time']
+                spatial_coords = {{dim: reference_array.coords[dim] for dim in spatial_dims if dim in reference_array.coords}}
+            else:
+                spatial_dims = list(reference_array.dims)
+                spatial_coords = dict(reference_array.coords)
+
             result = xr.DataArray(
-                data=
-                dims=
-                coords=
+                data=output_with_time.astype(np.float32),
+                dims=['time'] + list(spatial_dims),
+                coords={
+                    'time': [output_timestamp.values],
+                    'y': spatial_coords['y'].values,
+                    'x': spatial_coords['x'].values
+                }
             )
-
             return result
         ''').strip()
-
+
     def _upload_script_to_bucket(self, script_content: str, script_name: str, model_training_job_name: str, uid: str):
         """Upload the generated script to Google Cloud Storage"""
 
         client = storage.Client()
         bucket = client.get_bucket('terrakio-mass-requests')
         blob = bucket.blob(f'{uid}/{model_training_job_name}/inference_scripts/{script_name}')
-        # the first layer is the uid, the second layer is the model training job name
         blob.upload_from_string(script_content, content_type='text/plain')
         logging.info(f"Script uploaded successfully to {uid}/{model_training_job_name}/inference_scripts/{script_name}")
 
+
+
+
+    def download_file_to_path(self, job_name, stage, file_name, output_path):
+        if not self.mass_stats:
+            from terrakio_core.mass_stats import MassStats
+            if not self.url or not self.key:
+                raise ConfigurationError("Mass Stats client not initialized. Make sure API URL and key are set.")
+            self.mass_stats = MassStats(
+                base_url=self.url,
+                api_key=self.key,
+                verify=self.verify,
+                timeout=self.timeout
+            )
+
+        # fetch bucket info based on job name and stage
+
+        taskid = self.mass_stats.get_task_id(job_name, stage).get('task_id')
+        trackinfo = self.mass_stats.track_job([taskid])
+        bucket = trackinfo[taskid]['bucket']
+        return self.mass_stats.download_file(job_name, bucket, file_name, output_path)
```
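The new `create_dataset_file` ties the pieces together: it calls `tiles()` to build the request and manifest JSON, hands them to `MassStats.execute_job`, polls until the tiling job completes, then finishes with `combine_tiles`. A sketch of driving it, where the client setup, AOI path, and expression are placeholders:

```python
# Illustrative sketch only; BaseClient() is assumed to read the API URL and
# key from local configuration, and "aoi.geojson" is a placeholder AOI file.
from terrakio_core.client import BaseClient

client = BaseClient()
result = client.create_dataset_file(
    name="demo_tiles",
    aoi="aoi.geojson",
    expression="red=S2v2#(year,median).red@(year =2024) \n red",
    output="netcdf",
    tile_size=128.0,
    res=0.0001,
    region="aus",
    poll_interval=30,  # seconds between job status checks
)
print(result)
```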
terrakio_core/dataset_management.py
CHANGED

```diff
@@ -83,10 +83,63 @@ class DatasetManagement:
         except requests.RequestException as e:
             raise APIError(f"Request failed: {str(e)}")
 
+    # def create_dataset(self, name: str, collection: str = "terrakio-datasets", **kwargs) -> Dict[str, Any]:
+    #     """
+    #     Create a new dataset.
+
+    #     Args:
+    #         name: Name of the dataset (required)
+    #         collection: Dataset collection (default: 'terrakio-datasets')
+    #         **kwargs: Additional dataset parameters including:
+    #             - products: List of products
+    #             - dates_iso8601: List of dates
+    #             - bucket: Storage bucket
+    #             - path: Storage path
+    #             - data_type: Data type
+    #             - no_data: No data value
+    #             - l_max: Maximum level
+    #             - y_size: Y size
+    #             - x_size: X size
+    #             - proj4: Projection string
+    #             - abstract: Dataset abstract
+    #             - geotransform: Geotransform parameters
+
+    #     Returns:
+    #         Created dataset information
+
+    #     Raises:
+    #         APIError: If the API request fails
+    #     """
+    #     endpoint = f"{self.api_url}/datasets"
+    #     params = {"collection": collection}
+    #     # Create payload with required name parameter
+    #     payload = {"name": name}
+
+    #     # Add optional parameters if provided
+    #     for param in ["products", "dates_iso8601", "bucket", "path", "data_type",
+    #                   "no_data", "l_max", "y_size", "x_size", "proj4", "abstract", "geotransform", "input"]:
+    #         if param in kwargs:
+    #             payload[param] = kwargs[param]
+
+    #     try:
+    #         response = self.session.post(
+    #             endpoint,
+    #             params=params,
+    #             json=payload,
+    #             timeout=self.timeout,
+    #             verify=self.verify
+    #         )
+
+    #         if not response.ok:
+    #             raise APIError(f"API request failed: {response.status_code} {response.reason}")
+    #         return response.json()
+    #     except requests.RequestException as e:
+    #         raise APIError(f"Request failed: {str(e)}")
+
     def create_dataset(self, name: str, collection: str = "terrakio-datasets", **kwargs) -> Dict[str, Any]:
         """
         Create a new dataset.
-
+
         Args:
             name: Name of the dataset (required)
             collection: Dataset collection (default: 'terrakio-datasets')
@@ -103,24 +156,23 @@ class DatasetManagement:
             - proj4: Projection string
             - abstract: Dataset abstract
             - geotransform: Geotransform parameters
-
+            - padding: Padding value
+
         Returns:
             Created dataset information
-
+
         Raises:
             APIError: If the API request fails
         """
         endpoint = f"{self.api_url}/datasets"
         params = {"collection": collection}
-        # Create payload with required name parameter
         payload = {"name": name}
-
-
-
-            "no_data", "l_max", "y_size", "x_size", "proj4", "abstract", "geotransform", "input"]:
+
+        for param in ["products", "dates_iso8601", "bucket", "path", "data_type",
+                      "no_data", "l_max", "y_size", "x_size", "proj4", "abstract", "geotransform", "input", "padding"]:
             if param in kwargs:
                 payload[param] = kwargs[param]
-
+
         try:
             response = self.session.post(
                 endpoint,
@@ -129,7 +181,7 @@ class DatasetManagement:
                 timeout=self.timeout,
                 verify=self.verify
             )
-
+
             if not response.ok:
                 raise APIError(f"API request failed: {response.status_code} {response.reason}")
             return response.json()
```
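`create_dataset` now forwards a `padding` key, which `deploy_model` in client.py passes when registering the virtual dataset. A hypothetical call, where `dm` stands in for an already-configured `DatasetManagement` instance and the values are illustrative:

```python
# Illustrative only: dm is assumed to be a configured DatasetManagement
# instance; names and paths are placeholders.
result = dm.create_dataset(
    name="my_virtual_dataset",
    collection="terrakio-datasets",
    products=["canopy_height"],
    input="red=S2v2#(year,median).red@(year =2024) \n red",
    dates_iso8601=["2024-01-01"],
    padding=0,  # forwarded to the API as of 0.3.2
)
print(result)
```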
terrakio_core/decorators.py
ADDED

```diff
@@ -0,0 +1,18 @@
+# terrakio_core/decorators.py
+def admin_only_params(*restricted_params):
+    """
+    Decorator factory for restricting method parameters to admin users only.
+    """
+    def decorator(func):
+        def wrapper(self, *args, **kwargs):
+            if hasattr(self, '_is_admin') and self._is_admin:
+                return func(self, *args, **kwargs)
+
+            admin_params_used = set(kwargs.keys()) & set(restricted_params)
+            if admin_params_used:
+                raise PermissionError(f"Parameters {admin_params_used} are only available to admin users")
+
+            filtered_kwargs = {k: v for k, v in kwargs.items() if k not in restricted_params}
+            return func(self, *args, **filtered_kwargs)
+        return wrapper
+    return decorator
```
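The decorator has three paths: admins (`self._is_admin` truthy) bypass the check entirely, a non-admin who explicitly passes a restricted keyword gets a `PermissionError`, and otherwise the restricted names are stripped so the function's defaults apply. A toy demonstration; the `Client` class below is a stand-in, not part of the package:

```python
# Toy demonstration of admin_only_params; Client here is a hypothetical class.
from terrakio_core.decorators import admin_only_params

class Client:
    def __init__(self, is_admin=False):
        self._is_admin = is_admin

    @admin_only_params('server')
    def run(self, name, server=None):
        return name, server

Client(is_admin=True).run("job", server="dev-au.terrak.io")  # ('job', 'dev-au.terrak.io')
Client().run("job")                                          # ('job', None): default applies
Client().run("job", server="dev-au.terrak.io")               # raises PermissionError
```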
terrakio_core/generation/tiles.py
ADDED

```diff
@@ -0,0 +1,95 @@
+### implementing generation-tiles in python api
+### function should just generate the json file for mass_stats to pick up.
+
+import geopandas as gpd
+import shapely.geometry
+import json
+from rich import print
+
+def escape_newline(string):
+    if isinstance(string, list):
+        return [s.replace('\\n', '\n') for s in string]
+    else:
+        return string.replace('\\n', '\n')
+
+def get_bounds(aoi, crs, to_crs = None):
+    aoi : gpd.GeoDataFrame = gpd.read_file(aoi)
+    aoi = aoi.set_crs(crs, allow_override=True)
+    if to_crs:
+        aoi = aoi.to_crs(to_crs)
+    bounds = aoi.geometry[0].bounds
+    return *bounds, aoi
+
+def tile_generator(x_min, y_min, x_max, y_max, aoi, crs, res, tile_size, expression, output, fully_cover=True):
+    i_max = int((x_max-x_min)/(tile_size*res))
+    j_max = int((y_max-y_min)/(tile_size*res))
+    if fully_cover:
+        i_max += 1
+        j_max += 1
+    for j in range(0, int(j_max)):
+        for i in range(0, int(i_max)):
+            #print(f"Processing tile {i} {j}")
+            x = x_min + i*(tile_size*res)
+            y = y_max - j*(tile_size*res)
+            bbox = shapely.geometry.box(x, y-(tile_size*res), x + (tile_size*res), y)
+            if not aoi.geometry[0].intersects(bbox):
+                continue
+            feat = {"type": "Feature", "geometry": bbox.__geo_interface__}
+            data = {
+                "feature": feat,
+                "in_crs": crs,
+                "out_crs": crs,
+                "resolution": res,
+                "expr" : expression,
+                "output" : output,
+            }
+            yield data, i , j
+
+
+def tiles(
+    name: str,
+    aoi : str,
+    expression: str = "red=S2v2#(year,median).red@(year =2024) \n red",
+    output: str = "netcdf",
+    tile_size : float = 512,
+    crs : str = "epsg:3577",
+    res: float = 10,
+    region : str = "eu",
+    to_crs: str = None,
+    fully_cover: bool = True,
+    overwrite: bool = False,
+    skip_existing: bool = False,
+    non_interactive: bool = False,
+):
+
+    # Create requests for each tile
+    reqs = []
+    x_min, y_min, x_max, y_max, aoi = get_bounds(aoi, crs, to_crs)
+    #print(f"Bounds: {x_min}, {y_min}, {x_max}, {y_max}")
+
+    if to_crs is None:
+        to_crs = crs
+    for tile_req, i, j in tile_generator(x_min, y_min, x_max, y_max, aoi, to_crs, res, tile_size, expression, output, fully_cover):
+        req_name = f"{name}_{i:02d}_{j:02d}"
+        reqs.append({"group": "tiles", "file": req_name, "request": tile_req})
+
+    #print(f"Generated {len(reqs)} tile requests.")
+
+
+    count = len(reqs)
+    groups = list(set(dic["group"] for dic in reqs))
+
+    body = {
+        "name" : name,
+        "output" : output,
+        "region" : region,
+        "size" : count,
+        "overwrite" : overwrite,
+        "non_interactive": non_interactive,
+        "skip_existing" : skip_existing,
+    }
+    request_json = json.dumps(reqs)
+    manifest_json = json.dumps(groups)
+
+    return body, request_json, manifest_json
+
```
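`tiles()` itself does no network I/O: it reads the AOI with geopandas, windows its bounds into `tile_size * res` squares, and returns the job body plus the request and manifest JSON strings. An illustrative run, with the AOI path as a placeholder:

```python
# Illustrative only; "aoi.geojson" is a placeholder for any file geopandas can read.
from terrakio_core.generation.tiles import tiles

body, request_json, manifest_json = tiles(
    name="demo_tiles",
    aoi="aoi.geojson",
    expression="red=S2v2#(year,median).red@(year =2024) \n red",
    output="netcdf",
    tile_size=512,
    crs="epsg:3577",
    res=10,
)
print(body["size"])    # number of tile requests generated
print(manifest_json)   # e.g. '["tiles"]'
```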
terrakio_core/mass_stats.py
CHANGED
```diff
@@ -1,5 +1,8 @@
 import requests
-from typing import Optional, Dict, Any
+from typing import Optional, Dict, Any, List
+import json
+import json as json_lib
+import gzip
 
 class MassStats:
     def __init__(self, base_url: str, api_key: str, verify: bool = True, timeout: int = 60):
```
```diff
@@ -12,61 +15,272 @@ class MassStats:
             'x-api-key': self.api_key
         })
 
+    def _upload_file(self, file_path: str, url: str, use_gzip: bool = False):
+        """
+        Helper method to upload a JSON file to a signed URL.
+
+        Args:
+            file_path: Path to the JSON file
+            url: Signed URL to upload to
+            use_gzip: Whether to compress the file with gzip
+        """
+        try:
+            with open(file_path, 'r') as file:
+                json_data = json_lib.load(file)
+        except FileNotFoundError:
+            raise FileNotFoundError(f"JSON file not found: {file_path}")
+        except json.JSONDecodeError as e:
+            raise ValueError(f"Invalid JSON in file {file_path}: {e}")
+
+        # Check if using simplejson and support ignore_nan
+        if hasattr(json_lib, 'dumps') and 'ignore_nan' in json_lib.dumps.__code__.co_varnames:
+            dumps_kwargs = {'ignore_nan': True}
+        else:
+            dumps_kwargs = {}
+
+        if use_gzip:
+            # Serialize and compress the JSON data
+            body = gzip.compress(json_lib.dumps(json_data, **dumps_kwargs).encode('utf-8'))
+            headers = {
+                'Content-Type': 'application/json',
+                'Content-Encoding': 'gzip'
+            }
+        else:
+            body = json_lib.dumps(json_data, **dumps_kwargs).encode('utf-8')
+            headers = {
+                'Content-Type': 'application/json'
+            }
+
+        # Make the PUT request to the signed URL
+        response = requests.put(
+            url,
+            data=body,
+            headers=headers
+        )
+
+        return response
+
+
+    def download_file(self, job_name: str, bucket:str, file_name: str, output_path: str) -> str:
+        """
+        Download a file from mass_stats using job name and file name.
+
+        Args:
+            job_name: Name of the job
+            file_name: Name of the file to download
+            output_path: Path where the file should be saved
+
+        Returns:
+            str: Path to the downloaded file
+        """
+        import os
+        from pathlib import Path
+
+        endpoint_url = f"{self.base_url}/mass_stats/download_files"
+        request_body = {
+            "job_name": job_name,
+            "bucket": bucket,
+            "file_name": file_name
+        }
+
+        try:
+            # Get signed URL
+            response = self.session.post(
+                endpoint_url,
+                json=request_body,
+                verify=self.verify,
+                timeout=self.timeout
+            )
+            signed_url = response.json().get('download_url')
+            if not signed_url:
+                raise Exception("No download URL received from server")
+            print(f"Generated signed URL for download")
+
+            # Create output directory if it doesn't exist
+            output_dir = Path(output_path).parent
+            output_dir.mkdir(parents=True, exist_ok=True)
+
+            # Download the file using the signed URL
+            download_response = self.session.get(
+                signed_url,
+                verify=self.verify,
+                timeout=self.timeout,
+                stream=True  # Stream for large files
+            )
+            download_response.raise_for_status()
+
+            # Check if file exists in the response (content-length header)
+            content_length = download_response.headers.get('content-length')
+            if content_length and int(content_length) == 0:
+                raise Exception("File appears to be empty")
+
+            # Write the file
+            with open(output_path, 'wb') as file:
+                for chunk in download_response.iter_content(chunk_size=8192):
+                    if chunk:
+                        file.write(chunk)
+
+            # Verify file was written
+            if not os.path.exists(output_path):
+                raise Exception(f"File was not written to {output_path}")
+
+            file_size = os.path.getsize(output_path)
+            print(f"File downloaded successfully to {output_path} (size: {file_size / (1024 * 1024):.4f} mb)")
+
+            return output_path
+
+        except self.session.exceptions.RequestException as e:
+            if hasattr(e, 'response') and e.response is not None:
+                error_detail = e.response.text
+                raise Exception(f"Error getting signed URL: {e}. Details: {error_detail}")
+            raise Exception(f"Error in download process: {e}")
+        except IOError as e:
+            raise Exception(f"Error writing file to {output_path}: {e}")
+        except Exception as e:
+            # Clean up partial file if it exists
+            if os.path.exists(output_path):
+                try:
+                    os.remove(output_path)
+                except:
+                    pass
+            raise
+
+
+
+
     def upload_request(
         self,
         name: str,
         size: int,
-
+        region: List[str],
         output: str,
+        config: Dict[str, Any],
         location: Optional[str] = None,
-        force_loc: bool =
-        config: Optional[Dict[str, Any]] = None,
+        force_loc: Optional[bool] = None,
         overwrite: bool = False,
         server: Optional[str] = None,
-        skip_existing: bool = False
+        skip_existing: bool = False,
     ) -> Dict[str, Any]:
         """
         Initiate a mass stats upload job.
 
         Args:
             name: Name of the job
-            size: Size of the
-
-            output: Output
+            size: Size of the job
+            region: Region to run job [aus, eu, us]
+            output: Output type
+            config: Configuration dictionary
             location: (Optional) Location for the upload
             force_loc: Force location usage
-            config: Optional configuration dictionary
             overwrite: Overwrite existing data
             server: Optional server
             skip_existing: Skip existing files
         """
+
+
+
+        # Step 2: Create the upload job and get signed URLs
         url = f"{self.base_url}/mass_stats/upload"
+
         data = {
             "name": name,
             "size": size,
-            "
+            "region": region,
             "output": output,
-            "
+            "config": config,
             "overwrite": overwrite,
             "skip_existing": skip_existing
         }
+
         if location is not None:
             data["location"] = location
-        if
-            data["
+        if force_loc is not None:
+            data["force_loc"] = force_loc
         if server is not None:
             data["server"] = server
-        response = self.session.post(
-
-
+        response = self.session.post(
+            url,
+            json=data,
+            verify=self.verify,
+            timeout=self.timeout
+        )
         return response.json()
 
+
+
+
+    def execute_job(
+        self,
+        name: str,
+        region: str,
+        output: str,
+        config: Dict[str, Any],
+        overwrite: bool = False,
+        skip_existing: bool = False,
+        request_json: Optional[str] = None,
+        manifest_json: Optional[str] = None,
+        location: Optional[str] = None,
+        force_loc: Optional[bool] = None,
+        server: Optional[str] = None
+    ) -> Dict[str, Any]:
+        # Step 1: Calculate size from request JSON file if provided
+        size = 0
+        if request_json is not None:
+            try:
+                with open(request_json, 'r') as file:
+                    request_data = json_lib.load(file)
+
+                if isinstance(request_data, list):
+                    size = len(request_data)
+                else:
+                    raise ValueError(f"Request JSON file {request_json} should contain a list of dictionaries")
+
+            except FileNotFoundError:
+                raise FileNotFoundError(f"Request JSON file not found: {request_json}")
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Invalid JSON in request file {request_json}: {e}")
+
+        upload_result = self.upload_request(name, size, region, output, config, location, force_loc, overwrite, server, skip_existing)
+
+        # Step 3: Upload JSON files if provided
+        if request_json is not None or manifest_json is not None:
+            requests_url = upload_result.get('requests_url')
+            manifest_url = upload_result.get('manifest_url')
+
+            if request_json is not None:
+                if not requests_url:
+                    raise ValueError("No requests_url returned from server for request JSON upload")
+
+                try:
+                    requests_response = self._upload_file(request_json, requests_url, use_gzip=True)
+                    if requests_response.status_code not in [200, 201, 204]:
+                        print(f"Requests upload error: {requests_response.text}")
+                        raise Exception(f"Failed to upload request JSON: {requests_response.text}")
+                except Exception as e:
+                    raise Exception(f"Error uploading request JSON file {request_json}: {e}")
+
+            if manifest_json is not None:
+                if not manifest_url:
+                    raise ValueError("No manifest_url returned from server for manifest JSON upload")
+
+                try:
+                    manifest_response = self._upload_file(manifest_json, manifest_url, use_gzip=False)
+                    if manifest_response.status_code not in [200, 201, 204]:
+                        print(f"Manifest upload error: {manifest_response.text}")
+                        raise Exception(f"Failed to upload manifest JSON: {manifest_response.text}")
+                except Exception as e:
+                    raise Exception(f"Error uploading manifest JSON file {manifest_json}: {e}")
+
+
+        start_job_task_id = self.start_job(upload_result.get("id"))
+        return start_job_task_id
+
+
     def start_job(self, task_id: str) -> Dict[str, Any]:
         """
         Start a mass stats job by task ID.
         """
         url = f"{self.base_url}/mass_stats/start/{task_id}"
-        print("the self session header is ", self.session.headers)
         response = self.session.post(url, verify=self.verify, timeout=self.timeout)
         response.raise_for_status()
         return response.json()
```
```diff
@@ -79,7 +293,7 @@ class MassStats:
         if uid is not None:
             url += f"&uid={uid}"
         response = self.session.get(url, verify=self.verify, timeout=self.timeout)
-        print("response text is ", response.text)
+        #print("response text is ", response.text)
         return response.json()
 
     def track_job(self, ids: Optional[list] = None) -> Dict[str, Any]:
```
```diff
@@ -259,4 +473,32 @@ class MassStats:
         print("Response text:", response.text)
         # response.raise_for_status()
         return response.json()
+
+
+    ### Adding the wrapper function to call endpoint /mass_stats/combine_tiles
+    def combine_tiles(
+        self,
+        data_name: str,
+        usezarr: bool = False,
+        overwrite: bool = True,
+        output : str = "netcdf"
+    ) -> Dict[str, Any]:
+
+        url = f"{self.base_url}/mass_stats/combine_tiles"
+        request_body = {
+            'data_name': data_name,
+            'usezarr': str(usezarr).lower(),
+            'output': output,
+            'overwrite': str(overwrite).lower()
+        }
+        print(f"Request body: {json.dumps(request_body, indent=2)}")
+        response = self.session.post(url, json=request_body, verify=self.verify, timeout=self.timeout)
+        print(f"Response text: {response.text}")
+        return response.json()
+
+
 
```
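Taken together, `execute_job` sizes the job from the request file, creates it via `/mass_stats/upload`, PUTs the gzip-compressed request JSON and the plain manifest JSON to the signed URLs the server returns, and then starts the job. A hedged end-to-end sketch; the base URL, key, and file paths are placeholders:

```python
# Illustrative sketch; URL, key, and file paths are placeholders.
from terrakio_core.mass_stats import MassStats

ms = MassStats(base_url="https://api.example.com", api_key="YOUR_KEY")

task = ms.execute_job(
    name="demo_tiles",
    region="aus",
    output="netcdf",
    config={},
    request_json="requests.json",   # a list of per-tile request dicts
    manifest_json="manifest.json",  # a list of group names, e.g. ["tiles"]
)
print(ms.track_job([task["task_id"]]))

# Once tiling is complete, merge the tiles into a single output.
ms.combine_tiles("demo_tiles", usezarr=False, overwrite=True, output="netcdf")
```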
{terrakio_core-0.3.0.dist-info → terrakio_core-0.3.2.dist-info}/METADATA
CHANGED

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: terrakio-core
-Version: 0.3.0
+Version: 0.3.2
 Summary: Core components for Terrakio API clients
 Author-email: Yupeng Chao <yupeng@haizea.com.au>
 Project-URL: Homepage, https://github.com/HaizeaAnalytics/terrakio-python-api
@@ -22,6 +22,7 @@ Requires-Dist: xarray>=2023.1.0
 Requires-Dist: shapely>=2.0.0
 Requires-Dist: geopandas>=0.13.0
 Requires-Dist: google-cloud-storage>=2.0.0
+Requires-Dist: nest_asyncio
 
 # Terrakio Core
 
```
terrakio_core-0.3.2.dist-info/RECORD
ADDED

```diff
@@ -0,0 +1,16 @@
+terrakio_core/__init__.py,sha256=faOeYeL7Lmg3aTRVxVEBT6Dbhi62N4eeXK3mAluD3pA,88
+terrakio_core/auth.py,sha256=Nuj0_X3Hiy17svYgGxrSAR-LXpTlP0J0dSrfMnkPUbI,7717
+terrakio_core/client.py,sha256=oJPg4hcdCBaPwYN0SKQy57YUsJBTTA-SbQ4aaHaNq8E,65578
+terrakio_core/config.py,sha256=AwJ1VgR5K7N32XCU5k7_Dp1nIv_FYt8MBonq9yKlGzA,2658
+terrakio_core/dataset_management.py,sha256=Hdk3nkwd70jw3lBNEaGixrqNVhUWOmsIYktzm_8vXdc,10913
+terrakio_core/decorators.py,sha256=QeNOUX6WEAmdgBL5Igt5DXyYduh3jnmLbodttmwvXhE,785
+terrakio_core/exceptions.py,sha256=9S-I20-QiDRj1qgjFyYUwYM7BLic_bxurcDOIm2Fu_0,410
+terrakio_core/group_access_management.py,sha256=NJ7SX4keUzZAUENmJ5L6ynKf4eRlqtyir5uoKFyY17A,7315
+terrakio_core/mass_stats.py,sha256=UGZo8BH4hzWe3k7pevsYAdRwnVZl-08lXjTlHD4nMgQ,18212
+terrakio_core/space_management.py,sha256=wlUUQrlj_4U_Lpjn9lbF5oj0Rv3NPvvnrd5mWej5kmA,4211
+terrakio_core/user_management.py,sha256=MMNWkz0V_9X7ZYjjteuRU4H4W3F16iuQw1dpA2wVTGg,7400
+terrakio_core/generation/tiles.py,sha256=eiiMNzqaga-c42kG_7zHXTF2o8ZInCPUj0Vu4Ye30Ts,2980
+terrakio_core-0.3.2.dist-info/METADATA,sha256=QaQJrukRnVgy2jAqNr-BDCUaybrqFMbjOlr2KqLlhJI,1476
+terrakio_core-0.3.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+terrakio_core-0.3.2.dist-info/top_level.txt,sha256=5cBj6O7rNWyn97ND4YuvvXm0Crv4RxttT4JZvNdOG6Q,14
+terrakio_core-0.3.2.dist-info/RECORD,,
```
terrakio_core-0.3.0.dist-info/RECORD
REMOVED

```diff
@@ -1,14 +0,0 @@
-terrakio_core/__init__.py,sha256=iguSJomKouzVNPOB2_Ox-FGnQBUQ0ykx8CshjVzU1QM,88
-terrakio_core/auth.py,sha256=Nuj0_X3Hiy17svYgGxrSAR-LXpTlP0J0dSrfMnkPUbI,7717
-terrakio_core/client.py,sha256=CQ1qiR_8tWKEGX-UT2wLeatk8fYMpyo9KseMpCapw7c,56813
-terrakio_core/config.py,sha256=AwJ1VgR5K7N32XCU5k7_Dp1nIv_FYt8MBonq9yKlGzA,2658
-terrakio_core/dataset_management.py,sha256=LKUESSDPRu1JubQaQJWdPqHLGt-_Xv77Fpb4IM7vkzM,8751
-terrakio_core/exceptions.py,sha256=9S-I20-QiDRj1qgjFyYUwYM7BLic_bxurcDOIm2Fu_0,410
-terrakio_core/group_access_management.py,sha256=NJ7SX4keUzZAUENmJ5L6ynKf4eRlqtyir5uoKFyY17A,7315
-terrakio_core/mass_stats.py,sha256=AqYJsd6nqo2BDh4vEPUDgsv4T0UR1_TPDoXa3WO3gTU,9284
-terrakio_core/space_management.py,sha256=wlUUQrlj_4U_Lpjn9lbF5oj0Rv3NPvvnrd5mWej5kmA,4211
-terrakio_core/user_management.py,sha256=MMNWkz0V_9X7ZYjjteuRU4H4W3F16iuQw1dpA2wVTGg,7400
-terrakio_core-0.3.0.dist-info/METADATA,sha256=8mS_NJQUoFcr1lE3iUQXQi5VwSZo07t3XF0pCL7VNSI,1448
-terrakio_core-0.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-terrakio_core-0.3.0.dist-info/top_level.txt,sha256=5cBj6O7rNWyn97ND4YuvvXm0Crv4RxttT4JZvNdOG6Q,14
-terrakio_core-0.3.0.dist-info/RECORD,,
```
|
|
File without changes
|