hyperion-sdk 0.2.0.dev1741815359__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hyperion/dateutils.py ADDED
@@ -0,0 +1,238 @@
1
+ import datetime
2
+ import re
3
+ from collections.abc import Iterator
4
+ from dataclasses import dataclass
5
+ from typing import Literal, TypeAlias, cast
6
+
7
+ from dateutil.relativedelta import relativedelta
8
+
9
+ from hyperion.logging import get_logger
10
+
11
+ TIME_UNITS = ["s", "m", "h", "d", "w", "M", "y"]
12
+ PATT_TIME_RESOLUTION = re.compile(rf"(?P<value>\d+)(?P<unit>[{''.join(TIME_UNITS)}])")
13
+ TimeResolutionUnit: TypeAlias = Literal["s", "m", "h", "d", "w", "M", "y"]
14
+
15
+ logger = get_logger("hyperion-dateutils")
16
+
17
+
18
+ @dataclass(frozen=True, eq=True)
19
+ class TimeResolution:
20
+ value: int
21
+ unit: TimeResolutionUnit
22
+
23
+ def __post_init__(self) -> None:
24
+ if not isinstance(self.value, int):
25
+ super().__setattr__("value", int(self.value))
26
+ if self.unit not in TIME_UNITS:
27
+ raise ValueError(f"Unknown time unit {self.unit!r}. Pick one of {', '.join(TIME_UNITS)}")
28
+
29
+ def __repr__(self) -> str:
30
+ return f"{self.value}{self.unit}"
31
+
32
+ @staticmethod
33
+ def from_str(string: str) -> "TimeResolution":
34
+ if (rematch := PATT_TIME_RESOLUTION.match(string)) is None:
35
+ raise ValueError(f"Invalid time resolution specification {string!r}. Use expressions such as 1d, 5s or 3M.")
36
+ value = int(rematch.group("value"))
37
+ unit = cast(TimeResolutionUnit, rematch.group("unit"))
38
+ return TimeResolution(value=value, unit=unit)
39
+
40
+ @property
41
+ def delta(self) -> datetime.timedelta | relativedelta:
42
+ match self.unit:
43
+ case "s":
44
+ return datetime.timedelta(seconds=self.value)
45
+ case "m":
46
+ return datetime.timedelta(minutes=self.value)
47
+ case "h":
48
+ return datetime.timedelta(hours=self.value)
49
+ case "d":
50
+ return datetime.timedelta(days=self.value)
51
+ case "w":
52
+ return datetime.timedelta(days=self.value * 7)
53
+ case "M":
54
+ return relativedelta(months=self.value)
55
+ case "y":
56
+ return relativedelta(years=self.value)
57
+ case _:
58
+ raise ValueError(f"Unsupported time unit {self.unit!r}.")
59
+
60
+
61
def truncate_datetime(base: datetime.datetime | datetime.date, unit: TimeResolutionUnit) -> datetime.datetime:
    """Round ``base`` down to the start of the ``unit`` it falls into.

    Every field finer than ``unit`` is reset. Weeks start on Monday. A plain ``date`` is first
    turned into midnight UTC of that day; a ``datetime`` keeps whatever tzinfo it already has.

    Args:
        base (datetime.datetime | datetime.date): The moment to truncate.
        unit (TimeResolutionUnit): The coarsest unit to keep.

    Returns:
        datetime.datetime: The truncated moment.

    Raises:
        ValueError: If ``unit`` is not a known time unit.
    """
    if isinstance(base, datetime.datetime):
        moment = base
    else:
        moment = datetime.datetime(base.year, base.month, base.day, tzinfo=datetime.timezone.utc)

    # Sub-day units only clear the trailing clock fields.
    if unit == "s":
        return moment.replace(microsecond=0)
    if unit == "m":
        return moment.replace(second=0, microsecond=0)
    if unit == "h":
        return moment.replace(minute=0, second=0, microsecond=0)

    # Day and coarser units all start from midnight of the given day.
    midnight = moment.replace(hour=0, minute=0, second=0, microsecond=0)
    if unit == "d":
        return midnight
    if unit == "w":
        # weekday() is 0 on Monday, so this steps back to the Monday of the same week.
        return midnight - datetime.timedelta(days=moment.weekday())
    if unit == "M":
        return midnight.replace(day=1)
    if unit == "y":
        return midnight.replace(month=1, day=1)
    raise ValueError(f"Unknown time unit {unit!r}. Pick one of {', '.join(TIME_UNITS)}")
