stac-fastapi-core 4.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1071 @@
1
+ """Core client."""
2
+
3
+ import logging
4
+ from collections import deque
5
+ from datetime import datetime as datetime_type
6
+ from datetime import timezone
7
+ from enum import Enum
8
+ from typing import Any, Dict, List, Literal, Optional, Set, Type, Union
9
+ from urllib.parse import unquote_plus, urljoin
10
+
11
+ import attr
12
+ import orjson
13
+ from fastapi import HTTPException, Request
14
+ from overrides import overrides
15
+ from pydantic import ValidationError
16
+ from pygeofilter.backends.cql2_json import to_cql2
17
+ from pygeofilter.parsers.cql2_text import parse as parse_cql2_text
18
+ from stac_pydantic import Collection, Item, ItemCollection
19
+ from stac_pydantic.links import Relations
20
+ from stac_pydantic.shared import BBox, MimeTypes
21
+ from stac_pydantic.version import STAC_VERSION
22
+
23
+ from stac_fastapi.core.base_database_logic import BaseDatabaseLogic
24
+ from stac_fastapi.core.base_settings import ApiBaseSettings
25
+ from stac_fastapi.core.models.links import PagingLinks
26
+ from stac_fastapi.core.serializers import CollectionSerializer, ItemSerializer
27
+ from stac_fastapi.core.session import Session
28
+ from stac_fastapi.core.utilities import filter_fields
29
+ from stac_fastapi.extensions.core.filter.client import AsyncBaseFiltersClient
30
+ from stac_fastapi.extensions.third_party.bulk_transactions import (
31
+ BaseBulkTransactionsClient,
32
+ BulkTransactionMethod,
33
+ Items,
34
+ )
35
+ from stac_fastapi.types import stac as stac_types
36
+ from stac_fastapi.types.conformance import BASE_CONFORMANCE_CLASSES
37
+ from stac_fastapi.types.core import AsyncBaseCoreClient, AsyncBaseTransactionsClient
38
+ from stac_fastapi.types.extension import ApiExtension
39
+ from stac_fastapi.types.requests import get_base_url
40
+ from stac_fastapi.types.rfc3339 import DateTimeType, rfc3339_str_to_datetime
41
+ from stac_fastapi.types.search import BaseSearchPostRequest
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+
46
@attr.s
class CoreClient(AsyncBaseCoreClient):
    """Client for core endpoints defined by the STAC specification.

    This class is an implementation of `AsyncBaseCoreClient` that implements the core endpoints
    defined by the STAC specification. It uses the `DatabaseLogic` class to interact with the
    database, and `ItemSerializer` and `CollectionSerializer` to convert between STAC objects and
    database records.

    Attributes:
        session (Session): A requests session instance to be used for all HTTP requests.
        item_serializer (Type[serializers.ItemSerializer]): A serializer class to be used to convert
            between STAC items and database records.
        collection_serializer (Type[serializers.CollectionSerializer]): A serializer class to be
            used to convert between STAC collections and database records.
        database (DatabaseLogic): An instance of the `DatabaseLogic` class that is used to interact
            with the database.
    """

    # Search/storage backend used by every endpoint below.
    database: BaseDatabaseLogic = attr.ib()
    # Conformance URIs advertised on the landing page and /conformance.
    base_conformance_classes: List[str] = attr.ib(
        factory=lambda: BASE_CONFORMANCE_CLASSES
    )
    # API extensions enabled for this deployment (checked via extension_is_enabled).
    extensions: List[ApiExtension] = attr.ib(default=attr.Factory(list))

    session: Session = attr.ib(default=attr.Factory(Session.create_from_env))
    item_serializer: Type[ItemSerializer] = attr.ib(default=ItemSerializer)
    collection_serializer: Type[CollectionSerializer] = attr.ib(
        default=CollectionSerializer
    )
    # Pydantic model used to validate POST /search bodies; GET /search reuses it.
    post_request_model = attr.ib(default=BaseSearchPostRequest)
    stac_version: str = attr.ib(default=STAC_VERSION)
    landing_page_id: str = attr.ib(default="stac-fastapi")
    title: str = attr.ib(default="stac-fastapi")
    description: str = attr.ib(default="stac-fastapi")
