sqlseed-ai 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlseed_ai/__init__.py ADDED
@@ -0,0 +1,88 @@
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from typing import Any
5
+
6
+ from sqlseed.plugins.hookspecs import hookimpl
7
+ from sqlseed_ai.analyzer import SchemaAnalyzer
8
+
9
+ _SIMPLE_COL_RE = re.compile(
10
+ r'(^|[_\s])('
11
+ r'name|email|phone|address|url|uuid|'
12
+ r'date|time|datetime|timestamp|boolean|'
13
+ r'int|float|double|real|text|string|'
14
+ r'char|varchar|blob|byte|id|code|title|'
15
+ r'description|status|type|category|count|'
16
+ r'amount|price|value|number|index|order|level'
17
+ r')($|[_\s])',
18
+ re.IGNORECASE,
19
+ )
20
+
21
+
22
class AISqlseedPlugin:
    """sqlseed plugin that delegates schema analysis and template values to an LLM."""

    def __init__(self) -> None:
        # Built lazily so merely importing the plugin never touches the AI stack.
        self._analyzer: SchemaAnalyzer | None = None

    def _get_analyzer(self) -> SchemaAnalyzer:
        """Return the shared SchemaAnalyzer, creating it on first use."""
        if self._analyzer is None:
            self._analyzer = SchemaAnalyzer()
        return self._analyzer

    def _is_simple_column(self, column_name: str, column_type: str) -> bool:
        """True when the column name or type matches a built-in keyword."""
        return any(
            _SIMPLE_COL_RE.search(text) for text in (column_name, column_type)
        )

    @hookimpl
    def sqlseed_ai_analyze_table(
        self,
        table_name: str,
        columns: list[Any],
        indexes: list[dict[str, Any]],
        sample_data: list[dict[str, Any]],
        foreign_keys: list[Any],
        all_table_names: list[str],
    ) -> dict[str, Any] | None:
        """Hook: ask the LLM for a full table configuration suggestion."""
        return self._get_analyzer().analyze_table(
            table_name=table_name,
            columns=columns,
            indexes=indexes,
            sample_data=sample_data,
            foreign_keys=foreign_keys,
            all_table_names=all_table_names,
        )

    @hookimpl
    def sqlseed_pre_generate_templates(
        self,
        table_name: str,
        column_name: str,
        column_type: str,
        count: int,
        sample_data: list[Any],
    ) -> list[Any] | None:
        """Hook: produce AI template values for columns built-ins cannot infer.

        Returns None (defer to defaults) for simple columns or on any AI failure.
        """
        if self._is_simple_column(column_name, column_type):
            return None

        try:
            # Cap the request size; 50 values is plenty for a template pool.
            return self._get_analyzer().generate_template_values(
                column_name=column_name,
                column_type=column_type,
                count=min(count, 50),
                sample_data=sample_data,
            )
        except Exception:
            # Best-effort: silently fall back to the default generators.
            return None

    @hookimpl
    def sqlseed_register_providers(self, registry: Any) -> None:
        """Hook: this plugin registers no extra providers."""

    @hookimpl
    def sqlseed_register_column_mappers(self, mapper: Any) -> None:
        """Hook: this plugin registers no extra column mappers."""


# Module-level singleton discovered by sqlseed's plugin loader.
plugin = AISqlseedPlugin()
sqlseed_ai/_client.py ADDED
@@ -0,0 +1,31 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from sqlseed._utils.logger import get_logger
6
+
7
+ logger = get_logger(__name__)
8
+
9
+
10
+ def get_openai_client(config: Any | None = None) -> Any:
11
+ try:
12
+ from openai import OpenAI
13
+
14
+ from sqlseed_ai.config import AIConfig
15
+
16
+ if config is None:
17
+ config = AIConfig.from_env()
18
+
19
+ api_key = config.api_key if hasattr(config, "api_key") else None
20
+ base_url = config.base_url if hasattr(config, "base_url") else None
21
+
22
+ if not api_key:
23
+ raise ValueError(
24
+ "AI API key not configured. Set SQLSEED_AI_API_KEY or OPENAI_API_KEY environment variable."
25
+ )
26
+
27
+ return OpenAI(api_key=api_key, base_url=base_url)
28
+ except ImportError:
29
+ raise ImportError(
30
+ "openai is not installed. Install it with: pip install sqlseed-ai"
31
+ ) from None
@@ -0,0 +1,24 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from typing import Any
5
+
6
+
7
def parse_json_response(content: str) -> dict[str, Any]:
    """Extract a JSON object from an LLM reply, tolerating markdown fences.

    Returns an empty dict when the payload is not valid JSON, or parses to
    something other than a JSON object (e.g. a bare list).
    """
    text = content.strip()
    # Peel the optional ```json / ``` fences models sometimes emit despite
    # being told not to.
    if text.startswith("```json"):
        text = text[7:]
    if text.startswith("```"):
        text = text[3:]
    if text.endswith("```"):
        text = text[:-3]
    text = text.strip()

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return {}
    return parsed if isinstance(parsed, dict) else {}
