hyperion-sdk 0.2.0.dev1741815359__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,63 @@
1
+ import abc
2
+ import os
3
+ import re
4
+ from typing import cast
5
+
6
+ import boto3
7
+
8
+ from hyperion.config import secrets_config
9
+ from hyperion.logging import get_logger
10
+
11
+ logger = get_logger("hyperion-secrets")
12
+
13
+ SECRET_PATTERN = re.compile(r"!#secret:#(?P<secret_name>.+)")
14
+
15
+
16
class SecretsManager(abc.ABC):
    """Abstract secrets backend with one cached, process-wide instance.

    Concrete backends implement ``get_secret``. Callers obtain the shared
    instance through ``from_config``.
    """

    # Cached backend instance; filled in by the first ``from_config`` call.
    _instance: "SecretsManager | None" = None

    @staticmethod
    def _create_new() -> "SecretsManager":
        """Build the backend selected by ``secrets_config.backend``.

        Returns:
            SecretsManager: ``AWSSecretsManager`` for the ``"AWSSecretsManager"``
            backend, ``DummySecretsManager`` when no backend is configured.

        Raises:
            ValueError: If the configured backend is not supported.
        """
        backend = secrets_config.backend
        if backend == "AWSSecretsManager":
            logger.info("Using AWS Secrets Manager.")
            return AWSSecretsManager()
        if backend is not None:
            raise ValueError(f"Unsupported secrets backend: {backend!r}.")
        logger.warning("No secrets backend is configured. Using dummy secrets manager.")
        return DummySecretsManager()

    @staticmethod
    def from_config() -> "SecretsManager":
        """Return the cached backend, creating it on first use."""
        cached = SecretsManager._instance
        if cached is None:
            cached = SecretsManager._create_new()
            SecretsManager._instance = cached
        return cached

    @staticmethod
    def translate_env_vars() -> None:
        """Replace secret references in environment variables with their values.

        Only variables whose value matches `!#secret:#secret_name` are touched;
        every other variable is left as it is.
        """
        for name, raw_value in os.environ.items():
            found = SECRET_PATTERN.match(raw_value)
            if found is None:
                continue
            secret_name = found.group("secret_name")
            logger.info("Replacing secret in environment variable.", key=name, secret_name=secret_name)
            # The backend is resolved only once a reference is actually found.
            manager = SecretsManager.from_config()
            os.environ[name] = manager.get_secret(secret_name)

    @abc.abstractmethod
    def get_secret(self, secret_name: str) -> str:
        """Return the value of the secret called ``secret_name``."""
50
+
51
+
52
class DummySecretsManager(SecretsManager):
    """Backend that resolves every secret to an empty string."""

    def get_secret(self, secret_name: str) -> str:
        """Log a warning and return ``""`` whatever ``secret_name`` is."""
        empty_value = ""
        logger.warning("Using dummy secrets manager, no values will be returned.", secret_name=secret_name)
        return empty_value
56
+
57
+
58
class AWSSecretsManager(SecretsManager):
    """Backend that reads secrets through a boto3 ``secretsmanager`` client."""

    def __init__(self) -> None:
        """Create the boto3 ``secretsmanager`` client with no explicit arguments."""
        self.client = boto3.client("secretsmanager")

    def get_secret(self, secret_name: str) -> str:
        """Fetch ``secret_name`` and return the ``SecretString`` field of the response."""
        response = self.client.get_secret_value(SecretId=secret_name)
        secret_string = response["SecretString"]
        return cast(str, secret_string)
hyperion/logging.py ADDED
@@ -0,0 +1,122 @@
1
+ import logging as _logging
2
+ import sys
3
+ import time
4
+ import traceback
5
+ from collections.abc import Callable
6
+ from types import TracebackType
7
+ from typing import Any, ParamSpec, TypeVar
8
+
9
+ import loguru
10
+
11
+ from hyperion.config import config
12
+
13
+ P = ParamSpec("P")
14
+ T = TypeVar("T")
15
+
16
+ LOGURU_PRETTY_FORMAT = (
17
+ "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
18
+ "<level>{level: <8}</level> | "
19
+ "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
20
+ "<level>{message}</level>\n<dim>{extra}</dim>"
21
+ )
22
+
23
+
24
def setup_logger_from_env() -> None:
    """Reset loguru and attach one stderr sink configured from ``config``.

    When ``config.log_pretty`` is set, the sink is colorized and uses
    ``LOGURU_PRETTY_FORMAT``. Otherwise it is added with ``serialize=True``
    and a bare ``{message}`` format. Both variants use ``config.log_level``.
    """
    loguru.logger.remove()
    if not config.log_pretty:
        loguru.logger.add(sys.stderr, format="{message}", serialize=True, level=config.log_level)
        return
    loguru.logger.add(sys.stderr, colorize=True, format=LOGURU_PRETTY_FORMAT, level=config.log_level)