81
+
82
+
83
+ def iter_dates_between(
84
+ start_date: datetime.datetime | datetime.date,
85
+ end_date: datetime.datetime | datetime.date,
86
+ granularity: TimeResolutionUnit,
87
+ ) -> Iterator[datetime.datetime]:
88
+ """
89
+ Iterate over datetimes between start_date and end_date with steps based on the given granularity.
90
+ Includes the start_date and may include end_date (if end date is reachable from start date with given granularity).
91
+
92
+ :param start_date: The starting datetime.
93
+ :param end_date: The ending datetime.
94
+ :param granularity: The granularity for steps (e.g., "d" for days, "M" for months).
95
+ :return: An iterator of datetime objects.
96
+ """
97
+ start_date = assure_timezone(start_date)
98
+ end_date = assure_timezone(end_date)
99
+ logger.debug(
100
+ "Generating dates between two points.", start_date=start_date.isoformat(), end_date=end_date.isoformat()
101
+ )
102
+ if start_date > end_date:
103
+ raise ValueError("Start date cannot be later than end date.")
104
+
105
+ current = start_date
106
+
107
+ delta: datetime.timedelta | relativedelta
108
+
109
+ match granularity:
110
+ case "s":
111
+ delta = datetime.timedelta(seconds=1)
112
+ case "m":
113
+ delta = datetime.timedelta(minutes=1)
114
+ case "h":
115
+ delta = datetime.timedelta(hours=1)
116
+ case "d":
117
+ delta = datetime.timedelta(days=1)
118
+ case "w":
119
+ delta = datetime.timedelta(weeks=1)
120
+ case "M":
121
+ delta = relativedelta(months=1)
122
+ case "y":
123
+ delta = relativedelta(years=1)
124
+ case default:
125
+ raise ValueError(f"Unsupported granularity {default!r}.")
126
+
127
+ while current <= end_date:
128
+ yield current
129
+ current += delta
130
+
131
+
132
def quantize_datetime(base: datetime.datetime, resolution: TimeResolution | str) -> datetime.datetime:
    """
    Move a datetime forward to the next boundary of the given resolution.

    ``base`` is truncated to the resolution's unit and then shifted forward by
    ``value - (component % value)`` units, where ``component`` is the matching field of ``base``
    (second, minute, hour, day of month, month or year). The shift is always at least one unit,
    so a ``base`` that already sits on a boundary is moved to the following one.

    **Important Notes:**
    - Boundaries are counted inside the next-higher unit, so values that do not divide it produce
      overlapping intervals. For example a 7-second resolution can end intervals at 12:50:56 and
      12:51:03 while another interval starts at 12:51:00.
    - Prefer values that divide the higher unit (for instance divisors of 60 for seconds and minutes).
    - Day-of-month and month are 1-based, so for values above 1 the boundaries are multiples of the
      value counted from a (non-existent) day/month zero.
    - NOTE(review): for week resolutions with a value above 1 the shift is
      ``value * 7 - (weekday % value)`` days, which can land on a day other than Monday -- confirm
      that this is intended.

    Parameters:
    - base (datetime.datetime): The datetime to quantize.
    - resolution (TimeResolution | str): The resolution for quantization.
      A string is parsed with ``TimeResolution.from_str`` (e.g., "5s", "15m", "2h").

    Returns:
    - datetime.datetime: The quantized datetime.

    Raises:
    - ValueError: If the resolution unit is unsupported.
    """
    if not isinstance(resolution, TimeResolution):
        resolution = TimeResolution.from_str(resolution)
    step = resolution.value
    unit = resolution.unit
    floor = truncate_datetime(base, unit)

    def _units_to_next_boundary(component: int) -> int:
        # Distance from ``component`` to the next multiple of ``step``; never zero.
        return step - component % step

    if unit == "s":
        return floor + datetime.timedelta(seconds=_units_to_next_boundary(base.second))
    if unit == "m":
        return floor + datetime.timedelta(minutes=_units_to_next_boundary(base.minute))
    if unit == "h":
        return floor + datetime.timedelta(hours=_units_to_next_boundary(base.hour))
    if unit == "d":
        return floor + datetime.timedelta(days=_units_to_next_boundary(base.day))
    if unit == "w":
        # ``floor`` is already the Monday of the week; the shift is expressed in days.
        return floor + datetime.timedelta(days=step * 7 - base.weekday() % step)
    if unit == "M":
        return floor + relativedelta(months=_units_to_next_boundary(base.month))
    if unit == "y":
        return floor + relativedelta(years=_units_to_next_boundary(base.year))
    raise ValueError(f"Unsupported resolution unit {resolution.unit!r} for quantization.")  # pragma: no cover