sqlseed_ai/analyzer.py ADDED
@@ -0,0 +1,304 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from sqlseed._utils.logger import get_logger
6
+ from sqlseed_ai._client import get_openai_client
7
+ from sqlseed_ai.config import AIConfig
8
+
9
+ logger = get_logger(__name__)
10
+
11
# System message for whole-table schema analysis.  It pins the generator
# vocabulary and the exact JSON response contract so replies can be parsed
# mechanically (see SchemaAnalyzer._parse_json_response).
SYSTEM_PROMPT = """You are an expert database test data engineer.
You analyze SQLite table schemas and recommend data generation configurations for the sqlseed toolkit.

## Available Generators
- string (params: min_length, max_length, charset)
- integer (params: min_value, max_value)
- float (params: min_value, max_value, precision)
- boolean
- bytes (params: length)
- name, first_name, last_name
- email, phone, address, company
- url, ipv4, uuid
- date (params: start_year, end_year)
- datetime (params: start_year, end_year)
- timestamp
- text (params: min_length, max_length)
- sentence, password
- choice (params: choices)
- json (params: schema)
- pattern (params: regex) — generates strings matching a regex pattern

## Key Rules
1. INTEGER PRIMARY KEY AUTOINCREMENT columns → do NOT include (auto-skip)
2. Columns with DEFAULT values → do NOT include (auto-skip)
3. Nullable columns → do NOT include unless they have semantic meaning
4. Use `pattern` generator with regex for card numbers, codes, IDs with specific formats
5. Use `derive_from` + `expression` when one column is computed from another
6. Use `constraints.unique: true` for columns that must be unique
7. Detect cross-column dependencies: if CutCard4byte = last 8 chars of sCardNo, use derive_from
8. Detect implicit business associations: if sUserNo appears in multiple tables, note it

## Output Format
You MUST respond with a valid JSON object (NOT YAML, NOT markdown fences).
The JSON object must have this exact structure:
{
  "name": "table_name",
  "count": 1000,
  "columns": [
    {
      "name": "column_name",
      "generator": "generator_name",
      "params": {"key": "value"}
    },
    {
      "name": "derived_column",
      "derive_from": "source_column",
      "expression": "value[-8:]",
      "constraints": {"unique": true}
    }
  ]
}

IMPORTANT: Do NOT include columns that are PRIMARY KEY AUTOINCREMENT or have DEFAULT values."""
64
+
65
+
66
class SchemaAnalyzer:
    """LLM-backed schema analyst.

    Builds a few-shot chat transcript describing a single table, sends it to
    an OpenAI-compatible chat-completions endpoint, and parses the JSON reply
    into a plain config dict for sqlseed's generator pipeline.
    """

    def __init__(self, config: AIConfig | None = None) -> None:
        # May be None; resolved lazily from the environment on first use.
        self._config = config

    def analyze_table(
        self,
        table_name: str,
        columns: list[Any],
        indexes: list[dict[str, Any]],
        sample_data: list[dict[str, Any]],
        foreign_keys: list[Any],
        all_table_names: list[str],
    ) -> dict[str, Any] | None:
        """Return a suggested table config dict, or ``None``.

        ``None`` is returned (never raised) when no API key is configured or
        the LLM call fails, so callers can fall back to non-AI behaviour.
        """
        if self._config is None:
            self._config = AIConfig.from_env()

        if not self._config.api_key:
            logger.warning("AI API key not configured, skipping analysis")
            return None

        messages = self.build_initial_messages(
            table_name=table_name,
            columns=columns,
            indexes=indexes,
            sample_data=sample_data,
            foreign_keys=foreign_keys,
            all_table_names=all_table_names,
        )

        try:
            return self.call_llm(messages)
        except Exception as e:
            # Best-effort: an analysis failure must not break generation.
            logger.warning("AI analysis failed", table_name=table_name, error=str(e))
            return None

    def build_initial_messages(
        self,
        table_name: str,
        columns: list[Any],
        indexes: list[dict[str, Any]],
        sample_data: list[dict[str, Any]],
        foreign_keys: list[Any],
        all_table_names: list[str],
        distribution_profiles: list[dict[str, Any]] | None = None,
    ) -> list[dict[str, str]]:
        """Assemble the chat transcript: system prompt, few-shot pairs, then
        the real request rendered by :meth:`_build_context`."""
        context = self._build_context(
            table_name=table_name,
            columns=columns,
            indexes=indexes,
            sample_data=sample_data,
            foreign_keys=foreign_keys,
            all_table_names=all_table_names,
            distribution_profiles=distribution_profiles,
        )
        messages: list[dict[str, str]] = [
            {"role": "system", "content": SYSTEM_PROMPT},
        ]

        # Imported lazily to keep the examples module off the import path
        # until a conversation is actually built.
        from sqlseed_ai.examples import FEW_SHOT_EXAMPLES
        for example in FEW_SHOT_EXAMPLES:
            messages.append({"role": "user", "content": example["input"]})
            messages.append({"role": "assistant", "content": example["output"]})

        messages.append({"role": "user", "content": context})

        return messages

    def call_llm(self, messages: list[dict[str, str]]) -> dict[str, Any]:
        """Send *messages* to the configured model and parse the JSON reply.

        Returns:
            The parsed JSON object, or ``{}`` when the reply has no content.

        Raises:
            ValueError: If no API key is configured.
            RuntimeError: If the API call fails or returns no choices.
        """
        if self._config is None:
            self._config = AIConfig.from_env()
        if not self._config.api_key:
            raise ValueError("AI API key not configured")

        client = get_openai_client(self._config)
        try:
            # response_format forces a JSON-object reply on supporting models.
            response = client.chat.completions.create(
                model=self._config.model,
                messages=messages,
                max_tokens=self._config.max_tokens,
                temperature=self._config.temperature,
                response_format={"type": "json_object"},
            )
        except Exception as e:
            raise RuntimeError(
                f"LLM API call failed (model={self._config.model}): {e}"
            ) from e

        if not response.choices:
            raise RuntimeError(
                f"LLM returned no choices (model={self._config.model}). "
                "The API key or model may be invalid."
            )
        content = response.choices[0].message.content
        if content is None:
            return {}
        return self._parse_json_response(content)

    # System message used for single-column template-value generation.
    TEMPLATE_SYSTEM_PROMPT = (
        "You are a data generation assistant. Generate realistic sample values "
        "for the given database column. Return a JSON object with a 'values' "
        "array containing the requested number of unique, realistic values. "
        "Each value must be valid for the column type. Do NOT include explanations."
    )

    def generate_template_values(
        self,
        column_name: str,
        column_type: str,
        count: int,
        sample_data: list[Any],
    ) -> list[Any]:
        """Ask the LLM for *count* realistic sample values for one column.

        Returns an empty list when the reply lacks a ``values`` array.
        API/LLM errors propagate from :meth:`call_llm`.
        """
        prompt = (
            f"Generate {count} realistic sample values for a database column "
            f"named '{column_name}' with type '{column_type}'."
        )
        if sample_data:
            # A handful of real values anchors the style of generated ones.
            prompt += f"\nExisting sample values: {sample_data[:5]}"
        prompt += (
            f"\nRespond with a JSON object: {{\"values\": [...]}}."
            f"\nEach value should be a valid {column_type} value."
        )

        messages = [
            {"role": "system", "content": self.TEMPLATE_SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]
        result = self.call_llm(messages)
        return result.get("values", [])

    def _build_context(
        self,
        table_name: str,
        columns: list[Any],
        indexes: list[dict[str, Any]],
        sample_data: list[dict[str, Any]],
        foreign_keys: list[Any],
        all_table_names: list[str],
        distribution_profiles: list[dict[str, Any]] | None = None,
    ) -> str:
        """Render the table's schema and data context as a markdown request."""
        lines: list[str] = []
        lines.append(f"# Table: {table_name}")
        lines.append("")

        self._append_columns_info(lines, columns)

        if indexes:
            self._append_indexes_info(lines, indexes)

        if foreign_keys:
            lines.append("")
            lines.append("## Foreign Keys")
            for fk in foreign_keys:
                lines.append(f"- {fk.column} → {fk.ref_table}.{fk.ref_column}")

        if all_table_names:
            lines.append("")
            lines.append("## All Tables in Database")
            lines.append(", ".join(all_table_names))

        if sample_data:
            lines.append("")
            lines.append("## Sample Data (existing rows)")
            # Cap at five rows to keep the prompt small.
            for i, row in enumerate(sample_data[:5]):
                row_str = ", ".join(f"{k}={v}" for k, v in row.items())
                lines.append(f" Row {i + 1}: {row_str}")

        if distribution_profiles:
            self._append_distribution_info(lines, distribution_profiles)

        lines.append("")
        lines.append(
            "Please analyze this table schema and recommend "
            "a complete sqlseed JSON configuration for generating test data."
        )

        return "\n".join(lines)

    def _append_columns_info(
        self,
        lines: list[str],
        columns: list[Any],
    ) -> None:
        """Append one markdown bullet per column with its constraint flags."""
        lines.append("## Columns")
        for col in columns:
            parts = [f"- {col.name}: {col.type}"]
            if col.is_primary_key:
                parts.append("PRIMARY KEY")
            if col.is_autoincrement:
                parts.append("AUTOINCREMENT")
            if col.nullable:
                parts.append("NULLABLE")
            if col.default is not None:
                parts.append(f"DEFAULT={col.default}")
            # Only spell out NOT NULL when no other flag already implies the
            # column must be supplied.
            if not col.nullable and col.default is None and not col.is_primary_key:
                parts.append("NOT NULL")
            lines.append(" ".join(parts))

    def _append_indexes_info(
        self,
        lines: list[str],
        indexes: list[dict[str, Any]],
    ) -> None:
        """Append an Indexes section listing each (possibly UNIQUE) index."""
        lines.append("")
        lines.append("## Indexes")
        for idx in indexes:
            unique_str = "UNIQUE " if idx.get("unique") else ""
            cols_str = ", ".join(idx.get("columns", []))
            lines.append(f"- {unique_str}INDEX ({cols_str})")

    def _append_distribution_info(
        self,
        lines: list[str],
        distribution_profiles: list[dict[str, Any]],
    ) -> None:
        """Append per-column distribution stats mined from existing data."""
        lines.append("")
        lines.append("## Column Distribution (from existing data)")
        for profile in distribution_profiles:
            col = profile["column"]
            distinct = profile.get("distinct_count", "?")
            null_ratio = profile.get("null_ratio", 0)
            lines.append(
                f"- {col}: {distinct} distinct values, {null_ratio:.1%} null"
            )
            # Show at most the three most frequent values.
            top_values = profile.get("top_values", [])
            if top_values:
                top_str = ", ".join(
                    f"{tv['value']}({tv['frequency']:.0%})"
                    for tv in top_values[:3]
                )
                lines.append(f" Top values: {top_str}")
            vr = profile.get("value_range")
            if vr:
                lines.append(f" Range: [{vr['min']}, {vr['max']}]")

    def _parse_json_response(self, content: str) -> dict[str, Any]:
        """Delegate to the shared fence-tolerant JSON parser."""
        from sqlseed_ai._json_utils import parse_json_response

        return parse_json_response(content)
sqlseed_ai/config.py ADDED
@@ -0,0 +1,20 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
class AIConfig(BaseModel):
    """Runtime settings for the LLM backend."""

    # None means "not configured"; callers must check before making API calls.
    api_key: str | None = None
    model: str = "qwen3-coder-plus"
    base_url: str | None = None
    temperature: float = Field(default=0.3, ge=0.0, le=2.0)
    max_tokens: int = Field(default=4096, gt=0)

    @classmethod
    def from_env(cls) -> AIConfig:
        """Build a config from environment variables.

        ``SQLSEED_AI_*`` variables win over their ``OPENAI_*`` counterparts;
        the model falls back to the field's declared default.
        """
        env = os.environ
        return cls(
            api_key=env.get("SQLSEED_AI_API_KEY") or env.get("OPENAI_API_KEY"),
            base_url=env.get("SQLSEED_AI_BASE_URL") or env.get("OPENAI_BASE_URL"),
            model=env.get("SQLSEED_AI_MODEL", cls.model_fields["model"].default),
        )
sqlseed_ai/errors.py ADDED
@@ -0,0 +1,119 @@
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from dataclasses import dataclass
5
+ from typing import Any
6
+
7
+
8
@dataclass
class ErrorSummary:
    """Normalized description of a failure, suitable for an LLM feedback prompt."""

    error_type: str      # machine-readable category, e.g. "json_syntax"
    message: str         # human/LLM-readable explanation
    column: str | None   # affected column name, when identifiable
    retryable: bool      # whether another LLM attempt could plausibly fix it

    def to_prompt_str(self) -> str:
        """Render the summary as plain text for a refinement prompt."""
        rendered = [f"Error Type: {self.error_type}", f"Message: {self.message}"]
        if self.column:
            rendered.append(f"Affected Column: {self.column}")
        return "\n".join(rendered)
20
+
21
+
22
def summarize_error(exc: Exception) -> ErrorSummary:
    """Classify *exc* into an :class:`ErrorSummary` for the refinement loop.

    Checks run in priority order: pydantic validation errors, JSON syntax
    errors, unknown-generator errors (two detection paths), expression
    evaluation failures, fatal filesystem errors, then a generic retryable
    fallback.  Only ``FileNotFoundError``/``PermissionError`` are treated as
    non-retryable.
    """
    import json as _json

    # pydantic may be absent at runtime; validation detection is best-effort.
    try:
        from pydantic import ValidationError

        if isinstance(exc, ValidationError):
            # Report only the first error: one fix per refinement round keeps
            # the follow-up prompt focused.
            first = exc.errors()[0]
            loc = " → ".join(str(part) for part in first["loc"])
            col_name = _extract_column_from_pydantic_loc(first["loc"])
            return ErrorSummary(
                error_type="pydantic_validation",
                message=f"Field '{loc}': {first['msg']} (type={first['type']})",
                column=col_name,
                retryable=True,
            )
    except ImportError:
        pass

    if isinstance(exc, _json.JSONDecodeError):
        return ErrorSummary(
            error_type="json_syntax",
            message=f"JSON parsing failed at position {exc.pos}: {exc.msg}",
            column=None,
            retryable=True,
        )

    # Heuristic: an AttributeError mentioning "generate_" means the LLM chose
    # a generator method the provider does not implement.
    if isinstance(exc, AttributeError) and "generate_" in str(exc):
        gen_name = _extract_generator_name(str(exc))
        return ErrorSummary(
            error_type="unknown_generator",
            message=(
                f"Generator '{gen_name}' does not exist. "
                "Use one of the available generators listed in the system prompt."
            ),
            column=None,
            retryable=True,
        )

    # Preferred, structured path for the same condition when sqlseed's own
    # exception type is importable.
    try:
        from sqlseed.generators.stream import UnknownGeneratorError
        if isinstance(exc, UnknownGeneratorError):
            return ErrorSummary(
                error_type="unknown_generator",
                message=(
                    f"Generator '{exc.generator_name}' does not exist. "
                    "Use one of the available generators listed in the system prompt."
                ),
                column=exc.column_name,
                retryable=True,
            )
    except ImportError:
        pass

    exc_type_name = type(exc).__name__
    exc_module = str(getattr(type(exc), "__module__", ""))
    # simpleeval-raised errors (or timeouts) indicate a bad derive expression.
    if "ExpressionTimeout" in exc_type_name or "simpleeval" in exc_module:
        return ErrorSummary(
            error_type="expression_error",
            message=f"Expression evaluation failed: {str(exc)[:150]}",
            column=_extract_column_from_message(str(exc)),
            retryable=True,
        )

    # Environment problems cannot be fixed by re-prompting the model.
    if isinstance(exc, (FileNotFoundError, PermissionError)):
        return ErrorSummary(
            error_type="fatal",
            message=str(exc)[:200],
            column=None,
            retryable=False,
        )

    # Generic fallback: assume retryable and truncate the message for prompts.
    return ErrorSummary(
        error_type="runtime_error",
        message=f"{exc_type_name}: {str(exc)[:200]}",
        column=_extract_column_from_message(str(exc)),
        retryable=True,
    )
+ )
100
+
101
+
102
+ def _extract_column_from_pydantic_loc(loc: tuple[Any, ...]) -> str | None:
103
+ if len(loc) >= 3 and loc[0] == "columns":
104
+ if hasattr(loc[2], "value"):
105
+ return loc[2].value.get("name") if isinstance(loc[2].value, dict) else str(loc[2].value)
106
+ return str(loc[2])
107
+ if len(loc) >= 2 and loc[0] == "columns":
108
+ return str(loc[1])
109
+ return None
110
+
111
+
112
+ def _extract_column_from_message(msg: str) -> str | None:
113
+ match = re.search(r"column[:\s]+'?(\w+)'?", msg, re.IGNORECASE)
114
+ return match.group(1) if match else None
115
+
116
+
117
+ def _extract_generator_name(msg: str) -> str:
118
+ match = re.search(r"generate_(\w+)", msg)
119
+ return match.group(1) if match else "unknown"
sqlseed_ai/examples.py ADDED
@@ -0,0 +1,172 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+
5
# Few-shot (input → output) pairs injected as alternating user/assistant
# turns before the real request.  Each "input" mirrors the context format
# produced by SchemaAnalyzer._build_context and each "output" obeys the JSON
# contract from SYSTEM_PROMPT, steering the model toward parseable replies.
FEW_SHOT_EXAMPLES: list[dict[str, str]] = [
    # Example 1: basic table — AUTOINCREMENT/DEFAULT columns omitted and a
    # unique index mapped to constraints.unique.
    {
        "input": """# Table: users
## Columns
- id: INTEGER PRIMARY KEY AUTOINCREMENT
- name: VARCHAR(50) NOT NULL
- email: VARCHAR(100) NOT NULL
- status: INTEGER DEFAULT 1
- created_at: DATETIME
## Indexes
- UNIQUE INDEX (email)
## All Tables in Database
users, orders""",

        "output": json.dumps({
            "name": "users",
            "count": 1000,
            "columns": [
                {"name": "name", "generator": "name"},
                {"name": "email", "generator": "email", "constraints": {"unique": True}},
                {"name": "created_at", "generator": "datetime", "params": {"start_year": 2020, "end_year": 2025}},
            ]
        }, indent=2),
    },
    # Example 2: formatted identifiers via `pattern`, plus a cross-column
    # derive_from dependency (CutCard4byte is the tail of sCardNo).
    {
        "input": """# Table: card_info
## Columns
- cardId: INTEGER PRIMARY KEY AUTOINCREMENT
- sCardNo: VARCHAR(20) NOT NULL
- sUserNo: VARCHAR(32) NOT NULL
- CutCard4byte: VARCHAR(8)
- nStatus: INTEGER DEFAULT 0
- dCreateTime: DATETIME
## Indexes
- UNIQUE INDEX (sCardNo)
- UNIQUE INDEX (sUserNo)
## All Tables in Database
card_info, user_info""",

        "output": json.dumps({
            "name": "card_info",
            "count": 1000,
            "columns": [
                {
                    "name": "sCardNo",
                    "generator": "pattern",
                    "params": {"regex": "[0-9]{16}"},
                    "constraints": {"unique": True},
                },
                {
                    "name": "sUserNo",
                    "generator": "pattern",
                    "params": {"regex": "U[0-9]{10}"},
                    "constraints": {"unique": True},
                },
                {"name": "CutCard4byte", "derive_from": "sCardNo", "expression": "value[-8:]"},
                {
                    "name": "dCreateTime",
                    "generator": "datetime",
                    "params": {"start_year": 2023, "end_year": 2025},
                },
            ]
        }, indent=2),
    },
    # Example 3: foreign keys, numeric ranges and a closed `choice` set.
    {
        "input": """# Table: orders
## Columns
- id: INTEGER PRIMARY KEY AUTOINCREMENT
- user_id: INTEGER NOT NULL
- product_name: VARCHAR(100) NOT NULL
- quantity: INTEGER NOT NULL
- unit_price: FLOAT NOT NULL
- order_status: VARCHAR(20) NOT NULL
- order_date: DATE
- notes: TEXT
## Foreign Keys
- user_id → users.id
## Indexes
- INDEX (user_id)
## All Tables in Database
users, orders""",

        "output": json.dumps({
            "name": "orders",
            "count": 5000,
            "columns": [
                {
                    "name": "user_id",
                    "generator": "foreign_key",
                    "params": {"ref_table": "users", "ref_column": "id"},
                },
                {
                    "name": "product_name",
                    "generator": "string",
                    "params": {"min_length": 5, "max_length": 50},
                },
                {
                    "name": "quantity",
                    "generator": "integer",
                    "params": {"min_value": 1, "max_value": 100},
                },
                {
                    "name": "unit_price",
                    "generator": "float",
                    "params": {"min_value": 0.99, "max_value": 999.99, "precision": 2},
                },
                {
                    "name": "order_status",
                    "generator": "choice",
                    "params": {
                        "choices": [
                            "pending", "confirmed",
                            "shipped", "delivered", "cancelled",
                        ],
                    },
                },
                {
                    "name": "order_date",
                    "generator": "date",
                    "params": {"start_year": 2023, "end_year": 2025},
                },
            ]
        }, indent=2),
    },
    # Example 4: person-style columns; the nullable metadata/notes columns
    # are left out of the suggestions.
    {
        "input": """# Table: employees
## Columns
- emp_id: INTEGER PRIMARY KEY AUTOINCREMENT
- dept_id: INTEGER NOT NULL
- first_name: VARCHAR(50) NOT NULL
- last_name: VARCHAR(50) NOT NULL
- hire_date: DATE NOT NULL
- salary: INTEGER NOT NULL
- is_active: BOOLEAN
- metadata: TEXT
## Foreign Keys
- dept_id → departments.id
## Indexes
- UNIQUE INDEX (first_name, last_name)
## All Tables in Database
departments, employees""",

        "output": json.dumps({
            "name": "employees",
            "count": 2000,
            "columns": [
                {
                    "name": "dept_id",
                    "generator": "foreign_key",
                    "params": {"ref_table": "departments", "ref_column": "id"},
                },
                {"name": "first_name", "generator": "first_name"},
                {"name": "last_name", "generator": "last_name"},
                {
                    "name": "hire_date",
                    "generator": "date",
                    "params": {"start_year": 2015, "end_year": 2025},
                },
                {
                    "name": "salary",
                    "generator": "integer",
                    "params": {"min_value": 30000, "max_value": 200000},
                },
                {"name": "is_active", "generator": "boolean"},
            ]
        }, indent=2),
    },
]
@@ -0,0 +1,80 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from sqlseed._utils.logger import get_logger
6
+ from sqlseed_ai._client import get_openai_client
7
+ from sqlseed_ai.config import AIConfig
8
+
9
+ logger = get_logger(__name__)
10
+
11
+
12
class NLConfigGenerator:
    """Generates a full sqlseed config dict from a natural-language description."""

    def __init__(self, config: Any | None = None) -> None:
        # Optional AIConfig-like object; None means "read from environment".
        self._config = config

    def generate(self, description: str, db_path: str | None = None) -> dict[str, Any]:
        """Ask the LLM to draft a config for *description*.

        Args:
            description: Free-form description of the desired test data.
            db_path: Optional SQLite file whose schema is added as context.

        Returns:
            The parsed config dict, or ``{}`` on any failure (best-effort API).
        """
        try:
            client = get_openai_client(self._config)
            # Fall back to the declared default model when no config is given.
            model = AIConfig.model_fields["model"].default
            if self._config is not None and hasattr(self._config, "model"):
                model = self._config.model

            schema_info = ""
            if db_path:
                schema_info = self._read_schema(db_path)

            prompt = (
                f"Generate a sqlseed configuration based on this description:\n"
                f'"{description}"\n\n'
            )
            if schema_info:
                prompt += f"Database schema:\n{schema_info}\n\n"

            prompt += (
                "The JSON should follow this structure:\n"
                '{"db_path": "path", "provider": "mimesis", "locale": "en_US", '
                '"tables": [{"name": "table_name", "count": 1000, '
                '"columns": [{"name": "column_name", "generator": "generator_name", "params": {}}]}]}\n\n'
                "Respond with valid JSON only."
            )

            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=2048,
                temperature=0.5,
                response_format={"type": "json_object"},
            )

            content = response.choices[0].message.content if response.choices else None
            if content is None:
                return {}

            from sqlseed_ai._json_utils import parse_json_response

            return parse_json_response(content)

        except Exception as e:
            # BUG FIX: log the message rather than the exception object,
            # matching the `error=str(e)` convention used elsewhere in
            # this package (e.g. SchemaAnalyzer.analyze_table).
            logger.warning("NL config generation failed", error=str(e))
            return {}

    def _read_schema(self, db_path: str) -> str:
        """Return a one-line-per-table schema summary of *db_path* ('' on error)."""
        try:
            from sqlseed.database.raw_sqlite_adapter import RawSQLiteAdapter

            adapter = RawSQLiteAdapter()
            adapter.connect(db_path)
            try:
                lines: list[str] = []
                for table in adapter.get_table_names():
                    col_desc = ", ".join(
                        f"{c.name}({c.type})" for c in adapter.get_column_info(table)
                    )
                    lines.append(f" {table}: {col_desc}")
                return "\n".join(lines)
            finally:
                # BUG FIX: always release the connection, even when table or
                # column introspection raises (previously leaked on error
                # because close() was only reached on success).
                adapter.close()
        except Exception:
            return ""