81
+
82
+ def _landing_page(
83
+ self,
84
+ base_url: str,
85
+ conformance_classes: List[str],
86
+ extension_schemas: List[str],
87
+ ) -> stac_types.LandingPage:
88
+ landing_page = stac_types.LandingPage(
89
+ type="Catalog",
90
+ id=self.landing_page_id,
91
+ title=self.title,
92
+ description=self.description,
93
+ stac_version=self.stac_version,
94
+ conformsTo=conformance_classes,
95
+ links=[
96
+ {
97
+ "rel": Relations.self.value,
98
+ "type": MimeTypes.json,
99
+ "href": base_url,
100
+ },
101
+ {
102
+ "rel": Relations.root.value,
103
+ "type": MimeTypes.json,
104
+ "href": base_url,
105
+ },
106
+ {
107
+ "rel": "data",
108
+ "type": MimeTypes.json,
109
+ "href": urljoin(base_url, "collections"),
110
+ },
111
+ {
112
+ "rel": Relations.conformance.value,
113
+ "type": MimeTypes.json,
114
+ "title": "STAC/WFS3 conformance classes implemented by this server",
115
+ "href": urljoin(base_url, "conformance"),
116
+ },
117
+ {
118
+ "rel": Relations.search.value,
119
+ "type": MimeTypes.geojson,
120
+ "title": "STAC search",
121
+ "href": urljoin(base_url, "search"),
122
+ "method": "GET",
123
+ },
124
+ {
125
+ "rel": Relations.search.value,
126
+ "type": MimeTypes.geojson,
127
+ "title": "STAC search",
128
+ "href": urljoin(base_url, "search"),
129
+ "method": "POST",
130
+ },
131
+ ],
132
+ stac_extensions=extension_schemas,
133
+ )
134
+ return landing_page
135
+
136
    async def landing_page(self, **kwargs) -> stac_types.LandingPage:
        """Landing page.

        Called with `GET /`.

        Returns:
            API landing page, serving as an entry point to the API.
        """
        request: Request = kwargs["request"]
        base_url = get_base_url(request)
        # Start from the static document (core links + conformance classes),
        # then append extension- and collection-specific links below.
        landing_page = self._landing_page(
            base_url=base_url,
            conformance_classes=self.conformance_classes(),
            extension_schemas=[],
        )

        if self.extension_is_enabled("FilterExtension"):
            landing_page["links"].append(
                {
                    # TODO: replace this with Relations.queryables.value,
                    "rel": "queryables",
                    # TODO: replace this with MimeTypes.jsonschema,
                    "type": "application/schema+json",
                    "title": "Queryables",
                    "href": urljoin(base_url, "queryables"),
                }
            )

        if self.extension_is_enabled("AggregationExtension"):
            landing_page["links"].extend(
                [
                    {
                        "rel": "aggregate",
                        "type": "application/json",
                        "title": "Aggregate",
                        "href": urljoin(base_url, "aggregate"),
                    },
                    {
                        "rel": "aggregations",
                        "type": "application/json",
                        "title": "Aggregations",
                        "href": urljoin(base_url, "aggregations"),
                    },
                ]
            )

        # One child link per collection.
        # NOTE(review): all_collections reads this request's paging params, so
        # only the first page of collections is linked here — confirm intended.
        collections = await self.all_collections(request=kwargs["request"])
        for collection in collections["collections"]:
            landing_page["links"].append(
                {
                    "rel": Relations.child.value,
                    "type": MimeTypes.json.value,
                    "title": collection.get("title") or collection.get("id"),
                    "href": urljoin(base_url, f"collections/{collection['id']}"),
                }
            )

        # Add OpenAPI URL
        landing_page["links"].append(
            {
                "rel": "service-desc",
                "type": "application/vnd.oai.openapi+json;version=3.0",
                "title": "OpenAPI service description",
                "href": urljoin(
                    str(request.base_url), request.app.openapi_url.lstrip("/")
                ),
            }
        )

        # Add human readable service-doc
        landing_page["links"].append(
            {
                "rel": "service-doc",
                "type": "text/html",
                "title": "OpenAPI service documentation",
                "href": urljoin(
                    str(request.base_url), request.app.docs_url.lstrip("/")
                ),
            }
        )

        return landing_page