189
+
190
+
191
+ def assure_timezone(
192
+ base: datetime.datetime | datetime.date, tz: datetime.timezone = datetime.timezone.utc
193
+ ) -> datetime.datetime:
194
+ """Assure datetime has a datetime and return it timezone-aware if not."""
195
+ if not isinstance(base, datetime.datetime):
196
+ return datetime.datetime(base.year, base.month, base.day, tzinfo=tz)
197
+ if base.tzinfo is not None:
198
+ if base.tzinfo == tz:
199
+ return base
200
+ return base.astimezone(tz)
201
+ logger.warning(f"A timezone-unaware timestamp was given, assuming {tz!r}.")
202
+ return base.replace(tzinfo=tz)
203
+
204
+
205
def get_date_pattern(date: datetime.datetime, unit: TimeResolutionUnit) -> str:
    """Format ``date`` as an ISO-like string that stops at the given unit.

    The value is truncated with ``truncate_datetime`` first. Weeks have no field of their own
    in the output, so a week is rendered as the date of its first day (Monday).

    Examples:
        >>> get_date_pattern(datetime.datetime(2025, 1, 12, 12), "h")
        '2025-01-12T12'
        >>> get_date_pattern(datetime.datetime(2025, 1, 12, 12), "d")
        '2025-01-12'
        >>> get_date_pattern(datetime.datetime(2025, 1, 12, 12), "M")
        '2025-01'
        >>> get_date_pattern(datetime.datetime(2025, 1, 12, 12), "y")
        '2025'

    Raises:
        ValueError: If ``unit`` is not a supported time unit.
    """
    truncated = truncate_datetime(date, unit)
    if unit == "s":
        layout = "%Y-%m-%dT%H:%M:%S"
    elif unit == "m":
        layout = "%Y-%m-%dT%H:%M"
    elif unit == "h":
        layout = "%Y-%m-%dT%H"
    elif unit == "d":
        layout = "%Y-%m-%d"
    elif unit == "w":
        layout = "%Y-%m-%d"  # Week is special, keep the day
    elif unit == "M":
        layout = "%Y-%m"
    elif unit == "y":
        layout = "%Y"
    else:
        raise ValueError(f"Unsupported time unit {unit!r}.")
    return truncated.strftime(layout)
235
+
236
+
237
def utcnow() -> datetime.datetime:
    """Return the current moment as a timezone-aware datetime in UTC."""
    utc = datetime.timezone.utc
    return datetime.datetime.now(utc)
File without changes
@@ -0,0 +1,190 @@
1
+ import datetime
2
+ import json
3
+ from dataclasses import dataclass, field
4
+ from typing import ClassVar, Literal, Protocol
5
+
6
+ from pydantic import BaseModel
7
+
8
+ from hyperion.dateutils import TimeResolution, assure_timezone
9
+
10
# The closed set of asset kinds; each asset class exposes one of these values as ``asset_type``.
AssetType = Literal["data_lake", "feature", "persistent_store"]
11
+
12
+
13
def get_prefixed_path(path: str, prefix: str = "") -> str:
    """Prepend ``prefix`` to ``path`` with exactly one slash between them.

    Slashes surrounding ``prefix`` are dropped before joining; ``path`` is used verbatim.

    Args:
        path (str): The path.
        prefix (str): The prefix; empty (or made only of slashes) means no prefix at all.

    Returns:
        str: ``"<prefix>/<path>"``, or ``path`` unchanged when there is no effective prefix.
    """
    cleaned = prefix.strip("/")
    return f"{cleaned}/{path}" if cleaned else path
