sqlseed 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlseed/__init__.py +121 -0
- sqlseed/_utils/__init__.py +11 -0
- sqlseed/_utils/logger.py +30 -0
- sqlseed/_utils/metrics.py +45 -0
- sqlseed/_utils/progress.py +14 -0
- sqlseed/_utils/schema_helpers.py +51 -0
- sqlseed/_utils/sql_safe.py +45 -0
- sqlseed/_version.py +1 -0
- sqlseed/cli/__init__.py +3 -0
- sqlseed/cli/main.py +316 -0
- sqlseed/config/__init__.py +14 -0
- sqlseed/config/loader.py +66 -0
- sqlseed/config/models.py +99 -0
- sqlseed/config/snapshot.py +91 -0
- sqlseed/core/__init__.py +14 -0
- sqlseed/core/column_dag.py +108 -0
- sqlseed/core/constraints.py +116 -0
- sqlseed/core/expression.py +71 -0
- sqlseed/core/mapper.py +257 -0
- sqlseed/core/orchestrator.py +578 -0
- sqlseed/core/relation.py +124 -0
- sqlseed/core/result.py +23 -0
- sqlseed/core/schema.py +97 -0
- sqlseed/core/transform.py +27 -0
- sqlseed/database/__init__.py +14 -0
- sqlseed/database/_protocol.py +72 -0
- sqlseed/database/optimizer.py +96 -0
- sqlseed/database/raw_sqlite_adapter.py +197 -0
- sqlseed/database/sqlite_utils_adapter.py +183 -0
- sqlseed/generators/__init__.py +11 -0
- sqlseed/generators/_protocol.py +73 -0
- sqlseed/generators/base_provider.py +448 -0
- sqlseed/generators/faker_provider.py +157 -0
- sqlseed/generators/mimesis_provider.py +203 -0
- sqlseed/generators/registry.py +86 -0
- sqlseed/generators/stream.py +157 -0
- sqlseed/py.typed +0 -0
- sqlseed-0.1.0.dist-info/METADATA +934 -0
- sqlseed-0.1.0.dist-info/RECORD +42 -0
- sqlseed-0.1.0.dist-info/WHEEL +4 -0
- sqlseed-0.1.0.dist-info/entry_points.txt +6 -0
- sqlseed-0.1.0.dist-info/licenses/LICENSE +17 -0
|
@@ -0,0 +1,578 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
import time
|
|
5
|
+
from typing import TYPE_CHECKING, Any, ClassVar
|
|
6
|
+
|
|
7
|
+
from sqlseed._utils.logger import get_logger
|
|
8
|
+
from sqlseed._utils.metrics import MetricsCollector
|
|
9
|
+
from sqlseed._utils.progress import create_progress
|
|
10
|
+
from sqlseed.core.column_dag import ColumnDAG
|
|
11
|
+
from sqlseed.core.constraints import ConstraintSolver
|
|
12
|
+
from sqlseed.core.expression import ExpressionEngine
|
|
13
|
+
from sqlseed.core.mapper import ColumnMapper, GeneratorSpec
|
|
14
|
+
from sqlseed.core.relation import RelationResolver, SharedPool
|
|
15
|
+
from sqlseed.core.result import GenerationResult
|
|
16
|
+
from sqlseed.core.schema import SchemaInferrer
|
|
17
|
+
from sqlseed.core.transform import load_transform
|
|
18
|
+
from sqlseed.database.raw_sqlite_adapter import RawSQLiteAdapter
|
|
19
|
+
from sqlseed.database.sqlite_utils_adapter import SQLiteUtilsAdapter
|
|
20
|
+
from sqlseed.generators.registry import ProviderRegistry
|
|
21
|
+
from sqlseed.generators.stream import DataStream
|
|
22
|
+
from sqlseed.plugins.manager import PluginManager
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from sqlseed.database._protocol import DatabaseAdapter
|
|
26
|
+
|
|
27
|
+
logger = get_logger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class DataOrchestrator:
|
|
31
|
+
def __init__(
    self,
    db_path: str,
    *,
    provider_name: str = "mimesis",
    locale: str = "en_US",
    optimize_pragma: bool = True,
) -> None:
    """Store the configuration and build the collaborating components.

    ``connect`` is not called here; ``_ensure_connected`` does that on first use.

    Args:
        db_path: Path passed to the adapter's ``connect``.
        provider_name: Name of the provider requested from the registry.
        locale: Locale applied to the selected provider.
        optimize_pragma: When true, ``fill_table`` wraps the bulk write in the
            adapter's ``optimize_for_bulk_write`` / ``restore_settings`` calls.
    """
    # Plain settings, stored exactly as given.
    self._db_path, self._provider_name = db_path, provider_name
    self._locale, self._optimize_pragma = locale, optimize_pragma

    # The schema inferrer and relation resolver share the same adapter instance.
    adapter: DatabaseAdapter = self._create_adapter()
    self._db: DatabaseAdapter = adapter
    self._schema = SchemaInferrer(adapter)
    self._mapper = ColumnMapper()
    self._relation = RelationResolver(adapter)
    self._registry = ProviderRegistry()
    self._metrics = MetricsCollector()
    self._plugins = PluginManager()
    self._shared_pool = SharedPool()

    # Flipped to True by _ensure_connected, back to False by close/__exit__.
    self._connected = False
