eotdl 2025.4.22.post3__py3-none-any.whl → 2025.5.26.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,230 @@
+ import geopandas as gpd
+ import pandas as pd
+ from typing import Any, Dict, List, Optional
+ 
+ from .temporal_utils import compute_temporal_extent
+ from .spatial_utils import buffer_geometry
+ 
+ try:
+     import openeo_gfmap
+     from openeo_gfmap.manager.job_splitters import split_job_s2grid, append_h3_index
+ except ImportError:
+     print("openeo_gfmap is not installed, please install it with 'pip install openeo_gfmap'")
+ 
+ def create_utm_patch(
+     row: Any,
+     parent_crs: str,
+     start_date: str,
+     nb_months: int,
+     pixel_size: int,
+     resolution: int
+ ) -> Dict[str, Any]:
+     """
+     Process a single row from a GeoDataFrame to extract metadata and buffer its geometry.
+ 
+     Parameters:
+         row: A single row of a GeoDataFrame.
+         parent_crs: CRS of the parent GeoDataFrame.
+         start_date: Start date for the temporal extent (ISO format string).
+         nb_months: Number of months for the temporal extent.
+         pixel_size: Patch size in pixels.
+         resolution: Spatial resolution in meters per pixel.
+ 
+     Returns:
+         A dictionary with metadata and the processed geometry.
+     """
+     # Wrap the row in a single-row GeoDataFrame
+     gdf = gpd.GeoDataFrame(
+         [row],
+         crs=parent_crs,
+         geometry="geometry"
+     )
+ 
+     # Half the patch extent in meters: pixels * meters-per-pixel / 2
+     buffer = int(pixel_size * resolution / 2)
+     # Apply the buffer_geometry function
+     buffered_gdf = buffer_geometry(gdf, buffer=buffer, resolution=resolution)
+ 
+     # Compute the temporal extent
+     temporal_extent = compute_temporal_extent(start_date, nb_months)
+ 
+     # Return the processed data
+     return {
+         "fid": row.get("fid"),  # Include any relevant identifier
+         "geometry": buffered_gdf.iloc[0].geometry,
+         "crs": buffered_gdf.crs.to_string(),
+         "temporal_extent": temporal_extent,
+     }
+ 
+ def process_patch_geodataframe(
+     geodataframe: gpd.GeoDataFrame,
+     start_date: str,
+     nb_months: int,
+     pixel_size: int,
+     resolution: int
+ ) -> gpd.GeoDataFrame:
+     """
+     Process a GeoDataFrame to generate buffered patch geometries and metadata for each row.
+ 
+     Parameters:
+         geodataframe: The input GeoDataFrame with geometries to process.
+         start_date: Start date for the temporal extent (ISO format string).
+         nb_months: Number of months for the temporal extent.
+         pixel_size: Patch size in pixels, passed through to create_utm_patch.
+         resolution: Spatial resolution in meters per pixel.
+ 
+     Returns:
+         A processed GeoDataFrame with buffered geometries and additional metadata.
+     """
+     results = []  # List to store processed rows
+ 
+     for _, row in geodataframe.iterrows():
+         # Process each row and collect the result
+         result = create_utm_patch(row, geodataframe.crs, start_date, nb_months, pixel_size, resolution)
+         results.append(result)
+ 
+     # Convert the list of results into a GeoDataFrame
+     # (all patches are assumed to share the CRS of the last processed row)
+     processed_gdf = gpd.GeoDataFrame(
+         results,
+         geometry="geometry",
+         crs=result['crs']
+     )
+ 
+     return processed_gdf
+ 
+ def process_geodataframe(
+     geodataframe: gpd.GeoDataFrame,
+     start_date: str,
+     nb_months: int,
+     extra_cols: Optional[List[str]] = None,
+ ) -> gpd.GeoDataFrame:
+     """
+     Process a GeoDataFrame to attach a temporal extent and metadata to each row.
+ 
+     Parameters:
+         geodataframe: The input GeoDataFrame with geometries to process.
+         start_date: Start date for the temporal extent (ISO format string).
+         nb_months: Number of months for the temporal extent.
+         extra_cols: Names of additional input columns to carry over.
+ 
+     Returns:
+         A processed GeoDataFrame in EPSG:4326 with temporal metadata.
+     """
+     extra_cols = extra_cols or []  # avoid a mutable default argument
+     results = []  # List to store processed rows
+     geodataframe = geodataframe.to_crs(epsg=4326)
+ 
+     # The temporal extent is the same for every row, so compute it once
+     temporal_extent = compute_temporal_extent(start_date, nb_months)
+ 
+     for _, row in geodataframe.iterrows():
+         result = {
+             # "fid": row.get("fid"),  # Include any relevant identifier
+             "geometry": row.geometry,
+             "crs": geodataframe.crs.to_string(),
+             "temporal_extent": temporal_extent,
+             **{col: row[col] for col in extra_cols}
+         }
+         results.append(result)
+ 
+     # Convert the list of results into a GeoDataFrame
+     processed_gdf = gpd.GeoDataFrame(
+         results,
+         geometry="geometry",
+         crs=result['crs']
+     )
+ 
+     return processed_gdf
+ 
+ def split_geodataframe_by_s2_grid(base_gdf: gpd.GeoDataFrame, max_points: int, grid_resolution: int = 3) -> List[gpd.GeoDataFrame]:
+     """Append an H3 index and split into smaller job dataframes along the Sentinel-2 grid."""
+     original_crs = base_gdf.crs
+     # Append H3 index, which will change the CRS temporarily
+     base_gdf = append_h3_index(base_gdf, grid_resolution=grid_resolution)
+     # Transform back to the original CRS
+     h3_gdf = base_gdf.to_crs(original_crs)
+     split_gdf = split_job_s2grid(h3_gdf, max_points=max_points)
+     return split_gdf
+ 
+ # TODO: evaluate need
+ def generate_featurecollection_dataframe(split_jobs: List[gpd.GeoDataFrame]) -> pd.DataFrame:
+     """Create a DataFrame from split jobs with the essential information for each job, including feature count."""
+     job_data = []
+     for job in split_jobs:
+         # Ensure the temporal_extent field exists and handle missing data
+         temporal_extent = job.temporal_extent.iloc[0] if 'temporal_extent' in job.columns and job.temporal_extent.iloc[0] else None
+ 
+         # Handle missing S2 and H3 information gracefully
+         s2_tile = job.tile.iloc[0] if 'tile' in job.columns and job.tile.iloc[0] else None
+         h3index = job.h3index.iloc[0] if 'h3index' in job.columns and job.h3index.iloc[0] else None
+ 
+         # Extract CRS as string
+         crs = job.crs.to_string() if job.crs else None
+ 
+         # Count the number of features (rows) in the GeoDataFrame
+         feature_count = len(job)
+ 
+         # Serialize the entire GeoDataFrame to GeoJSON (including geometry and attributes)
+         job_json = job.to_json()
+ 
+         # Append all information, including feature count
+         job_data.append({
+             'temporal_extent': temporal_extent,
+             'geometry': job_json,  # The entire GeoDataFrame serialized to GeoJSON
+             's2_tile': s2_tile,
+             'h3index': h3index,
+             'crs': crs,
+             'feature_count': feature_count  # Include the feature count for the job
+         })
+ 
+     # Return the DataFrame with all job metadata
+     return pd.DataFrame(job_data)
+ 
+ 
+ def process_and_create_advanced_patch_jobs(
+     gdf: gpd.GeoDataFrame,
+     start_date: str,
+     nb_months: int,
+     pixel_size: int,
+     resolution: int,
+     max_points: int
+ ) -> pd.DataFrame:
+     """
+     A wrapper function to process geospatial data and generate a job metadata DataFrame.
+ 
+     Args:
+         gdf (GeoDataFrame): The input GeoDataFrame with geometries.
+         start_date (str): The start date for the temporal extent (ISO format).
+         nb_months (int): Number of months for the temporal extent.
+         pixel_size (int): Patch size in pixels.
+         resolution (int): Spatial resolution in meters per pixel.
+         max_points (int): Maximum number of points per job for splitting the grid.
+ 
+     Returns:
+         pd.DataFrame: A DataFrame containing the job metadata.
+     """
+     # Step 1: Process the GeoDataFrame, creating patches and temporal info
+     processed_gdf = process_patch_geodataframe(
+         gdf, start_date, nb_months, pixel_size, resolution
+     )
+ 
+     # Step 2: Split the processed GeoDataFrame into smaller jobs by Sentinel-2 grid
+     split_jobs = split_geodataframe_by_s2_grid(processed_gdf, max_points)
+ 
+     # Step 3: Generate the job metadata DataFrame
+     job_df = generate_featurecollection_dataframe(split_jobs)
+ 
+     return job_df
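Taken together, this module turns a GeoDataFrame of sample locations into openEO-ready patch jobs. A minimal usage sketch (the file name and parameter values are illustrative, not from the package):

    import geopandas as gpd
    gdf = gpd.read_file("samples.gpkg")  # hypothetical input with an 'fid' column
    job_df = process_and_create_advanced_patch_jobs(
        gdf,
        start_date="2021-01-01",
        nb_months=12,
        pixel_size=64,    # 64 x 64 pixel patches
        resolution=10,    # 10 m per pixel, so each patch spans 640 m
        max_points=100,
    )

Each row of job_df then describes one batch job: its temporal extent, Sentinel-2 tile, H3 index, CRS and a GeoJSON-serialized feature collection.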
@@ -0,0 +1,180 @@
+ from __future__ import (
+     annotations,  # Required because of type annotations for return results
+ )
+ 
+ import boto3
+ from boto3.s3.transfer import TransferConfig
+ from dataclasses import dataclass
+ import os
+ import hashlib
+ import datetime
+ import geopandas as gpd
+ from openeo.rest.connection import Connection
+ from tempfile import NamedTemporaryFile
+ 
+ def upload_geoparquet_artifactory(gdf: gpd.GeoDataFrame, row_id: int) -> str:
+     # Save the dataframe as GeoParquet to upload it to Artifactory
+     temporary_file = NamedTemporaryFile()
+     gdf.to_parquet(temporary_file.name)
+ 
+ 
+ @dataclass(frozen=True)
+ class S3URI:
+     bucket: str
+     key: str
+ 
+     @classmethod
+     def from_str(cls, uri: str) -> S3URI:
+         s3_prefix = "s3://"
+         if uri.startswith(s3_prefix):
+             without_prefix = uri[len(s3_prefix):]
+             without_prefix_parts = without_prefix.split("/")
+             bucket = without_prefix_parts[0]
+             if len(without_prefix_parts) == 1:
+                 return S3URI(bucket, "")
+             else:
+                 return S3URI(bucket, "/".join(without_prefix_parts[1:]))
+         else:
+             raise ValueError(
+                 f"Input {uri} is not a valid S3 URI; should be of form s3://<bucket>/<key>"
+             )
+ 
+ 
+ @dataclass(frozen=True)
+ class AWSSTSCredentials:
+     AWS_ACCESS_KEY_ID: str
+     AWS_SECRET_ACCESS_KEY: str
+     AWS_SESSION_TOKEN: str
+     subject_from_web_identity_token: str
+     STS_ENDPOINT = "https://sts.prod.warsaw.openeo.dataspace.copernicus.eu"
+ 
+     @classmethod
+     def _from_assume_role_response(cls, resp: dict) -> AWSSTSCredentials:
+         d = resp["Credentials"]
+         return AWSSTSCredentials(
+             AWS_ACCESS_KEY_ID=d["AccessKeyId"],
+             AWS_SECRET_ACCESS_KEY=d["SecretAccessKey"],
+             AWS_SESSION_TOKEN=d["SessionToken"],
+             subject_from_web_identity_token=resp["SubjectFromWebIdentityToken"],
+         )
+ 
+     def set_as_environment_variables(self) -> None:
+         """For when the temporary credentials are to be used elsewhere in the notebook."""
+         os.environ["AWS_ACCESS_KEY_ID"] = self.AWS_ACCESS_KEY_ID
+         os.environ["AWS_SECRET_ACCESS_KEY"] = self.AWS_SECRET_ACCESS_KEY
+         os.environ["AWS_SESSION_TOKEN"] = self.AWS_SESSION_TOKEN
+ 
+     def as_kwargs(self) -> dict:
+         return {
+             "aws_access_key_id": self.AWS_ACCESS_KEY_ID,
+             "aws_secret_access_key": self.AWS_SECRET_ACCESS_KEY,
+             "aws_session_token": self.AWS_SESSION_TOKEN,
+         }
+ 
+     @classmethod
+     def from_openeo_connection(cls, conn: Connection) -> AWSSTSCredentials:
+         """Take an openEO connection object and return temporary credentials to interact with S3."""
+         auth_token = conn.auth.bearer.split("/")
+         os.environ["AWS_ENDPOINT_URL_STS"] = cls.STS_ENDPOINT
+         sts = boto3.client("sts")
+         return AWSSTSCredentials._from_assume_role_response(
+             sts.assume_role_with_web_identity(
+                 RoleArn="arn:aws:iam::000000000000:role/S3Access",
+                 RoleSessionName=auth_token[1],
+                 WebIdentityToken=auth_token[2],
+                 DurationSeconds=43200,
+             )
+         )
+ 
+     def get_user_hash(self) -> str:
+         hash_object = hashlib.sha1(self.subject_from_web_identity_token.encode())
+         return hash_object.hexdigest()
+ 
+ 
+ class OpenEOArtifactHelper:
+     BUCKET_NAME = "OpenEO-artifacts"
+     S3_ENDPOINT = "https://s3.prod.warsaw.openeo.dataspace.copernicus.eu"
+     # File size threshold above which we switch to multipart upload
+     MULTIPART_THRESHOLD_IN_MB = 50
+ 
+     def __init__(self, creds: AWSSTSCredentials):
+         self._creds = creds
+         self.session = boto3.Session(**creds.as_kwargs())
+ 
+     @classmethod
+     def from_openeo_connection(cls, conn: Connection) -> OpenEOArtifactHelper:
+         creds = AWSSTSCredentials.from_openeo_connection(conn)
+         return OpenEOArtifactHelper(creds)
+ 
+     def get_s3_client(self):
+         return self.session.client("s3", endpoint_url=self.S3_ENDPOINT)
+ 
+     def set_env(self):
+         os.environ["AWS_ENDPOINT_URL_S3"] = self.S3_ENDPOINT
+ 
+     def user_prefix(self) -> str:
+         """Each user has their own prefix; retrieve it."""
+         return self._creds.get_user_hash()
+ 
+     def get_upload_prefix(self):
+         return (
+             f"{self.user_prefix()}/{datetime.datetime.utcnow().strftime('%Y/%m/%d')}/"
+         )
+ 
+     def get_upload_key(self, object_name: str) -> str:
+         return f"{self.get_upload_prefix()}{object_name}"
+ 
+     def upload_bytes(self, object_name: str, blob: bytes) -> str:
+         """Upload raw bytes into an object and return an S3 URI to it."""
+         bucket = self.BUCKET_NAME
+         key = self.get_upload_key(object_name)
+         self.get_s3_client().put_object(Body=blob, Bucket=bucket, Key=key)
+         return f"s3://{bucket}/{key}"
+ 
+     def upload_string(self, object_name: str, s: str, encoding: str = "utf-8") -> str:
+         """Upload a string into an object."""
+         return self.upload_bytes(object_name, s.encode(encoding))
+ 
+     def upload_file(self, object_name: str, src_file_path: str) -> str:
+         MB = 1024**2
+         config = TransferConfig(multipart_threshold=self.MULTIPART_THRESHOLD_IN_MB * MB)
+ 
+         bucket = self.BUCKET_NAME
+         key = self.get_upload_key(object_name)
+ 
+         self.get_s3_client().upload_file(src_file_path, bucket, key, Config=config)
+         return f"s3://{bucket}/{key}"
+ 
+     def get_presigned_url(
+         self, s3_uri: str, expires_in_seconds: int = 3600 * 24 * 6
+     ) -> str:
+         typed_s3_uri = S3URI.from_str(s3_uri)
+         return self.get_s3_client().generate_presigned_url(
+             "get_object",
+             Params={"Bucket": typed_s3_uri.bucket, "Key": typed_s3_uri.key},
+             ExpiresIn=expires_in_seconds,
+         )
+ 
+ 
+ def upload_geoparquet_file(gdf: gpd.GeoDataFrame, conn: Connection) -> str:
+     geo_parquet_path = "bounding_box_geometry.parquet"
+     gdf.to_parquet(geo_parquet_path)
+ 
+     # Create an instance of the custom S3 uploader
+     uploader = OpenEOArtifactHelper.from_openeo_connection(conn)
+ 
+     # Upload the GeoParquet file to S3
+     s3_uri = uploader.upload_file(geo_parquet_path, geo_parquet_path)
+ 
+     # Get the presigned URL for accessing the uploaded file
+     presigned_url = uploader.get_presigned_url(s3_uri)
+ 
+     # Optionally clean up the local file
+     os.remove(geo_parquet_path)
+ 
+     return presigned_url
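A typical end-to-end call, sketched under the assumption of an authenticated openEO connection to the Copernicus Data Space backend (the backend URL is an assumption, not taken from this module):

    import openeo
    conn = openeo.connect("openeo.dataspace.copernicus.eu").authenticate_oidc()
    presigned_url = upload_geoparquet_file(gdf, conn)  # HTTPS URL, valid for 6 days by default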
@@ -0,0 +1,28 @@
+ import geopandas as gpd
+ from shapely.geometry import Point
+ 
+ def buffer_geometry(gdf: gpd.GeoDataFrame, buffer: int, resolution: int) -> gpd.GeoDataFrame:
+     """
+     Buffer the geometries in a GeoDataFrame and return the modified GeoDataFrame.
+ 
+     Parameters:
+         gdf: Input GeoDataFrame with geometries to buffer.
+         buffer: Buffer distance in meters.
+         resolution: Grid size in meters to which centroids are snapped.
+ 
+     Returns:
+         A GeoDataFrame with buffered geometries in the estimated UTM CRS.
+     """
+     # Ensure the GeoDataFrame has a valid CRS
+     if gdf.crs is None:
+         raise ValueError("Input GeoDataFrame must have a defined CRS.")
+ 
+     # Estimate the UTM CRS based on the data's centroid
+     utm = gdf.estimate_utm_crs()
+     gdf_utm = gdf.to_crs(utm)
+ 
+     # Snap the centroids to the nearest `resolution`-meter grid, then buffer
+     # with square caps (cap_style=3) to produce square patches
+     gdf_utm['geometry'] = gdf_utm.centroid.apply(
+         lambda point: Point(round(point.x / resolution) * resolution, round(point.y / resolution) * resolution)
+     ).buffer(distance=buffer, cap_style=3)
+ 
+     return gdf_utm
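For example, with buffer=320 and resolution=10 a point becomes a 640 m square patch whose center sits on the 10 m grid. A minimal sketch with made-up coordinates:

    import geopandas as gpd
    from shapely.geometry import Point

    gdf = gpd.GeoDataFrame(geometry=[Point(4.35, 50.85)], crs="EPSG:4326")
    patch = buffer_geometry(gdf, buffer=320, resolution=10)
    print(patch.geometry.iloc[0].bounds)  # a 640 m x 640 m square in the local UTM CRS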
@@ -0,0 +1,16 @@
+ import pandas as pd
+ from typing import List
+ 
+ def compute_temporal_extent(start_date: str, nb_months: int) -> List[str]:
+     """Compute a temporal extent based on a start date and a duration in months.
+ 
+     Args:
+         start_date (str): Start date in 'YYYY-MM-DD' format.
+         nb_months (int): Number of months from the start date.
+ 
+     Returns:
+         List[str]: Temporal extent as [start_date, end_date].
+     """
+     start = pd.to_datetime(start_date)
+     end = start + pd.DateOffset(months=nb_months)
+     return [start.strftime('%Y-%m-%d'), end.strftime('%Y-%m-%d')]
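For example, compute_temporal_extent('2021-01-01', 12) returns ['2021-01-01', '2022-01-01']: pd.DateOffset shifts by whole calendar months, so varying month lengths and leap years are handled by pandas.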
eotdl/fe/retrieve.py ADDED
@@ -0,0 +1,18 @@
+ from ..repos import FEAPIRepo
+ 
+ 
+ def retrieve_pipelines(name=None, limit=None):
+     api_repo = FEAPIRepo()
+     data, error = api_repo.retrieve_pipelines(name, limit)
+     if data and not error:
+         # data is guaranteed non-empty here, so return the pipeline names directly
+         return [d["name"] for d in data]
+     return []
+ 
+ 
+ def retrieve_pipeline(name):
+     repo = FEAPIRepo()
+     data, error = repo.retrieve_pipeline(name)
+     if error:
+         raise Exception(error)
+     return data
eotdl/fe/stage.py ADDED
@@ -0,0 +1,63 @@
+ import os
+ from pathlib import Path
+ from tqdm import tqdm
+ import geopandas as gpd
+ 
+ from ..auth import with_auth
+ from .retrieve import retrieve_pipeline
+ from ..repos import FilesAPIRepo
+ from ..files.metadata import Metadata
+ 
+ @with_auth
+ def stage_pipeline(
+     pipeline_name,
+     version=None,
+     path=None,
+     logger=print,
+     assets=False,
+     force=False,
+     verbose=False,
+     user=None,
+     file=None,
+ ):
+     pipeline = retrieve_pipeline(pipeline_name)
+     if version is None:
+         version = sorted([v['version_id'] for v in pipeline["versions"]])[-1]
+     else:
+         assert version in [
+             v["version_id"] for v in pipeline["versions"]
+         ], f"Version {version} not found"
+     download_base_path = os.getenv(
+         "EOTDL_DOWNLOAD_PATH", str(Path.home()) + "/.cache/eotdl/pipelines"
+     )
+     if path is None:
+         download_path = download_base_path + "/" + pipeline_name
+     else:
+         download_path = path + "/" + pipeline_name
+     # check if the pipeline already exists
+     if os.path.exists(download_path) and not force:
+         os.makedirs(download_path, exist_ok=True)
+         # raise Exception(
+         #     f"pipeline `{pipeline['name']} v{str(version)}` already exists at {download_path}. To force download, use force=True or -f in the CLI."
+         # )
+ 
+     # stage metadata
+     repo = FilesAPIRepo()
+     catalog_path = repo.stage_file(pipeline["id"], f"catalog.v{version}.parquet", user, download_path)
+ 
+     # stage README.md
+     metadata = Metadata(**pipeline['metadata'], name=pipeline['name'])
+     metadata.save_metadata(download_path)
+ 
+     if assets:
+         gdf = gpd.read_parquet(catalog_path)
+         for _, row in tqdm(gdf.iterrows(), total=len(gdf), desc="Staging assets"):
+             for k, v in row["assets"].items():
+                 stage_pipeline_file(v["href"], download_path)
+ 
+     return download_path
+ 
+ @with_auth
+ def stage_pipeline_file(file_url, path, user):
+     repo = FilesAPIRepo()
+     return repo.stage_file_url(file_url, path, user)
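A usage sketch (the pipeline name is a placeholder; the import path follows the file location shown above):

    from eotdl.fe.stage import stage_pipeline

    path = stage_pipeline("my-pipeline", path="data", assets=True)
    print(path)  # data/my-pipeline, containing the catalog, README and staged assets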
eotdl/fe/update.py ADDED
@@ -0,0 +1,12 @@
+ from ..repos import FEAPIRepo
+ from ..auth import with_auth
+ from .retrieve import retrieve_pipeline
+ 
+ @with_auth
+ def deactivate_pipeline(pipeline_name, user):
+     pipeline = retrieve_pipeline(pipeline_name)
+     repo = FEAPIRepo()
+     data, error = repo.deactivate_pipeline(pipeline['id'], user)
+     if error:
+         raise Exception(error)
+     return data
eotdl/files/__init__.py CHANGED
@@ -0,0 +1 @@
+ from .get_url import get_file_url
eotdl/files/get_url.py ADDED
@@ -0,0 +1,18 @@
+ from ..auth import with_auth
+ from ..repos import FilesAPIRepo
+ from ..datasets import retrieve_dataset
+ from ..models import retrieve_model
+ from ..fe import retrieve_pipeline
+ 
+ @with_auth
+ def get_file_url(filename, dataset_or_model_name, endpoint, user):
+     if endpoint == "datasets":
+         dataset_or_model_id = retrieve_dataset(dataset_or_model_name)['id']
+     elif endpoint == "models":
+         dataset_or_model_id = retrieve_model(dataset_or_model_name)['id']
+     elif endpoint == "pipelines":
+         dataset_or_model_id = retrieve_pipeline(dataset_or_model_name)['id']
+     else:
+         raise Exception("Invalid endpoint (datasets, models or pipelines)")
+     repo = FilesAPIRepo()
+     return repo.generate_presigned_url(filename, dataset_or_model_id, user, endpoint)
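For instance, to get a presigned URL for a file in a dataset (the file and dataset names are placeholders):

    from eotdl.files import get_file_url

    url = get_file_url("catalog.v1.parquet", "my-dataset", "datasets")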
eotdl/files/ingest.py CHANGED
@@ -98,13 +98,13 @@ def ingest_virtual( # could work for a list of paths with minimal changes...
      data.append(create_stac_item('README.md', str(path / "README.md")))
      gdf = gpd.GeoDataFrame(data, geometry='geometry')
      gdf.to_parquet(path / "catalog.parquet")
-     return ingest(path, repo, retrieve, mode)
+     return ingest(path, repo, retrieve, mode, private=False)
  
  def ingest_catalog(path, repo, retrieve, mode):
-     return ingest(path, repo, retrieve, mode)
+     return ingest(path, repo, retrieve, mode, private=False)
  
  @with_auth
- def ingest(path, repo, retrieve, mode, user):
+ def ingest(path, repo, retrieve, mode, private, user):
      try:
          readme = frontmatter.load(path.joinpath("README.md"))
          metadata_dict = readme.metadata
@@ -115,7 +115,7 @@ def ingest(path, repo, retrieve, mode, user):
          print(str(e))
          raise Exception("Error loading metadata")
      # retrieve dataset (create if doesn't exist)
-     dataset_or_model = retrieve(metadata, user)
+     dataset_or_model = retrieve(metadata, user, private)
      current_version = sorted([v['version_id'] for v in dataset_or_model["versions"]])[-1]
      # TODO: update README if metadata changed in UI (db)
      # update_metadata = True
eotdl/files/metadata.py CHANGED
@@ -35,7 +35,8 @@ class Metadata(BaseModel):
          f.write(f"name: {self.name}\n")
          f.write(f"license: {self.license}\n")
          f.write(f"source: {self.source}\n")
-         f.write(f"thumbnail: {self.thumbnail}\n")
+         if self.thumbnail:
+             f.write(f"thumbnail: {self.thumbnail}\n")
          f.write(f"authors:\n")
          for author in self.authors:
              f.write(f" - {author}\n")
eotdl/models/ingest.py CHANGED
@@ -3,7 +3,7 @@ from pathlib import Path
3
3
  from ..repos import ModelsAPIRepo
4
4
  from ..files.ingest import prep_ingest_stac, prep_ingest_folder, ingest, ingest_virtual, ingest_catalog
5
5
 
6
- def retrieve_model(metadata, user):
6
+ def retrieve_model(metadata, user, private=False):
7
7
  repo = ModelsAPIRepo()
8
8
  data, error = repo.retrieve_model(metadata.name)
9
9
  # print(data, error)
@@ -22,6 +22,7 @@ def ingest_model(
22
22
  logger=print,
23
23
  force_metadata_update=False,
24
24
  sync_metadata=False,
25
+ private=False,
25
26
  ):
26
27
  path = Path(path)
27
28
  if not path.is_dir():
@@ -30,7 +31,7 @@ def ingest_model(
30
31
  prep_ingest_stac(path, logger)
31
32
  else:
32
33
  prep_ingest_folder(path, verbose, logger, force_metadata_update, sync_metadata)
33
- return ingest(path, ModelsAPIRepo(), retrieve_model, 'models')
34
+ return ingest(path, ModelsAPIRepo(), retrieve_model, 'models', private)
34
35
 
35
36
  def ingest_virtual_model( # could work for a list of paths with minimal changes...
36
37
  path,
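With this change the private flag threads from the public entry point down to creation. A sketch of the intended call (the import path is an assumption based on the file layout; the model path is a placeholder):

    from eotdl.models import ingest_model

    ingest_model("path/to/model", private=True)  # passed through to the repo, which marks visibility as 'private'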
@@ -18,15 +18,26 @@ class DatasetsAPIRepo(APIRepo):
          response = requests.get(url)
          return self.format_response(response)
  
+     def retrieve_private_datasets(self, user):
+         url = self.url + "datasets/private"
+         response = requests.get(url, headers=self.generate_headers(user))
+         return self.format_response(response)
+ 
      def retrieve_dataset(self, name):
          response = requests.get(self.url + "datasets?name=" + name)
          return self.format_response(response)
  
+     def retrieve_private_dataset(self, name, user):
+         response = requests.get(self.url + "datasets/private?name=" + name, headers=self.generate_headers(user))
+         return self.format_response(response)
+ 
      def get_dataset_by_id(self, dataset_id):
          response = requests.get(self.url + "datasets/" + dataset_id)
          return self.format_response(response)
  
-     def create_dataset(self, metadata, user):
+     def create_dataset(self, metadata, user, private=False):
+         if private:
+             metadata["visibility"] = "private"
          response = requests.post(
              self.url + "datasets",
              json=metadata,