hyperion-sdk 0.2.0.dev1741815359__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyperion/__init__.py +0 -0
- hyperion/asyncutils.py +79 -0
- hyperion/catalog/__init__.py +8 -0
- hyperion/catalog/catalog.py +623 -0
- hyperion/catalog/schema.py +153 -0
- hyperion/collections/__init__.py +0 -0
- hyperion/collections/asset_collection.py +285 -0
- hyperion/config.py +77 -0
- hyperion/dateutils.py +238 -0
- hyperion/entities/__init__.py +0 -0
- hyperion/entities/catalog.py +190 -0
- hyperion/infrastructure/__init__.py +0 -0
- hyperion/infrastructure/aws.py +220 -0
- hyperion/infrastructure/cache.py +396 -0
- hyperion/infrastructure/geo/__init__.py +7 -0
- hyperion/infrastructure/geo/gmaps.py +124 -0
- hyperion/infrastructure/geo/location.py +186 -0
- hyperion/infrastructure/http.py +62 -0
- hyperion/infrastructure/keyval.py +264 -0
- hyperion/infrastructure/queue.py +151 -0
- hyperion/infrastructure/secrets.py +63 -0
- hyperion/logging.py +122 -0
- hyperion/py.typed +0 -0
- hyperion/sources/__init__.py +0 -0
- hyperion/sources/base.py +105 -0
- hyperion/typeutils.py +52 -0
- hyperion_sdk-0.2.0.dev1741815359.dist-info/METADATA +476 -0
- hyperion_sdk-0.2.0.dev1741815359.dist-info/RECORD +29 -0
- hyperion_sdk-0.2.0.dev1741815359.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
"""Schema store."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
import json
|
|
5
|
+
from functools import lru_cache
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, ClassVar, cast
|
|
8
|
+
from urllib.parse import urlparse
|
|
9
|
+
|
|
10
|
+
from hyperion.config import storage_config
|
|
11
|
+
from hyperion.entities.catalog import AssetProtocol, AssetType
|
|
12
|
+
from hyperion.infrastructure.aws import S3Client
|
|
13
|
+
from hyperion.logging import get_logger
|
|
14
|
+
|
|
15
|
+
# Logger shared by every schema store in this module.
logger = get_logger("schema-store")

# Default root for local Avro schema JSON files: an `avro_schemas` directory beside this module.
# NOTE(review): this wheel's file list does not include an avro_schemas directory, so this default
# may not exist in an installed package — confirm before relying on LocalSchemaStore() defaults.
AVRO_SCHEMAS_PATH = Path(__file__).parent.joinpath("avro_schemas")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SchemaStore(abc.ABC):
    """Abstract base class for schema stores.

    A schema store maps an asset (name, schema version, asset type) to its Avro schema,
    returned as a JSON object (``dict``). Obtain stores through :meth:`from_path` or
    :meth:`from_config`, which cache one store per path string.
    """

    # Stores created by `from_path`, keyed by the exact path string they were requested with.
    _instances: ClassVar[dict[str, "SchemaStore"]] = {}

    def __init__(self, path: str) -> None:
        """Initialize the schema store with the given path.

        Args:
            path (str): The path to the schema store.
        """
        self.path = path

    def get_asset_schema(self, asset: AssetProtocol) -> dict[str, Any]:
        """Get the schema for the given asset.

        Args:
            asset (AssetProtocol): The asset to get the schema for.

        Returns:
            dict[str, Any]: The schema for the asset.
        """
        return self.get_schema(asset.name, asset.schema_version, asset_type=asset.asset_type)

    @abc.abstractmethod
    def get_schema(self, asset_name: str, schema_version: int, asset_type: AssetType) -> dict[str, Any]:
        """Get the schema for the asset with the given name and version."""

    @staticmethod
    def _create_new(path: str) -> "SchemaStore":
        """Create a schema store for ``path`` according to its URL scheme.

        Args:
            path: One of
                - a plain filesystem path without a scheme (relative paths are resolved
                  against the current working directory),
                - a ``file://`` URL,
                - an ``s3://<bucket>/<prefix>`` URL.

        Returns:
            SchemaStore: A ``LocalSchemaStore`` or an ``S3SchemaStore``.

        Raises:
            ValueError: If the scheme is not supported.
            FileNotFoundError: If the resolved local schema directory does not exist.
        """
        parsed = urlparse(path)
        if parsed.scheme in ("", "file"):
            if parsed.scheme:
                # file://host/dir keeps "host" as the first (relative) path component;
                # file:///abs/dir has an empty netloc and is rooted at "/".
                resolved = (Path(parsed.netloc or "/") / parsed.path.lstrip("/")).resolve()
            else:
                # Use a plain path verbatim. urlparse() would strip ';', '?' and '#' parts, and
                # re-rooting at "/" would silently turn relative paths into absolute ones.
                resolved = Path(path).resolve()
            logger.info("Using file schema store.", path=resolved.as_posix())
            return LocalSchemaStore(resolved)
        if parsed.scheme == "s3":
            bucket = parsed.netloc
            prefix = parsed.path.lstrip("/")
            logger.info("Using S3 schema store.", bucket=bucket, prefix=prefix)
            return S3SchemaStore(bucket, prefix)
        # Log the path that was actually requested; it need not be the configured schema path.
        logger.critical("Unsupported schema store scheme.", scheme=parsed.scheme, path=path)
        raise ValueError(f"Unsupported schema store scheme {parsed.scheme!r}.")

    @staticmethod
    def from_path(path: str) -> "SchemaStore":
        """Get a schema store from the given path.

        Stores are cached per path string, so repeated calls with the same path return the
        same instance.

        Args:
            path (str): The path to the schema store.

        Returns:
            SchemaStore: The schema store.
        """
        if path not in SchemaStore._instances:
            SchemaStore._instances[path] = SchemaStore._create_new(path)
        return SchemaStore._instances[path]

    @staticmethod
    def from_config() -> "SchemaStore":
        """Get a schema store from the configuration (``storage_config.schema_path``).

        Returns:
            SchemaStore: The schema store.
        """
        return SchemaStore.from_path(storage_config.schema_path)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class LocalSchemaStore(SchemaStore):
    """Schema store that reads Avro schemas from JSON files on the local filesystem.

    A schema lives at ``<schemas_path>/<asset_type>/<asset_name>.v<schema_version>.avro.json``.
    """

    def __init__(self, schemas_path: Path = AVRO_SCHEMAS_PATH) -> None:
        """Create a store rooted at ``schemas_path``.

        Args:
            schemas_path (Path, optional): Directory holding the schemas. Defaults to AVRO_SCHEMAS_PATH.

        Raises:
            FileNotFoundError: If ``schemas_path`` does not exist.
        """
        super().__init__(schemas_path.as_posix())
        self.schemas_path = schemas_path
        if schemas_path.exists():
            return
        logger.critical("Provided schemas path does not exist.", schemas_path=schemas_path.as_posix())
        raise FileNotFoundError("Provided schemas path does not exist.")

    def get_schema(self, asset_name: str, schema_version: int, asset_type: AssetType) -> dict[str, Any]:
        """Read and parse the schema file of the given asset.

        Returns:
            dict[str, Any]: A freshly parsed schema object.

        Raises:
            TypeError: If the file holds valid JSON that is not an object.
            Exception: Any read/parse error is logged and re-raised unchanged.
        """
        file_name = f"{asset_name}.v{schema_version}.avro.json"
        schema_file = self.schemas_path / asset_type / file_name
        log_context = {"path": schema_file.as_posix(), "asset_name": asset_name, "asset_type": asset_type}
        logger.info("Reading avro schema from stored json file.", **log_context)
        try:
            loaded = json.loads(schema_file.read_text(encoding="utf-8"))
            if not isinstance(loaded, dict):
                raise TypeError(f"Schema has unexpected type {type(loaded)}, expected 'dict'.")
            return loaded
        except Exception:
            logger.critical("Failed to get avro schema from stored json file.", **log_context)
            raise
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@lru_cache(maxsize=256)
def _get_schema_from_s3(bucket: str, key: str, client: S3Client) -> dict[str, Any]:
    """Download and parse an Avro schema JSON document from S3.

    Results are memoized per ``(bucket, key, client)`` in an LRU cache of 256 entries. Calls
    that raise are not cached, so a failed download or an invalid document is retried next time.
    The cached dict object is returned as-is on every hit, so callers must not mutate it.

    Args:
        bucket: The S3 bucket name.
        key: The object key of the schema document.
        client: The S3 client used for the download.

    Returns:
        dict[str, Any]: The parsed schema.

    Raises:
        TypeError: If the document is valid JSON but not a JSON object. This matches the
            check done by the local schema store.
    """
    schema = json.loads(client.download_as_string(bucket, key))
    if not isinstance(schema, dict):
        raise TypeError(f"Schema has unexpected type {type(schema)}, expected 'dict'.")
    return cast(dict[str, Any], schema)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class S3SchemaStore(SchemaStore):
    """Schema store backed by JSON objects in an S3 bucket.

    A schema is read from ``<prefix>/<asset_type>/<asset_name>.v<schema_version>.avro.json``.
    When the prefix is empty, the ``<prefix>/`` part is dropped.
    """

    def __init__(self, bucket: str, prefix: str) -> None:
        """Initialize the S3 schema store with the given bucket and prefix.

        Args:
            bucket (str): The S3 bucket.
            prefix (str): The prefix in the bucket. Leading and trailing slashes are ignored
                when building object keys.
        """
        super().__init__(f"s3://{bucket}/{prefix}")
        self.bucket = bucket
        self.prefix = prefix
        self.s3_client = S3Client()

    def get_schema(self, asset_name: str, schema_version: int, asset_type: AssetType) -> dict[str, Any]:
        """Get the schema of the given asset from S3, using the module-level download cache.

        Returns:
            dict[str, Any]: A private deep copy of the schema, safe for the caller to mutate.

        Raises:
            Exception: Any download or parse error is logged and re-raised unchanged.
        """
        import copy  # local import: only this method needs it

        relative_key = f"{asset_type}/{asset_name}.v{schema_version}.avro.json"
        # Nest keys under the store's prefix, just as the local store nests files under its root.
        # Without this, s3://bucket/<prefix> stores would read schemas from the bucket root.
        prefix = self.prefix.strip("/")
        key = f"{prefix}/{relative_key}" if prefix else relative_key
        logger.info("Getting avro schema from S3.", bucket=self.bucket, key=key)
        try:
            schema = _get_schema_from_s3(self.bucket, key, self.s3_client)
        except Exception:
            logger.critical("Failed to get avro schema from S3.", bucket=self.bucket, key=key)
            raise
        # `_get_schema_from_s3` is lru_cached and returns the same dict on every cache hit.
        # Hand out a copy so a caller mutating its schema cannot corrupt later lookups.
        return copy.deepcopy(schema)
|
|
File without changes
|
|
@@ -0,0 +1,285 @@
|
|
|
1
|
+
"""Asset collection is a class that allows you to fetch
|
|
2
|
+
and store data from the catalog in a type-safe manner.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import datetime
|
|
7
|
+
from collections.abc import Coroutine
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from typing import Any, ClassVar, Generic, TypeVar, cast
|
|
10
|
+
|
|
11
|
+
from hyperion.asyncutils import iter_async
|
|
12
|
+
from hyperion.catalog import Catalog
|
|
13
|
+
from hyperion.dateutils import utcnow
|
|
14
|
+
from hyperion.entities.catalog import FeatureAsset, FeatureModel
|
|
15
|
+
from hyperion.logging import get_logger
|
|
16
|
+
from hyperion.typeutils import DateOrDelta
|
|
17
|
+
|
|
18
|
+
# Module logger for asset collections and fetch specifications.
logger = get_logger("hyperion-model-specification")

