anysite-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of anysite-cli might be problematic. Click here for more details.
- anysite/__init__.py +4 -0
- anysite/__main__.py +6 -0
- anysite/api/__init__.py +21 -0
- anysite/api/client.py +271 -0
- anysite/api/errors.py +137 -0
- anysite/api/schemas.py +333 -0
- anysite/batch/__init__.py +1 -0
- anysite/batch/executor.py +176 -0
- anysite/batch/input.py +160 -0
- anysite/batch/rate_limiter.py +98 -0
- anysite/cli/__init__.py +1 -0
- anysite/cli/config.py +176 -0
- anysite/cli/executor.py +388 -0
- anysite/cli/options.py +249 -0
- anysite/config/__init__.py +11 -0
- anysite/config/paths.py +46 -0
- anysite/config/settings.py +187 -0
- anysite/dataset/__init__.py +37 -0
- anysite/dataset/analyzer.py +268 -0
- anysite/dataset/cli.py +644 -0
- anysite/dataset/collector.py +686 -0
- anysite/dataset/db_loader.py +248 -0
- anysite/dataset/errors.py +30 -0
- anysite/dataset/exporters.py +121 -0
- anysite/dataset/history.py +153 -0
- anysite/dataset/models.py +245 -0
- anysite/dataset/notifications.py +87 -0
- anysite/dataset/scheduler.py +107 -0
- anysite/dataset/storage.py +171 -0
- anysite/dataset/transformer.py +213 -0
- anysite/db/__init__.py +38 -0
- anysite/db/adapters/__init__.py +1 -0
- anysite/db/adapters/base.py +158 -0
- anysite/db/adapters/postgres.py +201 -0
- anysite/db/adapters/sqlite.py +183 -0
- anysite/db/cli.py +687 -0
- anysite/db/config.py +92 -0
- anysite/db/manager.py +166 -0
- anysite/db/operations/__init__.py +1 -0
- anysite/db/operations/insert.py +199 -0
- anysite/db/operations/query.py +43 -0
- anysite/db/schema/__init__.py +1 -0
- anysite/db/schema/inference.py +213 -0
- anysite/db/schema/types.py +71 -0
- anysite/db/utils/__init__.py +1 -0
- anysite/db/utils/sanitize.py +99 -0
- anysite/main.py +498 -0
- anysite/models/__init__.py +1 -0
- anysite/output/__init__.py +11 -0
- anysite/output/console.py +45 -0
- anysite/output/formatters.py +301 -0
- anysite/output/templates.py +76 -0
- anysite/py.typed +0 -0
- anysite/streaming/__init__.py +1 -0
- anysite/streaming/progress.py +121 -0
- anysite/streaming/writer.py +130 -0
- anysite/utils/__init__.py +1 -0
- anysite/utils/fields.py +242 -0
- anysite/utils/retry.py +109 -0
- anysite_cli-0.1.0.dist-info/METADATA +437 -0
- anysite_cli-0.1.0.dist-info/RECORD +64 -0
- anysite_cli-0.1.0.dist-info/WHEEL +4 -0
- anysite_cli-0.1.0.dist-info/entry_points.txt +2 -0
- anysite_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
"""Record transformer — filter, field selection, and column injection.
|
|
2
|
+
|
|
3
|
+
Applies per-source transforms to collected records before Parquet storage.
|
|
4
|
+
The filter parser is intentionally safe: no ``eval()``, only tokenize → parse.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import re
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from anysite.dataset.models import TransformConfig
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class FilterParseError(Exception):
    """Signal that a filter expression is malformed and cannot be compiled into a predicate."""
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class RecordTransformer:
    """Apply transform pipeline: filter → select fields → add columns.

    The pipeline never mutates the records passed to :meth:`apply`; when
    columns are added, new dictionaries are returned instead.
    """

    def __init__(self, config: TransformConfig) -> None:
        self.config = config
        # Compiled once up front so a malformed filter fails at construction time.
        self._filter_fn = _parse_filter(config.filter) if config.filter else None

    def apply(self, records: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Run the configured transforms over *records*.

        Args:
            records: Records to transform. They are not modified.

        Returns:
            The transformed records. With no transforms configured, this is
            *records* itself.
        """
        result = records

        # 1. Filter
        if self._filter_fn:
            result = [r for r in result if self._filter_fn(r)]

        # 2. Select fields (produces new dicts)
        if self.config.fields:
            result = [_select_fields(r, self.config.fields) for r in result]

        # 3. Add static columns. Build merged copies rather than calling
        # r.update(): without a field selection, `r` would still be the
        # caller's own dict and would be mutated in place.
        if self.config.add_columns:
            extra = self.config.add_columns
            result = [{**r, **extra} for r in result]

        return result
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# ---------------------------------------------------------------------------
|
|
46
|
+
# Safe filter parser
|
|
47
|
+
# ---------------------------------------------------------------------------
|
|
48
|
+
|
|
49
|
+
# Tokenizer for filter expressions. Each match consumes optional surrounding
# whitespace plus exactly one token. The named group that matched identifies the
# token kind: field path (leading dot), quoted string (no escape sequences, so a
# string cannot contain its own quote character), integer/decimal number,
# comparison operator, lowercase `and`/`or`, or a null literal.
# NOTE(review): `logic` and `null` have no word boundary. Input such as
# `nullable` therefore tokenizes as `null` and then fails on `able`. It is still
# rejected, but with a less obvious error message.
_TOKEN_RE = re.compile(
    r"""
    \s*(?:
        (?P<field>\.[a-zA-Z_][a-zA-Z0-9_.]*) |  # .field.path
        (?P<string>"[^"]*"|'[^']*') |  # quoted string
        (?P<number>-?\d+(?:\.\d+)?) |  # number
        (?P<op>==|!=|>=|<=|>|<) |  # comparison
        (?P<logic>and|or) |  # logical
        (?P<null>null|none|None)  # null literal
    )\s*
    """,
    re.VERBOSE,
)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _tokenize(expr: str) -> list[tuple[str, str]]:
    """Split *expr* into ``(kind, text)`` tokens using ``_TOKEN_RE``.

    ``kind`` is the name of the regex group that matched: ``field``,
    ``string``, ``number``, ``op``, ``logic`` or ``null``.

    Raises:
        FilterParseError: if no token can be matched at some position.
    """
    out: list[tuple[str, str]] = []
    cursor = 0
    end = len(expr)
    while cursor < end:
        match = _TOKEN_RE.match(expr, cursor)
        if match is None:
            raise FilterParseError(f"Unexpected character at position {cursor}: {expr[cursor:]!r}")
        # The named groups are mutually exclusive alternatives, so exactly one
        # participated in the match; lastgroup names it.
        kind = match.lastgroup
        out.append((kind, match.group(kind)))
        cursor = match.end()
    return out
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _parse_filter(expr: str) -> Any:
    """Parse a filter expression into a callable predicate.

    Supported syntax:
        .field > 10
        .name != ""
        .status == "active" and .count > 0
        .field != null

    The expression is a sequence of ``.path OP value`` comparisons joined by
    ``and`` / ``or``. ``and`` binds tighter than ``or``, so
    ``A or B and C`` means ``A or (B and C)``.

    Args:
        expr: The filter expression.

    Returns:
        A predicate ``record -> bool``, or ``None`` if *expr* is empty or
        whitespace-only.

    Raises:
        FilterParseError: if the expression is malformed, including a trailing
            ``and``/``or`` with no comparison after it.
    """
    if not expr or not expr.strip():
        return None

    tokens = _tokenize(expr)
    if not tokens:
        raise FilterParseError(f"Empty filter expression: {expr!r}")

    comparisons, connectors = _parse_comparisons(tokens)

    # Build an OR of AND-groups. Every "and" extends the current group; every
    # "or" starts a new one. This gives "and" higher precedence than "or".
    groups: list[list[tuple[str, str, Any]]] = [[comparisons[0]]]
    for conn, comparison in zip(connectors, comparisons[1:]):
        if conn == "and":
            groups[-1].append(comparison)
        else:  # or
            groups.append([comparison])

    def predicate(record: dict[str, Any]) -> bool:
        return any(
            all(_eval_comparison(record, f, o, v) for f, o, v in group)
            for group in groups
        )

    return predicate


