piccolo 1.8.0__py3-none-any.whl → 1.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. piccolo/__init__.py +1 -1
  2. piccolo/apps/fixtures/commands/load.py +1 -1
  3. piccolo/apps/migrations/auto/__init__.py +8 -0
  4. piccolo/apps/migrations/auto/migration_manager.py +2 -1
  5. piccolo/apps/migrations/commands/backwards.py +3 -1
  6. piccolo/apps/migrations/commands/base.py +1 -1
  7. piccolo/apps/migrations/commands/check.py +1 -1
  8. piccolo/apps/migrations/commands/clean.py +1 -1
  9. piccolo/apps/migrations/commands/forwards.py +3 -1
  10. piccolo/apps/migrations/commands/new.py +4 -2
  11. piccolo/apps/schema/commands/generate.py +2 -2
  12. piccolo/apps/shell/commands/run.py +1 -1
  13. piccolo/columns/base.py +55 -29
  14. piccolo/columns/column_types.py +28 -4
  15. piccolo/columns/defaults/base.py +6 -4
  16. piccolo/columns/defaults/date.py +9 -1
  17. piccolo/columns/defaults/interval.py +1 -0
  18. piccolo/columns/defaults/time.py +9 -1
  19. piccolo/columns/defaults/timestamp.py +1 -0
  20. piccolo/columns/defaults/uuid.py +1 -1
  21. piccolo/columns/m2m.py +7 -7
  22. piccolo/columns/operators/comparison.py +4 -0
  23. piccolo/conf/apps.py +9 -4
  24. piccolo/engine/base.py +69 -20
  25. piccolo/engine/cockroach.py +2 -3
  26. piccolo/engine/postgres.py +33 -19
  27. piccolo/engine/sqlite.py +27 -22
  28. piccolo/query/functions/__init__.py +5 -0
  29. piccolo/query/functions/math.py +48 -0
  30. piccolo/query/methods/create_index.py +1 -1
  31. piccolo/query/methods/drop_index.py +1 -1
  32. piccolo/query/methods/objects.py +7 -7
  33. piccolo/query/methods/select.py +13 -7
  34. piccolo/query/mixins.py +3 -10
  35. piccolo/querystring.py +18 -0
  36. piccolo/schema.py +20 -12
  37. piccolo/table.py +22 -21
  38. piccolo/utils/encoding.py +5 -3
  39. {piccolo-1.8.0.dist-info → piccolo-1.10.0.dist-info}/METADATA +1 -1
  40. {piccolo-1.8.0.dist-info → piccolo-1.10.0.dist-info}/RECORD +59 -52
  41. tests/apps/migrations/auto/integration/test_migrations.py +1 -1
  42. tests/columns/test_array.py +91 -19
  43. tests/columns/test_get_sql_value.py +66 -0
  44. tests/conf/test_apps.py +1 -1
  45. tests/engine/test_nested_transaction.py +2 -0
  46. tests/engine/test_transaction.py +1 -2
  47. tests/query/functions/__init__.py +0 -0
  48. tests/query/functions/base.py +34 -0
  49. tests/query/functions/test_functions.py +64 -0
  50. tests/query/functions/test_math.py +39 -0
  51. tests/query/functions/test_string.py +25 -0
  52. tests/query/functions/test_type_conversion.py +134 -0
  53. tests/query/test_querystring.py +136 -0
  54. tests/table/test_indexes.py +4 -2
  55. tests/utils/test_pydantic.py +70 -29
  56. tests/query/test_functions.py +0 -238
  57. {piccolo-1.8.0.dist-info → piccolo-1.10.0.dist-info}/LICENSE +0 -0
  58. {piccolo-1.8.0.dist-info → piccolo-1.10.0.dist-info}/WHEEL +0 -0
  59. {piccolo-1.8.0.dist-info → piccolo-1.10.0.dist-info}/entry_points.txt +0 -0
  60. {piccolo-1.8.0.dist-info → piccolo-1.10.0.dist-info}/top_level.txt +0 -0
