hyperion-sdk 0.2.0.dev1741815359__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyperion/__init__.py +0 -0
- hyperion/asyncutils.py +79 -0
- hyperion/catalog/__init__.py +8 -0
- hyperion/catalog/catalog.py +623 -0
- hyperion/catalog/schema.py +153 -0
- hyperion/collections/__init__.py +0 -0
- hyperion/collections/asset_collection.py +285 -0
- hyperion/config.py +77 -0
- hyperion/dateutils.py +238 -0
- hyperion/entities/__init__.py +0 -0
- hyperion/entities/catalog.py +190 -0
- hyperion/infrastructure/__init__.py +0 -0
- hyperion/infrastructure/aws.py +220 -0
- hyperion/infrastructure/cache.py +396 -0
- hyperion/infrastructure/geo/__init__.py +7 -0
- hyperion/infrastructure/geo/gmaps.py +124 -0
- hyperion/infrastructure/geo/location.py +186 -0
- hyperion/infrastructure/http.py +62 -0
- hyperion/infrastructure/keyval.py +264 -0
- hyperion/infrastructure/queue.py +151 -0
- hyperion/infrastructure/secrets.py +63 -0
- hyperion/logging.py +122 -0
- hyperion/py.typed +0 -0
- hyperion/sources/__init__.py +0 -0
- hyperion/sources/base.py +105 -0
- hyperion/typeutils.py +52 -0
- hyperion_sdk-0.2.0.dev1741815359.dist-info/METADATA +476 -0
- hyperion_sdk-0.2.0.dev1741815359.dist-info/RECORD +29 -0
- hyperion_sdk-0.2.0.dev1741815359.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
"""Common location classes and functions."""
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from collections.abc import Iterable
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import ClassVar, Generic, TypeVar, cast
|
|
8
|
+
|
|
9
|
+
import haversine
|
|
10
|
+
import numpy as np
|
|
11
|
+
import numpy.typing as npt
|
|
12
|
+
from haversine.haversine import get_avg_earth_radius
|
|
13
|
+
|
|
14
|
+
from hyperion.infrastructure.cache import Cache
|
|
15
|
+
from hyperion.logging import get_logger
|
|
16
|
+
|
|
17
|
+
# Approximate number of meters per degree of latitude; used by meters_to_degrees below.
LATITUDE_DEGREE_TO_METERS = 111_000
|
|
18
|
+
# Average Earth radius as reported by the haversine package, requested in meters.
EARTH_RADIUS_METERS = get_avg_earth_radius(haversine.Unit.METERS)
|
|
19
|
+
|
|
20
|
+
# Module-level logger shared by the geo helpers in this file.
logger = get_logger("hyperion-geo")
|
|
21
|
+
|
|
22
|
+
# Type variable for Location or any of its subclasses, so helpers can return the caller's concrete type.
AnyLocation = TypeVar("AnyLocation", bound="Location")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def meters_to_degrees(meters: float, at_latitude: float) -> tuple[float, float]:
    """Express a distance in meters as latitude and longitude offsets in degrees.

    Args:
        meters (float): The distance in meters.
        at_latitude (float): The latitude (in degrees) at which the conversion applies.

    Returns:
        tuple[float, float]: The distance expressed in degrees of latitude and
            in degrees of longitude, in that order.
    """
    latitude_degrees = meters / LATITUDE_DEGREE_TO_METERS
    # One degree of longitude covers less ground away from the equator: scale by cos(latitude).
    meters_per_longitude_degree = LATITUDE_DEGREE_TO_METERS * math.cos(math.radians(at_latitude))
    longitude_degrees = meters / meters_per_longitude_degree
    return latitude_degrees, longitude_degrees
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class SpatialKMeans(Generic[AnyLocation]):
    """K-means clustering for geographical locations.

    Points are assigned to the centroid with the smallest haversine distance;
    centroids are recomputed as the arithmetic mean of the latitude/longitude
    pairs of their members.
    """

    def __init__(self, locations: Iterable[AnyLocation]) -> None:
        """Initialize the K-means clustering with the given locations.

        Args:
            locations (Iterable[Location]): The locations to cluster.
        """
        self.locations = list(locations)
        # One [latitude, longitude] row per location, in the same order as self.locations.
        self.locations_array = np.array(
            [[location.latitude, location.longitude] for location in self.locations],
        )

    def _get_distances_from_centroids(self, centroids: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
        """Return a (locations x centroids) matrix of haversine distances."""
        return np.array(
            [
                [haversine.haversine((point[0], point[1]), (centroid[0], centroid[1])) for centroid in centroids]
                for point in self.locations_array
            ]
        )

    def fit(self, k: int, max_iters: int = 100) -> dict["Location", list[AnyLocation]]:
        """Fit the K-means model with k clusters and return the clusters.

        Initial centroids are drawn at random (without replacement) from the
        input locations, so results may differ between calls.

        Args:
            k (int): The desired number of clusters (between 1 and the number of locations).
            max_iters (int, optional): The maximum number of iterations. Defaults to 100.

        Returns:
            dict[Location, list[Location]]: The clusters with the centroids as keys.

        Raises:
            ValueError: If k is not between 1 and the number of locations.
        """
        if not 0 < k <= len(self.locations):
            raise ValueError(
                f"k must be between 1 and the number of locations ({len(self.locations)}), got {k!r}."
            )

        centroids = self.locations_array[np.random.choice(len(self.locations_array), k, replace=False), :2]

        for _ in range(max_iters):
            cluster_assignments = np.argmin(self._get_distances_from_centroids(centroids), axis=1)

            new_centroids_list = []
            for cluster_idx in range(k):
                cluster_points = self.locations_array[cluster_assignments == cluster_idx, :2]
                if len(cluster_points) > 0:
                    new_centroids_list.append(cluster_points.mean(axis=0))
                else:
                    new_centroids_list.append(centroids[cluster_idx])  # Keep old centroid for empty clusters
            new_centroids = np.array(new_centroids_list)

            if np.allclose(centroids, new_centroids):
                break
            centroids = new_centroids
        else:
            # The loop ran out of iterations (or never ran because max_iters <= 0):
            # the assignments are missing or were computed for the previous centroids,
            # so compute them against the centroids that are actually returned.
            cluster_assignments = np.argmin(self._get_distances_from_centroids(centroids), axis=1)

        clusters: dict[Location, list[AnyLocation]] = defaultdict(list)
        centroid_locations = [Location(float(lat), float(lon)) for lat, lon in centroids]

        for location_id, centroid_id in enumerate(cluster_assignments):
            clusters[centroid_locations[centroid_id]].append(self.locations[location_id])
        return dict(clusters)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@dataclass(frozen=True, eq=True)
class Location:
    """A geographical location.

    Latitude and longitude are expressed in degrees; title and address are
    optional free-form descriptions.
    """

    # Cache shared by all instances; created lazily by _get_cache().
    _cache: ClassVar[Cache | None] = None

    latitude: float
    longitude: float
    title: str | None = None
    address: str | None = None

    @classmethod
    def _get_cache(cls) -> Cache:
        """Return the class-wide cache, creating it from configuration on first use."""
        if cls._cache is None:
            cls._cache = Cache.from_config()
        return cls._cache

    def _get_distance_haversine(self, other: "Location") -> float:
        """Return the haversine (great-circle) distance to other in meters."""
        return cast(
            float,
            haversine.haversine(
                (self.latitude, self.longitude), (other.latitude, other.longitude), haversine.Unit.METERS
            ),
        )

    def _get_distance_euclidean(self, other: "Location") -> float:
        """Return an equirectangular (flat-plane) approximation of the distance in meters."""
        lat_diff = math.radians(self.latitude - other.latitude)
        lon_diff = math.radians(self.longitude - other.longitude)
        # Scale the longitude difference by the cosine of the mean latitude.
        x = lon_diff * math.cos(math.radians((self.latitude + other.latitude) / 2))
        y = lat_diff
        return cast(float, EARTH_RADIUS_METERS) * math.sqrt(x**2 + y**2)

    def get_distance(self, other: "Location", approximate: bool = False) -> float:
        """Get the distance to another location in meters.

        If approximate is False (default), will use haversine. Otherwise uses Euclidean to approximate the distance.

        Args:
            other (Location): The other location.
            approximate (bool, optional): Whether to approximate the distance. Defaults to False.

        Returns:
            float: The distance in meters.
        """
        cache_key = str((self.latitude, self.longitude, other.latitude, other.longitude, approximate))
        cache = self._get_cache()
        # NOTE(review): this method only reads from the cache; nothing here stores the
        # computed distance, so hits depend on the cache being populated elsewhere - confirm.
        if cached := cache.get(cache_key):
            return float(cached)
        if approximate:
            return self._get_distance_euclidean(other)
        return self._get_distance_haversine(other)

    def get_nearest(
        self, others: Iterable[AnyLocation], threshold: float | None = None, approximate: bool = False
    ) -> AnyLocation:
        """Get the closest location from the iterable of locations.

        If threshold is given and all locations are further than the threshold in meters, a ValueError is raised.

        Args:
            others (Iterable[Location]): The other locations.
            threshold (float, optional): The maximum distance in meters. Defaults to None.
            approximate (bool, optional): Whether to approximate the distance. Defaults to False.

        Returns:
            Location: The nearest location.

        Raises:
            ValueError: If others is empty or no location lies within the threshold.
        """
        nearest: tuple[AnyLocation, float] | None = None
        for other in others:
            distance = self.get_distance(other, approximate=approximate)
            # Candidates beyond the threshold are never eligible.
            if threshold is not None and distance > threshold:
                continue
            if nearest is None or distance < nearest[1]:
                nearest = (other, distance)
        if nearest is None:
            raise ValueError(f"None of the given locations is close enough to {self!r}.")
        logger.debug("Found nearest location.", this=self, other=nearest[0], distance=nearest[1])
        return nearest[0]
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
@dataclass(eq=True, frozen=True)
class NamedLocation:
    """A location paired with optional, human-readable address components.

    Every component defaults to None when it is not known.
    NOTE(review): the component names look like geocoder address component
    types - confirm against the code that builds these objects.
    """

    # The geographical point the names refer to.
    location: Location
    # Optional address components.
    route: str | None = None
    neighborhood: str | None = None
    sublocality: str | None = None
    administrative_area: str | None = None
    administrative_area_level_2: str | None = None
    country: str | None = None
    address: str | None = None
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from types import TracebackType
|
|
3
|
+
from urllib.parse import urlparse
|
|
4
|
+
|
|
5
|
+
import httpx
|
|
6
|
+
|
|
7
|
+
from hyperion.config import http_config
|
|
8
|
+
from hyperion.logging import get_logger
|
|
9
|
+
|
|
10
|
+
# Module-level logger for the HTTP helpers in this file.
logger = get_logger("hyperion-http")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def redact_url(url: str, replace: str = "***") -> str:
    """Replace password from {url} (if any) with {replace}.

    If the url contains no password, url is returned unchanged. The rest of
    the network location (username, host, port, IPv6 brackets) is preserved
    as written.

    Args:
        url: The URL that may contain credentials.
        replace: The text to put in place of the password.

    Returns:
        str: The URL with the password replaced, or the original URL.
    """
    parsed = urlparse(url)
    if not parsed.password:
        return url
    # Split the raw netloc instead of rebuilding it from `hostname`, which would
    # drop the port, strip IPv6 brackets and lowercase the host.
    userinfo, _, host_and_port = parsed.netloc.rpartition("@")
    username = userinfo.partition(":")[0]
    return parsed._replace(netloc=f"{username}:{replace}@{host_and_port}").geturl()
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class AsyncHTTPClientWrapper:
    """Re-entrant async context manager sharing a single httpx.AsyncClient.

    Every `async with` on the same wrapper returns the same client. The client
    is created on the first entry and closed once the number of exits matches
    the number of successful entries.
    """

    def __init__(self) -> None:
        self._client: httpx.AsyncClient | None = None
        # Number of currently active `async with` blocks using the client.
        self._stacklevel = 0
        # Serializes client creation / teardown between concurrent tasks.
        self._lock = asyncio.Lock()

    async def __aenter__(self) -> httpx.AsyncClient:
        """Return the shared client, creating and entering it if necessary."""
        async with self._lock:
            if self._client is None or self._client.is_closed:
                logger.debug("Creating and entering httpx async client.")
                client = httpx.AsyncClient(mounts=self._get_proxy_mounts())
                await client.__aenter__()
                self._client = client
            # Count the entry only once a usable client exists; `async with` never calls
            # __aexit__ after a failed __aenter__, so counting earlier would leak a level
            # and the client would never be closed afterwards.
            self._stacklevel += 1
            return self._client

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: TracebackType | None = None,
    ) -> None:
        """Release one reference and close the client when the last one is released."""
        async with self._lock:
            self._stacklevel -= 1
            if self._stacklevel == 0 and self._client is not None:
                logger.info("Closing httpx async client.")
                try:
                    await self._client.__aexit__(exc_type, exc_value, traceback)
                finally:
                    # Drop the reference even if closing failed so the next entry starts fresh.
                    self._client = None

    @staticmethod
    def _get_proxy_mounts() -> dict[str, httpx.AsyncHTTPTransport]:
        """Build httpx transport mounts for the proxies set in the HTTP configuration.

        Returns:
            dict[str, httpx.AsyncHTTPTransport]: Transports keyed by URL scheme
                prefix; empty when no proxy is configured.
        """
        proxy_mounts: dict[str, httpx.AsyncHTTPTransport] = {}
        configured_proxies = (
            ("http://", http_config.proxy_http),
            ("https://", http_config.proxy_https),
        )
        for scheme, proxy_url in configured_proxies:
            if not proxy_url:
                continue
            # Never log the proxy password.
            logger.info(f"Configuring HTTP proxy for {scheme}", proxy_url=redact_url(proxy_url))
            proxy_mounts[scheme] = httpx.AsyncHTTPTransport(proxy=proxy_url)
        return proxy_mounts