218
+
219
+ async def all_collections(self, **kwargs) -> stac_types.Collections:
220
+ """Read all collections from the database.
221
+
222
+ Args:
223
+ **kwargs: Keyword arguments from the request.
224
+
225
+ Returns:
226
+ A Collections object containing all the collections in the database and links to various resources.
227
+ """
228
+ request = kwargs["request"]
229
+ base_url = str(request.base_url)
230
+ limit = int(request.query_params.get("limit", 10))
231
+ token = request.query_params.get("token")
232
+
233
+ collections, next_token = await self.database.get_all_collections(
234
+ token=token, limit=limit, request=request
235
+ )
236
+
237
+ links = [
238
+ {"rel": Relations.root.value, "type": MimeTypes.json, "href": base_url},
239
+ {"rel": Relations.parent.value, "type": MimeTypes.json, "href": base_url},
240
+ {
241
+ "rel": Relations.self.value,
242
+ "type": MimeTypes.json,
243
+ "href": urljoin(base_url, "collections"),
244
+ },
245
+ ]
246
+
247
+ if next_token:
248
+ next_link = PagingLinks(next=next_token, request=request).link_next()
249
+ links.append(next_link)
250
+
251
+ return stac_types.Collections(collections=collections, links=links)
252
+
253
+ async def get_collection(
254
+ self, collection_id: str, **kwargs
255
+ ) -> stac_types.Collection:
256
+ """Get a collection from the database by its id.
257
+
258
+ Args:
259
+ collection_id (str): The id of the collection to retrieve.
260
+ kwargs: Additional keyword arguments passed to the API call.
261
+
262
+ Returns:
263
+ Collection: A `Collection` object representing the requested collection.
264
+
265
+ Raises:
266
+ NotFoundError: If the collection with the given id cannot be found in the database.
267
+ """
268
+ request = kwargs["request"]
269
+ collection = await self.database.find_collection(collection_id=collection_id)
270
+ return self.collection_serializer.db_to_stac(
271
+ collection=collection,
272
+ request=request,
273
+ extensions=[type(ext).__name__ for ext in self.extensions],
274
+ )
275
+
276
+ async def item_collection(
277
+ self,
278
+ collection_id: str,
279
+ bbox: Optional[BBox] = None,
280
+ datetime: Optional[str] = None,
281
+ limit: Optional[int] = 10,
282
+ token: Optional[str] = None,
283
+ **kwargs,
284
+ ) -> stac_types.ItemCollection:
285
+ """Read items from a specific collection in the database.
286
+
287
+ Args:
288
+ collection_id (str): The identifier of the collection to read items from.
289
+ bbox (Optional[BBox]): The bounding box to filter items by.
290
+ datetime (Optional[str]): The datetime range to filter items by.
291
+ limit (int): The maximum number of items to return. The default value is 10.
292
+ token (str): A token used for pagination.
293
+ request (Request): The incoming request.
294
+
295
+ Returns:
296
+ ItemCollection: An `ItemCollection` object containing the items from the specified collection that meet
297
+ the filter criteria and links to various resources.
298
+
299
+ Raises:
300
+ HTTPException: If the specified collection is not found.
301
+ Exception: If any error occurs while reading the items from the database.
302
+ """
303
+ request: Request = kwargs["request"]
304
+ token = request.query_params.get("token")
305
+
306
+ base_url = str(request.base_url)
307
+
308
+ collection = await self.get_collection(
309
+ collection_id=collection_id, request=request
310
+ )
311
+ collection_id = collection.get("id")
312
+ if collection_id is None:
313
+ raise HTTPException(status_code=404, detail="Collection not found")
314
+
315
+ search = self.database.make_search()
316
+ search = self.database.apply_collections_filter(
317
+ search=search, collection_ids=[collection_id]
318
+ )
319
+
320
+ if datetime:
321
+ datetime_search = self._return_date(datetime)
322
+ search = self.database.apply_datetime_filter(
323
+ search=search, datetime_search=datetime_search
324
+ )
325
+
326
+ if bbox:
327
+ bbox = [float(x) for x in bbox]
328
+ if len(bbox) == 6:
329
+ bbox = [bbox[0], bbox[1], bbox[3], bbox[4]]
330
+
331
+ search = self.database.apply_bbox_filter(search=search, bbox=bbox)
332
+
333
+ items, maybe_count, next_token = await self.database.execute_search(
334
+ search=search,
335
+ limit=limit,
336
+ sort=None,
337
+ token=token,
338
+ collection_ids=[collection_id],
339
+ )
340
+
341
+ items = [
342
+ self.item_serializer.db_to_stac(item, base_url=base_url) for item in items
343
+ ]
344
+
345
+ links = await PagingLinks(request=request, next=next_token).get_links()
346
+
347
+ return stac_types.ItemCollection(
348
+ type="FeatureCollection",
349
+ features=items,
350
+ links=links,
351
+ numReturned=len(items),
352
+ numMatched=maybe_count,
353
+ )
354
+
355
+ async def get_item(
356
+ self, item_id: str, collection_id: str, **kwargs
357
+ ) -> stac_types.Item:
358
+ """Get an item from the database based on its id and collection id.
359
+
360
+ Args:
361
+ collection_id (str): The ID of the collection the item belongs to.
362
+ item_id (str): The ID of the item to be retrieved.
363
+
364
+ Returns:
365
+ Item: An `Item` object representing the requested item.
366
+
367
+ Raises:
368
+ Exception: If any error occurs while getting the item from the database.
369
+ NotFoundError: If the item does not exist in the specified collection.
370
+ """
371
+ base_url = str(kwargs["request"].base_url)
372
+ item = await self.database.get_one_item(
373
+ item_id=item_id, collection_id=collection_id
374
+ )
375
+ return self.item_serializer.db_to_stac(item, base_url)
376
+
377
+ @staticmethod
378
+ def _return_date(
379
+ interval: Optional[Union[DateTimeType, str]]
380
+ ) -> Dict[str, Optional[str]]:
381
+ """
382
+ Convert a date interval.
383
+
384
+ (which may be a datetime, a tuple of one or two datetimes a string
385
+ representing a datetime or range, or None) into a dictionary for filtering
386
+ search results with Elasticsearch.
387
+
388
+ This function ensures the output dictionary contains 'gte' and 'lte' keys,
389
+ even if they are set to None, to prevent KeyError in the consuming logic.
390
+
391
+ Args:
392
+ interval (Optional[Union[DateTimeType, str]]): The date interval, which might be a single datetime,
393
+ a tuple with one or two datetimes, a string, or None.
394
+
395
+ Returns:
396
+ dict: A dictionary representing the date interval for use in filtering search results,
397
+ always containing 'gte' and 'lte' keys.
398
+ """
399
+ result: Dict[str, Optional[str]] = {"gte": None, "lte": None}
400
+
401
+ if interval is None:
402
+ return result
403
+
404
+ if isinstance(interval, str):
405
+ if "/" in interval:
406
+ parts = interval.split("/")
407
+ result["gte"] = parts[0] if parts[0] != ".." else None
408
+ result["lte"] = (
409
+ parts[1] if len(parts) > 1 and parts[1] != ".." else None
410
+ )
411
+ else:
412
+ converted_time = interval if interval != ".." else None
413
+ result["gte"] = result["lte"] = converted_time
414
+ return result
415
+
416
+ if isinstance(interval, datetime_type):
417
+ datetime_iso = interval.isoformat()
418
+ result["gte"] = result["lte"] = datetime_iso
419
+ elif isinstance(interval, tuple):
420
+ start, end = interval
421
+ # Ensure datetimes are converted to UTC and formatted with 'Z'
422
+ if start:
423
+ result["gte"] = start.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
424
+ if end:
425
+ result["lte"] = end.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
426
+
427
+ return result
428
+
429
+ def _format_datetime_range(self, date_str: str) -> str:
430
+ """
431
+ Convert a datetime range string into a normalized UTC string for API requests using rfc3339_str_to_datetime.
432
+
433
+ Args:
434
+ date_str (str): A string containing two datetime values separated by a '/'.
435
+
436
+ Returns:
437
+ str: A string formatted as 'YYYY-MM-DDTHH:MM:SSZ/YYYY-MM-DDTHH:MM:SSZ', with '..' used if any element is None.
438
+ """
439
+
440
+ def normalize(dt):
441
+ dt = dt.strip()
442
+ if not dt or dt == "..":
443
+ return ".."
444
+ dt_obj = rfc3339_str_to_datetime(dt)
445
+ dt_utc = dt_obj.astimezone(timezone.utc)
446
+ return dt_utc.strftime("%Y-%m-%dT%H:%M:%SZ")
447
+
448
+ if not isinstance(date_str, str):
449
+ return "../.."
450
+ if "/" not in date_str:
451
+ return f"{normalize(date_str)}/{normalize(date_str)}"
452
+ try:
453
+ start, end = date_str.split("/", 1)
454
+ except Exception:
455
+ return "../.."
456
+ return f"{normalize(start)}/{normalize(end)}"
457
+
458
+ async def get_search(
459
+ self,
460
+ request: Request,
461
+ collections: Optional[List[str]] = None,
462
+ ids: Optional[List[str]] = None,
463
+ bbox: Optional[BBox] = None,
464
+ datetime: Optional[str] = None,
465
+ limit: Optional[int] = 10,
466
+ query: Optional[str] = None,
467
+ token: Optional[str] = None,
468
+ fields: Optional[List[str]] = None,
469
+ sortby: Optional[str] = None,
470
+ q: Optional[List[str]] = None,
471
+ intersects: Optional[str] = None,
472
+ filter_expr: Optional[str] = None,
473
+ filter_lang: Optional[str] = None,
474
+ **kwargs,
475
+ ) -> stac_types.ItemCollection:
476
+ """Get search results from the database.
477
+
478
+ Args:
479
+ collections (Optional[List[str]]): List of collection IDs to search in.
480
+ ids (Optional[List[str]]): List of item IDs to search for.
481
+ bbox (Optional[BBox]): Bounding box to search in.
482
+ datetime (Optional[str]): Filter items based on the datetime field.
483
+ limit (Optional[int]): Maximum number of results to return.
484
+ query (Optional[str]): Query string to filter the results.
485
+ token (Optional[str]): Access token to use when searching the catalog.
486
+ fields (Optional[List[str]]): Fields to include or exclude from the results.
487
+ sortby (Optional[str]): Sorting options for the results.
488
+ q (Optional[List[str]]): Free text query to filter the results.
489
+ intersects (Optional[str]): GeoJSON geometry to search in.
490
+ kwargs: Additional parameters to be passed to the API.
491
+
492
+ Returns:
493
+ ItemCollection: Collection of `Item` objects representing the search results.
494
+
495
+ Raises:
496
+ HTTPException: If any error occurs while searching the catalog.
497
+ """
498
+ base_args = {
499
+ "collections": collections,
500
+ "ids": ids,
501
+ "bbox": bbox,
502
+ "limit": limit,
503
+ "token": token,
504
+ "query": orjson.loads(query) if query else query,
505
+ "q": q,
506
+ }
507
+
508
+ if datetime:
509
+ base_args["datetime"] = self._format_datetime_range(date_str=datetime)
510
+
511
+ if intersects:
512
+ base_args["intersects"] = orjson.loads(unquote_plus(intersects))
513
+
514
+ if sortby:
515
+ base_args["sortby"] = [
516
+ {"field": sort[1:], "direction": "desc" if sort[0] == "-" else "asc"}
517
+ for sort in sortby
518
+ ]
519
+
520
+ if filter_expr:
521
+ base_args["filter_lang"] = "cql2-json"
522
+ base_args["filter"] = orjson.loads(
523
+ unquote_plus(filter_expr)
524
+ if filter_lang == "cql2-json"
525
+ else to_cql2(parse_cql2_text(filter_expr))
526
+ )
527
+
528
+ if fields:
529
+ includes, excludes = set(), set()
530
+ for field in fields:
531
+ if field[0] == "-":
532
+ excludes.add(field[1:])
533
+ else:
534
+ includes.add(field[1:] if field[0] in "+ " else field)
535
+ base_args["fields"] = {"include": includes, "exclude": excludes}
536
+
537
+ # Do the request
538
+ try:
539
+ search_request = self.post_request_model(**base_args)
540
+ except ValidationError as e:
541
+ raise HTTPException(
542
+ status_code=400, detail=f"Invalid parameters provided: {e}"
543
+ )
544
+ resp = await self.post_search(search_request=search_request, request=request)
545
+
546
+ return resp
547
+
548
    async def post_search(
        self, search_request: BaseSearchPostRequest, request: Request
    ) -> stac_types.ItemCollection:
        """
        Perform a POST search on the catalog.

        Each populated field of the request is translated into a backend filter
        and chained onto one search object before a single execute_search call.

        Args:
            search_request (BaseSearchPostRequest): Request object that includes the parameters for the search.
            request (Request): The incoming HTTP request.

        Returns:
            ItemCollection: A collection of items matching the search criteria.

        Raises:
            HTTPException: If there is an error with the cql2_json filter or the free text query.
        """
        base_url = str(request.base_url)

        search = self.database.make_search()

        if search_request.ids:
            search = self.database.apply_ids_filter(
                search=search, item_ids=search_request.ids
            )

        if search_request.collections:
            search = self.database.apply_collections_filter(
                search=search, collection_ids=search_request.collections
            )

        if search_request.datetime:
            # _return_date always yields both "gte" and "lte" keys.
            datetime_search = self._return_date(search_request.datetime)
            search = self.database.apply_datetime_filter(
                search=search, datetime_search=datetime_search
            )

        if search_request.bbox:
            bbox = search_request.bbox
            # 6-element (3D) bbox: drop the elevation values, keep x/y bounds.
            if len(bbox) == 6:
                bbox = [bbox[0], bbox[1], bbox[3], bbox[4]]

            search = self.database.apply_bbox_filter(search=search, bbox=bbox)

        if search_request.intersects:
            search = self.database.apply_intersects_filter(
                search=search, intersects=search_request.intersects
            )

        if search_request.query:
            # STAC query extension: {"<property>": {"<op>": <value>, ...}, ...}
            for field_name, expr in search_request.query.items():
                field = "properties__" + field_name
                for op, value in expr.items():
                    # Convert enum to string
                    operator = op.value if isinstance(op, Enum) else op
                    search = self.database.apply_stacql_filter(
                        search=search, op=operator, field=field, value=value
                    )

        # only cql2_json is supported here
        if hasattr(search_request, "filter_expr"):
            cql2_filter = getattr(search_request, "filter_expr", None)
            try:
                search = self.database.apply_cql2_filter(search, cql2_filter)
            except Exception as e:
                raise HTTPException(
                    status_code=400, detail=f"Error with cql2_json filter: {e}"
                )

        if hasattr(search_request, "q"):
            free_text_queries = getattr(search_request, "q", None)
            try:
                search = self.database.apply_free_text_filter(search, free_text_queries)
            except Exception as e:
                raise HTTPException(
                    status_code=400, detail=f"Error with free text query: {e}"
                )

        sort = None
        if search_request.sortby:
            sort = self.database.populate_sort(search_request.sortby)

        # Fall back to a page size of 10 when the client did not set a limit.
        limit = 10
        if search_request.limit:
            limit = search_request.limit

        items, maybe_count, next_token = await self.database.execute_search(
            search=search,
            limit=limit,
            token=search_request.token,
            sort=sort,
            collection_ids=search_request.collections,
        )

        # Only honor include/exclude field sets when the FieldsExtension is on.
        fields = (
            getattr(search_request, "fields", None)
            if self.extension_is_enabled("FieldsExtension")
            else None
        )
        include: Set[str] = fields.include if fields and fields.include else set()
        exclude: Set[str] = fields.exclude if fields and fields.exclude else set()

        items = [
            filter_fields(
                self.item_serializer.db_to_stac(item, base_url=base_url),
                include,
                exclude,
            )
            for item in items
        ]
        links = await PagingLinks(request=request, next=next_token).get_links()

        return stac_types.ItemCollection(
            type="FeatureCollection",
            features=items,
            links=links,
            numReturned=len(items),
            numMatched=maybe_count,
        )
