hyperion-sdk 0.2.0.dev1741815359__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,153 @@
1
+ """Schema store."""
2
+
3
+ import abc
4
+ import json
5
+ from functools import lru_cache
6
+ from pathlib import Path
7
+ from typing import Any, ClassVar, cast
8
+ from urllib.parse import urlparse
9
+
10
+ from hyperion.config import storage_config
11
+ from hyperion.entities.catalog import AssetProtocol, AssetType
12
+ from hyperion.infrastructure.aws import S3Client
13
+ from hyperion.logging import get_logger
14
+
15
# Module-level logger for schema store operations.
logger = get_logger("schema-store")

# The `avro_schemas` directory located next to this module; it is the default
# root directory of LocalSchemaStore.
AVRO_SCHEMAS_PATH = Path(__file__).parent / "avro_schemas"
18
+
19
+
20
class SchemaStore(abc.ABC):
    """Abstract base class for schema stores.

    A schema store returns the Avro schema (parsed JSON) of an asset identified by
    asset name, schema version and asset type. Obtain stores through
    :meth:`from_path` or :meth:`from_config`; stores created that way are cached per
    path string, so repeated calls with the same path return the same instance.
    """

    # Stores created by `from_path`, keyed by the exact path string passed in.
    _instances: ClassVar[dict[str, "SchemaStore"]] = {}

    def __init__(self, path: str) -> None:
        """Initialize the schema store with the given path.

        Args:
            path (str): The path to the schema store.
        """
        self.path = path

    def get_asset_schema(self, asset: AssetProtocol) -> dict[str, Any]:
        """Get the schema for the given asset.

        Args:
            asset (AssetProtocol): The asset to get the schema for.

        Returns:
            dict[str, Any]: The schema for the asset.
        """
        return self.get_schema(asset.name, asset.schema_version, asset_type=asset.asset_type)

    @abc.abstractmethod
    def get_schema(self, asset_name: str, schema_version: int, asset_type: AssetType) -> dict[str, Any]:
        """Get the schema for the asset with the given name and version.

        Args:
            asset_name (str): Name of the asset.
            schema_version (int): Version of the asset's schema.
            asset_type (AssetType): Type of the asset.

        Returns:
            dict[str, Any]: The parsed schema.
        """

    @staticmethod
    def _create_new(path: str) -> "SchemaStore":
        """Create a new (uncached) schema store for ``path`` based on its URL scheme.

        Supported forms:
            * no scheme: a filesystem path, absolute or relative to the working directory;
            * ``file://``: netloc and path are joined, so ``file:///abs/dir`` and
              ``file://./rel/dir`` are both accepted;
            * ``s3://bucket/prefix``.

        Args:
            path (str): The schema store location.

        Returns:
            SchemaStore: A LocalSchemaStore or S3SchemaStore.

        Raises:
            ValueError: If the scheme is not supported.
        """
        parsed = urlparse(path)
        if parsed.scheme in ("", "file"):
            if parsed.scheme:
                # file:// URL: treat netloc as the leading path component.
                resolved = (Path(parsed.netloc or "/") / parsed.path.lstrip("/")).resolve()
            else:
                # Bare path: resolve it as given so relative paths are taken relative to
                # the working directory instead of being re-rooted at "/".
                resolved = Path(path).resolve()
            logger.info("Using file schema store.", path=resolved.as_posix())
            return LocalSchemaStore(resolved)
        if parsed.scheme == "s3":
            bucket = parsed.netloc
            prefix = parsed.path.lstrip("/")
            logger.info("Using S3 schema store.", bucket=bucket, prefix=prefix)
            return S3SchemaStore(bucket, prefix)
        # Report the path that was actually rejected, which may differ from the configured one.
        logger.critical("Unsupported schema store scheme.", scheme=parsed.scheme, path=path)
        raise ValueError(f"Unsupported schema store scheme {parsed.scheme!r}.")

    @staticmethod
    def from_path(path: str) -> "SchemaStore":
        """Get a schema store from the given path.

        The store is created on first use and cached per path string.

        Args:
            path (str): The path to the schema store.

        Returns:
            SchemaStore: The schema store.
        """
        if path not in SchemaStore._instances:
            SchemaStore._instances[path] = SchemaStore._create_new(path)
        return SchemaStore._instances[path]

    @staticmethod
    def from_config() -> "SchemaStore":
        """Get a schema store from the configuration (``storage_config.schema_path``).

        Returns:
            SchemaStore: The schema store.
        """
        return SchemaStore.from_path(storage_config.schema_path)
