sql-athame 0.4.0a8__tar.gz → 0.4.0a10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: sql-athame
3
- Version: 0.4.0a8
3
+ Version: 0.4.0a10
4
4
  Summary: Python tool for slicing and dicing SQL
5
5
  Home-page: https://github.com/bdowning/sql-athame
6
6
  License: MIT
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "sql-athame"
3
- version = "0.4.0-alpha-8"
3
+ version = "0.4.0-alpha-10"
4
4
  description = "Python tool for slicing and dicing SQL"
5
5
  authors = ["Brian Downing <bdowning@lavos.net>"]
6
6
  license = "MIT"
@@ -0,0 +1,2 @@
1
+ from .base import Fragment, sql
2
+ from .dataclasses import ColumnInfo, ModelBase, ReplaceMultiplePlan
@@ -1,5 +1,6 @@
1
1
  import datetime
2
2
  import functools
3
+ import sys
3
4
  import uuid
4
5
  from collections.abc import AsyncGenerator, Iterable, Mapping
5
6
  from dataclasses import Field, InitVar, dataclass, fields
@@ -7,6 +8,7 @@ from typing import (
7
8
  Annotated,
8
9
  Any,
9
10
  Callable,
11
+ Generic,
10
12
  Optional,
11
13
  TypeVar,
12
14
  Union,
@@ -34,10 +36,13 @@ class ColumnInfo:
34
36
  type: Optional[str] = None
35
37
  create_type: Optional[str] = None
36
38
  nullable: Optional[bool] = None
37
- _constraints: tuple[str, ...] = ()
38
39
 
40
+ _constraints: tuple[str, ...] = ()
39
41
  constraints: InitVar[Union[str, Iterable[str], None]] = None
40
42
 
43
+ serialize: Optional[Callable[[Any], Any]] = None
44
+ deserialize: Optional[Callable[[Any], Any]] = None
45
+
41
46
  def __post_init__(self, constraints: Union[str, Iterable[str], None]) -> None:
42
47
  if constraints is not None:
43
48
  if type(constraints) is str:
@@ -51,29 +56,41 @@ class ColumnInfo:
51
56
  create_type=b.create_type if b.create_type is not None else a.create_type,
52
57
  nullable=b.nullable if b.nullable is not None else a.nullable,
53
58
  _constraints=(*a._constraints, *b._constraints),
59
+ serialize=b.serialize if b.serialize is not None else a.serialize,
60
+ deserialize=b.deserialize if b.deserialize is not None else a.deserialize,
54
61
  )
55
62
 
56
63
 
57
64
  @dataclass
58
65
  class ConcreteColumnInfo:
66
+ field: Field
67
+ type_hint: type
59
68
  type: str
60
69
  create_type: str
61
70
  nullable: bool
62
71
  constraints: tuple[str, ...]
72
+ serialize: Optional[Callable[[Any], Any]] = None
73
+ deserialize: Optional[Callable[[Any], Any]] = None
63
74
 
64
75
  @staticmethod
65
- def from_column_info(name: str, *args: ColumnInfo) -> "ConcreteColumnInfo":
76
+ def from_column_info(
77
+ field: Field, type_hint: Any, *args: ColumnInfo
78
+ ) -> "ConcreteColumnInfo":
66
79
  info = functools.reduce(ColumnInfo.merge, args, ColumnInfo())
67
80
  if info.create_type is None and info.type is not None:
68
81
  info.create_type = info.type
69
82
  info.type = sql_create_type_map.get(info.type.upper(), info.type)
70
83
  if type(info.type) is not str or type(info.create_type) is not str:
71
- raise ValueError(f"Missing SQL type for column {name!r}")
84
+ raise ValueError(f"Missing SQL type for column {field.name!r}")
72
85
  return ConcreteColumnInfo(
86
+ field=field,
87
+ type_hint=type_hint,
73
88
  type=info.type,
74
89
  create_type=info.create_type,
75
90
  nullable=bool(info.nullable),
76
91
  constraints=info._constraints,
92
+ serialize=info.serialize,
93
+ deserialize=info.deserialize,
77
94
  )
78
95
 
79
96
  def create_table_string(self) -> str:
@@ -84,13 +101,24 @@ class ConcreteColumnInfo:
84
101
  )
85
102
  return " ".join(parts)
86
103
 