666
+
667
+
668
@attr.s
class TransactionsClient(AsyncBaseTransactionsClient):
    """Transactions extension specific CRUD operations."""

    # Search/storage backend that performs the actual writes.
    database: BaseDatabaseLogic = attr.ib()
    # Application settings; forwarded to BulkTransactionsClient for bulk inserts.
    settings: ApiBaseSettings = attr.ib()
    session: Session = attr.ib(default=attr.Factory(Session.create_from_env))
675
+
676
+ @overrides
677
+ async def create_item(
678
+ self, collection_id: str, item: Union[Item, ItemCollection], **kwargs
679
+ ) -> Optional[stac_types.Item]:
680
+ """Create an item in the collection.
681
+
682
+ Args:
683
+ collection_id (str): The id of the collection to add the item to.
684
+ item (stac_types.Item): The item to be added to the collection.
685
+ kwargs: Additional keyword arguments.
686
+
687
+ Returns:
688
+ stac_types.Item: The created item.
689
+
690
+ Raises:
691
+ NotFound: If the specified collection is not found in the database.
692
+ ConflictError: If the item in the specified collection already exists.
693
+
694
+ """
695
+ item = item.model_dump(mode="json")
696
+ base_url = str(kwargs["request"].base_url)
697
+
698
+ # If a feature collection is posted
699
+ if item["type"] == "FeatureCollection":
700
+ bulk_client = BulkTransactionsClient(
701
+ database=self.database, settings=self.settings
702
+ )
703
+ processed_items = [
704
+ bulk_client.preprocess_item(
705
+ item, base_url, BulkTransactionMethod.INSERT
706
+ )
707
+ for item in item["features"]
708
+ ]
709
+
710
+ await self.database.bulk_async(
711
+ collection_id, processed_items, refresh=kwargs.get("refresh", False)
712
+ )
713
+
714
+ return None
715
+ else:
716
+ item = await self.database.prep_create_item(item=item, base_url=base_url)
717
+ await self.database.create_item(item, refresh=kwargs.get("refresh", False))
718
+ return ItemSerializer.db_to_stac(item, base_url)
719
+
720
    @overrides
    async def update_item(
        self, collection_id: str, item_id: str, item: Item, **kwargs
    ) -> stac_types.Item:
        """Update an item in the collection.

        Args:
            collection_id (str): The ID of the collection the item belongs to.
            item_id (str): The ID of the item to be updated.
            item (stac_types.Item): The new item data.
            kwargs: Other optional arguments, including the request object.

        Returns:
            stac_types.Item: The updated item object.

        Raises:
            NotFound: If the specified collection is not found in the database.

        """
        item = item.model_dump(mode="json")
        base_url = str(kwargs["request"].base_url)
        # Stamp the update time as RFC 3339 UTC with a 'Z' suffix.
        now = datetime_type.now(timezone.utc).isoformat().replace("+00:00", "Z")
        item["properties"]["updated"] = now

        await self.database.check_collection_exists(collection_id)
        # NOTE(review): update is implemented as delete-then-create, which is
        # not atomic — a failure between the two calls loses the item.
        await self.delete_item(item_id=item_id, collection_id=collection_id)
        await self.create_item(collection_id=collection_id, item=Item(**item), **kwargs)

        return ItemSerializer.db_to_stac(item, base_url)