85
+
86
+
87
class LocalSchemaStore(SchemaStore):
    """Schema store reading Avro schemas from a local directory.

    A schema lives at ``<schemas_path>/<asset_type>/<asset_name>.v<version>.avro.json``.
    """

    def __init__(self, schemas_path: Path = AVRO_SCHEMAS_PATH) -> None:
        """Create a store rooted at ``schemas_path``.

        Args:
            schemas_path (Path, optional): Root directory of the schemas. Defaults to AVRO_SCHEMAS_PATH.

        Raises:
            FileNotFoundError: If ``schemas_path`` does not exist.
        """
        super().__init__(schemas_path.as_posix())
        self.schemas_path = schemas_path
        if schemas_path.exists():
            return
        logger.critical("Provided schemas path does not exist.", schemas_path=schemas_path.as_posix())
        raise FileNotFoundError("Provided schemas path does not exist.")

    def get_schema(self, asset_name: str, schema_version: int, asset_type: AssetType) -> dict[str, Any]:
        """Read and parse the schema file for the given asset.

        Returns:
            dict[str, Any]: The parsed schema.

        Raises:
            TypeError: If the file's top-level JSON value is not an object.
            Exception: Any I/O or JSON decoding error is logged and re-raised unchanged.
        """
        file_name = f"{asset_name}.v{schema_version}.avro.json"
        schema_file = self.schemas_path / asset_type / file_name
        log_context = {"path": schema_file.as_posix(), "asset_name": asset_name, "asset_type": asset_type}
        logger.info("Reading avro schema from stored json file.", **log_context)
        try:
            with schema_file.open("r", encoding="utf-8") as handle:
                loaded = json.load(handle)
            if isinstance(loaded, dict):
                return loaded
            raise TypeError(f"Schema has unexpected type {type(loaded)}, expected 'dict'.")
        except Exception:
            logger.critical("Failed to get avro schema from stored json file.", **log_context)
            raise
124
+
125
+
126
@lru_cache(maxsize=256)
def _get_schema_from_s3(bucket: str, key: str, client: S3Client) -> dict[str, Any]:
    """Download and parse a JSON schema object from S3, memoized per (bucket, key, client).

    Exceptions are not cached by ``lru_cache``, so a failed download is retried on the
    next call. Repeated calls with the same arguments return the same dict object, so
    callers should treat the result as read-only.

    Raises:
        TypeError: If the JSON document is not an object (same check as LocalSchemaStore).
    """
    schema = json.loads(client.download_as_string(bucket, key))
    # Validate instead of blindly casting, so a malformed schema fails here, not downstream.
    if not isinstance(schema, dict):
        raise TypeError(f"Schema has unexpected type {type(schema)}, expected 'dict'.")
    return cast(dict[str, Any], schema)
129
+
130
+
131
class S3SchemaStore(SchemaStore):
    """Schema store reading Avro schemas from an S3 bucket under an optional prefix."""

    def __init__(self, bucket: str, prefix: str) -> None:
        """Initialize the S3 schema store with the given bucket and prefix.

        Args:
            bucket (str): The S3 bucket.
            prefix (str): The prefix in the bucket under which schemas are stored; may be empty.
        """
        super().__init__(f"s3://{bucket}/{prefix}")
        self.bucket = bucket
        self.prefix = prefix
        self.s3_client = S3Client()

    def get_schema(self, asset_name: str, schema_version: int, asset_type: AssetType) -> dict[str, Any]:
        """Fetch the schema for the given asset from S3.

        The object key is ``<prefix>/<asset_type>/<asset_name>.v<schema_version>.avro.json``.
        The prefix part is omitted when the prefix is empty, and extra slashes around it are
        ignored.

        Returns:
            dict[str, Any]: The parsed schema.

        Raises:
            Exception: Any download or parsing error is logged and re-raised unchanged.
        """
        relative_key = f"{asset_type}/{asset_name}.v{schema_version}.avro.json"
        # Join the configured prefix so that s3://bucket/<prefix> is the root of the schemas.
        prefix = self.prefix.strip("/")
        key = f"{prefix}/{relative_key}" if prefix else relative_key
        logger.info("Getting avro schema from S3.", bucket=self.bucket, key=key)
        try:
            return _get_schema_from_s3(self.bucket, key, self.s3_client)
        except Exception:
            logger.critical("Failed to get avro schema from S3.", bucket=self.bucket, key=key)
            raise