# The feature model class a fetch specification produces rows of.
CClass = TypeVar("CClass", bound=FeatureModel)

# Any AssetCollection subclass.
# NOTE(review): not referenced in this module; confirm it is used elsewhere before removing.
CollectionType = TypeVar("CollectionType", bound="AssetCollection")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass(frozen=True, eq=True)
class FeatureAssetSpecification(Generic[CClass]):
    """Specification for fetching feature assets from the catalog.

    Args:
        feature: The feature model class to fetch.
        start_date: The start date or delta from now to fetch the data.
        end_date: The end date or delta from now to fetch the data.
    """

    feature: type[CClass]
    start_date: DateOrDelta | None = None
    end_date: DateOrDelta | None = None

    @staticmethod
    def _resolve_date(
        date_spec: DateOrDelta | None, default: datetime.datetime, the_now: datetime.datetime | None
    ) -> datetime.datetime:
        """Turn a date spec into a concrete datetime.

        - ``None`` returns ``default``.
        - A datetime is returned unchanged.
        - Anything else (a delta) is added to ``the_now``, or to ``utcnow()`` when ``the_now``
          is not given.
        """
        if date_spec is None:
            return default
        if isinstance(date_spec, datetime.datetime):
            return date_spec
        return (the_now or utcnow()) + date_spec

    def resolve_start_date(self, the_now: datetime.datetime | None = None) -> datetime.datetime:
        """Resolve the start date for fetching the feature asset data.

        Without an explicit start date, ``datetime.datetime.min`` is used, so there is no lower bound.
        """
        # NOTE(review): datetime.min is naive. If utcnow() returns an aware datetime, comparing the
        # two bounds downstream may fail — confirm how utcnow() and the catalog handle tzinfo.
        return self._resolve_date(self.start_date, datetime.datetime.min, the_now)

    def resolve_end_date(self, the_now: datetime.datetime | None = None) -> datetime.datetime:
        """Resolve the end date for fetching the feature asset data.

        Without an explicit end date, the range ends at ``the_now``, the same reference time
        that deltas are measured from. It falls back to the current time when ``the_now`` is
        not given, so an anchored fetch does not run past its anchor.
        """
        return self._resolve_date(self.end_date, the_now or utcnow(), the_now)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
