stac-fastapi-core 4.0.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,202 @@
1
+ """Filter extension logic for es conversion."""
2
+
3
+ # """
4
+ # Implements Filter Extension.
5
+
6
+ # Basic CQL2 (AND, OR, NOT), comparison operators (=, <>, <, <=, >, >=), and IS NULL.
7
+ # The comparison operators are allowed against string, numeric, boolean, date, and datetime types.
8
+
9
+ # Advanced comparison operators (http://www.opengis.net/spec/cql2/1.0/req/advanced-comparison-operators)
10
+ # defines the LIKE, IN, and BETWEEN operators.
11
+
12
+ # Basic Spatial Operators (http://www.opengis.net/spec/cql2/1.0/conf/basic-spatial-operators)
13
+ # defines the intersects operator (S_INTERSECTS).
14
+ # """
15
+
16
+ import re
17
+ from enum import Enum
18
+ from typing import Any, Dict
19
+
20
# Tokens of a CQL2 LIKE pattern that need translating:
#   * a backslash escape (backslash followed by any character),
#   * a bare CQL2 wildcard (% or _),
#   * a bare Elasticsearch wildcard metacharacter (* or ?), which is an
#     ordinary character in CQL2 and therefore has to be escaped,
#   * a dangling backslash at the very end of the pattern.
_cql2_like_patterns = re.compile(r"\\.|[%_*?]|\\$")
_valid_like_substitutions = {
    # A literal backslash must stay escaped: "\" is also the escape character
    # of the Elasticsearch (Lucene) wildcard syntax, so emitting a single "\"
    # would escape whatever character follows it.
    "\\\\": "\\\\",
    # Escaped CQL2 wildcards are literal characters with no special meaning
    # in the Elasticsearch wildcard syntax.
    "\\%": "%",
    "\\_": "_",
    # CQL2 wildcards -> Elasticsearch wildcards.
    "%": "*",
    "_": "?",
    # Literal characters in CQL2 that would otherwise act as wildcards.
    "*": "\\*",
    "?": "\\?",
}


def _replace_like_patterns(match: re.Match) -> str:
    """Return the Elasticsearch replacement for one matched LIKE token.

    Args:
        match (re.Match): A match produced by ``_cql2_like_patterns``.

    Returns:
        str: The replacement text for the matched token.

    Raises:
        ValueError: If the token is a backslash escape that is not listed in
            ``_valid_like_substitutions`` (including a trailing backslash).
    """
    pattern = match.group()
    try:
        return _valid_like_substitutions[pattern]
    except KeyError:
        # The KeyError is an implementation detail of the table lookup.
        raise ValueError(f"'{pattern}' is not a valid escape sequence") from None


def cql2_like_to_es(string: str) -> str:
    """
    Convert CQL2 "LIKE" characters to Elasticsearch "wildcard" characters.

    ``%`` becomes ``*`` and ``_`` becomes ``?``. The escapes ``\\%`` and
    ``\\_`` become the literal characters, ``\\\\`` stays an escaped
    backslash, and literal ``*`` / ``?`` are backslash-escaped so they are
    not interpreted as wildcards.

    Args:
        string (str): The string containing CQL2 wildcard characters.

    Returns:
        str: The converted string with Elasticsearch compatible wildcards.

    Raises:
        ValueError: If an invalid escape sequence is encountered.
    """
    return _cql2_like_patterns.sub(
        repl=_replace_like_patterns,
        string=string,
    )
55
+
56
+
57
class LogicalOp(str, Enum):
    """Logical connectives of a filter expression.

    Members are strings, so they compare equal to the raw operator names.
    """

    AND = "and"
    OR = "or"
    NOT = "not"
63
+
64
+
65
class ComparisonOp(str, Enum):
    """Basic comparison operators of a filter expression.

    Members are strings, so they compare equal to the raw operator symbols.
    """

    # Equality / inequality.
    EQ = "="
    NEQ = "<>"
    # Ordering comparisons.
    LT = "<"
    LTE = "<="
    GT = ">"
    GTE = ">="
    # Null test.
    IS_NULL = "isNull"
75
+
76
+
77
class AdvancedComparisonOp(str, Enum):
    """Advanced comparison operators: pattern match, range and membership.

    Members are strings, so they compare equal to the raw operator names.
    """

    LIKE = "like"
    BETWEEN = "between"
    IN = "in"
