sql-athame 0.4.0a5__tar.gz → 0.4.0a7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: sql-athame
3
- Version: 0.4.0a5
3
+ Version: 0.4.0a7
4
4
  Summary: Python tool for slicing and dicing SQL
5
5
  Home-page: https://github.com/bdowning/sql-athame
6
6
  License: MIT
@@ -0,0 +1,103 @@
1
+ [tool.poetry]
2
+ name = "sql-athame"
3
+ version = "0.4.0-alpha-7"
4
+ description = "Python tool for slicing and dicing SQL"
5
+ authors = ["Brian Downing <bdowning@lavos.net>"]
6
+ license = "MIT"
7
+ readme = "README.md"
8
+ homepage = "https://github.com/bdowning/sql-athame"
9
+ repository = "https://github.com/bdowning/sql-athame"
10
+
11
+ [tool.poetry.extras]
12
+ asyncpg = ["asyncpg"]
13
+
14
+ [tool.poetry.dependencies]
15
+ python = "^3.9"
16
+ asyncpg = { version = "*", optional = true }
17
+ typing-extensions = "*"
18
+
19
+ [tool.poetry.group.dev.dependencies]
20
+ pytest = "*"
21
+ mypy = "*"
22
+ flake8 = "*"
23
+ ipython = "*"
24
+ pytest-cov = "*"
25
+ bump2version = "*"
26
+ asyncpg = "*"
27
+ pytest-asyncio = "*"
28
+ grip = "*"
29
+ SQLAlchemy = "*"
30
+ ruff = "*"
31
+
32
+ [build-system]
33
+ requires = ["poetry>=0.12"]
34
+ build-backend = "poetry.masonry.api"
35
+
36
+ [tool.ruff]
37
+ target-version = "py39"
38
+
39
+ [tool.ruff.lint]
40
+ select = [
41
+ "ASYNC",
42
+ "B",
43
+ "BLE",
44
+ "C4",
45
+ "DTZ",
46
+ "E",
47
+ "F",
48
+ "I",
49
+ "INP",
50
+ "ISC",
51
+ "LOG",
52
+ "N",
53
+ "PIE",
54
+ "PT",
55
+ "RET",
56
+ "RUF",
57
+ "SLOT",
58
+ "UP",
59
+ ]
60
+ flake8-comprehensions.allow-dict-calls-with-keyword-arguments = true
61
+ ignore = [
62
+ "E501", # line too long
63
+ "E721", # type checks, currently broken
64
+ "ISC001", # conflicts with ruff format
65
+ "PT004", # Fixture `...` does not return anything, add leading underscore
66
+ "RET505", # Unnecessary `else` after `return` statement
67
+ "RET506", # Unnecessary `else` after `raise` statement
68
+ ]
69
+
70
+ [tool.ruff.lint.per-file-ignores]
71
+ "__init__.py" = [
72
+ "F401", # wildcard import
73
+ ]
74
+
75
+ [tool.pytest.ini_options]
76
+ addopts = [
77
+ "-v",
78
+ "--cov",
79
+ "--cov-report", "xml:results/pytest/coverage.xml",
80
+ "--cov-report", "html:results/pytest/cov_html",
81
+ "--cov-report", "term-missing",
82
+ "--junitxml=results/pytest/results.xml",
83
+ "--durations=5",
84
+ ]
85
+ junit_family = "legacy"
86
+ asyncio_mode = "auto"
87
+
88
+ [tool.coverage]
89
+ run.include = [
90
+ "sql_athame/**/*.py",
91
+ "tests/**/*.py",
92
+ ]
93
+ report.precision = 2
94
+
95
+ [tool.mypy]
96
+ disallow_incomplete_defs = true
97
+ check_untyped_defs = true
98
+
99
+ [[tool.mypy.overrides]]
100
+ module = [
101
+ "asyncpg",
102
+ ]
103
+ ignore_missing_imports = true
@@ -0,0 +1,2 @@
1
+ from .base import Fragment, sql
2
+ from .dataclasses import ColumnInfo, ModelBase
@@ -2,16 +2,11 @@ import dataclasses
2
2
  import json
3
3
  import re
4
4
  import string
