sqlseed 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. sqlseed/__init__.py +121 -0
  2. sqlseed/_utils/__init__.py +11 -0
  3. sqlseed/_utils/logger.py +30 -0
  4. sqlseed/_utils/metrics.py +45 -0
  5. sqlseed/_utils/progress.py +14 -0
  6. sqlseed/_utils/schema_helpers.py +51 -0
  7. sqlseed/_utils/sql_safe.py +45 -0
  8. sqlseed/_version.py +1 -0
  9. sqlseed/cli/__init__.py +3 -0
  10. sqlseed/cli/main.py +316 -0
  11. sqlseed/config/__init__.py +14 -0
  12. sqlseed/config/loader.py +66 -0
  13. sqlseed/config/models.py +99 -0
  14. sqlseed/config/snapshot.py +91 -0
  15. sqlseed/core/__init__.py +14 -0
  16. sqlseed/core/column_dag.py +108 -0
  17. sqlseed/core/constraints.py +116 -0
  18. sqlseed/core/expression.py +71 -0
  19. sqlseed/core/mapper.py +257 -0
  20. sqlseed/core/orchestrator.py +578 -0
  21. sqlseed/core/relation.py +124 -0
  22. sqlseed/core/result.py +23 -0
  23. sqlseed/core/schema.py +97 -0
  24. sqlseed/core/transform.py +27 -0
  25. sqlseed/database/__init__.py +14 -0
  26. sqlseed/database/_protocol.py +72 -0
  27. sqlseed/database/optimizer.py +96 -0
  28. sqlseed/database/raw_sqlite_adapter.py +197 -0
  29. sqlseed/database/sqlite_utils_adapter.py +183 -0
  30. sqlseed/generators/__init__.py +11 -0
  31. sqlseed/generators/_protocol.py +73 -0
  32. sqlseed/generators/base_provider.py +448 -0
  33. sqlseed/generators/faker_provider.py +157 -0
  34. sqlseed/generators/mimesis_provider.py +203 -0
  35. sqlseed/generators/registry.py +86 -0
  36. sqlseed/generators/stream.py +157 -0
  37. sqlseed/py.typed +0 -0
  38. sqlseed-0.1.0.dist-info/METADATA +934 -0
  39. sqlseed-0.1.0.dist-info/RECORD +42 -0
  40. sqlseed-0.1.0.dist-info/WHEEL +4 -0
  41. sqlseed-0.1.0.dist-info/entry_points.txt +6 -0
  42. sqlseed-0.1.0.dist-info/licenses/LICENSE +17 -0
sqlseed/__init__.py ADDED
@@ -0,0 +1,121 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from sqlseed._version import __version__
6
+
7
+ __all__ = [
8
+ "ColumnConfig",
9
+ "DataOrchestrator",
10
+ "GenerationResult",
11
+ "GeneratorConfig",
12
+ "ProviderType",
13
+ "TableConfig",
14
+ "__version__",
15
+ "connect",
16
+ "fill",
17
+ "fill_from_config",
18
+ "load_config",
19
+ "preview",
20
+ ]
21
+
22
+ from sqlseed.config.loader import load_config
23
+ from sqlseed.config.models import (
24
+ ColumnConfig,
25
+ GeneratorConfig,
26
+ ProviderType,
27
+ TableConfig,
28
+ )
29
+ from sqlseed.core.orchestrator import DataOrchestrator
30
+ from sqlseed.core.result import GenerationResult
31
+
32
+
33
def fill(
    db_path: str,
    *,
    table: str,
    count: int = 1000,
    columns: dict[str, Any] | None = None,
    provider: str = "mimesis",
    locale: str = "en_US",
    seed: int | None = None,
    batch_size: int = 5000,
    clear_before: bool = False,
    optimize_pragma: bool = True,
    transform: str | None = None,
) -> GenerationResult:
    """Generate rows for a single table and write them to a SQLite database.

    A ``DataOrchestrator`` is created for *db_path* and used as a context
    manager for the duration of the call; the work is delegated to its
    ``fill_table`` method.

    Args:
        db_path: Path of the SQLite database.
        table: Name of the table to fill.
        count: Number of rows to request from ``fill_table``.
        columns: Optional per-column overrides, forwarded unchanged to
            ``fill_table``.
        provider: Data provider name handed to the orchestrator.
        locale: Locale handed to the orchestrator.
        seed: Optional seed, forwarded to ``fill_table``.
        batch_size: Batch size, forwarded to ``fill_table``.
        clear_before: Forwarded to ``fill_table`` (clear the table first).
        optimize_pragma: Forwarded to the orchestrator constructor.
        transform: Optional transform, forwarded to ``fill_table`` exactly as
            ``fill_from_config`` forwards ``TableConfig.transform``.
            NOTE(review): presumably a path to a Python transform script --
            confirm against ``DataOrchestrator.fill_table``.

    Returns:
        The ``GenerationResult`` produced by ``fill_table``.
    """
    with DataOrchestrator(
        db_path=db_path,
        provider_name=provider,
        locale=locale,
        optimize_pragma=optimize_pragma,
    ) as orch:
        return orch.fill_table(
            table_name=table,
            count=count,
            columns=columns,
            seed=seed,
            batch_size=batch_size,
            clear_before=clear_before,
            transform=transform,
        )
