sql-athame 0.4.0a8__py3-none-any.whl → 0.4.0a9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_athame/dataclasses.py +102 -65
- {sql_athame-0.4.0a8.dist-info → sql_athame-0.4.0a9.dist-info}/METADATA +1 -1
- {sql_athame-0.4.0a8.dist-info → sql_athame-0.4.0a9.dist-info}/RECORD +5 -5
- {sql_athame-0.4.0a8.dist-info → sql_athame-0.4.0a9.dist-info}/LICENSE +0 -0
- {sql_athame-0.4.0a8.dist-info → sql_athame-0.4.0a9.dist-info}/WHEEL +0 -0
sql_athame/dataclasses.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
import datetime
|
2
2
|
import functools
|
3
|
+
import sys
|
3
4
|
import uuid
|
4
5
|
from collections.abc import AsyncGenerator, Iterable, Mapping
|
5
6
|
from dataclasses import Field, InitVar, dataclass, fields
|
@@ -34,10 +35,13 @@ class ColumnInfo:
|
|
34
35
|
type: Optional[str] = None
|
35
36
|
create_type: Optional[str] = None
|
36
37
|
nullable: Optional[bool] = None
|
37
|
-
_constraints: tuple[str, ...] = ()
|
38
38
|
|
39
|
+
_constraints: tuple[str, ...] = ()
|
39
40
|
constraints: InitVar[Union[str, Iterable[str], None]] = None
|
40
41
|
|
42
|
+
serialize: Optional[Callable[[Any], Any]] = None
|
43
|
+
deserialize: Optional[Callable[[Any], Any]] = None
|
44
|
+
|
41
45
|
def __post_init__(self, constraints: Union[str, Iterable[str], None]) -> None:
|
42
46
|
if constraints is not None:
|
43
47
|
if type(constraints) is str:
|
@@ -51,29 +55,41 @@ class ColumnInfo:
|
|
51
55
|
create_type=b.create_type if b.create_type is not None else a.create_type,
|
52
56
|
nullable=b.nullable if b.nullable is not None else a.nullable,
|
53
57
|
_constraints=(*a._constraints, *b._constraints),
|
58
|
+
serialize=b.serialize if b.serialize is not None else a.serialize,
|
59
|
+
deserialize=b.deserialize if b.deserialize is not None else a.deserialize,
|
54
60
|
)
|
55
61
|
|
56
62
|
|
57
63
|
@dataclass
|
58
64
|
class ConcreteColumnInfo:
|
65
|
+
field: Field
|
66
|
+
type_hint: type
|
59
67
|
type: str
|
60
68
|
create_type: str
|
61
69
|
nullable: bool
|
62
70
|
constraints: tuple[str, ...]
|
71
|
+
serialize: Optional[Callable[[Any], Any]] = None
|
72
|
+
deserialize: Optional[Callable[[Any], Any]] = None
|
63
73
|
|
64
74
|
@staticmethod
|
65
|
-
def from_column_info(
|
75
|
+
def from_column_info(
|
76
|
+
field: Field, type_hint: Any, *args: ColumnInfo
|
77
|
+
) -> "ConcreteColumnInfo":
|
66
78
|
info = functools.reduce(ColumnInfo.merge, args, ColumnInfo())
|
67
79
|
if info.create_type is None and info.type is not None:
|
68
80
|
info.create_type = info.type
|
69
81
|
info.type = sql_create_type_map.get(info.type.upper(), info.type)
|
70
82
|
if type(info.type) is not str or type(info.create_type) is not str:
|
71
|
-
raise ValueError(f"Missing SQL type for column {name!r}")
|
83
|
+
raise ValueError(f"Missing SQL type for column {field.name!r}")
|
72
84
|
return ConcreteColumnInfo(
|
85
|
+
field=field,
|
86
|
+
type_hint=type_hint,
|
73
87
|
type=info.type,
|
74
88
|
create_type=info.create_type,
|
75
89
|
nullable=bool(info.nullable),
|
76
90
|
constraints=info._constraints,
|
91
|
+
serialize=info.serialize,
|
92
|
+
deserialize=info.deserialize,
|
77
93
|
)
|
78
94
|
|
79
95
|
def create_table_string(self) -> str:
|
@@ -84,13 +100,24 @@ class ConcreteColumnInfo:
|
|
84
100
|
)
|
85
101
|
return " ".join(parts)
|
86
102
|
|
103
|
+
def maybe_serialize(self, value: Any) -> Any:
|
104
|
+
if self.serialize:
|
105
|
+
return self.serialize(value)
|
106
|
+
return value
|
107
|
+
|
108
|
+
|
109
|
+
UNION_TYPES: tuple = (Union,)
|
110
|
+
if sys.version_info >= (3, 10):
|
111
|
+
from types import UnionType
|
112
|
+
|
113
|
+
UNION_TYPES = (Union, UnionType)
|
87
114
|
|
88
115
|
NULLABLE_TYPES = (type(None), Any, object)
|
89
116
|
|
90
117
|
|
91
118
|
def split_nullable(typ: type) -> tuple[bool, type]:
|
92
119
|
nullable = typ in NULLABLE_TYPES
|
93
|
-
if get_origin(typ)
|
120
|
+
if get_origin(typ) in UNION_TYPES:
|
94
121
|
args = []
|
95
122
|
for arg in get_args(typ):
|
96
123
|
if arg in NULLABLE_TYPES:
|
@@ -108,7 +135,7 @@ sql_create_type_map = {
|
|
108
135
|
}
|
109
136
|
|
110
137
|
|
111
|
-
sql_type_map: dict[
|
138
|
+
sql_type_map: dict[type, str] = {
|
112
139
|
bool: "BOOLEAN",
|
113
140
|
bytes: "BYTEA",
|
114
141
|
datetime.date: "DATE",
|
@@ -125,12 +152,11 @@ U = TypeVar("U")
|
|
125
152
|
|
126
153
|
|
127
154
|
class ModelBase:
|
128
|
-
_column_info:
|
155
|
+
_column_info: dict[str, ConcreteColumnInfo]
|
129
156
|
_cache: dict[tuple, Any]
|
130
157
|
table_name: str
|
131
158
|
primary_key_names: tuple[str, ...]
|
132
159
|
array_safe_insert: bool
|
133
|
-
_type_hints: dict[str, type]
|
134
160
|
|
135
161
|
def __init_subclass__(
|
136
162
|
cls,
|
@@ -153,13 +179,6 @@ class ModelBase:
|
|
153
179
|
else:
|
154
180
|
cls.primary_key_names = tuple(primary_key)
|
155
181
|
|
156
|
-
@classmethod
|
157
|
-
def _fields(cls):
|
158
|
-
# wrapper to ignore typing weirdness: 'Argument 1 to "fields"
|
159
|
-
# has incompatible type "..."; expected "DataclassInstance |
|
160
|
-
# type[DataclassInstance]"'
|
161
|
-
return fields(cls) # type: ignore
|
162
|
-
|
163
182
|
@classmethod
|
164
183
|
def _cached(cls, key: tuple, thunk: Callable[[], U]) -> U:
|
165
184
|
try:
|
@@ -169,20 +188,11 @@ class ModelBase:
|
|
169
188
|
return cls._cache[key]
|
170
189
|
|
171
190
|
@classmethod
|
172
|
-
def
|
173
|
-
|
174
|
-
return cls._type_hints
|
175
|
-
except AttributeError:
|
176
|
-
cls._type_hints = get_type_hints(cls, include_extras=True)
|
177
|
-
return cls._type_hints
|
178
|
-
|
179
|
-
@classmethod
|
180
|
-
def column_info_for_field(cls, field: Field) -> ConcreteColumnInfo:
|
181
|
-
type_info = cls.type_hints()[field.name]
|
182
|
-
base_type = type_info
|
191
|
+
def column_info_for_field(cls, field: Field, type_hint: type) -> ConcreteColumnInfo:
|
192
|
+
base_type = type_hint
|
183
193
|
metadata = []
|
184
|
-
if get_origin(
|
185
|
-
base_type, *metadata = get_args(
|
194
|
+
if get_origin(type_hint) is Annotated:
|
195
|
+
base_type, *metadata = get_args(type_hint)
|
186
196
|
nullable, base_type = split_nullable(base_type)
|
187
197
|
info = [ColumnInfo(nullable=nullable)]
|
188
198
|
if base_type in sql_type_map:
|
@@ -190,17 +200,19 @@ class ModelBase:
|
|
190
200
|
for md in metadata:
|
191
201
|
if isinstance(md, ColumnInfo):
|
192
202
|
info.append(md)
|
193
|
-
return ConcreteColumnInfo.from_column_info(field
|
203
|
+
return ConcreteColumnInfo.from_column_info(field, type_hint, *info)
|
194
204
|
|
195
205
|
@classmethod
|
196
|
-
def column_info(cls
|
206
|
+
def column_info(cls) -> dict[str, ConcreteColumnInfo]:
|
197
207
|
try:
|
198
|
-
return cls._column_info
|
208
|
+
return cls._column_info
|
199
209
|
except AttributeError:
|
210
|
+
type_hints = get_type_hints(cls, include_extras=True)
|
200
211
|
cls._column_info = {
|
201
|
-
f.name: cls.column_info_for_field(f
|
212
|
+
f.name: cls.column_info_for_field(f, type_hints[f.name])
|
213
|
+
for f in fields(cls) # type: ignore
|
202
214
|
}
|
203
|
-
return cls._column_info
|
215
|
+
return cls._column_info
|
204
216
|
|
205
217
|
@classmethod
|
206
218
|
def table_name_sql(cls, *, prefix: Optional[str] = None) -> Fragment:
|
@@ -212,7 +224,11 @@ class ModelBase:
|
|
212
224
|
|
213
225
|
@classmethod
|
214
226
|
def field_names(cls, *, exclude: FieldNamesSet = ()) -> list[str]:
|
215
|
-
return [
|
227
|
+
return [
|
228
|
+
ci.field.name
|
229
|
+
for ci in cls.column_info().values()
|
230
|
+
if ci.field.name not in exclude
|
231
|
+
]
|
216
232
|
|
217
233
|
@classmethod
|
218
234
|
def field_names_sql(
|
@@ -231,9 +247,13 @@ class ModelBase:
|
|
231
247
|
) -> Callable[[T], list[Any]]:
|
232
248
|
env: dict[str, Any] = {}
|
233
249
|
func = ["def get_field_values(self): return ["]
|
234
|
-
for
|
235
|
-
if
|
236
|
-
|
250
|
+
for ci in cls.column_info().values():
|
251
|
+
if ci.field.name not in exclude:
|
252
|
+
if ci.serialize:
|
253
|
+
env[f"_ser_{ci.field.name}"] = ci.serialize
|
254
|
+
func.append(f"_ser_{ci.field.name}(self.{ci.field.name}), ")
|
255
|
+
else:
|
256
|
+
func.append(f"self.{ci.field.name},")
|
237
257
|
func += ["]"]
|
238
258
|
exec(" ".join(func), env)
|
239
259
|
return env["get_field_values"]
|
@@ -257,36 +277,46 @@ class ModelBase:
|
|
257
277
|
return [sql.value(value) for value in self.field_values()]
|
258
278
|
|
259
279
|
@classmethod
|
260
|
-
def
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
280
|
+
def _get_from_mapping_fn(cls: type[T]) -> Callable[[Mapping[str, Any]], T]:
|
281
|
+
env: dict[str, Any] = {"cls": cls}
|
282
|
+
func = ["def from_mapping(mapping):"]
|
283
|
+
if not any(ci.deserialize for ci in cls.column_info().values()):
|
284
|
+
func.append(" return cls(**mapping)")
|
285
|
+
else:
|
286
|
+
func.append(" deser_dict = dict(mapping)")
|
287
|
+
for ci in cls.column_info().values():
|
288
|
+
if ci.deserialize:
|
289
|
+
env[f"_deser_{ci.field.name}"] = ci.deserialize
|
290
|
+
func.append(f" if {ci.field.name!r} in deser_dict:")
|
291
|
+
func.append(
|
292
|
+
f" deser_dict[{ci.field.name!r}] = _deser_{ci.field.name}(deser_dict[{ci.field.name!r}])"
|
293
|
+
)
|
294
|
+
func.append(" return cls(**deser_dict)")
|
295
|
+
exec("\n".join(func), env)
|
296
|
+
return env["from_mapping"]
|
266
297
|
|
267
298
|
@classmethod
|
268
|
-
def
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
return cls(**kwargs)
|
299
|
+
def from_mapping(cls: type[T], mapping: Mapping[str, Any], /) -> T:
|
300
|
+
# KLUDGE nasty but... efficient?
|
301
|
+
from_mapping_fn = cls._get_from_mapping_fn()
|
302
|
+
cls.from_mapping = from_mapping_fn # type: ignore
|
303
|
+
return from_mapping_fn(mapping)
|
274
304
|
|
275
305
|
@classmethod
|
276
306
|
def ensure_model(cls: type[T], row: Union[T, Mapping[str, Any]]) -> T:
|
277
307
|
if isinstance(row, cls):
|
278
308
|
return row
|
279
|
-
return cls(
|
309
|
+
return cls.from_mapping(row) # type: ignore
|
280
310
|
|
281
311
|
@classmethod
|
282
312
|
def create_table_sql(cls) -> Fragment:
|
283
313
|
entries = [
|
284
314
|
sql(
|
285
315
|
"{} {}",
|
286
|
-
sql.identifier(
|
287
|
-
sql.literal(
|
316
|
+
sql.identifier(ci.field.name),
|
317
|
+
sql.literal(ci.create_table_string()),
|
288
318
|
)
|
289
|
-
for
|
319
|
+
for ci in cls.column_info().values()
|
290
320
|
]
|
291
321
|
if cls.primary_key_names:
|
292
322
|
entries += [sql("PRIMARY KEY ({})", sql.list(cls.primary_key_names_sql()))]
|
@@ -338,7 +368,7 @@ class ModelBase:
|
|
338
368
|
*cls.select_sql(order_by=order_by, for_update=for_update, where=where),
|
339
369
|
prefetch=prefetch,
|
340
370
|
):
|
341
|
-
yield cls(
|
371
|
+
yield cls.from_mapping(row)
|
342
372
|
|
343
373
|
@classmethod
|
344
374
|
async def select(
|
@@ -349,7 +379,7 @@ class ModelBase:
|
|
349
379
|
where: Where = (),
|
350
380
|
) -> list[T]:
|
351
381
|
return [
|
352
|
-
cls(
|
382
|
+
cls.from_mapping(row)
|
353
383
|
for row in await connection_or_pool.fetch(
|
354
384
|
*cls.select_sql(order_by=order_by, for_update=for_update, where=where)
|
355
385
|
)
|
@@ -357,11 +387,14 @@ class ModelBase:
|
|
357
387
|
|
358
388
|
@classmethod
|
359
389
|
def create_sql(cls: type[T], **kwargs: Any) -> Fragment:
|
390
|
+
column_info = cls.column_info()
|
360
391
|
return sql(
|
361
392
|
"INSERT INTO {table} ({fields}) VALUES ({values}) RETURNING {out_fields}",
|
362
393
|
table=cls.table_name_sql(),
|
363
|
-
fields=sql.list(sql.identifier(
|
364
|
-
values=sql.list(
|
394
|
+
fields=sql.list(sql.identifier(k) for k in kwargs.keys()),
|
395
|
+
values=sql.list(
|
396
|
+
sql.value(column_info[k].maybe_serialize(v)) for k, v in kwargs.items()
|
397
|
+
),
|
365
398
|
out_fields=sql.list(cls.field_names_sql()),
|
366
399
|
)
|
367
400
|
|
@@ -370,7 +403,7 @@ class ModelBase:
|
|
370
403
|
cls: type[T], connection_or_pool: Union[Connection, Pool], **kwargs: Any
|
371
404
|
) -> T:
|
372
405
|
row = await connection_or_pool.fetchrow(*cls.create_sql(**kwargs))
|
373
|
-
return cls(
|
406
|
+
return cls.from_mapping(row)
|
374
407
|
|
375
408
|
def insert_sql(self, exclude: FieldNamesSet = ()) -> Fragment:
|
376
409
|
cached = self._cached(
|
@@ -428,10 +461,11 @@ class ModelBase:
|
|
428
461
|
pks=sql.list(sql.identifier(pk) for pk in cls.primary_key_names),
|
429
462
|
).compile(),
|
430
463
|
)
|
464
|
+
column_info = cls.column_info()
|
431
465
|
return cached(
|
432
466
|
unnest=sql.unnest(
|
433
467
|
(row.primary_key() for row in rows),
|
434
|
-
(
|
468
|
+
(column_info[pk].type for pk in cls.primary_key_names),
|
435
469
|
),
|
436
470
|
)
|
437
471
|
|
@@ -451,10 +485,11 @@ class ModelBase:
|
|
451
485
|
fields=sql.list(cls.field_names_sql()),
|
452
486
|
).compile(),
|
453
487
|
)
|
488
|
+
column_info = cls.column_info()
|
454
489
|
return cached(
|
455
490
|
unnest=sql.unnest(
|
456
491
|
(row.field_values() for row in rows),
|
457
|
-
(
|
492
|
+
(column_info[name].type for name in cls.field_names()),
|
458
493
|
),
|
459
494
|
)
|
460
495
|
|
@@ -545,9 +580,9 @@ class ModelBase:
|
|
545
580
|
) -> Callable[[T, T], bool]:
|
546
581
|
env: dict[str, Any] = {}
|
547
582
|
func = ["def equal_ignoring(a, b):"]
|
548
|
-
for
|
549
|
-
if
|
550
|
-
func.append(f" if a.{
|
583
|
+
for ci in cls.column_info().values():
|
584
|
+
if ci.field.name not in ignore:
|
585
|
+
func.append(f" if a.{ci.field.name} != b.{ci.field.name}: return False")
|
551
586
|
func += [" return True"]
|
552
587
|
exec("\n".join(func), env)
|
553
588
|
return env["equal_ignoring"]
|
@@ -603,9 +638,11 @@ class ModelBase:
|
|
603
638
|
"def differences_ignoring(a, b):",
|
604
639
|
" diffs = []",
|
605
640
|
]
|
606
|
-
for
|
607
|
-
if
|
608
|
-
func.append(
|
641
|
+
for ci in cls.column_info().values():
|
642
|
+
if ci.field.name not in ignore:
|
643
|
+
func.append(
|
644
|
+
f" if a.{ci.field.name} != b.{ci.field.name}: diffs.append({ci.field.name!r})"
|
645
|
+
)
|
609
646
|
func += [" return diffs"]
|
610
647
|
exec("\n".join(func), env)
|
611
648
|
return env["differences_ignoring"]
|
@@ -1,11 +1,11 @@
|
|
1
1
|
sql_athame/__init__.py,sha256=7OBIMZOcrD2pvfIL-rjD1IGZ3TNQbwyu76a9PWk-yYg,79
|
2
2
|
sql_athame/base.py,sha256=FR7EmC0VkX1VRgvAutSEfYSWhlEYpoqS1Kqxp1jHp6Y,10293
|
3
|
-
sql_athame/dataclasses.py,sha256=
|
3
|
+
sql_athame/dataclasses.py,sha256=9Q-Z3itKyuqhR5u47bVBfA714uFbf-K4t1FPiFd8XAE,23792
|
4
4
|
sql_athame/escape.py,sha256=kK101xXeFitlvuG-L_hvhdpgGJCtmRTprsn1yEfZKws,758
|
5
5
|
sql_athame/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
6
|
sql_athame/sqlalchemy.py,sha256=aWopfPh3j71XwKmcN_VcHRNlhscI0Sckd4AiyGf8Tpw,1293
|
7
7
|
sql_athame/types.py,sha256=FQ06l9Uc-vo57UrAarvnukILdV2gN1IaYUnHJ_bNYic,475
|
8
|
-
sql_athame-0.4.
|
9
|
-
sql_athame-0.4.
|
10
|
-
sql_athame-0.4.
|
11
|
-
sql_athame-0.4.
|
8
|
+
sql_athame-0.4.0a9.dist-info/LICENSE,sha256=xqV29vPFqITcKifYrGPgVIBjq4fdmLSwY3gRUtDKafg,1076
|
9
|
+
sql_athame-0.4.0a9.dist-info/METADATA,sha256=pf4xAdRJ7NuJaViLLWsQeq8LRIA78tY_YdOmPBjpFgg,12845
|
10
|
+
sql_athame-0.4.0a9.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
11
|
+
sql_athame-0.4.0a9.dist-info/RECORD,,
|
File without changes
|
File without changes
|