File without changes
@@ -0,0 +1,285 @@
1
+ """Asset collection is a class that allows you to fetch
2
+ and store data from the catalog in a type-safe manner.
3
+ """
4
+
5
+ import asyncio
6
+ import datetime
7
+ from collections.abc import Coroutine
8
+ from dataclasses import dataclass, field
9
+ from typing import Any, ClassVar, Generic, TypeVar, cast
10
+
11
+ from hyperion.asyncutils import iter_async
12
+ from hyperion.catalog import Catalog
13
+ from hyperion.dateutils import utcnow
14
+ from hyperion.entities.catalog import FeatureAsset, FeatureModel
15
+ from hyperion.logging import get_logger
16
+ from hyperion.typeutils import DateOrDelta
17
+
18
# Module-level logger.
# NOTE(review): the name "hyperion-model-specification" does not match this module's
# subject (asset collections) — confirm it is intentional.
logger = get_logger("hyperion-model-specification")

# Feature model class carried by a specification / collection field.
CClass = TypeVar("CClass", bound=FeatureModel)
# Type variable bound to AssetCollection. NOTE(review): not used in the visible code.
CollectionType = TypeVar("CollectionType", bound="AssetCollection")
22
+
23
+
24
@dataclass(frozen=True, eq=True)
class FeatureAssetSpecification(Generic[CClass]):
    """Specification for fetching feature assets from the catalog.

    Each date bound may be an absolute datetime or a value added to a reference "now"
    (``the_now`` in the resolve methods, defaulting to the current UTC time).

    Args:
        feature: The feature model class to fetch.
        start_date: The start date or delta from now to fetch the data.
            Defaults to ``datetime.datetime.min``.
        end_date: The end date or delta from now to fetch the data. Defaults to "now".
    """

    feature: type[CClass]
    start_date: DateOrDelta | None = None
    end_date: DateOrDelta | None = None

    @staticmethod
    def _resolve_date(
        date_spec: DateOrDelta | None, default: datetime.datetime, the_now: datetime.datetime | None
    ) -> datetime.datetime:
        """Turn a date spec into a concrete datetime.

        ``None`` yields ``default``, and a datetime is returned unchanged. Anything else is
        added to ``the_now``, or to the current UTC time when ``the_now`` is None.
        """
        if date_spec is None:
            return default
        if isinstance(date_spec, datetime.datetime):
            return date_spec
        return (the_now or utcnow()) + date_spec

    def resolve_start_date(self, the_now: datetime.datetime | None = None) -> datetime.datetime:
        """Resolve the start date for fetching the feature asset data.

        NOTE(review): the default ``datetime.datetime.min`` is naive; confirm it is comparable
        with the datetimes returned by ``utcnow()`` downstream.
        """
        return self._resolve_date(self.start_date, datetime.datetime.min, the_now)

    def resolve_end_date(self, the_now: datetime.datetime | None = None) -> datetime.datetime:
        """Resolve the end date for fetching the feature asset data.

        When ``end_date`` is unset, the result is the reference time ``the_now``, falling back
        to the current UTC time. An anchored fetch therefore ends at its anchor, consistent
        with how relative deltas are resolved.
        """
        # Compute the reference "now" once so the default and any delta use the same instant.
        reference = the_now or utcnow()
        return self._resolve_date(self.end_date, reference, reference)
