hyperion-sdk 0.2.0.dev1741815359__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyperion/__init__.py +0 -0
- hyperion/asyncutils.py +79 -0
- hyperion/catalog/__init__.py +8 -0
- hyperion/catalog/catalog.py +623 -0
- hyperion/catalog/schema.py +153 -0
- hyperion/collections/__init__.py +0 -0
- hyperion/collections/asset_collection.py +285 -0
- hyperion/config.py +77 -0
- hyperion/dateutils.py +238 -0
- hyperion/entities/__init__.py +0 -0
- hyperion/entities/catalog.py +190 -0
- hyperion/infrastructure/__init__.py +0 -0
- hyperion/infrastructure/aws.py +220 -0
- hyperion/infrastructure/cache.py +396 -0
- hyperion/infrastructure/geo/__init__.py +7 -0
- hyperion/infrastructure/geo/gmaps.py +124 -0
- hyperion/infrastructure/geo/location.py +186 -0
- hyperion/infrastructure/http.py +62 -0
- hyperion/infrastructure/keyval.py +264 -0
- hyperion/infrastructure/queue.py +151 -0
- hyperion/infrastructure/secrets.py +63 -0
- hyperion/logging.py +122 -0
- hyperion/py.typed +0 -0
- hyperion/sources/__init__.py +0 -0
- hyperion/sources/base.py +105 -0
- hyperion/typeutils.py +52 -0
- hyperion_sdk-0.2.0.dev1741815359.dist-info/METADATA +476 -0
- hyperion_sdk-0.2.0.dev1741815359.dist-info/RECORD +29 -0
- hyperion_sdk-0.2.0.dev1741815359.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,396 @@
|
|
|
1
|
+
"""A serverless cache for our shenanigans."""
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import tempfile
|
|
5
|
+
import time
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, ClassVar, cast
|
|
9
|
+
|
|
10
|
+
import boto3
|
|
11
|
+
import cachetools
|
|
12
|
+
import snappy
|
|
13
|
+
|
|
14
|
+
from hyperion.catalog import AssetNotFoundError, Catalog
|
|
15
|
+
from hyperion.config import storage_config
|
|
16
|
+
from hyperion.dateutils import utcnow
|
|
17
|
+
from hyperion.entities.catalog import PersistentStoreAsset
|
|
18
|
+
from hyperion.logging import get_logger
|
|
19
|
+
|
|
20
|
+
# Default time-to-live, in seconds, for Cache constructors that are not given one.
DEFAULT_TTL_SECONDS = 60
# Largest compressed value, in bytes, that DynamoDBCache.set accepts before raising CachingError.
DYNAMODB_MAX_LENGTH = 65535
# Default byte budget for LocalFileCache (256 MiB); enforced against file sizes on disk.
DEFAULT_LOCAL_FILE_CACHE_MAX_SIZE = 256 * 1024 * 1024