|
|
54
|
+
|
|
55
|
+
def _create_adapter(self) -> DatabaseAdapter:
    """Pick the database adapter based on whether ``sqlite_utils`` imports.

    Returns:
        A ``SQLiteUtilsAdapter`` when ``sqlite_utils`` can be imported,
        otherwise a ``RawSQLiteAdapter``.
    """
    try:
        import sqlite_utils  # noqa: F401
    except ImportError:
        logger.debug("sqlite-utils not available, falling back to raw sqlite3")
        adapter: DatabaseAdapter = RawSQLiteAdapter()
    else:
        adapter = SQLiteUtilsAdapter()
    return adapter
|
|
62
|
+
|
|
63
|
+
def _ensure_connected(self) -> None:
    """Connect and run one-time plugin/provider setup; no-op once connected.

    The connected flag is set right after ``connect`` returns, before the
    plugin and provider setup runs. If the requested provider raises
    ``ImportError`` or ``ValueError``, the provider name is switched to
    ``"base"`` and that provider is used instead.
    """
    if self._connected:
        return

    self._db.connect(self._db_path)
    self._connected = True

    # Let plugins contribute providers and column mappers, then entry points.
    plugins = self._plugins
    registry = self._registry
    plugins.load_plugins()
    plugins.hook.sqlseed_register_providers(registry=registry)
    plugins.hook.sqlseed_register_column_mappers(mapper=self._mapper)
    registry.register_from_entry_points()

    try:
        provider = registry.ensure_provider(self._provider_name)
        registry.set_default(self._provider_name)
    except (ImportError, ValueError):
        logger.warning(
            "Provider not available, falling back to 'base'",
            provider_name=self._provider_name,
        )
        self._provider_name = "base"
        provider = registry.get(self._provider_name)

    provider.set_locale(self._locale)
