hyperion-sdk 0.2.0.dev1741815359__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyperion/__init__.py +0 -0
- hyperion/asyncutils.py +79 -0
- hyperion/catalog/__init__.py +8 -0
- hyperion/catalog/catalog.py +623 -0
- hyperion/catalog/schema.py +153 -0
- hyperion/collections/__init__.py +0 -0
- hyperion/collections/asset_collection.py +285 -0
- hyperion/config.py +77 -0
- hyperion/dateutils.py +238 -0
- hyperion/entities/__init__.py +0 -0
- hyperion/entities/catalog.py +190 -0
- hyperion/infrastructure/__init__.py +0 -0
- hyperion/infrastructure/aws.py +220 -0
- hyperion/infrastructure/cache.py +396 -0
- hyperion/infrastructure/geo/__init__.py +7 -0
- hyperion/infrastructure/geo/gmaps.py +124 -0
- hyperion/infrastructure/geo/location.py +186 -0
- hyperion/infrastructure/http.py +62 -0
- hyperion/infrastructure/keyval.py +264 -0
- hyperion/infrastructure/queue.py +151 -0
- hyperion/infrastructure/secrets.py +63 -0
- hyperion/logging.py +122 -0
- hyperion/py.typed +0 -0
- hyperion/sources/__init__.py +0 -0
- hyperion/sources/base.py +105 -0
- hyperion/typeutils.py +52 -0
- hyperion_sdk-0.2.0.dev1741815359.dist-info/METADATA +476 -0
- hyperion_sdk-0.2.0.dev1741815359.dist-info/RECORD +29 -0
- hyperion_sdk-0.2.0.dev1741815359.dist-info/WHEEL +4 -0
hyperion/dateutils.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
import datetime
|
|
2
|
+
import re
|
|
3
|
+
from collections.abc import Iterator
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Literal, TypeAlias, cast
|
|
6
|
+
|
|
7
|
+
from dateutil.relativedelta import relativedelta
|
|
8
|
+
|
|
9
|
+
from hyperion.logging import get_logger
|
|
10
|
+
|
|
11
|
+
# Supported time-resolution unit codes, finest first. Case matters: "m" is minutes, "M" is months.
TIME_UNITS = "s m h d w M y".split()
# Matches "<digits><unit>" expressions such as "5s", "15m" or "3M".
PATT_TIME_RESOLUTION = re.compile(r"(?P<value>\d+)(?P<unit>[" + "".join(TIME_UNITS) + "])")
TimeResolutionUnit: TypeAlias = Literal["s", "m", "h", "d", "w", "M", "y"]