# Logger shared by all cache backends in this module.
logger = get_logger("cache")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class CachingError(Exception):
    """Raised when a cache backend refuses to store a value (e.g. it is too large for DynamoDB)."""
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class Cache(ABC):
    """Abstract base class for the string key/value caches in this module.

    The base class derives backend keys (an optional SHA-256 hash of the key, then an
    optional ``"<prefix>:"`` namespace). It also provides snappy helpers for backends that
    persist bytes. Subclasses implement ``get``, ``set``, ``delete``, ``clear`` and ``hit``.
    """

    # Instances built by ``from_config``; one dict shared by every subclass.
    _instances: ClassVar[dict[tuple[str, bool], "Cache"]] = {}

    @classmethod
    def from_config(cls) -> "Cache":
        """
        Creates a cache from the configuration.

        This function emulates a singleton pattern. Instances are memoized per
        ``storage_config.cache_key_prefix``, so repeated calls with the same prefix return
        the same instance. The backend is chosen as follows:

        * ``cache_dynamodb_table`` is set -> ``DynamoDBCache``
        * otherwise ``cache_local_path`` is set -> ``LocalFileCache``
        * otherwise -> ``InMemoryCache``

        Returns:
            Cache: A cache instance.
        """
        # NOTE(review): the second tuple element is always True, so only the prefix
        # distinguishes instances; other settings are not part of the memoization key.
        instance_key = (storage_config.cache_key_prefix, True)
        if instance_key not in Cache._instances:
            if storage_config.cache_dynamodb_table:
                logger.info("Using DynamoDB Cache.")
                Cache._instances[instance_key] = DynamoDBCache(
                    prefix=storage_config.cache_key_prefix,
                    default_ttl=storage_config.cache_dynamodb_default_ttl,
                    table_name=storage_config.cache_dynamodb_table,
                )
            elif storage_config.cache_local_path:
                logger.info("Using LocalFileCache.", path=storage_config.cache_local_path)
                # NOTE(review): no default_ttl is forwarded here, so LocalFileCache falls back to
                # DEFAULT_TTL_SECONDS unlike the other backends — confirm this is intended.
                Cache._instances[instance_key] = LocalFileCache(
                    prefix=storage_config.cache_key_prefix, root_path=Path(storage_config.cache_local_path)
                )
            else:
                logger.info("Using InMemory Cache.")
                Cache._instances[instance_key] = InMemoryCache(
                    prefix=storage_config.cache_key_prefix, default_ttl=storage_config.cache_dynamodb_default_ttl
                )
        return Cache._instances[instance_key]

    def __init__(self, prefix: str, hash_keys: bool = True, default_ttl: int = DEFAULT_TTL_SECONDS):
        """Initializes the cache with the given prefix and default TTL.

        Args:
            prefix (str): The prefix for the cache keys; an empty prefix adds no namespace.
            hash_keys (bool): Whether to replace keys by their SHA-256 hex digest.
            default_ttl (int): The default TTL for the cache, in seconds.
        """
        self.prefix = prefix
        self.hash_keys = hash_keys
        self.default_ttl = default_ttl

    def _key(self, key: str) -> str:
        """Return the backend key for ``key``.

        The key is hashed to a SHA-256 hex digest when ``hash_keys`` is set. It is then
        prefixed with ``"<prefix>:"`` when a non-empty prefix is configured.
        """
        if self.hash_keys:
            key = hashlib.sha256(key.encode(encoding="utf-8")).hexdigest()
        prefix = f"{self.prefix}:" if self.prefix else ""
        return f"{prefix}{key}"

    def _compress(self, value: str) -> bytes:
        """Encode ``value`` as UTF-8 and compress it with snappy.

        The debug log reports both sizes in bytes. ``ratio`` is ``None`` for an empty value,
        which avoids a division by zero.
        """
        encoded = value.encode(encoding="utf-8")
        compressed_value = cast(bytes, snappy.compress(encoded))
        logger.debug(
            "Compressed value using snappy compression.",
            original_length=len(encoded),
            compressed_length=len(compressed_value),
            ratio=len(compressed_value) / len(encoded) if encoded else None,
        )
        return compressed_value

    def _decompress(self, value: bytes) -> str:
        """Decompress snappy ``value`` and decode it as UTF-8 (inverse of ``_compress``)."""
        return cast(str, snappy.decompress(value).decode(encoding="utf-8"))

    @abstractmethod
    def get(self, key: str) -> str | None:
        """Gets a value from the cache, or ``None`` when it is absent."""

    @abstractmethod
    def set(self, key: str, value: str) -> None:
        """Sets a value in the cache."""

    @abstractmethod
    def delete(self, key: str) -> None:
        """Deletes a value from the cache."""

    @abstractmethod
    def clear(self) -> None:
        """Clears the cache."""

    @abstractmethod
    def hit(self, key: str) -> bool:
        """Checks if a key exists in the cache."""
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class InMemoryCache(Cache):
    """Process-local cache held in a ``cachetools.TTLCache``.

    Values are stored as plain strings (no compression). Each entry expires
    ``default_ttl`` seconds after it is written. The underlying cache is bounded to
    ``max_size`` keys.
    """

    # Default upper bound on the number of stored keys.
    MAX_KEYS = 1000

    def __init__(
        self, prefix: str, hash_keys: bool = True, default_ttl: int = DEFAULT_TTL_SECONDS, max_size: int = MAX_KEYS
    ):
        super().__init__(prefix, hash_keys, default_ttl)
        self.max_size = max_size
        ttl_store = cachetools.TTLCache[str, str](maxsize=max_size, ttl=default_ttl)
        self.cache = ttl_store

    def get(self, key: str) -> str | None:
        """Return the unexpired value for ``key``, or ``None``."""
        cache_key = self._key(key)
        return self.cache.get(cache_key)

    def set(self, key: str, value: str) -> None:
        """Store ``value`` under ``key``, (re)starting its TTL."""
        cache_key = self._key(key)
        self.cache[cache_key] = value

    def delete(self, key: str) -> None:
        """Drop ``key`` if present; missing keys are ignored."""
        cache_key = self._key(key)
        self.cache.pop(cache_key, None)

    def clear(self) -> None:
        """Drop every entry."""
        self.cache.clear()

    def hit(self, key: str) -> bool:
        """Return whether ``key`` currently has an unexpired entry."""
        cache_key = self._key(key)
        return cache_key in self.cache
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class LocalFileCache(Cache):
    """File-system cache that keeps one snappy-compressed file per key under ``root_path``.

    Expiry is based on each file's modification time: a file older than ``default_ttl``
    seconds is treated as expired. On construction, expired files are removed. If
    ``max_size`` is a positive byte budget, the oldest files are then deleted until the
    files in ``root_path`` fit within it.
    """

    def __init__(
        self,
        prefix: str,
        hash_keys: bool = True,
        default_ttl: int = DEFAULT_TTL_SECONDS,
        root_path: Path | None = None,
        max_size: int | None = DEFAULT_LOCAL_FILE_CACHE_MAX_SIZE,
    ) -> None:
        """Initialize the cache directory and trim it to ``max_size``.

        Args:
            prefix (str): The prefix for the cache keys.
            hash_keys (bool): Whether to hash the keys (recommended, as keys become file names).
            default_ttl (int): Seconds after the last write before a file counts as expired.
            root_path (Path, optional): Cache directory; a fresh temporary directory when omitted.
            max_size (int, optional): Byte budget for the cache files; ``None``, 0 or a negative
                value disables the size limit.

        Raises:
            ValueError: If ``root_path`` exists and is not a directory.
        """
        super().__init__(prefix, hash_keys, default_ttl)
        self.root_path = root_path or Path(tempfile.mkdtemp())
        if self.root_path.exists() and not self.root_path.is_dir():
            raise ValueError(f"Given local cache path ({self.root_path.as_posix()}) is not a directory.")
        self.max_size = max_size
        self._assert_root_path()
        logger.info(f"Initialized LocalFileCache in {self.root_path.as_posix()}.", root_path=self.root_path.as_posix())
        if not hash_keys:
            # Unhashed keys are used verbatim as file names and may contain path separators.
            logger.warning("When using filesystem cache, it is recommended to hash keys.")
        self.shrink_to_fit_max_size()

    def _assert_root_path(self) -> None:
        """Create ``root_path`` (and parents) if it does not exist yet."""
        self.root_path.mkdir(parents=True, exist_ok=True)

    def cleanup(self) -> None:
        """Clean up all expired files from the cache."""
        self._assert_root_path()
        for key_path in self.root_path.iterdir():
            if key_path.is_file() and self._is_expired(key_path):
                logger.debug("Cleaning up expired file.", key_path=key_path)
                key_path.unlink()

    def get_total_size(self) -> int:
        """Return the summed size, in bytes, of the regular files directly under ``root_path``."""
        self._assert_root_path()
        return sum(key_path.stat().st_size for key_path in self.root_path.iterdir() if key_path.is_file())

    def shrink_to_fit_max_size(self) -> None:
        """Remove expired files, then evict the oldest files until the cache fits ``max_size``.

        Files are evicted in ascending modification-time order. A falsy or negative
        ``max_size`` disables the size limit, so only expired files are removed.
        """
        self.cleanup()
        if not self.max_size or self.max_size < 0:
            return
        total_size = self.get_total_size()
        if total_size <= self.max_size:
            return
        files_oldest_first = sorted(
            (path for path in self.root_path.iterdir() if path.is_file()),
            key=lambda path: path.stat().st_mtime,
        )
        # Iterating (instead of list.pop(0)) is O(n) and cannot run past the end of the list.
        for key_path in files_oldest_first:
            if total_size <= self.max_size:
                break
            size = key_path.stat().st_size
            logger.debug("Cleaning up old file to make some space.", key_path=key_path, size=size)
            key_path.unlink()
            total_size -= size

    def _key_path(self, key: str) -> Path:
        """Return the file path that stores ``key``."""
        return self.root_path / self._key(key)

    def _is_expired(self, key: str | Path) -> bool:
        """Return True if the file for ``key`` was last modified more than ``default_ttl`` seconds ago.

        Args:
            key (str | Path): A cache key (mapped through ``_key_path``) or a file path.

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        self._assert_root_path()
        if isinstance(key, str):
            key = self._key_path(key)
        current_time = time.time()
        return (current_time - key.stat().st_mtime) > self.default_ttl

    def get(self, key: str) -> str | None:
        """Return the cached value for ``key``; expired files are deleted and yield ``None``."""
        self._assert_root_path()
        key_path = self._key_path(key)
        if not key_path.exists():
            return None
        if self._is_expired(key_path):
            logger.debug("Key is expired, deleting file.", key=key, key_path=key_path)
            key_path.unlink()
            return None
        logger.debug("Reading key from file.", key=key, path=key_path.as_posix())
        return self._decompress(key_path.read_bytes())

    def set(self, key: str, value: str) -> None:
        """Write ``value`` (snappy-compressed) to the key's file, refreshing its mtime and thus its TTL."""
        self._assert_root_path()
        key_path = self._key_path(key)
        logger.debug("Storing key into a file.", key=key, path=key_path.as_posix())
        key_path.write_bytes(self._compress(value))

    def delete(self, key: str) -> None:
        """Remove the key's file if it exists."""
        self._assert_root_path()
        key_path = self._key_path(key)
        if not key_path.exists():
            return None
        logger.debug("Removing cached key.", key=key, path=key_path.as_posix())
        key_path.unlink()

    def hit(self, key: str) -> bool:
        """Return True if ``key`` has a cached file that has not expired.

        Expired files count as misses, matching ``get``, which returns ``None`` for them.
        They are left on disk here; ``get`` and ``cleanup`` remove them.
        """
        try:
            return not self._is_expired(self._key_path(key))
        except FileNotFoundError:
            return False

    def clear(self) -> None:
        """Delete every regular file directly under ``root_path``."""
        self._assert_root_path()
        for file in self.root_path.iterdir():
            if not file.is_file():
                continue
            logger.debug("Removing cached key.", path=file.as_posix())
            file.unlink()
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
class PersistentCache(Cache):
    """Uses a persistent store asset to store the cached data.

    Persistent store asset is a key-value store that is stored in the catalog.
    Please note that for now, there is no locking mechanism in place and two services using the same cache
    may overwrite each other's data.

    The cache is meant to be used as a context manager. Entering it loads every row of the
    asset into memory. Exiting it writes the whole in-memory mapping back as a new version
    of the asset. Any read or write made before the context is first entered raises
    ``RuntimeError``.

    Note that `default_ttl` and all ttl-related arguments are ignored.
    """

    # TODO: Implement a locking or ownership mechanism
    # https://github.com/Zephyr-Trade/FVE-map/issues/11
    def __init__(
        self,
        prefix: str,
        hash_keys: bool = True,
        default_ttl: int = DEFAULT_TTL_SECONDS,
        asset: "PersistentStoreAsset | None" = None,
        catalog: Catalog | None = None,
    ):
        """Initializes the persistent cache with the given prefix, asset, and catalog.

        Args:
            prefix (str): The prefix for the cache keys.
            hash_keys (bool): Whether to hash the keys.
            default_ttl (int): The default TTL for the cache (ignored by this backend).
            asset (PersistentStoreAsset, optional): The asset to use for the cache. Must be provided.
            catalog (Catalog, optional): The catalog to use for the cache. Defaults to None
                and creates a new one from the config.

        Raises:
            ValueError: If no asset is provided.
        """
        super().__init__(prefix, hash_keys, default_ttl)
        if asset is None:
            raise ValueError("No asset provided for persistent cache.")
        self.asset = asset
        self.catalog = catalog or Catalog.from_config()
        # None until the store has been loaded by __enter__.
        self._data: dict[str, str] | None = None

    def __enter__(self) -> "PersistentCache":
        """Retrieve the store to work with the cache.

        Rows are read as ``{"key": ..., "value": ...}`` mappings. A missing asset starts an
        empty cache. The cache itself is returned, so ``with PersistentCache(...) as cache:``
        binds a usable object.
        """
        try:
            data = self.catalog.retrieve_asset(self.asset)
            self._data = {row["key"]: row["value"] for row in data}
            logger.info(
                f"Retrieved {len(self._data)} items from persistent cache.", asset=self.asset, count=len(self._data)
            )
        except AssetNotFoundError:
            logger.info("Persistent cache not found, creating new cache.", asset=self.asset)
            self._data = {}
        return self

    def __exit__(self, *args: Any) -> None:
        """Upload the contents of the cache as the new version of the store.

        Every row is stamped with the same ``utcnow()`` timestamp. The upload also happens
        when the ``with`` body raised; the exception is not suppressed.
        """
        if self._data is None:
            logger.warning("Persistent cache has not yet been retrieved, no data written.", asset=self.asset)
            return None
        timestamp = utcnow()
        data = ({"key": key, "value": value, "timestamp": timestamp} for key, value in self._data.items())
        self.catalog.store_asset(self.asset, data)

    @property
    def data(self) -> dict[str, str]:
        """The loaded key/value mapping.

        Raises:
            RuntimeError: If the store has not been loaded via ``__enter__``.
        """
        if self._data is None:
            raise RuntimeError("Persistent cache must be used as a context manager.")
        return self._data

    def get(self, key: str) -> str | None:
        """Return the value stored for ``key``, or ``None``."""
        return self.data.get(self._key(key))

    def set(self, key: str, value: str) -> None:
        """Store ``value`` in memory; it is persisted on context exit."""
        self.data[self._key(key)] = value

    def delete(self, key: str) -> None:
        """Remove ``key`` from memory if present."""
        self.data.pop(self._key(key), None)

    def clear(self) -> None:
        """Remove all keys from memory; the emptied store is persisted on context exit."""
        self.data.clear()

    def hit(self, key: str) -> bool:
        """Return whether ``key`` is present in the loaded store."""
        return self._key(key) in self.data
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
class DynamoDBCache(Cache):
    """Cache backed by a DynamoDB table.

    Items are written as ``{"key": <cache key>, "value": <snappy bytes>, "time_to_live": <epoch seconds>}``,
    and the table is addressed by its ``key`` attribute. Presumably DynamoDB TTL is enabled on
    ``time_to_live`` (TODO confirm in infrastructure). DynamoDB removes expired items lazily,
    so reads also compare the TTL attribute with the current time.
    """

    TTL_ATTRIBUTE_NAME = "time_to_live"

    def __init__(
        self, prefix: str, hash_keys: bool = True, default_ttl: int = DEFAULT_TTL_SECONDS, table_name: str | None = None
    ):
        """Create the boto3 DynamoDB resource and bind the table.

        Args:
            prefix (str): The prefix for the cache keys.
            hash_keys (bool): Whether to hash the keys.
            default_ttl (int): Seconds until a written item expires.
            table_name (str, optional): Name of the DynamoDB table, passed straight to ``Table``.
        """
        super().__init__(prefix, hash_keys, default_ttl)
        self.client = boto3.resource("dynamodb")
        self.table_name = table_name
        self.table = self.client.Table(table_name)

    def _is_live(self, item: dict[str, Any]) -> bool:
        """Return True if ``item``'s TTL attribute lies in the future (a missing TTL counts as expired)."""
        return int(item.get(self.TTL_ATTRIBUTE_NAME, 0)) > int(time.time())

    def get(self, key: str) -> str | None:
        """Return the value for ``key``; items past their TTL that DynamoDB has not purged yet yield ``None``."""
        cache_key = self._key(key)
        item = self.table.get_item(Key={"key": cache_key}).get("Item")
        if not item or not self._is_live(item):
            return None
        return self._decompress(bytes(item["value"]))

    def set(self, key: str, value: str) -> None:
        """Store ``value`` under ``key`` with an expiry ``default_ttl`` seconds from now.

        Raises:
            CachingError: If the compressed value exceeds ``DYNAMODB_MAX_LENGTH`` bytes.
        """
        expiration_time = int(time.time()) + self.default_ttl
        cache_key = self._key(key)
        compressed_value = self._compress(value)
        if len(compressed_value) > DYNAMODB_MAX_LENGTH:
            logger.warning(
                "Value is too long to store in DynamoDB.",
                original_key=key,
                cache_key=cache_key,
                length=len(compressed_value),
            )
            raise CachingError(f"Value is too long to store in DynamoDB: {len(compressed_value)}")
        self.table.put_item(
            Item={"key": cache_key, "value": compressed_value, self.TTL_ATTRIBUTE_NAME: expiration_time}
        )

    def delete(self, key: str) -> None:
        """Delete the item for ``key`` (a no-op in DynamoDB if it does not exist)."""
        cache_key = self._key(key)
        self.table.delete_item(Key={"key": cache_key})

    def clear(self) -> None:
        """
        Deletes all items in the table.

        Warning: DynamoDB doesn't have a built-in clear mechanism, so we scan and delete all items manually.
        Scan results are paginated (``LastEvaluatedKey``), so every page is processed.
        """
        logger.info("Clearing cache.", cache_table=self.table_name)
        # "key" is a DynamoDB reserved word, so the projection goes through an attribute-name placeholder.
        scan_kwargs: dict[str, Any] = {"ProjectionExpression": "#k", "ExpressionAttributeNames": {"#k": "key"}}
        with self.table.batch_writer() as batch:
            while True:
                page = self.table.scan(**scan_kwargs)
                for item in page.get("Items", []):
                    batch.delete_item(Key={"key": item["key"]})
                last_evaluated_key = page.get("LastEvaluatedKey")
                if not last_evaluated_key:
                    break
                scan_kwargs["ExclusiveStartKey"] = last_evaluated_key

    def hit(self, key: str) -> bool:
        """Return whether ``key`` has an item whose TTL has not passed."""
        cache_key = self._key(key)
        item = self.table.get_item(Key={"key": cache_key}, ProjectionExpression=self.TTL_ATTRIBUTE_NAME).get("Item")
        return bool(item) and self._is_live(item)
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
"""Google Maps API client."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from contextlib import ExitStack
|
|
5
|
+
from dataclasses import asdict, replace
|
|
6
|
+
from typing import Any, ClassVar
|
|
7
|
+
|
|
8
|
+
import googlemaps
|
|
9
|
+
|
|
10
|
+
from hyperion.config import geo_config
|
|
11
|
+
from hyperion.entities.catalog import PersistentStoreAsset
|
|
12
|
+
from hyperion.infrastructure.cache import PersistentCache
|
|
13
|
+
from hyperion.infrastructure.geo.location import Location, NamedLocation
|
|
14
|
+
from hyperion.logging import get_logger
|
|
15
|
+
|
|
16
|
+
# Logger for the Google Maps client.
logger = get_logger("gmaps")

# Catalog asset holding persisted geocoding results (used by GoogleMaps.geocode_cache).
cache_asset = PersistentStoreAsset("GEOCodeCache", schema_version=1)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _find_info_by_type(components: list[dict[str, Any]], info_type: str) -> str | None:
|
|
22
|
+
for component in components:
|
|
23
|
+
if not isinstance((component_types := component.get("types")), list):
|
|
24
|
+
raise TypeError(f"Unexpected component type info, expected 'list', got {type(component_types)!r}.")
|
|
25
|
+
for component_type in component_types:
|
|
26
|
+
if component_type == info_type:
|
|
27
|
+
value = component.get("long_name", component.get("short_name"))
|
|
28
|
+
if value is None or isinstance(value, str):
|
|
29
|
+
return value
|
|
30
|
+
raise TypeError(f"Unexpected value type, expected 'str' or None, got {type(value)!r}.")
|
|
31
|
+
return None
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class GoogleMaps:
    """Google Maps API client.

    Geocoding results are memoized in a ``PersistentCache`` backed by the ``GEOCodeCache``
    catalog asset. Use the client as a context manager (``with client:``) so the cache is
    loaded on entry and written back on exit.
    """

    _instance: ClassVar["GoogleMaps | None"] = None

    @classmethod
    def from_config(cls) -> "GoogleMaps":
        """Get the shared Google Maps API client instance from the configuration.

        Raises:
            ValueError: If ``geo_config.gmaps_api_key`` is not set.
        """
        if cls._instance is None:
            if geo_config.gmaps_api_key is None:
                raise ValueError("Google Maps API key is not set.")
            cls._instance = GoogleMaps(api_key=geo_config.gmaps_api_key)
        return cls._instance

    def __init__(self, api_key: str) -> None:
        """Initialize the Google Maps API client.

        Args:
            api_key (str): The Google Maps API key.
        """
        # hash_keys=False: raw addresses are used as cache keys in the persisted asset.
        self.geocode_cache = PersistentCache("gmaps", hash_keys=False, asset=cache_asset)
        self.client = googlemaps.Client(key=api_key)
        self._cache_context: ExitStack | None = None

    def __enter__(self) -> "GoogleMaps":
        """Load the geocode cache and return this client."""
        self._cache_context = ExitStack()
        self._cache_context.enter_context(self.geocode_cache)
        return self

    def __exit__(self, *args: Any) -> None:
        """Close the geocode cache context, which writes the cached results back to the catalog."""
        if self._cache_context is None:
            return
        cache_context, self._cache_context = self._cache_context, None
        cache_context.close()

    def geocode(self, address: str) -> Location:
        """Geocode an address.

        Results are served from, and stored into, the geocode cache.

        Args:
            address (str): The address to geocode.

        Returns:
            Location: The geocoded location.

        Raises:
            ValueError: If the API returns no result for ``address``.
            RuntimeError: If the client's context has never been entered, so the cache is not loaded.
        """
        if (cached_location := self.geocode_cache.get(address)) is not None:
            logger.debug("Using geocoded information from cache.", address=address, location=cached_location)
            return Location(**json.loads(cached_location))
        result = self.client.geocode(address)
        if not result:
            raise ValueError(f"Could not geocode address: {address!r}.")
        location = Location(
            latitude=result[0]["geometry"]["location"]["lat"],
            longitude=result[0]["geometry"]["location"]["lng"],
            title=address,
            address=result[0]["formatted_address"],
        )
        logger.debug("Found geocoded information on address.", address=address, location=location)
        self.geocode_cache.set(address, json.dumps(asdict(location)))
        return location

    def reverse_geocode(self, location: Location, language: str | None = None) -> NamedLocation:
        """Reverse geocode a location into an address.

        Only the first API result is used. Its route becomes the title, falling back to the
        location's existing title.

        Args:
            location (Location): The location coordinates.
            language (str, optional): The language in which to return results. Defaults
                to None.

        Returns:
            NamedLocation: The location enriched with address and administrative-area names.

        Raises:
            ValueError: If the API returns no results.
            TypeError: If the API response does not have the expected shape.
        """
        results = self.client.reverse_geocode({"lat": location.latitude, "lng": location.longitude}, language=language)
        if not results or not isinstance(results, list):
            raise ValueError(f"Could not reverse-geocode provided location - no results. {location!r}")
        result = results[0]
        if not isinstance(result, dict):
            raise TypeError(f"Unexpected result type, expected 'dict', got {type(result)!r}.")
        if not isinstance(address_components := result.get("address_components"), list):
            raise TypeError(f"Unexpected address components type, expected 'list', got {type(address_components)!r}.")
        title = _find_info_by_type(address_components, "route")
        return NamedLocation(
            location=replace(
                location, address=result.get("formatted_address") or location.address, title=title or location.title
            ),
            route=title,
            neighborhood=_find_info_by_type(address_components, "neighborhood"),
            sublocality=_find_info_by_type(address_components, "sublocality")
            or _find_info_by_type(address_components, "sublocality_level_1"),
            administrative_area=_find_info_by_type(address_components, "administrative_area_level_1"),
            administrative_area_level_2=_find_info_by_type(address_components, "administrative_area_level_2"),
            country=_find_info_by_type(address_components, "country"),
            address=result.get("formatted_address"),
        )
|