60
+
61
+
62
def connect(
    db_path: str,
    *,
    provider: str = "mimesis",
    locale: str = "en_US",
    optimize_pragma: bool = True,
) -> DataOrchestrator:
    """Create a ``DataOrchestrator`` for *db_path* and hand it to the caller.

    The orchestrator is returned without being entered. The caller owns it
    and is expected to use it as a context manager, the way the other helpers
    in this module do (``with connect(path) as orch: ...``).

    Args:
        db_path: Path of the SQLite database.
        provider: Data provider name, passed as ``provider_name``.
        locale: Locale for data generation.
        optimize_pragma: Passed through to the orchestrator unchanged.
    """
    orchestrator_options = {
        "provider_name": provider,
        "locale": locale,
        "optimize_pragma": optimize_pragma,
    }
    return DataOrchestrator(db_path=db_path, **orchestrator_options)
75
+
76
+
77
def fill_from_config(config_path: str) -> list[GenerationResult]:
    """Load a generation config file and fill every table it lists.

    Provider, locale, database path and the ``optimize_pragma`` flag come from
    the loaded config. Tables are processed in the order of ``config.tables``,
    all through one shared orchestrator.

    Args:
        config_path: Path of the configuration file, read with ``load_config``.

    Returns:
        One ``GenerationResult`` per table, in processing order.
    """
    config = load_config(config_path)
    generated: list[GenerationResult] = []
    orchestrator = DataOrchestrator(
        db_path=config.db_path,
        provider_name=config.provider.value,
        locale=config.locale,
        optimize_pragma=config.optimize_pragma,
    )
    with orchestrator as orch:
        for tbl in config.tables:
            generated.append(
                orch.fill_table(
                    table_name=tbl.name,
                    count=tbl.count,
                    seed=tbl.seed,
                    batch_size=tbl.batch_size,
                    clear_before=tbl.clear_before,
                    column_configs=tbl.columns,
                    transform=tbl.transform,
                )
            )
    return generated
98
+
99
+
100
def preview(
    db_path: str,
    *,
    table: str,
    count: int = 5,
    columns: dict[str, Any] | None = None,
    provider: str = "mimesis",
    locale: str = "en_US",
    seed: int | None = None,
) -> list[dict[str, Any]]:
    """Return sample rows for *table* as produced by ``preview_table``.

    The orchestrator is opened with ``optimize_pragma=False``. Generation is
    delegated to ``DataOrchestrator.preview_table``, whose rows are returned
    unchanged.

    Args:
        db_path: Path of the SQLite database holding *table*.
        table: Name of the table to preview.
        count: Number of rows to request.
        columns: Optional per-column overrides, forwarded unchanged.
        provider: Data provider name.
        locale: Locale for data generation.
        seed: Optional seed, forwarded unchanged.
    """
    orchestrator = DataOrchestrator(
        db_path=db_path,
        provider_name=provider,
        locale=locale,
        optimize_pragma=False,
    )
    with orchestrator as orch:
        return orch.preview_table(table_name=table, count=count, columns=columns, seed=seed)