logger = get_logger("hyperion-dateutils")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass(frozen=True, eq=True)
|
|
19
|
+
class TimeResolution:
|
|
20
|
+
value: int
|
|
21
|
+
unit: TimeResolutionUnit
|
|
22
|
+
|
|
23
|
+
def __post_init__(self) -> None:
|
|
24
|
+
if not isinstance(self.value, int):
|
|
25
|
+
super().__setattr__("value", int(self.value))
|
|
26
|
+
if self.unit not in TIME_UNITS:
|
|
27
|
+
raise ValueError(f"Unknown time unit {self.unit!r}. Pick one of {', '.join(TIME_UNITS)}")
|
|
28
|
+
|
|
29
|
+
def __repr__(self) -> str:
|
|
30
|
+
return f"{self.value}{self.unit}"
|
|
31
|
+
|
|
32
|
+
@staticmethod
|
|
33
|
+
def from_str(string: str) -> "TimeResolution":
|
|
34
|
+
if (rematch := PATT_TIME_RESOLUTION.match(string)) is None:
|
|
35
|
+
raise ValueError(f"Invalid time resolution specification {string!r}. Use expressions such as 1d, 5s or 3M.")
|
|
36
|
+
value = int(rematch.group("value"))
|
|
37
|
+
unit = cast(TimeResolutionUnit, rematch.group("unit"))
|
|
38
|
+
return TimeResolution(value=value, unit=unit)
|
|
39
|
+
|
|
40
|
+
@property
|
|
41
|
+
def delta(self) -> datetime.timedelta | relativedelta:
|
|
42
|
+
match self.unit:
|
|
43
|
+
case "s":
|
|
44
|
+
return datetime.timedelta(seconds=self.value)
|
|
45
|
+
case "m":
|
|
46
|
+
return datetime.timedelta(minutes=self.value)
|
|
47
|
+
case "h":
|
|
48
|
+
return datetime.timedelta(hours=self.value)
|
|
49
|
+
case "d":
|
|
50
|
+
return datetime.timedelta(days=self.value)
|
|
51
|
+
case "w":
|
|
52
|
+
return datetime.timedelta(days=self.value * 7)
|
|
53
|
+
case "M":
|
|
54
|
+
return relativedelta(months=self.value)
|
|
55
|
+
case "y":
|
|
56
|
+
return relativedelta(years=self.value)
|
|
57
|
+
case _:
|
|
58
|
+
raise ValueError(f"Unsupported time unit {self.unit!r}.")
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def truncate_datetime(base: datetime.datetime | datetime.date, unit: TimeResolutionUnit) -> datetime.datetime:
    """Round ``base`` down to the start of the ``unit`` it falls in.

    Every field finer than ``unit`` is reset:
    - ``"s"`` drops microseconds, and ``"h"`` zeroes minutes and everything below.
    - ``"d"`` goes to midnight, and ``"w"`` to midnight of the Monday of the same week.
    - ``"M"`` goes to midnight on the 1st of the month, and ``"y"`` to midnight on January 1st.

    A plain ``date`` is first promoted to a UTC-midnight ``datetime``. A ``datetime`` keeps its
    own ``tzinfo``, so naive input gives naive output.

    Raises:
        ValueError: If ``unit`` is not one of ``TIME_UNITS``.
    """
    if not isinstance(base, datetime.datetime):
        base = datetime.datetime(base.year, base.month, base.day, tzinfo=datetime.timezone.utc)

    if unit == "s":
        return base.replace(microsecond=0)
    if unit == "m":
        return base.replace(second=0, microsecond=0)
    if unit == "h":
        return base.replace(minute=0, second=0, microsecond=0)
    if unit in ("d", "w", "M", "y"):
        midnight = base.replace(hour=0, minute=0, second=0, microsecond=0)
        if unit == "d":
            return midnight
        if unit == "w":
            # Step back to the Monday (weekday() == 0) of the same week.
            return midnight - datetime.timedelta(days=base.weekday())
        if unit == "M":
            return midnight.replace(day=1)
        return midnight.replace(month=1, day=1)
    raise ValueError(f"Unknown time unit {unit!r}. Pick one of {', '.join(TIME_UNITS)}")
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def iter_dates_between(
|
|
84
|
+
start_date: datetime.datetime | datetime.date,
|
|
85
|
+
end_date: datetime.datetime | datetime.date,
|
|
86
|
+
granularity: TimeResolutionUnit,
|
|
87
|
+
) -> Iterator[datetime.datetime]:
|
|
88
|
+
"""
|
|
89
|
+
Iterate over datetimes between start_date and end_date with steps based on the given granularity.
|
|
90
|
+
Includes the start_date and may include end_date (if end date is reachable from start date with given granularity).
|
|
91
|
+
|
|
92
|
+
:param start_date: The starting datetime.
|
|
93
|
+
:param end_date: The ending datetime.
|
|
94
|
+
:param granularity: The granularity for steps (e.g., "d" for days, "M" for months).
|
|
95
|
+
:return: An iterator of datetime objects.
|
|
96
|
+
"""
|
|
97
|
+
start_date = assure_timezone(start_date)
|
|
98
|
+
end_date = assure_timezone(end_date)
|
|
99
|
+
logger.debug(
|
|
100
|
+
"Generating dates between two points.", start_date=start_date.isoformat(), end_date=end_date.isoformat()
|
|
101
|
+
)
|
|
102
|
+
if start_date > end_date:
|
|
103
|
+
raise ValueError("Start date cannot be later than end date.")
|
|
104
|
+
|
|
105
|
+
current = start_date
|
|
106
|
+
|
|
107
|
+
delta: datetime.timedelta | relativedelta
|
|
108
|
+
|
|
109
|
+
match granularity:
|
|
110
|
+
case "s":
|
|
111
|
+
delta = datetime.timedelta(seconds=1)
|
|
112
|
+
case "m":
|
|
113
|
+
delta = datetime.timedelta(minutes=1)
|
|
114
|
+
case "h":
|
|
115
|
+
delta = datetime.timedelta(hours=1)
|
|
116
|
+
case "d":
|
|
117
|
+
delta = datetime.timedelta(days=1)
|
|
118
|
+
case "w":
|
|
119
|
+
delta = datetime.timedelta(weeks=1)
|
|
120
|
+
case "M":
|
|
121
|
+
delta = relativedelta(months=1)
|
|
122
|
+
case "y":
|
|
123
|
+
delta = relativedelta(years=1)
|
|
124
|
+
case default:
|
|
125
|
+
raise ValueError(f"Unsupported granularity {default!r}.")
|
|
126
|
+
|
|
127
|
+
while current <= end_date:
|
|
128
|
+
yield current
|
|
129
|
+
current += delta
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def quantize_datetime(base: datetime.datetime, resolution: TimeResolution | str) -> datetime.datetime:
|
|
133
|
+
"""
|
|
134
|
+
Quantize a datetime to the next interval based on the specified resolution.
|
|
135
|
+
|
|
136
|
+
This function aligns a given datetime to the next moment in time defined
|
|
137
|
+
by the resolution. The resolution is expressed as a unit (seconds, minutes,
|
|
138
|
+
hours, or days) and a value (e.g., 5 seconds, 15 minutes).
|
|
139
|
+
|
|
140
|
+
**Important Notes:**
|
|
141
|
+
- The resolution is always calculated relative to the higher unit, which may lead
|
|
142
|
+
to overlapping intervals for non-standard values. For example:
|
|
143
|
+
- A 7-second resolution could result in intervals ending at 12:50:56 and 12:51:03,
|
|
144
|
+
with another interval starting at 12:51:00, causing overlaps.
|
|
145
|
+
- To avoid such overlaps, it is recommended to use resolutions that are divisors of
|
|
146
|
+
the higher unit (e.g., 60 for seconds and minutes).
|
|
147
|
+
|
|
148
|
+
Parameters:
|
|
149
|
+
- base (datetime.datetime): The datetime to quantize.
|
|
150
|
+
- resolution (TimeResolution | str): The resolution for quantization.
|
|
151
|
+
If a string is provided, it should follow the format "{value}{unit}"
|
|
152
|
+
(e.g., "5s", "15m", "2h").
|
|
153
|
+
|
|
154
|
+
Returns:
|
|
155
|
+
- datetime.datetime: The quantized datetime.
|
|
156
|
+
|
|
157
|
+
Raises:
|
|
158
|
+
- ValueError: If the resolution unit is unsupported.
|
|
159
|
+
"""
|
|
160
|
+
resolution = resolution if isinstance(resolution, TimeResolution) else TimeResolution.from_str(resolution)
|
|
161
|
+
base_truncated = truncate_datetime(base, resolution.unit)
|
|
162
|
+
|
|
163
|
+
def _get_shift(value: int) -> int:
|
|
164
|
+
return resolution.value - (value % resolution.value)
|
|
165
|
+
|
|
166
|
+
match resolution.unit:
|
|
167
|
+
case "s":
|
|
168
|
+
seconds_shift = _get_shift(base.second)
|
|
169
|
+
return base_truncated + datetime.timedelta(seconds=seconds_shift)
|
|
170
|
+
case "m":
|
|
171
|
+
minutes_shift = _get_shift(base.minute)
|
|
172
|
+
return base_truncated + datetime.timedelta(minutes=minutes_shift)
|
|
173
|
+
case "h":
|
|
174
|
+
hours_shift = _get_shift(base.hour)
|
|
175
|
+
return base_truncated + datetime.timedelta(hours=hours_shift)
|
|
176
|
+
case "d":
|
|
177
|
+
days_shift = _get_shift(base.day)
|
|
178
|
+
return base_truncated + datetime.timedelta(days=days_shift)
|
|
179
|
+
case "w":
|
|
180
|
+
days_shift = resolution.value * 7 - (base.weekday() % resolution.value)
|
|
181
|
+
return base_truncated + datetime.timedelta(days=days_shift)
|
|
182
|
+
case "M":
|
|
183
|
+
months_shift = _get_shift(base.month)
|
|
184
|
+
return base_truncated + relativedelta(months=months_shift)
|
|
185
|
+
case "y":
|
|
186
|
+
years_shift = _get_shift(base.year)
|
|
187
|
+
return base_truncated + relativedelta(years=years_shift)
|
|
188
|
+
raise ValueError(f"Unsupported resolution unit {resolution.unit!r} for quantization.") # pragma: no cover
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def assure_timezone(
|
|
192
|
+
base: datetime.datetime | datetime.date, tz: datetime.timezone = datetime.timezone.utc
|
|
193
|
+
) -> datetime.datetime:
|
|
194
|
+
"""Assure datetime has a datetime and return it timezone-aware if not."""
|
|
195
|
+
if not isinstance(base, datetime.datetime):
|
|
196
|
+
return datetime.datetime(base.year, base.month, base.day, tzinfo=tz)
|
|
197
|
+
if base.tzinfo is not None:
|
|
198
|
+
if base.tzinfo == tz:
|
|
199
|
+
return base
|
|
200
|
+
return base.astimezone(tz)
|
|
201
|
+
logger.warning(f"A timezone-unaware timestamp was given, assuming {tz!r}.")
|
|
202
|
+
return base.replace(tzinfo=tz)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def get_date_pattern(date: datetime.datetime, unit: TimeResolutionUnit) -> str:
    """Render ``date`` as an ISO-like prefix that stops at ``unit``.

    The value is truncated with ``truncate_datetime`` first. For weeks, the result is therefore
    the Monday that starts the week, formatted as a full day.

    Examples:
        >>> get_date_pattern(datetime.datetime(2025, 1, 12, 12), "h")
        '2025-01-12T12'
        >>> get_date_pattern(datetime.datetime(2025, 1, 12, 12), "d")
        '2025-01-12'
        >>> get_date_pattern(datetime.datetime(2025, 1, 12, 12), "M")
        '2025-01'
        >>> get_date_pattern(datetime.datetime(2025, 1, 12, 12), "y")
        '2025'

    Raises:
        ValueError: If ``unit`` is not a supported time unit.
    """
    truncated = truncate_datetime(date, unit)
    formats_by_unit = {
        "s": "%Y-%m-%dT%H:%M:%S",
        "m": "%Y-%m-%dT%H:%M",
        "h": "%Y-%m-%dT%H",
        "d": "%Y-%m-%d",
        "w": "%Y-%m-%d",  # Week is special, keep the day (the Monday starting the week).
        "M": "%Y-%m",
        "y": "%Y",
    }
    pattern = formats_by_unit.get(unit)
    if pattern is None:
        raise ValueError(f"Unsupported time unit {unit!r}.")
    return truncated.strftime(pattern)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def utcnow() -> datetime.datetime:
    """Return the current moment as a timezone-aware datetime in UTC."""
    utc = datetime.timezone.utc
    return datetime.datetime.now(utc)