|
|
82
|
+
|
|
83
|
+
def fill_table(
    self,
    table_name: str,
    *,
    count: int = 1000,
    columns: dict[str, Any] | None = None,
    seed: int | None = None,
    batch_size: int = 5000,
    clear_before: bool = False,
    column_configs: list[Any] | None = None,
    transform: str | None = None,
) -> GenerationResult:
    """Generate rows for ``table_name`` and insert them batch by batch.

    Any exception raised while preparing or inserting is logged and reported
    through the ``errors`` field of the returned result instead of
    propagating. In that case the metrics, the after-generate hook and the
    shared-pool registration are skipped.

    Args:
        table_name: Table to populate.
        count: Number of rows requested from the data stream.
        columns: Per-column overrides, see ``_resolve_user_configs``.
        seed: Seed forwarded to the data stream.
        batch_size: Batch size used for generation and insertion.
        clear_before: When true, the table is cleared before generating.
        column_configs: Column config objects, see ``_resolve_user_configs``.
        transform: Passed to ``load_transform`` when truthy.

    Returns:
        A ``GenerationResult`` with the number of inserted rows and the
        elapsed time; ``batch_count`` on success, ``errors`` on failure.
    """
    self._ensure_connected()
    started_at = time.monotonic()
    rows_written = 0
    batches_seen = 0

    try:
        if self._optimize_pragma:
            self._db.optimize_for_bulk_write(count)
        if clear_before:
            self._db.clear_table(table_name)

        # Spec pipeline: schema -> user overrides -> FK values -> plugin refinements.
        column_infos = self._schema.get_column_info(table_name)
        user_configs = self._resolve_user_configs(columns, column_configs)
        generator_specs = self._mapper.map_columns(column_infos, user_configs)
        generator_specs = self._resolve_foreign_keys(table_name, generator_specs)
        generator_specs = self._apply_ai_suggestions(table_name, column_infos, generator_specs)
        generator_specs = self._apply_template_pool(table_name, column_infos, generator_specs, count)

        dag_nodes = ColumnDAG().build(
            generator_specs,
            list(user_configs.values()) if user_configs else None,
        )

        expr_engine = ExpressionEngine()
        constraint_solver = ConstraintSolver()
        transform_fn = load_transform(transform) if transform else None
        provider = self._registry.get(self._provider_name)

        self._plugins.hook.sqlseed_before_generate(
            table_name=table_name,
            count=count,
            config=None,
        )

        stream = DataStream(
            dag_nodes=dag_nodes,
            provider=provider,
            expr_engine=expr_engine,
            constraint_solver=constraint_solver,
            transform_fn=transform_fn,
            seed=seed,
        )

        progress = create_progress()
        with progress:
            task_id = progress.add_task(f"Generating {table_name}", total=count)
            for batch in stream.generate(count, batch_size):
                batches_seen += 1

                self._plugins.hook.sqlseed_before_insert(
                    table_name=table_name,
                    batch_number=batches_seen,
                    batch_size=len(batch),
                )

                # Plugins may replace the batch; the replacement is what gets inserted.
                rows = self._apply_batch_transforms(table_name, batch)
                inserted = self._db.batch_insert(table_name, iter(rows), batch_size)
                rows_written += inserted

                self._metrics.record(f"{table_name}.batch_insert", float(inserted))

                self._plugins.hook.sqlseed_after_insert(
                    table_name=table_name,
                    batch_number=batches_seen,
                    rows_inserted=inserted,
                )

                # Progress advances by the generated batch size, not the inserted count.
                progress.update(task_id, advance=len(batch))

    except Exception as e:
        logger.error("Failed to fill table", table_name=table_name, error=e)
        return GenerationResult(
            table_name=table_name,
            count=rows_written,
            elapsed=time.monotonic() - started_at,
            errors=[str(e)],
        )
    finally:
        # Runs on success and on failure alike.
        if self._optimize_pragma:
            self._db.restore_settings()

    elapsed = time.monotonic() - started_at

    self._metrics.record(f"{table_name}.total_elapsed", elapsed)
    self._metrics.record(f"{table_name}.total_rows", float(rows_written))

    self._plugins.hook.sqlseed_after_generate(
        table_name=table_name,
        count=rows_written,
        elapsed=elapsed,
    )

    self._register_shared_pool(table_name, generator_specs)

    return GenerationResult(
        table_name=table_name,
        count=rows_written,
        elapsed=elapsed,
        batch_count=batches_seen,
    )
|
|
200
|
+
|
|
201
|
+
def preview_table(
    self,
    table_name: str,
    *,
    count: int = 5,
    columns: dict[str, Any] | None = None,
    seed: int | None = None,
    transform: str | None = None,
    column_configs: list[Any] | None = None,
) -> list[dict[str, Any]]:
    """Generate ``count`` rows for ``table_name`` and return them without inserting.

    The specs go through column mapping and foreign-key resolution only; the
    AI-suggestion and template-pool steps used by ``fill_table`` are not run.

    Args:
        table_name: Table whose schema drives the generation.
        count: Number of rows requested; also used as the stream's batch size.
        columns: Per-column overrides, see ``_resolve_user_configs``.
        seed: Seed forwarded to the data stream.
        transform: Passed to ``load_transform`` when truthy.
        column_configs: Column config objects, see ``_resolve_user_configs``.

    Returns:
        The generated rows after plugin batch transforms were applied.
    """
    self._ensure_connected()

    column_infos = self._schema.get_column_info(table_name)
    user_configs = self._resolve_user_configs(columns, column_configs)
    generator_specs = self._resolve_foreign_keys(
        table_name,
        self._mapper.map_columns(column_infos, user_configs),
    )

    dag_nodes = ColumnDAG().build(
        generator_specs,
        list(user_configs.values()) if user_configs else None,
    )

    expr_engine = ExpressionEngine()
    constraint_solver = ConstraintSolver()
    transform_fn = load_transform(transform) if transform else None
    provider = self._registry.get(self._provider_name)

    stream = DataStream(
        dag_nodes=dag_nodes,
        provider=provider,
        expr_engine=expr_engine,
        constraint_solver=constraint_solver,
        transform_fn=transform_fn,
        seed=seed,
    )

    # Flatten the (possibly plugin-transformed) batches into one list of rows.
    return [
        row
        for batch in stream.generate(count, batch_size=count)
        for row in self._apply_batch_transforms(table_name, batch)
    ]
