pydantic-fixturegen 1.0.0-py3-none-any.whl → 1.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_fixturegen/api/__init__.py +137 -0
- pydantic_fixturegen/api/_runtime.py +726 -0
- pydantic_fixturegen/api/models.py +73 -0
- pydantic_fixturegen/cli/__init__.py +32 -1
- pydantic_fixturegen/cli/check.py +230 -0
- pydantic_fixturegen/cli/diff.py +992 -0
- pydantic_fixturegen/cli/doctor.py +188 -35
- pydantic_fixturegen/cli/gen/_common.py +134 -7
- pydantic_fixturegen/cli/gen/explain.py +597 -40
- pydantic_fixturegen/cli/gen/fixtures.py +244 -112
- pydantic_fixturegen/cli/gen/json.py +229 -138
- pydantic_fixturegen/cli/gen/schema.py +170 -85
- pydantic_fixturegen/cli/init.py +333 -0
- pydantic_fixturegen/cli/schema.py +45 -0
- pydantic_fixturegen/cli/watch.py +126 -0
- pydantic_fixturegen/core/config.py +137 -3
- pydantic_fixturegen/core/config_schema.py +178 -0
- pydantic_fixturegen/core/constraint_report.py +305 -0
- pydantic_fixturegen/core/errors.py +42 -0
- pydantic_fixturegen/core/field_policies.py +100 -0
- pydantic_fixturegen/core/generate.py +241 -37
- pydantic_fixturegen/core/io_utils.py +10 -2
- pydantic_fixturegen/core/path_template.py +197 -0
- pydantic_fixturegen/core/presets.py +73 -0
- pydantic_fixturegen/core/providers/temporal.py +10 -0
- pydantic_fixturegen/core/safe_import.py +146 -12
- pydantic_fixturegen/core/seed_freeze.py +176 -0
- pydantic_fixturegen/emitters/json_out.py +65 -16
- pydantic_fixturegen/emitters/pytest_codegen.py +68 -13
- pydantic_fixturegen/emitters/schema_out.py +27 -3
- pydantic_fixturegen/logging.py +114 -0
- pydantic_fixturegen/schemas/config.schema.json +244 -0
- pydantic_fixturegen-1.1.0.dist-info/METADATA +173 -0
- pydantic_fixturegen-1.1.0.dist-info/RECORD +57 -0
- pydantic_fixturegen-1.0.0.dist-info/METADATA +0 -280
- pydantic_fixturegen-1.0.0.dist-info/RECORD +0 -41
- {pydantic_fixturegen-1.0.0.dist-info → pydantic_fixturegen-1.1.0.dist-info}/WHEEL +0 -0
- {pydantic_fixturegen-1.0.0.dist-info → pydantic_fixturegen-1.1.0.dist-info}/entry_points.txt +0 -0
- {pydantic_fixturegen-1.0.0.dist-info → pydantic_fixturegen-1.1.0.dist-info}/licenses/LICENSE +0 -0
pydantic_fixturegen/core/providers/temporal.py

```diff
@@ -15,10 +15,20 @@ def generate_temporal(
     summary: FieldSummary,
     *,
     faker: Faker | None = None,
+    time_anchor: datetime.datetime | None = None,
 ) -> Any:
     faker = faker or Faker()
     type_name = summary.type

+    if time_anchor is not None:
+        anchor = time_anchor
+        if type_name == "datetime":
+            return anchor
+        if type_name == "date":
+            return anchor.date()
+        if type_name == "time":
+            return anchor.timetz() if anchor.tzinfo else anchor.time()
+
     if type_name == "datetime":
         return faker.date_time(tzinfo=datetime.timezone.utc)
     if type_name == "date":
```
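The new `time_anchor` parameter makes temporal values reproducible: when an anchor is supplied, datetime, date, and time fields are all derived from that single anchor instead of being drawn from Faker. A minimal standalone sketch of the anchoring behaviour shown above (the helper below is illustrative, not the packaged API):

```python
import datetime

# Fixed anchor: every temporal field type is derived from this one datetime.
anchor = datetime.datetime(2024, 1, 1, 12, 0, tzinfo=datetime.timezone.utc)

def anchored_value(type_name: str, anchor: datetime.datetime):
    # Mirrors the branch added in generate_temporal above.
    if type_name == "datetime":
        return anchor
    if type_name == "date":
        return anchor.date()
    if type_name == "time":
        # Keep tz info when the anchor is aware, as timetz() does.
        return anchor.timetz() if anchor.tzinfo else anchor.time()
    raise ValueError(f"unsupported temporal type: {type_name}")

print(anchored_value("datetime", anchor))  # 2024-01-01 12:00:00+00:00
print(anchored_value("date", anchor))      # 2024-01-01
print(anchored_value("time", anchor))      # 12:00:00+00:00
```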
pydantic_fixturegen/core/safe_import.py

```diff
@@ -45,6 +45,109 @@ class SafeImportResult:
     exit_code: int


+def _module_basename(path: Path) -> str:
+    """Return the module name portion for a Python file."""
+
+    if path.name == "__init__.py":
+        return path.parent.name or "module"
+    stem = path.stem
+    return stem if stem else "module"
+
+
+def _package_hierarchy(module_path: Path) -> list[Path]:
+    """Collect package directories (with __init__.py) from top to bottom."""
+
+    hierarchy: list[Path] = []
+    current = module_path.parent.resolve()
+    while True:
+        init_file = current / "__init__.py"
+        if not init_file.exists():
+            break
+        hierarchy.append(current)
+        parent = current.parent.resolve()
+        if parent == current:
+            break
+        current = parent
+    hierarchy.reverse()
+    return hierarchy
+
+
+def _resolve_module_name(module_path: Path, workdir: Path, index: int) -> str:
+    """Determine an importable module name for the module path."""
+
+    packages = _package_hierarchy(module_path)
+    if packages:
+        module_part = _module_basename(module_path)
+        if module_path.name == "__init__.py":
+            return ".".join(pkg.name for pkg in packages)
+        package_parts = [pkg.name for pkg in packages]
+        return ".".join(package_parts + [module_part])
+
+    try:
+        relative = module_path.relative_to(workdir)
+    except ValueError:
+        relative = None
+
+    if relative is not None:
+        parts = list(relative.parts)
+        if parts:
+            parts[-1] = _module_basename(module_path)
+        module_name = ".".join(part for part in parts if part not in ("", "."))
+        if module_name:
+            return module_name
+
+    fallback = _module_basename(module_path)
+    return fallback if index == 0 else f"{fallback}_{index}"
+
+
+def _candidate_python_paths(module_path: Path, workdir: Path) -> list[Path]:
+    """Return directories that should be added to PYTHONPATH for imports."""
+
+    candidates: list[Path] = []
+    packages = _package_hierarchy(module_path)
+    if packages:
+        highest_package = packages[0]
+        parent = highest_package.parent
+        if parent != highest_package:
+            candidates.append(parent)
+    candidates.append(module_path.parent)
+
+    if not candidates:
+        candidates.append(workdir)
+
+    return candidates
+
+
+def _build_module_entries(paths: Sequence[Path], workdir: Path) -> list[dict[str, str]]:
+    entries: list[dict[str, str]] = []
+    for index, module_path in enumerate(paths):
+        module_name = _resolve_module_name(module_path, workdir, index)
+        entries.append({"path": str(module_path), "name": module_name})
+    return entries
+
+
+def _build_pythonpath_entries(workdir: Path, paths: Sequence[Path]) -> list[Path]:
+    entries: list[Path] = []
+    seen: set[Path] = set()
+
+    def _add(path: Path) -> None:
+        resolved = path.resolve()
+        if not resolved.exists() or not resolved.is_dir():
+            return
+        if resolved in seen:
+            return
+        entries.append(resolved)
+        seen.add(resolved)
+
+    _add(workdir)
+
+    for module_path in paths:
+        for candidate in _candidate_python_paths(module_path, workdir):
+            _add(candidate)
+
+    return entries
+
+
 def safe_import_models(
     paths: Sequence[Path | str],
     *,
```
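These helpers let the sandboxed import resolve a discovered file to a proper dotted module name (walking up through `__init__.py` packages) instead of always importing it under its bare stem. A simplified, runnable sketch of that resolution rule, using a hypothetical layout created in a temp directory (this is not the packaged function itself):

```python
from pathlib import Path
import tempfile

def dotted_name(module_path: Path) -> str:
    # Simplified version of the rule above: climb while __init__.py exists,
    # then join the package directory names with the module stem.
    parts = [module_path.stem if module_path.name != "__init__.py" else ""]
    current = module_path.parent
    while (current / "__init__.py").exists():
        parts.insert(0, current.name)
        current = current.parent
    return ".".join(p for p in parts if p)

with tempfile.TemporaryDirectory() as tmp:
    pkg = Path(tmp, "src", "myapp")
    pkg.mkdir(parents=True)
    (pkg / "__init__.py").write_text("")
    models = pkg / "models.py"
    models.write_text("")
    print(dotted_name(models))  # -> "myapp.models"
```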
```diff
@@ -67,16 +170,22 @@ def safe_import_models(
     if not paths:
         return SafeImportResult(True, [], None, None, "", 0)

-    workdir = Path(cwd) if cwd else Path.cwd()
+    workdir = (Path(cwd) if cwd else Path.cwd()).resolve()
     python = python_executable or sys.executable

+    resolved_paths = [Path(path).resolve() for path in paths]
+    module_entries = _build_module_entries(resolved_paths, workdir)
+    pythonpath_entries = _build_pythonpath_entries(workdir, resolved_paths)
+
     request = {
-        "paths": [str(
+        "paths": [str(path) for path in resolved_paths],
+        "module_entries": module_entries,
+        "python_path_entries": [str(path) for path in pythonpath_entries],
         "memory_limit_mb": memory_limit_mb,
-        "workdir": str(workdir
+        "workdir": str(workdir),
     }

-    env = _build_env(workdir, extra_env)
+    env = _build_env(workdir, extra_env, pythonpath_entries)

     try:
         completed = subprocess.run(
@@ -146,10 +255,13 @@ def _safe_text(value: object) -> str:
     return value.decode("utf-8", "replace") if isinstance(value, bytes) else str(value or "")


-def _build_env(workdir: Path, extra_env: Mapping[str, str] | None) -> dict[str, str]:
+def _build_env(
+    workdir: Path,
+    extra_env: Mapping[str, str] | None,
+    pythonpath_entries: Sequence[Path],
+) -> dict[str, str]:
     base_env: dict[str, str] = {
         "PYTHONSAFEPATH": "1",
-        "PYTHONPATH": str(workdir),
         "NO_PROXY": "*",
         "no_proxy": "*",
         "http_proxy": "",
@@ -165,6 +277,9 @@ def _build_env(workdir: Path, extra_env: Mapping[str, str] | None) -> dict[str,
         "HOME": str(workdir),
     }

+    pythonpath_value = os.pathsep.join(str(entry) for entry in pythonpath_entries)
+    base_env["PYTHONPATH"] = pythonpath_value or str(workdir)
+
     allowed_passthrough = ["PATH", "SYSTEMROOT", "COMSPEC"]
     for key in allowed_passthrough:
         if key in os.environ:
```
```diff
@@ -335,8 +450,8 @@ _RUNNER_SNIPPET = textwrap.dedent(
         stem = module_path.stem or "module"
         return stem if index == 0 else f"{stem}_{index}"

-    def _load_module(module_path: Path, index: int):
-        module_name = _derive_module_name(module_path, index)
+    def _load_module(module_path: Path, index: int, explicit_name: str | None = None):
+        module_name = explicit_name or _derive_module_name(module_path, index)
         spec = importlib_util.spec_from_file_location(module_name, module_path)
         if spec is None or spec.loader is None:
             raise ImportError(f"Could not load module from {module_path}")
@@ -378,12 +493,31 @@ _RUNNER_SNIPPET = textwrap.dedent(
     _block_network()
     _restrict_filesystem(workdir)

-
+    python_path_entries = request.get("python_path_entries") or []
+    for extra in reversed(python_path_entries):
+        if not extra:
+            continue
+        extra_path = str(Path(extra))
+        if extra_path not in sys.path:
+            sys.path.insert(0, extra_path)
+
+    module_entries = request.get("module_entries") or []
+    normalized_entries = []
+    if module_entries:
+        for entry in module_entries:
+            raw_path = entry.get("path")
+            if not raw_path:
+                continue
+            module_path = Path(raw_path)
+            module_name = entry.get("name")
+            normalized_entries.append((module_path, module_name))
+    else:
+        fallback_paths = [Path(path) for path in request.get("paths", [])]
+        normalized_entries = [(path, None) for path in fallback_paths]

     collected = []
-    for idx,
-
-        module = _load_module(module_path, idx)
+    for idx, (module_path, module_name) in enumerate(normalized_entries):
+        module = _load_module(module_path, idx, module_name)
         collected.extend(_collect_models(module, module_path))

     payload = {"success": True, "models": collected}
```
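The request handed to the sandboxed runner now carries pre-resolved module names and PYTHONPATH entries alongside the raw paths. An illustrative shape of that payload (the key names come from the diff above; the values are made up) and the `sys.path` handling the runner applies to it:

```python
import sys

# Illustrative request payload; only the keys are taken from the diff above.
request = {
    "paths": ["/repo/src/myapp/models.py"],
    "module_entries": [
        {"path": "/repo/src/myapp/models.py", "name": "myapp.models"},
    ],
    "python_path_entries": ["/repo", "/repo/src"],
    "memory_limit_mb": 512,
    "workdir": "/repo",
}

# Same insertion order as the runner snippet: entries are pushed front-first
# in reverse, so the first entry above ends up earliest on sys.path.
for extra in reversed(request["python_path_entries"]):
    if extra and extra not in sys.path:
        sys.path.insert(0, extra)
```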
pydantic_fixturegen/core/seed_freeze.py (new file)

```diff
@@ -0,0 +1,176 @@
+"""Helpers for managing seed freeze files used for deterministic generation."""
+
+from __future__ import annotations
+
+import hashlib
+import json
+from dataclasses import dataclass
+from enum import Enum
+from pathlib import Path
+from typing import Any
+
+from pydantic import BaseModel
+
+from .seed import SeedManager
+
+FREEZE_FILE_BASENAME = ".pfg-seeds.json"
+FREEZE_FILE_VERSION = 1
+
+
+class FreezeStatus(str, Enum):
+    """Status classification for freeze entries when resolving seeds."""
+
+    MISSING = "missing"
+    STALE = "stale"
+    VALID = "valid"
+
+
+@dataclass(slots=True)
+class SeedRecord:
+    """Stored seed metadata for a single model."""
+
+    seed: int
+    model_digest: str | None = None
+
+    def to_payload(self) -> dict[str, Any]:
+        payload: dict[str, Any] = {"seed": self.seed}
+        if self.model_digest is not None:
+            payload["model_digest"] = self.model_digest
+        return payload
+
+
+class SeedFreezeFile:
+    """Abstraction over the freeze file storing per-model deterministic seeds."""
+
+    def __init__(self, path: Path) -> None:
+        self.path = path
+        self.exists = False
+        self._records: dict[str, SeedRecord] = {}
+        self._dirty = False
+        self.messages: list[str] = []
+
+    @property
+    def records(self) -> dict[str, SeedRecord]:
+        return self._records
+
+    @classmethod
+    def load(cls, path: Path) -> SeedFreezeFile:
+        manager = cls(path)
+        if not path.exists():
+            return manager
+
+        manager.exists = True
+        try:
+            raw = json.loads(path.read_text(encoding="utf-8"))
+        except json.JSONDecodeError as exc:
+            manager.messages.append(f"Failed to parse seed freeze file: {exc}")
+            return manager
+
+        version = raw.get("version")
+        if version != FREEZE_FILE_VERSION:
+            manager.messages.append("Seed freeze file version mismatch; ignoring entries")
+            return manager
+
+        models = raw.get("models", {})
+        if not isinstance(models, dict):
+            manager.messages.append("Seed freeze file missing 'models' mapping; ignoring entries")
+            return manager
+
+        for identifier, payload in models.items():
+            if not isinstance(payload, dict):
+                continue
+            seed = payload.get("seed")
+            if not isinstance(seed, int):
+                continue
+            record = SeedRecord(
+                seed=seed,
+                model_digest=payload.get("model_digest"),
+            )
+            manager._records[identifier] = record
+        return manager
+
+    def resolve_seed(
+        self, identifier: str, *, model_digest: str | None
+    ) -> tuple[int | None, FreezeStatus]:
+        record = self._records.get(identifier)
+        if record is None:
+            return None, FreezeStatus.MISSING
+
+        if model_digest and record.model_digest and record.model_digest != model_digest:
+            return record.seed, FreezeStatus.STALE
+
+        if model_digest and record.model_digest is None:
+            return record.seed, FreezeStatus.STALE
+
+        return record.seed, FreezeStatus.VALID
+
+    def record_seed(self, identifier: str, seed: int, *, model_digest: str | None) -> None:
+        current = self._records.get(identifier)
+        new_record = SeedRecord(seed=seed, model_digest=model_digest)
+        if (
+            current
+            and current.seed == new_record.seed
+            and current.model_digest == new_record.model_digest
+        ):
+            return
+
+        self._records[identifier] = new_record
+        self._dirty = True
+
+    def save(self) -> None:
+        if not self._dirty:
+            return
+
+        output = {
+            "version": FREEZE_FILE_VERSION,
+            "models": {
+                identifier: record.to_payload()
+                for identifier, record in sorted(self._records.items())
+            },
+        }
+
+        self.path.parent.mkdir(parents=True, exist_ok=True)
+        self.path.write_text(json.dumps(output, indent=2, sort_keys=True) + "\n", encoding="utf-8")
+        self._dirty = False
+
+
+def resolve_freeze_path(path_option: Path | None, *, root: Path | None = None) -> Path:
+    base = root or Path.cwd()
+    if path_option is None:
+        return base / FREEZE_FILE_BASENAME
+
+    candidate = Path(path_option)
+    if candidate.is_absolute():
+        return candidate
+    return base / candidate
+
+
+def model_identifier(model: type[BaseModel]) -> str:
+    return f"{model.__module__}.{model.__qualname__}"
+
+
+def compute_model_digest(model: type[BaseModel]) -> str | None:
+    try:
+        schema = model.model_json_schema()
+    except Exception:  # pragma: no cover - defensive
+        return None
+
+    serialized = json.dumps(schema, sort_keys=True, separators=(",", ":"))
+    return hashlib.sha256(serialized.encode("utf-8")).hexdigest()
+
+
+def derive_default_model_seed(base_seed: int | str | None, identifier: str) -> int:
+    manager = SeedManager(seed=base_seed)
+    return manager.derive_child_seed(identifier)
+
+
+__all__ = [
+    "FREEZE_FILE_BASENAME",
+    "FREEZE_FILE_VERSION",
+    "FreezeStatus",
+    "SeedFreezeFile",
+    "compute_model_digest",
+    "derive_default_model_seed",
+    "model_identifier",
+    "resolve_freeze_path",
+]
```
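A sketch of how these pieces are meant to fit together, inferred from the code above rather than from documented usage: load the freeze file, look up the model's seed by identifier and schema digest, and derive-and-record a seed when the entry is missing or stale.

```python
from pathlib import Path

from pydantic import BaseModel

from pydantic_fixturegen.core.seed_freeze import (
    FreezeStatus,
    SeedFreezeFile,
    compute_model_digest,
    derive_default_model_seed,
    model_identifier,
    resolve_freeze_path,
)


class User(BaseModel):
    name: str


freeze = SeedFreezeFile.load(resolve_freeze_path(None, root=Path.cwd()))
ident = model_identifier(User)        # e.g. "__main__.User" when run as a script
digest = compute_model_digest(User)   # sha256 of the model's JSON schema

seed, status = freeze.resolve_seed(ident, model_digest=digest)
if status is not FreezeStatus.VALID:
    # No frozen seed (or a stale one): derive a fresh seed and persist it.
    seed = derive_default_model_seed(None, ident)
    freeze.record_seed(ident, seed, model_digest=digest)
    freeze.save()
```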
pydantic_fixturegen/emitters/json_out.py

```diff
@@ -14,6 +14,8 @@ from typing import Any, cast

 from pydantic import BaseModel

+from pydantic_fixturegen.core.path_template import OutputTemplate, OutputTemplateContext
+
 orjson: ModuleType | None
 try:  # Optional dependency
     import orjson as _orjson
@@ -32,6 +34,8 @@ class JsonEmitConfig:

     output_path: Path
     count: int
+    template: OutputTemplate | None = None
+    template_context: OutputTemplateContext | None = None
     jsonl: bool = False
     indent: int | None = DEFAULT_INDENT
     shard_size: int | None = None
@@ -51,6 +55,8 @@ def emit_json_samples(
     use_orjson: bool = False,
     ensure_ascii: bool = False,
     max_workers: int | None = None,
+    template: OutputTemplate | None = None,
+    template_context: OutputTemplateContext | None = None,
 ) -> list[Path]:
     """Emit generated samples to JSON or JSONL files.

@@ -73,9 +79,20 @@ def emit_json_samples(
         List of ``Path`` objects for the created file(s), ordered by shard index.
     """

+    template_obj = template or OutputTemplate(output_path)
+    context = template_context or OutputTemplateContext()
+
+    if template_obj.fields:
+        initial_index = 1 if template_obj.uses_case_index() else None
+        resolved_path = template_obj.render(context=context, case_index=initial_index)
+    else:
+        resolved_path = template_obj.render(context=context)
+
     config = JsonEmitConfig(
-        output_path=
+        output_path=resolved_path,
         count=count,
+        template=template_obj if template_obj.fields else None,
+        template_context=context,
         jsonl=jsonl,
         indent=_normalise_indent(indent, jsonl=jsonl),
         shard_size=_normalise_shard_size(shard_size, count),
```
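The emitter now resolves its concrete output path through an `OutputTemplate` (with an `OutputTemplateContext`) instead of using `output_path` verbatim: templated paths with a case-index field start at index 1, plain paths render unchanged. A runnable mirror of that branch using stand-in classes (these are not the real classes from `core/path_template.py`, which expose at least `.fields`, `.uses_case_index()`, and `.render(context=..., case_index=...)`; the `{case_index}` placeholder spelling is a guess):

```python
from pathlib import Path

class FakeContext:
    """Stand-in for OutputTemplateContext."""

class FakeTemplate:
    """Stand-in for OutputTemplate with the attributes used by the diff."""

    def __init__(self, path: Path, fields: list[str]):
        self.path, self.fields = path, fields

    def uses_case_index(self) -> bool:
        return "case_index" in self.fields

    def render(self, *, context, case_index=None) -> Path:
        name = self.path.name.replace("{case_index}", str(case_index or 1))
        return self.path.with_name(name)

template_obj = FakeTemplate(Path("out/users_{case_index}.json"), ["case_index"])
context = FakeContext()

# Same branch structure as emit_json_samples above.
if template_obj.fields:
    initial_index = 1 if template_obj.uses_case_index() else None
    resolved_path = template_obj.render(context=context, case_index=initial_index)
else:
    resolved_path = template_obj.render(context=context)

print(resolved_path)  # out/users_1.json
```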
```diff
@@ -95,13 +112,13 @@ def emit_json_samples(
     if config.jsonl:
         path = _stream_jsonl(
             samples_iter,
-            config
+            _resolve_base_path(config, index=1),
             encoder,
         )
     else:
         path = _stream_json_array(
             samples_iter,
-            config
+            _resolve_base_path(config, index=1),
             encoder,
             indent=config.indent,
         )
@@ -167,7 +184,7 @@ def _write_empty_shard(
     path = _shard_path(base_path, 1, 1, jsonl)
     empty_payload = "" if jsonl else encoder.encode([])
     path.parent.mkdir(parents=True, exist_ok=True)
-    path.write_text(empty_payload, encoding="utf-8")
+    path.write_text(empty_payload, encoding="utf-8", newline="\n")
     return path


@@ -179,14 +196,19 @@ def _prepare_payload(
     workers: int,
 ) -> str:
     if not jsonl:
-
+        payload = encoder.encode(list(chunk))
+        if payload and not payload.endswith("\n"):
+            payload += "\n"
+        return payload

     if workers <= 1:
         lines = [encoder.encode(item) for item in chunk]
     else:
         with ThreadPoolExecutor(max_workers=workers) as executor:
             lines = list(executor.map(encoder.encode, chunk))
-
+    if not lines:
+        return ""
+    return "\n".join(lines) + "\n"


 def _stream_jsonl(
@@ -196,7 +218,7 @@ def _stream_jsonl(
 ) -> Path:
     path = _ensure_suffix(base_path, ".jsonl")
     path.parent.mkdir(parents=True, exist_ok=True)
-    with path.open("w", encoding="utf-8") as stream:
+    with path.open("w", encoding="utf-8", newline="\n") as stream:
         for record in iterator:
             stream.write(encoder.encode(record))
             stream.write("\n")
@@ -214,7 +236,7 @@ def _stream_json_array(
     path.parent.mkdir(parents=True, exist_ok=True)

     if indent is None:
-        with path.open("w", encoding="utf-8") as stream:
+        with path.open("w", encoding="utf-8", newline="\n") as stream:
             first = True
             stream.write("[")
             for record in iterator:
@@ -222,11 +244,11 @@ def _stream_json_array(
                     stream.write(",")
                 stream.write(encoder.encode(record))
                 first = False
-            stream.write("]")
+            stream.write("]\n")
         return path

     spacing = " " * indent
-    with path.open("w", encoding="utf-8") as stream:
+    with path.open("w", encoding="utf-8", newline="\n") as stream:
         written = False
         for record in iterator:
             encoded = encoder.encode(record)
@@ -237,9 +259,9 @@ def _stream_json_array(
             stream.write(f"{spacing}{encoded}")
             written = True
         if not written:
-            stream.write("[]")
+            stream.write("[]\n")
         else:
-            stream.write("\n]")
+            stream.write("\n]\n")
     return path


@@ -253,7 +275,13 @@ def _write_chunked_samples(

     chunk = list(islice(iterator, chunk_size))
     if not chunk:
-        results.append(
+        results.append(
+            _write_empty_shard(
+                _resolve_base_path(config, index=1),
+                config.jsonl,
+                encoder,
+            )
+        )
         return results

     index = 1
@@ -261,7 +289,7 @@ def _write_chunked_samples(
         next_chunk = list(islice(iterator, chunk_size))
         is_last = not next_chunk
         path = _chunk_path(
-            config
+            config,
             index=index,
             is_last=is_last,
             jsonl=config.jsonl,
@@ -273,7 +301,7 @@ def _write_chunked_samples(
             workers=_worker_count(config.max_workers, len(chunk)),
         )
         path.parent.mkdir(parents=True, exist_ok=True)
-        path.write_text(payload, encoding="utf-8")
+        path.write_text(payload, encoding="utf-8", newline="\n")
         results.append(path)

         chunk = next_chunk
@@ -283,13 +311,24 @@


 def _chunk_path(
-
+    config: JsonEmitConfig,
     *,
     index: int,
     is_last: bool,
     jsonl: bool,
 ) -> Path:
+    template = config.template
+    if template is not None:
+        base_path = template.render(
+            context=config.template_context,
+            case_index=index if template.uses_case_index() else None,
+        )
+    else:
+        base_path = config.output_path
+
     suffix = ".jsonl" if jsonl else ".json"
+    if template is not None and template.uses_case_index():
+        return _ensure_suffix(base_path, suffix)
     if is_last and index == 1:
         return _ensure_suffix(base_path, suffix)

@@ -297,6 +336,16 @@ def _chunk_path(
     return _shard_path(base_path, index, shard_total, jsonl)


+def _resolve_base_path(config: JsonEmitConfig, *, index: int) -> Path:
+    template = config.template
+    if template is None:
+        return config.output_path
+    return template.render(
+        context=config.template_context,
+        case_index=index if template.uses_case_index() else None,
+    )
+
+
 def _shard_path(base_path: Path, shard_index: int, shard_count: int, jsonl: bool) -> Path:
     suffix = ".jsonl" if jsonl else ".json"
     if shard_count <= 1:
```