sqlseed 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. sqlseed/__init__.py +121 -0
  2. sqlseed/_utils/__init__.py +11 -0
  3. sqlseed/_utils/logger.py +30 -0
  4. sqlseed/_utils/metrics.py +45 -0
  5. sqlseed/_utils/progress.py +14 -0
  6. sqlseed/_utils/schema_helpers.py +51 -0
  7. sqlseed/_utils/sql_safe.py +45 -0
  8. sqlseed/_version.py +1 -0
  9. sqlseed/cli/__init__.py +3 -0
  10. sqlseed/cli/main.py +316 -0
  11. sqlseed/config/__init__.py +14 -0
  12. sqlseed/config/loader.py +66 -0
  13. sqlseed/config/models.py +99 -0
  14. sqlseed/config/snapshot.py +91 -0
  15. sqlseed/core/__init__.py +14 -0
  16. sqlseed/core/column_dag.py +108 -0
  17. sqlseed/core/constraints.py +116 -0
  18. sqlseed/core/expression.py +71 -0
  19. sqlseed/core/mapper.py +257 -0
  20. sqlseed/core/orchestrator.py +578 -0
  21. sqlseed/core/relation.py +124 -0
  22. sqlseed/core/result.py +23 -0
  23. sqlseed/core/schema.py +97 -0
  24. sqlseed/core/transform.py +27 -0
  25. sqlseed/database/__init__.py +14 -0
  26. sqlseed/database/_protocol.py +72 -0
  27. sqlseed/database/optimizer.py +96 -0
  28. sqlseed/database/raw_sqlite_adapter.py +197 -0
  29. sqlseed/database/sqlite_utils_adapter.py +183 -0
  30. sqlseed/generators/__init__.py +11 -0
  31. sqlseed/generators/_protocol.py +73 -0
  32. sqlseed/generators/base_provider.py +448 -0
  33. sqlseed/generators/faker_provider.py +157 -0
  34. sqlseed/generators/mimesis_provider.py +203 -0
  35. sqlseed/generators/registry.py +86 -0
  36. sqlseed/generators/stream.py +157 -0
  37. sqlseed/py.typed +0 -0
  38. sqlseed-0.1.0.dist-info/METADATA +934 -0
  39. sqlseed-0.1.0.dist-info/RECORD +42 -0
  40. sqlseed-0.1.0.dist-info/WHEEL +4 -0
  41. sqlseed-0.1.0.dist-info/entry_points.txt +6 -0
  42. sqlseed-0.1.0.dist-info/licenses/LICENSE +17 -0
@@ -0,0 +1,66 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from pathlib import Path
5
+
6
+ import yaml
7
+
8
+ from sqlseed._utils.logger import get_logger
9
+ from sqlseed.config.models import GeneratorConfig, TableConfig
10
+
11
+ logger = get_logger(__name__)
12
+
13
+
14
def load_config(path: str) -> GeneratorConfig:
    """Read a generator configuration from a YAML or JSON file.

    Args:
        path: Location of the configuration file. The parser is selected from
            the lower-cased file extension: ``.yaml``/``.yml`` or ``.json``.

    Returns:
        A ``GeneratorConfig`` built from the file's top-level mapping.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the extension is not supported, or the parsed document
            is not a mapping.
    """
    source = Path(path)
    if not source.exists():
        raise FileNotFoundError(f"Configuration file not found: {path}")

    extension = source.suffix.lower()
    with open(source, encoding="utf-8") as handle:
        if extension == ".json":
            parsed = json.load(handle)
        elif extension in (".yaml", ".yml"):
            parsed = yaml.safe_load(handle)
        else:
            raise ValueError(f"Unsupported configuration file format: {extension}")

    # An empty document, a list or a scalar cannot be expanded into keyword
    # arguments, so only a mapping is accepted.
    if isinstance(parsed, dict):
        return GeneratorConfig(**parsed)
    raise ValueError("Configuration file must contain a YAML/JSON object")
