syvain-training-data 0.0.118__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,121 @@
1
+ Metadata-Version: 2.3
2
+ Name: syvain-training-data
3
+ Version: 0.0.118
4
+ Summary: Syvain training data manifest, loading, and saving utilities
5
+ Requires-Dist: obstore>=0.10.1
6
+ Requires-Dist: pydantic>=2.13.3
7
+ Requires-Dist: torch>=2.0.0,<2.11
8
+ Requires-Dist: typing-extensions>=4.15.0
9
+ Requires-Dist: pytest>=8.0.0 ; extra == 'dev'
10
+ Requires-Dist: ruff>=0.15.12 ; extra == 'dev'
11
+ Requires-Dist: ty>=0.0.34 ; extra == 'dev'
12
+ Requires-Python: >=3.10
13
+ Provides-Extra: dev
14
+ Description-Content-Type: text/markdown
15
+
16
+ # syvain-training-data
17
+
18
+ Internal [Syvain](https://syvain.com/) data utility. No secret sauce here, just
19
+ a shared helper.
20
+
21
+ > This is my dataloader. There are many like it, but this one is mine. My
22
+ > dataloader is my best friend. It is my life. I must master it as I must master
23
+ > my life. My dataloader, without me, is useless. Without my dataloader, I am
24
+ > useless.
25
+
26
+ ## Install
27
+
28
+ ```bash
29
+ uv add syvain-training-data
30
+ ```
31
+
32
+ ## Load data
33
+
34
+ ```python
35
+ from syvain_training_data import SyvainTrainingData
36
+
37
+ training_data = SyvainTrainingData(
38
+ s3_base_url="https://t3.storage.dev",
39
+ region="auto",
40
+ access_key_id="...",
41
+ secret_access_key="...",
42
+ )
43
+
44
+
45
+ def collate(records):
46
+ ...
47
+
48
+
49
+ loader = training_data.split_data_loader(
50
+ "s3://my-training-bucket/path/to/data-manifest-v1.json",
51
+ collate_fn=collate,
52
+ dataloader_args={"batch_size": 32, "num_workers": 4, ...},
53
+ )
54
+
55
+ train_batches = loader.load("train")
56
+ valid_batches = loader.load("valid")
57
+ easy_batches = loader.load("train", curriculum_stage="easy")
58
+ infinite_train_batches = loader.load("train", infinite_iter=True)
59
+ ```
60
+
61
+ ## Save data
62
+
63
+ ```python
64
+ from concurrent.futures import ProcessPoolExecutor
65
+
66
+ from syvain_training_data import SyvainTrainingData
67
+
68
+ def generate_data(split, curriculum_stage, shard_id):
69
+ ...
70
+
71
+ def save_shard(job):
72
+ saver, split, curriculum_stage, shard_id = job
73
+ records = generate_data(split, curriculum_stage, shard_id)
74
+ saver.save(split, curriculum_stage, records)
75
+
76
+
77
+ training_data = SyvainTrainingData(
78
+ s3_base_url="https://t3.storage.dev",
79
+ region="auto",
80
+ access_key_id="...",
81
+ secret_access_key="...",
82
+ )
83
+
84
+ saver = training_data.dataset_saver(
85
+ "s3://my-training-bucket/path/to/dataset/data-manifest-v1.json",
86
+ )
87
+
88
+ jobs = [
89
+ (saver, "train", stage, shard_id)
90
+ for stage in ["easy", "medium", "hard"]
91
+ for shard_id in range(32)
92
+ ] + [
93
+ (saver, "valid", None, shard_id) for shard_id in range(4)
94
+ ] + [
95
+ (saver, "test", None, shard_id) for shard_id in range(4)
96
+ ]
97
+
98
+ with ProcessPoolExecutor(max_workers=8) as pool:
99
+ list(pool.map(save_shard, jobs))
100
+
101
+ manifest = saver.commit_manifest()
102
+ ```
103
+
104
+ ## Copy a manifest
105
+
106
+ ```python
107
+ from syvain_training_data import SyvainTrainingData
108
+
109
+ training_data = SyvainTrainingData(
110
+ s3_base_url="https://t3.storage.dev",
111
+ region="auto",
112
+ access_key_id="...",
113
+ secret_access_key="...",
114
+ )
115
+
116
+ manifest = training_data.load_manifest("s3://my-training-bucket/shared/data-manifest-v1.json")
117
+
118
+ # Modify the manifest here if needed
119
+
120
+ training_data.save_manifest("s3://my-training-bucket/new-run/data-manifest-v1.json", manifest)
121
+ ```
@@ -0,0 +1,106 @@
1
+ # syvain-training-data
2
+
3
+ Internal [Syvain](https://syvain.com/) data utility. No secret sauce here, just
4
+ a shared helper.
5
+
6
+ > This is my dataloader. There are many like it, but this one is mine. My
7
+ > dataloader is my best friend. It is my life. I must master it as I must master
8
+ > my life. My dataloader, without me, is useless. Without my dataloader, I am
9
+ > useless.
10
+
11
+ ## Install
12
+
13
+ ```bash
14
+ uv add syvain-training-data
15
+ ```
16
+
17
+ ## Load data
18
+
19
+ ```python
20
+ from syvain_training_data import SyvainTrainingData
21
+
22
+ training_data = SyvainTrainingData(
23
+ s3_base_url="https://t3.storage.dev",
24
+ region="auto",
25
+ access_key_id="...",
26
+ secret_access_key="...",
27
+ )
28
+
29
+
30
+ def collate(records):
31
+ ...
32
+
33
+
34
+ loader = training_data.split_data_loader(
35
+ "s3://my-training-bucket/path/to/data-manifest-v1.json",
36
+ collate_fn=collate,
37
+ dataloader_args={"batch_size": 32, "num_workers": 4, ...},
38
+ )
39
+
40
+ train_batches = loader.load("train")
41
+ valid_batches = loader.load("valid")
42
+ easy_batches = loader.load("train", curriculum_stage="easy")
43
+ infinite_train_batches = loader.load("train", infinite_iter=True)
44
+ ```
45
+
46
+ ## Save data
47
+
48
+ ```python
49
+ from concurrent.futures import ProcessPoolExecutor
50
+
51
+ from syvain_training_data import SyvainTrainingData
52
+
53
+ def generate_data(split, curriculum_stage, shard_id):
54
+ ...
55
+
56
+ def save_shard(job):
57
+ saver, split, curriculum_stage, shard_id = job
58
+ records = generate_data(split, curriculum_stage, shard_id)
59
+ saver.save(split, curriculum_stage, records)
60
+
61
+
62
+ training_data = SyvainTrainingData(
63
+ s3_base_url="https://t3.storage.dev",
64
+ region="auto",
65
+ access_key_id="...",
66
+ secret_access_key="...",
67
+ )
68
+
69
+ saver = training_data.dataset_saver(
70
+ "s3://my-training-bucket/path/to/dataset/data-manifest-v1.json",
71
+ )
72
+
73
+ jobs = [
74
+ (saver, "train", stage, shard_id)
75
+ for stage in ["easy", "medium", "hard"]
76
+ for shard_id in range(32)
77
+ ] + [
78
+ (saver, "valid", None, shard_id) for shard_id in range(4)
79
+ ] + [
80
+ (saver, "test", None, shard_id) for shard_id in range(4)
81
+ ]
82
+
83
+ with ProcessPoolExecutor(max_workers=8) as pool:
84
+ list(pool.map(save_shard, jobs))
85
+
86
+ manifest = saver.commit_manifest()
87
+ ```
88
+
89
+ ## Copy a manifest
90
+
91
+ ```python
92
+ from syvain_training_data import SyvainTrainingData
93
+
94
+ training_data = SyvainTrainingData(
95
+ s3_base_url="https://t3.storage.dev",
96
+ region="auto",
97
+ access_key_id="...",
98
+ secret_access_key="...",
99
+ )
100
+
101
+ manifest = training_data.load_manifest("s3://my-training-bucket/shared/data-manifest-v1.json")
102
+
103
+ # Modify the manifest here if needed
104
+
105
+ training_data.save_manifest("s3://my-training-bucket/new-run/data-manifest-v1.json", manifest)
106
+ ```
@@ -0,0 +1,22 @@
1
+ [project]
2
+ name = "syvain-training-data"
3
+ version = "0.0.118"
4
+ description = "Syvain training data manifest, loading, and saving utilities"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ dependencies = [
8
+ "obstore>=0.10.1",
9
+ "pydantic>=2.13.3",
10
+ "torch>=2.0.0,<2.11",
11
+ "typing-extensions>=4.15.0",
12
+ ]
13
+
14
+ [project.optional-dependencies]
15
+ dev = ["pytest>=8.0.0", "ruff>=0.15.12", "ty>=0.0.34"]
16
+
17
+ [tool.pytest.ini_options]
18
+ testpaths = ["tests"]
19
+
20
+ [build-system]
21
+ requires = ["uv_build>=0.11.9,<0.12"]
22
+ build-backend = "uv_build"
@@ -0,0 +1,23 @@
1
+ from syvain_training_data.client import SyvainTrainingData
2
+ from syvain_training_data.loader import SplitDataLoader
3
+ from syvain_training_data.manifest import (
4
+ DataManifest,
5
+ JsonObject,
6
+ JsonValue,
7
+ ManifestCurriculumStage,
8
+ ManifestDatasetEntry,
9
+ ManifestShard,
10
+ )
11
+ from syvain_training_data.saver import DatasetSaver
12
+
13
+ __all__ = [
14
+ "DataManifest",
15
+ "DatasetSaver",
16
+ "JsonObject",
17
+ "JsonValue",
18
+ "ManifestCurriculumStage",
19
+ "ManifestDatasetEntry",
20
+ "ManifestShard",
21
+ "SplitDataLoader",
22
+ "SyvainTrainingData",
23
+ ]
@@ -0,0 +1,82 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Callable, Mapping, Sequence
4
+ from typing import Any, TypeVar
5
+
6
+ from syvain_training_data.loader import SplitDataLoader
7
+ from syvain_training_data.manifest import (
8
+ DataManifest,
9
+ manifest_json_bytes,
10
+ parse_manifest_json,
11
+ )
12
+ from syvain_training_data.saver import DatasetSaver
13
+ from syvain_training_data.storage import (
14
+ StorageConfig,
15
+ put_object_bytes,
16
+ read_object_bytes,
17
+ )
18
+
19
+ BatchT = TypeVar("BatchT")
20
+
21
+
22
class SyvainTrainingData:
    """Facade that pairs S3 connection settings with manifest, loader and saver helpers.

    The constructor only records the settings in a ``StorageConfig``; each
    method passes that config on to the storage, loader or saver layer.
    """

    def __init__(
        self,
        *,
        s3_base_url: str,
        region: str,
        access_key_id: str,
        secret_access_key: str,
        virtual_hosted_style_request: bool = False,
    ) -> None:
        """Record the connection settings reused by every later call."""
        connection_settings = {
            "s3_base_url": s3_base_url,
            "region": region,
            "access_key_id": access_key_id,
            "secret_access_key": secret_access_key,
            "virtual_hosted_style_request": virtual_hosted_style_request,
        }
        self._storage_config = StorageConfig(**connection_settings)

    def load_manifest(self, manifest_uri: str) -> DataManifest:
        """Read the object at ``manifest_uri`` and parse it into a ``DataManifest``."""
        raw_manifest = read_object_bytes(
            manifest_uri,
            storage_config=self._storage_config,
        )
        return parse_manifest_json(raw_manifest, source=manifest_uri)

    def save_manifest(
        self,
        manifest_uri: str,
        manifest: DataManifest | Mapping[str, object],
    ) -> None:
        """Serialize ``manifest`` and upload it to ``manifest_uri`` as JSON."""
        payload = manifest_json_bytes(manifest)
        put_object_bytes(
            manifest_uri,
            payload,
            storage_config=self._storage_config,
            attributes={"Content-Type": "application/json"},
        )

    def split_data_loader(
        self,
        manifest_uri: str,
        collate_fn: Callable[[Sequence[Mapping[str, Any]]], BatchT],
        dataloader_args: Mapping[str, Any] | None = None,
    ) -> SplitDataLoader[BatchT]:
        """Load the manifest and wrap it in a ``SplitDataLoader``.

        ``dataloader_args`` is copied into a fresh dict (empty when ``None``)
        before being handed to the loader.
        """
        manifest = self.load_manifest(manifest_uri)
        if dataloader_args is None:
            loader_options: dict[str, Any] = {}
        else:
            loader_options = dict(dataloader_args)
        return SplitDataLoader(
            manifest=manifest,
            storage_config=self._storage_config,
            collate_fn=collate_fn,
            dataloader_args=loader_options,
        )

    def dataset_saver(
        self,
        manifest_uri: str,
    ) -> DatasetSaver:
        """Create a ``DatasetSaver`` targeting ``manifest_uri`` with this client's settings."""
        return DatasetSaver(
            storage_config=self._storage_config,
            manifest_uri=manifest_uri,
        )
80
+
81
+
82
+ __all__ = ["SyvainTrainingData"]
@@ -0,0 +1,141 @@
1
+ from __future__ import annotations
2
+
3
+ import codecs
4
+ import gzip
5
+ import json
6
+ import logging
7
+ import zlib
8
+ from collections.abc import Iterable, Iterator, Mapping
9
+ from io import BytesIO
10
+ from typing import Any
11
+
12
+ from torch.utils.data import get_worker_info
13
+
14
+ from syvain_training_data.manifest import ManifestShard
15
+ from syvain_training_data.storage import (
16
+ StorageConfig,
17
+ parse_s3_uri,
18
+ store_for_uri,
19
+ )
20
+
21
+ LOGGER = logging.getLogger(__name__)
22
+
23
def records_to_body(records: Iterable[Mapping[str, Any]]) -> tuple[bytes, int]:
    """Serialize ``records`` as a gzip-compressed JSON Lines payload.

    Each record is written as one compact JSON document followed by a
    newline. The gzip stream is created with ``mtime=0`` so the header
    timestamp does not vary between runs.

    Returns:
        A ``(body, record_count)`` tuple: the compressed bytes and the number
        of records that were written.
    """
    buffer = BytesIO()
    written = 0
    compressor = gzip.GzipFile(fileobj=buffer, mode="wb", compresslevel=6, mtime=0)
    with compressor:
        for written, record in enumerate(records, start=1):
            encoded = json.dumps(record, separators=(",", ":")).encode("utf-8")
            compressor.write(encoded)
            compressor.write(b"\n")
    # The gzip trailer is only flushed on close, so read the buffer afterwards.
    return buffer.getvalue(), written
35
+
36
+
37
def iter_shard(
    shard: ManifestShard,
    *,
    storage_config: StorageConfig,
) -> Iterator[Mapping[str, Any]]:
    """Stream the JSON-object records of one gzip-compressed JSONL shard.

    The object is fetched in chunks, decompressed and UTF-8 decoded
    incrementally, and split on newlines. Blank lines are ignored; lines that
    are not valid JSON objects are logged and skipped by
    ``_parse_record_line``.

    Records are yielded as they are decoded, so the gzip end-of-stream check
    below only runs after every complete line has already been yielded.

    Raises:
        RuntimeError: if the object cannot be opened, if decompression or
            UTF-8 decoding fails, or if the gzip stream is truncated or is
            followed by extra data.
    """
    shard_uri = shard.uri
    location = parse_s3_uri(shard_uri)
    worker_info = get_worker_info()
    if worker_info is None:
        worker_label = "main"
    else:
        worker_label = str(worker_info.id)
    emitted = 0
    LOGGER.info(
        "Worker %s reading shard %s with %s expected records",
        worker_label,
        shard_uri,
        shard.records,
    )

    try:
        store, _ = store_for_uri(shard_uri, storage_config=storage_config)
        response = store.get(location.key)
    except Exception as exc:
        raise RuntimeError(f"Unreadable shard {shard_uri}") from exc

    # wbits = 16 + MAX_WBITS tells zlib to expect a gzip header and trailer.
    inflater = zlib.decompressobj(wbits=16 + zlib.MAX_WBITS)
    text_decoder = codecs.getincrementaldecoder("utf-8")()
    tail = ""
    line_number = 0

    try:
        for chunk in response.stream(min_chunk_size=1024 * 1024):
            text = text_decoder.decode(inflater.decompress(chunk), final=False)
            if not text:
                continue

            # The last piece is an unfinished line; keep it for the next chunk.
            *complete_lines, tail = (tail + text).split("\n")
            for raw_line in complete_lines:
                line_number += 1
                if not raw_line.strip():
                    continue
                record = _parse_record_line(
                    raw_line,
                    source=shard_uri,
                    line_number=line_number,
                )
                if record is not None:
                    emitted += 1
                    yield record

        tail += text_decoder.decode(inflater.flush(), final=True)
    except (zlib.error, UnicodeDecodeError) as exc:
        raise RuntimeError(
            f"Unreadable gzip-compressed UTF-8 JSONL shard {shard_uri}"
        ) from exc

    stream_is_complete = inflater.eof and not inflater.unused_data
    if not stream_is_complete:
        raise RuntimeError(
            "Invalid gzip-compressed UTF-8 JSONL shard "
            f"{shard_uri}: eof={inflater.eof} "
            f"unused_data={bool(inflater.unused_data)}"
        )

    # A final line without a trailing newline is still a record.
    if tail.strip():
        line_number += 1
        record = _parse_record_line(
            tail,
            source=shard_uri,
            line_number=line_number,
        )
        if record is not None:
            emitted += 1
            yield record

    LOGGER.info(
        "Worker %s finished shard %s with %s/%s yielded records",
        worker_label,
        shard_uri,
        emitted,
        shard.records,
    )
118
+
119
+
120
+ def _parse_record_line(
121
+ line: str,
122
+ *,
123
+ source: str,
124
+ line_number: int,
125
+ ) -> Mapping[str, Any] | None:
126
+ try:
127
+ raw_record = json.loads(line)
128
+ except json.JSONDecodeError as exc:
129
+ LOGGER.error(
130
+ "Skipping invalid JSON record in %s:%s: %s", source, line_number, exc
131
+ )
132
+ return None
133
+
134
+ if not isinstance(raw_record, Mapping):
135
+ LOGGER.error("Skipping non-object JSON record in %s:%s", source, line_number)
136
+ return None
137
+
138
+ return raw_record
139
+
140
+
141
+ __all__ = ["iter_shard", "records_to_body"]
@@ -0,0 +1,160 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from collections.abc import Callable, Iterator, Mapping, Sequence
5
+ from dataclasses import dataclass
6
+ from typing import Any, Generic, TypeVar, cast, overload
7
+
8
+ from torch.utils.data import DataLoader, IterableDataset, get_worker_info
9
+ from typing_extensions import Literal
10
+
11
+ from syvain_training_data import gzjsonl
12
+ from syvain_training_data.manifest import (
13
+ DataManifest,
14
+ DataManifestDataFormat,
15
+ ManifestShard,
16
+ curriculum_stage_entry,
17
+ manifest_split_entry,
18
+ )
19
+ from syvain_training_data.storage import StorageConfig
20
+
21
+ LOGGER = logging.getLogger(__name__)
22
+
23
+ BatchT = TypeVar("BatchT")
24
+
25
+
26
@dataclass(frozen=True)
class SplitDataLoader(Generic[BatchT]):
    """Builds torch ``DataLoader`` objects for the splits of one manifest.

    Attributes are the parsed manifest, the storage settings used to read
    shards, the collate function applied to each batch of records, and extra
    keyword arguments forwarded to ``DataLoader``.
    """

    manifest: DataManifest
    storage_config: StorageConfig
    collate_fn: Callable[[Sequence[Mapping[str, Any]]], BatchT]
    dataloader_args: Mapping[str, Any]

    # The overloads map the ``infinite_iter`` flag to the return type. Only
    # the ``Literal[False]`` form may carry a default: the implementation
    # defaults to ``False``, so a call that omits the flag must resolve to the
    # ``DataLoader`` overload and nothing else.
    @overload
    def load(
        self,
        split: str,
        curriculum_stage: str | None = None,
        infinite_iter: Literal[False] = False,
    ) -> DataLoader[BatchT]: ...

    @overload
    def load(
        self,
        split: str,
        curriculum_stage: str | None = None,
        *,
        infinite_iter: Literal[True],
    ) -> Iterator[BatchT]: ...

    @overload
    def load(
        self,
        split: str,
        curriculum_stage: str | None,
        infinite_iter: Literal[True],
    ) -> Iterator[BatchT]: ...

    @overload
    def load(
        self,
        split: str,
        curriculum_stage: str | None = None,
        infinite_iter: bool = False,
    ) -> DataLoader[BatchT] | Iterator[BatchT]: ...

    def load(
        self,
        split: str,
        curriculum_stage: str | None = None,
        infinite_iter: bool = False,
    ) -> DataLoader[BatchT] | Iterator[BatchT]:
        """Create a loader over one split, optionally limited to a curriculum stage.

        Args:
            split: Name of the split to look up in the manifest.
            curriculum_stage: When given, only the shards of that stage of the
                split are read.
            infinite_iter: When true, return an iterator that restarts the
                ``DataLoader`` each time it is exhausted instead of the
                ``DataLoader`` itself.

        Raises:
            ValueError: if the split or stage is unknown, if ``collate_fn`` is
                present in ``dataloader_args``, or if ``num_workers`` is
                negative.
        """
        split_entry = manifest_split_entry(self.manifest, split)
        if curriculum_stage is not None:
            split_entry = curriculum_stage_entry(split_entry, curriculum_stage)

        dataset = _ManifestShardDataset(
            shards=split_entry.shards,
            data_format=self.manifest.data_format,
            storage_config=self.storage_config,
        )
        # Work on a copy: ``_cap_worker_args`` edits the dict it is given.
        loader_kwargs = dict(self.dataloader_args)
        if "collate_fn" in loader_kwargs:
            raise ValueError(
                "Pass collate_fn to split_data_loader, not dataloader_args"
            )
        loader_kwargs["collate_fn"] = self.collate_fn
        loader_kwargs = _cap_worker_args(
            loader_kwargs,
            shard_count=dataset.shard_count,
            split=split,
        )

        loader = cast(DataLoader[BatchT], DataLoader(dataset, **loader_kwargs))
        if not infinite_iter:
            return loader
        return _iter_dataloader_forever(loader, split=split)
80
+
81
+
82
class _ManifestShardDataset(IterableDataset[Mapping[str, Any]]):
    """Iterable dataset that yields the records of a fixed set of manifest shards."""

    def __init__(
        self,
        *,
        shards: Sequence[ManifestShard],
        data_format: DataManifestDataFormat,
        storage_config: StorageConfig,
    ) -> None:
        """Keep an immutable copy of ``shards`` plus the format and storage settings."""
        self._storage_config = storage_config
        self._data_format = data_format
        self._shards = tuple(shards)

    @property
    def shard_count(self) -> int:
        """Number of shards this dataset was built with."""
        return len(self._shards)

    def __iter__(self) -> Iterator[Mapping[str, Any]]:
        """Yield every record of the shards selected by ``_worker_shards``.

        Raises:
            ValueError: when a shard is about to be read and the data format
                is not ``"jsonl.gz"``.
        """
        for shard in _worker_shards(self._shards):
            if self._data_format != "jsonl.gz":
                raise ValueError(f"Unknown data format: {self._data_format}")
            yield from gzjsonl.iter_shard(
                shard,
                storage_config=self._storage_config,
            )
107
+
108
+
109
+ def _cap_worker_args(
110
+ loader_kwargs: dict[str, Any],
111
+ *,
112
+ shard_count: int,
113
+ split: str,
114
+ ) -> dict[str, Any]:
115
+ requested_workers = int(loader_kwargs.get("num_workers", 0))
116
+ if requested_workers < 0:
117
+ raise ValueError("dataloader_args['num_workers'] must be non-negative")
118
+
119
+ num_workers = min(requested_workers, shard_count)
120
+ loader_kwargs["num_workers"] = num_workers
121
+ if num_workers < requested_workers:
122
+ LOGGER.debug(
123
+ "Capping %s dataloader workers from %s to %s because the split has %s shards",
124
+ split,
125
+ requested_workers,
126
+ num_workers,
127
+ shard_count,
128
+ )
129
+
130
+ if num_workers == 0:
131
+ loader_kwargs.pop("persistent_workers", None)
132
+ loader_kwargs.pop("prefetch_factor", None)
133
+ loader_kwargs.pop("multiprocessing_context", None)
134
+ if int(loader_kwargs.get("timeout", 0)) > 0:
135
+ loader_kwargs.pop("timeout", None)
136
+ return loader_kwargs
137
+
138
+
139
+ def _worker_shards(shards: Sequence[ManifestShard]) -> Sequence[ManifestShard]:
140
+ worker = get_worker_info()
141
+ if worker is None:
142
+ return shards
143
+ return shards[worker.id :: worker.num_workers]
144
+
145
+
146
+ def _iter_dataloader_forever(
147
+ loader: DataLoader[BatchT],
148
+ *,
149
+ split: str,
150
+ ) -> Iterator[BatchT]:
151
+ while True:
152
+ yielded_batch = False
153
+ for batch in loader:
154
+ yielded_batch = True
155
+ yield batch
156
+ if not yielded_batch:
157
+ raise RuntimeError(f"{split} split produced no batches")
158
+
159
+
160
+ __all__ = ["SplitDataLoader"]
@@ -0,0 +1,198 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from typing import Literal, Mapping
5
+ from urllib.parse import urlparse
6
+
7
+ from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
8
+ from typing_extensions import TypeAliasType
9
+
10
+ JsonValue = TypeAliasType(
11
+ "JsonValue",
12
+ str | int | float | bool | None | list["JsonValue"] | dict[str, "JsonValue"],
13
+ )
14
+ JsonObject = TypeAliasType("JsonObject", dict[str, JsonValue])
15
+ DataManifestDataFormat = TypeAliasType("DataManifestDataFormat", Literal["jsonl.gz"])
16
+
17
+ MANIFEST_VERSION = 1
18
+
19
+ MODEL_CONFIG = ConfigDict(extra="forbid", frozen=True, strict=True)
20
+
21
+
22
class ManifestShard(BaseModel):
    """Manifest entry describing one stored shard object."""

    model_config = MODEL_CONFIG

    # Location of the shard; must be a non-empty s3://bucket/key URI.
    uri: str = Field(min_length=1)
    # Number of records in the shard; must be positive.
    records: int = Field(gt=0)
    # Size of the shard in bytes; must be positive.
    size_bytes: int = Field(gt=0)
    # ETag of the shard; defaults to an empty string.
    etag: str = ""

    @field_validator("uri")
    @classmethod
    def validate_uri(cls, value: str) -> str:
        """Reject values that lack the ``s3`` scheme, a bucket, or a key."""
        parts = urlparse(value)
        has_key = bool(parts.path.lstrip("/"))
        is_s3_object = parts.scheme == "s3" and bool(parts.netloc) and has_key
        if not is_s3_object:
            raise ValueError(f"Expected non-empty s3://bucket/key URI, got: {value}")
        return value
37
+
38
+
39
class ManifestCurriculumStage(BaseModel):
    """Shards and record total for one curriculum stage of a split."""

    model_config = MODEL_CONFIG

    # Declared record total; checked against the shard sum after init.
    records: int = Field(gt=0)
    # At least one shard is required.
    shards: list[ManifestShard] = Field(min_length=1)

    @field_validator("shards")
    @classmethod
    def validate_shard_records(cls, value: list[ManifestShard]) -> list[ManifestShard]:
        """Require the shards to hold a positive number of records in total."""
        total = sum(shard.records for shard in value)
        if total <= 0:
            raise ValueError("stage shards must contain records")
        return value

    def model_post_init(self, __context: object) -> None:
        """Verify that ``records`` equals the sum of the shard record counts."""
        total = sum(shard.records for shard in self.shards)
        if self.records == total:
            return
        raise ValueError(
            f"stage records {self.records} do not match shard records "
            f"{total}"
        )
59
+
60
+
61
class ManifestDatasetEntry(BaseModel):
    """Manifest entry for one split: its shards plus optional curriculum stages."""

    model_config = MODEL_CONFIG

    # Declared record total; checked against the shard sum after init.
    records: int = Field(gt=0)
    # At least one shard is required.
    shards: list[ManifestShard] = Field(min_length=1)
    # Curriculum stages keyed by stage name; empty by default.
    curriculum: dict[str, ManifestCurriculumStage] = Field(default_factory=dict)

    @field_validator("curriculum")
    @classmethod
    def validate_curriculum(
        cls, value: dict[str, ManifestCurriculumStage]
    ) -> dict[str, ManifestCurriculumStage]:
        """Reject curriculum mappings that contain an empty stage name."""
        if any(not stage_name for stage_name in value):
            raise ValueError("curriculum stage names must not be empty")
        return value

    def model_post_init(self, __context: object) -> None:
        """Verify that ``records`` equals the sum of the shard record counts."""
        total = sum(shard.records for shard in self.shards)
        if self.records == total:
            return
        raise ValueError(
            f"split records {self.records} do not match shard records "
            f"{total}"
        )
85
+
86
+
87
class DataManifest(BaseModel):
    """Top-level manifest: format version, data format and the named splits."""

    model_config = MODEL_CONFIG

    # Only manifest version 1 is accepted.
    version: Literal[1]
    data_format: DataManifestDataFormat
    # Split entries keyed by split name; at least one split is required.
    splits: dict[str, ManifestDatasetEntry] = Field(min_length=1)

    @field_validator("splits")
    @classmethod
    def validate_splits(
        cls, value: dict[str, ManifestDatasetEntry]
    ) -> dict[str, ManifestDatasetEntry]:
        """Reject split mappings that contain an empty split name."""
        if any(not split_name for split_name in value):
            raise ValueError("split names must not be empty")
        return value
103
+
104
+
105
def manifest_split_entry(manifest: DataManifest, split: str) -> ManifestDatasetEntry:
    """Return the manifest entry for ``split``.

    Raises:
        ValueError: if the split is not in the manifest; the message lists
            the available split names in sorted order.
    """
    known_splits = manifest.splits
    try:
        entry = known_splits[split]
    except KeyError as exc:
        listing = ", ".join(sorted(known_splits))
        raise ValueError(
            f"Unknown split {split!r}; available splits: {listing}"
        ) from exc
    return entry
113
+
114
+
115
def curriculum_stage_entry(
    split_entry: ManifestDatasetEntry,
    stage_name: str,
) -> ManifestCurriculumStage:
    """Return the curriculum stage ``stage_name`` of a split entry.

    Raises:
        ValueError: if the stage is not defined for the split; the message
            lists the available stage names in sorted order.
    """
    known_stages = split_entry.curriculum
    try:
        stage = known_stages[stage_name]
    except KeyError as exc:
        listing = ", ".join(sorted(known_stages))
        raise ValueError(
            f"Unknown curriculum stage {stage_name!r}; available stages: {listing}"
        ) from exc
    return stage
126
+
127
+
128
def parse_manifest_json(body: bytes, *, source: str) -> DataManifest:
    """Decode, parse and validate a manifest document.

    ``source`` is only used to label error messages.

    Raises:
        ValueError: if ``body`` is not UTF-8, not JSON, or does not validate
            as a ``DataManifest``. The underlying error is chained as the
            cause.
    """
    try:
        text = body.decode("utf-8")
    except UnicodeDecodeError as exc:
        raise ValueError(f"Data manifest is not valid UTF-8: {source}") from exc

    try:
        document = json.loads(text)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Data manifest is not valid JSON: {source}") from exc

    try:
        return DataManifest.model_validate(document)
    except ValidationError as exc:
        raise ValueError(f"Data manifest has invalid shape: {source}") from exc
138
+
139
+
140
def manifest_json_bytes(manifest: DataManifest | Mapping[str, object]) -> bytes:
    """Render a manifest as indented UTF-8 JSON ending with a newline.

    A plain mapping is validated into a ``DataManifest`` first; an existing
    ``DataManifest`` is used as is. ``None`` values are left out of the
    output.
    """
    if isinstance(manifest, DataManifest):
        validated = manifest
    else:
        validated = DataManifest.model_validate(manifest)
    rendered = validated.model_dump_json(indent=2, exclude_none=True)
    return f"{rendered}\n".encode("utf-8")
149
+
150
+
151
def dataset_entry_from_shards(
    shards: list[ManifestShard],
    *,
    curriculum: Mapping[str, ManifestCurriculumStage] | None = None,
) -> ManifestDatasetEntry:
    """Build a split entry whose record total is the sum over ``shards``.

    ``curriculum`` is copied into a new dict; ``None`` means no stages.
    """
    total_records = sum(shard.records for shard in shards)
    if curriculum is None:
        stages: dict[str, ManifestCurriculumStage] = {}
    else:
        stages = dict(curriculum)
    return ManifestDatasetEntry(
        records=total_records,
        shards=shards,
        curriculum=stages,
    )
161
+
162
+
163
def curriculum_stage_from_shards(
    shards: list[ManifestShard],
) -> ManifestCurriculumStage:
    """Build a curriculum stage whose record total is the sum over ``shards``."""
    total_records = sum(shard.records for shard in shards)
    return ManifestCurriculumStage(records=total_records, shards=shards)
170
+
171
+
172
def manifest_from_splits(
    splits: Mapping[str, ManifestDatasetEntry],
    *,
    data_format: DataManifestDataFormat = "jsonl.gz",
) -> DataManifest:
    """Wrap split entries in a ``DataManifest`` stamped with ``MANIFEST_VERSION``.

    ``splits`` is copied into a new dict before validation.
    """
    split_entries = dict(splits)
    return DataManifest(
        version=MANIFEST_VERSION,
        data_format=data_format,
        splits=split_entries,
    )
180
+
181
+
182
+ __all__ = [
183
+ "DataManifest",
184
+ "DataManifestDataFormat",
185
+ "JsonObject",
186
+ "JsonValue",
187
+ "MANIFEST_VERSION",
188
+ "ManifestCurriculumStage",
189
+ "ManifestDatasetEntry",
190
+ "ManifestShard",
191
+ "curriculum_stage_entry",
192
+ "curriculum_stage_from_shards",
193
+ "dataset_entry_from_shards",
194
+ "manifest_from_splits",
195
+ "manifest_json_bytes",
196
+ "manifest_split_entry",
197
+ "parse_manifest_json",
198
+ ]
@@ -0,0 +1,191 @@
1
+ from __future__ import annotations
2
+
3
+ import multiprocessing
4
+ import uuid
5
+ from collections.abc import Callable, Iterable, Mapping
6
+ from typing import Any, TypedDict, cast
7
+
8
+ from syvain_training_data import gzjsonl
9
+ from syvain_training_data.manifest import (
10
+ DataManifest,
11
+ DataManifestDataFormat,
12
+ ManifestCurriculumStage,
13
+ ManifestDatasetEntry,
14
+ ManifestShard,
15
+ curriculum_stage_from_shards,
16
+ manifest_from_splits,
17
+ )
18
+ from syvain_training_data.storage import (
19
+ StorageConfig,
20
+ join_s3_uri,
21
+ parse_s3_uri,
22
+ put_object_bytes,
23
+ safe_path_segment,
24
+ )
25
+
26
+
27
+ class _SavedShard(TypedDict):
28
+ split: str
29
+ curriculum_stage: str | None
30
+ shard: dict[str, object]
31
+
32
+
33
+ _RecordsToBody = Callable[[Iterable[Mapping[str, Any]]], tuple[bytes, int]]
34
+
35
+
36
class DatasetSaver:
    """Uploads record shards next to a manifest URI and builds the manifest.

    Saved shards are tracked in a ``multiprocessing.Manager`` list guarded by
    a manager lock. ``__getstate__`` drops the manager itself from the
    pickled state while keeping the list and lock proxies.
    """

    def __init__(
        self,
        *,
        storage_config: StorageConfig,
        manifest_uri: str,
        data_format: DataManifestDataFormat = "jsonl.gz",
    ) -> None:
        """Validate the target URI and format, then create the shared shard list.

        Raises:
            ValueError: if ``manifest_uri`` is blank or not a manifest file
                URI, or if ``data_format`` is not ``"jsonl.gz"``.
        """
        cleaned_uri = manifest_uri.strip()
        if not cleaned_uri:
            raise ValueError("manifest_uri must be provided")

        self._storage_config = storage_config
        self._manifest_uri = cleaned_uri
        self._shards_base_uri = _manifest_shards_base_uri(cleaned_uri)
        self._data_format = data_format
        if data_format != "jsonl.gz":
            raise ValueError(f"Unknown data format: {data_format}")
        self._records_to_body: _RecordsToBody = gzjsonl.records_to_body
        self._manager = multiprocessing.Manager()
        self._saved_shards = self._manager.list()
        self._lock = self._manager.RLock()

    def __getstate__(self) -> dict[str, Any]:
        """Return the instance state for pickling with the manager set to ``None``."""
        state = dict(self.__dict__)
        state["_manager"] = None
        return state

    @property
    def manifest_uri(self) -> str:
        """The stripped manifest URI this saver commits to."""
        return self._manifest_uri

    def save(
        self,
        split: str,
        curriculum_stage: str | None,
        records: Iterable[Mapping[str, Any]],
    ) -> ManifestShard:
        """Serialize ``records``, upload them as a new shard and remember the shard.

        Returns:
            The ``ManifestShard`` describing the uploaded object.

        Raises:
            ValueError: if ``split`` is empty or ``records`` yields nothing.
        """
        if not split:
            raise ValueError("split must not be empty")

        payload, count = self._records_to_body(records)
        if count == 0:
            raise ValueError("Cannot save an empty record shard")

        target_uri = self._shard_uri(split=split, curriculum_stage=curriculum_stage)
        object_attributes = {
            "Content-Type": "application/gzip",
            "content-format": "jsonl",
            "compression": "gzip",
        }
        put_result = put_object_bytes(
            target_uri,
            payload,
            storage_config=self._storage_config,
            attributes=object_attributes,
        )
        # The put result is probed under both spellings of the ETag key.
        etag = put_result.get("e_tag") or put_result.get("etag") or ""
        shard = ManifestShard(
            uri=target_uri,
            records=count,
            size_bytes=len(payload),
            etag=str(etag),
        )
        entry: _SavedShard = {
            "split": split,
            "curriculum_stage": curriculum_stage,
            "shard": shard.model_dump(mode="json"),
        }
        with self._lock:
            self._saved_shards.append(entry)
        return shard

    def commit_manifest(self) -> DataManifest:
        """Build the manifest from the saved shards and upload it to ``manifest_uri``."""
        manifest = self.build_manifest()
        from syvain_training_data.manifest import manifest_json_bytes

        payload = manifest_json_bytes(manifest)
        put_object_bytes(
            self.manifest_uri,
            payload,
            storage_config=self._storage_config,
            attributes={"Content-Type": "application/json"},
        )
        return manifest

    def build_manifest(self) -> DataManifest:
        """Assemble a ``DataManifest`` from every shard saved so far.

        Each split lists all of its shards, including those saved under a
        curriculum stage; the stage-specific shards are also grouped under
        ``curriculum``. Splits are added in sorted name order.

        Raises:
            ValueError: if no shard has been saved.
        """
        with self._lock:
            snapshot = [cast(_SavedShard, item) for item in list(self._saved_shards)]
        if not snapshot:
            raise ValueError("Cannot commit a manifest with no saved shards")

        by_split: dict[str, list[_SavedShard]] = {}
        for item in snapshot:
            by_split.setdefault(item["split"], []).append(item)

        split_entries: dict[str, ManifestDatasetEntry] = {}
        for split in sorted(by_split):
            members = by_split[split]
            curriculum = _curriculum_entries(members)
            all_shards = _sorted_shards(members)
            split_entries[split] = ManifestDatasetEntry(
                records=sum(shard.records for shard in all_shards),
                shards=all_shards,
                curriculum=curriculum,
            )

        return manifest_from_splits(split_entries, data_format=self._data_format)

    def _shard_uri(self, *, split: str, curriculum_stage: str | None) -> str:
        """Build a unique shard URI under the shards base for a split and optional stage."""
        path_parts = [safe_path_segment(split)]
        file_name = f"part-{uuid.uuid4().hex}.jsonl.gz"
        if curriculum_stage is not None:
            path_parts += ["curriculum", safe_path_segment(curriculum_stage)]
        return join_s3_uri(self._shards_base_uri, *path_parts, file_name)
156
+
157
+
158
def _manifest_shards_base_uri(manifest_uri: str) -> str:
    """Return the ``shards`` prefix that sits beside the manifest object.

    Raises:
        ValueError: if the manifest key ends with ``/`` (a prefix, not a
            file), or if the URI is not a valid S3 URI.
    """
    location = parse_s3_uri(manifest_uri)
    if location.key.endswith("/"):
        raise ValueError(f"Expected manifest file URI, got: {manifest_uri}")
    parent, _, _ = location.key.rpartition("/")
    shards_key = f"{parent}/shards" if parent else "shards"
    return f"s3://{location.bucket}/{shards_key}"
166
+
167
+
168
def _curriculum_entries(items: list[_SavedShard]) -> dict[str, ManifestCurriculumStage]:
    """Group saved shards by curriculum stage and build one stage entry each.

    Items without a stage are ignored. Stages appear in sorted name order.
    """
    by_stage: dict[str, list[_SavedShard]] = {}
    for item in items:
        stage = item["curriculum_stage"]
        if stage is not None:
            by_stage.setdefault(stage, []).append(item)
    return {
        stage: curriculum_stage_from_shards(_sorted_shards(by_stage[stage]))
        for stage in sorted(by_stage)
    }
184
+
185
+
186
def _sorted_shards(items: Iterable[_SavedShard]) -> list[ManifestShard]:
    """Rebuild ``ManifestShard`` models from saved records, ordered by URI."""
    shards = [ManifestShard.model_validate(item["shard"]) for item in items]
    shards.sort(key=lambda shard: shard.uri)
    return shards
189
+
190
+
191
+ __all__ = ["DatasetSaver"]
@@ -0,0 +1,126 @@
1
+ from __future__ import annotations
2
+
3
+ import builtins
4
+ from collections.abc import Mapping
5
+ from dataclasses import dataclass
6
+ from typing import Any
7
+ from urllib.parse import quote, urlparse
8
+
9
+ from obstore.store import S3Store
10
+
11
+
12
@dataclass(frozen=True)
class S3Uri:
    """Immutable bucket/key pair identifying one S3 object."""

    bucket: str
    key: str

    @property
    def uri(self) -> str:
        """The pair rendered back as ``s3://bucket/key``."""
        return "s3://{}/{}".format(self.bucket, self.key)
20
+
21
+
22
@dataclass(frozen=True)
class StorageConfig:
    """Immutable S3 connection settings used to open per-bucket stores."""

    s3_base_url: str
    region: str
    access_key_id: str
    secret_access_key: str
    virtual_hosted_style_request: bool = False

    def __post_init__(self) -> None:
        """Require every string setting to be a non-blank string."""
        for field_name in (
            "s3_base_url",
            "region",
            "access_key_id",
            "secret_access_key",
        ):
            _require_non_empty(field_name, getattr(self, field_name))

    def open_store(self, bucket: str) -> S3Store:
        """Create an ``S3Store`` for ``bucket``; ``s3_base_url`` is passed as the endpoint."""
        return S3Store(
            bucket,
            endpoint=self.s3_base_url,
            region=self.region,
            access_key_id=self.access_key_id,
            secret_access_key=self.secret_access_key,
            virtual_hosted_style_request=self.virtual_hosted_style_request,
        )
45
+
46
+
47
def parse_s3_uri(uri: str) -> S3Uri:
    """Split an ``s3://bucket/key`` URI into its bucket and key.

    The key is everything after the first ``/`` that follows the bucket,
    taken verbatim with leading slashes removed. The URI is deliberately not
    run through ``urlparse``: that would treat ``?`` and ``#`` as query and
    fragment delimiters and silently cut them (and whatever follows) out of
    the key.

    Args:
        uri: The URI to split. The scheme is matched case-insensitively and
            no whitespace is stripped.

    Returns:
        The parsed ``S3Uri``.

    Raises:
        ValueError: if the scheme is not ``s3``, or the bucket or key is
            missing, or the bucket contains ``?`` or ``#``.
    """
    scheme, separator, remainder = uri.partition("://")
    bucket, _, raw_key = remainder.partition("/")
    key = raw_key.lstrip("/")
    # "s3://bucket?x/y" has no real key; urlparse also ends the bucket at ? or #.
    bucket_is_valid = bool(bucket) and "?" not in bucket and "#" not in bucket
    if scheme.lower() != "s3" or not separator or not bucket_is_valid or not key:
        raise ValueError(f"Expected non-empty s3://bucket/key URI, got: {uri}")
    return S3Uri(bucket=bucket, key=key)
53
+
54
+
55
def join_s3_uri(base_uri: str, *parts: str) -> str:
    """Append path parts to an S3 base URI, separated by single slashes.

    Trailing slashes on the base key and surrounding whitespace or slashes on
    each part are removed. Parts that are empty, or become empty once
    cleaned (for example ``"/"`` or ``" "``), are skipped so the result never
    contains an empty ``//`` segment.

    Raises:
        ValueError: if ``base_uri`` is not a non-empty ``s3://bucket/key`` URI.
    """
    parsed = parse_s3_uri(_ensure_base_key(base_uri))
    segments = [parsed.key.rstrip("/")]
    # Falsy parts are dropped before cleaning, as before; blank results after.
    cleaned_parts = (_clean_key_part(part) for part in parts if part)
    segments.extend(cleaned for cleaned in cleaned_parts if cleaned)
    joined_key = "/".join(segments)
    return f"s3://{parsed.bucket}/{joined_key}"
62
+
63
+
64
def safe_path_segment(value: str) -> str:
    """Strip ``value`` and percent-encode it for use as one S3 key segment.

    Only ``-``, ``_``, ``.``, ``=`` and ``+`` are kept unescaped besides the
    characters ``quote`` never escapes, so ``/`` is encoded too.

    Raises:
        ValueError: if ``value`` is empty or whitespace only.
    """
    segment = value.strip()
    if segment:
        return quote(segment, safe="-_.=+")
    raise ValueError("S3 path segment must not be empty")
69
+
70
+
71
def store_for_uri(
    uri: str,
    *,
    storage_config: StorageConfig,
) -> tuple[S3Store, S3Uri]:
    """Parse ``uri`` and open a store for its bucket.

    Returns:
        A ``(store, parsed_uri)`` tuple.
    """
    location = parse_s3_uri(uri)
    return storage_config.open_store(location.bucket), location
79
+
80
+
81
def read_object_bytes(
    uri: str,
    *,
    storage_config: StorageConfig,
) -> builtins.bytes:
    """Fetch the whole object at ``uri`` and return its content as ``bytes``."""
    store, location = store_for_uri(uri, storage_config=storage_config)
    result = store.get(location.key)
    return bytes(result.bytes())
88
+
89
+
90
def put_object_bytes(
    uri: str,
    body: builtins.bytes,
    *,
    storage_config: StorageConfig,
    attributes: dict[str, str] | None = None,
) -> Mapping[str, Any]:
    """Upload ``body`` to ``uri`` with optional object attributes.

    Returns:
        Whatever the store's ``put`` call returns.
    """
    store, location = store_for_uri(uri, storage_config=storage_config)
    put_result = store.put(location.key, body, attributes=attributes)
    return put_result
99
+
100
+
101
+ def _ensure_base_key(uri: str) -> str:
102
+ parsed = urlparse(uri)
103
+ if parsed.scheme == "s3" and parsed.netloc and parsed.path.lstrip("/"):
104
+ return uri
105
+ raise ValueError(f"Expected non-empty s3://bucket/key URI, got: {uri}")
106
+
107
+
108
+ def _clean_key_part(value: str) -> str:
109
+ return value.strip().strip("/")
110
+
111
+
112
+ def _require_non_empty(field_name: str, value: str) -> None:
113
+ if not isinstance(value, str) or not value.strip():
114
+ raise ValueError(f"{field_name} must be provided")
115
+
116
+
117
+ __all__ = [
118
+ "S3Uri",
119
+ "StorageConfig",
120
+ "join_s3_uri",
121
+ "parse_s3_uri",
122
+ "put_object_bytes",
123
+ "read_object_bytes",
124
+ "safe_path_segment",
125
+ "store_for_uri",
126
+ ]