104
+ def maybe_serialize(self, value: Any) -> Any:
105
+ if self.serialize:
106
+ return self.serialize(value)
107
+ return value
108
+
109
+
110
+ UNION_TYPES: tuple = (Union,)
111
+ if sys.version_info >= (3, 10):
112
+ from types import UnionType
113
+
114
+ UNION_TYPES = (Union, UnionType)
87
115
 
88
116
  NULLABLE_TYPES = (type(None), Any, object)
89
117
 
90
118
 
91
119
  def split_nullable(typ: type) -> tuple[bool, type]:
92
120
  nullable = typ in NULLABLE_TYPES
93
- if get_origin(typ) is Union:
121
+ if get_origin(typ) in UNION_TYPES:
94
122
  args = []
95
123
  for arg in get_args(typ):
96
124
  if arg in NULLABLE_TYPES:
@@ -108,7 +136,7 @@ sql_create_type_map = {
108
136
  }
109
137
 
110
138
 
111
- sql_type_map: dict[Any, str] = {
139
+ sql_type_map: dict[type, str] = {
112
140
  bool: "BOOLEAN",
113
141
  bytes: "BYTEA",
114
142
  datetime.date: "DATE",
@@ -125,12 +153,11 @@ U = TypeVar("U")
125
153
 
126
154
 
127
155
  class ModelBase:
128
- _column_info: Optional[dict[str, ConcreteColumnInfo]]
156
+ _column_info: dict[str, ConcreteColumnInfo]
129
157
  _cache: dict[tuple, Any]
130
158
  table_name: str
131
159
  primary_key_names: tuple[str, ...]
132
160
  array_safe_insert: bool
133
- _type_hints: dict[str, type]
134
161
 
135
162
  def __init_subclass__(
136
163
  cls,
@@ -153,13 +180,6 @@ class ModelBase:
153
180
  else:
154
181
  cls.primary_key_names = tuple(primary_key)
155
182
 
156
- @classmethod
157
- def _fields(cls):
158
- # wrapper to ignore typing weirdness: 'Argument 1 to "fields"
159
- # has incompatible type "..."; expected "DataclassInstance |
160
- # type[DataclassInstance]"'
161
- return fields(cls) # type: ignore
162
-
163
183
  @classmethod
164
184
  def _cached(cls, key: tuple, thunk: Callable[[], U]) -> U:
165
185
  try:
@@ -169,20 +189,11 @@ class ModelBase:
169
189
  return cls._cache[key]
170
190
 
171
191
  @classmethod
172
- def type_hints(cls) -> dict[str, type]:
173
- try:
174
- return cls._type_hints
175
- except AttributeError:
176
- cls._type_hints = get_type_hints(cls, include_extras=True)
177
- return cls._type_hints
178
-
179
- @classmethod
180
- def column_info_for_field(cls, field: Field) -> ConcreteColumnInfo:
181
- type_info = cls.type_hints()[field.name]
182
- base_type = type_info
192
+ def column_info_for_field(cls, field: Field, type_hint: type) -> ConcreteColumnInfo:
193
+ base_type = type_hint
183
194
  metadata = []
184
- if get_origin(type_info) is Annotated:
185
- base_type, *metadata = get_args(type_info)
195
+ if get_origin(type_hint) is Annotated:
196
+ base_type, *metadata = get_args(type_hint)
186
197
  nullable, base_type = split_nullable(base_type)
187
198
  info = [ColumnInfo(nullable=nullable)]
188
199
  if base_type in sql_type_map:
@@ -190,17 +201,19 @@ class ModelBase:
190
201
  for md in metadata:
191
202
  if isinstance(md, ColumnInfo):
192
203
  info.append(md)
193
- return ConcreteColumnInfo.from_column_info(field.name, *info)
204
+ return ConcreteColumnInfo.from_column_info(field, type_hint, *info)
194
205
 
195
206
  @classmethod
196
- def column_info(cls, column: str) -> ConcreteColumnInfo:
207
+ def column_info(cls) -> dict[str, ConcreteColumnInfo]:
197
208
  try:
198
- return cls._column_info[column] # type: ignore
209
+ return cls._column_info
199
210
  except AttributeError:
211
+ type_hints = get_type_hints(cls, include_extras=True)
200
212
  cls._column_info = {
201
- f.name: cls.column_info_for_field(f) for f in cls._fields()
213
+ f.name: cls.column_info_for_field(f, type_hints[f.name])
214
+ for f in fields(cls) # type: ignore
202
215
  }
203
- return cls._column_info[column]
216
+ return cls._column_info
204
217
 
205
218
  @classmethod
206
219
  def table_name_sql(cls, *, prefix: Optional[str] = None) -> Fragment:
@@ -212,7 +225,11 @@ class ModelBase:
212
225
 
213
226
  @classmethod
214
227
  def field_names(cls, *, exclude: FieldNamesSet = ()) -> list[str]:
215
- return [f.name for f in cls._fields() if f.name not in exclude]
228
+ return [
229
+ ci.field.name
230
+ for ci in cls.column_info().values()
231
+ if ci.field.name not in exclude
232
+ ]
216
233
 
217
234
  @classmethod
218
235
  def field_names_sql(
@@ -231,9 +248,13 @@ class ModelBase:
231
248
  ) -> Callable[[T], list[Any]]:
232
249
  env: dict[str, Any] = {}
233
250
  func = ["def get_field_values(self): return ["]
234
- for f in cls._fields():
235
- if f.name not in exclude:
236
- func.append(f"self.{f.name},")
251
+ for ci in cls.column_info().values():
252
+ if ci.field.name not in exclude:
253
+ if ci.serialize:
254
+ env[f"_ser_{ci.field.name}"] = ci.serialize
255
+ func.append(f"_ser_{ci.field.name}(self.{ci.field.name}),")
256
+ else:
257
+ func.append(f"self.{ci.field.name},")
237
258
  func += ["]"]
238
259
  exec(" ".join(func), env)
239
260
  return env["get_field_values"]
@@ -257,36 +278,46 @@ class ModelBase:
257
278
  return [sql.value(value) for value in self.field_values()]
258
279
 
259
280
  @classmethod
260
- def from_tuple(
261
- cls: type[T], tup: tuple, *, offset: int = 0, exclude: FieldNamesSet = ()
262
- ) -> T:
263
- names = (f.name for f in cls._fields() if f.name not in exclude)
264
- kwargs = {name: tup[offset] for offset, name in enumerate(names, start=offset)}
265
- return cls(**kwargs)
281
+ def _get_from_mapping_fn(cls: type[T]) -> Callable[[Mapping[str, Any]], T]:
282
+ env: dict[str, Any] = {"cls": cls}
283
+ func = ["def from_mapping(mapping):"]
284
+ if not any(ci.deserialize for ci in cls.column_info().values()):
285
+ func.append(" return cls(**mapping)")
286
+ else:
287
+ func.append(" deser_dict = dict(mapping)")
288
+ for ci in cls.column_info().values():
289
+ if ci.deserialize:
290
+ env[f"_deser_{ci.field.name}"] = ci.deserialize
291
+ func.append(f" if {ci.field.name!r} in deser_dict:")
292
+ func.append(
293
+ f" deser_dict[{ci.field.name!r}] = _deser_{ci.field.name}(deser_dict[{ci.field.name!r}])"
294
+ )
295
+ func.append(" return cls(**deser_dict)")
296
+ exec("\n".join(func), env)
297
+ return env["from_mapping"]
266
298
 
267
299
  @classmethod
268
- def from_dict(
269
- cls: type[T], dct: dict[str, Any], *, exclude: FieldNamesSet = ()
270
- ) -> T:
271
- names = {f.name for f in cls._fields() if f.name not in exclude}
272
- kwargs = {k: v for k, v in dct.items() if k in names}
273
- return cls(**kwargs)
300
+ def from_mapping(cls: type[T], mapping: Mapping[str, Any], /) -> T:
301
+ # KLUDGE nasty but... efficient?
302
+ from_mapping_fn = cls._get_from_mapping_fn()
303
+ cls.from_mapping = from_mapping_fn # type: ignore
304
+ return from_mapping_fn(mapping)
274
305
 
275
306
  @classmethod
276
307
  def ensure_model(cls: type[T], row: Union[T, Mapping[str, Any]]) -> T:
277
308
  if isinstance(row, cls):
278
309
  return row
279
- return cls(**row)
310
+ return cls.from_mapping(row) # type: ignore
280
311
 
281
312
  @classmethod
282
313
  def create_table_sql(cls) -> Fragment:
283
314
  entries = [
284
315
  sql(
285
316
  "{} {}",
286
- sql.identifier(f.name),
287
- sql.literal(cls.column_info(f.name).create_table_string()),
317
+ sql.identifier(ci.field.name),
318
+ sql.literal(ci.create_table_string()),
288
319
  )
289
- for f in cls._fields()
320
+ for ci in cls.column_info().values()
290
321
  ]
291
322
  if cls.primary_key_names:
292
323
  entries += [sql("PRIMARY KEY ({})", sql.list(cls.primary_key_names_sql()))]
@@ -338,7 +369,7 @@ class ModelBase:
338
369
  *cls.select_sql(order_by=order_by, for_update=for_update, where=where),
339
370
  prefetch=prefetch,
340
371
  ):
341
- yield cls(**row)
372
+ yield cls.from_mapping(row)
342
373
 
343
374
  @classmethod
344
375
  async def select(
@@ -349,7 +380,7 @@ class ModelBase:
349
380
  where: Where = (),
350
381
  ) -> list[T]:
351
382
  return [
352
- cls(**row)
383
+ cls.from_mapping(row)
353
384
  for row in await connection_or_pool.fetch(
354
385
  *cls.select_sql(order_by=order_by, for_update=for_update, where=where)
355
386
  )
@@ -357,11 +388,14 @@ class ModelBase:
357
388
 
358
389
  @classmethod
359
390
  def create_sql(cls: type[T], **kwargs: Any) -> Fragment:
391
+ column_info = cls.column_info()
360
392
  return sql(
361
393
  "INSERT INTO {table} ({fields}) VALUES ({values}) RETURNING {out_fields}",
362
394
  table=cls.table_name_sql(),
363
- fields=sql.list(sql.identifier(x) for x in kwargs.keys()),
364
- values=sql.list(sql.value(x) for x in kwargs.values()),
395
+ fields=sql.list(sql.identifier(k) for k in kwargs.keys()),
396
+ values=sql.list(
397
+ sql.value(column_info[k].maybe_serialize(v)) for k, v in kwargs.items()
398
+ ),
365
399
  out_fields=sql.list(cls.field_names_sql()),
366
400
  )
367
401
 
@@ -370,7 +404,7 @@ class ModelBase:
370
404
  cls: type[T], connection_or_pool: Union[Connection, Pool], **kwargs: Any
371
405
  ) -> T:
372
406
  row = await connection_or_pool.fetchrow(*cls.create_sql(**kwargs))
373
- return cls(**row)
407
+ return cls.from_mapping(row)
374
408
 
375
409
  def insert_sql(self, exclude: FieldNamesSet = ()) -> Fragment:
376
410
  cached = self._cached(
@@ -428,10 +462,11 @@ class ModelBase:
428
462
  pks=sql.list(sql.identifier(pk) for pk in cls.primary_key_names),
429
463
  ).compile(),
430
464
  )
465
+ column_info = cls.column_info()
431
466
  return cached(
432
467
  unnest=sql.unnest(
433
468
  (row.primary_key() for row in rows),
434
- (cls.column_info(pk).type for pk in cls.primary_key_names),
469
+ (column_info[pk].type for pk in cls.primary_key_names),
435
470
  ),
436
471
  )
437
472
 
@@ -451,10 +486,11 @@ class ModelBase:
451
486
  fields=sql.list(cls.field_names_sql()),
452
487
  ).compile(),