32
+
33
+
34
def save_config(config: GeneratorConfig, path: str) -> None:
    """Serialize ``config`` to ``path`` as YAML or JSON.

    The output format is chosen from the lower-cased file extension
    (``.yaml``/``.yml`` or ``.json``). Missing parent directories are created.

    Args:
        config: The configuration to write; dumped with
            ``model_dump(mode="json")``.
        path: Destination file path.

    Raises:
        ValueError: If the file extension is not a supported format. This is
            raised before anything is created or overwritten on disk.
    """
    config_path = Path(path)
    suffix = config_path.suffix.lower()

    # Validate the format before touching the filesystem: opening the target
    # in "w" mode first would truncate an existing file (or leave an empty
    # one behind) even though nothing can be written to it.
    if suffix not in (".yaml", ".yml", ".json"):
        raise ValueError(f"Unsupported configuration file format: {suffix}")

    data = config.model_dump(mode="json")
    config_path.parent.mkdir(parents=True, exist_ok=True)

    with open(config_path, "w", encoding="utf-8") as f:
        if suffix == ".json":
            json.dump(data, f, indent=2, ensure_ascii=False)
        else:
            yaml.dump(data, f, default_flow_style=False, allow_unicode=True, sort_keys=False)

    logger.info("Configuration saved", path=path)
50
+
51
+
52
def generate_template(db_path: str, table_name: str | None = None) -> GeneratorConfig:
    """Build a starter ``GeneratorConfig`` for the given database path.

    Args:
        db_path: Value stored in the configuration's ``db_path`` field.
        table_name: When truthy, a single ``TableConfig`` for this table is
            included with ``count=1000`` and no column entries; otherwise
            the table list is empty.

    Returns:
        The newly constructed configuration.
    """
    if table_name:
        tables: list[TableConfig] = [
            TableConfig(name=table_name, count=1000, columns=[]),
        ]
    else:
        tables = []

    return GeneratorConfig(db_path=db_path, tables=tables)
@@ -0,0 +1,99 @@
1
+ from __future__ import annotations
2
+
3
+ from enum import Enum
4
+ from typing import Any
5
+
6
+ from pydantic import BaseModel, Field, field_validator, model_validator
7
+ from typing_extensions import Self
8
+
9
+
10
class ProviderType(str, Enum):
    """Provider identifiers that a configuration may name.

    Members subclass ``str``, so each one compares equal to its plain-string
    value (for example ``ProviderType.FAKER == "faker"``).
    """

    BASE = "base"
    FAKER = "faker"
    MIMESIS = "mimesis"
    CUSTOM = "custom"
    AI = "ai"
16
+
17
+
18
class ColumnConstraintsConfig(BaseModel):
    """Column constraint configuration."""

    unique: bool = False
    # Optional numeric bounds; ``None`` leaves the bound unset.
    min_value: int | float | None = None
    max_value: int | float | None = None
    # Optional regular-expression pattern; ``None`` leaves it unset.
    regex: str | None = None
    # Must be strictly positive (validated by ``gt=0``).
    max_retries: int = Field(default=100, gt=0)
26
+
27
+
28
class ColumnConfig(BaseModel):
    """
    Column configuration — supports both source columns and derived columns.

    Source-column mode: specify ``generator`` + ``params``.
    Derived-column mode: specify ``derive_from`` + ``expression``.
    The two modes cannot be used together.
    """

    name: str

    # === Source-column mode ===
    generator: str | None = None
    provider: ProviderType | None = None
    params: dict[str, Any] = Field(default_factory=dict)
    null_ratio: float = Field(default=0.0, ge=0.0, le=1.0)

    # === Derived-column mode ===
    derive_from: str | None = None  # name of the source column
    expression: str | None = None  # derivation expression

    # === Constraints ===
    constraints: ColumnConstraintsConfig | None = None

    @field_validator("null_ratio")
    @classmethod
    def validate_null_ratio(cls, v: float) -> float:
        """Accept ``v`` only when it lies in the closed interval [0.0, 1.0]."""
        if 0.0 <= v <= 1.0:
            return v
        raise ValueError("null_ratio must be between 0.0 and 1.0")

    @model_validator(mode="after")
    def validate_column_mode(self) -> Self:
        """Reject configurations that mix, or only half-specify, derived mode."""
        if not self.derive_from:
            # Nothing to cross-check unless a source column is named.
            return self
        if self.generator:
            raise ValueError(f"Column '{self.name}': cannot use both 'generator' and 'derive_from'")
        if not self.expression:
            raise ValueError(f"Column '{self.name}': 'derive_from' requires 'expression'")
        return self
66
+
67
+
68
class TableConfig(BaseModel):
    """Generation configuration for a single table."""

    name: str
    # Both sizes must be strictly positive (validated by ``gt=0``).
    count: int = Field(default=1000, gt=0)
    batch_size: int = Field(default=5000, gt=0)
    columns: list[ColumnConfig] = Field(default_factory=list)
    clear_before: bool = False  # off by default, to protect the original data
    seed: int | None = None
    transform: str | None = None  # [NEW] path to a Python transform script
