sql-athame 0.4.0a5__tar.gz → 0.4.0a6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sql_athame-0.4.0a5 → sql_athame-0.4.0a6}/PKG-INFO +1 -1
- sql_athame-0.4.0a6/pyproject.toml +103 -0
- sql_athame-0.4.0a6/sql_athame/__init__.py +2 -0
- {sql_athame-0.4.0a5 → sql_athame-0.4.0a6}/sql_athame/base.py +32 -38
- {sql_athame-0.4.0a5 → sql_athame-0.4.0a6}/sql_athame/dataclasses.py +115 -120
- {sql_athame-0.4.0a5 → sql_athame-0.4.0a6}/sql_athame/escape.py +5 -4
- {sql_athame-0.4.0a5 → sql_athame-0.4.0a6}/sql_athame/sqlalchemy.py +4 -4
- {sql_athame-0.4.0a5 → sql_athame-0.4.0a6}/sql_athame/types.py +4 -2
- sql_athame-0.4.0a5/pyproject.toml +0 -35
- sql_athame-0.4.0a5/sql_athame/__init__.py +0 -2
- {sql_athame-0.4.0a5 → sql_athame-0.4.0a6}/LICENSE +0 -0
- {sql_athame-0.4.0a5 → sql_athame-0.4.0a6}/README.md +0 -0
- {sql_athame-0.4.0a5 → sql_athame-0.4.0a6}/sql_athame/py.typed +0 -0
@@ -0,0 +1,103 @@
|
|
1
|
+
[tool.poetry]
|
2
|
+
name = "sql-athame"
|
3
|
+
version = "0.4.0-alpha-6"
|
4
|
+
description = "Python tool for slicing and dicing SQL"
|
5
|
+
authors = ["Brian Downing <bdowning@lavos.net>"]
|
6
|
+
license = "MIT"
|
7
|
+
readme = "README.md"
|
8
|
+
homepage = "https://github.com/bdowning/sql-athame"
|
9
|
+
repository = "https://github.com/bdowning/sql-athame"
|
10
|
+
|
11
|
+
[tool.poetry.extras]
|
12
|
+
asyncpg = ["asyncpg"]
|
13
|
+
|
14
|
+
[tool.poetry.dependencies]
|
15
|
+
python = "^3.9"
|
16
|
+
asyncpg = { version = "*", optional = true }
|
17
|
+
typing-extensions = "*"
|
18
|
+
|
19
|
+
[tool.poetry.group.dev.dependencies]
|
20
|
+
pytest = "*"
|
21
|
+
mypy = "*"
|
22
|
+
flake8 = "*"
|
23
|
+
ipython = "*"
|
24
|
+
pytest-cov = "*"
|
25
|
+
bump2version = "*"
|
26
|
+
asyncpg = "*"
|
27
|
+
pytest-asyncio = "*"
|
28
|
+
grip = "*"
|
29
|
+
SQLAlchemy = "*"
|
30
|
+
ruff = "*"
|
31
|
+
|
32
|
+
[build-system]
|
33
|
+
requires = ["poetry>=0.12"]
|
34
|
+
build-backend = "poetry.masonry.api"
|
35
|
+
|
36
|
+
[tool.ruff]
|
37
|
+
target-version = "py39"
|
38
|
+
|
39
|
+
[tool.ruff.lint]
|
40
|
+
select = [
|
41
|
+
"ASYNC",
|
42
|
+
"B",
|
43
|
+
"BLE",
|
44
|
+
"C4",
|
45
|
+
"DTZ",
|
46
|
+
"E",
|
47
|
+
"F",
|
48
|
+
"I",
|
49
|
+
"INP",
|
50
|
+
"ISC",
|
51
|
+
"LOG",
|
52
|
+
"N",
|
53
|
+
"PIE",
|
54
|
+
"PT",
|
55
|
+
"RET",
|
56
|
+
"RUF",
|
57
|
+
"SLOT",
|
58
|
+
"UP",
|
59
|
+
]
|
60
|
+
flake8-comprehensions.allow-dict-calls-with-keyword-arguments = true
|
61
|
+
ignore = [
|
62
|
+
"E501", # line too long
|
63
|
+
"E721", # type checks, currently broken
|
64
|
+
"ISC001", # conflicts with ruff format
|
65
|
+
"PT004", # Fixture `...` does not return anything, add leading underscore
|
66
|
+
"RET505", # Unnecessary `else` after `return` statement
|
67
|
+
"RET506", # Unnecessary `else` after `raise` statement
|
68
|
+
]
|
69
|
+
|
70
|
+
[tool.ruff.lint.per-file-ignores]
|
71
|
+
"__init__.py" = [
|
72
|
+
"F401", # wildcard import
|
73
|
+
]
|
74
|
+
|
75
|
+
[tool.pytest.ini_options]
|
76
|
+
addopts = [
|
77
|
+
"-v",
|
78
|
+
"--cov",
|
79
|
+
"--cov-report", "xml:results/pytest/coverage.xml",
|
80
|
+
"--cov-report", "html:results/pytest/cov_html",
|
81
|
+
"--cov-report", "term-missing",
|
82
|
+
"--junitxml=results/pytest/results.xml",
|
83
|
+
"--durations=5",
|
84
|
+
]
|
85
|
+
junit_family = "legacy"
|
86
|
+
asyncio_mode = "auto"
|
87
|
+
|
88
|
+
[tool.coverage]
|
89
|
+
run.include = [
|
90
|
+
"sql_athame/**/*.py",
|
91
|
+
"tests/**/*.py",
|
92
|
+
]
|
93
|
+
report.precision = 2
|
94
|
+
|
95
|
+
[tool.mypy]
|
96
|
+
disallow_incomplete_defs = true
|
97
|
+
check_untyped_defs = true
|
98
|
+
|
99
|
+
[[tool.mypy.overrides]]
|
100
|
+
module = [
|
101
|
+
"asyncpg",
|
102
|
+
]
|
103
|
+
ignore_missing_imports = true
|
@@ -2,16 +2,11 @@ import dataclasses
|
|
2
2
|
import json
|
3
3
|
import re
|
4
4
|
import string
|
5
|
+
from collections.abc import Iterable, Iterator, Sequence
|
5
6
|
from typing import (
|
6
7
|
Any,
|
7
8
|
Callable,
|
8
|
-
Dict,
|
9
|
-
Iterable,
|
10
|
-
Iterator,
|
11
|
-
List,
|
12
9
|
Optional,
|
13
|
-
Sequence,
|
14
|
-
Tuple,
|
15
10
|
Union,
|
16
11
|
cast,
|
17
12
|
overload,
|
@@ -34,7 +29,7 @@ def auto_numbered(field_name):
|
|
34
29
|
def process_slot_value(
|
35
30
|
name: str,
|
36
31
|
value: Any,
|
37
|
-
placeholders:
|
32
|
+
placeholders: dict[str, Placeholder],
|
38
33
|
) -> Union["Fragment", Placeholder]:
|
39
34
|
if isinstance(value, Fragment):
|
40
35
|
return value
|
@@ -47,9 +42,9 @@ def process_slot_value(
|
|
47
42
|
@dataclasses.dataclass
|
48
43
|
class Fragment:
|
49
44
|
__slots__ = ["parts"]
|
50
|
-
parts:
|
45
|
+
parts: list[Part]
|
51
46
|
|
52
|
-
def flatten_into(self, parts:
|
47
|
+
def flatten_into(self, parts: list[FlatPart]) -> None:
|
53
48
|
for part in self.parts:
|
54
49
|
if isinstance(part, Fragment):
|
55
50
|
part.flatten_into(parts)
|
@@ -70,10 +65,10 @@ class Fragment:
|
|
70
65
|
for i, part in enumerate(flattened.parts):
|
71
66
|
if isinstance(part, Slot):
|
72
67
|
func.append(
|
73
|
-
f" process_slot_value({
|
68
|
+
f" process_slot_value({part.name!r}, slots[{part.name!r}], placeholders),"
|
74
69
|
)
|
75
70
|
elif isinstance(part, str):
|
76
|
-
func.append(f" {
|
71
|
+
func.append(f" {part!r},")
|
77
72
|
else:
|
78
73
|
env[f"part_{i}"] = part
|
79
74
|
func.append(f" part_{i},")
|
@@ -82,9 +77,9 @@ class Fragment:
|
|
82
77
|
return env["compiled"] # type: ignore
|
83
78
|
|
84
79
|
def flatten(self) -> "Fragment":
|
85
|
-
parts:
|
80
|
+
parts: list[FlatPart] = []
|
86
81
|
self.flatten_into(parts)
|
87
|
-
out_parts:
|
82
|
+
out_parts: list[Part] = []
|
88
83
|
for part in parts:
|
89
84
|
if isinstance(part, str) and out_parts and isinstance(out_parts[-1], str):
|
90
85
|
out_parts[-1] += part
|
@@ -93,9 +88,9 @@ class Fragment:
|
|
93
88
|
return Fragment(out_parts)
|
94
89
|
|
95
90
|
def fill(self, **kwargs: Any) -> "Fragment":
|
96
|
-
parts:
|
97
|
-
self.flatten_into(cast(
|
98
|
-
placeholders:
|
91
|
+
parts: list[Part] = []
|
92
|
+
self.flatten_into(cast(list[FlatPart], parts))
|
93
|
+
placeholders: dict[str, Placeholder] = {}
|
99
94
|
for i, part in enumerate(parts):
|
100
95
|
if isinstance(part, Slot):
|
101
96
|
parts[i] = process_slot_value(
|
@@ -106,24 +101,24 @@ class Fragment:
|
|
106
101
|
@overload
|
107
102
|
def prep_query(
|
108
103
|
self, allow_slots: Literal[True]
|
109
|
-
) ->
|
104
|
+
) -> tuple[str, list[Union[Placeholder, Slot]]]: ... # pragma: no cover
|
110
105
|
|
111
106
|
@overload
|
112
107
|
def prep_query(
|
113
108
|
self, allow_slots: Literal[False] = False
|
114
|
-
) ->
|
109
|
+
) -> tuple[str, list[Placeholder]]: ... # pragma: no cover
|
115
110
|
|
116
|
-
def prep_query(self, allow_slots: bool = False) ->
|
117
|
-
parts:
|
111
|
+
def prep_query(self, allow_slots: bool = False) -> tuple[str, list[Any]]:
|
112
|
+
parts: list[FlatPart] = []
|
118
113
|
self.flatten_into(parts)
|
119
|
-
args:
|
120
|
-
placeholder_ids:
|
121
|
-
slot_ids:
|
122
|
-
out_parts:
|
114
|
+
args: list[Union[Placeholder, Slot]] = []
|
115
|
+
placeholder_ids: dict[Placeholder, int] = {}
|
116
|
+
slot_ids: dict[Slot, int] = {}
|
117
|
+
out_parts: list[str] = []
|
123
118
|
for part in parts:
|
124
119
|
if isinstance(part, Slot):
|
125
120
|
if not allow_slots:
|
126
|
-
raise ValueError(f"Unfilled slot: {
|
121
|
+
raise ValueError(f"Unfilled slot: {part.name!r}")
|
127
122
|
if part not in slot_ids:
|
128
123
|
args.append(part)
|
129
124
|
slot_ids[part] = len(args)
|
@@ -138,7 +133,7 @@ class Fragment:
|
|
138
133
|
out_parts.append(part)
|
139
134
|
return "".join(out_parts).strip(), args
|
140
135
|
|
141
|
-
def query(self) ->
|
136
|
+
def query(self) -> tuple[str, list[Any]]:
|
142
137
|
query, args = self.prep_query()
|
143
138
|
placeholder_values = [arg.value for arg in args]
|
144
139
|
return query, placeholder_values
|
@@ -146,16 +141,16 @@ class Fragment:
|
|
146
141
|
def sqlalchemy_text(self) -> Any:
|
147
142
|
return sqlalchemy_text_from_fragment(self)
|
148
143
|
|
149
|
-
def prepare(self) ->
|
144
|
+
def prepare(self) -> tuple[str, Callable[..., list[Any]]]:
|
150
145
|
query, args = self.prep_query(allow_slots=True)
|
151
|
-
env =
|
146
|
+
env = {}
|
152
147
|
func = [
|
153
148
|
"def generate_args(**kwargs):",
|
154
149
|
" return [",
|
155
150
|
]
|
156
151
|
for i, arg in enumerate(args):
|
157
152
|
if isinstance(arg, Slot):
|
158
|
-
func.append(f" kwargs[{
|
153
|
+
func.append(f" kwargs[{arg.name!r}],")
|
159
154
|
else:
|
160
155
|
env[f"value_{i}"] = arg.value
|
161
156
|
func.append(f" value_{i},")
|
@@ -178,10 +173,10 @@ class SQLFormatter:
|
|
178
173
|
if not preserve_formatting:
|
179
174
|
fmt = newline_whitespace_re.sub(" ", fmt)
|
180
175
|
fmtr = string.Formatter()
|
181
|
-
parts:
|
182
|
-
placeholders:
|
176
|
+
parts: list[Part] = []
|
177
|
+
placeholders: dict[str, Placeholder] = {}
|
183
178
|
next_auto_field = 0
|
184
|
-
for literal_text, field_name,
|
179
|
+
for literal_text, field_name, _format_spec, _conversion in fmtr.parse(fmt):
|
185
180
|
parts.append(literal_text)
|
186
181
|
if field_name is not None:
|
187
182
|
if auto_numbered(field_name):
|
@@ -258,12 +253,11 @@ class SQLFormatter:
|
|
258
253
|
parts = parts[0]
|
259
254
|
return Fragment(list(join_parts(parts, infix=", ")))
|
260
255
|
|
261
|
-
|
262
|
-
|
263
|
-
nested = list(nest_for_type(x, t) for x, t in zip(zip(*data), types))
|
256
|
+
def unnest(self, data: Iterable[Sequence[Any]], types: Iterable[str]) -> Fragment:
|
257
|
+
nested = [nest_for_type(x, t) for x, t in zip(zip(*data), types)]
|
264
258
|
if not nested:
|
265
|
-
nested =
|
266
|
-
return Fragment(["UNNEST(",
|
259
|
+
nested = [nest_for_type([], t) for t in types]
|
260
|
+
return Fragment(["UNNEST(", self.list(nested), ")"])
|
267
261
|
|
268
262
|
|
269
263
|
sql = SQLFormatter()
|
@@ -296,7 +290,7 @@ def lit(text: str) -> Fragment:
|
|
296
290
|
return Fragment([text])
|
297
291
|
|
298
292
|
|
299
|
-
def any_all(frags:
|
293
|
+
def any_all(frags: list[Fragment], op: str, base_case: str) -> Fragment:
|
300
294
|
if not frags:
|
301
295
|
return lit(base_case)
|
302
296
|
parts = join_parts(frags, prefix="(", infix=f") {op} (", suffix=")")
|
@@ -1,60 +1,57 @@
|
|
1
1
|
import datetime
|
2
2
|
import uuid
|
3
|
-
from
|
3
|
+
from collections.abc import AsyncGenerator, Iterable, Mapping
|
4
|
+
from dataclasses import Field, InitVar, dataclass, fields
|
4
5
|
from typing import (
|
6
|
+
Annotated,
|
5
7
|
Any,
|
6
|
-
AsyncGenerator,
|
7
8
|
Callable,
|
8
|
-
Dict,
|
9
|
-
Iterable,
|
10
|
-
Iterator,
|
11
|
-
List,
|
12
|
-
Mapping,
|
13
9
|
Optional,
|
14
|
-
Set,
|
15
|
-
Tuple,
|
16
|
-
Type,
|
17
10
|
TypeVar,
|
18
11
|
Union,
|
12
|
+
get_origin,
|
13
|
+
get_type_hints,
|
19
14
|
)
|
20
15
|
|
16
|
+
from typing_extensions import TypeAlias
|
17
|
+
|
21
18
|
from .base import Fragment, sql
|
22
19
|
|
23
|
-
Where = Union[Fragment, Iterable[Fragment]]
|
20
|
+
Where: TypeAlias = Union[Fragment, Iterable[Fragment]]
|
24
21
|
# KLUDGE to avoid a string argument being valid
|
25
|
-
SequenceOfStrings = Union[
|
26
|
-
FieldNames = SequenceOfStrings
|
27
|
-
FieldNamesSet = Union[SequenceOfStrings,
|
22
|
+
SequenceOfStrings: TypeAlias = Union[list[str], tuple[str, ...]]
|
23
|
+
FieldNames: TypeAlias = SequenceOfStrings
|
24
|
+
FieldNamesSet: TypeAlias = Union[SequenceOfStrings, set[str]]
|
28
25
|
|
29
|
-
Connection = Any
|
30
|
-
Pool = Any
|
26
|
+
Connection: TypeAlias = Any
|
27
|
+
Pool: TypeAlias = Any
|
31
28
|
|
32
29
|
|
33
30
|
@dataclass
|
34
31
|
class ColumnInfo:
|
35
32
|
type: str
|
36
|
-
create_type: str
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
)
|
57
|
-
|
33
|
+
create_type: str = ""
|
34
|
+
nullable: bool = False
|
35
|
+
_constraints: tuple[str, ...] = ()
|
36
|
+
|
37
|
+
constraints: InitVar[Union[str, Iterable[str], None]] = None
|
38
|
+
|
39
|
+
def __post_init__(self, constraints: Union[str, Iterable[str], None]) -> None:
|
40
|
+
if self.create_type == "":
|
41
|
+
self.create_type = self.type
|
42
|
+
self.type = sql_create_type_map.get(self.type.upper(), self.type)
|
43
|
+
if constraints is not None:
|
44
|
+
if type(constraints) is str:
|
45
|
+
constraints = (constraints,)
|
46
|
+
self._constraints = tuple(constraints)
|
47
|
+
|
48
|
+
def create_table_string(self) -> str:
|
49
|
+
parts = (
|
50
|
+
self.create_type,
|
51
|
+
*(() if self.nullable else ("NOT NULL",)),
|
52
|
+
*self._constraints,
|
53
|
+
)
|
54
|
+
return " ".join(parts)
|
58
55
|
|
59
56
|
|
60
57
|
sql_create_type_map = {
|
@@ -64,43 +61,37 @@ sql_create_type_map = {
|
|
64
61
|
}
|
65
62
|
|
66
63
|
|
67
|
-
sql_type_map = {
|
68
|
-
Optional[bool]: ("BOOLEAN",),
|
69
|
-
Optional[bytes]: ("BYTEA",),
|
70
|
-
Optional[datetime.date]: ("DATE",),
|
71
|
-
Optional[datetime.datetime]: ("TIMESTAMP",),
|
72
|
-
Optional[float]: ("DOUBLE PRECISION",),
|
73
|
-
Optional[int]: ("INTEGER",),
|
74
|
-
Optional[str]: ("TEXT",),
|
75
|
-
Optional[uuid.UUID]: ("UUID",),
|
76
|
-
bool: ("BOOLEAN",
|
77
|
-
bytes: ("BYTEA",
|
78
|
-
datetime.date: ("DATE",
|
79
|
-
datetime.datetime: ("TIMESTAMP",
|
80
|
-
float: ("DOUBLE PRECISION",
|
81
|
-
int: ("INTEGER",
|
82
|
-
str: ("TEXT",
|
83
|
-
uuid.UUID: ("UUID",
|
64
|
+
sql_type_map: dict[Any, tuple[str, bool]] = {
|
65
|
+
Optional[bool]: ("BOOLEAN", True),
|
66
|
+
Optional[bytes]: ("BYTEA", True),
|
67
|
+
Optional[datetime.date]: ("DATE", True),
|
68
|
+
Optional[datetime.datetime]: ("TIMESTAMP", True),
|
69
|
+
Optional[float]: ("DOUBLE PRECISION", True),
|
70
|
+
Optional[int]: ("INTEGER", True),
|
71
|
+
Optional[str]: ("TEXT", True),
|
72
|
+
Optional[uuid.UUID]: ("UUID", True),
|
73
|
+
bool: ("BOOLEAN", False),
|
74
|
+
bytes: ("BYTEA", False),
|
75
|
+
datetime.date: ("DATE", False),
|
76
|
+
datetime.datetime: ("TIMESTAMP", False),
|
77
|
+
float: ("DOUBLE PRECISION", False),
|
78
|
+
int: ("INTEGER", False),
|
79
|
+
str: ("TEXT", False),
|
80
|
+
uuid.UUID: ("UUID", False),
|
84
81
|
}
|
85
82
|
|
86
83
|
|
87
|
-
def column_info_for_field(field):
|
88
|
-
if "sql_athame" in field.metadata:
|
89
|
-
return field.metadata["sql_athame"]
|
90
|
-
type, *constraints = sql_type_map[field.type]
|
91
|
-
return ColumnInfo(type, type, tuple(constraints))
|
92
|
-
|
93
|
-
|
94
84
|
T = TypeVar("T", bound="ModelBase")
|
95
85
|
U = TypeVar("U")
|
96
86
|
|
97
87
|
|
98
|
-
class ModelBase
|
99
|
-
_column_info: Optional[
|
100
|
-
_cache:
|
88
|
+
class ModelBase:
|
89
|
+
_column_info: Optional[dict[str, ColumnInfo]]
|
90
|
+
_cache: dict[tuple, Any]
|
101
91
|
table_name: str
|
102
|
-
primary_key_names:
|
92
|
+
primary_key_names: tuple[str, ...]
|
103
93
|
array_safe_insert: bool
|
94
|
+
_type_hints: dict[str, type]
|
104
95
|
|
105
96
|
def __init_subclass__(
|
106
97
|
cls,
|
@@ -138,27 +129,34 @@ class ModelBase(Mapping[str, Any]):
|
|
138
129
|
cls._cache[key] = thunk()
|
139
130
|
return cls._cache[key]
|
140
131
|
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
return iter(self.keys())
|
149
|
-
|
150
|
-
def __len__(self) -> int:
|
151
|
-
return len(self.keys())
|
132
|
+
@classmethod
|
133
|
+
def type_hints(cls) -> dict[str, type]:
|
134
|
+
try:
|
135
|
+
return cls._type_hints
|
136
|
+
except AttributeError:
|
137
|
+
cls._type_hints = get_type_hints(cls, include_extras=True)
|
138
|
+
return cls._type_hints
|
152
139
|
|
153
|
-
|
154
|
-
|
140
|
+
@classmethod
|
141
|
+
def column_info_for_field(cls, field: Field) -> ColumnInfo:
|
142
|
+
type_info = cls.type_hints()[field.name]
|
143
|
+
base_type = type_info
|
144
|
+
if get_origin(type_info) is Annotated:
|
145
|
+
base_type = type_info.__origin__ # type: ignore
|
146
|
+
for md in type_info.__metadata__: # type: ignore
|
147
|
+
if isinstance(md, ColumnInfo):
|
148
|
+
return md
|
149
|
+
type, nullable = sql_type_map[base_type]
|
150
|
+
return ColumnInfo(type=type, nullable=nullable)
|
155
151
|
|
156
152
|
@classmethod
|
157
153
|
def column_info(cls, column: str) -> ColumnInfo:
|
158
154
|
try:
|
159
155
|
return cls._column_info[column] # type: ignore
|
160
156
|
except AttributeError:
|
161
|
-
cls._column_info = {
|
157
|
+
cls._column_info = {
|
158
|
+
f.name: cls.column_info_for_field(f) for f in cls._fields()
|
159
|
+
}
|
162
160
|
return cls._column_info[column]
|
163
161
|
|
164
162
|
@classmethod
|
@@ -166,17 +164,17 @@ class ModelBase(Mapping[str, Any]):
|
|
166
164
|
return sql.identifier(cls.table_name, prefix=prefix)
|
167
165
|
|
168
166
|
@classmethod
|
169
|
-
def primary_key_names_sql(cls, *, prefix: Optional[str] = None) ->
|
167
|
+
def primary_key_names_sql(cls, *, prefix: Optional[str] = None) -> list[Fragment]:
|
170
168
|
return [sql.identifier(pk, prefix=prefix) for pk in cls.primary_key_names]
|
171
169
|
|
172
170
|
@classmethod
|
173
|
-
def field_names(cls, *, exclude: FieldNamesSet = ()) ->
|
171
|
+
def field_names(cls, *, exclude: FieldNamesSet = ()) -> list[str]:
|
174
172
|
return [f.name for f in cls._fields() if f.name not in exclude]
|
175
173
|
|
176
174
|
@classmethod
|
177
175
|
def field_names_sql(
|
178
176
|
cls, *, prefix: Optional[str] = None, exclude: FieldNamesSet = ()
|
179
|
-
) ->
|
177
|
+
) -> list[Fragment]:
|
180
178
|
return [
|
181
179
|
sql.identifier(f, prefix=prefix) for f in cls.field_names(exclude=exclude)
|
182
180
|
]
|
@@ -186,9 +184,9 @@ class ModelBase(Mapping[str, Any]):
|
|
186
184
|
|
187
185
|
@classmethod
|
188
186
|
def _get_field_values_fn(
|
189
|
-
cls:
|
190
|
-
) -> Callable[[T],
|
191
|
-
env:
|
187
|
+
cls: type[T], exclude: FieldNamesSet = ()
|
188
|
+
) -> Callable[[T], list[Any]]:
|
189
|
+
env: dict[str, Any] = {}
|
192
190
|
func = ["def get_field_values(self): return ["]
|
193
191
|
for f in cls._fields():
|
194
192
|
if f.name not in exclude:
|
@@ -197,7 +195,7 @@ class ModelBase(Mapping[str, Any]):
|
|
197
195
|
exec(" ".join(func), env)
|
198
196
|
return env["get_field_values"]
|
199
197
|
|
200
|
-
def field_values(self, *, exclude: FieldNamesSet = ()) ->
|
198
|
+
def field_values(self, *, exclude: FieldNamesSet = ()) -> list[Any]:
|
201
199
|
get_field_values = self._cached(
|
202
200
|
("get_field_values", tuple(sorted(exclude))),
|
203
201
|
lambda: self._get_field_values_fn(exclude),
|
@@ -206,7 +204,7 @@ class ModelBase(Mapping[str, Any]):
|
|
206
204
|
|
207
205
|
def field_values_sql(
|
208
206
|
self, *, exclude: FieldNamesSet = (), default_none: bool = False
|
209
|
-
) ->
|
207
|
+
) -> list[Fragment]:
|
210
208
|
if default_none:
|
211
209
|
return [
|
212
210
|
sql.literal("DEFAULT") if value is None else sql.value(value)
|
@@ -217,7 +215,7 @@ class ModelBase(Mapping[str, Any]):
|
|
217
215
|
|
218
216
|
@classmethod
|
219
217
|
def from_tuple(
|
220
|
-
cls:
|
218
|
+
cls: type[T], tup: tuple, *, offset: int = 0, exclude: FieldNamesSet = ()
|
221
219
|
) -> T:
|
222
220
|
names = (f.name for f in cls._fields() if f.name not in exclude)
|
223
221
|
kwargs = {name: tup[offset] for offset, name in enumerate(names, start=offset)}
|
@@ -225,14 +223,14 @@ class ModelBase(Mapping[str, Any]):
|
|
225
223
|
|
226
224
|
@classmethod
|
227
225
|
def from_dict(
|
228
|
-
cls:
|
226
|
+
cls: type[T], dct: dict[str, Any], *, exclude: FieldNamesSet = ()
|
229
227
|
) -> T:
|
230
228
|
names = {f.name for f in cls._fields() if f.name not in exclude}
|
231
229
|
kwargs = {k: v for k, v in dct.items() if k in names}
|
232
230
|
return cls(**kwargs)
|
233
231
|
|
234
232
|
@classmethod
|
235
|
-
def ensure_model(cls:
|
233
|
+
def ensure_model(cls: type[T], row: Union[T, Mapping[str, Any]]) -> T:
|
236
234
|
if isinstance(row, cls):
|
237
235
|
return row
|
238
236
|
return cls(**row)
|
@@ -286,7 +284,7 @@ class ModelBase(Mapping[str, Any]):
|
|
286
284
|
|
287
285
|
@classmethod
|
288
286
|
async def select_cursor(
|
289
|
-
cls:
|
287
|
+
cls: type[T],
|
290
288
|
connection: Connection,
|
291
289
|
order_by: Union[FieldNames, str] = (),
|
292
290
|
for_update: bool = False,
|
@@ -301,12 +299,12 @@ class ModelBase(Mapping[str, Any]):
|
|
301
299
|
|
302
300
|
@classmethod
|
303
301
|
async def select(
|
304
|
-
cls:
|
302
|
+
cls: type[T],
|
305
303
|
connection_or_pool: Union[Connection, Pool],
|
306
304
|
order_by: Union[FieldNames, str] = (),
|
307
305
|
for_update: bool = False,
|
308
306
|
where: Where = (),
|
309
|
-
) ->
|
307
|
+
) -> list[T]:
|
310
308
|
return [
|
311
309
|
cls(**row)
|
312
310
|
for row in await connection_or_pool.fetch(
|
@@ -315,7 +313,7 @@ class ModelBase(Mapping[str, Any]):
|
|
315
313
|
]
|
316
314
|
|
317
315
|
@classmethod
|
318
|
-
def create_sql(cls:
|
316
|
+
def create_sql(cls: type[T], **kwargs: Any) -> Fragment:
|
319
317
|
return sql(
|
320
318
|
"INSERT INTO {table} ({fields}) VALUES ({values}) RETURNING {out_fields}",
|
321
319
|
table=cls.table_name_sql(),
|
@@ -326,7 +324,7 @@ class ModelBase(Mapping[str, Any]):
|
|
326
324
|
|
327
325
|
@classmethod
|
328
326
|
async def create(
|
329
|
-
cls:
|
327
|
+
cls: type[T], connection_or_pool: Union[Connection, Pool], **kwargs: Any
|
330
328
|
) -> T:
|
331
329
|
row = await connection_or_pool.fetchrow(*cls.create_sql(**kwargs))
|
332
330
|
return cls(**row)
|
@@ -375,11 +373,10 @@ class ModelBase(Mapping[str, Any]):
|
|
375
373
|
self.upsert_sql(self.insert_sql(exclude=exclude), exclude=exclude),
|
376
374
|
)
|
377
375
|
result = await connection_or_pool.fetchrow(*query)
|
378
|
-
|
379
|
-
return is_update
|
376
|
+
return result["xmax"] != 0
|
380
377
|
|
381
378
|
@classmethod
|
382
|
-
def delete_multiple_sql(cls:
|
379
|
+
def delete_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
|
383
380
|
cached = cls._cached(
|
384
381
|
("delete_multiple_sql",),
|
385
382
|
lambda: sql(
|
@@ -397,12 +394,12 @@ class ModelBase(Mapping[str, Any]):
|
|
397
394
|
|
398
395
|
@classmethod
|
399
396
|
async def delete_multiple(
|
400
|
-
cls:
|
397
|
+
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
|
401
398
|
) -> str:
|
402
399
|
return await connection_or_pool.execute(*cls.delete_multiple_sql(rows))
|
403
400
|
|
404
401
|
@classmethod
|
405
|
-
def insert_multiple_sql(cls:
|
402
|
+
def insert_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
|
406
403
|
cached = cls._cached(
|
407
404
|
("insert_multiple_sql",),
|
408
405
|
lambda: sql(
|
@@ -419,7 +416,7 @@ class ModelBase(Mapping[str, Any]):
|
|
419
416
|
)
|
420
417
|
|
421
418
|
@classmethod
|
422
|
-
def insert_multiple_array_safe_sql(cls:
|
419
|
+
def insert_multiple_array_safe_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
|
423
420
|
return sql(
|
424
421
|
"INSERT INTO {table} ({fields}) VALUES {values}",
|
425
422
|
table=cls.table_name_sql(),
|
@@ -432,13 +429,13 @@ class ModelBase(Mapping[str, Any]):
|
|
432
429
|
|
433
430
|
@classmethod
|
434
431
|
async def insert_multiple_unnest(
|
435
|
-
cls:
|
432
|
+
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
|
436
433
|
) -> str:
|
437
434
|
return await connection_or_pool.execute(*cls.insert_multiple_sql(rows))
|
438
435
|
|
439
436
|
@classmethod
|
440
437
|
async def insert_multiple_array_safe(
|
441
|
-
cls:
|
438
|
+
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
|
442
439
|
) -> str:
|
443
440
|
last = ""
|
444
441
|
for chunk in chunked(rows, 100):
|
@@ -449,7 +446,7 @@ class ModelBase(Mapping[str, Any]):
|
|
449
446
|
|
450
447
|
@classmethod
|
451
448
|
async def insert_multiple(
|
452
|
-
cls:
|
449
|
+
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
|
453
450
|
) -> str:
|
454
451
|
if cls.array_safe_insert:
|
455
452
|
return await cls.insert_multiple_array_safe(connection_or_pool, rows)
|
@@ -458,7 +455,7 @@ class ModelBase(Mapping[str, Any]):
|
|
458
455
|
|
459
456
|
@classmethod
|
460
457
|
async def upsert_multiple_unnest(
|
461
|
-
cls:
|
458
|
+
cls: type[T],
|
462
459
|
connection_or_pool: Union[Connection, Pool],
|
463
460
|
rows: Iterable[T],
|
464
461
|
insert_only: FieldNamesSet = (),
|
@@ -469,7 +466,7 @@ class ModelBase(Mapping[str, Any]):
|
|
469
466
|
|
470
467
|
@classmethod
|
471
468
|
async def upsert_multiple_array_safe(
|
472
|
-
cls:
|
469
|
+
cls: type[T],
|
473
470
|
connection_or_pool: Union[Connection, Pool],
|
474
471
|
rows: Iterable[T],
|
475
472
|
insert_only: FieldNamesSet = (),
|
@@ -485,7 +482,7 @@ class ModelBase(Mapping[str, Any]):
|
|
485
482
|
|
486
483
|
@classmethod
|
487
484
|
async def upsert_multiple(
|
488
|
-
cls:
|
485
|
+
cls: type[T],
|
489
486
|
connection_or_pool: Union[Connection, Pool],
|
490
487
|
rows: Iterable[T],
|
491
488
|
insert_only: FieldNamesSet = (),
|
@@ -501,9 +498,9 @@ class ModelBase(Mapping[str, Any]):
|
|
501
498
|
|
502
499
|
@classmethod
|
503
500
|
def _get_equal_ignoring_fn(
|
504
|
-
cls:
|
501
|
+
cls: type[T], ignore: FieldNamesSet = ()
|
505
502
|
) -> Callable[[T, T], bool]:
|
506
|
-
env:
|
503
|
+
env: dict[str, Any] = {}
|
507
504
|
func = ["def equal_ignoring(a, b):"]
|
508
505
|
for f in cls._fields():
|
509
506
|
if f.name not in ignore:
|
@@ -514,14 +511,14 @@ class ModelBase(Mapping[str, Any]):
|
|
514
511
|
|
515
512
|
@classmethod
|
516
513
|
async def replace_multiple(
|
517
|
-
cls:
|
514
|
+
cls: type[T],
|
518
515
|
connection: Connection,
|
519
516
|
rows: Union[Iterable[T], Iterable[Mapping[str, Any]]],
|
520
517
|
*,
|
521
518
|
where: Where,
|
522
519
|
ignore: FieldNamesSet = (),
|
523
520
|
insert_only: FieldNamesSet = (),
|
524
|
-
) ->
|
521
|
+
) -> tuple[list[T], list[T], list[T]]:
|
525
522
|
ignore = sorted(set(ignore) | set(insert_only))
|
526
523
|
equal_ignoring = cls._cached(
|
527
524
|
("equal_ignoring", tuple(ignore)),
|
@@ -556,32 +553,30 @@ class ModelBase(Mapping[str, Any]):
|
|
556
553
|
|
557
554
|
@classmethod
|
558
555
|
def _get_differences_ignoring_fn(
|
559
|
-
cls:
|
560
|
-
) -> Callable[[T, T],
|
561
|
-
env:
|
556
|
+
cls: type[T], ignore: FieldNamesSet = ()
|
557
|
+
) -> Callable[[T, T], list[str]]:
|
558
|
+
env: dict[str, Any] = {}
|
562
559
|
func = [
|
563
560
|
"def differences_ignoring(a, b):",
|
564
561
|
" diffs = []",
|
565
562
|
]
|
566
563
|
for f in cls._fields():
|
567
564
|
if f.name not in ignore:
|
568
|
-
func.append(
|
569
|
-
f" if a.{f.name} != b.{f.name}: diffs.append({repr(f.name)})"
|
570
|
-
)
|
565
|
+
func.append(f" if a.{f.name} != b.{f.name}: diffs.append({f.name!r})")
|
571
566
|
func += [" return diffs"]
|
572
567
|
exec("\n".join(func), env)
|
573
568
|
return env["differences_ignoring"]
|
574
569
|
|
575
570
|
@classmethod
|
576
571
|
async def replace_multiple_reporting_differences(
|
577
|
-
cls:
|
572
|
+
cls: type[T],
|
578
573
|
connection: Connection,
|
579
574
|
rows: Union[Iterable[T], Iterable[Mapping[str, Any]]],
|
580
575
|
*,
|
581
576
|
where: Where,
|
582
577
|
ignore: FieldNamesSet = (),
|
583
578
|
insert_only: FieldNamesSet = (),
|
584
|
-
) ->
|
579
|
+
) -> tuple[list[T], list[tuple[T, T, list[str]]], list[T]]:
|
585
580
|
ignore = sorted(set(ignore) | set(insert_only))
|
586
581
|
differences_ignoring = cls._cached(
|
587
582
|
("differences_ignoring", tuple(ignore)),
|
@@ -1,19 +1,20 @@
|
|
1
1
|
import math
|
2
2
|
import uuid
|
3
|
-
from
|
3
|
+
from collections.abc import Sequence
|
4
|
+
from typing import Any
|
4
5
|
|
5
6
|
|
6
7
|
def escape(value: Any) -> str:
|
7
8
|
if isinstance(value, str):
|
8
|
-
return f"E{
|
9
|
+
return f"E{value!r}"
|
9
10
|
elif isinstance(value, float) or isinstance(value, int):
|
10
11
|
if math.isnan(value):
|
11
12
|
raise ValueError("Can't escape NaN float")
|
12
13
|
elif math.isinf(value):
|
13
14
|
raise ValueError("Can't escape infinite float")
|
14
|
-
return f"{
|
15
|
+
return f"{value!r}"
|
15
16
|
elif isinstance(value, uuid.UUID):
|
16
|
-
return f"{
|
17
|
+
return f"{str(value)!r}::UUID"
|
17
18
|
elif isinstance(value, Sequence):
|
18
19
|
args = ", ".join(escape(x) for x in value)
|
19
20
|
return f"ARRAY[{args}]"
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import TYPE_CHECKING, Any
|
1
|
+
from typing import TYPE_CHECKING, Any
|
2
2
|
|
3
3
|
from .types import FlatPart, Placeholder, Slot
|
4
4
|
|
@@ -7,10 +7,10 @@ try:
|
|
7
7
|
from sqlalchemy.sql.elements import BindParameter
|
8
8
|
|
9
9
|
def sqlalchemy_text_from_fragment(self: "Fragment") -> Any:
|
10
|
-
parts:
|
10
|
+
parts: list[FlatPart] = []
|
11
11
|
self.flatten_into(parts)
|
12
|
-
bindparams:
|
13
|
-
out_parts:
|
12
|
+
bindparams: dict[str, Any] = {}
|
13
|
+
out_parts: list[str] = []
|
14
14
|
for part in parts:
|
15
15
|
if isinstance(part, Slot):
|
16
16
|
out_parts.append(f"(:{part.name})")
|
@@ -1,6 +1,8 @@
|
|
1
1
|
import dataclasses
|
2
2
|
from typing import TYPE_CHECKING, Any, Union
|
3
3
|
|
4
|
+
from typing_extensions import TypeAlias
|
5
|
+
|
4
6
|
|
5
7
|
@dataclasses.dataclass(eq=False)
|
6
8
|
class Placeholder:
|
@@ -15,8 +17,8 @@ class Slot:
|
|
15
17
|
name: str
|
16
18
|
|
17
19
|
|
18
|
-
Part = Union[str, Placeholder, Slot, "Fragment"]
|
19
|
-
FlatPart = Union[str, Placeholder, Slot]
|
20
|
+
Part: TypeAlias = Union[str, Placeholder, Slot, "Fragment"]
|
21
|
+
FlatPart: TypeAlias = Union[str, Placeholder, Slot]
|
20
22
|
|
21
23
|
if TYPE_CHECKING:
|
22
24
|
from .base import Fragment
|
@@ -1,35 +0,0 @@
|
|
1
|
-
[tool.poetry]
|
2
|
-
name = "sql-athame"
|
3
|
-
version = "0.4.0-alpha-5"
|
4
|
-
description = "Python tool for slicing and dicing SQL"
|
5
|
-
authors = ["Brian Downing <bdowning@lavos.net>"]
|
6
|
-
license = "MIT"
|
7
|
-
readme = "README.md"
|
8
|
-
homepage = "https://github.com/bdowning/sql-athame"
|
9
|
-
repository = "https://github.com/bdowning/sql-athame"
|
10
|
-
|
11
|
-
[tool.poetry.extras]
|
12
|
-
asyncpg = ["asyncpg"]
|
13
|
-
|
14
|
-
[tool.poetry.dependencies]
|
15
|
-
python = "^3.9"
|
16
|
-
asyncpg = { version = "*", optional = true }
|
17
|
-
typing-extensions = "*"
|
18
|
-
|
19
|
-
[tool.poetry.dev-dependencies]
|
20
|
-
black = "*"
|
21
|
-
isort = "*"
|
22
|
-
pytest = "*"
|
23
|
-
mypy = "*"
|
24
|
-
flake8 = "*"
|
25
|
-
ipython = "*"
|
26
|
-
pytest-cov = "*"
|
27
|
-
bump2version = "*"
|
28
|
-
asyncpg = "*"
|
29
|
-
pytest-asyncio = "*"
|
30
|
-
grip = "*"
|
31
|
-
SQLAlchemy = "*"
|
32
|
-
|
33
|
-
[build-system]
|
34
|
-
requires = ["poetry>=0.12"]
|
35
|
-
build-backend = "poetry.masonry.api"
|
File without changes
|
File without changes
|
File without changes
|