pydantic-fixturegen 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-fixturegen has been flagged as potentially problematic; see the package's registry page for details.

Files changed (41)
  1. pydantic_fixturegen/__init__.py +7 -0
  2. pydantic_fixturegen/cli/__init__.py +85 -0
  3. pydantic_fixturegen/cli/doctor.py +235 -0
  4. pydantic_fixturegen/cli/gen/__init__.py +23 -0
  5. pydantic_fixturegen/cli/gen/_common.py +139 -0
  6. pydantic_fixturegen/cli/gen/explain.py +145 -0
  7. pydantic_fixturegen/cli/gen/fixtures.py +283 -0
  8. pydantic_fixturegen/cli/gen/json.py +262 -0
  9. pydantic_fixturegen/cli/gen/schema.py +164 -0
  10. pydantic_fixturegen/cli/list.py +164 -0
  11. pydantic_fixturegen/core/__init__.py +103 -0
  12. pydantic_fixturegen/core/ast_discover.py +169 -0
  13. pydantic_fixturegen/core/config.py +440 -0
  14. pydantic_fixturegen/core/errors.py +136 -0
  15. pydantic_fixturegen/core/generate.py +311 -0
  16. pydantic_fixturegen/core/introspect.py +141 -0
  17. pydantic_fixturegen/core/io_utils.py +77 -0
  18. pydantic_fixturegen/core/providers/__init__.py +32 -0
  19. pydantic_fixturegen/core/providers/collections.py +74 -0
  20. pydantic_fixturegen/core/providers/identifiers.py +68 -0
  21. pydantic_fixturegen/core/providers/numbers.py +133 -0
  22. pydantic_fixturegen/core/providers/registry.py +98 -0
  23. pydantic_fixturegen/core/providers/strings.py +109 -0
  24. pydantic_fixturegen/core/providers/temporal.py +42 -0
  25. pydantic_fixturegen/core/safe_import.py +403 -0
  26. pydantic_fixturegen/core/schema.py +320 -0
  27. pydantic_fixturegen/core/seed.py +154 -0
  28. pydantic_fixturegen/core/strategies.py +193 -0
  29. pydantic_fixturegen/core/version.py +52 -0
  30. pydantic_fixturegen/emitters/__init__.py +15 -0
  31. pydantic_fixturegen/emitters/json_out.py +373 -0
  32. pydantic_fixturegen/emitters/pytest_codegen.py +365 -0
  33. pydantic_fixturegen/emitters/schema_out.py +84 -0
  34. pydantic_fixturegen/plugins/builtin.py +45 -0
  35. pydantic_fixturegen/plugins/hookspecs.py +59 -0
  36. pydantic_fixturegen/plugins/loader.py +72 -0
  37. pydantic_fixturegen-1.0.0.dist-info/METADATA +280 -0
  38. pydantic_fixturegen-1.0.0.dist-info/RECORD +41 -0
  39. pydantic_fixturegen-1.0.0.dist-info/WHEEL +4 -0
  40. pydantic_fixturegen-1.0.0.dist-info/entry_points.txt +5 -0
  41. pydantic_fixturegen-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,365 @@
