sqlseed-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. sqlseed/__init__.py +121 -0
  2. sqlseed/_utils/__init__.py +11 -0
  3. sqlseed/_utils/logger.py +30 -0
  4. sqlseed/_utils/metrics.py +45 -0
  5. sqlseed/_utils/progress.py +14 -0
  6. sqlseed/_utils/schema_helpers.py +51 -0
  7. sqlseed/_utils/sql_safe.py +45 -0
  8. sqlseed/_version.py +1 -0
  9. sqlseed/cli/__init__.py +3 -0
  10. sqlseed/cli/main.py +316 -0
  11. sqlseed/config/__init__.py +14 -0
  12. sqlseed/config/loader.py +66 -0
  13. sqlseed/config/models.py +99 -0
  14. sqlseed/config/snapshot.py +91 -0
  15. sqlseed/core/__init__.py +14 -0
  16. sqlseed/core/column_dag.py +108 -0
  17. sqlseed/core/constraints.py +116 -0
  18. sqlseed/core/expression.py +71 -0
  19. sqlseed/core/mapper.py +257 -0
  20. sqlseed/core/orchestrator.py +578 -0
  21. sqlseed/core/relation.py +124 -0
  22. sqlseed/core/result.py +23 -0
  23. sqlseed/core/schema.py +97 -0
  24. sqlseed/core/transform.py +27 -0
  25. sqlseed/database/__init__.py +14 -0
  26. sqlseed/database/_protocol.py +72 -0
  27. sqlseed/database/optimizer.py +96 -0
  28. sqlseed/database/raw_sqlite_adapter.py +197 -0
  29. sqlseed/database/sqlite_utils_adapter.py +183 -0
  30. sqlseed/generators/__init__.py +11 -0
  31. sqlseed/generators/_protocol.py +73 -0
  32. sqlseed/generators/base_provider.py +448 -0
  33. sqlseed/generators/faker_provider.py +157 -0
  34. sqlseed/generators/mimesis_provider.py +203 -0
  35. sqlseed/generators/registry.py +86 -0
  36. sqlseed/generators/stream.py +157 -0
  37. sqlseed/py.typed +0 -0
  38. sqlseed-0.1.0.dist-info/METADATA +934 -0
  39. sqlseed-0.1.0.dist-info/RECORD +42 -0
  40. sqlseed-0.1.0.dist-info/WHEEL +4 -0
  41. sqlseed-0.1.0.dist-info/entry_points.txt +6 -0
  42. sqlseed-0.1.0.dist-info/licenses/LICENSE +17 -0