def _parse_comparisons(
    tokens: list[tuple[str, str]],
) -> tuple[list[tuple[str, str, Any]], list[str]]:
    """Parse ``field op value (logic field op value)*`` into comparisons and connectors.

    Guarantees ``len(connectors) == len(comparisons) - 1`` and at least one comparison.

    Raises:
        FilterParseError: on any token out of place, or a dangling connector.
    """
    comparisons: list[tuple[str, str, Any]] = []  # (field, op, value)
    connectors: list[str] = []  # 'and' | 'or'
    n = len(tokens)
    i = 0

    while True:
        # Expect: field op value. Reaching the end here means the previous
        # token was a connector with nothing after it.
        if i >= n or tokens[i][0] != "field":
            got = tokens[i] if i < n else "end"
            raise FilterParseError(f"Expected field, got {got}")
        field_path = tokens[i][1][1:]  # strip leading dot
        i += 1

        if i >= n or tokens[i][0] != "op":
            raise FilterParseError(f"Expected operator after .{field_path}")
        op = tokens[i][1]
        i += 1

        if i >= n:
            raise FilterParseError(f"Expected value after .{field_path} {op}")

        tok_type, tok_val = tokens[i]
        if tok_type == "string":
            value: Any = tok_val[1:-1]  # strip quotes
        elif tok_type == "number":
            value = float(tok_val) if "." in tok_val else int(tok_val)
        elif tok_type == "null":
            value = None
        else:
            raise FilterParseError(f"Expected value, got {tokens[i]}")
        i += 1

        comparisons.append((field_path, op, value))

        if i >= n:
            return comparisons, connectors

        if tokens[i][0] != "logic":
            raise FilterParseError(f"Expected 'and'/'or', got {tokens[i]}")
        connectors.append(tokens[i][1])
        i += 1


