sqlseed-ai 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,69 @@
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual Environment
24
+ .venv/
25
+ env/
26
+ ENV/
27
+
28
+ # IDE
29
+ .vscode/
30
+ .idea/
31
+ *.swp
32
+ *.swo
33
+ *~
34
+
35
+ # Testing
36
+ .pytest_cache/
37
+ .coverage
38
+ htmlcov/
39
+
40
+ # Type checking
41
+ .mypy_cache/
42
+ .dmypy.json
43
+ dmypy.json
44
+
45
+ # Linting
46
+ .ruff_cache/
47
+
48
+ # Project specific
49
+ *.db
50
+ *.sqlite
51
+ *.sqlite3
52
+ snapshots/
53
+
54
+ # AI cache
55
+ .sqlseed_cache/
56
+
57
+ # Archived temp files
58
+ _archived_temp/
59
+
60
+ # macOS
61
+ .DS_Store
62
+
63
+ # Trae IDE
64
+ .trae/
65
+
66
+ # Build artifacts
67
+ dist/
68
+ *.whl
69
+ *.tar.gz
@@ -0,0 +1,19 @@
1
+ Metadata-Version: 2.4
2
+ Name: sqlseed-ai
3
+ Version: 0.1.0
4
+ Summary: AI-powered data generation plugin for sqlseed
5
+ Project-URL: Homepage, https://github.com/sunbos/sqlseed
6
+ Project-URL: Repository, https://github.com/sunbos/sqlseed/tree/main/plugins/sqlseed-ai
7
+ Author-email: SunBo <1443584939@qq.com>
8
+ License-Expression: AGPL-3.0-or-later
9
+ Classifier: Development Status :: 3 - Alpha
10
+ Classifier: License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)
11
+ Classifier: Programming Language :: Python :: 3
12
+ Classifier: Programming Language :: Python :: 3.10
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Classifier: Programming Language :: Python :: 3.13
16
+ Requires-Python: >=3.10
17
+ Requires-Dist: openai>=1.0
18
+ Requires-Dist: pydantic>=2.0
19
+ Requires-Dist: sqlseed
@@ -0,0 +1,37 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "sqlseed-ai"
7
+ version = "0.1.0"
8
+ requires-python = ">=3.10"
9
+ description = "AI-powered data generation plugin for sqlseed"
10
+ license = "AGPL-3.0-or-later"
11
+ authors = [
12
+ {name = "SunBo", email = "1443584939@qq.com"},
13
+ ]
14
+ classifiers = [
15
+ "Development Status :: 3 - Alpha",
16
+ "License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)",
17
+ "Programming Language :: Python :: 3",
18
+ "Programming Language :: Python :: 3.10",
19
+ "Programming Language :: Python :: 3.11",
20
+ "Programming Language :: Python :: 3.12",
21
+ "Programming Language :: Python :: 3.13",
22
+ ]
23
+ dependencies = [
24
+ "sqlseed",
25
+ "openai>=1.0",
26
+ "pydantic>=2.0",
27
+ ]
28
+
29
+ [project.urls]
30
+ Homepage = "https://github.com/sunbos/sqlseed"
31
+ Repository = "https://github.com/sunbos/sqlseed/tree/main/plugins/sqlseed-ai"
32
+
33
+ [project.entry-points."sqlseed"]
34
+ ai = "sqlseed_ai:plugin"
35
+
36
+ [tool.hatch.build.targets.wheel]
37
+ packages = ["src/sqlseed_ai"]
@@ -0,0 +1,88 @@
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from typing import Any
5
+
6
+ from sqlseed.plugins.hookspecs import hookimpl
7
+ from sqlseed_ai.analyzer import SchemaAnalyzer
8
+
9
# Pre-compiled, case-insensitive pattern of common "simple" column tokens
# (name/email/date/int/...), each delimited by string start/end, an
# underscore, or whitespace.  Columns whose name or declared type matches
# are assumed cheap to generate with built-in providers, so the AI template
# hook skips them (see AISqlseedPlugin._is_simple_column).
_SIMPLE_COL_RE = re.compile(
    r'(^|[_\s])('
    r'name|email|phone|address|url|uuid|'
    r'date|time|datetime|timestamp|boolean|'
    r'int|float|double|real|text|string|'
    r'char|varchar|blob|byte|id|code|title|'
    r'description|status|type|category|count|'
    r'amount|price|value|number|index|order|level'
    r')($|[_\s])',
    re.IGNORECASE,
)
20
+
21
+
22
class AISqlseedPlugin:
    """sqlseed plugin object wiring AI-backed hooks into the host toolkit.

    A single SchemaAnalyzer is created on demand and reused across all
    hook invocations.
    """

    def __init__(self) -> None:
        # Created lazily by _get_analyzer() so importing the plugin is cheap.
        self._analyzer: SchemaAnalyzer | None = None

    def _get_analyzer(self) -> SchemaAnalyzer:
        """Return the shared analyzer, instantiating it on first use."""
        if self._analyzer is None:
            self._analyzer = SchemaAnalyzer()
        return self._analyzer

    def _is_simple_column(self, column_name: str, column_type: str) -> bool:
        """Report whether the column looks "simple" by name or declared type."""
        return any(
            _SIMPLE_COL_RE.search(candidate) is not None
            for candidate in (column_name, column_type)
        )

    @hookimpl
    def sqlseed_ai_analyze_table(
        self,
        table_name: str,
        columns: list[Any],
        indexes: list[dict[str, Any]],
        sample_data: list[dict[str, Any]],
        foreign_keys: list[Any],
        all_table_names: list[str],
    ) -> dict[str, Any] | None:
        """Delegate whole-table schema analysis to the shared analyzer."""
        return self._get_analyzer().analyze_table(
            table_name=table_name,
            columns=columns,
            indexes=indexes,
            sample_data=sample_data,
            foreign_keys=foreign_keys,
            all_table_names=all_table_names,
        )

    @hookimpl
    def sqlseed_pre_generate_templates(
        self,
        table_name: str,
        column_name: str,
        column_type: str,
        count: int,
        sample_data: list[Any],
    ) -> list[Any] | None:
        """Produce AI-generated template values for one column.

        Returns None — meaning "let the host's built-in generators handle
        it" — both for simple columns and on any AI failure (best effort).
        """
        if self._is_simple_column(column_name, column_type):
            return None

        try:
            # Cap the request at 50 values to bound token usage per column.
            return self._get_analyzer().generate_template_values(
                column_name=column_name,
                column_type=column_type,
                count=min(count, 50),
                sample_data=sample_data,
            )
        except Exception:
            return None

    @hookimpl
    def sqlseed_register_providers(self, registry: Any) -> None:
        """This plugin registers no custom providers."""
        pass

    @hookimpl
    def sqlseed_register_column_mappers(self, mapper: Any) -> None:
        """This plugin registers no custom column mappers."""
        pass
