sql-athame 0.4.0a4__py3-none-any.whl → 0.4.0a6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sql_athame/__init__.py CHANGED
@@ -1,2 +1,2 @@
1
- from .base import Fragment, sql # noqa: F401
2
- from .dataclasses import ModelBase, model_field, model_field_metadata # noqa: F401
1
+ from .base import Fragment, sql
2
+ from .dataclasses import ColumnInfo, ModelBase
sql_athame/base.py CHANGED
@@ -2,16 +2,11 @@ import dataclasses
2
2
  import json
3
3
  import re
4
4
  import string
5
+ from collections.abc import Iterable, Iterator, Sequence
5
6
  from typing import (
6
7
  Any,
7
8
  Callable,
8
- Dict,
9
- Iterable,
10
- Iterator,
11
- List,
12
9
  Optional,
13
- Sequence,
14
- Tuple,
15
10
  Union,
16
11
  cast,
17
12
  overload,
@@ -34,7 +29,7 @@ def auto_numbered(field_name):
34
29
  def process_slot_value(
35
30
  name: str,
36
31
  value: Any,
37
- placeholders: Dict[str, Placeholder],
32
+ placeholders: dict[str, Placeholder],
38
33
  ) -> Union["Fragment", Placeholder]:
39
34
  if isinstance(value, Fragment):
40
35
  return value
@@ -47,9 +42,9 @@ def process_slot_value(
47
42
  @dataclasses.dataclass
48
43
  class Fragment:
49
44
  __slots__ = ["parts"]
50
- parts: List[Part]
45
+ parts: list[Part]
51
46
 
52
- def flatten_into(self, parts: List[FlatPart]) -> None:
47
+ def flatten_into(self, parts: list[FlatPart]) -> None:
53
48
  for part in self.parts:
54
49
  if isinstance(part, Fragment):
55
50
  part.flatten_into(parts)
@@ -70,10 +65,10 @@ class Fragment:
70
65
  for i, part in enumerate(flattened.parts):
71
66
  if isinstance(part, Slot):
72
67
  func.append(
73
- f" process_slot_value({repr(part.name)}, slots[{repr(part.name)}], placeholders),"
68
+ f" process_slot_value({part.name!r}, slots[{part.name!r}], placeholders),"
74
69
  )
75
70
  elif isinstance(part, str):
76
- func.append(f" {repr(part)},")
71
+ func.append(f" {part!r},")
77
72
  else:
78
73
  env[f"part_{i}"] = part
79
74
  func.append(f" part_{i},")
@@ -82,9 +77,9 @@ class Fragment:
82
77
  return env["compiled"] # type: ignore
83
78
 
84
79
  def flatten(self) -> "Fragment":
85
- parts: List[FlatPart] = []
80
+ parts: list[FlatPart] = []
86
81
  self.flatten_into(parts)
87
- out_parts: List[Part] = []
82
+ out_parts: list[Part] = []
88
83
  for part in parts:
89
84
  if isinstance(part, str) and out_parts and isinstance(out_parts[-1], str):
90
85
  out_parts[-1] += part
@@ -93,9 +88,9 @@ class Fragment:
93
88
  return Fragment(out_parts)
94
89
 
95
90
  def fill(self, **kwargs: Any) -> "Fragment":
96
- parts: List[Part] = []
97
- self.flatten_into(cast(List[FlatPart], parts))
98
- placeholders: Dict[str, Placeholder] = {}
91
+ parts: list[Part] = []
92
+ self.flatten_into(cast(list[FlatPart], parts))
93
+ placeholders: dict[str, Placeholder] = {}
99
94
  for i, part in enumerate(parts):
100
95
  if isinstance(part, Slot):
101
96
  parts[i] = process_slot_value(
@@ -106,24 +101,24 @@ class Fragment:
106
101
  @overload
107
102
  def prep_query(
108
103
  self, allow_slots: Literal[True]
109
- ) -> Tuple[str, List[Union[Placeholder, Slot]]]: ... # pragma: no cover
104
+ ) -> tuple[str, list[Union[Placeholder, Slot]]]: ... # pragma: no cover
110
105
 
111
106
  @overload
112
107
  def prep_query(
113
108
  self, allow_slots: Literal[False] = False
114
- ) -> Tuple[str, List[Placeholder]]: ... # pragma: no cover
109
+ ) -> tuple[str, list[Placeholder]]: ... # pragma: no cover
115
110
 
116
- def prep_query(self, allow_slots: bool = False) -> Tuple[str, List[Any]]:
117
- parts: List[FlatPart] = []
111
+ def prep_query(self, allow_slots: bool = False) -> tuple[str, list[Any]]:
112
+ parts: list[FlatPart] = []
118
113
  self.flatten_into(parts)
119
- args: List[Union[Placeholder, Slot]] = []
120
- placeholder_ids: Dict[Placeholder, int] = {}
121
- slot_ids: Dict[Slot, int] = {}
122
- out_parts: List[str] = []
114
+ args: list[Union[Placeholder, Slot]] = []
115
+ placeholder_ids: dict[Placeholder, int] = {}
116
+ slot_ids: dict[Slot, int] = {}
117
+ out_parts: list[str] = []
123
118
  for part in parts:
124
119
  if isinstance(part, Slot):
125
120
  if not allow_slots:
126
- raise ValueError(f"Unfilled slot: {repr(part.name)}")
121
+ raise ValueError(f"Unfilled slot: {part.name!r}")
127
122
  if part not in slot_ids:
128
123
  args.append(part)
129
124
  slot_ids[part] = len(args)
@@ -138,7 +133,7 @@ class Fragment:
138
133
  out_parts.append(part)
139
134
  return "".join(out_parts).strip(), args
140
135
 
141
- def query(self) -> Tuple[str, List[Any]]:
136
+ def query(self) -> tuple[str, list[Any]]:
142
137
  query, args = self.prep_query()
143
138
  placeholder_values = [arg.value for arg in args]
144
139
  return query, placeholder_values
@@ -146,16 +141,16 @@ class Fragment:
146
141
  def sqlalchemy_text(self) -> Any:
147
142
  return sqlalchemy_text_from_fragment(self)
148
143
 
149
- def prepare(self) -> Tuple[str, Callable[..., List[Any]]]:
144
+ def prepare(self) -> tuple[str, Callable[..., list[Any]]]:
150
145
  query, args = self.prep_query(allow_slots=True)
151
- env = dict()
146
+ env = {}
152
147
  func = [
153
148
  "def generate_args(**kwargs):",
154
149
  " return [",
155
150
  ]
156
151
  for i, arg in enumerate(args):
157
152
  if isinstance(arg, Slot):
158
- func.append(f" kwargs[{repr(arg.name)}],")
153
+ func.append(f" kwargs[{arg.name!r}],")
159
154
  else:
160
155
  env[f"value_{i}"] = arg.value
161
156
  func.append(f" value_{i},")
@@ -178,10 +173,10 @@ class SQLFormatter:
178
173
  if not preserve_formatting:
179
174
  fmt = newline_whitespace_re.sub(" ", fmt)
180
175
  fmtr = string.Formatter()
181
- parts: List[Part] = []
182
- placeholders: Dict[str, Placeholder] = {}
176
+ parts: list[Part] = []
177
+ placeholders: dict[str, Placeholder] = {}
183
178
  next_auto_field = 0
184
- for literal_text, field_name, format_spec, conversion in fmtr.parse(fmt):
179
+ for literal_text, field_name, _format_spec, _conversion in fmtr.parse(fmt):
185
180
  parts.append(literal_text)
186
181
  if field_name is not None:
187
182
  if auto_numbered(field_name):
@@ -258,12 +253,11 @@ class SQLFormatter:
258
253
  parts = parts[0]
259
254
  return Fragment(list(join_parts(parts, infix=", ")))
260
255
 
261
- @staticmethod
262
- def unnest(data: Iterable[Sequence[Any]], types: Iterable[str]) -> Fragment:
263
- nested = list(nest_for_type(x, t) for x, t in zip(zip(*data), types))
256
+ def unnest(self, data: Iterable[Sequence[Any]], types: Iterable[str]) -> Fragment:
257
+ nested = [nest_for_type(x, t) for x, t in zip(zip(*data), types)]
264
258
  if not nested:
265
- nested = list(nest_for_type([], t) for t in types)
266
- return Fragment(["UNNEST(", sql.list(nested), ")"])
259
+ nested = [nest_for_type([], t) for t in types]
260
+ return Fragment(["UNNEST(", self.list(nested), ")"])
267
261
 
268
262
 
269
263
  sql = SQLFormatter()
@@ -296,7 +290,7 @@ def lit(text: str) -> Fragment:
296
290
  return Fragment([text])
297
291
 
298
292
 
299
- def any_all(frags: List[Fragment], op: str, base_case: str) -> Fragment:
293
+ def any_all(frags: list[Fragment], op: str, base_case: str) -> Fragment:
300
294
  if not frags:
301
295
  return lit(base_case)
302
296
  parts = join_parts(frags, prefix="(", infix=f") {op} (", suffix=")")
sql_athame/dataclasses.py CHANGED
@@ -1,60 +1,57 @@
1
1
  import datetime
2
2
  import uuid
3
- from dataclasses import dataclass, field, fields
3
+ from collections.abc import AsyncGenerator, Iterable, Mapping
4
+ from dataclasses import Field, InitVar, dataclass, fields
4
5
  from typing import (
6
+ Annotated,
5
7
  Any,
6
- AsyncGenerator,
7
8
  Callable,
8
- Dict,
9
- Iterable,
10
- Iterator,
11
- List,
12
- Mapping,
13
9
  Optional,
14
- Set,
15
- Tuple,
16
- Type,
17
10
  TypeVar,
18
11
  Union,
12
+ get_origin,
13
+ get_type_hints,
19
14
  )
20
15
 
16
+ from typing_extensions import TypeAlias
17
+
21
18
  from .base import Fragment, sql
22
19
 
23
- Where = Union[Fragment, Iterable[Fragment]]
20
+ Where: TypeAlias = Union[Fragment, Iterable[Fragment]]
24
21
  # KLUDGE to avoid a string argument being valid
25
- SequenceOfStrings = Union[List[str], Tuple[str, ...]]
26
- FieldNames = SequenceOfStrings
27
- FieldNamesSet = Union[SequenceOfStrings, Set[str]]
22
+ SequenceOfStrings: TypeAlias = Union[list[str], tuple[str, ...]]
23
+ FieldNames: TypeAlias = SequenceOfStrings
24
+ FieldNamesSet: TypeAlias = Union[SequenceOfStrings, set[str]]
28
25
 
29
- Connection = Any
30
- Pool = Any
26
+ Connection: TypeAlias = Any
27
+ Pool: TypeAlias = Any
31
28
 
32
29
 
33
30
  @dataclass
34
31
  class ColumnInfo:
35
32
  type: str
36
- create_type: str
37
- constraints: Tuple[str, ...]
38
-
39
- def create_table_string(self):
40
- return " ".join((self.create_type, *self.constraints))
41
-
42
-
43
- def model_field_metadata(
44
- type: str, constraints: Union[str, Iterable[str]] = ()
45
- ) -> Dict[str, Any]:
46
- if isinstance(constraints, str):
47
- constraints = (constraints,)
48
- info = ColumnInfo(
49
- sql_create_type_map.get(type.upper(), type), type, tuple(constraints)
50
- )
51
- return {"sql_athame": info}
52
-
53
-
54
- def model_field(
55
- *, type: str, constraints: Union[str, Iterable[str]] = (), **kwargs: Any
56
- ) -> Any:
57
- return field(**kwargs, metadata=model_field_metadata(type, constraints))
33
+ create_type: str = ""
34
+ nullable: bool = False
35
+ _constraints: tuple[str, ...] = ()
36
+
37
+ constraints: InitVar[Union[str, Iterable[str], None]] = None
38
+
39
+ def __post_init__(self, constraints: Union[str, Iterable[str], None]) -> None:
40
+ if self.create_type == "":
41
+ self.create_type = self.type
42
+ self.type = sql_create_type_map.get(self.type.upper(), self.type)
43
+ if constraints is not None:
44
+ if type(constraints) is str:
45
+ constraints = (constraints,)
46
+ self._constraints = tuple(constraints)
47
+
48
+ def create_table_string(self) -> str:
49
+ parts = (
50
+ self.create_type,
51
+ *(() if self.nullable else ("NOT NULL",)),
52
+ *self._constraints,
53
+ )
54
+ return " ".join(parts)
58
55
 
59
56
 
60
57
  sql_create_type_map = {
@@ -64,43 +61,37 @@ sql_create_type_map = {
64
61
  }
65
62
 
66
63
 
67
- sql_type_map = {
68
- Optional[bool]: ("BOOLEAN",),
69
- Optional[bytes]: ("BYTEA",),
70
- Optional[datetime.date]: ("DATE",),
71
- Optional[datetime.datetime]: ("TIMESTAMP",),
72
- Optional[float]: ("DOUBLE PRECISION",),
73
- Optional[int]: ("INTEGER",),
74
- Optional[str]: ("TEXT",),
75
- Optional[uuid.UUID]: ("UUID",),
76
- bool: ("BOOLEAN", "NOT NULL"),
77
- bytes: ("BYTEA", "NOT NULL"),
78
- datetime.date: ("DATE", "NOT NULL"),
79
- datetime.datetime: ("TIMESTAMP", "NOT NULL"),
80
- float: ("DOUBLE PRECISION", "NOT NULL"),
81
- int: ("INTEGER", "NOT NULL"),
82
- str: ("TEXT", "NOT NULL"),
83
- uuid.UUID: ("UUID", "NOT NULL"),
64
+ sql_type_map: dict[Any, tuple[str, bool]] = {
65
+ Optional[bool]: ("BOOLEAN", True),
66
+ Optional[bytes]: ("BYTEA", True),
67
+ Optional[datetime.date]: ("DATE", True),
68
+ Optional[datetime.datetime]: ("TIMESTAMP", True),
69
+ Optional[float]: ("DOUBLE PRECISION", True),
70
+ Optional[int]: ("INTEGER", True),
71
+ Optional[str]: ("TEXT", True),
72
+ Optional[uuid.UUID]: ("UUID", True),
73
+ bool: ("BOOLEAN", False),
74
+ bytes: ("BYTEA", False),
75
+ datetime.date: ("DATE", False),
76
+ datetime.datetime: ("TIMESTAMP", False),
77
+ float: ("DOUBLE PRECISION", False),
78
+ int: ("INTEGER", False),
79
+ str: ("TEXT", False),
80
+ uuid.UUID: ("UUID", False),
84
81
  }
85
82
 
86
83
 
87
- def column_info_for_field(field):
88
- if "sql_athame" in field.metadata:
89
- return field.metadata["sql_athame"]
90
- type, *constraints = sql_type_map[field.type]
91
- return ColumnInfo(type, type, tuple(constraints))
92
-
93
-
94
84
  T = TypeVar("T", bound="ModelBase")
95
85
  U = TypeVar("U")
96
86
 
97
87
 
98
- class ModelBase(Mapping[str, Any]):
99
- _column_info: Optional[Dict[str, ColumnInfo]]
100
- _cache: Dict[tuple, Any]
88
+ class ModelBase:
89
+ _column_info: Optional[dict[str, ColumnInfo]]
90
+ _cache: dict[tuple, Any]
101
91
  table_name: str
102
- primary_key_names: Tuple[str, ...]
92
+ primary_key_names: tuple[str, ...]
103
93
  array_safe_insert: bool
94
+ _type_hints: dict[str, type]
104
95
 
105
96
  def __init_subclass__(
106
97
  cls,
@@ -138,27 +129,34 @@ class ModelBase(Mapping[str, Any]):
138
129
  cls._cache[key] = thunk()
139
130
  return cls._cache[key]
140
131
 
141
- def keys(self):
142
- return self.field_names()
143
-
144
- def __getitem__(self, key: str) -> Any:
145
- return getattr(self, key)
146
-
147
- def __iter__(self) -> Iterator[Any]:
148
- return iter(self.keys())
149
-
150
- def __len__(self) -> int:
151
- return len(self.keys())
132
+ @classmethod
133
+ def type_hints(cls) -> dict[str, type]:
134
+ try:
135
+ return cls._type_hints
136
+ except AttributeError:
137
+ cls._type_hints = get_type_hints(cls, include_extras=True)
138
+ return cls._type_hints
152
139
 
153
- def get(self, key: str, default: Any = None) -> Any:
154
- return getattr(self, key, default)
140
+ @classmethod
141
+ def column_info_for_field(cls, field: Field) -> ColumnInfo:
142
+ type_info = cls.type_hints()[field.name]
143
+ base_type = type_info
144
+ if get_origin(type_info) is Annotated:
145
+ base_type = type_info.__origin__ # type: ignore
146
+ for md in type_info.__metadata__: # type: ignore
147
+ if isinstance(md, ColumnInfo):
148
+ return md
149
+ type, nullable = sql_type_map[base_type]
150
+ return ColumnInfo(type=type, nullable=nullable)
155
151
 
156
152
  @classmethod
157
153
  def column_info(cls, column: str) -> ColumnInfo:
158
154
  try:
159
155
  return cls._column_info[column] # type: ignore
160
156
  except AttributeError:
161
- cls._column_info = {f.name: column_info_for_field(f) for f in cls._fields()}
157
+ cls._column_info = {
158
+ f.name: cls.column_info_for_field(f) for f in cls._fields()
159
+ }
162
160
  return cls._column_info[column]
163
161
 
164
162
  @classmethod
@@ -166,17 +164,17 @@ class ModelBase(Mapping[str, Any]):
166
164
  return sql.identifier(cls.table_name, prefix=prefix)
167
165
 
168
166
  @classmethod
169
- def primary_key_names_sql(cls, *, prefix: Optional[str] = None) -> List[Fragment]:
167
+ def primary_key_names_sql(cls, *, prefix: Optional[str] = None) -> list[Fragment]:
170
168
  return [sql.identifier(pk, prefix=prefix) for pk in cls.primary_key_names]
171
169
 
172
170
  @classmethod
173
- def field_names(cls, *, exclude: FieldNamesSet = ()) -> List[str]:
171
+ def field_names(cls, *, exclude: FieldNamesSet = ()) -> list[str]:
174
172
  return [f.name for f in cls._fields() if f.name not in exclude]
175
173
 
176
174
  @classmethod
177
175
  def field_names_sql(
178
176
  cls, *, prefix: Optional[str] = None, exclude: FieldNamesSet = ()
179
- ) -> List[Fragment]:
177
+ ) -> list[Fragment]:
180
178
  return [
181
179
  sql.identifier(f, prefix=prefix) for f in cls.field_names(exclude=exclude)
182
180
  ]
@@ -186,9 +184,9 @@ class ModelBase(Mapping[str, Any]):
186
184
 
187
185
  @classmethod
188
186
  def _get_field_values_fn(
189
- cls: Type[T], exclude: FieldNamesSet = ()
190
- ) -> Callable[[T], List[Any]]:
191
- env: Dict[str, Any] = dict()
187
+ cls: type[T], exclude: FieldNamesSet = ()
188
+ ) -> Callable[[T], list[Any]]:
189
+ env: dict[str, Any] = {}
192
190
  func = ["def get_field_values(self): return ["]
193
191
  for f in cls._fields():
194
192
  if f.name not in exclude:
@@ -197,7 +195,7 @@ class ModelBase(Mapping[str, Any]):
197
195
  exec(" ".join(func), env)
198
196
  return env["get_field_values"]
199
197
 
200
- def field_values(self, *, exclude: FieldNamesSet = ()) -> List[Any]:
198
+ def field_values(self, *, exclude: FieldNamesSet = ()) -> list[Any]:
201
199
  get_field_values = self._cached(
202
200
  ("get_field_values", tuple(sorted(exclude))),
203
201
  lambda: self._get_field_values_fn(exclude),
@@ -206,7 +204,7 @@ class ModelBase(Mapping[str, Any]):
206
204
 
207
205
  def field_values_sql(
208
206
  self, *, exclude: FieldNamesSet = (), default_none: bool = False
209
- ) -> List[Fragment]:
207
+ ) -> list[Fragment]:
210
208
  if default_none:
211
209
  return [
212
210
  sql.literal("DEFAULT") if value is None else sql.value(value)
@@ -217,7 +215,7 @@ class ModelBase(Mapping[str, Any]):
217
215
 
218
216
  @classmethod
219
217
  def from_tuple(
220
- cls: Type[T], tup: tuple, *, offset: int = 0, exclude: FieldNamesSet = ()
218
+ cls: type[T], tup: tuple, *, offset: int = 0, exclude: FieldNamesSet = ()
221
219
  ) -> T:
222
220
  names = (f.name for f in cls._fields() if f.name not in exclude)
223
221
  kwargs = {name: tup[offset] for offset, name in enumerate(names, start=offset)}
@@ -225,14 +223,14 @@ class ModelBase(Mapping[str, Any]):
225
223
 
226
224
  @classmethod
227
225
  def from_dict(
228
- cls: Type[T], dct: Dict[str, Any], *, exclude: FieldNamesSet = ()
226
+ cls: type[T], dct: dict[str, Any], *, exclude: FieldNamesSet = ()
229
227
  ) -> T:
230
228
  names = {f.name for f in cls._fields() if f.name not in exclude}
231
229
  kwargs = {k: v for k, v in dct.items() if k in names}
232
230
  return cls(**kwargs)
233
231
 
234
232
  @classmethod
235
- def ensure_model(cls: Type[T], row: Union[T, Mapping[str, Any]]) -> T:
233
+ def ensure_model(cls: type[T], row: Union[T, Mapping[str, Any]]) -> T:
236
234
  if isinstance(row, cls):
237
235
  return row
238
236
  return cls(**row)
@@ -286,7 +284,7 @@ class ModelBase(Mapping[str, Any]):
286
284
 
287
285
  @classmethod
288
286
  async def select_cursor(
289
- cls: Type[T],
287
+ cls: type[T],
290
288
  connection: Connection,
291
289
  order_by: Union[FieldNames, str] = (),
292
290
  for_update: bool = False,
@@ -301,12 +299,12 @@ class ModelBase(Mapping[str, Any]):
301
299
 
302
300
  @classmethod
303
301
  async def select(
304
- cls: Type[T],
302
+ cls: type[T],
305
303
  connection_or_pool: Union[Connection, Pool],
306
304
  order_by: Union[FieldNames, str] = (),
307
305
  for_update: bool = False,
308
306
  where: Where = (),
309
- ) -> List[T]:
307
+ ) -> list[T]:
310
308
  return [
311
309
  cls(**row)
312
310
  for row in await connection_or_pool.fetch(
@@ -315,7 +313,7 @@ class ModelBase(Mapping[str, Any]):
315
313
  ]
316
314
 
317
315
  @classmethod
318
- def create_sql(cls: Type[T], **kwargs: Any) -> Fragment:
316
+ def create_sql(cls: type[T], **kwargs: Any) -> Fragment:
319
317
  return sql(
320
318
  "INSERT INTO {table} ({fields}) VALUES ({values}) RETURNING {out_fields}",
321
319
  table=cls.table_name_sql(),
@@ -326,7 +324,7 @@ class ModelBase(Mapping[str, Any]):
326
324
 
327
325
  @classmethod
328
326
  async def create(
329
- cls: Type[T], connection_or_pool: Union[Connection, Pool], **kwargs: Any
327
+ cls: type[T], connection_or_pool: Union[Connection, Pool], **kwargs: Any
330
328
  ) -> T:
331
329
  row = await connection_or_pool.fetchrow(*cls.create_sql(**kwargs))
332
330
  return cls(**row)
@@ -375,11 +373,10 @@ class ModelBase(Mapping[str, Any]):
375
373
  self.upsert_sql(self.insert_sql(exclude=exclude), exclude=exclude),
376
374
  )
377
375
  result = await connection_or_pool.fetchrow(*query)
378
- is_update = result["xmax"] != 0
379
- return is_update
376
+ return result["xmax"] != 0
380
377
 
381
378
  @classmethod
382
- def delete_multiple_sql(cls: Type[T], rows: Iterable[T]) -> Fragment:
379
+ def delete_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
383
380
  cached = cls._cached(
384
381
  ("delete_multiple_sql",),
385
382
  lambda: sql(
@@ -397,12 +394,12 @@ class ModelBase(Mapping[str, Any]):
397
394
 
398
395
  @classmethod
399
396
  async def delete_multiple(
400
- cls: Type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
397
+ cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
401
398
  ) -> str:
402
399
  return await connection_or_pool.execute(*cls.delete_multiple_sql(rows))
403
400
 
404
401
  @classmethod
405
- def insert_multiple_sql(cls: Type[T], rows: Iterable[T]) -> Fragment:
402
+ def insert_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
406
403
  cached = cls._cached(
407
404
  ("insert_multiple_sql",),
408
405
  lambda: sql(
@@ -419,7 +416,7 @@ class ModelBase(Mapping[str, Any]):
419
416
  )
420
417
 
421
418
  @classmethod
422
- def insert_multiple_array_safe_sql(cls: Type[T], rows: Iterable[T]) -> Fragment:
419
+ def insert_multiple_array_safe_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
423
420
  return sql(
424
421
  "INSERT INTO {table} ({fields}) VALUES {values}",
425
422
  table=cls.table_name_sql(),
@@ -432,13 +429,13 @@ class ModelBase(Mapping[str, Any]):
432
429
 
433
430
  @classmethod
434
431
  async def insert_multiple_unnest(
435
- cls: Type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
432
+ cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
436
433
  ) -> str:
437
434
  return await connection_or_pool.execute(*cls.insert_multiple_sql(rows))
438
435
 
439
436
  @classmethod
440
437
  async def insert_multiple_array_safe(
441
- cls: Type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
438
+ cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
442
439
  ) -> str:
443
440
  last = ""
444
441
  for chunk in chunked(rows, 100):
@@ -449,7 +446,7 @@ class ModelBase(Mapping[str, Any]):
449
446
 
450
447
  @classmethod
451
448
  async def insert_multiple(
452
- cls: Type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
449
+ cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
453
450
  ) -> str:
454
451
  if cls.array_safe_insert:
455
452
  return await cls.insert_multiple_array_safe(connection_or_pool, rows)
@@ -458,37 +455,52 @@ class ModelBase(Mapping[str, Any]):
458
455
 
459
456
  @classmethod
460
457
  async def upsert_multiple_unnest(
461
- cls: Type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
458
+ cls: type[T],
459
+ connection_or_pool: Union[Connection, Pool],
460
+ rows: Iterable[T],
461
+ insert_only: FieldNamesSet = (),
462
462
  ) -> str:
463
463
  return await connection_or_pool.execute(
464
- *cls.upsert_sql(cls.insert_multiple_sql(rows))
464
+ *cls.upsert_sql(cls.insert_multiple_sql(rows), exclude=insert_only)
465
465
  )
466
466
 
467
467
  @classmethod
468
468
  async def upsert_multiple_array_safe(
469
- cls: Type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
469
+ cls: type[T],
470
+ connection_or_pool: Union[Connection, Pool],
471
+ rows: Iterable[T],
472
+ insert_only: FieldNamesSet = (),
470
473
  ) -> str:
471
474
  last = ""
472
475
  for chunk in chunked(rows, 100):
473
476
  last = await connection_or_pool.execute(
474
- *cls.upsert_sql(cls.insert_multiple_array_safe_sql(chunk))
477
+ *cls.upsert_sql(
478
+ cls.insert_multiple_array_safe_sql(chunk), exclude=insert_only
479
+ )
475
480
  )
476
481
  return last
477
482
 
478
483
  @classmethod
479
484
  async def upsert_multiple(
480
- cls: Type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
485
+ cls: type[T],
486
+ connection_or_pool: Union[Connection, Pool],
487
+ rows: Iterable[T],
488
+ insert_only: FieldNamesSet = (),
481
489
  ) -> str:
482
490
  if cls.array_safe_insert:
483
- return await cls.upsert_multiple_array_safe(connection_or_pool, rows)
491
+ return await cls.upsert_multiple_array_safe(
492
+ connection_or_pool, rows, insert_only=insert_only
493
+ )
484
494
  else:
485
- return await cls.upsert_multiple_unnest(connection_or_pool, rows)
495
+ return await cls.upsert_multiple_unnest(
496
+ connection_or_pool, rows, insert_only=insert_only
497
+ )
486
498
 
487
499
  @classmethod
488
500
  def _get_equal_ignoring_fn(
489
- cls: Type[T], ignore: FieldNamesSet = ()
501
+ cls: type[T], ignore: FieldNamesSet = ()
490
502
  ) -> Callable[[T, T], bool]:
491
- env: Dict[str, Any] = dict()
503
+ env: dict[str, Any] = {}
492
504
  func = ["def equal_ignoring(a, b):"]
493
505
  for f in cls._fields():
494
506
  if f.name not in ignore:
@@ -499,15 +511,17 @@ class ModelBase(Mapping[str, Any]):
499
511
 
500
512
  @classmethod
501
513
  async def replace_multiple(
502
- cls: Type[T],
514
+ cls: type[T],
503
515
  connection: Connection,
504
516
  rows: Union[Iterable[T], Iterable[Mapping[str, Any]]],
505
517
  *,
506
518
  where: Where,
507
519
  ignore: FieldNamesSet = (),
508
- ) -> Tuple[List[T], List[T], List[T]]:
520
+ insert_only: FieldNamesSet = (),
521
+ ) -> tuple[list[T], list[T], list[T]]:
522
+ ignore = sorted(set(ignore) | set(insert_only))
509
523
  equal_ignoring = cls._cached(
510
- ("equal_ignoring", tuple(sorted(ignore))),
524
+ ("equal_ignoring", tuple(ignore)),
511
525
  lambda: cls._get_equal_ignoring_fn(ignore),
512
526
  )
513
527
  pending = {row.primary_key(): row for row in map(cls.ensure_model, rows)}
@@ -529,7 +543,9 @@ class ModelBase(Mapping[str, Any]):
529
543
  created = list(pending.values())
530
544
 
531
545
  if created or updated:
532
- await cls.upsert_multiple(connection, (*created, *updated))
546
+ await cls.upsert_multiple(
547
+ connection, (*created, *updated), insert_only=insert_only
548
+ )
533
549
  if deleted:
534
550
  await cls.delete_multiple(connection, deleted)
535
551
 
@@ -537,33 +553,33 @@ class ModelBase(Mapping[str, Any]):
537
553
 
538
554
  @classmethod
539
555
  def _get_differences_ignoring_fn(
540
- cls: Type[T], ignore: FieldNamesSet = ()
541
- ) -> Callable[[T, T], List[str]]:
542
- env: Dict[str, Any] = dict()
556
+ cls: type[T], ignore: FieldNamesSet = ()
557
+ ) -> Callable[[T, T], list[str]]:
558
+ env: dict[str, Any] = {}
543
559
  func = [
544
560
  "def differences_ignoring(a, b):",
545
561
  " diffs = []",
546
562
  ]
547
563
  for f in cls._fields():
548
564
  if f.name not in ignore:
549
- func.append(
550
- f" if a.{f.name} != b.{f.name}: diffs.append({repr(f.name)})"
551
- )
565
+ func.append(f" if a.{f.name} != b.{f.name}: diffs.append({f.name!r})")
552
566
  func += [" return diffs"]
553
567
  exec("\n".join(func), env)
554
568
  return env["differences_ignoring"]
555
569
 
556
570
  @classmethod
557
571
  async def replace_multiple_reporting_differences(
558
- cls: Type[T],
572
+ cls: type[T],
559
573
  connection: Connection,
560
574
  rows: Union[Iterable[T], Iterable[Mapping[str, Any]]],
561
575
  *,
562
576
  where: Where,
563
577
  ignore: FieldNamesSet = (),
564
- ) -> Tuple[List[T], List[Tuple[T, T, List[str]]], List[T]]:
578
+ insert_only: FieldNamesSet = (),
579
+ ) -> tuple[list[T], list[tuple[T, T, list[str]]], list[T]]:
580
+ ignore = sorted(set(ignore) | set(insert_only))
565
581
  differences_ignoring = cls._cached(
566
- ("differences_ignoring", tuple(sorted(ignore))),
582
+ ("differences_ignoring", tuple(ignore)),
567
583
  lambda: cls._get_differences_ignoring_fn(ignore),
568
584
  )
569
585
 
@@ -588,7 +604,9 @@ class ModelBase(Mapping[str, Any]):
588
604
 
589
605
  if created or updated_triples:
590
606
  await cls.upsert_multiple(
591
- connection, (*created, *(t[1] for t in updated_triples))
607
+ connection,
608
+ (*created, *(t[1] for t in updated_triples)),
609
+ insert_only=insert_only,
592
610
  )
593
611
  if deleted:
594
612
  await cls.delete_multiple(connection, deleted)
sql_athame/escape.py CHANGED
@@ -1,19 +1,20 @@
1
1
  import math
2
2
  import uuid
3
- from typing import Any, Sequence
3
+ from collections.abc import Sequence
4
+ from typing import Any
4
5
 
5
6
 
6
7
  def escape(value: Any) -> str:
7
8
  if isinstance(value, str):
8
- return f"E{repr(value)}"
9
+ return f"E{value!r}"
9
10
  elif isinstance(value, float) or isinstance(value, int):
10
11
  if math.isnan(value):
11
12
  raise ValueError("Can't escape NaN float")
12
13
  elif math.isinf(value):
13
14
  raise ValueError("Can't escape infinite float")
14
- return f"{repr(value)}"
15
+ return f"{value!r}"
15
16
  elif isinstance(value, uuid.UUID):
16
- return f"{repr(str(value))}::UUID"
17
+ return f"{str(value)!r}::UUID"
17
18
  elif isinstance(value, Sequence):
18
19
  args = ", ".join(escape(x) for x in value)
19
20
  return f"ARRAY[{args}]"
sql_athame/sqlalchemy.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import TYPE_CHECKING, Any, Dict, List
1
+ from typing import TYPE_CHECKING, Any
2
2
 
3
3
  from .types import FlatPart, Placeholder, Slot
4
4
 
@@ -7,10 +7,10 @@ try:
7
7
  from sqlalchemy.sql.elements import BindParameter
8
8
 
9
9
  def sqlalchemy_text_from_fragment(self: "Fragment") -> Any:
10
- parts: List[FlatPart] = []
10
+ parts: list[FlatPart] = []
11
11
  self.flatten_into(parts)
12
- bindparams: Dict[str, Any] = {}
13
- out_parts: List[str] = []
12
+ bindparams: dict[str, Any] = {}
13
+ out_parts: list[str] = []
14
14
  for part in parts:
15
15
  if isinstance(part, Slot):
16
16
  out_parts.append(f"(:{part.name})")
sql_athame/types.py CHANGED
@@ -1,6 +1,8 @@
1
1
  import dataclasses
2
2
  from typing import TYPE_CHECKING, Any, Union
3
3
 
4
+ from typing_extensions import TypeAlias
5
+
4
6
 
5
7
  @dataclasses.dataclass(eq=False)
6
8
  class Placeholder:
@@ -15,8 +17,8 @@ class Slot:
15
17
  name: str
16
18
 
17
19
 
18
- Part = Union[str, Placeholder, Slot, "Fragment"]
19
- FlatPart = Union[str, Placeholder, Slot]
20
+ Part: TypeAlias = Union[str, Placeholder, Slot, "Fragment"]
21
+ FlatPart: TypeAlias = Union[str, Placeholder, Slot]
20
22
 
21
23
  if TYPE_CHECKING:
22
24
  from .base import Fragment
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: sql-athame
3
- Version: 0.4.0a4
3
+ Version: 0.4.0a6
4
4
  Summary: Python tool for slicing and dicing SQL
5
5
  Home-page: https://github.com/bdowning/sql-athame
6
6
  License: MIT
@@ -0,0 +1,11 @@
1
+ sql_athame/__init__.py,sha256=7OBIMZOcrD2pvfIL-rjD1IGZ3TNQbwyu76a9PWk-yYg,79
2
+ sql_athame/base.py,sha256=FR7EmC0VkX1VRgvAutSEfYSWhlEYpoqS1Kqxp1jHp6Y,10293
3
+ sql_athame/dataclasses.py,sha256=qb4EESR6J-iv6UScktMLuKAwH3ZA3IOwCM0v6oMv8Q8,20848
4
+ sql_athame/escape.py,sha256=kK101xXeFitlvuG-L_hvhdpgGJCtmRTprsn1yEfZKws,758
5
+ sql_athame/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ sql_athame/sqlalchemy.py,sha256=aWopfPh3j71XwKmcN_VcHRNlhscI0Sckd4AiyGf8Tpw,1293
7
+ sql_athame/types.py,sha256=FQ06l9Uc-vo57UrAarvnukILdV2gN1IaYUnHJ_bNYic,475
8
+ sql_athame-0.4.0a6.dist-info/LICENSE,sha256=xqV29vPFqITcKifYrGPgVIBjq4fdmLSwY3gRUtDKafg,1076
9
+ sql_athame-0.4.0a6.dist-info/METADATA,sha256=8Ov07iCKPAo35uWW3t9WRRlXIgRlNGGrMpyXnOI6TRs,12845
10
+ sql_athame-0.4.0a6.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
11
+ sql_athame-0.4.0a6.dist-info/RECORD,,
@@ -1,11 +0,0 @@
1
- sql_athame/__init__.py,sha256=rzUQcbzmj3qkPZpL9jI_ALTRv-e1pAV4jSCryWkutlk,130
2
- sql_athame/base.py,sha256=fSnHQhh5ULeJ5q32RVUAvpWtF0qoY61B2gEEP59Nrpo,10350
3
- sql_athame/dataclasses.py,sha256=EPq9wd2mqcEvT0kAEhrZztlDDV3Hroiwa8GoP3hi_l8,19787
4
- sql_athame/escape.py,sha256=LXExbiYtc407yDU4vPieyY2Pq5nypsJFfBc_2-gsbUg,743
5
- sql_athame/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
- sql_athame/sqlalchemy.py,sha256=c-pCLE11hTh5I19rY1Vp5E7P7lAaj9i-i7ko2L8rlF4,1305
7
- sql_athame/types.py,sha256=7P4OyY0ezRlb2UDD9lpdXiLChnhQcBvHWaG_PKy3jmE,412
8
- sql_athame-0.4.0a4.dist-info/LICENSE,sha256=xqV29vPFqITcKifYrGPgVIBjq4fdmLSwY3gRUtDKafg,1076
9
- sql_athame-0.4.0a4.dist-info/METADATA,sha256=J5gkrkq4cAwD81oODgQjugY9y3k7-RImhFvqLMaJqr4,12845
10
- sql_athame-0.4.0a4.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
11
- sql_athame-0.4.0a4.dist-info/RECORD,,