syvain-training-data 0.0.118__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- syvain_training_data-0.0.118/PKG-INFO +121 -0
- syvain_training_data-0.0.118/README.md +106 -0
- syvain_training_data-0.0.118/pyproject.toml +22 -0
- syvain_training_data-0.0.118/src/syvain_training_data/__init__.py +23 -0
- syvain_training_data-0.0.118/src/syvain_training_data/client.py +82 -0
- syvain_training_data-0.0.118/src/syvain_training_data/gzjsonl.py +141 -0
- syvain_training_data-0.0.118/src/syvain_training_data/loader.py +160 -0
- syvain_training_data-0.0.118/src/syvain_training_data/manifest.py +198 -0
- syvain_training_data-0.0.118/src/syvain_training_data/py.typed +1 -0
- syvain_training_data-0.0.118/src/syvain_training_data/saver.py +191 -0
- syvain_training_data-0.0.118/src/syvain_training_data/storage.py +126 -0
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: syvain-training-data
|
|
3
|
+
Version: 0.0.118
|
|
4
|
+
Summary: Syvain training data manifest, loading, and saving utilities
|
|
5
|
+
Requires-Dist: obstore>=0.10.1
|
|
6
|
+
Requires-Dist: pydantic>=2.13.3
|
|
7
|
+
Requires-Dist: torch>=2.0.0,<2.11
|
|
8
|
+
Requires-Dist: typing-extensions>=4.15.0
|
|
9
|
+
Requires-Dist: pytest>=8.0.0 ; extra == 'dev'
|
|
10
|
+
Requires-Dist: ruff>=0.15.12 ; extra == 'dev'
|
|
11
|
+
Requires-Dist: ty>=0.0.34 ; extra == 'dev'
|
|
12
|
+
Requires-Python: >=3.10
|
|
13
|
+
Provides-Extra: dev
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
|
|
16
|
+
# syvain-training-data
|
|
17
|
+
|
|
18
|
+
Internal [Syvain](https://syvain.com/) data utility. No secret sauce here, just
|
|
19
|
+
a shared helper.
|
|
20
|
+
|
|
21
|
+
> This is my dataloader. There are many like it, but this one is mine. My
|
|
22
|
+
> dataloader is my best friend. It is my life. I must master it as I must master
|
|
23
|
+
> my life. My dataloader, without me, is useless. Without my dataloader, I am
|
|
24
|
+
> useless.
|
|
25
|
+
|
|
26
|
+
## Install
|
|
27
|
+
|
|
28
|
+
```bash
|
|
29
|
+
uv add syvain-training-data
|
|
30
|
+
```
|
|
31
|
+
|
|
32
|
+
## Load data
|
|
33
|
+
|
|
34
|
+
```python
|
|
35
|
+
from syvain_training_data import SyvainTrainingData
|
|
36
|
+
|
|
37
|
+
training_data = SyvainTrainingData(
|
|
38
|
+
s3_base_url="https://t3.storage.dev",
|
|
39
|
+
region="auto",
|
|
40
|
+
access_key_id="...",
|
|
41
|
+
secret_access_key="...",
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def collate(records):
|
|
46
|
+
...
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
loader = training_data.split_data_loader(
|
|
50
|
+
"s3://my-training-bucket/path/to/data-manifest-v1.json",
|
|
51
|
+
collate_fn=collate,
|
|
52
|
+
dataloader_args={"batch_size": 32, "num_workers": 4},  # plus any other torch DataLoader kwargs
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
train_batches = loader.load("train")
|
|
56
|
+
valid_batches = loader.load("valid")
|
|
57
|
+
easy_batches = loader.load("train", curriculum_stage="easy")
|
|
58
|
+
infinite_train_batches = loader.load("train", infinite_iter=True)
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
## Save data
|
|
62
|
+
|
|
63
|
+
```python
|
|
64
|
+
from concurrent.futures import ProcessPoolExecutor
|
|
65
|
+
|
|
66
|
+
from syvain_training_data import SyvainTrainingData
|
|
67
|
+
|
|
68
|
+
def generate_data(split, curriculum_stage, shard_id):
|
|
69
|
+
...
|
|
70
|
+
|
|
71
|
+
def save_shard(job):
|
|
72
|
+
saver, split, curriculum_stage, shard_id = job
|
|
73
|
+
records = generate_data(split, curriculum_stage, shard_id)
|
|
74
|
+
saver.save(split, curriculum_stage, records)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
training_data = SyvainTrainingData(
|
|
78
|
+
s3_base_url="https://t3.storage.dev",
|
|
79
|
+
region="auto",
|
|
80
|
+
access_key_id="...",
|
|
81
|
+
secret_access_key="...",
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
saver = training_data.dataset_saver(
|
|
85
|
+
"s3://my-training-bucket/path/to/dataset/data-manifest-v1.json",
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
jobs = [
|
|
89
|
+
(saver, "train", stage, shard_id)
|
|
90
|
+
for stage in ["easy", "medium", "hard"]
|
|
91
|
+
for shard_id in range(32)
|
|
92
|
+
] + [
|
|
93
|
+
(saver, "valid", None, shard_id) for shard_id in range(4)
|
|
94
|
+
] + [
|
|
95
|
+
(saver, "test", None, shard_id) for shard_id in range(4)
|
|
96
|
+
]
|
|
97
|
+
|
|
98
|
+
with ProcessPoolExecutor(max_workers=8) as pool:
|
|
99
|
+
list(pool.map(save_shard, jobs))
|
|
100
|
+
|
|
101
|
+
manifest = saver.commit_manifest()
|
|
102
|
+
```
|
|
103
|
+
|
|
104
|
+
## Copy a manifest
|
|
105
|
+
|
|
106
|
+
```python
|
|
107
|
+
from syvain_training_data import SyvainTrainingData
|
|
108
|
+
|
|
109
|
+
training_data = SyvainTrainingData(
|
|
110
|
+
s3_base_url="https://t3.storage.dev",
|
|
111
|
+
region="auto",
|
|
112
|
+
access_key_id="...",
|
|
113
|
+
secret_access_key="...",
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
manifest = training_data.load_manifest("s3://my-training-bucket/shared/data-manifest-v1.json")
|
|
117
|
+
|
|
118
|
+
# Do modifications if needed
|
|
119
|
+
|
|
120
|
+
training_data.save_manifest("s3://my-training-bucket/new-run/data-manifest-v1.json", manifest)
|
|
121
|
+
```
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
# syvain-training-data
|
|
2
|
+
|
|
3
|
+
Internal [Syvain](https://syvain.com/) data utility. No secret sauce here, just
|
|
4
|
+
a shared helper.
|
|
5
|
+
|
|
6
|
+
> This is my dataloader. There are many like it, but this one is mine. My
|
|
7
|
+
> dataloader is my best friend. It is my life. I must master it as I must master
|
|
8
|
+
> my life. My dataloader, without me, is useless. Without my dataloader, I am
|
|
9
|
+
> useless.
|
|
10
|
+
|
|
11
|
+
## Install
|
|
12
|
+
|
|
13
|
+
```bash
|
|
14
|
+
uv add syvain-training-data
|
|
15
|
+
```
|
|
16
|
+
|
|
17
|
+
## Load data
|
|
18
|
+
|
|
19
|
+
```python
|
|
20
|
+
from syvain_training_data import SyvainTrainingData
|
|
21
|
+
|
|
22
|
+
training_data = SyvainTrainingData(
|
|
23
|
+
s3_base_url="https://t3.storage.dev",
|
|
24
|
+
region="auto",
|
|
25
|
+
access_key_id="...",
|
|
26
|
+
secret_access_key="...",
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def collate(records):
|
|
31
|
+
...
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
loader = training_data.split_data_loader(
|
|
35
|
+
"s3://my-training-bucket/path/to/data-manifest-v1.json",
|
|
36
|
+
collate_fn=collate,
|
|
37
|
+
dataloader_args={"batch_size": 32, "num_workers": 4},  # plus any other torch DataLoader kwargs
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
train_batches = loader.load("train")
|
|
41
|
+
valid_batches = loader.load("valid")
|
|
42
|
+
easy_batches = loader.load("train", curriculum_stage="easy")
|
|
43
|
+
infinite_train_batches = loader.load("train", infinite_iter=True)
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
## Save data
|
|
47
|
+
|
|
48
|
+
```python
|
|
49
|
+
from concurrent.futures import ProcessPoolExecutor
|
|
50
|
+
|
|
51
|
+
from syvain_training_data import SyvainTrainingData
|
|
52
|
+
|
|
53
|
+
def generate_data(split, curriculum_stage, shard_id):
|
|
54
|
+
...
|
|
55
|
+
|
|
56
|
+
def save_shard(job):
|
|
57
|
+
saver, split, curriculum_stage, shard_id = job
|
|
58
|
+
records = generate_data(split, curriculum_stage, shard_id)
|
|
59
|
+
saver.save(split, curriculum_stage, records)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
training_data = SyvainTrainingData(
|
|
63
|
+
s3_base_url="https://t3.storage.dev",
|
|
64
|
+
region="auto",
|
|
65
|
+
access_key_id="...",
|
|
66
|
+
secret_access_key="...",
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
saver = training_data.dataset_saver(
|
|
70
|
+
"s3://my-training-bucket/path/to/dataset/data-manifest-v1.json",
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
jobs = [
|
|
74
|
+
(saver, "train", stage, shard_id)
|
|
75
|
+
for stage in ["easy", "medium", "hard"]
|
|
76
|
+
for shard_id in range(32)
|
|
77
|
+
] + [
|
|
78
|
+
(saver, "valid", None, shard_id) for shard_id in range(4)
|
|
79
|
+
] + [
|
|
80
|
+
(saver, "test", None, shard_id) for shard_id in range(4)
|
|
81
|
+
]
|
|
82
|
+
|
|
83
|
+
with ProcessPoolExecutor(max_workers=8) as pool:
|
|
84
|
+
list(pool.map(save_shard, jobs))
|
|
85
|
+
|
|
86
|
+
manifest = saver.commit_manifest()
|
|
87
|
+
```
|
|
88
|
+
|
|
89
|
+
## Copy a manifest
|
|
90
|
+
|
|
91
|
+
```python
|
|
92
|
+
from syvain_training_data import SyvainTrainingData
|
|
93
|
+
|
|
94
|
+
training_data = SyvainTrainingData(
|
|
95
|
+
s3_base_url="https://t3.storage.dev",
|
|
96
|
+
region="auto",
|
|
97
|
+
access_key_id="...",
|
|
98
|
+
secret_access_key="...",
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
manifest = training_data.load_manifest("s3://my-training-bucket/shared/data-manifest-v1.json")
|
|
102
|
+
|
|
103
|
+
# Do modifications if needed
|
|
104
|
+
|
|
105
|
+
training_data.save_manifest("s3://my-training-bucket/new-run/data-manifest-v1.json", manifest)
|
|
106
|
+
```
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "syvain-training-data"
|
|
3
|
+
version = "0.0.118"
|
|
4
|
+
description = "Syvain training data manifest, loading, and saving utilities"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.10"
|
|
7
|
+
dependencies = [
|
|
8
|
+
"obstore>=0.10.1",
|
|
9
|
+
"pydantic>=2.13.3",
|
|
10
|
+
"torch>=2.0.0,<2.11",
|
|
11
|
+
"typing-extensions>=4.15.0",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
[project.optional-dependencies]
|
|
15
|
+
dev = ["pytest>=8.0.0", "ruff>=0.15.12", "ty>=0.0.34"]
|
|
16
|
+
|
|
17
|
+
[tool.pytest.ini_options]
|
|
18
|
+
testpaths = ["tests"]
|
|
19
|
+
|
|
20
|
+
[build-system]
|
|
21
|
+
requires = ["uv_build>=0.11.9,<0.12"]
|
|
22
|
+
build-backend = "uv_build"
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from syvain_training_data.client import SyvainTrainingData
|
|
2
|
+
from syvain_training_data.loader import SplitDataLoader
|
|
3
|
+
from syvain_training_data.manifest import (
|
|
4
|
+
DataManifest,
|
|
5
|
+
JsonObject,
|
|
6
|
+
JsonValue,
|
|
7
|
+
ManifestCurriculumStage,
|
|
8
|
+
ManifestDatasetEntry,
|
|
9
|
+
ManifestShard,
|
|
10
|
+
)
|
|
11
|
+
from syvain_training_data.saver import DatasetSaver
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"DataManifest",
|
|
15
|
+
"DatasetSaver",
|
|
16
|
+
"JsonObject",
|
|
17
|
+
"JsonValue",
|
|
18
|
+
"ManifestCurriculumStage",
|
|
19
|
+
"ManifestDatasetEntry",
|
|
20
|
+
"ManifestShard",
|
|
21
|
+
"SplitDataLoader",
|
|
22
|
+
"SyvainTrainingData",
|
|
23
|
+
]
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable, Mapping, Sequence
|
|
4
|
+
from typing import Any, TypeVar
|
|
5
|
+
|
|
6
|
+
from syvain_training_data.loader import SplitDataLoader
|
|
7
|
+
from syvain_training_data.manifest import (
|
|
8
|
+
DataManifest,
|
|
9
|
+
manifest_json_bytes,
|
|
10
|
+
parse_manifest_json,
|
|
11
|
+
)
|
|
12
|
+
from syvain_training_data.saver import DatasetSaver
|
|
13
|
+
from syvain_training_data.storage import (
|
|
14
|
+
StorageConfig,
|
|
15
|
+
put_object_bytes,
|
|
16
|
+
read_object_bytes,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
BatchT = TypeVar("BatchT")


class SyvainTrainingData:
    """Facade that binds one set of S3 credentials to the manifest/loader/saver helpers.

    Every method reuses the ``StorageConfig`` built in ``__init__``, so callers
    configure storage access once and then work purely with manifest URIs.
    """

    def __init__(
        self,
        *,
        s3_base_url: str,
        region: str,
        access_key_id: str,
        secret_access_key: str,
        virtual_hosted_style_request: bool = False,
    ) -> None:
        """Capture the S3 connection settings shared by all helpers.

        Args:
            s3_base_url: Endpoint URL of the S3-compatible service.
            region: Region name forwarded to the storage client.
            access_key_id: Access key used for authentication.
            secret_access_key: Secret key used for authentication.
            virtual_hosted_style_request: Forwarded unchanged to ``StorageConfig``.
        """
        settings = dict(
            s3_base_url=s3_base_url,
            region=region,
            access_key_id=access_key_id,
            secret_access_key=secret_access_key,
            virtual_hosted_style_request=virtual_hosted_style_request,
        )
        self._storage_config = StorageConfig(**settings)

    def load_manifest(self, manifest_uri: str) -> DataManifest:
        """Read the object at *manifest_uri* and parse it into a ``DataManifest``.

        Raises:
            ValueError: propagated from ``parse_manifest_json`` when the bytes are
                not UTF-8 JSON matching the manifest schema.
        """
        raw_bytes = read_object_bytes(manifest_uri, storage_config=self._storage_config)
        return parse_manifest_json(raw_bytes, source=manifest_uri)

    def save_manifest(
        self,
        manifest_uri: str,
        manifest: DataManifest | Mapping[str, object],
    ) -> None:
        """Serialize *manifest* to JSON and write it to *manifest_uri*.

        A plain mapping is validated into a ``DataManifest`` by
        ``manifest_json_bytes`` before anything is written.
        """
        payload = manifest_json_bytes(manifest)
        put_object_bytes(
            manifest_uri,
            payload,
            storage_config=self._storage_config,
            attributes={"Content-Type": "application/json"},
        )

    def split_data_loader(
        self,
        manifest_uri: str,
        collate_fn: Callable[[Sequence[Mapping[str, Any]]], BatchT],
        dataloader_args: Mapping[str, Any] | None = None,
    ) -> SplitDataLoader[BatchT]:
        """Load the manifest at *manifest_uri* and wrap it in a ``SplitDataLoader``.

        Args:
            manifest_uri: Location of the data manifest.
            collate_fn: Batch collation function handed to every DataLoader.
            dataloader_args: Extra DataLoader keyword arguments; copied so later
                mutation by the caller has no effect. ``None`` means none.
        """
        manifest = self.load_manifest(manifest_uri)
        loader_args = dict(dataloader_args) if dataloader_args is not None else {}
        return SplitDataLoader(
            manifest=manifest,
            storage_config=self._storage_config,
            collate_fn=collate_fn,
            dataloader_args=loader_args,
        )

    def dataset_saver(
        self,
        manifest_uri: str,
    ) -> DatasetSaver:
        """Create a ``DatasetSaver`` that will publish its manifest at *manifest_uri*."""
        return DatasetSaver(
            manifest_uri=manifest_uri,
            storage_config=self._storage_config,
        )
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
__all__ = ["SyvainTrainingData"]
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import codecs
|
|
4
|
+
import gzip
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import zlib
|
|
8
|
+
from collections.abc import Iterable, Iterator, Mapping
|
|
9
|
+
from io import BytesIO
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from torch.utils.data import get_worker_info
|
|
13
|
+
|
|
14
|
+
from syvain_training_data.manifest import ManifestShard
|
|
15
|
+
from syvain_training_data.storage import (
|
|
16
|
+
StorageConfig,
|
|
17
|
+
parse_s3_uri,
|
|
18
|
+
store_for_uri,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
LOGGER = logging.getLogger(__name__)


def records_to_body(records: Iterable[Mapping[str, Any]]) -> tuple[bytes, int]:
    """Serialize *records* as gzip-compressed JSON Lines.

    Each record becomes one compact JSON line (``","``/``":"`` separators, no
    padding) terminated by a newline. The gzip header timestamp is pinned to 0,
    so the output does not depend on the time of the call.

    Returns:
        ``(body, count)``: the gzip-compressed bytes and the number of records
        written.
    """
    buffer = BytesIO()
    count = 0
    with gzip.GzipFile(fileobj=buffer, mode="wb", compresslevel=6, mtime=0) as sink:
        for count, record in enumerate(records, start=1):
            encoded = json.dumps(record, separators=(",", ":")).encode("utf-8")
            sink.write(encoded)
            sink.write(b"\n")
    return buffer.getvalue(), count
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def iter_shard(
    shard: ManifestShard,
    *,
    storage_config: StorageConfig,
) -> Iterator[Mapping[str, Any]]:
    """Stream the JSON-object records of one gzip-compressed JSONL shard.

    The object named by ``shard.uri`` is fetched and then decompressed and
    UTF-8 decoded chunk by chunk. Blank lines are skipped. Lines that are not
    valid JSON objects are logged and skipped by ``_parse_record_line``.

    Args:
        shard: Manifest entry whose ``uri`` names the object to read. Its
            ``records`` value is only compared against the yielded count for
            logging.
        storage_config: Connection settings passed to ``store_for_uri``.

    Yields:
        One mapping per valid JSON-object line, in file order.

    Raises:
        RuntimeError: if the object cannot be opened, if it is not a single
            complete gzip stream, or if it is not valid UTF-8.
    """
    source = shard.uri
    parsed_uri = parse_s3_uri(source)
    worker = get_worker_info()
    worker_id = "main" if worker is None else str(worker.id)
    LOGGER.info(
        "Worker %s reading shard %s with %s expected records",
        worker_id,
        source,
        shard.records,
    )

    try:
        store, _ = store_for_uri(source, storage_config=storage_config)
        response = store.get(parsed_uri.key)
    except Exception as exc:
        raise RuntimeError(f"Unreadable shard {source}") from exc

    yielded_records = 0
    chunks = response.stream(min_chunk_size=1024 * 1024)
    for line_number, line in _iter_gzip_utf8_lines(chunks, source=source):
        if not line.strip():
            continue
        record = _parse_record_line(line, source=source, line_number=line_number)
        if record is not None:
            yielded_records += 1
            yield record

    # A count that differs from the manifest means records were skipped as
    # invalid, or the manifest is stale. Log above INFO so it is not lost
    # among the routine per-shard messages.
    log = LOGGER.info if yielded_records == shard.records else LOGGER.warning
    log(
        "Worker %s finished shard %s with %s/%s yielded records",
        worker_id,
        source,
        yielded_records,
        shard.records,
    )


def _iter_gzip_utf8_lines(
    chunks: Iterable[bytes],
    *,
    source: str,
) -> Iterator[tuple[int, str]]:
    """Decompress one gzip member from *chunks* and yield ``(line_number, line)``.

    Line numbers are 1-based and count every newline-terminated line, blank or
    not. A final unterminated, non-blank line is yielded only after the gzip
    stream has been verified to be complete with no trailing data.

    Raises:
        RuntimeError: on corrupt or truncated gzip data, invalid UTF-8, or
            bytes after the end of the gzip member.
    """
    decompressor = zlib.decompressobj(wbits=16 + zlib.MAX_WBITS)
    decoder = codecs.getincrementaldecoder("utf-8")()
    # Pieces of the current, not yet newline-terminated line. They are kept as
    # a list and joined once, so a line spanning many chunks costs time linear
    # in its length instead of being re-split on every chunk.
    partial: list[str] = []
    line_number = 0

    try:
        for chunk in chunks:
            text = decoder.decode(decompressor.decompress(chunk), final=False)
            *complete, tail = text.split("\n")
            if complete:
                complete[0] = "".join(partial) + complete[0]
                partial.clear()
                for line in complete:
                    line_number += 1
                    yield line_number, line
            if tail:
                partial.append(tail)
        partial.append(decoder.decode(decompressor.flush(), final=True))
    except (zlib.error, UnicodeDecodeError) as exc:
        raise RuntimeError(
            f"Unreadable gzip-compressed UTF-8 JSONL shard {source}"
        ) from exc

    # Truncated input never reaches EOF. A second gzip member or trailing
    # garbage ends up in unused_data.
    if not decompressor.eof or decompressor.unused_data:
        raise RuntimeError(
            "Invalid gzip-compressed UTF-8 JSONL shard "
            f"{source}: eof={decompressor.eof} "
            f"unused_data={bool(decompressor.unused_data)}"
        )

    last_line = "".join(partial)
    if last_line.strip():
        line_number += 1
        yield line_number, last_line
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _parse_record_line(
|
|
121
|
+
line: str,
|
|
122
|
+
*,
|
|
123
|
+
source: str,
|
|
124
|
+
line_number: int,
|
|
125
|
+
) -> Mapping[str, Any] | None:
|
|
126
|
+
try:
|
|
127
|
+
raw_record = json.loads(line)
|
|
128
|
+
except json.JSONDecodeError as exc:
|
|
129
|
+
LOGGER.error(
|
|
130
|
+
"Skipping invalid JSON record in %s:%s: %s", source, line_number, exc
|
|
131
|
+
)
|
|
132
|
+
return None
|
|
133
|
+
|
|
134
|
+
if not isinstance(raw_record, Mapping):
|
|
135
|
+
LOGGER.error("Skipping non-object JSON record in %s:%s", source, line_number)
|
|
136
|
+
return None
|
|
137
|
+
|
|
138
|
+
return raw_record
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
__all__ = ["iter_shard", "records_to_body"]
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from collections.abc import Callable, Iterator, Mapping, Sequence
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any, Generic, TypeVar, cast, overload
|
|
7
|
+
|
|
8
|
+
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
|
|
9
|
+
from typing_extensions import Literal
|
|
10
|
+
|
|
11
|
+
from syvain_training_data import gzjsonl
|
|
12
|
+
from syvain_training_data.manifest import (
|
|
13
|
+
DataManifest,
|
|
14
|
+
DataManifestDataFormat,
|
|
15
|
+
ManifestShard,
|
|
16
|
+
curriculum_stage_entry,
|
|
17
|
+
manifest_split_entry,
|
|
18
|
+
)
|
|
19
|
+
from syvain_training_data.storage import StorageConfig
|
|
20
|
+
|
|
21
|
+
LOGGER = logging.getLogger(__name__)

BatchT = TypeVar("BatchT")


@dataclass(frozen=True)
class SplitDataLoader(Generic[BatchT]):
    """Build torch DataLoaders for the splits and curriculum stages of one manifest.

    Attributes:
        manifest: Parsed data manifest describing the splits and their shards.
        storage_config: Storage settings handed to every shard reader.
        collate_fn: Collation function passed to each DataLoader.
        dataloader_args: Extra DataLoader keyword arguments. Must not contain
            ``collate_fn``. ``num_workers`` is capped at the shard count.
    """

    manifest: DataManifest
    storage_config: StorageConfig
    collate_fn: Callable[[Sequence[Mapping[str, Any]]], BatchT]
    dataloader_args: Mapping[str, Any]

    # Overload set for the infinite_iter flag. Literal[False] (or omitted)
    # returns a DataLoader, Literal[True] returns an endless iterator, and a
    # non-literal bool returns the union. Literal[True] has no default, so it
    # never overlaps the default call; it is spelled both positionally and
    # keyword-only because a defaulted curriculum_stage precedes it.
    @overload
    def load(
        self,
        split: str,
        curriculum_stage: str | None = None,
        infinite_iter: Literal[False] = False,
    ) -> DataLoader[BatchT]: ...

    @overload
    def load(
        self,
        split: str,
        curriculum_stage: str | None,
        infinite_iter: Literal[True],
    ) -> Iterator[BatchT]: ...

    @overload
    def load(
        self,
        split: str,
        curriculum_stage: str | None = None,
        *,
        infinite_iter: Literal[True],
    ) -> Iterator[BatchT]: ...

    @overload
    def load(
        self,
        split: str,
        curriculum_stage: str | None = None,
        infinite_iter: bool = False,
    ) -> DataLoader[BatchT] | Iterator[BatchT]: ...

    def load(
        self,
        split: str,
        curriculum_stage: str | None = None,
        infinite_iter: bool = False,
    ) -> DataLoader[BatchT] | Iterator[BatchT]:
        """Create a DataLoader over one split, or over one curriculum stage of it.

        Args:
            split: Name of the manifest split to read.
            curriculum_stage: Optional stage within *split*. When given, only
                that stage's shards are read.
            infinite_iter: When true, return an iterator that re-iterates the
                DataLoader forever instead of the DataLoader itself.

        Returns:
            The DataLoader, or an endless batch iterator if *infinite_iter*.

        Raises:
            ValueError: for an unknown split or stage, when ``collate_fn`` appears
                in ``dataloader_args``, or when ``num_workers`` is negative.
        """
        split_entry = manifest_split_entry(self.manifest, split)
        if curriculum_stage is not None:
            split_entry = curriculum_stage_entry(split_entry, curriculum_stage)

        dataset = _ManifestShardDataset(
            shards=split_entry.shards,
            data_format=self.manifest.data_format,
            storage_config=self.storage_config,
        )
        loader_kwargs = dict(self.dataloader_args)
        if "collate_fn" in loader_kwargs:
            raise ValueError(
                "Pass collate_fn to split_data_loader, not dataloader_args"
            )
        loader_kwargs["collate_fn"] = self.collate_fn
        # Workers are assigned whole shards, so extra workers would sit idle.
        loader_kwargs = _cap_worker_args(
            loader_kwargs,
            shard_count=dataset.shard_count,
            split=split,
        )

        loader = cast(DataLoader[BatchT], DataLoader(dataset, **loader_kwargs))
        if not infinite_iter:
            return loader
        return _iter_dataloader_forever(loader, split=split)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class _ManifestShardDataset(IterableDataset[Mapping[str, Any]]):
|
|
83
|
+
def __init__(
|
|
84
|
+
self,
|
|
85
|
+
*,
|
|
86
|
+
shards: Sequence[ManifestShard],
|
|
87
|
+
data_format: DataManifestDataFormat,
|
|
88
|
+
storage_config: StorageConfig,
|
|
89
|
+
) -> None:
|
|
90
|
+
self._shards = tuple(shards)
|
|
91
|
+
self._data_format = data_format
|
|
92
|
+
self._storage_config = storage_config
|
|
93
|
+
|
|
94
|
+
@property
|
|
95
|
+
def shard_count(self) -> int:
|
|
96
|
+
return len(self._shards)
|
|
97
|
+
|
|
98
|
+
def __iter__(self) -> Iterator[Mapping[str, Any]]:
|
|
99
|
+
for shard in _worker_shards(self._shards):
|
|
100
|
+
if self._data_format == "jsonl.gz":
|
|
101
|
+
yield from gzjsonl.iter_shard(
|
|
102
|
+
shard,
|
|
103
|
+
storage_config=self._storage_config,
|
|
104
|
+
)
|
|
105
|
+
else:
|
|
106
|
+
raise ValueError(f"Unknown data format: {self._data_format}")
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _cap_worker_args(
|
|
110
|
+
loader_kwargs: dict[str, Any],
|
|
111
|
+
*,
|
|
112
|
+
shard_count: int,
|
|
113
|
+
split: str,
|
|
114
|
+
) -> dict[str, Any]:
|
|
115
|
+
requested_workers = int(loader_kwargs.get("num_workers", 0))
|
|
116
|
+
if requested_workers < 0:
|
|
117
|
+
raise ValueError("dataloader_args['num_workers'] must be non-negative")
|
|
118
|
+
|
|
119
|
+
num_workers = min(requested_workers, shard_count)
|
|
120
|
+
loader_kwargs["num_workers"] = num_workers
|
|
121
|
+
if num_workers < requested_workers:
|
|
122
|
+
LOGGER.debug(
|
|
123
|
+
"Capping %s dataloader workers from %s to %s because the split has %s shards",
|
|
124
|
+
split,
|
|
125
|
+
requested_workers,
|
|
126
|
+
num_workers,
|
|
127
|
+
shard_count,
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
if num_workers == 0:
|
|
131
|
+
loader_kwargs.pop("persistent_workers", None)
|
|
132
|
+
loader_kwargs.pop("prefetch_factor", None)
|
|
133
|
+
loader_kwargs.pop("multiprocessing_context", None)
|
|
134
|
+
if int(loader_kwargs.get("timeout", 0)) > 0:
|
|
135
|
+
loader_kwargs.pop("timeout", None)
|
|
136
|
+
return loader_kwargs
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _worker_shards(shards: Sequence[ManifestShard]) -> Sequence[ManifestShard]:
|
|
140
|
+
worker = get_worker_info()
|
|
141
|
+
if worker is None:
|
|
142
|
+
return shards
|
|
143
|
+
return shards[worker.id :: worker.num_workers]
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _iter_dataloader_forever(
|
|
147
|
+
loader: DataLoader[BatchT],
|
|
148
|
+
*,
|
|
149
|
+
split: str,
|
|
150
|
+
) -> Iterator[BatchT]:
|
|
151
|
+
while True:
|
|
152
|
+
yielded_batch = False
|
|
153
|
+
for batch in loader:
|
|
154
|
+
yielded_batch = True
|
|
155
|
+
yield batch
|
|
156
|
+
if not yielded_batch:
|
|
157
|
+
raise RuntimeError(f"{split} split produced no batches")
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
__all__ = ["SplitDataLoader"]
|
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
from __future__ import annotations

import json
from typing import Literal, Mapping
from urllib.parse import urlparse

from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    ValidationError,
    field_validator,
    model_validator,
)
from typing_extensions import TypeAliasType
|
|
9
|
+
|
|
10
|
+
# Recursive alias for any JSON-representable value.
JsonValue = TypeAliasType(
    "JsonValue",
    str | int | float | bool | None | list["JsonValue"] | dict[str, "JsonValue"],
)

# A JSON object: string keys mapped to JSON values.
JsonObject = TypeAliasType("JsonObject", dict[str, JsonValue])

# Shard storage formats a manifest may declare. Only gzip-compressed JSONL is
# defined.
DataManifestDataFormat = TypeAliasType("DataManifestDataFormat", Literal["jsonl.gz"])

# Schema version stamped by manifest_from_splits. DataManifest.version only
# accepts 1.
MANIFEST_VERSION = 1

# Shared pydantic settings for every manifest model: strict typing (no
# coercion), immutable instances, and rejection of unknown keys.
MODEL_CONFIG = ConfigDict(strict=True, frozen=True, extra="forbid")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ManifestShard(BaseModel):
    """A single stored shard object referenced from a manifest."""

    model_config = MODEL_CONFIG

    # s3://bucket/key location of the shard object (checked by validate_uri).
    uri: str = Field(min_length=1)
    # Number of records in the shard; must be positive.
    records: int = Field(gt=0)
    # Stored object size in bytes; must be positive.
    size_bytes: int = Field(gt=0)
    # Object entity tag; defaults to the empty string.
    etag: str = ""

    @field_validator("uri")
    @classmethod
    def validate_uri(cls, value: str) -> str:
        """Accept only ``s3://bucket/key`` URIs with a bucket and a non-empty key."""
        parts = urlparse(value)
        has_bucket = bool(parts.netloc)
        has_key = bool(parts.path.lstrip("/"))
        if parts.scheme == "s3" and has_bucket and has_key:
            return value
        raise ValueError(f"Expected non-empty s3://bucket/key URI, got: {value}")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class ManifestCurriculumStage(BaseModel):
    """A named curriculum stage of a split: its own shards and their total record count."""

    model_config = MODEL_CONFIG

    # Total records in the stage; must equal the sum of the shard record counts.
    records: int = Field(gt=0)
    # Shards belonging to this stage; at least one is required.
    shards: list[ManifestShard] = Field(min_length=1)

    @field_validator("shards")
    @classmethod
    def validate_shard_records(cls, value: list[ManifestShard]) -> list[ManifestShard]:
        """Reject a shard list whose record total is not positive.

        Defensive only: each shard already has ``records > 0`` and the list
        has ``min_length=1``, so this should not trigger in practice.
        """
        if sum(shard.records for shard in value) <= 0:
            raise ValueError("stage shards must contain records")
        return value

    @model_validator(mode="after")
    def validate_records_total(self) -> ManifestCurriculumStage:
        """Check that ``records`` equals the sum of the shard record counts.

        This runs as a pydantic after-validator rather than in
        ``model_post_init``. A mismatch is therefore reported as a
        ``ValidationError`` like any other schema violation, and
        ``model_construct`` (which deliberately skips validation) no longer
        runs the check.
        """
        shard_records = sum(shard.records for shard in self.shards)
        if self.records != shard_records:
            raise ValueError(
                f"stage records {self.records} do not match shard records "
                f"{shard_records}"
            )
        return self
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class ManifestDatasetEntry(BaseModel):
    """One split of a manifest: its shards, total record count, and optional curriculum."""

    model_config = MODEL_CONFIG

    # Total records in the split; must equal the sum of the shard record counts.
    records: int = Field(gt=0)
    # Every shard of the split; at least one is required.
    shards: list[ManifestShard] = Field(min_length=1)
    # Optional curriculum stages keyed by non-empty stage name, each with its own shards.
    curriculum: dict[str, ManifestCurriculumStage] = Field(default_factory=dict)

    @field_validator("curriculum")
    @classmethod
    def validate_curriculum(
        cls, value: dict[str, ManifestCurriculumStage]
    ) -> dict[str, ManifestCurriculumStage]:
        """Reject empty curriculum stage names."""
        for stage_name in value:
            if not stage_name:
                raise ValueError("curriculum stage names must not be empty")
        return value

    @model_validator(mode="after")
    def validate_records_total(self) -> ManifestDatasetEntry:
        """Check that ``records`` equals the sum of the shard record counts.

        This runs as a pydantic after-validator rather than in
        ``model_post_init``. A mismatch is therefore reported as a
        ``ValidationError`` like any other schema violation, and
        ``model_construct`` (which deliberately skips validation) no longer
        runs the check.
        """
        shard_records = sum(shard.records for shard in self.shards)
        if self.records != shard_records:
            raise ValueError(
                f"split records {self.records} do not match shard records "
                f"{shard_records}"
            )
        return self
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class DataManifest(BaseModel):
    """Top-level data manifest: schema version, shard format, and named splits."""

    model_config = MODEL_CONFIG

    # Schema version; only version 1 is accepted.
    version: Literal[1]
    # Storage format shared by every shard in the manifest.
    data_format: DataManifestDataFormat
    # Splits keyed by non-empty name; at least one split is required.
    splits: dict[str, ManifestDatasetEntry] = Field(min_length=1)

    @field_validator("splits")
    @classmethod
    def validate_splits(
        cls, value: dict[str, ManifestDatasetEntry]
    ) -> dict[str, ManifestDatasetEntry]:
        """Reject empty split names."""
        if any(not name for name in value):
            raise ValueError("split names must not be empty")
        return value
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def manifest_split_entry(manifest: DataManifest, split: str) -> ManifestDatasetEntry:
    """Return the entry for *split* in *manifest*.

    Raises:
        ValueError: if the split does not exist. The message lists the known
            split names in sorted order, and the ``KeyError`` is chained as the
            cause.
    """
    try:
        entry = manifest.splits[split]
    except KeyError as missing:
        known = ", ".join(sorted(manifest.splits))
        raise ValueError(
            f"Unknown split {split!r}; available splits: {known}"
        ) from missing
    return entry
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def curriculum_stage_entry(
    split_entry: ManifestDatasetEntry,
    stage_name: str,
) -> ManifestCurriculumStage:
    """Return the curriculum stage *stage_name* of *split_entry*.

    Raises:
        ValueError: if the split has no stage with that name. The message lists
            the available stage names in sorted order, or ``(none)`` when the
            split defines no curriculum. The ``KeyError`` is chained as the
            cause.
    """
    try:
        return split_entry.curriculum[stage_name]
    except KeyError as exc:
        # curriculum defaults to an empty dict. A bare join would leave the
        # message ending in "available stages: ", which reads as truncated.
        available = ", ".join(sorted(split_entry.curriculum)) or "(none)"
        raise ValueError(
            f"Unknown curriculum stage {stage_name!r}; available stages: {available}"
        ) from exc
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def parse_manifest_json(body: bytes, *, source: str) -> DataManifest:
    """Decode, parse, and validate serialized manifest bytes.

    Args:
        body: Raw manifest bytes, expected to be UTF-8 encoded JSON.
        source: Where the bytes came from; used only in error messages.

    Raises:
        ValueError: if *body* is not UTF-8, is not JSON, or does not match the
            ``DataManifest`` schema. The underlying error is chained as the
            cause.
    """
    try:
        text = body.decode("utf-8")
    except UnicodeDecodeError as err:
        raise ValueError(f"Data manifest is not valid UTF-8: {source}") from err

    try:
        payload = json.loads(text)
    except json.JSONDecodeError as err:
        raise ValueError(f"Data manifest is not valid JSON: {source}") from err

    try:
        return DataManifest.model_validate(payload)
    except ValidationError as err:
        raise ValueError(f"Data manifest has invalid shape: {source}") from err
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def manifest_json_bytes(manifest: DataManifest | Mapping[str, object]) -> bytes:
    """Serialize *manifest* as indented UTF-8 JSON with a trailing newline.

    A plain mapping is first validated into a ``DataManifest``, so invalid data
    raises pydantic's ``ValidationError`` instead of being written. Fields that
    are ``None`` are omitted from the output.
    """
    if not isinstance(manifest, DataManifest):
        manifest = DataManifest.model_validate(manifest)
    text = manifest.model_dump_json(indent=2, exclude_none=True)
    return f"{text}\n".encode("utf-8")
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def dataset_entry_from_shards(
    shards: list[ManifestShard],
    *,
    curriculum: Mapping[str, ManifestCurriculumStage] | None = None,
) -> ManifestDatasetEntry:
    """Build a split entry whose record count is the total over *shards*.

    Args:
        shards: Shards of the split, passed through unchanged.
        curriculum: Optional stages, copied into a new dict. ``None`` means no
            curriculum.
    """
    total_records = sum(s.records for s in shards)
    stages = dict(curriculum) if curriculum is not None else {}
    return ManifestDatasetEntry(records=total_records, shards=shards, curriculum=stages)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def curriculum_stage_from_shards(
    shards: list[ManifestShard],
) -> ManifestCurriculumStage:
    """Build a curriculum stage whose record count is the total over *shards*."""
    total_records = sum(s.records for s in shards)
    return ManifestCurriculumStage(records=total_records, shards=shards)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def manifest_from_splits(
    splits: Mapping[str, ManifestDatasetEntry],
    *,
    data_format: DataManifestDataFormat = "jsonl.gz",
) -> DataManifest:
    """Build a ``MANIFEST_VERSION`` manifest from already-built split entries.

    *splits* is copied into a new dict before the manifest is validated.
    """
    split_map = dict(splits)
    return DataManifest(
        version=MANIFEST_VERSION,
        data_format=data_format,
        splits=split_map,
    )
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
__all__ = [
|
|
183
|
+
"DataManifest",
|
|
184
|
+
"DataManifestDataFormat",
|
|
185
|
+
"JsonObject",
|
|
186
|
+
"JsonValue",
|
|
187
|
+
"MANIFEST_VERSION",
|
|
188
|
+
"ManifestCurriculumStage",
|
|
189
|
+
"ManifestDatasetEntry",
|
|
190
|
+
"ManifestShard",
|
|
191
|
+
"curriculum_stage_entry",
|
|
192
|
+
"curriculum_stage_from_shards",
|
|
193
|
+
"dataset_entry_from_shards",
|
|
194
|
+
"manifest_from_splits",
|
|
195
|
+
"manifest_json_bytes",
|
|
196
|
+
"manifest_split_entry",
|
|
197
|
+
"parse_manifest_json",
|
|
198
|
+
]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import multiprocessing
|
|
4
|
+
import uuid
|
|
5
|
+
from collections.abc import Callable, Iterable, Mapping
|
|
6
|
+
from typing import Any, TypedDict, cast
|
|
7
|
+
|
|
8
|
+
from syvain_training_data import gzjsonl
|
|
9
|
+
from syvain_training_data.manifest import (
|
|
10
|
+
DataManifest,
|
|
11
|
+
DataManifestDataFormat,
|
|
12
|
+
ManifestCurriculumStage,
|
|
13
|
+
ManifestDatasetEntry,
|
|
14
|
+
ManifestShard,
|
|
15
|
+
curriculum_stage_from_shards,
|
|
16
|
+
manifest_from_splits,
|
|
17
|
+
)
|
|
18
|
+
from syvain_training_data.storage import (
|
|
19
|
+
StorageConfig,
|
|
20
|
+
join_s3_uri,
|
|
21
|
+
parse_s3_uri,
|
|
22
|
+
put_object_bytes,
|
|
23
|
+
safe_path_segment,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class _SavedShard(TypedDict):
|
|
28
|
+
split: str
|
|
29
|
+
curriculum_stage: str | None
|
|
30
|
+
shard: dict[str, object]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Serializer signature used by DatasetSaver.save: consumes an iterable of
# record mappings and returns ``(encoded_body, record_count)``. save() rejects
# a count of 0 and stores ``len(encoded_body)`` as the shard's size_bytes.
_RecordsToBody = Callable[[Iterable[Mapping[str, Any]]], tuple[bytes, int]]
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class DatasetSaver:
    """Upload record shards to S3 and commit a manifest that indexes them.

    Each :meth:`save` call serializes one batch of records, uploads it as a
    new uuid-named shard under a ``shards`` prefix next to the manifest, and
    records the resulting :class:`ManifestShard`. :meth:`commit_manifest`
    then groups every recorded shard by split (and curriculum stage) and
    uploads the manifest JSON to :attr:`manifest_uri`.

    The recorded-shard list and its lock are ``multiprocessing.Manager``
    proxies, so pickled copies of the saver (``__getstate__`` drops only the
    manager handle) append to the same shared list. The manager runs a
    server process; call :meth:`close` on the original instance, or use it
    as a context manager, to stop that process once the manifest is
    committed.
    """

    def __init__(
        self,
        *,
        storage_config: StorageConfig,
        manifest_uri: str,
        data_format: DataManifestDataFormat = "jsonl.gz",
    ) -> None:
        """Validate the target location and start the shared-state manager.

        Args:
            storage_config: Endpoint and credentials used for every upload.
            manifest_uri: ``s3://bucket/key`` of the manifest file. Shards are
                written under a ``shards`` prefix in the same directory.
            data_format: Shard encoding; only ``"jsonl.gz"`` is supported.

        Raises:
            ValueError: If ``manifest_uri`` is blank, is not an
                ``s3://bucket/key`` file URI, or ``data_format`` is unknown.
        """
        manifest_uri = manifest_uri.strip()
        if not manifest_uri:
            raise ValueError("manifest_uri must be provided")

        self._storage_config = storage_config
        self._manifest_uri = manifest_uri
        self._shards_base_uri = _manifest_shards_base_uri(manifest_uri)
        self._data_format = data_format
        if data_format == "jsonl.gz":
            self._records_to_body: _RecordsToBody = gzjsonl.records_to_body
        else:
            raise ValueError(f"Unknown data format: {data_format}")
        # All validation happens before the manager starts, so a bad argument
        # never leaves an orphaned manager server process behind.
        # Typed as Any because close() and __getstate__ reset it to None.
        self._manager: Any = multiprocessing.Manager()
        self._saved_shards = self._manager.list()
        self._lock = self._manager.RLock()

    def __getstate__(self) -> dict[str, Any]:
        """Pickle everything except the manager handle itself.

        The list and lock proxies are kept, so an unpickled copy keeps
        appending to the same shared list; its ``_manager`` is ``None``, which
        makes :meth:`close` a no-op on copies.
        """
        state = self.__dict__.copy()
        state["_manager"] = None
        return state

    def __enter__(self) -> DatasetSaver:
        """Return ``self``; the manager is shut down on exit."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Call :meth:`close`; exceptions from the ``with`` body propagate."""
        self.close()

    def close(self) -> None:
        """Shut down the manager server process started by ``__init__``.

        Safe to call more than once, and a no-op on pickled copies. Once
        closed, the shared shard list is gone for every copy of this saver,
        so :meth:`save`, :meth:`build_manifest` and :meth:`commit_manifest`
        can no longer be used; commit the manifest first.
        """
        manager = self._manager
        if manager is None:
            return
        self._manager = None
        manager.shutdown()

    @property
    def manifest_uri(self) -> str:
        """The (whitespace-stripped) URI the manifest is committed to."""
        return self._manifest_uri

    def save(
        self,
        split: str,
        curriculum_stage: str | None,
        records: Iterable[Mapping[str, Any]],
    ) -> ManifestShard:
        """Serialize ``records`` into one new shard object and register it.

        Args:
            split: Manifest split name; must be non-empty.
            curriculum_stage: Optional stage name. When given, the shard is
                stored under ``<split>/curriculum/<stage>/`` and also listed
                under that stage in the manifest.
            records: Mappings handed to the data format's serializer.

        Returns:
            The :class:`ManifestShard` describing the uploaded object.

        Raises:
            ValueError: If ``split`` is empty or ``records`` yields nothing.
        """
        if not split:
            raise ValueError("split must not be empty")

        body, record_count = self._records_to_body(records)
        if record_count == 0:
            raise ValueError("Cannot save an empty record shard")

        shard_uri = self._shard_uri(split=split, curriculum_stage=curriculum_stage)
        response = put_object_bytes(
            shard_uri,
            body,
            storage_config=self._storage_config,
            attributes={
                "Content-Type": "application/gzip",
                "content-format": "jsonl",
                "compression": "gzip",
            },
        )
        # Accept either spelling of the ETag key; fall back to "" if absent.
        etag = response.get("e_tag") or response.get("etag") or ""
        shard = ManifestShard(
            uri=shard_uri,
            records=record_count,
            size_bytes=len(body),
            etag=str(etag),
        )
        saved: _SavedShard = {
            "split": split,
            "curriculum_stage": curriculum_stage,
            "shard": shard.model_dump(mode="json"),
        }
        with self._lock:
            self._saved_shards.append(saved)
        return shard

    def commit_manifest(self) -> DataManifest:
        """Build the manifest from all saved shards and upload it as JSON.

        Returns:
            The manifest that was written to :attr:`manifest_uri`.

        Raises:
            ValueError: If no shard has been saved yet.
        """
        manifest = self.build_manifest()
        from syvain_training_data.manifest import manifest_json_bytes

        put_object_bytes(
            self.manifest_uri,
            manifest_json_bytes(manifest),
            storage_config=self._storage_config,
            attributes={"Content-Type": "application/json"},
        )
        return manifest

    def build_manifest(self) -> DataManifest:
        """Assemble (without uploading) a manifest from every saved shard.

        Splits and curriculum stages appear in sorted name order and shards
        within each group are sorted by URI. A split's root ``shards`` and
        ``records`` cover all of its shards, including those saved with a
        curriculum stage.

        Raises:
            ValueError: If no shard has been saved yet.
        """
        with self._lock:
            # One slice fetches the whole list in a single manager round trip;
            # ``list(proxy)`` would issue one ``__getitem__`` call per element
            # because list proxies do not expose ``__iter__``.
            snapshot = self._saved_shards[:]
        saved_shards = [cast(_SavedShard, item) for item in snapshot]
        if not saved_shards:
            raise ValueError("Cannot commit a manifest with no saved shards")

        # Group in one pass, preserving save order within each split.
        by_split: dict[str, list[_SavedShard]] = {}
        for item in saved_shards:
            by_split.setdefault(item["split"], []).append(item)

        split_entries: dict[str, ManifestDatasetEntry] = {}
        for split in sorted(by_split):
            split_items = by_split[split]
            curriculum_entries = _curriculum_entries(split_items)
            root_shards = _sorted_shards(split_items)
            split_entries[split] = ManifestDatasetEntry(
                records=sum(shard.records for shard in root_shards),
                shards=root_shards,
                curriculum=curriculum_entries,
            )

        return manifest_from_splits(split_entries, data_format=self._data_format)

    def _shard_uri(self, *, split: str, curriculum_stage: str | None) -> str:
        """Return a fresh uuid-named shard URI under the shards prefix.

        Layout: ``<shards>/<split>/part-<hex>.jsonl.gz``, or
        ``<shards>/<split>/curriculum/<stage>/part-<hex>.jsonl.gz`` when a
        curriculum stage is given. Names are percent-encoded path segments.
        """
        split_segment = safe_path_segment(split)
        shard_name = f"part-{uuid.uuid4().hex}.jsonl.gz"
        if curriculum_stage is None:
            return join_s3_uri(self._shards_base_uri, split_segment, shard_name)
        stage_segment = safe_path_segment(curriculum_stage)
        return join_s3_uri(
            self._shards_base_uri,
            split_segment,
            "curriculum",
            stage_segment,
            shard_name,
        )
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def _manifest_shards_base_uri(manifest_uri: str) -> str:
    """Return the ``s3://`` prefix under which a manifest's shards are stored.

    The prefix is a ``shards`` directory next to the manifest file, e.g.
    ``s3://b/run/manifest.json`` -> ``s3://b/run/shards`` and
    ``s3://b/manifest.json`` -> ``s3://b/shards``.

    Raises:
        ValueError: If the URI is not a valid ``s3://bucket/key`` URI, or its
            key ends with ``/`` (names a directory rather than a file).
    """
    location = parse_s3_uri(manifest_uri)
    if location.key.endswith("/"):
        raise ValueError(f"Expected manifest file URI, got: {manifest_uri}")
    directory = location.key.rpartition("/")[0]
    shards_key = f"{directory}/shards" if directory else "shards"
    return f"s3://{location.bucket}/{shards_key}"
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _curriculum_entries(items: list[_SavedShard]) -> dict[str, ManifestCurriculumStage]:
    """Build one curriculum stage entry per stage name found in ``items``.

    Shards saved without a stage (``curriculum_stage is None``) are ignored.
    The returned dict is keyed by stage name in sorted order; each stage's
    shards keep their relative order before being sorted by URI.
    """
    by_stage: dict[str, list[_SavedShard]] = {}
    for saved in items:
        stage = saved["curriculum_stage"]
        if stage is None:
            continue
        by_stage.setdefault(stage, []).append(saved)

    entries: dict[str, ManifestCurriculumStage] = {}
    for stage in sorted(by_stage):
        entries[stage] = curriculum_stage_from_shards(_sorted_shards(by_stage[stage]))
    return entries
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _sorted_shards(items: Iterable[_SavedShard]) -> list[ManifestShard]:
    """Re-validate saved shard dicts into ``ManifestShard`` models, sorted by URI."""
    models = [ManifestShard.model_validate(saved["shard"]) for saved in items]
    models.sort(key=lambda model: model.uri)
    return models
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
# Only the saver class is public; the module-level helpers are private.
__all__ = ["DatasetSaver"]
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import builtins
|
|
4
|
+
from collections.abc import Mapping
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any
|
|
7
|
+
from urllib.parse import quote, urlparse
|
|
8
|
+
|
|
9
|
+
from obstore.store import S3Store
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass(frozen=True)
class S3Uri:
    """Immutable (bucket, key) pair naming one S3 object location."""

    bucket: str  # bucket name
    key: str  # object key inside the bucket

    @property
    def uri(self) -> str:
        """Render the location back to ``s3://<bucket>/<key>`` form."""
        return "s3://{}/{}".format(self.bucket, self.key)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass(frozen=True)
class StorageConfig:
    """Connection settings for an S3-compatible object store.

    All string fields must be non-blank; this is checked in ``__post_init__``.
    ``__repr__`` is written by hand so that ``secret_access_key`` is masked
    and never ends up in logs or tracebacks. ``dataclass`` keeps a
    class-defined ``__repr__`` instead of generating one that prints every
    field.
    """

    s3_base_url: str  # passed to S3Store as ``endpoint``
    region: str
    access_key_id: str
    secret_access_key: str
    virtual_hosted_style_request: bool = False

    def __post_init__(self) -> None:
        """Reject blank or non-string connection fields with ``ValueError``."""
        _require_non_empty("s3_base_url", self.s3_base_url)
        _require_non_empty("region", self.region)
        _require_non_empty("access_key_id", self.access_key_id)
        _require_non_empty("secret_access_key", self.secret_access_key)

    def __repr__(self) -> str:
        """Mirror the dataclass repr, but mask the secret access key."""
        return (
            f"{type(self).__qualname__}("
            f"s3_base_url={self.s3_base_url!r}, "
            f"region={self.region!r}, "
            f"access_key_id={self.access_key_id!r}, "
            "secret_access_key='***', "
            f"virtual_hosted_style_request={self.virtual_hosted_style_request!r})"
        )

    def open_store(self, bucket: str) -> S3Store:
        """Construct a new ``S3Store`` for ``bucket`` using these settings."""
        kwargs: dict[str, Any] = {
            "endpoint": self.s3_base_url,
            "region": self.region,
            "access_key_id": self.access_key_id,
            "secret_access_key": self.secret_access_key,
            "virtual_hosted_style_request": self.virtual_hosted_style_request,
        }
        return S3Store(bucket, **kwargs)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def parse_s3_uri(uri: str) -> S3Uri:
    """Split an ``s3://bucket/key`` URI into its bucket and key.

    Leading slashes are removed from the key. URIs containing ``?`` or ``#``
    are rejected rather than parsed: ``urlparse`` treats those characters as
    query/fragment delimiters and would silently drop everything after them.
    A key such as ``data#1.json`` would therefore resolve to the different
    object ``data``.

    Raises:
        ValueError: If the URI contains ``?`` or ``#``, its scheme is not
            ``s3``, or the bucket or key is empty.
    """
    if "?" in uri or "#" in uri:
        raise ValueError(
            "S3 URI must not contain '?' or '#' "
            f"(they would be dropped from the key), got: {uri}"
        )
    parsed = urlparse(uri)
    key = parsed.path.lstrip("/")
    if parsed.scheme != "s3" or not parsed.netloc or not key:
        raise ValueError(f"Expected non-empty s3://bucket/key URI, got: {uri}")
    return S3Uri(bucket=parsed.netloc, key=key)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def join_s3_uri(base_uri: str, *parts: str) -> str:
    """Append key segments to an ``s3://bucket/key`` base URI.

    Trailing slashes are removed from the base key. Each part is stripped of
    surrounding whitespace and slashes (inner slashes are kept). Parts that
    are empty before or after cleaning are skipped, so the result never
    gains an empty ``//`` segment. In S3 such a segment would name a
    different object.

    Raises:
        ValueError: If ``base_uri`` is not a non-empty ``s3://bucket/key`` URI.
    """
    parsed = parse_s3_uri(_ensure_base_key(base_uri))
    # Filter *after* cleaning: a part like "/" or "  " is truthy but cleans
    # down to "" and must not contribute an empty segment.
    cleaned = (_clean_key_part(part) for part in parts if part)
    segments = [parsed.key.rstrip("/"), *(part for part in cleaned if part)]
    return f"s3://{parsed.bucket}/" + "/".join(segments)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def safe_path_segment(value: str) -> str:
    """Turn an arbitrary label into one percent-encoded S3 key segment.

    Surrounding whitespace is stripped. Every character other than ASCII
    letters, digits and ``-_.~=+`` is then percent-encoded (``quote`` always
    keeps ``_.-~``), so characters such as ``/`` cannot add extra key levels.

    Raises:
        ValueError: If the value is blank, or is ``.`` or ``..`` after
            stripping. ``quote`` leaves dots untouched, so these would yield
            relative-path segments that path-normalising tools or clients
            may collapse or reject.
    """
    stripped = value.strip()
    if not stripped:
        raise ValueError("S3 path segment must not be empty")
    if stripped in {".", ".."}:
        raise ValueError(f"S3 path segment must not be '.' or '..', got: {value!r}")
    return quote(stripped, safe="-_.=+")
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def store_for_uri(
    uri: str,
    *,
    storage_config: StorageConfig,
) -> tuple[S3Store, S3Uri]:
    """Parse ``uri`` and open a store for its bucket.

    A new store is constructed on every call via ``storage_config.open_store``.

    Returns:
        ``(store, location)``, where ``location`` is the parsed ``S3Uri``.

    Raises:
        ValueError: If ``uri`` is not a valid ``s3://bucket/key`` URI.
    """
    location = parse_s3_uri(uri)
    return storage_config.open_store(location.bucket), location
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def read_object_bytes(
    uri: str,
    *,
    storage_config: StorageConfig,
) -> builtins.bytes:
    """Download the whole object at ``uri`` and return it as ``bytes``."""
    store, location = store_for_uri(uri, storage_config=storage_config)
    result = store.get(location.key)
    payload = result.bytes()
    # Convert the store's buffer type into a plain Python bytes object.
    return bytes(payload)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def put_object_bytes(
    uri: str,
    body: builtins.bytes,
    *,
    storage_config: StorageConfig,
    attributes: dict[str, str] | None = None,
) -> Mapping[str, Any]:
    """Upload ``body`` to ``uri`` and return the store's put result.

    Args:
        uri: Destination ``s3://bucket/key``.
        body: Object contents.
        storage_config: Endpoint and credentials for the upload.
        attributes: Optional object attributes (callers pass e.g.
            ``Content-Type``), forwarded unchanged to ``store.put``.
    """
    store, location = store_for_uri(uri, storage_config=storage_config)
    result = store.put(location.key, body, attributes=attributes)
    return result
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _ensure_base_key(uri: str) -> str:
|
|
102
|
+
parsed = urlparse(uri)
|
|
103
|
+
if parsed.scheme == "s3" and parsed.netloc and parsed.path.lstrip("/"):
|
|
104
|
+
return uri
|
|
105
|
+
raise ValueError(f"Expected non-empty s3://bucket/key URI, got: {uri}")
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _clean_key_part(value: str) -> str:
|
|
109
|
+
return value.strip().strip("/")
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _require_non_empty(field_name: str, value: str) -> None:
|
|
113
|
+
if not isinstance(value, str) or not value.strip():
|
|
114
|
+
raise ValueError(f"{field_name} must be provided")
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# Public storage helpers; the underscore-prefixed validators are private.
__all__ = [
    "S3Uri",
    "StorageConfig",
    "join_s3_uri",
    "parse_s3_uri",
    "put_object_bytes",
    "read_object_bytes",
    "safe_path_segment",
    "store_for_uri",
]
|