453
488
  )
489
+ column_info = cls.column_info()
454
490
  return cached(
455
491
  unnest=sql.unnest(
456
492
  (row.field_values() for row in rows),
457
- (cls.column_info(name).type for name in cls.field_names()),
493
+ (column_info[name].type for name in cls.field_names()),
458
494
  ),
459
495
  )
460
496
 
@@ -545,15 +581,15 @@ class ModelBase:
545
581
  ) -> Callable[[T, T], bool]:
546
582
  env: dict[str, Any] = {}
547
583
  func = ["def equal_ignoring(a, b):"]
548
- for f in cls._fields():
549
- if f.name not in ignore:
550
- func.append(f" if a.{f.name} != b.{f.name}: return False")
584
+ for ci in cls.column_info().values():
585
+ if ci.field.name not in ignore:
586
+ func.append(f" if a.{ci.field.name} != b.{ci.field.name}: return False")
551
587
  func += [" return True"]
552
588
  exec("\n".join(func), env)
553
589
  return env["equal_ignoring"]
554
590
 
555
591
  @classmethod
556
- async def replace_multiple(
592
+ async def plan_replace_multiple(
557
593
  cls: type[T],
558
594
  connection: Connection,
559
595
  rows: Union[Iterable[T], Iterable[Mapping[str, Any]]],
@@ -561,7 +597,7 @@ class ModelBase:
561
597
  where: Where,
562
598
  ignore: FieldNamesSet = (),
563
599
  insert_only: FieldNamesSet = (),
564
- ) -> tuple[list[T], list[T], list[T]]:
600
+ ) -> "ReplaceMultiplePlan[T]":
565
601
  ignore = sorted(set(ignore) | set(insert_only))
