sqlseed 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlseed/__init__.py +121 -0
- sqlseed/_utils/__init__.py +11 -0
- sqlseed/_utils/logger.py +30 -0
- sqlseed/_utils/metrics.py +45 -0
- sqlseed/_utils/progress.py +14 -0
- sqlseed/_utils/schema_helpers.py +51 -0
- sqlseed/_utils/sql_safe.py +45 -0
- sqlseed/_version.py +1 -0
- sqlseed/cli/__init__.py +3 -0
- sqlseed/cli/main.py +316 -0
- sqlseed/config/__init__.py +14 -0
- sqlseed/config/loader.py +66 -0
- sqlseed/config/models.py +99 -0
- sqlseed/config/snapshot.py +91 -0
- sqlseed/core/__init__.py +14 -0
- sqlseed/core/column_dag.py +108 -0
- sqlseed/core/constraints.py +116 -0
- sqlseed/core/expression.py +71 -0
- sqlseed/core/mapper.py +257 -0
- sqlseed/core/orchestrator.py +578 -0
- sqlseed/core/relation.py +124 -0
- sqlseed/core/result.py +23 -0
- sqlseed/core/schema.py +97 -0
- sqlseed/core/transform.py +27 -0
- sqlseed/database/__init__.py +14 -0
- sqlseed/database/_protocol.py +72 -0
- sqlseed/database/optimizer.py +96 -0
- sqlseed/database/raw_sqlite_adapter.py +197 -0
- sqlseed/database/sqlite_utils_adapter.py +183 -0
- sqlseed/generators/__init__.py +11 -0
- sqlseed/generators/_protocol.py +73 -0
- sqlseed/generators/base_provider.py +448 -0
- sqlseed/generators/faker_provider.py +157 -0
- sqlseed/generators/mimesis_provider.py +203 -0
- sqlseed/generators/registry.py +86 -0
- sqlseed/generators/stream.py +157 -0
- sqlseed/py.typed +0 -0
- sqlseed-0.1.0.dist-info/METADATA +934 -0
- sqlseed-0.1.0.dist-info/RECORD +42 -0
- sqlseed-0.1.0.dist-info/WHEEL +4 -0
- sqlseed-0.1.0.dist-info/entry_points.txt +6 -0
- sqlseed-0.1.0.dist-info/licenses/LICENSE +17 -0
sqlseed/config/loader.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import yaml
|
|
7
|
+
|
|
8
|
+
from sqlseed._utils.logger import get_logger
|
|
9
|
+
from sqlseed.config.models import GeneratorConfig, TableConfig
|
|
10
|
+
|
|
11
|
+
logger = get_logger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def load_config(path: str | Path) -> GeneratorConfig:
    """Load and validate a generator configuration file.

    The format is chosen from the file extension: ``.yaml``/``.yml`` files are parsed
    with ``yaml.safe_load`` and ``.json`` files with ``json.load``. The top-level value
    must be a mapping, which is then validated into a :class:`GeneratorConfig`.

    Args:
        path: Path to the configuration file (``str`` or ``Path``).

    Returns:
        The validated configuration.

    Raises:
        FileNotFoundError: If *path* does not exist.
        ValueError: If the extension is not supported or the document is not a mapping.
        pydantic.ValidationError: If the mapping does not satisfy the configuration schema.
    """
    config_path = Path(path)
    if not config_path.exists():
        raise FileNotFoundError(f"Configuration file not found: {path}")

    suffix = config_path.suffix.lower()
    # Reject unknown formats before doing any file I/O.
    if suffix not in (".yaml", ".yml", ".json"):
        raise ValueError(f"Unsupported configuration file format: {suffix}")

    with open(config_path, encoding="utf-8") as f:
        # safe_load only builds plain Python objects (no arbitrary object construction).
        raw = json.load(f) if suffix == ".json" else yaml.safe_load(f)

    # An empty YAML file yields None; lists/scalars are not valid configurations either.
    if not isinstance(raw, dict):
        raise ValueError("Configuration file must contain a YAML/JSON object")

    # model_validate() validates the mapping as a whole. Unpacking with ** would raise a bare
    # TypeError ("keywords must be strings") for YAML documents with non-string keys.
    return GeneratorConfig.model_validate(raw)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def save_config(config: GeneratorConfig, path: str | Path) -> None:
    """Serialize *config* to a YAML or JSON file chosen by the extension of *path*.

    Missing parent directories are created. The format is validated before anything
    touches the filesystem. The document is fully serialized in memory before the
    target is opened.

    Args:
        config: Configuration to write (dumped with ``model_dump(mode="json")``).
        path: Destination file; ``.yaml``/``.yml`` write YAML, ``.json`` writes JSON.

    Raises:
        ValueError: If the extension is not supported. In that case no directory is
            created and no file is opened.
    """
    config_path = Path(path)
    suffix = config_path.suffix.lower()
    # Validate first: opening in "w" mode before this check would truncate (or create)
    # the target file even though the call is about to fail.
    if suffix not in (".yaml", ".yml", ".json"):
        raise ValueError(f"Unsupported configuration file format: {suffix}")

    data = config.model_dump(mode="json")
    # Serialize in memory so a serialization error cannot leave a half-written file behind.
    if suffix == ".json":
        text = json.dumps(data, indent=2, ensure_ascii=False)
    else:
        text = yaml.dump(data, default_flow_style=False, allow_unicode=True, sort_keys=False)

    config_path.parent.mkdir(parents=True, exist_ok=True)
    with open(config_path, "w", encoding="utf-8") as f:
        f.write(text)

    logger.info("Configuration saved", path=path)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def generate_template(db_path: str, table_name: str | None = None) -> GeneratorConfig:
    """Create a starter configuration for the database at *db_path*.

    Args:
        db_path: Database path stored in the returned configuration.
        table_name: Optional table to include. When given (and non-empty), the template
            holds one entry for it with ``count=1000`` and no column overrides.

    Returns:
        A configuration whose ``tables`` list is empty or holds that single entry.
    """
    if not table_name:
        return GeneratorConfig(db_path=db_path, tables=[])

    starter_table = TableConfig(name=table_name, count=1000, columns=[])
    return GeneratorConfig(db_path=db_path, tables=[starter_table])