sqlseed_ai/provider.py ADDED
@@ -0,0 +1,88 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+
6
class AIProvider:
    """Inert provider registered under the name ``"ai"``.

    Every generator is a stub returning a neutral default value; real values
    are supplied through the LLM-driven template hooks, not this provider.
    """

    @property
    def name(self) -> str:
        """Registry key for this provider."""
        return "ai"

    # --- lifecycle (no-ops: stub output is not locale- or seed-dependent) ---

    def set_locale(self, locale: str) -> None:
        """No effect."""

    def set_seed(self, seed: int) -> None:
        """No effect."""

    # --- primitive types ---

    def generate_string(self, **kwargs: Any) -> str:
        return ""

    def generate_integer(self, **kwargs: Any) -> int:
        return 0

    def generate_float(self, **kwargs: Any) -> float:
        return 0.0

    def generate_boolean(self) -> bool:
        return False

    def generate_bytes(self, **kwargs: Any) -> bytes:
        return b""

    # --- person / contact ---

    def generate_name(self) -> str:
        return ""

    def generate_first_name(self) -> str:
        return ""

    def generate_last_name(self) -> str:
        return ""

    def generate_email(self) -> str:
        return ""

    def generate_phone(self) -> str:
        return ""

    def generate_address(self) -> str:
        return ""

    def generate_company(self) -> str:
        return ""

    # --- network / identifiers ---

    def generate_url(self) -> str:
        return ""

    def generate_ipv4(self) -> str:
        return ""

    def generate_uuid(self) -> str:
        return ""

    # --- temporal ---

    def generate_date(self, **kwargs: Any) -> str:
        return ""

    def generate_datetime(self, **kwargs: Any) -> str:
        return ""

    def generate_timestamp(self) -> int:
        return 0

    # --- text & structured ---

    def generate_text(self, **kwargs: Any) -> str:
        return ""

    def generate_sentence(self) -> str:
        return ""

    def generate_password(self, **kwargs: Any) -> str:
        return ""

    def generate_choice(self, choices: list[Any]) -> Any:
        # The only stub with input-dependent output: first option, or None.
        return choices[0] if choices else None

    def generate_json(self, **kwargs: Any) -> str:
        return "{}"

    def generate_pattern(self, *, regex: str, **kwargs: Any) -> str:
        return ""
