sql-athame 0.4.0a7__py3-none-any.whl → 0.4.0a9__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
sql_athame/dataclasses.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import datetime
2
2
  import functools
3
+ import sys
3
4
  import uuid
4
5
  from collections.abc import AsyncGenerator, Iterable, Mapping
5
6
  from dataclasses import Field, InitVar, dataclass, fields
@@ -10,6 +11,7 @@ from typing import (
10
11
  Optional,
11
12
  TypeVar,
12
13
  Union,
14
+ get_args,
13
15
  get_origin,
14
16
  get_type_hints,
15
17
  )
@@ -33,10 +35,13 @@ class ColumnInfo:
33
35
  type: Optional[str] = None
34
36
  create_type: Optional[str] = None
35
37
  nullable: Optional[bool] = None
36
- _constraints: tuple[str, ...] = ()
37
38
 
39
+ _constraints: tuple[str, ...] = ()
38
40
  constraints: InitVar[Union[str, Iterable[str], None]] = None
39
41
 
42
+ serialize: Optional[Callable[[Any], Any]] = None
43
+ deserialize: Optional[Callable[[Any], Any]] = None
44
+
40
45
  def __post_init__(self, constraints: Union[str, Iterable[str], None]) -> None:
41
46
  if constraints is not None:
42
47
  if type(constraints) is str:
@@ -50,29 +55,41 @@ class ColumnInfo:
50
55
  create_type=b.create_type if b.create_type is not None else a.create_type,
51
56
  nullable=b.nullable if b.nullable is not None else a.nullable,
52
57
  _constraints=(*a._constraints, *b._constraints),
58
+ serialize=b.serialize if b.serialize is not None else a.serialize,
59
+ deserialize=b.deserialize if b.deserialize is not None else a.deserialize,
53
60
  )
54
61
 
55
62
 
56
63
  @dataclass
57
64
  class ConcreteColumnInfo:
65
+ field: Field
66
+ type_hint: type
58
67
  type: str
59
68
  create_type: str
60
69
  nullable: bool
61
70
  constraints: tuple[str, ...]
71
+ serialize: Optional[Callable[[Any], Any]] = None
72
+ deserialize: Optional[Callable[[Any], Any]] = None
62
73
 
63
74
  @staticmethod
64
- def from_column_info(name: str, *args: ColumnInfo) -> "ConcreteColumnInfo":
75
+ def from_column_info(
76
+ field: Field, type_hint: Any, *args: ColumnInfo
77
+ ) -> "ConcreteColumnInfo":
65
78
  info = functools.reduce(ColumnInfo.merge, args, ColumnInfo())
66
79
  if info.create_type is None and info.type is not None:
67
80
  info.create_type = info.type
68
81
  info.type = sql_create_type_map.get(info.type.upper(), info.type)
69
82
  if type(info.type) is not str or type(info.create_type) is not str:
70
- raise ValueError(f"Missing SQL type for column {name!r}")
83
+ raise ValueError(f"Missing SQL type for column {field.name!r}")
71
84
  return ConcreteColumnInfo(
85
+ field=field,
86
+ type_hint=type_hint,
72
87
  type=info.type,
73
88
  create_type=info.create_type,
74
89
  nullable=bool(info.nullable),
75
90
  constraints=info._constraints,
91
+ serialize=info.serialize,
92
+ deserialize=info.deserialize,
76
93
  )
77
94
 
78
95
  def create_table_string(self) -> str:
@@ -83,6 +100,33 @@ class ConcreteColumnInfo:
83
100
  )
84
101
  return " ".join(parts)
85
102
 
103
+ def maybe_serialize(self, value: Any) -> Any:
104
+ if self.serialize:
105
+ return self.serialize(value)
106
+ return value
107
+
108
+
109
+ UNION_TYPES: tuple = (Union,)
110
+ if sys.version_info >= (3, 10):
111
+ from types import UnionType
112
+
113
+ UNION_TYPES = (Union, UnionType)
114
+
115
+ NULLABLE_TYPES = (type(None), Any, object)
116
+
117
+
118
+ def split_nullable(typ: type) -> tuple[bool, type]:
119
+ nullable = typ in NULLABLE_TYPES
120
+ if get_origin(typ) in UNION_TYPES:
121
+ args = []
122
+ for arg in get_args(typ):
123
+ if arg in NULLABLE_TYPES:
124
+ nullable = True
125
+ else:
126
+ args.append(arg)
127
+ return nullable, Union[tuple(args)] # type: ignore
128
+ return nullable, typ
129
+
86
130
 