27
+
28
+
29
class AssetProtocol(Protocol):
    """Structural interface shared by the asset classes of the catalog.

    Attributes:
        asset_type (ClassVar[AssetType]): The type of the asset.
        name (str): The name of the asset.
        schema_version (int): The schema version of the asset.
    """

    asset_type: ClassVar[AssetType]

    @property
    def name(self) -> str:
        """The name of the asset."""
        ...

    @property
    def schema_version(self) -> int:
        """The schema version of the asset."""
        ...

    def get_path(self, prefix: str = "") -> str:
        """Build the storage path of the asset.

        Args:
            prefix (str): The prefix for the path.

        Returns:
            str: The path for the asset.
        """
        ...

    def to_metadata(self) -> dict[str, str]:
        """Describe the asset as string-to-string metadata.

        Returns:
            dict[str, str]: The metadata for the asset.
        """
        ...
65
+
66
+
67
@dataclass(frozen=True, eq=True)
class DataLakeAsset:
    """An immutable reference to a dated asset in the data lake.

    Attributes:
        asset_type (ClassVar[AssetType]): Always ``"data_lake"``.
        name (str): The name of the asset.
        date (datetime.datetime): The date of the asset.
        schema_version (int): The schema version of the asset, 1 by default.
    """

    asset_type: ClassVar[AssetType] = "data_lake"
    name: str
    date: datetime.datetime
    schema_version: int = 1

    def get_path(self, prefix: str = "") -> str:
        """Return ``<prefix>/<name>/date=<iso timestamp>/v<schema_version>.avro``.

        The timestamp is passed through ``assure_timezone`` before formatting.
        """
        stamp = assure_timezone(self.date).isoformat()
        relative_path = f"{self.name}/date={stamp}/v{self.schema_version}.avro"
        return get_prefixed_path(relative_path, prefix)

    def to_metadata(self) -> dict[str, str]:
        """Return name, ISO-formatted date (as stored) and schema version as strings."""
        return dict(name=self.name, date=self.date.isoformat(), schema_version=str(self.schema_version))

    def __repr__(self) -> str:
        """Show the class name together with the unprefixed asset path."""
        return f"{type(self).__name__}({self.get_path()!r})"
94
+
95
+
96
@dataclass(frozen=True, eq=True)
class PersistentStoreAsset:
    """An immutable reference to an undated asset in the persistent store.

    Attributes:
        asset_type (ClassVar[AssetType]): Always ``"persistent_store"``.
        name (str): The name of the asset.
        schema_version (int): The schema version of the asset, 1 by default.
    """

    asset_type: ClassVar[AssetType] = "persistent_store"
    name: str
    schema_version: int = 1

    def get_path(self, prefix: str = "") -> str:
        """Return ``<prefix>/<name>/v<schema_version>.avro``."""
        relative_path = f"{self.name}/v{self.schema_version}.avro"
        return get_prefixed_path(relative_path, prefix)

    def to_metadata(self) -> dict[str, str]:
        """Return the name and the schema version as strings."""
        return dict(name=self.name, schema_version=str(self.schema_version))

    def __repr__(self) -> str:
        """Show the class name together with the unprefixed asset path."""
        return f"{type(self).__name__}({self.get_path()!r})"