|
|
244
|
+
|
|
245
|
+
def get_schema_context(self, table_name: str) -> dict[str, Any]:
    """Collect schema information about ``table_name`` into one dictionary.

    Column info, foreign keys and table names are required lookups. Indexes,
    sample data and the column distribution are best-effort: if their lookup
    raises, the corresponding entry stays an empty list.

    Returns:
        A dict with the keys ``table_name``, ``columns``, ``foreign_keys``,
        ``indexes``, ``sample_data``, ``all_table_names`` and ``distribution``.
    """
    self._ensure_connected()

    context: dict[str, Any] = {
        "table_name": table_name,
        "columns": self._schema.get_column_info(table_name),
        "foreign_keys": self._db.get_foreign_keys(table_name),
        "indexes": [],
        "sample_data": [],
        "all_table_names": self._db.get_table_names(),
        "distribution": [],
    }

    with contextlib.suppress(Exception):
        context["indexes"] = [
            {"name": idx.name, "columns": idx.columns, "unique": idx.unique}
            for idx in self._schema.get_index_info(table_name)
        ]

    with contextlib.suppress(Exception):
        context["sample_data"] = self._schema.get_sample_data(table_name, limit=5)

    with contextlib.suppress(Exception):
        context["distribution"] = self._schema.profile_column_distribution(table_name, limit=1000)

    return context
|
|
273
|
+
|
|
274
|
+
def get_column_names(self, table_name: str) -> set[str]:
    """Return the set of column names reported for ``table_name``."""
    self._ensure_connected()
    names: set[str] = set()
    for column in self._schema.get_column_info(table_name):
        names.add(column.name)
    return names
|
|
277
|
+
|
|
278
|
+
def get_skippable_columns(self, table_name: str) -> set[str]:
    """Return the names of columns that are autoincrementing primary keys or have a default."""
    self._ensure_connected()
    skippable: set[str] = set()
    for column in self._schema.get_column_info(table_name):
        auto_pk = column.is_primary_key and column.is_autoincrement
        if auto_pk or column.default is not None:
            skippable.add(column.name)
    return skippable
|
|
285
|
+
|
|
286
|
+
def report(self) -> str:
    """Return a text summary with one row-count line per table.

    Returns a fixed notice instead when no connection has been opened; this
    method never opens one itself.
    """
    if not self._connected:
        return "Not connected to any database."

    tables = self._db.get_table_names()
    header = [f"Database: {self._db_path}", "=" * 50]
    rows = [f" {table}: {self._db.get_row_count(table)} rows" for table in tables]
    return "\n".join(header + rows)
|
|
296
|
+
|
|
297
|
+
def _resolve_user_configs(
    self,
    columns: dict[str, Any] | None,
    column_configs: list[Any] | None,
) -> dict[str, Any]:
    """Merge both user inputs into one mapping of column name to ``ColumnConfig``.

    ``column_configs`` entries are taken first and keyed by their ``name``;
    entries that are not ``ColumnConfig`` instances are ignored. ``columns``
    entries are applied afterwards and win on a name clash: a string value
    is used as the generator name, a dict value supplies the generator under
    ``"type"`` (default ``"string"``) and its remaining items as params.
    Values of any other type are ignored.
    """
    from sqlseed.config.models import ColumnConfig

    resolved: dict[str, Any] = {}

    for entry in column_configs or ():
        if isinstance(entry, ColumnConfig):
            resolved[entry.name] = entry

    for name, raw in (columns or {}).items():
        if isinstance(raw, str):
            resolved[name] = ColumnConfig(name=name, generator=raw)
            continue
        if isinstance(raw, dict):
            # Work on a copy so the caller's dict keeps its "type" key.
            extras = dict(raw)
            generator = extras.pop("type", "string")
            resolved[name] = ColumnConfig(
                name=name,
                generator=generator,
                params=extras,
            )

    return resolved