55
+
56
+
57
@dataclass
class _CollectionState:
    """Internal, mutable, per-collection state of an ``AssetCollection``.

    Args:
        fetched: Whether ``fetch_all`` has completed for the collection.
        data: A mapping of field names to the fetched rows.
        fetch_specifications: A mapping of field names to the fetch specifications.
        anchor_timestamps: Per-field reference time used to resolve relative dates;
            ``None`` means "the current UTC time at fetch".
        semaphore: The asyncio semaphore limiting concurrent partition retrievals;
            created lazily on first use.
        max_concurrency: The concurrency limit the semaphore was created with.
    """

    fetched: bool = False
    # Filled by fetch_all once every field has been retrieved.
    data: dict[str, list[Any]] = field(default_factory=dict)
    # Populated when FeatureFetchSpecifier fields are bound to the collection class.
    fetch_specifications: dict[str, FeatureAssetSpecification[Any]] = field(default_factory=dict)
    anchor_timestamps: dict[str, datetime.datetime | None] = field(default_factory=dict)
    semaphore: asyncio.Semaphore | None = None
    max_concurrency: int | None = None
76
+
77
+
78
class AssetCollection:
    """A collection of feature assets that can be fetched from the catalog.

    Subclasses declare fields with ``FeatureFetchSpecifier``. Each field registers its fetch
    specification in the collection's class-level state. ``await fetch_all()`` downloads
    all registered fields concurrently, after which reading a field returns its rows.
    All state lives on the class, not on instances.

    Attributes:
        catalog: The catalog to fetch the data from. If not set, it will be created from the config.
        max_concurrency: The maximum concurrency for fetching data. Default is 8.
        reserved_fields: The reserved field names for the collection.
        _state: The internal state of the collection. It should under no circumstances be modified directly.
    """

    catalog: ClassVar[Catalog | None] = None
    max_concurrency: ClassVar[int] = 8
    reserved_fields: ClassVar = ("catalog", "max_concurrency", "reserved_fields")
    _state: ClassVar[_CollectionState]

    @classmethod
    def _get_state(cls) -> _CollectionState:
        """Return the collection's state, creating an empty one on first use."""
        # NOTE(review): `hasattr` also finds a `_state` inherited from a parent class, so a
        # subclass of an already-initialised collection shares its parent's state. Confirm
        # whether subclassing collections is meant to be supported.
        if not hasattr(cls, "_state"):
            logger.debug("Creating new empty state for the collection.", collection=cls.__name__)
            cls._state = _CollectionState()
        return cls._state

    @classmethod
    def _get_semaphore(cls) -> asyncio.Semaphore:
        """Return the semaphore bounding concurrent partition retrievals, creating it lazily.

        The limit is taken from ``max_concurrency`` when the semaphore is created. Later
        changes only produce a warning until ``clear()`` discards the semaphore.
        """
        state = cls._get_state()
        if state.semaphore is None:
            logger.debug("Creating new semaphore.", collection=cls.__name__, max_concurrency=cls.max_concurrency)
            state.semaphore = asyncio.Semaphore(cls.max_concurrency)
            state.max_concurrency = cls.max_concurrency
        if state.max_concurrency != cls.max_concurrency:
            logger.warning(
                "Config max_concurrency cannot be changed after first use of the collection.", collection=cls.__name__
            )
        return state.semaphore

    @classmethod
    def is_fetched(cls) -> bool:
        """Check if the collection has fetched all data."""
        return cls._get_state().fetched

    @classmethod
    def get_data(cls, field: str) -> list[Any]:
        """Get the fetched data for the given field.

        Raises:
            ValueError: If no data is stored for ``field`` (e.g. before ``fetch_all()``).
        """
        state = cls._get_state()
        if field not in state.data:
            raise ValueError(f"Data for {field!r} has not been fetched yet. Did you call 'fetch_all()'?")
        return state.data[field]

    @classmethod
    def clear(cls) -> None:
        """Clear all fetched data from the collection.

        Registered fetch specifications and anchor timestamps are kept, so the collection
        can be fetched again. The semaphore is discarded too. asyncio primitives bind to
        the event loop in which they first have to wait, so reusing the semaphore from a
        later ``asyncio.run`` could raise ``RuntimeError``. It is recreated, with the
        current ``max_concurrency``, on the next fetch.
        """
        logger.info("Clearing all fetched data from the collection.", collection=cls.__name__)
        state = cls._get_state()
        state.data = {}
        state.fetched = False
        state.semaphore = None
        state.max_concurrency = None

    @classmethod
    def _get_catalog(cls) -> Catalog:
        """Return the collection's catalog, creating it from the config on first use."""
        if cls.catalog is None:
            cls.catalog = Catalog.from_config()
        return cls.catalog

    @classmethod
    async def _gather_asset_range(
        cls, asset_spec: FeatureAssetSpecification[CClass], the_now: datetime.datetime
    ) -> list[CClass]:
        """Download every row of one feature asset within the specification's date range.

        Args:
            asset_spec: The feature to fetch and its (possibly relative) date bounds.
            the_now: Reference time used to resolve relative date bounds.

        Returns:
            One ``asset_spec.feature`` instance per retrieved row, in partition order.
        """
        start_date = asset_spec.resolve_start_date(the_now)
        end_date = asset_spec.resolve_end_date(the_now)
        # NOTE: listing partitions is a synchronous call made directly on the event loop.
        partitions = list(
            cls._get_catalog().iter_feature_store_partitions(
                asset_spec.feature.asset_name,
                asset_spec.feature.resolution,
                start_date,
                end_date,
                asset_spec.feature.schema_version,
            )
        )
        all_data: list[CClass] = []

        async def _retrieve_async(partition: FeatureAsset) -> list[dict[str, Any]]:
            # The class-wide semaphore bounds concurrent retrievals across all fields of the
            # collection. The rows are materialized with `list` in a worker thread.
            async with cls._get_semaphore():
                logger.debug("Retrieving partition.", partition=partition)
                return await asyncio.to_thread(list, cls._get_catalog().retrieve_asset(partition))

        logger.info(
            f"Retrieving {len(partitions)} partitions.",
            partitions=len(partitions),
            asset_name=asset_spec.feature.asset_name,
        )
        tasks: list[Coroutine[None, None, list[dict[str, Any]]]] = [
            _retrieve_async(partition) for partition in partitions
        ]
        # gather preserves input order, so rows stay grouped in partition order.
        results = await asyncio.gather(*tasks)
        for data in results:
            all_data.extend(asset_spec.feature(**row) for row in data)
        logger.info(
            f"Downloaded {len(all_data)} rows from {len(partitions)} partitions.",
            asset_name=asset_spec.feature.asset_name,
        )
        return all_data

    @classmethod
    def register_specification(
        cls,
        field_name: str,
        specification: FeatureAssetSpecification[CClass],
        anchor_timestamp: datetime.datetime | None = None,
    ) -> None:
        """Register a fetch specification for a field in the collection.

        This is normally only called by the `FeatureFetchSpecifier` descriptor and should
        not be called directly. Re-registering a field replaces the previous specification
        and logs a warning.

        Args:
            field_name: The name of the field to register the specification for.
            specification: The fetch specification for the field.
            anchor_timestamp: The anchor timestamp for fetching the data; ``None`` means the
                current UTC time at fetch.
        """
        if field_name in cls._get_state().fetch_specifications:
            logger.warning(
                "Registering duplicate fetch specification, existing will be discarded.",
                field=field_name,
                asset_name=specification.feature.asset_name,
            )
        logger.debug("Registering field into an asset collection.", collection=cls.__name__, field=field_name)
        cls._get_state().fetch_specifications[field_name] = specification
        cls._get_state().anchor_timestamps[field_name] = anchor_timestamp

    @classmethod
    async def fetch_all(cls) -> None:
        """Fetch all data for the collection.

        Returns immediately, with an info log, if the collection is already fetched.
        Otherwise all registered fields are fetched concurrently. Data is stored only after
        every field has been retrieved; if any retrieval raises, the exception propagates
        from ``asyncio.gather`` and nothing is stored.
        """
        if cls.is_fetched():
            logger.info(
                "Collection already fetched all data, if you want to start over, call .clear()", collection=cls.__name__
            )
            return
        logger.info("Gather all data within the collection.", collection=cls.__name__)
        tasks: list[Coroutine[None, None, tuple[str, list[Any]]]] = []

        async def _gather(name: str, specs: FeatureAssetSpecification[CClass]) -> tuple[str, list[CClass]]:
            # Fields without an anchor resolve relative dates against the current time.
            anchor_timestamp = cls._get_state().anchor_timestamps.get(name) or utcnow()
            return (name, await cls._gather_asset_range(specs, anchor_timestamp))

        async for prop, specs in iter_async(cls._get_state().fetch_specifications.items()):
            tasks.append(_gather(prop, specs))

        results = await asyncio.gather(*tasks)
        for name, data in results:
            logger.info("Finished receiving feature data.", field=name)
            cls._get_state().data[name] = data

        cls._get_state().fetched = True