86
+
87
+
88
# Module-level singleton exposed via the "sqlseed" entry point
# (ai = "sqlseed_ai:plugin" in pyproject.toml).
plugin = AISqlseedPlugin()
@@ -0,0 +1,31 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from sqlseed._utils.logger import get_logger
6
+
7
+ logger = get_logger(__name__)
8
+
9
+
10
+ def get_openai_client(config: Any | None = None) -> Any:
11
+ try:
12
+ from openai import OpenAI
13
+
14
+ from sqlseed_ai.config import AIConfig
15
+
16
+ if config is None:
17
+ config = AIConfig.from_env()
18
+
19
+ api_key = config.api_key if hasattr(config, "api_key") else None
20
+ base_url = config.base_url if hasattr(config, "base_url") else None
21
+
22
+ if not api_key:
23
+ raise ValueError(
24
+ "AI API key not configured. Set SQLSEED_AI_API_KEY or OPENAI_API_KEY environment variable."
25
+ )
26
+
27
+ return OpenAI(api_key=api_key, base_url=base_url)
28
+ except ImportError:
29
+ raise ImportError(
30
+ "openai is not installed. Install it with: pip install sqlseed-ai"
31
+ ) from None
@@ -0,0 +1,24 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from typing import Any
5
+
6
+
7
def parse_json_response(content: str) -> dict[str, Any]:
    """Parse an LLM reply as a JSON object, tolerating markdown fences.

    Strips a leading ```json / ``` fence and a trailing ``` fence, then
    attempts ``json.loads``.  Returns the parsed mapping, or an empty dict
    when the content is not valid JSON or is not a JSON object.
    """
    text = content.strip()
    # Remove an opening markdown fence (with or without the "json" tag).
    if text.startswith("```json"):
        text = text[len("```json"):]
    if text.startswith("```"):
        text = text[len("```"):]
    # Remove a closing fence, if present.
    if text.endswith("```"):
        text = text[: -len("```")]
    text = text.strip()

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return {}
    return parsed if isinstance(parsed, dict) else {}