30
+
31
+
32
class InterceptHandler(_logging.Handler):
    """Standard-library logging handler that forwards records to loguru."""

    def emit(self, record: _logging.LogRecord) -> None:
        """Re-emit ``record`` through loguru, attributed to the original caller.

        Args:
            record: The standard-library log record to forward.
        """
        # Local import: ``inspect`` is not among this module's imports.
        import inspect

        # Prefer the loguru level with the same name; fall back to the
        # numeric level when loguru does not know that name.
        log_level: str | int
        try:
            log_level = loguru.logger.level(record.levelname).name
        except ValueError:
            log_level = record.levelno

        # Walk outwards from this very frame until the first frame that is
        # neither inside the ``logging`` package nor an import bootstrap
        # frame. Starting from ``emit`` itself (depth 0) keeps the computed
        # depth correct regardless of how many frames
        # ``logging.currentframe()`` skips, which changed in Python 3.11.
        frame, depth = inspect.currentframe(), 0
        while frame is not None:
            filename = frame.f_code.co_filename
            is_logging = filename == _logging.__file__
            is_frozen = "importlib" in filename and "_bootstrap" in filename
            if depth > 0 and not (is_logging or is_frozen):
                break
            frame = frame.f_back
            depth += 1

        loguru.logger.opt(depth=depth, exception=record.exc_info).log(log_level, record.getMessage())
47
+
48
+
49
def intercept_python_loggers() -> None:
    """Route standard-library logging through ``InterceptHandler``.

    Drops every handler currently on the root logger, sets the root level to
    ``DEBUG`` and installs a single ``InterceptHandler``.
    """
    root_logger = _logging.root
    root_logger.handlers = []
    root_logger.setLevel(_logging.DEBUG)
    root_logger.addHandler(InterceptHandler())
53
+
54
+
55
class LogTiming:
    """Log how long an action takes; usable as a context manager or decorator.

    On success an ``info`` record is logged, on failure a ``warning`` record
    with the exception type, message and traceback. Exceptions are never
    suppressed.

    Note:
        The start time is stored on the instance, so re-entering the same
        instance before it exits (e.g. decorating a recursive function)
        overwrites the outer start time.

    Args:
        action: Human-readable name of the timed action. When used as a
            decorator without a name, the decorated function's ``__name__``
            is used.
        logger: Logger to write to. Defaults to ``get_logger("timing")``.
    """

    def __init__(self, action: str | None = None, logger: "loguru.Logger | None" = None) -> None:
        self.action = action
        self.start_time: float = 0
        self.logger = logger if logger is not None else get_logger("timing")

    def __enter__(self) -> "LogTiming":
        """Start the timer and return this instance for ``with ... as`` use."""
        self.start_time = time.monotonic()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc: BaseException | None = None,
        tb: TracebackType | None = None,
    ) -> None:
        """Log the elapsed time; returns ``None`` so exceptions propagate."""
        duration = time.monotonic() - self.start_time
        action = self.action or "unnamed action"
        if exc is not None:
            self.logger.warning(
                f"Action {action!r} failed after {duration:0.3f} seconds.",
                action=self.action,
                duration=duration,
                exception_type=exc_type.__name__ if exc_type else "unknown",
                exception=str(exc),
                # format_tb() entries already end with a newline.
                traceback="".join(traceback.format_tb(tb)),
            )
            return
        self.logger.info(
            f"Action {action!r} finished after {duration:0.3f} seconds.", action=self.action, duration=duration
        )

    def __call__(self, func: Callable[P, T]) -> Callable[P, T]:
        """Wrap ``func`` so that every call is timed and logged.

        Args:
            func: The callable to decorate.

        Returns:
            A wrapper with the same signature and metadata as ``func``.
        """
        # Local import: ``functools`` is not among this module's imports.
        import functools

        if self.action is None:
            self.action = func.__name__

        @functools.wraps(func)
        def _wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
            with self:
                return func(*args, **kwargs)

        return _wrapped
95
+
96
+
97
+ setup_logger_from_env()
98
+ intercept_python_loggers()
99
+
100
+
101
def get_logger(service: str, **context: Any) -> "loguru.Logger":
    """Return a loguru logger bound to ``service`` plus any extra context values."""
    bindings = {"service": service, **context}
    return loguru.logger.bind(**bindings)