83
+
84
+
85
class SpatialIntersectsOp(str, Enum):
    """Spatial operator of a filter expression (only intersection is defined).

    The single member is a string equal to the raw operator name.
    """

    S_INTERSECTS = "s_intersects"
89
+
90
+
91
# Short queryable name -> path of the field inside the indexed document.
queryables_mapping = {
    # Top-level fields keep their name.
    "id": "id",
    "collection": "collection",
    "geometry": "geometry",
    # Fields stored under "properties".
    "datetime": "properties.datetime",
    "created": "properties.created",
    "updated": "properties.updated",
    # Short aliases for namespaced properties.
    "cloud_cover": "properties.eo:cloud_cover",
    "cloud_shadow_percentage": "properties.s2:cloud_shadow_percentage",
    "nodata_pixel_percentage": "properties.s2:nodata_pixel_percentage",
}


def to_es_field(field: str) -> str:
    """
    Translate a queryable name into the field path used in Elasticsearch queries.

    Args:
        field (str): The field name from a user query or filter.

    Returns:
        str: The entry from ``queryables_mapping`` when the name is known,
            otherwise ``field`` unchanged.
    """
    if field in queryables_mapping:
        return queryables_mapping[field]
    return field
115
+
116
+
117
def to_es(query: Dict[str, Any]) -> Dict[str, Any]:
    """
    Translate a CQL2 JSON expression into an Elasticsearch query DSL dictionary.

    Logical operators recurse into their operands; every other supported
    operator reads the property name from ``args[0]["property"]`` and its
    operand(s) from the remaining ``args``.

    Args:
        query (Dict[str, Any]): The query dictionary containing 'op' and 'args'.

    Returns:
        Dict[str, Any]: The corresponding Elasticsearch query, or an empty
            dictionary when the operator is not one of the supported ones.

    Raises:
        ValueError: If the operand of an "in" expression is not a list, or a
            "like" pattern contains an invalid escape sequence.
    """
    op = query["op"]

    # and / or / not -> bool query wrapping the translated operands.
    if op in [LogicalOp.AND, LogicalOp.OR, LogicalOp.NOT]:
        occurrence = {
            LogicalOp.AND: "must",
            LogicalOp.OR: "should",
            LogicalOp.NOT: "must_not",
        }[op]
        return {"bool": {occurrence: [to_es(operand) for operand in query["args"]]}}

    # =, <>, <, <=, >, >=
    if op in [
        ComparisonOp.EQ,
        ComparisonOp.NEQ,
        ComparisonOp.LT,
        ComparisonOp.LTE,
        ComparisonOp.GT,
        ComparisonOp.GTE,
    ]:
        field = to_es_field(query["args"][0]["property"])
        operand = query["args"][1]

        # {"timestamp": ...} operands are unwrapped; equality on them is
        # expressed as a closed range instead of a term match.
        is_timestamp = isinstance(operand, dict) and "timestamp" in operand
        if is_timestamp:
            operand = operand["timestamp"]

        if op == ComparisonOp.EQ or op == ComparisonOp.NEQ:
            if is_timestamp:
                equality = {"range": {field: {"gte": operand, "lte": operand}}}
            else:
                equality = {"term": {field: operand}}
            if op == ComparisonOp.EQ:
                return equality
            return {"bool": {"must_not": [equality]}}

        bound = {
            ComparisonOp.LT: "lt",
            ComparisonOp.LTE: "lte",
            ComparisonOp.GT: "gt",
            ComparisonOp.GTE: "gte",
        }[op]
        return {"range": {field: {bound: operand}}}

    # isNull -> the field must not exist.
    if op == ComparisonOp.IS_NULL:
        field = to_es_field(query["args"][0]["property"])
        return {"bool": {"must_not": {"exists": {"field": field}}}}

    # between -> inclusive range; timestamp operands are unwrapped.
    if op == AdvancedComparisonOp.BETWEEN:
        field = to_es_field(query["args"][0]["property"])
        lower, upper = query["args"][1], query["args"][2]
        if isinstance(lower, dict) and "timestamp" in lower:
            lower = lower["timestamp"]
        if isinstance(upper, dict) and "timestamp" in upper:
            upper = upper["timestamp"]
        return {"range": {field: {"gte": lower, "lte": upper}}}

    # in -> terms query over the listed values.
    if op == AdvancedComparisonOp.IN:
        field = to_es_field(query["args"][0]["property"])
        candidates = query["args"][1]
        if not isinstance(candidates, list):
            raise ValueError(f"Arg {candidates} is not a list")
        return {"terms": {field: candidates}}

    # like -> case-insensitive wildcard query.
    if op == AdvancedComparisonOp.LIKE:
        field = to_es_field(query["args"][0]["property"])
        wildcard = cql2_like_to_es(query["args"][1])
        return {"wildcard": {field: {"value": wildcard, "case_insensitive": True}}}

    # s_intersects -> geo_shape query with the "intersects" relation.
    if op == SpatialIntersectsOp.S_INTERSECTS:
        field = to_es_field(query["args"][0]["property"])
        shape = query["args"][1]
        return {"geo_shape": {field: {"shape": shape, "relation": "intersects"}}}

    # Any other operator produces an empty query.
    return {}