78
+
79
+
80
class ColumnAssociation(BaseModel):
    """Cross-table column association declaration.

    Used for implicit associations, i.e. a same-named column referenced
    across tables.
    """

    column_name: str
    source_table: str
    target_tables: list[str] = Field(default_factory=list)
    strategy: str = "shared_pool"
87
+
88
+
89
class GeneratorConfig(BaseModel):
    """Global generation configuration."""

    # The only required field; every other field has a default.
    db_path: str
    provider: ProviderType = ProviderType.MIMESIS
    locale: str = "en_US"
    tables: list[TableConfig] = Field(default_factory=list)
    associations: list[ColumnAssociation] = Field(default_factory=list)
    optimize_pragma: bool = True
    log_level: str = "INFO"
    snapshot_dir: str | None = None
@@ -0,0 +1,91 @@
1
+ from __future__ import annotations
2
+
3
+ from datetime import datetime
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import yaml
8
+
9
+ from sqlseed._utils.logger import get_logger
10
+ from sqlseed.config.models import GeneratorConfig
11
+
12
+ logger = get_logger(__name__)
13
+
14
+
15
class SnapshotManager:
    """Save, load, replay and list YAML snapshots of generation runs."""

    def __init__(self, snapshot_dir: str | None = None) -> None:
        """Remember the snapshot directory, defaulting to ``./snapshots``."""
        self._snapshot_dir = Path(snapshot_dir) if snapshot_dir else Path("./snapshots")

    def save(
        self,
        config: GeneratorConfig,
        table_name: str,
        count: int,
        seed: int | None = None,
    ) -> str:
        """Write a snapshot file and return its path as a string.

        The file is named ``<timestamp>_<table_name>.yaml`` inside the
        snapshot directory, which is created if missing. The timestamp has
        one-second resolution and the file is opened in write mode, so a
        second save for the same table within the same second overwrites
        the first.
        """
        self._snapshot_dir.mkdir(parents=True, exist_ok=True)

        timestamp = datetime.now().strftime("%Y-%m-%d_%H%M%S")
        filepath = self._snapshot_dir / f"{timestamp}_{table_name}.yaml"

        snapshot_data = {
            "timestamp": timestamp,
            "table_name": table_name,
            "count": count,
            "seed": seed,
            "config": config.model_dump(mode="json"),
        }

        with open(filepath, "w", encoding="utf-8") as f:
            yaml.dump(snapshot_data, f, default_flow_style=False, allow_unicode=True, sort_keys=False)

        logger.info("Snapshot saved", filepath=str(filepath))
        return str(filepath)

    def load(self, snapshot_path: str) -> dict[str, Any]:
        """Parse a snapshot file and return its top-level mapping.

        Raises:
            FileNotFoundError: If ``snapshot_path`` does not exist.
            ValueError: If the file does not contain a YAML mapping (for
                example an empty file, a list or a scalar).
        """
        path = Path(snapshot_path)
        if not path.exists():
            raise FileNotFoundError(f"Snapshot not found: {snapshot_path}")

        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f)

        # safe_load returns None for an empty document and may return a list
        # or scalar; fail here with a clear message instead of letting
        # callers hit an opaque TypeError when they index the result.
        if not isinstance(data, dict):
            raise ValueError(f"Snapshot file must contain a YAML mapping: {snapshot_path}")

        return data

    def replay(self, snapshot_path: str) -> Any:
        """Re-run the generation described by a snapshot.

        Rebuilds the ``GeneratorConfig`` stored under ``"config"``, then
        calls ``fill_table`` with the recorded table name, count and seed.
        Batch size, ``clear_before`` and column configs come from the
        matching table entry in the stored config; when no entry matches,
        ``5000``, ``False`` and ``None`` are used.

        Returns:
            Whatever ``DataOrchestrator.fill_table`` returns.
        """
        # Imported at call time rather than module level — presumably to
        # avoid a circular import with the orchestrator; confirm before moving.
        from sqlseed.core.orchestrator import DataOrchestrator

        data = self.load(snapshot_path)
        config = GeneratorConfig(**data["config"])

        table_name = data["table_name"]
        count = data["count"]
        seed = data.get("seed")

        # First table entry whose name matches, or None.
        table_config = next((tc for tc in config.tables if tc.name == table_name), None)

        with DataOrchestrator(
            db_path=config.db_path,
            provider_name=config.provider.value,
            locale=config.locale,
            optimize_pragma=config.optimize_pragma,
        ) as orch:
            return orch.fill_table(
                table_name=table_name,
                count=count,
                seed=seed,
                batch_size=table_config.batch_size if table_config else 5000,
                clear_before=table_config.clear_before if table_config else False,
                column_configs=table_config.columns if table_config else None,
            )

    def list_snapshots(self) -> list[str]:
        """Return the sorted paths of all ``*.yaml`` files in the snapshot directory.

        Returns an empty list when the directory does not exist.
        """
        if not self._snapshot_dir.exists():
            return []
        return sorted(str(p) for p in self._snapshot_dir.glob("*.yaml"))