sqlseed_ai/refiner.py ADDED
@@ -0,0 +1,277 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import json
5
+ import time
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING, Any
8
+
9
+ from sqlseed._utils.logger import get_logger
10
+ from sqlseed_ai.errors import ErrorSummary, summarize_error
11
+
12
+ if TYPE_CHECKING:
13
+ from sqlseed_ai.analyzer import SchemaAnalyzer
14
+
15
+ logger = get_logger(__name__)
16
+
17
+
18
class AISuggestionFailedError(Exception):
    """Raised when the AI cannot produce a valid config within the retry budget."""
20
+
21
+
22
class AiConfigRefiner:
    """Generates an AI table config and iteratively refines it until it
    validates, feeding each failure back to the LLM as a follow-up prompt."""

    def __init__(
        self,
        analyzer: SchemaAnalyzer,
        db_path: str,
        *,
        cache_dir: str | None = None,
    ) -> None:
        # The analyzer owns the LLM conversation; db_path is the target SQLite file.
        self._analyzer = analyzer
        self._db_path = db_path
        # Validated configs are cached on disk; default lives under the CWD.
        self._cache_dir = Path(cache_dir) if cache_dir else Path(".sqlseed_cache/ai_configs")
34
+
35
    def generate_and_refine(
        self,
        table_name: str,
        *,
        max_retries: int = 3,
        no_cache: bool = False,
    ) -> dict[str, Any]:
        """Produce a validated config for *table_name*, retrying on fixable errors.

        Flow: check the on-disk cache (keyed by a schema hash), then loop up
        to ``max_retries + 1`` LLM calls; each failed validation is appended
        to the conversation as an assistant/user refinement exchange.

        Raises:
            AISuggestionFailedError: On a non-retryable error or when all
                retries are exhausted.
        """
        from sqlseed.core.orchestrator import DataOrchestrator

        with DataOrchestrator(self._db_path) as orch:
            schema_hash = self._compute_schema_hash(orch, table_name)

            if not no_cache:
                # Cache helpers are defined elsewhere in this module (not
                # shown in this chunk).
                cached = self.get_cached_config(table_name, schema_hash)
                if cached is not None:
                    logger.info("Using cached AI config", table_name=table_name)
                    return cached

            schema_ctx = orch.get_schema_context(table_name)

            initial_messages = self._analyzer.build_initial_messages(
                table_name=schema_ctx["table_name"],
                columns=schema_ctx["columns"],
                indexes=schema_ctx["indexes"],
                sample_data=schema_ctx["sample_data"],
                foreign_keys=schema_ctx["foreign_keys"],
                all_table_names=schema_ctx["all_table_names"],
                distribution_profiles=schema_ctx.get("distribution"),
            )

            # Copy so refinement turns never mutate the initial transcript.
            messages_history = list(initial_messages)

            for attempt in range(max_retries + 1):
                messages = list(messages_history)
                try:
                    config_dict = self._analyzer.call_llm(messages)
                except Exception as e:
                    # Transport-level failure: classify and retry or give up.
                    error = summarize_error(e)
                    if not error.retryable:
                        raise AISuggestionFailedError(
                            f"Non-retryable error: {error.message}"
                        ) from e
                    if attempt == max_retries:
                        raise AISuggestionFailedError(
                            f"Failed after {max_retries} retries. Last error: {error.message}"
                        ) from e
                    logger.info(
                        "LLM API call failed, retrying",
                        attempt=attempt + 1,
                        max_retries=max_retries,
                        error=error.message,
                    )
                    continue

                # Content-level validation of the parsed config.
                error = self._validate_config(orch, table_name, config_dict)

                if error is None:
                    logger.info(
                        "AI config validated successfully",
                        table_name=table_name,
                        attempts=attempt + 1,
                    )
                    self._cache_successful_config(table_name, config_dict, schema_hash)
                    return config_dict

                if not error.retryable:
                    raise AISuggestionFailedError(
                        f"Non-retryable error: {error.message}"
                    )

                if attempt == max_retries:
                    logger.warning(
                        "AI config refinement exhausted all retries",
                        table_name=table_name,
                        last_error=error.error_type,
                    )
                    raise AISuggestionFailedError(
                        f"Failed after {max_retries} retries. Last error: {error.message}"
                    )

                logger.info(
                    "AI config refinement attempt",
                    attempt=attempt + 1,
                    max_retries=max_retries,
                    error_type=error.error_type,
                    column=error.column,
                )

                # Record the failed answer plus a corrective prompt so the
                # next attempt sees its own mistake.
                messages_history.append({
                    "role": "assistant",
                    "content": json.dumps(config_dict, ensure_ascii=False),
                })
                messages_history.append({
                    "role": "user",
                    "content": self._build_refinement_prompt(error, attempt, max_retries),
                })

        # Defensive: the loop above always returns or raises.
        raise AISuggestionFailedError("Unexpected state")