@@ -0,0 +1,79 @@
1
+ """STAC SQLAlchemy specific query search model.
2
+
3
+ # TODO: replace with stac-pydantic
4
+ """
5
+
6
+ import logging
7
+ import operator
8
+ from dataclasses import dataclass
9
+ from enum import auto
10
+ from types import DynamicClassAttribute
11
+ from typing import Any, Callable, Dict, Optional
12
+
13
+ from pydantic import BaseModel, root_validator
14
+ from stac_pydantic.utils import AutoValueEnum
15
+
16
+ from stac_fastapi.extensions.core.query import QueryExtension as QueryExtensionBase
17
+
18
# NOTE(review): this looks up the logger named "uvicorn" and sets its level
# to INFO as an import-time side effect of this module -- confirm that
# changing the shared logger's level here is intended.
logger = logging.getLogger("uvicorn")
logger.setLevel(logging.INFO)
20
+
21
+
22
class Operator(str, AutoValueEnum):
    """Defines the set of operators supported by the API."""

    eq = auto()
    ne = auto()
    lt = auto()
    lte = auto()
    gt = auto()
    gte = auto()

    # TODO: These are defined in the spec but aren't currently implemented by the api
    # startsWith = auto()
    # endsWith = auto()
    # contains = auto()
    # in = auto()

    @DynamicClassAttribute
    def operator(self) -> Callable[[Any, Any], bool]:
        """Return the function of the stdlib ``operator`` module for this member.

        Assumes each member's value is its own name (e.g. ``"lte"``), which
        is what the ``auto()`` values are expected to produce -- TODO confirm
        against ``AutoValueEnum``.
        """
        # The API spells these operators "lte"/"gte", but the standard
        # library names them "le"/"ge"; a direct getattr would raise
        # AttributeError for those two members.
        stdlib_name = {"lte": "le", "gte": "ge"}.get(self._value_, self._value_)
        return getattr(operator, stdlib_name)
42
+
43
+
44
class Queryables(str, AutoValueEnum):
    """Queryable fields.

    Placeholder enumeration: no members are declared here.
    """
48
+
49
+
50
@dataclass
class QueryableTypes:
    """Defines a set of queryable fields.

    Placeholder dataclass: no fields are declared here.
    """
55
+
56
+
57
class QueryExtensionPostRequest(BaseModel):
    """Queryable validation.

    Add queryables validation to the POST request
    to raise errors for unsupported queries.
    """

    # Maps a field name to a mapping of operator -> operand.
    query: Optional[Dict[str, Dict[Operator, Any]]] = None

    @root_validator(pre=True)
    def validate_query_fields(cls, values: Dict) -> Dict:
        """Validate query fields.

        No per-field checks are implemented yet, so the incoming values are
        passed through unchanged.

        Args:
            values (Dict): The raw input data of the model.

        Returns:
            Dict: The same data, handed on to field validation.
        """
        # A pre root validator must hand the data back; a bare `...` body
        # returns None, which replaces the input and breaks validation.
        return values
70
+
71
+
72
class QueryExtension(QueryExtensionBase):
    """Query Extension.

    Replaces the POST request model of the base extension with
    ``QueryExtensionPostRequest`` so that POST bodies are validated
    against the supported fields.
    """

    POST = QueryExtensionPostRequest