749
+
750
+ @overrides
751
+ async def delete_item(self, item_id: str, collection_id: str, **kwargs) -> None:
752
+ """Delete an item from a collection.
753
+
754
+ Args:
755
+ item_id (str): The identifier of the item to delete.
756
+ collection_id (str): The identifier of the collection that contains the item.
757
+
758
+ Returns:
759
+ None: Returns 204 No Content on successful deletion
760
+ """
761
+ await self.database.delete_item(item_id=item_id, collection_id=collection_id)
762
+ return None
763
+
764
+ @overrides
765
+ async def create_collection(
766
+ self, collection: Collection, **kwargs
767
+ ) -> stac_types.Collection:
768
+ """Create a new collection in the database.
769
+
770
+ Args:
771
+ collection (stac_types.Collection): The collection to be created.
772
+ kwargs: Additional keyword arguments.
773
+
774
+ Returns:
775
+ stac_types.Collection: The created collection object.
776
+
777
+ Raises:
778
+ ConflictError: If the collection already exists.
779
+ """
780
+ collection = collection.model_dump(mode="json")
781
+ request = kwargs["request"]
782
+ collection = self.database.collection_serializer.stac_to_db(collection, request)
783
+ await self.database.create_collection(collection=collection)
784
+ return CollectionSerializer.db_to_stac(
785
+ collection,
786
+ request,
787
+ extensions=[type(ext).__name__ for ext in self.database.extensions],
788
+ )
789
+
790
    @overrides
    async def update_collection(
        self, collection_id: str, collection: Collection, **kwargs
    ) -> stac_types.Collection:
        """
        Update a collection.

        This method updates an existing collection in the database by first finding
        the collection by the id given in the keyword argument `collection_id`.
        If no `collection_id` is given the id of the given collection object is used.
        If the object and keyword collection ids don't match the sub items
        collection id is updated else the items are left unchanged.
        The updated collection is then returned.

        Args:
            collection_id: id of the existing collection to be updated
            collection: A STAC collection that needs to be updated.
            kwargs: Additional keyword arguments.

        Returns:
            A STAC collection that has been updated in the database.

        """
        collection = collection.model_dump(mode="json")

        request = kwargs["request"]

        # Convert the API representation to the database representation before
        # handing the document to the backend's id-aware update.
        collection = self.database.collection_serializer.stac_to_db(collection, request)
        await self.database.update_collection(
            collection_id=collection_id, collection=collection
        )

        # Serialize back to STAC, advertising the backend's enabled extensions.
        return CollectionSerializer.db_to_stac(
            collection,
            request,
            extensions=[type(ext).__name__ for ext in self.database.extensions],
        )