120
+
121
+
122
@dataclass(frozen=True, eq=True)
class FeatureAsset:
    """An immutable reference to one partition of a feature.

    Instances are hashable. ``partition_keys`` is a dict and therefore cannot take part in the
    hash, so it is excluded from it; it still takes part in equality comparison.

    Attributes:
        asset_type (ClassVar[AssetType]): Always ``"feature"``.
        name (str): The name of the asset.
        partition_date (datetime.datetime): The partition timestamp of the asset.
        resolution (TimeResolution | str): The resolution of the asset.
        schema_version (int): The schema version of the asset, 1 by default.
        partition_keys (dict[str, str]): The partition keys of the asset, empty by default.
    """

    asset_type: ClassVar[AssetType] = "feature"
    name: str
    partition_date: datetime.datetime
    resolution: TimeResolution | str
    schema_version: int = 1
    # hash=False: the generated __hash__ would otherwise raise TypeError on the dict for every instance.
    partition_keys: dict[str, str] = field(default_factory=dict, hash=False)

    @property
    def time_resolution(self) -> TimeResolution:
        """The time resolution of the feature, parsed from ``resolution`` when that is a string."""
        if isinstance(self.resolution, TimeResolution):
            return self.resolution
        return TimeResolution.from_str(self.resolution)

    @property
    def feature_name(self) -> str:
        """The name of the feature including the time resolution, as ``<name>.<resolution>``."""
        return f"{self.name}.{self.time_resolution!r}"

    def _get_partition_keys_prefix(self) -> str:
        """Render the partition keys as ``key=value`` path segments, sorted by key name."""
        key_names = sorted(self.partition_keys.keys())
        return ("/".join(f"{key}={self.partition_keys[key]}" for key in key_names)).strip("/")

    def get_path(self, prefix: str = "") -> str:
        """Get the path for the asset with the given prefix.

        The layout is ``<prefix>/<feature_name>/<key=value>/.../partition_date=<iso>/v<schema_version>.avro``;
        the timestamp is passed through ``assure_timezone`` before formatting.

        Args:
            prefix (str): The prefix for the path.

        Returns:
            str: The path for the asset.
        """
        partition_date = assure_timezone(self.partition_date).isoformat()
        keys_prefix = self._get_partition_keys_prefix()
        if keys_prefix:
            keys_prefix = keys_prefix + "/"
        return get_prefixed_path(
            f"{self.feature_name}/{keys_prefix}partition_date={partition_date}/v{self.schema_version}.avro", prefix
        )

    def to_metadata(self) -> dict[str, str]:
        """Get the metadata for the asset.

        Returns:
            dict[str, str]: Name, ISO-formatted partition date (as stored), schema version and the
            partition keys serialized as JSON.
        """
        return {
            "name": self.name,
            "partition_date": self.partition_date.isoformat(),
            "schema_version": str(self.schema_version),
            "partition_keys": json.dumps(self.partition_keys),
        }

    def __repr__(self) -> str:
        """Show the class name together with the unprefixed asset path."""
        return f"{self.__class__.__name__}({self.get_path()!r})"
179
+
180
+
181
class FeatureModel(BaseModel):
    """Base class for typed feature rows built on pydantic's ``BaseModel``.

    Subclass it to define type-safe feature models and override the class-level settings below.
    Use with "AssetCollection" to make powerful typed feature collections.
    """

    # Placeholders: ``NotImplemented`` marks a setting the base class does not provide.
    asset_name: ClassVar[str] = NotImplemented
    resolution: ClassVar[TimeResolution] = NotImplemented
    # Schema version used unless a subclass overrides it.
    schema_version: ClassVar[int] = 1
File without changes
@@ -0,0 +1,220 @@
1
+ """AWS helpers and methods."""
2
+
3
+ import datetime
4
+ import logging
5
+ from asyncio import Semaphore
6
+ from collections.abc import AsyncIterator, Iterator
7
+ from contextlib import ExitStack, asynccontextmanager
8
+ from dataclasses import dataclass
9
+ from enum import Enum
10
+ from pathlib import Path
11
+ from typing import IO, BinaryIO, TypeVar, cast
12
+
13
+ import aioboto3
14
+ import boto3
15
+ import botocore.exceptions
16
+
17
+ from hyperion.config import storage_config
18
+ from hyperion.logging import get_logger
19
+
20
# What the S3 helpers accept as a local source or destination: a filesystem path or a binary stream.
PathOrIOBinary = str | Path | BinaryIO | IO[bytes]

T = TypeVar("T")

logger = get_logger("aws")

