sql-athame 0.4.0a8__py3-none-any.whl → 0.4.0a9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sql_athame/dataclasses.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import datetime
2
2
  import functools
3
+ import sys
3
4
  import uuid
4
5
  from collections.abc import AsyncGenerator, Iterable, Mapping
5
6
  from dataclasses import Field, InitVar, dataclass, fields
@@ -34,10 +35,13 @@ class ColumnInfo:
34
35
  type: Optional[str] = None
35
36
  create_type: Optional[str] = None
36
37
  nullable: Optional[bool] = None
37
- _constraints: tuple[str, ...] = ()
38
38
 
39
+ _constraints: tuple[str, ...] = ()
39
40
  constraints: InitVar[Union[str, Iterable[str], None]] = None
40
41
 
42
+ serialize: Optional[Callable[[Any], Any]] = None
43
+ deserialize: Optional[Callable[[Any], Any]] = None
44
+
41
45
  def __post_init__(self, constraints: Union[str, Iterable[str], None]) -> None:
42
46
  if constraints is not None:
43
47
  if type(constraints) is str:
@@ -51,29 +55,41 @@ class ColumnInfo:
51
55
  create_type=b.create_type if b.create_type is not None else a.create_type,
52
56
  nullable=b.nullable if b.nullable is not None else a.nullable,
53
57
  _constraints=(*a._constraints, *b._constraints),
58
+ serialize=b.serialize if b.serialize is not None else a.serialize,
59
+ deserialize=b.deserialize if b.deserialize is not None else a.deserialize,
54
60
  )
55
61
 
56
62
 
57
63
  @dataclass
58
64
  class ConcreteColumnInfo:
65
+ field: Field
66
+ type_hint: type
59
67
  type: str
60
68
  create_type: str
61
69
  nullable: bool
62
70
  constraints: tuple[str, ...]
71
+ serialize: Optional[Callable[[Any], Any]] = None
72
+ deserialize: Optional[Callable[[Any], Any]] = None
63
73
 
64
74
  @staticmethod
65
- def from_column_info(name: str, *args: ColumnInfo) -> "ConcreteColumnInfo":
75
+ def from_column_info(
76
+ field: Field, type_hint: Any, *args: ColumnInfo
77
+ ) -> "ConcreteColumnInfo":
66
78
  info = functools.reduce(ColumnInfo.merge, args, ColumnInfo())
67
79
  if info.create_type is None and info.type is not None:
68
80
  info.create_type = info.type
69
81
  info.type = sql_create_type_map.get(info.type.upper(), info.type)
70
82
  if type(info.type) is not str or type(info.create_type) is not str:
71
- raise ValueError(f"Missing SQL type for column {name!r}")
83
+ raise ValueError(f"Missing SQL type for column {field.name!r}")
72
84
  return ConcreteColumnInfo(
85
+ field=field,
86
+ type_hint=type_hint,
73
87
  type=info.type,
74
88
  create_type=info.create_type,
75
89
  nullable=bool(info.nullable),
76
90
  constraints=info._constraints,
91
+ serialize=info.serialize,
92
+ deserialize=info.deserialize,
77
93
  )
78
94
 
79
95
  def create_table_string(self) -> str:
@@ -84,13 +100,24 @@ class ConcreteColumnInfo:
84
100
  )
85
101
  return " ".join(parts)
86
102
 
103
+ def maybe_serialize(self, value: Any) -> Any:
104
+ if self.serialize:
105
+ return self.serialize(value)
106
+ return value
107
+
108
+
109
+ UNION_TYPES: tuple = (Union,)
110
+ if sys.version_info >= (3, 10):
111
+ from types import UnionType
112
+
113
+ UNION_TYPES = (Union, UnionType)
87
114
 
88
115
  NULLABLE_TYPES = (type(None), Any, object)
89
116
 
90
117
 
91
118
  def split_nullable(typ: type) -> tuple[bool, type]:
92
119
  nullable = typ in NULLABLE_TYPES
93
- if get_origin(typ) is Union:
120
+ if get_origin(typ) in UNION_TYPES:
94
121
  args = []
95
122
  for arg in get_args(typ):
96
123
  if arg in NULLABLE_TYPES:
@@ -108,7 +135,7 @@ sql_create_type_map = {
108
135
  }
109
136
 
110
137
 
111
- sql_type_map: dict[Any, str] = {
138
+ sql_type_map: dict[type, str] = {
112
139
  bool: "BOOLEAN",
113
140
  bytes: "BYTEA",
114
141
  datetime.date: "DATE",
@@ -125,12 +152,11 @@ U = TypeVar("U")
125
152
 
126
153
 
127
154
  class ModelBase:
128
- _column_info: Optional[dict[str, ConcreteColumnInfo]]
155
+ _column_info: dict[str, ConcreteColumnInfo]
129
156
  _cache: dict[tuple, Any]
130
157
  table_name: str
131
158
  primary_key_names: tuple[str, ...]
132
159
  array_safe_insert: bool
133
- _type_hints: dict[str, type]
134
160
 
135
161
  def __init_subclass__(
136
162
  cls,
@@ -153,13 +179,6 @@ class ModelBase:
153
179
  else:
154
180
  cls.primary_key_names = tuple(primary_key)
155
181
 
156
- @classmethod
157
- def _fields(cls):
158
- # wrapper to ignore typing weirdness: 'Argument 1 to "fields"
159
- # has incompatible type "..."; expected "DataclassInstance |
160
- # type[DataclassInstance]"'
161
- return fields(cls) # type: ignore
162
-
163
182
  @classmethod
164
183
  def _cached(cls, key: tuple, thunk: Callable[[], U]) -> U:
165
184
  try:
@@ -169,20 +188,11 @@ class ModelBase:
169
188
  return cls._cache[key]
170
189
 
171
190
  @classmethod
172
- def type_hints(cls) -> dict[str, type]:
173
- try:
174
- return cls._type_hints
175
- except AttributeError:
176
- cls._type_hints = get_type_hints(cls, include_extras=True)
177
- return cls._type_hints
178
-
179
- @classmethod
180
- def column_info_for_field(cls, field: Field) -> ConcreteColumnInfo:
181
- type_info = cls.type_hints()[field.name]
182
- base_type = type_info
191
+ def column_info_for_field(cls, field: Field, type_hint: type) -> ConcreteColumnInfo:
192
+ base_type = type_hint
183
193
  metadata = []
184
- if get_origin(type_info) is Annotated:
185
- base_type, *metadata = get_args(type_info)
194
+ if get_origin(type_hint) is Annotated:
195
+ base_type, *metadata = get_args(type_hint)
186
196
  nullable, base_type = split_nullable(base_type)
187
197
  info = [ColumnInfo(nullable=nullable)]
188
198
  if base_type in sql_type_map:
@@ -190,17 +200,19 @@ class ModelBase:
190
200
  for md in metadata:
191
201
  if isinstance(md, ColumnInfo):
192
202
  info.append(md)
193
- return ConcreteColumnInfo.from_column_info(field.name, *info)
203
+ return ConcreteColumnInfo.from_column_info(field, type_hint, *info)
194
204
 
195
205
  @classmethod
196
- def column_info(cls, column: str) -> ConcreteColumnInfo:
206
+ def column_info(cls) -> dict[str, ConcreteColumnInfo]:
197
207
  try:
198
- return cls._column_info[column] # type: ignore
208
+ return cls._column_info
199
209
  except AttributeError:
210
+ type_hints = get_type_hints(cls, include_extras=True)
200
211
  cls._column_info = {
201
- f.name: cls.column_info_for_field(f) for f in cls._fields()
212
+ f.name: cls.column_info_for_field(f, type_hints[f.name])
213
+ for f in fields(cls) # type: ignore
202
214
  }
203
- return cls._column_info[column]
215
+ return cls._column_info
204
216
 
205
217
  @classmethod
206
218
  def table_name_sql(cls, *, prefix: Optional[str] = None) -> Fragment:
@@ -212,7 +224,11 @@ class ModelBase:
212
224
 
213
225
  @classmethod
214
226
  def field_names(cls, *, exclude: FieldNamesSet = ()) -> list[str]:
215
- return [f.name for f in cls._fields() if f.name not in exclude]
227
+ return [
228
+ ci.field.name
229
+ for ci in cls.column_info().values()
230
+ if ci.field.name not in exclude
231
+ ]
216
232
 
217
233
  @classmethod
218
234
  def field_names_sql(
@@ -231,9 +247,13 @@ class ModelBase:
231
247
  ) -> Callable[[T], list[Any]]:
232
248
  env: dict[str, Any] = {}
233
249
  func = ["def get_field_values(self): return ["]
234
- for f in cls._fields():
235
- if f.name not in exclude:
236
- func.append(f"self.{f.name},")
250
+ for ci in cls.column_info().values():
251
+ if ci.field.name not in exclude:
252
+ if ci.serialize:
253
+ env[f"_ser_{ci.field.name}"] = ci.serialize
254
+ func.append(f"_ser_{ci.field.name}(self.{ci.field.name}), ")
255
+ else:
256
+ func.append(f"self.{ci.field.name},")
237
257
  func += ["]"]
238
258
  exec(" ".join(func), env)
239
259
  return env["get_field_values"]
@@ -257,36 +277,46 @@ class ModelBase:
257
277
  return [sql.value(value) for value in self.field_values()]
258
278
 
259
279
  @classmethod
260
- def from_tuple(
261
- cls: type[T], tup: tuple, *, offset: int = 0, exclude: FieldNamesSet = ()
262
- ) -> T:
263
- names = (f.name for f in cls._fields() if f.name not in exclude)
264
- kwargs = {name: tup[offset] for offset, name in enumerate(names, start=offset)}
265
- return cls(**kwargs)
280
+ def _get_from_mapping_fn(cls: type[T]) -> Callable[[Mapping[str, Any]], T]:
281
+ env: dict[str, Any] = {"cls": cls}
282
+ func = ["def from_mapping(mapping):"]
283
+ if not any(ci.deserialize for ci in cls.column_info().values()):
284
+ func.append(" return cls(**mapping)")
285
+ else:
286
+ func.append(" deser_dict = dict(mapping)")
287
+ for ci in cls.column_info().values():
288
+ if ci.deserialize:
289
+ env[f"_deser_{ci.field.name}"] = ci.deserialize
290
+ func.append(f" if {ci.field.name!r} in deser_dict:")
291
+ func.append(
292
+ f" deser_dict[{ci.field.name!r}] = _deser_{ci.field.name}(deser_dict[{ci.field.name!r}])"
293
+ )
294
+ func.append(" return cls(**deser_dict)")
295
+ exec("\n".join(func), env)
296
+ return env["from_mapping"]
266
297
 
267
298
  @classmethod
268
- def from_dict(
269
- cls: type[T], dct: dict[str, Any], *, exclude: FieldNamesSet = ()
270
- ) -> T:
271
- names = {f.name for f in cls._fields() if f.name not in exclude}
272
- kwargs = {k: v for k, v in dct.items() if k in names}
273
- return cls(**kwargs)
299
+ def from_mapping(cls: type[T], mapping: Mapping[str, Any], /) -> T:
300
+ # KLUDGE nasty but... efficient?
301
+ from_mapping_fn = cls._get_from_mapping_fn()
302
+ cls.from_mapping = from_mapping_fn # type: ignore
303
+ return from_mapping_fn(mapping)
274
304
 
275
305
  @classmethod
276
306
  def ensure_model(cls: type[T], row: Union[T, Mapping[str, Any]]) -> T:
277
307
  if isinstance(row, cls):
278
308
  return row
279
- return cls(**row)
309
+ return cls.from_mapping(row) # type: ignore
280
310
 
281
311
  @classmethod
282
312
  def create_table_sql(cls) -> Fragment:
283
313
  entries = [
284
314
  sql(
285
315
  "{} {}",
286
- sql.identifier(f.name),
287
- sql.literal(cls.column_info(f.name).create_table_string()),
316
+ sql.identifier(ci.field.name),
317
+ sql.literal(ci.create_table_string()),
288
318
  )
289
- for f in cls._fields()
319
+ for ci in cls.column_info().values()
290
320
  ]
291
321
  if cls.primary_key_names:
292
322
  entries += [sql("PRIMARY KEY ({})", sql.list(cls.primary_key_names_sql()))]
@@ -338,7 +368,7 @@ class ModelBase:
338
368
  *cls.select_sql(order_by=order_by, for_update=for_update, where=where),
339
369
  prefetch=prefetch,
340
370
  ):
341
- yield cls(**row)
371
+ yield cls.from_mapping(row)
342
372
 
343
373
  @classmethod
344
374
  async def select(
@@ -349,7 +379,7 @@ class ModelBase:
349
379
  where: Where = (),
350
380
  ) -> list[T]:
351
381
  return [
352
- cls(**row)
382
+ cls.from_mapping(row)
353
383
  for row in await connection_or_pool.fetch(
354
384
  *cls.select_sql(order_by=order_by, for_update=for_update, where=where)
355
385
  )
@@ -357,11 +387,14 @@ class ModelBase:
357
387
 
358
388
  @classmethod
359
389
  def create_sql(cls: type[T], **kwargs: Any) -> Fragment:
390
+ column_info = cls.column_info()
360
391
  return sql(
361
392
  "INSERT INTO {table} ({fields}) VALUES ({values}) RETURNING {out_fields}",
362
393
  table=cls.table_name_sql(),
363
- fields=sql.list(sql.identifier(x) for x in kwargs.keys()),
364
- values=sql.list(sql.value(x) for x in kwargs.values()),
394
+ fields=sql.list(sql.identifier(k) for k in kwargs.keys()),
395
+ values=sql.list(
396
+ sql.value(column_info[k].maybe_serialize(v)) for k, v in kwargs.items()
397
+ ),
365
398
  out_fields=sql.list(cls.field_names_sql()),
366
399
  )
367
400
 
@@ -370,7 +403,7 @@ class ModelBase:
370
403
  cls: type[T], connection_or_pool: Union[Connection, Pool], **kwargs: Any
371
404
  ) -> T:
372
405
  row = await connection_or_pool.fetchrow(*cls.create_sql(**kwargs))
373
- return cls(**row)
406
+ return cls.from_mapping(row)
374
407
 
375
408
  def insert_sql(self, exclude: FieldNamesSet = ()) -> Fragment:
376
409
  cached = self._cached(
@@ -428,10 +461,11 @@ class ModelBase:
428
461
  pks=sql.list(sql.identifier(pk) for pk in cls.primary_key_names),
429
462
  ).compile(),
430
463
  )
464
+ column_info = cls.column_info()
431
465
  return cached(
432
466
  unnest=sql.unnest(
433
467
  (row.primary_key() for row in rows),
434
- (cls.column_info(pk).type for pk in cls.primary_key_names),
468
+ (column_info[pk].type for pk in cls.primary_key_names),
435
469
  ),
436
470
  )
437
471
 
@@ -451,10 +485,11 @@ class ModelBase:
451
485
  fields=sql.list(cls.field_names_sql()),
452
486
  ).compile(),