|
|
File without changes
|
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
import datetime
|
|
2
|
+
import json
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import ClassVar, Literal, Protocol
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
from hyperion.dateutils import TimeResolution, assure_timezone
|
|
9
|
+
|
|
10
|
+
# Discriminator values for the kinds of asset the catalog knows about.
AssetType = Literal[
    "data_lake",
    "feature",
    "persistent_store",
]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def get_prefixed_path(path: str, prefix: str = "") -> str:
    """Join ``prefix`` and ``path`` with a single slash.

    Slashes around ``prefix`` are stripped first. If nothing is left, ``path`` is returned
    as-is. ``path`` itself is never modified.

    Args:
        path (str): The path to prefix.
        prefix (str): Optional prefix; leading/trailing slashes are ignored.

    Returns:
        str: ``"<prefix>/<path>"``, or just ``path`` when the prefix is empty.
    """
    folder = prefix.strip("/")
    return f"{folder}/{path}" if folder else path
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class AssetProtocol(Protocol):
    """Structural interface shared by every kind of catalog asset.

    Attributes:
        asset_type (ClassVar[AssetType]): Discriminator for the kind of asset.
        name (str): The name of the asset.
        schema_version (int): The schema version of the asset.
    """

    asset_type: ClassVar[AssetType]

    @property
    def name(self) -> str:
        """The name of the asset."""
        ...

    @property
    def schema_version(self) -> int:
        """The schema version of the asset."""
        ...

    def get_path(self, prefix: str = "") -> str:
        """Return the storage path of the asset, optionally placed under ``prefix``.

        Args:
            prefix (str): The prefix for the path.

        Returns:
            str: The path for the asset.
        """
        ...

    def to_metadata(self) -> dict[str, str]:
        """Return the asset's descriptive fields as a string-to-string mapping.

        Returns:
            dict[str, str]: The metadata for the asset.
        """
        ...
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@dataclass(frozen=True, eq=True)
class DataLakeAsset:
    """An immutable, date-partitioned data lake asset.

    Its path is ``[<prefix>/]<name>/date=<ISO-8601 timestamp>/v<schema_version>.avro``. The
    timestamp is normalized with ``assure_timezone`` (UTC by default).

    Attributes:
        asset_type (ClassVar[AssetType]): Always ``"data_lake"``.
        name (str): The name of the asset (first path component).
        date (datetime.datetime): The date partition of the asset.
        schema_version (int): The schema version of the asset; defaults to 1.
    """

    asset_type: ClassVar[AssetType] = "data_lake"
    name: str
    date: datetime.datetime
    schema_version: int = 1

    def get_path(self, prefix: str = "") -> str:
        """Return the storage path of the asset, optionally placed under ``prefix``."""
        iso_date = assure_timezone(self.date).isoformat()
        relative_path = f"{self.name}/date={iso_date}/v{self.schema_version}.avro"
        return get_prefixed_path(relative_path, prefix)

    def to_metadata(self) -> dict[str, str]:
        """Return name, date and schema version as strings.

        Note: unlike ``get_path``, the date here is serialized as stored, without timezone
        normalization.
        """
        return dict(
            name=self.name,
            date=self.date.isoformat(),
            schema_version=str(self.schema_version),
        )

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.get_path()!r})"
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@dataclass(frozen=True, eq=True)
class PersistentStoreAsset:
    """An immutable, unpartitioned persistent store asset.

    Its path is ``[<prefix>/]<name>/v<schema_version>.avro``.

    Attributes:
        asset_type (ClassVar[AssetType]): Always ``"persistent_store"``.
        name (str): The name of the asset.
        schema_version (int): The schema version of the asset; defaults to 1.
    """

    asset_type: ClassVar[AssetType] = "persistent_store"
    name: str
    schema_version: int = 1

    def get_path(self, prefix: str = "") -> str:
        """Return the storage path of the asset, optionally placed under ``prefix``."""
        file_name = f"v{self.schema_version}.avro"
        return get_prefixed_path(f"{self.name}/{file_name}", prefix)

    def to_metadata(self) -> dict[str, str]:
        """Return name and schema version as strings."""
        return dict(name=self.name, schema_version=str(self.schema_version))

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.get_path()!r})"
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
@dataclass(frozen=True, eq=True)
class FeatureAsset:
    """Feature asset.

    Its path is:

        [<prefix>/]<name>.<resolution>/[<key>=<value>/...]partition_date=<ISO timestamp>/v<schema_version>.avro

    Partition keys are sorted by name, and the partition date is normalized with
    ``assure_timezone``.

    Instances are hashable even though ``partition_keys`` is a dict (see ``__hash__``).

    Attributes:
        asset_type (ClassVar[AssetType]): The type of the asset.
        name (str): The name of the asset.
        partition_date (datetime.datetime): The partition timestamp of the asset.
        resolution (TimeResolution | str): The resolution of the asset.
        schema_version (int): The schema version of the asset.
        partition_keys (dict[str, str]): The partition keys of the asset.
    """

    asset_type: ClassVar[AssetType] = "feature"
    name: str
    partition_date: datetime.datetime
    resolution: TimeResolution | str
    schema_version: int = 1
    partition_keys: dict[str, str] = field(default_factory=dict)

    def __hash__(self) -> int:
        # The hash generated for a frozen, eq dataclass would hash `partition_keys` directly and
        # raise TypeError (dicts are unhashable). Hash an order-independent snapshot instead.
        # Equal instances have equal fields, so this stays consistent with the generated __eq__.
        return hash(
            (
                self.name,
                self.partition_date,
                self.resolution,
                self.schema_version,
                frozenset(self.partition_keys.items()),
            )
        )

    @property
    def time_resolution(self) -> TimeResolution:
        """The time resolution of the feature, parsed from ``resolution`` if it is a string."""
        if isinstance(self.resolution, TimeResolution):
            return self.resolution
        return TimeResolution.from_str(self.resolution)

    @property
    def feature_name(self) -> str:
        """The name of the feature including the time resolution, e.g. ``"name.1d"``."""
        return f"{self.name}.{self.time_resolution!r}"

    def _get_partition_keys_prefix(self) -> str:
        # "key=value" segments joined by "/", in sorted key order so the path is deterministic.
        key_names = sorted(self.partition_keys.keys())
        return ("/".join(f"{key}={self.partition_keys[key]}" for key in key_names)).strip("/")

    def get_path(self, prefix: str = "") -> str:
        """Get the path for the asset with the given prefix."""
        partition_date = assure_timezone(self.partition_date).isoformat()
        keys_prefix = self._get_partition_keys_prefix()
        if keys_prefix:
            keys_prefix = keys_prefix + "/"
        return get_prefixed_path(
            f"{self.feature_name}/{keys_prefix}partition_date={partition_date}/v{self.schema_version}.avro", prefix
        )

    def to_metadata(self) -> dict[str, str]:
        """Get the metadata for the asset; partition keys are serialized as a JSON object."""
        return {
            "name": self.name,
            "partition_date": self.partition_date.isoformat(),
            "schema_version": str(self.schema_version),
            "partition_keys": json.dumps(self.partition_keys),
        }

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.get_path()!r})"
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class FeatureModel(BaseModel):
    """A base class for feature models.

    You may use this base class (along with pydantic's BaseModel) to define type-safe feature models.
    Use with "AssetCollection" to make powerful typed feature collections.
    """

    # Name of the feature asset; NotImplemented placeholder, presumably set by subclasses (confirm).
    asset_name: ClassVar[str] = NotImplemented
    # Time resolution of the feature; NotImplemented placeholder, presumably set by subclasses (confirm).
    resolution: ClassVar[TimeResolution] = NotImplemented
    # Schema version of the feature data; defaults to 1.
    schema_version: ClassVar[int] = 1