1
+ """Emit pytest fixture modules from Pydantic models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import re
7
+ import shutil
8
+ import subprocess
9
+ from collections.abc import Iterable, Sequence
10
+ from dataclasses import dataclass
11
+ from pathlib import Path
12
+ from pprint import pformat
13
+ from typing import Any, Literal, cast
14
+
15
+ from pydantic import BaseModel
16
+
17
+ from pydantic_fixturegen.core.generate import GenerationConfig, InstanceGenerator
18
+ from pydantic_fixturegen.core.io_utils import WriteResult, write_atomic_text
19
+ from pydantic_fixturegen.core.version import build_artifact_header
20
+
21
DEFAULT_SCOPE = "function"
# Scopes accepted by emit_pytest_fixtures (a subset of pytest's fixture scopes).
ALLOWED_SCOPES: set[str] = {DEFAULT_SCOPE, "module", "session"}
DEFAULT_STYLE: Literal["functions", "factory", "class"] = "functions"
DEFAULT_RETURN_TYPE: Literal["model", "dict"] = "model"
25
+
26
+
27
@dataclass(slots=True)
class PytestEmitConfig:
    """Configuration for pytest fixture emission.

    ``emit_pytest_fixtures`` validates ``scope``, ``cases``, ``style`` and
    ``return_type`` before generating anything.
    """

    scope: str = DEFAULT_SCOPE  # pytest fixture scope; must be one of ALLOWED_SCOPES
    style: Literal["functions", "factory", "class"] = DEFAULT_STYLE  # shape of the emitted fixtures
    return_type: Literal["model", "dict"] = DEFAULT_RETURN_TYPE  # validated model instance or plain dict
    cases: int = 1  # instances generated per model; more than one yields a parametrised fixture
    seed: int | None = None  # passed to GenerationConfig and recorded in the file header
    optional_p_none: float | None = None  # overrides GenerationConfig.optional_p_none when not None
    model_digest: str | None = None  # recorded in the generated file header
    hash_compare: bool = True  # forwarded to write_atomic_text
39
+
40
+
41
def emit_pytest_fixtures(
    models: Sequence[type[BaseModel]],
    *,
    output_path: str | Path,
    config: PytestEmitConfig | None = None,
) -> WriteResult:
    """Generate pytest fixture code for ``models`` and write it atomically.

    Each model is instantiated ``config.cases`` times by ``InstanceGenerator``.
    The instances are converted to JSON-compatible literals and embedded in a
    generated module rendered in the configured ``style``.

    Args:
        models: Pydantic model classes to emit fixtures for (at least one).
        output_path: Destination file for the generated module.
        config: Emission options; ``PytestEmitConfig()`` when omitted.

    Returns:
        The ``WriteResult`` reported by ``write_atomic_text``.

    Raises:
        ValueError: If ``models`` is empty, or ``config`` has an unsupported
            scope, style or return type, or ``cases`` below 1.
        RuntimeError: If the generator yields fewer instances than requested.
    """
    if not models:
        raise ValueError("At least one model must be provided.")

    cfg = config or PytestEmitConfig()
    _validate_emit_config(cfg)

    generation_config = GenerationConfig(seed=cfg.seed)
    if cfg.optional_p_none is not None:
        generation_config.optional_p_none = cfg.optional_p_none
    generator = InstanceGenerator(config=generation_config)

    # The generated module imports every model by its bare class name. A
    # fixture function or factory helper class with the same name would shadow
    # that import (e.g. helper ``FooFactory`` vs. a model called ``FooFactory``),
    # so reserve the model names before any generated name is handed out.
    model_names = {model.__name__ for model in models}
    fixture_names: dict[str, int] = dict.fromkeys(model_names, 1)
    helper_names: dict[str, int] = dict.fromkeys(model_names, 1)

    model_entries: list[_ModelEntry] = []
    for model in models:
        instances = generator.generate(model, count=cfg.cases)
        if len(instances) < cfg.cases:
            raise RuntimeError(
                f"Failed to generate {cfg.cases} instance(s) for {model.__qualname__}."
            )
        data = [_model_to_literal(instance) for instance in instances]

        base_name = model.__name__
        if cfg.style in {"factory", "class"}:
            base_name = f"{base_name}_factory"
        fixture_name = _unique_fixture_name(base_name, fixture_names)

        helper_name = None
        if cfg.style == "class":
            helper_name = _unique_helper_name(f"{model.__name__}Factory", helper_names)

        model_entries.append(
            _ModelEntry(
                model=model,
                data=data,
                fixture_name=fixture_name,
                helper_name=helper_name,
            )
        )

    rendered = _render_module(entries=model_entries, config=cfg)
    return write_atomic_text(
        output_path,
        rendered,
        hash_compare=cfg.hash_compare,
    )


def _validate_emit_config(cfg: PytestEmitConfig) -> None:
    """Raise ``ValueError`` when ``cfg`` holds an unsupported option."""
    if cfg.scope not in ALLOWED_SCOPES:
        raise ValueError(f"Unsupported fixture scope: {cfg.scope!r}")
    if cfg.cases < 1:
        raise ValueError("cases must be >= 1.")
    if cfg.style not in {"functions", "factory", "class"}:
        raise ValueError(f"Unsupported pytest fixture style: {cfg.style!r}")
    if cfg.return_type not in {"model", "dict"}:
        raise ValueError(f"Unsupported return_type: {cfg.return_type!r}")
105
+
106
+
107
+ # --------------------------------------------------------------------------- rendering helpers
108
@dataclass(slots=True)
class _ModelEntry:
    """Per-model rendering input assembled by ``emit_pytest_fixtures``."""

    model: type[BaseModel]  # model class imported and validated in the generated module
    data: list[dict[str, Any]]  # one JSON-compatible dict per generated case
    fixture_name: str  # unique snake_case name of the emitted fixture function
    helper_name: str | None = None  # factory class name; only set for the "class" style
114
+
115
+
116
def _render_module(*, entries: Iterable[_ModelEntry], config: PytestEmitConfig) -> str:
    """Assemble the full source of the generated fixture module.

    The module starts with a ``__future__`` import and a comment built by
    ``build_artifact_header`` that records the seed, digest and emission
    options. It then has the pytest/typing/model imports and, per entry, an
    optional factory class (``"class"`` style) followed by the fixture. The
    result is passed through ``_format_code``.
    """
    materialised = list(entries)
    qualified_models = [
        f"{item.model.__module__}.{item.model.__name__}" for item in materialised
    ]
    header = build_artifact_header(
        seed=config.seed,
        model_digest=config.model_digest,
        extras={
            "style": config.style,
            "scope": config.scope,
            "return": config.return_type,
            "cases": config.cases,
            "models": ", ".join(qualified_models),
        },
    )

    # Only import the typing names the chosen style/return type references.
    typing_names = [
        name
        for name, wanted in (
            ("Any", config.return_type == "dict" or config.style in {"factory", "class"}),
            ("Callable", config.style == "factory"),
        )
        if wanted
    ]
    imports_by_module = _collect_model_imports(materialised)

    out: list[str] = [
        "from __future__ import annotations",
        "",
        f"# {header}",
        "",
        "import pytest",
    ]
    if typing_names:
        out.append(f"from typing import {', '.join(typing_names)}")
    out.extend(
        f"from {module} import {', '.join(sorted(names))}"
        for module, names in imports_by_module.items()
    )

    for item in materialised:
        if config.style == "class":
            out += ["", *_render_factory_class(item, config=config)]
        out += ["", *_render_fixture(item, config=config)]

    out.append("")
    return _format_code("\n".join(out))
166
+
167
+
168
def _collect_model_imports(entries: Iterable[_ModelEntry]) -> dict[str, set[str]]:
    """Group the model class names to import by their defining module.

    Returns:
        A mapping ``module -> {class names}`` in first-seen module order.

    Raises:
        ValueError: If two different modules contribute a class with the same
            ``__name__``. The generated module imports models by bare name, so
            the later ``from ... import`` would silently shadow the earlier one
            and its fixtures would validate against the wrong model.
    """
    imports: dict[str, set[str]] = {}
    origin: dict[str, str] = {}
    for entry in entries:
        module = entry.model.__module__
        name = entry.model.__name__
        first_module = origin.setdefault(name, module)
        if first_module != module:
            raise ValueError(
                f"Cannot import two models named {name!r} "
                f"(from {first_module!r} and {module!r}) into one fixture module."
            )
        imports.setdefault(module, set()).add(name)
    return imports
173
+
174
+
175
+ def _render_fixture(entry: _ModelEntry, *, config: PytestEmitConfig) -> list[str]:
176
+ if config.style == "functions":
177
+ return _render_functions_fixture(entry, config=config)
178
+ if config.style == "factory":
179
+ return _render_factory_fixture(entry, config=config)
180
+ return _render_class_fixture(entry, config=config)
181
+
182
+
183
def _render_functions_fixture(entry: _ModelEntry, *, config: PytestEmitConfig) -> list[str]:
    """Render a plain fixture returning one model instance (or dict) per case.

    With a single case, the data literal is inlined in the fixture body. With
    several cases, they are passed as ``@pytest.fixture(params=...)`` and read
    back from ``request.param``, so dependent tests run once per case. Emitted
    body lines use a four-space indent.
    """
    annotation = entry.model.__name__ if config.return_type == "model" else "dict[str, Any]"
    has_params = len(entry.data) > 1

    lines: list[str] = []
    if has_params:
        params_literal = _format_literal(entry.data)
        lines.append(f'@pytest.fixture(scope="{config.scope}", params={params_literal})')
        lines.append(f"def {entry.fixture_name}(request) -> {annotation}:")
        lines.append("    data = request.param")
    else:
        lines.append(f'@pytest.fixture(scope="{config.scope}")')
        lines.append(f"def {entry.fixture_name}() -> {annotation}:")
        lines.extend(_format_assignment_lines("data", _format_literal(entry.data[0])))

    if config.return_type == "model":
        lines.append(f"    return {entry.model.__name__}.model_validate(data)")
    else:
        # Shallow copy, so the returned dict is not the ``request.param`` object itself.
        lines.append("    return dict(data)")
    return lines
210
+
211
+
212
def _render_factory_fixture(entry: _ModelEntry, *, config: PytestEmitConfig) -> list[str]:
    """Render a fixture that returns a ``builder(overrides=None)`` callable.

    The builder copies the base data, applies ``overrides`` on top and returns
    a validated model or a plain dict, depending on ``config.return_type``.
    With several cases, the base data comes from ``request.param``.
    """
    return_annotation = entry.model.__name__ if config.return_type == "model" else "dict[str, Any]"
    fixture_annotation = f"Callable[[dict[str, Any] | None], {return_annotation}]"
    has_params = len(entry.data) > 1

    lines: list[str] = []
    if has_params:
        params_literal = _format_literal(entry.data)
        lines.append(f'@pytest.fixture(scope="{config.scope}", params={params_literal})')
        lines.append(f"def {entry.fixture_name}(request) -> {fixture_annotation}:")
        lines.append("    base_data = request.param")
    else:
        lines.append(f'@pytest.fixture(scope="{config.scope}")')
        lines.append(f"def {entry.fixture_name}() -> {fixture_annotation}:")
        lines.extend(_format_assignment_lines("base_data", _format_literal(entry.data[0])))

    # The nested builder is one level deeper than the fixture body, so its own
    # body needs eight spaces (twelve inside the ``if``) to form a valid block.
    lines.append(
        f"    def builder(overrides: dict[str, Any] | None = None) -> {return_annotation}:"
    )
    lines.append("        data = dict(base_data)")
    lines.append("        if overrides:")
    lines.append("            data.update(overrides)")
    if config.return_type == "model":
        lines.append(f"        return {entry.model.__name__}.model_validate(data)")
    else:
        lines.append("        return dict(data)")
    lines.append("    return builder")
    return lines
247
+
248
+
249
def _render_factory_class(entry: _ModelEntry, *, config: PytestEmitConfig) -> list[str]:
    """Render the helper class used by the ``"class"`` fixture style.

    The class keeps a shallow copy of the base data. Its
    ``build(**overrides)`` method merges keyword overrides and returns a
    validated model or a plain dict, depending on ``config.return_type``.
    Method bodies are indented one level below their ``def`` so the emitted
    class compiles.
    """
    class_name = entry.helper_name or f"{entry.model.__name__}Factory"
    return_annotation = entry.model.__name__ if config.return_type == "model" else "dict[str, Any]"
    if config.return_type == "model":
        result_line = f"        return {entry.model.__name__}.model_validate(data)"
    else:
        result_line = "        return dict(data)"

    return [
        f"class {class_name}:",
        "    def __init__(self, base_data: dict[str, Any]) -> None:",
        "        self._base_data = dict(base_data)",
        "",
        f"    def build(self, **overrides: Any) -> {return_annotation}:",
        "        data = dict(self._base_data)",
        "        if overrides:",
        "            data.update(overrides)",
        result_line,
    ]
267
+
268
+
269
def _render_class_fixture(entry: _ModelEntry, *, config: PytestEmitConfig) -> list[str]:
    """Render a fixture that returns an instance of the entry's factory class.

    The class name is ``entry.helper_name``, falling back to
    ``<Model>Factory``. With several cases, the base data comes from
    ``request.param``. Emitted body lines use a four-space indent.
    """
    class_name = entry.helper_name or f"{entry.model.__name__}Factory"
    has_params = len(entry.data) > 1

    lines: list[str] = []
    if has_params:
        params_literal = _format_literal(entry.data)
        lines.append(f'@pytest.fixture(scope="{config.scope}", params={params_literal})')
        lines.append(f"def {entry.fixture_name}(request) -> {class_name}:")
        lines.append("    base_data = request.param")
    else:
        lines.append(f'@pytest.fixture(scope="{config.scope}")')
        lines.append(f"def {entry.fixture_name}() -> {class_name}:")
        lines.extend(_format_assignment_lines("base_data", _format_literal(entry.data[0])))

    lines.append(f"    return {class_name}(base_data)")
    return lines
294
+
295
+
296
def _format_literal(value: Any) -> str:
    """Return ``value`` as Python literal source wrapped at 88 columns.

    Dict keys are sorted so the emitted code is stable from run to run.
    """
    return pformat(value, sort_dicts=True, width=88)
298
+
299
+
300
def _format_assignment_lines(var_name: str, literal: str, *, indent: str = "    ") -> list[str]:
    """Render ``var_name = literal`` as indented lines of a function body.

    ``literal`` may span several lines (``pformat`` wraps long values). Each
    continuation line keeps its own alignment and gets the same ``indent``
    prefix. For ``pformat`` output, continuation lines fall inside the
    literal's brackets, so their exact indentation does not affect parsing.

    Args:
        var_name: Name of the variable being assigned.
        literal: Python source for the assigned value.
        indent: Prefix for every emitted line (one body level by default).
    """
    first, *rest = literal.split("\n")
    return [f"{indent}{var_name} = {first}", *(f"{indent}{piece}" for piece in rest)]
309
+
310
+
311
def _unique_fixture_name(base: str, seen: dict[str, int]) -> str:
    """Return a snake_case fixture name derived from ``base`` that is unique in ``seen``.

    The first request for a root returns it unchanged; repeats get ``_2``,
    ``_3``, ... suffixes. ``seen`` records every requested root (with its
    request count) and every name handed out. A suffixed name is therefore
    never handed out again for a different root (e.g. models ``Foo``, ``Foo``
    and ``Foo_2`` no longer both map to ``foo_2``).

    Roots that are Python keywords, or equal to ``request``, get a trailing
    underscore. A keyword root would emit ``def class():``, which does not
    compile. Parametrised fixtures in the generated module take pytest's
    built-in ``request`` fixture as an argument, and a fixture named
    ``request`` would shadow it.
    """
    import keyword

    root = _to_snake_case(base)
    if keyword.iskeyword(root) or root == "request":
        root = f"{root}_"

    count = seen.get(root, 0)
    candidate = root if count == 0 else f"{root}_{count + 1}"
    # Skip suffixed names already taken, e.g. by a model literally named ``Foo_2``.
    while count > 0 and candidate in seen:
        count += 1
        candidate = f"{root}_{count + 1}"
    seen[root] = count + 1
    seen.setdefault(candidate, 1)
    return candidate
318
+
319
+
320
def _unique_helper_name(base: str, seen: dict[str, int]) -> str:
    """Return ``base`` on first use, then ``base2``, ``base3``, ... on repeats.

    ``seen`` is updated in place with the number of requests per base name.
    """
    previous = seen.get(base, 0)
    seen[base] = previous + 1
    return f"{base}{previous + 1}" if previous else base
326
+
327
+
328
# Insert a break before a capitalised word: "HTTPServer" -> "HTTP_Server".
_CAMEL_CASE_PATTERN_1 = re.compile(r"(.)([A-Z][a-z]+)")
# Insert a break between a lowercase letter/digit and a capital: "fooBar" -> "foo_Bar".
_CAMEL_CASE_PATTERN_2 = re.compile(r"([a-z0-9])([A-Z])")
330
+
331
+
332
def _to_snake_case(name: str) -> str:
    """Convert a CamelCase identifier to snake_case ("HTTPServer" -> "http_server")."""
    for pattern in (_CAMEL_CASE_PATTERN_1, _CAMEL_CASE_PATTERN_2):
        name = pattern.sub(r"\1_\2", name)
    return name.lower()
336
+
337
+
338
def _model_to_literal(instance: BaseModel) -> dict[str, Any]:
    """Convert ``instance`` into plain JSON-compatible data with sorted keys.

    ``model_dump(mode="json")`` produces JSON-friendly values. The
    ``json.dumps``/``json.loads`` round trip then rebuilds every nested dict
    with sorted keys and yields plain ``dict``/``list`` containers.
    """
    as_json = json.dumps(instance.model_dump(mode="json"), sort_keys=True, ensure_ascii=False)
    return cast(dict[str, Any], json.loads(as_json))
342
+
343
+
344
def _format_code(source: str, *, timeout: float | None = 30.0) -> str:
    """Format ``source`` with ``ruff format`` when available (best effort).

    Returns ``source`` unchanged if ``ruff`` is not on ``PATH``, cannot be
    started, exceeds ``timeout`` seconds, exits non-zero, prints nothing, or
    prints output that is not valid UTF-8.

    Args:
        source: Python source to format.
        timeout: Seconds to wait for the formatter. Formatting is optional, so
            a hung subprocess must not block fixture generation.
    """
    formatter = shutil.which("ruff")
    if not formatter:
        return source

    try:
        proc = subprocess.run(
            [formatter, "format", "--stdin-filename", "fixtures.py", "-"],
            input=source.encode("utf-8"),
            capture_output=True,
            check=False,
            timeout=timeout,
        )
    except (OSError, subprocess.TimeoutExpired):
        # subprocess.run kills the child before re-raising TimeoutExpired.
        return source

    if proc.returncode != 0 or not proc.stdout:
        return source

    try:
        return proc.stdout.decode("utf-8")
    except UnicodeDecodeError:
        return source
@@ -0,0 +1,84 @@
1
+ """Schema emitter utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from collections.abc import Iterable
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ from pydantic import BaseModel
12
+
13
+
14
@dataclass(slots=True)
class SchemaEmitConfig:
    """Resolved options for writing a JSON schema file.

    ``emit_model_schema`` and ``emit_models_schema`` build one from their
    keyword arguments after passing ``indent`` through ``_normalise_indent``.
    """

    output_path: Path  # destination file; the emitters create missing parent directories
    indent: int | None = 2  # json.dumps indent; None emits compact JSON
    ensure_ascii: bool = False  # forwarded to json.dumps
