sql-athame 0.4.0a4__py3-none-any.whl → 0.4.0a6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_athame/__init__.py +2 -2
- sql_athame/base.py +32 -38
- sql_athame/dataclasses.py +146 -128
- sql_athame/escape.py +5 -4
- sql_athame/sqlalchemy.py +4 -4
- sql_athame/types.py +4 -2
- {sql_athame-0.4.0a4.dist-info → sql_athame-0.4.0a6.dist-info}/METADATA +1 -1
- sql_athame-0.4.0a6.dist-info/RECORD +11 -0
- sql_athame-0.4.0a4.dist-info/RECORD +0 -11
- {sql_athame-0.4.0a4.dist-info → sql_athame-0.4.0a6.dist-info}/LICENSE +0 -0
- {sql_athame-0.4.0a4.dist-info → sql_athame-0.4.0a6.dist-info}/WHEEL +0 -0
sql_athame/__init__.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1
|
-
from .base import Fragment, sql
|
2
|
-
from .dataclasses import
|
1
|
+
from .base import Fragment, sql
|
2
|
+
from .dataclasses import ColumnInfo, ModelBase
|
sql_athame/base.py
CHANGED
@@ -2,16 +2,11 @@ import dataclasses
|
|
2
2
|
import json
|
3
3
|
import re
|
4
4
|
import string
|
5
|
+
from collections.abc import Iterable, Iterator, Sequence
|
5
6
|
from typing import (
|
6
7
|
Any,
|
7
8
|
Callable,
|
8
|
-
Dict,
|
9
|
-
Iterable,
|
10
|
-
Iterator,
|
11
|
-
List,
|
12
9
|
Optional,
|
13
|
-
Sequence,
|
14
|
-
Tuple,
|
15
10
|
Union,
|
16
11
|
cast,
|
17
12
|
overload,
|
@@ -34,7 +29,7 @@ def auto_numbered(field_name):
|
|
34
29
|
def process_slot_value(
|
35
30
|
name: str,
|
36
31
|
value: Any,
|
37
|
-
placeholders:
|
32
|
+
placeholders: dict[str, Placeholder],
|
38
33
|
) -> Union["Fragment", Placeholder]:
|
39
34
|
if isinstance(value, Fragment):
|
40
35
|
return value
|
@@ -47,9 +42,9 @@ def process_slot_value(
|
|
47
42
|
@dataclasses.dataclass
|
48
43
|
class Fragment:
|
49
44
|
__slots__ = ["parts"]
|
50
|
-
parts:
|
45
|
+
parts: list[Part]
|
51
46
|
|
52
|
-
def flatten_into(self, parts:
|
47
|
+
def flatten_into(self, parts: list[FlatPart]) -> None:
|
53
48
|
for part in self.parts:
|
54
49
|
if isinstance(part, Fragment):
|
55
50
|
part.flatten_into(parts)
|
@@ -70,10 +65,10 @@ class Fragment:
|
|
70
65
|
for i, part in enumerate(flattened.parts):
|
71
66
|
if isinstance(part, Slot):
|
72
67
|
func.append(
|
73
|
-
f" process_slot_value({
|
68
|
+
f" process_slot_value({part.name!r}, slots[{part.name!r}], placeholders),"
|
74
69
|
)
|
75
70
|
elif isinstance(part, str):
|
76
|
-
func.append(f" {
|
71
|
+
func.append(f" {part!r},")
|
77
72
|
else:
|
78
73
|
env[f"part_{i}"] = part
|
79
74
|
func.append(f" part_{i},")
|
@@ -82,9 +77,9 @@ class Fragment:
|
|
82
77
|
return env["compiled"] # type: ignore
|
83
78
|
|
84
79
|
def flatten(self) -> "Fragment":
|
85
|
-
parts:
|
80
|
+
parts: list[FlatPart] = []
|
86
81
|
self.flatten_into(parts)
|
87
|
-
out_parts:
|
82
|
+
out_parts: list[Part] = []
|
88
83
|
for part in parts:
|
89
84
|
if isinstance(part, str) and out_parts and isinstance(out_parts[-1], str):
|
90
85
|
out_parts[-1] += part
|
@@ -93,9 +88,9 @@ class Fragment:
|
|
93
88
|
return Fragment(out_parts)
|
94
89
|
|
95
90
|
def fill(self, **kwargs: Any) -> "Fragment":
|
96
|
-
parts:
|
97
|
-
self.flatten_into(cast(
|
98
|
-
placeholders:
|
91
|
+
parts: list[Part] = []
|
92
|
+
self.flatten_into(cast(list[FlatPart], parts))
|
93
|
+
placeholders: dict[str, Placeholder] = {}
|
99
94
|
for i, part in enumerate(parts):
|
100
95
|
if isinstance(part, Slot):
|
101
96
|
parts[i] = process_slot_value(
|
@@ -106,24 +101,24 @@ class Fragment:
|
|
106
101
|
@overload
|
107
102
|
def prep_query(
|
108
103
|
self, allow_slots: Literal[True]
|
109
|
-
) ->
|
104
|
+
) -> tuple[str, list[Union[Placeholder, Slot]]]: ... # pragma: no cover
|
110
105
|
|
111
106
|
@overload
|
112
107
|
def prep_query(
|
113
108
|
self, allow_slots: Literal[False] = False
|
114
|
-
) ->
|
109
|
+
) -> tuple[str, list[Placeholder]]: ... # pragma: no cover
|
115
110
|
|
116
|
-
def prep_query(self, allow_slots: bool = False) ->
|
117
|
-
parts:
|
111
|
+
def prep_query(self, allow_slots: bool = False) -> tuple[str, list[Any]]:
|
112
|
+
parts: list[FlatPart] = []
|
118
113
|
self.flatten_into(parts)
|
119
|
-
args:
|
120
|
-
placeholder_ids:
|
121
|
-
slot_ids:
|
122
|
-
out_parts:
|
114
|
+
args: list[Union[Placeholder, Slot]] = []
|
115
|
+
placeholder_ids: dict[Placeholder, int] = {}
|
116
|
+
slot_ids: dict[Slot, int] = {}
|
117
|
+
out_parts: list[str] = []
|
123
118
|
for part in parts:
|
124
119
|
if isinstance(part, Slot):
|
125
120
|
if not allow_slots:
|
126
|
-
raise ValueError(f"Unfilled slot: {
|
121
|
+
raise ValueError(f"Unfilled slot: {part.name!r}")
|
127
122
|
if part not in slot_ids:
|
128
123
|
args.append(part)
|
129
124
|
slot_ids[part] = len(args)
|
@@ -138,7 +133,7 @@ class Fragment:
|
|
138
133
|
out_parts.append(part)
|
139
134
|
return "".join(out_parts).strip(), args
|
140
135
|
|
141
|
-
def query(self) ->
|
136
|
+
def query(self) -> tuple[str, list[Any]]:
|
142
137
|
query, args = self.prep_query()
|
143
138
|
placeholder_values = [arg.value for arg in args]
|
144
139
|
return query, placeholder_values
|
@@ -146,16 +141,16 @@ class Fragment:
|
|
146
141
|
def sqlalchemy_text(self) -> Any:
|
147
142
|
return sqlalchemy_text_from_fragment(self)
|
148
143
|
|
149
|
-
def prepare(self) ->
|
144
|
+
def prepare(self) -> tuple[str, Callable[..., list[Any]]]:
|
150
145
|
query, args = self.prep_query(allow_slots=True)
|
151
|
-
env =
|
146
|
+
env = {}
|
152
147
|
func = [
|
153
148
|
"def generate_args(**kwargs):",
|
154
149
|
" return [",
|
155
150
|
]
|
156
151
|
for i, arg in enumerate(args):
|
157
152
|
if isinstance(arg, Slot):
|
158
|
-
func.append(f" kwargs[{
|
153
|
+
func.append(f" kwargs[{arg.name!r}],")
|
159
154
|
else:
|
160
155
|
env[f"value_{i}"] = arg.value
|
161
156
|
func.append(f" value_{i},")
|
@@ -178,10 +173,10 @@ class SQLFormatter:
|
|
178
173
|
if not preserve_formatting:
|
179
174
|
fmt = newline_whitespace_re.sub(" ", fmt)
|
180
175
|
fmtr = string.Formatter()
|
181
|
-
parts:
|
182
|
-
placeholders:
|
176
|
+
parts: list[Part] = []
|
177
|
+
placeholders: dict[str, Placeholder] = {}
|
183
178
|
next_auto_field = 0
|
184
|
-
for literal_text, field_name,
|
179
|
+
for literal_text, field_name, _format_spec, _conversion in fmtr.parse(fmt):
|
185
180
|
parts.append(literal_text)
|
186
181
|
if field_name is not None:
|
187
182
|
if auto_numbered(field_name):
|
@@ -258,12 +253,11 @@ class SQLFormatter:
|
|
258
253
|
parts = parts[0]
|
259
254
|
return Fragment(list(join_parts(parts, infix=", ")))
|
260
255
|
|
261
|
-
|
262
|
-
|
263
|
-
nested = list(nest_for_type(x, t) for x, t in zip(zip(*data), types))
|
256
|
+
def unnest(self, data: Iterable[Sequence[Any]], types: Iterable[str]) -> Fragment:
|
257
|
+
nested = [nest_for_type(x, t) for x, t in zip(zip(*data), types)]
|
264
258
|
if not nested:
|
265
|
-
nested =
|
266
|
-
return Fragment(["UNNEST(",
|
259
|
+
nested = [nest_for_type([], t) for t in types]
|
260
|
+
return Fragment(["UNNEST(", self.list(nested), ")"])
|
267
261
|
|
268
262
|
|
269
263
|
sql = SQLFormatter()
|
@@ -296,7 +290,7 @@ def lit(text: str) -> Fragment:
|
|
296
290
|
return Fragment([text])
|
297
291
|
|
298
292
|
|
299
|
-
def any_all(frags:
|
293
|
+
def any_all(frags: list[Fragment], op: str, base_case: str) -> Fragment:
|
300
294
|
if not frags:
|
301
295
|
return lit(base_case)
|
302
296
|
parts = join_parts(frags, prefix="(", infix=f") {op} (", suffix=")")
|
sql_athame/dataclasses.py
CHANGED
@@ -1,60 +1,57 @@
|
|
1
1
|
import datetime
|
2
2
|
import uuid
|
3
|
-
from
|
3
|
+
from collections.abc import AsyncGenerator, Iterable, Mapping
|
4
|
+
from dataclasses import Field, InitVar, dataclass, fields
|
4
5
|
from typing import (
|
6
|
+
Annotated,
|
5
7
|
Any,
|
6
|
-
AsyncGenerator,
|
7
8
|
Callable,
|
8
|
-
Dict,
|
9
|
-
Iterable,
|
10
|
-
Iterator,
|
11
|
-
List,
|
12
|
-
Mapping,
|
13
9
|
Optional,
|
14
|
-
Set,
|
15
|
-
Tuple,
|
16
|
-
Type,
|
17
10
|
TypeVar,
|
18
11
|
Union,
|
12
|
+
get_origin,
|
13
|
+
get_type_hints,
|
19
14
|
)
|
20
15
|
|
16
|
+
from typing_extensions import TypeAlias
|
17
|
+
|
21
18
|
from .base import Fragment, sql
|
22
19
|
|
23
|
-
Where = Union[Fragment, Iterable[Fragment]]
|
20
|
+
Where: TypeAlias = Union[Fragment, Iterable[Fragment]]
|
24
21
|
# KLUDGE to avoid a string argument being valid
|
25
|
-
SequenceOfStrings = Union[
|
26
|
-
FieldNames = SequenceOfStrings
|
27
|
-
FieldNamesSet = Union[SequenceOfStrings,
|
22
|
+
SequenceOfStrings: TypeAlias = Union[list[str], tuple[str, ...]]
|
23
|
+
FieldNames: TypeAlias = SequenceOfStrings
|
24
|
+
FieldNamesSet: TypeAlias = Union[SequenceOfStrings, set[str]]
|
28
25
|
|
29
|
-
Connection = Any
|
30
|
-
Pool = Any
|
26
|
+
Connection: TypeAlias = Any
|
27
|
+
Pool: TypeAlias = Any
|
31
28
|
|
32
29
|
|
33
30
|
@dataclass
|
34
31
|
class ColumnInfo:
|
35
32
|
type: str
|
36
|
-
create_type: str
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
)
|
57
|
-
|
33
|
+
create_type: str = ""
|
34
|
+
nullable: bool = False
|
35
|
+
_constraints: tuple[str, ...] = ()
|
36
|
+
|
37
|
+
constraints: InitVar[Union[str, Iterable[str], None]] = None
|
38
|
+
|
39
|
+
def __post_init__(self, constraints: Union[str, Iterable[str], None]) -> None:
|
40
|
+
if self.create_type == "":
|
41
|
+
self.create_type = self.type
|
42
|
+
self.type = sql_create_type_map.get(self.type.upper(), self.type)
|
43
|
+
if constraints is not None:
|
44
|
+
if type(constraints) is str:
|
45
|
+
constraints = (constraints,)
|
46
|
+
self._constraints = tuple(constraints)
|
47
|
+
|
48
|
+
def create_table_string(self) -> str:
|
49
|
+
parts = (
|
50
|
+
self.create_type,
|
51
|
+
*(() if self.nullable else ("NOT NULL",)),
|
52
|
+
*self._constraints,
|
53
|
+
)
|
54
|
+
return " ".join(parts)
|
58
55
|
|
59
56
|
|
60
57
|
sql_create_type_map = {
|
@@ -64,43 +61,37 @@ sql_create_type_map = {
|
|
64
61
|
}
|
65
62
|
|
66
63
|
|
67
|
-
sql_type_map = {
|
68
|
-
Optional[bool]: ("BOOLEAN",),
|
69
|
-
Optional[bytes]: ("BYTEA",),
|
70
|
-
Optional[datetime.date]: ("DATE",),
|
71
|
-
Optional[datetime.datetime]: ("TIMESTAMP",),
|
72
|
-
Optional[float]: ("DOUBLE PRECISION",),
|
73
|
-
Optional[int]: ("INTEGER",),
|
74
|
-
Optional[str]: ("TEXT",),
|
75
|
-
Optional[uuid.UUID]: ("UUID",),
|
76
|
-
bool: ("BOOLEAN",
|
77
|
-
bytes: ("BYTEA",
|
78
|
-
datetime.date: ("DATE",
|
79
|
-
datetime.datetime: ("TIMESTAMP",
|
80
|
-
float: ("DOUBLE PRECISION",
|
81
|
-
int: ("INTEGER",
|
82
|
-
str: ("TEXT",
|
83
|
-
uuid.UUID: ("UUID",
|
64
|
+
sql_type_map: dict[Any, tuple[str, bool]] = {
|
65
|
+
Optional[bool]: ("BOOLEAN", True),
|
66
|
+
Optional[bytes]: ("BYTEA", True),
|
67
|
+
Optional[datetime.date]: ("DATE", True),
|
68
|
+
Optional[datetime.datetime]: ("TIMESTAMP", True),
|
69
|
+
Optional[float]: ("DOUBLE PRECISION", True),
|
70
|
+
Optional[int]: ("INTEGER", True),
|
71
|
+
Optional[str]: ("TEXT", True),
|
72
|
+
Optional[uuid.UUID]: ("UUID", True),
|
73
|
+
bool: ("BOOLEAN", False),
|
74
|
+
bytes: ("BYTEA", False),
|
75
|
+
datetime.date: ("DATE", False),
|
76
|
+
datetime.datetime: ("TIMESTAMP", False),
|
77
|
+
float: ("DOUBLE PRECISION", False),
|
78
|
+
int: ("INTEGER", False),
|
79
|
+
str: ("TEXT", False),
|
80
|
+
uuid.UUID: ("UUID", False),
|
84
81
|
}
|
85
82
|
|
86
83
|
|
87
|
-
def column_info_for_field(field):
|
88
|
-
if "sql_athame" in field.metadata:
|
89
|
-
return field.metadata["sql_athame"]
|
90
|
-
type, *constraints = sql_type_map[field.type]
|
91
|
-
return ColumnInfo(type, type, tuple(constraints))
|
92
|
-
|
93
|
-
|
94
84
|
T = TypeVar("T", bound="ModelBase")
|
95
85
|
U = TypeVar("U")
|
96
86
|
|
97
87
|
|
98
|
-
class ModelBase
|
99
|
-
_column_info: Optional[
|
100
|
-
_cache:
|
88
|
+
class ModelBase:
|
89
|
+
_column_info: Optional[dict[str, ColumnInfo]]
|
90
|
+
_cache: dict[tuple, Any]
|
101
91
|
table_name: str
|
102
|
-
primary_key_names:
|
92
|
+
primary_key_names: tuple[str, ...]
|
103
93
|
array_safe_insert: bool
|
94
|
+
_type_hints: dict[str, type]
|
104
95
|
|
105
96
|
def __init_subclass__(
|
106
97
|
cls,
|
@@ -138,27 +129,34 @@ class ModelBase(Mapping[str, Any]):
|
|
138
129
|
cls._cache[key] = thunk()
|
139
130
|
return cls._cache[key]
|
140
131
|
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
return iter(self.keys())
|
149
|
-
|
150
|
-
def __len__(self) -> int:
|
151
|
-
return len(self.keys())
|
132
|
+
@classmethod
|
133
|
+
def type_hints(cls) -> dict[str, type]:
|
134
|
+
try:
|
135
|
+
return cls._type_hints
|
136
|
+
except AttributeError:
|
137
|
+
cls._type_hints = get_type_hints(cls, include_extras=True)
|
138
|
+
return cls._type_hints
|
152
139
|
|
153
|
-
|
154
|
-
|
140
|
+
@classmethod
|
141
|
+
def column_info_for_field(cls, field: Field) -> ColumnInfo:
|
142
|
+
type_info = cls.type_hints()[field.name]
|
143
|
+
base_type = type_info
|
144
|
+
if get_origin(type_info) is Annotated:
|
145
|
+
base_type = type_info.__origin__ # type: ignore
|
146
|
+
for md in type_info.__metadata__: # type: ignore
|
147
|
+
if isinstance(md, ColumnInfo):
|
148
|
+
return md
|
149
|
+
type, nullable = sql_type_map[base_type]
|
150
|
+
return ColumnInfo(type=type, nullable=nullable)
|
155
151
|
|
156
152
|
@classmethod
|
157
153
|
def column_info(cls, column: str) -> ColumnInfo:
|
158
154
|
try:
|
159
155
|
return cls._column_info[column] # type: ignore
|
160
156
|
except AttributeError:
|
161
|
-
cls._column_info = {
|
157
|
+
cls._column_info = {
|
158
|
+
f.name: cls.column_info_for_field(f) for f in cls._fields()
|
159
|
+
}
|
162
160
|
return cls._column_info[column]
|
163
161
|
|
164
162
|
@classmethod
|
@@ -166,17 +164,17 @@ class ModelBase(Mapping[str, Any]):
|
|
166
164
|
return sql.identifier(cls.table_name, prefix=prefix)
|
167
165
|
|
168
166
|
@classmethod
|
169
|
-
def primary_key_names_sql(cls, *, prefix: Optional[str] = None) ->
|
167
|
+
def primary_key_names_sql(cls, *, prefix: Optional[str] = None) -> list[Fragment]:
|
170
168
|
return [sql.identifier(pk, prefix=prefix) for pk in cls.primary_key_names]
|
171
169
|
|
172
170
|
@classmethod
|
173
|
-
def field_names(cls, *, exclude: FieldNamesSet = ()) ->
|
171
|
+
def field_names(cls, *, exclude: FieldNamesSet = ()) -> list[str]:
|
174
172
|
return [f.name for f in cls._fields() if f.name not in exclude]
|
175
173
|
|
176
174
|
@classmethod
|
177
175
|
def field_names_sql(
|
178
176
|
cls, *, prefix: Optional[str] = None, exclude: FieldNamesSet = ()
|
179
|
-
) ->
|
177
|
+
) -> list[Fragment]:
|
180
178
|
return [
|
181
179
|
sql.identifier(f, prefix=prefix) for f in cls.field_names(exclude=exclude)
|
182
180
|
]
|
@@ -186,9 +184,9 @@ class ModelBase(Mapping[str, Any]):
|
|
186
184
|
|
187
185
|
@classmethod
|
188
186
|
def _get_field_values_fn(
|
189
|
-
cls:
|
190
|
-
) -> Callable[[T],
|
191
|
-
env:
|
187
|
+
cls: type[T], exclude: FieldNamesSet = ()
|
188
|
+
) -> Callable[[T], list[Any]]:
|
189
|
+
env: dict[str, Any] = {}
|
192
190
|
func = ["def get_field_values(self): return ["]
|
193
191
|
for f in cls._fields():
|
194
192
|
if f.name not in exclude:
|
@@ -197,7 +195,7 @@ class ModelBase(Mapping[str, Any]):
|
|
197
195
|
exec(" ".join(func), env)
|
198
196
|
return env["get_field_values"]
|
199
197
|
|
200
|
-
def field_values(self, *, exclude: FieldNamesSet = ()) ->
|
198
|
+
def field_values(self, *, exclude: FieldNamesSet = ()) -> list[Any]:
|
201
199
|
get_field_values = self._cached(
|
202
200
|
("get_field_values", tuple(sorted(exclude))),
|
203
201
|
lambda: self._get_field_values_fn(exclude),
|
@@ -206,7 +204,7 @@ class ModelBase(Mapping[str, Any]):
|
|
206
204
|
|
207
205
|
def field_values_sql(
|
208
206
|
self, *, exclude: FieldNamesSet = (), default_none: bool = False
|
209
|
-
) ->
|
207
|
+
) -> list[Fragment]:
|
210
208
|
if default_none:
|
211
209
|
return [
|
212
210
|
sql.literal("DEFAULT") if value is None else sql.value(value)
|
@@ -217,7 +215,7 @@ class ModelBase(Mapping[str, Any]):
|
|
217
215
|
|
218
216
|
@classmethod
|
219
217
|
def from_tuple(
|
220
|
-
cls:
|
218
|
+
cls: type[T], tup: tuple, *, offset: int = 0, exclude: FieldNamesSet = ()
|
221
219
|
) -> T:
|
222
220
|
names = (f.name for f in cls._fields() if f.name not in exclude)
|
223
221
|
kwargs = {name: tup[offset] for offset, name in enumerate(names, start=offset)}
|
@@ -225,14 +223,14 @@ class ModelBase(Mapping[str, Any]):
|
|
225
223
|
|
226
224
|
@classmethod
|
227
225
|
def from_dict(
|
228
|
-
cls:
|
226
|
+
cls: type[T], dct: dict[str, Any], *, exclude: FieldNamesSet = ()
|
229
227
|
) -> T:
|
230
228
|
names = {f.name for f in cls._fields() if f.name not in exclude}
|
231
229
|
kwargs = {k: v for k, v in dct.items() if k in names}
|
232
230
|
return cls(**kwargs)
|
233
231
|
|
234
232
|
@classmethod
|
235
|
-
def ensure_model(cls:
|
233
|
+
def ensure_model(cls: type[T], row: Union[T, Mapping[str, Any]]) -> T:
|
236
234
|
if isinstance(row, cls):
|
237
235
|
return row
|
238
236
|
return cls(**row)
|
@@ -286,7 +284,7 @@ class ModelBase(Mapping[str, Any]):
|
|
286
284
|
|
287
285
|
@classmethod
|
288
286
|
async def select_cursor(
|
289
|
-
cls:
|
287
|
+
cls: type[T],
|
290
288
|
connection: Connection,
|
291
289
|
order_by: Union[FieldNames, str] = (),
|
292
290
|
for_update: bool = False,
|
@@ -301,12 +299,12 @@ class ModelBase(Mapping[str, Any]):
|
|
301
299
|
|
302
300
|
@classmethod
|
303
301
|
async def select(
|
304
|
-
cls:
|
302
|
+
cls: type[T],
|
305
303
|
connection_or_pool: Union[Connection, Pool],
|
306
304
|
order_by: Union[FieldNames, str] = (),
|
307
305
|
for_update: bool = False,
|
308
306
|
where: Where = (),
|
309
|
-
) ->
|
307
|
+
) -> list[T]:
|
310
308
|
return [
|
311
309
|
cls(**row)
|
312
310
|
for row in await connection_or_pool.fetch(
|
@@ -315,7 +313,7 @@ class ModelBase(Mapping[str, Any]):
|
|
315
313
|
]
|
316
314
|
|
317
315
|
@classmethod
|
318
|
-
def create_sql(cls:
|
316
|
+
def create_sql(cls: type[T], **kwargs: Any) -> Fragment:
|
319
317
|
return sql(
|
320
318
|
"INSERT INTO {table} ({fields}) VALUES ({values}) RETURNING {out_fields}",
|
321
319
|
table=cls.table_name_sql(),
|
@@ -326,7 +324,7 @@ class ModelBase(Mapping[str, Any]):
|
|
326
324
|
|
327
325
|
@classmethod
|
328
326
|
async def create(
|
329
|
-
cls:
|
327
|
+
cls: type[T], connection_or_pool: Union[Connection, Pool], **kwargs: Any
|
330
328
|
) -> T:
|
331
329
|
row = await connection_or_pool.fetchrow(*cls.create_sql(**kwargs))
|
332
330
|
return cls(**row)
|
@@ -375,11 +373,10 @@ class ModelBase(Mapping[str, Any]):
|
|
375
373
|
self.upsert_sql(self.insert_sql(exclude=exclude), exclude=exclude),
|
376
374
|
)
|
377
375
|
result = await connection_or_pool.fetchrow(*query)
|
378
|
-
|
379
|
-
return is_update
|
376
|
+
return result["xmax"] != 0
|
380
377
|
|
381
378
|
@classmethod
|
382
|
-
def delete_multiple_sql(cls:
|
379
|
+
def delete_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
|
383
380
|
cached = cls._cached(
|
384
381
|
("delete_multiple_sql",),
|
385
382
|
lambda: sql(
|
@@ -397,12 +394,12 @@ class ModelBase(Mapping[str, Any]):
|
|
397
394
|
|
398
395
|
@classmethod
|
399
396
|
async def delete_multiple(
|
400
|
-
cls:
|
397
|
+
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
|
401
398
|
) -> str:
|
402
399
|
return await connection_or_pool.execute(*cls.delete_multiple_sql(rows))
|
403
400
|
|
404
401
|
@classmethod
|
405
|
-
def insert_multiple_sql(cls:
|
402
|
+
def insert_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
|
406
403
|
cached = cls._cached(
|
407
404
|
("insert_multiple_sql",),
|
408
405
|
lambda: sql(
|
@@ -419,7 +416,7 @@ class ModelBase(Mapping[str, Any]):
|
|
419
416
|
)
|
420
417
|
|
421
418
|
@classmethod
|
422
|
-
def insert_multiple_array_safe_sql(cls:
|
419
|
+
def insert_multiple_array_safe_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
|
423
420
|
return sql(
|
424
421
|
"INSERT INTO {table} ({fields}) VALUES {values}",
|
425
422
|
table=cls.table_name_sql(),
|
@@ -432,13 +429,13 @@ class ModelBase(Mapping[str, Any]):
|
|
432
429
|
|
433
430
|
@classmethod
|
434
431
|
async def insert_multiple_unnest(
|
435
|
-
cls:
|
432
|
+
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
|
436
433
|
) -> str:
|
437
434
|
return await connection_or_pool.execute(*cls.insert_multiple_sql(rows))
|
438
435
|
|
439
436
|
@classmethod
|
440
437
|
async def insert_multiple_array_safe(
|
441
|
-
cls:
|
438
|
+
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
|
442
439
|
) -> str:
|
443
440
|
last = ""
|
444
441
|
for chunk in chunked(rows, 100):
|
@@ -449,7 +446,7 @@ class ModelBase(Mapping[str, Any]):
|
|
449
446
|
|
450
447
|
@classmethod
|
451
448
|
async def insert_multiple(
|
452
|
-
cls:
|
449
|
+
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
|
453
450
|
) -> str:
|
454
451
|
if cls.array_safe_insert:
|
455
452
|
return await cls.insert_multiple_array_safe(connection_or_pool, rows)
|
@@ -458,37 +455,52 @@ class ModelBase(Mapping[str, Any]):
|
|
458
455
|
|
459
456
|
@classmethod
|
460
457
|
async def upsert_multiple_unnest(
|
461
|
-
cls:
|
458
|
+
cls: type[T],
|
459
|
+
connection_or_pool: Union[Connection, Pool],
|
460
|
+
rows: Iterable[T],
|
461
|
+
insert_only: FieldNamesSet = (),
|
462
462
|
) -> str:
|
463
463
|
return await connection_or_pool.execute(
|
464
|
-
*cls.upsert_sql(cls.insert_multiple_sql(rows))
|
464
|
+
*cls.upsert_sql(cls.insert_multiple_sql(rows), exclude=insert_only)
|
465
465
|
)
|
466
466
|
|
467
467
|
@classmethod
|
468
468
|
async def upsert_multiple_array_safe(
|
469
|
-
cls:
|
469
|
+
cls: type[T],
|
470
|
+
connection_or_pool: Union[Connection, Pool],
|
471
|
+
rows: Iterable[T],
|
472
|
+
insert_only: FieldNamesSet = (),
|
470
473
|
) -> str:
|
471
474
|
last = ""
|
472
475
|
for chunk in chunked(rows, 100):
|
473
476
|
last = await connection_or_pool.execute(
|
474
|
-
*cls.upsert_sql(
|
477
|
+
*cls.upsert_sql(
|
478
|
+
cls.insert_multiple_array_safe_sql(chunk), exclude=insert_only
|
479
|
+
)
|
475
480
|
)
|
476
481
|
return last
|
477
482
|
|
478
483
|
@classmethod
|
479
484
|
async def upsert_multiple(
|
480
|
-
cls:
|
485
|
+
cls: type[T],
|
486
|
+
connection_or_pool: Union[Connection, Pool],
|
487
|
+
rows: Iterable[T],
|
488
|
+
insert_only: FieldNamesSet = (),
|
481
489
|
) -> str:
|
482
490
|
if cls.array_safe_insert:
|
483
|
-
return await cls.upsert_multiple_array_safe(
|
491
|
+
return await cls.upsert_multiple_array_safe(
|
492
|
+
connection_or_pool, rows, insert_only=insert_only
|
493
|
+
)
|
484
494
|
else:
|
485
|
-
return await cls.upsert_multiple_unnest(
|
495
|
+
return await cls.upsert_multiple_unnest(
|
496
|
+
connection_or_pool, rows, insert_only=insert_only
|
497
|
+
)
|
486
498
|
|
487
499
|
@classmethod
|
488
500
|
def _get_equal_ignoring_fn(
|
489
|
-
cls:
|
501
|
+
cls: type[T], ignore: FieldNamesSet = ()
|
490
502
|
) -> Callable[[T, T], bool]:
|
491
|
-
env:
|
503
|
+
env: dict[str, Any] = {}
|
492
504
|
func = ["def equal_ignoring(a, b):"]
|
493
505
|
for f in cls._fields():
|
494
506
|
if f.name not in ignore:
|
@@ -499,15 +511,17 @@ class ModelBase(Mapping[str, Any]):
|
|
499
511
|
|
500
512
|
@classmethod
|
501
513
|
async def replace_multiple(
|
502
|
-
cls:
|
514
|
+
cls: type[T],
|
503
515
|
connection: Connection,
|
504
516
|
rows: Union[Iterable[T], Iterable[Mapping[str, Any]]],
|
505
517
|
*,
|
506
518
|
where: Where,
|
507
519
|
ignore: FieldNamesSet = (),
|
508
|
-
|
520
|
+
insert_only: FieldNamesSet = (),
|
521
|
+
) -> tuple[list[T], list[T], list[T]]:
|
522
|
+
ignore = sorted(set(ignore) | set(insert_only))
|
509
523
|
equal_ignoring = cls._cached(
|
510
|
-
("equal_ignoring", tuple(
|
524
|
+
("equal_ignoring", tuple(ignore)),
|
511
525
|
lambda: cls._get_equal_ignoring_fn(ignore),
|
512
526
|
)
|
513
527
|
pending = {row.primary_key(): row for row in map(cls.ensure_model, rows)}
|
@@ -529,7 +543,9 @@ class ModelBase(Mapping[str, Any]):
|
|
529
543
|
created = list(pending.values())
|
530
544
|
|
531
545
|
if created or updated:
|
532
|
-
await cls.upsert_multiple(
|
546
|
+
await cls.upsert_multiple(
|
547
|
+
connection, (*created, *updated), insert_only=insert_only
|
548
|
+
)
|
533
549
|
if deleted:
|
534
550
|
await cls.delete_multiple(connection, deleted)
|
535
551
|
|
@@ -537,33 +553,33 @@ class ModelBase(Mapping[str, Any]):
|
|
537
553
|
|
538
554
|
@classmethod
|
539
555
|
def _get_differences_ignoring_fn(
|
540
|
-
cls:
|
541
|
-
) -> Callable[[T, T],
|
542
|
-
env:
|
556
|
+
cls: type[T], ignore: FieldNamesSet = ()
|
557
|
+
) -> Callable[[T, T], list[str]]:
|
558
|
+
env: dict[str, Any] = {}
|
543
559
|
func = [
|
544
560
|
"def differences_ignoring(a, b):",
|
545
561
|
" diffs = []",
|
546
562
|
]
|
547
563
|
for f in cls._fields():
|
548
564
|
if f.name not in ignore:
|
549
|
-
func.append(
|
550
|
-
f" if a.{f.name} != b.{f.name}: diffs.append({repr(f.name)})"
|
551
|
-
)
|
565
|
+
func.append(f" if a.{f.name} != b.{f.name}: diffs.append({f.name!r})")
|
552
566
|
func += [" return diffs"]
|
553
567
|
exec("\n".join(func), env)
|
554
568
|
return env["differences_ignoring"]
|
555
569
|
|
556
570
|
@classmethod
|
557
571
|
async def replace_multiple_reporting_differences(
|
558
|
-
cls:
|
572
|
+
cls: type[T],
|
559
573
|
connection: Connection,
|
560
574
|
rows: Union[Iterable[T], Iterable[Mapping[str, Any]]],
|
561
575
|
*,
|
562
576
|
where: Where,
|
563
577
|
ignore: FieldNamesSet = (),
|
564
|
-
|
578
|
+
insert_only: FieldNamesSet = (),
|
579
|
+
) -> tuple[list[T], list[tuple[T, T, list[str]]], list[T]]:
|
580
|
+
ignore = sorted(set(ignore) | set(insert_only))
|
565
581
|
differences_ignoring = cls._cached(
|
566
|
-
("differences_ignoring", tuple(
|
582
|
+
("differences_ignoring", tuple(ignore)),
|
567
583
|
lambda: cls._get_differences_ignoring_fn(ignore),
|
568
584
|
)
|
569
585
|
|
@@ -588,7 +604,9 @@ class ModelBase(Mapping[str, Any]):
|
|
588
604
|
|
589
605
|
if created or updated_triples:
|
590
606
|
await cls.upsert_multiple(
|
591
|
-
connection,
|
607
|
+
connection,
|
608
|
+
(*created, *(t[1] for t in updated_triples)),
|
609
|
+
insert_only=insert_only,
|
592
610
|
)
|
593
611
|
if deleted:
|
594
612
|
await cls.delete_multiple(connection, deleted)
|
sql_athame/escape.py
CHANGED
@@ -1,19 +1,20 @@
|
|
1
1
|
import math
|
2
2
|
import uuid
|
3
|
-
from
|
3
|
+
from collections.abc import Sequence
|
4
|
+
from typing import Any
|
4
5
|
|
5
6
|
|
6
7
|
def escape(value: Any) -> str:
|
7
8
|
if isinstance(value, str):
|
8
|
-
return f"E{
|
9
|
+
return f"E{value!r}"
|
9
10
|
elif isinstance(value, float) or isinstance(value, int):
|
10
11
|
if math.isnan(value):
|
11
12
|
raise ValueError("Can't escape NaN float")
|
12
13
|
elif math.isinf(value):
|
13
14
|
raise ValueError("Can't escape infinite float")
|
14
|
-
return f"{
|
15
|
+
return f"{value!r}"
|
15
16
|
elif isinstance(value, uuid.UUID):
|
16
|
-
return f"{
|
17
|
+
return f"{str(value)!r}::UUID"
|
17
18
|
elif isinstance(value, Sequence):
|
18
19
|
args = ", ".join(escape(x) for x in value)
|
19
20
|
return f"ARRAY[{args}]"
|
sql_athame/sqlalchemy.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import TYPE_CHECKING, Any
|
1
|
+
from typing import TYPE_CHECKING, Any
|
2
2
|
|
3
3
|
from .types import FlatPart, Placeholder, Slot
|
4
4
|
|
@@ -7,10 +7,10 @@ try:
|
|
7
7
|
from sqlalchemy.sql.elements import BindParameter
|
8
8
|
|
9
9
|
def sqlalchemy_text_from_fragment(self: "Fragment") -> Any:
|
10
|
-
parts:
|
10
|
+
parts: list[FlatPart] = []
|
11
11
|
self.flatten_into(parts)
|
12
|
-
bindparams:
|
13
|
-
out_parts:
|
12
|
+
bindparams: dict[str, Any] = {}
|
13
|
+
out_parts: list[str] = []
|
14
14
|
for part in parts:
|
15
15
|
if isinstance(part, Slot):
|
16
16
|
out_parts.append(f"(:{part.name})")
|
sql_athame/types.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1
1
|
import dataclasses
|
2
2
|
from typing import TYPE_CHECKING, Any, Union
|
3
3
|
|
4
|
+
from typing_extensions import TypeAlias
|
5
|
+
|
4
6
|
|
5
7
|
@dataclasses.dataclass(eq=False)
|
6
8
|
class Placeholder:
|
@@ -15,8 +17,8 @@ class Slot:
|
|
15
17
|
name: str
|
16
18
|
|
17
19
|
|
18
|
-
Part = Union[str, Placeholder, Slot, "Fragment"]
|
19
|
-
FlatPart = Union[str, Placeholder, Slot]
|
20
|
+
Part: TypeAlias = Union[str, Placeholder, Slot, "Fragment"]
|
21
|
+
FlatPart: TypeAlias = Union[str, Placeholder, Slot]
|
20
22
|
|
21
23
|
if TYPE_CHECKING:
|
22
24
|
from .base import Fragment
|
@@ -0,0 +1,11 @@
|
|
1
|
+
sql_athame/__init__.py,sha256=7OBIMZOcrD2pvfIL-rjD1IGZ3TNQbwyu76a9PWk-yYg,79
|
2
|
+
sql_athame/base.py,sha256=FR7EmC0VkX1VRgvAutSEfYSWhlEYpoqS1Kqxp1jHp6Y,10293
|
3
|
+
sql_athame/dataclasses.py,sha256=qb4EESR6J-iv6UScktMLuKAwH3ZA3IOwCM0v6oMv8Q8,20848
|
4
|
+
sql_athame/escape.py,sha256=kK101xXeFitlvuG-L_hvhdpgGJCtmRTprsn1yEfZKws,758
|
5
|
+
sql_athame/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
|
+
sql_athame/sqlalchemy.py,sha256=aWopfPh3j71XwKmcN_VcHRNlhscI0Sckd4AiyGf8Tpw,1293
|
7
|
+
sql_athame/types.py,sha256=FQ06l9Uc-vo57UrAarvnukILdV2gN1IaYUnHJ_bNYic,475
|
8
|
+
sql_athame-0.4.0a6.dist-info/LICENSE,sha256=xqV29vPFqITcKifYrGPgVIBjq4fdmLSwY3gRUtDKafg,1076
|
9
|
+
sql_athame-0.4.0a6.dist-info/METADATA,sha256=8Ov07iCKPAo35uWW3t9WRRlXIgRlNGGrMpyXnOI6TRs,12845
|
10
|
+
sql_athame-0.4.0a6.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
11
|
+
sql_athame-0.4.0a6.dist-info/RECORD,,
|
@@ -1,11 +0,0 @@
|
|
1
|
-
sql_athame/__init__.py,sha256=rzUQcbzmj3qkPZpL9jI_ALTRv-e1pAV4jSCryWkutlk,130
|
2
|
-
sql_athame/base.py,sha256=fSnHQhh5ULeJ5q32RVUAvpWtF0qoY61B2gEEP59Nrpo,10350
|
3
|
-
sql_athame/dataclasses.py,sha256=EPq9wd2mqcEvT0kAEhrZztlDDV3Hroiwa8GoP3hi_l8,19787
|
4
|
-
sql_athame/escape.py,sha256=LXExbiYtc407yDU4vPieyY2Pq5nypsJFfBc_2-gsbUg,743
|
5
|
-
sql_athame/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
|
-
sql_athame/sqlalchemy.py,sha256=c-pCLE11hTh5I19rY1Vp5E7P7lAaj9i-i7ko2L8rlF4,1305
|
7
|
-
sql_athame/types.py,sha256=7P4OyY0ezRlb2UDD9lpdXiLChnhQcBvHWaG_PKy3jmE,412
|
8
|
-
sql_athame-0.4.0a4.dist-info/LICENSE,sha256=xqV29vPFqITcKifYrGPgVIBjq4fdmLSwY3gRUtDKafg,1076
|
9
|
-
sql_athame-0.4.0a4.dist-info/METADATA,sha256=J5gkrkq4cAwD81oODgQjugY9y3k7-RImhFvqLMaJqr4,12845
|
10
|
-
sql_athame-0.4.0a4.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
11
|
-
sql_athame-0.4.0a4.dist-info/RECORD,,
|
File without changes
|
File without changes
|