@@ -0,0 +1,578 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import time
5
+ from typing import TYPE_CHECKING, Any, ClassVar
6
+
7
+ from sqlseed._utils.logger import get_logger
8
+ from sqlseed._utils.metrics import MetricsCollector
9
+ from sqlseed._utils.progress import create_progress
10
+ from sqlseed.core.column_dag import ColumnDAG
11
+ from sqlseed.core.constraints import ConstraintSolver
12
+ from sqlseed.core.expression import ExpressionEngine
13
+ from sqlseed.core.mapper import ColumnMapper, GeneratorSpec
14
+ from sqlseed.core.relation import RelationResolver, SharedPool
15
+ from sqlseed.core.result import GenerationResult
16
+ from sqlseed.core.schema import SchemaInferrer
17
+ from sqlseed.core.transform import load_transform
18
+ from sqlseed.database.raw_sqlite_adapter import RawSQLiteAdapter
19
+ from sqlseed.database.sqlite_utils_adapter import SQLiteUtilsAdapter
20
+ from sqlseed.generators.registry import ProviderRegistry
21
+ from sqlseed.generators.stream import DataStream
22
+ from sqlseed.plugins.manager import PluginManager
23
+
24
+ if TYPE_CHECKING:
25
+ from sqlseed.database._protocol import DatabaseAdapter
26
+
27
+ logger = get_logger(__name__)
28
+
29
+
30
class DataOrchestrator:
    """Drive end-to-end test-data generation for a single SQLite database.

    Wires together schema inspection (``SchemaInferrer``), column-to-generator
    mapping (``ColumnMapper``), foreign-key resolution (``RelationResolver``),
    plugin hooks (``PluginManager``) and a ``ProviderRegistry``, then streams
    generated rows into the database in batches.

    The connection, plugin loading and provider registration happen lazily on
    the first public call (or on ``__enter__``). After each successful
    ``fill_table`` the stored column values are merged into a ``SharedPool`` so
    later tables can reuse them for implicit cross-table associations.
    """

    # Generator names whose specs an AI plugin (via the
    # ``sqlseed_ai_analyze_table`` hook) may be asked to improve.
    AI_APPLICABLE_GENERATORS: ClassVar[frozenset[str]] = frozenset(
        {"string", "integer", "date", "datetime", "choice"}
    )

    def __init__(
        self,
        db_path: str,
        *,
        provider_name: str = "mimesis",
        locale: str = "en_US",
        optimize_pragma: bool = True,
    ) -> None:
        """Create an orchestrator; ``connect`` is not called until first use.

        Args:
            db_path: Path passed to the database adapter's ``connect``.
            provider_name: Data provider to activate; falls back to ``"base"``
                if it cannot be loaded.
            locale: Locale applied to the active provider.
            optimize_pragma: Whether ``fill_table`` wraps its work in
                ``optimize_for_bulk_write`` / ``restore_settings``.
        """
        self._db_path = db_path
        self._provider_name = provider_name
        self._locale = locale
        self._optimize_pragma = optimize_pragma

        self._db: DatabaseAdapter = self._create_adapter()
        self._schema = SchemaInferrer(self._db)
        self._mapper = ColumnMapper()
        self._relation = RelationResolver(self._db)
        self._registry = ProviderRegistry()
        self._metrics = MetricsCollector()
        self._plugins = PluginManager()
        self._shared_pool = SharedPool()

        self._connected = False

    def _create_adapter(self) -> DatabaseAdapter:
        """Prefer the sqlite-utils adapter; use raw sqlite3 if it is not installed."""
        try:
            import sqlite_utils  # noqa: F401
        except ImportError:
            logger.debug("sqlite-utils not available, falling back to raw sqlite3")
            return RawSQLiteAdapter()
        return SQLiteUtilsAdapter()

    def _ensure_connected(self) -> None:
        """Connect and initialise plugins/providers once; later calls are no-ops.

        Loads plugins, lets them register providers and column mappers, adds
        entry-point providers, then activates ``self._provider_name``. If that
        raises ``ImportError``/``ValueError`` the ``"base"`` provider is used
        instead. Finally ``self._locale`` is applied to the active provider.
        """
        if self._connected:
            return
        self._db.connect(self._db_path)
        self._connected = True
        self._plugins.load_plugins()
        self._plugins.hook.sqlseed_register_providers(registry=self._registry)
        self._plugins.hook.sqlseed_register_column_mappers(mapper=self._mapper)
        self._registry.register_from_entry_points()
        try:
            provider = self._registry.ensure_provider(self._provider_name)
            self._registry.set_default(self._provider_name)
        except (ImportError, ValueError):
            logger.warning(
                "Provider not available, falling back to 'base'",
                provider_name=self._provider_name,
            )
            self._provider_name = "base"
            provider = self._registry.get(self._provider_name)
        provider.set_locale(self._locale)

    def fill_table(
        self,
        table_name: str,
        *,
        count: int = 1000,
        columns: dict[str, Any] | None = None,
        seed: int | None = None,
        batch_size: int = 5000,
        clear_before: bool = False,
        column_configs: list[Any] | None = None,
        transform: str | None = None,
    ) -> GenerationResult:
        """Generate ``count`` rows for ``table_name`` and insert them in batches.

        Args:
            table_name: Target table.
            count: Number of rows to generate.
            columns: Shorthand per-column overrides (see ``_resolve_user_configs``).
            seed: Seed forwarded to ``DataStream``.
            batch_size: Rows per generated/inserted batch.
            clear_before: Clear the table before generating.
            column_configs: ``ColumnConfig`` objects; ``columns`` entries win on
                name clashes.
            transform: Transform spec passed to ``load_transform``.

        Returns:
            A ``GenerationResult``. On failure, the exception is logged, not
            raised. The result then carries ``errors=[str(exc)]``, and
            ``count`` is the number of rows inserted before the failure.
        """
        self._ensure_connected()
        start_time = time.monotonic()
        total_inserted = 0
        batch_count = 0

        try:
            if self._optimize_pragma:
                self._db.optimize_for_bulk_write(count)

            if clear_before:
                self._db.clear_table(table_name)

            column_infos = self._schema.get_column_info(table_name)
            user_configs = self._resolve_user_configs(columns, column_configs)
            generator_specs = self._mapper.map_columns(column_infos, user_configs)
            generator_specs = self._resolve_foreign_keys(table_name, generator_specs)
            generator_specs = self._apply_ai_suggestions(table_name, column_infos, generator_specs)
            generator_specs = self._apply_template_pool(table_name, column_infos, generator_specs, count)

            dag = ColumnDAG()
            col_configs_list = list(user_configs.values()) if user_configs else None
            dag_nodes = dag.build(generator_specs, col_configs_list)

            expr_engine = ExpressionEngine()
            constraint_solver = ConstraintSolver()

            transform_fn = load_transform(transform) if transform else None

            provider = self._registry.get(self._provider_name)

            self._plugins.hook.sqlseed_before_generate(
                table_name=table_name,
                count=count,
                config=None,
            )

            stream = DataStream(
                dag_nodes=dag_nodes,
                provider=provider,
                expr_engine=expr_engine,
                constraint_solver=constraint_solver,
                transform_fn=transform_fn,
                seed=seed,
            )

            progress = create_progress()
            with progress:
                task_id = progress.add_task(f"Generating {table_name}", total=count)
                for batch in stream.generate(count, batch_size):
                    batch_count += 1

                    self._plugins.hook.sqlseed_before_insert(
                        table_name=table_name,
                        batch_number=batch_count,
                        batch_size=len(batch),
                    )

                    current_batch = self._apply_batch_transforms(table_name, batch)

                    inserted = self._db.batch_insert(table_name, iter(current_batch), batch_size)
                    total_inserted += inserted

                    self._metrics.record(f"{table_name}.batch_insert", float(inserted))

                    self._plugins.hook.sqlseed_after_insert(
                        table_name=table_name,
                        batch_number=batch_count,
                        rows_inserted=inserted,
                    )

                    progress.update(task_id, advance=len(batch))

        except Exception as e:
            logger.error("Failed to fill table", table_name=table_name, error=e)
            return GenerationResult(
                table_name=table_name,
                count=total_inserted,
                elapsed=time.monotonic() - start_time,
                errors=[str(e)],
            )
        finally:
            # Runs on both the success and the error path.
            if self._optimize_pragma:
                self._db.restore_settings()

        elapsed = time.monotonic() - start_time

        self._metrics.record(f"{table_name}.total_elapsed", elapsed)
        self._metrics.record(f"{table_name}.total_rows", float(total_inserted))

        self._plugins.hook.sqlseed_after_generate(
            table_name=table_name,
            count=total_inserted,
            elapsed=elapsed,
        )

        # Make this table's values available to later tables (implicit FKs).
        self._register_shared_pool(table_name, generator_specs)

        return GenerationResult(
            table_name=table_name,
            count=total_inserted,
            elapsed=elapsed,
            batch_count=batch_count,
        )

    def preview_table(
        self,
        table_name: str,
        *,
        count: int = 5,
        columns: dict[str, Any] | None = None,
        seed: int | None = None,
        transform: str | None = None,
        column_configs: list[Any] | None = None,
    ) -> list[dict[str, Any]]:
        """Generate ``count`` rows for ``table_name`` without inserting them.

        Differs from ``fill_table`` in four ways:
        - AI suggestions and template pools are not applied.
        - The generate/insert hooks do not fire.
        - Everything is generated in a single batch.
        - Batch-transform hooks are still applied.
        """
        self._ensure_connected()

        column_infos = self._schema.get_column_info(table_name)
        user_configs = self._resolve_user_configs(columns, column_configs)
        generator_specs = self._mapper.map_columns(column_infos, user_configs)
        generator_specs = self._resolve_foreign_keys(table_name, generator_specs)

        dag = ColumnDAG()
        col_configs_list = list(user_configs.values()) if user_configs else None
        dag_nodes = dag.build(generator_specs, col_configs_list)

        expr_engine = ExpressionEngine()
        constraint_solver = ConstraintSolver()

        transform_fn = load_transform(transform) if transform else None

        provider = self._registry.get(self._provider_name)

        stream = DataStream(
            dag_nodes=dag_nodes,
            provider=provider,
            expr_engine=expr_engine,
            constraint_solver=constraint_solver,
            transform_fn=transform_fn,
            seed=seed,
        )
        result: list[dict[str, Any]] = []
        for batch in stream.generate(count, batch_size=count):
            result.extend(self._apply_batch_transforms(table_name, batch))
        return result

    def get_schema_context(self, table_name: str) -> dict[str, Any]:
        """Collect schema metadata for ``table_name`` as a plain dict.

        Columns, foreign keys and table names are required. Index, sample-data
        and distribution lookups are best-effort: any exception leaves that
        entry as an empty list.
        """
        self._ensure_connected()
        column_infos = self._schema.get_column_info(table_name)
        fks = self._db.get_foreign_keys(table_name)
        all_tables = self._db.get_table_names()

        indexes: list[dict[str, Any]] = []
        with contextlib.suppress(Exception):
            idx_infos = self._schema.get_index_info(table_name)
            indexes = [{"name": idx.name, "columns": idx.columns, "unique": idx.unique} for idx in idx_infos]

        sample_data: list[dict[str, Any]] = []
        with contextlib.suppress(Exception):
            sample_data = self._schema.get_sample_data(table_name, limit=5)

        distribution: list[dict[str, Any]] = []
        with contextlib.suppress(Exception):
            distribution = self._schema.profile_column_distribution(table_name, limit=1000)

        return {
            "table_name": table_name,
            "columns": column_infos,
            "foreign_keys": fks,
            "indexes": indexes,
            "sample_data": sample_data,
            "all_table_names": all_tables,
            "distribution": distribution,
        }

    def get_column_names(self, table_name: str) -> set[str]:
        """Return the names of all columns of ``table_name``."""
        self._ensure_connected()
        return {c.name for c in self._schema.get_column_info(table_name)}

    def get_skippable_columns(self, table_name: str) -> set[str]:
        """Return columns needing no generated value.

        These are autoincrement primary keys and columns that have a default.
        """
        self._ensure_connected()
        return {
            c.name
            for c in self._schema.get_column_info(table_name)
            if (c.is_primary_key and c.is_autoincrement) or c.default is not None
        }

    def report(self) -> str:
        """Return a text summary of row counts per table (no auto-connect)."""
        if not self._connected:
            return "Not connected to any database."

        tables = self._db.get_table_names()
        lines = [f"Database: {self._db_path}", "=" * 50]
        for table in tables:
            count = self._db.get_row_count(table)
            lines.append(f" {table}: {count} rows")
        return "\n".join(lines)

    def _resolve_user_configs(
        self,
        columns: dict[str, Any] | None,
        column_configs: list[Any] | None,
    ) -> dict[str, Any]:
        """Merge explicit ``ColumnConfig`` objects and the ``columns`` shorthand.

        Returns a mapping of column name to ``ColumnConfig``. Rules:
        - ``column_configs`` entries that are not ``ColumnConfig`` instances are
          ignored.
        - In ``columns``, a ``str`` value is a generator name.
        - In ``columns``, a ``dict`` value uses its ``"type"`` key as the
          generator (default ``"string"``), and the remaining keys become params.
        - Any other ``columns`` value is ignored.
        - ``columns`` entries override ``column_configs`` of the same name.
        """
        from sqlseed.config.models import ColumnConfig

        configs: dict[str, Any] = {}

        if column_configs:
            for cc in column_configs:
                if isinstance(cc, ColumnConfig):
                    configs[cc.name] = cc

        if columns:
            for col_name, col_spec in columns.items():
                if isinstance(col_spec, str):
                    configs[col_name] = ColumnConfig(name=col_name, generator=col_spec)
                elif isinstance(col_spec, dict):
                    # Copy so popping "type" does not mutate the caller's dict.
                    spec_copy = dict(col_spec)
                    gen_type = spec_copy.pop("type", "string")
                    configs[col_name] = ColumnConfig(
                        name=col_name,
                        generator=gen_type,
                        params=spec_copy,
                    )

        return configs

    def _resolve_foreign_keys(
        self,
        table_name: str,
        specs: dict[str, GeneratorSpec],
    ) -> dict[str, GeneratorSpec]:
        """Turn foreign-key style specs into concrete generators (mutates ``specs``).

        Resolution runs in three passes:

        1. A ``foreign_key_or_integer`` column with a declared FK becomes a
           ``foreign_key`` spec that draws from the referenced column's current
           values. An explicit ``foreign_key`` spec that has a ``ref_table``
           gets its ``_ref_values`` loaded from the database.
        2. Any ``foreign_key_or_integer`` column still unresolved whose name
           exists in the shared pool is resolved by
           ``_resolve_implicit_associations``.
        3. Whatever is still ``foreign_key_or_integer`` falls back to a random
           integer in ``[1, 999999]``.

        The integer fallback must run last; otherwise no column would be left
        for the shared-pool pass to resolve.
        """
        for col_name, spec in list(specs.items()):
            if spec.generator_name == "foreign_key_or_integer":
                fk_info = self._relation.get_fk_info(table_name, col_name)
                if not fk_info:
                    # No declared FK: leave for the shared-pool/integer passes.
                    continue
                ref_values = self._relation.resolve_foreign_key_values(table_name, col_name)
                specs[col_name] = GeneratorSpec(
                    generator_name="foreign_key",
                    params={
                        "ref_table": fk_info.ref_table,
                        "ref_column": fk_info.ref_column,
                        "strategy": "random",
                        "_ref_values": ref_values,
                    },
                    null_ratio=spec.null_ratio,
                    provider=spec.provider,
                )
            elif spec.generator_name == "foreign_key" and "ref_table" in spec.params:
                spec.params["_ref_values"] = self._db.get_column_values(
                    spec.params["ref_table"],
                    spec.params["ref_column"],
                )

        specs = self._resolve_implicit_associations(table_name, specs)

        for col_name, spec in list(specs.items()):
            if spec.generator_name == "foreign_key_or_integer":
                specs[col_name] = GeneratorSpec(
                    generator_name="integer",
                    params={"min_value": 1, "max_value": 999999},
                    null_ratio=spec.null_ratio,
                    provider=spec.provider,
                )

        return specs

    def _resolve_implicit_associations(
        self,
        table_name: str,
        specs: dict[str, GeneratorSpec],
    ) -> dict[str, GeneratorSpec]:
        """Resolve implicit cross-table associations via SharedPool.

        A column can still be ``foreign_key_or_integer`` and have values in the
        SharedPool from a previously filled table. Such a column is switched to
        a ``foreign_key`` spec that draws from those pooled values. This covers
        cases like ``sUserNo`` appearing in several tables without an explicit
        FK constraint.
        """
        for col_name, spec in list(specs.items()):
            if spec.generator_name != "foreign_key_or_integer":
                continue

            # get() returns [] for unknown columns, covering the has() check.
            pool_values = self._shared_pool.get(col_name)
            if not pool_values:
                continue

            specs[col_name] = GeneratorSpec(
                generator_name="foreign_key",
                params={
                    "ref_table": "__shared_pool__",
                    "ref_column": col_name,
                    "strategy": "random",
                    "_ref_values": pool_values,
                },
                null_ratio=spec.null_ratio,
                provider=spec.provider,
            )
            logger.debug(
                "Resolved implicit association via SharedPool",
                table_name=table_name,
                column_name=col_name,
                pool_size=len(pool_values),
            )

        return specs

    def _apply_ai_suggestions(
        self,
        table_name: str,
        column_infos: list[Any],
        specs: dict[str, GeneratorSpec],
    ) -> dict[str, GeneratorSpec]:
        """Let an AI plugin override generator specs (mutates ``specs``).

        The ``sqlseed_ai_analyze_table`` hook is only called when at least one
        eligible column exists. A column is eligible when it is not a primary
        key, not autoincrement, has no default, and currently uses a generator
        in ``AI_APPLICABLE_GENERATORS``.

        How the hook result is applied:
        - It is used only if it is a dict with a ``"columns"`` list.
        - An entry whose ``"generator"`` is ``"skip"``, or is missing, is ignored.
        - An entry with both ``derive_from`` and ``expression`` becomes a
          ``__derive__`` spec.
        - Any other entry becomes ``generator`` with its dict ``params``.

        Any exception is logged at debug level, and ``specs`` is returned as it
        stands at that point.
        """
        eligible_cols = [
            col
            for col in column_infos
            if specs.get(col.name) is not None
            and specs[col.name].generator_name in self.AI_APPLICABLE_GENERATORS
            and not col.is_primary_key
            and not col.is_autoincrement
            and col.default is None
        ]
        if not eligible_cols:
            return specs

        try:
            fks = self._db.get_foreign_keys(table_name)
            all_tables = self._db.get_table_names()
            indexes = self._schema.get_index_info(table_name)
            sample_data = self._schema.get_sample_data(table_name, limit=5)

            ai_result = self._plugins.hook.sqlseed_ai_analyze_table(
                table_name=table_name,
                columns=column_infos,
                indexes=[{"name": i.name, "columns": i.columns, "unique": i.unique} for i in indexes],
                sample_data=sample_data,
                foreign_keys=fks,
                all_table_names=all_tables,
            )

            if not (ai_result and isinstance(ai_result, dict)):
                return specs
            ai_columns = ai_result.get("columns", [])
            if not isinstance(ai_columns, list):
                return specs

            # NOTE(review): suggestions are applied to any column present in
            # ``specs``, not only the eligible ones computed above — confirm
            # that overriding e.g. FK/PK specs is intended.
            for col_cfg in ai_columns:
                col_name = col_cfg.get("name") if isinstance(col_cfg, dict) else None
                if not col_name or col_name not in specs:
                    continue
                gen = col_cfg.get("generator")
                if not gen or gen == "skip":
                    continue

                derive_from = col_cfg.get("derive_from")
                expression = col_cfg.get("expression")
                if derive_from and expression:
                    specs[col_name] = GeneratorSpec(
                        generator_name="__derive__",
                        params={"derive_from": derive_from, "expression": expression},
                    )
                    continue

                params = col_cfg.get("params", {})
                if isinstance(params, dict):
                    specs[col_name] = GeneratorSpec(
                        generator_name=gen,
                        params=params,
                    )

        except Exception as e:
            logger.debug("AI suggestions not available", table_name=table_name, error=str(e))

        return specs

    def _apply_batch_transforms(
        self,
        table_name: str,
        batch: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Run ``sqlseed_transform_batch`` hooks and return the resulting batch.

        Every hook implementation receives the same input ``batch``. The last
        non-``None`` entry in the hook's result list wins. If there is none,
        ``batch`` is returned unchanged.
        """
        results = self._plugins.hook.sqlseed_transform_batch(
            table_name=table_name,
            batch=batch,
        )
        current = batch
        for r in results or ():
            if r is not None:
                current = r
        return current

    def _apply_template_pool(
        self,
        table_name: str,
        column_infos: list[Any],
        specs: dict[str, GeneratorSpec],
        count: int,
    ) -> dict[str, GeneratorSpec]:
        """Replace plain ``string`` specs with plugin-provided template values.

        The hook ``sqlseed_pre_generate_templates`` is asked for each eligible
        ``string`` column. A column is eligible when it is not a primary key,
        not autoincrement, and has no default. Each request asks for up to
        ``min(count, 50)`` templates and passes up to 10 existing values as
        samples.

        A non-empty result turns the column into a ``foreign_key`` spec that
        picks randomly from the templates. The column's ``null_ratio`` and
        ``provider`` are kept, as the FK resolution paths do.
        """
        # Map name -> first matching column info (one pass instead of a
        # linear search per column).
        infos_by_name: dict[str, Any] = {}
        for info in column_infos:
            infos_by_name.setdefault(info.name, info)

        for col_name, spec in list(specs.items()):
            if spec.generator_name != "string":
                continue
            col_info = infos_by_name.get(col_name)
            if col_info is None or col_info.is_primary_key or col_info.is_autoincrement:
                continue
            if col_info.default is not None:
                continue

            sample_data_for_col: list[Any] = []
            with contextlib.suppress(Exception):
                sample_data_for_col = self._db.get_column_values(table_name, col_name, limit=10)

            template_values = self._plugins.hook.sqlseed_pre_generate_templates(
                table_name=table_name,
                column_name=col_name,
                column_type=col_info.type,
                count=min(count, 50),
                sample_data=sample_data_for_col,
            )
            if template_values:
                specs[col_name] = GeneratorSpec(
                    generator_name="foreign_key",
                    params={
                        "ref_table": "__template_pool__",
                        "ref_column": col_name,
                        "strategy": "random",
                        "_ref_values": template_values,
                    },
                    null_ratio=spec.null_ratio,
                    provider=spec.provider,
                )
        return specs

    def _register_shared_pool(
        self,
        table_name: str,
        generator_specs: dict[str, GeneratorSpec],
    ) -> None:
        """Merge up to 10000 stored values per non-skip column into the shared pool.

        Read failures for individual columns are ignored.
        """
        for col_name, spec in generator_specs.items():
            if spec.generator_name == "skip":
                continue
            with contextlib.suppress(Exception):
                values = self._db.get_column_values(table_name, col_name, limit=10000)
                if values:
                    self._shared_pool.merge(col_name, values)

    def fill(
        self,
        table_name: str,
        *,
        count: int = 1000,
        columns: dict[str, Any] | None = None,
        seed: int | None = None,
        batch_size: int = 5000,
        clear_before: bool = False,
        column_configs: list[Any] | None = None,
        transform: str | None = None,
    ) -> GenerationResult:
        """Alias for :meth:`fill_table` with identical arguments."""
        return self.fill_table(
            table_name=table_name,
            count=count,
            columns=columns,
            seed=seed,
            batch_size=batch_size,
            clear_before=clear_before,
            column_configs=column_configs,
            transform=transform,
        )

    def close(self) -> None:
        """Close the database connection if one is open; safe to call repeatedly."""
        if self._connected:
            self._db.close()
            self._connected = False

    def __enter__(self) -> DataOrchestrator:
        self._ensure_connected()
        return self

    def __exit__(self, exc_type: type | None, exc_val: Exception | None, exc_tb: Any) -> None:
        # Exceptions are not suppressed (implicit ``None`` return).
        self.close()
@@ -0,0 +1,124 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ from sqlseed._utils.logger import get_logger
6
+
7
+ if TYPE_CHECKING:
8
+ from sqlseed.database._protocol import ForeignKeyInfo
9
+
10
+ logger = get_logger(__name__)
11
+
12
+
13
class SharedPool:
    """Cross-table shared value pool for maintaining referential integrity.

    Maps a column name to the list of values collected for it. ``register``
    replaces a column's list wholesale. ``merge`` appends only values not
    already present, keeping first-seen order, and therefore requires hashable
    values.
    """

    def __init__(self) -> None:
        # column name -> list of pooled values for that column
        self._pools: dict[str, list[Any]] = {}

    def register(self, column_name: str, values: list[Any]) -> None:
        """Replace the pool for ``column_name`` with a copy of ``values``."""
        self._pools[column_name] = list(values)

    def get(self, column_name: str) -> list[Any]:
        """Return a copy of the values pooled for ``column_name`` (``[]`` if none).

        A copy is returned for two reasons. Callers cannot mutate the pool
        through the result. Later ``merge`` calls do not silently change a list
        a caller already holds.
        """
        return list(self._pools.get(column_name, ()))

    def has(self, column_name: str) -> bool:
        """Return True if ``column_name`` has at least one pooled value."""
        return bool(self._pools.get(column_name))

    def merge(self, column_name: str, values: list[Any]) -> None:
        """Append each value not already pooled for ``column_name``, in order."""
        pool = self._pools.setdefault(column_name, [])
        seen = set(pool)
        for value in values:
            if value not in seen:
                pool.append(value)
                seen.add(value)

    def clear(self) -> None:
        """Drop every pooled column."""
        self._pools.clear()
39
+
40
+
41
class RelationResolver:
    """Resolve foreign-key relationships between tables through a DB adapter.

    Foreign-key metadata is fetched with ``db_adapter.get_foreign_keys`` once
    per table and cached until :meth:`clear_cache` is called.
    """

    def __init__(self, db_adapter: Any) -> None:
        self._db = db_adapter
        # table name -> foreign keys declared on that table
        self._fk_cache: dict[str, list[ForeignKeyInfo]] = {}

    def get_foreign_keys(self, table_name: str) -> list[ForeignKeyInfo]:
        """Return (and cache) the foreign keys declared on ``table_name``."""
        if table_name not in self._fk_cache:
            self._fk_cache[table_name] = self._db.get_foreign_keys(table_name)
        return self._fk_cache[table_name]

    def get_dependencies(self, table_name: str) -> set[str]:
        """Return the tables ``table_name`` references, excluding itself."""
        fks = self.get_foreign_keys(table_name)
        return {fk.ref_table for fk in fks if fk.ref_table != table_name}

    def topological_sort(self, table_names: list[str]) -> list[str]:
        """Order ``table_names`` so each table comes after the tables it references.

        Only dependencies among ``table_names`` are considered, and
        self-references are ignored. Dependencies are visited in the order they
        appear in ``table_names``. This keeps the result deterministic across
        runs; iterating a ``set`` of strings is not deterministic, because
        string hashing is randomized per process.

        Raises:
            ValueError: If the dependencies among ``table_names`` form a cycle.
        """
        # First position of each table in the input, used to order dependencies.
        position: dict[str, int] = {}
        for index, name in enumerate(table_names):
            position.setdefault(name, index)

        graph: dict[str, list[str]] = {
            table: sorted(self.get_dependencies(table) & position.keys(), key=position.__getitem__)
            for table in table_names
        }

        visited: set[str] = set()
        in_progress: set[str] = set()
        result: list[str] = []

        def visit(node: str) -> None:
            # Depth-first post-order; a node seen again while still in
            # progress means a cycle.
            if node in visited:
                return
            if node in in_progress:
                raise ValueError(f"Circular dependency detected involving table: {node}")
            in_progress.add(node)
            for dep in graph.get(node, ()):
                visit(dep)
            in_progress.discard(node)
            visited.add(node)
            result.append(node)

        for table in table_names:
            visit(table)

        return result

    def resolve_foreign_key_values(
        self,
        table_name: str,
        column_name: str,
    ) -> list[Any]:
        """Return the current values of the column ``column_name`` references.

        Returns ``[]`` when ``column_name`` has no foreign key on ``table_name``.
        """
        fk = self.get_fk_info(table_name, column_name)
        if fk is None:
            return []
        values: list[Any] = self._db.get_column_values(fk.ref_table, fk.ref_column)
        logger.debug(
            "Resolved FK",
            table_name=table_name,
            column_name=column_name,
            ref_table=fk.ref_table,
            ref_column=fk.ref_column,
            values_count=len(values),
        )
        return values

    def get_fk_info(self, table_name: str, column_name: str) -> ForeignKeyInfo | None:
        """Return the first foreign key on ``table_name`` for ``column_name``, if any."""
        for fk in self.get_foreign_keys(table_name):
            if fk.column == column_name:
                return fk
        return None

    def clear_cache(self) -> None:
        """Forget all cached foreign-key metadata."""
        self._fk_cache.clear()

    def load_shared_pool(self, shared_pool: SharedPool, table_name: str) -> None:
        """Log which of ``table_name``'s foreign keys reference a pooled column.

        NOTE(review): despite its name, this method only emits debug logs. It
        does not store the pooled values, so FK resolution is unaffected.
        Matching compares ``fk.ref_column`` with the pooled column name only;
        ``fk.ref_table`` is not checked.
        """
        pools = shared_pool._pools
        if not pools:
            return
        fks = self.get_foreign_keys(table_name)
        for col_name, values in pools.items():
            for fk in fks:
                if fk.ref_column == col_name:
                    logger.debug(
                        "Loaded shared pool into FK",
                        table_name=table_name,
                        column_name=fk.column,
                        ref_column=col_name,
                        values_count=len(values),
                    )
+ )