@@ -0,0 +1,11 @@
1
+ from sqlseed._utils.metrics import MetricsCollector
2
+ from sqlseed._utils.progress import create_progress
3
+ from sqlseed._utils.sql_safe import build_insert_sql, quote_identifier, validate_table_name
4
+
5
+ __all__ = [
6
+ "MetricsCollector",
7
+ "build_insert_sql",
8
+ "create_progress",
9
+ "quote_identifier",
10
+ "validate_table_name",
11
+ ]
@@ -0,0 +1,30 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import sys
5
+ from typing import Any
6
+
7
+ import structlog
8
+
9
+
10
+ def configure_logging(level: str = "INFO") -> None:
11
+ structlog.configure(
12
+ processors=[
13
+ structlog.contextvars.merge_contextvars,
14
+ structlog.processors.add_log_level,
15
+ structlog.processors.StackInfoRenderer(),
16
+ structlog.dev.set_exc_info,
17
+ structlog.processors.TimeStamper(fmt="iso"),
18
+ structlog.dev.ConsoleRenderer(),
19
+ ],
20
+ wrapper_class=structlog.make_filtering_bound_logger(
21
+ getattr(logging, level.upper(), logging.INFO),
22
+ ),
23
+ context_class=dict,
24
+ logger_factory=structlog.PrintLoggerFactory(file=sys.stderr),
25
+ cache_logger_on_first_use=True,
26
+ )
27
+
28
+
29
def get_logger(name: str | None = None) -> Any:
    """Return a structlog logger.

    *name* is passed straight through to ``structlog.get_logger``, which hands
    its positional arguments to the configured logger factory. The concrete
    logger type depends on the active structlog configuration, e.g. the one
    installed by ``configure_logging`` if that has been called; hence the
    ``Any`` return annotation.
    """
    return structlog.get_logger(name)
@@ -0,0 +1,45 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from dataclasses import dataclass, field
5
+ from typing import Any
6
+
7
+
8
@dataclass
class MetricEntry:
    """A single named measurement recorded by ``MetricsCollector``."""

    name: str
    value: float
    # Filled from time.monotonic() when the entry is created. The monotonic
    # clock has an undefined reference point, so this is only useful for
    # comparing entries within one process, not as a wall-clock time.
    timestamp: float = field(default_factory=time.monotonic)
13
+
14
+
15
class MetricsCollector:
    """In-memory list of named numeric measurements with simple aggregation."""

    def __init__(self) -> None:
        # Entries are kept in recording order; record() appends one per call.
        self._entries: list[MetricEntry] = []

    def record(self, name: str, value: float) -> None:
        """Append one measurement called *name* with the given *value*."""
        entry = MetricEntry(name=name, value=value)
        self._entries.append(entry)

    def get_entries(self, name: str | None = None) -> list[MetricEntry]:
        """Return a new list of entries, optionally only those called *name*."""
        if name is not None:
            return [entry for entry in self._entries if entry.name == name]
        return self._entries.copy()

    def summary(self) -> dict[str, Any]:
        """Aggregate entries per name.

        Returns a dict keyed by metric name, in first-seen order. Each value
        holds ``count``, ``total``, ``avg``, ``min`` and ``max``. With no
        entries at all, the result is an empty dict.
        """
        grouped: dict[str, list[float]] = {}
        for entry in self._entries:
            if entry.name not in grouped:
                grouped[entry.name] = []
            grouped[entry.name].append(entry.value)

        report: dict[str, Any] = {}
        for metric, samples in grouped.items():
            total = sum(samples)
            size = len(samples)
            report[metric] = {
                "count": size,
                "total": total,
                "avg": total / size,
                "min": min(samples),
                "max": max(samples),
            }
        return report

    def clear(self) -> None:
        """Discard every recorded entry."""
        self._entries.clear()
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+
3
+ from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
4
+
5
+
6
def create_progress() -> Progress:
    """Build a transient Rich progress display.

    Columns, left to right: spinner, task description, progress bar,
    percentage complete and elapsed time. ``transient=True`` makes Rich
    remove the display from the terminal once it stops.
    """
    columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
    )
    return Progress(*columns, transient=True)