|
|
325
|
+
|
|
326
|
+
def _resolve_foreign_keys(
    self,
    table_name: str,
    specs: dict[str, GeneratorSpec],
) -> dict[str, GeneratorSpec]:
    """Turn foreign-key placeholders in ``specs`` into concrete generators.

    For a ``"foreign_key_or_integer"`` spec:

    * with a declared foreign key, the spec becomes a ``"foreign_key"``
      generator carrying the referenced values;
    * without one, but with values for the column name in the shared pool,
      the spec is left untouched so ``_resolve_implicit_associations`` can
      bind it to the pool. Previously it was turned into an integer
      generator first, which made the implicit association unreachable;
    * otherwise it becomes an integer generator over 1..999999.

    For a ``"foreign_key"`` spec that names a ``ref_table``, the referenced
    column values are loaded into ``params["_ref_values"]``.

    Args:
        table_name: Table the specs belong to.
        specs: Column name to generator spec; modified in place.

    Returns:
        ``specs`` after ``_resolve_implicit_associations`` has also run.
    """
    # Only existing keys are reassigned, so iterating items() stays valid.
    for col_name, spec in specs.items():
        if spec.generator_name == "foreign_key_or_integer":
            fk_info = self._relation.get_fk_info(table_name, col_name)
            if fk_info:
                ref_values = self._relation.resolve_foreign_key_values(table_name, col_name)
                specs[col_name] = GeneratorSpec(
                    generator_name="foreign_key",
                    params={
                        "ref_table": fk_info.ref_table,
                        "ref_column": fk_info.ref_column,
                        "strategy": "random",
                        "_ref_values": ref_values,
                    },
                    null_ratio=spec.null_ratio,
                    provider=spec.provider,
                )
            elif self._shared_pool.has(col_name):
                # No declared FK, but an earlier table produced values under
                # this column name: keep the placeholder for the implicit step.
                continue
            else:
                specs[col_name] = GeneratorSpec(
                    generator_name="integer",
                    params={"min_value": 1, "max_value": 999999},
                    null_ratio=spec.null_ratio,
                    provider=spec.provider,
                )
        elif spec.generator_name == "foreign_key" and "ref_table" in spec.params:
            spec.params["_ref_values"] = self._db.get_column_values(
                spec.params["ref_table"],
                spec.params["ref_column"],
            )

    return self._resolve_implicit_associations(table_name, specs)
|
|
365
|
+
|
|
366
|
+
def _resolve_implicit_associations(
    self,
    table_name: str,
    specs: dict[str, GeneratorSpec],
) -> dict[str, GeneratorSpec]:
    """Bind placeholder columns to same-named value pools in the SharedPool.

    A spec is rewritten only when its generator is still
    ``"foreign_key_or_integer"`` and the shared pool holds a non-empty list
    under the column's name. The replacement is a ``"foreign_key"``
    generator that draws from the pooled values, with ``"__shared_pool__"``
    as its ``ref_table``. This covers columns such as ``sUserNo`` that
    recur across tables without a declared FK constraint.

    Args:
        table_name: Table the specs belong to (used for logging only).
        specs: Column name to generator spec; modified in place.

    Returns:
        The same ``specs`` mapping.
    """
    pool = self._shared_pool
    if not pool._pools:
        return specs

    for col_name, spec in list(specs.items()):
        is_placeholder = spec.generator_name == "foreign_key_or_integer"
        if not (is_placeholder and pool.has(col_name)):
            continue

        pool_values = pool.get(col_name)
        if not pool_values:
            continue

        specs[col_name] = GeneratorSpec(
            generator_name="foreign_key",
            params={
                "ref_table": "__shared_pool__",
                "ref_column": col_name,
                "strategy": "random",
                "_ref_values": pool_values,
            },
            null_ratio=spec.null_ratio,
            provider=spec.provider,
        )
        logger.debug(
            "Resolved implicit association via SharedPool",
            table_name=table_name,
            column_name=col_name,
            pool_size=len(pool_values),
        )

    return specs
|
|
410
|
+
|
|
411
|
+
AI_APPLICABLE_GENERATORS: ClassVar[frozenset[str]] = frozenset({"string", "integer", "date", "datetime", "choice"})