|
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
"""Key-Value stores.
|
|
2
|
+
|
|
3
|
+
If you need a store with a TTL, look at Cache instead.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import gzip
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from collections.abc import Iterable, Iterator
|
|
9
|
+
from fnmatch import fnmatch
|
|
10
|
+
from typing import Literal, TypeGuard, cast
|
|
11
|
+
|
|
12
|
+
import boto3
|
|
13
|
+
import snappy
|
|
14
|
+
|
|
15
|
+
CompressionType = Literal["snappy", "gzip"]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def is_valid_compression_type(compression_type: str) -> TypeGuard[CompressionType]:
|
|
19
|
+
return compression_type in ("snappy", "gzip")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class KeyValueStore(ABC, Iterable[str]):
|
|
23
|
+
"""A key-value store."""
|
|
24
|
+
|
|
25
|
+
def __init__(self, prefix: str | None = None, compression: CompressionType | None = None) -> None:
|
|
26
|
+
"""Initialize the store.
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
prefix: A prefix to add to all keys.
|
|
30
|
+
compression: The compression algorithm to use.
|
|
31
|
+
"""
|
|
32
|
+
self.prefix = prefix
|
|
33
|
+
self.compression = compression
|
|
34
|
+
|
|
35
|
+
def _key(self, key: str) -> str:
|
|
36
|
+
"""Return the key with the prefix added."""
|
|
37
|
+
prefix = f"{self.prefix}:" if self.prefix else ""
|
|
38
|
+
return f"{prefix}{key}"
|
|
39
|
+
|
|
40
|
+
def _compress(self, value: str) -> bytes:
|
|
41
|
+
"""Compress a value using configured compression.
|
|
42
|
+
|
|
43
|
+
If no compression is selected, the string is encoded as utf-8 into bytes.
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
value: The value to compress.
|
|
47
|
+
|
|
48
|
+
Returns:
|
|
49
|
+
bytes: The compressed value.
|
|
50
|
+
|
|
51
|
+
Raises:
|
|
52
|
+
ValueError: If an unsupported compression is selected.
|
|
53
|
+
"""
|
|
54
|
+
value_bytes = value.encode("utf-8")
|
|
55
|
+
match self.compression:
|
|
56
|
+
case None:
|
|
57
|
+
return value_bytes
|
|
58
|
+
case "snappy":
|
|
59
|
+
return cast(bytes, snappy.compress(value_bytes))
|
|
60
|
+
case "gzip":
|
|
61
|
+
return gzip.compress(value_bytes)
|
|
62
|
+
case _:
|
|
63
|
+
raise ValueError(f"Unsupported compression {self.compression!r}.")
|
|
64
|
+
|
|
65
|
+
def _decompress(self, value: bytes) -> str:
|
|
66
|
+
"""Decompress a value using configured compression.
|
|
67
|
+
|
|
68
|
+
If no compression is selected, the bytes are decoded as utf-8 into a string.
|
|
69
|
+
|
|
70
|
+
Args:
|
|
71
|
+
value: The value to decompress.
|
|
72
|
+
|
|
73
|
+
Returns:
|
|
74
|
+
str: The decompressed value.
|
|
75
|
+
|
|
76
|
+
Raises:
|
|
77
|
+
ValueError: If an unsupported compression is selected.
|
|
78
|
+
TypeError: If the decompression returns an unexpected type.
|
|
79
|
+
"""
|
|
80
|
+
match self.compression:
|
|
81
|
+
case None:
|
|
82
|
+
return value.decode("utf-8")
|
|
83
|
+
case "snappy":
|
|
84
|
+
value_decompressed = snappy.decompress(value)
|
|
85
|
+
if isinstance(value_decompressed, str):
|
|
86
|
+
return value_decompressed
|
|
87
|
+
elif isinstance(value_decompressed, bytes):
|
|
88
|
+
return value_decompressed.decode("utf-8")
|
|
89
|
+
raise TypeError(
|
|
90
|
+
"Unexpected value returned from snappy decompression, "
|
|
91
|
+
f"expected str | bytes, got {type(value_decompressed)!r}."
|
|
92
|
+
)
|
|
93
|
+
case "gzip":
|
|
94
|
+
return gzip.decompress(value).decode("utf-8")
|
|
95
|
+
case _:
|
|
96
|
+
raise ValueError(f"Unsupported compression {self.compression!r}.")
|
|
97
|
+
|
|
98
|
+
@abstractmethod
|
|
99
|
+
def _get_raw(self, hashed_key: str) -> bytes | None:
|
|
100
|
+
"""Get a raw uncompressed value from the store.
|
|
101
|
+
|
|
102
|
+
The key passed here must already be hashed and prefixed.
|
|
103
|
+
|
|
104
|
+
Args:
|
|
105
|
+
hashed_key: The key to get.
|
|
106
|
+
|
|
107
|
+
Returns:
|
|
108
|
+
bytes | None: The raw value, or None if not found.
|
|
109
|
+
"""
|
|
110
|
+
|
|
111
|
+
def get(self, key: str) -> str | None:
|
|
112
|
+
"""Gets a value from the store."""
|
|
113
|
+
if value := self._get_raw(self._key(key)):
|
|
114
|
+
return self._decompress(value)
|
|
115
|
+
return None
|
|
116
|
+
|
|
117
|
+
@abstractmethod
|
|
118
|
+
def _set_raw(self, hashed_key: str, compresed_value: bytes) -> None:
|
|
119
|
+
"""Set a raw, compressed bytes value in the store.
|
|
120
|
+
|
|
121
|
+
The key passed here must already by hashed and prefixed.
|
|
122
|
+
|
|
123
|
+
Args:
|
|
124
|
+
hashed_key: The key to set.
|
|
125
|
+
compresed_value: The compressed value to set.
|
|
126
|
+
"""
|
|
127
|
+
|
|
128
|
+
def set(self, key: str, value: str) -> None:
|
|
129
|
+
"""Sets a value in the store."""
|
|
130
|
+
return self._set_raw(self._key(key), self._compress(value))
|
|
131
|
+
|
|
132
|
+
@abstractmethod
|
|
133
|
+
def _delete_raw(self, hashed_key: str) -> None:
|
|
134
|
+
"""Delete a value from the store.
|
|
135
|
+
|
|
136
|
+
The key passed here must already by hashed and prefixed.
|
|
137
|
+
|
|
138
|
+
Args:
|
|
139
|
+
hashed_key: The key to delete.
|
|
140
|
+
"""
|
|
141
|
+
|
|
142
|
+
def __iter__(self) -> Iterator[str]:
|
|
143
|
+
"""Iterate all keys available in the store."""
|
|
144
|
+
return iter(self._iter_all_keys())
|
|
145
|
+
|
|
146
|
+
@abstractmethod
|
|
147
|
+
def _iter_all_keys(self) -> Iterable[str]:
|
|
148
|
+
"""Iterate all keys available in the store."""
|
|
149
|
+
|
|
150
|
+
def keys(self, match: str = "*") -> Iterable[str]:
|
|
151
|
+
"""Iterate all keys available in the store that match the given expression.
|
|
152
|
+
|
|
153
|
+
Nothing fancy is supported, UNIX-style fnmatch is performed.
|
|
154
|
+
|
|
155
|
+
Args:
|
|
156
|
+
match: The expression to match.
|
|
157
|
+
|
|
158
|
+
Returns:
|
|
159
|
+
Iterable[str]: The keys that match the expression
|
|
160
|
+
"""
|
|
161
|
+
for key in self._iter_all_keys():
|
|
162
|
+
if fnmatch(key, match):
|
|
163
|
+
yield key
|
|
164
|
+
|
|
165
|
+
def delete(self, key: str) -> None:
|
|
166
|
+
"""Delete a value from the store.
|
|
167
|
+
|
|
168
|
+
Args:
|
|
169
|
+
key: The key to delete.
|
|
170
|
+
"""
|
|
171
|
+
return self._delete_raw(self._key(key))
|
|
172
|
+
|
|
173
|
+
def exists(self, key: str) -> bool:
|
|
174
|
+
"""Returns whether a key exists in the store.
|
|
175
|
+
|
|
176
|
+
Args:
|
|
177
|
+
key: The key to check.
|
|
178
|
+
"""
|
|
179
|
+
return key in self.keys()
|
|
180
|
+
|
|
181
|
+
def __contains__(self, key: str) -> bool:
|
|
182
|
+
return self.exists(key)
|
|
183
|
+
|
|
184
|
+
def __getitem__(self, key: str) -> str | None:
|
|
185
|
+
return self.get(key)
|
|
186
|
+
|
|
187
|
+
def __setitem__(self, key: str, value: str) -> None:
|
|
188
|
+
return self.set(key, value)
|
|
189
|
+
|
|
190
|
+
def __delitem__(self, key: str) -> None:
|
|
191
|
+
return self.delete(key)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class DynamoDBStore(KeyValueStore):
    """A key-value store using DynamoDB.

    The table must have a key attribute and a value attribute.
    """

    def __init__(
        self,
        prefix: str | None = None,
        compression: CompressionType | None = None,
        table_name: str | None = None,
        key_attribute: str = "key",
        value_attribute: str = "value",
    ):
        """Initialize the store.

        Args:
            prefix: A prefix to add to all keys.
            compression: The compression algorithm to use.
            table_name: The name of the DynamoDB table (required).
            key_attribute: The name of the attribute holding the key.
            value_attribute: The name of the attribute holding the value.

        Raises:
            ValueError: If no table name is provided.
        """
        super().__init__(prefix, compression)
        # Validate before creating the boto3 resource so a missing table name is
        # reported as such, independent of the AWS configuration.
        if table_name is None:
            raise ValueError("No table name provided for DynamoDBStore.")
        self.client = boto3.resource("dynamodb")
        self.table_name = table_name
        self.table = self.client.Table(self.table_name)
        self.key_attribute = key_attribute
        self.value_attribute = value_attribute

    def _get_raw(self, hashed_key: str) -> bytes | None:
        """Fetch the raw value stored under the (already prefixed) key, or None."""
        response = self.table.get_item(Key={self.key_attribute: hashed_key})
        item = response.get("Item")
        if not item:
            return None
        return bytes(item[self.value_attribute])

    def _set_raw(self, hashed_key: str, compresed_value: bytes) -> None:
        """Store the raw value under the (already prefixed) key."""
        self.table.put_item(Item={self.key_attribute: hashed_key, self.value_attribute: compresed_value})

    def _delete_raw(self, hashed_key: str) -> None:
        """Delete the item stored under the (already prefixed) key."""
        self.table.delete_item(Key={self.key_attribute: hashed_key})

    def _iter_all_keys(self) -> Iterable[str]:
        """Scan the table and yield the keys of this store without their prefix.

        When a prefix is configured, items whose key does not start with it are
        skipped: they cannot be addressed through this store.
        """
        key_prefix = f"{self.prefix}:" if self.prefix else ""
        # "#k" is an alias for the key attribute name, so only that attribute is returned.
        scan_kwargs = {"ProjectionExpression": "#k", "ExpressionAttributeNames": {"#k": self.key_attribute}}
        while True:
            response = self.table.scan(**scan_kwargs)
            for item in response.get("Items", []):
                key = str(item[self.key_attribute])
                if not key.startswith(key_prefix):
                    continue
                # Strip the prefix only at the start of the key.
                yield key.removeprefix(key_prefix)
            # Scan results are paginated; continue from where the last page ended.
            if "LastEvaluatedKey" not in response:
                break
            scan_kwargs["ExclusiveStartKey"] = response["LastEvaluatedKey"]
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
class InMemoryStore(KeyValueStore):
    """An in-memory key-value store.

    This store is not persistent and is not shared between instances.
    """

    def __init__(self, prefix: str | None = None, compression: CompressionType | None = None):
        """Initialize an empty store.

        Args:
            prefix: A prefix to add to all keys.
            compression: The compression algorithm to use.
        """
        super().__init__(prefix, compression)
        # Maps prefixed keys to (possibly compressed) raw values.
        self.store: dict[str, bytes] = {}

    def _get_raw(self, hashed_key: str) -> bytes | None:
        """Return the raw value stored under the (already prefixed) key, or None."""
        return self.store.get(hashed_key)

    def _set_raw(self, hashed_key: str, compresed_value: bytes) -> None:
        """Store the raw value under the (already prefixed) key."""
        self.store[hashed_key] = compresed_value

    def _delete_raw(self, hashed_key: str) -> None:
        """Remove the (already prefixed) key; missing keys are ignored."""
        self.store.pop(hashed_key, None)

    def _iter_all_keys(self) -> Iterable[str]:
        """Yield the keys of this store without their prefix.

        The dictionary holds prefixed keys; the prefix is stripped so the keys
        yielded here are the ones callers pass to get / set / delete.
        """
        key_prefix = f"{self.prefix}:" if self.prefix else ""
        # Iterate over a snapshot so entries may be deleted while iterating.
        for stored_key in list(self.store):
            if stored_key.startswith(key_prefix):
                yield stored_key.removeprefix(key_prefix)
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import datetime
|
|
3
|
+
from collections.abc import Iterator
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from typing import ClassVar
|
|
6
|
+
|
|
7
|
+
from aws_lambda_typing.events import SQSEvent
|
|
8
|
+
from boto3 import client
|
|
9
|
+
|
|
10
|
+
# After a long consideration, I decided to use pydantic to serialize / deserialize the messages.
|
|
11
|
+
# The convenience outweighs the performance hit.
|
|
12
|
+
from pydantic import BaseModel, Field
|
|
13
|
+
|
|
14
|
+
from hyperion.config import config, queue_config
|
|
15
|
+
from hyperion.dateutils import utcnow
|
|
16
|
+
from hyperion.entities.catalog import DataLakeAsset
|
|
17
|
+
from hyperion.logging import get_logger
|
|
18
|
+
|
|
19
|
+
# Module-level logger for the queue helpers in this file.
logger = get_logger("hyperion-queue")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ArrivalEvent(str, Enum):
    """Kinds of arrival events carried by DataLakeArrivalMessage.

    Members are also plain strings, so they compare equal to their values.
    """

    ARRIVED = "arrived"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Base model for queue messages. Every subclass registers itself under its class
# name so that `Message.deserialize` can rebuild the right type from JSON.
# (Kept as comments on purpose: a class docstring would end up in the model's JSON schema.)
class Message(BaseModel):
    # Registry of message classes keyed by class name, filled by __init_subclass__.
    _subclasses: ClassVar[dict[str, type["Message"]]] = {}

    # Creation timestamp; defaults to utcnow() at instantiation.
    created: datetime.datetime = Field(default_factory=utcnow)
    # Name of the sending service, taken from the configuration at import time.
    sender: str = config.service_name
    # Queue receipt handle, when the message was received from a queue.
    receipt_handle: str | None = None

    def __init_subclass__(cls: type["Message"]) -> None:
        """Register the new subclass in the shared registry under its class name."""
        cls._subclasses[cls.__name__] = cls

    @staticmethod
    def deserialize(json_str: str, message_type: str, receipt_handle: str | None = None) -> "Message":
        """Rebuild a message of the named type from its JSON representation.

        Args:
            json_str: The JSON document describing the message.
            message_type: The class name of the message subclass to build.
            receipt_handle: A receipt handle to attach when the JSON carries none.

        Returns:
            Message: The validated message instance.

        Raises:
            KeyError: If no subclass is registered under message_type.
        """
        message_cls = Message._subclasses[message_type]
        message = message_cls.model_validate_json(json_str)
        # A handle already present in the payload wins over the one supplied here.
        if not receipt_handle or message.receipt_handle:
            return message
        message.receipt_handle = receipt_handle
        return message
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Message describing an arrival event for a data lake asset.
class DataLakeArrivalMessage(Message):
    # The asset the event refers to.
    asset: DataLakeAsset
    # What happened to the asset.
    event: ArrivalEvent
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# Message requesting a backfill of a source over a date range.
class SourceBackfillMessage(Message):
    # Name of the source to backfill.
    source: str
    # Start of the backfill range.
    start_date: datetime.datetime | datetime.date
    # End of the backfill range; optional.
    end_date: datetime.datetime | datetime.date | None = None
    # Defaults to True. NOTE(review): presumably controls notifications after the backfill - confirm with the consumer.
    notify: bool = True
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def iter_messages_from_sqs_event(event: SQSEvent) -> Iterator[Message]:
    """Yield a deserialized Message for every typed record of an SQS event.

    Records without a "MessageType" message attribute are logged and skipped.
    This is a generator, so the event is only validated once iteration starts.

    Args:
        event: The SQS event, a dict with a "Records" list.

    Yields:
        Message: The message rebuilt from each record's body.

    Raises:
        ValueError: If the event is not a dict with a "Records" list.
    """
    looks_like_sqs_event = isinstance(event, dict) and "Records" in event and isinstance(event["Records"], list)
    if not looks_like_sqs_event:
        raise ValueError("Provided event is not a valid SQS Event.")

    for record in event["Records"]:
        logger.info("Deserializing message.", message_id=record["messageId"])
        attributes = record["messageAttributes"]
        if "MessageType" in attributes:
            message_type = attributes["MessageType"]["stringValue"]
            logger.debug("Attempting to deserialize message.", message=record, type=message_type)
            yield Message.deserialize(record["body"], message_type, record["receiptHandle"])
        else:
            logger.warning("Message has no type and probably does not come from Hyperion.", message=record)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def create_backfill_event(
    message: SourceBackfillMessage, message_id: str = "test", receipt_handle: str = ""
) -> SQSEvent:
    """Wrap a backfill message in a single-record SQS event.

    The record carries the message as JSON in its body and a "MessageType"
    attribute set to "SourceBackfillMessage".

    Args:
        message: The backfill message to wrap.
        message_id: The message id to put on the record. Defaults to "test".
        receipt_handle: The receipt handle to put on the record. Defaults to "".

    Returns:
        SQSEvent: An event with exactly one record.
    """
    serialized_message = message.model_dump_json()
    event: SQSEvent = {
        "Records": [
            {
                "body": serialized_message,
                "messageId": message_id,
                "receiptHandle": receipt_handle,
                "messageAttributes": {
                    "MessageType": {
                        "binaryListValues": [],
                        "binaryValue": None,
                        "dataType": "String",
                        "stringListValues": [],
                        "stringValue": "SourceBackfillMessage",
                    }
                },
            }
        ]
    }
    return event
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class Queue(abc.ABC):
    """Abstract message queue with send and delete operations."""

    @staticmethod
    def from_config() -> "Queue":
        """Build the queue selected by the configuration.

        Returns:
            Queue: An SQSQueue when a queue URL is configured, otherwise an InMemoryQueue.
        """
        if queue_config.url is not None:
            logger.info("Using SQS queue.")
            return SQSQueue(queue_config.url)
        logger.info("Using in-memory queue.")
        return InMemoryQueue()

    @abc.abstractmethod
    def send(self, message: Message) -> None:
        """Send a message to the queue."""

    @abc.abstractmethod
    def delete(self, receipt_handle: str) -> None:
        """Delete a message from the queue."""

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"<{class_name}>"
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class InMemoryQueue(Queue):
    """Queue that keeps sent messages in a list on the instance."""

    def __init__(self) -> None:
        super().__init__()
        # Messages in the order they were sent.
        self._messages: list[Message] = []

    def send(self, message: Message) -> None:
        """Append the message to the in-memory list."""
        self._messages.append(message)

    def delete(self, receipt_handle: str) -> None:
        """Delete the message using the created as isoformat.

        Only the first message whose `created.isoformat()` equals the receipt
        handle is removed; nothing happens when none matches.
        """
        target = next(
            (candidate for candidate in self._messages if candidate.created.isoformat() == receipt_handle),
            None,
        )
        if target is not None:
            self._messages.remove(target)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class SQSQueue(Queue):
    """Queue backed by an SQS queue accessed through a boto3 client."""

    def __init__(self, queue_url: str) -> None:
        """Create an SQS client for the queue at the given URL."""
        super().__init__()
        self._queue_url = queue_url
        self._client = client("sqs")

    def send(self, message: Message) -> None:
        """Send the message as JSON, tagging it with its class name as "MessageType"."""
        # The receiving side uses this attribute to pick the class to deserialize into.
        message_attributes = {
            "MessageType": {
                "DataType": "String",
                "StringValue": message.__class__.__name__,
            }
        }
        self._client.send_message(
            QueueUrl=self._queue_url,
            MessageBody=message.model_dump_json(),
            MessageAttributes=message_attributes,
        )

    def delete(self, receipt_handle: str) -> None:
        """Delete the message identified by the receipt handle from the queue."""
        self._client.delete_message(QueueUrl=self._queue_url, ReceiptHandle=receipt_handle)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"<{class_name} {self._queue_url}>"
|