|
sqlseed/config/models.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, Field, field_validator, model_validator
|
|
7
|
+
from typing_extensions import Self
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ProviderType(str, Enum):
    """Backends available for generating column data.

    Because of the ``str`` mixin, each member compares equal to its lowercase string
    value. Configuration files can therefore use the plain names (``"faker"``, ...).
    """

    BASE = "base"
    FAKER = "faker"
    MIMESIS = "mimesis"
    CUSTOM = "custom"
    AI = "ai"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ColumnConstraintsConfig(BaseModel):
    """Column constraint configuration.

    Attributes:
        unique: Require every generated value of the column to be distinct.
        min_value: Optional lower bound for generated values.
        max_value: Optional upper bound for generated values.
        regex: Optional pattern constraint on generated values.
        max_retries: Attempts allowed to satisfy the constraints; must be > 0.
    """

    unique: bool = False
    min_value: int | float | None = None
    max_value: int | float | None = None
    regex: str | None = None
    max_retries: int = Field(default=100, gt=0)

    @model_validator(mode="after")
    def validate_value_range(self) -> Self:
        """Reject bounds where ``min_value`` exceeds ``max_value`` (an empty range)."""
        if (
            self.min_value is not None
            and self.max_value is not None
            and self.min_value > self.max_value
        ):
            raise ValueError(
                f"min_value ({self.min_value}) must not be greater than max_value ({self.max_value})"
            )
        return self
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ColumnConfig(BaseModel):
    """Configuration of one column, in one of two modes.

    * Source column: set ``generator`` (optionally with ``provider`` / ``params``).
    * Derived column: set ``derive_from`` together with ``expression``.

    The two modes cannot be combined.
    """

    name: str

    # --- source-column mode ---
    generator: str | None = None
    provider: ProviderType | None = None
    params: dict[str, Any] = Field(default_factory=dict)
    # Constrained to [0.0, 1.0]; presumably the share of NULL values — confirm with the orchestrator.
    null_ratio: float = Field(default=0.0, ge=0.0, le=1.0)

    # --- derived-column mode ---
    derive_from: str | None = None  # name of the source column
    expression: str | None = None  # derivation expression

    # --- constraints ---
    constraints: ColumnConstraintsConfig | None = None

    @field_validator("null_ratio")
    @classmethod
    def validate_null_ratio(cls, value: float) -> float:
        """Reject a null_ratio outside [0.0, 1.0] (the Field bounds enforce this as well)."""
        within_bounds = 0.0 <= value <= 1.0
        if not within_bounds:
            raise ValueError("null_ratio must be between 0.0 and 1.0")
        return value

    @model_validator(mode="after")
    def validate_column_mode(self) -> Self:
        """Ensure a derived column has an expression and no generator."""
        if not self.derive_from:
            return self
        if self.generator:
            raise ValueError(f"Column '{self.name}': cannot use both 'generator' and 'derive_from'")
        if not self.expression:
            raise ValueError(f"Column '{self.name}': 'derive_from' requires 'expression'")
        return self
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class TableConfig(BaseModel):
    """Generation settings for a single table."""

    name: str
    count: int = Field(1000, gt=0)  # must be > 0
    batch_size: int = Field(5000, gt=0)  # must be > 0
    columns: list[ColumnConfig] = Field(default_factory=list)  # per-column overrides
    clear_before: bool = False  # off by default to protect existing data
    seed: int | None = None
    transform: str | None = None  # [NEW] path to a Python transform script
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class ColumnAssociation(BaseModel):
    """Cross-table column association.

    Declares an implicit relation: a column of the same name is referenced across tables.
    """

    column_name: str  # column shared between the tables
    source_table: str  # table the values originate from
    target_tables: list[str] = Field(default_factory=list)  # tables reusing those values
    strategy: str = "shared_pool"  # association strategy name
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class GeneratorConfig(BaseModel):
    """Global generation configuration (the top-level object of a config file)."""

    db_path: str
    # Default data backend for all tables.
    provider: ProviderType = ProviderType.MIMESIS
    locale: str = "en_US"
    tables: list[TableConfig] = Field(default_factory=list)
    associations: list[ColumnAssociation] = Field(default_factory=list)
    optimize_pragma: bool = True
    log_level: str = "INFO"
    snapshot_dir: str | None = None
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import yaml
|
|
8
|
+
|
|
9
|
+
from sqlseed._utils.logger import get_logger
|
|
10
|
+
from sqlseed.config.models import GeneratorConfig
|
|
11
|
+
|
|
12
|
+
logger = get_logger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SnapshotManager:
    """Save, load, list and replay generation snapshots stored as YAML files.

    A snapshot records one table run — table name, row count and seed — together with
    the full serialized :class:`GeneratorConfig`. :meth:`replay` can then rebuild the
    orchestrator and call ``fill_table`` with the same arguments.
    """

    def __init__(self, snapshot_dir: str | None = None) -> None:
        """Create a manager for *snapshot_dir* (default ``./snapshots``, relative to the CWD).

        The directory is not created until the first :meth:`save`.
        """
        self._snapshot_dir = Path(snapshot_dir) if snapshot_dir else Path("./snapshots")

    @staticmethod
    def _safe_name(name: str) -> str:
        """Return *name* with every character that is not alphanumeric, ``-``, ``_`` or ``.`` replaced by ``_``.

        Table names come from the database and may contain path separators or other
        characters that must not end up verbatim in a file name inside the snapshot directory.
        """
        return "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in name)

    def save(
        self,
        config: GeneratorConfig,
        table_name: str,
        count: int,
        seed: int | None = None,
    ) -> str:
        """Write a snapshot file and return its path.

        The file is named ``<YYYY-MM-DD_HHMMSS>_<table>.yaml``. If that name is already
        taken (e.g. two saves of the same table within one second), ``_1``, ``_2``, ...
        is appended rather than overwriting the earlier snapshot.

        Args:
            config: Full configuration used for the run.
            table_name: Table the run filled (stored verbatim; sanitized for the file name).
            count: Row count of the run.
            seed: Seed of the run, if any.

        Returns:
            Path of the written snapshot file as a string.
        """
        self._snapshot_dir.mkdir(parents=True, exist_ok=True)

        timestamp = datetime.now().strftime("%Y-%m-%d_%H%M%S")
        stem = f"{timestamp}_{self._safe_name(table_name)}"

        snapshot_data = {
            "timestamp": timestamp,
            "table_name": table_name,
            "count": count,
            "seed": seed,
            "config": config.model_dump(mode="json"),
        }
        text = yaml.dump(snapshot_data, default_flow_style=False, allow_unicode=True, sort_keys=False)

        # Exclusive creation ("x") never clobbers an existing snapshot; on a name clash
        # retry with the next numeric suffix.
        attempt = 0
        while True:
            filename = f"{stem}.yaml" if attempt == 0 else f"{stem}_{attempt}.yaml"
            filepath = self._snapshot_dir / filename
            try:
                handle = open(filepath, "x", encoding="utf-8")
            except FileExistsError:
                attempt += 1
                continue
            with handle:
                handle.write(text)
            break

        logger.info("Snapshot saved", filepath=str(filepath))
        return str(filepath)

    def load(self, snapshot_path: str) -> dict[str, Any]:
        """Read a snapshot file and return its top-level mapping.

        Raises:
            FileNotFoundError: If *snapshot_path* does not exist.
            ValueError: If the file does not contain a YAML mapping (e.g. it is empty).
        """
        path = Path(snapshot_path)
        if not path.exists():
            raise FileNotFoundError(f"Snapshot not found: {snapshot_path}")

        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f)

        # Fail here with a clear message instead of a TypeError/KeyError later in replay().
        if not isinstance(data, dict):
            raise ValueError(f"Snapshot file must contain a YAML mapping: {snapshot_path}")
        return data

    def replay(self, snapshot_path: str) -> Any:
        """Re-run the table fill recorded in a snapshot.

        ``batch_size``, ``clear_before`` and the column configs come from the snapshot
        config's entry for the recorded table. Without such an entry, the defaults are
        ``5000``, ``False`` and ``None``.

        Returns:
            Whatever ``DataOrchestrator.fill_table`` returns.

        Raises:
            KeyError: If the snapshot lacks ``config``, ``table_name`` or ``count``.
        """
        # Local import, presumably to avoid an import cycle with the core package.
        from sqlseed.core.orchestrator import DataOrchestrator

        data = self.load(snapshot_path)
        config = GeneratorConfig(**data["config"])

        table_name = data["table_name"]
        count = data["count"]
        seed = data.get("seed")

        table_config = next((tc for tc in config.tables if tc.name == table_name), None)

        with DataOrchestrator(
            db_path=config.db_path,
            provider_name=config.provider.value,
            locale=config.locale,
            optimize_pragma=config.optimize_pragma,
        ) as orch:
            return orch.fill_table(
                table_name=table_name,
                count=count,
                seed=seed,
                batch_size=table_config.batch_size if table_config else 5000,
                clear_before=table_config.clear_before if table_config else False,
                column_configs=table_config.columns if table_config else None,
            )

    def list_snapshots(self) -> list[str]:
        """Return the paths of all ``*.yaml`` files in the snapshot directory, sorted.

        Returns an empty list when the directory does not exist.
        """
        if not self._snapshot_dir.exists():
            return []
        return sorted(str(p) for p in self._snapshot_dir.glob("*.yaml"))