230
+
231
+
232
class _FeatureFetchSpecifier(Generic[CClass]):
    """Descriptor backing ``FeatureFetchSpecifier`` fields of an ``AssetCollection``.

    When the owning class is created (``__set_name__``), the descriptor registers its fetch
    specification with that class. On attribute access (``__get__``), it returns the
    owner's fetched rows for this field.
    """

    def __init__(
        self,
        feature: type[CClass],
        start_date: DateOrDelta | None = None,
        end_date: DateOrDelta | None = None,
    ) -> None:
        self._specification = FeatureAssetSpecification(feature, start_date, end_date)
        # Both are set by __set_name__ when the owning class body is executed.
        self._owner: type[AssetCollection] | None = None
        self._field_name: str | None = None

    @property
    def owner(self) -> type[AssetCollection]:
        """The collection class this field is bound to.

        Raises:
            RuntimeError: If the descriptor was never bound via ``__set_name__``.
        """
        if self._owner is None:
            raise RuntimeError("Field was not properly initialized and has no owner.")
        return self._owner

    @property
    def field_name(self) -> str:
        """The attribute name this field is bound to.

        Raises:
            RuntimeError: If the descriptor was never bound via ``__set_name__``.
        """
        if self._field_name is None:
            raise RuntimeError("Field was not properly initialized and has no name.")
        return self._field_name

    def __set_name__(self, owner: type[AssetCollection], field_name: str) -> None:
        """Bind the descriptor to its owning collection and register its specification.

        Raises:
            TypeError: If ``owner`` is not ``AssetCollection`` or a subclass of it.
            ValueError: If ``field_name`` starts with an underscore or is a reserved name.
        """
        # issubclass(X, X) is True, so AssetCollection itself is accepted as well.
        if not issubclass(owner, AssetCollection):
            raise TypeError(
                f"{self.__class__.__name__!r} can only be a field of AssetCollection or its subclass, "
                f"{owner!r} is not a valid owner."
            )
        if field_name.startswith("_") or field_name in owner.reserved_fields:
            raise ValueError(f"Field name {field_name!r} is reserved for internal use.")
        self._owner = owner
        self._field_name = field_name
        owner.register_specification(self.field_name, self._specification)

    def __get__(
        self, _instance: AssetCollection | None, _instance_type: type[AssetCollection] | None = None
    ) -> list[CClass]:
        """Return the fetched rows of this field, on class or instance access alike.

        Data is always read from the owning collection's class-level state.

        Raises:
            RuntimeError: If the owning collection has not been fetched yet.
        """
        if not self.owner.is_fetched():
            raise RuntimeError(
                f"Owner collection {self.owner.__name__!r} was not fetched yet. Did you call fetch_all()?"
            )
        return self.owner.get_data(self.field_name)