453
487
  )
488
+ column_info = cls.column_info()
454
489
  return cached(
455
490
  unnest=sql.unnest(
456
491
  (row.field_values() for row in rows),
457
- (cls.column_info(name).type for name in cls.field_names()),
492
+ (column_info[name].type for name in cls.field_names()),
458
493
  ),
459
494
  )
460
495
 
@@ -545,9 +580,9 @@ class ModelBase:
545
580
  ) -> Callable[[T, T], bool]:
546
581
  env: dict[str, Any] = {}
547
582
  func = ["def equal_ignoring(a, b):"]
548
- for f in cls._fields():
549
- if f.name not in ignore:
550
- func.append(f" if a.{f.name} != b.{f.name}: return False")
583
+ for ci in cls.column_info().values():
584
+ if ci.field.name not in ignore:
585
+ func.append(f" if a.{ci.field.name} != b.{ci.field.name}: return False")
551
586
  func += [" return True"]
552
587
  exec("\n".join(func), env)
553
588
  return env["equal_ignoring"]
@@ -603,9 +638,11 @@ class ModelBase:
603
638
  "def differences_ignoring(a, b):",
604
639
  " diffs = []",
605
640
  ]
606
- for f in cls._fields():
607
- if f.name not in ignore:
608
- func.append(f" if a.{f.name} != b.{f.name}: diffs.append({f.name!r})")
641
+ for ci in cls.column_info().values():
642
+ if ci.field.name not in ignore:
643
+ func.append(
644
+ f" if a.{ci.field.name} != b.{ci.field.name}: diffs.append({ci.field.name!r})"
645
+ )
609
646
  func += [" return diffs"]
