s3-overlay 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- s3_overlay/__init__.py +6 -0
- s3_overlay/app.py +61 -0
- s3_overlay/proxy.py +816 -0
- s3_overlay-0.0.0.dist-info/METADATA +14 -0
- s3_overlay-0.0.0.dist-info/RECORD +6 -0
- s3_overlay-0.0.0.dist-info/WHEEL +4 -0
s3_overlay/__init__.py
ADDED
s3_overlay/app.py
ADDED
@@ -0,0 +1,61 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from litestar import Litestar, Request, get
from litestar.config.cors import CORSConfig
from litestar.handlers import asgi
from litestar.plugins.prometheus import PrometheusConfig, PrometheusController

from .proxy import S3OverlayProxy

if TYPE_CHECKING:
    from litestar.types import Receive, Scope, Send


prometheus_config = PrometheusConfig(app_name="s3_overlay", prefix="s3_overlay")


def create_app() -> Litestar:
    """Create the S3 overlay proxy ASGI application."""
    proxy = S3OverlayProxy.from_env()

    @get("/health", include_in_schema=False)
    async def health() -> dict[str, str]:
        return {"status": "ok"}

    @asgi(path="/", is_mount=True, copy_scope=True)
    async def proxy_handler(scope: Scope, receive: Receive, send: Send) -> None:
        request = Request(scope=scope, receive=receive)
        path = scope.get("path", "/")
        if not path.startswith("/"):
            path = f"/{path}"
        if path != "/" and path.endswith("/"):
            path = path.rstrip("/") or "/"
        response = await proxy.handle(request, path)
        asgi_response = response.to_asgi_response(None, request)
        await asgi_response(scope, receive, send)

    async def startup(app: Litestar) -> None:
        await proxy.startup()

    async def shutdown(app: Litestar) -> None:
        await proxy.shutdown()

    cors_config = CORSConfig(
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
        expose_headers=["ETag", "x-amz-*"],
    )

    return Litestar(
        route_handlers=[health, proxy_handler, PrometheusController],
        on_startup=[startup],
        on_shutdown=[shutdown],
        cors_config=cors_config,
        middleware=[prometheus_config.middleware],
    )


app = create_app()

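For orientation (not part of the packaged files): because app.py exposes a module-level `app = create_app()`, the proxy can be served by any ASGI server. Below is a minimal sketch, assuming the uvicorn server that the `litestar[standard]` extra pulls in and a MinIO instance on the default local endpoint; the file name, host, and port are placeholders.

# run_local.py -- hypothetical launcher, not shipped in the wheel
import os

import uvicorn

# Settings are read from the environment when create_app() runs (see proxy.py below).
os.environ.setdefault("S3_OVERLAY_LOCAL_ENDPOINT", "http://127.0.0.1:9000")

if __name__ == "__main__":
    # "s3_overlay.app:app" refers to the module-level Litestar instance above.
    uvicorn.run("s3_overlay.app:app", host="0.0.0.0", port=8080)
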
s3_overlay/proxy.py
ADDED
@@ -0,0 +1,816 @@
from __future__ import annotations

import logging
from datetime import UTC, datetime
from email.utils import format_datetime
from functools import partial
from typing import TYPE_CHECKING, Any

from anyio import to_thread
import httpx
from boto3.session import Session
from botocore.config import Config as BotoConfig
from botocore.exceptions import ClientError
from litestar.response import Response, Stream
from pydantic import AliasChoices, Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

if TYPE_CHECKING:
    from collections.abc import AsyncIterator, Callable, Iterable, Mapping

    from litestar import Request
else:  # pragma: no cover
    AsyncIterator = Iterable = Mapping = Any

LOG = logging.getLogger("s3_overlay.proxy")


async def _run_sync(func: "Callable[..., Any]", /, *args: Any, **kwargs: Any) -> Any:
    return await to_thread.run_sync(func, *args, **kwargs)


class LocalSettings(BaseSettings):
    """Configuration for local S3 storage."""

    model_config = SettingsConfigDict(
        env_prefix="", case_sensitive=False, extra="ignore"
    )

    endpoint: str = Field(
        default="http://127.0.0.1:9000",
        validation_alias="S3_OVERLAY_LOCAL_ENDPOINT",
    )
    access_key: str = Field(
        default="minioadmin",
        validation_alias=AliasChoices(
            "S3_OVERLAY_LOCAL_ACCESS_KEY",
            "AWS_ACCESS_KEY_ID",
        ),
    )
    secret_key: str = Field(
        default="minioadmin",
        validation_alias=AliasChoices(
            "S3_OVERLAY_LOCAL_SECRET_KEY",
            "AWS_SECRET_ACCESS_KEY",
        ),
    )
    session_token: str | None = Field(
        default=None,
        validation_alias="S3_OVERLAY_LOCAL_SESSION_TOKEN",
    )
    region: str = Field(
        default="us-east-1",
        validation_alias="S3_OVERLAY_LOCAL_REGION",
    )
    bucket_location: str = Field(
        default="us-east-1",
        validation_alias="S3_OVERLAY_DEFAULT_BUCKET_LOCATION",
    )
    chunk_threshold: int = Field(
        default=50 * 1024 * 1024,
        validation_alias="S3_OVERLAY_CHUNK_THRESHOLD",
    )
    chunk_size: int = Field(
        default=16 * 1024 * 1024,
        validation_alias="S3_OVERLAY_CHUNK_SIZE",
    )
    cache_bucket_name: str = Field(
        default="s3-overlay-cache",
        validation_alias="S3_OVERLAY_CACHE_BUCKET",
    )


class RemoteSettings(BaseSettings):
    """Configuration for remote S3 storage."""

    model_config = SettingsConfigDict(
        env_prefix="", case_sensitive=False, extra="ignore"
    )

    endpoint: str | None = Field(
        default=None,
        validation_alias="S3_OVERLAY_REMOTE_ENDPOINT",
    )
    access_key: str | None = Field(
        default=None,
        validation_alias=AliasChoices(
            "S3_OVERLAY_REMOTE_ACCESS_KEY_ID",
            "AWS_REMOTE_ACCESS_KEY_ID",
        ),
    )
    secret_key: str | None = Field(
        default=None,
        validation_alias=AliasChoices(
            "S3_OVERLAY_REMOTE_SECRET_ACCESS_KEY",
            "AWS_REMOTE_SECRET_ACCESS_KEY",
        ),
    )
    session_token: str | None = Field(
        default=None,
        validation_alias=AliasChoices(
            "S3_OVERLAY_REMOTE_SESSION_TOKEN",
            "AWS_REMOTE_SESSION_TOKEN",
        ),
    )
    region: str | None = Field(
        default=None,
        validation_alias=AliasChoices(
            "S3_OVERLAY_REMOTE_REGION",
            "AWS_REGION",
        ),
    )
    bucket_mapping: dict[str, str] | None = Field(
        default=None,
        validation_alias="S3_OVERLAY_BUCKET_MAPPING",
    )
    addressing_style: str = Field(
        default="virtual",
        validation_alias="S3_OVERLAY_REMOTE_ADDRESSING_STYLE",
    )

    @field_validator("bucket_mapping", mode="before")
    @classmethod
    def _parse_bucket_mapping(cls, value: object) -> dict[str, str] | None:
        if value is None:
            return None
        if isinstance(value, dict):
            return {str(k).strip(): str(v).strip() for k, v in value.items()}
        if isinstance(value, str):
            mapping: dict[str, str] = {}
            for pair in value.split(","):
                if ":" in pair:
                    local, remote = pair.split(":", 1)
                    mapping[local.strip()] = remote.strip()
            return mapping or None
        msg = "Invalid bucket mapping format"
        raise ValueError(msg)

    @property
    def enabled(self) -> bool:
        """Check if remote storage is enabled based on configuration."""
        return bool(self.endpoint or self.access_key or self.secret_key or self.region)


def load_local_settings_from_env() -> LocalSettings:
    """Load local S3 settings from environment variables.

    Returns:
        LocalSettings instance populated from environment variables.
    """
    return LocalSettings()


def load_remote_settings_from_env() -> RemoteSettings:
    """Load remote S3 settings from environment variables.

    Returns:
        RemoteSettings instance populated from environment variables.
    """
    return RemoteSettings()


class S3OverlayProxy:
    def __init__(self, local: LocalSettings, remote: RemoteSettings):
        self._local_settings = local
        self._remote_settings = remote
        self._http_client: httpx.AsyncClient | None = None
        self._local_client = self._build_local_client()
        self._remote_client = self._build_remote_client() if remote.enabled else None

    async def startup(self) -> None:
        self._http_client = httpx.AsyncClient(
            base_url=self._local_settings.endpoint,
            timeout=httpx.Timeout(60.0, read=300.0),
            trust_env=False,
        )
        LOG.info(
            "S3 overlay ready (local=%s, remote=%s)",
            self._local_settings.endpoint,
            self._describe_remote(),
        )

    async def shutdown(self) -> None:
        if self._http_client is not None:
            await self._http_client.aclose()
            self._http_client = None

    async def handle(self, request: Request, path: str) -> Response:
        LOG.debug("handle method=%s path=%s", request.method, path)
        if request.method in {"GET", "HEAD"}:
            response = await self._handle_read(request, path)
            if response is not None:
                LOG.debug("handled via read path=%s", path)
                return response
            LOG.debug("falling back to direct proxy path=%s", path)

        if self._http_client is None:
            message = "proxy not initialised"
            raise RuntimeError(message)

        local_request = await self._build_httpx_request(request, path)
        response = await self._http_client.send(local_request, stream=True)

        if response.status_code == 404 and self._should_backfill(request, path):
            await response.aclose()
            if await self._backfill_from_remote(request, path):
                local_request = await self._build_httpx_request(request, path)
                response = await self._http_client.send(local_request, stream=True)

        return self._to_streaming_response(response)

    def _build_local_client(self):
        session = Session(
            aws_access_key_id=self._local_settings.access_key,
            aws_secret_access_key=self._local_settings.secret_key,
            aws_session_token=self._local_settings.session_token,
            region_name=self._local_settings.region,
        )
        return session.client(
            "s3",
            endpoint_url=self._local_settings.endpoint,
            config=BotoConfig(signature_version="s3v4", retries={"max_attempts": 3}),
        )

    def _build_remote_client(self):
        session = Session(
            aws_access_key_id=self._remote_settings.access_key,
            aws_secret_access_key=self._remote_settings.secret_key,
            aws_session_token=self._remote_settings.session_token,
            region_name=self._remote_settings.region,
        )
        return session.client(
            "s3",
            endpoint_url=self._remote_settings.endpoint,
            config=BotoConfig(
                signature_version="s3v4",
                retries={"max_attempts": 3},
                s3={"addressing_style": self._remote_settings.addressing_style},
            ),
        )

    async def _build_httpx_request(self, request: Request, path: str) -> httpx.Request:
        assert self._http_client is not None
        url = path or "/"
        if request.scope.get("query_string"):
            query = request.scope["query_string"].decode("latin-1")
            url = f"{url}?{query}"

        headers = self._prepare_outgoing_headers(request.headers)
        content = None
        if request.method in {"POST", "PUT", "PATCH"}:
            content = await self._build_content_bytes(request)

        return self._http_client.build_request(
            method=request.method,
            url=url,
            headers=headers,
            content=content,
        )

    def _prepare_outgoing_headers(self, headers: Mapping[str, str]) -> dict[str, str]:
        hop_by_hop = {
            "connection",
            "keep-alive",
            "proxy-authenticate",
            "proxy-authorization",
            "te",
            "trailers",
            "transfer-encoding",
            "upgrade",
        }
        prepared: dict[str, str] = {}
        for key, value in headers.items():
            lowered = key.lower()
            if lowered in hop_by_hop:
                continue
            prepared[lowered if lowered == "host" else key] = value
        return prepared

    async def _build_content_bytes(self, request: Request) -> bytes:
        """Read the entire request body as bytes directly from ASGI scope."""
        body_parts = []
        receive = request.receive

        while True:
            message = await receive()
            if message["type"] == "http.request":
                body = message.get("body", b"")
                if body:
                    body_parts.append(body)
                if not message.get("more_body", False):
                    break
            elif message["type"] == "http.disconnect":
                break

        return b"".join(body_parts)

    def _should_backfill(self, request: Request, path: str) -> bool:
        if request.method not in {"GET", "HEAD"}:
            return False
        if not self._remote_client:
            return False
        bucket, key = self._extract_bucket_and_key(path)
        return bool(bucket and key is not None and key != "")

    async def _handle_read(self, request: Request, path: str) -> Response | None:
        LOG.debug("handle_read method=%s path=%s", request.method, path)
        bucket, key = self._extract_bucket_and_key(path)
        if not bucket or key is None:
            LOG.debug("missing bucket/key for path %s", path)
            return None
        if key == "":
            # Empty key means bucket listing, not object operation
            LOG.debug("empty key (bucket listing) for path %s", path)
            return None

        range_header = request.headers.get("range")

        try:
            if request.method == "HEAD":
                result = await _run_sync(
                    partial(self._local_client.head_object, Bucket=bucket, Key=key)
                )
                headers = self._object_headers(result)
                LOG.debug("HEAD hit for s3://%s/%s", bucket, key)
                return Response(content=b"", headers=headers, status_code=200)

            get_kwargs = {"Bucket": bucket, "Key": key}
            if range_header:
                get_kwargs["Range"] = range_header
            result = await _run_sync(
                partial(self._local_client.get_object, **get_kwargs)
            )
        except ClientError as error:
            miss_result = await self._handle_read_miss(
                error, request, path, bucket, key
            )
            if isinstance(miss_result, Response):
                return miss_result
            if not miss_result:
                return self._from_client_error(error)

            if request.method == "HEAD":
                result = await _run_sync(
                    partial(self._local_client.head_object, Bucket=bucket, Key=key)
                )
                headers = self._object_headers(result)
                return Response(content=b"", headers=headers, status_code=200)

            get_kwargs = {"Bucket": bucket, "Key": key}
            if range_header:
                get_kwargs["Range"] = range_header
            result = await _run_sync(
                partial(self._local_client.get_object, **get_kwargs)
            )

        headers = self._object_headers(result)
        streaming_body = result["Body"]

        async def iterator() -> AsyncIterator[bytes]:
            try:
                while True:
                    chunk = await _run_sync(streaming_body.read, 1024 * 64)
                    if not chunk:
                        break
                    yield chunk
            finally:
                await _run_sync(streaming_body.close)

        status_code = 206 if range_header else 200
        LOG.debug("GET hit for s3://%s/%s status=%s", bucket, key, status_code)
        return Stream(content=iterator, status_code=status_code, headers=headers)

    async def _handle_read_miss(
        self,
        error: ClientError,
        request: Request,
        path: str,
        bucket: str,
        key: str,
    ) -> bool | Response:
        code = error.response.get("Error", {}).get("Code")
        if code not in {"404", "NoSuchKey", "NotFound", "NoSuchBucket"}:
            raise error
        if not self._remote_client:
            return False

        remote_bucket = self._map_to_remote_bucket(bucket)

        # Check remote object metadata first
        try:
            remote_head = await _run_sync(
                partial(self._remote_client.head_object, Bucket=remote_bucket, Key=key)
            )
        except ClientError as e:
            # Check if it is a 404
            err_code = e.response.get("Error", {}).get("Code")
            if err_code in {"404", "NoSuchKey", "NotFound"}:
                LOG.debug(
                    "remote miss for s3://%s/%s (remote: %s)",
                    bucket,
                    key,
                    remote_bucket,
                )
                return False
            raise

        size = remote_head.get("ContentLength", 0)

        # If file is large, use chunked caching
        if size > self._local_settings.chunk_threshold:
            LOG.info(
                "using chunked caching for s3://%s/%s (size=%d, threshold=%d)",
                bucket,
                key,
                size,
                self._local_settings.chunk_threshold,
            )
            return await self._serve_chunked(
                bucket,
                key,
                size,
                request.headers.get("range"),
                remote_bucket,
                remote_head,
            )

        return await self._backfill_from_remote(request, path)

    def _object_headers(self, result: Mapping[str, Any]) -> dict[str, str]:
        headers: dict[str, str] = {}
        mapping = {
            "Accept-Ranges": "AcceptRanges",
            "Cache-Control": "CacheControl",
            "Content-Disposition": "ContentDisposition",
            "Content-Encoding": "ContentEncoding",
            "Content-Language": "ContentLanguage",
            "Content-Length": "ContentLength",
            "Content-Range": "ContentRange",
            "Content-Type": "ContentType",
            "ETag": "ETag",
            "Expires": "Expires",
            "Last-Modified": "LastModified",
            "x-amz-delete-marker": "DeleteMarker",
            "x-amz-version-id": "VersionId",
            "x-amz-storage-class": "StorageClass",
        }

        for header, key in mapping.items():
            value = result.get(key)
            if value is None:
                continue
            headers[header] = self._format_header_value(value)

        metadata = result.get("Metadata") or {}
        for meta_key, meta_value in metadata.items():
            headers[f"x-amz-meta-{meta_key}"] = meta_value

        return headers

    @staticmethod
    def _format_header_value(value: Any) -> str:
        if isinstance(value, datetime):
            aware = value if value.tzinfo is not None else value.replace(tzinfo=UTC)
            aware = aware.astimezone(UTC)
            return format_datetime(aware, usegmt=True)
        return str(value)

    def _from_client_error(self, error: ClientError) -> Response:
        status_code = int(
            error.response.get("ResponseMetadata", {}).get("HTTPStatusCode", 500)
        )
        message = error.response.get("Error", {}).get(
            "Message", "Internal Server Error"
        )
        return Response(content=message, status_code=status_code)

    def _map_to_remote_bucket(self, local_bucket: str) -> str:
        """Map a local bucket name to its remote equivalent."""
        if not self._remote_settings.bucket_mapping:
            return local_bucket
        return self._remote_settings.bucket_mapping.get(local_bucket, local_bucket)

    async def _backfill_from_remote(self, request: Request, path: str) -> bool:
        if not self._remote_client:
            return False

        bucket, key = self._extract_bucket_and_key(path)
        if not bucket or key is None:
            return False

        remote_bucket = self._map_to_remote_bucket(bucket)
        range_header = request.headers.get("range")
        try:
            remote_obj = await _run_sync(
                partial(self._remote_client.get_object, Bucket=remote_bucket, Key=key)
            )
        except ClientError as error:
            code = error.response.get("Error", {}).get("Code")
            if code in {"404", "NoSuchKey", "NotFound"}:
                LOG.debug(
                    "remote miss for s3://%s/%s (remote: %s)",
                    bucket,
                    key,
                    remote_bucket,
                )
                return False
            raise

        body = await _run_sync(remote_obj["Body"].read)
        await _run_sync(remote_obj["Body"].close)

        await self._ensure_bucket(bucket)

        put_kwargs = {
            "Bucket": bucket,
            "Key": key,
            "Body": body,
            "Metadata": remote_obj.get("Metadata", {}),
        }

        for field in (
            "ContentType",
            "ContentEncoding",
            "ContentDisposition",
            "ContentLanguage",
            "CacheControl",
        ):
            if remote_obj.get(field):
                put_kwargs[field] = remote_obj[field]

        if remote_obj.get("Expires"):
            put_kwargs["Expires"] = remote_obj["Expires"]

        await _run_sync(partial(self._local_client.put_object, **put_kwargs))
        LOG.info(
            "cached remote object s3://%s/%s (%s bytes)%s (from remote: %s)",
            bucket,
            key,
            len(body),
            " with Range" if range_header else "",
            remote_bucket,
        )
        return True

    async def _ensure_bucket(self, bucket: str) -> None:
        try:
            await _run_sync(partial(self._local_client.head_bucket, Bucket=bucket))
        except ClientError as error:
            code = error.response.get("Error", {}).get("Code")
            if code not in {"404", "NoSuchBucket", "NotFound"}:
                raise
            create_kwargs = {"Bucket": bucket}
            location = self._local_settings.bucket_location
            if location and location != "us-east-1":
                create_kwargs["CreateBucketConfiguration"] = {
                    "LocationConstraint": location
                }
            await _run_sync(partial(self._local_client.create_bucket, **create_kwargs))
            LOG.info("created local bucket %s", bucket)

    def _extract_bucket_and_key(self, path: str) -> tuple[str | None, str | None]:
        trimmed = path.lstrip("/")
        if not trimmed:
            return None, None
        if "/" not in trimmed:
            return trimmed, ""
        bucket, key = trimmed.split("/", 1)
        return bucket, key

    def _to_streaming_response(self, response: httpx.Response) -> Response:
        headers = self._prepare_response_headers(response.headers.raw)

        async def iterator() -> AsyncIterator[bytes]:
            try:
                async for chunk in response.aiter_raw():
                    yield chunk
            finally:
                await response.aclose()

        return Stream(
            content=iterator(), status_code=response.status_code, headers=headers
        )

    def _prepare_response_headers(
        self, headers: list[tuple[bytes, bytes]]
    ) -> dict[str, str]:
        hop_by_hop = {
            "connection",
            "keep-alive",
            "proxy-authenticate",
            "proxy-authorization",
            "te",
            "trailers",
            "transfer-encoding",
            "upgrade",
        }
        prepared: dict[str, str] = {}
        for key_bytes, value_bytes in headers:
            key = key_bytes.decode("latin-1")
            value = value_bytes.decode("latin-1")
            if key.lower() in hop_by_hop:
                continue
            prepared[key] = value
        return prepared

    def _describe_remote(self) -> str:
        if not self._remote_client:
            return "disabled"
        endpoint = self._remote_settings.endpoint or "aws"
        region = self._remote_settings.region or "default"
        return f"{endpoint} ({region})"

    @classmethod
    def from_env(cls) -> S3OverlayProxy:
        """Create an S3OverlayProxy instance from environment variables.

        Returns:
            S3OverlayProxy configured from environment variables.
        """
        return cls(
            local=load_local_settings_from_env(),
            remote=load_remote_settings_from_env(),
        )

    async def _serve_chunked(
        self,
        bucket: str,
        key: str,
        total_size: int,
        range_header: str | None,
        remote_bucket: str,
        remote_head: dict[str, Any],
    ) -> Response:
        start, end = self._parse_range(range_header, total_size)
        content_length = end - start + 1
        chunk_size = self._local_settings.chunk_size
        cache_bucket = self._local_settings.cache_bucket_name

        await self._ensure_bucket(cache_bucket)

        async def iterator() -> AsyncIterator[bytes]:
            current_pos = start
            while current_pos <= end:
                chunk_index = current_pos // chunk_size
                chunk_start_in_file = chunk_index * chunk_size

                # Ensure chunk is available
                await self._ensure_chunk(
                    bucket,
                    key,
                    chunk_index,
                    chunk_size,
                    total_size,
                    remote_bucket,
                    cache_bucket,
                )

                # Read from chunk
                chunk_key = f"v1/{bucket}/{key}/{chunk_size}/{chunk_index}"
                offset_in_chunk = current_pos - chunk_start_in_file

                # Calculate how much to read from this chunk
                # We need up to the end of the request or end of the chunk
                bytes_left_in_request = end - current_pos + 1
                bytes_left_in_chunk = chunk_size - offset_in_chunk

                # Handling the last chunk which might be smaller than chunk_size
                if chunk_index * chunk_size + chunk_size > total_size:
                    bytes_left_in_chunk = (
                        total_size - chunk_start_in_file - offset_in_chunk
                    )

                bytes_to_read = min(bytes_left_in_request, bytes_left_in_chunk)

                chunk_range = (
                    f"bytes={offset_in_chunk}-{offset_in_chunk + bytes_to_read - 1}"
                )

                obj = await _run_sync(
                    partial(
                        self._local_client.get_object,
                        Bucket=cache_bucket,
                        Key=chunk_key,
                        Range=chunk_range,
                    )
                )
                stream = obj["Body"]
                try:
                    chunk_bytes = await _run_sync(stream.read)
                    yield chunk_bytes
                finally:
                    await _run_sync(stream.close)

                current_pos += bytes_to_read

        headers = self._object_headers(remote_head)
        headers.update(
            {
                "Content-Length": str(content_length),
                "Content-Range": f"bytes {start}-{end}/{total_size}",
            }
        )
        # Remove ETag as it might correspond to the full file and we are sending partial?
        # S3 usually keeps ETag for ranges.

        status_code = 206
        return Stream(content=iterator, status_code=status_code, headers=headers)

    async def _ensure_chunk(
        self,
        bucket: str,
        key: str,
        index: int,
        chunk_size: int,
        total_size: int,
        remote_bucket: str,
        cache_bucket: str,
    ) -> None:
        assert self._remote_client is not None
        chunk_key = f"v1/{bucket}/{key}/{chunk_size}/{index}"

        # Check if exists
        try:
            await _run_sync(
                partial(
                    self._local_client.head_object, Bucket=cache_bucket, Key=chunk_key
                )
            )
        except ClientError:
            pass
        else:
            return

        # Download from remote
        start = index * chunk_size
        end = min(start + chunk_size, total_size) - 1
        range_header = f"bytes={start}-{end}"

        try:
            remote_obj = await _run_sync(
                partial(
                    self._remote_client.get_object,
                    Bucket=remote_bucket,
                    Key=key,
                    Range=range_header,
                )
            )
        except ClientError:
            LOG.exception("Failed to download chunk %s from remote", chunk_key)
            raise

        body = await _run_sync(remote_obj["Body"].read)
        await _run_sync(remote_obj["Body"].close)

        # Upload to cache
        await _run_sync(
            partial(
                self._local_client.put_object,
                Bucket=cache_bucket,
                Key=chunk_key,
                Body=body,
            )
        )
        LOG.debug("cached chunk %s (%d bytes)", chunk_key, len(body))

    def _parse_range(
        self, range_header: str | None, total_size: int
    ) -> tuple[int, int]:
        if not range_header:
            return 0, total_size - 1

        try:
            unit, ranges = range_header.split("=", 1)
            if unit.strip().lower() != "bytes":
                return 0, total_size - 1

            r = ranges.split(",")[0].strip()
            if "-" not in r:
                return 0, total_size - 1

            start_str, end_str = r.split("-", 1)

            if start_str and end_str:
                start = int(start_str)
                end = int(end_str)
            elif start_str:
                start = int(start_str)
                end = total_size - 1
            elif end_str:
                length = int(end_str)
                start = total_size - length
                end = total_size - 1
            else:
                return 0, total_size - 1

            if end >= total_size:
                end = total_size - 1
            if start < 0:
                start = 0
            if start > end:
                # Invalid range, fallback to full
                return 0, total_size - 1
        except ValueError:
            return 0, total_size - 1
        else:
            return start, end

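For reference (again, not part of the packaged files), a minimal sketch of configuring and driving the proxy outside the bundled Litestar app. The environment variable names are the ones declared by LocalSettings and RemoteSettings above; the endpoint, region, and bucket values are placeholders.

# overlay_sketch.py -- hypothetical usage sketch
import asyncio
import os

from s3_overlay.proxy import S3OverlayProxy

# Local MinIO acting as the cache tier.
os.environ["S3_OVERLAY_LOCAL_ENDPOINT"] = "http://127.0.0.1:9000"
# Setting any remote value (endpoint, keys, or region) enables remote backfill.
os.environ["S3_OVERLAY_REMOTE_REGION"] = "us-east-1"
# Comma-separated "local:remote" pairs, parsed by _parse_bucket_mapping.
os.environ["S3_OVERLAY_BUCKET_MAPPING"] = "models:prod-models,datasets:prod-datasets"


async def main() -> None:
    proxy = S3OverlayProxy.from_env()  # builds boto3 clients from the env above
    await proxy.startup()              # opens the httpx client used for passthrough
    try:
        # proxy.handle(request, path) is what the ASGI handler in app.py calls;
        # GET/HEAD on /<bucket>/<key> are served locally and backfilled on miss.
        ...
    finally:
        await proxy.shutdown()


asyncio.run(main())
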
s3_overlay-0.0.0.dist-info/METADATA
ADDED
@@ -0,0 +1,14 @@
Metadata-Version: 2.4
Name: s3-overlay
Version: 0.0.0
Summary: S3 overlay proxy for transparent remote object caching
License-Expression: MIT
Requires-Dist: anyio>=4.0.0
Requires-Dist: boto3>=1.37.0
Requires-Dist: httpx>=0.27.2
Requires-Dist: litestar[standard,prometheus]>=2.10.0
Requires-Dist: pydantic>=2.6.0
Requires-Dist: pydantic-settings>=2.2.0
Requires-Dist: uv~=0.10.0 ; extra == 'build'
Requires-Python: >=3.13
Provides-Extra: build

s3_overlay-0.0.0.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
s3_overlay/__init__.py,sha256=l6BnnaQCmhQto2hXfns4XjkX0XAkwxWwmqf8GYaWICM,235
s3_overlay/app.py,sha256=Xg7l4Y4O6awqIRP4fb_W4zcfV2IjG00s0Rre5A1u3Ho,1849
s3_overlay/proxy.py,sha256=ocyKJGbyKsqfrhNvZ2mjKSmulgC2_59CgKoP9Df7pqQ,28069
s3_overlay-0.0.0.dist-info/WHEEL,sha256=iHtWm8nRfs0VRdCYVXocAWFW8ppjHL-uTJkAdZJKOBM,80
s3_overlay-0.0.0.dist-info/METADATA,sha256=I3nbzHZC3Gs_aU72EQUDjvs9rrS96X5afYAvw1aZ6iY,443
s3_overlay-0.0.0.dist-info/RECORD,,