def _apply_ai_suggestions(
    self,
    table_name: str,
    column_infos: list[Any],
    specs: dict[str, GeneratorSpec],
) -> dict[str, GeneratorSpec]:
    """Replace specs of eligible columns with suggestions from the AI plugin hook.

    A column is eligible when it has a spec whose generator is listed in
    ``AI_APPLICABLE_GENERATORS`` and the column is neither a primary key nor
    autoincrementing and has no default. Suggestions for any other column
    are ignored, so primary-key, foreign-key and defaulted columns keep the
    spec they already have. Previously the eligibility check only decided
    whether the hook was called at all.

    The whole step is best-effort: any exception is logged at debug level
    and the specs collected so far are returned.

    Args:
        table_name: Table being analysed.
        column_infos: Column descriptions passed to the hook.
        specs: Column name to generator spec; modified in place.

    Returns:
        The same ``specs`` mapping.
    """
    eligible = {
        col.name
        for col in column_infos
        if specs.get(col.name) is not None
        and specs[col.name].generator_name in self.AI_APPLICABLE_GENERATORS
        and not col.is_primary_key
        and not col.is_autoincrement
        and col.default is None
    }
    if not eligible:
        return specs

    try:
        fks = self._db.get_foreign_keys(table_name)
        all_tables = self._db.get_table_names()
        indexes = self._schema.get_index_info(table_name)
        sample_data = self._schema.get_sample_data(table_name, limit=5)

        ai_result = self._plugins.hook.sqlseed_ai_analyze_table(
            table_name=table_name,
            columns=column_infos,
            indexes=[{"name": i.name, "columns": i.columns, "unique": i.unique} for i in indexes],
            sample_data=sample_data,
            foreign_keys=fks,
            all_table_names=all_tables,
        )

        # Anything but a dict with a list under "columns" yields no suggestions.
        ai_columns = ai_result.get("columns", []) if isinstance(ai_result, dict) else []
        if not isinstance(ai_columns, list):
            ai_columns = []

        for col_cfg in ai_columns:
            if not isinstance(col_cfg, dict):
                continue
            col_name = col_cfg.get("name")
            if col_name not in eligible:
                continue
            gen = col_cfg.get("generator")
            if not gen or gen == "skip":
                continue

            derive_from = col_cfg.get("derive_from")
            expression = col_cfg.get("expression")
            if derive_from and expression:
                specs[col_name] = GeneratorSpec(
                    generator_name="__derive__",
                    params={"derive_from": derive_from, "expression": expression},
                )
                continue

            params = col_cfg.get("params", {})
            if isinstance(params, dict):
                specs[col_name] = GeneratorSpec(
                    generator_name=gen,
                    params=params,
                )

    except Exception as e:
        logger.debug("AI suggestions not available", table_name=table_name, error=str(e))

    return specs