610
647
  exec("\n".join(func), env)
611
648
  return env["differences_ignoring"]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: sql-athame
3
- Version: 0.4.0a8
3
+ Version: 0.4.0a9
4
4
  Summary: Python tool for slicing and dicing SQL
5
5
  Home-page: https://github.com/bdowning/sql-athame
6
6
  License: MIT
@@ -1,11 +1,11 @@
1
1
  sql_athame/__init__.py,sha256=7OBIMZOcrD2pvfIL-rjD1IGZ3TNQbwyu76a9PWk-yYg,79
2
2
  sql_athame/base.py,sha256=FR7EmC0VkX1VRgvAutSEfYSWhlEYpoqS1Kqxp1jHp6Y,10293
3
- sql_athame/dataclasses.py,sha256=NRPGOlqaTMj49B4gbvC5nCRrFZFrivoRdhSmG8VTkAg,22167
3
+ sql_athame/dataclasses.py,sha256=9Q-Z3itKyuqhR5u47bVBfA714uFbf-K4t1FPiFd8XAE,23792
4
4
  sql_athame/escape.py,sha256=kK101xXeFitlvuG-L_hvhdpgGJCtmRTprsn1yEfZKws,758
5
5
  sql_athame/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  sql_athame/sqlalchemy.py,sha256=aWopfPh3j71XwKmcN_VcHRNlhscI0Sckd4AiyGf8Tpw,1293
7
7
  sql_athame/types.py,sha256=FQ06l9Uc-vo57UrAarvnukILdV2gN1IaYUnHJ_bNYic,475
8
- sql_athame-0.4.0a8.dist-info/LICENSE,sha256=xqV29vPFqITcKifYrGPgVIBjq4fdmLSwY3gRUtDKafg,1076
9
- sql_athame-0.4.0a8.dist-info/METADATA,sha256=dXFzs8L8wxnzsjKcL8jj3NyMJPi6l5tIbdZ01ALkzc8,12845
10
- sql_athame-0.4.0a8.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
11
- sql_athame-0.4.0a8.dist-info/RECORD,,
8
+ sql_athame-0.4.0a9.dist-info/LICENSE,sha256=xqV29vPFqITcKifYrGPgVIBjq4fdmLSwY3gRUtDKafg,1076
9
+ sql_athame-0.4.0a9.dist-info/METADATA,sha256=pf4xAdRJ7NuJaViLLWsQeq8LRIA78tY_YdOmPBjpFgg,12845
10
+ sql_athame-0.4.0a9.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
11
+ sql_athame-0.4.0a9.dist-info/RECORD,,