133
+
134
    def _compute_schema_hash(self, orch: Any, table_name: str) -> str:
        """Short, order-independent fingerprint of the table's column names.

        Used as part of the config cache key.  md5 is fine here — this is a
        cache key, not a security boundary.
        NOTE(review): only column *names* are hashed, so a type-only schema
        change will not invalidate a cached config — confirm that is intended.
        """
        column_names = orch.get_column_names(table_name)
        raw = "|".join(sorted(column_names))
        return hashlib.md5(raw.encode()).hexdigest()[:12]
138
+
139
    def _validate_config(
        self,
        orch: Any,
        table_name: str,
        config_dict: dict[str, Any],
    ) -> ErrorSummary | None:
        """Validate an LLM-proposed config; return ``None`` when it is usable.

        Three stages: pydantic model validation, column-existence checks
        against the live schema, then a 5-row dry-run preview to smoke-test
        the generators.  The first failure is summarized and returned.
        """
        from sqlseed.config.models import TableConfig

        try:
            table_config = TableConfig(**config_dict)
        except Exception as e:
            return summarize_error(e)

        # Presumably both return sets — the `-` below requires it; confirm
        # against the orchestrator's API.
        actual_columns = orch.get_column_names(table_name)
        skippable_cols = orch.get_skippable_columns(table_name)
        suggestable_cols = actual_columns - skippable_cols

        # Reject hallucinated columns before attempting any generation.
        for col_cfg in table_config.columns:
            if col_cfg.name not in actual_columns:
                return ErrorSummary(
                    error_type="column_mismatch",
                    message=(
                        f"Column '{col_cfg.name}' does not exist in table "
                        f"'{table_name}'. Available columns: "
                        f"{sorted(actual_columns)}"
                    ),
                    column=col_cfg.name,
                    retryable=True,
                )

        # An empty suggestion set is only acceptable when every column is
        # auto-skippable (autoincrement / default-valued).
        if suggestable_cols and len(table_config.columns) == 0:
            return ErrorSummary(
                error_type="empty_config",
                message=(
                    f"No column suggestions provided for table '{table_name}'. "
                    f"There are {len(suggestable_cols)} suggestable columns: "
                    f"{sorted(suggestable_cols)}. "
                    "Please provide generator suggestions for at least the "
                    "non-default, non-autoincrement columns."
                ),
                column=None,
                retryable=True,
            )

        # Dry-run a tiny batch to catch unknown generators, bad params and
        # broken derive expressions.
        try:
            orch.preview_table(
                table_name=table_name,
                count=5,
                column_configs=table_config.columns,
            )
        except Exception as e:
            return summarize_error(e)

        return None