827
+
828
+ @overrides
829
+ async def delete_collection(self, collection_id: str, **kwargs) -> None:
830
+ """
831
+ Delete a collection.
832
+
833
+ This method deletes an existing collection in the database.
834
+
835
+ Args:
836
+ collection_id (str): The identifier of the collection to delete
837
+
838
+ Returns:
839
+ None: Returns 204 No Content on successful deletion
840
+
841
+ Raises:
842
+ NotFoundError: If the collection doesn't exist
843
+ """
844
+ await self.database.delete_collection(collection_id=collection_id)
845
+ return None
846
+
847
+
848
@attr.s
class BulkTransactionsClient(BaseBulkTransactionsClient):
    """A client for posting bulk transactions to the backing database.

    Attributes:
        session: An instance of `Session` to use for database connection.
        database: An instance of `DatabaseLogic` to perform database operations.
    """

    # Search backend performing the actual bulk writes.
    database: BaseDatabaseLogic = attr.ib()
    # Settings object; supplies the low-level client in __attrs_post_init__.
    settings: ApiBaseSettings = attr.ib()
    session: Session = attr.ib(default=attr.Factory(Session.create_from_env))
860
+
861
    def __attrs_post_init__(self):
        """Create es engine."""
        # NOTE(review): no call parentheses here — create_client appears to be
        # a property/factory on the settings object; confirm against
        # ApiBaseSettings before relying on self.client being a live client.
        self.client = self.settings.create_client