19
+
20
+
21
def emit_model_schema(
    model: type[BaseModel],
    *,
    output_path: str | Path,
    indent: int | None = 2,
    ensure_ascii: bool = False,
) -> Path:
    """Write ``model``'s JSON schema (keys sorted) to ``output_path``.

    An ``indent`` of 0 or ``None`` produces compact JSON; a negative
    ``indent`` raises ``ValueError`` before anything is written. Missing
    parent directories are created. Returns the written path.
    """
    config = SchemaEmitConfig(
        output_path=Path(output_path),
        indent=_normalise_indent(indent),
        ensure_ascii=ensure_ascii,
    )
    document = json.dumps(
        model.model_json_schema(),
        indent=config.indent,
        ensure_ascii=config.ensure_ascii,
        sort_keys=True,
    )
    target = config.output_path
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(document, encoding="utf-8")
    return target
45
+
46
+
47
def emit_models_schema(
    models: Iterable[type[BaseModel]],
    *,
    output_path: str | Path,
    indent: int | None = 2,
    ensure_ascii: bool = False,
) -> Path:
    """Emit one JSON document mapping each model's class name to its schema.

    The combined object is keyed by ``model.__name__`` (not the qualified
    name). Listing the same model more than once is harmless.

    Args:
        models: Model classes to include.
        output_path: Destination file; missing parent directories are created.
        indent: ``json.dumps`` indent; 0 or ``None`` produces compact JSON.
        ensure_ascii: Forwarded to ``json.dumps``.

    Returns:
        The path that was written.

    Raises:
        ValueError: If ``indent`` is negative, or if two distinct models share
            a ``__name__`` (one schema would otherwise silently overwrite the
            other). Nothing is written in either case.
    """
    config = SchemaEmitConfig(
        output_path=Path(output_path),
        indent=_normalise_indent(indent),
        ensure_ascii=ensure_ascii,
    )
    combined: dict[str, Any] = {}
    owners: dict[str, type[BaseModel]] = {}
    for model in models:
        name = model.__name__
        owner = owners.setdefault(name, model)
        if owner is not model:
            raise ValueError(
                f"Models {owner.__module__}.{owner.__qualname__} and "
                f"{model.__module__}.{model.__qualname__} share the schema key {name!r}."
            )
        combined[name] = model.model_json_schema()

    payload = json.dumps(
        combined,
        indent=config.indent,
        ensure_ascii=config.ensure_ascii,
        sort_keys=True,
    )
    config.output_path.parent.mkdir(parents=True, exist_ok=True)
    config.output_path.write_text(payload, encoding="utf-8")
    return config.output_path