273
+
274
+
275
def FeatureFetchSpecifier(  # noqa: N802, a fake class factory
    feature: type[CClass], start_date: DateOrDelta | None = None, end_date: DateOrDelta | None = None
) -> list[CClass]:
    """Declare a feature field on an ``AssetCollection`` subclass.

    Named like a class, but it returns a descriptor object. It is typed as
    ``list[CClass]`` so that type checkers see the declared field as the list of fetched
    feature rows.

    Args:
        feature: The feature model class to fetch.
        start_date: The start date or delta from now to fetch the data.
        end_date: The end date or delta from now to fetch the data.
    """
    descriptor = _FeatureFetchSpecifier(feature, start_date=start_date, end_date=end_date)
    # The cast is for type checkers only; at runtime the descriptor itself is returned.
    return cast(list[CClass], descriptor)
hyperion/config.py ADDED
@@ -0,0 +1,77 @@
1
+ from dotenv import find_dotenv, load_dotenv
2
+ from env_proxy import EnvConfig, EnvProxy, Field
3
+
4
+ load_dotenv(find_dotenv(usecwd=True))
5
+
6
+
7
class CommonConfig(EnvConfig):
    """General settings (logging, service name), read through an EnvProxy with the ``HYPERION_COMMON`` prefix."""

    env_proxy = EnvProxy(prefix="HYPERION_COMMON")

    log_pretty: bool = Field(
        description="If truthy, logs will be pretty-printed. Defaults to JSON logs.",
        default=False,
    )
    log_level: str = Field(description="The minimum log level.", default="DEBUG")
    service_name: str = Field(default="hyperion", description="The name of the service.")