87
131
  sql_create_type_map = {
88
132
  "BIGSERIAL": "BIGINT",
@@ -91,23 +135,15 @@ sql_create_type_map = {
91
135
  }
92
136
 
93
137
 
94
- sql_type_map: dict[Any, tuple[str, bool]] = {
95
- Optional[bool]: ("BOOLEAN", True),
96
- Optional[bytes]: ("BYTEA", True),
97
- Optional[datetime.date]: ("DATE", True),
98
- Optional[datetime.datetime]: ("TIMESTAMP", True),
99
- Optional[float]: ("DOUBLE PRECISION", True),
100
- Optional[int]: ("INTEGER", True),
101
- Optional[str]: ("TEXT", True),
102
- Optional[uuid.UUID]: ("UUID", True),
103
- bool: ("BOOLEAN", False),
104
- bytes: ("BYTEA", False),
105
- datetime.date: ("DATE", False),
106
- datetime.datetime: ("TIMESTAMP", False),
107
- float: ("DOUBLE PRECISION", False),
108
- int: ("INTEGER", False),
109
- str: ("TEXT", False),
110
- uuid.UUID: ("UUID", False),
138
+ sql_type_map: dict[type, str] = {
139
+ bool: "BOOLEAN",
140
+ bytes: "BYTEA",
141
+ datetime.date: "DATE",
142
+ datetime.datetime: "TIMESTAMP",
143
+ float: "DOUBLE PRECISION",
144
+ int: "INTEGER",
145
+ str: "TEXT",
146
+ uuid.UUID: "UUID",
111
147
  }
112
148
 
113
149
 
@@ -116,12 +152,11 @@ U = TypeVar("U")
116
152
 
117
153
 
118
154
  class ModelBase:
119
- _column_info: Optional[dict[str, ConcreteColumnInfo]]
155
+ _column_info: dict[str, ConcreteColumnInfo]
120
156
  _cache: dict[tuple, Any]
121
157
  table_name: str
122
158
  primary_key_names: tuple[str, ...]
123
159
  array_safe_insert: bool
124
- _type_hints: dict[str, type]
125
160
 
126
161
  def __init_subclass__(
127
162
  cls,
@@ -144,13 +179,6 @@ class ModelBase:
144
179
  else:
145
180
  cls.primary_key_names = tuple(primary_key)
146
181
 
147
- @classmethod
148
- def _fields(cls):
149
- # wrapper to ignore typing weirdness: 'Argument 1 to "fields"
150
- # has incompatible type "..."; expected "DataclassInstance |
151
- # type[DataclassInstance]"'
152
- return fields(cls) # type: ignore
153
-
154
182
  @classmethod
155
183
  def _cached(cls, key: tuple, thunk: Callable[[], U]) -> U:
156
184
  try:
@@ -160,38 +188,31 @@ class ModelBase:
160
188
  return cls._cache[key]
161
189
 
162
190
  @classmethod
163
- def type_hints(cls) -> dict[str, type]:
164
- try:
165
- return cls._type_hints
166
- except AttributeError:
167
- cls._type_hints = get_type_hints(cls, include_extras=True)
168
- return cls._type_hints
169
-
170
- @classmethod
171
- def column_info_for_field(cls, field: Field) -> ConcreteColumnInfo:
172
- type_info = cls.type_hints()[field.name]
173
- base_type = type_info
174
- if get_origin(type_info) is Annotated:
175
- base_type = type_info.__origin__ # type: ignore
176
- info = []
191
+ def column_info_for_field(cls, field: Field, type_hint: type) -> ConcreteColumnInfo:
192
+ base_type = type_hint
193
+ metadata = []
194
+ if get_origin(type_hint) is Annotated:
195
+ base_type, *metadata = get_args(type_hint)
196
+ nullable, base_type = split_nullable(base_type)
197
+ info = [ColumnInfo(nullable=nullable)]
177
198
  if base_type in sql_type_map:
178
- _type, nullable = sql_type_map[base_type]
179
- info.append(ColumnInfo(type=_type, nullable=nullable))
180
- if get_origin(type_info) is Annotated:
181
- for md in type_info.__metadata__: # type: ignore
182
- if isinstance(md, ColumnInfo):
183
- info.append(md)
184
- return ConcreteColumnInfo.from_column_info(field.name, *info)
199
+ info.append(ColumnInfo(type=sql_type_map[base_type]))
200
+ for md in metadata:
201
+ if isinstance(md, ColumnInfo):
202
+ info.append(md)
203
+ return ConcreteColumnInfo.from_column_info(field, type_hint, *info)
185
204
 
186
205
  @classmethod
187
- def column_info(cls, column: str) -> ConcreteColumnInfo:
206
+ def column_info(cls) -> dict[str, ConcreteColumnInfo]:
188
207
  try:
189
- return cls._column_info[column] # type: ignore
208
+ return cls._column_info
190
209
  except AttributeError:
210
+ type_hints = get_type_hints(cls, include_extras=True)
191
211
  cls._column_info = {
192
- f.name: cls.column_info_for_field(f) for f in cls._fields()
212
+ f.name: cls.column_info_for_field(f, type_hints[f.name])
213
+ for f in fields(cls) # type: ignore
193
214
  }
194
- return cls._column_info[column]
215
+ return cls._column_info
195
216
 
196
217
  @classmethod
197
218
  def table_name_sql(cls, *, prefix: Optional[str] = None) -> Fragment:
@@ -203,7 +224,11 @@ class ModelBase:
203
224
 
204
225
  @classmethod
205
226
  def field_names(cls, *, exclude: FieldNamesSet = ()) -> list[str]:
206
- return [f.name for f in cls._fields() if f.name not in exclude]
227
+ return [
228
+ ci.field.name
229
+ for ci in cls.column_info().values()
230
+ if ci.field.name not in exclude
231
+ ]
207
232
 
208
233
  @classmethod
209
234
  def field_names_sql(
@@ -222,9 +247,13 @@ class ModelBase:
222
247
  ) -> Callable[[T], list[Any]]:
223
248
  env: dict[str, Any] = {}
224
249
  func = ["def get_field_values(self): return ["]
225
- for f in cls._fields():
226
- if f.name not in exclude:
227
- func.append(f"self.{f.name},")
250
+ for ci in cls.column_info().values():
251
+ if ci.field.name not in exclude:
252
+ if ci.serialize:
253
+ env[f"_ser_{ci.field.name}"] = ci.serialize
254
+ func.append(f"_ser_{ci.field.name}(self.{ci.field.name}), ")
255
+ else:
256
+ func.append(f"self.{ci.field.name},")
228
257
  func += ["]"]
229
258
  exec(" ".join(func), env)
230
259
  return env["get_field_values"]
@@ -248,36 +277,46 @@ class ModelBase:
248
277
  return [sql.value(value) for value in self.field_values()]
249
278
 
250
279
  @classmethod
251
- def from_tuple(
252
- cls: type[T], tup: tuple, *, offset: int = 0, exclude: FieldNamesSet = ()
253
- ) -> T:
254
- names = (f.name for f in cls._fields() if f.name not in exclude)
255
- kwargs = {name: tup[offset] for offset, name in enumerate(names, start=offset)}
256
- return cls(**kwargs)
280
+ def _get_from_mapping_fn(cls: type[T]) -> Callable[[Mapping[str, Any]], T]:
281
+ env: dict[str, Any] = {"cls": cls}
282
+ func = ["def from_mapping(mapping):"]
283
+ if not any(ci.deserialize for ci in cls.column_info().values()):
284
+ func.append(" return cls(**mapping)")
285
+ else:
286
+ func.append(" deser_dict = dict(mapping)")
287
+ for ci in cls.column_info().values():
288
+ if ci.deserialize:
289
+ env[f"_deser_{ci.field.name}"] = ci.deserialize
290
+ func.append(f" if {ci.field.name!r} in deser_dict:")
291
+ func.append(
292
+ f" deser_dict[{ci.field.name!r}] = _deser_{ci.field.name}(deser_dict[{ci.field.name!r}])"
293
+ )
294
+ func.append(" return cls(**deser_dict)")
295
+ exec("\n".join(func), env)
296
+ return env["from_mapping"]
257
297
 
258
298
  @classmethod
259
- def from_dict(
260
- cls: type[T], dct: dict[str, Any], *, exclude: FieldNamesSet = ()
261
- ) -> T:
262
- names = {f.name for f in cls._fields() if f.name not in exclude}
263
- kwargs = {k: v for k, v in dct.items() if k in names}
264
- return cls(**kwargs)
299
+ def from_mapping(cls: type[T], mapping: Mapping[str, Any], /) -> T:
300
+ # KLUDGE nasty but... efficient?
301
+ from_mapping_fn = cls._get_from_mapping_fn()
302
+ cls.from_mapping = from_mapping_fn # type: ignore
303
+ return from_mapping_fn(mapping)
265
304
 
266
305
  @classmethod
267
306
  def ensure_model(cls: type[T], row: Union[T, Mapping[str, Any]]) -> T:
268
307
  if isinstance(row, cls):
269
308
  return row
270
- return cls(**row)
309
+ return cls.from_mapping(row) # type: ignore
271
310
 
272
311
  @classmethod
273
312
  def create_table_sql(cls) -> Fragment:
274
313
  entries = [
275
314
  sql(
276
315
  "{} {}",
277
- sql.identifier(f.name),
278
- sql.literal(cls.column_info(f.name).create_table_string()),
316
+ sql.identifier(ci.field.name),
317
+ sql.literal(ci.create_table_string()),
279
318
  )
280
- for f in cls._fields()
319
+ for ci in cls.column_info().values()
281
320
  ]
282
321
  if cls.primary_key_names:
283
322
  entries += [sql("PRIMARY KEY ({})", sql.list(cls.primary_key_names_sql()))]
@@ -329,7 +368,7 @@ class ModelBase:
329
368
  *cls.select_sql(order_by=order_by, for_update=for_update, where=where),
330
369
  prefetch=prefetch,
331
370
  ):
332
- yield cls(**row)
371
+ yield cls.from_mapping(row)
333
372
 
334
373
  @classmethod
335
374
  async def select(
@@ -340,7 +379,7 @@ class ModelBase:
340
379
  where: Where = (),
341
380
  ) -> list[T]:
342
381
  return [
343
- cls(**row)
382
+ cls.from_mapping(row)
344
383
  for row in await connection_or_pool.fetch(
345
384
  *cls.select_sql(order_by=order_by, for_update=for_update, where=where)
346
385
  )
@@ -348,11 +387,14 @@ class ModelBase:
348
387
 
349
388
  @classmethod
350
389
  def create_sql(cls: type[T], **kwargs: Any) -> Fragment:
390
+ column_info = cls.column_info()
351
391
  return sql(
352
392
  "INSERT INTO {table} ({fields}) VALUES ({values}) RETURNING {out_fields}",
353
393
  table=cls.table_name_sql(),
354
- fields=sql.list(sql.identifier(x) for x in kwargs.keys()),
355
- values=sql.list(sql.value(x) for x in kwargs.values()),
394
+ fields=sql.list(sql.identifier(k) for k in kwargs.keys()),
395
+ values=sql.list(
396
+ sql.value(column_info[k].maybe_serialize(v)) for k, v in kwargs.items()
397
+ ),
356
398
  out_fields=sql.list(cls.field_names_sql()),
357
399
  )
358
400
 
@@ -361,7 +403,7 @@ class ModelBase:
361
403
  cls: type[T], connection_or_pool: Union[Connection, Pool], **kwargs: Any
362
404
  ) -> T:
363
405
  row = await connection_or_pool.fetchrow(*cls.create_sql(**kwargs))
364
- return cls(**row)
406
+ return cls.from_mapping(row)
365
407
 
366
408
  def insert_sql(self, exclude: FieldNamesSet = ()) -> Fragment:
367
409
  cached = self._cached(
@@ -419,10 +461,11 @@ class ModelBase:
419
461
  pks=sql.list(sql.identifier(pk) for pk in cls.primary_key_names),
420
462
  ).compile(),
421
463
  )
464
+ column_info = cls.column_info()
422
465
  return cached(
423
466
  unnest=sql.unnest(
424
467
  (row.primary_key() for row in rows),
425
- (cls.column_info(pk).type for pk in cls.primary_key_names),
468
+ (column_info[pk].type for pk in cls.primary_key_names),
426
469
  ),
427
470
  )
428
471
 
@@ -442,10 +485,11 @@ class ModelBase:
442
485
  fields=sql.list(cls.field_names_sql()),
443
486
  ).compile(),