class _CollectionState:
    """Mutable per-collection bookkeeping used by AssetCollection.

    Attributes:
        fetched: Set to True by fetch_all() once data for every registered field is stored.
        data: Fetched rows, keyed by field name.
        fetch_specifications: The fetch specification registered for each field name.
        anchor_timestamps: Optional reference time per field name. None means the current
            time is used when fetching.
        semaphore: Lazily created semaphore that bounds concurrent partition downloads.
        max_concurrency: The limit the semaphore was created with.
    """

    fetched: bool = False
    data: dict[str, list[Any]] = field(default_factory=dict)
    fetch_specifications: dict[str, FeatureAssetSpecification[Any]] = field(default_factory=dict)
    anchor_timestamps: dict[str, datetime.datetime | None] = field(default_factory=dict)
    semaphore: asyncio.Semaphore | None = None
    max_concurrency: int | None = None
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class AssetCollection:
    """A collection of feature assets that can be fetched from the catalog.

    Fields are declared on subclasses through ``FeatureFetchSpecifier``. Each field registers
    its fetch specification in the class-level state when the class is created. Calling
    ``await fetch_all()`` downloads every registered field; after that, reading a field returns
    its fetched rows. All methods are classmethods, so the state belongs to the class rather
    than to instances.

    Attributes:
        catalog: The catalog to fetch the data from. If not set, it will be created from the config.
        max_concurrency: The maximum concurrency for fetching data. Default is 8.
        reserved_fields: The reserved field names for the collection.
        _state: The internal state of the collection. It should under no circumstances be modified directly.
    """

    catalog: ClassVar[Catalog | None] = None
    max_concurrency: ClassVar[int] = 8
    reserved_fields: ClassVar = ("catalog", "max_concurrency", "reserved_fields")
    _state: ClassVar[_CollectionState]

    @classmethod
    def _get_state(cls) -> _CollectionState:
        """Return the collection's state, creating an empty one on first use."""
        # NOTE(review): hasattr() also finds a `_state` inherited from a parent class. A subclass of
        # a collection that already has state therefore shares (and registers its fields into) the
        # parent's state. Confirm this is intended before building collection hierarchies.
        if not hasattr(cls, "_state"):
            logger.debug("Creating new empty state for the collection.", collection=cls.__name__)
            cls._state = _CollectionState()
        return cls._state

    @classmethod
    def _get_semaphore(cls) -> asyncio.Semaphore:
        """Return the semaphore limiting concurrent partition downloads, creating it lazily.

        The semaphore is sized from ``max_concurrency`` when it is created. Later changes to
        ``max_concurrency`` only log a warning until ``clear()`` drops the semaphore.
        """
        state = cls._get_state()
        if state.semaphore is None:
            logger.debug("Creating new semaphore.", collection=cls.__name__, max_concurrency=cls.max_concurrency)
            state.semaphore = asyncio.Semaphore(cls.max_concurrency)
            state.max_concurrency = cls.max_concurrency
        if state.max_concurrency != cls.max_concurrency:
            logger.warning(
                "Config max_concurrency cannot be changed after first use of the collection.", collection=cls.__name__
            )
        return state.semaphore

    @classmethod
    def is_fetched(cls) -> bool:
        """Check if the collection has fetched all data."""
        return cls._get_state().fetched

    @classmethod
    def get_data(cls, field: str) -> list[Any]:
        """Get the fetched data for the given field.

        Raises:
            ValueError: If no data has been stored for ``field``.
        """
        state = cls._get_state()
        if field not in state.data:
            raise ValueError(f"Data for {field!r} has not been fetched yet. Did you call 'fetch_all()'?")
        return state.data[field]

    @classmethod
    def clear(cls) -> None:
        """Clear all fetched data from the collection.

        The concurrency semaphore is dropped as well. The next fetch therefore creates a fresh
        one, sized from the current ``max_concurrency``, in whichever event loop it runs in.
        """
        logger.info("Clearing all fetched data from the collection.", collection=cls.__name__)
        state = cls._get_state()
        state.data = {}
        state.fetched = False
        # Since Python 3.10, asyncio primitives bind to the event loop they first wait in. Without
        # this reset, a second asyncio.run(...fetch_all()) would reuse a semaphore bound to the old,
        # closed loop and raise RuntimeError as soon as downloads contend for it.
        state.semaphore = None
        state.max_concurrency = None

    @classmethod
    def _get_catalog(cls) -> Catalog:
        """Return the catalog, creating it from the config on first use."""
        if cls.catalog is None:
            cls.catalog = Catalog.from_config()
        return cls.catalog

    @classmethod
    async def _gather_asset_range(
        cls, asset_spec: FeatureAssetSpecification[CClass], the_now: datetime.datetime
    ) -> list[CClass]:
        """Download every partition of one feature asset within the spec's date range.

        Partitions are retrieved concurrently, bounded by the collection semaphore. Every row
        is turned into an instance of ``asset_spec.feature``.

        Args:
            asset_spec: What to fetch and for which date range.
            the_now: Reference time that relative (delta) dates are resolved against.

        Returns:
            The rows of all partitions, in partition order.
        """
        start_date = asset_spec.resolve_start_date(the_now)
        end_date = asset_spec.resolve_end_date(the_now)
        partitions = list(
            cls._get_catalog().iter_feature_store_partitions(
                asset_spec.feature.asset_name,
                asset_spec.feature.resolution,
                start_date,
                end_date,
                asset_spec.feature.schema_version,
            )
        )
        all_data: list[CClass] = []

        async def _retrieve_async(partition: FeatureAsset) -> list[dict[str, Any]]:
            async with cls._get_semaphore():
                logger.debug("Retrieving partition.", partition=partition)
                # NOTE(review): retrieve_asset(partition) is evaluated on the event-loop thread; only
                # list() consumes it in the worker thread. If retrieve_asset does eager I/O before
                # returning an iterator, that I/O blocks the loop — confirm.
                return await asyncio.to_thread(list, cls._get_catalog().retrieve_asset(partition))

        tasks: list[Coroutine[None, None, list[dict[str, Any]]]] = []

        logger.info(
            f"Retrieving {len(partitions)} partitions.",
            partitions=len(partitions),
            asset_name=asset_spec.feature.asset_name,
        )
        for partition in partitions:
            tasks.append(_retrieve_async(partition))
        results = await asyncio.gather(*tasks)
        for data in results:
            all_data.extend(asset_spec.feature(**row) for row in data)
        logger.info(
            f"Downloaded {len(all_data)} rows from {len(partitions)} partitions.",
            asset_name=asset_spec.feature.asset_name,
        )
        return all_data

    @classmethod
    def register_specification(
        cls,
        field_name: str,
        specification: FeatureAssetSpecification[CClass],
        anchor_timestamp: datetime.datetime | None = None,
    ) -> None:
        """Register a fetch specification for a field in the collection.

        This is normally only called by the `FeatureFetchSpecifier` descriptor and should
        not be called directly.

        Args:
            field_name: The name of the field to register the specification for.
            specification: The fetch specification for the field.
            anchor_timestamp: The anchor timestamp for fetching the data.
        """
        if field_name in cls._get_state().fetch_specifications:
            logger.warning(
                "Registering duplicate fetch specification, existing will be discarded.",
                field=field_name,
                asset_name=specification.feature.asset_name,
            )
        logger.debug("Registering field into an asset collection.", collection=cls.__name__, field=field_name)
        cls._get_state().fetch_specifications[field_name] = specification
        cls._get_state().anchor_timestamps[field_name] = anchor_timestamp

    @classmethod
    async def fetch_all(cls) -> None:
        """Fetch all data for the collection.

        Every registered field is fetched concurrently. Each field's dates are resolved against
        its anchor timestamp, or the current time when it has none. Once the collection is
        fetched, this is a no-op until ``clear()`` is called.
        """
        if cls.is_fetched():
            logger.info(
                "Collection already fetched all data, if you want to start over, call .clear()", collection=cls.__name__
            )
            return
        logger.info("Gather all data within the collection.", collection=cls.__name__)
        tasks: list[Coroutine[None, None, tuple[str, list[Any]]]] = []

        async def _gather(name: str, specs: FeatureAssetSpecification[CClass]) -> tuple[str, list[CClass]]:
            anchor_timestamp = cls._get_state().anchor_timestamps.get(name) or utcnow()
            return (name, await cls._gather_asset_range(specs, anchor_timestamp))

        async for prop, specs in iter_async(cls._get_state().fetch_specifications.items()):
            tasks.append(_gather(prop, specs))

        results = await asyncio.gather(*tasks)
        for name, data in results:
            logger.info("Finished receiving feature data.", field=name)
            cls._get_state().data[name] = data

        cls._get_state().fetched = True
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
class _FeatureFetchSpecifier(Generic[CClass]):
    """Descriptor behind fields declared with ``FeatureFetchSpecifier`` on an AssetCollection.

    When the owning class is created, ``__set_name__`` validates the field and registers its
    fetch specification with the owner. Reading the attribute returns the owner's fetched
    data for the field.
    """

    def __init__(
        self,
        feature: type[CClass],
        start_date: DateOrDelta | None = None,
        end_date: DateOrDelta | None = None,
    ) -> None:
        """Store the fetch specification; ``__set_name__`` fills in the owner and field name."""
        self._specification = FeatureAssetSpecification(feature, start_date, end_date)
        self._owner: type[AssetCollection] | None = None
        self._field_name: str | None = None

    @property
    def owner(self) -> type[AssetCollection]:
        """The collection class this field belongs to.

        Raises:
            RuntimeError: If the descriptor was never bound to a class.
        """
        if self._owner is None:
            raise RuntimeError("Field was not properly initialized and has no owner.")
        return self._owner

    @property
    def field_name(self) -> str:
        """The attribute name this field was assigned to.

        Raises:
            RuntimeError: If the descriptor was never bound to a class.
        """
        if self._field_name is None:
            raise RuntimeError("Field was not properly initialized and has no name.")
        return self._field_name

    def __set_name__(self, owner: type[AssetCollection], field_name: str) -> None:
        """Validate the owner and field name, then register the specification with the owner.

        Raises:
            TypeError: If the owner is not AssetCollection or one of its subclasses.
            ValueError: If the field name starts with an underscore or is reserved.
        """
        # issubclass() is reflexive, so this also accepts AssetCollection itself.
        if not issubclass(owner, AssetCollection):
            raise TypeError(
                f"{self.__class__.__name__!r} can only be a field of AssetCollection or its subclass, "
                f"{owner!r} is not a valid owner."
            )
        if field_name.startswith("_") or field_name in owner.reserved_fields:
            raise ValueError(f"Field name {field_name!r} is reserved for internal use.")
        self._owner = owner
        self._field_name = field_name
        owner.register_specification(self.field_name, self._specification)

    def __get__(
        self, _instance: AssetCollection | None, _instance_type: type[AssetCollection] | None = None
    ) -> list[CClass]:
        """Return the rows fetched for this field.

        The data is stored on the owner class, so class access and instance access return
        the same list. The owner argument is optional, as the descriptor protocol allows.

        Raises:
            RuntimeError: If the owner collection has not been fetched yet.
        """
        if not self.owner.is_fetched():
            raise RuntimeError(
                f"Owner collection {self.owner.__name__!r} was not fetched yet. Did you call fetch_all()?"
            )
        return self.owner.get_data(self.field_name)
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def FeatureFetchSpecifier(  # noqa: N802, a fake class factory
    feature: type[CClass], start_date: DateOrDelta | None = None, end_date: DateOrDelta | None = None
) -> list[CClass]:
    """Declare an AssetCollection field that fetches rows of ``feature``.

    Although it is called like a class, this is a factory. It returns a
    ``_FeatureFetchSpecifier`` descriptor but is typed as ``list[CClass]``, so type checkers
    see the field as the list of model instances it yields once the collection is fetched.

    Args:
        feature: The feature model class to fetch.
        start_date: The start date or delta from now to fetch the data.
        end_date: The end date or delta from now to fetch the data.

    Returns:
        The descriptor, typed as the list it resolves to.
    """
    descriptor = _FeatureFetchSpecifier(feature, start_date=start_date, end_date=end_date)
    return cast(list[CClass], descriptor)