864
+
865
+ def preprocess_item(
866
+ self, item: stac_types.Item, base_url, method: BulkTransactionMethod
867
+ ) -> stac_types.Item:
868
+ """Preprocess an item to match the data model.
869
+
870
+ Args:
871
+ item: The item to preprocess.
872
+ base_url: The base URL of the request.
873
+ method: The bulk transaction method.
874
+
875
+ Returns:
876
+ The preprocessed item.
877
+ """
878
+ exist_ok = method == BulkTransactionMethod.UPSERT
879
+ return self.database.sync_prep_create_item(
880
+ item=item, base_url=base_url, exist_ok=exist_ok
881
+ )
882
+
883
+ @overrides
884
+ def bulk_item_insert(
885
+ self, items: Items, chunk_size: Optional[int] = None, **kwargs
886
+ ) -> str:
887
+ """Perform a bulk insertion of items into the database using Elasticsearch.
888
+
889
+ Args:
890
+ items: The items to insert.
891
+ chunk_size: The size of each chunk for bulk processing.
892
+ **kwargs: Additional keyword arguments, such as `request` and `refresh`.
893
+
894
+ Returns:
895
+ A string indicating the number of items successfully added.
896
+ """
897
+ request = kwargs.get("request")
898
+ if request:
899
+ base_url = str(request.base_url)
900
+ else:
901
+ base_url = ""
902
+
903
+ processed_items = [
904
+ self.preprocess_item(item, base_url, items.method)
905
+ for item in items.items.values()
906
+ ]
907
+
908
+ # not a great way to get the collection_id-- should be part of the method signature
909
+ collection_id = processed_items[0]["collection"]
910
+
911
+ self.database.bulk_sync(
912
+ collection_id, processed_items, refresh=kwargs.get("refresh", False)
913
+ )
914
+
915
+ return f"Successfully added {len(processed_items)} Items."
916
+
917
+
918
+ _DEFAULT_QUERYABLES: Dict[str, Dict[str, Any]] = {
919
+ "id": {
920
+ "description": "ID",
921
+ "$ref": "https://schemas.stacspec.org/v1.0.0/item-spec/json-schema/item.json#/definitions/core/allOf/2/properties/id",
922
+ },
923
+ "collection": {
924
+ "description": "Collection",
925
+ "$ref": "https://schemas.stacspec.org/v1.0.0/item-spec/json-schema/item.json#/definitions/core/allOf/2/then/properties/collection",
926
+ },
927
+ "geometry": {
928
+ "description": "Geometry",
929
+ "$ref": "https://schemas.stacspec.org/v1.0.0/item-spec/json-schema/item.json#/definitions/core/allOf/1/oneOf/0/properties/geometry",
930
+ },
931
+ "datetime": {
932
+ "description": "Acquisition Timestamp",
933
+ "$ref": "https://schemas.stacspec.org/v1.0.0/item-spec/json-schema/datetime.json#/properties/datetime",
934
+ },
935
+ "created": {
936
+ "description": "Creation Timestamp",
937
+ "$ref": "https://schemas.stacspec.org/v1.0.0/item-spec/json-schema/datetime.json#/properties/created",
938
+ },
939
+ "updated": {
940
+ "description": "Creation Timestamp",
941
+ "$ref": "https://schemas.stacspec.org/v1.0.0/item-spec/json-schema/datetime.json#/properties/updated",
942
+ },
943
+ "cloud_cover": {
944
+ "description": "Cloud Cover",
945
+ "$ref": "https://stac-extensions.github.io/eo/v1.0.0/schema.json#/definitions/fields/properties/eo:cloud_cover",
946
+ },
947
+ "cloud_shadow_percentage": {
948
+ "title": "Cloud Shadow Percentage",
949
+ "description": "Cloud Shadow Percentage",
950
+ "type": "number",
951
+ "minimum": 0,
952
+ "maximum": 100,
953
+ },
954
+ "nodata_pixel_percentage": {
955
+ "title": "No Data Pixel Percentage",
956
+ "description": "No Data Pixel Percentage",
957
+ "type": "number",
958
+ "minimum": 0,
959
+ "maximum": 100,
960
+ },
961
+ }
962
+
963
+ _ES_MAPPING_TYPE_TO_JSON: Dict[
964
+ str, Literal["string", "number", "boolean", "object", "array", "null"]
965
+ ] = {
966
+ "date": "string",
967
+ "date_nanos": "string",
968
+ "keyword": "string",
969
+ "match_only_text": "string",
970
+ "text": "string",
971
+ "wildcard": "string",
972
+ "byte": "number",
973
+ "double": "number",
974
+ "float": "number",
975
+ "half_float": "number",
976
+ "long": "number",
977
+ "scaled_float": "number",
978
+ "short": "number",
979
+ "token_count": "number",
980
+ "unsigned_long": "number",
981
+ "geo_point": "object",
982
+ "geo_shape": "object",
983
+ "nested": "array",
984
+ }
985
+
986
+
987
@attr.s
class EsAsyncBaseFiltersClient(AsyncBaseFiltersClient):
    """Defines a pattern for implementing the STAC filter extension."""

    database: BaseDatabaseLogic = attr.ib()

    async def get_queryables(
        self, collection_id: Optional[str] = None, **kwargs
    ) -> Dict[str, Any]:
        """Get the queryables available for the given collection_id.

        If collection_id is None, returns the intersection of all
        queryables over all collections.

        This base implementation returns a blank queryable schema. This is not allowed
        under OGC CQL but it is allowed by the STAC API Filter Extension

        https://github.com/radiantearth/stac-api-spec/tree/master/fragments/filter#queryables

        Args:
            collection_id (str, optional): The id of the collection to get queryables for.
            **kwargs: additional keyword arguments

        Returns:
            Dict[str, Any]: A dictionary containing the queryables for the given collection.
        """
        # Fix: work on per-field COPIES of the defaults. The traversal below
        # calls ``setdefault`` on each field dict; previously those were the
        # module-level ``_DEFAULT_QUERYABLES`` dicts themselves, so titles and
        # types derived from one collection's mapping leaked into every later
        # response (including the collection-independent one).
        properties: Dict[str, Any] = {
            name: dict(schema) for name, schema in _DEFAULT_QUERYABLES.items()
        }
        queryables: Dict[str, Any] = {
            "$schema": "https://json-schema.org/draft/2019-09/schema",
            "$id": "https://stac-api.example.com/queryables",
            "type": "object",
            "title": "Queryables for STAC API",
            "description": "Queryable names for the STAC API Item Search filter.",
            "properties": properties,
            "additionalProperties": True,
        }
        if not collection_id:
            return queryables

        # A concrete collection advertises a closed set of queryables.
        queryables["additionalProperties"] = False

        mapping_data = await self.database.get_items_mapping(collection_id)
        mapping_properties = next(iter(mapping_data.values()))["mappings"]["properties"]
        # Breadth-first walk over the (possibly nested) mapping definition.
        stack = deque(mapping_properties.items())

        while stack:
            field_name, field_def = stack.popleft()

            # Iterate over nested fields
            field_properties = field_def.get("properties")
            if field_properties:
                # Fields in Item Properties should be exposed with their un-prefixed names,
                # and not require expressions to prefix them with properties,
                # e.g., eo:cloud_cover instead of properties.eo:cloud_cover.
                if field_name == "properties":
                    stack.extend(field_properties.items())
                else:
                    stack.extend(
                        (f"{field_name}.{k}", v) for k, v in field_properties.items()
                    )

            # Skip non-indexed or disabled fields
            field_type = field_def.get("type")
            if not field_type or not field_def.get("enabled", True):
                continue

            # Reuse the (copied) default entry when one exists, otherwise start
            # a fresh schema; ``setdefault`` keeps any default metadata intact.
            field_result = properties.setdefault(field_name, {})

            field_name_human = field_name.replace("_", " ").title()
            field_result.setdefault("title", field_name_human)

            field_type_json = _ES_MAPPING_TYPE_TO_JSON.get(field_type, field_type)
            field_result.setdefault("type", field_type_json)

            if field_type in {"date", "date_nanos"}:
                field_result.setdefault("format", "date-time")

        return queryables