@@ -0,0 +1,14 @@
1
+ from sqlseed.core.mapper import ColumnMapper, GeneratorSpec
2
+ from sqlseed.core.orchestrator import DataOrchestrator
3
+ from sqlseed.core.relation import RelationResolver
4
+ from sqlseed.core.result import GenerationResult
5
+ from sqlseed.core.schema import SchemaInferrer
6
+
7
+ __all__ = [
8
+ "ColumnMapper",
9
+ "DataOrchestrator",
10
+ "GenerationResult",
11
+ "GeneratorSpec",
12
+ "RelationResolver",
13
+ "SchemaInferrer",
14
+ ]
@@ -0,0 +1,108 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any
5
+
6
+ from sqlseed.core.mapper import GeneratorSpec
7
+
8
+
9
@dataclass
class ColumnConstraints:
    """Column-level constraints."""

    unique: bool = False
    # Optional numeric bounds; ``None`` leaves the bound unset.
    min_value: int | float | None = None
    max_value: int | float | None = None
    # Optional regular-expression pattern; ``None`` leaves it unset.
    regex: str | None = None
    max_retries: int = 100
18
+
19
+
20
@dataclass
class ColumnNode:
    """A node in the DAG, representing one column."""

    name: str
    generator_spec: GeneratorSpec
    depends_on: list[str] = field(default_factory=list)  # names of the source columns depended on
    expression: str | None = None  # derivation expression
    constraints: ColumnConstraints | None = None  # constraint conditions
    is_derived: bool = False  # whether this is a derived column

    @property
    def is_skip(self) -> bool:
        """True when the node's generator spec names the ``"skip"`` generator."""
        generator_name = self.generator_spec.generator_name
        return generator_name == "skip"
34
+
35
+
36
class ColumnDAG:
    """Build and manage the column dependency graph."""

    def build(
        self,
        specs: dict[str, GeneratorSpec],
        column_configs: list[Any] | None = None,
    ) -> list[ColumnNode]:
        """Create one node per column and return them in dependency order.

        Args:
            specs: Generator spec for every column, keyed by column name.
            column_configs: Optional per-column configuration objects.
                Entries without a ``name`` attribute are ignored. A config's
                ``constraints`` are copied onto the node, and a truthy
                ``derive_from`` turns the column into a derived one whose
                spec is replaced by the ``"__derive__"`` generator.

        Returns:
            The nodes ordered so that every column follows the columns it
            depends on.

        Raises:
            ValueError: If the dependencies contain a cycle.
        """
        config_map: dict[str, Any] = {
            cc.name: cc for cc in (column_configs or []) if hasattr(cc, "name")
        }

        nodes: dict[str, ColumnNode] = {}
        for col_name, spec in specs.items():
            cc = config_map.get(col_name)
            constraints = None
            expression = None
            depends_on: list[str] = []
            is_derived = False
            final_spec = spec

            if cc:
                cfg_constraints = getattr(cc, "constraints", None)
                if cfg_constraints:
                    # Carry over every configured constraint; bounds and the
                    # regex are read defensively because the config objects
                    # are duck-typed here.
                    constraints = ColumnConstraints(
                        unique=cfg_constraints.unique,
                        min_value=getattr(cfg_constraints, "min_value", None),
                        max_value=getattr(cfg_constraints, "max_value", None),
                        regex=getattr(cfg_constraints, "regex", None),
                        max_retries=cfg_constraints.max_retries,
                    )
                derive_from = getattr(cc, "derive_from", None)
                if derive_from:
                    depends_on = [derive_from]
                    expression = cc.expression
                    is_derived = True
                    final_spec = GeneratorSpec(generator_name="__derive__")

            nodes[col_name] = ColumnNode(
                name=col_name,
                generator_spec=final_spec,
                depends_on=depends_on,
                expression=expression,
                constraints=constraints,
                is_derived=is_derived,
            )

        return self._topological_sort(nodes)

    def _topological_sort(self, nodes: dict[str, ColumnNode]) -> list[ColumnNode]:
        """Topologically sort ``nodes`` using Kahn's algorithm.

        Dependencies naming a column that is not in ``nodes`` are ignored.

        Raises:
            ValueError: If a cycle prevents every node from being emitted.
        """
        from collections import deque

        in_degree: dict[str, int] = dict.fromkeys(nodes, 0)
        adjacency: dict[str, list[str]] = {name: [] for name in nodes}

        for name, node in nodes.items():
            for dep in node.depends_on:
                if dep in adjacency:
                    adjacency[dep].append(name)
                    in_degree[name] += 1

        # A deque gives O(1) pops from the front while keeping FIFO order.
        queue = deque(name for name, deg in in_degree.items() if deg == 0)
        result: list[ColumnNode] = []

        while queue:
            current = queue.popleft()
            result.append(nodes[current])
            for neighbor in adjacency[current]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)

        # Nodes on a cycle never reach in-degree zero, so they are missing.
        if len(result) != len(nodes):
            raise ValueError("Circular dependency detected in column definitions")

        return result