74
+
75
+
76
def _normalise_indent(indent: int | None) -> int | None:
    """Map 0/``None`` to ``None`` (compact JSON), reject negatives, keep the rest."""
    if indent is None:
        return None
    if indent < 0:
        raise ValueError("indent must be >= 0")
    return indent or None
82
+
83
+
84
# Public API of the schema emitter module.
__all__ = [
    "SchemaEmitConfig",
    "emit_model_schema",
    "emit_models_schema",
]
@@ -0,0 +1,45 @@
1
+ """Built-in plugin registrations exposed via entry points."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Callable
6
+
7
+ from pydantic_fixturegen.core.providers import (
8
+ register_collection_providers,
9
+ register_identifier_providers,
10
+ register_numeric_providers,
11
+ register_string_providers,
12
+ register_temporal_providers,
13
+ )
14
+ from pydantic_fixturegen.core.providers.registry import ProviderRegistry
15
+ from pydantic_fixturegen.plugins.hookspecs import hookimpl
16
+
17
# One representative provider key per built-in family. If the registry already
# returns a provider for this key, that family's registration is skipped.
_SENTINEL_KEYS = dict(
    numbers="int",
    strings="string",
    collections="list",
    temporal="datetime",
    identifiers="email",
)
24
+
25
+
26
# Type of the ``register_*_providers`` callables passed to ``_ensure_registered``:
# each takes the provider registry and returns nothing.
RegisterFunc = Callable[[ProviderRegistry], None]
27
+
28
+
29
def _ensure_registered(registry: ProviderRegistry, key: str, register_func: RegisterFunc) -> None:
    """Call ``register_func(registry)`` unless ``registry.get(key)`` already returns a provider."""
    already_present = registry.get(key) is not None
    if already_present:
        return
    register_func(registry)
32
+
33
+
34
@hookimpl
def pfg_register_providers(registry: ProviderRegistry) -> None:
    """Plugin hook: register every built-in provider family that is missing.

    Families are handled in a fixed order (numbers, strings, collections,
    temporal, identifiers). A family is skipped when its sentinel key from
    ``_SENTINEL_KEYS`` already resolves in ``registry``, so families that are
    already present are not registered twice.
    """
    families: tuple[tuple[str, RegisterFunc], ...] = (
        ("numbers", register_numeric_providers),
        ("strings", register_string_providers),
        ("collections", register_collection_providers),
        ("temporal", register_temporal_providers),
        ("identifiers", register_identifier_providers),
    )
    for family, register in families:
        _ensure_registered(registry, _SENTINEL_KEYS[family], register)
43
+
44
+
45
# Only the hook implementation is public; helpers stay module-private.
__all__ = [
    "pfg_register_providers",
]
@@ -0,0 +1,59 @@
1
+ """Hookspec definitions for pydantic-fixturegen plugins."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Mapping, Sequence
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import TYPE_CHECKING
9
+
10
+ import pluggy
11
+
12
+ if TYPE_CHECKING: # pragma: no cover
13
+ from pydantic import BaseModel
14
+ from pydantic_fixturegen.core.providers import ProviderRegistry
15
+ from pydantic_fixturegen.core.strategies import Strategy
16
+
17
# Marker pair for the "pfg" hook namespace. The project name must match the one
# given to ``pluggy.PluginManager`` for specs and implementations to be recognised.
hookspec = pluggy.HookspecMarker("pfg")
hookimpl = pluggy.HookimplMarker("pfg")
19
+
20
+
21
@dataclass(slots=True)
class EmitterContext:
    """Context passed to emitter plugins through the ``pfg_emit_artifact`` hook."""

    models: Sequence[type[BaseModel]]  # model classes supplied by the caller
    output: Path  # output location supplied by the caller (presumably the artifact destination)
    parameters: Mapping[str, object]  # free-form emitter options; key schema is not defined here