104
+
105
+
106
if __name__ == "__main__":
    # Manual demo: one record per level, then a successful and a failing
    # timed action.
    demo_logger = get_logger("demo")
    demo_logger.trace("Trace message")
    demo_logger.debug("Debug message")
    demo_logger.info("Informational message")
    demo_logger.warning("A warning!")
    demo_logger.error("An error...")
    demo_logger.critical("This is terrible.")

    pause_seconds = 1
    with LogTiming("This will take around 1 second.", demo_logger):
        time.sleep(pause_seconds)

    try:
        with LogTiming("This will fail after around 1 second.", demo_logger):
            time.sleep(pause_seconds)
            raise ValueError("This failed.")
    except Exception as failure:
        # LogTiming does not swallow the error, so it is reported here.
        demo_logger.exception(failure)
hyperion/py.typed ADDED
File without changes
File without changes
@@ -0,0 +1,105 @@
1
+ """Base abstract class for sources."""
2
+
3
+ import abc
4
+ import asyncio
5
+ import datetime
6
+ from collections.abc import AsyncIterator, Awaitable, Iterable
7
+ from dataclasses import dataclass
8
+ from typing import Any, ClassVar, cast
9
+
10
+ from aws_lambda_typing.context import Context
11
+ from aws_lambda_typing.events import EventBridgeEvent, SQSEvent
12
+
13
+ from hyperion.asyncutils import AsyncTaskQueue, get_loop
14
+ from hyperion.catalog import Catalog
15
+ from hyperion.config import storage_config
16
+ from hyperion.entities.catalog import DataLakeAsset
17
+ from hyperion.infrastructure.queue import SourceBackfillMessage, SQSQueue, iter_messages_from_sqs_event
18
+ from hyperion.logging import get_logger
19
+
20
+ SourceEventType = EventBridgeEvent | SQSEvent
21
+
22
+ logger = get_logger("hyperion-source")
23
+
24
+
25
@dataclass(frozen=True, eq=True)
class SourceAsset:
    """Immutable pairing of a ``DataLakeAsset`` with an iterable of row dicts."""

    # The asset the rows are stored under.
    asset: DataLakeAsset
    # The rows themselves, one ``dict`` per row.
    data: Iterable[dict[str, Any]]
29
+
30
+
31
class Source(abc.ABC):
    """Abstract base for sources whose assets are stored through a catalog.

    Subclasses set the ``source`` class attribute and implement ``run``.
    """

    # Name of the source; subclasses must replace the ``NotImplemented`` placeholder.
    source: ClassVar[str] = NotImplemented

    def __init__(self, catalog: Catalog) -> None:
        """Keep a reference to ``catalog``.

        Raises:
            NotImplementedError: If the subclass did not define ``source``.
        """
        name_missing = self.source is NotImplemented
        if name_missing:
            raise NotImplementedError("Cannot instantiate a source without a source name.")
        self.catalog = catalog

    @abc.abstractmethod
    def run(
        self,
        start_date: datetime.datetime | datetime.date | None = None,
        end_date: datetime.datetime | datetime.date | None = None,
    ) -> Awaitable[Iterable[SourceAsset]] | AsyncIterator[SourceAsset]:
        """Run the extraction and produce the source assets.

        Implementations either return an awaitable resolving to an iterable
        of assets or an async iterator yielding them one by one.
        """

    @classmethod
    async def _run(
        cls,
        catalog: Catalog,
        notify: bool = True,
        start_date: datetime.datetime | datetime.date | None = None,
        end_date: datetime.datetime | datetime.date | None = None,
    ) -> None:
        """Instantiate the source, run it and queue every asset for storage.

        Args:
            catalog: Catalog the assets are stored through.
            notify: Passed on to ``store_asset_async`` for every asset.
            start_date: Passed on to ``run``.
            end_date: Passed on to ``run``.
        """
        instance = cls(catalog)
        produced = instance.run(start_date=start_date, end_date=end_date)
        async with AsyncTaskQueue[None](maxsize=storage_config.max_concurrency) as tasks:
            if not isinstance(produced, AsyncIterator):
                # Awaitable flavour: resolve it first, then walk the iterable.
                for item in await produced:
                    logger.info("Processing asset retrieved from source.", asset=item.asset)
                    await tasks.add_task(instance.catalog.store_asset_async(item.asset, item.data, notify=notify))
            else:
                async for item in produced:
                    logger.info("Processing asset retrieved from source.", asset=item.asset)
                    await tasks.add_task(instance.catalog.store_asset_async(item.asset, item.data, notify=notify))

    @classmethod
    def handle_aws_lambda_event(
        cls,
        event: SourceEventType | None = None,
        context: Context | None = None,
        *,
        loop: asyncio.AbstractEventLoop | None = None,
    ) -> None:
        """Entry point that runs the source in response to an AWS Lambda event.

        A dict event with a ``"Records"`` key is handled as an SQS event: each
        ``SourceBackfillMessage`` addressed to this source triggers one run
        with the message's dates and ``notify`` flag, after which the message
        is deleted from the queue if it carries a receipt handle. Any other
        event, or none at all, triggers a single run without dates.

        Args:
            event: The triggering event, if any.
            context: The Lambda context; only logged.
            loop: Event loop to run on. Defaults to ``get_loop()``.
        """
        logger.info("Starting Hyperion source.", source=cls.__name__, event=str(event), context=str(context))
        catalog = Catalog.from_config()
        if not loop:
            loop = get_loop()
        queue = SQSQueue.from_config()

        if isinstance(event, dict) and "Records" in event:
            # We may presume this is an SQS Event
            for message in iter_messages_from_sqs_event(cast(SQSEvent, event)):
                if not isinstance(message, SourceBackfillMessage):
                    logger.warning(
                        "Only SourceBackfillMessage is supported, this message is probably not for us.", message=message
                    )
                    continue
                if message.source != cls.source:
                    logger.info("Message is not intended for us.", source=cls.source, message_source=message.source)
                    continue
                logger.info("Source triggered by an SQS Message.", source=cls.source, message=message)
                backfill = cls._run(
                    catalog, start_date=message.start_date, end_date=message.end_date, notify=message.notify
                )
                loop.run_until_complete(backfill)
                if message.receipt_handle:
                    queue.delete(message.receipt_handle)
            return

        # Both remaining cases run without configuration; only the log differs.
        if isinstance(event, dict) and "detail" in event:
            # We may presume this is an EventBridgeEvent
            logger.warning("EventBridge events can carry no config for now.")
        else:
            logger.warning("No event was provided, assuming a no-config run.")
        loop.run_until_complete(cls._run(catalog, start_date=None, end_date=None))