def _eval_comparison(record: dict[str, Any], field: str, op: str, val: Any) -> bool:
    """Evaluate one ``.field op val`` comparison against *record*.

    A ``null`` literal supports only ``==``/``!=`` (identity with ``None``).
    A missing or ``None`` field fails every non-null comparison.
    Incomparable types (``TypeError``) evaluate to ``False``.
    """
    actual = _get_dot_value(record, field)
    if val is None:
        if op == "==":
            return actual is None
        if op == "!=":
            return actual is not None
        return False
    if actual is None:
        return False
    try:
        if op == "==":
            return actual == val
        if op == "!=":
            return actual != val
        if op == ">":
            return actual > val
        if op == "<":
            return actual < val
        if op == ">=":
            return actual >= val
        if op == "<=":
            return actual <= val
    except TypeError:
        return False
    return False
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def _get_dot_value(record: dict[str, Any], path: str) -> Any:
|
|
184
|
+
"""Get a nested value using dot notation."""
|
|
185
|
+
current: Any = record
|
|
186
|
+
for part in path.split("."):
|
|
187
|
+
if isinstance(current, dict):
|
|
188
|
+
current = current.get(part)
|
|
189
|
+
else:
|
|
190
|
+
return None
|
|
191
|
+
return current
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _select_fields(record: dict[str, Any], fields: list[str]) -> dict[str, Any]:
    """Project *record* onto *fields*, resolving dotted paths.

    Each entry is either ``"path"`` or ``"path AS alias"`` / ``"path as alias"``.
    Only the all-upper or all-lower keyword, delimited by spaces, is recognized,
    and ``" AS "`` is checked first. Without an alias, the output key is the
    path with dots replaced by underscores.
    """
    selected: dict[str, Any] = {}
    for spec in fields:
        for keyword in (" AS ", " as "):
            if keyword in spec:
                source, _, target = spec.partition(keyword)
                source = source.strip()
                target = target.strip()
                break
        else:
            source = spec
            target = spec.replace(".", "_")
        selected[target] = _get_dot_value(record, source)
    return selected
