sqlseed-ai 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlseed_ai/__init__.py +88 -0
- sqlseed_ai/_client.py +31 -0
- sqlseed_ai/_json_utils.py +24 -0
- sqlseed_ai/analyzer.py +304 -0
- sqlseed_ai/config.py +20 -0
- sqlseed_ai/errors.py +119 -0
- sqlseed_ai/examples.py +172 -0
- sqlseed_ai/nl_config.py +80 -0
- sqlseed_ai/provider.py +88 -0
- sqlseed_ai/refiner.py +277 -0
- sqlseed_ai/suggest.py +62 -0
- sqlseed_ai-0.1.0.dist-info/METADATA +19 -0
- sqlseed_ai-0.1.0.dist-info/RECORD +15 -0
- sqlseed_ai-0.1.0.dist-info/WHEEL +4 -0
- sqlseed_ai-0.1.0.dist-info/entry_points.txt +2 -0
sqlseed_ai/__init__.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from sqlseed.plugins.hookspecs import hookimpl
|
|
7
|
+
from sqlseed_ai.analyzer import SchemaAnalyzer
|
|
8
|
+
|
|
9
|
+
_SIMPLE_COL_RE = re.compile(
|
|
10
|
+
r'(^|[_\s])('
|
|
11
|
+
r'name|email|phone|address|url|uuid|'
|
|
12
|
+
r'date|time|datetime|timestamp|boolean|'
|
|
13
|
+
r'int|float|double|real|text|string|'
|
|
14
|
+
r'char|varchar|blob|byte|id|code|title|'
|
|
15
|
+
r'description|status|type|category|count|'
|
|
16
|
+
r'amount|price|value|number|index|order|level'
|
|
17
|
+
r')($|[_\s])',
|
|
18
|
+
re.IGNORECASE,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class AISqlseedPlugin:
    """sqlseed plugin entry point wiring AI-backed schema analysis into the host."""

    def __init__(self) -> None:
        # Created on demand so importing the plugin stays cheap.
        self._analyzer: SchemaAnalyzer | None = None

    def _get_analyzer(self) -> SchemaAnalyzer:
        """Return the shared SchemaAnalyzer, creating it on first access."""
        analyzer = self._analyzer
        if analyzer is None:
            analyzer = SchemaAnalyzer()
            self._analyzer = analyzer
        return analyzer

    def _is_simple_column(self, column_name: str, column_type: str) -> bool:
        """True when the column name or its declared type contains a well-known
        simple token (name, email, varchar, ...)."""
        return any(
            _SIMPLE_COL_RE.search(text) is not None
            for text in (column_name, column_type)
        )

    @hookimpl
    def sqlseed_ai_analyze_table(
        self,
        table_name: str,
        columns: list[Any],
        indexes: list[dict[str, Any]],
        sample_data: list[dict[str, Any]],
        foreign_keys: list[Any],
        all_table_names: list[str],
    ) -> dict[str, Any] | None:
        """Hook: delegate whole-table analysis to the SchemaAnalyzer."""
        return self._get_analyzer().analyze_table(
            table_name=table_name,
            columns=columns,
            indexes=indexes,
            sample_data=sample_data,
            foreign_keys=foreign_keys,
            all_table_names=all_table_names,
        )

    @hookimpl
    def sqlseed_pre_generate_templates(
        self,
        table_name: str,
        column_name: str,
        column_type: str,
        count: int,
        sample_data: list[Any],
    ) -> list[Any] | None:
        """Hook: produce AI template values for non-trivial columns.

        Returns None (meaning "no opinion, use defaults") for simple columns
        and on any analyzer failure.
        """
        if self._is_simple_column(column_name, column_type):
            return None  # built-in generators already cover simple columns

        try:
            # Cap the request so a huge row count does not blow up the prompt.
            return self._get_analyzer().generate_template_values(
                column_name=column_name,
                column_type=column_type,
                count=min(count, 50),
                sample_data=sample_data,
            )
        except Exception:
            return None  # best-effort: fall back to default generation

    @hookimpl
    def sqlseed_register_providers(self, registry: Any) -> None:
        """Hook: this plugin registers no extra providers."""

    @hookimpl
    def sqlseed_register_column_mappers(self, mapper: Any) -> None:
        """Hook: this plugin registers no extra column mappers."""


# Module-level singleton picked up by the sqlseed plugin loader.
plugin = AISqlseedPlugin()
|
sqlseed_ai/_client.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from sqlseed._utils.logger import get_logger
|
|
6
|
+
|
|
7
|
+
logger = get_logger(__name__)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def get_openai_client(config: Any | None = None) -> Any:
|
|
11
|
+
try:
|
|
12
|
+
from openai import OpenAI
|
|
13
|
+
|
|
14
|
+
from sqlseed_ai.config import AIConfig
|
|
15
|
+
|
|
16
|
+
if config is None:
|
|
17
|
+
config = AIConfig.from_env()
|
|
18
|
+
|
|
19
|
+
api_key = config.api_key if hasattr(config, "api_key") else None
|
|
20
|
+
base_url = config.base_url if hasattr(config, "base_url") else None
|
|
21
|
+
|
|
22
|
+
if not api_key:
|
|
23
|
+
raise ValueError(
|
|
24
|
+
"AI API key not configured. Set SQLSEED_AI_API_KEY or OPENAI_API_KEY environment variable."
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
return OpenAI(api_key=api_key, base_url=base_url)
|
|
28
|
+
except ImportError:
|
|
29
|
+
raise ImportError(
|
|
30
|
+
"openai is not installed. Install it with: pip install sqlseed-ai"
|
|
31
|
+
) from None
|
|
sqlseed_ai/_json_utils.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def parse_json_response(content: str) -> dict[str, Any]:
    """Parse an LLM reply as a JSON object, tolerating markdown code fences.

    Strips a leading ```json / ``` fence and a trailing ``` fence, then
    attempts ``json.loads``.  Returns {} when the payload is not valid JSON
    or when the top-level value is not an object.
    """
    text = content.strip()
    # Mirror the original's sequential prefix handling: "```json" first,
    # then a bare "```" (covers fences without a language tag).
    text = text.removeprefix("```json").removeprefix("```")
    text = text.removesuffix("```").strip()

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return {}
    return parsed if isinstance(parsed, dict) else {}
|
sqlseed_ai/analyzer.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from sqlseed._utils.logger import get_logger
|
|
6
|
+
from sqlseed_ai._client import get_openai_client
|
|
7
|
+
from sqlseed_ai.config import AIConfig
|
|
8
|
+
|
|
9
|
+
# Module-level logger (project logging helper; the call sites pass structured
# kwargs, so this is presumably a structlog-style logger — confirm).
logger = get_logger(__name__)

# System prompt for the schema-analysis chat call.  Behavior-bearing text:
# the few-shot examples in sqlseed_ai/examples.py and the parser in
# _json_utils.py both rely on the exact output contract described here, so
# edit with care.
SYSTEM_PROMPT = """You are an expert database test data engineer.
You analyze SQLite table schemas and recommend data generation configurations for the sqlseed toolkit.

## Available Generators
- string (params: min_length, max_length, charset)
- integer (params: min_value, max_value)
- float (params: min_value, max_value, precision)
- boolean
- bytes (params: length)
- name, first_name, last_name
- email, phone, address, company
- url, ipv4, uuid
- date (params: start_year, end_year)
- datetime (params: start_year, end_year)
- timestamp
- text (params: min_length, max_length)
- sentence, password
- choice (params: choices)
- json (params: schema)
- pattern (params: regex) — generates strings matching a regex pattern

## Key Rules
1. INTEGER PRIMARY KEY AUTOINCREMENT columns → do NOT include (auto-skip)
2. Columns with DEFAULT values → do NOT include (auto-skip)
3. Nullable columns → do NOT include unless they have semantic meaning
4. Use `pattern` generator with regex for card numbers, codes, IDs with specific formats
5. Use `derive_from` + `expression` when one column is computed from another
6. Use `constraints.unique: true` for columns that must be unique
7. Detect cross-column dependencies: if CutCard4byte = last 8 chars of sCardNo, use derive_from
8. Detect implicit business associations: if sUserNo appears in multiple tables, note it

## Output Format
You MUST respond with a valid JSON object (NOT YAML, NOT markdown fences).
The JSON object must have this exact structure:
{
  "name": "table_name",
  "count": 1000,
  "columns": [
    {
      "name": "column_name",
      "generator": "generator_name",
      "params": {"key": "value"}
    },
    {
      "name": "derived_column",
      "derive_from": "source_column",
      "expression": "value[-8:]",
      "constraints": {"unique": true}
    }
  ]
}

IMPORTANT: Do NOT include columns that are PRIMARY KEY AUTOINCREMENT or have DEFAULT values."""
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class SchemaAnalyzer:
    """LLM-backed analyzer that turns a SQLite table schema into a sqlseed
    generation config (a dict) via an OpenAI-compatible chat API.

    ``analyze_table`` is the fault-tolerant entry point (returns ``None`` on
    any failure); ``build_initial_messages`` / ``call_llm`` are the lower-level
    pieces and raise on error so callers (e.g. a refiner) can retry.
    """

    def __init__(self, config: AIConfig | None = None) -> None:
        # May be None; resolved lazily from the environment on first use
        # (see analyze_table / call_llm).
        self._config = config

    def analyze_table(
        self,
        table_name: str,
        columns: list[Any],
        indexes: list[dict[str, Any]],
        sample_data: list[dict[str, Any]],
        foreign_keys: list[Any],
        all_table_names: list[str],
    ) -> dict[str, Any] | None:
        """Ask the LLM for a sqlseed generation config for *table_name*.

        Returns the parsed JSON dict, or None when no API key is configured
        or the LLM call raises (failures are logged, never propagated).
        """
        if self._config is None:
            self._config = AIConfig.from_env()

        if not self._config.api_key:
            logger.warning("AI API key not configured, skipping analysis")
            return None

        messages = self.build_initial_messages(
            table_name=table_name,
            columns=columns,
            indexes=indexes,
            sample_data=sample_data,
            foreign_keys=foreign_keys,
            all_table_names=all_table_names,
        )

        try:
            return self.call_llm(messages)
        except Exception as e:
            # Analysis is best-effort: callers fall back to defaults on None.
            logger.warning("AI analysis failed", table_name=table_name, error=str(e))
            return None

    def build_initial_messages(
        self,
        table_name: str,
        columns: list[Any],
        indexes: list[dict[str, Any]],
        sample_data: list[dict[str, Any]],
        foreign_keys: list[Any],
        all_table_names: list[str],
        distribution_profiles: list[dict[str, Any]] | None = None,
    ) -> list[dict[str, str]]:
        """Assemble the chat messages: system prompt, few-shot (user,
        assistant) pairs, then the table context from _build_context()."""
        context = self._build_context(
            table_name=table_name,
            columns=columns,
            indexes=indexes,
            sample_data=sample_data,
            foreign_keys=foreign_keys,
            all_table_names=all_table_names,
            distribution_profiles=distribution_profiles,
        )
        messages: list[dict[str, str]] = [
            {"role": "system", "content": SYSTEM_PROMPT},
        ]

        # Imported here, not at module top — presumably to keep module import
        # light / avoid a cycle; confirm before hoisting.
        from sqlseed_ai.examples import FEW_SHOT_EXAMPLES
        for example in FEW_SHOT_EXAMPLES:
            messages.append({"role": "user", "content": example["input"]})
            messages.append({"role": "assistant", "content": example["output"]})

        messages.append({"role": "user", "content": context})

        return messages

    def call_llm(self, messages: list[dict[str, str]]) -> dict[str, Any]:
        """Send *messages* to the configured model and parse the JSON reply.

        Raises:
            ValueError: when no API key is configured.
            RuntimeError: when the API call fails or returns no choices.

        Returns {} when the model responds with empty content.
        """
        if self._config is None:
            self._config = AIConfig.from_env()
        if not self._config.api_key:
            raise ValueError("AI API key not configured")

        client = get_openai_client(self._config)
        try:
            # JSON mode: the server is asked to emit a bare JSON object.
            response = client.chat.completions.create(
                model=self._config.model,
                messages=messages,
                max_tokens=self._config.max_tokens,
                temperature=self._config.temperature,
                response_format={"type": "json_object"},
            )
        except Exception as e:
            raise RuntimeError(
                f"LLM API call failed (model={self._config.model}): {e}"
            ) from e

        if not response.choices:
            raise RuntimeError(
                f"LLM returned no choices (model={self._config.model}). "
                "The API key or model may be invalid."
            )
        content = response.choices[0].message.content
        if content is None:
            return {}
        return self._parse_json_response(content)

    # System prompt for the single-column template-value calls below.
    TEMPLATE_SYSTEM_PROMPT = (
        "You are a data generation assistant. Generate realistic sample values "
        "for the given database column. Return a JSON object with a 'values' "
        "array containing the requested number of unique, realistic values. "
        "Each value must be valid for the column type. Do NOT include explanations."
    )

    def generate_template_values(
        self,
        column_name: str,
        column_type: str,
        count: int,
        sample_data: list[Any],
    ) -> list[Any]:
        """Ask the LLM for *count* realistic values for a single column.

        Returns the model's 'values' array ([] when absent); propagates any
        exception raised by call_llm().
        """
        prompt = (
            f"Generate {count} realistic sample values for a database column "
            f"named '{column_name}' with type '{column_type}'."
        )
        if sample_data:
            # At most five existing values, to anchor the style cheaply.
            prompt += f"\nExisting sample values: {sample_data[:5]}"
        prompt += (
            f"\nRespond with a JSON object: {{\"values\": [...]}}."
            f"\nEach value should be a valid {column_type} value."
        )

        messages = [
            {"role": "system", "content": self.TEMPLATE_SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]
        result = self.call_llm(messages)
        return result.get("values", [])

    def _build_context(
        self,
        table_name: str,
        columns: list[Any],
        indexes: list[dict[str, Any]],
        sample_data: list[dict[str, Any]],
        foreign_keys: list[Any],
        all_table_names: list[str],
        distribution_profiles: list[dict[str, Any]] | None = None,
    ) -> str:
        """Render schema/sample/distribution info as the markdown-style block
        the few-shot examples in examples.py were written against."""
        lines: list[str] = []
        lines.append(f"# Table: {table_name}")
        lines.append("")

        self._append_columns_info(lines, columns)

        if indexes:
            self._append_indexes_info(lines, indexes)

        if foreign_keys:
            lines.append("")
            lines.append("## Foreign Keys")
            for fk in foreign_keys:
                # Assumes fk records expose .column/.ref_table/.ref_column —
                # looks like the sqlseed adapter's FK shape; confirm.
                lines.append(f"- {fk.column} → {fk.ref_table}.{fk.ref_column}")

        if all_table_names:
            lines.append("")
            lines.append("## All Tables in Database")
            lines.append(", ".join(all_table_names))

        if sample_data:
            lines.append("")
            lines.append("## Sample Data (existing rows)")
            # Cap at five rows to keep the prompt small.
            for i, row in enumerate(sample_data[:5]):
                row_str = ", ".join(f"{k}={v}" for k, v in row.items())
                lines.append(f"  Row {i + 1}: {row_str}")

        if distribution_profiles:
            self._append_distribution_info(lines, distribution_profiles)

        lines.append("")
        lines.append(
            "Please analyze this table schema and recommend "
            "a complete sqlseed JSON configuration for generating test data."
        )

        return "\n".join(lines)

    def _append_columns_info(
        self,
        lines: list[str],
        columns: list[Any],
    ) -> None:
        """Append a '## Columns' section: one bullet per column with its flags."""
        lines.append("## Columns")
        for col in columns:
            # Assumes column records expose .name/.type/.is_primary_key/
            # .is_autoincrement/.nullable/.default (sqlseed adapter shape).
            parts = [f"- {col.name}: {col.type}"]
            if col.is_primary_key:
                parts.append("PRIMARY KEY")
            if col.is_autoincrement:
                parts.append("AUTOINCREMENT")
            if col.nullable:
                parts.append("NULLABLE")
            if col.default is not None:
                parts.append(f"DEFAULT={col.default}")
            # Non-nullable, no default, not a PK → spell out NOT NULL so the
            # model knows the column must be populated.
            if not col.nullable and col.default is None and not col.is_primary_key:
                parts.append("NOT NULL")
            lines.append(" ".join(parts))

    def _append_indexes_info(
        self,
        lines: list[str],
        indexes: list[dict[str, Any]],
    ) -> None:
        """Append an '## Indexes' section from index dicts ('unique', 'columns')."""
        lines.append("")
        lines.append("## Indexes")
        for idx in indexes:
            unique_str = "UNIQUE " if idx.get("unique") else ""
            cols_str = ", ".join(idx.get("columns", []))
            lines.append(f"- {unique_str}INDEX ({cols_str})")

    def _append_distribution_info(
        self,
        lines: list[str],
        distribution_profiles: list[dict[str, Any]],
    ) -> None:
        """Append per-column stats: distinct count, null ratio, top values, range.

        Assumes each profile dict carries 'column' plus optional
        'distinct_count', 'null_ratio' (float), 'top_values'
        ([{'value', 'frequency'}]), and 'value_range' ({'min', 'max'}) — confirm
        against the producer of these profiles.
        """
        lines.append("")
        lines.append("## Column Distribution (from existing data)")
        for profile in distribution_profiles:
            col = profile["column"]
            distinct = profile.get("distinct_count", "?")
            null_ratio = profile.get("null_ratio", 0)
            lines.append(
                f"- {col}: {distinct} distinct values, {null_ratio:.1%} null"
            )
            top_values = profile.get("top_values", [])
            if top_values:
                top_str = ", ".join(
                    f"{tv['value']}({tv['frequency']:.0%})"
                    for tv in top_values[:3]
                )
                lines.append(f"  Top values: {top_str}")
            vr = profile.get("value_range")
            if vr:
                lines.append(f"  Range: [{vr['min']}, {vr['max']}]")

    def _parse_json_response(self, content: str) -> dict[str, Any]:
        """Delegate to the shared fence-tolerant JSON parser."""
        from sqlseed_ai._json_utils import parse_json_response

        return parse_json_response(content)
|
|
304
|
+
|
sqlseed_ai/config.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class AIConfig(BaseModel):
    """Settings for the AI-backed analysis calls (pydantic v2 model)."""

    api_key: str | None = None   # no key → AI features are skipped
    model: str = "qwen3-coder-plus"
    base_url: str | None = None  # None → provider default endpoint
    temperature: float = Field(default=0.3, ge=0.0, le=2.0)
    max_tokens: int = Field(default=4096, gt=0)

    @classmethod
    def from_env(cls) -> AIConfig:
        """Build a config from environment variables.

        SQLSEED_AI_* variables win over their OPENAI_* counterparts;
        temperature and max_tokens keep their field defaults.
        """
        api_key = os.environ.get("SQLSEED_AI_API_KEY") or os.environ.get("OPENAI_API_KEY")
        base_url = os.environ.get("SQLSEED_AI_BASE_URL") or os.environ.get("OPENAI_BASE_URL")
        # Use `or`, not a .get() default, so SQLSEED_AI_MODEL set to an empty
        # string still falls back to the field default — consistent with the
        # api_key/base_url lines above (the old .get default produced model="").
        model = os.environ.get("SQLSEED_AI_MODEL") or cls.model_fields["model"].default
        return cls(api_key=api_key, base_url=base_url, model=model)
|
sqlseed_ai/errors.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
|
|
9
|
+
class ErrorSummary:
|
|
10
|
+
error_type: str
|
|
11
|
+
message: str
|
|
12
|
+
column: str | None
|
|
13
|
+
retryable: bool
|
|
14
|
+
|
|
15
|
+
def to_prompt_str(self) -> str:
|
|
16
|
+
parts = [f"Error Type: {self.error_type}", f"Message: {self.message}"]
|
|
17
|
+
if self.column:
|
|
18
|
+
parts.append(f"Affected Column: {self.column}")
|
|
19
|
+
return "\n".join(parts)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def summarize_error(exc: Exception) -> ErrorSummary:
|
|
23
|
+
import json as _json
|
|
24
|
+
|
|
25
|
+
try:
|
|
26
|
+
from pydantic import ValidationError
|
|
27
|
+
|
|
28
|
+
if isinstance(exc, ValidationError):
|
|
29
|
+
first = exc.errors()[0]
|
|
30
|
+
loc = " → ".join(str(part) for part in first["loc"])
|
|
31
|
+
col_name = _extract_column_from_pydantic_loc(first["loc"])
|
|
32
|
+
return ErrorSummary(
|
|
33
|
+
error_type="pydantic_validation",
|
|
34
|
+
message=f"Field '{loc}': {first['msg']} (type={first['type']})",
|
|
35
|
+
column=col_name,
|
|
36
|
+
retryable=True,
|
|
37
|
+
)
|
|
38
|
+
except ImportError:
|
|
39
|
+
pass
|
|
40
|
+
|
|
41
|
+
if isinstance(exc, _json.JSONDecodeError):
|
|
42
|
+
return ErrorSummary(
|
|
43
|
+
error_type="json_syntax",
|
|
44
|
+
message=f"JSON parsing failed at position {exc.pos}: {exc.msg}",
|
|
45
|
+
column=None,
|
|
46
|
+
retryable=True,
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
if isinstance(exc, AttributeError) and "generate_" in str(exc):
|
|
50
|
+
gen_name = _extract_generator_name(str(exc))
|
|
51
|
+
return ErrorSummary(
|
|
52
|
+
error_type="unknown_generator",
|
|
53
|
+
message=(
|
|
54
|
+
f"Generator '{gen_name}' does not exist. "
|
|
55
|
+
"Use one of the available generators listed in the system prompt."
|
|
56
|
+
),
|
|
57
|
+
column=None,
|
|
58
|
+
retryable=True,
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
try:
|
|
62
|
+
from sqlseed.generators.stream import UnknownGeneratorError
|
|
63
|
+
if isinstance(exc, UnknownGeneratorError):
|
|
64
|
+
return ErrorSummary(
|
|
65
|
+
error_type="unknown_generator",
|
|
66
|
+
message=(
|
|
67
|
+
f"Generator '{exc.generator_name}' does not exist. "
|
|
68
|
+
"Use one of the available generators listed in the system prompt."
|
|
69
|
+
),
|
|
70
|
+
column=exc.column_name,
|
|
71
|
+
retryable=True,
|
|
72
|
+
)
|
|
73
|
+
except ImportError:
|
|
74
|
+
pass
|
|
75
|
+
|
|
76
|
+
exc_type_name = type(exc).__name__
|
|
77
|
+
exc_module = str(getattr(type(exc), "__module__", ""))
|
|
78
|
+
if "ExpressionTimeout" in exc_type_name or "simpleeval" in exc_module:
|
|
79
|
+
return ErrorSummary(
|
|
80
|
+
error_type="expression_error",
|
|
81
|
+
message=f"Expression evaluation failed: {str(exc)[:150]}",
|
|
82
|
+
column=_extract_column_from_message(str(exc)),
|
|
83
|
+
retryable=True,
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
if isinstance(exc, (FileNotFoundError, PermissionError)):
|
|
87
|
+
return ErrorSummary(
|
|
88
|
+
error_type="fatal",
|
|
89
|
+
message=str(exc)[:200],
|
|
90
|
+
column=None,
|
|
91
|
+
retryable=False,
|
|
92
|
+
)
|
|
93
|
+
|
|
94
|
+
return ErrorSummary(
|
|
95
|
+
error_type="runtime_error",
|
|
96
|
+
message=f"{exc_type_name}: {str(exc)[:200]}",
|
|
97
|
+
column=_extract_column_from_message(str(exc)),
|
|
98
|
+
retryable=True,
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _extract_column_from_pydantic_loc(loc: tuple[Any, ...]) -> str | None:
|
|
103
|
+
if len(loc) >= 3 and loc[0] == "columns":
|
|
104
|
+
if hasattr(loc[2], "value"):
|
|
105
|
+
return loc[2].value.get("name") if isinstance(loc[2].value, dict) else str(loc[2].value)
|
|
106
|
+
return str(loc[2])
|
|
107
|
+
if len(loc) >= 2 and loc[0] == "columns":
|
|
108
|
+
return str(loc[1])
|
|
109
|
+
return None
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _extract_column_from_message(msg: str) -> str | None:
|
|
113
|
+
match = re.search(r"column[:\s]+'?(\w+)'?", msg, re.IGNORECASE)
|
|
114
|
+
return match.group(1) if match else None
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _extract_generator_name(msg: str) -> str:
|
|
118
|
+
match = re.search(r"generate_(\w+)", msg)
|
|
119
|
+
return match.group(1) if match else "unknown"
|
sqlseed_ai/examples.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
# Few-shot (input, output) message pairs injected before the real request by
# SchemaAnalyzer.build_initial_messages().  Each "input" mirrors the context
# format produced by SchemaAnalyzer._build_context(); each "output" is the
# exact JSON config the model is expected to emit (json.dumps renders Python
# True as JSON true).  Behavior-bearing data — keep the text verbatim.
FEW_SHOT_EXAMPLES: list[dict[str, str]] = [
    # Example 1: plain table — skip the AUTOINCREMENT PK and DEFAULT column,
    # honor the unique email index.
    {
        "input": """# Table: users
## Columns
- id: INTEGER PRIMARY KEY AUTOINCREMENT
- name: VARCHAR(50) NOT NULL
- email: VARCHAR(100) NOT NULL
- status: INTEGER DEFAULT 1
- created_at: DATETIME
## Indexes
- UNIQUE INDEX (email)
## All Tables in Database
users, orders""",

        "output": json.dumps({
            "name": "users",
            "count": 1000,
            "columns": [
                {"name": "name", "generator": "name"},
                {"name": "email", "generator": "email", "constraints": {"unique": True}},
                {"name": "created_at", "generator": "datetime", "params": {"start_year": 2020, "end_year": 2025}},
            ]
        }, indent=2),
    },
    # Example 2: pattern generators for formatted codes, plus a derived
    # column (CutCard4byte = last 8 chars of sCardNo → derive_from).
    {
        "input": """# Table: card_info
## Columns
- cardId: INTEGER PRIMARY KEY AUTOINCREMENT
- sCardNo: VARCHAR(20) NOT NULL
- sUserNo: VARCHAR(32) NOT NULL
- CutCard4byte: VARCHAR(8)
- nStatus: INTEGER DEFAULT 0
- dCreateTime: DATETIME
## Indexes
- UNIQUE INDEX (sCardNo)
- UNIQUE INDEX (sUserNo)
## All Tables in Database
card_info, user_info""",

        "output": json.dumps({
            "name": "card_info",
            "count": 1000,
            "columns": [
                {
                    "name": "sCardNo",
                    "generator": "pattern",
                    "params": {"regex": "[0-9]{16}"},
                    "constraints": {"unique": True},
                },
                {
                    "name": "sUserNo",
                    "generator": "pattern",
                    "params": {"regex": "U[0-9]{10}"},
                    "constraints": {"unique": True},
                },
                {"name": "CutCard4byte", "derive_from": "sCardNo", "expression": "value[-8:]"},
                {
                    "name": "dCreateTime",
                    "generator": "datetime",
                    "params": {"start_year": 2023, "end_year": 2025},
                },
            ]
        }, indent=2),
    },
    # Example 3: foreign key → foreign_key generator; enum-ish status column
    # → choice generator; nullable notes column is simply omitted.
    {
        "input": """# Table: orders
## Columns
- id: INTEGER PRIMARY KEY AUTOINCREMENT
- user_id: INTEGER NOT NULL
- product_name: VARCHAR(100) NOT NULL
- quantity: INTEGER NOT NULL
- unit_price: FLOAT NOT NULL
- order_status: VARCHAR(20) NOT NULL
- order_date: DATE
- notes: TEXT
## Foreign Keys
- user_id → users.id
## Indexes
- INDEX (user_id)
## All Tables in Database
users, orders""",

        "output": json.dumps({
            "name": "orders",
            "count": 5000,
            "columns": [
                {
                    "name": "user_id",
                    "generator": "foreign_key",
                    "params": {"ref_table": "users", "ref_column": "id"},
                },
                {
                    "name": "product_name",
                    "generator": "string",
                    "params": {"min_length": 5, "max_length": 50},
                },
                {
                    "name": "quantity",
                    "generator": "integer",
                    "params": {"min_value": 1, "max_value": 100},
                },
                {
                    "name": "unit_price",
                    "generator": "float",
                    "params": {"min_value": 0.99, "max_value": 999.99, "precision": 2},
                },
                {
                    "name": "order_status",
                    "generator": "choice",
                    "params": {
                        "choices": [
                            "pending", "confirmed",
                            "shipped", "delivered", "cancelled",
                        ],
                    },
                },
                {
                    "name": "order_date",
                    "generator": "date",
                    "params": {"start_year": 2023, "end_year": 2025},
                },
            ]
        }, indent=2),
    },
    # Example 4: people-style columns map to the name/date generators; the
    # nullable metadata column is omitted.
    {
        "input": """# Table: employees
## Columns
- emp_id: INTEGER PRIMARY KEY AUTOINCREMENT
- dept_id: INTEGER NOT NULL
- first_name: VARCHAR(50) NOT NULL
- last_name: VARCHAR(50) NOT NULL
- hire_date: DATE NOT NULL
- salary: INTEGER NOT NULL
- is_active: BOOLEAN
- metadata: TEXT
## Foreign Keys
- dept_id → departments.id
## Indexes
- UNIQUE INDEX (first_name, last_name)
## All Tables in Database
departments, employees""",

        "output": json.dumps({
            "name": "employees",
            "count": 2000,
            "columns": [
                {
                    "name": "dept_id",
                    "generator": "foreign_key",
                    "params": {"ref_table": "departments", "ref_column": "id"},
                },
                {"name": "first_name", "generator": "first_name"},
                {"name": "last_name", "generator": "last_name"},
                {
                    "name": "hire_date",
                    "generator": "date",
                    "params": {"start_year": 2015, "end_year": 2025},
                },
                {
                    "name": "salary",
                    "generator": "integer",
                    "params": {"min_value": 30000, "max_value": 200000},
                },
                {"name": "is_active", "generator": "boolean"},
            ]
        }, indent=2),
    },
]
|
sqlseed_ai/nl_config.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from sqlseed._utils.logger import get_logger
|
|
6
|
+
from sqlseed_ai._client import get_openai_client
|
|
7
|
+
from sqlseed_ai.config import AIConfig
|
|
8
|
+
|
|
9
|
+
logger = get_logger(__name__)


class NLConfigGenerator:
    """Generates a sqlseed configuration dict from a natural-language description."""

    def __init__(self, config: Any | None = None) -> None:
        # Optional AIConfig-like object; environment defaults are used when None.
        self._config = config

    def generate(self, description: str, db_path: str | None = None) -> dict[str, Any]:
        """Ask the LLM to turn *description* (optionally grounded on the schema
        of the SQLite database at *db_path*) into a sqlseed config dict.

        Best-effort: any failure is logged and {} is returned.
        """
        try:
            client = get_openai_client(self._config)
            model = AIConfig.model_fields["model"].default
            if self._config is not None and hasattr(self._config, "model"):
                model = self._config.model

            schema_info = ""
            if db_path:
                schema_info = self._read_schema(db_path)

            prompt = (
                f"Generate a sqlseed configuration based on this description:\n"
                f'"{description}"\n\n'
            )
            if schema_info:
                prompt += f"Database schema:\n{schema_info}\n\n"

            prompt += (
                "The JSON should follow this structure:\n"
                '{"db_path": "path", "provider": "mimesis", "locale": "en_US", '
                '"tables": [{"name": "table_name", "count": 1000, '
                '"columns": [{"name": "column_name", "generator": "generator_name", "params": {}}]}]}\n\n'
                "Respond with valid JSON only."
            )

            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=2048,
                temperature=0.5,
                response_format={"type": "json_object"},
            )

            content = response.choices[0].message.content if response.choices else None
            if content is None:
                return {}

            from sqlseed_ai._json_utils import parse_json_response

            return parse_json_response(content)

        except Exception as e:
            # str(e), not the exception object: keeps the log payload
            # serializable and matches analyzer.py's logging style.
            logger.warning("NL config generation failed", error=str(e))
            return {}

    def _read_schema(self, db_path: str) -> str:
        """Return a one-line-per-table schema summary of *db_path*, or "" on
        any failure (best-effort)."""
        try:
            from sqlseed.database.raw_sqlite_adapter import RawSQLiteAdapter

            adapter = RawSQLiteAdapter()
            adapter.connect(db_path)
            try:
                lines: list[str] = []
                for table in adapter.get_table_names():
                    columns = adapter.get_column_info(table)
                    col_desc = ", ".join(
                        f"{c.name}({c.type})" for c in columns
                    )
                    lines.append(f"  {table}: {col_desc}")
            finally:
                # Always release the connection — the original only closed on
                # the success path, leaking it when introspection raised.
                adapter.close()
            return "\n".join(lines)
        except Exception:
            return ""
|
sqlseed_ai/provider.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class AIProvider:
    """Placeholder data provider registered under the name ``"ai"``.

    Every ``generate_*`` method returns an inert empty/zero value; the
    real AI-driven generation is performed elsewhere in the plugin.  The
    class exists so the plugin satisfies the provider interface.
    """

    @property
    def name(self) -> str:
        """Identifier under which this provider is registered."""
        return "ai"

    def set_locale(self, locale: str) -> None:
        """Locale has no effect on this stub provider."""

    def set_seed(self, seed: int) -> None:
        """Seeding has no effect on this stub provider."""

    def generate_string(self, **kwargs: Any) -> str:
        """Return an empty string placeholder."""
        return ""

    def generate_integer(self, **kwargs: Any) -> int:
        """Return a zero placeholder."""
        return 0

    def generate_float(self, **kwargs: Any) -> float:
        """Return a zero placeholder."""
        return 0.0

    def generate_boolean(self) -> bool:
        """Return ``False`` as the placeholder boolean."""
        return False

    def generate_bytes(self, **kwargs: Any) -> bytes:
        """Return an empty bytes placeholder."""
        return b""

    def generate_name(self) -> str:
        """Return an empty string placeholder."""
        return ""

    def generate_first_name(self) -> str:
        """Return an empty string placeholder."""
        return ""

    def generate_last_name(self) -> str:
        """Return an empty string placeholder."""
        return ""

    def generate_email(self) -> str:
        """Return an empty string placeholder."""
        return ""

    def generate_phone(self) -> str:
        """Return an empty string placeholder."""
        return ""

    def generate_address(self) -> str:
        """Return an empty string placeholder."""
        return ""

    def generate_company(self) -> str:
        """Return an empty string placeholder."""
        return ""

    def generate_url(self) -> str:
        """Return an empty string placeholder."""
        return ""

    def generate_ipv4(self) -> str:
        """Return an empty string placeholder."""
        return ""

    def generate_uuid(self) -> str:
        """Return an empty string placeholder."""
        return ""

    def generate_date(self, **kwargs: Any) -> str:
        """Return an empty string placeholder."""
        return ""

    def generate_datetime(self, **kwargs: Any) -> str:
        """Return an empty string placeholder."""
        return ""

    def generate_timestamp(self) -> int:
        """Return a zero placeholder."""
        return 0

    def generate_text(self, **kwargs: Any) -> str:
        """Return an empty string placeholder."""
        return ""

    def generate_sentence(self) -> str:
        """Return an empty string placeholder."""
        return ""

    def generate_password(self, **kwargs: Any) -> str:
        """Return an empty string placeholder."""
        return ""

    def generate_choice(self, choices: list[Any]) -> Any:
        """Return the first option, or ``None`` when *choices* is empty."""
        if not choices:
            return None
        return choices[0]

    def generate_json(self, **kwargs: Any) -> str:
        """Return an empty JSON object literal."""
        return "{}"

    def generate_pattern(self, *, regex: str, **kwargs: Any) -> str:
        """Return an empty string placeholder; *regex* is ignored."""
        return ""
|
sqlseed_ai/refiner.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
5
|
+
import time
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import TYPE_CHECKING, Any
|
|
8
|
+
|
|
9
|
+
from sqlseed._utils.logger import get_logger
|
|
10
|
+
from sqlseed_ai.errors import ErrorSummary, summarize_error
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from sqlseed_ai.analyzer import SchemaAnalyzer
|
|
14
|
+
|
|
15
|
+
logger = get_logger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AISuggestionFailedError(Exception):
    """Raised when an AI config cannot be produced/validated.

    Emitted by :class:`AiConfigRefiner` when the LLM call hits a
    non-retryable error or when every refinement retry is exhausted.
    """
    pass
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class AiConfigRefiner:
    """Generates a table seeding config via an LLM and iteratively refines it.

    Workflow: ask the analyzer's LLM for a config, validate it against the
    live schema (plus a small preview run), and on validation failure feed a
    structured error description back to the LLM for another attempt.
    Successful configs are cached on disk, keyed by table name and a hash of
    the table's column names.
    """

    def __init__(
        self,
        analyzer: SchemaAnalyzer,
        db_path: str,
        *,
        cache_dir: str | None = None,
    ) -> None:
        # analyzer: builds LLM messages and performs the actual LLM call.
        self._analyzer = analyzer
        # Path to the SQLite database the config is generated for.
        self._db_path = db_path
        # Cache directory for validated configs (one JSON file per table).
        self._cache_dir = Path(cache_dir) if cache_dir else Path(".sqlseed_cache/ai_configs")

    def generate_and_refine(
        self,
        table_name: str,
        *,
        max_retries: int = 3,
        no_cache: bool = False,
    ) -> dict[str, Any]:
        """Return a validated config dict for *table_name*.

        Performs up to ``max_retries + 1`` LLM attempts.  Raises
        :class:`AISuggestionFailedError` on a non-retryable error or when
        all attempts are exhausted.  Set *no_cache* to skip the disk cache.
        """
        from sqlseed.core.orchestrator import DataOrchestrator

        with DataOrchestrator(self._db_path) as orch:
            schema_hash = self._compute_schema_hash(orch, table_name)

            # Serve from cache only when the schema hash still matches.
            if not no_cache:
                cached = self.get_cached_config(table_name, schema_hash)
                if cached is not None:
                    logger.info("Using cached AI config", table_name=table_name)
                    return cached

            schema_ctx = orch.get_schema_context(table_name)

            initial_messages = self._analyzer.build_initial_messages(
                table_name=schema_ctx["table_name"],
                columns=schema_ctx["columns"],
                indexes=schema_ctx["indexes"],
                sample_data=schema_ctx["sample_data"],
                foreign_keys=schema_ctx["foreign_keys"],
                all_table_names=schema_ctx["all_table_names"],
                # Optional key: absent distribution info becomes None.
                distribution_profiles=schema_ctx.get("distribution"),
            )

            # Grows with each failed attempt (assistant reply + refinement
            # prompt) so the LLM sees the full correction history.
            messages_history = list(initial_messages)

            for attempt in range(max_retries + 1):
                # Copy so the analyzer cannot mutate our history in place.
                messages = list(messages_history)
                try:
                    config_dict = self._analyzer.call_llm(messages)
                except Exception as e:
                    # Transport/API failure path: retry only if the
                    # summarized error is marked retryable.
                    error = summarize_error(e)
                    if not error.retryable:
                        raise AISuggestionFailedError(
                            f"Non-retryable error: {error.message}"
                        ) from e
                    if attempt == max_retries:
                        raise AISuggestionFailedError(
                            f"Failed after {max_retries} retries. Last error: {error.message}"
                        ) from e
                    logger.info(
                        "LLM API call failed, retrying",
                        attempt=attempt + 1,
                        max_retries=max_retries,
                        error=error.message,
                    )
                    continue

                # Validation path: None means the config is acceptable.
                error = self._validate_config(orch, table_name, config_dict)

                if error is None:
                    logger.info(
                        "AI config validated successfully",
                        table_name=table_name,
                        attempts=attempt + 1,
                    )
                    self._cache_successful_config(table_name, config_dict, schema_hash)
                    return config_dict

                if not error.retryable:
                    raise AISuggestionFailedError(
                        f"Non-retryable error: {error.message}"
                    )

                if attempt == max_retries:
                    logger.warning(
                        "AI config refinement exhausted all retries",
                        table_name=table_name,
                        last_error=error.error_type,
                    )
                    raise AISuggestionFailedError(
                        f"Failed after {max_retries} retries. Last error: {error.message}"
                    )

                logger.info(
                    "AI config refinement attempt",
                    attempt=attempt + 1,
                    max_retries=max_retries,
                    error_type=error.error_type,
                    column=error.column,
                )

                # Record the failed reply and the corrective instructions so
                # the next attempt can fix only the broken parts.
                messages_history.append({
                    "role": "assistant",
                    "content": json.dumps(config_dict, ensure_ascii=False),
                })
                messages_history.append({
                    "role": "user",
                    "content": self._build_refinement_prompt(error, attempt, max_retries),
                })

        # Defensive: the loop above always returns or raises.
        raise AISuggestionFailedError("Unexpected state")

    def _compute_schema_hash(self, orch: Any, table_name: str) -> str:
        """Return a short fingerprint of the table's column names.

        Used purely as a cache-invalidation key (not security-sensitive),
        hence truncated MD5 is sufficient.
        """
        column_names = orch.get_column_names(table_name)
        # Sorted so the hash is independent of column ordering.
        raw = "|".join(sorted(column_names))
        return hashlib.md5(raw.encode()).hexdigest()[:12]

    def _validate_config(
        self,
        orch: Any,
        table_name: str,
        config_dict: dict[str, Any],
    ) -> ErrorSummary | None:
        """Validate an LLM-produced config; return an ErrorSummary or None.

        Checks, in order: pydantic model validation, column-name existence,
        non-empty suggestions for suggestable columns, and a 5-row preview
        generation as a smoke test.
        """
        from sqlseed.config.models import TableConfig

        try:
            table_config = TableConfig(**config_dict)
        except Exception as e:
            return summarize_error(e)

        actual_columns = orch.get_column_names(table_name)
        skippable_cols = orch.get_skippable_columns(table_name)
        # NOTE(review): set difference — assumes both orchestrator calls
        # return sets; confirm against the orchestrator implementation.
        suggestable_cols = actual_columns - skippable_cols

        for col_cfg in table_config.columns:
            if col_cfg.name not in actual_columns:
                return ErrorSummary(
                    error_type="column_mismatch",
                    message=(
                        f"Column '{col_cfg.name}' does not exist in table "
                        f"'{table_name}'. Available columns: "
                        f"{sorted(actual_columns)}"
                    ),
                    column=col_cfg.name,
                    retryable=True,
                )

        if suggestable_cols and len(table_config.columns) == 0:
            return ErrorSummary(
                error_type="empty_config",
                message=(
                    f"No column suggestions provided for table '{table_name}'. "
                    f"There are {len(suggestable_cols)} suggestable columns: "
                    f"{sorted(suggestable_cols)}. "
                    "Please provide generator suggestions for at least the "
                    "non-default, non-autoincrement columns."
                ),
                column=None,
                retryable=True,
            )

        try:
            # Smoke test: actually generate a few rows with the config.
            orch.preview_table(
                table_name=table_name,
                count=5,
                column_configs=table_config.columns,
            )
        except Exception as e:
            return summarize_error(e)

        return None

    def _build_refinement_prompt(
        self,
        error: ErrorSummary,
        attempt: int,
        max_retries: int,
    ) -> str:
        """Build the corrective follow-up prompt for a failed attempt."""
        parts = [
            "Your previous configuration contained an error. Please fix it.",
            "",
            "## Error Details",
            error.to_prompt_str(),
            "",
            "## Instructions",
            "- Only fix the column(s) mentioned in the error.",
            "- Do NOT modify other column configurations that were working correctly.",
            "- Return the COMPLETE configuration JSON "
            "with only the problematic parts corrected.",
            "- If you are unsure how to fix the error, "
            "use 'string' generator as a safe fallback.",
            "",
            f"This is refinement attempt {attempt + 1} of {max_retries}.",
        ]

        if attempt >= max_retries - 1:
            parts.append(
                "WARNING: This is the LAST attempt. "
                "Use the simplest possible generators to ensure validity."
            )

        return "\n".join(parts)

    def _cache_successful_config(
        self,
        table_name: str,
        config_dict: dict[str, Any],
        schema_hash: str,
    ) -> None:
        """Persist a validated config to disk (best-effort, never raises)."""
        try:
            self._cache_dir.mkdir(parents=True, exist_ok=True)
            cache_file = self._cache_dir / f"{table_name}.json"
            # Envelope with metadata so stale entries can be invalidated.
            entry = {
                "_meta": {
                    "schema_hash": schema_hash,
                    "created_at": time.time(),
                },
                "config": config_dict,
            }
            cache_file.write_text(
                json.dumps(entry, ensure_ascii=False, indent=2),
                encoding="utf-8",
            )
            logger.debug(
                "Cached AI config",
                table_name=table_name,
                path=str(cache_file),
                schema_hash=schema_hash,
            )
        except Exception as e:
            # Caching is optional; failures are logged but never fatal.
            logger.debug("Failed to cache AI config", error=str(e))

    def get_cached_config(
        self,
        table_name: str,
        schema_hash: str | None = None,
    ) -> dict[str, Any] | None:
        """Return a cached config for *table_name*, or None.

        When *schema_hash* is given and the stored hash differs, the entry
        is treated as stale and ignored.  Entries without a ``_meta``
        envelope (legacy format) are returned as-is.
        """
        cache_file = self._cache_dir / f"{table_name}.json"
        if cache_file.exists():
            try:
                entry = json.loads(cache_file.read_text(encoding="utf-8"))
                if isinstance(entry, dict) and "_meta" in entry:
                    cached_hash = entry["_meta"].get("schema_hash", "")
                    if schema_hash and cached_hash != schema_hash:
                        logger.debug(
                            "Cache schema hash mismatch, invalidating",
                            table_name=table_name,
                            cached_hash=cached_hash,
                            current_hash=schema_hash,
                        )
                        return None
                    return entry.get("config")
                # Legacy entry with no metadata envelope.
                return entry
            except Exception:
                # Corrupt cache file: fall through to a cache miss.
                pass
        return None
|
sqlseed_ai/suggest.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from sqlseed._utils.logger import get_logger
|
|
6
|
+
from sqlseed_ai._client import get_openai_client
|
|
7
|
+
from sqlseed_ai.config import AIConfig
|
|
8
|
+
|
|
9
|
+
logger = get_logger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ColumnSuggester:
    """Asks the LLM for a single-column generator suggestion."""

    def __init__(self, config: Any | None = None) -> None:
        # Optional AI configuration object; may carry a ``model`` override.
        self._config = config

    def suggest(
        self,
        column_name: str,
        column_type: str,
        table_name: str,
        all_column_names: list[str],
    ) -> dict[str, Any] | None:
        """Return ``{"generator": ..., "params": ...}`` or ``None``.

        Best-effort: any failure (client setup, API call, bad response) is
        logged and results in ``None``.
        """
        try:
            client = get_openai_client(self._config)

            # Model resolution: prefer the configured override, otherwise
            # fall back to the AIConfig field default.
            if self._config is not None and hasattr(self._config, "model"):
                model = self._config.model
            else:
                model = AIConfig.model_fields["model"].default

            prompt = (
                f"Given a SQLite table '{table_name}' with columns {all_column_names}, "
                f"the column '{column_name}' has type '{column_type}'. "
                f"Suggest the best data generator name and params for this column. "
                f"Available generators: string, integer, float, boolean, bytes, "
                f"name, first_name, last_name, email, phone, address, company, "
                f"url, ipv4, uuid, date, datetime, timestamp, text, sentence, "
                f"password, choice, json, foreign_key, pattern. "
                f'Respond in JSON format: {{"generator": "...", "params": {{}}}}'
            )

            completion = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=256,
                temperature=0.3,
                response_format={"type": "json_object"},
            )

            reply_choices = completion.choices
            content = reply_choices[0].message.content if reply_choices else None
            if content is None:
                return None

            from sqlseed_ai._json_utils import parse_json_response

            parsed = parse_json_response(content)
            # Only accept a reply that actually names a generator.
            if "generator" in parsed:
                return parsed
        except Exception as e:
            logger.warning("AI suggestion failed", column_name=column_name, error=e)

        return None
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: sqlseed-ai
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: AI-powered data generation plugin for sqlseed
|
|
5
|
+
Project-URL: Homepage, https://github.com/sunbos/sqlseed
|
|
6
|
+
Project-URL: Repository, https://github.com/sunbos/sqlseed/tree/main/plugins/sqlseed-ai
|
|
7
|
+
Author-email: SunBo <1443584939@qq.com>
|
|
8
|
+
License-Expression: AGPL-3.0-or-later
|
|
9
|
+
Classifier: Development Status :: 3 - Alpha
|
|
10
|
+
Classifier: License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)
|
|
11
|
+
Classifier: Programming Language :: Python :: 3
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
16
|
+
Requires-Python: >=3.10
|
|
17
|
+
Requires-Dist: openai>=1.0
|
|
18
|
+
Requires-Dist: pydantic>=2.0
|
|
19
|
+
Requires-Dist: sqlseed
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
sqlseed_ai/__init__.py,sha256=KD8UutkBS4td2IaeEcG1YlDPafiMEFYnX1A-2_X0qaw,2450
|
|
2
|
+
sqlseed_ai/_client.py,sha256=0lLJD03sCvJPeCZXYMUYrdi_RVGq9Xhufa2oZt_TvL8,891
|
|
3
|
+
sqlseed_ai/_json_utils.py,sha256=TByhP4wvxj17aiQe9-Lgx88k4tqIa2RoOBWR7i-xgB4,556
|
|
4
|
+
sqlseed_ai/analyzer.py,sha256=aetDgwymMBDop14Qhu2kLIlTP6AWJ5m1xKPQNVX-qgU,10554
|
|
5
|
+
sqlseed_ai/config.py,sha256=Zc3N1EOtbuAcfxo48nftKDJK1paR4Av8x0lYcYojTmg,713
|
|
6
|
+
sqlseed_ai/errors.py,sha256=W43rjOWaP1zFUSltXRtpB9CXalEP_qgHzMfzdmD418c,3846
|
|
7
|
+
sqlseed_ai/examples.py,sha256=Sfl9JVjGiMjwXm9NuuHZPmnKKGTvg_yl599w5z8O9Yk,5322
|
|
8
|
+
sqlseed_ai/nl_config.py,sha256=Wcf8ujcoA98XAKjUA6HkReuQplodCC83afDL7UIFDP4,2781
|
|
9
|
+
sqlseed_ai/provider.py,sha256=JNr4Gsr0-cXH3hRxVBogrXehIM4YHudWJhVJgwzEdiM,1853
|
|
10
|
+
sqlseed_ai/refiner.py,sha256=br0wVQYGm-HuN3FW5x8wAyOgFqDGonFGsEx3INCQVnc,9896
|
|
11
|
+
sqlseed_ai/suggest.py,sha256=EckLx5q5CgwK9RJRsslOJTjJG6DQNufdRqCOxmhr8VU,2196
|
|
12
|
+
sqlseed_ai-0.1.0.dist-info/METADATA,sha256=RcpIcb4GFHQfGemI1UdxjgRcuSS7Ds5uDfNpnHs3U2c,826
|
|
13
|
+
sqlseed_ai-0.1.0.dist-info/WHEEL,sha256=QccIxa26bgl1E6uMy58deGWi-0aeIkkangHcxk2kWfw,87
|
|
14
|
+
sqlseed_ai-0.1.0.dist-info/entry_points.txt,sha256=zEHu6heoDb7qXrooVmJN10y-4LSWIhitMNdmcovm1UE,33
|
|
15
|
+
sqlseed_ai-0.1.0.dist-info/RECORD,,
|