sqlseed 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlseed/__init__.py +121 -0
- sqlseed/_utils/__init__.py +11 -0
- sqlseed/_utils/logger.py +30 -0
- sqlseed/_utils/metrics.py +45 -0
- sqlseed/_utils/progress.py +14 -0
- sqlseed/_utils/schema_helpers.py +51 -0
- sqlseed/_utils/sql_safe.py +45 -0
- sqlseed/_version.py +1 -0
- sqlseed/cli/__init__.py +3 -0
- sqlseed/cli/main.py +316 -0
- sqlseed/config/__init__.py +14 -0
- sqlseed/config/loader.py +66 -0
- sqlseed/config/models.py +99 -0
- sqlseed/config/snapshot.py +91 -0
- sqlseed/core/__init__.py +14 -0
- sqlseed/core/column_dag.py +108 -0
- sqlseed/core/constraints.py +116 -0
- sqlseed/core/expression.py +71 -0
- sqlseed/core/mapper.py +257 -0
- sqlseed/core/orchestrator.py +578 -0
- sqlseed/core/relation.py +124 -0
- sqlseed/core/result.py +23 -0
- sqlseed/core/schema.py +97 -0
- sqlseed/core/transform.py +27 -0
- sqlseed/database/__init__.py +14 -0
- sqlseed/database/_protocol.py +72 -0
- sqlseed/database/optimizer.py +96 -0
- sqlseed/database/raw_sqlite_adapter.py +197 -0
- sqlseed/database/sqlite_utils_adapter.py +183 -0
- sqlseed/generators/__init__.py +11 -0
- sqlseed/generators/_protocol.py +73 -0
- sqlseed/generators/base_provider.py +448 -0
- sqlseed/generators/faker_provider.py +157 -0
- sqlseed/generators/mimesis_provider.py +203 -0
- sqlseed/generators/registry.py +86 -0
- sqlseed/generators/stream.py +157 -0
- sqlseed/py.typed +0 -0
- sqlseed-0.1.0.dist-info/METADATA +934 -0
- sqlseed-0.1.0.dist-info/RECORD +42 -0
- sqlseed-0.1.0.dist-info/WHEEL +4 -0
- sqlseed-0.1.0.dist-info/entry_points.txt +6 -0
- sqlseed-0.1.0.dist-info/licenses/LICENSE +17 -0
sqlseed/core/result.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
class GenerationResult:
    """Outcome of generating rows for a single table.

    ``rows_per_second`` is derived as ``count / elapsed`` whenever both values
    are positive; otherwise whatever was passed to the constructor (default
    ``0.0``) is kept unchanged.
    """

    table_name: str
    count: int
    elapsed: float  # NOTE(review): printed with an "s" suffix by __str__, so presumably seconds
    rows_per_second: float = 0.0
    batch_count: int = 0
    errors: list[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Phrased as `not (a and b)` rather than `<= 0` comparisons so that NaN
        # inputs keep the constructor-supplied value, exactly like a positive check.
        if not (self.count > 0 and self.elapsed > 0):
            return
        self.rows_per_second = self.count / self.elapsed

    def __str__(self) -> str:
        elapsed_txt = f"{self.elapsed:.2f}"
        speed_txt = f"{self.rows_per_second:.0f}"
        return (
            f"GenerationResult(table={self.table_name}, count={self.count}, "
            f"elapsed={elapsed_txt}s, speed={speed_txt} rows/s)"
        )
|
sqlseed/core/schema.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from sqlseed.database._protocol import ColumnInfo, ForeignKeyInfo, IndexInfo
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class SchemaInferrer:
    """Read-only schema lookups and lightweight data profiling.

    Every lookup is forwarded to the wrapped ``db_adapter``; sequence results are
    copied into fresh lists so callers may mutate them freely.
    """

    def __init__(self, db_adapter: Any) -> None:
        self._db = db_adapter

    def get_column_info(self, table_name: str) -> list[ColumnInfo]:
        """Return the adapter's column metadata for ``table_name`` as a new list."""
        return [*self._db.get_column_info(table_name)]

    def get_foreign_keys(self, table_name: str) -> list[ForeignKeyInfo]:
        """Return the adapter's foreign-key metadata for ``table_name`` as a new list."""
        return [*self._db.get_foreign_keys(table_name)]

    def get_table_names(self) -> list[str]:
        """Return every table name the adapter reports, as a new list."""
        return [*self._db.get_table_names()]

    def get_primary_keys(self, table_name: str) -> list[str]:
        """Return the primary-key column names of ``table_name`` as a new list."""
        return [*self._db.get_primary_keys(table_name)]

    def get_table_schema(self, table_name: str) -> dict[str, ColumnInfo]:
        """Map each column name of ``table_name`` to its ``ColumnInfo``."""
        by_name: dict[str, ColumnInfo] = {}
        for column in self.get_column_info(table_name):
            by_name[column.name] = column
        return by_name

    def get_index_info(self, table_name: str) -> list[IndexInfo]:
        """Return the adapter's index metadata for ``table_name`` as a new list."""
        return [*self._db.get_index_info(table_name)]

    def get_sample_data(self, table_name: str, limit: int = 5) -> list[dict[str, Any]]:
        """Return up to ``limit`` sample rows as reported by the adapter."""
        return cast("list[dict[str, Any]]", self._db.get_sample_rows(table_name, limit=limit))

    def profile_column_distribution(
        self,
        table_name: str,
        limit: int = 1000,
    ) -> list[dict[str, Any]]:
        """Profile each column of ``table_name`` from a sample of ``limit`` values.

        Columns that are both primary key and autoincrement are skipped. An
        empty table (row count of 0) yields an empty list.
        """
        columns = self.get_column_info(table_name)
        row_count = self._db.get_row_count(table_name)
        if row_count == 0:
            return []
        return [
            self._profile_single_column(table_name, col.name, row_count, limit)
            for col in columns
            if not (col.is_primary_key and col.is_autoincrement)
        ]

    def _profile_single_column(
        self,
        table_name: str,
        column_name: str,
        total_rows: int,
        limit: int,
    ) -> dict[str, Any]:
        """Build a best-effort profile dict for one column.

        Ratios and counts are computed over the sampled values only (at most
        ``limit``), not over ``total_rows``. Any exception is swallowed and
        recorded as ``"error": "failed to profile"`` next to whatever keys were
        already filled in.
        """
        from collections import Counter

        result: dict[str, Any] = {"column": column_name}
        try:
            sample = self._db.get_column_values(table_name, column_name, limit=limit)
            nulls = sum(v is None for v in sample)
            present = [v for v in sample if v is not None]

            result["null_ratio"] = round(nulls / len(sample), 3) if sample else 0.0
            result["distinct_count"] = len(set(present))
            result["sample_size"] = len(sample)
            result["total_rows"] = total_rows

            # Counter of an empty list yields no items, so this is [] when every value is NULL.
            present_total = len(present)
            result["top_values"] = [
                {"value": str(val)[:50], "frequency": round(hits / present_total, 3)}
                for val, hits in Counter(present).most_common(5)
            ]

            numbers = [v for v in present if isinstance(v, (int, float))]
            result["value_range"] = {"min": min(numbers), "max": max(numbers)} if numbers else None
        except Exception:
            result["error"] = "failed to profile"
        return result
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import importlib.util
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Protocol, cast
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class RowTransformFn(Protocol):
    """Callable signature a user transform script's ``transform_row`` must match.

    Takes a row dict and a context dict and returns a row dict.
    NOTE(review): whether the returned dict may be the mutated input row is not
    specified here — confirm against the caller.
    """

    def __call__(self, row: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]: ...
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def load_transform(script_path: str) -> RowTransformFn:
    """Load the ``transform_row`` callable from a user-supplied Python script.

    The script is imported under the module name ``user_transform`` without
    being registered in ``sys.modules``, and its top-level code is executed.

    SECURITY: loading a script runs arbitrary code with this process's
    privileges — only point this at trusted files.

    Args:
        script_path: Path to a Python source file that defines
            ``transform_row(row, ctx)``.

    Returns:
        The script's ``transform_row`` function.

    Raises:
        FileNotFoundError: If ``script_path`` is not an existing regular file.
        ImportError: If no import spec/loader can be created for the path.
        AttributeError: If the script does not define ``transform_row``.
        TypeError: If ``transform_row`` exists but is not callable.
    """
    path = Path(script_path)
    # is_file() rather than exists(): a directory would otherwise pass this
    # check and fail later with a less helpful ImportError.
    if not path.is_file():
        raise FileNotFoundError(f"Transform script not found: {script_path}")

    spec = importlib.util.spec_from_file_location("user_transform", str(path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load transform script: {script_path}")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    fn = getattr(module, "transform_row", None)
    if fn is None:
        raise AttributeError(f"Transform script must define a 'transform_row(row, ctx)' function: {script_path}")
    # Reject e.g. `transform_row = 5` up front instead of failing on the first row.
    if not callable(fn):
        raise TypeError(f"'transform_row' in transform script is not callable: {script_path}")
    return cast("RowTransformFn", fn)
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""Database adapters and schema metadata types for sqlseed."""

from sqlseed.database._protocol import ColumnInfo, DatabaseAdapter, ForeignKeyInfo, IndexInfo
from sqlseed.database.optimizer import PragmaOptimizer, PragmaProfile
from sqlseed.database.raw_sqlite_adapter import RawSQLiteAdapter
from sqlseed.database.sqlite_utils_adapter import SQLiteUtilsAdapter

# IndexInfo is returned by DatabaseAdapter.get_index_info, so it is re-exported
# alongside the other metadata records.
__all__ = [
    "ColumnInfo",
    "DatabaseAdapter",
    "ForeignKeyInfo",
    "IndexInfo",
    "PragmaOptimizer",
    "PragmaProfile",
    "RawSQLiteAdapter",
    "SQLiteUtilsAdapter",
]
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
from collections.abc import Iterator
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass(frozen=True)
class ColumnInfo:
    """Immutable description of one table column.

    RawSQLiteAdapter.get_column_info builds these from ``PRAGMA table_info``.
    """

    name: str
    # Declared column type; RawSQLiteAdapter upper-cases it and uses "TEXT" when none is declared.
    type: str
    # RawSQLiteAdapter reports foreign-key columns as non-nullable even if the schema allows NULL.
    nullable: bool
    # Default value exactly as the adapter reports it (``dflt_value`` for RawSQLiteAdapter).
    default: Any
    is_primary_key: bool
    # RawSQLiteAdapter only sets this for primary-key columns, via detect_autoincrement (defined elsewhere).
    is_autoincrement: bool
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(frozen=True)
class ForeignKeyInfo:
    """Immutable description of one foreign-key column mapping."""

    # Child-side column (``from`` in PRAGMA foreign_key_list).
    column: str
    # Parent table referenced by the constraint.
    ref_table: str
    # Parent-side column (``to``). NOTE(review): SQLite reports NULL for ``to`` when the
    # constraint omits the parent column list, so adapters may pass None here despite the annotation.
    ref_column: str
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass(frozen=True)
class IndexInfo:
    """Immutable description of one index on a table.

    ``columns`` is a ``list`` (unhashable), so it is excluded from the generated
    ``__hash__`` via ``field(hash=False)``. Without that, ``hash()`` on a frozen
    instance raised TypeError. Equality still compares ``columns``, and equal
    instances still hash equally because the hash uses a subset of the compared fields.
    """

    name: str
    table: str
    # Indexed column names. RawSQLiteAdapter builds this from PRAGMA index_info, skipping NULL
    # names (expression index entries).
    columns: list[str] = field(hash=False)
    unique: bool
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@runtime_checkable
class DatabaseAdapter(Protocol):
    """Structural interface a sqlseed database backend must satisfy.

    Decorated with ``runtime_checkable`` so ``isinstance(obj, DatabaseAdapter)``
    works; such checks only verify that the members exist, not their signatures.
    """

    # Open the database at ``db_path``.
    def connect(self, db_path: str) -> None: ...

    # Release the connection.
    def close(self) -> None: ...

    # Names of user tables in the database.
    def get_table_names(self) -> list[str]: ...

    # Per-column metadata for ``table_name``.
    def get_column_info(self, table_name: str) -> list[ColumnInfo]: ...

    # Primary-key column names of ``table_name``.
    def get_primary_keys(self, table_name: str) -> list[str]: ...

    # Foreign-key mappings declared on ``table_name``.
    def get_foreign_keys(self, table_name: str) -> list[ForeignKeyInfo]: ...

    # Number of rows currently in ``table_name``.
    def get_row_count(self, table_name: str) -> int: ...

    # Up to ``limit`` values of one column (RawSQLiteAdapter uses ``LIMIT ?``).
    def get_column_values(self, table_name: str, column_name: str, limit: int = 1000) -> list[Any]: ...

    # Index metadata for ``table_name``.
    def get_index_info(self, table_name: str) -> list[IndexInfo]: ...

    # Up to ``limit`` rows as column-name -> value dicts.
    def get_sample_rows(self, table_name: str, limit: int = 5) -> list[dict[str, Any]]: ...

    # Insert rows drawn from ``data`` in chunks of ``batch_size``; returns the number inserted.
    def batch_insert(
        self,
        table_name: str,
        data: Iterator[dict[str, Any]],
        batch_size: int = 5000,
    ) -> int: ...

    # Delete every row of ``table_name``.
    def clear_table(self, table_name: str) -> None: ...

    # Tune connection settings for a bulk load of about ``expected_rows`` rows; pair with restore_settings().
    def optimize_for_bulk_write(self, expected_rows: int | None = None) -> None: ...

    # Undo optimize_for_bulk_write().
    def restore_settings(self) -> None: ...

    def __enter__(self) -> DatabaseAdapter: ...

    # Annotated with BaseException to match the context-manager protocol (and RawSQLiteAdapter).
    def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: ...
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from sqlseed._utils.logger import get_logger
|
|
8
|
+
|
|
9
|
+
logger = get_logger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class PragmaProfile:
    """Snapshot of SQLite PRAGMA values taken by PragmaOptimizer.preserve().

    Each field holds whatever the fetch callback returned for that PRAGMA, or
    ``None`` when it was not captured. PragmaOptimizer.restore() skips ``None`` values.
    """

    synchronous: Any = None
    journal_mode: Any = None
    cache_size: Any = None
    temp_store: Any = None
    auto_vacuum: Any = None
    page_size: Any = None
    mmap_size: Any = None
    # Not populated by PragmaOptimizer.preserve() and not written back by restore().
    wal_autocheckpoint: Any = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class PragmaOptimizer:
    """Apply bulk-write PRAGMA presets and later restore the previous values.

    The optimizer never touches SQLite directly. It is driven by two callbacks:

    * ``execute_fn(sql)`` runs a ``PRAGMA name = value`` statement.
    * ``fetch_pragma_fn(name)`` returns the current value of PRAGMA ``name``.
    """

    # Pragmas captured by preserve() and written back by restore(), in this order.
    _TRACKED_PRAGMAS: tuple[str, ...] = (
        "synchronous",
        "journal_mode",
        "cache_size",
        "temp_store",
        "auto_vacuum",
        "page_size",
        "mmap_size",
    )
    # Snapshot strings are interpolated into SQL, so only plain tokens are accepted.
    _SAFE_TOKEN = re.compile(r"[a-zA-Z0-9_-]+")

    def __init__(self, execute_fn: Any, fetch_pragma_fn: Any) -> None:
        self._execute = execute_fn
        self._fetch_pragma = fetch_pragma_fn
        self._original: PragmaProfile | None = None

    def preserve(self) -> None:
        """Snapshot the current PRAGMA values so restore() can reinstate them.

        If a snapshot is already held (preserve() called again before
        restore()), it is kept. Re-snapshotting at that point would capture the
        already-optimized values (e.g. ``synchronous=OFF``), and restore()
        could then never return to the originals.
        """
        if self._original is not None:
            logger.debug("PRAGMA config already preserved; keeping existing snapshot")
            return
        self._original = PragmaProfile(**{name: self._fetch_pragma(name) for name in self._TRACKED_PRAGMAS})
        logger.debug("Preserved PRAGMA config", config=self._original)

    def optimize(self, expected_rows: int | None = None) -> None:
        """Apply a preset chosen by ``expected_rows`` (treated as 10000 when None).

        More than 100000 rows selects the aggressive preset, more than 10000 the
        moderate one, and anything else the light one.
        """
        if expected_rows is None:
            expected_rows = 10000

        if expected_rows > 100000:
            self._apply_aggressive()
        elif expected_rows > 10000:
            self._apply_moderate()
        else:
            self._apply_light()

    def _apply_light(self) -> None:
        self._execute("PRAGMA synchronous = NORMAL")
        self._execute("PRAGMA temp_store = MEMORY")
        self._execute("PRAGMA cache_size = -8000")
        logger.debug("Applied LIGHT PRAGMA optimization")

    def _apply_moderate(self) -> None:
        self._execute("PRAGMA synchronous = OFF")
        self._execute("PRAGMA journal_mode = MEMORY")
        self._execute("PRAGMA temp_store = MEMORY")
        self._execute("PRAGMA cache_size = -16000")
        self._execute("PRAGMA mmap_size = 268435456")
        logger.debug("Applied MODERATE PRAGMA optimization")

    def _apply_aggressive(self) -> None:
        # synchronous=OFF plus journal_mode=OFF trades crash safety for speed.
        self._execute("PRAGMA synchronous = OFF")
        self._execute("PRAGMA journal_mode = OFF")
        self._execute("PRAGMA temp_store = MEMORY")
        self._execute("PRAGMA cache_size = -32000")
        self._execute("PRAGMA mmap_size = 536870912")
        self._execute("PRAGMA page_size = 4096")
        logger.debug("Applied AGGRESSIVE PRAGMA optimization")

    @classmethod
    def _is_safe_value(cls, value: Any) -> bool:
        """Return True if ``value`` may be interpolated into a PRAGMA assignment."""
        if value is None:
            return False
        if isinstance(value, (int, float)):
            return True
        # fullmatch: `^...$` would also accept a trailing newline.
        return isinstance(value, str) and cls._SAFE_TOKEN.fullmatch(value) is not None

    def restore(self) -> None:
        """Write back the snapshot taken by preserve(), then forget it.

        Does nothing if there is no snapshot. Values that are None or not plain
        tokens are skipped. If ``execute_fn`` raises, the snapshot is kept, so a
        later preserve() will not overwrite it.
        """
        snapshot = self._original
        if snapshot is None:
            return

        for name in self._TRACKED_PRAGMAS:
            value = getattr(snapshot, name)
            if self._is_safe_value(value):
                self._execute(f"PRAGMA {name} = {value}")

        logger.debug("Restored original PRAGMA config")
        self._original = None
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import sqlite3
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
6
|
+
from typing_extensions import Self
|
|
7
|
+
|
|
8
|
+
from sqlseed._utils.logger import get_logger
|
|
9
|
+
from sqlseed._utils.sql_safe import build_insert_sql, quote_identifier
|
|
10
|
+
from sqlseed.database._protocol import ColumnInfo, ForeignKeyInfo, IndexInfo
|
|
11
|
+
from sqlseed.database.optimizer import PragmaOptimizer
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from collections.abc import Iterator
|
|
15
|
+
|
|
16
|
+
logger = get_logger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class RawSQLiteAdapter:
    """``DatabaseAdapter`` implemented directly on the standard-library ``sqlite3`` module.

    Table, column and index names pass through ``quote_identifier`` before being
    interpolated into SQL. Data values are always bound with ``?`` placeholders.
    """

    def __init__(self) -> None:
        self._conn: sqlite3.Connection | None = None
        self._optimizer: PragmaOptimizer | None = None
        self._db_path: str = ""

    @property
    def conn(self) -> sqlite3.Connection:
        """Return the open connection.

        Raises:
            AssertionError: If connect() has not been called, or close() was called.
        """
        # NOTE(review): `assert` is stripped under `python -O`; this would then return
        # None and callers would fail later with AttributeError.
        assert self._conn is not None, "Database not connected. Call connect() first."
        return self._conn

    def connect(self, db_path: str) -> None:
        """Open ``db_path`` with foreign-key enforcement turned on.

        If the adapter is already connected, the previous connection is closed
        first instead of being leaked.
        """
        if self._conn is not None:
            self.close()
        self._db_path = db_path
        self._conn = sqlite3.connect(db_path)
        self._conn.execute("PRAGMA foreign_keys = ON")
        self._optimizer = PragmaOptimizer(
            execute_fn=self._execute_pragma,
            fetch_pragma_fn=self._fetch_pragma,
        )
        logger.debug("Connected to database via raw sqlite3", db_path=db_path)

    def close(self) -> None:
        """Close the connection if open; calling it again is a no-op."""
        if self._conn is not None:
            self._conn.close()
            self._conn = None
            logger.debug("Closed raw sqlite3 connection", db_path=self._db_path)

    def get_table_names(self) -> list[str]:
        """Return user table names, excluding SQLite's reserved ``sqlite_*`` tables."""
        # `_` is a LIKE wildcard, so it is escaped. Otherwise user tables such as
        # "sqlites" or "sqlite1x" would also be filtered out.
        cursor = self.conn.execute(
            r"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite\_%' ESCAPE '\'"
        )
        return [row[0] for row in cursor.fetchall()]

    def get_column_info(self, table_name: str) -> list[ColumnInfo]:
        """Describe each column of ``table_name`` using ``PRAGMA table_info``."""
        fk_columns = {fk.column for fk in self.get_foreign_keys(table_name)}

        cursor = self.conn.execute(f"PRAGMA table_info({quote_identifier(table_name)})")
        result: list[ColumnInfo] = []
        for _cid, name, col_type, notnull, default_val, pk_position in cursor.fetchall():
            # `pk` is non-zero exactly for primary-key columns, so the same row answers
            # that question without a second PRAGMA round-trip.
            is_pk = bool(pk_position)
            result.append(
                ColumnInfo(
                    name=name,
                    type=col_type.upper() if col_type else "TEXT",
                    # NOTE(review): FK columns are always reported non-nullable,
                    # presumably so generated rows always reference a parent.
                    nullable=not notnull and name not in fk_columns,
                    default=default_val,
                    is_primary_key=is_pk,
                    is_autoincrement=is_pk and self._is_autoincrement(table_name, name),
                )
            )
        return result

    def get_primary_keys(self, table_name: str) -> list[str]:
        """Return primary-key column names in PRIMARY KEY declaration order."""
        cursor = self.conn.execute(f"PRAGMA table_info({quote_identifier(table_name)})")
        # `pk` is 0 for non-key columns, otherwise the 1-based position within the
        # PRIMARY KEY. Sorting on it orders composite keys by declaration, not column order.
        keyed = [(pk, name) for _cid, name, _type, _notnull, _dflt, pk in cursor.fetchall() if pk]
        return [name for _pk, name in sorted(keyed)]

    def get_foreign_keys(self, table_name: str) -> list[ForeignKeyInfo]:
        """Return one ``ForeignKeyInfo`` per column mapping of each foreign key.

        When a constraint omits the parent column list (``REFERENCES parent``),
        SQLite reports ``to`` as NULL and the parent's PRIMARY KEY is implied.
        That column is resolved here. ``ref_column`` stays None only if the
        parent has no matching primary-key column.
        """
        cursor = self.conn.execute(f"PRAGMA foreign_key_list({quote_identifier(table_name)})")
        parent_keys: dict[str, list[str]] = {}
        result: list[ForeignKeyInfo] = []
        for _fk_id, seq, ref_table, from_col, to_col, *_ in cursor.fetchall():
            if to_col is None:
                if ref_table not in parent_keys:
                    parent_keys[ref_table] = self.get_primary_keys(ref_table)
                implied = parent_keys[ref_table]
                # `seq` is the 0-based position of this column within the (possibly composite) key.
                if seq < len(implied):
                    to_col = implied[seq]
            result.append(
                ForeignKeyInfo(
                    column=from_col,
                    ref_table=ref_table,
                    ref_column=to_col,
                )
            )
        return result

    def get_row_count(self, table_name: str) -> int:
        """Return ``SELECT COUNT(*)`` for ``table_name``."""
        safe_table = quote_identifier(table_name)
        cursor = self.conn.execute(f"SELECT COUNT(*) FROM {safe_table}")
        return int(cursor.fetchone()[0])

    def get_column_values(self, table_name: str, column_name: str, limit: int = 1000) -> list[Any]:
        """Return up to ``limit`` values of one column, in whatever order SQLite yields them."""
        safe_table = quote_identifier(table_name)
        safe_column = quote_identifier(column_name)
        cursor = self.conn.execute(
            f"SELECT {safe_column} FROM {safe_table} LIMIT ?",
            [limit],
        )
        return [row[0] for row in cursor.fetchall()]

    def get_index_info(self, table_name: str) -> list[IndexInfo]:
        """Return explicitly created indexes on ``table_name``.

        Automatic ``sqlite_autoindex_*`` indexes are skipped. NULL column names
        from ``PRAGMA index_info`` are dropped.
        """
        safe_table = quote_identifier(table_name)
        cursor = self.conn.execute(f"PRAGMA index_list({safe_table})")
        result: list[IndexInfo] = []
        for row in cursor.fetchall():
            idx_name = row[1]
            is_unique = bool(row[2])
            if idx_name.startswith("sqlite_autoindex_"):
                continue
            col_cursor = self.conn.execute(f"PRAGMA index_info({quote_identifier(idx_name)})")
            columns = [cr[2] for cr in col_cursor.fetchall() if cr[2] is not None]
            result.append(IndexInfo(name=idx_name, table=table_name, columns=columns, unique=is_unique))
        return result

    def get_sample_rows(self, table_name: str, limit: int = 5) -> list[dict[str, Any]]:
        """Return up to ``limit`` rows as column-name -> value dicts."""
        safe_table = quote_identifier(table_name)
        columns = self.get_column_info(table_name)
        col_names = [quote_identifier(c.name) for c in columns]
        cols_sql = ", ".join(col_names)
        cursor = self.conn.execute(f"SELECT {cols_sql} FROM {safe_table} LIMIT ?", [limit])
        col_name_list = [c.name for c in columns]
        return [dict(zip(col_name_list, row, strict=False)) for row in cursor.fetchall()]

    def batch_insert(
        self,
        table_name: str,
        data: Iterator[dict[str, Any]],
        batch_size: int = 5000,
    ) -> int:
        """Insert rows from ``data`` in batches of ``batch_size``, committing each batch.

        Returns:
            The number of rows inserted.
        """
        inserted = 0
        batch: list[dict[str, Any]] = []
        for row in data:
            batch.append(row)
            if len(batch) >= batch_size:
                inserted += self._insert_batch(table_name, batch)
                batch = []
        if batch:
            inserted += self._insert_batch(table_name, batch)
        return inserted

    def _insert_batch(self, table_name: str, batch: list[dict[str, Any]]) -> int:
        """Insert one batch with ``executemany`` and commit it.

        Column order comes from the first row, and every row must contain those
        keys. If the insert fails, the implicit transaction is rolled back.
        Otherwise the rows written before the failing one would stay pending
        and be persisted by the next ``commit()`` (e.g. in restore_settings).
        """
        if not batch:
            return 0
        column_names = list(batch[0].keys())
        sql = build_insert_sql(table_name, column_names)
        values = [tuple(row[col] for col in column_names) for row in batch]
        try:
            self.conn.executemany(sql, values)
        except Exception:
            self.conn.rollback()
            raise
        self.conn.commit()
        return len(batch)

    def clear_table(self, table_name: str) -> None:
        """Delete every row of ``table_name`` and commit."""
        safe_table = quote_identifier(table_name)
        self.conn.execute(f"DELETE FROM {safe_table}")
        self.conn.commit()
        logger.debug("Cleared table", table_name=table_name)

    def optimize_for_bulk_write(self, expected_rows: int | None = None) -> None:
        """Snapshot current PRAGMAs, then apply a bulk-write preset sized by ``expected_rows``."""
        if self._optimizer is not None:
            self._optimizer.preserve()
            self._optimizer.optimize(expected_rows)

    def restore_settings(self) -> None:
        """Restore PRAGMAs saved by optimize_for_bulk_write(), then commit."""
        if self._optimizer is not None:
            self._optimizer.restore()
        self.conn.commit()

    def _is_autoincrement(self, table_name: str, column_name: str) -> bool:
        from sqlseed._utils.schema_helpers import detect_autoincrement

        return detect_autoincrement(self.conn.execute, table_name, column_name)

    def _execute_pragma(self, sql: str) -> None:
        self.conn.execute(sql)

    def _fetch_pragma(self, name: str) -> Any:
        # `name` is interpolated unquoted; it only ever comes from PragmaOptimizer's fixed list.
        cursor = self.conn.execute(f"PRAGMA {name}")
        row = cursor.fetchone()
        return row[0] if row else None

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        self.close()
|