piccolo 1.9.0__py3-none-any.whl → 1.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- piccolo/__init__.py +1 -1
- piccolo/apps/fixtures/commands/load.py +1 -1
- piccolo/apps/migrations/auto/__init__.py +8 -0
- piccolo/apps/migrations/auto/migration_manager.py +2 -1
- piccolo/apps/migrations/commands/backwards.py +3 -1
- piccolo/apps/migrations/commands/base.py +1 -1
- piccolo/apps/migrations/commands/check.py +1 -1
- piccolo/apps/migrations/commands/clean.py +1 -1
- piccolo/apps/migrations/commands/forwards.py +3 -1
- piccolo/apps/migrations/commands/new.py +4 -2
- piccolo/apps/schema/commands/generate.py +2 -2
- piccolo/apps/shell/commands/run.py +1 -1
- piccolo/columns/column_types.py +28 -4
- piccolo/columns/defaults/base.py +1 -1
- piccolo/columns/defaults/date.py +9 -1
- piccolo/columns/defaults/interval.py +1 -0
- piccolo/columns/defaults/time.py +9 -1
- piccolo/columns/defaults/timestamp.py +1 -0
- piccolo/columns/defaults/uuid.py +1 -1
- piccolo/columns/m2m.py +7 -7
- piccolo/columns/operators/comparison.py +4 -0
- piccolo/conf/apps.py +9 -4
- piccolo/engine/base.py +69 -20
- piccolo/engine/cockroach.py +2 -3
- piccolo/engine/postgres.py +33 -19
- piccolo/engine/sqlite.py +27 -22
- piccolo/query/methods/create_index.py +1 -1
- piccolo/query/methods/drop_index.py +1 -1
- piccolo/query/methods/objects.py +7 -7
- piccolo/query/methods/select.py +13 -7
- piccolo/query/mixins.py +3 -10
- piccolo/schema.py +18 -11
- piccolo/table.py +22 -21
- piccolo/utils/encoding.py +5 -3
- {piccolo-1.9.0.dist-info → piccolo-1.10.0.dist-info}/METADATA +1 -1
- {piccolo-1.9.0.dist-info → piccolo-1.10.0.dist-info}/RECORD +47 -47
- tests/apps/migrations/auto/integration/test_migrations.py +1 -1
- tests/columns/test_array.py +28 -0
- tests/conf/test_apps.py +1 -1
- tests/engine/test_nested_transaction.py +2 -0
- tests/engine/test_transaction.py +1 -2
- tests/table/test_indexes.py +4 -2
- tests/utils/test_pydantic.py +70 -29
- {piccolo-1.9.0.dist-info → piccolo-1.10.0.dist-info}/LICENSE +0 -0
- {piccolo-1.9.0.dist-info → piccolo-1.10.0.dist-info}/WHEEL +0 -0
- {piccolo-1.9.0.dist-info → piccolo-1.10.0.dist-info}/entry_points.txt +0 -0
- {piccolo-1.9.0.dist-info → piccolo-1.10.0.dist-info}/top_level.txt +0 -0
piccolo/__init__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__VERSION__ = "1.9.0"
|
1
|
+
__VERSION__ = "1.10.0"
|
@@ -2,3 +2,11 @@ from .diffable_table import DiffableTable
|
|
2
2
|
from .migration_manager import MigrationManager
|
3
3
|
from .schema_differ import AlterStatements, SchemaDiffer
|
4
4
|
from .schema_snapshot import SchemaSnapshot
|
5
|
+
|
6
|
+
__all__ = [
|
7
|
+
"DiffableTable",
|
8
|
+
"MigrationManager",
|
9
|
+
"AlterStatements",
|
10
|
+
"SchemaDiffer",
|
11
|
+
"SchemaSnapshot",
|
12
|
+
]
|
@@ -261,7 +261,8 @@ class MigrationManager:
|
|
261
261
|
cleaned_params = deserialise_params(params=params)
|
262
262
|
column = column_class(**cleaned_params)
|
263
263
|
column._meta.name = column_name
|
264
|
-
|
264
|
+
if db_column_name:
|
265
|
+
column._meta.db_column_name = db_column_name
|
265
266
|
|
266
267
|
self.add_columns.append(
|
267
268
|
AddColumnClass(
|
@@ -32,7 +32,9 @@ class BackwardsMigrationManager(BaseMigrationManager):
|
|
32
32
|
|
33
33
|
async def run_migrations_backwards(self, app_config: AppConfig):
|
34
34
|
migration_modules: t.Dict[str, MigrationModule] = (
|
35
|
-
self.get_migration_modules(app_config.migrations_folder_path)
|
35
|
+
self.get_migration_modules(
|
36
|
+
app_config.resolved_migrations_folder_path
|
37
|
+
)
|
36
38
|
)
|
37
39
|
|
38
40
|
ran_migration_ids = await Migration.get_migrations_which_ran(
|
@@ -86,7 +86,7 @@ class BaseMigrationManager(Finder):
|
|
86
86
|
"""
|
87
87
|
migration_managers: t.List[MigrationManager] = []
|
88
88
|
|
89
|
-
migrations_folder = app_config.migrations_folder_path
|
89
|
+
migrations_folder = app_config.resolved_migrations_folder_path
|
90
90
|
|
91
91
|
migration_modules: t.Dict[str, MigrationModule] = (
|
92
92
|
self.get_migration_modules(migrations_folder)
|
@@ -36,7 +36,7 @@ class CheckMigrationManager(BaseMigrationManager):
|
|
36
36
|
continue
|
37
37
|
|
38
38
|
migration_modules = self.get_migration_modules(
|
39
|
-
app_config.migrations_folder_path
|
39
|
+
app_config.resolved_migrations_folder_path
|
40
40
|
)
|
41
41
|
ids = self.get_migration_ids(migration_modules)
|
42
42
|
for _id in ids:
|
@@ -20,7 +20,7 @@ class CleanMigrationManager(BaseMigrationManager):
|
|
20
20
|
app_config = self.get_app_config(app_name=self.app_name)
|
21
21
|
|
22
22
|
migration_module_dict = self.get_migration_modules(
|
23
|
-
folder_path=app_config.migrations_folder_path
|
23
|
+
folder_path=app_config.resolved_migrations_folder_path
|
24
24
|
)
|
25
25
|
|
26
26
|
# The migration IDs which are in migration modules.
|
@@ -33,7 +33,9 @@ class ForwardsMigrationManager(BaseMigrationManager):
|
|
33
33
|
)
|
34
34
|
|
35
35
|
migration_modules: t.Dict[str, MigrationModule] = (
|
36
|
-
self.get_migration_modules(app_config.migrations_folder_path)
|
36
|
+
self.get_migration_modules(
|
37
|
+
app_config.resolved_migrations_folder_path
|
38
|
+
)
|
37
39
|
)
|
38
40
|
|
39
41
|
ids = self.get_migration_ids(migration_modules)
|
@@ -98,7 +98,9 @@ def _generate_migration_meta(app_config: AppConfig) -> NewMigrationMeta:
|
|
98
98
|
|
99
99
|
filename = f"{cleaned_app_name}_{cleaned_id}"
|
100
100
|
|
101
|
-
path = os.path.join(app_config.migrations_folder_path, f"{filename}.py")
|
101
|
+
path = os.path.join(
|
102
|
+
app_config.resolved_migrations_folder_path, f"{filename}.py"
|
103
|
+
)
|
102
104
|
|
103
105
|
return NewMigrationMeta(
|
104
106
|
migration_id=_id, migration_filename=filename, migration_path=path
|
@@ -255,7 +257,7 @@ async def new(
|
|
255
257
|
|
256
258
|
app_config = Finder().get_app_config(app_name=app_name)
|
257
259
|
|
258
|
-
_create_migrations_folder(app_config.migrations_folder_path)
|
260
|
+
_create_migrations_folder(app_config.resolved_migrations_folder_path)
|
259
261
|
|
260
262
|
try:
|
261
263
|
await _create_new_migration(
|
@@ -313,7 +313,7 @@ COLUMN_TYPE_MAP_COCKROACH: t.Dict[str, t.Type[Column]] = {
|
|
313
313
|
**{"integer": BigInt, "json": JSONB},
|
314
314
|
}
|
315
315
|
|
316
|
-
COLUMN_DEFAULT_PARSER = {
|
316
|
+
COLUMN_DEFAULT_PARSER: t.Dict[t.Type[Column], t.Any] = {
|
317
317
|
BigInt: re.compile(r"^'?(?P<value>-?[0-9]\d*)'?(?:::bigint)?$"),
|
318
318
|
Boolean: re.compile(r"^(?P<value>true|false)$"),
|
319
319
|
Bytea: re.compile(r"'(?P<value>.*)'::bytea$"),
|
@@ -373,7 +373,7 @@ COLUMN_DEFAULT_PARSER = {
|
|
373
373
|
}
|
374
374
|
|
375
375
|
# Re-map for Cockroach compatibility.
|
376
|
-
COLUMN_DEFAULT_PARSER_COCKROACH = {
|
376
|
+
COLUMN_DEFAULT_PARSER_COCKROACH: t.Dict[t.Type[Column], t.Any] = {
|
377
377
|
**COLUMN_DEFAULT_PARSER,
|
378
378
|
BigInt: re.compile(r"^(?P<value>-?\d+)$"),
|
379
379
|
}
|
@@ -24,7 +24,7 @@ def start_ipython_shell(**tables: t.Type[Table]): # pragma: no cover
|
|
24
24
|
if table_class_name not in existing_global_names:
|
25
25
|
globals()[table_class_name] = table_class
|
26
26
|
|
27
|
-
IPython.embed(using=_asyncio_runner, colors="neutral")
|
27
|
+
IPython.embed(using=_asyncio_runner, colors="neutral") # type: ignore
|
28
28
|
|
29
29
|
|
30
30
|
def run() -> None:
|
piccolo/columns/column_types.py
CHANGED
@@ -57,7 +57,11 @@ from piccolo.columns.defaults.timestamptz import (
|
|
57
57
|
TimestamptzNow,
|
58
58
|
)
|
59
59
|
from piccolo.columns.defaults.uuid import UUID4, UUIDArg
|
60
|
-
from piccolo.columns.operators.comparison import ArrayAll, ArrayAny
|
60
|
+
from piccolo.columns.operators.comparison import (
|
61
|
+
ArrayAll,
|
62
|
+
ArrayAny,
|
63
|
+
ArrayNotAny,
|
64
|
+
)
|
61
65
|
from piccolo.columns.operators.string import Concat
|
62
66
|
from piccolo.columns.reference import LazyTableReference
|
63
67
|
from piccolo.querystring import QueryString
|
@@ -1952,7 +1956,9 @@ class ForeignKey(Column, t.Generic[ReferencedTable]):
|
|
1952
1956
|
|
1953
1957
|
if is_table_class:
|
1954
1958
|
# Record the reverse relationship on the target table.
|
1955
|
-
|
1959
|
+
t.cast(
|
1960
|
+
t.Type[Table], references
|
1961
|
+
)._meta._foreign_key_references.append(self)
|
1956
1962
|
|
1957
1963
|
# Allow columns on the referenced table to be accessed via
|
1958
1964
|
# auto completion.
|
@@ -2670,6 +2676,24 @@ class Array(Column):
|
|
2670
2676
|
else:
|
2671
2677
|
raise ValueError("Unrecognised engine type")
|
2672
2678
|
|
2679
|
+
def not_any(self, value: t.Any) -> Where:
|
2680
|
+
"""
|
2681
|
+
Check if the given value isn't in the array.
|
2682
|
+
|
2683
|
+
.. code-block:: python
|
2684
|
+
|
2685
|
+
>>> await Ticket.select().where(Ticket.seat_numbers.not_any(510))
|
2686
|
+
|
2687
|
+
"""
|
2688
|
+
engine_type = self._meta.engine_type
|
2689
|
+
|
2690
|
+
if engine_type in ("postgres", "cockroach"):
|
2691
|
+
return Where(column=self, value=value, operator=ArrayNotAny)
|
2692
|
+
elif engine_type == "sqlite":
|
2693
|
+
return self.not_like(f"%{value}%")
|
2694
|
+
else:
|
2695
|
+
raise ValueError("Unrecognised engine type")
|
2696
|
+
|
2673
2697
|
def all(self, value: t.Any) -> Where:
|
2674
2698
|
"""
|
2675
2699
|
Check if all of the items in the array match the given value.
|
@@ -2688,7 +2712,7 @@ class Array(Column):
|
|
2688
2712
|
else:
|
2689
2713
|
raise ValueError("Unrecognised engine type")
|
2690
2714
|
|
2691
|
-
def cat(self, value: t.List[t.Any]) -> QueryString:
|
2715
|
+
def cat(self, value: t.Union[t.Any, t.List[t.Any]]) -> QueryString:
|
2692
2716
|
"""
|
2693
2717
|
Used in an ``update`` query to append items to an array.
|
2694
2718
|
|
@@ -2719,7 +2743,7 @@ class Array(Column):
|
|
2719
2743
|
db_column_name = self._meta.db_column_name
|
2720
2744
|
return QueryString(f'array_cat("{db_column_name}", {{}})', value)
|
2721
2745
|
|
2722
|
-
def __add__(self, value: t.List[t.Any]) -> QueryString:
|
2746
|
+
def __add__(self, value: t.Union[t.Any, t.List[t.Any]]) -> QueryString:
|
2723
2747
|
return self.cat(value)
|
2724
2748
|
|
2725
2749
|
###########################################################################
|
piccolo/columns/defaults/base.py
CHANGED
piccolo/columns/defaults/date.py
CHANGED
@@ -102,7 +102,15 @@ class DateCustom(Default):
|
|
102
102
|
|
103
103
|
|
104
104
|
# Might add an enum back which encapsulates all of the options.
|
105
|
-
DateArg = t.Union[
|
105
|
+
DateArg = t.Union[
|
106
|
+
DateOffset,
|
107
|
+
DateCustom,
|
108
|
+
DateNow,
|
109
|
+
Enum,
|
110
|
+
None,
|
111
|
+
datetime.date,
|
112
|
+
t.Callable[[], datetime.date],
|
113
|
+
]
|
106
114
|
|
107
115
|
|
108
116
|
__all__ = ["DateArg", "DateOffset", "DateCustom", "DateNow"]
|
piccolo/columns/defaults/time.py
CHANGED
@@ -89,7 +89,15 @@ class TimeCustom(Default):
|
|
89
89
|
)
|
90
90
|
|
91
91
|
|
92
|
-
TimeArg = t.Union[
|
92
|
+
TimeArg = t.Union[
|
93
|
+
TimeCustom,
|
94
|
+
TimeNow,
|
95
|
+
TimeOffset,
|
96
|
+
Enum,
|
97
|
+
None,
|
98
|
+
datetime.time,
|
99
|
+
t.Callable[[], datetime.time],
|
100
|
+
]
|
93
101
|
|
94
102
|
|
95
103
|
__all__ = ["TimeArg", "TimeCustom", "TimeNow", "TimeOffset"]
|
piccolo/columns/defaults/uuid.py
CHANGED
piccolo/columns/m2m.py
CHANGED
@@ -131,6 +131,7 @@ class M2MSelect(Selectable):
|
|
131
131
|
if len(self.columns) > 1 or not self.serialisation_safe:
|
132
132
|
column_name = table_2_pk_name
|
133
133
|
else:
|
134
|
+
assert len(self.columns) > 0
|
134
135
|
column_name = self.columns[0]._meta.db_column_name
|
135
136
|
|
136
137
|
return QueryString(
|
@@ -256,15 +257,14 @@ class M2MMeta:
|
|
256
257
|
|
257
258
|
@dataclass
|
258
259
|
class M2MAddRelated:
|
259
|
-
|
260
260
|
target_row: Table
|
261
261
|
m2m: M2M
|
262
262
|
rows: t.Sequence[Table]
|
263
263
|
extra_column_values: t.Dict[t.Union[Column, str], t.Any]
|
264
264
|
|
265
|
-
|
266
|
-
|
267
|
-
|
265
|
+
@property
|
266
|
+
def resolved_extra_column_values(self) -> t.Dict[str, t.Any]:
|
267
|
+
return {
|
268
268
|
i._meta.name if isinstance(i, Column) else i: j
|
269
269
|
for i, j in self.extra_column_values.items()
|
270
270
|
}
|
@@ -281,7 +281,9 @@ class M2MAddRelated:
|
|
281
281
|
joining_table_rows = []
|
282
282
|
|
283
283
|
for row in rows:
|
284
|
-
joining_table_row = joining_table(
|
284
|
+
joining_table_row = joining_table(
|
285
|
+
**self.resolved_extra_column_values
|
286
|
+
)
|
285
287
|
setattr(
|
286
288
|
joining_table_row,
|
287
289
|
self.m2m._meta.primary_foreign_key._meta.name,
|
@@ -323,7 +325,6 @@ class M2MAddRelated:
|
|
323
325
|
|
324
326
|
@dataclass
|
325
327
|
class M2MRemoveRelated:
|
326
|
-
|
327
328
|
target_row: Table
|
328
329
|
m2m: M2M
|
329
330
|
rows: t.Sequence[Table]
|
@@ -363,7 +364,6 @@ class M2MRemoveRelated:
|
|
363
364
|
|
364
365
|
@dataclass
|
365
366
|
class M2MGetRelated:
|
366
|
-
|
367
367
|
row: Table
|
368
368
|
m2m: M2M
|
369
369
|
|
piccolo/conf/apps.py
CHANGED
@@ -157,17 +157,22 @@ class AppConfig:
|
|
157
157
|
"""
|
158
158
|
|
159
159
|
app_name: str
|
160
|
-
migrations_folder_path: str
|
160
|
+
migrations_folder_path: t.Union[str, pathlib.Path]
|
161
161
|
table_classes: t.List[t.Type[Table]] = field(default_factory=list)
|
162
162
|
migration_dependencies: t.List[str] = field(default_factory=list)
|
163
163
|
commands: t.List[t.Union[t.Callable, Command]] = field(
|
164
164
|
default_factory=list
|
165
165
|
)
|
166
166
|
|
167
|
-
|
168
|
-
|
169
|
-
|
167
|
+
@property
|
168
|
+
def resolved_migrations_folder_path(self) -> str:
|
169
|
+
return (
|
170
|
+
str(self.migrations_folder_path)
|
171
|
+
if isinstance(self.migrations_folder_path, pathlib.Path)
|
172
|
+
else self.migrations_folder_path
|
173
|
+
)
|
170
174
|
|
175
|
+
def __post_init__(self) -> None:
|
171
176
|
self._migration_dependency_app_configs: t.Optional[
|
172
177
|
t.List[AppConfig]
|
173
178
|
] = None
|
piccolo/engine/base.py
CHANGED
@@ -7,12 +7,14 @@ import string
|
|
7
7
|
import typing as t
|
8
8
|
from abc import ABCMeta, abstractmethod
|
9
9
|
|
10
|
+
from typing_extensions import Self
|
11
|
+
|
10
12
|
from piccolo.querystring import QueryString
|
11
13
|
from piccolo.utils.sync import run_sync
|
12
14
|
from piccolo.utils.warnings import Level, colored_string, colored_warning
|
13
15
|
|
14
16
|
if t.TYPE_CHECKING: # pragma: no cover
|
15
|
-
from piccolo.query.base import Query
|
17
|
+
from piccolo.query.base import DDL, Query
|
16
18
|
|
17
19
|
|
18
20
|
logger = logging.getLogger(__name__)
|
@@ -32,31 +34,76 @@ def validate_savepoint_name(savepoint_name: str) -> None:
|
|
32
34
|
)
|
33
35
|
|
34
36
|
|
35
|
-
class Batch:
|
36
|
-
|
37
|
+
class BaseBatch(metaclass=ABCMeta):
|
38
|
+
@abstractmethod
|
39
|
+
async def __aenter__(self: Self, *args, **kwargs) -> Self: ...
|
37
40
|
|
41
|
+
@abstractmethod
|
42
|
+
async def __aexit__(self, *args, **kwargs): ...
|
38
43
|
|
39
|
-
|
44
|
+
@abstractmethod
|
45
|
+
def __aiter__(self: Self) -> Self: ...
|
40
46
|
|
47
|
+
@abstractmethod
|
48
|
+
async def __anext__(self) -> t.List[t.Dict]: ...
|
41
49
|
|
42
|
-
class Engine(t.Generic[TransactionClass], metaclass=ABCMeta):
|
43
50
|
|
44
|
-
|
51
|
+
class BaseTransaction(metaclass=ABCMeta):
|
45
52
|
|
46
|
-
|
47
|
-
run_sync(self.check_version())
|
48
|
-
run_sync(self.prep_database())
|
49
|
-
self.query_id = 0
|
53
|
+
__slots__: t.Tuple[str, ...] = tuple()
|
50
54
|
|
51
|
-
@property
|
52
55
|
@abstractmethod
|
53
|
-
def
|
54
|
-
pass
|
56
|
+
async def __aenter__(self, *args, **kwargs): ...
|
55
57
|
|
56
|
-
@property
|
57
58
|
@abstractmethod
|
58
|
-
def
|
59
|
-
|
59
|
+
async def __aexit__(self, *args, **kwargs) -> bool: ...
|
60
|
+
|
61
|
+
|
62
|
+
class BaseAtomic(metaclass=ABCMeta):
|
63
|
+
|
64
|
+
__slots__: t.Tuple[str, ...] = tuple()
|
65
|
+
|
66
|
+
@abstractmethod
|
67
|
+
def add(self, *query: t.Union[Query, DDL]): ...
|
68
|
+
|
69
|
+
@abstractmethod
|
70
|
+
async def run(self): ...
|
71
|
+
|
72
|
+
@abstractmethod
|
73
|
+
def run_sync(self): ...
|
74
|
+
|
75
|
+
@abstractmethod
|
76
|
+
def __await__(self): ...
|
77
|
+
|
78
|
+
|
79
|
+
TransactionClass = t.TypeVar("TransactionClass", bound=BaseTransaction)
|
80
|
+
|
81
|
+
|
82
|
+
class Engine(t.Generic[TransactionClass], metaclass=ABCMeta):
|
83
|
+
__slots__ = (
|
84
|
+
"query_id",
|
85
|
+
"log_queries",
|
86
|
+
"log_responses",
|
87
|
+
"engine_type",
|
88
|
+
"min_version_number",
|
89
|
+
"current_transaction",
|
90
|
+
)
|
91
|
+
|
92
|
+
def __init__(
|
93
|
+
self,
|
94
|
+
engine_type: str,
|
95
|
+
min_version_number: t.Union[int, float],
|
96
|
+
log_queries: bool = False,
|
97
|
+
log_responses: bool = False,
|
98
|
+
):
|
99
|
+
self.log_queries = log_queries
|
100
|
+
self.log_responses = log_responses
|
101
|
+
self.engine_type = engine_type
|
102
|
+
self.min_version_number = min_version_number
|
103
|
+
|
104
|
+
run_sync(self.check_version())
|
105
|
+
run_sync(self.prep_database())
|
106
|
+
self.query_id = 0
|
60
107
|
|
61
108
|
@abstractmethod
|
62
109
|
async def get_version(self) -> float:
|
@@ -76,11 +123,13 @@ class Engine(t.Generic[TransactionClass], metaclass=ABCMeta):
|
|
76
123
|
query: Query,
|
77
124
|
batch_size: int = 100,
|
78
125
|
node: t.Optional[str] = None,
|
79
|
-
) -> Batch:
|
126
|
+
) -> BaseBatch:
|
80
127
|
pass
|
81
128
|
|
82
129
|
@abstractmethod
|
83
|
-
async def run_querystring(
|
130
|
+
async def run_querystring(
|
131
|
+
self, querystring: QueryString, in_pool: bool = True
|
132
|
+
):
|
84
133
|
pass
|
85
134
|
|
86
135
|
@abstractmethod
|
@@ -88,11 +137,11 @@ class Engine(t.Generic[TransactionClass], metaclass=ABCMeta):
|
|
88
137
|
pass
|
89
138
|
|
90
139
|
@abstractmethod
|
91
|
-
def transaction(self):
|
140
|
+
def transaction(self, *args, **kwargs) -> TransactionClass:
|
92
141
|
pass
|
93
142
|
|
94
143
|
@abstractmethod
|
95
|
-
def atomic(self):
|
144
|
+
def atomic(self) -> BaseAtomic:
|
96
145
|
pass
|
97
146
|
|
98
147
|
async def check_version(self):
|
piccolo/engine/cockroach.py
CHANGED
@@ -16,9 +16,6 @@ class CockroachEngine(PostgresEngine):
|
|
16
16
|
:class:`PostgresEngine <piccolo.engine.postgres.PostgresEngine>`.
|
17
17
|
"""
|
18
18
|
|
19
|
-
engine_type = "cockroach"
|
20
|
-
min_version_number = 0 # Doesn't seem to work with cockroach versioning.
|
21
|
-
|
22
19
|
def __init__(
|
23
20
|
self,
|
24
21
|
config: t.Dict[str, t.Any],
|
@@ -34,6 +31,8 @@ class CockroachEngine(PostgresEngine):
|
|
34
31
|
log_responses=log_responses,
|
35
32
|
extra_nodes=extra_nodes,
|
36
33
|
)
|
34
|
+
self.engine_type = "cockroach"
|
35
|
+
self.min_version_number = 0
|
37
36
|
|
38
37
|
async def prep_database(self):
|
39
38
|
try:
|
piccolo/engine/postgres.py
CHANGED
@@ -4,7 +4,15 @@ import contextvars
|
|
4
4
|
import typing as t
|
5
5
|
from dataclasses import dataclass
|
6
6
|
|
7
|
-
from piccolo.engine.base import Batch, Engine, validate_savepoint_name
|
7
|
+
from typing_extensions import Self
|
8
|
+
|
9
|
+
from piccolo.engine.base import (
|
10
|
+
BaseAtomic,
|
11
|
+
BaseBatch,
|
12
|
+
BaseTransaction,
|
13
|
+
Engine,
|
14
|
+
validate_savepoint_name,
|
15
|
+
)
|
8
16
|
from piccolo.engine.exceptions import TransactionError
|
9
17
|
from piccolo.query.base import DDL, Query
|
10
18
|
from piccolo.querystring import QueryString
|
@@ -18,16 +26,17 @@ if t.TYPE_CHECKING: # pragma: no cover
|
|
18
26
|
from asyncpg.connection import Connection
|
19
27
|
from asyncpg.cursor import Cursor
|
20
28
|
from asyncpg.pool import Pool
|
29
|
+
from asyncpg.transaction import Transaction
|
21
30
|
|
22
31
|
|
23
32
|
@dataclass
|
24
|
-
class AsyncBatch(Batch):
|
33
|
+
class AsyncBatch(BaseBatch):
|
25
34
|
connection: Connection
|
26
35
|
query: Query
|
27
36
|
batch_size: int
|
28
37
|
|
29
38
|
# Set internally
|
30
|
-
_transaction = None
|
39
|
+
_transaction: t.Optional[Transaction] = None
|
31
40
|
_cursor: t.Optional[Cursor] = None
|
32
41
|
|
33
42
|
@property
|
@@ -36,20 +45,26 @@ class AsyncBatch(Batch):
|
|
36
45
|
raise ValueError("_cursor not set")
|
37
46
|
return self._cursor
|
38
47
|
|
48
|
+
@property
|
49
|
+
def transaction(self) -> Transaction:
|
50
|
+
if not self._transaction:
|
51
|
+
raise ValueError("The transaction can't be found.")
|
52
|
+
return self._transaction
|
53
|
+
|
39
54
|
async def next(self) -> t.List[t.Dict]:
|
40
55
|
data = await self.cursor.fetch(self.batch_size)
|
41
56
|
return await self.query._process_results(data)
|
42
57
|
|
43
|
-
def __aiter__(self):
|
58
|
+
def __aiter__(self: Self) -> Self:
|
44
59
|
return self
|
45
60
|
|
46
|
-
async def __anext__(self):
|
61
|
+
async def __anext__(self) -> t.List[t.Dict]:
|
47
62
|
response = await self.next()
|
48
63
|
if response == []:
|
49
64
|
raise StopAsyncIteration()
|
50
65
|
return response
|
51
66
|
|
52
|
-
async def __aenter__(self):
|
67
|
+
async def __aenter__(self: Self) -> Self:
|
53
68
|
self._transaction = self.connection.transaction()
|
54
69
|
await self._transaction.start()
|
55
70
|
querystring = self.query.querystrings[0]
|
@@ -60,9 +75,9 @@ class AsyncBatch(Batch):
|
|
60
75
|
|
61
76
|
async def __aexit__(self, exception_type, exception, traceback):
|
62
77
|
if exception:
|
63
|
-
await self._transaction.rollback()
|
78
|
+
await self.transaction.rollback()
|
64
79
|
else:
|
65
|
-
await self._transaction.commit()
|
80
|
+
await self.transaction.commit()
|
66
81
|
|
67
82
|
await self.connection.close()
|
68
83
|
|
@@ -72,7 +87,7 @@ class AsyncBatch(Batch):
|
|
72
87
|
###############################################################################
|
73
88
|
|
74
89
|
|
75
|
-
class Atomic:
|
90
|
+
class Atomic(BaseAtomic):
|
76
91
|
"""
|
77
92
|
This is useful if you want to build up a transaction programatically, by
|
78
93
|
adding queries to it.
|
@@ -140,7 +155,7 @@ class Savepoint:
|
|
140
155
|
)
|
141
156
|
|
142
157
|
|
143
|
-
class PostgresTransaction:
|
158
|
+
class PostgresTransaction(BaseTransaction):
|
144
159
|
"""
|
145
160
|
Used for wrapping queries in a transaction, using a context manager.
|
146
161
|
Currently it's async only.
|
@@ -243,7 +258,7 @@ class PostgresTransaction:
|
|
243
258
|
|
244
259
|
###########################################################################
|
245
260
|
|
246
|
-
async def __aexit__(self, exception_type, exception, traceback):
|
261
|
+
async def __aexit__(self, exception_type, exception, traceback) -> bool:
|
247
262
|
if self._parent:
|
248
263
|
return exception is None
|
249
264
|
|
@@ -269,7 +284,7 @@ class PostgresTransaction:
|
|
269
284
|
###############################################################################
|
270
285
|
|
271
286
|
|
272
|
-
class PostgresEngine(Engine[t.Optional[PostgresTransaction]]):
|
287
|
+
class PostgresEngine(Engine[PostgresTransaction]):
|
273
288
|
"""
|
274
289
|
Used to connect to PostgreSQL.
|
275
290
|
|
@@ -331,16 +346,10 @@ class PostgresEngine(Engine[t.Optional[PostgresTransaction]]):
|
|
331
346
|
__slots__ = (
|
332
347
|
"config",
|
333
348
|
"extensions",
|
334
|
-
"log_queries",
|
335
|
-
"log_responses",
|
336
349
|
"extra_nodes",
|
337
350
|
"pool",
|
338
|
-
"current_transaction",
|
339
351
|
)
|
340
352
|
|
341
|
-
engine_type = "postgres"
|
342
|
-
min_version_number = 10
|
343
|
-
|
344
353
|
def __init__(
|
345
354
|
self,
|
346
355
|
config: t.Dict[str, t.Any],
|
@@ -362,7 +371,12 @@ class PostgresEngine(Engine[t.Optional[PostgresTransaction]]):
|
|
362
371
|
self.current_transaction = contextvars.ContextVar(
|
363
372
|
f"pg_current_transaction_{database_name}", default=None
|
364
373
|
)
|
365
|
-
super().__init__(
|
374
|
+
super().__init__(
|
375
|
+
engine_type="postgres",
|
376
|
+
log_queries=log_queries,
|
377
|
+
log_responses=log_responses,
|
378
|
+
min_version_number=10,
|
379
|
+
)
|
366
380
|
|
367
381
|
@staticmethod
|
368
382
|
def _parse_raw_version_string(version_string: str) -> float:
|