28
+
29
+
30
@hookspec
def pfg_register_providers(registry: ProviderRegistry) -> None:  # pragma: no cover
    """Register additional providers with the given registry.

    Hook specification only: pluggy reads this signature to validate
    implementations (their argument names must be a subset of the ones
    declared here) and never executes this body.
    """
    raise NotImplementedError
34
+
35
+
36
@hookspec
def pfg_modify_strategy(
    model: type[BaseModel],
    field_name: str,
    strategy: Strategy,
) -> Strategy | None:  # pragma: no cover
    """Modify or replace the strategy chosen for a model field.

    Hook specification only; pluggy never executes this body. The spec is
    declared without ``firstresult``, so calling the hook returns a list of
    every non-``None`` implementation result. How callers combine those
    results is not visible here.
    """
    raise NotImplementedError
44
+
45
+
46
@hookspec
def pfg_emit_artifact(kind: str, context: EmitterContext) -> bool:  # pragma: no cover
    """Handle artifact emission for ``kind``. Return True to skip default behaviour.

    Hook specification only; pluggy never executes this body.
    ``plugins.loader.emit_artifact`` treats the emission as handled when any
    implementation returns a truthy value.
    """
    raise NotImplementedError
50
+
51
+
52
# Public names: the context type, the marker pair and the three hook specs.
__all__ = ["EmitterContext", "hookimpl", "hookspec", "pfg_emit_artifact", "pfg_modify_strategy", "pfg_register_providers"]
@@ -0,0 +1,72 @@
1
+ """Utilities for loading and interacting with fixturegen plugins."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Iterable
6
+ from contextlib import suppress
7
+ from importlib import metadata
8
+ from typing import Any
9
+
10
+ import pluggy
11
+ from pydantic_fixturegen.plugins import hookspecs
12
+
13
# Process-wide plugin manager for the "pfg" hook namespace; the hook
# specifications from ``hookspecs`` are attached at import time.
_plugin_manager = pluggy.PluginManager("pfg")
_plugin_manager.add_hookspecs(hookspecs)
# Entry-point groups already loaded by ``load_entrypoint_plugins``.
_loaded_groups: set[str] = set()
16
+
17
+
18
def get_plugin_manager() -> pluggy.PluginManager:
    """Return the global plugin manager.

    This is the module-level instance created at import time; every call
    returns the same object.
    """

    return _plugin_manager
22
+
23
+
24
def register_plugin(plugin: Any) -> None:
    """Register ``plugin`` with the global manager, tolerating re-registration.

    ``pluggy.PluginManager.register`` raises ``ValueError`` both when the
    object is already registered and when its canonical name is already
    taken. Both are ignored here, so a different plugin whose name collides
    is silently left unregistered.
    """
    try:
        _plugin_manager.register(plugin)
    except ValueError:
        pass
29
+
30
+
31
def load_entrypoint_plugins(
    group: str = "pydantic_fixturegen",
    *,
    force: bool = False,
) -> list[Any]:
    """Load and register every plugin declared under entry-point ``group``.

    A group is loaded at most once unless ``force`` is true. Returns the
    plugin objects loaded by this call (an empty list when the group was
    already loaded and ``force`` is false). An exception from an entry point's
    ``load()`` propagates, and the group is not marked as loaded.
    """
    if not force and group in _loaded_groups:
        return []
    # Reached only on a first load (no-op) or when forcing a reload.
    _loaded_groups.discard(group)

    available = metadata.entry_points()
    select = getattr(available, "select", None)
    entries: Iterable[Any]
    if select is None:  # pragma: no cover - Python <3.10 fallback
        entries = available.get(group, [])
    else:
        entries = select(group=group)

    loaded: list[Any] = []
    for entry_point in entries:
        plugin = entry_point.load()
        register_plugin(plugin)
        loaded.append(plugin)

    _loaded_groups.add(group)
    return loaded
59
+
60
+
61
def emit_artifact(kind: str, context: hookspecs.EmitterContext) -> bool:
    """Run the ``pfg_emit_artifact`` hook for ``kind`` and report whether it was handled.

    Returns ``True`` as soon as any plugin result is truthy. In that case the
    caller should skip its default emission for this artifact.
    """
    hook = _plugin_manager.hook
    return any(hook.pfg_emit_artifact(kind=kind, context=context))
70
+
71
+
72
# Public plugin-loading API.
__all__ = [
    "emit_artifact",
    "get_plugin_manager",
    "load_entrypoint_plugins",
    "register_plugin",
]