schematico 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- schematico/__init__.py +27 -0
- schematico/cli/__init__.py +0 -0
- schematico/cli/main.py +375 -0
- schematico/cli/progress.py +31 -0
- schematico/cli/projects.py +173 -0
- schematico/cli/runner.py +125 -0
- schematico/cli/wizard.py +177 -0
- schematico/discovery.py +102 -0
- schematico/generator.py +97 -0
- schematico/helpers.py +38 -0
- schematico/logging.py +34 -0
- schematico/models.py +140 -0
- schematico/providers.py +95 -0
- schematico/tools/tavily_tools.py +79 -0
- schematico-0.1.0.dist-info/METADATA +289 -0
- schematico-0.1.0.dist-info/RECORD +19 -0
- schematico-0.1.0.dist-info/WHEEL +4 -0
- schematico-0.1.0.dist-info/entry_points.txt +2 -0
- schematico-0.1.0.dist-info/licenses/LICENSE +21 -0
schematico/cli/runner.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
from schematico.cli.progress import ProgressReporter
|
|
10
|
+
from schematico.cli.projects import ProjectConfig
|
|
11
|
+
from schematico.discovery import run_discovery
|
|
12
|
+
from schematico.generator import run_generation
|
|
13
|
+
from schematico.logging import get_logger
|
|
14
|
+
from schematico.models import model_from_dict, model_from_json
|
|
15
|
+
from schematico.providers import DEFAULT_MODEL, SchematicoModel, get_llm_model
|
|
16
|
+
|
|
17
|
+
# Module logger; get_logger() returns the child "schematico.cli.runner" of the
# package logger that schematico.logging configures.
logger = get_logger("cli.runner")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _resolve_output_file(output_path: str, project_name: str) -> Path:
    """Return the JSON file to write, creating any directories it needs.

    A path with a suffix (e.g. ``out/data.json``) is used as the file itself.
    A path without one is treated as a directory, and a timestamped
    ``<project_name>_<YYYYmmdd_HHMMSS>.json`` file inside it is returned.

    Raises:
        OSError: if a required directory cannot be created (permission
            denied, a path component is an existing file, ...).
    """
    out_path = Path(output_path)
    if out_path.suffix:
        out_path.parent.mkdir(parents=True, exist_ok=True)
        return out_path
    out_path.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return out_path / f"{project_name}_{timestamp}.json"


def run(
    config: ProjectConfig,
    *,
    output_override: str | None = None,
    count_override: int | None = None,
    model_override: str | None = None,
) -> None:
    """Execute a project config end-to-end and write the records as JSON.

    Steps:

    1. Check that the config has a schema and that the required API keys are
       present in the environment.
    2. Build the LLM model and load the record schema.
    3. Run discovery or generation, depending on ``config.mode``.
    4. Write the resulting records to disk.

    Handled failures print a ``schematico: error...`` line to stderr and exit
    with status 1.

    Args:
        config: The project configuration to run.
        output_override: Output file or directory. It takes precedence over
            ``config.output_path``.
        count_override: Number of records to request. It takes precedence
            over the ``rows`` value loaded from the schema.
        model_override: Model identifier. It takes precedence over
            ``config.model``, then ``DEFAULT_MODEL``.
    """
    if not config.record_schema and not config.schema_path:
        print(
            f"schematico: error: config '{config.name}.{config.mode}.toml' has no "
            "schema. Either fill the [schema] table or set schema_path "
            f"(e.g. `schematico {config.mode} schema import ./schema.json`).",
            file=sys.stderr,
        )
        sys.exit(1)

    # Fail fast, before any model setup: discovery relies on Tavily web search.
    if config.mode == "discover" and not os.environ.get("TAVILY_API_KEY"):
        print(
            "schematico: error: `discover` mode searches the live web and needs a "
            "Tavily API key. Set TAVILY_API_KEY (get one free at "
            "https://tavily.com), or use `generate` mode to synthesize records "
            "without web search.",
            file=sys.stderr,
        )
        sys.exit(1)

    model_str = model_override or config.model or DEFAULT_MODEL

    # config.env_key names the environment variable holding the key. An empty
    # env_key means no key is passed (api_key stays None).
    api_key: str | None = None
    if config.env_key:
        api_key = os.environ.get(config.env_key)
        if not api_key:
            print(
                f"schematico: error: env var '{config.env_key}' is not set "
                f"(required by config '{config.name}').",
                file=sys.stderr,
            )
            sys.exit(1)

    llm_model = get_llm_model(
        SchematicoModel(
            model=model_str,
            api_key=api_key,
            base_url=config.base_url or None,
        )
    )

    output_path = output_override or config.output_path

    try:
        if config.schema_path:
            record_model, rows, instructions = model_from_json(config.schema_path)
        else:
            record_model, rows, instructions = model_from_dict(config.record_schema)
    except (FileNotFoundError, ValueError) as e:
        print(f"schematico: error: {e}", file=sys.stderr)
        sys.exit(1)

    # NOTE(review): config.count is not consulted here. The request size comes
    # from count_override, else the schema's rows. Confirm this is intended.
    samples = count_override if count_override is not None else rows
    table = record_model.__name__.removesuffix("Record")

    logger.info(
        "Running %s for '%s': %d fields, %d records, model=%s",
        config.mode,
        table,
        len(record_model.model_fields),
        samples,
        model_str,
    )

    from pydantic_ai.exceptions import UserError

    run_fn = run_discovery if config.mode == "discover" else run_generation

    reporter = ProgressReporter(table)
    try:
        records = run_fn(
            record_model,
            samples,
            instructions,
            model=llm_model,
            logfire_token=config.logfire_token or None,
            progress_cb=reporter.update,
        )
    except UserError as e:
        print(f"schematico: error: {e}", file=sys.stderr)
        sys.exit(1)
    reporter.done(len(records))

    # Directory creation is inside the try as well. A mkdir failure then gets
    # the same clean error as a write failure instead of a traceback.
    try:
        out = _resolve_output_file(output_path, config.name)
        with out.open("w", encoding="utf-8") as f:
            json.dump(records, f, indent=2, ensure_ascii=False)
    except OSError as e:
        print(f"schematico: error writing output: {e}", file=sys.stderr)
        sys.exit(1)

    logger.info("Wrote %d records to %s", len(records), out)
    print(f"Generated {len(records)} records from schema '{table}' -> {out}")
|
schematico/cli/wizard.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import typer
|
|
9
|
+
|
|
10
|
+
from schematico.cli.projects import (
|
|
11
|
+
DEFAULT_COUNT,
|
|
12
|
+
DEFAULT_ENV_KEY,
|
|
13
|
+
Mode,
|
|
14
|
+
ProjectConfig,
|
|
15
|
+
ensure_config_dir,
|
|
16
|
+
next_available_name,
|
|
17
|
+
save_project,
|
|
18
|
+
set_default,
|
|
19
|
+
)
|
|
20
|
+
from schematico.models import model_from_dict
|
|
21
|
+
|
|
22
|
+
# Starter schema that the wizard writes to disk when the user has no schema
# file yet. The placeholder table name and sample fields are meant to be
# edited by hand before the first run.
_TEMPLATE_SCHEMA: dict[str, Any] = dict(
    table="REPLACE_WITH_TABLE_NAME",
    rows=25,
    instructions="",
    fields=[
        dict(name="id", type="string", description="UUID v4"),
        dict(name="full_name", type="string", description="realistic full name"),
        dict(name="email", type="string", description="unique work email"),
        dict(name="role", type="enum", values=["admin", "editor", "viewer"]),
        dict(
            name="country",
            type="string",
            description="ISO 3166-1 alpha-2 country code",
        ),
        dict(
            name="created_at",
            type="string",
            description="ISO 8601 UTC timestamp",
        ),
    ],
)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _setup_schema(name: str, mode: Mode) -> tuple[dict[str, Any], str]:
    """Interactively choose where the project's record schema comes from.

    The user can either import an existing JSON schema file or get a starter
    template. An imported file is validated with ``model_from_dict`` and then
    either embedded into the config or referenced by path. A template is
    written to ``<config dir>/<name>.<mode>.schema.json``; an existing
    template is kept unless the user agrees to overwrite it.

    Args:
        name: Project name, used in the template file name.
        mode: Project mode, also used in the template file name.

    Returns:
        ``(record_schema, schema_path)``. When the schema is embedded, this is
        the parsed dict and ``""``. Otherwise it is ``{}`` and the path of the
        schema file.

    Raises:
        typer.Exit: with code 1 when the chosen file is missing, unreadable,
            not valid JSON, not a JSON object, or rejected by
            ``model_from_dict``.
    """
    if typer.confirm("Do you have an existing JSON schema file to use?", default=False):
        path_str = typer.prompt("Path to JSON schema file")
        p = Path(path_str).expanduser()
        if not p.exists():
            typer.echo(f"schematico: error: file not found: {p}", err=True)
            raise typer.Exit(1)
        try:
            raw = json.loads(p.read_text(encoding="utf-8"))
        except json.JSONDecodeError as e:
            typer.echo(f"schematico: error: invalid JSON in '{p}': {e}", err=True)
            raise typer.Exit(1)
        except (OSError, UnicodeDecodeError) as e:
            # exists() is also true for directories and unreadable files, and
            # the content may not be UTF-8.
            typer.echo(f"schematico: error: could not read '{p}': {e}", err=True)
            raise typer.Exit(1)
        # The schema is returned as a dict and raw.get() is used below, so a
        # top-level JSON array or scalar must be rejected up front.
        if not isinstance(raw, dict):
            typer.echo(
                f"schematico: error: schema in '{p}' must be a JSON object, "
                f"not {type(raw).__name__}.",
                err=True,
            )
            raise typer.Exit(1)
        try:
            model_from_dict(raw)
        except ValueError as e:
            typer.echo(f"schematico: error: {e}", err=True)
            raise typer.Exit(1)

        embed = typer.confirm(
            "Embed the schema into the config (vs. reference it by path)?",
            default=True,
        )
        if embed:
            typer.echo(
                f"Imported schema from '{p}' ({len(raw.get('fields', []))} fields)."
            )
            return raw, ""
        typer.echo(f"Referencing schema at '{p}'.")
        return {}, str(p)

    # No existing file: create (or keep) a starter template in the config dir.
    tpl = ensure_config_dir() / f"{name}.{mode}.schema.json"
    if tpl.exists() and not typer.confirm(
        f"{tpl} already exists. Overwrite with a fresh template?", default=False
    ):
        typer.echo(f"Keeping existing template at '{tpl}'. Edit it before running.")
        return {}, str(tpl)
    tpl.write_text(json.dumps(_TEMPLATE_SCHEMA, indent=2) + "\n", encoding="utf-8")
    typer.echo(f"Created template schema at '{tpl}'.")
    typer.echo(
        "Edit it (replace REPLACE_WITH_TABLE_NAME, adjust fields, "
        "optionally add 'instructions') before running."
    )
    return {}, str(tpl)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _prompt_mode() -> Mode:
    """Ask for the project mode, accepting 'd'/'discover' or 'g'/'generate'.

    Raises:
        typer.Exit: with code 1 for any other answer.
    """
    mode_input = (
        typer.prompt("Mode ([d]iscover/[g]enerate)", default="discover").strip().lower()
    )
    if mode_input in ("d", "discover"):
        return "discover"
    if mode_input in ("g", "generate"):
        return "generate"
    typer.echo(
        f"Invalid mode '{mode_input}'. Use 'd'/'discover' or 'g'/'generate'.",
        err=True,
    )
    raise typer.Exit(1)


def _prompt_env_key() -> str:
    """Ask for the name of the env var holding the API key.

    Returns ``""`` when the user opts out for keyless/local models.
    ``typer.prompt`` replaces an empty answer with the default, so "leave it
    blank" cannot opt out while ``DEFAULT_ENV_KEY`` is non-empty. An explicit
    ``none`` (or ``-``) is therefore accepted as "no API key".
    """
    answer = typer.prompt(
        "Name of env var that holds your API key (not the key itself; "
        "enter 'none' for local/keyless models like ollama)",
        default=DEFAULT_ENV_KEY,
    ).strip()
    if answer.lower() in ("none", "-"):
        return ""
    return answer


def run_wizard() -> Path:
    """Interactively create a project config, save it, and make it the default.

    Prompts, in order, for:

    - project name and mode;
    - output directory and record count;
    - model, API-key env var, and base URL;
    - an optional Logfire token;
    - the schema source.

    The config is then saved with ``save_project`` and registered as the
    default for its mode with ``set_default``.

    Returns:
        The path returned by ``save_project`` for the new config.

    Raises:
        typer.Exit: with code 1 for an invalid mode, a count below 1, or a
            schema-file problem reported by ``_setup_schema``.
    """
    name_input = typer.prompt("Project name", default="project_1")
    mode = _prompt_mode()

    name = next_available_name(name_input, mode)
    if name != name_input:
        typer.echo(f"Name '{name_input}' is taken; using '{name}' instead.")

    output_path = typer.prompt(
        "Default output directory (filename will be <project>_<timestamp>.json)",
        default="./.schematico/output",
    )
    count = typer.prompt("Count (records to generate)", default=DEFAULT_COUNT, type=int)
    if count < 1:
        typer.echo(f"Invalid count {count}. Use a positive integer.", err=True)
        raise typer.Exit(1)
    model = typer.prompt(
        "Model (e.g. gateway/anthropic:claude-sonnet-4-6; "
        "leave blank to inherit PAI_MODEL)",
        default=os.environ.get("PAI_MODEL", ""),
        show_default=True,
    ).strip()
    env_key = _prompt_env_key()
    base_url = typer.prompt(
        "Custom base URL (leave blank to use the provider's default; "
        "e.g. http://localhost:11434/v1 for ollama)",
        default="",
        show_default=False,
    ).strip()
    logfire_token = ""
    if typer.confirm(
        "Enable Logfire observability? (You'll need a Logfire write token. "
        "Decline to log to stdout only.)",
        default=False,
    ):
        # The token is a credential; keep it out of the terminal scrollback.
        logfire_token = typer.prompt(
            "Paste your Logfire write token (e.g. pylf_v1_...)",
            hide_input=True,
        ).strip()

    record_schema, schema_path = _setup_schema(name, mode)

    cfg = ProjectConfig(
        name=name,
        mode=mode,
        model=model,
        env_key=env_key,
        base_url=base_url,
        output_path=output_path,
        count=count,
        logfire_token=logfire_token,
        schema_path=schema_path,
        record_schema=record_schema,
    )
    path = save_project(cfg)
    set_default(mode, name)

    typer.echo("")
    typer.echo(f"Created {path}")
    typer.echo(f"Set as default for `schematico {mode}`.")
    if record_schema:
        typer.echo(f"Run `schematico {mode}` from this directory when ready.")
    else:
        typer.echo(
            f"Edit '{schema_path}', then run `schematico {mode}` from this directory."
        )
    return path
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
if __name__ == "__main__":
    # Standalone entry point: executing this module directly runs the
    # interactive wizard through typer.
    typer.run(run_wizard)
|
schematico/discovery.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
import logfire
|
|
2
|
+
from pydantic import BaseModel
|
|
3
|
+
from pydantic_ai.models import Model
|
|
4
|
+
from pydantic_ai.agent import Agent
|
|
5
|
+
from typing import Callable
|
|
6
|
+
|
|
7
|
+
from schematico.helpers import _table_name, _hash_record, _describe_fields
|
|
8
|
+
from schematico.logging import get_logger
|
|
9
|
+
from schematico.models import build_batch_model
|
|
10
|
+
from schematico.providers import DEFAULT_MODEL
|
|
11
|
+
from schematico.tools.tavily_tools import (
|
|
12
|
+
search_web,
|
|
13
|
+
extract_web_content,
|
|
14
|
+
crawl_paths,
|
|
15
|
+
map_website,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
# Module logger; get_logger() returns the child "schematico.core.discovery" of
# the package logger that schematico.logging configures.
logger = get_logger("core.discovery")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _build_prompt(schema: type[BaseModel], samples: int, instructions: str) -> str:
    """Compose the system prompt for the web-discovery agent.

    The prompt has blank-line-separated sections:

    - a header naming the table and listing each field, as rendered by
      ``_describe_fields``;
    - the rules (uniqueness, enum values, numeric ranges, exact count, use of
      the Tavily tools);
    - the caller's ``instructions``, only when they are non-empty.
    """
    rule_lines = [
        "- Every record must be unique across all fields.",
        "- Enum fields must use only the declared values.",
        "- Numeric fields must respect any declared min/max range.",
        "- Return exactly the requested number of records.",
        "- Use the tavily tools to find the records.",
    ]
    header = (
        f"You are a data discovery agent for the '{_table_name(schema)}' table.\n"
        f"Find exactly {samples} realistic, unique records with these fields:"
    )
    sections = [
        header + "\n" + "\n".join(_describe_fields(schema)),
        "Rules:\n" + "\n".join(rule_lines),
    ]
    if instructions:
        sections.append(f"Additional instructions:\n{instructions}")
    return "\n\n".join(sections)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def build_agent(
    schema: type[BaseModel],
    samples: int,
    instructions: str = "",
    model: str | Model | None = None,
) -> Agent:
    """Create the discovery agent for ``schema``.

    The agent's output type is the batch model built from ``schema``. Its
    system prompt comes from ``_build_prompt``, and it can call the Tavily
    ``search_web``, ``extract_web_content``, ``crawl_paths`` and
    ``map_website`` tools.

    Args:
        schema: Pydantic model describing one record.
        samples: Number of records the prompt asks for.
        instructions: Extra guidance appended to the system prompt.
        model: Model name or instance; ``DEFAULT_MODEL`` when ``None``.
    """
    chosen: str | Model = DEFAULT_MODEL if model is None else model
    logger.debug(
        "Building agent for '%s' with model %r", _table_name(schema), chosen
    )
    output_model = build_batch_model(schema)
    system_prompt = _build_prompt(schema, samples, instructions)
    return Agent(
        model=chosen,
        output_type=output_model,
        system_prompt=system_prompt,
        tools=[search_web, extract_web_content, crawl_paths, map_website],
    )
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def run_discovery(
    schema: type[BaseModel],
    samples: int,
    instructions: str = "",
    model: str | Model | None = None,
    logfire_token: str | None = None,
    progress_cb: Callable[[int, int, str], None] | None = None,
) -> list[dict]:
    """Search the web for ``samples`` records matching ``schema``.

    The function:

    1. Configures Logfire, exporting only when ``logfire_token`` is set.
    2. Instruments pydantic-ai.
    3. Runs the Tavily-equipped discovery agent once.
    4. Drops records whose content hash repeats an earlier record.

    Args:
        schema: Pydantic model describing one record.
        samples: Number of records to ask the agent for.
        instructions: Extra free-form guidance appended to the system prompt.
        model: Model name or instance; ``DEFAULT_MODEL`` when ``None``.
        logfire_token: Logfire write token; when falsy nothing is sent.
        progress_cb: Called once per returned record with
            ``(unique_so_far, samples, "found" | "duplicate")``.

    Returns:
        The unique records as plain dicts, in the order the agent returned
        them. The agent's output is not truncated or padded, so the length
        may differ from ``samples``.
    """
    # NOTE(review): scrubbing is disabled here but left at its default in
    # generator.run_generation; confirm which behaviour is intended.
    if logfire_token:
        logfire.configure(token=logfire_token, send_to_logfire=True, scrubbing=False)
    else:
        logfire.configure(send_to_logfire=False, scrubbing=False)
    logfire.instrument_pydantic_ai()

    table = _table_name(schema)
    logger.info(
        "Starting discovery run for '%s' (%d records requested)", table, samples
    )
    agent = build_agent(schema, samples, instructions, model=model)
    result = agent.run_sync(
        f"Find exactly {samples} unique records for the '{table}' table."
    )
    logger.debug("Agent returned %d raw records", len(result.output.records))

    # Keyed by content hash so the first occurrence of each record wins.
    seen: dict[str, dict] = {}
    duplicates = 0
    for record in result.output.records:
        record_dict = record.model_dump()
        h = _hash_record(record_dict)
        if h in seen:
            duplicates += 1
            if progress_cb:
                progress_cb(len(seen), samples, "duplicate")
            continue
        seen[h] = record_dict
        if progress_cb:
            progress_cb(len(seen), samples, "found")

    logger.info(
        "Discovery run complete: %d unique records (%d duplicates discarded)",
        len(seen),
        duplicates,
    )
    return list(seen.values())
|
schematico/generator.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Callable
|
|
4
|
+
|
|
5
|
+
import logfire
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
from pydantic_ai import Agent
|
|
8
|
+
from pydantic_ai.models import Model
|
|
9
|
+
|
|
10
|
+
from schematico.helpers import _table_name, _hash_record, _describe_fields
|
|
11
|
+
from schematico.logging import get_logger
|
|
12
|
+
from schematico.models import build_batch_model
|
|
13
|
+
from schematico.providers import DEFAULT_MODEL
|
|
14
|
+
|
|
15
|
+
# Module logger; get_logger() returns the child "schematico.core.generator" of
# the package logger that schematico.logging configures.
logger = get_logger("core.generator")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _build_prompt(schema: type[BaseModel], samples: int, instructions: str) -> str:
    """Compose the system prompt for the synthetic-data generation agent.

    The prompt names the table, lists each field as rendered by
    ``_describe_fields``, and states the uniqueness, enum, range and
    exact-count rules. Non-empty ``instructions`` are appended as a final
    "Additional instructions" section.
    """
    table = _table_name(schema)
    fields_block = "\n".join(_describe_fields(schema))
    rules_block = "\n".join(
        (
            "- Every record must be unique across all fields.",
            "- Enum fields must use only the declared values.",
            "- Numeric fields must respect any declared min/max range.",
            "- Return exactly the requested number of records.",
        )
    )
    prompt = (
        f"You are a data generation agent for the '{table}' table.\n"
        f"Generate exactly {samples} realistic, unique records with these fields:\n"
        f"{fields_block}\n\n"
        f"Rules:\n{rules_block}"
    )
    if not instructions:
        return prompt
    return f"{prompt}\n\nAdditional instructions:\n{instructions}"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def build_agent(
    schema: type[BaseModel],
    samples: int,
    instructions: str = "",
    model: str | Model | None = None,
) -> Agent:
    """Create the tool-less generation agent for ``schema``.

    The agent's output type is the batch model built from ``schema``, and its
    system prompt comes from ``_build_prompt``.

    Args:
        schema: Pydantic model describing one record.
        samples: Number of records the prompt asks for.
        instructions: Extra guidance appended to the system prompt.
        model: Model name or instance; ``DEFAULT_MODEL`` when ``None``.
    """
    chosen: str | Model = DEFAULT_MODEL if model is None else model
    logger.debug(
        "Building agent for '%s' with model %r", _table_name(schema), chosen
    )
    return Agent(
        model=chosen,
        output_type=build_batch_model(schema),
        system_prompt=_build_prompt(schema, samples, instructions),
    )
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def run_generation(
    schema: type[BaseModel],
    samples: int,
    instructions: str = "",
    model: str | Model | None = None,
    logfire_token: str | None = None,
    progress_cb: Callable[[int, int, str], None] | None = None,
) -> list[dict]:
    """Synthesize ``samples`` records for ``schema`` with a single agent run.

    Logfire is configured first and exports only when ``logfire_token`` is
    given; pydantic-ai is then instrumented. Records whose content hash
    repeats an earlier record are discarded.

    Args:
        schema: Pydantic model describing one record.
        samples: Number of records to ask the agent for.
        instructions: Extra free-form guidance appended to the system prompt.
        model: Model name or instance; ``DEFAULT_MODEL`` when ``None``.
        logfire_token: Logfire write token; when falsy nothing is sent.
        progress_cb: Invoked per returned record with
            ``(unique_so_far, samples, "found" | "duplicate")``.

    Returns:
        Unique records as plain dicts, first occurrences in agent order.
    """
    if logfire_token:
        logfire_kwargs = {"token": logfire_token, "send_to_logfire": True}
    else:
        logfire_kwargs = {"send_to_logfire": False}
    logfire.configure(**logfire_kwargs)
    logfire.instrument_pydantic_ai()

    table = _table_name(schema)
    logger.info(
        "Starting generation run for '%s' (%d records requested)", table, samples
    )
    result = build_agent(schema, samples, instructions, model=model).run_sync(
        f"Generate exactly {samples} unique records for the '{table}' table."
    )
    raw_records = result.output.records
    logger.debug("Agent returned %d raw records", len(raw_records))

    unique: dict[str, dict] = {}
    dropped = 0
    for rec in raw_records:
        as_dict = rec.model_dump()
        key = _hash_record(as_dict)
        is_dup = key in unique
        if is_dup:
            dropped += 1
        else:
            unique[key] = as_dict
        if progress_cb:
            progress_cb(len(unique), samples, "duplicate" if is_dup else "found")

    logger.info(
        "Generation run complete: %d unique records (%d duplicates discarded)",
        len(unique),
        dropped,
    )
    return list(unique.values())
|
schematico/helpers.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import json
|
|
3
|
+
import re
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _table_name(schema: type[BaseModel]) -> str:
    """Derive a snake_case table name from a record model's class name.

    A trailing ``Record`` suffix is dropped, then CamelCase is split into
    lowercase words, e.g. ``UserAccountRecord`` -> ``user_account`` and
    ``HTTPServerRecord`` -> ``http_server``.
    """
    base = schema.__name__.removesuffix("Record")
    # Pass 1: underscore before any Capitalized word that follows a character.
    spaced = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", base)
    # Pass 2: underscore between a lowercase letter/digit and a capital.
    spaced = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", spaced)
    return spaced.lower()
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _hash_record(record: dict) -> str:
    """Return a SHA-256 hex digest identifying the content of ``record``.

    Keys are sorted before hashing, so two dicts with the same items hash
    equally regardless of insertion order. Values that JSON cannot encode
    natively are serialized through ``str()``.
    """
    payload = json.dumps(record, sort_keys=True, default=str).encode()
    digest = hashlib.sha256(payload)
    return digest.hexdigest()
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _resolve_json_ref(
    prop: dict[str, Any], defs: dict[str, dict[str, Any]]
) -> dict[str, Any]:
    """Inline a local ``#/$defs/<Name>`` reference in a JSON-schema property.

    Handles both ``{"$ref": ...}`` and the single-element
    ``{"allOf": [{"$ref": ...}]}`` wrapper. Keys set on the property itself,
    such as its description, win over keys from the referenced definition.
    Properties without a resolvable local reference are returned unchanged.
    """
    ref = prop.get("$ref")
    if ref is None:
        all_of = prop.get("allOf")
        if isinstance(all_of, list) and len(all_of) == 1 and isinstance(all_of[0], dict):
            ref = all_of[0].get("$ref")
    prefix = "#/$defs/"
    if not isinstance(ref, str) or not ref.startswith(prefix):
        return prop
    target = defs.get(ref[len(prefix):])
    if not isinstance(target, dict):
        return prop
    merged = {**target, **prop}
    merged.pop("$ref", None)
    merged.pop("allOf", None)
    return merged


def _describe_fields(schema: type[BaseModel]) -> list[str]:
    """Render one prompt line per property of ``schema``'s JSON schema.

    Each line looks like
    ``- name: type — description (must be exactly one of: [...]) (range: lo to hi)``.
    The description, enum and range parts appear only when declared.
    Properties that point into ``$defs`` via ``$ref`` (for example Enum-typed
    fields) are resolved first, so their type and allowed values are not lost.
    Without that step they would render as ``any`` with no allowed values.
    """
    json_schema = schema.model_json_schema()
    defs: dict[str, dict[str, Any]] = json_schema.get("$defs", {})
    properties: dict[str, dict[str, Any]] = json_schema.get("properties", {})
    lines: list[str] = []
    for name, raw_prop in properties.items():
        prop = _resolve_json_ref(raw_prop, defs)
        ptype = prop.get("type") or prop.get("anyOf") or "any"
        line = f"- {name}: {ptype}"
        desc = prop.get("description")
        if desc:
            line += f" — {desc}"
        if "enum" in prop:
            line += f" (must be exactly one of: {prop['enum']})"
        lo, hi = prop.get("minimum"), prop.get("maximum")
        if lo is not None or hi is not None:
            line += f" (range: {lo} to {hi})"
        lines.append(line)
    return lines
|
schematico/logging.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
|
|
7
|
+
# Name of the package-level logger; get_logger() returns "<_ROOT_NAME>.<name>"
# children of it, and configure_logging() attaches the handler here.
_ROOT_NAME: str = "schematico"
# Set to True once configure_logging() has run, making later calls no-ops.
_configured: bool = False
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def configure_logging() -> None:
    """Install the package's stderr log handler; later calls are no-ops.

    The level is read from the ``LOG_LEVEL`` environment variable. It may be
    a level name in any case (``debug``, ``INFO``, ...) or a decimal number
    (``15``). Surrounding whitespace is ignored, and anything unrecognized
    falls back to ``WARNING``. The ``schematico`` logger is set not to
    propagate, so its records are emitted only through this handler.
    """
    global _configured
    if _configured:
        return

    level_name = os.environ.get("LOG_LEVEL", "WARNING").strip().upper()
    if level_name.isdecimal():
        level = int(level_name)
    else:
        # getattr alone can return non-level module attributes (e.g. the
        # string logging.BASIC_FORMAT), which setLevel() would reject.
        candidate = getattr(logging, level_name, None)
        level = candidate if isinstance(candidate, int) else logging.WARNING

    handler = logging.StreamHandler(sys.stderr)
    handler.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)-8s %(name)s: %(message)s")
    )

    root = logging.getLogger(_ROOT_NAME)
    root.setLevel(level)
    root.addHandler(handler)
    root.propagate = False

    _configured = True
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def get_logger(name: str) -> logging.Logger:
    """Return the ``schematico.<name>`` logger, setting up logging on first use.

    Args:
        name: Dotted suffix appended to the package root logger name.
    """
    configure_logging()  # no-op after the first call
    logger_name = f"{_ROOT_NAME}.{name}"
    return logging.getLogger(logger_name)
|