444
487
  )
488
+ column_info = cls.column_info()
445
489
  return cached(
446
490
  unnest=sql.unnest(
447
491
  (row.field_values() for row in rows),
448
- (cls.column_info(name).type for name in cls.field_names()),
492
+ (column_info[name].type for name in cls.field_names()),
449
493
  ),
450
494
  )
451
495
 
@@ -536,9 +580,9 @@ class ModelBase:
536
580
  ) -> Callable[[T, T], bool]:
537
581
  env: dict[str, Any] = {}
538
582
  func = ["def equal_ignoring(a, b):"]
539
- for f in cls._fields():
540
- if f.name not in ignore:
541
- func.append(f" if a.{f.name} != b.{f.name}: return False")
583
+ for ci in cls.column_info().values():
584
+ if ci.field.name not in ignore:
585
+ func.append(f" if a.{ci.field.name} != b.{ci.field.name}: return False")
542
586
  func += [" return True"]
543
587
  exec("\n".join(func), env)
544
588
  return env["equal_ignoring"]
@@ -594,9 +638,11 @@ class ModelBase:
594
638
  "def differences_ignoring(a, b):",
595
639
  " diffs = []",
596
640
  ]
597
- for f in cls._fields():
598
- if f.name not in ignore:
599
- func.append(f" if a.{f.name} != b.{f.name}: diffs.append({f.name!r})")
641
+ for ci in cls.column_info().values():
642
+ if ci.field.name not in ignore:
643
+ func.append(
644
+ f" if a.{ci.field.name} != b.{ci.field.name}: diffs.append({ci.field.name!r})"
645
+ )
600
646
  func += [" return diffs"]