|
hyperion/config.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
from dotenv import find_dotenv, load_dotenv
|
|
2
|
+
from env_proxy import EnvConfig, EnvProxy, Field
|
|
3
|
+
|
|
4
|
+
# Load a .env file (located with find_dotenv, starting from the current working directory)
# at import time, before the config objects below are instantiated.
load_dotenv(dotenv_path=find_dotenv(usecwd=True))
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class CommonConfig(EnvConfig):
    """General SDK settings (logging and service name); EnvProxy prefix ``HYPERION_COMMON``."""

    env_proxy = EnvProxy(prefix="HYPERION_COMMON")

    log_pretty: bool = Field(
        description="If truthy, logs will be pretty-printed. Defaults to JSON logs.",
        default=False,
    )
    log_level: str = Field(description="The minimum log level.", default="DEBUG")
    service_name: str = Field(default="hyperion", description="The name of the service.")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class StorageConfig(EnvConfig):
    """Storage settings (buckets, prefixes, schema store, cache); EnvProxy prefix ``HYPERION_STORAGE``."""

    env_proxy = EnvProxy(prefix="HYPERION_STORAGE")

    max_concurrency: int = Field(default=5, description="The maximum number of concurrent storage tasks.")

    data_lake_bucket: str = Field(description="Data lake store bucket name.")
    feature_store_bucket: str = Field(description="Feature store bucket name.")
    persistent_store_bucket: str = Field(description="Persistent store bucket name.")

    data_lake_prefix: str = Field(default="", description="Optional data lake paths prefix.")
    feature_store_prefix: str = Field(default="", description="Optional feature store paths prefix.")
    persistent_store_prefix: str = Field(default="", description="Optional persistent store path prefix.")

    # Read by SchemaStore.from_config(). That factory accepts only plain local paths, 'file' URLs
    # and 's3' URLs, and raises ValueError for any other scheme.
    schema_path: str = Field(
        description="The path to the schema store. Supported schemes: 'file' (or a plain local path) and 's3'."
    )

    # Optional by contract ("If not set, memory cache is used"), so default to None explicitly,
    # like the other optional fields in this module.
    cache_dynamodb_table: str | None = Field(
        default=None,
        description="The name of DynamoDB table that will be used as key-value cache. If not set, memory cache is used.",
    )
    cache_dynamodb_default_ttl: int = Field(
        default=60, description="Default TTL for key-value cache in DynamoDB (in seconds)."
    )

    cache_local_path: str | None = Field(
        description="Path to a directory where cache is going to be stored to.", default=None
    )
    cache_key_prefix: str = Field(default="", description="Optional prefix for key-value cache items.")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class QueueConfig(EnvConfig):
    """Message queue settings; EnvProxy prefix ``HYPERION_QUEUE``."""

    env_proxy = EnvProxy(prefix="HYPERION_QUEUE")

    url: str | None = Field(default=None, description="The URL of the SQS queue.")
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class SecretsConfig(EnvConfig):
    """Secrets management settings; EnvProxy prefix ``HYPERION_SECRETS``."""

    env_proxy = EnvProxy(prefix="HYPERION_SECRETS")

    backend: str | None = Field(
        default=None,
        description="The backend to use for secrets management. Can only be `AWSSecretsManager` or None.",
    )
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class GeoConfig(EnvConfig):
    """Geo service settings; EnvProxy prefix ``HYPERION_GEO``."""

    env_proxy = EnvProxy(prefix="HYPERION_GEO")

    # Typed as optional, so default to None explicitly, like the other optional fields in this module.
    gmaps_api_key: str | None = Field(default=None, description="Google Maps API key.")
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class HttpConfig(EnvConfig):
    """HTTP settings; EnvProxy prefix ``HYPERION_HTTP``."""

    env_proxy = EnvProxy(prefix="HYPERION_HTTP")

    # Both optional and unset by default.
    # NOTE(review): presumably proxy URLs for http/https traffic — confirm against the HTTP client code.
    proxy_http: str | None = Field(default=None)
    proxy_https: str | None = Field(default=None)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# Module-level settings objects, created once when this module is imported (after the .env file
# above has been loaded). Other modules import these instances, e.g. `storage_config`.
config = CommonConfig()
storage_config = StorageConfig()
geo_config = GeoConfig()
queue_config = QueueConfig()
secrets_config = SecretsConfig()
http_config = HttpConfig()
|