@@ -0,0 +1,51 @@
1
+ """
2
+ Shared database schema helper utilities.
3
+
4
+ Extracted from adapter implementations to avoid code duplication (DRY principle).
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ from typing import Any
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
# Opening quote character -> closing character for SQLite quoted tokens.
_QUOTE_CLOSERS = {'"': '"', "'": "'", "`": "`", "[": "]"}


def _top_level_definitions(create_sql: str) -> list[str]:
    """Split the parenthesised body of a CREATE TABLE statement at top-level commas.

    Commas nested in parentheses (e.g. ``DECIMAL(10,2)``, ``CHECK (...)``) or
    inside quoted tokens are not split points. Parentheses inside a quoted
    table name before the column list are ignored as well.
    """
    definitions: list[str] = []
    current: list[str] = []
    depth = 0
    closer: str | None = None  # closing quote while inside a quoted token
    for ch in create_sql:
        if closer is not None:
            # A doubled quote closes and immediately reopens, which keeps the
            # scan correct for escaped quotes such as "a""b".
            if ch == closer:
                closer = None
        elif ch in _QUOTE_CLOSERS:
            closer = _QUOTE_CLOSERS[ch]
        elif ch == "(":
            depth += 1
            if depth == 1:
                continue  # opening parenthesis of the definition list itself
        elif ch == ")":
            depth -= 1
            if depth == 0:
                definitions.append("".join(current))
                break
        elif ch == "," and depth == 1:
            definitions.append("".join(current))
            current = []
            continue
        if depth >= 1:
            current.append(ch)
    return definitions


def _split_leading_identifier(definition: str) -> tuple[str, str]:
    """Return ``(identifier, remainder)`` for the first token of a definition.

    Quoted identifiers are unquoted (doubled quotes collapse to one);
    unquoted identifiers end at whitespace or ``(``.
    """
    text = definition.lstrip()
    if not text:
        return "", ""
    closer = _QUOTE_CLOSERS.get(text[0])
    if closer is None:
        end = 0
        while end < len(text) and not text[end].isspace() and text[end] != "(":
            end += 1
        return text[:end], text[end:]
    chars: list[str] = []
    i = 1
    while i < len(text):
        ch = text[i]
        if ch == closer:
            # Brackets have no escape form; the other quotes escape by doubling.
            if closer != "]" and text[i + 1 : i + 2] == closer:
                chars.append(ch)
                i += 2
                continue
            return "".join(chars), text[i + 1 :]
        chars.append(ch)
        i += 1
    return "".join(chars), ""


def detect_autoincrement(
    execute_fn: Any,
    table_name: str,
    column_name: str,
) -> bool:
    """
    Detect whether a column is AUTOINCREMENT by inspecting the CREATE TABLE SQL.

    Works with both sqlite-utils Database.execute() and raw sqlite3 Connection.execute().

    The stored CREATE TABLE statement is split into its top-level column and
    constraint definitions. The definition whose leading identifier equals
    *column_name* is located, comparing case-insensitively with quotes
    removed. The column counts as AUTOINCREMENT when that definition contains
    the AUTOINCREMENT keyword; SQLite only accepts that keyword on an
    INTEGER PRIMARY KEY column. Exact identifier matching prevents false
    positives such as a column named ``key`` matching ``PRIMARY KEY``.
    The table-constraint form ``PRIMARY KEY (col AUTOINCREMENT)`` is not
    recognised.

    Args:
        execute_fn: A callable that executes SQL and returns a cursor-like object
            with a .fetchone() method (or the row itself).
        table_name: Name of the table.
        column_name: Name of the column to check.

    Returns:
        True if the column is declared with AUTOINCREMENT. False otherwise,
        including when the table is missing or its SQL could not be read.
    """
    try:
        result = execute_fn(
            "SELECT sql FROM sqlite_master WHERE type='table' AND name=?",
            [table_name],
        )
        row = result.fetchone() if hasattr(result, "fetchone") else result
        if not row or not row[0]:
            return False
        create_sql = row[0]
        # Cheap early exit: most tables never use AUTOINCREMENT.
        if "AUTOINCREMENT" not in create_sql.upper():
            return False
        target = column_name.upper()
        for definition in _top_level_definitions(create_sql):
            name, rest = _split_leading_identifier(definition)
            if name.upper() == target:
                return "AUTOINCREMENT" in rest.upper().split()
    except Exception:
        # Best effort by design: a detection failure must not abort generation.
        logger.debug(
            "Failed to detect autoincrement",
            extra={"table": table_name, "column": column_name},
            exc_info=True,
        )
    return False
@@ -0,0 +1,45 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import re
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+
9
def quote_identifier(name: str) -> str:
    """
    Return *name* as a double-quoted SQLite identifier.

    Embedded double quotes are doubled (``"`` becomes ``""``), and then the
    whole name is wrapped in double quotes. This follows SQLite's identifier
    quoting rules, so table and column names can be spliced safely into SQL.

    Raises:
        ValueError: If *name* is empty or consists only of whitespace.
    """
    if name and name.strip():
        return '"' + name.replace('"', '""') + '"'
    raise ValueError("SQL identifier cannot be empty")
21
+
22
+
23
def validate_table_name(name: str) -> str:
    """
    Validate and escape a table name.

    The name is always quoted with ``quote_identifier``. In addition, a
    warning is logged when the name is not a plain identifier (ASCII letter
    or underscore followed by letters, digits or underscores).

    Raises:
        ValueError: If *name* is empty or whitespace-only. This is raised
            before any warning is logged.
    """
    quoted = quote_identifier(name)
    # fullmatch, unlike match with "$", does not accept a trailing newline,
    # so a name like "users\n" is correctly flagged.
    if re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*", name) is None:
        logger.warning("Table name '%s' contains special characters and will be quoted", name)
    return quoted
33
+
34
+
35
def build_insert_sql(table_name: str, column_names: list[str]) -> str:
    """
    Build a parameterised INSERT statement with every identifier quoted.

    Args:
        table_name: Target table, quoted with ``quote_identifier``.
        column_names: Columns to insert, in the order their ``?`` placeholders
            are bound. May be empty.

    Returns:
        ``INSERT INTO "table" ("col1", "col2") VALUES (?, ?)``. When
        *column_names* is empty, returns ``INSERT INTO "table" DEFAULT VALUES``
        instead, because ``INSERT INTO t () VALUES ()`` is not valid SQLite.

    Raises:
        ValueError: Propagated from ``quote_identifier`` for an empty name.
    """
    safe_table = quote_identifier(table_name)
    if not column_names:
        # Every column takes its default value; this form binds no parameters.
        return f"INSERT INTO {safe_table} DEFAULT VALUES"
    safe_columns = ", ".join(quote_identifier(col) for col in column_names)
    placeholders = ", ".join(["?"] * len(column_names))
    return f"INSERT INTO {safe_table} ({safe_columns}) VALUES ({placeholders})"
sqlseed/_version.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = "0.1.0"
@@ -0,0 +1,3 @@
1
+ from sqlseed.cli.main import cli
2
+
3
+ __all__ = ["cli"]
sqlseed/cli/main.py ADDED
@@ -0,0 +1,316 @@
1
+ from __future__ import annotations
2
+
3
+ import click
4
+
5
+ from sqlseed._version import __version__
6
+
7
+
8
@click.group()
@click.version_option(version=__version__, prog_name="sqlseed")
def cli() -> None:
    """sqlseed - Declarative SQLite test data generation toolkit."""
    # The docstring above doubles as the group's --help text. The group itself
    # does no work; subcommands attach via @cli.command().
13
+
14
+
15
@cli.command()
@click.argument("db_path", required=False)
@click.option("--table", "-t", default=None, help="Target table name")
@click.option("--count", "-n", default=1000, type=int, help="Number of rows to generate")
@click.option("--provider", "-p", default="mimesis", help="Data provider (mimesis|faker|base)")
@click.option("--locale", "-l", default="en_US", help="Locale for data generation")
@click.option("--seed", "-s", default=None, type=int, help="Random seed for reproducibility")
@click.option("--batch-size", "-b", default=5000, type=int, help="Batch size for insertion")
@click.option("--clear", is_flag=True, help="Clear table before generating")
@click.option("--config", "-c", "config_path", default=None, help="YAML/JSON config file path")
@click.option("--transform", "transform_path", default=None, help="Python transform script path")
@click.option("--snapshot", is_flag=True, help="Save generation snapshot for replay")
def fill(
    db_path: str | None,
    table: str | None,
    count: int,
    provider: str,
    locale: str,
    seed: int | None,
    batch_size: int,
    clear: bool,
    config_path: str | None,
    transform_path: str | None,
    snapshot: bool,
) -> None:
    """Fill a table with generated test data.

    Use --config for config-driven generation, or provide db_path + --table
    for direct generation.
    """
    # The docstring above is the command's --help text; keep it user-facing.
    if config_path:
        from sqlseed import fill_from_config

        for result in fill_from_config(config_path):
            click.echo(str(result))
        return

    if not db_path:
        raise click.UsageError("db_path is required when not using --config")
    if not table:
        raise click.UsageError("--table is required when not using --config")

    if transform_path:
        from sqlseed.core.transform import load_transform

        # Load the script up front so a broken transform is reported before
        # any rows are generated. The path itself is forwarded below.
        load_transform(transform_path)
        click.echo(f"Transform script loaded: {transform_path}")

    from sqlseed.core.orchestrator import DataOrchestrator

    # Use the orchestrator directly, with the same arguments sqlseed.fill()
    # passes, so the transform actually reaches fill_table. It is forwarded
    # the same way fill_from_config forwards TableConfig.transform.
    with DataOrchestrator(
        db_path=db_path,
        provider_name=provider,
        locale=locale,
        optimize_pragma=True,
    ) as orch:
        result = orch.fill_table(
            table_name=table,
            count=count,
            columns=None,
            seed=seed,
            batch_size=batch_size,
            clear_before=clear,
            transform=transform_path,
        )
    click.echo(str(result))

    if transform_path:
        click.echo(f"Transform applied: {transform_path}")

    if snapshot:
        from sqlseed.config.models import GeneratorConfig, ProviderType, TableConfig
        from sqlseed.config.snapshot import SnapshotManager

        config = GeneratorConfig(
            db_path=db_path,
            provider=ProviderType(provider),
            locale=locale,
            tables=[
                TableConfig(
                    name=table,
                    count=count,
                    batch_size=batch_size,
                    clear_before=clear,
                    seed=seed,
                    # Record the transform so a replay reproduces the same data.
                    transform=transform_path,
                )
            ],
        )
        manager = SnapshotManager()
        snapshot_path = manager.save(config, table, count, seed)
        click.echo(f"Snapshot saved: {snapshot_path}")
103
+
104
+
105
@cli.command()
@click.argument("db_path", type=click.Path(exists=True, dir_okay=False))
@click.option("--table", "-t", required=True, help="Target table name")
@click.option("--count", "-n", default=5, type=int, help="Number of rows to preview")
@click.option("--provider", "-p", default="mimesis", help="Data provider")
@click.option("--locale", "-l", default="en_US", help="Locale")
@click.option("--seed", "-s", default=None, type=int, help="Random seed")
def preview(
    db_path: str,
    table: str,
    count: int,
    provider: str,
    locale: str,
    seed: int | None,
) -> None:
    """Preview generated data without writing to database."""
    # db_path must already exist (click.Path above). Previewing a missing
    # database cannot succeed, and opening the path could otherwise leave an
    # empty database file behind.
    from rich.console import Console
    from rich.table import Table as RichTable
    from rich.text import Text

    from sqlseed import preview as api_preview

    rows = api_preview(
        db_path,
        table=table,
        count=count,
        provider=provider,
        locale=locale,
        seed=seed,
    )

    if not rows:
        click.echo("No data generated.")
        return

    column_names = list(rows[0])
    # Report the number of rows actually returned, not the number requested.
    rich_table = RichTable(title=f"Preview: {table} ({len(rows)} rows)")

    # Wrap headers and values in Text so they are shown literally. Plain
    # strings are parsed as Rich console markup, so generated values such as
    # "[x]" could be styled or disappear.
    for col_name in column_names:
        rich_table.add_column(Text(str(col_name)))

    for row in rows:
        rich_table.add_row(*(Text(str(row.get(col))) for col in column_names))

    Console().print(rich_table)
149
+
150
+
151
@cli.command()
@click.argument("db_path", type=click.Path(exists=True, dir_okay=False))
@click.option("--table", "-t", default=None, help="Specific table to inspect")
@click.option("--show-mapping", is_flag=True, help="Show column mapping strategy")
def inspect(db_path: str, table: str | None, show_mapping: bool) -> None:
    """Inspect database schema and column mapping strategies."""
    # db_path must already exist (click.Path above). This command only reads,
    # and opening a missing path could otherwise create an empty database.
    from rich.console import Console
    from rich.table import Table as RichTable

    from sqlseed.core.orchestrator import DataOrchestrator

    with DataOrchestrator(db_path) as orch:
        console = Console()

        available = orch._db.get_table_names()
        if table:
            # Reject an unknown --table with a clear CLI error rather than
            # whatever the schema/row-count queries would raise. SQLite
            # resolves table names case-insensitively, so compare that way.
            if table.lower() not in {name.lower() for name in available}:
                raise click.BadParameter(
                    f"no table named {table!r} in {db_path}",
                    param_hint="'--table'",
                )
            tables = [table]
        else:
            tables = available

        for tbl in tables:
            count = orch._db.get_row_count(tbl)
            columns = orch._schema.get_column_info(tbl)
            fks = orch._db.get_foreign_keys(tbl)

            rich_table = RichTable(title=f"Table: {tbl} ({count} rows)")
            rich_table.add_column("Column")
            rich_table.add_column("Type")
            rich_table.add_column("Nullable")
            rich_table.add_column("PK")
            rich_table.add_column("Auto")

            if show_mapping:
                rich_table.add_column("Generator")
                rich_table.add_column("Params")

            for col in columns:
                row_data = [
                    col.name,
                    col.type,
                    "✓" if col.nullable else "✗",
                    "✓" if col.is_primary_key else "",
                    "✓" if col.is_autoincrement else "",
                ]
                if show_mapping:
                    spec = orch._mapper.map_column(col)
                    row_data.extend([spec.generator_name, str(spec.params)])
                rich_table.add_row(*row_data)

            console.print(rich_table)

            if fks:
                fk_table = RichTable(title=f"Foreign Keys: {tbl}")
                fk_table.add_column("Column")
                fk_table.add_column("Ref Table")
                fk_table.add_column("Ref Column")
                for fk in fks:
                    fk_table.add_row(fk.column, fk.ref_table, fk.ref_column)
                console.print(fk_table)
206
+
207
+
208
@cli.command()
@click.argument("config_path")
@click.option("--db", default="test.db", help="Database path for template")
def init(config_path: str, db: str) -> None:
    """Generate a YAML configuration template."""
    # Build the template for the --db path, write it to config_path, and
    # confirm where it went. The docstring above is the command's --help text.
    from sqlseed.config.loader import generate_template, save_config

    save_config(generate_template(db), config_path)
    click.echo(f"Configuration template saved to: {config_path}")
218
+
219
+
220
@cli.command()
@click.argument("snapshot_path")
def replay(snapshot_path: str) -> None:
    """Replay a previously saved snapshot."""
    # Delegates to SnapshotManager.replay and prints its result.
    from sqlseed.config.snapshot import SnapshotManager

    replay_result = SnapshotManager().replay(snapshot_path)
    click.echo(str(replay_result))
229
+
230
+
231
@cli.command("ai-suggest")
@click.argument("db_path", type=click.Path(exists=True, dir_okay=False))
@click.option("--table", "-t", required=True, help="Target table name")
@click.option("--output", "-o", required=True, help="Output YAML file path")
@click.option("--model", "-m", default=None, help="AI model name (default: qwen3-coder-plus)")
@click.option("--api-key", envvar="SQLSEED_AI_API_KEY", default=None, help="AI API key")
@click.option("--base-url", envvar="SQLSEED_AI_BASE_URL", default=None, help="AI API base URL")
@click.option("--max-retries", default=3, type=int, help="Max refinement retries (0=disable)")
@click.option("--verify/--no-verify", default=True, help="Enable AI config self-correction")
@click.option("--no-cache", is_flag=True, help="Skip cached AI configs")
def ai_suggest(
    db_path: str,
    table: str,
    output: str,
    model: str | None,
    api_key: str | None,
    base_url: str | None,
    max_retries: int,
    verify: bool,
    no_cache: bool,
) -> None:
    """Analyze table schema and suggest generation rules via AI."""
    # Every failure path exits with status 1 so scripts can detect it.
    import yaml

    # The sqlseed-ai plugin is optional. Turn a missing install into a clear
    # CLI error instead of an ImportError traceback.
    try:
        from sqlseed_ai.analyzer import SchemaAnalyzer
        from sqlseed_ai.config import AIConfig
    except ImportError as exc:
        raise click.ClickException(
            f"The sqlseed-ai plugin is required for ai-suggest but could not be imported ({exc})."
        ) from exc

    ai_config = AIConfig(api_key=api_key, base_url=base_url)
    if model:
        ai_config.model = model

    analyzer = SchemaAnalyzer(config=ai_config)

    if verify and max_retries > 0:
        from sqlseed_ai.refiner import AiConfigRefiner, AISuggestionFailedError

        refiner = AiConfigRefiner(analyzer, db_path)
        try:
            result = refiner.generate_and_refine(
                table_name=table,
                max_retries=max_retries,
                no_cache=no_cache,
            )
        except AISuggestionFailedError as e:
            click.echo(f"AI suggestion failed: {e}", err=True)
            raise SystemExit(1) from e
    else:
        from sqlseed.core.orchestrator import DataOrchestrator

        with DataOrchestrator(db_path) as orch:
            schema_ctx = orch.get_schema_context(table)
            messages = analyzer.build_initial_messages(
                table_name=schema_ctx["table_name"],
                columns=schema_ctx["columns"],
                indexes=schema_ctx["indexes"],
                sample_data=schema_ctx["sample_data"],
                foreign_keys=schema_ctx["foreign_keys"],
                all_table_names=schema_ctx["all_table_names"],
                distribution_profiles=schema_ctx.get("distribution"),
            )
        try:
            result = analyzer.call_llm(messages)
        except (ValueError, RuntimeError) as e:
            click.echo(f"AI suggestion failed: {e}", err=True)
            raise SystemExit(1) from e

    if not result:
        click.echo(
            "No suggestions received. Ensure sqlseed-ai plugin is installed and API key is configured.",
            err=True,
        )
        click.echo("Set SQLSEED_AI_API_KEY environment variable or use --api-key option.", err=True)
        raise SystemExit(1)

    # NOTE(review): provider and locale are hard-coded here ("zh_CN"), while
    # other commands default to "en_US". Confirm this is intended.
    output_data = {
        "db_path": db_path,
        "provider": "mimesis",
        "locale": "zh_CN",
        "tables": [result],
    }
    with open(output, "w", encoding="utf-8") as f:
        yaml.dump(output_data, f, allow_unicode=True, sort_keys=False, default_flow_style=False)
    click.echo(f"AI suggestions saved to {output}")
309
+
310
+
311
def main() -> None:
    """Entry point: run the ``cli`` command group with the process arguments."""
    # Calling a click command object is an alias for its main() method.
    cli.main()
313
+
314
+
315
# Start the CLI when this module is executed directly rather than imported.
if __name__ == "__main__":
    main()
@@ -0,0 +1,14 @@
1
+ from sqlseed.config.loader import generate_template, load_config, save_config
2
+ from sqlseed.config.models import ColumnConfig, GeneratorConfig, ProviderType, TableConfig
3
+ from sqlseed.config.snapshot import SnapshotManager
4
+
5
+ __all__ = [
6
+ "ColumnConfig",
7
+ "GeneratorConfig",
8
+ "ProviderType",
9
+ "SnapshotManager",
10
+ "TableConfig",
11
+ "generate_template",
12
+ "load_config",
13
+ "save_config",
14
+ ]