pydantic-fixturegen 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-fixturegen might be problematic. Click here for more details.
- pydantic_fixturegen/api/__init__.py +137 -0
- pydantic_fixturegen/api/_runtime.py +726 -0
- pydantic_fixturegen/api/models.py +73 -0
- pydantic_fixturegen/cli/__init__.py +32 -1
- pydantic_fixturegen/cli/check.py +230 -0
- pydantic_fixturegen/cli/diff.py +992 -0
- pydantic_fixturegen/cli/doctor.py +188 -35
- pydantic_fixturegen/cli/gen/_common.py +134 -7
- pydantic_fixturegen/cli/gen/explain.py +597 -40
- pydantic_fixturegen/cli/gen/fixtures.py +244 -112
- pydantic_fixturegen/cli/gen/json.py +229 -138
- pydantic_fixturegen/cli/gen/schema.py +170 -85
- pydantic_fixturegen/cli/init.py +333 -0
- pydantic_fixturegen/cli/schema.py +45 -0
- pydantic_fixturegen/cli/watch.py +126 -0
- pydantic_fixturegen/core/config.py +137 -3
- pydantic_fixturegen/core/config_schema.py +178 -0
- pydantic_fixturegen/core/constraint_report.py +305 -0
- pydantic_fixturegen/core/errors.py +42 -0
- pydantic_fixturegen/core/field_policies.py +100 -0
- pydantic_fixturegen/core/generate.py +241 -37
- pydantic_fixturegen/core/io_utils.py +10 -2
- pydantic_fixturegen/core/path_template.py +197 -0
- pydantic_fixturegen/core/presets.py +73 -0
- pydantic_fixturegen/core/providers/temporal.py +10 -0
- pydantic_fixturegen/core/safe_import.py +146 -12
- pydantic_fixturegen/core/seed_freeze.py +176 -0
- pydantic_fixturegen/emitters/json_out.py +65 -16
- pydantic_fixturegen/emitters/pytest_codegen.py +68 -13
- pydantic_fixturegen/emitters/schema_out.py +27 -3
- pydantic_fixturegen/logging.py +114 -0
- pydantic_fixturegen/schemas/config.schema.json +244 -0
- pydantic_fixturegen-1.1.0.dist-info/METADATA +173 -0
- pydantic_fixturegen-1.1.0.dist-info/RECORD +57 -0
- pydantic_fixturegen-1.0.0.dist-info/METADATA +0 -280
- pydantic_fixturegen-1.0.0.dist-info/RECORD +0 -41
- {pydantic_fixturegen-1.0.0.dist-info → pydantic_fixturegen-1.1.0.dist-info}/WHEEL +0 -0
- {pydantic_fixturegen-1.0.0.dist-info → pydantic_fixturegen-1.1.0.dist-info}/entry_points.txt +0 -0
- {pydantic_fixturegen-1.0.0.dist-info → pydantic_fixturegen-1.1.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -3,19 +3,27 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import dataclasses
|
|
6
|
+
import datetime
|
|
6
7
|
import enum
|
|
7
8
|
import inspect
|
|
8
9
|
import random
|
|
9
|
-
from collections.abc import Iterable, Sized
|
|
10
|
+
from collections.abc import Iterable, Mapping, Sized
|
|
10
11
|
from dataclasses import dataclass, is_dataclass
|
|
11
12
|
from dataclasses import fields as dataclass_fields
|
|
12
13
|
from typing import Any, get_type_hints
|
|
13
14
|
|
|
14
15
|
from faker import Faker
|
|
15
|
-
from pydantic import BaseModel
|
|
16
|
+
from pydantic import BaseModel, ValidationError
|
|
16
17
|
from pydantic.fields import FieldInfo
|
|
17
18
|
|
|
18
19
|
from pydantic_fixturegen.core import schema as schema_module
|
|
20
|
+
from pydantic_fixturegen.core.config import ConfigError
|
|
21
|
+
from pydantic_fixturegen.core.constraint_report import ConstraintReporter
|
|
22
|
+
from pydantic_fixturegen.core.field_policies import (
|
|
23
|
+
FieldPolicy,
|
|
24
|
+
FieldPolicyConflictError,
|
|
25
|
+
FieldPolicySet,
|
|
26
|
+
)
|
|
19
27
|
from pydantic_fixturegen.core.providers import ProviderRegistry, create_default_registry
|
|
20
28
|
from pydantic_fixturegen.core.schema import FieldConstraints, FieldSummary, extract_constraints
|
|
21
29
|
from pydantic_fixturegen.core.strategies import (
|
|
@@ -36,6 +44,19 @@ class GenerationConfig:
|
|
|
36
44
|
default_p_none: float = 0.0
|
|
37
45
|
optional_p_none: float = 0.0
|
|
38
46
|
seed: int | None = None
|
|
47
|
+
time_anchor: datetime.datetime | None = None
|
|
48
|
+
field_policies: tuple[FieldPolicy, ...] = ()
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass(slots=True)
|
|
52
|
+
class _PathEntry:
|
|
53
|
+
module: str
|
|
54
|
+
qualname: str
|
|
55
|
+
via_field: str | None = None
|
|
56
|
+
|
|
57
|
+
@property
|
|
58
|
+
def full(self) -> str:
|
|
59
|
+
return f"{self.module}.{self.qualname}"
|
|
39
60
|
|
|
40
61
|
|
|
41
62
|
class InstanceGenerator:
|
|
@@ -70,6 +91,15 @@ class InstanceGenerator:
|
|
|
70
91
|
plugin_manager=self._plugin_manager,
|
|
71
92
|
)
|
|
72
93
|
self._strategy_cache: dict[type[Any], dict[str, StrategyResult]] = {}
|
|
94
|
+
self._constraint_reporter = ConstraintReporter()
|
|
95
|
+
self._field_policy_set = (
|
|
96
|
+
FieldPolicySet(self.config.field_policies) if self.config.field_policies else None
|
|
97
|
+
)
|
|
98
|
+
self._path_stack: list[_PathEntry] = []
|
|
99
|
+
|
|
100
|
+
@property
def constraint_report(self) -> ConstraintReporter:
    """Return the reporter that this generator feeds while building instances.

    The same ``ConstraintReporter`` object created in ``__init__`` is returned
    on every access; no copy is made.
    """
    return self._constraint_reporter
|
|
73
103
|
|
|
74
104
|
# ------------------------------------------------------------------ public API
|
|
75
105
|
def generate_one(self, model: type[BaseModel]) -> BaseModel | None:
|
|
@@ -86,64 +116,223 @@ class InstanceGenerator:
|
|
|
86
116
|
return results
|
|
87
117
|
|
|
88
118
|
# ------------------------------------------------------------------ internals
|
|
89
|
-
def _build_model_instance(
|
|
119
|
+
def _build_model_instance(
|
|
120
|
+
self,
|
|
121
|
+
model_type: type[Any],
|
|
122
|
+
*,
|
|
123
|
+
depth: int,
|
|
124
|
+
via_field: str | None = None,
|
|
125
|
+
) -> Any | None:
|
|
90
126
|
if depth >= self.config.max_depth:
|
|
91
127
|
return None
|
|
92
128
|
if not self._consume_object():
|
|
93
129
|
return None
|
|
94
130
|
|
|
131
|
+
entry = self._make_path_entry(model_type, via_field)
|
|
132
|
+
self._path_stack.append(entry)
|
|
95
133
|
try:
|
|
96
134
|
strategies = self._get_model_strategies(model_type)
|
|
97
135
|
except TypeError:
|
|
98
136
|
return None
|
|
99
137
|
|
|
138
|
+
self._constraint_reporter.begin_model(model_type)
|
|
139
|
+
|
|
100
140
|
values: dict[str, Any] = {}
|
|
101
|
-
|
|
102
|
-
|
|
141
|
+
try:
|
|
142
|
+
for field_name, strategy in strategies.items():
|
|
143
|
+
self._apply_field_policies(field_name, strategy)
|
|
144
|
+
values[field_name] = self._evaluate_strategy(
|
|
145
|
+
strategy,
|
|
146
|
+
depth,
|
|
147
|
+
model_type,
|
|
148
|
+
field_name,
|
|
149
|
+
)
|
|
150
|
+
finally:
|
|
151
|
+
self._path_stack.pop()
|
|
103
152
|
|
|
104
153
|
try:
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
154
|
+
instance: Any | None = None
|
|
155
|
+
if (isinstance(model_type, type) and issubclass(model_type, BaseModel)) or is_dataclass(
|
|
156
|
+
model_type
|
|
157
|
+
):
|
|
158
|
+
instance = model_type(**values)
|
|
159
|
+
except ValidationError as exc:
|
|
160
|
+
self._constraint_reporter.finish_model(
|
|
161
|
+
model_type,
|
|
162
|
+
success=False,
|
|
163
|
+
errors=exc.errors(),
|
|
164
|
+
)
|
|
165
|
+
return None
|
|
109
166
|
except Exception:
|
|
167
|
+
self._constraint_reporter.finish_model(model_type, success=False)
|
|
110
168
|
return None
|
|
111
|
-
return None
|
|
112
|
-
|
|
113
|
-
def _evaluate_strategy(self, strategy: StrategyResult, depth: int) -> Any:
|
|
114
|
-
if isinstance(strategy, UnionStrategy):
|
|
115
|
-
return self._evaluate_union(strategy, depth)
|
|
116
|
-
return self._evaluate_single(strategy, depth)
|
|
117
169
|
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
if not choices:
|
|
170
|
+
if instance is None:
|
|
171
|
+
self._constraint_reporter.finish_model(model_type, success=False)
|
|
121
172
|
return None
|
|
122
173
|
|
|
123
|
-
|
|
124
|
-
return
|
|
174
|
+
self._constraint_reporter.finish_model(model_type, success=True)
|
|
175
|
+
return instance
|
|
125
176
|
|
|
126
|
-
def
|
|
127
|
-
|
|
128
|
-
|
|
177
|
+
def _evaluate_strategy(
    self,
    strategy: StrategyResult,
    depth: int,
    model_type: type[Any],
    field_name: str,
) -> Any:
    """Produce a value for one field by dispatching on the strategy kind.

    ``UnionStrategy`` objects go to ``_evaluate_union``; everything else is
    handled by ``_evaluate_single``.  All arguments are forwarded unchanged.
    """
    if isinstance(strategy, UnionStrategy):
        evaluate = self._evaluate_union
    else:
        evaluate = self._evaluate_single
    return evaluate(strategy, depth, model_type, field_name)
|
|
187
|
+
|
|
188
|
+
def _make_path_entry(self, model_type: type[Any], via_field: str | None) -> _PathEntry:
    """Describe ``model_type`` as a path-stack frame.

    The module falls back to ``"<unknown>"`` and the qualified name falls back
    to ``__name__`` and then ``str(model_type)`` when the attributes are
    missing.
    """
    module_name = getattr(model_type, "__module__", "<unknown>")
    fallback_name = getattr(model_type, "__name__", str(model_type))
    qualified_name = getattr(model_type, "__qualname__", fallback_name)
    return _PathEntry(
        module=module_name,
        qualname=qualified_name,
        via_field=via_field,
    )
|
|
129
196
|
|
|
130
|
-
|
|
131
|
-
|
|
197
|
+
def _apply_field_policies(self, field_name: str, strategy: StrategyResult) -> None:
|
|
198
|
+
if self._field_policy_set is None:
|
|
199
|
+
return
|
|
200
|
+
full_path, aliases = self._current_field_paths(field_name)
|
|
201
|
+
try:
|
|
202
|
+
policy_values = self._field_policy_set.resolve(full_path, aliases=aliases)
|
|
203
|
+
except FieldPolicyConflictError as exc:
|
|
204
|
+
raise ConfigError(str(exc)) from exc
|
|
132
205
|
|
|
133
|
-
if
|
|
134
|
-
return
|
|
206
|
+
if not policy_values:
|
|
207
|
+
return
|
|
135
208
|
|
|
136
|
-
|
|
209
|
+
if isinstance(strategy, UnionStrategy):
|
|
210
|
+
union_override = policy_values.get("union_policy")
|
|
211
|
+
if union_override is not None:
|
|
212
|
+
strategy.policy = union_override
|
|
213
|
+
element_policy = {
|
|
214
|
+
key: policy_values[key]
|
|
215
|
+
for key in ("p_none", "enum_policy")
|
|
216
|
+
if key in policy_values and policy_values[key] is not None
|
|
217
|
+
}
|
|
218
|
+
if element_policy:
|
|
219
|
+
for choice in strategy.choices:
|
|
220
|
+
self._apply_field_policy_to_strategy(choice, element_policy)
|
|
221
|
+
return
|
|
222
|
+
|
|
223
|
+
self._apply_field_policy_to_strategy(strategy, policy_values)
|
|
224
|
+
|
|
225
|
+
def _current_field_paths(self, field_name: str) -> tuple[str, tuple[str, ...]]:
|
|
226
|
+
if not self._path_stack:
|
|
227
|
+
return field_name, ()
|
|
228
|
+
|
|
229
|
+
full_segments: list[str] = []
|
|
230
|
+
name_segments: list[str] = []
|
|
231
|
+
model_segments: list[str] = []
|
|
232
|
+
field_segments: list[str] = []
|
|
233
|
+
|
|
234
|
+
for index, entry in enumerate(self._path_stack):
|
|
235
|
+
if index == 0:
|
|
236
|
+
full_segments.append(entry.full)
|
|
237
|
+
name_segments.append(entry.qualname)
|
|
238
|
+
model_segments.append(entry.qualname)
|
|
239
|
+
else:
|
|
240
|
+
if entry.via_field:
|
|
241
|
+
full_segments.append(entry.via_field)
|
|
242
|
+
name_segments.append(entry.via_field)
|
|
243
|
+
field_segments.append(entry.via_field)
|
|
244
|
+
full_segments.append(entry.full)
|
|
245
|
+
name_segments.append(entry.qualname)
|
|
246
|
+
model_segments.append(entry.qualname)
|
|
247
|
+
|
|
248
|
+
full_path = ".".join((*full_segments, field_name))
|
|
249
|
+
|
|
250
|
+
alias_candidates: list[str] = []
|
|
251
|
+
|
|
252
|
+
alias_candidates.append(".".join((*name_segments, field_name)))
|
|
253
|
+
alias_candidates.append(".".join((*model_segments, field_name)))
|
|
254
|
+
|
|
255
|
+
if field_segments:
|
|
256
|
+
alias_candidates.append(".".join((*field_segments, field_name)))
|
|
257
|
+
root_fields = ".".join((name_segments[0], *field_segments, field_name))
|
|
258
|
+
alias_candidates.append(root_fields)
|
|
259
|
+
|
|
260
|
+
last_entry = self._path_stack[-1]
|
|
261
|
+
alias_candidates.append(".".join((last_entry.qualname, field_name)))
|
|
262
|
+
alias_candidates.append(".".join((last_entry.full, field_name)))
|
|
263
|
+
alias_candidates.append(field_name)
|
|
264
|
+
|
|
265
|
+
aliases = self._dedupe_paths(alias_candidates, exclude=full_path)
|
|
266
|
+
return full_path, aliases
|
|
137
267
|
|
|
138
|
-
|
|
139
|
-
|
|
268
|
+
@staticmethod
|
|
269
|
+
def _dedupe_paths(paths: Iterable[str], *, exclude: str) -> tuple[str, ...]:
|
|
270
|
+
seen: dict[str, None] = {}
|
|
271
|
+
for path in paths:
|
|
272
|
+
if not path or path == exclude or path in seen:
|
|
273
|
+
continue
|
|
274
|
+
seen[path] = None
|
|
275
|
+
return tuple(seen.keys())
|
|
140
276
|
|
|
141
|
-
|
|
142
|
-
|
|
277
|
+
def _apply_field_policy_to_strategy(
|
|
278
|
+
self,
|
|
279
|
+
strategy: Strategy,
|
|
280
|
+
policy_values: Mapping[str, Any],
|
|
281
|
+
) -> None:
|
|
282
|
+
if "p_none" in policy_values and policy_values["p_none"] is not None:
|
|
283
|
+
strategy.p_none = policy_values["p_none"]
|
|
284
|
+
if "enum_policy" in policy_values and policy_values["enum_policy"] is not None:
|
|
285
|
+
strategy.enum_policy = policy_values["enum_policy"]
|
|
143
286
|
|
|
144
|
-
|
|
287
|
+
def _evaluate_union(
|
|
288
|
+
self,
|
|
289
|
+
strategy: UnionStrategy,
|
|
290
|
+
depth: int,
|
|
291
|
+
model_type: type[Any],
|
|
292
|
+
field_name: str,
|
|
293
|
+
) -> Any:
|
|
294
|
+
choices = strategy.choices
|
|
295
|
+
if not choices:
|
|
145
296
|
return None
|
|
146
|
-
|
|
297
|
+
|
|
298
|
+
selected = self.random.choice(choices) if strategy.policy == "random" else choices[0]
|
|
299
|
+
return self._evaluate_single(selected, depth, model_type, field_name)
|
|
300
|
+
|
|
301
|
+
def _evaluate_single(
|
|
302
|
+
self,
|
|
303
|
+
strategy: Strategy,
|
|
304
|
+
depth: int,
|
|
305
|
+
model_type: type[Any],
|
|
306
|
+
field_name: str,
|
|
307
|
+
) -> Any:
|
|
308
|
+
summary = strategy.summary
|
|
309
|
+
self._constraint_reporter.record_field_attempt(model_type, field_name, summary)
|
|
310
|
+
|
|
311
|
+
if self._should_return_none(strategy):
|
|
312
|
+
value: Any = None
|
|
313
|
+
else:
|
|
314
|
+
enum_values = strategy.enum_values or summary.enum_values
|
|
315
|
+
if enum_values:
|
|
316
|
+
value = self._select_enum_value(strategy, enum_values)
|
|
317
|
+
else:
|
|
318
|
+
annotation = strategy.annotation
|
|
319
|
+
|
|
320
|
+
if self._is_model_like(annotation):
|
|
321
|
+
value = self._build_model_instance(
|
|
322
|
+
annotation,
|
|
323
|
+
depth=depth + 1,
|
|
324
|
+
via_field=field_name,
|
|
325
|
+
)
|
|
326
|
+
elif summary.type in {"list", "set", "tuple", "mapping"}:
|
|
327
|
+
value = self._evaluate_collection(strategy, depth)
|
|
328
|
+
else:
|
|
329
|
+
if strategy.provider_ref is None:
|
|
330
|
+
value = None
|
|
331
|
+
else:
|
|
332
|
+
value = self._call_strategy_provider(strategy)
|
|
333
|
+
|
|
334
|
+
self._constraint_reporter.record_field_value(field_name, value)
|
|
335
|
+
return value
|
|
147
336
|
|
|
148
337
|
def _evaluate_collection(self, strategy: Strategy, depth: int) -> Any:
|
|
149
338
|
summary = strategy.summary
|
|
@@ -154,13 +343,22 @@ class InstanceGenerator:
|
|
|
154
343
|
return base_value
|
|
155
344
|
|
|
156
345
|
if summary.type == "mapping":
|
|
157
|
-
return self._build_mapping_collection(
|
|
346
|
+
return self._build_mapping_collection(
|
|
347
|
+
base_value,
|
|
348
|
+
item_annotation,
|
|
349
|
+
depth,
|
|
350
|
+
strategy.field_name,
|
|
351
|
+
)
|
|
158
352
|
|
|
159
353
|
length = self._collection_length_from_value(base_value)
|
|
160
354
|
count = max(1, length)
|
|
161
355
|
items: list[Any] = []
|
|
162
356
|
for _ in range(count):
|
|
163
|
-
nested = self._build_model_instance(
|
|
357
|
+
nested = self._build_model_instance(
|
|
358
|
+
item_annotation,
|
|
359
|
+
depth=depth + 1,
|
|
360
|
+
via_field=strategy.field_name,
|
|
361
|
+
)
|
|
164
362
|
if nested is not None:
|
|
165
363
|
items.append(nested)
|
|
166
364
|
|
|
@@ -180,6 +378,7 @@ class InstanceGenerator:
|
|
|
180
378
|
base_value: Any,
|
|
181
379
|
annotation: Any,
|
|
182
380
|
depth: int,
|
|
381
|
+
field_name: str,
|
|
183
382
|
) -> dict[str, Any]:
|
|
184
383
|
if isinstance(base_value, dict) and base_value:
|
|
185
384
|
keys: Iterable[str] = base_value.keys()
|
|
@@ -190,7 +389,11 @@ class InstanceGenerator:
|
|
|
190
389
|
|
|
191
390
|
result: dict[str, Any] = {}
|
|
192
391
|
for key in keys:
|
|
193
|
-
nested = self._build_model_instance(
|
|
392
|
+
nested = self._build_model_instance(
|
|
393
|
+
annotation,
|
|
394
|
+
depth=depth + 1,
|
|
395
|
+
via_field=field_name,
|
|
396
|
+
)
|
|
194
397
|
if nested is not None:
|
|
195
398
|
result[str(key)] = nested
|
|
196
399
|
return result
|
|
@@ -297,6 +500,7 @@ class InstanceGenerator:
|
|
|
297
500
|
"summary": strategy.summary,
|
|
298
501
|
"faker": self.faker,
|
|
299
502
|
"random_generator": self.random,
|
|
503
|
+
"time_anchor": self.config.time_anchor,
|
|
300
504
|
}
|
|
301
505
|
kwargs.update(strategy.provider_kwargs)
|
|
302
506
|
|
|
@@ -7,6 +7,7 @@ import os
|
|
|
7
7
|
import tempfile
|
|
8
8
|
from dataclasses import dataclass
|
|
9
9
|
from pathlib import Path
|
|
10
|
+
from typing import Any
|
|
10
11
|
|
|
11
12
|
__all__ = ["WriteResult", "write_atomic_text", "write_atomic_bytes"]
|
|
12
13
|
|
|
@@ -19,6 +20,7 @@ class WriteResult:
|
|
|
19
20
|
wrote: bool
|
|
20
21
|
skipped: bool
|
|
21
22
|
reason: str | None = None
|
|
23
|
+
metadata: dict[str, Any] | None = None
|
|
22
24
|
|
|
23
25
|
|
|
24
26
|
def write_atomic_text(
|
|
@@ -53,7 +55,13 @@ def _write_atomic(path: Path, data: bytes, *, hash_compare: bool) -> WriteResult
|
|
|
53
55
|
if hash_compare and path.exists():
|
|
54
56
|
existing = path.read_bytes()
|
|
55
57
|
if _hash_bytes(existing) == digest:
|
|
56
|
-
return WriteResult(
|
|
58
|
+
return WriteResult(
|
|
59
|
+
path=path,
|
|
60
|
+
wrote=False,
|
|
61
|
+
skipped=True,
|
|
62
|
+
reason="unchanged",
|
|
63
|
+
metadata=None,
|
|
64
|
+
)
|
|
57
65
|
|
|
58
66
|
with tempfile.NamedTemporaryFile(
|
|
59
67
|
delete=False,
|
|
@@ -70,7 +78,7 @@ def _write_atomic(path: Path, data: bytes, *, hash_compare: bool) -> WriteResult
|
|
|
70
78
|
temp_path.unlink(missing_ok=True)
|
|
71
79
|
raise exc
|
|
72
80
|
|
|
73
|
-
return WriteResult(path=path, wrote=True, skipped=False, reason=None)
|
|
81
|
+
return WriteResult(path=path, wrote=True, skipped=False, reason=None, metadata=None)
|
|
74
82
|
|
|
75
83
|
|
|
76
84
|
def _hash_bytes(data: bytes) -> str:
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
"""Utilities for rendering templated output paths."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import datetime as _dt
|
|
6
|
+
import re
|
|
7
|
+
from collections.abc import Mapping, Sequence
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from string import Formatter
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
from .errors import EmitError
|
|
14
|
+
|
|
15
|
+
# Public names of this module.
__all__ = [
    "OutputTemplate",
    "OutputTemplateContext",
    "OutputTemplateError",
]


# Placeholder names a template may reference; OutputTemplate rejects any other
# name at construction time.
_ALLOWED_FIELDS = {"model", "case_index", "timestamp"}
# Captures the text between one pair of braces that contains no nested braces.
_PLACEHOLDER_PATTERN = re.compile(r"{([^{}]+)}")
# Matches each run of characters other than ASCII letters, digits, ".", "_" and "-".
_SANITIZE_PATTERN = re.compile(r"[^A-Za-z0-9._-]+")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class OutputTemplateError(EmitError):
    """Raised when a templated path cannot be rendered safely."""

    def __init__(self, message: str, *, details: dict[str, object] | None = None) -> None:
        # Forwards both arguments unchanged to EmitError; ``details`` is
        # keyword-only and optional.
        super().__init__(message, details=details)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass(slots=True)
class OutputTemplateContext:
    """Context data supplied when rendering an :class:`OutputTemplate`."""

    # Value substituted for ``{model}``; rendering a template that uses
    # ``{model}`` fails when this is None.
    model: str | None = None
    # Value substituted for ``{timestamp}``; when None, the current UTC time is
    # used at render time.
    timestamp: _dt.datetime | None = None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class _StrictFormatter(Formatter):
    """Formatter that accepts named fields only and reports gaps as template errors."""

    def get_value(self, key: object, args: Sequence[Any], kwargs: Mapping[str, Any]) -> object:
        """Look up a named field in ``kwargs``.

        Raises:
            OutputTemplateError: For positional (non-string) keys, or when the
                named field was not supplied.
        """
        if not isinstance(key, str):
            raise OutputTemplateError("Positional fields are not supported in templates.")
        try:
            return kwargs[key]
        except KeyError as exc:
            raise OutputTemplateError(
                f"Missing template variable '{{{key}}}'.",
                details={"field": key},
            ) from exc
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class _TemplateTimestamp:
    """Wrap a datetime so that formatting it yields a path-safe segment."""

    def __init__(self, value: _dt.datetime) -> None:
        self._value = value

    def __format__(self, format_spec: str) -> str:
        """Render via ``strftime`` and sanitize the result.

        An empty ``format_spec`` uses ``%Y%m%dT%H%M%S``.  If sanitizing leaves
        nothing, the literal ``"timestamp"`` is returned.
        """
        pattern = format_spec or "%Y%m%dT%H%M%S"
        segment = _sanitize_segment(self._value.strftime(pattern))
        if segment:
            return segment
        return "timestamp"

    def __str__(self) -> str:
        return self.__format__("")
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class _TemplateString(str):
    """String whose content is sanitized for use as a path segment."""

    def __new__(cls, value: str) -> _TemplateString:
        cleaned = _sanitize_segment(value)
        if not cleaned:
            # Nothing usable survived sanitizing; fall back to a fixed name.
            cleaned = "artifact"
        return str.__new__(cls, cleaned)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class _TemplateCaseIndex(int):
    """Positive integer index that formats to a path-safe segment."""

    def __new__(cls, value: int) -> _TemplateCaseIndex:
        if value >= 1:
            return int.__new__(cls, value)
        raise OutputTemplateError(
            "case_index must be >= 1",
            details={"case_index": value},
        )

    def __format__(self, format_spec: str) -> str:
        """Format as a plain int, then sanitize; an empty result becomes ``"1"``."""
        number = int(self)
        if format_spec:
            text = format(number, format_spec)
        else:
            text = str(number)
        return _sanitize_segment(text) or "1"

    def __str__(self) -> str:
        return self.__format__("")
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _sanitize_segment(value: str) -> str:
    """Make ``value`` safe for a path segment.

    Surrounding whitespace is removed, every run of characters outside
    ``A-Za-z0-9._-`` becomes a single underscore, and leading/trailing
    ``.``, ``_`` and ``-`` are stripped.  The result may be empty.
    """
    trimmed = value.strip()
    replaced = _SANITIZE_PATTERN.sub("_", trimmed)
    return replaced.strip("._-")
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _collect_fields(template: str) -> set[str]:
    """Return the placeholder names referenced by ``template``.

    A placeholder may carry a format spec, e.g. ``{timestamp:%Y%m%d}`` or
    ``{case_index:03d}``; only the part before the first ``:`` is the field
    name.  Previously the spec was kept as part of the name, so such
    placeholders were rejected as unsupported variables even though the
    template value types implement ``__format__`` for exactly this purpose.

    Conversions (``{model!r}``) are deliberately left attached to the name so
    they continue to be rejected: a conversion would bypass the sanitizing
    ``__format__`` implementations.
    """
    names: set[str] = set()
    for match in _PLACEHOLDER_PATTERN.finditer(template):
        field_name, _, _format_spec = match.group(1).partition(":")
        names.add(field_name)
    return names
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _contains_placeholder(segment: str) -> bool:
    """Tell whether ``segment`` contains at least one ``{...}`` placeholder."""
    return _PLACEHOLDER_PATTERN.search(segment) is not None
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class OutputTemplate:
    """Render filesystem paths from user-supplied templates."""

    def __init__(self, template: str | Path) -> None:
        """Store the template and validate the placeholders it references.

        Raises:
            OutputTemplateError: If the template uses a placeholder name that
                is not in the allowed set.
        """
        self._raw = str(template)
        self._formatter = _StrictFormatter()
        self._fields = _collect_fields(self._raw)
        unsupported = sorted(self._fields - _ALLOWED_FIELDS)
        if unsupported:
            raise OutputTemplateError(
                "Unsupported template variable",
                details={"fields": unsupported},
            )

    @property
    def raw(self) -> str:
        """The template text exactly as supplied."""
        return self._raw

    @property
    def fields(self) -> set[str]:
        """A copy of the placeholder names found in the template."""
        return set(self._fields)

    def uses_case_index(self) -> bool:
        """Tell whether the template references ``{case_index}``."""
        return "case_index" in self._fields

    def has_dynamic_directories(self) -> bool:
        """Tell whether any path component before the last has a placeholder."""
        for segment in Path(self._raw).parts[:-1]:
            if _contains_placeholder(segment):
                return True
        return False

    def watch_parent(self) -> Path:
        """Return the longest leading directory path free of placeholders.

        Falls back to ``Path(".")`` when the template has no directory part or
        its first directory component already contains a placeholder.
        """
        stable: list[str] = []
        for segment in Path(self._raw).parts[:-1]:
            if _contains_placeholder(segment):
                break
            stable.append(segment)
        return Path(*stable) if stable else Path(".")

    def preview_path(self) -> Path:
        """Render the template with the model name ``"preview"`` (and index 1 if used)."""
        index = 1 if self.uses_case_index() else None
        return self.render(context=OutputTemplateContext(model="preview"), case_index=index)

    def render(
        self,
        *,
        context: OutputTemplateContext | None = None,
        case_index: int | None = None,
    ) -> Path:
        """Substitute placeholders and return the resulting path.

        A template without placeholders is returned as-is (with ``~``
        expanded).  Otherwise each referenced placeholder is filled from
        ``context`` / ``case_index``; a missing ``timestamp`` defaults to the
        current UTC time.

        Raises:
            OutputTemplateError: If a required value is missing, the rendered
                text is empty, or the rendered path contains a ``..`` part.
        """
        if not self._fields:
            return Path(self._raw).expanduser()

        ctx = context or OutputTemplateContext()
        values: dict[str, object] = {}

        if "model" in self._fields:
            if ctx.model is None:
                raise OutputTemplateError(
                    "Template variable '{model}' requires a model context.",
                )
            values["model"] = _TemplateString(ctx.model)

        if "timestamp" in self._fields:
            moment = ctx.timestamp or _dt.datetime.now(_dt.timezone.utc)
            values["timestamp"] = _TemplateTimestamp(moment)

        if "case_index" in self._fields:
            if case_index is None:
                raise OutputTemplateError(
                    "Template variable '{case_index}' requires an index.",
                )
            values["case_index"] = _TemplateCaseIndex(case_index)

        rendered = self._formatter.format(self._raw, **values).strip()
        if not rendered:
            raise OutputTemplateError(
                "Rendered template produced an empty path.",
                details={"template": self._raw},
            )

        path = Path(rendered).expanduser()
        if ".." in path.parts:
            raise OutputTemplateError(
                "Templates cannot traverse above the working directory.",
                details={"path": rendered},
            )
        return path
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
"""Preset definitions for curated generation policy bundles."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Mapping
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass(frozen=True)
class PresetSpec:
    """Descriptor for a preset and the configuration values it applies."""

    # Canonical preset name; for built-in presets it equals the registry key.
    name: str
    # Human-readable summary of what the preset is for.
    description: str
    # Configuration values the preset supplies, keyed by setting name.
    settings: Mapping[str, Any]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Built-in presets keyed by canonical name.
_PRESET_DEFINITIONS: dict[str, PresetSpec] = {
    "boundary": PresetSpec(
        name="boundary",
        description=(
            "Favor random union/enum selection and increase optional None probability for"
            " boundary-focused datasets."
        ),
        settings={
            "union_policy": "random",
            "enum_policy": "random",
            "p_none": 0.35,
        },
    ),
    "boundary-max": PresetSpec(
        name="boundary-max",
        description=(
            "Aggressive boundary exploration with high optional None probability and"
            " randomized union/enum selection."
        ),
        settings={
            "union_policy": "random",
            "enum_policy": "random",
            "p_none": 0.6,
            "json": {"indent": 0},
        },
    ),
}

# Alternative spellings mapped to the canonical preset name they stand for.
_ALIASES: dict[str, str] = {
    "edge": "boundary",
    "boundary-heavy": "boundary-max",
}
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def normalize_preset_name(name: str) -> str:
    """Return the canonical form of ``name``.

    The name is stripped of surrounding whitespace and lower-cased; a known
    alias is then replaced by the preset name it refers to.
    """
    folded = name.strip().lower()
    if folded in _ALIASES:
        return _ALIASES[folded]
    return folded
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def get_preset_spec(name: str) -> PresetSpec:
    """Look up the preset registered under ``name`` (aliases accepted).

    Raises:
        KeyError: If the normalized name is not a known preset.
    """
    return _PRESET_DEFINITIONS[normalize_preset_name(name)]
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def available_presets() -> list[PresetSpec]:
    """Return a new list holding every built-in preset specification."""
    return [*_PRESET_DEFINITIONS.values()]
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
__all__ = ["PresetSpec", "available_presets", "get_preset_spec", "normalize_preset_name"]
|