hyperion/typeutils.py ADDED
@@ -0,0 +1,52 @@
1
+ import datetime
2
+ from collections.abc import Sequence
3
+ from dataclasses import asdict
4
+ from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar
5
+
6
+ if TYPE_CHECKING:
7
+ from _typeshed import DataclassInstance
8
+
9
+ T = TypeVar("T")
10
+
11
+ DateOrDelta: TypeAlias = datetime.datetime | datetime.timedelta
12
+
13
+
14
def assert_type(variable: Any, assertion: type[T]) -> T:
    """Return ``variable`` unchanged if it is an instance of ``assertion``.

    Args:
        variable (Any): Value to check.
        assertion (type[T]): Type the value is expected to be an instance of.

    Returns:
        T: The same ``variable`` object.

    Raises:
        TypeError: If ``variable`` is not an instance of ``assertion``.
    """
    if isinstance(variable, assertion):
        return variable
    raise TypeError(f"Expected {assertion}, got {type(variable)}.")
18
+
19
+
20
+ def dataclass_asdict(
21
+ dataclass: "DataclassInstance",
22
+ *,
23
+ exclude: Sequence[str] | None = None,
24
+ include: Sequence[str] | None = None,
25
+ ) -> dict[str, Any]:
26
+ """Convert a dataclass instance to a dictionary.
27
+
28
+ Args:
29
+ dataclass (DataclassInstance): Dataclass instance to convert.
30
+ exclude (Sequence[str], optional): Fields to exclude. Defaults to None.
31
+ include (Sequence[str], optional): Fields to include. Defaults to None.
32
+
33
+ Returns:
34
+ dict[str, Any]: Converted dictionary.
35
+
36
+ Raises:
37
+ ValueError: If include and exclude overlap.
38
+ ValueError: If include field is not found in dataclass.
39
+ """
40
+ exclude = exclude or []
41
+ include = include or []
42
+ dct = asdict(dataclass)
43
+ fields = list(dct.keys())
44
+ if set(include) & set(exclude):
45
+ raise ValueError("Include and exclude cannot overlap.")
46
+ if set(include) - set(fields):
47
+ raise ValueError("Include field not found in dataclass.")
48
+ if exclude or include:
49
+ for field in fields:
50
+ if (include and field not in include) or field in exclude:
51
+ del dct[field]
52
+ return dct