terrakio-core 0.2.6__tar.gz → 0.2.8__tar.gz

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registry.

Potentially problematic release: this version of terrakio-core might be problematic.

Files changed (20)
  1. {terrakio_core-0.2.6 → terrakio_core-0.2.8}/PKG-INFO +1 -1
  2. {terrakio_core-0.2.6 → terrakio_core-0.2.8}/pyproject.toml +1 -1
  3. terrakio_core-0.2.8/terrakio_core/__init__.py +7 -0
  4. {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/client.py +284 -86
  5. {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/dataset_management.py +1 -1
  6. {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core.egg-info/PKG-INFO +1 -1
  7. terrakio_core-0.2.6/terrakio_core/__init__.py +0 -0
  8. {terrakio_core-0.2.6 → terrakio_core-0.2.8}/README.md +0 -0
  9. {terrakio_core-0.2.6 → terrakio_core-0.2.8}/setup.cfg +0 -0
  10. {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/auth.py +0 -0
  11. {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/config.py +0 -0
  12. {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/exceptions.py +0 -0
  13. {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/group_access_management.py +0 -0
  14. {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/mass_stats.py +0 -0
  15. {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/space_management.py +0 -0
  16. {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/user_management.py +0 -0
  17. {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core.egg-info/SOURCES.txt +0 -0
  18. {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core.egg-info/dependency_links.txt +0 -0
  19. {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core.egg-info/requires.txt +0 -0
  20. {terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core.egg-info/top_level.txt +0 -0
{terrakio_core-0.2.6 → terrakio_core-0.2.8}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: terrakio-core
-Version: 0.2.6
+Version: 0.2.8
 Summary: Core components for Terrakio API clients
 Author-email: Yupeng Chao <yupeng@haizea.com.au>
 Project-URL: Homepage, https://github.com/HaizeaAnalytics/terrakio-python-api
{terrakio_core-0.2.6 → terrakio_core-0.2.8}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "terrakio-core"
-version = "0.2.6"
+version = "0.2.8"
 authors = [
     {name = "Yupeng Chao", email = "yupeng@haizea.com.au"},
 ]
terrakio_core-0.2.8/terrakio_core/__init__.py (new file)

@@ -0,0 +1,7 @@
+"""
+Terrakio Core
+
+Core components for Terrakio API clients.
+"""
+
+__version__ = "0.2.8"
{terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/client.py

@@ -11,14 +11,17 @@ import xarray as xr
 import nest_asyncio
 from shapely.geometry import shape, mapping
 from shapely.geometry.base import BaseGeometry as ShapelyGeometry
-
+from google.cloud import storage
 from .exceptions import APIError, ConfigurationError
+import logging
+import textwrap
+

 class BaseClient:
     def __init__(self, url: Optional[str] = None, key: Optional[str] = None,
                  auth_url: Optional[str] = "https://dev-au.terrak.io",
                  quiet: bool = False, config_file: Optional[str] = None,
-                 verify: bool = True, timeout: int = 60):
+                 verify: bool = True, timeout: int = 300):
         nest_asyncio.apply()
         self.quiet = quiet
         self.verify = verify
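
The default request timeout rises from 60 to 300 seconds in this release. A minimal construction sketch, assuming terrakio_core exposes BaseClient from terrakio_core/client.py as the file list suggests; the URL and key values are placeholders, not from this diff:

    from terrakio_core.client import BaseClient

    client = BaseClient(
        url="https://api.terrak.io",  # placeholder endpoint
        key="YOUR_API_KEY",           # placeholder credential
        timeout=60,                   # override to restore the 0.2.6 default if 300 s is too long
    )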
@@ -86,9 +89,10 @@ class BaseClient:
         )
         return self._aiohttp_session

-    async def wcs_async(self, expr: str, feature: Union[Dict[str, Any], ShapelyGeometry],
-                        in_crs: str = "epsg:4326", out_crs: str = "epsg:4326",
-                        output: str = "csv", resolution: int = -1, **kwargs):
+    async def wcs_async(self, expr: str, feature: Union[Dict[str, Any], ShapelyGeometry],
+                        in_crs: str = "epsg:4326", out_crs: str = "epsg:4326",
+                        output: str = "csv", resolution: int = -1, buffer: bool = False,
+                        retry: int = 3, **kwargs):
         """
         Asynchronous version of the wcs() method using aiohttp.

@@ -99,6 +103,8 @@ class BaseClient:
             out_crs (str): Output coordinate reference system
             output (str): Output format ('csv' or 'netcdf')
             resolution (int): Resolution parameter
+            buffer (bool): Whether to buffer the request (default False)
+            retry (int): Number of retry attempts (default 3)
             **kwargs: Additional parameters to pass to the WCS request

         Returns:
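
A hedged usage sketch for the extended wcs_async signature; the expression is borrowed from a docstring example elsewhere in this diff, and every other value is a placeholder:

    import asyncio
    from shapely.geometry import Point
    from terrakio_core.client import BaseClient

    async def fetch_point():
        client = BaseClient(key="YOUR_API_KEY")  # placeholder credential
        try:
            # Shapely geometries are wrapped into a GeoJSON Feature automatically
            return await client.wcs_async(
                expr="MSWX.air_temperature@(year=2021, month=1)",
                feature=Point(149.1, -35.3),
                output="csv",
                buffer=False,  # new in 0.2.8
                retry=3,       # new in 0.2.8: retry transient failures up to 3 times
            )
        finally:
            await client.close_async()

    df = asyncio.run(fetch_point())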
@@ -111,8 +117,7 @@ class BaseClient:
                 "geometry": mapping(feature),
                 "properties": {}
             }
-        self.validate_feature(feature)
-
+
         payload = {
             "feature": feature,
             "in_crs": in_crs,
@@ -120,47 +125,65 @@ class BaseClient:
             "output": output,
             "resolution": resolution,
             "expr": expr,
+            "buffer": buffer,
             **kwargs
         }
-
         request_url = f"{self.url}/geoquery"
-
-        try:
-            # Get the shared aiohttp session
-            session = await self.aiohttp_session
-            async with session.post(request_url, json=payload, ssl=self.verify) as response:
-                if not response.ok:
-                    error_msg = f"API request failed: {response.status} {response.reason}"
-                    try:
-                        error_data = await response.json()
-                        if "detail" in error_data:
-                            error_msg += f" - {error_data['detail']}"
-                    except:
-                        pass
-                    raise APIError(error_msg)
-
-                content = await response.read()
-
-                if output.lower() == "csv":
-                    import pandas as pd
-                    df = pd.read_csv(BytesIO(content))
-                    return df
-                elif output.lower() == "netcdf":
-                    return xr.open_dataset(BytesIO(content))
-                else:
-                    try:
-                        return xr.open_dataset(BytesIO(content))
-                    except ValueError:
-                        import pandas as pd
-                        try:
-                            return pd.read_csv(BytesIO(content))
-                        except:
-                            return content
-
-        except aiohttp.ClientError as e:
-            raise APIError(f"Request failed: {str(e)}")
-        except Exception as e:
-            raise
+        for attempt in range(retry + 1):
+            try:
+                session = await self.aiohttp_session
+                async with session.post(request_url, json=payload, ssl=self.verify) as response:
+                    if not response.ok:
+                        should_retry = False
+                        if response.status in [408, 502, 503, 504]:
+                            should_retry = True
+                        elif response.status == 500:
+                            try:
+                                response_text = await response.text()
+                                if "Internal server error" not in response_text:
+                                    should_retry = True
+                            except:
+                                should_retry = True
+
+                        if should_retry and attempt < retry:
+                            continue
+                        else:
+                            error_msg = f"API request failed: {response.status} {response.reason}"
+                            try:
+                                error_data = await response.json()
+                                if "detail" in error_data:
+                                    error_msg += f" - {error_data['detail']}"
+                            except:
+                                pass
+                            raise APIError(error_msg)
+
+                    content = await response.read()
+
+                    if output.lower() == "csv":
+                        import pandas as pd
+                        df = pd.read_csv(BytesIO(content))
+                        return df
+                    elif output.lower() == "netcdf":
+                        return xr.open_dataset(BytesIO(content))
+                    else:
+                        try:
+                            return xr.open_dataset(BytesIO(content))
+                        except ValueError:
+                            import pandas as pd
+                            try:
+                                return pd.read_csv(BytesIO(content))
+                            except:
+                                return content
+
+            except aiohttp.ClientError as e:
+                if attempt == retry:
+                    raise APIError(f"Request failed: {str(e)}")
+                continue
+            except Exception as e:
+                if attempt == retry:
+                    raise
+                continue

     async def close_async(self):
         """Close the aiohttp session"""
@@ -174,41 +197,6 @@ class BaseClient:
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         await self.close_async()

-    def validate_feature(self, feature: Dict[str, Any]) -> None:
-        if hasattr(feature, 'is_valid'):
-            from shapely.geometry import mapping
-            feature = {
-                "type": "Feature",
-                "geometry": mapping(feature),
-                "properties": {}
-            }
-        if not isinstance(feature, dict):
-            raise ValueError("Feature must be a dictionary or a Shapely geometry")
-        if feature.get("type") != "Feature":
-            raise ValueError("GeoJSON object must be of type 'Feature'")
-        if "geometry" not in feature:
-            raise ValueError("Feature must contain a 'geometry' field")
-        if "properties" not in feature:
-            raise ValueError("Feature must contain a 'properties' field")
-        try:
-            geometry = shape(feature["geometry"])
-        except Exception as e:
-            raise ValueError(f"Invalid geometry format: {str(e)}")
-        if not geometry.is_valid:
-            raise ValueError(f"Invalid geometry: {geometry.is_valid_reason}")
-        geom_type = feature["geometry"]["type"]
-        if geom_type == "Point":
-            if len(feature["geometry"]["coordinates"]) != 2:
-                raise ValueError("Point must have exactly 2 coordinates")
-        elif geom_type == "Polygon":
-            if not geometry.is_simple:
-                raise ValueError("Polygon must be simple (not self-intersecting)")
-            if geometry.area == 0:
-                raise ValueError("Polygon must have non-zero area")
-            coords = feature["geometry"]["coordinates"][0]
-            if coords[0] != coords[-1]:
-                raise ValueError("Polygon must be closed (first and last points must match)")
-
     def signup(self, email: str, password: str) -> Dict[str, Any]:
         if not self.auth_client:
             raise ConfigurationError("Authentication client not initialized. Please provide auth_url during client initialization.")
@@ -309,7 +297,6 @@ class BaseClient:
                 "geometry": mapping(feature),
                 "properties": {}
             }
-        self.validate_feature(feature)
         payload = {
             "feature": feature,
             "in_crs": in_crs,
@@ -321,7 +308,10 @@ class BaseClient:
         }
         request_url = f"{self.url}/geoquery"
         try:
+            print("the request url is ", request_url)
+            print("the payload is ", payload)
             response = self.session.post(request_url, json=payload, timeout=self.timeout, verify=self.verify)
+            print("the response is ", response.text)
             if not response.ok:
                 error_msg = f"API request failed: {response.status_code} {response.reason}"
                 try:
@@ -690,15 +680,39 @@ class BaseClient:
         )
         return self.mass_stats.random_sample(name, **kwargs)

-    async def zonal_stats_async(self, gdb, expr, conc=20, inplace=False, output="csv"):
+    async def zonal_stats_async(self, gdb, expr, conc=20, inplace=False, output="csv",
+                                in_crs="epsg:4326", out_crs="epsg:4326", resolution=-1, buffer=False):
         """
         Compute zonal statistics for all geometries in a GeoDataFrame using asyncio for concurrency.
+
+        Args:
+            gdb (geopandas.GeoDataFrame): GeoDataFrame containing geometries
+            expr (str): Terrakio expression to evaluate, can include spatial aggregations
+            conc (int): Number of concurrent requests to make
+            inplace (bool): Whether to modify the input GeoDataFrame in place
+            output (str): Output format (csv or netcdf)
+            in_crs (str): Input coordinate reference system
+            out_crs (str): Output coordinate reference system
+            resolution (int): Resolution parameter
+            buffer (bool): Whether to buffer the request (default False)
+
+        Returns:
+            geopandas.GeoDataFrame: GeoDataFrame with added columns for results, or None if inplace=True
         """
+        if conc > 100:
+            raise ValueError("Concurrency (conc) is too high. Please set conc to 100 or less.")

         # Process geometries in batches
         all_results = []
         row_indices = []

+        # Calculate total batches for progress reporting
+        total_geometries = len(gdb)
+        total_batches = (total_geometries + conc - 1) // conc  # Ceiling division
+        completed_batches = 0
+
+        print(f"Processing {total_geometries} geometries with concurrency {conc}")
+
         async def process_geometry(geom, index):
             """Process a single geometry"""
             try:
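
The progress counters rely on ceiling division to compute the batch count; a tiny worked example:

    total_geometries = 250
    conc = 20
    total_batches = (total_geometries + conc - 1) // conc
    # 250 geometries at concurrency 20 -> 13 batches (12 full, 1 with 10)
    assert total_batches == 13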
@@ -707,7 +721,8 @@ class BaseClient:
                     "geometry": mapping(geom),
                     "properties": {"index": index}
                 }
-                result = await self.wcs_async(expr=expr, feature=feature, output=output)
+                result = await self.wcs_async(expr=expr, feature=feature, output=output,
+                                              in_crs=in_crs, out_crs=out_crs, resolution=resolution, buffer=buffer)
                 # Add original index to track which geometry this result belongs to
                 if isinstance(result, pd.DataFrame):
                     result['_geometry_index'] = index
@@ -749,11 +764,19 @@ class BaseClient:
                 batch_results = await process_batch(batch_indices)
                 all_results.extend(batch_results)
                 row_indices.extend(batch_indices)
+
+                # Update progress
+                completed_batches += 1
+                processed_geometries = min(i + conc, total_geometries)
+                print(f"Progress: {completed_batches}/{total_batches} completed ({processed_geometries}/{total_geometries} geometries processed)")
+
         except Exception as e:
             if hasattr(e, 'response'):
                 raise APIError(f"API request failed: {e.response.text}")
             raise

+        print("All batches completed! Processing results...")
+
         if not all_results:
             raise ValueError("No valid results were returned for any geometry")

@@ -845,7 +868,8 @@ class BaseClient:
         else:
             return result_gdf

-    def zonal_stats(self, gdb, expr, conc=20, inplace=False, output="csv"):
+    def zonal_stats(self, gdb, expr, conc=20, inplace=False, output="csv",
+                    in_crs="epsg:4326", out_crs="epsg:4326", resolution=-1, buffer=False):
         """
         Compute zonal statistics for all geometries in a GeoDataFrame.

@@ -855,32 +879,44 @@ class BaseClient:
             conc (int): Number of concurrent requests to make
             inplace (bool): Whether to modify the input GeoDataFrame in place
             output (str): Output format (csv or netcdf)
+            in_crs (str): Input coordinate reference system
+            out_crs (str): Output coordinate reference system
+            resolution (int): Resolution parameter
+            buffer (bool): Whether to buffer the request (default False)

         Returns:
             geopandas.GeoDataFrame: GeoDataFrame with added columns for results, or None if inplace=True
         """
+        if conc > 100:
+            raise ValueError("Concurrency (conc) is too high. Please set conc to 100 or less.")
         import asyncio

+        print(f"Starting zonal statistics computation for expression: {expr}")
+
         # Check if we're in a Jupyter environment or already have an event loop
         try:
             loop = asyncio.get_running_loop()
             # We're in an async context (like Jupyter), use create_task
             nest_asyncio.apply()
-            result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output))
+            result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output,
+                                                        in_crs, out_crs, resolution, buffer))
         except RuntimeError:
             # No running event loop, safe to use asyncio.run()
-            result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output))
+            result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output,
+                                                        in_crs, out_crs, resolution, buffer))
         except ImportError:
             # nest_asyncio not available, try alternative approach
             try:
                 loop = asyncio.get_running_loop()
                 # Create task in existing loop
-                task = loop.create_task(self.zonal_stats_async(gdb, expr, conc, inplace, output))
+                task = loop.create_task(self.zonal_stats_async(gdb, expr, conc, inplace, output,
+                                                               in_crs, out_crs, resolution, buffer))
                 # This won't work directly - we need a different approach
                 raise RuntimeError("Cannot run async code in Jupyter without nest_asyncio. Please install: pip install nest-asyncio")
             except RuntimeError:
                 # No event loop, use asyncio.run
-                result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output))
+                result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output,
+                                                            in_crs, out_crs, resolution, buffer))

         # Ensure aiohttp session is closed after running async code
         try:
@@ -890,6 +926,7 @@ class BaseClient:
             # Event loop may already be closed, ignore
             pass

+        print("Zonal statistics computation completed!")
         return result

     # Group access management protected methods
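
Putting the zonal_stats changes together, a hedged end-to-end sketch; the GeoJSON file and expression are invented for illustration:

    import geopandas as gpd
    from terrakio_core.client import BaseClient

    gdf = gpd.read_file("paddocks.geojson")  # placeholder input file
    client = BaseClient(key="YOUR_API_KEY")  # placeholder credential
    result = client.zonal_stats(
        gdf,
        expr="mean(MSWX.air_temperature@(year=2021, month=1))",  # illustrative expression
        conc=20,             # must be 100 or less as of 0.2.8
        in_crs="epsg:4326",  # new pass-through parameters in 0.2.8
        out_crs="epsg:4326",
        resolution=-1,
        buffer=False,
    )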
@@ -1011,6 +1048,80 @@ class BaseClient:
             timeout=self.timeout
         )
         return self.space_management.delete_data_in_path(path, region)
+
+    def generate_ai_dataset(
+        self,
+        name: str,
+        aoi_geojson: str,
+        expression_x: str,
+        expression_y: str,
+        samples: int,
+        tile_size: int,
+        crs: str = "epsg:4326",
+        res: float = 0.001,
+        region: str = "aus",
+        start_year: int = None,
+        end_year: int = None,
+    ) -> dict:
+        """
+        Generate an AI dataset using specified parameters.
+
+        Args:
+            name (str): Name of the dataset to generate
+            aoi_geojson (str): Path to GeoJSON file containing area of interest
+            expression_x (str): Expression for X variable (e.g. "MSWX.air_temperature@(year=2021, month=1)")
+            expression_y (str): Expression for Y variable with {year} placeholder
+            samples (int): Number of samples to generate
+            tile_size (int): Size of tiles in degrees
+            crs (str, optional): Coordinate reference system. Defaults to "epsg:4326"
+            res (float, optional): Resolution in degrees. Defaults to 0.001
+            region (str, optional): Region code. Defaults to "aus"
+            start_year (int, optional): Start year for data generation. Required if end_year is provided
+            end_year (int, optional): End year for data generation. Required if start_year is provided
+
+        Returns:
+            dict: Response from the AI dataset generation API
+
+        Raises:
+            ValidationError: If required parameters are missing or invalid
+            APIError: If the API request fails
+        """
+        # Build the sampling config, then pass the parameters through random_sample
+        config = {
+            "expressions": [{"expr": expression_x, "res": res, "prefix": "x"}],
+            "filters": []
+        }
+        config["expressions"].append({"expr": expression_y, "res": res, "prefix": "y"})
+
+        expression_x = expression_x.replace("{year}", str(start_year))
+        expression_y = expression_y.replace("{year}", str(start_year))
+        print("the aoi geojson is ", aoi_geojson)
+        with open(aoi_geojson, 'r') as f:
+            aoi_data = json.load(f)
+        print("the config is ", config)
+        task_id = self.random_sample(
+            name=name,
+            config=config,
+            aoi=aoi_data,
+            samples=samples,
+            year_range=[start_year, end_year],
+            crs=crs,
+            tile_size=tile_size,
+            res=res,
+            region=region,
+            output="netcdf",
+            server=self.url,
+            bucket="terrakio-mass-requests",
+            overwrite=True
+        )["task_id"]
+        print("the task id is ", task_id)
+        task_id = self.start_mass_stats_job(task_id)
+        print("the task id is ", task_id)
+        return task_id
+

     def train_model(self, model_name: str, training_data: dict) -> dict:
         """
@@ -1044,3 +1155,90 @@ class BaseClient:
         except requests.RequestException as e:
             raise APIError(f"Model training request failed: {str(e)}")

+    def deploy_model(self, dataset: str, product: str, model_name: str, input_expression: str, model_training_job_name: str, uid: str, dates_iso8601: list):
+        # Deployment has three steps: generate an inference script for the model
+        # (one uniform script per random forest deployment), upload it to the
+        # Google Cloud bucket where it can download the model and run inference,
+        # then register the JSON description as a virtual dataset. After that we
+        # simply wait for the user to make a request call to the explorer.
+        script_content = self._generate_script(model_name, product, model_training_job_name, uid)
+        script_name = f"{product}.py"
+        print("the script content is ", script_content)
+        print("the script name is ", script_name)
+        self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
+        # After uploading the script, create the new virtual dataset that points at it
+        self._create_dataset(name=dataset, collection="terrakio-datasets", products=[product], path=f"gs://terrakio-mass-requests/{uid}/{model_training_job_name}/inference_scripts", input=input_expression, dates_iso8601=dates_iso8601, padding=0)
+
+    def _generate_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
+        return textwrap.dedent(f'''
+            import logging
+            from io import BytesIO
+            from google.cloud import storage
+            from onnxruntime import InferenceSession
+            import numpy as np
+            import xarray as xr
+            import datetime
+
+            logging.basicConfig(
+                level=logging.INFO
+            )
+
+            def get_model():
+                logging.info("Loading model for {model_name}...")
+
+                client = storage.Client()
+                bucket = client.get_bucket('terrakio-mass-requests')
+                blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')
+
+                model = BytesIO()
+                blob.download_to_file(model)
+                model.seek(0)
+
+                session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
+                return session
+
+            def {product}(*bands, model):
+                logging.info("start preparing data")
+
+                original_shape = bands[0].shape
+                logging.info(f"Original shape: {{original_shape}}")
+
+                transformed_bands = []
+                for band in bands:
+                    transformed_band = band.values.reshape(-1, 1)
+                    transformed_bands.append(transformed_band)
+
+                input_data = np.hstack(transformed_bands)
+
+                logging.info(f"Final input shape: {{input_data.shape}}")
+
+                output = model.run(None, {{"float_input": input_data.astype(np.float32)}})[0]
+
+                logging.info(f"Model output shape: {{output.shape}}")
+
+                output_reshaped = output.reshape(original_shape)
+                result = xr.DataArray(
+                    data=output_reshaped,
+                    dims=bands[0].dims,
+                    coords=bands[0].coords
+                )
+
+                return result
+        ''').strip()
+
+    def _upload_script_to_bucket(self, script_content: str, script_name: str, model_training_job_name: str, uid: str):
+        """Upload the generated script to Google Cloud Storage"""
+        # Bucket layout: the first level is the uid, the second is the model training job name
+        client = storage.Client()
+        bucket = client.get_bucket('terrakio-mass-requests')
+        blob = bucket.blob(f'{uid}/{model_training_job_name}/inference_scripts/{script_name}')
+        blob.upload_from_string(script_content, content_type='text/plain')
+        logging.info(f"Script uploaded successfully to {uid}/{model_training_job_name}/inference_scripts/{script_name}")
+
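
And a hedged sketch of deploy_model, which wires a trained ONNX model into a virtual dataset; every identifier below is a placeholder, though the bucket layout mirrors the code above:

    client = BaseClient(key="YOUR_API_KEY")   # placeholder credential
    client.deploy_model(
        dataset="MyVirtualDataset",           # placeholder dataset name
        product="tree_cover",                 # becomes the generated function and script name
        model_name="rf_v1",                   # expects rf_v1.onnx under .../models/ in the bucket
        input_expression="S2.B04, S2.B08",    # invented input expression
        model_training_job_name="job_2024_01",
        uid="user-123",
        dates_iso8601=["2024-01-01"],
    )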
{terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core/dataset_management.py

@@ -117,7 +117,7 @@ class DatasetManagement:

         # Add optional parameters if provided
         for param in ["products", "dates_iso8601", "bucket", "path", "data_type",
-                      "no_data", "l_max", "y_size", "x_size", "proj4", "abstract", "geotransform"]:
+                      "no_data", "l_max", "y_size", "x_size", "proj4", "abstract", "geotransform", "input"]:
             if param in kwargs:
                 payload[param] = kwargs[param]
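
With "input" now whitelisted, dataset creation can forward the expression that deploy_model passes; a hedged sketch mirroring that call (all values are placeholders):

    client = BaseClient(key="YOUR_API_KEY")  # placeholder credential
    client._create_dataset(
        name="MyVirtualDataset",
        collection="terrakio-datasets",
        products=["tree_cover"],
        path="gs://terrakio-mass-requests/user-123/job_2024_01/inference_scripts",
        input="S2.B04, S2.B08",              # "input" is newly accepted in 0.2.8
        dates_iso8601=["2024-01-01"],
        padding=0,
    )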
{terrakio_core-0.2.6 → terrakio_core-0.2.8}/terrakio_core.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: terrakio-core
-Version: 0.2.6
+Version: 0.2.8
 Summary: Core components for Terrakio API clients
 Author-email: Yupeng Chao <yupeng@haizea.com.au>
 Project-URL: Homepage, https://github.com/HaizeaAnalytics/terrakio-python-api