|
|
474
|
+
|
|
475
|
+
def _apply_batch_transforms(
    self,
    table_name: str,
    batch: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Run the batch-transform hook and pick the batch to use.

    Every hook implementation receives the original ``batch``. The last
    non-``None`` entry of the hook's results is returned; if there is none,
    ``batch`` is returned unchanged.
    """
    outcomes = self._plugins.hook.sqlseed_transform_batch(
        table_name=table_name,
        batch=batch,
    )
    chosen = batch
    for candidate in outcomes or ():
        if candidate is not None:
            chosen = candidate
    return chosen
|
|
490
|
+
|
|
491
|
+
def _apply_template_pool(
    self,
    table_name: str,
    column_infos: list[Any],
    specs: dict[str, GeneratorSpec],
    count: int,
) -> dict[str, GeneratorSpec]:
    """Let plugins supply a pool of template values for plain string columns.

    Only specs with the ``"string"`` generator are considered, and only when
    the matching column is neither a primary key nor autoincrementing and
    has no default. For each such column the template hook is asked for at
    most 50 values (fewer when ``count`` is smaller) and is given up to 10
    existing values of the column as samples. A truthy hook result replaces
    the spec with a ``"foreign_key"`` generator that draws from it, using
    ``"__template_pool__"`` as its ``ref_table``.

    Args:
        table_name: Table being filled.
        column_infos: Column descriptions used to look up each column.
        specs: Column name to generator spec; modified in place.
        count: Number of rows that will be generated.

    Returns:
        The same ``specs`` mapping.
    """
    for col_name, spec in list(specs.items()):
        if spec.generator_name != "string":
            continue

        # First column description with a matching name, if any.
        col_info = None
        for candidate in column_infos:
            if candidate.name == col_name:
                col_info = candidate
                break
        if col_info is None:
            continue
        if col_info.is_primary_key or col_info.is_autoincrement or col_info.default is not None:
            continue

        # Sampling existing values is best-effort; failures leave the list empty.
        existing_values: list[Any] = []
        with contextlib.suppress(Exception):
            existing_values = self._db.get_column_values(table_name, col_name, limit=10)

        template_values = self._plugins.hook.sqlseed_pre_generate_templates(
            table_name=table_name,
            column_name=col_name,
            column_type=col_info.type,
            count=min(count, 50),
            sample_data=existing_values,
        )
        if not template_values:
            continue

        specs[col_name] = GeneratorSpec(
            generator_name="foreign_key",
            params={
                "ref_table": "__template_pool__",
                "ref_column": col_name,
                "strategy": "random",
                "_ref_values": template_values,
            },
        )

    return specs
|
|
529
|
+
|
|
530
|
+
def _register_shared_pool(
    self,
    table_name: str,
    generator_specs: dict[str, GeneratorSpec],
) -> None:
    """Merge the table's stored column values into the shared pool.

    For every column whose generator is not ``"skip"``, up to 10000 values
    are read back from the table and merged into the pool under the column
    name. Reading and merging are best-effort: an exception for one column
    is suppressed and the remaining columns are still processed.
    """
    for col_name, spec in generator_specs.items():
        if spec.generator_name != "skip":
            with contextlib.suppress(Exception):
                stored = self._db.get_column_values(table_name, col_name, limit=10000)
                if stored:
                    self._shared_pool.merge(col_name, stored)
|
|
542
|
+
|
|
543
|
+
def fill(
    self,
    table_name: str,
    *,
    count: int = 1000,
    columns: dict[str, Any] | None = None,
    seed: int | None = None,
    batch_size: int = 5000,
    clear_before: bool = False,
    column_configs: list[Any] | None = None,
    transform: str | None = None,
) -> GenerationResult:
    """Alias for ``fill_table``; every argument is forwarded unchanged."""
    options: dict[str, Any] = {
        "count": count,
        "columns": columns,
        "seed": seed,
        "batch_size": batch_size,
        "clear_before": clear_before,
        "column_configs": column_configs,
        "transform": transform,
    }
    return self.fill_table(table_name=table_name, **options)
|
|
565
|
+
|
|
566
|
+
def close(self) -> None:
    """Close the database connection if one is open; otherwise do nothing."""
    if not self._connected:
        return
    self._db.close()
    self._connected = False
|
|
570
|
+
|
|
571
|
+
def __enter__(self) -> DataOrchestrator:
    """Connect (if not already connected) and return this orchestrator."""
    self._ensure_connected()
    orchestrator = self
    return orchestrator
|
|
574
|
+
|
|
575
|
+
def __exit__(
    self,
    exc_type: type[BaseException] | None,
    exc_val: BaseException | None,
    exc_tb: Any,
) -> None:
    """Close the connection when the ``with`` block ends.

    Delegates to ``close`` so both shutdown paths share one implementation.
    Returns ``None``, so an exception raised inside the block keeps
    propagating. The exception arguments are annotated with
    ``BaseException`` because the context-manager protocol also passes
    non-``Exception`` errors such as ``KeyboardInterrupt``.
    """
    self.close()
|
sqlseed/core/relation.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Any
|
|
4
|
+
|
|
5
|
+
from sqlseed._utils.logger import get_logger
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from sqlseed.database._protocol import ForeignKeyInfo
|
|
9
|
+
|
|
10
|
+
logger = get_logger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SharedPool:
    """Value lists shared across tables, keyed by column name."""

    def __init__(self) -> None:
        # column name -> values collected for that column
        self._pools: dict[str, list[Any]] = {}

    def register(self, column_name: str, values: list[Any]) -> None:
        """Replace the pool for ``column_name`` with a copy of ``values``."""
        self._pools[column_name] = [*values]

    def get(self, column_name: str) -> list[Any]:
        """Return the stored list for ``column_name``, or a new empty list.

        The stored list itself is returned, not a copy.
        """
        if column_name in self._pools:
            return self._pools[column_name]
        return []

    def has(self, column_name: str) -> bool:
        """Return True when a non-empty pool exists for ``column_name``."""
        return bool(self._pools.get(column_name))

    def merge(self, column_name: str, values: list[Any]) -> None:
        """Append the values not yet in the pool, keeping first-seen order.

        The pool is created if missing. Duplicates within ``values`` are
        added once; entries already stored are left as they are.
        """
        pool = self._pools.setdefault(column_name, [])
        seen = set(pool)
        for value in values:
            if value in seen:
                continue
            pool.append(value)
            seen.add(value)

    def clear(self) -> None:
        """Remove every pool."""
        self._pools.clear()
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class RelationResolver:
    """Foreign-key lookups and table ordering on top of a database adapter.

    Foreign keys are read through the adapter's ``get_foreign_keys`` and
    cached per table until ``clear_cache`` is called.
    """

    def __init__(self, db_adapter: Any) -> None:
        self._db = db_adapter
        # table name -> foreign keys as returned by the adapter
        self._fk_cache: dict[str, list[ForeignKeyInfo]] = {}

    def get_foreign_keys(self, table_name: str) -> list[ForeignKeyInfo]:
        """Return the foreign keys of ``table_name``, querying the adapter once."""
        if table_name not in self._fk_cache:
            self._fk_cache[table_name] = self._db.get_foreign_keys(table_name)
        return self._fk_cache[table_name]

    def get_dependencies(self, table_name: str) -> set[str]:
        """Return the tables referenced by ``table_name``, excluding itself."""
        fks = self.get_foreign_keys(table_name)
        return {fk.ref_table for fk in fks if fk.ref_table != table_name}

    def topological_sort(self, table_names: list[str]) -> list[str]:
        """Order ``table_names`` so that referenced tables come first.

        Only dependencies among the given tables are considered. The
        dependencies of each table are visited in sorted order, so the same
        input always yields the same output.

        Raises:
            ValueError: If the tables reference each other in a cycle.
        """
        requested = set(table_names)
        graph: dict[str, set[str]] = {
            table: self.get_dependencies(table) & requested for table in table_names
        }

        visited: set[str] = set()
        in_progress: set[str] = set()
        ordered: list[str] = []

        def visit(node: str) -> None:
            if node in visited:
                return
            if node in in_progress:
                raise ValueError(f"Circular dependency detected involving table: {node}")
            in_progress.add(node)
            # Sorted so the result does not depend on set iteration order.
            for dep in sorted(graph.get(node, ())):
                visit(dep)
            in_progress.discard(node)
            visited.add(node)
            ordered.append(node)

        for table in table_names:
            visit(table)

        return ordered

    def resolve_foreign_key_values(
        self,
        table_name: str,
        column_name: str,
    ) -> list[Any]:
        """Return the referenced column's values for a foreign-key column.

        Returns an empty list when ``column_name`` has no foreign key.
        """
        fk = self.get_fk_info(table_name, column_name)
        if fk is None:
            return []
        values: list[Any] = self._db.get_column_values(fk.ref_table, fk.ref_column)
        logger.debug(
            "Resolved FK",
            table_name=table_name,
            column_name=column_name,
            ref_table=fk.ref_table,
            ref_column=fk.ref_column,
            values_count=len(values),
        )
        return values

    def get_fk_info(self, table_name: str, column_name: str) -> ForeignKeyInfo | None:
        """Return the first foreign key declared on ``column_name``, or ``None``."""
        for fk in self.get_foreign_keys(table_name):
            if fk.column == column_name:
                return fk
        return None

    def clear_cache(self) -> None:
        """Forget all cached foreign keys."""
        self._fk_cache.clear()

    def load_shared_pool(self, shared_pool: SharedPool, table_name: str) -> None:
        """Log which foreign keys of ``table_name`` match a shared-pool column.

        For every pool column, each foreign key whose ``ref_column`` equals
        the pool's column name is reported at debug level. No values are
        stored or changed by this method.
        """
        pools = shared_pool._pools
        if not pools:
            return
        # The foreign keys do not change while iterating; fetch them once.
        fks = self.get_foreign_keys(table_name)
        for col_name, values in pools.items():
            for fk in fks:
                if fk.ref_column == col_name:
                    logger.debug(
                        "Loaded shared pool into FK",
                        table_name=table_name,
                        column_name=fk.column,
                        ref_column=col_name,
                        values_count=len(values),
                    )
|