# Keep botocore's own loggers at WARNING and above.
logging.getLogger("botocore.endpoint").setLevel(logging.WARNING)
logging.getLogger("botocore").setLevel(logging.WARNING)
28
+
29
+
30
class S3StorageClass(str, Enum):
    """Storage class names of S3 objects.

    Members are ``str`` instances whose value equals their name, so a raw storage-class
    string can be turned into a member with ``S3StorageClass(raw)``.
    """

    STANDARD = "STANDARD"
    REDUCED_REDUNDANCY = "REDUCED_REDUNDANCY"
    STANDARD_IA = "STANDARD_IA"
    ONEZONE_IA = "ONEZONE_IA"
    INTELLIGENT_TIERING = "INTELLIGENT_TIERING"
    GLACIER = "GLACIER"
    DEEP_ARCHIVE = "DEEP_ARCHIVE"
    OUTPOSTS = "OUTPOSTS"
    GLACIER_IR = "GLACIER_IR"
    SNOW = "SNOW"
    EXPRESS_ONEZONE = "EXPRESS_ONEZONE"
44
+
45
+
46
@dataclass
class S3ObjectAttributes:
    """A selection of attributes describing one S3 object.

    Attributes:
        last_modified (datetime.datetime): When the object was last modified.
        etag (str): The object's ETag.
        storage_class (S3StorageClass): The storage class the object is kept in.
        object_size (int): The object size.
    """

    last_modified: datetime.datetime
    etag: str
    storage_class: S3StorageClass
    object_size: int