piccolo/__init__.py CHANGED
@@ -1 +1 @@
1
- __VERSION__ = "1.8.0"
1
+ __VERSION__ = "1.10.0"
@@ -51,7 +51,7 @@ async def load_json_string(
51
51
  finder = Finder()
52
52
  engine = engine_finder()
53
53
 
54
- if not engine:
54
+ if engine is None:
55
55
  raise Exception("Unable to find the engine.")
56
56
 
57
57
  # This is what we want to the insert into the database:
@@ -2,3 +2,11 @@ from .diffable_table import DiffableTable
2
2
  from .migration_manager import MigrationManager
3
3
  from .schema_differ import AlterStatements, SchemaDiffer
4
4
  from .schema_snapshot import SchemaSnapshot
5
+
6
+ __all__ = [
7
+ "DiffableTable",
8
+ "MigrationManager",
9
+ "AlterStatements",
10
+ "SchemaDiffer",
11
+ "SchemaSnapshot",
12
+ ]
@@ -261,7 +261,8 @@ class MigrationManager:
261
261
  cleaned_params = deserialise_params(params=params)
262
262
  column = column_class(**cleaned_params)
263
263
  column._meta.name = column_name
264
- column._meta.db_column_name = db_column_name
264
+ if db_column_name:
265
+ column._meta.db_column_name = db_column_name
265
266
 
266
267
  self.add_columns.append(
267
268
  AddColumnClass(
@@ -32,7 +32,9 @@ class BackwardsMigrationManager(BaseMigrationManager):
32
32
 
33
33
  async def run_migrations_backwards(self, app_config: AppConfig):
34
34
  migration_modules: t.Dict[str, MigrationModule] = (
35
- self.get_migration_modules(app_config.migrations_folder_path)
35
+ self.get_migration_modules(
36
+ app_config.resolved_migrations_folder_path
37
+ )
36
38
  )
37
39
 
38
40
  ran_migration_ids = await Migration.get_migrations_which_ran(
@@ -86,7 +86,7 @@ class BaseMigrationManager(Finder):
86
86
  """
87
87
  migration_managers: t.List[MigrationManager] = []
88
88
 
89
- migrations_folder = app_config.migrations_folder_path
89
+ migrations_folder = app_config.resolved_migrations_folder_path
90
90
 
91
91
  migration_modules: t.Dict[str, MigrationModule] = (
92
92
  self.get_migration_modules(migrations_folder)
@@ -36,7 +36,7 @@ class CheckMigrationManager(BaseMigrationManager):
36
36
  continue
37
37
 
38
38
  migration_modules = self.get_migration_modules(
39
- app_config.migrations_folder_path
39
+ app_config.resolved_migrations_folder_path
40
40
  )
41
41
  ids = self.get_migration_ids(migration_modules)
42
42
  for _id in ids:
@@ -20,7 +20,7 @@ class CleanMigrationManager(BaseMigrationManager):
20
20
  app_config = self.get_app_config(app_name=self.app_name)
21
21
 
22
22
  migration_module_dict = self.get_migration_modules(
23
- folder_path=app_config.migrations_folder_path
23
+ folder_path=app_config.resolved_migrations_folder_path
24
24
  )
25
25
 
26
26
  # The migration IDs which are in migration modules.
@@ -33,7 +33,9 @@ class ForwardsMigrationManager(BaseMigrationManager):
33
33
  )
34
34
 
35
35
  migration_modules: t.Dict[str, MigrationModule] = (
36
- self.get_migration_modules(app_config.migrations_folder_path)
36
+ self.get_migration_modules(
37
+ app_config.resolved_migrations_folder_path
38
+ )
37
39
  )
38
40
 
39
41
  ids = self.get_migration_ids(migration_modules)
@@ -98,7 +98,9 @@ def _generate_migration_meta(app_config: AppConfig) -> NewMigrationMeta:
98
98
 
99
99
  filename = f"{cleaned_app_name}_{cleaned_id}"
100
100
 
101
- path = os.path.join(app_config.migrations_folder_path, f"{filename}.py")
101
+ path = os.path.join(
102
+ app_config.resolved_migrations_folder_path, f"{filename}.py"
103
+ )
102
104
 
103
105
  return NewMigrationMeta(
104
106
  migration_id=_id, migration_filename=filename, migration_path=path
@@ -255,7 +257,7 @@ async def new(
255
257
 
256
258
  app_config = Finder().get_app_config(app_name=app_name)
257
259
 
258
- _create_migrations_folder(app_config.migrations_folder_path)
260
+ _create_migrations_folder(app_config.resolved_migrations_folder_path)
259
261
 
260
262
  try:
261
263
  await _create_new_migration(
@@ -313,7 +313,7 @@ COLUMN_TYPE_MAP_COCKROACH: t.Dict[str, t.Type[Column]] = {
313
313
  **{"integer": BigInt, "json": JSONB},
314
314
  }
315
315
 
316
- COLUMN_DEFAULT_PARSER = {
316
+ COLUMN_DEFAULT_PARSER: t.Dict[t.Type[Column], t.Any] = {
317
317
  BigInt: re.compile(r"^'?(?P<value>-?[0-9]\d*)'?(?:::bigint)?$"),
318
318
  Boolean: re.compile(r"^(?P<value>true|false)$"),
319
319
  Bytea: re.compile(r"'(?P<value>.*)'::bytea$"),
@@ -373,7 +373,7 @@ COLUMN_DEFAULT_PARSER = {
373
373
  }
374
374
 
375
375
  # Re-map for Cockroach compatibility.
376
- COLUMN_DEFAULT_PARSER_COCKROACH = {
376
+ COLUMN_DEFAULT_PARSER_COCKROACH: t.Dict[t.Type[Column], t.Any] = {
377
377
  **COLUMN_DEFAULT_PARSER,
378
378
  BigInt: re.compile(r"^(?P<value>-?\d+)$"),
379
379
  }
@@ -24,7 +24,7 @@ def start_ipython_shell(**tables: t.Type[Table]): # pragma: no cover
24
24
  if table_class_name not in existing_global_names:
25
25
  globals()[table_class_name] = table_class
26
26
 
27
- IPython.embed(using=_asyncio_runner, colors="neutral")
27
+ IPython.embed(using=_asyncio_runner, colors="neutral") # type: ignore
28
28
 
29
29
 
30
30
  def run() -> None:
piccolo/columns/base.py CHANGED
@@ -830,7 +830,11 @@ class Column(Selectable):
830
830
  engine_type=engine_type, with_alias=False
831
831
  )
832
832
 
833
- def get_sql_value(self, value: t.Any) -> t.Any:
833
+ def get_sql_value(
834
+ self,
835
+ value: t.Any,
836
+ delimiter: str = "'",
837
+ ) -> str:
834
838
  """
835
839
  When using DDL statements, we can't parameterise the values. An example
836
840
  is when setting the default for a column. So we have to convert from
@@ -839,11 +843,18 @@ class Column(Selectable):
839
843
 
840
844
  :param value:
841
845
  The Python value to convert to a string usable in a DDL statement
842
- e.g. 1.
846
+ e.g. ``1``.
847
+ :param delimiter:
848
+ The string returned by this function is wrapped in delimiters,
849
+ ready to be added to a DDL statement. For example:
850
+ ``'hello world'``.
843
851
  :returns:
844
- The string usable in the DDL statement e.g. '1'.
852
+ The string usable in the DDL statement e.g. ``'1'``.
845
853
 
846
854
  """
855
+ from piccolo.engine.sqlite import ADAPTERS as sqlite_adapters
856
+
857
+ # Common across all DB engines
847
858
  if isinstance(value, Default):
848
859
  return getattr(value, self._meta.engine_type)
849
860
  elif value is None:
@@ -851,37 +862,52 @@ class Column(Selectable):
851
862
  elif isinstance(value, (float, decimal.Decimal)):
852
863
  return str(value)
853
864
  elif isinstance(value, str):
854
- return f"'{value}'"
865
+ return f"{delimiter}{value}{delimiter}"
855
866
  elif isinstance(value, bool):
856
867
  return str(value).lower()
857
- elif isinstance(value, datetime.datetime):
858
- return f"'{value.isoformat().replace('T', ' ')}'"
859
- elif isinstance(value, datetime.date):
860
- return f"'{value.isoformat()}'"
861
- elif isinstance(value, datetime.time):
862
- return f"'{value.isoformat()}'"
863
- elif isinstance(value, datetime.timedelta):
864
- interval = IntervalCustom.from_timedelta(value)
865
- return getattr(interval, self._meta.engine_type)
866
868
  elif isinstance(value, bytes):
867
- return f"'{value.hex()}'"
868
- elif isinstance(value, uuid.UUID):
869
- return f"'{value}'"
870
- elif isinstance(value, list):
871
- # Convert to the array syntax.
872
- return (
873
- "'{"
874
- + ", ".join(
875
- (
876
- f'"{i}"'
877
- if isinstance(i, str)
878
- else str(self.get_sql_value(i))
869
+ return f"{delimiter}{value.hex()}{delimiter}"
870
+
871
+ # SQLite specific
872
+ if self._meta.engine_type == "sqlite":
873
+ if adapter := sqlite_adapters.get(type(value)):
874
+ sqlite_value = adapter(value)
875
+ return (
876
+ f"{delimiter}{sqlite_value}{delimiter}"
877
+ if isinstance(sqlite_value, str)
878
+ else sqlite_value
879
+ )
880
+
881
+ # Postgres and Cockroach
882
+ if self._meta.engine_type in ["postgres", "cockroach"]:
883
+ if isinstance(value, datetime.datetime):
884
+ return f"{delimiter}{value.isoformat().replace('T', ' ')}{delimiter}" # noqa: E501
885
+ elif isinstance(value, datetime.date):
886
+ return f"{delimiter}{value.isoformat()}{delimiter}"
887
+ elif isinstance(value, datetime.time):
888
+ return f"{delimiter}{value.isoformat()}{delimiter}"
889
+ elif isinstance(value, datetime.timedelta):
890
+ interval = IntervalCustom.from_timedelta(value)
891
+ return getattr(interval, self._meta.engine_type)
892
+ elif isinstance(value, uuid.UUID):
893
+ return f"{delimiter}{value}{delimiter}"
894
+ elif isinstance(value, list):
895
+ # Convert to the array syntax.
896
+ return (
897
+ delimiter
898
+ + "{"
899
+ + ",".join(
900
+ self.get_sql_value(
901
+ i,
902
+ delimiter="" if isinstance(i, list) else '"',
903
+ )
904
+ for i in value
879
905
  )
880
- for i in value
906
+ + "}"
907
+ + delimiter
881
908
  )
882
- ) + "}'"
883
- else:
884
- return value
909
+
910
+ return str(value)
885
911
 
886
912
  @property
887
913
  def column_type(self):
@@ -57,7 +57,11 @@ from piccolo.columns.defaults.timestamptz import (
57
57
  TimestamptzNow,
58
58
  )
59
59
  from piccolo.columns.defaults.uuid import UUID4, UUIDArg
60
- from piccolo.columns.operators.comparison import ArrayAll, ArrayAny
60
+ from piccolo.columns.operators.comparison import (
61
+ ArrayAll,
62
+ ArrayAny,
63
+ ArrayNotAny,
64
+ )
61
65
  from piccolo.columns.operators.string import Concat
62
66
  from piccolo.columns.reference import LazyTableReference
63
67
  from piccolo.querystring import QueryString
@@ -1952,7 +1956,9 @@ class ForeignKey(Column, t.Generic[ReferencedTable]):
1952
1956
 
1953
1957
  if is_table_class:
1954
1958
  # Record the reverse relationship on the target table.
1955
- references._meta._foreign_key_references.append(self)
1959
+ t.cast(
1960
+ t.Type[Table], references
1961
+ )._meta._foreign_key_references.append(self)
1956
1962
 
1957
1963
  # Allow columns on the referenced table to be accessed via
1958
1964
  # auto completion.
@@ -2670,6 +2676,24 @@ class Array(Column):
2670
2676
  else:
2671
2677
  raise ValueError("Unrecognised engine type")
2672
2678
 
2679
+ def not_any(self, value: t.Any) -> Where:
2680
+ """
2681
+ Check if the given value isn't in the array.
2682
+
2683
+ .. code-block:: python
2684
+
2685
+ >>> await Ticket.select().where(Ticket.seat_numbers.not_any(510))
2686
+
2687
+ """
2688
+ engine_type = self._meta.engine_type
2689
+
2690
+ if engine_type in ("postgres", "cockroach"):
2691
+ return Where(column=self, value=value, operator=ArrayNotAny)
2692
+ elif engine_type == "sqlite":
2693
+ return self.not_like(f"%{value}%")
2694
+ else:
2695
+ raise ValueError("Unrecognised engine type")
2696
+
2673
2697
  def all(self, value: t.Any) -> Where:
2674
2698
  """
2675
2699
  Check if all of the items in the array match the given value.
@@ -2688,7 +2712,7 @@ class Array(Column):
2688
2712
  else:
2689
2713
  raise ValueError("Unrecognised engine type")
2690
2714
 
2691
- def cat(self, value: t.List[t.Any]) -> QueryString:
2715
+ def cat(self, value: t.Union[t.Any, t.List[t.Any]]) -> QueryString:
2692
2716
  """
2693
2717
  Used in an ``update`` query to append items to an array.
2694
2718
 
@@ -2719,7 +2743,7 @@ class Array(Column):
2719
2743
  db_column_name = self._meta.db_column_name
2720
2744
  return QueryString(f'array_cat("{db_column_name}", {{}})', value)
2721
2745
 
2722
- def __add__(self, value: t.List[t.Any]) -> QueryString:
2746
+ def __add__(self, value: t.Union[t.Any, t.List[t.Any]]) -> QueryString:
2723
2747
  return self.cat(value)
2724
2748
 
2725
2749
  ###########################################################################
@@ -1,22 +1,24 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import typing as t
4
- from abc import ABC, abstractmethod, abstractproperty
4
+ from abc import ABC, abstractmethod
5
5
 
6
6
  from piccolo.utils.repr import repr_class_instance
7
7
 
8
8
 
9
9
  class Default(ABC):
10
- @abstractproperty
10
+ @property
11
+ @abstractmethod
11
12
  def postgres(self) -> str:
12
13
  pass
13
14
 
14
- @abstractproperty
15
+ @property
16
+ @abstractmethod
15
17
  def sqlite(self) -> str:
16
18
  pass
17
19
 
18
20
  @abstractmethod
19
- def python(self):
21
+ def python(self) -> t.Any:
20
22
  pass
21
23
 
22
24
  def get_postgres_interval_string(self, attributes: t.List[str]) -> str:
@@ -102,7 +102,15 @@ class DateCustom(Default):
102
102
 
103
103
 
104
104
  # Might add an enum back which encapsulates all of the options.
105
- DateArg = t.Union[DateOffset, DateCustom, DateNow, Enum, None, datetime.date]
105
+ DateArg = t.Union[
106
+ DateOffset,
107
+ DateCustom,
108
+ DateNow,
109
+ Enum,
110
+ None,
111
+ datetime.date,
112
+ t.Callable[[], datetime.date],
113
+ ]
106
114
 
107
115
 
108
116
  __all__ = ["DateArg", "DateOffset", "DateCustom", "DateNow"]
@@ -80,6 +80,7 @@ IntervalArg = t.Union[
80
80
  Enum,
81
81
  None,
82
82
  datetime.timedelta,
83
+ t.Callable[[], datetime.timedelta],
83
84
  ]
84
85
 
85
86
 
@@ -89,7 +89,15 @@ class TimeCustom(Default):
89
89
  )
90
90
 
91
91
 
92
- TimeArg = t.Union[TimeCustom, TimeNow, TimeOffset, Enum, None, datetime.time]
92
+ TimeArg = t.Union[
93
+ TimeCustom,
94
+ TimeNow,
95
+ TimeOffset,
96
+ Enum,
97
+ None,
98
+ datetime.time,
99
+ t.Callable[[], datetime.time],
100
+ ]
93
101
 
94
102
 
95
103
  __all__ = ["TimeArg", "TimeCustom", "TimeNow", "TimeOffset"]
@@ -138,6 +138,7 @@ TimestampArg = t.Union[
138
138
  None,
139
139
  datetime.datetime,
140
140
  DatetimeDefault,
141
+ t.Callable[[], datetime.datetime],
141
142
  ]
142
143
 
143
144
 
@@ -22,7 +22,7 @@ class UUID4(Default):
22
22
  return uuid.uuid4()
23
23
 
24
24
 
25
- UUIDArg = t.Union[UUID4, uuid.UUID, str, Enum, None]
25
+ UUIDArg = t.Union[UUID4, uuid.UUID, str, Enum, None, t.Callable[[], uuid.UUID]]
26
26
 
27
27
 
28
28
  __all__ = ["UUIDArg", "UUID4"]
piccolo/columns/m2m.py CHANGED
@@ -131,6 +131,7 @@ class M2MSelect(Selectable):
131
131
  if len(self.columns) > 1 or not self.serialisation_safe:
132
132
  column_name = table_2_pk_name
133
133
  else:
134
+ assert len(self.columns) > 0
134
135
  column_name = self.columns[0]._meta.db_column_name
135
136
 
136
137
  return QueryString(
@@ -256,15 +257,14 @@ class M2MMeta:
256
257
 
257
258
  @dataclass
258
259
  class M2MAddRelated:
259
-
260
260
  target_row: Table
261
261
  m2m: M2M
262
262
  rows: t.Sequence[Table]
263
263
  extra_column_values: t.Dict[t.Union[Column, str], t.Any]
264
264
 
265
- def __post_init__(self) -> None:
266
- # Normalise `extra_column_values`, so we just have the column names.
267
- self.extra_column_values: t.Dict[str, t.Any] = {
265
+ @property
266
+ def resolved_extra_column_values(self) -> t.Dict[str, t.Any]:
267
+ return {
268
268
  i._meta.name if isinstance(i, Column) else i: j
269
269
  for i, j in self.extra_column_values.items()
270
270
  }
@@ -281,7 +281,9 @@ class M2MAddRelated:
281
281
  joining_table_rows = []
282
282
 
283
283
  for row in rows:
284
- joining_table_row = joining_table(**self.extra_column_values)
284
+ joining_table_row = joining_table(
285
+ **self.resolved_extra_column_values
286
+ )
285
287
  setattr(
286
288
  joining_table_row,
287
289
  self.m2m._meta.primary_foreign_key._meta.name,
@@ -323,7 +325,6 @@ class M2MAddRelated:
323
325
 
324
326
  @dataclass
325
327
  class M2MRemoveRelated:
326
-
327
328
  target_row: Table
328
329
  m2m: M2M
329
330
  rows: t.Sequence[Table]
@@ -363,7 +364,6 @@ class M2MRemoveRelated:
363
364
 
364
365
  @dataclass
365
366
  class M2MGetRelated:
366
-
367
367
  row: Table
368
368
  m2m: M2M
369
369
 
@@ -62,5 +62,9 @@ class ArrayAny(ComparisonOperator):
62
62
  template = "{value} = ANY ({name})"
63
63
 
64
64
 
65
+ class ArrayNotAny(ComparisonOperator):
66
+ template = "NOT {value} = ANY ({name})"
67
+
68
+
65
69
  class ArrayAll(ComparisonOperator):
66
70
  template = "{value} = ALL ({name})"
piccolo/conf/apps.py CHANGED
@@ -157,17 +157,22 @@ class AppConfig:
157
157
  """
158
158
 
159
159
  app_name: str
160
- migrations_folder_path: str
160
+ migrations_folder_path: t.Union[str, pathlib.Path]
161
161
  table_classes: t.List[t.Type[Table]] = field(default_factory=list)
162
162
  migration_dependencies: t.List[str] = field(default_factory=list)
163
163
  commands: t.List[t.Union[t.Callable, Command]] = field(
164
164
  default_factory=list
165
165
  )
166
166
 
167
- def __post_init__(self) -> None:
168
- if isinstance(self.migrations_folder_path, pathlib.Path):
169
- self.migrations_folder_path = str(self.migrations_folder_path)
167
+ @property
168
+ def resolved_migrations_folder_path(self) -> str:
169
+ return (
170
+ str(self.migrations_folder_path)
171
+ if isinstance(self.migrations_folder_path, pathlib.Path)
172
+ else self.migrations_folder_path
173
+ )
170
174
 
175
+ def __post_init__(self) -> None:
171
176
  self._migration_dependency_app_configs: t.Optional[
172
177
  t.List[AppConfig]
173
178
  ] = None
piccolo/engine/base.py CHANGED
@@ -7,12 +7,14 @@ import string
7
7
  import typing as t
8
8
  from abc import ABCMeta, abstractmethod
9
9
 
10
+ from typing_extensions import Self
11
+
10
12
  from piccolo.querystring import QueryString
11
13
  from piccolo.utils.sync import run_sync
12
14
  from piccolo.utils.warnings import Level, colored_string, colored_warning
13
15
 
14
16
  if t.TYPE_CHECKING: # pragma: no cover
15
- from piccolo.query.base import Query
17
+ from piccolo.query.base import DDL, Query
16
18
 
17
19
 
18
20
  logger = logging.getLogger(__name__)
@@ -32,31 +34,76 @@ def validate_savepoint_name(savepoint_name: str) -> None:
32
34
  )
33
35
 
34
36
 
35
- class Batch:
36
- pass
37
+ class BaseBatch(metaclass=ABCMeta):
38
+ @abstractmethod
39
+ async def __aenter__(self: Self, *args, **kwargs) -> Self: ...
37
40
 
41
+ @abstractmethod
42
+ async def __aexit__(self, *args, **kwargs): ...
38
43
 
39
- TransactionClass = t.TypeVar("TransactionClass")
44
+ @abstractmethod
45
+ def __aiter__(self: Self) -> Self: ...
40
46
 
47
+ @abstractmethod
48
+ async def __anext__(self) -> t.List[t.Dict]: ...
41
49
 
42
- class Engine(t.Generic[TransactionClass], metaclass=ABCMeta):
43
50
 
44
- __slots__ = ("query_id",)
51
+ class BaseTransaction(metaclass=ABCMeta):
45
52
 
46
- def __init__(self):
47
- run_sync(self.check_version())
48
- run_sync(self.prep_database())
49
- self.query_id = 0
53
+ __slots__: t.Tuple[str, ...] = tuple()
50
54
 
51
- @property
52
55
  @abstractmethod
53
- def engine_type(self) -> str:
54
- pass
56
+ async def __aenter__(self, *args, **kwargs): ...
55
57
 
56
- @property
57
58
  @abstractmethod
58
- def min_version_number(self) -> float:
59
- pass
59
+ async def __aexit__(self, *args, **kwargs) -> bool: ...
60
+
61
+
62
+ class BaseAtomic(metaclass=ABCMeta):
63
+
64
+ __slots__: t.Tuple[str, ...] = tuple()
65
+
66
+ @abstractmethod
67
+ def add(self, *query: t.Union[Query, DDL]): ...
68
+
69
+ @abstractmethod
70
+ async def run(self): ...
71
+
72
+ @abstractmethod
73
+ def run_sync(self): ...
74
+
75
+ @abstractmethod
76
+ def __await__(self): ...
77
+
78
+
79
+ TransactionClass = t.TypeVar("TransactionClass", bound=BaseTransaction)
80
+
81
+
82
+ class Engine(t.Generic[TransactionClass], metaclass=ABCMeta):
83
+ __slots__ = (
84
+ "query_id",
85
+ "log_queries",
86
+ "log_responses",
87
+ "engine_type",
88
+ "min_version_number",
89
+ "current_transaction",
90
+ )
91
+
92
+ def __init__(
93
+ self,
94
+ engine_type: str,
95
+ min_version_number: t.Union[int, float],
96
+ log_queries: bool = False,
97
+ log_responses: bool = False,
98
+ ):
99
+ self.log_queries = log_queries
100
+ self.log_responses = log_responses
101
+ self.engine_type = engine_type
102
+ self.min_version_number = min_version_number
103
+
104
+ run_sync(self.check_version())
105
+ run_sync(self.prep_database())
106
+ self.query_id = 0
60
107
 
61
108
  @abstractmethod
62
109
  async def get_version(self) -> float:
@@ -76,11 +123,13 @@ class Engine(t.Generic[TransactionClass], metaclass=ABCMeta):
76
123
  query: Query,
77
124
  batch_size: int = 100,
78
125
  node: t.Optional[str] = None,
79
- ) -> Batch:
126
+ ) -> BaseBatch:
80
127
  pass
81
128
 
82
129
  @abstractmethod
83
- async def run_querystring(self, querystring: QueryString, in_pool: bool):
130
+ async def run_querystring(
131
+ self, querystring: QueryString, in_pool: bool = True
132
+ ):
84
133
  pass
85
134
 
86
135
  @abstractmethod
@@ -88,11 +137,11 @@ class Engine(t.Generic[TransactionClass], metaclass=ABCMeta):
88
137
  pass
89
138
 
90
139
  @abstractmethod
91
- def transaction(self):
140
+ def transaction(self, *args, **kwargs) -> TransactionClass:
92
141
  pass
93
142
 
94
143
  @abstractmethod
95
- def atomic(self):
144
+ def atomic(self) -> BaseAtomic:
96
145
  pass
97
146
 
98
147
  async def check_version(self):
@@ -16,9 +16,6 @@ class CockroachEngine(PostgresEngine):
16
16
  :class:`PostgresEngine <piccolo.engine.postgres.PostgresEngine>`.
17
17
  """
18
18
 
19
- engine_type = "cockroach"
20
- min_version_number = 0 # Doesn't seem to work with cockroach versioning.
21
-
22
19
  def __init__(
23
20
  self,
24
21
  config: t.Dict[str, t.Any],
@@ -34,6 +31,8 @@ class CockroachEngine(PostgresEngine):
34
31
  log_responses=log_responses,
35
32
  extra_nodes=extra_nodes,
36
33
  )
34
+ self.engine_type = "cockroach"
35
+ self.min_version_number = 0
37
36
 
38
37
  async def prep_database(self):
39
38
  try: