sqlseed-ai 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlseed_ai-0.1.0/.gitignore +69 -0
- sqlseed_ai-0.1.0/PKG-INFO +19 -0
- sqlseed_ai-0.1.0/pyproject.toml +37 -0
- sqlseed_ai-0.1.0/src/sqlseed_ai/__init__.py +88 -0
- sqlseed_ai-0.1.0/src/sqlseed_ai/_client.py +31 -0
- sqlseed_ai-0.1.0/src/sqlseed_ai/_json_utils.py +24 -0
- sqlseed_ai-0.1.0/src/sqlseed_ai/analyzer.py +304 -0
- sqlseed_ai-0.1.0/src/sqlseed_ai/config.py +20 -0
- sqlseed_ai-0.1.0/src/sqlseed_ai/errors.py +119 -0
- sqlseed_ai-0.1.0/src/sqlseed_ai/examples.py +172 -0
- sqlseed_ai-0.1.0/src/sqlseed_ai/nl_config.py +80 -0
- sqlseed_ai-0.1.0/src/sqlseed_ai/provider.py +88 -0
- sqlseed_ai-0.1.0/src/sqlseed_ai/refiner.py +277 -0
- sqlseed_ai-0.1.0/src/sqlseed_ai/suggest.py +62 -0
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
# Python
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[cod]
|
|
4
|
+
*$py.class
|
|
5
|
+
*.so
|
|
6
|
+
.Python
|
|
7
|
+
build/
|
|
8
|
+
develop-eggs/
|
|
9
|
+
dist/
|
|
10
|
+
downloads/
|
|
11
|
+
eggs/
|
|
12
|
+
.eggs/
|
|
13
|
+
lib/
|
|
14
|
+
lib64/
|
|
15
|
+
parts/
|
|
16
|
+
sdist/
|
|
17
|
+
var/
|
|
18
|
+
wheels/
|
|
19
|
+
*.egg-info/
|
|
20
|
+
.installed.cfg
|
|
21
|
+
*.egg
|
|
22
|
+
|
|
23
|
+
# Virtual Environment
|
|
24
|
+
.venv/
|
|
25
|
+
env/
|
|
26
|
+
ENV/
|
|
27
|
+
|
|
28
|
+
# IDE
|
|
29
|
+
.vscode/
|
|
30
|
+
.idea/
|
|
31
|
+
*.swp
|
|
32
|
+
*.swo
|
|
33
|
+
*~
|
|
34
|
+
|
|
35
|
+
# Testing
|
|
36
|
+
.pytest_cache/
|
|
37
|
+
.coverage
|
|
38
|
+
htmlcov/
|
|
39
|
+
|
|
40
|
+
# Type checking
|
|
41
|
+
.mypy_cache/
|
|
42
|
+
.dmypy.json
|
|
43
|
+
dmypy.json
|
|
44
|
+
|
|
45
|
+
# Linting
|
|
46
|
+
.ruff_cache/
|
|
47
|
+
|
|
48
|
+
# Project specific
|
|
49
|
+
*.db
|
|
50
|
+
*.sqlite
|
|
51
|
+
*.sqlite3
|
|
52
|
+
snapshots/
|
|
53
|
+
|
|
54
|
+
# AI cache
|
|
55
|
+
.sqlseed_cache/
|
|
56
|
+
|
|
57
|
+
# Archived temp files
|
|
58
|
+
_archived_temp/
|
|
59
|
+
|
|
60
|
+
# macOS
|
|
61
|
+
.DS_Store
|
|
62
|
+
|
|
63
|
+
# Trae IDE
|
|
64
|
+
.trae/
|
|
65
|
+
|
|
66
|
+
# Build artifacts
|
|
67
|
+
dist/
|
|
68
|
+
*.whl
|
|
69
|
+
*.tar.gz
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: sqlseed-ai
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: AI-powered data generation plugin for sqlseed
|
|
5
|
+
Project-URL: Homepage, https://github.com/sunbos/sqlseed
|
|
6
|
+
Project-URL: Repository, https://github.com/sunbos/sqlseed/tree/main/plugins/sqlseed-ai
|
|
7
|
+
Author-email: SunBo <1443584939@qq.com>
|
|
8
|
+
License-Expression: AGPL-3.0-or-later
|
|
9
|
+
Classifier: Development Status :: 3 - Alpha
|
|
10
|
+
Classifier: License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)
|
|
11
|
+
Classifier: Programming Language :: Python :: 3
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
16
|
+
Requires-Python: >=3.10
|
|
17
|
+
Requires-Dist: openai>=1.0
|
|
18
|
+
Requires-Dist: pydantic>=2.0
|
|
19
|
+
Requires-Dist: sqlseed
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "sqlseed-ai"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
requires-python = ">=3.10"
|
|
9
|
+
description = "AI-powered data generation plugin for sqlseed"
|
|
10
|
+
license = "AGPL-3.0-or-later"
|
|
11
|
+
authors = [
|
|
12
|
+
{name = "SunBo", email = "1443584939@qq.com"},
|
|
13
|
+
]
|
|
14
|
+
classifiers = [
|
|
15
|
+
"Development Status :: 3 - Alpha",
|
|
16
|
+
"License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)",
|
|
17
|
+
"Programming Language :: Python :: 3",
|
|
18
|
+
"Programming Language :: Python :: 3.10",
|
|
19
|
+
"Programming Language :: Python :: 3.11",
|
|
20
|
+
"Programming Language :: Python :: 3.12",
|
|
21
|
+
"Programming Language :: Python :: 3.13",
|
|
22
|
+
]
|
|
23
|
+
dependencies = [
|
|
24
|
+
"sqlseed",
|
|
25
|
+
"openai>=1.0",
|
|
26
|
+
"pydantic>=2.0",
|
|
27
|
+
]
|
|
28
|
+
|
|
29
|
+
[project.urls]
|
|
30
|
+
Homepage = "https://github.com/sunbos/sqlseed"
|
|
31
|
+
Repository = "https://github.com/sunbos/sqlseed/tree/main/plugins/sqlseed-ai"
|
|
32
|
+
|
|
33
|
+
[project.entry-points."sqlseed"]
|
|
34
|
+
ai = "sqlseed_ai:plugin"
|
|
35
|
+
|
|
36
|
+
[tool.hatch.build.targets.wheel]
|
|
37
|
+
packages = ["src/sqlseed_ai"]
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from sqlseed.plugins.hookspecs import hookimpl
|
|
7
|
+
from sqlseed_ai.analyzer import SchemaAnalyzer
|
|
8
|
+
|
|
9
|
+
_SIMPLE_COL_RE = re.compile(
|
|
10
|
+
r'(^|[_\s])('
|
|
11
|
+
r'name|email|phone|address|url|uuid|'
|
|
12
|
+
r'date|time|datetime|timestamp|boolean|'
|
|
13
|
+
r'int|float|double|real|text|string|'
|
|
14
|
+
r'char|varchar|blob|byte|id|code|title|'
|
|
15
|
+
r'description|status|type|category|count|'
|
|
16
|
+
r'amount|price|value|number|index|order|level'
|
|
17
|
+
r')($|[_\s])',
|
|
18
|
+
re.IGNORECASE,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class AISqlseedPlugin:
    """sqlseed plugin implementation exposing AI-backed hooks.

    The underlying :class:`SchemaAnalyzer` is created lazily so that merely
    importing the plugin never triggers AI configuration loading.
    """

    def __init__(self) -> None:
        # Lazily-created shared analyzer instance.
        self._analyzer: SchemaAnalyzer | None = None

    def _get_analyzer(self) -> SchemaAnalyzer:
        """Return the shared analyzer, creating it on first use."""
        if self._analyzer is None:
            self._analyzer = SchemaAnalyzer()
        return self._analyzer

    def _is_simple_column(self, column_name: str, column_type: str) -> bool:
        """True when the column name or type matches a well-known keyword."""
        for candidate in (column_name, column_type):
            if _SIMPLE_COL_RE.search(candidate):
                return True
        return False

    @hookimpl
    def sqlseed_ai_analyze_table(
        self,
        table_name: str,
        columns: list[Any],
        indexes: list[dict[str, Any]],
        sample_data: list[dict[str, Any]],
        foreign_keys: list[Any],
        all_table_names: list[str],
    ) -> dict[str, Any] | None:
        """Delegate whole-table schema analysis to the shared analyzer."""
        return self._get_analyzer().analyze_table(
            table_name=table_name,
            columns=columns,
            indexes=indexes,
            sample_data=sample_data,
            foreign_keys=foreign_keys,
            all_table_names=all_table_names,
        )

    @hookimpl
    def sqlseed_pre_generate_templates(
        self,
        table_name: str,
        column_name: str,
        column_type: str,
        count: int,
        sample_data: list[Any],
    ) -> list[Any] | None:
        """Ask the LLM for template values for hard-to-generate columns.

        Returns ``None`` — letting sqlseed fall back to its built-in
        generators — both for "simple" columns and whenever the AI call
        fails for any reason.
        """
        if self._is_simple_column(column_name, column_type):
            return None

        try:
            # Cap the request at 50 values to bound token usage per column.
            return self._get_analyzer().generate_template_values(
                column_name=column_name,
                column_type=column_type,
                count=min(count, 50),
                sample_data=sample_data,
            )
        except Exception:
            # Best effort: any AI failure silently falls back to built-ins.
            return None

    @hookimpl
    def sqlseed_register_providers(self, registry: Any) -> None:
        """Hook stub: this plugin registers no value providers."""

    @hookimpl
    def sqlseed_register_column_mappers(self, mapper: Any) -> None:
        """Hook stub: this plugin registers no column mappers."""
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
# Singleton plugin instance; exposed to sqlseed via the entry point declared
# in pyproject.toml: [project.entry-points."sqlseed"] ai = "sqlseed_ai:plugin".
plugin = AISqlseedPlugin()
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from sqlseed._utils.logger import get_logger
|
|
6
|
+
|
|
7
|
+
# Module-level logger obtained from sqlseed's logging helper.
# NOTE(review): currently unused within this module — kept for parity with
# the other modules; confirm before removing.
logger = get_logger(__name__)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def get_openai_client(config: Any | None = None) -> Any:
|
|
11
|
+
try:
|
|
12
|
+
from openai import OpenAI
|
|
13
|
+
|
|
14
|
+
from sqlseed_ai.config import AIConfig
|
|
15
|
+
|
|
16
|
+
if config is None:
|
|
17
|
+
config = AIConfig.from_env()
|
|
18
|
+
|
|
19
|
+
api_key = config.api_key if hasattr(config, "api_key") else None
|
|
20
|
+
base_url = config.base_url if hasattr(config, "base_url") else None
|
|
21
|
+
|
|
22
|
+
if not api_key:
|
|
23
|
+
raise ValueError(
|
|
24
|
+
"AI API key not configured. Set SQLSEED_AI_API_KEY or OPENAI_API_KEY environment variable."
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
return OpenAI(api_key=api_key, base_url=base_url)
|
|
28
|
+
except ImportError:
|
|
29
|
+
raise ImportError(
|
|
30
|
+
"openai is not installed. Install it with: pip install sqlseed-ai"
|
|
31
|
+
) from None
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def parse_json_response(content: str) -> dict[str, Any]:
    """Parse an LLM response into a JSON object, tolerating markdown fences.

    Strips an optional leading ```` ```json ```` or ```` ``` ```` fence and a
    trailing ```` ``` ```` fence before parsing.

    Args:
        content: Raw text returned by the model.

    Returns:
        The parsed object, or ``{}`` when the text is not valid JSON or its
        top-level value is not an object — callers never see an exception.
    """
    cleaned = content.strip()
    # Tolerate models that wrap their JSON in a fenced code block despite
    # being told not to.  removeprefix/removesuffix (3.9+) are no-ops when
    # the fence is absent, matching the original slice-based logic.
    cleaned = cleaned.removeprefix("```json").removeprefix("```")
    cleaned = cleaned.removesuffix("```")
    cleaned = cleaned.strip()

    try:
        result = json.loads(cleaned)
    except json.JSONDecodeError:
        return {}

    # Non-object top-level values (arrays, scalars) count as "no result".
    return result if isinstance(result, dict) else {}
|
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from sqlseed._utils.logger import get_logger
|
|
6
|
+
from sqlseed_ai._client import get_openai_client
|
|
7
|
+
from sqlseed_ai.config import AIConfig
|
|
8
|
+
|
|
9
|
+
# Module-level logger from sqlseed's logging helper; used below with keyword
# context (e.g. table_name=..., error=...), i.e. structured-style calls.
logger = get_logger(__name__)
|
|
10
|
+
|
|
11
|
+
SYSTEM_PROMPT = """You are an expert database test data engineer.
|
|
12
|
+
You analyze SQLite table schemas and recommend data generation configurations for the sqlseed toolkit.
|
|
13
|
+
|
|
14
|
+
## Available Generators
|
|
15
|
+
- string (params: min_length, max_length, charset)
|
|
16
|
+
- integer (params: min_value, max_value)
|
|
17
|
+
- float (params: min_value, max_value, precision)
|
|
18
|
+
- boolean
|
|
19
|
+
- bytes (params: length)
|
|
20
|
+
- name, first_name, last_name
|
|
21
|
+
- email, phone, address, company
|
|
22
|
+
- url, ipv4, uuid
|
|
23
|
+
- date (params: start_year, end_year)
|
|
24
|
+
- datetime (params: start_year, end_year)
|
|
25
|
+
- timestamp
|
|
26
|
+
- text (params: min_length, max_length)
|
|
27
|
+
- sentence, password
|
|
28
|
+
- choice (params: choices)
|
|
29
|
+
- json (params: schema)
|
|
30
|
+
- pattern (params: regex) — generates strings matching a regex pattern
|
|
31
|
+
|
|
32
|
+
## Key Rules
|
|
33
|
+
1. INTEGER PRIMARY KEY AUTOINCREMENT columns → do NOT include (auto-skip)
|
|
34
|
+
2. Columns with DEFAULT values → do NOT include (auto-skip)
|
|
35
|
+
3. Nullable columns → do NOT include unless they have semantic meaning
|
|
36
|
+
4. Use `pattern` generator with regex for card numbers, codes, IDs with specific formats
|
|
37
|
+
5. Use `derive_from` + `expression` when one column is computed from another
|
|
38
|
+
6. Use `constraints.unique: true` for columns that must be unique
|
|
39
|
+
7. Detect cross-column dependencies: if CutCard4byte = last 8 chars of sCardNo, use derive_from
|
|
40
|
+
8. Detect implicit business associations: if sUserNo appears in multiple tables, note it
|
|
41
|
+
|
|
42
|
+
## Output Format
|
|
43
|
+
You MUST respond with a valid JSON object (NOT YAML, NOT markdown fences).
|
|
44
|
+
The JSON object must have this exact structure:
|
|
45
|
+
{
|
|
46
|
+
"name": "table_name",
|
|
47
|
+
"count": 1000,
|
|
48
|
+
"columns": [
|
|
49
|
+
{
|
|
50
|
+
"name": "column_name",
|
|
51
|
+
"generator": "generator_name",
|
|
52
|
+
"params": {"key": "value"}
|
|
53
|
+
},
|
|
54
|
+
{
|
|
55
|
+
"name": "derived_column",
|
|
56
|
+
"derive_from": "source_column",
|
|
57
|
+
"expression": "value[-8:]",
|
|
58
|
+
"constraints": {"unique": true}
|
|
59
|
+
}
|
|
60
|
+
]
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
IMPORTANT: Do NOT include columns that are PRIMARY KEY AUTOINCREMENT or have DEFAULT values."""
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class SchemaAnalyzer:
    """Drives LLM-based analysis of SQLite table schemas.

    Builds a textual description of a table (columns, indexes, foreign keys,
    sample rows, optional value distributions), sends it to an
    OpenAI-compatible chat API together with SYSTEM_PROMPT and few-shot
    examples, and parses the JSON configuration the model returns.
    """

    def __init__(self, config: AIConfig | None = None) -> None:
        # May be None: configuration is lazily loaded from the environment on
        # first use (see analyze_table / call_llm).
        self._config = config

    def analyze_table(
        self,
        table_name: str,
        columns: list[Any],
        indexes: list[dict[str, Any]],
        sample_data: list[dict[str, Any]],
        foreign_keys: list[Any],
        all_table_names: list[str],
    ) -> dict[str, Any] | None:
        """Analyze one table and return a sqlseed config dict, or None.

        Returns None (instead of raising) when no API key is configured or
        when the LLM call fails, so callers can fall back gracefully.
        """
        if self._config is None:
            self._config = AIConfig.from_env()

        if not self._config.api_key:
            logger.warning("AI API key not configured, skipping analysis")
            return None

        messages = self.build_initial_messages(
            table_name=table_name,
            columns=columns,
            indexes=indexes,
            sample_data=sample_data,
            foreign_keys=foreign_keys,
            all_table_names=all_table_names,
        )

        try:
            return self.call_llm(messages)
        except Exception as e:
            # AI analysis is best-effort: any failure only warns and yields None.
            logger.warning("AI analysis failed", table_name=table_name, error=str(e))
            return None

    def build_initial_messages(
        self,
        table_name: str,
        columns: list[Any],
        indexes: list[dict[str, Any]],
        sample_data: list[dict[str, Any]],
        foreign_keys: list[Any],
        all_table_names: list[str],
        distribution_profiles: list[dict[str, Any]] | None = None,
    ) -> list[dict[str, str]]:
        """Assemble the chat message list: system prompt, few-shot pairs,
        then the table context as the final user message."""
        context = self._build_context(
            table_name=table_name,
            columns=columns,
            indexes=indexes,
            sample_data=sample_data,
            foreign_keys=foreign_keys,
            all_table_names=all_table_names,
            distribution_profiles=distribution_profiles,
        )
        messages: list[dict[str, str]] = [
            {"role": "system", "content": SYSTEM_PROMPT},
        ]

        # Imported here (not at module top) — presumably to avoid paying the
        # examples-module import cost until a prompt is actually built.
        from sqlseed_ai.examples import FEW_SHOT_EXAMPLES
        for example in FEW_SHOT_EXAMPLES:
            messages.append({"role": "user", "content": example["input"]})
            messages.append({"role": "assistant", "content": example["output"]})

        messages.append({"role": "user", "content": context})

        return messages

    def call_llm(self, messages: list[dict[str, str]]) -> dict[str, Any]:
        """Send *messages* to the chat API and parse the JSON reply.

        Raises:
            ValueError: When no API key is configured.
            RuntimeError: When the API call fails or returns no choices.
        """
        if self._config is None:
            self._config = AIConfig.from_env()
        if not self._config.api_key:
            raise ValueError("AI API key not configured")

        client = get_openai_client(self._config)
        try:
            # response_format forces a JSON object reply on compatible models.
            response = client.chat.completions.create(
                model=self._config.model,
                messages=messages,
                max_tokens=self._config.max_tokens,
                temperature=self._config.temperature,
                response_format={"type": "json_object"},
            )
        except Exception as e:
            raise RuntimeError(
                f"LLM API call failed (model={self._config.model}): {e}"
            ) from e

        if not response.choices:
            raise RuntimeError(
                f"LLM returned no choices (model={self._config.model}). "
                "The API key or model may be invalid."
            )
        content = response.choices[0].message.content
        if content is None:
            # Empty completion: treat as "no recommendation".
            return {}
        return self._parse_json_response(content)

    # System prompt for the per-column template-value completion (used by
    # generate_template_values, distinct from the table-level SYSTEM_PROMPT).
    TEMPLATE_SYSTEM_PROMPT = (
        "You are a data generation assistant. Generate realistic sample values "
        "for the given database column. Return a JSON object with a 'values' "
        "array containing the requested number of unique, realistic values. "
        "Each value must be valid for the column type. Do NOT include explanations."
    )

    def generate_template_values(
        self,
        column_name: str,
        column_type: str,
        count: int,
        sample_data: list[Any],
    ) -> list[Any]:
        """Ask the LLM for *count* realistic sample values for one column.

        Returns the model's 'values' array, or [] when the reply lacks one.
        Exceptions from call_llm propagate to the caller.
        """
        prompt = (
            f"Generate {count} realistic sample values for a database column "
            f"named '{column_name}' with type '{column_type}'."
        )
        if sample_data:
            # Show at most five existing values so the model matches the style.
            prompt += f"\nExisting sample values: {sample_data[:5]}"
        prompt += (
            f"\nRespond with a JSON object: {{\"values\": [...]}}."
            f"\nEach value should be a valid {column_type} value."
        )

        messages = [
            {"role": "system", "content": self.TEMPLATE_SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]
        result = self.call_llm(messages)
        return result.get("values", [])

    def _build_context(
        self,
        table_name: str,
        columns: list[Any],
        indexes: list[dict[str, Any]],
        sample_data: list[dict[str, Any]],
        foreign_keys: list[Any],
        all_table_names: list[str],
        distribution_profiles: list[dict[str, Any]] | None = None,
    ) -> str:
        """Render the table description sent to the model as markdown-ish text."""
        lines: list[str] = []
        lines.append(f"# Table: {table_name}")
        lines.append("")

        self._append_columns_info(lines, columns)

        if indexes:
            self._append_indexes_info(lines, indexes)

        if foreign_keys:
            lines.append("")
            lines.append("## Foreign Keys")
            for fk in foreign_keys:
                # fk objects expose .column, .ref_table, .ref_column attributes.
                lines.append(f"- {fk.column} → {fk.ref_table}.{fk.ref_column}")

        if all_table_names:
            lines.append("")
            lines.append("## All Tables in Database")
            lines.append(", ".join(all_table_names))

        if sample_data:
            lines.append("")
            lines.append("## Sample Data (existing rows)")
            # Only the first five rows — enough context, bounded token cost.
            for i, row in enumerate(sample_data[:5]):
                row_str = ", ".join(f"{k}={v}" for k, v in row.items())
                lines.append(f" Row {i + 1}: {row_str}")

        if distribution_profiles:
            self._append_distribution_info(lines, distribution_profiles)

        lines.append("")
        lines.append(
            "Please analyze this table schema and recommend "
            "a complete sqlseed JSON configuration for generating test data."
        )

        return "\n".join(lines)

    def _append_columns_info(
        self,
        lines: list[str],
        columns: list[Any],
    ) -> None:
        """Append one '## Columns' bullet per column with its key attributes."""
        lines.append("## Columns")
        for col in columns:
            # col objects expose .name, .type, .is_primary_key,
            # .is_autoincrement, .nullable, .default attributes.
            parts = [f"- {col.name}: {col.type}"]
            if col.is_primary_key:
                parts.append("PRIMARY KEY")
            if col.is_autoincrement:
                parts.append("AUTOINCREMENT")
            if col.nullable:
                parts.append("NULLABLE")
            if col.default is not None:
                parts.append(f"DEFAULT={col.default}")
            if not col.nullable and col.default is None and not col.is_primary_key:
                parts.append("NOT NULL")
            lines.append(" ".join(parts))

    def _append_indexes_info(
        self,
        lines: list[str],
        indexes: list[dict[str, Any]],
    ) -> None:
        """Append an '## Indexes' section listing each index's columns."""
        lines.append("")
        lines.append("## Indexes")
        for idx in indexes:
            unique_str = "UNIQUE " if idx.get("unique") else ""
            cols_str = ", ".join(idx.get("columns", []))
            lines.append(f"- {unique_str}INDEX ({cols_str})")

    def _append_distribution_info(
        self,
        lines: list[str],
        distribution_profiles: list[dict[str, Any]],
    ) -> None:
        """Append per-column distribution stats (distinct count, null ratio,
        top values, numeric range) gathered from existing data."""
        lines.append("")
        lines.append("## Column Distribution (from existing data)")
        for profile in distribution_profiles:
            col = profile["column"]
            distinct = profile.get("distinct_count", "?")
            null_ratio = profile.get("null_ratio", 0)
            lines.append(
                f"- {col}: {distinct} distinct values, {null_ratio:.1%} null"
            )
            top_values = profile.get("top_values", [])
            if top_values:
                # Show only the three most frequent values with their share.
                top_str = ", ".join(
                    f"{tv['value']}({tv['frequency']:.0%})"
                    for tv in top_values[:3]
                )
                lines.append(f" Top values: {top_str}")
            vr = profile.get("value_range")
            if vr:
                lines.append(f" Range: [{vr['min']}, {vr['max']}]")

    def _parse_json_response(self, content: str) -> dict[str, Any]:
        """Delegate to the shared fence-tolerant JSON parser."""
        from sqlseed_ai._json_utils import parse_json_response

        return parse_json_response(content)
|
|
304
|
+
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class AIConfig(BaseModel):
    """Configuration for the OpenAI-compatible chat API (pydantic model).

    Values may be set directly or loaded from environment variables via
    :meth:`from_env`.
    """

    # API key; an LLM call is only attempted when this is set.
    api_key: str | None = None
    # Model name; the default targets an OpenAI-compatible endpoint.
    model: str = "qwen3-coder-plus"
    # Optional base URL override for OpenAI-compatible gateways.
    base_url: str | None = None
    # Sampling temperature, validated to the API's 0.0–2.0 range.
    temperature: float = Field(default=0.3, ge=0.0, le=2.0)
    # Completion token budget; must be positive.
    max_tokens: int = Field(default=4096, gt=0)

    @classmethod
    def from_env(cls) -> AIConfig:
        """Build a config from SQLSEED_AI_* (preferred) or OPENAI_* env vars.

        Temperature and max_tokens always use the field defaults; only the
        key, base URL, and model are read from the environment.
        """
        api_key = os.environ.get("SQLSEED_AI_API_KEY") or os.environ.get("OPENAI_API_KEY")
        base_url = os.environ.get("SQLSEED_AI_BASE_URL") or os.environ.get("OPENAI_BASE_URL")
        # Fall back to the field's declared default when the env var is unset
        # (pydantic v2: cls.model_fields[...].default).
        model = os.environ.get("SQLSEED_AI_MODEL", cls.model_fields["model"].default)
        return cls(api_key=api_key, base_url=base_url, model=model)
|