@@ -0,0 +1 @@
1
+ """stac_fastapi.elasticsearch.models module."""
@@ -0,0 +1,205 @@
1
+ """link helpers."""
2
+
3
+ from typing import Any, Dict, List, Optional
4
+ from urllib.parse import ParseResult, parse_qs, urlencode, urljoin, urlparse
5
+
6
+ import attr
7
+ from stac_pydantic.links import Relations
8
+ from stac_pydantic.shared import MimeTypes
9
+ from starlette.requests import Request
10
+
11
+ # Copied from pgstac links
12
+
13
+ # These can be inferred from the item/collection, so they aren't included in the database
14
+ # Instead they are dynamically generated when querying the database using the classes defined below
15
INFERRED_LINK_RELS = {
    "self",
    "item",
    "parent",
    "collection",
    "root",
}
16
+
17
+
18
def merge_params(url: str, newparams: Dict) -> str:
    """Return ``url`` with ``newparams`` merged into its query string.

    Parameters already present in the URL are kept unless ``newparams``
    supplies the same key, in which case the new value replaces them.
    Every other URL component is left untouched.
    """
    parsed = urlparse(url)
    query_args = parse_qs(parsed.query)
    query_args.update(newparams)
    # doseq=True expands the value lists produced by parse_qs into repeated keys.
    merged_query = urlencode(query_args, True)
    return parsed._replace(query=merged_query).geturl()
34
+
35
+
36
@attr.s
class BaseLinks:
    """Build the links that are derived from the request rather than stored."""

    request: Request = attr.ib()

    @property
    def base_url(self):
        """Base url of the current request, as a string."""
        return str(self.request.base_url)

    @property
    def url(self):
        """Full url of the current request, as a string."""
        return str(self.request.url)

    def resolve(self, url):
        """Join ``url`` onto the base url, so relative hrefs become absolute."""
        return urljoin(str(self.base_url), str(url))

    def link_self(self) -> Dict:
        """Build the `self` link, pointing at the current request url."""
        return {
            "rel": Relations.self.value,
            "type": MimeTypes.json.value,
            "href": self.url,
        }

    def link_root(self) -> Dict:
        """Build the `root` link, pointing at the base url."""
        return {
            "rel": Relations.root.value,
            "type": MimeTypes.json.value,
            "href": self.base_url,
        }

    def create_links(self) -> List[Dict[str, Any]]:
        """Call every ``link_*`` method and collect the links that are not None."""
        collected = []
        for attr_name in dir(self):
            if not attr_name.startswith("link_"):
                continue
            if not callable(getattr(self, attr_name)):
                continue
            built = getattr(self, attr_name)()
            if built is None:
                continue
            collected.append(built)
        return collected

    async def get_links(
        self, extra_links: Optional[List[Dict[str, Any]]] = None
    ) -> List[Dict[str, Any]]:
        """
        Generate all the links.

        Combines the links produced by the ``link_*`` methods of this class
        with the given ``extra_links``.
        """
        # TODO: Pass request.json() into function so this doesn't need to be coroutine
        if self.request.method == "POST":
            # Keep the parsed body on the request so link builders can read it.
            self.request.postbody = await self.request.json()

        links = self.create_links()
        if not extra_links:
            return links

        # Extra links whose rel is generated by the server (self, parent, ...)
        # are dropped; the remaining ones get their href resolved against the
        # base url, so stored relative paths come out as absolute urls.
        for extra in extra_links:
            if extra["rel"] in INFERRED_LINK_RELS:
                continue
            links.append({**extra, "href": self.resolve(extra["href"])})

        return links