@@ -0,0 +1,116 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any
5
+
6
+
7
@dataclass
class RegisterResult:
    """Outcome of an attempt to register a value with the constraint solver."""

    # Whether the value was accepted.
    registered: bool = True
    # Whether the caller should backtrack, and for which columns.
    need_backtrack: bool = False
    backtrack_targets: list[str] = field(default_factory=list)
12
+
13
+
14
class ConstraintSolver:
    """Constraint solver supporting backtracking and composite unique constraints.

    For large datasets (>100K rows), set probabilistic=True to use
    a hash-based probabilistic set that trades a small false-positive
    rate for significantly reduced memory usage.
    """

    def __init__(
        self,
        *,
        probabilistic: bool = False,
        expected_count: int = 10000,
    ) -> None:
        """Create an empty solver.

        Args:
            probabilistic: Track ``hash(value)`` per column instead of the
                values themselves.
            expected_count: Stored on the instance; not otherwise used by
                the methods of this class.
        """
        self._probabilistic = probabilistic
        self._expected_count = expected_count
        self._seen: dict[str, set[Any]] = {}
        self._composite_seen: dict[str, set[tuple[Any, ...]]] = {}
        # Always present so that every method can touch it unconditionally.
        self._hash_seen: dict[str, set[int]] = {}

    def _is_seen(self, column_name: str, value: Any) -> bool:
        """Report whether ``value`` was already recorded for the column.

        Side effect: an unseen value is recorded before ``False`` is
        returned, so a second call with the same value returns ``True``.
        """
        if self._probabilistic:
            bucket: set[Any] = self._hash_seen.setdefault(column_name, set())
            key: Any = hash(value)
        else:
            bucket = self._seen.setdefault(column_name, set())
            key = value

        if key in bucket:
            return True
        bucket.add(key)
        return False

    def _unregister_value(self, column_name: str, value: Any) -> None:
        """Forget ``value`` (or its hash, in probabilistic mode) for the column."""
        if self._probabilistic:
            if column_name in self._hash_seen:
                self._hash_seen[column_name].discard(hash(value))
        elif column_name in self._seen:
            self._seen[column_name].discard(value)

    def check_and_register(
        self,
        column_name: str,
        value: Any,
        unique: bool = False,
    ) -> bool:
        """Return whether ``value`` is acceptable, recording it when unique.

        Non-unique columns always accept. For unique columns the value is
        accepted (and recorded) only the first time it is offered.
        """
        if not unique:
            return True
        return not self._is_seen(column_name, value)

    def try_register(
        self,
        column_name: str,
        value: Any,
        unique: bool = False,
        source_columns: list[str] | None = None,
    ) -> RegisterResult:
        """Register ``value`` and describe what to do on a uniqueness clash.

        On a clash the result requests backtracking over ``source_columns``
        when given (and non-empty), otherwise over ``column_name`` itself.
        """
        if not unique:
            return RegisterResult(registered=True)

        if self._is_seen(column_name, value):
            return RegisterResult(
                registered=False,
                need_backtrack=True,
                backtrack_targets=source_columns if source_columns else [column_name],
            )
        return RegisterResult(registered=True)

    def check_composite(
        self,
        key_name: str,
        values: tuple[Any, ...],
    ) -> bool:
        """Return ``True`` and record ``values`` if the tuple is new for ``key_name``."""
        seen = self._composite_seen.setdefault(key_name, set())
        if values in seen:
            return False
        seen.add(values)
        return True

    def unregister_composite(
        self,
        key_name: str,
        values: tuple[Any, ...],
    ) -> None:
        """Forget a previously recorded composite tuple, if present."""
        if key_name in self._composite_seen:
            self._composite_seen[key_name].discard(values)

    def reset(self) -> None:
        """Forget every recorded value, hash and composite tuple."""
        self._seen.clear()
        self._composite_seen.clear()
        self._hash_seen.clear()

    def reset_column(self, column_name: str) -> None:
        """Forget everything recorded for a single column.

        Both stores are cleared so the reset also takes effect in
        probabilistic mode, where only hashes are tracked.
        """
        self._seen.pop(column_name, None)
        self._hash_seen.pop(column_name, None)

    def unregister(self, column_name: str, value: Any) -> None:
        """Public wrapper around ``_unregister_value``."""
        self._unregister_value(column_name, value)