@@ -0,0 +1,304 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from sqlseed._utils.logger import get_logger
6
+ from sqlseed_ai._client import get_openai_client
7
+ from sqlseed_ai.config import AIConfig
8
+
9
+ logger = get_logger(__name__)
10
+
11
+ SYSTEM_PROMPT = """You are an expert database test data engineer.
12
+ You analyze SQLite table schemas and recommend data generation configurations for the sqlseed toolkit.
13
+
14
+ ## Available Generators
15
+ - string (params: min_length, max_length, charset)
16
+ - integer (params: min_value, max_value)
17
+ - float (params: min_value, max_value, precision)
18
+ - boolean
19
+ - bytes (params: length)
20
+ - name, first_name, last_name
21
+ - email, phone, address, company
22
+ - url, ipv4, uuid
23
+ - date (params: start_year, end_year)
24
+ - datetime (params: start_year, end_year)
25
+ - timestamp
26
+ - text (params: min_length, max_length)
27
+ - sentence, password
28
+ - choice (params: choices)
29
+ - json (params: schema)
30
+ - pattern (params: regex) — generates strings matching a regex pattern
31
+
32
+ ## Key Rules
33
+ 1. INTEGER PRIMARY KEY AUTOINCREMENT columns → do NOT include (auto-skip)
34
+ 2. Columns with DEFAULT values → do NOT include (auto-skip)
35
+ 3. Nullable columns → do NOT include unless they have semantic meaning
36
+ 4. Use `pattern` generator with regex for card numbers, codes, IDs with specific formats
37
+ 5. Use `derive_from` + `expression` when one column is computed from another
38
+ 6. Use `constraints.unique: true` for columns that must be unique
39
+ 7. Detect cross-column dependencies: if CutCard4byte = last 8 chars of sCardNo, use derive_from
40
+ 8. Detect implicit business associations: if sUserNo appears in multiple tables, note it
41
+
42
+ ## Output Format
43
+ You MUST respond with a valid JSON object (NOT YAML, NOT markdown fences).
44
+ The JSON object must have this exact structure:
45
+ {
46
+ "name": "table_name",
47
+ "count": 1000,
48
+ "columns": [
49
+ {
50
+ "name": "column_name",
51
+ "generator": "generator_name",
52
+ "params": {"key": "value"}
53
+ },
54
+ {
55
+ "name": "derived_column",
56
+ "derive_from": "source_column",
57
+ "expression": "value[-8:]",
58
+ "constraints": {"unique": true}
59
+ }
60
+ ]
61
+ }
62
+
63
+ IMPORTANT: Do NOT include columns that are PRIMARY KEY AUTOINCREMENT or have DEFAULT values."""
64
+
65
+
66
+ class SchemaAnalyzer:
67
+ def __init__(self, config: AIConfig | None = None) -> None:
68
+ self._config = config
69
+
70
+ def analyze_table(
71
+ self,
72
+ table_name: str,
73
+ columns: list[Any],
74
+ indexes: list[dict[str, Any]],
75
+ sample_data: list[dict[str, Any]],
76
+ foreign_keys: list[Any],
77
+ all_table_names: list[str],
78
+ ) -> dict[str, Any] | None:
79
+ if self._config is None:
80
+ self._config = AIConfig.from_env()
81
+
82
+ if not self._config.api_key:
83
+ logger.warning("AI API key not configured, skipping analysis")
84
+ return None
85
+
86
+ messages = self.build_initial_messages(
87
+ table_name=table_name,
88
+ columns=columns,
89
+ indexes=indexes,
90
+ sample_data=sample_data,
91
+ foreign_keys=foreign_keys,
92
+ all_table_names=all_table_names,
93
+ )
94
+
95
+ try:
96
+ return self.call_llm(messages)
97
+ except Exception as e:
98
+ logger.warning("AI analysis failed", table_name=table_name, error=str(e))
99
+ return None
100
+
101
+ def build_initial_messages(
102
+ self,
103
+ table_name: str,
104
+ columns: list[Any],
105
+ indexes: list[dict[str, Any]],
106
+ sample_data: list[dict[str, Any]],
107
+ foreign_keys: list[Any],
108
+ all_table_names: list[str],
109
+ distribution_profiles: list[dict[str, Any]] | None = None,
110
+ ) -> list[dict[str, str]]:
111
+ context = self._build_context(
112
+ table_name=table_name,
113
+ columns=columns,
114
+ indexes=indexes,
115
+ sample_data=sample_data,
116
+ foreign_keys=foreign_keys,
117
+ all_table_names=all_table_names,
118
+ distribution_profiles=distribution_profiles,
119
+ )
120
+ messages: list[dict[str, str]] = [
121
+ {"role": "system", "content": SYSTEM_PROMPT},
122
+ ]
123
+
124
+ from sqlseed_ai.examples import FEW_SHOT_EXAMPLES
125
+ for example in FEW_SHOT_EXAMPLES:
126
+ messages.append({"role": "user", "content": example["input"]})
127
+ messages.append({"role": "assistant", "content": example["output"]})
128
+
129
+ messages.append({"role": "user", "content": context})
130
+
131
+ return messages
132
+
133
+ def call_llm(self, messages: list[dict[str, str]]) -> dict[str, Any]:
134
+ if self._config is None:
135
+ self._config = AIConfig.from_env()
136
+ if not self._config.api_key:
137
+ raise ValueError("AI API key not configured")
138
+
139
+ client = get_openai_client(self._config)
140
+ try:
141
+ response = client.chat.completions.create(
142
+ model=self._config.model,
143
+ messages=messages,
144
+ max_tokens=self._config.max_tokens,
145
+ temperature=self._config.temperature,
146
+ response_format={"type": "json_object"},
147
+ )
148
+ except Exception as e:
149
+ raise RuntimeError(
150
+ f"LLM API call failed (model={self._config.model}): {e}"
151
+ ) from e
152
+
153
+ if not response.choices:
154
+ raise RuntimeError(
155
+ f"LLM returned no choices (model={self._config.model}). "
156
+ "The API key or model may be invalid."
157
+ )
158
+ content = response.choices[0].message.content
159
+ if content is None:
160
+ return {}
161
+ return self._parse_json_response(content)
162
+
163
+ TEMPLATE_SYSTEM_PROMPT = (
164
+ "You are a data generation assistant. Generate realistic sample values "
165
+ "for the given database column. Return a JSON object with a 'values' "
166
+ "array containing the requested number of unique, realistic values. "
167
+ "Each value must be valid for the column type. Do NOT include explanations."
168
+ )
169
+
170
+ def generate_template_values(
171
+ self,
172
+ column_name: str,
173
+ column_type: str,
174
+ count: int,
175
+ sample_data: list[Any],
176
+ ) -> list[Any]:
177
+ prompt = (
178
+ f"Generate {count} realistic sample values for a database column "
179
+ f"named '{column_name}' with type '{column_type}'."
180
+ )
181
+ if sample_data:
182
+ prompt += f"\nExisting sample values: {sample_data[:5]}"
183
+ prompt += (
184
+ f"\nRespond with a JSON object: {{\"values\": [...]}}."
185
+ f"\nEach value should be a valid {column_type} value."
186
+ )
187
+
188
+ messages = [
189
+ {"role": "system", "content": self.TEMPLATE_SYSTEM_PROMPT},
190
+ {"role": "user", "content": prompt},
191
+ ]
192
+ result = self.call_llm(messages)
193
+ return result.get("values", [])
194
+
195
+ def _build_context(
196
+ self,
197
+ table_name: str,
198
+ columns: list[Any],
199
+ indexes: list[dict[str, Any]],
200
+ sample_data: list[dict[str, Any]],
201
+ foreign_keys: list[Any],
202
+ all_table_names: list[str],
203
+ distribution_profiles: list[dict[str, Any]] | None = None,
204
+ ) -> str:
205
+ lines: list[str] = []
206
+ lines.append(f"# Table: {table_name}")
207
+ lines.append("")
208
+
209
+ self._append_columns_info(lines, columns)
210
+
211
+ if indexes:
212
+ self._append_indexes_info(lines, indexes)
213
+
214
+ if foreign_keys:
215
+ lines.append("")
216
+ lines.append("## Foreign Keys")
217
+ for fk in foreign_keys:
218
+ lines.append(f"- {fk.column} → {fk.ref_table}.{fk.ref_column}")
219
+
220
+ if all_table_names:
221
+ lines.append("")
222
+ lines.append("## All Tables in Database")
223
+ lines.append(", ".join(all_table_names))
224
+
225
+ if sample_data:
226
+ lines.append("")
227
+ lines.append("## Sample Data (existing rows)")
228
+ for i, row in enumerate(sample_data[:5]):
229
+ row_str = ", ".join(f"{k}={v}" for k, v in row.items())
230
+ lines.append(f" Row {i + 1}: {row_str}")
231
+
232
+ if distribution_profiles:
233
+ self._append_distribution_info(lines, distribution_profiles)
234
+
235
+ lines.append("")
236
+ lines.append(
237
+ "Please analyze this table schema and recommend "
238
+ "a complete sqlseed JSON configuration for generating test data."
239
+ )
240
+
241
+ return "\n".join(lines)
242
+
243
+ def _append_columns_info(
244
+ self,
245
+ lines: list[str],
246
+ columns: list[Any],
247
+ ) -> None:
248
+ lines.append("## Columns")
249
+ for col in columns:
250
+ parts = [f"- {col.name}: {col.type}"]
251
+ if col.is_primary_key:
252
+ parts.append("PRIMARY KEY")
253
+ if col.is_autoincrement:
254
+ parts.append("AUTOINCREMENT")
255
+ if col.nullable:
256
+ parts.append("NULLABLE")
257
+ if col.default is not None:
258
+ parts.append(f"DEFAULT={col.default}")
259
+ if not col.nullable and col.default is None and not col.is_primary_key:
260
+ parts.append("NOT NULL")
261
+ lines.append(" ".join(parts))
262
+
263
+ def _append_indexes_info(
264
+ self,
265
+ lines: list[str],
266
+ indexes: list[dict[str, Any]],
267
+ ) -> None:
268
+ lines.append("")
269
+ lines.append("## Indexes")
270
+ for idx in indexes:
271
+ unique_str = "UNIQUE " if idx.get("unique") else ""
272
+ cols_str = ", ".join(idx.get("columns", []))
273
+ lines.append(f"- {unique_str}INDEX ({cols_str})")
274
+
275
+ def _append_distribution_info(
276
+ self,
277
+ lines: list[str],
278
+ distribution_profiles: list[dict[str, Any]],
279
+ ) -> None:
280
+ lines.append("")
281
+ lines.append("## Column Distribution (from existing data)")
282
+ for profile in distribution_profiles:
283
+ col = profile["column"]
284
+ distinct = profile.get("distinct_count", "?")
285
+ null_ratio = profile.get("null_ratio", 0)
286
+ lines.append(
287
+ f"- {col}: {distinct} distinct values, {null_ratio:.1%} null"
288
+ )
289
+ top_values = profile.get("top_values", [])
290
+ if top_values:
291
+ top_str = ", ".join(
292
+ f"{tv['value']}({tv['frequency']:.0%})"
293
+ for tv in top_values[:3]
294
+ )
295
+ lines.append(f" Top values: {top_str}")
296
+ vr = profile.get("value_range")
297
+ if vr:
298
+ lines.append(f" Range: [{vr['min']}, {vr['max']}]")
299
+
300
+ def _parse_json_response(self, content: str) -> dict[str, Any]:
301
+ from sqlseed_ai._json_utils import parse_json_response
302
+
303
+ return parse_json_response(content)
304
+
@@ -0,0 +1,20 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
class AIConfig(BaseModel):
    """Runtime configuration for the AI plugin (pydantic-validated)."""

    # API key for the OpenAI-compatible endpoint; None disables AI features.
    api_key: str | None = None
    # Default model name sent to the endpoint.
    model: str = "qwen3-coder-plus"
    # Optional custom endpoint (e.g. a proxy or an alternative provider).
    base_url: str | None = None
    # Bounded to the API's accepted range [0.0, 2.0].
    temperature: float = Field(default=0.3, ge=0.0, le=2.0)
    # Must be positive.
    max_tokens: int = Field(default=4096, gt=0)

    @classmethod
    def from_env(cls) -> AIConfig:
        """Build a config from environment variables.

        SQLSEED_AI_* variables take precedence over the OPENAI_* fallbacks.
        Note: temperature and max_tokens are not read from the environment
        here; they keep their field defaults.
        """
        api_key = os.environ.get("SQLSEED_AI_API_KEY") or os.environ.get("OPENAI_API_KEY")
        base_url = os.environ.get("SQLSEED_AI_BASE_URL") or os.environ.get("OPENAI_BASE_URL")
        # Fall back to the field's declared default when the env var is unset.
        model = os.environ.get("SQLSEED_AI_MODEL", cls.model_fields["model"].default)
        return cls(api_key=api_key, base_url=base_url, model=model)