|
sqlseed/core/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from sqlseed.core.mapper import ColumnMapper, GeneratorSpec
|
|
2
|
+
from sqlseed.core.orchestrator import DataOrchestrator
|
|
3
|
+
from sqlseed.core.relation import RelationResolver
|
|
4
|
+
from sqlseed.core.result import GenerationResult
|
|
5
|
+
from sqlseed.core.schema import SchemaInferrer
|
|
6
|
+
|
|
7
|
+
# Public API re-exported by ``sqlseed.core`` (alphabetical order).
__all__ = [
    "ColumnMapper",
    "DataOrchestrator",
    "GenerationResult",
    "GeneratorSpec",
    "RelationResolver",
    "SchemaInferrer",
]
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from sqlseed.core.mapper import GeneratorSpec
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class ColumnConstraints:
    """Column-level constraints attached to a column node.

    Runtime counterpart of the column constraints configuration.
    """

    unique: bool = False  # values of the column must be distinct
    min_value: int | float | None = None  # optional lower bound
    max_value: int | float | None = None  # optional upper bound
    regex: str | None = None  # optional pattern constraint
    max_retries: int = 100  # attempts allowed to satisfy the constraints
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class ColumnNode:
    """One table column, represented as a node of the column dependency graph."""

    name: str
    generator_spec: GeneratorSpec
    # Names of the source columns this column is derived from (empty for source columns).
    depends_on: list[str] = field(default_factory=list)
    # Derivation expression; only set for derived columns.
    expression: str | None = None
    # Constraints to enforce on this column's values, if any.
    constraints: ColumnConstraints | None = None
    # True when the value is computed from other columns instead of generated.
    is_derived: bool = False

    @property
    def is_skip(self) -> bool:
        """True if the column uses the special ``"skip"`` generator."""
        spec_name = self.generator_spec.generator_name
        return spec_name == "skip"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ColumnDAG:
    """Build a table's column dependency graph and order it for generation.

    Columns whose config sets ``derive_from`` become derived nodes (``__derive__``
    spec) that depend on their source column. All other columns keep the spec they
    were mapped to. :meth:`build` returns the nodes so that every source column
    precedes the columns derived from it.
    """

    def build(
        self,
        specs: dict[str, GeneratorSpec],
        column_configs: list[Any] | None = None,
    ) -> list[ColumnNode]:
        """Create one node per entry of *specs* and return them in dependency order.

        Args:
            specs: Generator spec per column name; insertion order breaks ties.
            column_configs: Optional per-column config objects. Entries without a
                ``name`` attribute are ignored; later entries win on duplicate names.
                Configs for columns absent from *specs* have no effect.

        Returns:
            Nodes ordered so that each derived column follows its source column.

        Raises:
            ValueError: If the ``derive_from`` relations form a cycle.
        """
        config_map: dict[str, Any] = {
            cc.name: cc for cc in (column_configs or []) if hasattr(cc, "name")
        }

        nodes: dict[str, ColumnNode] = {}
        for col_name, spec in specs.items():
            cc = config_map.get(col_name)
            constraints = None
            expression = None
            depends_on: list[str] = []
            is_derived = False
            final_spec = spec

            if cc:
                raw_constraints = getattr(cc, "constraints", None)
                if raw_constraints:
                    # Carry over every configured constraint; the bounds and the pattern
                    # used to be dropped here, silently discarding user configuration.
                    constraints = ColumnConstraints(
                        unique=raw_constraints.unique,
                        min_value=getattr(raw_constraints, "min_value", None),
                        max_value=getattr(raw_constraints, "max_value", None),
                        regex=getattr(raw_constraints, "regex", None),
                        max_retries=raw_constraints.max_retries,
                    )
                derive_from = getattr(cc, "derive_from", None)
                if derive_from:
                    depends_on = [derive_from]
                    expression = cc.expression
                    is_derived = True
                    final_spec = GeneratorSpec(generator_name="__derive__")

            nodes[col_name] = ColumnNode(
                name=col_name,
                generator_spec=final_spec,
                depends_on=depends_on,
                expression=expression,
                constraints=constraints,
                is_derived=is_derived,
            )

        return self._topological_sort(nodes)

    def _topological_sort(self, nodes: dict[str, ColumnNode]) -> list[ColumnNode]:
        """Order *nodes* with Kahn's algorithm (FIFO, so ties keep insertion order).

        Dependencies naming a column that is not in *nodes* are ignored.
        NOTE(review): a typo in ``derive_from`` therefore goes unnoticed here; confirm
        whether callers validate it.

        Raises:
            ValueError: If a dependency cycle prevents ordering all nodes.
        """
        from collections import deque

        in_degree: dict[str, int] = {name: 0 for name in nodes}
        dependents: dict[str, list[str]] = {name: [] for name in nodes}

        for name, node in nodes.items():
            for dep in node.depends_on:
                if dep in dependents:
                    dependents[dep].append(name)
                    in_degree[name] += 1

        # deque gives O(1) pops from the left (list.pop(0) is O(n)).
        ready = deque(name for name, deg in in_degree.items() if deg == 0)
        ordered: list[ColumnNode] = []

        while ready:
            current = ready.popleft()
            ordered.append(nodes[current])
            for child in dependents[current]:
                in_degree[child] -= 1
                if in_degree[child] == 0:
                    ready.append(child)

        if len(ordered) != len(nodes):
            raise ValueError("Circular dependency detected in column definitions")

        return ordered
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
class RegisterResult:
    """Outcome of an attempt to register a value with the constraint solver."""

    # True when the value was accepted.
    registered: bool = True
    # True when the value collided with an earlier one and must be regenerated.
    need_backtrack: bool = False
    # Columns to regenerate on a collision (the source columns, or the column itself).
    backtrack_targets: list[str] = field(default_factory=list)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ConstraintSolver:
    """Constraint solver for unique and composite-unique constraints, with backtracking.

    Values are registered per column. A caller that rejects a row can release a value
    again with :meth:`unregister` (or :meth:`unregister_composite`) so it may be reused.

    For large datasets (>100K rows), set probabilistic=True to store ``hash(value)``
    instead of the value itself. Distinct values with equal hashes are then reported as
    duplicates: a small false-positive rate is traded for reduced memory usage.
    """

    def __init__(
        self,
        *,
        probabilistic: bool = False,
        expected_count: int = 10000,
    ) -> None:
        self._probabilistic = probabilistic
        # Stored only; not read anywhere in this class.
        self._expected_count = expected_count
        self._seen: dict[str, set[Any]] = {}
        self._composite_seen: dict[str, set[tuple[Any, ...]]] = {}
        # Holds hash(value) ints in probabilistic mode. It is always present (empty otherwise),
        # so reset() and reset_column() can clear it unconditionally.
        self._hash_seen: dict[str, set[Any]] = {}

    def _store(self) -> dict[str, set[Any]]:
        """Return the per-column registry used in the current mode."""
        return self._hash_seen if self._probabilistic else self._seen

    def _key(self, value: Any) -> Any:
        """Return the entry recorded for *value*: its hash in probabilistic mode, else itself."""
        return hash(value) if self._probabilistic else value

    def _is_seen(self, column_name: str, value: Any) -> bool:
        """Return True if *value* is already registered for *column_name*.

        Otherwise register it and return False.
        """
        key = self._key(value)
        registry = self._store().setdefault(column_name, set())
        if key in registry:
            return True
        registry.add(key)
        return False

    def _unregister_value(self, column_name: str, value: Any) -> None:
        """Forget *value* for *column_name*; unknown columns or values are ignored.

        In probabilistic mode this discards the hash, which also frees any other value
        that shares it.
        """
        registry = self._store().get(column_name)
        if registry is not None:
            registry.discard(self._key(value))

    def check_and_register(
        self,
        column_name: str,
        value: Any,
        unique: bool = False,
    ) -> bool:
        """Return whether *value* may be used for *column_name*.

        Without ``unique`` this is always True and nothing is recorded. With ``unique``
        the value is registered on success.
        """
        if unique:
            return not self._is_seen(column_name, value)
        return True

    def try_register(
        self,
        column_name: str,
        value: Any,
        unique: bool = False,
        source_columns: list[str] | None = None,
    ) -> RegisterResult:
        """Like :meth:`check_and_register`, but report which columns to backtrack on a clash.

        Args:
            column_name: Column the value belongs to.
            value: Candidate value.
            unique: Whether the column is unique; otherwise the value is always accepted.
            source_columns: Columns to regenerate on a collision. Defaults to *column_name*
                itself.
        """
        if not unique:
            return RegisterResult(registered=True)

        if self._is_seen(column_name, value):
            return RegisterResult(
                registered=False,
                need_backtrack=True,
                # Copy so that mutating the result cannot alter the caller's list.
                backtrack_targets=list(source_columns) if source_columns else [column_name],
            )
        return RegisterResult(registered=True)

    def check_composite(
        self,
        key_name: str,
        values: tuple[Any, ...],
    ) -> bool:
        """Register the value tuple *values* under *key_name*.

        Returns False if the tuple was already present.
        """
        registry = self._composite_seen.setdefault(key_name, set())
        if values in registry:
            return False
        registry.add(values)
        return True

    def unregister_composite(
        self,
        key_name: str,
        values: tuple[Any, ...],
    ) -> None:
        """Release a previously registered composite tuple (no-op if unknown)."""
        if key_name in self._composite_seen:
            self._composite_seen[key_name].discard(values)

    def reset(self) -> None:
        """Forget every registered value and composite tuple."""
        self._seen.clear()
        self._composite_seen.clear()
        self._hash_seen.clear()

    def reset_column(self, column_name: str) -> None:
        """Forget all values registered for *column_name* in either mode.

        Composite keys are not affected.
        """
        self._seen.pop(column_name, None)
        # Previously only _seen was cleared, making this a no-op in probabilistic mode.
        self._hash_seen.pop(column_name, None)

    def unregister(self, column_name: str, value: Any) -> None:
        """Release a single registered *value* of *column_name* (used when backtracking)."""
        self._unregister_value(column_name, value)
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import threading
|
|
4
|
+
from typing import Any, ClassVar
|
|
5
|
+
|
|
6
|
+
import simpleeval
|
|
7
|
+
|
|
8
|
+
from sqlseed._utils.logger import get_logger
|
|
9
|
+
|
|
10
|
+
logger = get_logger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ExpressionTimeoutError(TimeoutError):
    """Raised when an expression does not finish evaluating within the engine's timeout."""
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ExpressionEngine:
    """Safe expression evaluator built on ``simpleeval`` with a per-call time limit.

    Only the whitelisted helpers in ``SAFE_FUNCTIONS`` are callable from expressions.
    The *context* mapping passed to :meth:`evaluate` supplies the variable names.

    Note: :meth:`evaluate` assigns ``names`` on a shared evaluator, so a single engine
    instance should not be used from several threads at once.
    """

    SAFE_FUNCTIONS: ClassVar[dict[str, Any]] = {
        "len": len,
        "int": int,
        "str": str,
        "float": float,
        "hex": hex,
        "oct": oct,
        "bin": bin,
        "abs": abs,
        "min": min,
        "max": max,
        "upper": lambda s: s.upper(),
        "lower": lambda s: s.lower(),
        "strip": lambda s: s.strip(),
        "lstrip": lambda s: s.lstrip(),
        "rstrip": lambda s: s.rstrip(),
        "zfill": lambda s, w: str(s).zfill(w),
        "replace": lambda s, old, new: str(s).replace(old, new),
        "substr": lambda s, start, end=None: str(s)[start:end],
        "lpad": lambda s, width, char="0": str(s).rjust(width, char),
        "rpad": lambda s, width, char="0": str(s).ljust(width, char),
        "concat": lambda *args: "".join(str(a) for a in args),
    }

    def __init__(self, timeout_seconds: float | None = 5) -> None:
        """Args:
            timeout_seconds: Maximum wall-clock seconds per :meth:`evaluate` call.
                ``None`` disables the limit and evaluates inline, without a worker thread.
        """
        self._timeout = timeout_seconds
        self._evaluator = self._make_evaluator()

    def _make_evaluator(self) -> simpleeval.SimpleEval:
        """Return a fresh evaluator restricted to a copy of ``SAFE_FUNCTIONS``."""
        evaluator = simpleeval.SimpleEval()
        evaluator.functions = dict(self.SAFE_FUNCTIONS)
        return evaluator

    def evaluate(self, expression: str, context: dict[str, Any]) -> Any:
        """Evaluate *expression* with *context* as its variable names and return the result.

        Raises:
            ExpressionTimeoutError: If evaluation exceeds the configured timeout.
            Exception: Any error raised by the evaluation itself is re-raised unchanged.
        """
        evaluator = self._evaluator
        evaluator.names = context

        if self._timeout is None:
            # No limit: a thread joined without a timeout would just add overhead.
            return evaluator.eval(expression)

        result: Any = None
        error: Exception | None = None

        def _eval() -> None:
            nonlocal result, error
            try:
                result = evaluator.eval(expression)
            except Exception as e:
                error = e

        # daemon=True: Python threads cannot be killed. A timed-out evaluation keeps running in
        # the background, and a non-daemon thread would keep the interpreter from exiting.
        thread = threading.Thread(target=_eval, name="sqlseed-expression", daemon=True)
        thread.start()
        thread.join(timeout=self._timeout)

        if thread.is_alive():
            # The orphaned thread still holds `evaluator`; give later calls their own instance so
            # their contexts are not read by the runaway evaluation.
            self._evaluator = self._make_evaluator()
            raise ExpressionTimeoutError(f"Expression evaluation timed out after {self._timeout}s: {expression[:100]}")

        if error is not None:
            raise error

        return result
|