tab-cli 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tab_cli/__init__.py +3 -0
- tab_cli/cli.py +171 -0
- tab_cli/config.py +14 -0
- tab_cli/formats/__init__.py +15 -0
- tab_cli/formats/avro.py +47 -0
- tab_cli/formats/base.py +63 -0
- tab_cli/formats/csv.py +45 -0
- tab_cli/formats/jsonl.py +41 -0
- tab_cli/formats/parquet.py +57 -0
- tab_cli/handlers/__init__.py +87 -0
- tab_cli/handlers/base.py +259 -0
- tab_cli/handlers/cli_table.py +55 -0
- tab_cli/storage/__init__.py +83 -0
- tab_cli/storage/aws.py +223 -0
- tab_cli/storage/az.py +249 -0
- tab_cli/storage/base.py +36 -0
- tab_cli/storage/fsspec.py +60 -0
- tab_cli/storage/gcloud.py +215 -0
- tab_cli/storage/local.py +25 -0
- tab_cli/style.py +4 -0
- tab_cli/url_parser.py +97 -0
- tab_cli-0.1.1.dist-info/METADATA +27 -0
- tab_cli-0.1.1.dist-info/RECORD +26 -0
- tab_cli-0.1.1.dist-info/WHEEL +4 -0
- tab_cli-0.1.1.dist-info/entry_points.txt +2 -0
- tab_cli-0.1.1.dist-info/licenses/LICENSE +21 -0
tab_cli/handlers/base.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
"""Base classes for table reading and writing."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from collections.abc import Iterable
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
|
|
8
|
+
import polars as pl
|
|
9
|
+
from rich import box
|
|
10
|
+
from rich.progress import Progress
|
|
11
|
+
from rich.table import Table
|
|
12
|
+
|
|
13
|
+
from tab_cli.formats.base import FormatHandler
|
|
14
|
+
from tab_cli.storage.base import StorageBackend
|
|
15
|
+
from tab_cli.style import _ALT_ROW_STYLE_0, _ALT_ROW_STYLE_1, _KEY_STYLE, _VAL_STYLE
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class TableSchema:
    """Column names and dtypes of a table, renderable by rich."""

    # (column name, dtype) pairs, rendered in the order given.
    columns: list[tuple[str, pl.DataType]]

    def __rich__(self) -> Table:
        """Build a header-less two-column table: column name | dtype."""
        rendered = Table(
            show_header=False,
            box=box.SIMPLE_HEAD,
            row_styles=[_ALT_ROW_STYLE_0, _ALT_ROW_STYLE_1],
        )
        # First column holds the names, second the stringified dtypes.
        for column_style in (_KEY_STYLE, _VAL_STYLE):
            rendered.add_column(style=column_style)
        for column_name, column_dtype in self.columns:
            rendered.add_row(column_name, str(column_dtype))
        return rendered
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
class TableSummary:
    """Summary information for a table, renderable by rich."""

    # Size in bytes; rendered with binary units (KiB, MiB, ...).
    file_size: int
    num_rows: int
    num_columns: int
    # Optional extra key/value rows appended after the fixed rows.
    extra: dict[str, str | int | float] | None = None

    @staticmethod
    def _format_size(size: int) -> str:
        """Format a byte count using binary units.

        Plain bytes are shown as an integer ("512 B"); larger units are shown
        with one decimal ("1.5 KiB"). Anything of 1024 TiB or more is reported
        in PiB.
        """
        scaled: float = size
        for unit in ("B", "KiB", "MiB", "GiB", "TiB"):
            if scaled < 1024:
                return f"{int(scaled)} {unit}" if unit == "B" else f"{scaled:.1f} {unit}"
            scaled /= 1024
        return f"{scaled:.1f} PiB"

    def __rich__(self) -> Table:
        """Build a header-less key/value table describing the summary."""
        table = Table(
            show_header=False,
            box=box.SIMPLE_HEAD,
            # The module imports only the _0/_1 alternating styles; the
            # previous `_ALT_ROW_STYLE` name was undefined and raised NameError.
            row_styles=[_ALT_ROW_STYLE_0, _ALT_ROW_STYLE_1],
        )
        table.add_column(style=_KEY_STYLE)
        table.add_column(style=_VAL_STYLE)

        table.add_row("File size", self._format_size(self.file_size))
        table.add_row("Rows", f"{self.num_rows:,}")
        table.add_row("Columns", str(self.num_columns))

        if self.extra:
            for key, value in self.extra.items():
                table.add_row(key, str(value))

        return table
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class TableReader:
    """Reads tabular data by composing a StorageBackend and FormatHandler.

    The backend answers storage questions (directory checks, file listings,
    sizes, URI normalisation, storage options); the format handler performs
    the scanning, row counting and schema collection.
    """

    def __init__(self, backend: StorageBackend, format: FormatHandler):
        self.backend = backend
        self.format = format

    def read(self, url: str, limit: int | None = None, offset: int = 0) -> pl.LazyFrame:
        """Return a LazyFrame over ``url``, optionally restricted to a row window.

        Args:
            url: File or directory location understood by the backend.
            limit: Maximum number of rows to keep; ``None`` keeps all rows.
            offset: Number of leading rows to skip; values <= 0 skip nothing.
        """
        if self.backend.is_directory(url):
            lf = self._read_directory(url)
        else:
            lf = self._scan(url)

        if offset > 0:
            lf = lf.slice(offset, length=limit)
        elif limit is not None:
            lf = lf.head(limit)
        return lf

    def _scan(self, url: str) -> pl.LazyFrame:
        """Scan a single file through the format handler."""
        polars_uri = self.backend.normalize_for_polars(url)
        storage_options = self.backend.storage_options(url)
        return self.format.scan(polars_uri, storage_options=storage_options)

    def _list_files_or_raise(self, url: str) -> list:
        """List files under ``url`` with the format's extension.

        Raises:
            ValueError: If no matching file is found.
        """
        extension = self.format.extension()
        files = list(self.backend.list_files(url, extension))
        if not files:
            raise ValueError(f"No {extension} files found in {url}")
        return files

    def _read_directory(self, url: str) -> pl.LazyFrame:
        """Read all files in a directory.

        Raises:
            ValueError: If the format has no glob support and the directory
                contains no file with the format's extension.
        """
        if self.format.supports_glob():
            import posixpath

            # Use native glob support.
            polars_uri = self.backend.normalize_for_polars(url)
            storage_options = self.backend.storage_options(url)
            # Remote URIs always use "/"; os.path.join would insert "\\" on
            # Windows and break the pattern. Local paths keep the OS joiner.
            join = posixpath.join if "://" in polars_uri else os.path.join
            glob_pattern = join(polars_uri, "**", f"*{self.format.extension()}")
            return self.format.scan(glob_pattern, storage_options=storage_options)

        # Manual concatenation for formats without glob support.
        frames = [self._scan(f.url) for f in self._list_files_or_raise(url)]
        return pl.concat(frames, how="vertical")

    def schema(self, url: str) -> TableSchema:
        """Return the schema of ``url``.

        For a directory the schema is taken from the first listed file only.

        Raises:
            ValueError: If ``url`` is a directory with no matching files.
        """
        if self.backend.is_directory(url):
            url = self._list_files_or_raise(url)[0].url
        polars_uri = self.backend.normalize_for_polars(url)
        storage_options = self.backend.storage_options(url)
        columns = self.format.collect_schema(polars_uri, storage_options=storage_options)
        return TableSchema(columns=columns)

    def summary(self, url: str) -> TableSummary:
        """Return size/row/column summary for a file or a directory."""
        if self.backend.is_directory(url):
            return self._summary_directory(url)
        return self._summary_single(url)

    def _summary_single(self, url: str) -> TableSummary:
        """Summarise a single file."""
        file_size = self.backend.size(url)
        polars_uri = self.backend.normalize_for_polars(url)
        storage_options = self.backend.storage_options(url)
        num_rows = self.format.count_rows(polars_uri, storage_options=storage_options)
        schema = self.format.collect_schema(polars_uri, storage_options=storage_options)
        return TableSummary(
            file_size=file_size,
            num_rows=num_rows,
            num_columns=len(schema),
            extra=self.format.extra_summary(url),
        )

    def _summary_directory(self, url: str) -> TableSummary:
        """Aggregate summary from all files in directory.

        Sizes and row counts are summed. Numeric extra values are summed per
        key; other extra values are collected as distinct strings and joined
        with ", ". A "Partitions" entry holds the number of files.

        Raises:
            ValueError: If no matching files exist or the files disagree on
                the number of columns.
        """
        files = self._list_files_or_raise(url)

        file_size = 0
        num_rows = 0
        num_columns: int | None = None

        extra_numeric: dict[str, float] = {}
        extra_strings: dict[str, set[str]] = {}

        for file_info in files:
            file_size += file_info.size
            polars_uri = self.backend.normalize_for_polars(file_info.url)
            storage_options = self.backend.storage_options(file_info.url)
            num_rows += self.format.count_rows(polars_uri, storage_options=storage_options)

            schema = self.format.collect_schema(polars_uri, storage_options=storage_options)
            if num_columns is None:
                num_columns = len(schema)
            elif len(schema) != num_columns:
                raise ValueError(f"Inconsistent column counts in {url}")

            extra = self.format.extra_summary(file_info.url)
            if extra:
                for key, value in extra.items():
                    if isinstance(value, (int, float)):
                        extra_numeric[key] = extra_numeric.get(key, 0) + value
                    else:
                        extra_strings.setdefault(key, set()).add(str(value))

        result_extra: dict[str, str | int | float] = {"Partitions": len(files)}
        for key, value in extra_numeric.items():
            # Show whole-number totals without a trailing ".0".
            result_extra[key] = int(value) if float(value).is_integer() else value

        for key, values in extra_strings.items():
            # Sorting keeps the joined output deterministic.
            result_extra[key] = ", ".join(sorted(values))

        return TableSummary(
            file_size=file_size,
            num_rows=num_rows,
            num_columns=num_columns or 0,
            extra=result_extra,
        )
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class TableWriter(ABC):
    """Base class for writing tabular data."""

    @abstractmethod
    def extension(self) -> str:
        """Return the file extension for this format."""
        pass

    @abstractmethod
    def write(self, lf: pl.LazyFrame) -> Iterable[bytes]:
        """Write LazyFrame to bytes (for streaming output)."""
        pass

    @abstractmethod
    def write_to_single_file(self, lf: pl.LazyFrame, path: str) -> None:
        """Write LazyFrame to a single file."""
        pass

    def write_to_path(self, lf: pl.LazyFrame, path: str, partitions: int | None = None) -> None:
        """Write LazyFrame to a file or partitioned directory.

        Args:
            lf: Data to write.
            path: Target file when ``partitions`` is ``None``; otherwise the
                directory that receives ``part-NNNNN<extension>`` files.
            partitions: Number of row-wise partitions to split the data into.
                Partitions that would be empty are not written.

        Raises:
            ValueError: If ``partitions`` is given but smaller than 1.
        """
        # Validate before touching the filesystem: 0 would otherwise surface
        # as ZeroDivisionError after the directory has already been created.
        if partitions is not None and partitions < 1:
            raise ValueError(f"partitions must be a positive integer, got {partitions}")

        if partitions is None:
            with Progress() as progress:
                task = progress.add_task("Writing...", total=1)
                self.write_to_single_file(lf, path)
                progress.update(task, completed=1)
            return

        os.makedirs(path, exist_ok=True)
        row_count = lf.select(pl.len()).collect().item()
        # Ceiling division so that every row lands in some partition.
        rows_per_part = (row_count + partitions - 1) // partitions
        with Progress() as progress:
            task = progress.add_task("Writing partitions...", total=partitions)
            for i in range(partitions):
                offset = i * rows_per_part
                if offset < row_count:
                    part_lf = lf.slice(offset, rows_per_part)
                    part_path = os.path.join(path, f"part-{i:05d}{self.extension()}")
                    self.write_to_single_file(part_lf, part_path)
                # Advance even for skipped partitions so the bar reaches its total.
                progress.update(task, advance=1)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
class FormatWriter(TableWriter):
    """Expose a FormatHandler through the TableWriter interface.

    Every method forwards to the wrapped handler's method of the same name.
    """

    def __init__(self, format: FormatHandler):
        # The wrapped handler that does the actual work.
        self._format = format

    def extension(self) -> str:
        """Return the wrapped handler's file extension."""
        handler = self._format
        return handler.extension()

    def write(self, lf: pl.LazyFrame) -> Iterable[bytes]:
        """Return whatever the wrapped handler's ``write`` produces for ``lf``."""
        handler = self._format
        return handler.write(lf)

    def write_to_single_file(self, lf: pl.LazyFrame, path: str) -> None:
        """Have the wrapped handler write ``lf`` to ``path``."""
        handler = self._format
        handler.write_to_single_file(lf, path)
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
from collections.abc import Iterable
|
|
3
|
+
|
|
4
|
+
from rich.table import Table
|
|
5
|
+
from rich import box
|
|
6
|
+
from rich.console import Console
|
|
7
|
+
import polars as pl
|
|
8
|
+
|
|
9
|
+
from tab_cli.handlers.base import TableWriter
|
|
10
|
+
from tab_cli.style import _ALT_ROW_STYLE_0, _ALT_ROW_STYLE_1, _KEY_STYLE
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class CliTableFormatter(TableWriter):
    """TableWriter that renders a LazyFrame as a rich text table."""

    def __init__(self, truncated: bool = False, svg_capture: bool = False):
        # When set, a final row of "..." cells is appended to the table.
        self.truncated = truncated
        # When set, the table is printed and an SVG export goes to stderr.
        self.svg_capture = svg_capture

    def extension(self) -> str:
        """Return the extension used for rendered tables."""
        return ".txt"

    def write(self, lf: pl.LazyFrame) -> Iterable[bytes]:
        """Render ``lf`` as a table.

        In the default mode, yields one UTF-8 encoded chunk holding the
        captured console output. In SVG-capture mode nothing is yielded: the
        table is printed through a recording console and its SVG export is
        written to stderr.
        """
        rendered = Table(
            show_header=True,
            header_style=_KEY_STYLE,
            box=box.SIMPLE_HEAD,
            row_styles=[_ALT_ROW_STYLE_0, _ALT_ROW_STYLE_1],
        )

        for column_name in lf.collect_schema().names():
            rendered.add_column(column_name)

        for batch in lf.collect_batches():
            for record in batch.iter_rows():
                # None is shown as an empty cell rather than the text "None".
                cells = ["" if cell is None else str(cell) for cell in record]
                rendered.add_row(*cells)

        if self.truncated:
            rendered.add_row(*["..."] * len(lf.collect_schema().names()))

        if not self.svg_capture:
            console = Console()
            with console.capture() as captured:
                console.print(rendered)
            # The captured text is only read after the capture context exits.
            yield captured.get().encode("utf-8")
            return

        recorder = Console(record=True, width=80)
        recorder.print(rendered)
        print(recorder.export_svg(), file=sys.stderr)

    def write_to_single_file(self, lf: pl.LazyFrame, path: str) -> None:
        """Write a LazyFrame to a single text file."""
        with open(path, "wb") as sink:
            for piece in self.write(lf):
                sink.write(piece)
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""Storage backends for filesystem abstraction."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from loguru import logger
|
|
6
|
+
from urllib.parse import urlparse
|
|
7
|
+
|
|
8
|
+
from tab_cli.storage.base import FileInfo, StorageBackend
|
|
9
|
+
from tab_cli.storage.local import LocalBackend
|
|
10
|
+
from tab_cli.url_parser import parse_url, ParsedUrl
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"FileInfo",
|
|
14
|
+
"StorageBackend",
|
|
15
|
+
"LocalBackend",
|
|
16
|
+
"get_backend",
|
|
17
|
+
]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_backend(url: str) -> StorageBackend:
    """Pick the storage backend matching the scheme of ``url``.

    Scheme dispatch:
    - no scheme or file://  -> LocalBackend
    - az://                 -> AzBackend (Azure Blob Storage, requires adlfs)
    - abfs://, abfss://     -> AzBackend (Azure Data Lake Storage Gen2, requires adlfs)
    - gs://                 -> GcloudBackend (Google Cloud Storage, requires gcsfs)
    - s3://                 -> AwsBackend (AWS S3, requires s3fs)
    - anything else         -> FsspecBackend for that protocol

    For az:// URLs the meaning of the URL authority depends on the
    --az-url-authority-is-account global flag:
    - If set: authority is the storage account name
        - az://account/container/path
        - az:///container/path (account from AZURE_STORAGE_ACCOUNT)
    - If not set (default): authority is the container name
        - az://container/path (standard adlfs behavior)

    Backend modules are imported lazily inside each branch, so a backend's
    module is only loaded when its scheme is actually requested.
    """
    parsed = parse_url(url)
    logger.debug(f"Accessing data from\n"
                 f" - Protocol: [bold]{parsed.scheme}[/]\n"
                 f" - Account: {parsed.account}\n"
                 f" - Bucket: {parsed.bucket}\n"
                 f" - Path: {parsed.path}"
                 )

    scheme = parsed.scheme

    # Local filesystem
    if scheme == "file" or not scheme:
        return LocalBackend()

    if scheme == "az":
        from tab_cli import config
        from tab_cli.storage.az import AzBackend

        return AzBackend(
            account=parsed.account,
            container=parsed.bucket,
            az_url_authority_is_account=config.config.az_url_authority_is_account,
        )

    if scheme in {"abfs", "abfss"}:
        from tab_cli.storage.az import AzBackend

        # abfs/abfss always uses account in URL or env
        return AzBackend(
            account=parsed.account,
            container=parsed.bucket,
        )

    # Google Cloud Storage
    if scheme == "gs":
        from tab_cli.storage.gcloud import GcloudBackend

        return GcloudBackend()

    # AWS S3
    if scheme == "s3":
        from tab_cli.storage.aws import AwsBackend

        return AwsBackend()

    # All other protocols via fsspec
    from tab_cli.storage.fsspec import FsspecBackend

    return FsspecBackend(scheme)