5
+ from collections.abc import Iterable, Iterator, Sequence
5
6
  from typing import (
6
7
  Any,
7
8
  Callable,
8
- Dict,
9
- Iterable,
10
- Iterator,
11
- List,
12
9
  Optional,
13
- Sequence,
14
- Tuple,
15
10
  Union,
16
11
  cast,
17
12
  overload,
@@ -34,7 +29,7 @@ def auto_numbered(field_name):
34
29
  def process_slot_value(
35
30
  name: str,
36
31
  value: Any,
37
- placeholders: Dict[str, Placeholder],
32
+ placeholders: dict[str, Placeholder],
38
33
  ) -> Union["Fragment", Placeholder]:
39
34
  if isinstance(value, Fragment):
40
35
  return value
@@ -47,9 +42,9 @@ def process_slot_value(
47
42
  @dataclasses.dataclass
48
43
  class Fragment:
49
44
  __slots__ = ["parts"]
50
- parts: List[Part]
45
+ parts: list[Part]
51
46
 
52
- def flatten_into(self, parts: List[FlatPart]) -> None:
47
+ def flatten_into(self, parts: list[FlatPart]) -> None:
53
48
  for part in self.parts:
54
49
  if isinstance(part, Fragment):
55
50
  part.flatten_into(parts)
@@ -70,10 +65,10 @@ class Fragment:
70
65
  for i, part in enumerate(flattened.parts):
71
66
  if isinstance(part, Slot):
72
67
  func.append(
73
- f" process_slot_value({repr(part.name)}, slots[{repr(part.name)}], placeholders),"
68
+ f" process_slot_value({part.name!r}, slots[{part.name!r}], placeholders),"
74
69
  )
75
70
  elif isinstance(part, str):
76
- func.append(f" {repr(part)},")
71
+ func.append(f" {part!r},")
77
72
  else:
78
73
  env[f"part_{i}"] = part
79
74
  func.append(f" part_{i},")
@@ -82,9 +77,9 @@ class Fragment:
82
77
  return env["compiled"] # type: ignore
83
78
 
84
79
  def flatten(self) -> "Fragment":
85
- parts: List[FlatPart] = []
80
+ parts: list[FlatPart] = []
86
81
  self.flatten_into(parts)
87
- out_parts: List[Part] = []
82
+ out_parts: list[Part] = []
88
83
  for part in parts:
89
84
  if isinstance(part, str) and out_parts and isinstance(out_parts[-1], str):
90
85
  out_parts[-1] += part
@@ -93,9 +88,9 @@ class Fragment:
93
88
  return Fragment(out_parts)
94
89
 
95
90
  def fill(self, **kwargs: Any) -> "Fragment":
96
- parts: List[Part] = []
97
- self.flatten_into(cast(List[FlatPart], parts))
98
- placeholders: Dict[str, Placeholder] = {}
91
+ parts: list[Part] = []
92
+ self.flatten_into(cast(list[FlatPart], parts))
93
+ placeholders: dict[str, Placeholder] = {}
99
94
  for i, part in enumerate(parts):
100
95
  if isinstance(part, Slot):
101
96
  parts[i] = process_slot_value(
@@ -106,24 +101,24 @@ class Fragment:
106
101
  @overload
107
102
  def prep_query(
108
103
  self, allow_slots: Literal[True]
109
- ) -> Tuple[str, List[Union[Placeholder, Slot]]]: ... # pragma: no cover
104
+ ) -> tuple[str, list[Union[Placeholder, Slot]]]: ... # pragma: no cover
110
105
 
111
106
  @overload
112
107
  def prep_query(
113
108
  self, allow_slots: Literal[False] = False
114
- ) -> Tuple[str, List[Placeholder]]: ... # pragma: no cover
109
+ ) -> tuple[str, list[Placeholder]]: ... # pragma: no cover
115
110
 
116
- def prep_query(self, allow_slots: bool = False) -> Tuple[str, List[Any]]:
117
- parts: List[FlatPart] = []
111
+ def prep_query(self, allow_slots: bool = False) -> tuple[str, list[Any]]:
112
+ parts: list[FlatPart] = []
118
113
  self.flatten_into(parts)
119
- args: List[Union[Placeholder, Slot]] = []
120
- placeholder_ids: Dict[Placeholder, int] = {}
121
- slot_ids: Dict[Slot, int] = {}
122
- out_parts: List[str] = []
114
+ args: list[Union[Placeholder, Slot]] = []
115
+ placeholder_ids: dict[Placeholder, int] = {}
116
+ slot_ids: dict[Slot, int] = {}
117
+ out_parts: list[str] = []
123
118
  for part in parts:
124
119
  if isinstance(part, Slot):
125
120
  if not allow_slots:
126
- raise ValueError(f"Unfilled slot: {repr(part.name)}")
121
+ raise ValueError(f"Unfilled slot: {part.name!r}")
127
122
  if part not in slot_ids:
128
123
  args.append(part)
129
124
  slot_ids[part] = len(args)
@@ -138,7 +133,7 @@ class Fragment:
138
133
  out_parts.append(part)
139
134
  return "".join(out_parts).strip(), args
140
135
 
141
- def query(self) -> Tuple[str, List[Any]]:
136
+ def query(self) -> tuple[str, list[Any]]:
142
137
  query, args = self.prep_query()
143
138
  placeholder_values = [arg.value for arg in args]
144
139
  return query, placeholder_values
@@ -146,16 +141,16 @@ class Fragment:
146
141
  def sqlalchemy_text(self) -> Any:
147
142
  return sqlalchemy_text_from_fragment(self)
148
143
 
149
- def prepare(self) -> Tuple[str, Callable[..., List[Any]]]:
144
+ def prepare(self) -> tuple[str, Callable[..., list[Any]]]:
150
145
  query, args = self.prep_query(allow_slots=True)
151
- env = dict()
146
+ env = {}
152
147
  func = [
153
148
  "def generate_args(**kwargs):",
154
149
  " return [",
155
150
  ]
156
151
  for i, arg in enumerate(args):
157
152
  if isinstance(arg, Slot):
158
- func.append(f" kwargs[{repr(arg.name)}],")
153
+ func.append(f" kwargs[{arg.name!r}],")
159
154
  else:
160
155
  env[f"value_{i}"] = arg.value
161
156
  func.append(f" value_{i},")
@@ -178,10 +173,10 @@ class SQLFormatter:
178
173
  if not preserve_formatting:
179
174
  fmt = newline_whitespace_re.sub(" ", fmt)
180
175
  fmtr = string.Formatter()
181
- parts: List[Part] = []
182
- placeholders: Dict[str, Placeholder] = {}
176
+ parts: list[Part] = []
177
+ placeholders: dict[str, Placeholder] = {}
183
178
  next_auto_field = 0
184
- for literal_text, field_name, format_spec, conversion in fmtr.parse(fmt):
179
+ for literal_text, field_name, _format_spec, _conversion in fmtr.parse(fmt):
185
180
  parts.append(literal_text)
186
181
  if field_name is not None:
187
182
  if auto_numbered(field_name):
@@ -258,12 +253,11 @@ class SQLFormatter:
258
253
  parts = parts[0]
259
254
  return Fragment(list(join_parts(parts, infix=", ")))
260
255
 
261
- @staticmethod
262
- def unnest(data: Iterable[Sequence[Any]], types: Iterable[str]) -> Fragment:
263
- nested = list(nest_for_type(x, t) for x, t in zip(zip(*data), types))
256
+ def unnest(self, data: Iterable[Sequence[Any]], types: Iterable[str]) -> Fragment:
257
+ nested = [nest_for_type(x, t) for x, t in zip(zip(*data), types)]
264
258
  if not nested:
265
- nested = list(nest_for_type([], t) for t in types)
266
- return Fragment(["UNNEST(", sql.list(nested), ")"])
259
+ nested = [nest_for_type([], t) for t in types]
260
+ return Fragment(["UNNEST(", self.list(nested), ")"])
267
261
 
268
262
 
269
263
  sql = SQLFormatter()
@@ -296,7 +290,7 @@ def lit(text: str) -> Fragment:
296
290
  return Fragment([text])
297
291
 
298
292
 
299
- def any_all(frags: List[Fragment], op: str, base_case: str) -> Fragment:
293
+ def any_all(frags: list[Fragment], op: str, base_case: str) -> Fragment:
300
294
  if not frags:
301
295
  return lit(base_case)
302
296
  parts = join_parts(frags, prefix="(", infix=f") {op} (", suffix=")")
@@ -1,60 +1,87 @@
1
1
  import datetime
2
+ import functools
2
3
  import uuid
3
- from dataclasses import dataclass, field, fields
4
+ from collections.abc import AsyncGenerator, Iterable, Mapping
5
+ from dataclasses import Field, InitVar, dataclass, fields
4
6
  from typing import (
7
+ Annotated,
5
8
  Any,
6
- AsyncGenerator,
7
9
  Callable,
8
- Dict,
9
- Iterable,
10
- Iterator,
11
- List,
12
- Mapping,
13
10
  Optional,
14
- Set,
15
- Tuple,
16
- Type,
17
11
  TypeVar,
18
12
  Union,
13
+ get_origin,
14
+ get_type_hints,
19
15
  )
20
16
 
17
+ from typing_extensions import TypeAlias
18
+
21
19
  from .base import Fragment, sql
22
20
 
23
- Where = Union[Fragment, Iterable[Fragment]]
21
+ Where: TypeAlias = Union[Fragment, Iterable[Fragment]]
24
22
  # KLUDGE to avoid a string argument being valid
25
- SequenceOfStrings = Union[List[str], Tuple[str, ...]]
26
- FieldNames = SequenceOfStrings
27
- FieldNamesSet = Union[SequenceOfStrings, Set[str]]
23
+ SequenceOfStrings: TypeAlias = Union[list[str], tuple[str, ...]]
24
+ FieldNames: TypeAlias = SequenceOfStrings
25
+ FieldNamesSet: TypeAlias = Union[SequenceOfStrings, set[str]]
28
26
 
29
- Connection = Any
30
- Pool = Any
27
+ Connection: TypeAlias = Any
28
+ Pool: TypeAlias = Any
31
29
 
32
30
 
33
31
  @dataclass
34
32
  class ColumnInfo:
35
- type: str
36
- create_type: str
37
- constraints: Tuple[str, ...]
38
-
39
- def create_table_string(self):
40
- return " ".join((self.create_type, *self.constraints))
41
-
33
+ type: Optional[str] = None
34
+ create_type: Optional[str] = None
35
+ nullable: Optional[bool] = None
36
+ _constraints: tuple[str, ...] = ()
37
+
38
+ constraints: InitVar[Union[str, Iterable[str], None]] = None
39
+
40
+ def __post_init__(self, constraints: Union[str, Iterable[str], None]) -> None:
41
+ if constraints is not None:
42
+ if type(constraints) is str:
43
+ constraints = (constraints,)
44
+ self._constraints = tuple(constraints)
45
+
46
+ @staticmethod
47
+ def merge(a: "ColumnInfo", b: "ColumnInfo") -> "ColumnInfo":
48
+ return ColumnInfo(
49
+ type=b.type if b.type is not None else a.type,
50
+ create_type=b.create_type if b.create_type is not None else a.create_type,
51
+ nullable=b.nullable if b.nullable is not None else a.nullable,
52
+ _constraints=(*a._constraints, *b._constraints),
53
+ )
42
54
 
43
- def model_field_metadata(
44
- type: str, constraints: Union[str, Iterable[str]] = ()
45
- ) -> Dict[str, Any]:
46
- if isinstance(constraints, str):
47
- constraints = (constraints,)
48
- info = ColumnInfo(
49
- sql_create_type_map.get(type.upper(), type), type, tuple(constraints)
50
- )
51
- return {"sql_athame": info}
52
55
 
56
+ @dataclass
57
+ class ConcreteColumnInfo:
58
+ type: str
59
+ create_type: str
60
+ nullable: bool
61
+ constraints: tuple[str, ...]
62
+
63
+ @staticmethod
64
+ def from_column_info(name: str, *args: ColumnInfo) -> "ConcreteColumnInfo":
65
+ info = functools.reduce(ColumnInfo.merge, args, ColumnInfo())
66
+ if info.create_type is None and info.type is not None:
67
+ info.create_type = info.type
68
+ info.type = sql_create_type_map.get(info.type.upper(), info.type)
69
+ if type(info.type) is not str or type(info.create_type) is not str:
70
+ raise ValueError(f"Missing SQL type for column {name!r}")
71
+ return ConcreteColumnInfo(
72
+ type=info.type,
73
+ create_type=info.create_type,
74
+ nullable=bool(info.nullable),
75
+ constraints=info._constraints,
76
+ )
53
77
 
54
- def model_field(
55
- *, type: str, constraints: Union[str, Iterable[str]] = (), **kwargs: Any
56
- ) -> Any:
57
- return field(**kwargs, metadata=model_field_metadata(type, constraints))
78
+ def create_table_string(self) -> str:
79
+ parts = (
80
+ self.create_type,
81
+ *(() if self.nullable else ("NOT NULL",)),
82
+ *self.constraints,
83
+ )
84
+ return " ".join(parts)
58
85
 
59
86
 
60
87
  sql_create_type_map = {
@@ -64,43 +91,37 @@ sql_create_type_map = {
64
91
  }
65
92
 
66
93
 
67
- sql_type_map = {
68
- Optional[bool]: ("BOOLEAN",),
69
- Optional[bytes]: ("BYTEA",),
70
- Optional[datetime.date]: ("DATE",),
71
- Optional[datetime.datetime]: ("TIMESTAMP",),
72
- Optional[float]: ("DOUBLE PRECISION",),
73
- Optional[int]: ("INTEGER",),
74
- Optional[str]: ("TEXT",),
75
- Optional[uuid.UUID]: ("UUID",),
76
- bool: ("BOOLEAN", "NOT NULL"),
77
- bytes: ("BYTEA", "NOT NULL"),
78
- datetime.date: ("DATE", "NOT NULL"),
79
- datetime.datetime: ("TIMESTAMP", "NOT NULL"),
80
- float: ("DOUBLE PRECISION", "NOT NULL"),
81
- int: ("INTEGER", "NOT NULL"),
82
- str: ("TEXT", "NOT NULL"),
83
- uuid.UUID: ("UUID", "NOT NULL"),
94
+ sql_type_map: dict[Any, tuple[str, bool]] = {
95
+ Optional[bool]: ("BOOLEAN", True),
96
+ Optional[bytes]: ("BYTEA", True),
97
+ Optional[datetime.date]: ("DATE", True),
98
+ Optional[datetime.datetime]: ("TIMESTAMP", True),
99
+ Optional[float]: ("DOUBLE PRECISION", True),
100
+ Optional[int]: ("INTEGER", True),
101
+ Optional[str]: ("TEXT", True),
102
+ Optional[uuid.UUID]: ("UUID", True),
103
+ bool: ("BOOLEAN", False),
104
+ bytes: ("BYTEA", False),
105
+ datetime.date: ("DATE", False),
106
+ datetime.datetime: ("TIMESTAMP", False),
107
+ float: ("DOUBLE PRECISION", False),
108
+ int: ("INTEGER", False),
109
+ str: ("TEXT", False),
110
+ uuid.UUID: ("UUID", False),
84
111
  }
85
112
 
86
113
 
87
- def column_info_for_field(field):
88
- if "sql_athame" in field.metadata:
89
- return field.metadata["sql_athame"]
90
- type, *constraints = sql_type_map[field.type]
91
- return ColumnInfo(type, type, tuple(constraints))
92
-
93
-
94
114
  T = TypeVar("T", bound="ModelBase")
95
115
  U = TypeVar("U")
96
116
 
97
117
 
98
- class ModelBase(Mapping[str, Any]):
99
- _column_info: Optional[Dict[str, ColumnInfo]]
100
- _cache: Dict[tuple, Any]
118
+ class ModelBase:
119
+ _column_info: Optional[dict[str, ConcreteColumnInfo]]
120
+ _cache: dict[tuple, Any]
101
121
  table_name: str
102
- primary_key_names: Tuple[str, ...]
122
+ primary_key_names: tuple[str, ...]
103
123
  array_safe_insert: bool
124
+ _type_hints: dict[str, type]
104
125
 
105
126
  def __init_subclass__(
106
127
  cls,
@@ -138,27 +159,38 @@ class ModelBase(Mapping[str, Any]):
138
159
  cls._cache[key] = thunk()
139
160
  return cls._cache[key]
140
161
 
141
- def keys(self):
142
- return self.field_names()
143
-
144
- def __getitem__(self, key: str) -> Any:
145
- return getattr(self, key)
146
-
147
- def __iter__(self) -> Iterator[Any]:
148
- return iter(self.keys())
149
-
150
- def __len__(self) -> int:
151
- return len(self.keys())
152
-
153
- def get(self, key: str, default: Any = None) -> Any:
154
- return getattr(self, key, default)
155
-
156
162
  @classmethod
157
- def column_info(cls, column: str) -> ColumnInfo:
163
+ def type_hints(cls) -> dict[str, type]:
164
+ try:
165
+ return cls._type_hints
166
+ except AttributeError:
167
+ cls._type_hints = get_type_hints(cls, include_extras=True)
168
+ return cls._type_hints
169
+
170
+ @classmethod
171
+ def column_info_for_field(cls, field: Field) -> ConcreteColumnInfo:
172
+ type_info = cls.type_hints()[field.name]
173
+ base_type = type_info
174
+ if get_origin(type_info) is Annotated:
175
+ base_type = type_info.__origin__ # type: ignore
176
+ info = []
177
+ if base_type in sql_type_map:
178
+ _type, nullable = sql_type_map[base_type]
179
+ info.append(ColumnInfo(type=_type, nullable=nullable))
180
+ if get_origin(type_info) is Annotated:
181
+ for md in type_info.__metadata__: # type: ignore
182
+ if isinstance(md, ColumnInfo):
183
+ info.append(md)
184
+ return ConcreteColumnInfo.from_column_info(field.name, *info)
185
+
186
+ @classmethod
187
+ def column_info(cls, column: str) -> ConcreteColumnInfo:
158
188
  try:
159
189
  return cls._column_info[column] # type: ignore
160
190
  except AttributeError:
161
- cls._column_info = {f.name: column_info_for_field(f) for f in cls._fields()}
191
+ cls._column_info = {
192
+ f.name: cls.column_info_for_field(f) for f in cls._fields()
193
+ }
162
194
  return cls._column_info[column]
163
195
 
164
196
  @classmethod
@@ -166,17 +198,17 @@ class ModelBase(Mapping[str, Any]):
166
198
  return sql.identifier(cls.table_name, prefix=prefix)
167
199
 
168
200
  @classmethod
169
- def primary_key_names_sql(cls, *, prefix: Optional[str] = None) -> List[Fragment]:
201
+ def primary_key_names_sql(cls, *, prefix: Optional[str] = None) -> list[Fragment]:
170
202
  return [sql.identifier(pk, prefix=prefix) for pk in cls.primary_key_names]
171
203
 
172
204
  @classmethod
173
- def field_names(cls, *, exclude: FieldNamesSet = ()) -> List[str]:
205
+ def field_names(cls, *, exclude: FieldNamesSet = ()) -> list[str]:
174
206
  return [f.name for f in cls._fields() if f.name not in exclude]
175
207
 
176
208
  @classmethod
177
209
  def field_names_sql(
178
210
  cls, *, prefix: Optional[str] = None, exclude: FieldNamesSet = ()
179
- ) -> List[Fragment]:
211
+ ) -> list[Fragment]:
180
212
  return [
181
213
  sql.identifier(f, prefix=prefix) for f in cls.field_names(exclude=exclude)
182
214
  ]
@@ -186,9 +218,9 @@ class ModelBase(Mapping[str, Any]):
186
218
 
187
219
  @classmethod
188
220
  def _get_field_values_fn(
189
- cls: Type[T], exclude: FieldNamesSet = ()
190
- ) -> Callable[[T], List[Any]]:
191
- env: Dict[str, Any] = dict()
221
+ cls: type[T], exclude: FieldNamesSet = ()
222
+ ) -> Callable[[T], list[Any]]:
223
+ env: dict[str, Any] = {}
192
224
  func = ["def get_field_values(self): return ["]
193
225
  for f in cls._fields():
194
226
  if f.name not in exclude:
@@ -197,7 +229,7 @@ class ModelBase(Mapping[str, Any]):
197
229
  exec(" ".join(func), env)
198
230
  return env["get_field_values"]
199
231
 
200
- def field_values(self, *, exclude: FieldNamesSet = ()) -> List[Any]:
232
+ def field_values(self, *, exclude: FieldNamesSet = ()) -> list[Any]:
201
233
  get_field_values = self._cached(
202
234
  ("get_field_values", tuple(sorted(exclude))),
203
235
  lambda: self._get_field_values_fn(exclude),
@@ -206,7 +238,7 @@ class ModelBase(Mapping[str, Any]):
206
238
 
207
239
  def field_values_sql(
208
240
  self, *, exclude: FieldNamesSet = (), default_none: bool = False
209
- ) -> List[Fragment]:
241
+ ) -> list[Fragment]:
210
242
  if default_none:
211
243
  return [
212
244
  sql.literal("DEFAULT") if value is None else sql.value(value)
@@ -217,7 +249,7 @@ class ModelBase(Mapping[str, Any]):
217
249
 
218
250
  @classmethod
219
251
  def from_tuple(
220
- cls: Type[T], tup: tuple, *, offset: int = 0, exclude: FieldNamesSet = ()
252
+ cls: type[T], tup: tuple, *, offset: int = 0, exclude: FieldNamesSet = ()
221
253
  ) -> T:
222
254
  names = (f.name for f in cls._fields() if f.name not in exclude)
223
255
  kwargs = {name: tup[offset] for offset, name in enumerate(names, start=offset)}
@@ -225,14 +257,14 @@ class ModelBase(Mapping[str, Any]):
225
257
 
226
258
  @classmethod
227
259
  def from_dict(
228
- cls: Type[T], dct: Dict[str, Any], *, exclude: FieldNamesSet = ()
260
+ cls: type[T], dct: dict[str, Any], *, exclude: FieldNamesSet = ()
229
261
  ) -> T:
230
262
  names = {f.name for f in cls._fields() if f.name not in exclude}
231
263
  kwargs = {k: v for k, v in dct.items() if k in names}
232
264
  return cls(**kwargs)
233
265
 
234
266
  @classmethod
235
- def ensure_model(cls: Type[T], row: Union[T, Mapping[str, Any]]) -> T:
267
+ def ensure_model(cls: type[T], row: Union[T, Mapping[str, Any]]) -> T:
236
268
  if isinstance(row, cls):
237
269
  return row
238
270
  return cls(**row)
@@ -286,7 +318,7 @@ class ModelBase(Mapping[str, Any]):
286
318
 
287
319
  @classmethod
288
320
  async def select_cursor(
289
- cls: Type[T],
321
+ cls: type[T],
290
322
  connection: Connection,
291
323
  order_by: Union[FieldNames, str] = (),
292
324
  for_update: bool = False,
@@ -301,12 +333,12 @@ class ModelBase(Mapping[str, Any]):
301
333
 
302
334
  @classmethod
303
335
  async def select(
304
- cls: Type[T],
336
+ cls: type[T],
305
337
  connection_or_pool: Union[Connection, Pool],
306
338
  order_by: Union[FieldNames, str] = (),
307
339
  for_update: bool = False,
308
340
  where: Where = (),
309
- ) -> List[T]:
341
+ ) -> list[T]:
310
342
  return [
311
343
  cls(**row)
312
344
  for row in await connection_or_pool.fetch(
@@ -315,7 +347,7 @@ class ModelBase(Mapping[str, Any]):
315
347
  ]
316
348
 
317
349
  @classmethod
318
- def create_sql(cls: Type[T], **kwargs: Any) -> Fragment:
350
+ def create_sql(cls: type[T], **kwargs: Any) -> Fragment:
319
351
  return sql(
320
352
  "INSERT INTO {table} ({fields}) VALUES ({values}) RETURNING {out_fields}",
321
353
  table=cls.table_name_sql(),
@@ -326,7 +358,7 @@ class ModelBase(Mapping[str, Any]):
326
358
 
327
359
  @classmethod
328
360
  async def create(
329
- cls: Type[T], connection_or_pool: Union[Connection, Pool], **kwargs: Any
361
+ cls: type[T], connection_or_pool: Union[Connection, Pool], **kwargs: Any
330
362
  ) -> T:
331
363
  row = await connection_or_pool.fetchrow(*cls.create_sql(**kwargs))
332
364
  return cls(**row)
@@ -375,11 +407,10 @@ class ModelBase(Mapping[str, Any]):
375
407
  self.upsert_sql(self.insert_sql(exclude=exclude), exclude=exclude),
376
408
  )
377
409
  result = await connection_or_pool.fetchrow(*query)
378
- is_update = result["xmax"] != 0
379
- return is_update
410
+ return result["xmax"] != 0
380
411
 
381
412
  @classmethod
382
- def delete_multiple_sql(cls: Type[T], rows: Iterable[T]) -> Fragment:
413
+ def delete_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
383
414
  cached = cls._cached(
384
415
  ("delete_multiple_sql",),
385
416
  lambda: sql(
@@ -397,12 +428,12 @@ class ModelBase(Mapping[str, Any]):
397
428
 
398
429
  @classmethod
399
430
  async def delete_multiple(
400
- cls: Type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
431
+ cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
401
432
  ) -> str:
402
433
  return await connection_or_pool.execute(*cls.delete_multiple_sql(rows))
403
434
 
404
435
  @classmethod
405
- def insert_multiple_sql(cls: Type[T], rows: Iterable[T]) -> Fragment:
436
+ def insert_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
406
437
  cached = cls._cached(
407
438
  ("insert_multiple_sql",),
408
439
  lambda: sql(
@@ -419,7 +450,7 @@ class ModelBase(Mapping[str, Any]):
419
450
  )
420
451
 
421
452
  @classmethod
422
- def insert_multiple_array_safe_sql(cls: Type[T], rows: Iterable[T]) -> Fragment:
453
+ def insert_multiple_array_safe_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
423
454
  return sql(
424
455
  "INSERT INTO {table} ({fields}) VALUES {values}",
425
456
  table=cls.table_name_sql(),
@@ -432,13 +463,13 @@ class ModelBase(Mapping[str, Any]):
432
463
 
433
464
  @classmethod
434
465
  async def insert_multiple_unnest(
435
- cls: Type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
466
+ cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
436
467
  ) -> str:
437
468
  return await connection_or_pool.execute(*cls.insert_multiple_sql(rows))
438
469
 
439
470
  @classmethod
440
471
  async def insert_multiple_array_safe(
441
- cls: Type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
472
+ cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
442
473
  ) -> str:
443
474
  last = ""
444
475
  for chunk in chunked(rows, 100):
@@ -449,7 +480,7 @@ class ModelBase(Mapping[str, Any]):
449
480
 
450
481
  @classmethod
451
482
  async def insert_multiple(
452
- cls: Type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
483
+ cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
453
484
  ) -> str:
454
485
  if cls.array_safe_insert:
455
486
  return await cls.insert_multiple_array_safe(connection_or_pool, rows)
@@ -458,7 +489,7 @@ class ModelBase(Mapping[str, Any]):
458
489
 
459
490
  @classmethod
460
491
  async def upsert_multiple_unnest(
461
- cls: Type[T],
492
+ cls: type[T],
462
493
  connection_or_pool: Union[Connection, Pool],
463
494
  rows: Iterable[T],
464
495
  insert_only: FieldNamesSet = (),
@@ -469,7 +500,7 @@ class ModelBase(Mapping[str, Any]):
469
500
 
470
501
  @classmethod
471
502
  async def upsert_multiple_array_safe(
472
- cls: Type[T],
503
+ cls: type[T],
473
504
  connection_or_pool: Union[Connection, Pool],
474
505
  rows: Iterable[T],
475
506
  insert_only: FieldNamesSet = (),
@@ -485,7 +516,7 @@ class ModelBase(Mapping[str, Any]):
485
516
 
486
517
  @classmethod
487
518
  async def upsert_multiple(
488
- cls: Type[T],
519
+ cls: type[T],
489
520
  connection_or_pool: Union[Connection, Pool],
490
521
  rows: Iterable[T],
491
522
  insert_only: FieldNamesSet = (),
@@ -501,9 +532,9 @@ class ModelBase(Mapping[str, Any]):
501
532
 
502
533
  @classmethod
503
534
  def _get_equal_ignoring_fn(
504
- cls: Type[T], ignore: FieldNamesSet = ()
535
+ cls: type[T], ignore: FieldNamesSet = ()
505
536
  ) -> Callable[[T, T], bool]:
506
- env: Dict[str, Any] = dict()
537
+ env: dict[str, Any] = {}
507
538
  func = ["def equal_ignoring(a, b):"]
508
539
  for f in cls._fields():
509
540
  if f.name not in ignore:
@@ -514,14 +545,14 @@ class ModelBase(Mapping[str, Any]):
514
545
 
515
546
  @classmethod
516
547
  async def replace_multiple(
517
- cls: Type[T],
548
+ cls: type[T],
518
549
  connection: Connection,
519
550
  rows: Union[Iterable[T], Iterable[Mapping[str, Any]]],
520
551
  *,
521
552
  where: Where,
522
553
  ignore: FieldNamesSet = (),
523
554
  insert_only: FieldNamesSet = (),
524
- ) -> Tuple[List[T], List[T], List[T]]:
555
+ ) -> tuple[list[T], list[T], list[T]]:
525
556
  ignore = sorted(set(ignore) | set(insert_only))
526
557
  equal_ignoring = cls._cached(
527
558
  ("equal_ignoring", tuple(ignore)),
@@ -556,32 +587,30 @@ class ModelBase(Mapping[str, Any]):
556
587
 
557
588
  @classmethod
558
589
  def _get_differences_ignoring_fn(
559
- cls: Type[T], ignore: FieldNamesSet = ()
560
- ) -> Callable[[T, T], List[str]]:
561
- env: Dict[str, Any] = dict()
590
+ cls: type[T], ignore: FieldNamesSet = ()
591
+ ) -> Callable[[T, T], list[str]]:
592
+ env: dict[str, Any] = {}
562
593
  func = [
563
594
  "def differences_ignoring(a, b):",
564
595
  " diffs = []",
565
596
  ]
566
597
  for f in cls._fields():
567
598
  if f.name not in ignore:
568
- func.append(
569
- f" if a.{f.name} != b.{f.name}: diffs.append({repr(f.name)})"
570
- )
599
+ func.append(f" if a.{f.name} != b.{f.name}: diffs.append({f.name!r})")
571
600
  func += [" return diffs"]
572
601
  exec("\n".join(func), env)
573
602
  return env["differences_ignoring"]
574
603
 
575
604
  @classmethod
576
605
  async def replace_multiple_reporting_differences(
577
- cls: Type[T],
606
+ cls: type[T],
578
607
  connection: Connection,
579
608
  rows: Union[Iterable[T], Iterable[Mapping[str, Any]]],
580
609
  *,
581
610
  where: Where,
582
611
  ignore: FieldNamesSet = (),
583
612
  insert_only: FieldNamesSet = (),
584
- ) -> Tuple[List[T], List[Tuple[T, T, List[str]]], List[T]]:
613
+ ) -> tuple[list[T], list[tuple[T, T, list[str]]], list[T]]:
585
614
  ignore = sorted(set(ignore) | set(insert_only))
586
615
  differences_ignoring = cls._cached(
587
616
  ("differences_ignoring", tuple(ignore)),
@@ -1,19 +1,20 @@
1
1
  import math
2
2
  import uuid
3
- from typing import Any, Sequence
3
+ from collections.abc import Sequence
4
+ from typing import Any
4
5
 
5
6
 
6
7
  def escape(value: Any) -> str:
7
8
  if isinstance(value, str):
8
- return f"E{repr(value)}"
9
+ return f"E{value!r}"
9
10
  elif isinstance(value, float) or isinstance(value, int):
10
11
  if math.isnan(value):
11
12
  raise ValueError("Can't escape NaN float")
12
13
  elif math.isinf(value):
13
14
  raise ValueError("Can't escape infinite float")
14
- return f"{repr(value)}"
15
+ return f"{value!r}"
15
16
  elif isinstance(value, uuid.UUID):
16
- return f"{repr(str(value))}::UUID"
17
+ return f"{str(value)!r}::UUID"
17
18
  elif isinstance(value, Sequence):
18
19
  args = ", ".join(escape(x) for x in value)
19
20
  return f"ARRAY[{args}]"
@@ -1,4 +1,4 @@
1
- from typing import TYPE_CHECKING, Any, Dict, List
1
+ from typing import TYPE_CHECKING, Any
2
2
 
3
3
  from .types import FlatPart, Placeholder, Slot
4
4
 
@@ -7,10 +7,10 @@ try:
7
7
  from sqlalchemy.sql.elements import BindParameter
8
8
 
9
9
  def sqlalchemy_text_from_fragment(self: "Fragment") -> Any:
10
- parts: List[FlatPart] = []
10
+ parts: list[FlatPart] = []
11
11
  self.flatten_into(parts)
12
- bindparams: Dict[str, Any] = {}
13
- out_parts: List[str] = []
12
+ bindparams: dict[str, Any] = {}
13
+ out_parts: list[str] = []
14
14
  for part in parts:
15
15
  if isinstance(part, Slot):
16
16
  out_parts.append(f"(:{part.name})")
@@ -1,6 +1,8 @@
1
1
  import dataclasses
2
2
  from typing import TYPE_CHECKING, Any, Union
3
3
 
4
+ from typing_extensions import TypeAlias
5
+
4
6
 
5
7
  @dataclasses.dataclass(eq=False)
6
8
  class Placeholder:
@@ -15,8 +17,8 @@ class Slot:
15
17
  name: str
16
18
 
17
19
 
18
- Part = Union[str, Placeholder, Slot, "Fragment"]
19
- FlatPart = Union[str, Placeholder, Slot]
20
+ Part: TypeAlias = Union[str, Placeholder, Slot, "Fragment"]
21
+ FlatPart: TypeAlias = Union[str, Placeholder, Slot]
20
22
 
21
23
  if TYPE_CHECKING:
22
24
  from .base import Fragment
@@ -1,35 +0,0 @@
1
- [tool.poetry]
2
- name = "sql-athame"
3
- version = "0.4.0-alpha-5"
4
- description = "Python tool for slicing and dicing SQL"
5
- authors = ["Brian Downing <bdowning@lavos.net>"]
6
- license = "MIT"
7
- readme = "README.md"
8
- homepage = "https://github.com/bdowning/sql-athame"
9
- repository = "https://github.com/bdowning/sql-athame"
10
-
11
- [tool.poetry.extras]
12
- asyncpg = ["asyncpg"]
13
-
14
- [tool.poetry.dependencies]
15
- python = "^3.9"
16
- asyncpg = { version = "*", optional = true }
17
- typing-extensions = "*"
18
-
19
- [tool.poetry.dev-dependencies]
20
- black = "*"
21
- isort = "*"
22
- pytest = "*"
23
- mypy = "*"
24
- flake8 = "*"
25
- ipython = "*"
26
- pytest-cov = "*"
27
- bump2version = "*"
28
- asyncpg = "*"
29
- pytest-asyncio = "*"
30
- grip = "*"
31
- SQLAlchemy = "*"
32
-
33
- [build-system]
34
- requires = ["poetry>=0.12"]
35
- build-backend = "poetry.masonry.api"
@@ -1,2 +0,0 @@
1
- from .base import Fragment, sql # noqa: F401
2
- from .dataclasses import ModelBase, model_field, model_field_metadata # noqa: F401
File without changes
File without changes