193
+
194
+ def _build_refinement_prompt(
195
+ self,
196
+ error: ErrorSummary,
197
+ attempt: int,
198
+ max_retries: int,
199
+ ) -> str:
200
+ parts = [
201
+ "Your previous configuration contained an error. Please fix it.",
202
+ "",
203
+ "## Error Details",
204
+ error.to_prompt_str(),
205
+ "",
206
+ "## Instructions",
207
+ "- Only fix the column(s) mentioned in the error.",
208
+ "- Do NOT modify other column configurations that were working correctly.",
209
+ "- Return the COMPLETE configuration JSON "
210
+ "with only the problematic parts corrected.",
211
+ "- If you are unsure how to fix the error, "
212
+ "use 'string' generator as a safe fallback.",
213
+ "",
214
+ f"This is refinement attempt {attempt + 1} of {max_retries}.",
215
+ ]
216
+
217
+ if attempt >= max_retries - 1:
218
+ parts.append(
219
+ "WARNING: This is the LAST attempt. "
220
+ "Use the simplest possible generators to ensure validity."
221
+ )
222
+
223
+ return "\n".join(parts)
224
+
225
+ def _cache_successful_config(
226
+ self,
227
+ table_name: str,
228
+ config_dict: dict[str, Any],
229
+ schema_hash: str,
230
+ ) -> None:
231
+ try:
232
+ self._cache_dir.mkdir(parents=True, exist_ok=True)
233
+ cache_file = self._cache_dir / f"{table_name}.json"
234
+ entry = {
235
+ "_meta": {
236
+ "schema_hash": schema_hash,
237
+ "created_at": time.time(),
238
+ },
239
+ "config": config_dict,
240
+ }
241
+ cache_file.write_text(
242
+ json.dumps(entry, ensure_ascii=False, indent=2),
243
+ encoding="utf-8",
244
+ )
245
+ logger.debug(
246
+ "Cached AI config",
247
+ table_name=table_name,
248
+ path=str(cache_file),
249
+ schema_hash=schema_hash,
250
+ )
251
+ except Exception as e:
252
+ logger.debug("Failed to cache AI config", error=str(e))
253
+
254
+ def get_cached_config(
255
+ self,
256
+ table_name: str,
257
+ schema_hash: str | None = None,
258
+ ) -> dict[str, Any] | None:
259
+ cache_file = self._cache_dir / f"{table_name}.json"
260
+ if cache_file.exists():
261
+ try:
262
+ entry = json.loads(cache_file.read_text(encoding="utf-8"))
263
+ if isinstance(entry, dict) and "_meta" in entry:
264
+ cached_hash = entry["_meta"].get("schema_hash", "")
265
+ if schema_hash and cached_hash != schema_hash:
266
+ logger.debug(
267
+ "Cache schema hash mismatch, invalidating",
268
+ table_name=table_name,
269
+ cached_hash=cached_hash,
270
+ current_hash=schema_hash,
271
+ )
272
+ return None
273
+ return entry.get("config")
274
+ return entry
275
+ except Exception:
276
+ pass
277
+ return None
sqlseed_ai/suggest.py ADDED
@@ -0,0 +1,62 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from sqlseed._utils.logger import get_logger
6
+ from sqlseed_ai._client import get_openai_client
7
+ from sqlseed_ai.config import AIConfig
8
+
9
+ logger = get_logger(__name__)
10
+
11
+
12
class ColumnSuggester:
    """Suggests a data generator for a single column via an LLM call."""

    def __init__(self, config: Any | None = None) -> None:
        # Optional AIConfig-like object carrying model/API settings.
        self._config = config

    def suggest(
        self,
        column_name: str,
        column_type: str,
        table_name: str,
        all_column_names: list[str],
    ) -> dict[str, Any] | None:
        """Ask the LLM for a generator suggestion for one column.

        Returns a dict like ``{"generator": "...", "params": {...}}`` or
        ``None`` when the call fails or the response is unusable. Errors
        are logged and swallowed: AI suggestion is strictly best-effort.
        """
        try:
            client = get_openai_client(self._config)
            # Fall back to the declared default model unless the config
            # supplies one explicitly.
            model = AIConfig.model_fields["model"].default
            if self._config is not None and hasattr(self._config, "model"):
                model = self._config.model

            prompt = (
                f"Given a SQLite table '{table_name}' with columns {all_column_names}, "
                f"the column '{column_name}' has type '{column_type}'. "
                f"Suggest the best data generator name and params for this column. "
                f"Available generators: string, integer, float, boolean, bytes, "
                f"name, first_name, last_name, email, phone, address, company, "
                f"url, ipv4, uuid, date, datetime, timestamp, text, sentence, "
                f"password, choice, json, foreign_key, pattern. "
                f'Respond in JSON format: {{"generator": "...", "params": {{}}}}'
            )

            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=256,
                temperature=0.3,
                response_format={"type": "json_object"},
            )

            content = response.choices[0].message.content if response.choices else None
            if content is None:
                return None

            from sqlseed_ai._json_utils import parse_json_response

            result = parse_json_response(content)
            # Guard against non-dict JSON (e.g. a bare list) so the
            # declared `dict | None` return type actually holds.
            if isinstance(result, dict) and "generator" in result:
                return result

        except Exception as e:
            # Stringify the exception so structured log fields stay
            # serializable — consistent with the rest of the package
            # (cf. the refiner's `error=str(e)` logging).
            logger.warning(
                "AI suggestion failed", column_name=column_name, error=str(e)
            )

        return None