|
tab_cli/storage/aws.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from enum import Enum
|
|
3
|
+
|
|
4
|
+
from typing import BinaryIO, Iterator
|
|
5
|
+
from loguru import logger
|
|
6
|
+
|
|
7
|
+
from tab_cli.storage.base import StorageBackend, FileInfo
|
|
8
|
+
from tab_cli.url_parser import parse_url
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AwsAuthMethod(Enum):
    """Authentication strategy selected for an S3 backend."""

    # AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY
    EXPLICIT_KEYS = 1
    # AWS_PROFILE or default profile (handles SSO, assume role, etc.)
    PROFILE = 2
    # Public buckets
    ANONYMOUS = 3
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class AwsBackend(StorageBackend):
    """Storage backend for AWS S3.

    URL format: s3://bucket/path

    Authentication is handled by boto3's credential chain via s3fs:
    1. Explicit keys from environment (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN)
    2. Profile-based auth (AWS_PROFILE or default) - handles:
       - Shared credentials file (~/.aws/credentials)
       - AWS config file (~/.aws/config)
       - SSO credentials (from `aws sso login`)
       - Assume role
       - Container credentials (ECS/EKS)
       - EC2 instance metadata
    3. Anonymous access (for public buckets, if requested)
    """

    def __init__(self, anon: bool = False) -> None:
        """Initialize the AWS S3 storage backend.

        Args:
            anon: If True, use anonymous access (for public buckets only).

        Raises:
            ImportError: If the ``s3fs`` package is not installed.
            ValueError: If no filesystem could be created with any method.
        """
        try:
            import s3fs
        except ImportError as e:
            raise ImportError("Package 's3fs' is required for s3:// URLs. Install with: pip install s3fs") from e

        self.s3fs = s3fs
        self.fs = None
        # Set to the AwsAuthMethod that produced self.fs.
        self.method = None
        self.anon = anon

        # Get profile from environment
        self.profile = os.environ.get("AWS_PROFILE")
        self.region = os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION")

        # Check for explicit credentials in environment
        self.access_key = os.environ.get("AWS_ACCESS_KEY_ID")
        self.secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
        self.session_token = os.environ.get("AWS_SESSION_TOKEN")

        if anon:
            # Anonymous access for public buckets
            logger.debug("Using anonymous access for S3")
            self.fs = self.s3fs.S3FileSystem(anon=True)
            self.method = AwsAuthMethod.ANONYMOUS
            return

        # 1. Try explicit credentials from environment
        if self.access_key and self.secret_key:
            logger.debug("Authenticating to S3 using explicit credentials from environment")
            try:
                self.fs = self.s3fs.S3FileSystem(
                    key=self.access_key,
                    secret=self.secret_key,
                    token=self.session_token,
                )
                self.method = AwsAuthMethod.EXPLICIT_KEYS
                return
            except Exception as e:
                logger.debug("Explicit credentials authentication failed: {}", e)

        # 2. Fall back to profile-based auth (boto3 credential chain)
        # This handles: ~/.aws/credentials, ~/.aws/config, SSO, assume role, instance metadata
        profile_desc = f"profile '{self.profile}'" if self.profile else "default credential chain"
        logger.debug("Authenticating to S3 using {}", profile_desc)
        try:
            self.fs = self.s3fs.S3FileSystem(profile=self.profile)
            self.method = AwsAuthMethod.PROFILE
            return
        except Exception as e:
            logger.debug("Profile-based authentication failed: {}", e)

        if self.fs is None:
            raise ValueError(
                "Could not authenticate to AWS S3. "
                "Set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY, "
                "configure ~/.aws/credentials, run 'aws configure', "
                "or run 'aws sso login'."
            )

    def normalize_for_polars(self, url: str) -> str:
        """Normalize URL to a format Polars understands.

        Polars expects s3://bucket/path format.

        Returns:
            Normalized URL in s3://bucket/path format.
        """
        parsed = parse_url(url)
        return f"s3://{parsed.bucket}/{parsed.path}"

    def _credential_options(self, access_key: str, secret_key: str, token: str | None) -> dict:
        """Build storage options from concrete credentials.

        Includes both s3fs-style keys and Rust object_store keys for
        compatibility; the session token and region are added only when set.
        """
        opts = {
            # s3fs keys
            "key": access_key,
            "secret": secret_key,
            # object_store keys
            "aws_access_key_id": access_key,
            "aws_secret_access_key": secret_key,
        }
        if token:
            opts["token"] = token
            opts["aws_session_token"] = token
        if self.region:
            opts["client_kwargs"] = {"region_name": self.region}
            opts["aws_region"] = self.region
        return opts

    def storage_options(self, url: str) -> dict[str, str] | None:
        """Return storage options for Polars S3 access.

        Returns:
            Dict with appropriate authentication options for S3, or None when
            there is nothing to pass. Includes both s3fs-style keys and Rust
            object_store keys for compatibility. Note that "anon" and
            "client_kwargs" carry non-string values.
        """
        if self.method == AwsAuthMethod.ANONYMOUS:
            return {
                "anon": True,
                "aws_skip_signature": "true",
            }

        if self.method == AwsAuthMethod.EXPLICIT_KEYS:
            return self._credential_options(self.access_key, self.secret_key, self.session_token)

        if self.method == AwsAuthMethod.PROFILE:
            # For profile-based auth, we need to fetch the resolved credentials
            # so Polars can use them (Polars doesn't understand profiles directly)
            try:
                credentials = self._get_credentials_from_session()
            except Exception as e:
                logger.debug("Failed to resolve credentials from session: {}", e)
                credentials = None
            if credentials:
                return self._credential_options(
                    credentials["access_key"],
                    credentials["secret_key"],
                    credentials.get("token"),
                )

            # Fallback: just pass the profile and hope Polars/fsspec can handle it
            opts = {}
            if self.profile:
                opts["profile"] = self.profile
            if self.region:
                opts["client_kwargs"] = {"region_name": self.region}
                opts["aws_region"] = self.region
            return opts if opts else None

        return None

    def _get_credentials_from_session(self) -> dict | None:
        """Get resolved credentials from boto3 session.

        Best effort: returns None when boto3 is unavailable, no credentials
        are found, or resolution fails. Failures are logged at debug level
        instead of being silently discarded.
        """
        try:
            import boto3

            session = boto3.Session(profile_name=self.profile)
            credentials = session.get_credentials()
            if credentials:
                frozen = credentials.get_frozen_credentials()
                return {
                    "access_key": frozen.access_key,
                    "secret_key": frozen.secret_key,
                    "token": frozen.token,
                }
        except Exception as e:
            logger.debug("Could not resolve credentials from boto3 session: {}", e)
        return None

    def _to_internal(self, url: str) -> str:
        """Convert URL to internal path for s3fs operations."""
        parsed = parse_url(url)
        return f"{parsed.bucket}/{parsed.path}"

    def _to_uri(self, internal_path: str) -> str:
        """Convert internal path back to s3:// URL."""
        return f"s3://{internal_path}"

    def open(self, url: str) -> BinaryIO:
        """Open the object at ``url`` for binary reading."""
        return self.fs.open(self._to_internal(url), "rb")

    def list_files(self, url: str, extension: str) -> Iterator[FileInfo]:
        """Yield a FileInfo (s3:// URL and size) for each file under ``url``.

        Files are matched recursively by ``extension`` and yielded in sorted
        path order.
        """
        # Strip a trailing slash so the glob does not contain "//".
        internal_path = self._to_internal(url).rstrip("/")
        pattern = f"{internal_path}/**/*{extension}"
        for path in sorted(self.fs.glob(pattern)):
            info = self.fs.info(path)
            yield FileInfo(url=self._to_uri(path), size=info["size"])

    def size(self, url: str) -> int:
        """Return the size of the object at ``url`` as reported by s3fs."""
        return self.fs.size(self._to_internal(url))

    def is_directory(self, url: str) -> bool:
        """Return True if ``url`` is a directory.

        If ``info`` reports the path as missing, the path is still treated as
        a directory when listing it returns at least one entry.
        """
        path = self._to_internal(url)
        try:
            info = self.fs.info(path)
            return info.get("type") == "directory"
        except FileNotFoundError:
            try:
                contents = self.fs.ls(path, detail=False)
                return len(contents) > 0
            except Exception:
                return False
|