sql-athame 0.4.0a5__py3-none-any.whl → 0.4.0a7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_athame/__init__.py +2 -2
- sql_athame/base.py +32 -38
- sql_athame/dataclasses.py +150 -121
- sql_athame/escape.py +5 -4
- sql_athame/sqlalchemy.py +4 -4
- sql_athame/types.py +4 -2
- {sql_athame-0.4.0a5.dist-info → sql_athame-0.4.0a7.dist-info}/METADATA +1 -1
- sql_athame-0.4.0a7.dist-info/RECORD +11 -0
- sql_athame-0.4.0a5.dist-info/RECORD +0 -11
- {sql_athame-0.4.0a5.dist-info → sql_athame-0.4.0a7.dist-info}/LICENSE +0 -0
- {sql_athame-0.4.0a5.dist-info → sql_athame-0.4.0a7.dist-info}/WHEEL +0 -0
sql_athame/__init__.py
CHANGED
@@ -1,2 +1,2 @@
-from .base import Fragment, sql
-from .dataclasses import
+from .base import Fragment, sql
+from .dataclasses import ColumnInfo, ModelBase
sql_athame/base.py
CHANGED
@@ -2,16 +2,11 @@ import dataclasses
 import json
 import re
 import string
+from collections.abc import Iterable, Iterator, Sequence
 from typing import (
 Any,
 Callable,
-Dict,
-Iterable,
-Iterator,
-List,
 Optional,
-Sequence,
-Tuple,
 Union,
 cast,
 overload,
@@ -34,7 +29,7 @@ def auto_numbered(field_name):
 def process_slot_value(
 name: str,
 value: Any,
-placeholders:
+placeholders: dict[str, Placeholder],
 ) -> Union["Fragment", Placeholder]:
 if isinstance(value, Fragment):
 return value
@@ -47,9 +42,9 @@ def process_slot_value(
 @dataclasses.dataclass
 class Fragment:
 __slots__ = ["parts"]
-parts:
+parts: list[Part]

-def flatten_into(self, parts:
+def flatten_into(self, parts: list[FlatPart]) -> None:
 for part in self.parts:
 if isinstance(part, Fragment):
 part.flatten_into(parts)
@@ -70,10 +65,10 @@ class Fragment:
 for i, part in enumerate(flattened.parts):
 if isinstance(part, Slot):
 func.append(
-f" process_slot_value({
+f" process_slot_value({part.name!r}, slots[{part.name!r}], placeholders),"
 )
 elif isinstance(part, str):
-func.append(f" {
+func.append(f" {part!r},")
 else:
 env[f"part_{i}"] = part
 func.append(f" part_{i},")
@@ -82,9 +77,9 @@ class Fragment:
 return env["compiled"] # type: ignore

 def flatten(self) -> "Fragment":
-parts:
+parts: list[FlatPart] = []
 self.flatten_into(parts)
-out_parts:
+out_parts: list[Part] = []
 for part in parts:
 if isinstance(part, str) and out_parts and isinstance(out_parts[-1], str):
 out_parts[-1] += part
@@ -93,9 +88,9 @@ class Fragment:
 return Fragment(out_parts)

 def fill(self, **kwargs: Any) -> "Fragment":
-parts:
-self.flatten_into(cast(
-placeholders:
+parts: list[Part] = []
+self.flatten_into(cast(list[FlatPart], parts))
+placeholders: dict[str, Placeholder] = {}
 for i, part in enumerate(parts):
 if isinstance(part, Slot):
 parts[i] = process_slot_value(
@@ -106,24 +101,24 @@ class Fragment:
 @overload
 def prep_query(
 self, allow_slots: Literal[True]
-) ->
+) -> tuple[str, list[Union[Placeholder, Slot]]]: ...  # pragma: no cover

 @overload
 def prep_query(
 self, allow_slots: Literal[False] = False
-) ->
+) -> tuple[str, list[Placeholder]]: ...  # pragma: no cover

-def prep_query(self, allow_slots: bool = False) ->
-parts:
+def prep_query(self, allow_slots: bool = False) -> tuple[str, list[Any]]:
+parts: list[FlatPart] = []
 self.flatten_into(parts)
-args:
-placeholder_ids:
-slot_ids:
-out_parts:
+args: list[Union[Placeholder, Slot]] = []
+placeholder_ids: dict[Placeholder, int] = {}
+slot_ids: dict[Slot, int] = {}
+out_parts: list[str] = []
 for part in parts:
 if isinstance(part, Slot):
 if not allow_slots:
-raise ValueError(f"Unfilled slot: {
+raise ValueError(f"Unfilled slot: {part.name!r}")
 if part not in slot_ids:
 args.append(part)
 slot_ids[part] = len(args)
@@ -138,7 +133,7 @@ class Fragment:
 out_parts.append(part)
 return "".join(out_parts).strip(), args

-def query(self) ->
+def query(self) -> tuple[str, list[Any]]:
 query, args = self.prep_query()
 placeholder_values = [arg.value for arg in args]
 return query, placeholder_values
@@ -146,16 +141,16 @@ class Fragment:
 def sqlalchemy_text(self) -> Any:
 return sqlalchemy_text_from_fragment(self)

-def prepare(self) ->
+def prepare(self) -> tuple[str, Callable[..., list[Any]]]:
 query, args = self.prep_query(allow_slots=True)
-env =
+env = {}
 func = [
 "def generate_args(**kwargs):",
 " return [",
 ]
 for i, arg in enumerate(args):
 if isinstance(arg, Slot):
-func.append(f" kwargs[{
+func.append(f" kwargs[{arg.name!r}],")
 else:
 env[f"value_{i}"] = arg.value
 func.append(f" value_{i},")
@@ -178,10 +173,10 @@ class SQLFormatter:
 if not preserve_formatting:
 fmt = newline_whitespace_re.sub(" ", fmt)
 fmtr = string.Formatter()
-parts:
-placeholders:
+parts: list[Part] = []
+placeholders: dict[str, Placeholder] = {}
 next_auto_field = 0
-for literal_text, field_name,
+for literal_text, field_name, _format_spec, _conversion in fmtr.parse(fmt):
 parts.append(literal_text)
 if field_name is not None:
 if auto_numbered(field_name):
@@ -258,12 +253,11 @@ class SQLFormatter:
 parts = parts[0]
 return Fragment(list(join_parts(parts, infix=", ")))

-
-
-nested = list(nest_for_type(x, t) for x, t in zip(zip(*data), types))
+def unnest(self, data: Iterable[Sequence[Any]], types: Iterable[str]) -> Fragment:
+nested = [nest_for_type(x, t) for x, t in zip(zip(*data), types)]
 if not nested:
-nested =
-return Fragment(["UNNEST(",
+nested = [nest_for_type([], t) for t in types]
+return Fragment(["UNNEST(", self.list(nested), ")"])


 sql = SQLFormatter()
@@ -296,7 +290,7 @@ def lit(text: str) -> Fragment:
 return Fragment([text])


-def any_all(frags:
+def any_all(frags: list[Fragment], op: str, base_case: str) -> Fragment:
 if not frags:
 return lit(base_case)
 parts = join_parts(frags, prefix="(", infix=f") {op} (", suffix=")")
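Most of the base.py changes modernize the annotations (builtin list/dict/tuple generics plus collections.abc imports), which also makes the public return types explicit: Fragment.query() is now annotated as tuple[str, list[Any]]. A minimal sketch of how that pair is normally consumed, based on the library's documented usage; the table and column names below are made up:

from sql_athame import sql

# sql(...) builds a Fragment; query() flattens it into (sql_text, args),
# the pair now annotated as tuple[str, list[Any]] in this release.
query, args = sql("SELECT * FROM users WHERE name = {}", "Alice").query()
print(query)  # something like: SELECT * FROM users WHERE name = $1
print(args)   # ['Alice']

# With asyncpg-style drivers the pair is typically splatted straight into fetch():
#     rows = await conn.fetch(*sql("SELECT * FROM users WHERE id = {}", user_id).query())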
sql_athame/dataclasses.py
CHANGED
@@ -1,60 +1,87 @@
 import datetime
+import functools
 import uuid
-from
+from collections.abc import AsyncGenerator, Iterable, Mapping
+from dataclasses import Field, InitVar, dataclass, fields
 from typing import (
+Annotated,
 Any,
-AsyncGenerator,
 Callable,
-Dict,
-Iterable,
-Iterator,
-List,
-Mapping,
 Optional,
-Set,
-Tuple,
-Type,
 TypeVar,
 Union,
+get_origin,
+get_type_hints,
 )

+from typing_extensions import TypeAlias
+
 from .base import Fragment, sql

-Where = Union[Fragment, Iterable[Fragment]]
+Where: TypeAlias = Union[Fragment, Iterable[Fragment]]
 # KLUDGE to avoid a string argument being valid
-SequenceOfStrings = Union[
-FieldNames = SequenceOfStrings
-FieldNamesSet = Union[SequenceOfStrings,
+SequenceOfStrings: TypeAlias = Union[list[str], tuple[str, ...]]
+FieldNames: TypeAlias = SequenceOfStrings
+FieldNamesSet: TypeAlias = Union[SequenceOfStrings, set[str]]

-Connection = Any
-Pool = Any
+Connection: TypeAlias = Any
+Pool: TypeAlias = Any


 @dataclass
 class ColumnInfo:
-type: str
-create_type: str
-
-
-
-
-
+type: Optional[str] = None
+create_type: Optional[str] = None
+nullable: Optional[bool] = None
+_constraints: tuple[str, ...] = ()
+
+constraints: InitVar[Union[str, Iterable[str], None]] = None
+
+def __post_init__(self, constraints: Union[str, Iterable[str], None]) -> None:
+if constraints is not None:
+if type(constraints) is str:
+constraints = (constraints,)
+self._constraints = tuple(constraints)
+
+@staticmethod
+def merge(a: "ColumnInfo", b: "ColumnInfo") -> "ColumnInfo":
+return ColumnInfo(
+type=b.type if b.type is not None else a.type,
+create_type=b.create_type if b.create_type is not None else a.create_type,
+nullable=b.nullable if b.nullable is not None else a.nullable,
+_constraints=(*a._constraints, *b._constraints),
+)

-def model_field_metadata(
-type: str, constraints: Union[str, Iterable[str]] = ()
-) -> Dict[str, Any]:
-if isinstance(constraints, str):
-constraints = (constraints,)
-info = ColumnInfo(
-sql_create_type_map.get(type.upper(), type), type, tuple(constraints)
-)
-return {"sql_athame": info}

+@dataclass
+class ConcreteColumnInfo:
+type: str
+create_type: str
+nullable: bool
+constraints: tuple[str, ...]
+
+@staticmethod
+def from_column_info(name: str, *args: ColumnInfo) -> "ConcreteColumnInfo":
+info = functools.reduce(ColumnInfo.merge, args, ColumnInfo())
+if info.create_type is None and info.type is not None:
+info.create_type = info.type
+info.type = sql_create_type_map.get(info.type.upper(), info.type)
+if type(info.type) is not str or type(info.create_type) is not str:
+raise ValueError(f"Missing SQL type for column {name!r}")
+return ConcreteColumnInfo(
+type=info.type,
+create_type=info.create_type,
+nullable=bool(info.nullable),
+constraints=info._constraints,
+)

-def
-
-
-
+def create_table_string(self) -> str:
+parts = (
+self.create_type,
+*(() if self.nullable else ("NOT NULL",)),
+*self.constraints,
+)
+return " ".join(parts)


 sql_create_type_map = {
@@ -64,43 +91,37 @@ sql_create_type_map = {
 }


-sql_type_map = {
-Optional[bool]: ("BOOLEAN",),
-Optional[bytes]: ("BYTEA",),
-Optional[datetime.date]: ("DATE",),
-Optional[datetime.datetime]: ("TIMESTAMP",),
-Optional[float]: ("DOUBLE PRECISION",),
-Optional[int]: ("INTEGER",),
-Optional[str]: ("TEXT",),
-Optional[uuid.UUID]: ("UUID",),
-bool: ("BOOLEAN",
-bytes: ("BYTEA",
-datetime.date: ("DATE",
-datetime.datetime: ("TIMESTAMP",
-float: ("DOUBLE PRECISION",
-int: ("INTEGER",
-str: ("TEXT",
-uuid.UUID: ("UUID",
+sql_type_map: dict[Any, tuple[str, bool]] = {
+Optional[bool]: ("BOOLEAN", True),
+Optional[bytes]: ("BYTEA", True),
+Optional[datetime.date]: ("DATE", True),
+Optional[datetime.datetime]: ("TIMESTAMP", True),
+Optional[float]: ("DOUBLE PRECISION", True),
+Optional[int]: ("INTEGER", True),
+Optional[str]: ("TEXT", True),
+Optional[uuid.UUID]: ("UUID", True),
+bool: ("BOOLEAN", False),
+bytes: ("BYTEA", False),
+datetime.date: ("DATE", False),
+datetime.datetime: ("TIMESTAMP", False),
+float: ("DOUBLE PRECISION", False),
+int: ("INTEGER", False),
+str: ("TEXT", False),
+uuid.UUID: ("UUID", False),
 }


-def column_info_for_field(field):
-if "sql_athame" in field.metadata:
-return field.metadata["sql_athame"]
-type, *constraints = sql_type_map[field.type]
-return ColumnInfo(type, type, tuple(constraints))
-
-
 T = TypeVar("T", bound="ModelBase")
 U = TypeVar("U")


-class ModelBase
-_column_info: Optional[
-_cache:
+class ModelBase:
+_column_info: Optional[dict[str, ConcreteColumnInfo]]
+_cache: dict[tuple, Any]
 table_name: str
-primary_key_names:
+primary_key_names: tuple[str, ...]
 array_safe_insert: bool
+_type_hints: dict[str, type]

 def __init_subclass__(
 cls,
@@ -138,27 +159,38 @@ class ModelBase(Mapping[str, Any]):
 cls._cache[key] = thunk()
 return cls._cache[key]

-def keys(self):
-return self.field_names()
-
-def __getitem__(self, key: str) -> Any:
-return getattr(self, key)
-
-def __iter__(self) -> Iterator[Any]:
-return iter(self.keys())
-
-def __len__(self) -> int:
-return len(self.keys())
-
-def get(self, key: str, default: Any = None) -> Any:
-return getattr(self, key, default)
-
 @classmethod
-def
+def type_hints(cls) -> dict[str, type]:
+try:
+return cls._type_hints
+except AttributeError:
+cls._type_hints = get_type_hints(cls, include_extras=True)
+return cls._type_hints
+
+@classmethod
+def column_info_for_field(cls, field: Field) -> ConcreteColumnInfo:
+type_info = cls.type_hints()[field.name]
+base_type = type_info
+if get_origin(type_info) is Annotated:
+base_type = type_info.__origin__ # type: ignore
+info = []
+if base_type in sql_type_map:
+_type, nullable = sql_type_map[base_type]
+info.append(ColumnInfo(type=_type, nullable=nullable))
+if get_origin(type_info) is Annotated:
+for md in type_info.__metadata__: # type: ignore
+if isinstance(md, ColumnInfo):
+info.append(md)
+return ConcreteColumnInfo.from_column_info(field.name, *info)
+
+@classmethod
+def column_info(cls, column: str) -> ConcreteColumnInfo:
 try:
 return cls._column_info[column] # type: ignore
 except AttributeError:
-cls._column_info = {
+cls._column_info = {
+f.name: cls.column_info_for_field(f) for f in cls._fields()
+}
 return cls._column_info[column]

 @classmethod
@@ -166,17 +198,17 @@ class ModelBase(Mapping[str, Any]):
 return sql.identifier(cls.table_name, prefix=prefix)

 @classmethod
-def primary_key_names_sql(cls, *, prefix: Optional[str] = None) ->
+def primary_key_names_sql(cls, *, prefix: Optional[str] = None) -> list[Fragment]:
 return [sql.identifier(pk, prefix=prefix) for pk in cls.primary_key_names]

 @classmethod
-def field_names(cls, *, exclude: FieldNamesSet = ()) ->
+def field_names(cls, *, exclude: FieldNamesSet = ()) -> list[str]:
 return [f.name for f in cls._fields() if f.name not in exclude]

 @classmethod
 def field_names_sql(
 cls, *, prefix: Optional[str] = None, exclude: FieldNamesSet = ()
-) ->
+) -> list[Fragment]:
 return [
 sql.identifier(f, prefix=prefix) for f in cls.field_names(exclude=exclude)
 ]
@@ -186,9 +218,9 @@ class ModelBase(Mapping[str, Any]):

 @classmethod
 def _get_field_values_fn(
-cls:
-) -> Callable[[T],
-env:
+cls: type[T], exclude: FieldNamesSet = ()
+) -> Callable[[T], list[Any]]:
+env: dict[str, Any] = {}
 func = ["def get_field_values(self): return ["]
 for f in cls._fields():
 if f.name not in exclude:
@@ -197,7 +229,7 @@ class ModelBase(Mapping[str, Any]):
 exec(" ".join(func), env)
 return env["get_field_values"]

-def field_values(self, *, exclude: FieldNamesSet = ()) ->
+def field_values(self, *, exclude: FieldNamesSet = ()) -> list[Any]:
 get_field_values = self._cached(
 ("get_field_values", tuple(sorted(exclude))),
 lambda: self._get_field_values_fn(exclude),
@@ -206,7 +238,7 @@ class ModelBase(Mapping[str, Any]):

 def field_values_sql(
 self, *, exclude: FieldNamesSet = (), default_none: bool = False
-) ->
+) -> list[Fragment]:
 if default_none:
 return [
 sql.literal("DEFAULT") if value is None else sql.value(value)
@@ -217,7 +249,7 @@ class ModelBase(Mapping[str, Any]):

 @classmethod
 def from_tuple(
-cls:
+cls: type[T], tup: tuple, *, offset: int = 0, exclude: FieldNamesSet = ()
 ) -> T:
 names = (f.name for f in cls._fields() if f.name not in exclude)
 kwargs = {name: tup[offset] for offset, name in enumerate(names, start=offset)}
@@ -225,14 +257,14 @@ class ModelBase(Mapping[str, Any]):

 @classmethod
 def from_dict(
-cls:
+cls: type[T], dct: dict[str, Any], *, exclude: FieldNamesSet = ()
 ) -> T:
 names = {f.name for f in cls._fields() if f.name not in exclude}
 kwargs = {k: v for k, v in dct.items() if k in names}
 return cls(**kwargs)

 @classmethod
-def ensure_model(cls:
+def ensure_model(cls: type[T], row: Union[T, Mapping[str, Any]]) -> T:
 if isinstance(row, cls):
 return row
 return cls(**row)
@@ -286,7 +318,7 @@ class ModelBase(Mapping[str, Any]):

 @classmethod
 async def select_cursor(
-cls:
+cls: type[T],
 connection: Connection,
 order_by: Union[FieldNames, str] = (),
 for_update: bool = False,
@@ -301,12 +333,12 @@ class ModelBase(Mapping[str, Any]):

 @classmethod
 async def select(
-cls:
+cls: type[T],
 connection_or_pool: Union[Connection, Pool],
 order_by: Union[FieldNames, str] = (),
 for_update: bool = False,
 where: Where = (),
-) ->
+) -> list[T]:
 return [
 cls(**row)
 for row in await connection_or_pool.fetch(
@@ -315,7 +347,7 @@ class ModelBase(Mapping[str, Any]):
 ]

 @classmethod
-def create_sql(cls:
+def create_sql(cls: type[T], **kwargs: Any) -> Fragment:
 return sql(
 "INSERT INTO {table} ({fields}) VALUES ({values}) RETURNING {out_fields}",
 table=cls.table_name_sql(),
@@ -326,7 +358,7 @@ class ModelBase(Mapping[str, Any]):

 @classmethod
 async def create(
-cls:
+cls: type[T], connection_or_pool: Union[Connection, Pool], **kwargs: Any
 ) -> T:
 row = await connection_or_pool.fetchrow(*cls.create_sql(**kwargs))
 return cls(**row)
@@ -375,11 +407,10 @@ class ModelBase(Mapping[str, Any]):
 self.upsert_sql(self.insert_sql(exclude=exclude), exclude=exclude),
 )
 result = await connection_or_pool.fetchrow(*query)
-
-return is_update
+return result["xmax"] != 0

 @classmethod
-def delete_multiple_sql(cls:
+def delete_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
 cached = cls._cached(
 ("delete_multiple_sql",),
 lambda: sql(
@@ -397,12 +428,12 @@ class ModelBase(Mapping[str, Any]):

 @classmethod
 async def delete_multiple(
-cls:
+cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
 ) -> str:
 return await connection_or_pool.execute(*cls.delete_multiple_sql(rows))

 @classmethod
-def insert_multiple_sql(cls:
+def insert_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
 cached = cls._cached(
 ("insert_multiple_sql",),
 lambda: sql(
@@ -419,7 +450,7 @@ class ModelBase(Mapping[str, Any]):
 )

 @classmethod
-def insert_multiple_array_safe_sql(cls:
+def insert_multiple_array_safe_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
 return sql(
 "INSERT INTO {table} ({fields}) VALUES {values}",
 table=cls.table_name_sql(),
@@ -432,13 +463,13 @@ class ModelBase(Mapping[str, Any]):

 @classmethod
 async def insert_multiple_unnest(
-cls:
+cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
 ) -> str:
 return await connection_or_pool.execute(*cls.insert_multiple_sql(rows))

 @classmethod
 async def insert_multiple_array_safe(
-cls:
+cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
 ) -> str:
 last = ""
 for chunk in chunked(rows, 100):
@@ -449,7 +480,7 @@ class ModelBase(Mapping[str, Any]):

 @classmethod
 async def insert_multiple(
-cls:
+cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
 ) -> str:
 if cls.array_safe_insert:
 return await cls.insert_multiple_array_safe(connection_or_pool, rows)
@@ -458,7 +489,7 @@ class ModelBase(Mapping[str, Any]):

 @classmethod
 async def upsert_multiple_unnest(
-cls:
+cls: type[T],
 connection_or_pool: Union[Connection, Pool],
 rows: Iterable[T],
 insert_only: FieldNamesSet = (),
@@ -469,7 +500,7 @@ class ModelBase(Mapping[str, Any]):

 @classmethod
 async def upsert_multiple_array_safe(
-cls:
+cls: type[T],
 connection_or_pool: Union[Connection, Pool],
 rows: Iterable[T],
 insert_only: FieldNamesSet = (),
@@ -485,7 +516,7 @@ class ModelBase(Mapping[str, Any]):

 @classmethod
 async def upsert_multiple(
-cls:
+cls: type[T],
 connection_or_pool: Union[Connection, Pool],
 rows: Iterable[T],
 insert_only: FieldNamesSet = (),
@@ -501,9 +532,9 @@ class ModelBase(Mapping[str, Any]):

 @classmethod
 def _get_equal_ignoring_fn(
-cls:
+cls: type[T], ignore: FieldNamesSet = ()
 ) -> Callable[[T, T], bool]:
-env:
+env: dict[str, Any] = {}
 func = ["def equal_ignoring(a, b):"]
 for f in cls._fields():
 if f.name not in ignore:
@@ -514,14 +545,14 @@ class ModelBase(Mapping[str, Any]):

 @classmethod
 async def replace_multiple(
-cls:
+cls: type[T],
 connection: Connection,
 rows: Union[Iterable[T], Iterable[Mapping[str, Any]]],
 *,
 where: Where,
 ignore: FieldNamesSet = (),
 insert_only: FieldNamesSet = (),
-) ->
+) -> tuple[list[T], list[T], list[T]]:
 ignore = sorted(set(ignore) | set(insert_only))
 equal_ignoring = cls._cached(
 ("equal_ignoring", tuple(ignore)),
@@ -556,32 +587,30 @@ class ModelBase(Mapping[str, Any]):

 @classmethod
 def _get_differences_ignoring_fn(
-cls:
-) -> Callable[[T, T],
-env:
+cls: type[T], ignore: FieldNamesSet = ()
+) -> Callable[[T, T], list[str]]:
+env: dict[str, Any] = {}
 func = [
 "def differences_ignoring(a, b):",
 " diffs = []",
 ]
 for f in cls._fields():
 if f.name not in ignore:
-func.append(
-f" if a.{f.name} != b.{f.name}: diffs.append({repr(f.name)})"
-)
+func.append(f" if a.{f.name} != b.{f.name}: diffs.append({f.name!r})")
 func += [" return diffs"]
 exec("\n".join(func), env)
 return env["differences_ignoring"]

 @classmethod
 async def replace_multiple_reporting_differences(
-cls:
+cls: type[T],
 connection: Connection,
 rows: Union[Iterable[T], Iterable[Mapping[str, Any]]],
 *,
 where: Where,
 ignore: FieldNamesSet = (),
 insert_only: FieldNamesSet = (),
-) ->
+) -> tuple[list[T], list[tuple[T, T, list[str]]], list[T]]:
 ignore = sorted(set(ignore) | set(insert_only))
 differences_ignoring = cls._cached(
 ("differences_ignoring", tuple(ignore)),
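The dataclasses.py rewrite is the substantive change in this release: the old model_field_metadata() helper and its dict-based field metadata are gone, and column metadata is now declared by attaching ColumnInfo instances to Annotated type hints. For each field, a ColumnInfo inferred from the Python type (via sql_type_map) is merged with any ColumnInfo found in the Annotated metadata, producing a ConcreteColumnInfo with create_table_string(). A hedged sketch of the new declaration style follows; the table_name/primary_key subclass arguments are assumed from the library's existing ModelBase API and are not shown in this diff:

import uuid
from dataclasses import dataclass
from typing import Annotated, Optional

from sql_athame import ColumnInfo, ModelBase

@dataclass
class Widget(ModelBase, table_name="widgets", primary_key="id"):  # subclass kwargs assumed, not part of this diff
    # Annotated metadata overrides or extends the ColumnInfo inferred from the Python type.
    id: Annotated[uuid.UUID, ColumnInfo(constraints="DEFAULT gen_random_uuid()")]
    name: str                                                      # inferred: TEXT, NOT NULL
    price: Annotated[Optional[float], ColumnInfo(type="NUMERIC")]  # type overridden, stays nullable

# Per-column SQL derived from the merged metadata:
print(Widget.column_info("id").create_table_string())     # UUID NOT NULL DEFAULT gen_random_uuid()
print(Widget.column_info("price").create_table_string())  # NUMERIC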
sql_athame/escape.py
CHANGED
@@ -1,19 +1,20 @@
 import math
 import uuid
-from
+from collections.abc import Sequence
+from typing import Any


 def escape(value: Any) -> str:
 if isinstance(value, str):
-return f"E{
+return f"E{value!r}"
 elif isinstance(value, float) or isinstance(value, int):
 if math.isnan(value):
 raise ValueError("Can't escape NaN float")
 elif math.isinf(value):
 raise ValueError("Can't escape infinite float")
-return f"{
+return f"{value!r}"
 elif isinstance(value, uuid.UUID):
-return f"{
+return f"{str(value)!r}::UUID"
 elif isinstance(value, Sequence):
 args = ", ".join(escape(x) for x in value)
 return f"ARRAY[{args}]"
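escape.py only swaps in modern imports and the !r conversion inside its f-strings; the escaping rules themselves are unchanged and can be read off the new lines. A quick illustration of those rules, derived directly from the code shown above:

import uuid

from sql_athame.escape import escape

# Numbers use their Python repr, UUIDs become a quoted string with a ::UUID
# cast, sequences recurse into an ARRAY[...] literal, and NaN/inf are rejected.
print(escape(42))                # 42
print(escape([1, 2, 3]))         # ARRAY[1, 2, 3]
print(escape(uuid.UUID(int=0)))  # '00000000-0000-0000-0000-000000000000'::UUID
# escape(float("nan")) raises ValueError("Can't escape NaN float")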
sql_athame/sqlalchemy.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any

 from .types import FlatPart, Placeholder, Slot

@@ -7,10 +7,10 @@ try:
 from sqlalchemy.sql.elements import BindParameter

 def sqlalchemy_text_from_fragment(self: "Fragment") -> Any:
-parts:
+parts: list[FlatPart] = []
 self.flatten_into(parts)
-bindparams:
-out_parts:
+bindparams: dict[str, Any] = {}
+out_parts: list[str] = []
 for part in parts:
 if isinstance(part, Slot):
 out_parts.append(f"(:{part.name})")
sql_athame/types.py
CHANGED
@@ -1,6 +1,8 @@
 import dataclasses
 from typing import TYPE_CHECKING, Any, Union

+from typing_extensions import TypeAlias
+

 @dataclasses.dataclass(eq=False)
 class Placeholder:
@@ -15,8 +17,8 @@ class Slot:
 name: str


-Part = Union[str, Placeholder, Slot, "Fragment"]
-FlatPart = Union[str, Placeholder, Slot]
+Part: TypeAlias = Union[str, Placeholder, Slot, "Fragment"]
+FlatPart: TypeAlias = Union[str, Placeholder, Slot]

 if TYPE_CHECKING:
 from .base import Fragment
sql_athame-0.4.0a7.dist-info/RECORD
ADDED
@@ -0,0 +1,11 @@
+sql_athame/__init__.py,sha256=7OBIMZOcrD2pvfIL-rjD1IGZ3TNQbwyu76a9PWk-yYg,79
+sql_athame/base.py,sha256=FR7EmC0VkX1VRgvAutSEfYSWhlEYpoqS1Kqxp1jHp6Y,10293
+sql_athame/dataclasses.py,sha256=8JDACQr5RCeCbu2QRAzA9rpM9i1TJNGKEFXEFbGJUgo,22193
+sql_athame/escape.py,sha256=kK101xXeFitlvuG-L_hvhdpgGJCtmRTprsn1yEfZKws,758
+sql_athame/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+sql_athame/sqlalchemy.py,sha256=aWopfPh3j71XwKmcN_VcHRNlhscI0Sckd4AiyGf8Tpw,1293
+sql_athame/types.py,sha256=FQ06l9Uc-vo57UrAarvnukILdV2gN1IaYUnHJ_bNYic,475
+sql_athame-0.4.0a7.dist-info/LICENSE,sha256=xqV29vPFqITcKifYrGPgVIBjq4fdmLSwY3gRUtDKafg,1076
+sql_athame-0.4.0a7.dist-info/METADATA,sha256=OqUSaxi_5K6vfYxWpiXGP_qDXPU-qQnCrkY3yruhzi4,12845
+sql_athame-0.4.0a7.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+sql_athame-0.4.0a7.dist-info/RECORD,,
sql_athame-0.4.0a5.dist-info/RECORD
REMOVED
@@ -1,11 +0,0 @@
-sql_athame/__init__.py,sha256=rzUQcbzmj3qkPZpL9jI_ALTRv-e1pAV4jSCryWkutlk,130
-sql_athame/base.py,sha256=fSnHQhh5ULeJ5q32RVUAvpWtF0qoY61B2gEEP59Nrpo,10350
-sql_athame/dataclasses.py,sha256=KTn8bUBajNXDDVxgoBF7BdFkv8gSQynfa8fyYv8WrzM,20442
-sql_athame/escape.py,sha256=LXExbiYtc407yDU4vPieyY2Pq5nypsJFfBc_2-gsbUg,743
-sql_athame/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sql_athame/sqlalchemy.py,sha256=c-pCLE11hTh5I19rY1Vp5E7P7lAaj9i-i7ko2L8rlF4,1305
-sql_athame/types.py,sha256=7P4OyY0ezRlb2UDD9lpdXiLChnhQcBvHWaG_PKy3jmE,412
-sql_athame-0.4.0a5.dist-info/LICENSE,sha256=xqV29vPFqITcKifYrGPgVIBjq4fdmLSwY3gRUtDKafg,1076
-sql_athame-0.4.0a5.dist-info/METADATA,sha256=-9SuyMtbLf8Dl4oHXKPRrIGRyuVcVPc2Yt5B3958zNo,12845
-sql_athame-0.4.0a5.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-sql_athame-0.4.0a5.dist-info/RECORD,,
{sql_athame-0.4.0a5.dist-info → sql_athame-0.4.0a7.dist-info}/LICENSE
File without changes

{sql_athame-0.4.0a5.dist-info → sql_athame-0.4.0a7.dist-info}/WHEEL
File without changes