14
+ service_name: str = Field(description="The name of the service.", default="hyperion")
15
+
16
+
17
class StorageConfig(EnvConfig):
    """Storage settings (buckets, prefixes, schema store, cache), read through an EnvProxy with the ``HYPERION_STORAGE`` prefix."""

    env_proxy = EnvProxy(prefix="HYPERION_STORAGE")

    max_concurrency: int = Field(default=5, description="The maximum number of concurrent storage tasks.")

    # Bucket names have no default value.
    data_lake_bucket: str = Field(description="Data lake store bucket name.")
    feature_store_bucket: str = Field(description="Feature store bucket name.")
    persistent_store_bucket: str = Field(description="Persistent store bucket name.")

    data_lake_prefix: str = Field(default="", description="Optional data lake paths prefix.")
    feature_store_prefix: str = Field(default="", description="Optional feature store paths prefix.")
    persistent_store_prefix: str = Field(default="", description="Optional persistent store path prefix.")

    # NOTE(review): SchemaStore._create_new handles only 'file', 's3' or scheme-less paths;
    # 'python' appears unsupported there — confirm the description.
    schema_path: str = Field(description="The path to the schema store. Supported schemes: 'file', 's3', 'python'.")

    # Optional: the explicit default=None matches the description ("If not set, ...") and the
    # other optional fields of this module.
    cache_dynamodb_table: str | None = Field(
        default=None,
        description="The name of DynamoDB table that will be used as key-value cache. If not set, memory cache is used.",
    )
    cache_dynamodb_default_ttl: int = Field(
        default=60, description="Default TTL for key-value cache in DynamoDB (in seconds)."
    )

    cache_local_path: str | None = Field(
        description="Path to a directory where cache is going to be stored to.", default=None
    )
    cache_key_prefix: str = Field(default="", description="Optional prefix for key-value cache items.")
43
+
44
+
45
class QueueConfig(EnvConfig):
    """Queue settings, read through an EnvProxy with the ``HYPERION_QUEUE`` prefix."""

    env_proxy = EnvProxy(prefix="HYPERION_QUEUE")

    # Optional; None when no queue URL is configured.
    url: str | None = Field(default=None, description="The URL of the SQS queue.")
49
+
50
+
51
class SecretsConfig(EnvConfig):
    """Secrets-management settings, read through an EnvProxy with the ``HYPERION_SECRETS`` prefix."""

    env_proxy = EnvProxy(prefix="HYPERION_SECRETS")

    backend: str | None = Field(
        default=None,
        description="The backend to use for secrets management. Can only be `AWSSecretsManager` or None.",
    )
56
+ )
57
+
58
+
59
class GeoConfig(EnvConfig):
    """Geo settings, read through an EnvProxy with the ``HYPERION_GEO`` prefix."""

    env_proxy = EnvProxy(prefix="HYPERION_GEO")

    # Optional (typed `str | None`). The explicit default=None matches the other optional
    # fields, so an unset key does not depend on how EnvConfig treats fields without a default.
    gmaps_api_key: str | None = Field(default=None, description="Google Maps API key.")
63
+
64
+
65
class HttpConfig(EnvConfig):
    """HTTP settings, read through an EnvProxy with the ``HYPERION_HTTP`` prefix."""

    env_proxy = EnvProxy(prefix="HYPERION_HTTP")

    # Both optional and default to None.
    # NOTE(review): presumably proxy URLs for HTTP and HTTPS traffic respectively — confirm with consumers.
    proxy_http: str | None = Field(default=None)
    proxy_https: str | None = Field(default=None)
70
+
71
+
72
# Module-level config singletons, created when this module is imported.
# NOTE(review): how EnvConfig handles a field without a default whose variable is unset is
# not visible here — confirm whether importing this module can fail on missing variables.
config = CommonConfig()
storage_config = StorageConfig()
geo_config = GeoConfig()
queue_config = QueueConfig()
secrets_config = SecretsConfig()
http_config = HttpConfig()