@@ -0,0 +1,71 @@
1
+ from __future__ import annotations
2
+
3
+ import threading
4
+ from typing import Any, ClassVar
5
+
6
+ import simpleeval
7
+
8
+ from sqlseed._utils.logger import get_logger
9
+
10
+ logger = get_logger(__name__)
11
+
12
+
13
class ExpressionTimeoutError(TimeoutError):
    """Raised when an expression does not finish evaluating within the time limit."""
15
+
16
+
17
class ExpressionEngine:
    """Safe expression evaluator.

    Expressions are evaluated by a ``simpleeval.SimpleEval`` instance whose
    callable functions are replaced by ``SAFE_FUNCTIONS``; each evaluation
    runs in a worker thread so that it can be abandoned after a timeout.
    """

    # Functions callable from inside an expression, keyed by the name used
    # in the expression.
    SAFE_FUNCTIONS: ClassVar[dict[str, Any]] = {
        "len": len,
        "int": int,
        "str": str,
        "float": float,
        "hex": hex,
        "oct": oct,
        "bin": bin,
        "abs": abs,
        "min": min,
        "max": max,
        "upper": lambda s: s.upper(),
        "lower": lambda s: s.lower(),
        "strip": lambda s: s.strip(),
        "lstrip": lambda s: s.lstrip(),
        "rstrip": lambda s: s.rstrip(),
        "zfill": lambda s, w: str(s).zfill(w),
        "replace": lambda s, old, new: str(s).replace(old, new),
        "substr": lambda s, start, end=None: str(s)[start:end],
        # lpad pads on the left (rjust); rpad pads on the right (ljust).
        "lpad": lambda s, width, char="0": str(s).rjust(width, char),
        "rpad": lambda s, width, char="0": str(s).ljust(width, char),
        "concat": lambda *args: "".join(str(a) for a in args),
    }

    def __init__(self, timeout_seconds: int = 5) -> None:
        """Create an evaluator that gives up after ``timeout_seconds``."""
        self._timeout = timeout_seconds
        self._evaluator = simpleeval.SimpleEval()
        # Copy so that per-instance changes never leak into the class table.
        self._evaluator.functions = dict(self.SAFE_FUNCTIONS)

    def evaluate(self, expression: str, context: dict[str, Any]) -> Any:
        """Evaluate ``expression`` with ``context`` as its variable names.

        Args:
            expression: The expression text to evaluate.
            context: Mapping installed as the evaluator's ``names``.

        Returns:
            The value the expression evaluates to.

        Raises:
            ExpressionTimeoutError: If evaluation is still running after the
                configured timeout. The worker thread is not cancelled; it
                is merely abandoned.
            Exception: Any exception raised during evaluation is re-raised
                in the calling thread.
        """
        self._evaluator.names = context
        result: Any = None
        error: Exception | None = None

        def _eval() -> None:
            nonlocal result, error
            try:
                result = self._evaluator.eval(expression)
            except Exception as e:
                # Captured here and re-raised by the caller's thread below.
                error = e

        # Daemon thread: an abandoned, still-running evaluation must not
        # keep the interpreter alive at shutdown.
        thread = threading.Thread(target=_eval, daemon=True)
        thread.start()
        thread.join(timeout=self._timeout)

        if thread.is_alive():
            raise ExpressionTimeoutError(f"Expression evaluation timed out after {self._timeout}s: {expression[:100]}")

        if error is not None:
            raise error

        return result