atdata-0.2.3b1-py3-none-any.whl → atdata-0.3.0b1-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in that registry.
- atdata/.gitignore +1 -0
- atdata/__init__.py +30 -0
- atdata/_exceptions.py +168 -0
- atdata/_helpers.py +29 -15
- atdata/_hf_api.py +63 -11
- atdata/_logging.py +70 -0
- atdata/_protocols.py +19 -62
- atdata/_schema_codec.py +5 -4
- atdata/_type_utils.py +28 -2
- atdata/atmosphere/__init__.py +19 -9
- atdata/atmosphere/records.py +3 -2
- atdata/atmosphere/schema.py +2 -2
- atdata/cli/__init__.py +157 -171
- atdata/cli/inspect.py +69 -0
- atdata/cli/local.py +1 -1
- atdata/cli/preview.py +63 -0
- atdata/cli/schema.py +109 -0
- atdata/dataset.py +428 -326
- atdata/lens.py +9 -2
- atdata/local/__init__.py +71 -0
- atdata/local/_entry.py +157 -0
- atdata/local/_index.py +940 -0
- atdata/local/_repo_legacy.py +218 -0
- atdata/local/_s3.py +349 -0
- atdata/local/_schema.py +380 -0
- atdata/manifest/__init__.py +28 -0
- atdata/manifest/_aggregates.py +156 -0
- atdata/manifest/_builder.py +163 -0
- atdata/manifest/_fields.py +154 -0
- atdata/manifest/_manifest.py +146 -0
- atdata/manifest/_query.py +150 -0
- atdata/manifest/_writer.py +74 -0
- atdata/promote.py +4 -4
- atdata/providers/__init__.py +25 -0
- atdata/providers/_base.py +140 -0
- atdata/providers/_factory.py +69 -0
- atdata/providers/_postgres.py +214 -0
- atdata/providers/_redis.py +171 -0
- atdata/providers/_sqlite.py +191 -0
- atdata/repository.py +323 -0
- atdata/testing.py +337 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.0b1.dist-info}/METADATA +4 -1
- atdata-0.3.0b1.dist-info/RECORD +54 -0
- atdata/local.py +0 -1720
- atdata-0.2.3b1.dist-info/RECORD +0 -28
- {atdata-0.2.3b1.dist-info → atdata-0.3.0b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.0b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.0b1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,218 @@
"""Deprecated Repo class for legacy S3 repository operations."""

from atdata import Dataset

from atdata.local._entry import LocalDatasetEntry
from atdata.local._s3 import _s3_env, _s3_from_credentials, _create_s3_write_callbacks

from pathlib import Path
from uuid import uuid4
from tempfile import TemporaryDirectory
from typing import Any, BinaryIO, TypeVar, cast

from redis import Redis
import msgpack
import webdataset as wds
import warnings

from atdata._protocols import Packable

T = TypeVar("T", bound=Packable)


class Repo:
    """Repository for storing and managing atdata datasets.

    .. deprecated::
        Use :class:`Index` with :class:`S3DataStore` instead::

            store = S3DataStore(credentials, bucket="my-bucket")
            index = Index(redis=redis, data_store=store)
            entry = index.insert_dataset(ds, name="my-dataset")

    Provides storage of datasets in S3-compatible object storage with Redis-based
    indexing. Datasets are stored as WebDataset tar files with optional metadata.

    Attributes:
        s3_credentials: S3 credentials dictionary or None.
        bucket_fs: S3FileSystem instance or None.
        hive_path: Path within S3 bucket for storing datasets.
        hive_bucket: Name of the S3 bucket.
        index: Index instance for tracking datasets.
    """

    ##

    def __init__(
        self,
        s3_credentials: str | Path | dict[str, Any] | None = None,
        hive_path: str | Path | None = None,
        redis: Redis | None = None,
    ) -> None:
        """Initialize a repository.

        .. deprecated::
            Use Index with S3DataStore instead.

        Args:
            s3_credentials: Path to .env file with S3 credentials, or dict with
                AWS_ENDPOINT, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY.
                If None, S3 functionality will be disabled.
            hive_path: Path within the S3 bucket to store datasets.
                Required if s3_credentials is provided.
            redis: Redis connection for indexing. If None, creates a new connection.

        Raises:
            ValueError: If hive_path is not provided when s3_credentials is set.
        """
        warnings.warn(
            "Repo is deprecated. Use Index with S3DataStore instead:\n"
            "    store = S3DataStore(credentials, bucket='my-bucket')\n"
            "    index = Index(redis=redis, data_store=store)\n"
            "    entry = index.insert_dataset(ds, name='my-dataset')",
            DeprecationWarning,
            stacklevel=2,
        )

        if s3_credentials is None:
            self.s3_credentials = None
        elif isinstance(s3_credentials, dict):
            self.s3_credentials = s3_credentials
        else:
            self.s3_credentials = _s3_env(s3_credentials)

        if self.s3_credentials is None:
            self.bucket_fs = None
        else:
            self.bucket_fs = _s3_from_credentials(self.s3_credentials)

        if self.bucket_fs is not None:
            if hive_path is None:
                raise ValueError("Must specify hive path within bucket")
            self.hive_path = Path(hive_path)
            self.hive_bucket = self.hive_path.parts[0]
        else:
            self.hive_path = None
            self.hive_bucket = None

        #

        from atdata.local._index import Index

        self.index = Index(redis=redis)

    ##

    def insert(
        self,
        ds: Dataset[T],
        *,
        name: str,
        cache_local: bool = False,
        schema_ref: str | None = None,
        **kwargs,
    ) -> tuple[LocalDatasetEntry, Dataset[T]]:
        """Insert a dataset into the repository.

        Writes the dataset to S3 as WebDataset tar files, stores metadata,
        and creates an index entry in Redis.

        Args:
            ds: The dataset to insert.
            name: Human-readable name for the dataset.
            cache_local: If True, write to local temporary storage first, then
                copy to S3. This can be faster for some workloads.
            schema_ref: Optional schema reference. If None, generates from sample type.
            **kwargs: Additional arguments passed to wds.ShardWriter.

        Returns:
            A tuple of (index_entry, new_dataset) where:
            - index_entry: LocalDatasetEntry for the stored dataset
            - new_dataset: Dataset object pointing to the stored copy

        Raises:
            ValueError: If S3 credentials or hive_path are not configured.
            RuntimeError: If no shards were written.
        """
        if self.s3_credentials is None:
            raise ValueError(
                "S3 credentials required for insert(). Initialize Repo with s3_credentials."
            )
        if self.hive_bucket is None or self.hive_path is None:
            raise ValueError(
                "hive_path required for insert(). Initialize Repo with hive_path."
            )

        new_uuid = str(uuid4())

        hive_fs = _s3_from_credentials(self.s3_credentials)

        # Write metadata
        metadata_path = (
            self.hive_path / "metadata" / f"atdata-metadata--{new_uuid}.msgpack"
        )
        # Note: S3 doesn't need directories created beforehand - s3fs handles this

        if ds.metadata is not None:
            # Use s3:// prefix to ensure s3fs treats this as an S3 path
            with cast(
                BinaryIO, hive_fs.open(f"s3://{metadata_path.as_posix()}", "wb")
            ) as f:
                meta_packed = msgpack.packb(ds.metadata)
                f.write(cast(bytes, meta_packed))

        # Write data
        shard_pattern = (self.hive_path / f"atdata--{new_uuid}--%06d.tar").as_posix()

        written_shards: list[str] = []
        with TemporaryDirectory() as temp_dir:
            writer_opener, writer_post = _create_s3_write_callbacks(
                credentials=self.s3_credentials,
                temp_dir=temp_dir,
                written_shards=written_shards,
                fs=hive_fs,
                cache_local=cache_local,
                add_s3_prefix=False,
            )

            with wds.writer.ShardWriter(
                shard_pattern,
                opener=writer_opener,
                post=writer_post,
                **kwargs,
            ) as sink:
                for sample in ds.ordered(batch_size=None):
                    sink.write(sample.as_wds)

        # Make a new Dataset object for the written dataset copy
        if len(written_shards) == 0:
            raise RuntimeError(
                "Cannot form new dataset entry -- did not write any shards"
            )

        elif len(written_shards) < 2:
            new_dataset_url = (
                self.hive_path / (Path(written_shards[0]).name)
            ).as_posix()

        else:
            shard_s3_format = (
                (self.hive_path / f"atdata--{new_uuid}").as_posix()
            ) + "--{shard_id}.tar"
            shard_id_braced = "{" + f"{0:06d}..{len(written_shards) - 1:06d}" + "}"
            new_dataset_url = shard_s3_format.format(shard_id=shard_id_braced)

        new_dataset = Dataset[ds.sample_type](
            url=new_dataset_url,
            metadata_url=metadata_path.as_posix(),
        )

        # Add to index (use ds._metadata to avoid network requests)
        new_entry = self.index.add_entry(
            new_dataset,
            name=name,
            schema_ref=schema_ref,
            metadata=ds._metadata,
        )

        return new_entry, new_dataset
atdata/local/_s3.py
ADDED
@@ -0,0 +1,349 @@
"""S3-compatible data store and helper functions."""

from atdata import Dataset

from pathlib import Path
from uuid import uuid4
from tempfile import TemporaryDirectory
from dotenv import dotenv_values
from typing import Any, BinaryIO, cast

from s3fs import S3FileSystem
import webdataset as wds


def _s3_env(credentials_path: str | Path) -> dict[str, Any]:
    """Load S3 credentials from .env file.

    Args:
        credentials_path: Path to .env file containing AWS_ENDPOINT,
            AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY.

    Returns:
        Dict with the three required credential keys.

    Raises:
        ValueError: If any required key is missing from the .env file.
    """
    credentials_path = Path(credentials_path)
    env_values = dotenv_values(credentials_path)

    required_keys = ("AWS_ENDPOINT", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")
    missing = [k for k in required_keys if k not in env_values]
    if missing:
        raise ValueError(
            f"Missing required keys in {credentials_path}: {', '.join(missing)}"
        )

    return {k: env_values[k] for k in required_keys}


def _s3_from_credentials(creds: str | Path | dict) -> S3FileSystem:
    """Create S3FileSystem from credentials dict or .env file path."""
    if not isinstance(creds, dict):
        creds = _s3_env(creds)

    # Build kwargs, making endpoint_url optional
    kwargs = {
        "key": creds["AWS_ACCESS_KEY_ID"],
        "secret": creds["AWS_SECRET_ACCESS_KEY"],
    }
    if "AWS_ENDPOINT" in creds:
        kwargs["endpoint_url"] = creds["AWS_ENDPOINT"]

    return S3FileSystem(**kwargs)


def _create_s3_write_callbacks(
    credentials: dict[str, Any],
    temp_dir: str,
    written_shards: list[str],
    fs: S3FileSystem | None,
    cache_local: bool,
    add_s3_prefix: bool = False,
) -> tuple:
    """Create opener and post callbacks for ShardWriter with S3 upload.

    Args:
        credentials: S3 credentials dict.
        temp_dir: Temporary directory for local caching.
        written_shards: List to append written shard paths to.
        fs: S3FileSystem for direct writes (used when cache_local=False).
        cache_local: If True, write locally then copy to S3.
        add_s3_prefix: If True, prepend 's3://' to shard paths.

    Returns:
        Tuple of (writer_opener, writer_post) callbacks.
    """
    if cache_local:
        import boto3

        s3_client_kwargs = {
            "aws_access_key_id": credentials["AWS_ACCESS_KEY_ID"],
            "aws_secret_access_key": credentials["AWS_SECRET_ACCESS_KEY"],
        }
        if "AWS_ENDPOINT" in credentials:
            s3_client_kwargs["endpoint_url"] = credentials["AWS_ENDPOINT"]
        s3_client = boto3.client("s3", **s3_client_kwargs)

        def _writer_opener(p: str):
            local_path = Path(temp_dir) / p
            local_path.parent.mkdir(parents=True, exist_ok=True)
            return open(local_path, "wb")

        def _writer_post(p: str):
            local_path = Path(temp_dir) / p
            path_parts = Path(p).parts
            bucket = path_parts[0]
            key = str(Path(*path_parts[1:]))

            with open(local_path, "rb") as f_in:
                s3_client.put_object(Bucket=bucket, Key=key, Body=f_in.read())

            local_path.unlink()
            if add_s3_prefix:
                written_shards.append(f"s3://{p}")
            else:
                written_shards.append(p)

        return _writer_opener, _writer_post
    else:
        if fs is None:
            raise ValueError("S3FileSystem required when cache_local=False")

        def _direct_opener(s: str):
            return cast(BinaryIO, fs.open(f"s3://{s}", "wb"))

        def _direct_post(s: str):
            if add_s3_prefix:
                written_shards.append(f"s3://{s}")
            else:
                written_shards.append(s)

        return _direct_opener, _direct_post


class S3DataStore:
    """S3-compatible data store implementing AbstractDataStore protocol.

    Handles writing dataset shards to S3-compatible object storage and
    resolving URLs for reading.

    Attributes:
        credentials: S3 credentials dictionary.
        bucket: Target bucket name.
        _fs: S3FileSystem instance.
    """

    def __init__(
        self,
        credentials: str | Path | dict[str, Any],
        *,
        bucket: str,
    ) -> None:
        """Initialize an S3 data store.

        Args:
            credentials: Path to .env file or dict with AWS_ACCESS_KEY_ID,
                AWS_SECRET_ACCESS_KEY, and optionally AWS_ENDPOINT.
            bucket: Name of the S3 bucket for storage.
        """
        if isinstance(credentials, dict):
            self.credentials = credentials
        else:
            self.credentials = _s3_env(credentials)

        self.bucket = bucket
        self._fs = _s3_from_credentials(self.credentials)

    def write_shards(
        self,
        ds: Dataset,
        *,
        prefix: str,
        cache_local: bool = False,
        manifest: bool = False,
        schema_version: str = "1.0.0",
        source_job_id: str | None = None,
        parent_shards: list[str] | None = None,
        pipeline_version: str | None = None,
        **kwargs,
    ) -> list[str]:
        """Write dataset shards to S3.

        Args:
            ds: The Dataset to write.
            prefix: Path prefix within bucket (e.g., 'datasets/mnist/v1').
            cache_local: If True, write locally first then copy to S3.
            manifest: If True, generate per-shard manifest files alongside
                each tar shard (``.manifest.json`` + ``.manifest.parquet``).
            schema_version: Schema version for manifest headers.
            source_job_id: Optional provenance job identifier for manifests.
            parent_shards: Optional list of input shard identifiers for provenance.
            pipeline_version: Optional pipeline version string for provenance.
            **kwargs: Additional args passed to wds.ShardWriter (e.g., maxcount).

        Returns:
            List of S3 URLs for the written shards.

        Raises:
            RuntimeError: If no shards were written.
        """
        new_uuid = str(uuid4())
        shard_pattern = f"{self.bucket}/{prefix}/data--{new_uuid}--%06d.tar"

        written_shards: list[str] = []

        # Manifest tracking state shared with the post callback
        manifest_builders: list = []
        current_builder: list = [None]  # mutable ref for closure
        shard_counter: list[int] = [0]

        if manifest:
            from atdata.manifest import ManifestBuilder, ManifestWriter

            def _make_builder(shard_idx: int) -> ManifestBuilder:
                shard_id = f"{self.bucket}/{prefix}/data--{new_uuid}--{shard_idx:06d}"
                return ManifestBuilder(
                    sample_type=ds.sample_type,
                    shard_id=shard_id,
                    schema_version=schema_version,
                    source_job_id=source_job_id,
                    parent_shards=parent_shards,
                    pipeline_version=pipeline_version,
                )

            current_builder[0] = _make_builder(0)

        with TemporaryDirectory() as temp_dir:
            writer_opener, writer_post_orig = _create_s3_write_callbacks(
                credentials=self.credentials,
                temp_dir=temp_dir,
                written_shards=written_shards,
                fs=self._fs,
                cache_local=cache_local,
                add_s3_prefix=True,
            )

            if manifest:

                def writer_post(p: str):
                    # Finalize the current manifest builder when a shard completes
                    builder = current_builder[0]
                    if builder is not None:
                        manifest_builders.append(builder)
                    # Advance to the next shard's builder
                    shard_counter[0] += 1
                    current_builder[0] = _make_builder(shard_counter[0])
                    # Call original post callback
                    writer_post_orig(p)
            else:
                writer_post = writer_post_orig

            offset = 0
            with wds.writer.ShardWriter(
                shard_pattern,
                opener=writer_opener,
                post=writer_post,
                **kwargs,
            ) as sink:
                for sample in ds.ordered(batch_size=None):
                    wds_dict = sample.as_wds
                    sink.write(wds_dict)

                    if manifest and current_builder[0] is not None:
                        packed_size = len(wds_dict.get("msgpack", b""))
                        current_builder[0].add_sample(
                            key=wds_dict["__key__"],
                            offset=offset,
                            size=packed_size,
                            sample=sample,
                        )
                        # Approximate tar entry: 512-byte header + data rounded to 512
                        offset += 512 + packed_size + (512 - packed_size % 512) % 512

            # Finalize the last shard's builder (post isn't called for the last shard
            # until ShardWriter closes, but we handle it here for safety)
            if manifest and current_builder[0] is not None:
                builder = current_builder[0]
                if builder._rows:  # Only if samples were added
                    manifest_builders.append(builder)

            # Write all manifest files
            if manifest:
                for builder in manifest_builders:
                    built = builder.build()
                    writer = ManifestWriter(Path(temp_dir) / Path(built.shard_id))
                    json_path, parquet_path = writer.write(built)

                    # Upload manifest files to S3 alongside shards
                    shard_id = built.shard_id
                    json_key = f"{shard_id}.manifest.json"
                    parquet_key = f"{shard_id}.manifest.parquet"

                    if cache_local:
                        import boto3

                        s3_kwargs = {
                            "aws_access_key_id": self.credentials["AWS_ACCESS_KEY_ID"],
                            "aws_secret_access_key": self.credentials[
                                "AWS_SECRET_ACCESS_KEY"
                            ],
                        }
                        if "AWS_ENDPOINT" in self.credentials:
                            s3_kwargs["endpoint_url"] = self.credentials["AWS_ENDPOINT"]
                        s3_client = boto3.client("s3", **s3_kwargs)

                        bucket_name = Path(shard_id).parts[0]
                        json_s3_key = str(Path(*Path(json_key).parts[1:]))
                        parquet_s3_key = str(Path(*Path(parquet_key).parts[1:]))

                        with open(json_path, "rb") as f:
                            s3_client.put_object(
                                Bucket=bucket_name, Key=json_s3_key, Body=f.read()
                            )
                        with open(parquet_path, "rb") as f:
                            s3_client.put_object(
                                Bucket=bucket_name, Key=parquet_s3_key, Body=f.read()
                            )
                    else:
                        self._fs.put(str(json_path), f"s3://{json_key}")
                        self._fs.put(str(parquet_path), f"s3://{parquet_key}")

        if len(written_shards) == 0:
            raise RuntimeError("No shards written")

        return written_shards

    def read_url(self, url: str) -> str:
        """Resolve an S3 URL for reading/streaming.

        For S3-compatible stores with custom endpoints (like Cloudflare R2,
        MinIO, etc.), converts s3:// URLs to HTTPS URLs that WebDataset can
        stream directly.

        For standard AWS S3 (no custom endpoint), URLs are returned unchanged
        since WebDataset's built-in s3fs integration handles them.

        Args:
            url: S3 URL to resolve (e.g., 's3://bucket/path/file.tar').

        Returns:
            HTTPS URL if custom endpoint is configured, otherwise unchanged.
            Example: 's3://bucket/path' -> 'https://endpoint.com/bucket/path'
        """
        endpoint = self.credentials.get("AWS_ENDPOINT")
        if endpoint and url.startswith("s3://"):
            # s3://bucket/path -> https://endpoint/bucket/path
            path = url[5:]  # Remove 's3://' prefix
            endpoint = endpoint.rstrip("/")
            return f"{endpoint}/{path}"
        return url

    def supports_streaming(self) -> bool:
        """S3 supports streaming reads.

        Returns:
            True.
        """
        return True