|
anysite/db/__init__.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""Database integration subsystem for storing API data in SQL databases."""
|
|
2
|
+
|
|
3
|
+
from typing import NoReturn
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def check_db_deps(db_type: str | None = None) -> None:
|
|
7
|
+
"""Check that optional database dependencies are installed.
|
|
8
|
+
|
|
9
|
+
Args:
|
|
10
|
+
db_type: Specific database type to check ('postgres', 'mysql').
|
|
11
|
+
If None, only checks that the db module itself is usable.
|
|
12
|
+
|
|
13
|
+
Raises:
|
|
14
|
+
SystemExit: If required packages are not installed.
|
|
15
|
+
"""
|
|
16
|
+
if db_type == "postgres":
|
|
17
|
+
try:
|
|
18
|
+
import psycopg # noqa: F401
|
|
19
|
+
except ImportError:
|
|
20
|
+
_missing_deps_error(["psycopg"], extra="postgres")
|
|
21
|
+
|
|
22
|
+
elif db_type == "mysql":
|
|
23
|
+
try:
|
|
24
|
+
import pymysql # noqa: F401
|
|
25
|
+
except ImportError:
|
|
26
|
+
_missing_deps_error(["pymysql"], extra="mysql")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _missing_deps_error(missing: list[str], extra: str = "db") -> NoReturn:
    """Print an install hint for the *missing* packages to stderr, then exit with status 1.

    Args:
        missing: Names of the packages that could not be imported.
        extra: Name of the ``anysite-cli`` extra suggested in the install hint.

    Raises:
        typer.Exit: Always raised, with exit code 1.
    """
    import typer

    message = (
        f"Error: Missing required packages: {', '.join(missing)}\n"
        f"Install with: pip install anysite-cli[{extra}]"
    )
    typer.echo(message, err=True)
    raise typer.Exit(1)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Database adapters."""
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
"""Abstract base class for database adapters."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from collections.abc import Generator
|
|
7
|
+
from contextlib import contextmanager
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from anysite.db.config import OnConflict
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class DatabaseAdapter(ABC):
    """Abstract base class for all database adapters.

    Adapters are synchronous. Use as a context manager for
    automatic connect/disconnect:

        with SQLiteAdapter(config) as db:
            db.execute("CREATE TABLE ...")

    Subclasses implement the abstract methods below, including the
    ``_begin/_commit/_rollback_transaction`` hooks used by :meth:`transaction`.
    """

    @abstractmethod
    def connect(self) -> None:
        """Open a connection to the database."""

    @abstractmethod
    def disconnect(self) -> None:
        """Close the database connection."""

    @abstractmethod
    def execute(self, sql: str, params: tuple[Any, ...] | None = None) -> None:
        """Execute a SQL statement.

        Args:
            sql: SQL statement with parameter placeholders.
            params: Parameter values for the statement.
        """

    @abstractmethod
    def fetch_one(self, sql: str, params: tuple[Any, ...] | None = None) -> dict[str, Any] | None:
        """Execute a query and return the first row.

        Args:
            sql: SQL query with parameter placeholders.
            params: Parameter values for the query.

        Returns:
            First row as a dictionary, or None if no results.
        """

    @abstractmethod
    def fetch_all(self, sql: str, params: tuple[Any, ...] | None = None) -> list[dict[str, Any]]:
        """Execute a query and return all rows.

        Args:
            sql: SQL query with parameter placeholders.
            params: Parameter values for the query.

        Returns:
            List of rows as dictionaries.
        """

    @abstractmethod
    def insert_batch(
        self,
        table: str,
        rows: list[dict[str, Any]],
        on_conflict: OnConflict = OnConflict.ERROR,
        conflict_columns: list[str] | None = None,
    ) -> int:
        """Insert multiple rows into a table.

        Args:
            table: Table name.
            rows: List of row dictionaries.
            on_conflict: Conflict resolution strategy.
            conflict_columns: Columns that define uniqueness for upsert.

        Returns:
            Number of rows inserted/affected.
        """

    @abstractmethod
    def table_exists(self, table: str) -> bool:
        """Check if a table exists.

        Args:
            table: Table name.

        Returns:
            True if the table exists.
        """

    @abstractmethod
    def get_table_schema(self, table: str) -> list[dict[str, str]]:
        """Get the schema of a table.

        Args:
            table: Table name.

        Returns:
            List of column info dicts with 'name', 'type', 'nullable', 'primary_key' keys.
        """

    @abstractmethod
    def create_table(self, table: str, columns: dict[str, str], primary_key: str | None = None) -> None:
        """Create a table.

        Args:
            table: Table name.
            columns: Mapping of column name to SQL type.
            primary_key: Optional column name to use as primary key.
        """

    @abstractmethod
    def get_server_info(self) -> dict[str, str]:
        """Get database server information.

        Returns:
            Dictionary with server info (version, type, etc.).
        """

    @contextmanager
    def transaction(self) -> Generator[None, None, None]:
        """Context manager for transactions.

        Commits when the block exits normally. Rolls back and re-raises if the
        block, or the commit itself, raises anything, including
        ``KeyboardInterrupt``/``SystemExit``.

        Usage:
            with adapter.transaction():
                adapter.execute("INSERT ...")
                adapter.execute("UPDATE ...")
        """
        self._begin_transaction()
        try:
            yield
            self._commit_transaction()
        except BaseException:
            # BaseException rather than Exception: an interrupt inside the block
            # must still roll back. Otherwise the transaction is left open and
            # any state changed by _begin_transaction is never restored (e.g. the
            # Postgres adapter's autocommit flag).
            self._rollback_transaction()
            raise

    @abstractmethod
    def _begin_transaction(self) -> None:
        """Begin a transaction."""

    @abstractmethod
    def _commit_transaction(self) -> None:
        """Commit the current transaction."""

    @abstractmethod
    def _rollback_transaction(self) -> None:
        """Roll back the current transaction."""

    def __enter__(self) -> DatabaseAdapter:
        self.connect()
        return self

    def __exit__(self, exc_type: type | None, exc_val: Exception | None, exc_tb: Any) -> None:
        # Always disconnect; returning None lets any exception propagate.
        self.disconnect()
|
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
"""PostgreSQL database adapter using psycopg v3."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from anysite.db.adapters.base import DatabaseAdapter
|
|
9
|
+
from anysite.db.config import ConnectionConfig, OnConflict
|
|
10
|
+
from anysite.db.utils.sanitize import sanitize_identifier, sanitize_table_name
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class PostgresAdapter(DatabaseAdapter):
    """PostgreSQL adapter using psycopg v3 (sync mode).

    The connection is opened with ``autocommit = True`` and returns rows as
    dictionaries (``dict_row``). :meth:`insert_batch` wraps each batch in its
    own ``conn.transaction()`` block. The base class's :meth:`transaction`
    toggles autocommit off and on around a multi-statement transaction.
    """

    def __init__(self, config: ConnectionConfig) -> None:
        self.config = config
        self._conn: Any = None  # psycopg.Connection once connected

    def connect(self) -> None:
        """Open the connection; a no-op if already connected."""
        if self._conn is not None:
            return

        # Imported here rather than at module level, so psycopg is only
        # required once a connection is actually opened.
        import psycopg
        from psycopg.rows import dict_row

        url = self.config.get_url()
        if url:
            self._conn = psycopg.connect(url, row_factory=dict_row)
        else:
            password = self.config.get_password()
            connect_kwargs: dict[str, Any] = {
                "host": self.config.host,
                "dbname": self.config.database,
                "row_factory": dict_row,
            }
            if self.config.user:
                connect_kwargs["user"] = self.config.user
            if password:
                connect_kwargs["password"] = password
            if self.config.port:
                connect_kwargs["port"] = self.config.port

            self._conn = psycopg.connect(**connect_kwargs)

        # Set autocommit for non-transactional operations
        self._conn.autocommit = True

    def disconnect(self) -> None:
        """Close the connection if open."""
        if self._conn is not None:
            self._conn.close()
            self._conn = None

    @property
    def conn(self) -> Any:
        """The live connection.

        Raises:
            RuntimeError: if not connected.
        """
        if self._conn is None:
            raise RuntimeError("Not connected. Call connect() first or use as context manager.")
        return self._conn

    def execute(self, sql: str, params: tuple[Any, ...] | None = None) -> None:
        self.conn.execute(sql, params)

    def fetch_one(self, sql: str, params: tuple[Any, ...] | None = None) -> dict[str, Any] | None:
        cursor = self.conn.execute(sql, params)
        return cursor.fetchone()

    def fetch_all(self, sql: str, params: tuple[Any, ...] | None = None) -> list[dict[str, Any]]:
        cursor = self.conn.execute(sql, params)
        return cursor.fetchall()

    def insert_batch(
        self,
        table: str,
        rows: list[dict[str, Any]],
        on_conflict: OnConflict = OnConflict.ERROR,
        conflict_columns: list[str] | None = None,
    ) -> int:
        """Insert *rows* into *table* inside a single transaction.

        The column list is the union of keys across all rows, in first-seen
        order; a row missing a column inserts NULL for it. ``dict``/``list``
        values are serialized to JSON text.

        Args:
            table: Target table name (sanitized before use).
            rows: Row dictionaries.
            on_conflict: Conflict strategy. ``IGNORE`` emits ``ON CONFLICT DO
                NOTHING``, with ``conflict_columns`` as the target if given.
                ``UPDATE``/``REPLACE`` emit an upsert on ``conflict_columns``;
                without them a plain INSERT is issued.
            conflict_columns: Columns that define uniqueness.

        Returns:
            Number of rows submitted. Rows skipped by ``DO NOTHING`` are still
            counted.
        """
        if not rows:
            return 0

        # Union of keys in first-seen order (dict preserves insertion order).
        all_columns: list[str] = list(dict.fromkeys(col for row in rows for col in row))
        sql = self._build_insert_sql(table, all_columns, on_conflict, conflict_columns)

        # Positional parameters, ordered to match all_columns.
        params = [
            tuple(self._prepare_value(row.get(col)) for col in all_columns)
            for row in rows
        ]

        # Use executemany for batch insert; the cursor is closed when done.
        with self.conn.transaction(), self.conn.cursor() as cursor:
            cursor.executemany(sql, params)
        return len(params)

    @staticmethod
    def _build_insert_sql(
        table: str,
        columns: list[str],
        on_conflict: OnConflict,
        conflict_columns: list[str] | None,
    ) -> str:
        """Build the INSERT (optionally ON CONFLICT) statement for *columns*."""
        safe_table = sanitize_table_name(table)
        safe_columns = [sanitize_identifier(col) for col in columns]
        col_list = ", ".join(safe_columns)
        # Positional %s placeholders. Column names come from record keys and
        # must never be spliced into the SQL text: a key containing ")s" would
        # close a %(name)s placeholder early and inject raw SQL.
        placeholders = ", ".join(["%s"] * len(columns))
        base = f"INSERT INTO {safe_table} ({col_list}) VALUES ({placeholders})"

        if on_conflict == OnConflict.IGNORE:
            if conflict_columns:
                target = ", ".join(sanitize_identifier(c) for c in conflict_columns)
                return f"{base} ON CONFLICT ({target}) DO NOTHING"
            # With no conflict target, DO NOTHING covers any unique violation.
            return f"{base} ON CONFLICT DO NOTHING"

        if on_conflict in (OnConflict.UPDATE, OnConflict.REPLACE) and conflict_columns:
            safe_conflict = [sanitize_identifier(c) for c in conflict_columns]
            target = ", ".join(safe_conflict)
            update_cols = [c for c in safe_columns if c not in safe_conflict]
            if not update_cols:
                # Every inserted column is part of the key: nothing to update,
                # and an empty SET list would be a syntax error.
                return f"{base} ON CONFLICT ({target}) DO NOTHING"
            update_clause = ", ".join(f"{c} = EXCLUDED.{c}" for c in update_cols)
            return f"{base} ON CONFLICT ({target}) DO UPDATE SET {update_clause}"

        return base

    @staticmethod
    def _prepare_value(value: Any) -> Any:
        """Serialize dict/list values to JSON text; pass everything else through."""
        if isinstance(value, (dict, list)):
            return json.dumps(value)
        return value

    def table_exists(self, table: str) -> bool:
        """Return True if *table* exists in the ``public`` schema."""
        row = self.fetch_one(
            "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = %s AND table_schema = 'public')",
            (table,),
        )
        return bool(row and row.get("exists"))

    def get_table_schema(self, table: str) -> list[dict[str, str]]:
        """Return one entry per column of ``public.<table>``, in ordinal order.

        ``nullable`` and ``primary_key`` are the strings ``'YES'``/``'NO'``.
        """
        # The PK flag uses a correlated EXISTS rather than LEFT JOINs. Joining
        # key_column_usage directly would yield one row per constraint a column
        # takes part in (PK + FK + UNIQUE), duplicating columns.
        rows = self.fetch_all(
            """
            SELECT c.column_name, c.data_type, c.is_nullable,
                   CASE WHEN EXISTS (
                       SELECT 1
                       FROM information_schema.table_constraints tc
                       JOIN information_schema.key_column_usage kcu
                         ON kcu.constraint_schema = tc.constraint_schema
                        AND kcu.constraint_name = tc.constraint_name
                        AND kcu.table_schema = tc.table_schema
                        AND kcu.table_name = tc.table_name
                       WHERE tc.constraint_type = 'PRIMARY KEY'
                         AND tc.table_schema = c.table_schema
                         AND tc.table_name = c.table_name
                         AND kcu.column_name = c.column_name
                   ) THEN 'YES' ELSE 'NO' END AS primary_key
            FROM information_schema.columns c
            WHERE c.table_name = %s AND c.table_schema = 'public'
            ORDER BY c.ordinal_position
            """,
            (table,),
        )
        return [
            {
                "name": r["column_name"],
                "type": r["data_type"],
                "nullable": r["is_nullable"],
                "primary_key": r["primary_key"],
            }
            for r in rows
        ]

    def create_table(
        self,
        table: str,
        columns: dict[str, str],
        primary_key: str | None = None,
    ) -> None:
        """Create *table* if it does not already exist.

        NOTE(review): column types are interpolated verbatim. This presumably
        relies on callers passing types from schema inference; confirm they
        are never user-controlled.
        """
        safe_table = sanitize_table_name(table)
        col_defs: list[str] = []
        for col_name, col_type in columns.items():
            safe_col = sanitize_identifier(col_name)
            pk_suffix = " PRIMARY KEY" if col_name == primary_key else ""
            col_defs.append(f"{safe_col} {col_type}{pk_suffix}")

        cols_sql = ", ".join(col_defs)
        sql = f"CREATE TABLE IF NOT EXISTS {safe_table} ({cols_sql})"
        self.execute(sql)

    def get_server_info(self) -> dict[str, str]:
        """Return the server version string plus configured host and database."""
        row = self.fetch_one("SELECT version()")
        version = row["version"] if row else "unknown"
        return {
            "type": "postgres",
            "version": version,
            "host": self.config.host or "unknown",
            "database": self.config.database or "unknown",
        }

    def _begin_transaction(self) -> None:
        self.conn.autocommit = False

    def _commit_transaction(self) -> None:
        self.conn.commit()
        self.conn.autocommit = True

    def _rollback_transaction(self) -> None:
        self.conn.rollback()
        self.conn.autocommit = True
|