atdata 0.2.3b1__py3-none-any.whl → 0.3.1b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/.gitignore +1 -0
- atdata/__init__.py +39 -0
- atdata/_cid.py +0 -21
- atdata/_exceptions.py +168 -0
- atdata/_helpers.py +41 -15
- atdata/_hf_api.py +95 -11
- atdata/_logging.py +70 -0
- atdata/_protocols.py +77 -238
- atdata/_schema_codec.py +7 -6
- atdata/_stub_manager.py +5 -25
- atdata/_type_utils.py +28 -2
- atdata/atmosphere/__init__.py +31 -20
- atdata/atmosphere/_types.py +4 -4
- atdata/atmosphere/client.py +64 -12
- atdata/atmosphere/lens.py +11 -12
- atdata/atmosphere/records.py +12 -12
- atdata/atmosphere/schema.py +16 -18
- atdata/atmosphere/store.py +6 -7
- atdata/cli/__init__.py +161 -175
- atdata/cli/diagnose.py +2 -2
- atdata/cli/{local.py → infra.py} +11 -11
- atdata/cli/inspect.py +69 -0
- atdata/cli/preview.py +63 -0
- atdata/cli/schema.py +109 -0
- atdata/dataset.py +583 -328
- atdata/index/__init__.py +54 -0
- atdata/index/_entry.py +157 -0
- atdata/index/_index.py +1198 -0
- atdata/index/_schema.py +380 -0
- atdata/lens.py +9 -2
- atdata/lexicons/__init__.py +121 -0
- atdata/lexicons/ac.foundation.dataset.arrayFormat.json +16 -0
- atdata/lexicons/ac.foundation.dataset.getLatestSchema.json +78 -0
- atdata/lexicons/ac.foundation.dataset.lens.json +99 -0
- atdata/lexicons/ac.foundation.dataset.record.json +96 -0
- atdata/lexicons/ac.foundation.dataset.schema.json +107 -0
- atdata/lexicons/ac.foundation.dataset.schemaType.json +16 -0
- atdata/lexicons/ac.foundation.dataset.storageBlobs.json +24 -0
- atdata/lexicons/ac.foundation.dataset.storageExternal.json +25 -0
- atdata/lexicons/ndarray_shim.json +16 -0
- atdata/local/__init__.py +70 -0
- atdata/local/_repo_legacy.py +218 -0
- atdata/manifest/__init__.py +28 -0
- atdata/manifest/_aggregates.py +156 -0
- atdata/manifest/_builder.py +163 -0
- atdata/manifest/_fields.py +154 -0
- atdata/manifest/_manifest.py +146 -0
- atdata/manifest/_query.py +150 -0
- atdata/manifest/_writer.py +74 -0
- atdata/promote.py +18 -14
- atdata/providers/__init__.py +25 -0
- atdata/providers/_base.py +140 -0
- atdata/providers/_factory.py +69 -0
- atdata/providers/_postgres.py +214 -0
- atdata/providers/_redis.py +171 -0
- atdata/providers/_sqlite.py +191 -0
- atdata/repository.py +323 -0
- atdata/stores/__init__.py +23 -0
- atdata/stores/_disk.py +123 -0
- atdata/stores/_s3.py +349 -0
- atdata/testing.py +341 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/METADATA +5 -2
- atdata-0.3.1b1.dist-info/RECORD +67 -0
- atdata/local.py +0 -1720
- atdata-0.2.3b1.dist-info/RECORD +0 -28
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/WHEEL +0 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.2.3b1.dist-info → atdata-0.3.1b1.dist-info}/licenses/LICENSE +0 -0
atdata/repository.py
ADDED
@@ -0,0 +1,323 @@
"""Repository and atmosphere backend for the unified Index.

A ``Repository`` pairs an ``IndexProvider`` (persistence backend) with an
optional ``AbstractDataStore`` (shard storage), forming a named storage unit
that can be mounted into an ``Index``.

The ``_AtmosphereBackend`` is an internal adapter that wraps an
``Atmosphere`` to present the same operational surface as a repository,
but routes through the ATProto network instead of a local provider.

Examples:
    >>> from atdata.repository import Repository, create_repository
    >>> repo = Repository(provider=SqliteProvider("/data/lab.db"))
    >>> repo = create_repository("sqlite", path="/data/lab.db")
    >>>
    >>> # With a data store for shard storage
    >>> repo = Repository(
    ...     provider=SqliteProvider(),
    ...     data_store=S3DataStore(credentials, bucket="lab-data"),
    ... )
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterator, Optional, TYPE_CHECKING

from ._protocols import AbstractDataStore

if TYPE_CHECKING:
    from .providers._base import IndexProvider


@dataclass
class Repository:
    """A named storage backend pairing index persistence with optional data storage.

    Repositories are mounted into an ``Index`` by name. The built-in ``"local"``
    repository uses SQLite by default; additional repositories can be added for
    multi-source dataset management.

    Attributes:
        provider: IndexProvider handling dataset/schema persistence.
        data_store: Optional data store for reading/writing dataset shards.
            If present, ``insert_dataset`` will write shards to this store.

    Examples:
        >>> from atdata.providers import create_provider
        >>> from atdata.repository import Repository
        >>>
        >>> provider = create_provider("sqlite", path="/data/lab.db")
        >>> repo = Repository(provider=provider)
        >>>
        >>> # With S3 shard storage
        >>> repo = Repository(
        ...     provider=provider,
        ...     data_store=S3DataStore(credentials, bucket="lab-data"),
        ... )
    """

    provider: IndexProvider
    data_store: AbstractDataStore | None = None


def create_repository(
    provider: str = "sqlite",
    *,
    path: str | Path | None = None,
    dsn: str | None = None,
    redis: Any = None,
    data_store: AbstractDataStore | None = None,
    **kwargs: Any,
) -> Repository:
    """Create a Repository with a provider by name.

    This is a convenience factory that combines ``create_provider`` with
    ``Repository`` construction.

    Args:
        provider: Backend name: ``"sqlite"``, ``"redis"``, or ``"postgres"``.
        path: Database file path (SQLite only).
        dsn: Connection string (PostgreSQL only).
        redis: Existing Redis connection (Redis only).
        data_store: Optional data store for shard storage.
        **kwargs: Extra arguments forwarded to the provider constructor.

    Returns:
        A ready-to-use Repository.

    Raises:
        ValueError: If provider name is not recognised.

    Examples:
        >>> repo = create_repository("sqlite", path="/data/lab.db")
        >>> repo = create_repository(
        ...     "sqlite",
        ...     data_store=S3DataStore(creds, bucket="lab"),
        ... )
    """
    from .providers._factory import create_provider as _create_provider

    backend = _create_provider(provider, path=path, dsn=dsn, redis=redis, **kwargs)
    return Repository(provider=backend, data_store=data_store)


class _AtmosphereBackend:
    """Internal adapter wrapping Atmosphere for Index routing.

    This class extracts the operational logic from ``AtmosphereIndex`` into an
    internal component that the unified ``Index`` uses for ATProto resolution.
    It is not part of the public API.

    The backend is lazily initialised -- the publishers/loaders are only
    created when the client is authenticated or when operations require them.
    """

    def __init__(
        self,
        client: Any,  # Atmosphere, typed as Any to avoid hard import
        *,
        data_store: Optional[AbstractDataStore] = None,
    ) -> None:
        from .atmosphere.client import Atmosphere

        if not isinstance(client, Atmosphere):
            raise TypeError(f"Expected Atmosphere, got {type(client).__name__}")
        self.client: Atmosphere = client
        self._data_store = data_store
        self._schema_publisher: Any = None
        self._schema_loader: Any = None
        self._dataset_publisher: Any = None
        self._dataset_loader: Any = None

    def _ensure_loaders(self) -> None:
        """Lazily create publishers/loaders on first use."""
        if self._schema_loader is not None:
            return
        from .atmosphere.schema import SchemaPublisher, SchemaLoader
        from .atmosphere.records import DatasetPublisher, DatasetLoader

        self._schema_publisher = SchemaPublisher(self.client)
        self._schema_loader = SchemaLoader(self.client)
        self._dataset_publisher = DatasetPublisher(self.client)
        self._dataset_loader = DatasetLoader(self.client)

    @property
    def data_store(self) -> Optional[AbstractDataStore]:
        """The data store for this atmosphere backend, or None."""
        return self._data_store

    # -- Dataset operations --

    def get_dataset(self, ref: str) -> Any:
        """Get a dataset entry by name or AT URI.

        Args:
            ref: Dataset name or AT URI.

        Returns:
            AtmosphereIndexEntry for the dataset.

        Raises:
            ValueError: If record is not a dataset.
        """
        self._ensure_loaders()
        from .atmosphere import AtmosphereIndexEntry

        record = self._dataset_loader.get(ref)
        return AtmosphereIndexEntry(ref, record)

    def list_datasets(self, repo: str | None = None) -> list[Any]:
        """List all dataset entries.

        Args:
            repo: DID of repository. Defaults to authenticated user.

        Returns:
            List of AtmosphereIndexEntry for each dataset.
        """
        self._ensure_loaders()
        from .atmosphere import AtmosphereIndexEntry

        records = self._dataset_loader.list_all(repo=repo)
        return [
            AtmosphereIndexEntry(rec.get("uri", ""), rec.get("value", rec))
            for rec in records
        ]

    def iter_datasets(self, repo: str | None = None) -> Iterator[Any]:
        """Lazily iterate over all dataset entries.

        Args:
            repo: DID of repository. Defaults to authenticated user.

        Yields:
            AtmosphereIndexEntry for each dataset.
        """
        self._ensure_loaders()
        from .atmosphere import AtmosphereIndexEntry

        records = self._dataset_loader.list_all(repo=repo)
        for rec in records:
            uri = rec.get("uri", "")
            yield AtmosphereIndexEntry(uri, rec.get("value", rec))

    def insert_dataset(
        self,
        ds: Any,
        *,
        name: str,
        schema_ref: str | None = None,
        **kwargs: Any,
    ) -> Any:
        """Insert a dataset into ATProto.

        Args:
            ds: The Dataset to publish.
            name: Human-readable name.
            schema_ref: Optional schema AT URI. If None, auto-publishes schema.
            **kwargs: Additional options (description, tags, license).

        Returns:
            AtmosphereIndexEntry for the inserted dataset.
        """
        self._ensure_loaders()
        from .atmosphere import AtmosphereIndexEntry

        uri = self._dataset_publisher.publish(
            ds,
            name=name,
            schema_uri=schema_ref,
            description=kwargs.get("description"),
            tags=kwargs.get("tags"),
            license=kwargs.get("license"),
            auto_publish_schema=(schema_ref is None),
        )
        record = self._dataset_loader.get(uri)
        return AtmosphereIndexEntry(str(uri), record)

    # -- Schema operations --

    def publish_schema(
        self,
        sample_type: type,
        *,
        version: str = "1.0.0",
        **kwargs: Any,
    ) -> str:
        """Publish a schema to ATProto.

        Args:
            sample_type: A Packable type.
            version: Semantic version string.
            **kwargs: Additional options.

        Returns:
            AT URI of the schema record.
        """
        self._ensure_loaders()
        uri = self._schema_publisher.publish(
            sample_type,
            version=version,
            description=kwargs.get("description"),
            metadata=kwargs.get("metadata"),
        )
        return str(uri)

    def get_schema(self, ref: str) -> dict:
        """Get a schema record by AT URI.

        Args:
            ref: AT URI of the schema record.

        Returns:
            Schema record dictionary.
        """
        self._ensure_loaders()
        return self._schema_loader.get(ref)

    def list_schemas(self, repo: str | None = None) -> list[dict]:
        """List all schema records.

        Args:
            repo: DID of repository. Defaults to authenticated user.

        Returns:
            List of schema records as dictionaries.
        """
        self._ensure_loaders()
        records = self._schema_loader.list_all(repo=repo)
        return [rec.get("value", rec) for rec in records]

    def iter_schemas(self) -> Iterator[dict]:
        """Lazily iterate over all schema records.

        Yields:
            Schema records as dictionaries.
        """
        self._ensure_loaders()
        records = self._schema_loader.list_all()
        for rec in records:
            yield rec.get("value", rec)

    def decode_schema(self, ref: str) -> type:
        """Reconstruct a Python type from a schema record.

        Args:
            ref: AT URI of the schema record.

        Returns:
            Dynamically generated Packable type.
        """
        from ._schema_codec import schema_to_type

        schema = self.get_schema(ref)
        return schema_to_type(schema)


__all__ = [
    "Repository",
    "create_repository",
]
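For orientation, here is a minimal usage sketch of the repository layer added above, assembled only from calls documented in its docstrings. The database path is illustrative and the LocalDiskStore pairing is an assumption drawn from the new atdata.stores module later in this diff, not from repository.py itself.

# Sketch only; "/data/lab.db" is a hypothetical path.
from atdata.providers import create_provider
from atdata.repository import Repository, create_repository
from atdata.stores import LocalDiskStore

# Explicit construction: build the persistence backend, then wrap it.
provider = create_provider("sqlite", path="/data/lab.db")
repo = Repository(provider=provider)

# One-step factory; attaching a data_store means insert_dataset can also
# write shard data (any AbstractDataStore works, LocalDiskStore shown here).
repo = create_repository("sqlite", path="/data/lab.db", data_store=LocalDiskStore())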
atdata/stores/__init__.py
ADDED

@@ -0,0 +1,23 @@
"""Data stores for atdata datasets.

Key classes:

- ``LocalDiskStore``: Local filesystem data store.
- ``S3DataStore``: S3-compatible object storage.
"""

from atdata.stores._disk import LocalDiskStore
from atdata.stores._s3 import (
    S3DataStore,
    _s3_env,
    _s3_from_credentials,
    _create_s3_write_callbacks,
)

__all__ = [
    "LocalDiskStore",
    "S3DataStore",
    "_s3_env",
    "_s3_from_credentials",
    "_create_s3_write_callbacks",
]
atdata/stores/_disk.py
ADDED
@@ -0,0 +1,123 @@
"""Local filesystem data store for WebDataset shards.

Writes and reads WebDataset tar archives on the local filesystem,
implementing the ``AbstractDataStore`` protocol.

Examples:
    >>> store = LocalDiskStore(root="~/.atdata/data")
    >>> urls = store.write_shards(dataset, prefix="mnist/v1")
    >>> print(urls[0])
    /home/user/.atdata/data/mnist/v1/data--a1b2c3--000000.tar
"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any
from uuid import uuid4

import webdataset as wds

if TYPE_CHECKING:
    from atdata.dataset import Dataset


class LocalDiskStore:
    """Local filesystem data store.

    Writes WebDataset shards to a directory on disk. Implements the
    ``AbstractDataStore`` protocol for use with ``Index``.

    Args:
        root: Root directory for shard storage. Defaults to
            ``~/.atdata/data/``. Created automatically if it does
            not exist.

    Examples:
        >>> store = LocalDiskStore()
        >>> urls = store.write_shards(dataset, prefix="my-dataset")
    """

    def __init__(self, root: str | Path | None = None) -> None:
        if root is None:
            root = Path.home() / ".atdata" / "data"
        self._root = Path(root).expanduser().resolve()
        self._root.mkdir(parents=True, exist_ok=True)

    @property
    def root(self) -> Path:
        """Root directory for shard storage."""
        return self._root

    def write_shards(
        self,
        ds: "Dataset",
        *,
        prefix: str,
        **kwargs: Any,
    ) -> list[str]:
        """Write dataset shards to the local filesystem.

        Args:
            ds: The Dataset to write.
            prefix: Path prefix within root (e.g., ``'datasets/mnist/v1'``).
            **kwargs: Additional args passed to ``wds.writer.ShardWriter``
                (e.g., ``maxcount``, ``maxsize``).

        Returns:
            List of absolute file paths for the written shards.

        Raises:
            RuntimeError: If no shards were written.
        """
        shard_dir = self._root / prefix
        shard_dir.mkdir(parents=True, exist_ok=True)

        new_uuid = str(uuid4())[:8]
        shard_pattern = str(shard_dir / f"data--{new_uuid}--%06d.tar")

        written_shards: list[str] = []

        def _track_shard(path: str) -> None:
            written_shards.append(str(Path(path).resolve()))

        # Filter out kwargs that are specific to other stores (e.g. S3)
        # and not understood by wds.writer.ShardWriter / TarWriter.
        writer_kwargs = {k: v for k, v in kwargs.items() if k not in ("cache_local",)}

        with wds.writer.ShardWriter(
            shard_pattern,
            post=_track_shard,
            **writer_kwargs,
        ) as sink:
            for sample in ds.ordered(batch_size=None):
                sink.write(sample.as_wds)

        if not written_shards:
            raise RuntimeError(
                f"No shards written for prefix {prefix!r} in {self._root}"
            )

        return written_shards

    def read_url(self, url: str) -> str:
        """Resolve a storage URL for reading.

        Local filesystem paths are returned as-is since WebDataset
        can read them directly.

        Args:
            url: Absolute file path to a shard.

        Returns:
            The same path, unchanged.
        """
        return url

    def supports_streaming(self) -> bool:
        """Whether this store supports streaming reads.

        Returns:
            ``True`` — local filesystem supports streaming.
        """
        return True
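A short sketch of the call pattern documented above. It uses only the methods defined in this file; ``ds`` is assumed to be an existing atdata Dataset, and the root and prefix are illustrative.

from atdata.stores import LocalDiskStore

# ds: an existing atdata Dataset, assumed to be defined elsewhere.
store = LocalDiskStore(root="~/.atdata/data")  # root is created if missing

# Shards land under <root>/<prefix>/data--<uuid>--NNNNNN.tar;
# maxcount is forwarded to wds.writer.ShardWriter.
urls = store.write_shards(ds, prefix="datasets/mnist/v1", maxcount=10_000)

# Local paths need no resolution step, so read_url is the identity.
assert store.read_url(urls[0]) == urls[0]
assert store.supports_streaming()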