@@ -0,0 +1,19 @@
1
+ Metadata-Version: 2.4
2
+ Name: sqlseed-ai
3
+ Version: 0.1.0
4
+ Summary: AI-powered data generation plugin for sqlseed
5
+ Project-URL: Homepage, https://github.com/sunbos/sqlseed
6
+ Project-URL: Repository, https://github.com/sunbos/sqlseed/tree/main/plugins/sqlseed-ai
7
+ Author-email: SunBo <1443584939@qq.com>
8
+ License-Expression: AGPL-3.0-or-later
9
+ Classifier: Development Status :: 3 - Alpha
10
+ Classifier: License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)
11
+ Classifier: Programming Language :: Python :: 3
12
+ Classifier: Programming Language :: Python :: 3.10
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Classifier: Programming Language :: Python :: 3.13
16
+ Requires-Python: >=3.10
17
+ Requires-Dist: openai>=1.0
18
+ Requires-Dist: pydantic>=2.0
19
+ Requires-Dist: sqlseed
@@ -0,0 +1,15 @@
1
+ sqlseed_ai/__init__.py,sha256=KD8UutkBS4td2IaeEcG1YlDPafiMEFYnX1A-2_X0qaw,2450
2
+ sqlseed_ai/_client.py,sha256=0lLJD03sCvJPeCZXYMUYrdi_RVGq9Xhufa2oZt_TvL8,891
3
+ sqlseed_ai/_json_utils.py,sha256=TByhP4wvxj17aiQe9-Lgx88k4tqIa2RoOBWR7i-xgB4,556
4
+ sqlseed_ai/analyzer.py,sha256=aetDgwymMBDop14Qhu2kLIlTP6AWJ5m1xKPQNVX-qgU,10554
5
+ sqlseed_ai/config.py,sha256=Zc3N1EOtbuAcfxo48nftKDJK1paR4Av8x0lYcYojTmg,713
6
+ sqlseed_ai/errors.py,sha256=W43rjOWaP1zFUSltXRtpB9CXalEP_qgHzMfzdmD418c,3846
7
+ sqlseed_ai/examples.py,sha256=Sfl9JVjGiMjwXm9NuuHZPmnKKGTvg_yl599w5z8O9Yk,5322
8
+ sqlseed_ai/nl_config.py,sha256=Wcf8ujcoA98XAKjUA6HkReuQplodCC83afDL7UIFDP4,2781
9
+ sqlseed_ai/provider.py,sha256=JNr4Gsr0-cXH3hRxVBogrXehIM4YHudWJhVJgwzEdiM,1853
10
+ sqlseed_ai/refiner.py,sha256=br0wVQYGm-HuN3FW5x8wAyOgFqDGonFGsEx3INCQVnc,9896
11
+ sqlseed_ai/suggest.py,sha256=EckLx5q5CgwK9RJRsslOJTjJG6DQNufdRqCOxmhr8VU,2196
12
+ sqlseed_ai-0.1.0.dist-info/METADATA,sha256=RcpIcb4GFHQfGemI1UdxjgRcuSS7Ds5uDfNpnHs3U2c,826
13
+ sqlseed_ai-0.1.0.dist-info/WHEEL,sha256=QccIxa26bgl1E6uMy58deGWi-0aeIkkangHcxk2kWfw,87
14
+ sqlseed_ai-0.1.0.dist-info/entry_points.txt,sha256=zEHu6heoDb7qXrooVmJN10y-4LSWIhitMNdmcovm1UE,33
15
+ sqlseed_ai-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.29.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,2 @@
1
+ [sqlseed]
2
+ ai = sqlseed_ai:plugin