108
+
109
+
110
@attr.s
class CollectionLinks(BaseLinks):
    """Create inferred links specific to collections."""

    collection_id: str = attr.ib()
    # Names of the enabled extensions; they gate the optional links below.
    extensions: List[str] = attr.ib(default=attr.Factory(list))

    def link_self(self) -> Dict:
        """Return the self link."""
        return dict(
            rel=Relations.self.value,
            type=MimeTypes.json.value,
            href=urljoin(self.base_url, f"collections/{self.collection_id}"),
        )

    def link_parent(self) -> Dict[str, Any]:
        """Create the `parent` link."""
        # Use the plain string value, like every other link in this module,
        # rather than the enum member itself.
        return dict(
            rel=Relations.parent.value,
            type=MimeTypes.json.value,
            href=self.base_url,
        )

    def link_items(self) -> Dict[str, Any]:
        """Create the `items` link."""
        return dict(
            rel="items",
            type=MimeTypes.geojson.value,
            href=urljoin(self.base_url, f"collections/{self.collection_id}/items"),
        )

    def link_queryables(self) -> Optional[Dict[str, Any]]:
        """Create the `queryables` link, or None without the FilterExtension."""
        if "FilterExtension" in self.extensions:
            return dict(
                rel="queryables",
                type=MimeTypes.json.value,
                href=urljoin(
                    self.base_url, f"collections/{self.collection_id}/queryables"
                ),
            )
        return None

    def link_aggregate(self) -> Optional[Dict[str, Any]]:
        """Create the `aggregate` link, or None without the AggregationExtension."""
        if "AggregationExtension" in self.extensions:
            return dict(
                rel="aggregate",
                type=MimeTypes.json.value,
                href=urljoin(
                    self.base_url, f"collections/{self.collection_id}/aggregate"
                ),
            )
        return None

    def link_aggregations(self) -> Optional[Dict[str, Any]]:
        """Create the `aggregations` link, or None without the AggregationExtension."""
        if "AggregationExtension" in self.extensions:
            return dict(
                rel="aggregations",
                type=MimeTypes.json.value,
                href=urljoin(
                    self.base_url, f"collections/{self.collection_id}/aggregations"
                ),
            )
        return None
175
+
176
+
177
@attr.s
class PagingLinks(BaseLinks):
    """Create links for paging."""

    # Token of the next page; None when there is no further page.
    next: Optional[str] = attr.ib(kw_only=True, default=None)

    def link_next(self) -> Optional[Dict[str, Any]]:
        """Create link for next page.

        For GET requests the token is merged into the query string of the
        current url. For POST requests the link repeats the request body
        with the token added. Returns None when there is no next token or
        the request method is neither GET nor POST.
        """
        if self.next is None:
            return None

        method = self.request.method
        if method == "GET":
            return dict(
                rel=Relations.next.value,
                type=MimeTypes.json.value,
                method=method,
                href=merge_params(self.url, {"token": self.next}),
            )
        if method == "POST":
            return {
                # Plain string values, consistent with the GET branch.
                "rel": Relations.next.value,
                "type": MimeTypes.json.value,
                "method": method,
                "href": f"{self.request.url}",
                # `postbody` is attached to the request by `get_links`.
                "body": {**self.request.postbody, "token": self.next},
            }

        return None
@@ -0,0 +1 @@
1
+ """Unused search model."""
@@ -0,0 +1,44 @@
1
+ """Rate limiting middleware."""
2
+
3
+ import logging
4
+ import os
5
+ from typing import Optional
6
+
7
+ from fastapi import FastAPI, Request
8
+ from slowapi import Limiter, _rate_limit_exceeded_handler
9
+ from slowapi.errors import RateLimitExceeded
10
+ from slowapi.middleware import SlowAPIMiddleware
11
+ from slowapi.util import get_remote_address
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
def get_limiter(key_func=get_remote_address):
    """Build a new ``Limiter`` that uses ``key_func`` as its key function."""
    limiter = Limiter(key_func=key_func)
    return limiter
19
+
20
+
21
def setup_rate_limit(
    app: FastAPI, rate_limit: Optional[str] = None, key_func=get_remote_address
):
    """Install rate limiting on ``app``, if a limit is configured.

    The limit is taken from ``rate_limit`` or, when that is empty, from the
    ``STAC_FASTAPI_RATE_LIMIT`` environment variable. Without either, the
    function only logs that rate limiting is disabled and returns.
    """
    limit_spec = rate_limit or os.getenv("STAC_FASTAPI_RATE_LIMIT")

    if not limit_spec:
        logger.info("Rate limiting is disabled")
        return

    logger.info(f"Setting up rate limit with RATE_LIMIT={limit_spec}")

    # Register the limiter on the app state, plus the handler and middleware
    # that come with slowapi.
    limiter = get_limiter(key_func)
    app.state.limiter = limiter
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
    app.add_middleware(SlowAPIMiddleware)

    # Pass-through HTTP middleware decorated with the configured limit.
    @app.middleware("http")
    @limiter.limit(limit_spec)
    async def rate_limit_middleware(request: Request, call_next):
        return await call_next(request)

    logger.info("Rate limit setup complete")