stac-fastapi-opensearch 6.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1725 @@
1
+ """Database logic."""
2
+
3
+ import asyncio
4
+ import logging
5
+ from base64 import urlsafe_b64decode, urlsafe_b64encode
6
+ from collections.abc import Iterable
7
+ from copy import deepcopy
8
+ from typing import Any, Dict, List, Optional, Tuple, Type
9
+
10
+ import attr
11
+ import orjson
12
+ from fastapi import HTTPException
13
+ from opensearchpy import exceptions, helpers
14
+ from opensearchpy.helpers.query import Q
15
+ from opensearchpy.helpers.search import Search
16
+ from starlette.requests import Request
17
+
18
+ from stac_fastapi.core.base_database_logic import BaseDatabaseLogic
19
+ from stac_fastapi.core.serializers import CollectionSerializer, ItemSerializer
20
+ from stac_fastapi.core.utilities import MAX_LIMIT, bbox2polygon, get_bool_env
21
+ from stac_fastapi.extensions.core.transaction.request import (
22
+ PartialCollection,
23
+ PartialItem,
24
+ PatchOperation,
25
+ )
26
+ from stac_fastapi.opensearch.config import (
27
+ AsyncOpensearchSettings as AsyncSearchSettings,
28
+ )
29
+ from stac_fastapi.opensearch.config import OpensearchSettings as SyncSearchSettings
30
+ from stac_fastapi.sfeos_helpers import filter as filter_module
31
+ from stac_fastapi.sfeos_helpers.database import (
32
+ add_bbox_shape_to_collection,
33
+ apply_collections_bbox_filter_shared,
34
+ apply_collections_datetime_filter_shared,
35
+ apply_free_text_filter_shared,
36
+ apply_intersects_filter_shared,
37
+ create_index_templates_shared,
38
+ delete_item_index_shared,
39
+ get_queryables_mapping_shared,
40
+ index_alias_by_collection_id,
41
+ mk_actions,
42
+ mk_item_id,
43
+ populate_sort_shared,
44
+ return_date,
45
+ validate_refresh,
46
+ )
47
+ from stac_fastapi.sfeos_helpers.database.query import (
48
+ ES_MAX_URL_LENGTH,
49
+ add_collections_to_body,
50
+ )
51
+ from stac_fastapi.sfeos_helpers.database.utils import (
52
+ merge_to_operations,
53
+ operations_to_script,
54
+ )
55
+ from stac_fastapi.sfeos_helpers.mappings import (
56
+ AGGREGATION_MAPPING,
57
+ COLLECTIONS_INDEX,
58
+ DEFAULT_SORT,
59
+ ES_COLLECTIONS_MAPPINGS,
60
+ ITEM_INDICES,
61
+ ITEMS_INDEX_PREFIX,
62
+ Geometry,
63
+ )
64
+ from stac_fastapi.sfeos_helpers.search_engine import (
65
+ BaseIndexInserter,
66
+ BaseIndexSelector,
67
+ IndexInsertionFactory,
68
+ IndexSelectorFactory,
69
+ )
70
+ from stac_fastapi.types.errors import ConflictError, NotFoundError
71
+ from stac_fastapi.types.links import resolve_links
72
+ from stac_fastapi.types.stac import Collection, Item
73
+
74
+ logger = logging.getLogger(__name__)
75
+
76
+
77
+ async def create_index_templates() -> None:
78
+ """
79
+ Create index templates for the Collection and Item indices.
80
+
81
+ Returns:
82
+ None
83
+
84
+ """
85
+ await create_index_templates_shared(settings=AsyncSearchSettings())
86
+
87
+
88
+ async def create_collection_index() -> None:
89
+ """
90
+ Create the index for a Collection. The settings of the index template will be used implicitly.
91
+
92
+ Returns:
93
+ None
94
+
95
+ """
96
+ client = AsyncSearchSettings().create_client
97
+
98
+ index = f"{COLLECTIONS_INDEX}-000001"
99
+
100
+ exists = await client.indices.exists(index=index)
101
+ if not exists:
102
+ await client.indices.create(
103
+ index=index,
104
+ body={
105
+ "aliases": {COLLECTIONS_INDEX: {}},
106
+ "mappings": ES_COLLECTIONS_MAPPINGS,
107
+ },
108
+ )
109
+ await client.close()
110
+
111
+
112
+ async def delete_item_index(collection_id: str) -> None:
113
+ """Delete the index for items in a collection.
114
+
115
+ Args:
116
+ collection_id (str): The ID of the collection whose items index will be deleted.
117
+
118
+ Notes:
119
+ This function delegates to the shared implementation in delete_item_index_shared.
120
+ """
121
+ await delete_item_index_shared(
122
+ settings=AsyncSearchSettings(), collection_id=collection_id
123
+ )
124
+
125
+
126
+ @attr.s
127
+ class DatabaseLogic(BaseDatabaseLogic):
128
+ """Database logic."""
129
+
130
+ async_settings: AsyncSearchSettings = attr.ib(factory=AsyncSearchSettings)
131
+ sync_settings: SyncSearchSettings = attr.ib(factory=SyncSearchSettings)
132
+
133
+ async_index_selector: BaseIndexSelector = attr.ib(init=False)
134
+ async_index_inserter: BaseIndexInserter = attr.ib(init=False)
135
+
136
+ client = attr.ib(init=False)
137
+ sync_client = attr.ib(init=False)
138
+
139
+ def __attrs_post_init__(self):
140
+ """Initialize clients after the class is instantiated."""
141
+ self.client = self.async_settings.create_client
142
+ self.sync_client = self.sync_settings.create_client
143
+ self.async_index_inserter = IndexInsertionFactory.create_insertion_strategy(
144
+ self.client
145
+ )
146
+ self.async_index_selector = IndexSelectorFactory.create_selector(self.client)
147
+
148
+ item_serializer: Type[ItemSerializer] = attr.ib(default=ItemSerializer)
149
+ collection_serializer: Type[CollectionSerializer] = attr.ib(
150
+ default=CollectionSerializer
151
+ )
152
+
153
+ extensions: List[str] = attr.ib(default=attr.Factory(list))
154
+
155
+ aggregation_mapping: Dict[str, Dict[str, Any]] = AGGREGATION_MAPPING
156
+
157
+ """CORE LOGIC"""
158
+
159
+ async def get_all_collections(
160
+ self,
161
+ token: Optional[str],
162
+ limit: int,
163
+ request: Request,
164
+ sort: Optional[List[Dict[str, Any]]] = None,
165
+ bbox: Optional[List[float]] = None,
166
+ q: Optional[List[str]] = None,
167
+ filter: Optional[Dict[str, Any]] = None,
168
+ query: Optional[Dict[str, Dict[str, Any]]] = None,
169
+ datetime: Optional[str] = None,
170
+ ) -> Tuple[List[Dict[str, Any]], Optional[str], Optional[int]]:
171
+ """Retrieve a list of collections from OpenSearch, supporting pagination.
172
+
173
+ Args:
174
+ token (Optional[str]): The pagination token.
175
+ limit (int): The number of results to return.
176
+ request (Request): The FastAPI request object.
177
+ sort (Optional[List[Dict[str, Any]]]): Optional sort parameter from the request.
178
+ bbox (Optional[List[float]]): Bounding box to filter collections by spatial extent.
179
+ q (Optional[List[str]]): Free text search terms.
180
+ query (Optional[Dict[str, Dict[str, Any]]]): Query extension parameters.
181
+ filter (Optional[Dict[str, Any]]): Structured query in CQL2 format.
182
+ datetime (Optional[str]): Temporal filter.
183
+
184
+ Returns:
185
+ A tuple of (collections, next pagination token if any, matched count if known).
186
+
187
+ Raises:
188
+ HTTPException: If sorting is requested on a field that is not sortable.
189
+ """
190
+ # Define sortable fields based on the ES_COLLECTIONS_MAPPINGS
191
+ sortable_fields = ["id", "extent.temporal.interval", "temporal"]
192
+
193
+ # Format the sort parameter
194
+ formatted_sort = []
195
+ if sort:
196
+ for item in sort:
197
+ field = item.get("field")
198
+ direction = item.get("direction", "asc")
199
+ if field:
200
+ # Validate that the field is sortable
201
+ if field not in sortable_fields:
202
+ raise HTTPException(
203
+ status_code=400,
204
+ detail=f"Field '{field}' is not sortable. Sortable fields are: {', '.join(sortable_fields)}. "
205
+ + "Text fields are not sortable by default in OpenSearch. "
206
+ + "To make a field sortable, update the mapping to use 'keyword' type or add a '.keyword' subfield. ",
207
+ )
208
+ formatted_sort.append({field: {"order": direction}})
209
+ # Always include id as a secondary sort to ensure consistent pagination
210
+ if not any("id" in item for item in formatted_sort):
211
+ formatted_sort.append({"id": {"order": "asc"}})
212
+ else:
213
+ formatted_sort = [{"id": {"order": "asc"}}]
214
+
215
+ body = {
216
+ "sort": formatted_sort,
217
+ "size": limit,
218
+ }
219
+
220
+ # Handle search_after token - split by '|' to get all sort values
221
+ search_after = None
222
+ if token:
223
+ try:
224
+ # The token should be a pipe-separated string of sort values
225
+ # e.g., "2023-01-01T00:00:00Z|collection-1"
226
+ search_after = token.split("|")
227
+ # If the number of sort fields doesn't match token parts, ignore the token
228
+ if len(search_after) != len(formatted_sort):
229
+ search_after = None
230
+ except Exception:
231
+ search_after = None
232
+
233
+ if search_after is not None:
234
+ body["search_after"] = search_after
235
+
236
+ # Build the query part of the body
237
+ query_parts = []
238
+
239
+ # Apply free text query if provided
240
+ if q:
241
+ # For collections, we want to search across all relevant fields
242
+ should_clauses = []
243
+
244
+ # For each search term
245
+ for term in q:
246
+ # Create a multi_match query for each term
247
+ for field in [
248
+ "id",
249
+ "title",
250
+ "description",
251
+ "keywords",
252
+ "summaries.platform",
253
+ "summaries.constellation",
254
+ "providers.name",
255
+ "providers.url",
256
+ ]:
257
+ should_clauses.append(
258
+ {
259
+ "wildcard": {
260
+ field: {"value": f"*{term}*", "case_insensitive": True}
261
+ }
262
+ }
263
+ )
264
+
265
+ # Add the free text query to the query parts
266
+ query_parts.append(
267
+ {"bool": {"should": should_clauses, "minimum_should_match": 1}}
268
+ )
269
+
270
+ # Apply structured filter if provided
271
+ if filter:
272
+ # Convert string filter to dict if needed
273
+ if isinstance(filter, str):
274
+ filter = orjson.loads(filter)
275
+ # Convert the filter to an OpenSearch query using the filter module
276
+ es_query = filter_module.to_es(await self.get_queryables_mapping(), filter)
277
+ query_parts.append(es_query)
278
+
279
+ # Apply query extension if provided
280
+ if query:
281
+ try:
282
+ # First create a search object to apply filters
283
+ search = Search(index=COLLECTIONS_INDEX)
284
+
285
+ # Process each field and operator in the query
286
+ for field_name, expr in query.items():
287
+ for op, value in expr.items():
288
+ # For collections, we don't need to prefix with 'properties__'
289
+ field = field_name
290
+ # Apply the filter using apply_stacql_filter
291
+ search = self.apply_stacql_filter(
292
+ search=search, op=op, field=field, value=value
293
+ )
294
+
295
+ # Convert the search object to a query dict and add it to query_parts
296
+ search_dict = search.to_dict()
297
+ if "query" in search_dict:
298
+ query_parts.append(search_dict["query"])
299
+
300
+ except Exception as e:
301
+ logger.error(f"Error converting query to OpenSearch: {e}")
302
+ # If there's an error, add a query that matches nothing
303
+ query_parts.append({"bool": {"must_not": {"match_all": {}}}})
304
+ raise
305
+
306
+ # Apply bbox filter if provided
307
+ bbox_filter = apply_collections_bbox_filter_shared(bbox)
308
+ if bbox_filter:
309
+ query_parts.append(bbox_filter)
310
+
311
+ # Apply datetime filter if provided
312
+ datetime_filter = apply_collections_datetime_filter_shared(datetime)
313
+ if datetime_filter:
314
+ query_parts.append(datetime_filter)
315
+
316
+ # Combine all query parts with AND logic
317
+ if query_parts:
318
+ body["query"] = (
319
+ query_parts[0]
320
+ if len(query_parts) == 1
321
+ else {"bool": {"must": query_parts}}
322
+ )
323
+
324
+ # Create async tasks for both search and count
325
+ search_task = asyncio.create_task(
326
+ self.client.search(
327
+ index=COLLECTIONS_INDEX,
328
+ body=body,
329
+ )
330
+ )
331
+
332
+ count_task = asyncio.create_task(
333
+ self.client.count(
334
+ index=COLLECTIONS_INDEX,
335
+ body={"query": body.get("query", {"match_all": {}})},
336
+ )
337
+ )
338
+
339
+ # Wait for search task to complete
340
+ response = await search_task
341
+
342
+ hits = response["hits"]["hits"]
343
+ collections = [
344
+ self.collection_serializer.db_to_stac(
345
+ collection=hit["_source"], request=request, extensions=self.extensions
346
+ )
347
+ for hit in hits
348
+ ]
349
+
350
+ next_token = None
351
+ if len(hits) == limit:
352
+ next_token_values = hits[-1].get("sort")
353
+ if next_token_values:
354
+ # Join all sort values with '|' to create the token
355
+ next_token = "|".join(str(val) for val in next_token_values)
356
+
357
+ # Get the total count of collections
358
+ matched = (
359
+ response["hits"]["total"]["value"]
360
+ if response["hits"]["total"]["relation"] == "eq"
361
+ else None
362
+ )
363
+
364
+ # If count task is done, use its result
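+ # (the total reported by the search response above is exact only when its relation is "eq")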
365
+ if count_task.done():
366
+ try:
367
+ matched = count_task.result().get("count")
368
+ except Exception as e:
369
+ logger.error(f"Count task failed: {e}")
370
+
371
+ return collections, next_token, matched
372
+
373
+ async def get_one_item(self, collection_id: str, item_id: str) -> Dict:
374
+ """Retrieve a single item from the database.
375
+
376
+ Args:
377
+ collection_id (str): The id of the Collection that the Item belongs to.
378
+ item_id (str): The id of the Item.
379
+
380
+ Returns:
381
+ item (Dict): A dictionary containing the source data for the Item.
382
+
383
+ Raises:
384
+ NotFoundError: If the specified Item does not exist in the Collection.
385
+
386
+ Notes:
387
+ The Item is retrieved from the OpenSearch database via `client.search` with a term query on the document id,
388
+ using the Collection's index alias as the target and the combined `mk_item_id` as the id value.
389
+ """
390
+ try:
391
+ response = await self.client.search(
392
+ index=index_alias_by_collection_id(collection_id),
393
+ body={
394
+ "query": {"term": {"_id": mk_item_id(item_id, collection_id)}},
395
+ "size": 1,
396
+ },
397
+ )
398
+ if response["hits"]["total"]["value"] == 0:
399
+ raise NotFoundError(
400
+ f"Item {item_id} does not exist inside Collection {collection_id}"
401
+ )
402
+
403
+ return response["hits"]["hits"][0]["_source"]
404
+ except exceptions.NotFoundError:
405
+ raise NotFoundError(
406
+ f"Item {item_id} does not exist inside Collection {collection_id}"
407
+ )
408
+
409
+ async def get_queryables_mapping(self, collection_id: str = "*") -> dict:
410
+ """Retrieve mapping of Queryables for search.
411
+
412
+ Args:
413
+ collection_id (str, optional): The id of the Collection the Queryables
414
+ belongs to. Defaults to "*".
415
+
416
+ Returns:
417
+ dict: A dictionary containing the Queryables mappings.
418
+ """
419
+ mappings = await self.client.indices.get_mapping(
420
+ index=f"{ITEMS_INDEX_PREFIX}{collection_id}",
421
+ )
422
+ return await get_queryables_mapping_shared(
423
+ collection_id=collection_id, mappings=mappings
424
+ )
425
+
426
+ @staticmethod
427
+ def make_search():
428
+ """Database logic to create a Search instance."""
429
+ return Search().sort(*DEFAULT_SORT)
430
+
431
+ @staticmethod
432
+ def apply_ids_filter(search: Search, item_ids: List[str]):
433
+ """Database logic to search a list of STAC item ids."""
434
+ return search.filter("terms", id=item_ids)
435
+
436
+ @staticmethod
437
+ def apply_collections_filter(search: Search, collection_ids: List[str]):
438
+ """Database logic to search a list of STAC collection ids."""
439
+ return search.filter("terms", collection=collection_ids)
440
+
441
+ @staticmethod
442
+ def apply_free_text_filter(search: Search, free_text_queries: Optional[List[str]]):
443
+ """Create a free text query for OpenSearch queries.
444
+
445
+ This method delegates to the shared implementation in apply_free_text_filter_shared.
446
+
447
+ Args:
448
+ search (Search): The search object to apply the query to.
449
+ free_text_queries (Optional[List[str]]): A list of text strings to search for in the properties.
450
+
451
+ Returns:
452
+ Search: The search object with the free text query applied, or the original search
453
+ object if no free_text_queries were provided.
454
+ """
455
+ return apply_free_text_filter_shared(
456
+ search=search, free_text_queries=free_text_queries
457
+ )
458
+
459
+ @staticmethod
460
+ def apply_datetime_filter(
461
+ search: Search, datetime: Optional[str]
462
+ ) -> Tuple[Search, Dict[str, Optional[str]]]:
463
+ """Apply a filter to search on datetime, start_datetime, and end_datetime fields.
464
+
465
+ Args:
466
+ search: The search object to filter.
467
+ datetime (Optional[str]): The datetime or datetime interval to filter on, if any.
468
+
469
+ Returns:
470
+ A tuple of the filtered search object and the parsed datetime search dictionary.
471
+ """
472
+ datetime_search = return_date(datetime)
473
+
474
+ if not datetime_search:
475
+ return search, datetime_search
476
+
477
+ # USE_DATETIME env var
478
+ # True: Search by datetime, if null search by start/end datetime
479
+ # False: Always search only by start/end datetime
480
+ USE_DATETIME = get_bool_env("USE_DATETIME", default=True)
481
+
482
+ if USE_DATETIME:
483
+ if "eq" in datetime_search:
484
+ # For exact matches, include:
485
+ # 1. Items with matching exact datetime
486
+ # 2. Items with datetime:null where the time falls within their range
487
+ should = [
488
+ Q(
489
+ "bool",
490
+ filter=[
491
+ Q("exists", field="properties.datetime"),
492
+ Q(
493
+ "term",
494
+ **{"properties__datetime": datetime_search["eq"]},
495
+ ),
496
+ ],
497
+ ),
498
+ Q(
499
+ "bool",
500
+ must_not=[Q("exists", field="properties.datetime")],
501
+ filter=[
502
+ Q("exists", field="properties.start_datetime"),
503
+ Q("exists", field="properties.end_datetime"),
504
+ Q(
505
+ "range",
506
+ properties__start_datetime={
507
+ "lte": datetime_search["eq"]
508
+ },
509
+ ),
510
+ Q(
511
+ "range",
512
+ properties__end_datetime={"gte": datetime_search["eq"]},
513
+ ),
514
+ ],
515
+ ),
516
+ ]
517
+ else:
518
+ # For date ranges, include:
519
+ # 1. Items with datetime in the range
520
+ # 2. Items with datetime:null that overlap the search range
521
+ should = [
522
+ Q(
523
+ "bool",
524
+ filter=[
525
+ Q("exists", field="properties.datetime"),
526
+ Q(
527
+ "range",
528
+ properties__datetime={
529
+ "gte": datetime_search["gte"],
530
+ "lte": datetime_search["lte"],
531
+ },
532
+ ),
533
+ ],
534
+ ),
535
+ Q(
536
+ "bool",
537
+ must_not=[Q("exists", field="properties.datetime")],
538
+ filter=[
539
+ Q("exists", field="properties.start_datetime"),
540
+ Q("exists", field="properties.end_datetime"),
541
+ Q(
542
+ "range",
543
+ properties__start_datetime={
544
+ "lte": datetime_search["lte"]
545
+ },
546
+ ),
547
+ Q(
548
+ "range",
549
+ properties__end_datetime={
550
+ "gte": datetime_search["gte"]
551
+ },
552
+ ),
553
+ ],
554
+ ),
555
+ ]
556
+
557
+ return (
558
+ search.query(Q("bool", should=should, minimum_should_match=1)),
559
+ datetime_search,
560
+ )
561
+ else:
562
+ if "eq" in datetime_search:
563
+ filter_query = Q(
564
+ "bool",
565
+ filter=[
566
+ Q("exists", field="properties.start_datetime"),
567
+ Q("exists", field="properties.end_datetime"),
568
+ Q(
569
+ "range",
570
+ properties__start_datetime={"lte": datetime_search["eq"]},
571
+ ),
572
+ Q(
573
+ "range",
574
+ properties__end_datetime={"gte": datetime_search["eq"]},
575
+ ),
576
+ ],
577
+ )
578
+ else:
579
+ filter_query = Q(
580
+ "bool",
581
+ filter=[
582
+ Q("exists", field="properties.start_datetime"),
583
+ Q("exists", field="properties.end_datetime"),
584
+ Q(
585
+ "range",
586
+ properties__start_datetime={"lte": datetime_search["lte"]},
587
+ ),
588
+ Q(
589
+ "range",
590
+ properties__end_datetime={"gte": datetime_search["gte"]},
591
+ ),
592
+ ],
593
+ )
594
+ return search.query(filter_query), datetime_search
595
+
596
+ @staticmethod
597
+ def apply_bbox_filter(search: Search, bbox: List):
598
+ """Filter search results based on bounding box.
599
+
600
+ Args:
601
+ search (Search): The search object to apply the filter to.
602
+ bbox (List): The bounding box coordinates, represented as a list of four values [minx, miny, maxx, maxy].
603
+
604
+ Returns:
605
+ search (Search): The search object with the bounding box filter applied.
606
+
607
+ Notes:
608
+ The bounding box is transformed into a polygon using the `bbox2polygon` function and
609
+ a geo_shape filter is added to the search object, set to intersect with the specified polygon.
610
+ """
611
+ return search.filter(
612
+ Q(
613
+ {
614
+ "geo_shape": {
615
+ "geometry": {
616
+ "shape": {
617
+ "type": "polygon",
618
+ "coordinates": bbox2polygon(*bbox),
619
+ },
620
+ "relation": "intersects",
621
+ }
622
+ }
623
+ }
624
+ )
625
+ )
626
+
627
+ @staticmethod
628
+ def apply_intersects_filter(
629
+ search: Search,
630
+ intersects: Geometry,
631
+ ):
632
+ """Filter search results based on intersecting geometry.
633
+
634
+ Args:
635
+ search (Search): The search object to apply the filter to.
636
+ intersects (Geometry): The intersecting geometry, represented as a GeoJSON-like object.
637
+
638
+ Returns:
639
+ search (Search): The search object with the intersecting geometry filter applied.
640
+
641
+ Notes:
642
+ A geo_shape filter is added to the search object, set to intersect with the specified geometry.
643
+ """
644
+ filter = apply_intersects_filter_shared(intersects=intersects)
645
+ return search.filter(Q(filter))
646
+
647
+ @staticmethod
648
+ def apply_stacql_filter(search: Search, op: str, field: str, value: float):
649
+ """Filter search results based on a comparison between a field and a value.
650
+
651
+ Args:
652
+ search (Search): The search object to apply the filter to.
653
+ op (str): The comparison operator to use. Can be 'eq' (equal), 'ne'/'neq' (not equal), 'gt' (greater than),
654
+ 'gte' (greater than or equal), 'lt' (less than), 'lte' (less than or equal), 'in' (membership in a list), or 'contains' (array contains the value).
655
+ field (str): The field to perform the comparison on.
656
+ value (float): The value to compare the field against (a list of values for the 'in' operator).
657
+
658
+ Returns:
659
+ search (Search): The search object with the specified filter applied.
660
+ """
661
+ if op == "eq":
662
+ search = search.filter("term", **{field: value})
663
+ elif op == "ne" or op == "neq":
664
+ # For not equal, use a bool query with must_not
665
+ search = search.exclude("term", **{field: value})
666
+ elif op in ["gt", "gte", "lt", "lte"]:
667
+ # For range operators
668
+ key_filter = {field: {op: value}}
669
+ search = search.filter(Q("range", **key_filter))
670
+ elif op == "in":
671
+ # For in operator (value should be a list)
672
+ if isinstance(value, list):
673
+ search = search.filter("terms", **{field: value})
674
+ else:
675
+ search = search.filter("term", **{field: value})
676
+ elif op == "contains":
677
+ # For contains operator (for arrays)
678
+ search = search.filter("term", **{field: value})
679
+
680
+ return search
681
+
682
+ async def apply_cql2_filter(
683
+ self, search: Search, _filter: Optional[Dict[str, Any]]
684
+ ):
685
+ """
686
+ Apply a CQL2 filter to an Opensearch Search object.
687
+
688
+ This method transforms a dictionary representing a CQL2 filter into an Opensearch query
689
+ and applies it to the provided Search object. If the filter is None, the original Search
690
+ object is returned unmodified.
691
+
692
+ Args:
693
+ search (Search): The Opensearch Search object to which the filter will be applied.
694
+ _filter (Optional[Dict[str, Any]]): The filter in dictionary form that needs to be applied
695
+ to the search. The dictionary should follow the structure
696
+ required by the `to_es` function which converts it
697
+ to an Opensearch query.
698
+
699
+ Returns:
700
+ Search: The modified Search object with the filter applied if a filter is provided,
701
+ otherwise the original Search object.
702
+ """
703
+ if _filter is not None:
704
+ es_query = filter_module.to_es(await self.get_queryables_mapping(), _filter)
705
+ search = search.filter(es_query)
706
+
707
+ return search
708
+
709
+ @staticmethod
710
+ def populate_sort(sortby: List) -> Optional[Dict[str, Dict[str, str]]]:
711
+ """Create a sort configuration for OpenSearch queries.
712
+
713
+ This method delegates to the shared implementation in populate_sort_shared.
714
+
715
+ Args:
716
+ sortby (List): A list of sort specifications, each containing a field and direction.
717
+
718
+ Returns:
719
+ Optional[Dict[str, Dict[str, str]]]: A dictionary mapping field names to sort direction
720
+ configurations, or None if no sort was specified.
721
+ """
722
+ return populate_sort_shared(sortby=sortby)
723
+
724
+ async def execute_search(
725
+ self,
726
+ search: Search,
727
+ limit: int,
728
+ token: Optional[str],
729
+ sort: Optional[Dict[str, Dict[str, str]]],
730
+ collection_ids: Optional[List[str]],
731
+ datetime_search: Dict[str, Optional[str]],
732
+ ignore_unavailable: bool = True,
733
+ ) -> Tuple[Iterable[Dict[str, Any]], Optional[int], Optional[str]]:
734
+ """Execute a search query with limit and other optional parameters.
735
+
736
+ Args:
737
+ search (Search): The search query to be executed.
738
+ limit (int): The maximum number of results to be returned.
739
+ token (Optional[str]): The token used to return the next set of results.
740
+ sort (Optional[Dict[str, Dict[str, str]]]): Specifies how the results should be sorted.
741
+ collection_ids (Optional[List[str]]): The collection ids to search.
742
+ datetime_search (Dict[str, Optional[str]]): Datetime range used for index selection.
743
+ ignore_unavailable (bool, optional): Whether to ignore unavailable collections. Defaults to True.
744
+
745
+ Returns:
746
+ Tuple[Iterable[Dict[str, Any]], Optional[int], Optional[str]]: A tuple containing:
747
+ - An iterable of search results, where each result is a dictionary with keys and values representing the
748
+ fields and values of each document.
749
+ - The total number of results (if the count could be computed), or None if the count could not be
750
+ computed.
751
+ - The token to be used to retrieve the next set of results, or None if there are no more results.
752
+
753
+ Raises:
754
+ NotFoundError: If the collections specified in `collection_ids` do not exist.
755
+ """
756
+ search_body: Dict[str, Any] = {}
757
+ query = search.query.to_dict() if search.query else None
758
+
759
+ index_param = await self.async_index_selector.select_indexes(
760
+ collection_ids, datetime_search
761
+ )
762
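+ # If the comma-joined list of selected indexes risks exceeding the OpenSearch URL length
+ # limit, fall back to searching all item indices and filter by collection ids in the body.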
+ if len(index_param) > ES_MAX_URL_LENGTH - 300:
763
+ index_param = ITEM_INDICES
764
+ query = add_collections_to_body(collection_ids, query)
765
+
766
+ if query:
767
+ search_body["query"] = query
768
+
769
+ search_after = None
770
+
771
+ if token:
772
+ search_after = orjson.loads(urlsafe_b64decode(token))
773
+ if search_after:
774
+ search_body["search_after"] = search_after
775
+
776
+ search_body["sort"] = sort if sort else DEFAULT_SORT
777
+
778
+ max_result_window = MAX_LIMIT
779
+
780
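+ # Request one extra hit beyond the page limit so we can tell whether a next page exists.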
+ size_limit = min(limit + 1, max_result_window)
781
+
782
+ search_task = asyncio.create_task(
783
+ self.client.search(
784
+ index=index_param,
785
+ ignore_unavailable=ignore_unavailable,
786
+ body=search_body,
787
+ size=size_limit,
788
+ )
789
+ )
790
+
791
+ count_task = asyncio.create_task(
792
+ self.client.count(
793
+ index=index_param,
794
+ ignore_unavailable=ignore_unavailable,
795
+ body=search.to_dict(count=True),
796
+ )
797
+ )
798
+
799
+ try:
800
+ es_response = await search_task
801
+ except exceptions.NotFoundError:
802
+ raise NotFoundError(f"Collections '{collection_ids}' do not exist")
803
+
804
+ hits = es_response["hits"]["hits"]
805
+ items = (hit["_source"] for hit in hits[:limit])
806
+
807
+ next_token = None
808
+ if len(hits) > limit and limit < max_result_window:
809
+ if hits and (sort_array := hits[limit - 1].get("sort")):
810
+ next_token = urlsafe_b64encode(orjson.dumps(sort_array)).decode()
811
+
812
+ matched = (
813
+ es_response["hits"]["total"]["value"]
814
+ if es_response["hits"]["total"]["relation"] == "eq"
815
+ else None
816
+ )
817
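+ # Prefer the exact count from the parallel count request when it has already finished.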
+ if count_task.done():
818
+ try:
819
+ matched = count_task.result().get("count")
820
+ except Exception as e:
821
+ logger.error(f"Count task failed: {e}")
822
+
823
+ return items, matched, next_token
824
+
825
+ """ AGGREGATE LOGIC """
826
+
827
+ async def aggregate(
828
+ self,
829
+ collection_ids: Optional[List[str]],
830
+ aggregations: List[str],
831
+ search: Search,
832
+ centroid_geohash_grid_precision: int,
833
+ centroid_geohex_grid_precision: int,
834
+ centroid_geotile_grid_precision: int,
835
+ geometry_geohash_grid_precision: int,
836
+ geometry_geotile_grid_precision: int,
837
+ datetime_frequency_interval: str,
838
+ datetime_search,
839
+ ignore_unavailable: Optional[bool] = True,
840
+ ):
841
+ """Return aggregations of STAC Items."""
842
+ search_body: Dict[str, Any] = {}
843
+ query = search.query.to_dict() if search.query else None
844
+ if query:
845
+ search_body["query"] = query
846
+
847
+ def _fill_aggregation_parameters(name: str, agg: dict) -> dict:
848
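+ # Each aggregation template is expected to have exactly one top-level key (its aggregation type).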
+ [key] = agg.keys()
849
+ agg_precision = {
850
+ "centroid_geohash_grid_frequency": centroid_geohash_grid_precision,
851
+ "centroid_geohex_grid_frequency": centroid_geohex_grid_precision,
852
+ "centroid_geotile_grid_frequency": centroid_geotile_grid_precision,
853
+ "geometry_geohash_grid_frequency": geometry_geohash_grid_precision,
854
+ "geometry_geotile_grid_frequency": geometry_geotile_grid_precision,
855
+ }
856
+ if name in agg_precision:
857
+ agg[key]["precision"] = agg_precision[name]
858
+
859
+ if key == "date_histogram":
860
+ agg[key]["calendar_interval"] = datetime_frequency_interval
861
+
862
+ return agg
863
+
864
+ # include all aggregations specified
865
+ # this will ignore aggregations with the wrong names
866
+ search_body["aggregations"] = {
867
+ k: _fill_aggregation_parameters(k, deepcopy(v))
868
+ for k, v in self.aggregation_mapping.items()
869
+ if k in aggregations
870
+ }
871
+
872
+ index_param = await self.async_index_selector.select_indexes(
873
+ collection_ids, datetime_search
874
+ )
875
+
876
+ search_task = asyncio.create_task(
877
+ self.client.search(
878
+ index=index_param,
879
+ ignore_unavailable=ignore_unavailable,
880
+ body=search_body,
881
+ )
882
+ )
883
+
884
+ try:
885
+ db_response = await search_task
886
+ except exceptions.NotFoundError:
887
+ raise NotFoundError(f"Collections '{collection_ids}' do not exist")
888
+
889
+ return db_response
890
+
891
+ """ TRANSACTION LOGIC """
892
+
893
+ async def check_collection_exists(self, collection_id: str):
894
+ """Database logic to check if a collection exists."""
895
+ if not await self.client.exists(index=COLLECTIONS_INDEX, id=collection_id):
896
+ raise NotFoundError(f"Collection {collection_id} does not exist")
897
+
898
+ async def async_prep_create_item(
899
+ self, item: Item, base_url: str, exist_ok: bool = False
900
+ ) -> Item:
901
+ """
902
+ Preps an item for insertion into the database.
903
+
904
+ Args:
905
+ item (Item): The item to be prepped for insertion.
906
+ base_url (str): The base URL used to create the item's self URL.
907
+ exist_ok (bool): Indicates whether the item can exist already.
908
+
909
+ Returns:
910
+ Item: The prepped item.
911
+
912
+ Raises:
913
+ ConflictError: If the item already exists in the database.
914
+
915
+ """
916
+ await self.check_collection_exists(collection_id=item["collection"])
917
+ alias = index_alias_by_collection_id(item["collection"])
918
+ doc_id = mk_item_id(item["id"], item["collection"])
919
+
920
+ if not exist_ok:
921
+ alias_exists = await self.client.indices.exists_alias(name=alias)
922
+
923
+ if alias_exists:
924
+ alias_info = await self.client.indices.get_alias(name=alias)
925
+ indices = list(alias_info.keys())
926
+
927
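+ # The alias may resolve to more than one underlying index; check each for an existing document id.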
+ for index in indices:
928
+ if await self.client.exists(index=index, id=doc_id):
929
+ raise ConflictError(
930
+ f"Item {item['id']} in collection {item['collection']} already exists"
931
+ )
932
+
933
+ return self.item_serializer.stac_to_db(item, base_url)
934
+
935
+ async def bulk_async_prep_create_item(
936
+ self, item: Item, base_url: str, exist_ok: bool = False
937
+ ) -> Item:
938
+ """
939
+ Prepare an item for insertion into the database.
940
+
941
+ This method performs pre-insertion preparation on the given `item`, such as:
942
+ - Verifying that the collection the item belongs to exists.
943
+ - Optionally checking if an item with the same ID already exists in the database.
944
+ - Serializing the item into a database-compatible format.
945
+
946
+ Args:
947
+ item (Item): The item to be prepared for insertion.
948
+ base_url (str): The base URL used to construct the item's self URL.
949
+ exist_ok (bool): Indicates whether the item can already exist in the database.
950
+ If False, a `ConflictError` is raised if the item exists.
951
+
952
+ Returns:
953
+ Item: The prepared item, serialized into a database-compatible format.
954
+
955
+ Raises:
956
+ NotFoundError: If the collection that the item belongs to does not exist in the database.
957
+ ConflictError: If an item with the same ID already exists in the collection and `exist_ok` is False,
958
+ and `RAISE_ON_BULK_ERROR` is set to `true`.
959
+ """
960
+ logger.debug(f"Preparing item {item['id']} in collection {item['collection']}.")
961
+
962
+ # Check if the collection exists
963
+ await self.check_collection_exists(collection_id=item["collection"])
964
+
965
+ # Check if the item already exists in the database
966
+ if not exist_ok and await self.client.exists(
967
+ index=index_alias_by_collection_id(item["collection"]),
968
+ id=mk_item_id(item["id"], item["collection"]),
969
+ ):
970
+ error_message = (
971
+ f"Item {item['id']} in collection {item['collection']} already exists."
972
+ )
973
+ if self.async_settings.raise_on_bulk_error:
974
+ raise ConflictError(error_message)
975
+ else:
976
+ logger.warning(
977
+ f"{error_message} Continuing as `RAISE_ON_BULK_ERROR` is set to false."
978
+ )
979
+ # Serialize the item into a database-compatible format
980
+ prepped_item = self.item_serializer.stac_to_db(item, base_url)
981
+ logger.debug(f"Item {item['id']} prepared successfully.")
982
+ return prepped_item
983
+
984
+ def bulk_sync_prep_create_item(
985
+ self, item: Item, base_url: str, exist_ok: bool = False
986
+ ) -> Item:
987
+ """
988
+ Prepare an item for insertion into the database.
989
+
990
+ This method performs pre-insertion preparation on the given `item`, such as:
991
+ - Verifying that the collection the item belongs to exists.
992
+ - Optionally checking if an item with the same ID already exists in the database.
993
+ - Serializing the item into a database-compatible format.
994
+
995
+ Args:
996
+ item (Item): The item to be prepared for insertion.
997
+ base_url (str): The base URL used to construct the item's self URL.
998
+ exist_ok (bool): Indicates whether the item can already exist in the database.
999
+ If False, a `ConflictError` is raised if the item exists.
1000
+
1001
+ Returns:
1002
+ Item: The prepared item, serialized into a database-compatible format.
1003
+
1004
+ Raises:
1005
+ NotFoundError: If the collection that the item belongs to does not exist in the database.
1006
+ ConflictError: If an item with the same ID already exists in the collection and `exist_ok` is False,
1007
+ and `RAISE_ON_BULK_ERROR` is set to `true`.
1008
+ """
1009
+ logger.debug(f"Preparing item {item['id']} in collection {item['collection']}.")
1010
+
1011
+ # Check if the collection exists
1012
+ if not self.sync_client.exists(index=COLLECTIONS_INDEX, id=item["collection"]):
1013
+ raise NotFoundError(f"Collection {item['collection']} does not exist")
1014
+
1015
+ # Check if the item already exists in the database
1016
+ if not exist_ok and self.sync_client.exists(
1017
+ index=index_alias_by_collection_id(item["collection"]),
1018
+ id=mk_item_id(item["id"], item["collection"]),
1019
+ ):
1020
+ error_message = (
1021
+ f"Item {item['id']} in collection {item['collection']} already exists."
1022
+ )
1023
+ if self.sync_settings.raise_on_bulk_error:
1024
+ raise ConflictError(error_message)
1025
+ else:
1026
+ logger.warning(
1027
+ f"{error_message} Continuing as `RAISE_ON_BULK_ERROR` is set to false."
1028
+ )
1029
+
1030
+ # Serialize the item into a database-compatible format
1031
+ prepped_item = self.item_serializer.stac_to_db(item, base_url)
1032
+ logger.debug(f"Item {item['id']} prepared successfully.")
1033
+ return prepped_item
1034
+
1035
+ async def create_item(
1036
+ self,
1037
+ item: Item,
1038
+ base_url: str = "",
1039
+ exist_ok: bool = False,
1040
+ **kwargs: Any,
1041
+ ):
1042
+ """Database logic for creating one item.
1043
+
1044
+ Args:
1045
+ item (Item): The item to be created.
1046
+ base_url (str, optional): The base URL for the item. Defaults to an empty string.
1047
+ exist_ok (bool, optional): Whether to allow the item to exist already. Defaults to False.
1048
+ **kwargs: Additional keyword arguments like refresh.
1049
+
1050
+ Raises:
1051
+ ConflictError: If the item already exists in the database.
1052
+
1053
+ Returns:
1054
+ None
1055
+ """
1056
+ # todo: check if collection exists, but cache
1057
+ item_id = item["id"]
1058
+ collection_id = item["collection"]
1059
+
1060
+ # Ensure kwargs is a dictionary
1061
+ kwargs = kwargs or {}
1062
+
1063
+ # Resolve the `refresh` parameter
1064
+ refresh = kwargs.get("refresh", self.async_settings.database_refresh)
1065
+ refresh = validate_refresh(refresh)
1066
+
1067
+ # Log the creation attempt
1068
+ logger.info(
1069
+ f"Creating item {item_id} in collection {collection_id} with refresh={refresh}"
1070
+ )
1071
+
1072
+ item = await self.async_prep_create_item(
1073
+ item=item, base_url=base_url, exist_ok=exist_ok
1074
+ )
1075
+
1076
+ target_index = await self.async_index_inserter.get_target_index(
1077
+ collection_id, item
1078
+ )
1079
+
1080
+ await self.client.index(
1081
+ index=target_index,
1082
+ id=mk_item_id(item_id, collection_id),
1083
+ body=item,
1084
+ refresh=refresh,
1085
+ )
1086
+
1087
+ async def merge_patch_item(
1088
+ self,
1089
+ collection_id: str,
1090
+ item_id: str,
1091
+ item: PartialItem,
1092
+ base_url: str,
1093
+ refresh: bool = True,
1094
+ ) -> Item:
1095
+ """Database logic for merge patching an item following RF7396.
1096
+
1097
+ Args:
1098
+ collection_id(str): Collection that item belongs to.
1099
+ item_id(str): Id of item to be patched.
1100
+ item (PartialItem): The partial item to be updated.
1101
+ base_url (str): The base URL used for constructing URLs for the item.
1102
+ refresh (bool, optional): Refresh the index after performing the operation. Defaults to True.
1103
+
1104
+ Returns:
1105
+ patched item.
1106
+ """
1107
+ operations = merge_to_operations(item.model_dump())
1108
+
1109
+ return await self.json_patch_item(
1110
+ collection_id=collection_id,
1111
+ item_id=item_id,
1112
+ operations=operations,
1113
+ base_url=base_url,
1114
+ create_nest=True,
1115
+ refresh=refresh,
1116
+ )
1117
+
1118
+ async def json_patch_item(
1119
+ self,
1120
+ collection_id: str,
1121
+ item_id: str,
1122
+ operations: List[PatchOperation],
1123
+ base_url: str,
1124
+ create_nest: bool = False,
1125
+ refresh: bool = True,
1126
+ ) -> Item:
1127
+ """Database logic for json patching an item following RF6902.
1128
+
1129
+ Args:
1130
+ collection_id(str): Collection that item belongs to.
1131
+ item_id(str): Id of item to be patched.
1132
+ operations (list): List of operations to run.
1133
+ base_url (str): The base URL used for constructing URLs for the item.
1134
+ refresh (bool, optional): Refresh the index after performing the operation. Defaults to True.
1135
+
1136
+ Returns:
1137
+ patched item.
1138
+ """
1139
+ new_item_id = None
1140
+ new_collection_id = None
1141
+ script_operations = []
1142
+
1143
+ for operation in operations:
1144
+ if operation.path in ["collection", "id"] and operation.op in [
1145
+ "add",
1146
+ "replace",
1147
+ ]:
1148
+ if operation.path == "collection" and collection_id != operation.value:
1149
+ await self.check_collection_exists(collection_id=operation.value)
1150
+ new_collection_id = operation.value
1151
+
1152
+ if operation.path == "id" and item_id != operation.value:
1153
+ new_item_id = operation.value
1154
+
1155
+ else:
1156
+ script_operations.append(operation)
1157
+
1158
+ script = operations_to_script(script_operations, create_nest=create_nest)
1159
+
1160
+ try:
1161
+ search_response = await self.client.search(
1162
+ index=index_alias_by_collection_id(collection_id),
1163
+ body={
1164
+ "query": {"term": {"_id": mk_item_id(item_id, collection_id)}},
1165
+ "size": 1,
1166
+ },
1167
+ )
1168
+ if search_response["hits"]["total"]["value"] == 0:
1169
+ raise NotFoundError(
1170
+ f"Item {item_id} does not exist inside Collection {collection_id}"
1171
+ )
1172
+ document_index = search_response["hits"]["hits"][0]["_index"]
1173
+ await self.client.update(
1174
+ index=document_index,
1175
+ id=mk_item_id(item_id, collection_id),
1176
+ body={"script": script},
1177
+ refresh=True,
1178
+ )
1179
+ except exceptions.NotFoundError:
1180
+ raise NotFoundError(
1181
+ f"Item {item_id} does not exist inside Collection {collection_id}"
1182
+ )
1183
+ except exceptions.RequestError as exc:
1184
+ raise HTTPException(
1185
+ status_code=400, detail=exc.info["error"]["caused_by"]
1186
+ ) from exc
1187
+
1188
+ item = await self.get_one_item(collection_id, item_id)
1189
+
1190
+ if new_collection_id:
1191
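+ # The item is moving to a different collection: reindex the document into the target
+ # collection's index, rewriting its _id and collection field, then delete the original below.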
+ await self.client.reindex(
1192
+ body={
1193
+ "dest": {"index": f"{ITEMS_INDEX_PREFIX}{new_collection_id}"},
1194
+ "source": {
1195
+ "index": f"{ITEMS_INDEX_PREFIX}{collection_id}",
1196
+ "query": {"term": {"id": {"value": item_id}}},
1197
+ },
1198
+ "script": {
1199
+ "lang": "painless",
1200
+ "source": (
1201
+ f"""ctx._id = ctx._id.replace('{collection_id}', '{new_collection_id}');"""
1202
+ f"""ctx._source.collection = '{new_collection_id}';"""
1203
+ ),
1204
+ },
1205
+ },
1206
+ wait_for_completion=True,
1207
+ refresh=True,
1208
+ )
1209
+
1210
+ await self.delete_item(
1211
+ item_id=item_id,
1212
+ collection_id=collection_id,
1213
+ refresh=refresh,
1214
+ )
1215
+
1216
+ item["collection"] = new_collection_id
1217
+ collection_id = new_collection_id
1218
+
1219
+ if new_item_id:
1220
+ item["id"] = new_item_id
1221
+ item = await self.async_prep_create_item(item=item, base_url=base_url)
1222
+ await self.create_item(item=item, refresh=True)
1223
+
1224
+ await self.delete_item(
1225
+ item_id=item_id,
1226
+ collection_id=collection_id,
1227
+ refresh=refresh,
1228
+ )
1229
+
1230
+ return item
1231
+
1232
+ async def delete_item(self, item_id: str, collection_id: str, **kwargs: Any):
1233
+ """Delete a single item from the database.
1234
+
1235
+ Args:
1236
+ item_id (str): The id of the Item to be deleted.
1237
+ collection_id (str): The id of the Collection that the Item belongs to.
1238
+ **kwargs: Additional keyword arguments like refresh.
1239
+
1240
+ Raises:
1241
+ NotFoundError: If the Item does not exist in the database.
1242
+ """
1243
+ # Ensure kwargs is a dictionary
1244
+ kwargs = kwargs or {}
1245
+
1246
+ # Resolve the `refresh` parameter
1247
+ refresh = kwargs.get("refresh", self.async_settings.database_refresh)
1248
+ refresh = validate_refresh(refresh)
1249
+
1250
+ # Log the deletion attempt
1251
+ logger.info(
1252
+ f"Deleting item {item_id} from collection {collection_id} with refresh={refresh}"
1253
+ )
1254
+
1255
+ try:
1256
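+ # Deleting through the collection alias removes the item from whichever underlying index holds it.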
+ await self.client.delete_by_query(
1257
+ index=index_alias_by_collection_id(collection_id),
1258
+ body={"query": {"term": {"_id": mk_item_id(item_id, collection_id)}}},
1259
+ refresh=refresh,
1260
+ )
1261
+ except exceptions.NotFoundError:
1262
+ raise NotFoundError(
1263
+ f"Item {item_id} in collection {collection_id} not found"
1264
+ )
1265
+
1266
+ async def get_items_mapping(self, collection_id: str) -> Dict[str, Any]:
1267
+ """Get the mapping for the specified collection's items index.
1268
+
1269
+ Args:
1270
+ collection_id (str): The ID of the collection to get items mapping for.
1271
+
1272
+ Returns:
1273
+ Dict[str, Any]: The mapping information.
1274
+ """
1275
+ index_name = index_alias_by_collection_id(collection_id)
1276
+ try:
1277
+ mapping = await self.client.indices.get_mapping(
1278
+ index=index_name, params={"allow_no_indices": "false"}
1279
+ )
1280
+ return mapping
1281
+ except exceptions.NotFoundError:
1282
+ raise NotFoundError(f"Mapping for index {index_name} not found")
1283
+
1284
+ async def get_items_unique_values(
1285
+ self, collection_id: str, field_names: Iterable[str], *, limit: int = 100
1286
+ ) -> Dict[str, List[str]]:
1287
+ """Get the unique values for the given fields in the collection."""
1288
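+ # Ask for one more bucket than the limit so fields that exceed it can be detected and skipped.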
+ limit_plus_one = limit + 1
1289
+ index_name = index_alias_by_collection_id(collection_id)
1290
+
1291
+ query = await self.client.search(
1292
+ index=index_name,
1293
+ body={
1294
+ "size": 0,
1295
+ "aggs": {
1296
+ field: {"terms": {"field": field, "size": limit_plus_one}}
1297
+ for field in field_names
1298
+ },
1299
+ },
1300
+ )
1301
+
1302
+ result: Dict[str, List[str]] = {}
1303
+ for field, agg in query["aggregations"].items():
1304
+ if len(agg["buckets"]) > limit:
1305
+ logger.warning(
1306
+ "Skipping enum field %s: exceeds limit of %d unique values. "
1307
+ "Consider excluding this field from enumeration or increase the limit.",
1308
+ field,
1309
+ limit,
1310
+ )
1311
+ continue
1312
+ result[field] = [bucket["key"] for bucket in agg["buckets"]]
1313
+ return result
1314
+
1315
+ async def create_collection(self, collection: Collection, **kwargs: Any):
1316
+ """Create a single collection in the database.
1317
+
1318
+ Args:
1319
+ collection (Collection): The Collection object to be created.
1320
+ **kwargs: Additional keyword arguments like refresh.
1321
+
1322
+ Raises:
1323
+ ConflictError: If a Collection with the same id already exists in the database.
1324
+
1325
+ Notes:
1326
+ A new index is created for the items in the Collection if the index insertion strategy requires it.
1327
+ """
1328
+ collection_id = collection["id"]
1329
+
1330
+ # Ensure kwargs is a dictionary
1331
+ kwargs = kwargs or {}
1332
+
1333
+ # Resolve the `refresh` parameter
1334
+ refresh = kwargs.get("refresh", self.async_settings.database_refresh)
1335
+ refresh = validate_refresh(refresh)
1336
+
1337
+ # Log the creation attempt
1338
+ logger.info(f"Creating collection {collection_id} with refresh={refresh}")
1339
+
1340
+ if await self.client.exists(index=COLLECTIONS_INDEX, id=collection_id):
1341
+ raise ConflictError(f"Collection {collection_id} already exists")
1342
+
1343
+ if get_bool_env("ENABLE_COLLECTIONS_SEARCH") or get_bool_env(
1344
+ "ENABLE_COLLECTIONS_SEARCH_ROUTE"
1345
+ ):
1346
+ # Convert bbox to bbox_shape for geospatial queries (ES/OS specific)
1347
+ add_bbox_shape_to_collection(collection)
1348
+
1349
+ await self.client.index(
1350
+ index=COLLECTIONS_INDEX,
1351
+ id=collection_id,
1352
+ body=collection,
1353
+ refresh=refresh,
1354
+ )
1355
+ if self.async_index_inserter.should_create_collection_index():
1356
+ await self.async_index_inserter.create_simple_index(
1357
+ self.client, collection_id
1358
+ )
1359
+
1360
+ async def find_collection(self, collection_id: str) -> Collection:
1361
+ """Find and return a collection from the database.
1362
+
1363
+ Args:
1364
+ self: The instance of the object calling this function.
1365
+ collection_id (str): The ID of the collection to be found.
1366
+
1367
+ Returns:
1368
+ Collection: The found collection, represented as a `Collection` object.
1369
+
1370
+ Raises:
1371
+ NotFoundError: If the collection with the given `collection_id` is not found in the database.
1372
+
1373
+ Notes:
1374
+ This function searches for a collection in the database using the specified `collection_id` and returns the found
1375
+ collection as a `Collection` object. If the collection is not found, a `NotFoundError` is raised.
1376
+ """
1377
+ try:
1378
+ collection = await self.client.get(
1379
+ index=COLLECTIONS_INDEX, id=collection_id
1380
+ )
1381
+ except exceptions.NotFoundError:
1382
+ raise NotFoundError(f"Collection {collection_id} not found")
1383
+
1384
+ return collection["_source"]
1385
+
1386
+ async def update_collection(
1387
+ self, collection_id: str, collection: Collection, **kwargs: Any
1388
+ ):
1389
+ """Update a collection from the database.
1390
+
1391
+ Args:
1392
+ collection_id (str): The ID of the collection to be updated.
1393
+ collection (Collection): The Collection object to be used for the update.
1394
+ **kwargs: Additional keyword arguments like refresh.
1395
+
1396
+ Raises:
1397
+ NotFoundError: If the collection with the given `collection_id` is not
1398
+ found in the database.
1399
+
1400
+ Notes:
1401
+ This function updates the collection in the database using the specified
1402
+ `collection_id` and with the collection specified in the `Collection` object.
1403
+ If the collection is not found, a `NotFoundError` is raised.
1404
+ """
1405
+ # Ensure kwargs is a dictionary
1406
+ kwargs = kwargs or {}
1407
+
1408
+ # Resolve the `refresh` parameter
1409
+ refresh = kwargs.get("refresh", self.async_settings.database_refresh)
1410
+ refresh = validate_refresh(refresh)
1411
+
1412
+ # Log the update attempt
1413
+ logger.info(f"Updating collection {collection_id} with refresh={refresh}")
1414
+
1415
+ await self.find_collection(collection_id=collection_id)
1416
+
1417
+ if collection_id != collection["id"]:
1418
+ logger.info(
1419
+ f"Collection ID change detected: {collection_id} -> {collection['id']}"
1420
+ )
1421
+
1422
+ await self.create_collection(collection, refresh=refresh)
1423
+
1424
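+ # The collection id changed: copy every item into the new collection's index, rewriting
+ # item ids and the collection field, before deleting the old collection below.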
+ await self.client.reindex(
1425
+ body={
1426
+ "dest": {"index": f"{ITEMS_INDEX_PREFIX}{collection['id']}"},
1427
+ "source": {"index": f"{ITEMS_INDEX_PREFIX}{collection_id}"},
1428
+ "script": {
1429
+ "lang": "painless",
1430
+ "source": f"""ctx._id = ctx._id.replace('{collection_id}', '{collection["id"]}'); ctx._source.collection = '{collection["id"]}' ;""",
1431
+ },
1432
+ },
1433
+ wait_for_completion=True,
1434
+ refresh=refresh,
1435
+ )
1436
+
1437
+ await self.delete_collection(collection_id=collection_id, **kwargs)
1438
+
1439
+ else:
1440
+ if get_bool_env("ENABLE_COLLECTIONS_SEARCH") or get_bool_env(
1441
+ "ENABLE_COLLECTIONS_SEARCH_ROUTE"
1442
+ ):
1443
+ # Convert bbox to bbox_shape for geospatial queries (ES/OS specific)
1444
+ add_bbox_shape_to_collection(collection)
1445
+
1446
+ await self.client.index(
1447
+ index=COLLECTIONS_INDEX,
1448
+ id=collection_id,
1449
+ body=collection,
1450
+ refresh=refresh,
1451
+ )
1452
+
1453
+ async def merge_patch_collection(
1454
+ self,
1455
+ collection_id: str,
1456
+ collection: PartialCollection,
1457
+ base_url: str,
1458
+ refresh: bool = True,
1459
+ ) -> Collection:
1460
+ """Database logic for merge patching a collection following RF7396.
1461
+
1462
+ Args:
1463
+ collection_id(str): Id of collection to be patched.
1464
+ collection (PartialCollection): The partial collection to be updated.
1465
+ base_url (str): The base URL used for constructing links.
1466
+ refresh (bool, optional): Refresh the index after performing the operation. Defaults to True.
1467
+
1468
+
1469
+ Returns:
1470
+ patched collection.
1471
+ """
1472
+ operations = merge_to_operations(collection.model_dump())
1473
+
1474
+ return await self.json_patch_collection(
1475
+ collection_id=collection_id,
1476
+ operations=operations,
1477
+ base_url=base_url,
1478
+ create_nest=True,
1479
+ refresh=refresh,
1480
+ )
1481
+
1482
+ async def json_patch_collection(
1483
+ self,
1484
+ collection_id: str,
1485
+ operations: List[PatchOperation],
1486
+ base_url: str,
1487
+ create_nest: bool = False,
1488
+ refresh: bool = True,
1489
+ ) -> Collection:
1490
+ """Database logic for json patching a collection following RF6902.
1491
+
1492
+ Args:
1493
+ collection_id(str): Id of collection to be patched.
1494
+ operations (list): List of operations to run.
1495
+ base_url (str): The base URL used for constructing links.
1496
+ refresh (bool, optional): Refresh the index after performing the operation. Defaults to True.
1497
+
1498
+ Returns:
1499
+ patched collection.
1500
+ """
1501
+ new_collection_id = None
1502
+ script_operations = []
1503
+
1504
+ for operation in operations:
1505
+ if (
1506
+ operation.op in ["add", "replace"]
1507
+ and operation.path == "collection"
1508
+ and collection_id != operation.value
1509
+ ):
1510
+ new_collection_id = operation.value
1511
+
1512
+ else:
1513
+ script_operations.append(operation)
1514
+
1515
+ script = operations_to_script(script_operations, create_nest=create_nest)
1516
+
1517
+ try:
1518
+ await self.client.update(
1519
+ index=COLLECTIONS_INDEX,
1520
+ id=collection_id,
1521
+ body={"script": script},
1522
+ refresh=True,
1523
+ )
1524
+
1525
+ except exceptions.RequestError as exc:
1526
+ raise HTTPException(
1527
+ status_code=400, detail=exc.info["error"]["caused_by"]
1528
+ ) from exc
1529
+
1530
+ collection = await self.find_collection(collection_id)
1531
+
1532
+ if new_collection_id:
1533
+ collection["id"] = new_collection_id
1534
+ collection["links"] = resolve_links([], base_url)
1535
+
1536
+ await self.update_collection(
1537
+ collection_id=collection_id,
1538
+ collection=collection,
1539
+ refresh=refresh,
1540
+ )
1541
+
1542
+ return collection
1543
+
1544
+ async def delete_collection(self, collection_id: str, **kwargs: Any):
1545
+ """Delete a collection from the database.
1546
+
1547
+ Args:
1548
+ self: The instance of the object calling this function.
1549
+ collection_id (str): The ID of the collection to be deleted.
1550
+ **kwargs: Additional keyword arguments like refresh.
1551
+
1552
+ Raises:
1553
+ NotFoundError: If the collection with the given `collection_id` is not found in the database.
1554
+
1555
+ Notes:
1556
+ This function first verifies that the collection with the specified `collection_id` exists in the database, and then
1557
+ deletes the collection. If `refresh` is set to "true", "false", or "wait_for", the index is refreshed accordingly after
1558
+ the deletion. Additionally, this function also calls `delete_item_index` to delete the index for the items in the collection.
1559
+ """
1560
+ # Ensure kwargs is a dictionary
1561
+ kwargs = kwargs or {}
1562
+
1563
+ await self.find_collection(collection_id=collection_id)
1564
+
1565
+ # Resolve the `refresh` parameter
1566
+ refresh = kwargs.get("refresh", self.async_settings.database_refresh)
1567
+ refresh = validate_refresh(refresh)
1568
+
1569
+ # Log the deletion attempt
1570
+ logger.info(f"Deleting collection {collection_id} with refresh={refresh}")
1571
+
1572
+ await self.client.delete(
1573
+ index=COLLECTIONS_INDEX, id=collection_id, refresh=refresh
1574
+ )
1575
+ # Delete the item index for the collection
1576
+ await delete_item_index(collection_id)
1577
+
1578
+ async def bulk_async(
1579
+ self,
1580
+ collection_id: str,
1581
+ processed_items: List[Item],
1582
+ **kwargs: Any,
1583
+ ) -> Tuple[int, List[Dict[str, Any]]]:
1584
+ """
1585
+ Perform a bulk insert of items into the database asynchronously.
1586
+
1587
+ Args:
1588
+ collection_id (str): The ID of the collection to which the items belong.
1589
+ processed_items (List[Item]): A list of `Item` objects to be inserted into the database.
1590
+ **kwargs (Any): Additional keyword arguments, including:
1591
+ - refresh (str, optional): Whether to refresh the index after the bulk insert.
1592
+ Can be "true", "false", or "wait_for". Defaults to the value of `self.async_settings.database_refresh`.
1594
+ - raise_on_error (bool, optional): Whether to raise an error if any of the bulk operations fail.
1595
+ Defaults to the value of `self.async_settings.raise_on_bulk_error`.
1596
+
1597
+ Returns:
1598
+ Tuple[int, List[Dict[str, Any]]]: A tuple containing:
1599
+ - The number of successfully processed actions (`success`).
1600
+ - A list of errors encountered during the bulk operation (`errors`).
1601
+
1602
+ Notes:
1603
+ This function performs a bulk insert of `processed_items` into the database using the specified `collection_id`.
1604
+ The insert is performed asynchronously, and the coroutine completes only once the bulk request has
1605
+ finished. The index insertion strategy prepares the list of actions for the bulk insert. The `refresh`
1606
+ parameter determines whether the index is refreshed after the bulk insert:
1607
+ - "true": Forces an immediate refresh of the index.
1608
+ - "false": Does not refresh the index immediately (default behavior).
1609
+ - "wait_for": Waits for the next refresh cycle to make the changes visible.
1610
+ """
1611
+ # Ensure kwargs is a dictionary
1612
+ kwargs = kwargs or {}
1613
+
1614
+ # Resolve the `refresh` parameter
1615
+ refresh = kwargs.get("refresh", self.async_settings.database_refresh)
1616
+ refresh = validate_refresh(refresh)
1617
+
1618
+ # Log the bulk insert attempt
1619
+ logger.info(
1620
+ f"Performing bulk insert for collection {collection_id} with refresh={refresh}"
1621
+ )
1622
+
1623
+ # Handle empty processed_items
1624
+ if not processed_items:
1625
+ logger.warning(f"No items to insert for collection {collection_id}")
1626
+ return 0, []
1627
+
1628
+ raise_on_error = self.async_settings.raise_on_bulk_error
1629
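+ # The configured insertion strategy builds the bulk actions and determines the target index for each item.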
+ actions = await self.async_index_inserter.prepare_bulk_actions(
1630
+ collection_id, processed_items
1631
+ )
1632
+
1633
+ success, errors = await helpers.async_bulk(
1634
+ self.client,
1635
+ actions,
1636
+ refresh=refresh,
1637
+ raise_on_error=raise_on_error,
1638
+ )
1639
+ # Log the result
1640
+ logger.info(
1641
+ f"Bulk insert completed for collection {collection_id}: {success} successes, {len(errors)} errors"
1642
+ )
1643
+ return success, errors
1644
+
1645
+ def bulk_sync(
1646
+ self,
1647
+ collection_id: str,
1648
+ processed_items: List[Item],
1649
+ **kwargs: Any,
1650
+ ) -> Tuple[int, List[Dict[str, Any]]]:
1651
+ """
1652
+ Perform a bulk insert of items into the database synchronously.
1653
+
1654
+ Args:
1655
+ collection_id (str): The ID of the collection to which the items belong.
1656
+ processed_items (List[Item]): A list of `Item` objects to be inserted into the database.
1657
+ **kwargs (Any): Additional keyword arguments, including:
1658
+ - refresh (str, optional): Whether to refresh the index after the bulk insert.
1659
+ Can be "true", "false", or "wait_for". Defaults to the value of `self.sync_settings.database_refresh`.
1661
+ - raise_on_error (bool, optional): Whether to raise an error if any of the bulk operations fail.
1662
+ Defaults to the value of `self.sync_settings.raise_on_bulk_error`.
1663
+
1664
+ Returns:
1665
+ Tuple[int, List[Dict[str, Any]]]: A tuple containing:
1666
+ - The number of successfully processed actions (`success`).
1667
+ - A list of errors encountered during the bulk operation (`errors`).
1668
+
1669
+ Notes:
1670
+ This function performs a bulk insert of `processed_items` into the database using the specified `collection_id`.
1671
+ The insert is performed synchronously and blocking, meaning that the function does not return until the insert has
1672
+ completed. The `mk_actions` function is called to generate a list of actions for the bulk insert. The `refresh`
1673
+ parameter determines whether the index is refreshed after the bulk insert:
1674
+ - "true": Forces an immediate refresh of the index.
1675
+ - "false": Does not refresh the index immediately (default behavior).
1676
+ - "wait_for": Waits for the next refresh cycle to make the changes visible.
1677
+ """
1678
+ # Ensure kwargs is a dictionary
1679
+ kwargs = kwargs or {}
1680
+
1681
+ # Resolve the `refresh` parameter
1682
+ refresh = kwargs.get("refresh", self.async_settings.database_refresh)
1683
+ refresh = validate_refresh(refresh)
1684
+
1685
+ # Log the bulk insert attempt
1686
+ logger.info(
1687
+ f"Performing bulk insert for collection {collection_id} with refresh={refresh}"
1688
+ )
1689
+
1690
+ # Handle empty processed_items
1691
+ if not processed_items:
1692
+ logger.warning(f"No items to insert for collection {collection_id}")
1693
+ return 0, []
1699
+
1700
+ raise_on_error = self.sync_settings.raise_on_bulk_error
1701
+ success, errors = helpers.bulk(
1702
+ self.sync_client,
1703
+ mk_actions(collection_id, processed_items),
1704
+ refresh=refresh,
1705
+ raise_on_error=raise_on_error,
1706
+ )
1707
+ return success, errors
1708
+
1709
+ # DANGER
1710
+ async def delete_items(self) -> None:
1711
+ """Danger. this is only for tests."""
1712
+ await self.client.delete_by_query(
1713
+ index=ITEM_INDICES,
1714
+ body={"query": {"match_all": {}}},
1715
+ wait_for_completion=True,
1716
+ )
1717
+
1718
+ # DANGER
1719
+ async def delete_collections(self) -> None:
1720
+ """Danger. this is only for tests."""
1721
+ await self.client.delete_by_query(
1722
+ index=COLLECTIONS_INDEX,
1723
+ body={"query": {"match_all": {}}},
1724
+ wait_for_completion=True,
1725
+ )