601
647
  exec("\n".join(func), env)
602
648
  return env["differences_ignoring"]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: sql-athame
3
- Version: 0.4.0a7
3
+ Version: 0.4.0a9
4
4
  Summary: Python tool for slicing and dicing SQL
5
5
  Home-page: https://github.com/bdowning/sql-athame
6
6
  License: MIT
@@ -1,11 +1,11 @@
1
1
  sql_athame/__init__.py,sha256=7OBIMZOcrD2pvfIL-rjD1IGZ3TNQbwyu76a9PWk-yYg,79
2
2
  sql_athame/base.py,sha256=FR7EmC0VkX1VRgvAutSEfYSWhlEYpoqS1Kqxp1jHp6Y,10293
3
- sql_athame/dataclasses.py,sha256=8JDACQr5RCeCbu2QRAzA9rpM9i1TJNGKEFXEFbGJUgo,22193
3
+ sql_athame/dataclasses.py,sha256=9Q-Z3itKyuqhR5u47bVBfA714uFbf-K4t1FPiFd8XAE,23792
4
4
  sql_athame/escape.py,sha256=kK101xXeFitlvuG-L_hvhdpgGJCtmRTprsn1yEfZKws,758
5
5
  sql_athame/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  sql_athame/sqlalchemy.py,sha256=aWopfPh3j71XwKmcN_VcHRNlhscI0Sckd4AiyGf8Tpw,1293
7
7
  sql_athame/types.py,sha256=FQ06l9Uc-vo57UrAarvnukILdV2gN1IaYUnHJ_bNYic,475
8
- sql_athame-0.4.0a7.dist-info/LICENSE,sha256=xqV29vPFqITcKifYrGPgVIBjq4fdmLSwY3gRUtDKafg,1076
9
- sql_athame-0.4.0a7.dist-info/METADATA,sha256=OqUSaxi_5K6vfYxWpiXGP_qDXPU-qQnCrkY3yruhzi4,12845
10
- sql_athame-0.4.0a7.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
11
- sql_athame-0.4.0a7.dist-info/RECORD,,
8
+ sql_athame-0.4.0a9.dist-info/LICENSE,sha256=xqV29vPFqITcKifYrGPgVIBjq4fdmLSwY3gRUtDKafg,1076
9
+ sql_athame-0.4.0a9.dist-info/METADATA,sha256=pf4xAdRJ7NuJaViLLWsQeq8LRIA78tY_YdOmPBjpFgg,12845
10
+ sql_athame-0.4.0a9.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
11
+ sql_athame-0.4.0a9.dist-info/RECORD,,