566
602
  equal_ignoring = cls._cached(
567
603
  ("equal_ignoring", tuple(ignore)),
@@ -585,14 +621,23 @@ class ModelBase:
585
621
 
586
622
  created = list(pending.values())
587
623
 
588
- if created or updated:
589
- await cls.upsert_multiple(
590
- connection, (*created, *updated), insert_only=insert_only
591
- )
592
- if deleted:
593
- await cls.delete_multiple(connection, deleted)
624
+ return ReplaceMultiplePlan(cls, insert_only, created, updated, deleted)
594
625
 
595
- return created, updated, deleted
626
+ @classmethod
627
+ async def replace_multiple(
628
+ cls: type[T],
629
+ connection: Connection,
630
+ rows: Union[Iterable[T], Iterable[Mapping[str, Any]]],
631
+ *,
632
+ where: Where,
633
+ ignore: FieldNamesSet = (),
634
+ insert_only: FieldNamesSet = (),
635
+ ) -> tuple[list[T], list[T], list[T]]:
636
+ plan = await cls.plan_replace_multiple(
637
+ connection, rows, where=where, ignore=ignore, insert_only=insert_only
638
+ )
639
+ await plan.execute(connection)
640
+ return plan.cud
596
641
 
597
642
  @classmethod
598
643
  def _get_differences_ignoring_fn(
@@ -603,9 +648,11 @@ class ModelBase:
603
648
  "def differences_ignoring(a, b):",
604
649
  " diffs = []",
605
650
  ]
606
- for f in cls._fields():
607
- if f.name not in ignore:
608
- func.append(f" if a.{f.name} != b.{f.name}: diffs.append({f.name!r})")
651
+ for ci in cls.column_info().values():
652
+ if ci.field.name not in ignore:
653
+ func.append(
654
+ f" if a.{ci.field.name} != b.{ci.field.name}: diffs.append({ci.field.name!r})"
655
+ )
609
656
  func += [" return diffs"]
610
657
  exec("\n".join(func), env)
611
658
  return env["differences_ignoring"]
@@ -657,6 +704,33 @@ class ModelBase:
657
704
  return created, updated_triples, deleted
658
705
 
659
706
 
707
+ @dataclass
708
+ class ReplaceMultiplePlan(Generic[T]):
709
+ model_class: type[T]
710
+ insert_only: FieldNamesSet
711
+ created: list[T]
712
+ updated: list[T]
713
+ deleted: list[T]
714
+
715
+ @property
716
+ def cud(self) -> tuple[list[T], list[T], list[T]]:
717
+ return (self.created, self.updated, self.deleted)
718
+
719
+ async def execute_upserts(self, connection: Connection) -> None:
720
+ if self.created or self.updated:
721
+ await self.model_class.upsert_multiple(
722
+ connection, (*self.created, *self.updated), insert_only=self.insert_only
723
+ )
724
+
725
+ async def execute_deletes(self, connection: Connection) -> None:
726
+ if self.deleted:
727
+ await self.model_class.delete_multiple(connection, self.deleted)
728
+
729
+ async def execute(self, connection: Connection) -> None:
730
+ await self.execute_upserts(connection)
731
+ await self.execute_deletes(connection)
732
+
733
+
660
734
  def chunked(lst, n):
661
735
  if type(lst) is not list:
662
736
  lst = list(lst)
@@ -1,2 +0,0 @@
1
- from .base import Fragment, sql
2
- from .dataclasses import ColumnInfo, ModelBase
File without changes
File without changes