54
+
55
+
56
class S3Client:
    """Thin wrapper around a synchronous boto3 S3 client and an aioboto3 session.

    Asynchronous operations are throttled by one class-level semaphore that is shared by all
    instances and sized by ``storage_config.max_concurrency``.
    """

    _storage_semaphore: Semaphore | None = None

    @classmethod
    @asynccontextmanager
    async def semaphore(cls) -> AsyncIterator[None]:
        """Semaphore for asynchronous storage operations.

        Use as ``async with S3Client.semaphore(): ...``; the body runs while one slot is held.
        """
        # Lazily initialize the semaphore instance
        if cls._storage_semaphore is None:
            logger.debug(
                "Initializing new storage operations semaphore.", max_concurrency=storage_config.max_concurrency
            )
            cls._storage_semaphore = Semaphore(storage_config.max_concurrency)
        logger.debug("Attempting to acquire the storage operations semaphore.")
        async with cls._storage_semaphore:
            logger.debug("Green light, semaphore open.")
            yield

    def __init__(self) -> None:
        """Create the synchronous client and the session used for asynchronous clients."""
        self._client = boto3.client("s3")
        self._aio_session = aioboto3.Session()

    async def upload_async(self, file: PathOrIOBinary, bucket: str, name: str) -> None:
        """Upload a file to S3 asynchronously.

        Args:
            file (PathOrIOBinary): The file to upload; a path is opened for binary reading here.
            bucket (str): The bucket to upload the file to.
            name (str): The name of the file in the bucket.

        Raises:
            botocore.exceptions.ClientError: If the upload fails (logged, then re-raised).
        """
        # The ExitStack closes the file only when this method opened it; caller-owned streams stay open.
        with ExitStack() as file_context:
            if isinstance(file, str | Path):
                path = Path(file)
                logger.debug("Uploading from path.", path=path.as_posix(), bucket=bucket, name=name)
                file = file_context.enter_context(path.open("rb"))
            async with self.semaphore(), self._aio_session.client("s3") as s3:
                try:
                    await s3.upload_fileobj(file, bucket, name)
                except botocore.exceptions.ClientError:
                    logger.error("Error when uploading file to S3.", bucket=bucket, name=name)
                    raise

    def iter_objects(self, bucket: str, prefix: str) -> Iterator[str]:
        """Iterate over objects in an S3 bucket.

        Args:
            bucket (str): The bucket to list objects in.
            prefix (str): The prefix to filter objects by.

        Yields:
            Iterator[str]: The keys of the objects in the bucket.
        """
        paginator = self._client.get_paginator("list_objects_v2")
        pagination_config = {"StartingToken": None}
        logger.debug("Listing contents of a bucket.", bucket=bucket, prefix=prefix)
        response_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix, PaginationConfig=pagination_config)
        for response in response_iterator:
            logger.debug("Received a response from S3.", response=response)
            # A page without "Contents" carries no keys.
            if "Contents" not in response:
                continue
            yield from (s3_object["Key"] for s3_object in response["Contents"])

    def upload(self, file: PathOrIOBinary, bucket: str, name: str) -> None:
        """Upload a file to S3.

        Args:
            file (PathOrIOBinary): The file to upload, either a path or an open binary stream.
            bucket (str): The bucket to upload the file to.
            name (str): The name of the file in the bucket.

        Raises:
            botocore.exceptions.ClientError: If the upload fails (logged, then re-raised).
        """
        if isinstance(file, str | Path):
            path = Path(file)
            logger.debug("Uploading from path.", path=path.as_posix(), bucket=bucket, name=name)
            try:
                # boto3 documents Filename as a str, so hand over a string rather than the Path object.
                self._client.upload_file(str(path), bucket, name)
            except botocore.exceptions.ClientError:
                logger.error("Error when uploading file to S3.", file=path.as_posix(), bucket=bucket, name=name)
                raise
            return
        logger.debug("Uploading from an open file stream.", bucket=bucket, name=name)
        try:
            self._client.upload_fileobj(file, bucket, name)
        except botocore.exceptions.ClientError:
            logger.error("Error when uploading file to S3.", bucket=bucket, name=name)
            raise

    def get_object_attributes(self, bucket: str, name: str) -> S3ObjectAttributes:
        """Get the attributes of an object in S3.

        Args:
            bucket (str): The bucket containing the object.
            name (str): The name of the object.

        Returns:
            S3ObjectAttributes: The attributes of the object.

        Raises:
            Exception: Whatever building the result from the response raises (for example a missing
                key or an unknown storage class); the response is logged before re-raising.
        """
        response = self._client.get_object_attributes(
            Bucket=bucket,
            Key=name,
            ObjectAttributes=[
                "ETag",
                "Checksum",
                "ObjectParts",
                "StorageClass",
                "ObjectSize",
            ],
        )
        try:
            return S3ObjectAttributes(
                last_modified=response["LastModified"],
                etag=response["ETag"],
                storage_class=S3StorageClass(response["StorageClass"]),
                object_size=int(response["ObjectSize"]),
            )
        except Exception:
            logger.error(
                "Failed to get object attributes, the response is probably invalid.",
                bucket=bucket,
                object_name=name,
                response=response,
            )
            raise

    def download(self, bucket: str, name: str, file: PathOrIOBinary) -> None:
        """Download a file from S3.

        Args:
            bucket (str): The bucket containing the file.
            name (str): The name of the file.
            file (PathOrIOBinary): The file to download to, either a path or a writable binary stream.

        Raises:
            botocore.exceptions.ClientError: If the download fails (logged, then re-raised).
        """
        if isinstance(file, str | Path):
            path = Path(file)
            logger.debug("Downloading into a path.", path=path.as_posix(), bucket=bucket, name=name)
            try:
                # boto3 documents Filename as a str, so hand over a string rather than the Path object.
                self._client.download_file(bucket, name, str(path))
            except botocore.exceptions.ClientError:
                logger.error("Error when downloading file from S3.", bucket=bucket, name=name, path=path.as_posix())
                raise
            return
        logger.debug("Downloading into a file stream.", bucket=bucket, name=name)
        try:
            self._client.download_fileobj(bucket, name, file)
        except botocore.exceptions.ClientError:
            logger.error("Error when downloading file from S3.", bucket=bucket, name=name)
            raise

    def download_as_string(self, bucket: str, name: str) -> str:
        """Download an object from S3 as a string.

        Args:
            bucket (str): The bucket containing the object.
            name (str): The name of the object.

        Returns:
            str: The object content decoded as UTF-8.

        Raises:
            botocore.exceptions.ClientError: If the download fails (logged, then re-raised).
        """
        logger.debug("Downloading object as a string.", bucket=bucket, name=name)
        try:
            body = self._client.get_object(Bucket=bucket, Key=name)["Body"]
            try:
                return cast(str, body.read().decode("utf-8"))
            finally:
                # Close the response stream explicitly instead of leaving it to garbage collection.
                body.close()
        except botocore.exceptions.ClientError:
            logger.error("Error when downloading file from S3.", bucket=bucket, name=name)
            raise