terrakio-core 0.2.4__tar.gz → 0.2.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of terrakio-core might be problematic.

Files changed (20)
  1. {terrakio_core-0.2.4 → terrakio_core-0.2.7}/PKG-INFO +1 -1
  2. {terrakio_core-0.2.4 → terrakio_core-0.2.7}/pyproject.toml +1 -1
  3. terrakio_core-0.2.7/terrakio_core/__init__.py +7 -0
  4. {terrakio_core-0.2.4 → terrakio_core-0.2.7}/terrakio_core/client.py +306 -82
  5. {terrakio_core-0.2.4 → terrakio_core-0.2.7}/terrakio_core/dataset_management.py +1 -1
  6. {terrakio_core-0.2.4 → terrakio_core-0.2.7}/terrakio_core.egg-info/PKG-INFO +1 -1
  7. terrakio_core-0.2.4/terrakio_core/__init__.py +0 -0
  8. {terrakio_core-0.2.4 → terrakio_core-0.2.7}/README.md +0 -0
  9. {terrakio_core-0.2.4 → terrakio_core-0.2.7}/setup.cfg +0 -0
  10. {terrakio_core-0.2.4 → terrakio_core-0.2.7}/terrakio_core/auth.py +0 -0
  11. {terrakio_core-0.2.4 → terrakio_core-0.2.7}/terrakio_core/config.py +0 -0
  12. {terrakio_core-0.2.4 → terrakio_core-0.2.7}/terrakio_core/exceptions.py +0 -0
  13. {terrakio_core-0.2.4 → terrakio_core-0.2.7}/terrakio_core/group_access_management.py +0 -0
  14. {terrakio_core-0.2.4 → terrakio_core-0.2.7}/terrakio_core/mass_stats.py +0 -0
  15. {terrakio_core-0.2.4 → terrakio_core-0.2.7}/terrakio_core/space_management.py +0 -0
  16. {terrakio_core-0.2.4 → terrakio_core-0.2.7}/terrakio_core/user_management.py +0 -0
  17. {terrakio_core-0.2.4 → terrakio_core-0.2.7}/terrakio_core.egg-info/SOURCES.txt +0 -0
  18. {terrakio_core-0.2.4 → terrakio_core-0.2.7}/terrakio_core.egg-info/dependency_links.txt +0 -0
  19. {terrakio_core-0.2.4 → terrakio_core-0.2.7}/terrakio_core.egg-info/requires.txt +0 -0
  20. {terrakio_core-0.2.4 → terrakio_core-0.2.7}/terrakio_core.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: terrakio-core
-Version: 0.2.4
+Version: 0.2.7
 Summary: Core components for Terrakio API clients
 Author-email: Yupeng Chao <yupeng@haizea.com.au>
 Project-URL: Homepage, https://github.com/HaizeaAnalytics/terrakio-python-api
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "terrakio-core"
-version = "0.2.4"
+version = "0.2.7"
 authors = [
     {name = "Yupeng Chao", email = "yupeng@haizea.com.au"},
 ]
@@ -0,0 +1,7 @@
+"""
+Terrakio Core
+
+Core components for Terrakio API clients.
+"""
+
+__version__ = "0.2.7"
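The new terrakio_core/__init__.py gives the package a module docstring and a __version__ attribute. A quick post-install sanity check (a minimal sketch):

    import terrakio_core

    # Expected to print 0.2.7 for this release
    print(terrakio_core.__version__)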
@@ -8,16 +8,21 @@ import aiohttp
 import pandas as pd
 import geopandas as gpd
 import xarray as xr
+import nest_asyncio
 from shapely.geometry import shape, mapping
 from shapely.geometry.base import BaseGeometry as ShapelyGeometry
-
+from google.cloud import storage
 from .exceptions import APIError, ConfigurationError
+import logging
+import textwrap
+

 class BaseClient:
     def __init__(self, url: Optional[str] = None, key: Optional[str] = None,
                  auth_url: Optional[str] = "https://dev-au.terrak.io",
                  quiet: bool = False, config_file: Optional[str] = None,
-                 verify: bool = True, timeout: int = 60):
+                 verify: bool = True, timeout: int = 300):
+        nest_asyncio.apply()
         self.quiet = quiet
         self.verify = verify
         self.timeout = timeout
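Note on the nest_asyncio.apply() call added to __init__: it patches the current event loop so asyncio.run() can be invoked re-entrantly, which is what lets the synchronous wrappers later in this file work inside Jupyter. A minimal sketch of the pattern this enables (illustrative only, not part of the package):

    import asyncio
    import nest_asyncio

    nest_asyncio.apply()  # allow re-entrant asyncio.run() inside a running loop

    async def fetch():
        await asyncio.sleep(0)  # stand-in for an aiohttp request
        return "done"

    # In Jupyter a loop is already running; without apply() asyncio.run() raises RuntimeError
    print(asyncio.run(fetch()))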
@@ -84,9 +89,10 @@ class BaseClient:
         )
         return self._aiohttp_session

-    async def wcs_async(self, expr: str, feature: Union[Dict[str, Any], ShapelyGeometry],
-                        in_crs: str = "epsg:4326", out_crs: str = "epsg:4326",
-                        output: str = "csv", resolution: int = -1, **kwargs):
+    async def wcs_async(self, expr: str, feature: Union[Dict[str, Any], ShapelyGeometry],
+                        in_crs: str = "epsg:4326", out_crs: str = "epsg:4326",
+                        output: str = "csv", resolution: int = -1, buffer: bool = True,
+                        retry: int = 3, **kwargs):
         """
         Asynchronous version of the wcs() method using aiohttp.

@@ -97,6 +103,8 @@ class BaseClient:
             out_crs (str): Output coordinate reference system
             output (str): Output format ('csv' or 'netcdf')
             resolution (int): Resolution parameter
+            buffer (bool): Whether to buffer the request (default True)
+            retry (int): Number of retry attempts (default 3)
             **kwargs: Additional parameters to pass to the WCS request

         Returns:
@@ -109,8 +117,7 @@ class BaseClient:
                 "geometry": mapping(feature),
                 "properties": {}
             }
-        self.validate_feature(feature)
-
+
         payload = {
             "feature": feature,
             "in_crs": in_crs,
@@ -118,47 +125,68 @@ class BaseClient:
             "output": output,
             "resolution": resolution,
             "expr": expr,
+            "buffer": buffer,
+            "resolution": resolution,
             **kwargs
         }

         request_url = f"{self.url}/geoquery"

-        try:
-            # Get the shared aiohttp session
-            session = await self.aiohttp_session
-            async with session.post(request_url, json=payload, ssl=self.verify) as response:
-                if not response.ok:
-                    error_msg = f"API request failed: {response.status} {response.reason}"
-                    try:
-                        error_data = await response.json()
-                        if "detail" in error_data:
-                            error_msg += f" - {error_data['detail']}"
-                    except:
-                        pass
-                    raise APIError(error_msg)
-
-                content = await response.read()
-
-                if output.lower() == "csv":
-                    import pandas as pd
-                    df = pd.read_csv(BytesIO(content))
-                    return df
-                elif output.lower() == "netcdf":
-                    return xr.open_dataset(BytesIO(content))
-                else:
-                    try:
-                        return xr.open_dataset(BytesIO(content))
-                    except ValueError:
+        for attempt in range(retry + 1):
+            try:
+                session = await self.aiohttp_session
+                async with session.post(request_url, json=payload, ssl=self.verify) as response:
+                    if not response.ok:
+                        should_retry = False
+
+                        if response.status in [408, 502, 503, 504]:
+                            should_retry = True
+                        elif response.status == 500:
+                            try:
+                                response_text = await response.text()
+                                if "Internal server error" not in response_text:
+                                    should_retry = True
+                            except:
+                                should_retry = True
+
+                        if should_retry and attempt < retry:
+                            continue
+                        else:
+                            error_msg = f"API request failed: {response.status} {response.reason}"
+                            try:
+                                error_data = await response.json()
+                                if "detail" in error_data:
+                                    error_msg += f" - {error_data['detail']}"
+                            except:
+                                pass
+                            raise APIError(error_msg)
+
+                    content = await response.read()
+
+                    if output.lower() == "csv":
                         import pandas as pd
+                        df = pd.read_csv(BytesIO(content))
+                        return df
+                    elif output.lower() == "netcdf":
+                        return xr.open_dataset(BytesIO(content))
+                    else:
                         try:
-                        return pd.read_csv(BytesIO(content))
-                    except:
-                        return content
-
-        except aiohttp.ClientError as e:
-            raise APIError(f"Request failed: {str(e)}")
-        except Exception as e:
-            raise
+                            return xr.open_dataset(BytesIO(content))
+                        except ValueError:
+                            import pandas as pd
+                            try:
+                                return pd.read_csv(BytesIO(content))
+                            except:
+                                return content
+
+            except aiohttp.ClientError as e:
+                if attempt == retry:
+                    raise APIError(f"Request failed: {str(e)}")
+                continue
+            except Exception as e:
+                if attempt == retry:
+                    raise
+                continue

     async def close_async(self):
         """Close the aiohttp session"""
@@ -172,41 +200,6 @@ class BaseClient:
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         await self.close_async()

-    def validate_feature(self, feature: Dict[str, Any]) -> None:
-        if hasattr(feature, 'is_valid'):
-            from shapely.geometry import mapping
-            feature = {
-                "type": "Feature",
-                "geometry": mapping(feature),
-                "properties": {}
-            }
-        if not isinstance(feature, dict):
-            raise ValueError("Feature must be a dictionary or a Shapely geometry")
-        if feature.get("type") != "Feature":
-            raise ValueError("GeoJSON object must be of type 'Feature'")
-        if "geometry" not in feature:
-            raise ValueError("Feature must contain a 'geometry' field")
-        if "properties" not in feature:
-            raise ValueError("Feature must contain a 'properties' field")
-        try:
-            geometry = shape(feature["geometry"])
-        except Exception as e:
-            raise ValueError(f"Invalid geometry format: {str(e)}")
-        if not geometry.is_valid:
-            raise ValueError(f"Invalid geometry: {geometry.is_valid_reason}")
-        geom_type = feature["geometry"]["type"]
-        if geom_type == "Point":
-            if len(feature["geometry"]["coordinates"]) != 2:
-                raise ValueError("Point must have exactly 2 coordinates")
-        elif geom_type == "Polygon":
-            if not geometry.is_simple:
-                raise ValueError("Polygon must be simple (not self-intersecting)")
-            if geometry.area == 0:
-                raise ValueError("Polygon must have non-zero area")
-            coords = feature["geometry"]["coordinates"][0]
-            if coords[0] != coords[-1]:
-                raise ValueError("Polygon must be closed (first and last points must match)")
-
     def signup(self, email: str, password: str) -> Dict[str, Any]:
         if not self.auth_client:
             raise ConfigurationError("Authentication client not initialized. Please provide auth_url during client initialization.")
@@ -307,7 +300,6 @@ class BaseClient:
                 "geometry": mapping(feature),
                 "properties": {}
             }
-        self.validate_feature(feature)
         payload = {
             "feature": feature,
             "in_crs": in_crs,
@@ -319,7 +311,10 @@ class BaseClient:
         }
         request_url = f"{self.url}/geoquery"
         try:
+            print("the request url is ", request_url)
+            print("the payload is ", payload)
             response = self.session.post(request_url, json=payload, timeout=self.timeout, verify=self.verify)
+            print("the response is ", response.text)
             if not response.ok:
                 error_msg = f"API request failed: {response.status_code} {response.reason}"
                 try:
@@ -519,7 +514,24 @@ class BaseClient:
             self.auth_client.session.close()
         # Close aiohttp session if it exists
         if self._aiohttp_session and not self._aiohttp_session.closed:
-            asyncio.run(self.close_async())
+            try:
+                nest_asyncio.apply()
+                asyncio.run(self.close_async())
+            except ImportError:
+                try:
+                    asyncio.run(self.close_async())
+                except RuntimeError as e:
+                    if "cannot be called from a running event loop" in str(e):
+                        # In Jupyter, we can't properly close the async session
+                        # Log a warning or handle gracefully
+                        import warnings
+                        warnings.warn("Cannot properly close aiohttp session in Jupyter environment. "
+                                      "Consider using 'await client.close_async()' instead.")
+                    else:
+                        raise
+            except RuntimeError:
+                # Event loop may already be closed, ignore
+                pass

     def __enter__(self):
         return self
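Given the guarded cleanup above, notebook users can sidestep the __del__ warning entirely by closing the session explicitly; a minimal sketch (endpoint and key are placeholders):

    import asyncio
    from terrakio_core.client import BaseClient

    client = BaseClient(url="https://api.terrak.io", key="YOUR_KEY")
    try:
        pass  # ... issue requests ...
    finally:
        # In a notebook cell this can simply be: await client.close_async()
        asyncio.run(client.close_async())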
@@ -671,10 +683,27 @@ class BaseClient:
         )
         return self.mass_stats.random_sample(name, **kwargs)

-    async def zonal_stats_async(self, gdb, expr, conc=20, inplace=False, output="csv"):
+    async def zonal_stats_async(self, gdb, expr, conc=20, inplace=False, output="csv",
+                                in_crs="epsg:4326", out_crs="epsg:4326", resolution=0.005, buffer=True):
         """
         Compute zonal statistics for all geometries in a GeoDataFrame using asyncio for concurrency.
+
+        Args:
+            gdb (geopandas.GeoDataFrame): GeoDataFrame containing geometries
+            expr (str): Terrakio expression to evaluate, can include spatial aggregations
+            conc (int): Number of concurrent requests to make
+            inplace (bool): Whether to modify the input GeoDataFrame in place
+            output (str): Output format (csv or netcdf)
+            in_crs (str): Input coordinate reference system
+            out_crs (str): Output coordinate reference system
+            resolution (int): Resolution parameter
+            buffer (bool): Whether to buffer the request (default True)
+
+        Returns:
+            geopandas.GeoDataFrame: GeoDataFrame with added columns for results, or None if inplace=True
         """
+        if conc > 100:
+            raise ValueError("Concurrency (conc) is too high. Please set conc to 100 or less.")

         # Process geometries in batches
         all_results = []
@@ -688,7 +717,8 @@ class BaseClient:
                 "geometry": mapping(geom),
                 "properties": {"index": index}
             }
-            result = await self.wcs_async(expr=expr, feature=feature, output=output)
+            result = await self.wcs_async(expr=expr, feature=feature, output=output,
+                                          in_crs=in_crs, out_crs=out_crs, resolution=resolution, buffer=buffer)
             # Add original index to track which geometry this result belongs to
             if isinstance(result, pd.DataFrame):
                 result['_geometry_index'] = index
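A usage sketch of the extended zonal_stats_async; the file name, endpoint, key, and expression are illustrative assumptions:

    import asyncio
    import geopandas as gpd
    from terrakio_core.client import BaseClient

    async def main():
        gdf = gpd.read_file("paddocks.geojson")  # hypothetical input file
        async with BaseClient(url="https://api.terrak.io", key="YOUR_KEY") as client:
            result = await client.zonal_stats_async(
                gdf,
                expr="mean(MSWX.air_temperature@(year=2021, month=1))",
                conc=20,
                resolution=0.005,
                buffer=True,
            )
            print(result.head())

    asyncio.run(main())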
@@ -826,7 +856,8 @@ class BaseClient:
         else:
             return result_gdf

-    def zonal_stats(self, gdb, expr, conc=20, inplace=False, output="csv"):
+    def zonal_stats(self, gdb, expr, conc=20, inplace=False, output="csv",
+                    in_crs="epsg:4326", out_crs="epsg:4326", resolution=0.005, buffer=True):
         """
         Compute zonal statistics for all geometries in a GeoDataFrame.

@@ -836,12 +867,43 @@ class BaseClient:
             conc (int): Number of concurrent requests to make
             inplace (bool): Whether to modify the input GeoDataFrame in place
             output (str): Output format (csv or netcdf)
+            in_crs (str): Input coordinate reference system
+            out_crs (str): Output coordinate reference system
+            resolution (int): Resolution parameter
+            buffer (bool): Whether to buffer the request (default True)

         Returns:
             geopandas.GeoDataFrame: GeoDataFrame with added columns for results, or None if inplace=True
         """
+        if conc > 100:
+            raise ValueError("Concurrency (conc) is too high. Please set conc to 100 or less.")
         import asyncio
-        result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output))
+
+        # Check if we're in a Jupyter environment or already have an event loop
+        try:
+            loop = asyncio.get_running_loop()
+            # We're in an async context (like Jupyter), use create_task
+            nest_asyncio.apply()
+            result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output,
+                                                        in_crs, out_crs, resolution, buffer))
+        except RuntimeError:
+            # No running event loop, safe to use asyncio.run()
+            result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output,
+                                                        in_crs, out_crs, resolution, buffer))
+        except ImportError:
+            # nest_asyncio not available, try alternative approach
+            try:
+                loop = asyncio.get_running_loop()
+                # Create task in existing loop
+                task = loop.create_task(self.zonal_stats_async(gdb, expr, conc, inplace, output,
+                                                               in_crs, out_crs, resolution, buffer))
+                # This won't work directly - we need a different approach
+                raise RuntimeError("Cannot run async code in Jupyter without nest_asyncio. Please install: pip install nest-asyncio")
+            except RuntimeError:
+                # No event loop, use asyncio.run
+                result = asyncio.run(self.zonal_stats_async(gdb, expr, conc, inplace, output,
+                                                            in_crs, out_crs, resolution, buffer))
+

         # Ensure aiohttp session is closed after running async code
         try:
             if self._aiohttp_session and not self._aiohttp_session.closed:
@@ -849,6 +911,7 @@ class BaseClient:
         except RuntimeError:
             # Event loop may already be closed, ignore
             pass
+
         return result

     # Group access management protected methods
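And the synchronous wrapper, which now handles an already-running event loop itself; the same caveats about placeholder values apply:

    import geopandas as gpd
    from terrakio_core.client import BaseClient

    gdf = gpd.read_file("paddocks.geojson")  # hypothetical input file
    client = BaseClient(url="https://api.terrak.io", key="YOUR_KEY")

    # conc is now capped at 100 by the new guard
    result = client.zonal_stats(gdf, expr="mean(MSWX.air_temperature@(year=2021, month=1))", conc=20)
    print(result.head())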
@@ -970,6 +1033,80 @@ class BaseClient:
             timeout=self.timeout
         )
         return self.space_management.delete_data_in_path(path, region)
+
+    def generate_ai_dataset(
+        self,
+        name: str,
+        aoi_geojson: str,
+        expression_x: str,
+        expression_y: str,
+        samples: int,
+        tile_size: int,
+        crs: str = "epsg:4326",
+        res: float = 0.001,
+        region: str = "aus",
+        start_year: int = None,
+        end_year: int = None,
+    ) -> dict:
+        """
+        Generate an AI dataset using specified parameters.
+
+        Args:
+            name (str): Name of the dataset to generate
+            aoi_geojson (str): Path to GeoJSON file containing area of interest
+            expression_x (str): Expression for X variable (e.g. "MSWX.air_temperature@(year=2021, month=1)")
+            expression_y (str): Expression for Y variable with {year} placeholder
+            samples (int): Number of samples to generate
+            tile_size (int): Size of tiles in degrees
+            crs (str, optional): Coordinate reference system. Defaults to "epsg:4326"
+            res (float, optional): Resolution in degrees. Defaults to 0.001
+            region (str, optional): Region code. Defaults to "aus"
+            start_year (int, optional): Start year for data generation. Required if end_year provided
+            end_year (int, optional): End year for data generation. Required if start_year provided
+            overwrite (bool, optional): Whether to overwrite existing dataset. Defaults to False
+
+        Returns:
+            dict: Response from the AI dataset generation API
+
+        Raises:
+            ValidationError: If required parameters are missing or invalid
+            APIError: If the API request fails
+        """
+
+        # we have the parameters, let pass the parameters to the random sample function
+        # task_id = self.random_sample(name, aoi_geojson, expression_x, expression_y, samples, tile_size, crs, res, region, start_year, end_year, overwrite)
+        config = {
+            "expressions" : [{"expr": expression_x, "res": res, "prefix": "x"}],
+            "filters" : []
+        }
+        config["expressions"].append({"expr": expression_y, "res" : res, "prefix": "y"})
+
+        expression_x = expression_x.replace("{year}", str(start_year))
+        expression_y = expression_y.replace("{year}", str(start_year))
+        print("the aoi geojson is ", aoi_geojson)
+        with open(aoi_geojson, 'r') as f:
+            aoi_data = json.load(f)
+        print("the config is ", config)
+        task_id = self.random_sample(
+            name=name,
+            config=config,
+            aoi=aoi_data,
+            samples=samples,
+            year_range=[start_year, end_year],
+            crs=crs,
+            tile_size=tile_size,
+            res=res,
+            region=region,
+            output="netcdf",
+            server=self.url,
+            bucket="terrakio-mass-requests",
+            overwrite=True
+        )["task_id"]
+        print("the task id is ", task_id)
+        task_id = self.start_mass_stats_job(task_id)
+        print("the task id is ", task_id)
+        return task_id
+

     def train_model(self, model_name: str, training_data: dict) -> dict:
         """
@@ -1003,3 +1140,90 @@ class BaseClient:
         except requests.RequestException as e:
             raise APIError(f"Model training request failed: {str(e)}")

+    def deploy_model(self, dataset: str, product: str, model_name: str, input_expression: str, model_training_job_name: str, uid: str, dates_iso8601: list):
+        # we have the dataset and we have the product, and we have the model name, we need to create a new json file and add that to the dataset as our virtual dataset
+        # upload the script to the bucket, the script should be able to download the model and do the inferencing
+        # we need to upload the the json to the to the dataset as our virtual dataset
+        # then we do nothing and wait for the user to make the request call to the explorer
+        # we should have a uniform script for the random forest deployment
+        # create a script for each model
+        # upload script to google bucket,
+        #
+
+        script_content = self._generate_script(model_name, product, model_training_job_name, uid)
+        # self.create_dataset(collection = "terrakio-datasets", input = input, )
+        # we have the script, we need to upload it to the bucket
+        script_name = f"{product}.py"
+        print("the script content is ", script_content)
+        print("the script name is ", script_name)
+        self._upload_script_to_bucket(script_content, script_name, model_training_job_name, uid)
+        # after uploading the script, we need to create a new virtual dataset
+        self._create_dataset(name = dataset, collection = "terrakio-datasets", products = [product], path = f"gs://terrakio-mass-requests/{uid}/{model_training_job_name}/inference_scripts", input = input_expression, dates_iso8601 = dates_iso8601, padding = 0)
+
+    def _generate_script(self, model_name: str, product: str, model_training_job_name: str, uid: str) -> str:
+        return textwrap.dedent(f'''
+            import logging
+            from io import BytesIO
+            from google.cloud import storage
+            from onnxruntime import InferenceSession
+            import numpy as np
+            import xarray as xr
+            import datetime
+
+            logging.basicConfig(
+                level=logging.INFO
+            )
+
+            def get_model():
+                logging.info("Loading model for {model_name}...")
+
+                client = storage.Client()
+                bucket = client.get_bucket('terrakio-mass-requests')
+                blob = bucket.blob('{uid}/{model_training_job_name}/models/{model_name}.onnx')
+
+                model = BytesIO()
+                blob.download_to_file(model)
+                model.seek(0)
+
+                session = InferenceSession(model.read(), providers=["CPUExecutionProvider"])
+                return session
+
+            def {product}(*bands, model):
+                logging.info("start preparing data")
+
+                original_shape = bands[0].shape
+                logging.info(f"Original shape: {{original_shape}}")
+
+                transformed_bands = []
+                for band in bands:
+                    transformed_band = band.values.reshape(-1,1)
+                    transformed_bands.append(transformed_band)
+
+                input_data = np.hstack(transformed_bands)
+
+                logging.info(f"Final input shape: {{input_data.shape}}")
+
+                output = model.run(None, {{"float_input": input_data.astype(np.float32)}})[0]
+
+                logging.info(f"Model output shape: {{output.shape}}")
+
+                output_reshaped = output.reshape(original_shape)
+                result = xr.DataArray(
+                    data=output_reshaped,
+                    dims=bands[0].dims,
+                    coords=bands[0].coords
+                )
+
+                return result
+        ''').strip()
+
+    def _upload_script_to_bucket(self, script_content: str, script_name: str, model_training_job_name: str, uid: str):
+        """Upload the generated script to Google Cloud Storage"""
+
+        client = storage.Client()
+        bucket = client.get_bucket('terrakio-mass-requests')
+        blob = bucket.blob(f'{uid}/{model_training_job_name}/inference_scripts/{script_name}')
+        # the first layer is the uid, the second layer is the model training job name
+        blob.upload_from_string(script_content, content_type='text/plain')
+        logging.info(f"Script uploaded successfully to {uid}/{model_training_job_name}/inference_scripts/{script_name}")
+
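A deployment sketch for the new deploy_model path. Note it requires Google Cloud credentials with write access to the terrakio-mass-requests bucket; every identifier below is illustrative:

    from terrakio_core.client import BaseClient

    client = BaseClient(url="https://api.terrak.io", key="YOUR_KEY")
    client.deploy_model(
        dataset="my_virtual_dataset",
        product="rf_prediction",
        model_name="random_forest_v1",
        input_expression="S2.B04@(year=2023)",  # hypothetical input expression
        model_training_job_name="rf_training_job",
        uid="user-123",
        dates_iso8601=["2023-01-01"],
    )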
@@ -117,7 +117,7 @@ class DatasetManagement:

         # Add optional parameters if provided
         for param in ["products", "dates_iso8601", "bucket", "path", "data_type",
-                      "no_data", "l_max", "y_size", "x_size", "proj4", "abstract", "geotransform"]:
+                      "no_data", "l_max", "y_size", "x_size", "proj4", "abstract", "geotransform", "input"]:
             if param in kwargs:
                 payload[param] = kwargs[param]

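The practical effect of whitelisting "input" is that dataset creation can now forward an input expression, which is what deploy_model relies on above. A hedged sketch mirroring that call (all names are illustrative):

    client._create_dataset(
        name="my_virtual_dataset",
        collection="terrakio-datasets",
        products=["rf_prediction"],
        path="gs://terrakio-mass-requests/user-123/rf_training_job/inference_scripts",
        input="S2.B04@(year=2023)",  # hypothetical input expression
        dates_iso8601=["2023-01-01"],
        padding=0,
    )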
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: terrakio-core
-Version: 0.2.4
+Version: 0.2.7
 Summary: Core components for Terrakio API clients
 Author-email: Yupeng Chao <yupeng@haizea.com.au>
 Project-URL: Homepage, https://github.com/HaizeaAnalytics/terrakio-python-api