|
|
File without changes
|
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
"""AWS helpers and methods."""
|
|
2
|
+
|
|
3
|
+
import datetime
|
|
4
|
+
import logging
|
|
5
|
+
from asyncio import Semaphore
|
|
6
|
+
from collections.abc import AsyncIterator, Iterator
|
|
7
|
+
from contextlib import ExitStack, asynccontextmanager
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from enum import Enum
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import IO, BinaryIO, TypeVar, cast
|
|
12
|
+
|
|
13
|
+
import aioboto3
|
|
14
|
+
import boto3
|
|
15
|
+
import botocore.exceptions
|
|
16
|
+
|
|
17
|
+
from hyperion.config import storage_config
|
|
18
|
+
from hyperion.logging import get_logger
|
|
19
|
+
|
|
20
|
+
# Anything the S3 helpers accept as a file: a filesystem path (str or Path) or an open binary stream.
PathOrIOBinary = str | Path | BinaryIO | IO[bytes]

# Generic type variable.
T = TypeVar("T")

logger = get_logger("aws")

# Only let botocore's warnings and errors through.
for _botocore_logger_name in ("botocore.endpoint", "botocore"):
    logging.getLogger(_botocore_logger_name).setLevel(logging.WARNING)
del _botocore_logger_name
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class S3StorageClass(str, Enum):
    """S3 storage classes.

    Member values are the ``StorageClass`` strings used by the S3 API. Because the enum also
    subclasses ``str``, members compare equal to their raw values
    (``S3StorageClass.STANDARD == "STANDARD"``).
    """

    STANDARD = "STANDARD"
    REDUCED_REDUNDANCY = "REDUCED_REDUNDANCY"
    STANDARD_IA = "STANDARD_IA"
    ONEZONE_IA = "ONEZONE_IA"
    INTELLIGENT_TIERING = "INTELLIGENT_TIERING"
    GLACIER = "GLACIER"
    DEEP_ARCHIVE = "DEEP_ARCHIVE"
    OUTPOSTS = "OUTPOSTS"
    GLACIER_IR = "GLACIER_IR"
    SNOW = "SNOW"
    EXPRESS_ONEZONE = "EXPRESS_ONEZONE"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class S3ObjectAttributes:
    """Attributes of an S3 object, as built by ``S3Client.get_object_attributes``."""

    # Last modification time, from the response's "LastModified" field.
    last_modified: datetime.datetime
    # Entity tag of the object, from the response's "ETag" field.
    etag: str
    # Storage class the object is kept in.
    storage_class: S3StorageClass
    # Object size from the response's "ObjectSize" field; presumably in bytes (per the S3 API — confirm).
    object_size: int
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class S3Client:
    """S3 client wrapping a synchronous boto3 client and an aioboto3 session.

    Asynchronous operations share one class-level semaphore, sized by
    ``storage_config.max_concurrency``, which caps how many of them run at once.
    """

    _storage_semaphore: Semaphore | None = None

    @classmethod
    @asynccontextmanager
    async def semaphore(cls) -> AsyncIterator[None]:
        """Hold the shared storage semaphore for the duration of the ``async with`` block."""
        # Lazily initialize the semaphore instance (stored on the class, shared by all instances).
        if cls._storage_semaphore is None:
            logger.debug(
                "Initializing new storage operations semaphore.", max_concurrency=storage_config.max_concurrency
            )
            cls._storage_semaphore = Semaphore(storage_config.max_concurrency)
        logger.debug("Attempting to acquire the storage operations semaphore.")
        async with cls._storage_semaphore:
            logger.debug("Green light, semaphore open.")
            yield

    def __init__(self) -> None:
        self._client = boto3.client("s3")
        self._aio_session = aioboto3.Session()

    async def upload_async(self, file: PathOrIOBinary, bucket: str, name: str) -> None:
        """Upload a file to S3 asynchronously, while holding the shared storage semaphore.

        A path is opened here and closed again when the upload finishes or fails.

        Args:
            file (PathOrIOBinary): The file to upload.
            bucket (str): The bucket to upload the file to.
            name (str): The name of the file in the bucket.

        Raises:
            botocore.exceptions.ClientError: If the upload fails.
        """
        with ExitStack() as file_context:
            if isinstance(file, str | Path):
                path = Path(file)
                logger.debug("Uploading from path.", path=path.as_posix(), bucket=bucket, name=name)
                file = file_context.enter_context(path.open("rb"))
            async with self.semaphore(), self._aio_session.client("s3") as s3:
                try:
                    await s3.upload_fileobj(file, bucket, name)
                except botocore.exceptions.ClientError:
                    logger.error("Error when uploading file to S3.", bucket=bucket, name=name)
                    raise

    def iter_objects(self, bucket: str, prefix: str) -> Iterator[str]:
        """Iterate over objects in an S3 bucket.

        Args:
            bucket (str): The bucket to list objects in.
            prefix (str): The prefix to filter objects by.

        Yields:
            Iterator[str]: The keys of the objects in the bucket.
        """
        paginator = self._client.get_paginator("list_objects_v2")
        pagination_config = {"StartingToken": None}
        logger.debug("Listing contents of a bucket.", bucket=bucket, prefix=prefix)
        response_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, PaginationConfig=pagination_config)
        for response in response_iterator:
            logger.debug("Received a response from S3.", response=response)
            # Pages with no matching objects carry no "Contents" key.
            if "Contents" not in response:
                continue
            yield from (s3_object["Key"] for s3_object in response["Contents"])

    def upload(self, file: PathOrIOBinary, bucket: str, name: str) -> None:
        """Upload a file to S3.

        Args:
            file (PathOrIOBinary): The file to upload.
            bucket (str): The bucket to upload the file to.
            name (str): The name of the file in the bucket.

        Raises:
            botocore.exceptions.ClientError: If the upload fails.
        """
        if isinstance(file, str | Path):
            file = Path(file)
            logger.debug("Uploading from path.", path=file.as_posix(), bucket=bucket, name=name)
            try:
                self._client.upload_file(file, bucket, name)
            except botocore.exceptions.ClientError:
                logger.error("Error when uploading file to S3.", file=file.as_posix(), bucket=bucket, name=name)
                raise
            return
        logger.debug("Uploading from an open file stream.", bucket=bucket, name=name)
        try:
            self._client.upload_fileobj(file, bucket, name)
        except botocore.exceptions.ClientError:
            logger.error("Error when uploading file to S3.", bucket=bucket, name=name)
            raise

    def get_object_attributes(self, bucket: str, name: str) -> S3ObjectAttributes:
        """Get the attributes of an object in S3.

        Args:
            bucket (str): The bucket containing the object.
            name (str): The name of the object.

        Returns:
            S3ObjectAttributes: The attributes of the object.

        Raises:
            botocore.exceptions.ClientError: If the S3 request fails.
            Exception: Re-raised (after logging) if the response cannot be converted.
        """
        try:
            response = self._client.get_object_attributes(
                Bucket=bucket,
                Key=name,
                ObjectAttributes=[
                    "ETag",
                    "Checksum",
                    "ObjectParts",
                    "StorageClass",
                    "ObjectSize",
                ],
            )
        except botocore.exceptions.ClientError:
            logger.error("Error when getting object attributes from S3.", bucket=bucket, name=name)
            raise
        try:
            return S3ObjectAttributes(
                last_modified=response["LastModified"],
                etag=response["ETag"],
                # Per the GetObjectAttributes API reference, StorageClass is returned for every
                # class except STANDARD, so its absence means STANDARD rather than an error.
                storage_class=S3StorageClass(response.get("StorageClass", S3StorageClass.STANDARD)),
                object_size=int(response["ObjectSize"]),
            )
        except Exception:
            logger.error(
                "Failed to get object attributes, the response is probably invalid.",
                bucket=bucket,
                object_name=name,
                response=response,
            )
            raise

    def download(self, bucket: str, name: str, file: PathOrIOBinary) -> None:
        """Download a file from S3.

        Args:
            bucket (str): The bucket containing the file.
            name (str): The name of the file.
            file (PathOrIOBinary): The file to download to.

        Raises:
            botocore.exceptions.ClientError: If the download fails.
        """
        if isinstance(file, str | Path):
            file = Path(file)
            logger.debug("Downloading into a path.", path=file.as_posix(), bucket=bucket, name=name)
            try:
                self._client.download_file(bucket, name, file)
            except botocore.exceptions.ClientError:
                logger.error("Error when downloading file from S3.", bucket=bucket, name=name, path=file.as_posix())
                raise
            return
        logger.debug("Downloading into a file stream.", bucket=bucket, name=name)
        try:
            self._client.download_fileobj(bucket, name, file)
        except botocore.exceptions.ClientError:
            logger.error("Error when downloading file from S3.", bucket=bucket, name=name)
            raise

    def download_as_string(self, bucket: str, name: str) -> str:
        """Download an object from S3 as a UTF-8 decoded string.

        Args:
            bucket (str): The bucket containing the object.
            name (str): The name of the object.

        Returns:
            str: The object as a string.

        Raises:
            botocore.exceptions.ClientError: If the download fails.
        """
        logger.debug("Downloading object as a string.", bucket=bucket, name=name)
        try:
            return cast(str, self._client.get_object(Bucket=bucket, Key=name)["Body"].read().decode("utf-8"))
        except botocore.exceptions.ClientError:
            logger.error("Error when downloading file from S3.", bucket=bucket, name=name)
            raise
|