spinta 0.2.dev24__py3-none-any.whl → 0.2.dev26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spinta/backends/components.py +8 -1
- spinta/backends/constants.py +10 -0
- spinta/backends/helpers.py +27 -20
- spinta/backends/postgresql/commands/load.py +2 -1
- spinta/backends/postgresql/commands/wait.py +2 -2
- spinta/backends/postgresql/commands/wipe.py +1 -1
- spinta/backends/postgresql/components.py +7 -1
- spinta/backends/postgresql/helpers/migrate/actions.py +82 -3
- spinta/backends/postgresql/helpers/migrate/citus.py +383 -0
- spinta/backends/postgresql/sqlalchemy.py +19 -0
- spinta/cli/admin.py +2 -0
- spinta/cli/comment.py +14 -8
- spinta/cli/helpers/admin/components.py +3 -0
- spinta/cli/helpers/admin/registry.py +14 -0
- spinta/cli/helpers/admin/scripts/add_local_ids.py +80 -0
- spinta/cli/helpers/admin/scripts/citus_shard.py +126 -0
- spinta/cli/helpers/admin/scripts/remove_local_ids.py +55 -0
- spinta/cli/helpers/upgrade/scripts/backends/postgresql/comments.py +62 -26
- spinta/cli/inspect.py +3 -0
- spinta/cli/uncomment.py +80 -18
- spinta/components.py +7 -1
- spinta/config.py +3 -0
- spinta/config.yml +12 -1
- spinta/datasets/backends/dataframe/commands/read.py +5 -1
- spinta/datasets/backends/dataframe/ufuncs/query/components.py +26 -2
- spinta/datasets/backends/dataframe/ufuncs/query/ufuncs.py +33 -26
- spinta/datasets/backends/sql/commands/cast.py +89 -19
- spinta/datasets/backends/sql/commands/read.py +29 -79
- spinta/datasets/backends/sql/helpers.py +61 -0
- spinta/dimensions/scope/components.py +20 -2
- spinta/dimensions/scope/helpers.py +28 -2
- spinta/dimensions/scope/ufuncs.py +51 -0
- spinta/exceptions.py +17 -0
- spinta/manifests/commands/link.py +2 -0
- spinta/manifests/sql/helpers.py +1 -1
- spinta/manifests/tabular/helpers.py +2 -2
- spinta/testing/citus.py +96 -0
- spinta/testing/cli.py +13 -0
- spinta/testing/pytest.py +38 -24
- spinta/types/config.py +13 -1
- spinta/types/helpers.py +53 -3
- spinta/types/model.py +67 -4
- spinta/urlparams.py +85 -2
- spinta/utils/enums.py +17 -8
- spinta/utils/url.py +10 -0
- {spinta-0.2.dev24.dist-info → spinta-0.2.dev26.dist-info}/METADATA +1 -1
- {spinta-0.2.dev24.dist-info → spinta-0.2.dev26.dist-info}/RECORD +50 -44
- {spinta-0.2.dev24.dist-info → spinta-0.2.dev26.dist-info}/WHEEL +0 -0
- {spinta-0.2.dev24.dist-info → spinta-0.2.dev26.dist-info}/entry_points.txt +0 -0
- {spinta-0.2.dev24.dist-info → spinta-0.2.dev26.dist-info}/licenses/LICENSE +0 -0
spinta/backends/components.py
CHANGED
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import contextlib
|
|
4
|
+
import dataclasses
|
|
4
5
|
from typing import Any, Type
|
|
5
6
|
from typing import Dict
|
|
6
7
|
from typing import Optional
|
|
7
8
|
from typing import Set
|
|
8
9
|
|
|
9
|
-
from spinta.backends.constants import BackendOrigin, BackendFeatures
|
|
10
|
+
from spinta.backends.constants import BackendOrigin, BackendFeatures, DistributionType
|
|
10
11
|
from spinta.core.ufuncs import Env
|
|
11
12
|
from spinta.ufuncs.resultbuilder.components import ResultBuilder
|
|
12
13
|
|
|
@@ -58,3 +59,9 @@ class Backend:
|
|
|
58
59
|
|
|
59
60
|
|
|
60
61
|
SelectTree = Optional[Dict[str, "SelectTree"]]
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclasses.dataclass
|
|
65
|
+
class DistributionStrategy:
|
|
66
|
+
distribution_type: DistributionType
|
|
67
|
+
property: str | None = None
|
spinta/backends/constants.py
CHANGED
|
@@ -35,3 +35,13 @@ class BackendFeatures(enum.Enum):
|
|
|
35
35
|
|
|
36
36
|
# Backend supports
|
|
37
37
|
EXPAND = "EXPAND"
|
|
38
|
+
|
|
39
|
+
# Backend supports sharding
|
|
40
|
+
DISTRIBUTE = "DISTRIBUTE"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class DistributionType(enum.Enum):
|
|
44
|
+
SCHEMA = "schema"
|
|
45
|
+
TABLE = "table"
|
|
46
|
+
COPY = "copy"
|
|
47
|
+
UNDISTRIBUTED = "undistributed"
|
spinta/backends/helpers.py
CHANGED
|
@@ -91,7 +91,7 @@ def get_select_tree(
|
|
|
91
91
|
select = list(select.keys())
|
|
92
92
|
|
|
93
93
|
select = _apply_always_show_id(context, action, select)
|
|
94
|
-
if select is None and action in (Action.GETALL, Action.SEARCH):
|
|
94
|
+
if select is None and action in (Action.GETALL, Action.SEARCH, Action.GETONE):
|
|
95
95
|
# If select is not given, select everything.
|
|
96
96
|
select = {"*": {}}
|
|
97
97
|
return flat_select_to_nested(select)
|
|
@@ -344,7 +344,7 @@ def get_ns_reserved_props(action: Action) -> list[str]:
|
|
|
344
344
|
return []
|
|
345
345
|
|
|
346
346
|
|
|
347
|
-
@dataclasses.dataclass
|
|
347
|
+
@dataclasses.dataclass(frozen=True)
|
|
348
348
|
class TableIdentifier:
|
|
349
349
|
"""
|
|
350
350
|
Represents a table identifier across logical (app) and PostgreSQL layers.
|
|
@@ -391,35 +391,42 @@ class TableIdentifier:
|
|
|
391
391
|
table_arg: str | None = dataclasses.field(default=None)
|
|
392
392
|
default_pg_schema: str | None = dataclasses.field(default=None)
|
|
393
393
|
|
|
394
|
-
logical_name: str = dataclasses.field(init=False)
|
|
394
|
+
logical_name: str = dataclasses.field(init=False, compare=False)
|
|
395
395
|
# Name with namespace connected with '/', like it is used with Model class
|
|
396
|
-
logical_qualified_name: str = dataclasses.field(init=False)
|
|
396
|
+
logical_qualified_name: str = dataclasses.field(init=False, compare=False)
|
|
397
397
|
|
|
398
|
-
pg_table_name: str = dataclasses.field(init=False)
|
|
399
|
-
pg_schema_name: str | None = dataclasses.field(init=False)
|
|
398
|
+
pg_table_name: str = dataclasses.field(init=False, compare=False)
|
|
399
|
+
pg_schema_name: str | None = dataclasses.field(init=False, compare=False)
|
|
400
400
|
# Used for hashed schema and table names
|
|
401
|
-
pg_qualified_name: str = dataclasses.field(init=False)
|
|
401
|
+
pg_qualified_name: str = dataclasses.field(init=False, compare=False)
|
|
402
402
|
# Escaped qualified name, used for queries
|
|
403
|
-
pg_escaped_qualified_name: str = dataclasses.field(init=False)
|
|
403
|
+
pg_escaped_qualified_name: str = dataclasses.field(init=False, compare=False)
|
|
404
404
|
|
|
405
405
|
def __post_init__(self):
|
|
406
|
-
|
|
406
|
+
logical_name = self.base_name + self.table_type.value
|
|
407
407
|
if self.table_arg:
|
|
408
|
-
|
|
408
|
+
logical_name += "/" + self.table_arg
|
|
409
409
|
|
|
410
|
-
|
|
410
|
+
logical_qualified_name = f"{self.schema}/{logical_name}" if self.schema else logical_name
|
|
411
411
|
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
if self.pg_schema_name
|
|
420
|
-
else pg_identifier_preparer.quote(self.pg_table_name)
|
|
412
|
+
pg_table_name = get_pg_name(logical_name)
|
|
413
|
+
pg_schema_name = get_pg_name(self.schema) if self.schema else self.default_pg_schema
|
|
414
|
+
pg_qualified_name = f"{pg_schema_name}.{pg_table_name}" if pg_schema_name else pg_table_name
|
|
415
|
+
pg_escaped_qualified_name = (
|
|
416
|
+
f"{pg_identifier_preparer.quote(pg_schema_name)}.{pg_identifier_preparer.quote(pg_table_name)}"
|
|
417
|
+
if pg_schema_name
|
|
418
|
+
else pg_identifier_preparer.quote(pg_table_name)
|
|
421
419
|
)
|
|
422
420
|
|
|
421
|
+
# This is needed because we want to make this dataclass hashable (frozen=True, does that)
|
|
422
|
+
# But because it becomes immutable, we need to set all the attributes manually (the same way dataclass __init__ does).
|
|
423
|
+
object.__setattr__(self, "logical_name", logical_name)
|
|
424
|
+
object.__setattr__(self, "logical_qualified_name", logical_qualified_name)
|
|
425
|
+
object.__setattr__(self, "pg_table_name", pg_table_name)
|
|
426
|
+
object.__setattr__(self, "pg_schema_name", pg_schema_name)
|
|
427
|
+
object.__setattr__(self, "pg_qualified_name", pg_qualified_name)
|
|
428
|
+
object.__setattr__(self, "pg_escaped_qualified_name", pg_escaped_qualified_name)
|
|
429
|
+
|
|
423
430
|
def change_table_type(self, new_type: TableType, table_arg: str | None = None) -> "TableIdentifier":
|
|
424
431
|
return dataclasses.replace(self, table_type=new_type, table_arg=table_arg)
|
|
425
432
|
|
|
@@ -5,6 +5,7 @@ import sqlalchemy as sa
|
|
|
5
5
|
from spinta import commands
|
|
6
6
|
from spinta.backends.postgresql.components import PostgreSQL
|
|
7
7
|
from spinta.backends.postgresql.helpers.name import PG_NAMING_CONVENTION
|
|
8
|
+
from spinta.backends.postgresql.sqlalchemy import create_postgresql_engine
|
|
8
9
|
from spinta.components import Context
|
|
9
10
|
from spinta.utils.sqlalchemy import get_metadata_naming_convention
|
|
10
11
|
|
|
@@ -12,7 +13,7 @@ from spinta.utils.sqlalchemy import get_metadata_naming_convention
|
|
|
12
13
|
@commands.load.register(Context, PostgreSQL, dict)
|
|
13
14
|
def load(context: Context, backend: PostgreSQL, config: Dict[str, Any]):
|
|
14
15
|
backend.dsn = config["dsn"]
|
|
15
|
-
backend.engine =
|
|
16
|
+
backend.engine = create_postgresql_engine(backend.dsn, echo=False)
|
|
16
17
|
backend.schema = sa.MetaData(backend.engine, naming_convention=get_metadata_naming_convention(PG_NAMING_CONVENTION))
|
|
17
18
|
backend.tables = {}
|
|
18
19
|
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
import sqlalchemy as sa
|
|
2
1
|
import sqlalchemy.exc
|
|
3
2
|
|
|
4
3
|
from spinta import commands
|
|
4
|
+
from spinta.backends.postgresql.sqlalchemy import create_postgresql_engine
|
|
5
5
|
from spinta.components import Context
|
|
6
6
|
from spinta.backends.postgresql.components import PostgreSQL
|
|
7
7
|
|
|
@@ -10,7 +10,7 @@ from spinta.backends.postgresql.components import PostgreSQL
|
|
|
10
10
|
def wait(context: Context, backend: PostgreSQL, *, fail: bool = False) -> bool:
|
|
11
11
|
rc = context.get("rc")
|
|
12
12
|
dsn = rc.get("backends", backend.name, "dsn", required=True)
|
|
13
|
-
engine =
|
|
13
|
+
engine = create_postgresql_engine(dsn, connect_args={"connect_timeout": 0})
|
|
14
14
|
try:
|
|
15
15
|
conn = engine.connect()
|
|
16
16
|
except sqlalchemy.exc.OperationalError:
|
|
@@ -48,7 +48,7 @@ def wipe(context: Context, model: Model, backend: PostgreSQL):
|
|
|
48
48
|
if changelog_table_identifier.pg_schema_name
|
|
49
49
|
else f'"{seqname}"'
|
|
50
50
|
)
|
|
51
|
-
connection.execute(
|
|
51
|
+
connection.execute(sa.func.setval(seq_escaped_named, 1, False))
|
|
52
52
|
|
|
53
53
|
# Delete data table
|
|
54
54
|
table = backend.get_table(model)
|
|
@@ -26,7 +26,13 @@ class PostgreSQL(Backend):
|
|
|
26
26
|
},
|
|
27
27
|
}
|
|
28
28
|
|
|
29
|
-
features = {
|
|
29
|
+
features = {
|
|
30
|
+
BackendFeatures.FILE_BLOCKS,
|
|
31
|
+
BackendFeatures.WRITE,
|
|
32
|
+
BackendFeatures.EXPAND,
|
|
33
|
+
BackendFeatures.PAGINATION,
|
|
34
|
+
BackendFeatures.DISTRIBUTE,
|
|
35
|
+
}
|
|
30
36
|
|
|
31
37
|
engine: Engine = None
|
|
32
38
|
schema: sa.MetaData = None
|
|
@@ -3,7 +3,7 @@ from sqlalchemy.dialects.postgresql import UUID
|
|
|
3
3
|
from sqlalchemy.dialects import postgresql
|
|
4
4
|
|
|
5
5
|
import sqlalchemy as sa
|
|
6
|
-
from typing import TYPE_CHECKING
|
|
6
|
+
from typing import TYPE_CHECKING, Generator
|
|
7
7
|
|
|
8
8
|
from spinta.backends.helpers import TableIdentifier
|
|
9
9
|
from spinta.backends.postgresql.helpers.name import name_changed
|
|
@@ -91,6 +91,19 @@ class RenameTableMigrationAction(MigrationAction):
|
|
|
91
91
|
)
|
|
92
92
|
|
|
93
93
|
|
|
94
|
+
class SetTableCommentMigrationAction(MigrationAction):
|
|
95
|
+
def __init__(self, table_identifier: TableIdentifier, comment: str) -> None:
|
|
96
|
+
self.table_identifier = table_identifier
|
|
97
|
+
self.comment = comment
|
|
98
|
+
|
|
99
|
+
def execute(self, op: "Operations") -> None:
|
|
100
|
+
op.create_table_comment(
|
|
101
|
+
table_name=self.table_identifier.pg_table_name,
|
|
102
|
+
comment=self.comment,
|
|
103
|
+
schema=self.table_identifier.pg_schema_name,
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
|
|
94
107
|
class AddColumnMigrationAction(MigrationAction):
|
|
95
108
|
def __init__(self, table_identifier: TableIdentifier, column: sa.Column) -> None:
|
|
96
109
|
self.table_identifier = table_identifier
|
|
@@ -149,6 +162,21 @@ class AlterColumnMigrationAction(MigrationAction):
|
|
|
149
162
|
)
|
|
150
163
|
|
|
151
164
|
|
|
165
|
+
class SetColumnCommentMigrationAction(MigrationAction):
|
|
166
|
+
def __init__(self, table_identifier: TableIdentifier, column: str, comment: str) -> None:
|
|
167
|
+
self.table_identifier = table_identifier
|
|
168
|
+
self.comment = comment
|
|
169
|
+
self.column = column
|
|
170
|
+
|
|
171
|
+
def execute(self, op: "Operations") -> None:
|
|
172
|
+
op.alter_column(
|
|
173
|
+
table_name=self.table_identifier.pg_table_name,
|
|
174
|
+
column_name=self.column,
|
|
175
|
+
comment=self.comment,
|
|
176
|
+
schema=self.table_identifier.pg_schema_name,
|
|
177
|
+
)
|
|
178
|
+
|
|
179
|
+
|
|
152
180
|
class DropConstraintMigrationAction(MigrationAction):
|
|
153
181
|
def __init__(self, table_identifier: TableIdentifier, constraint_name: str) -> None:
|
|
154
182
|
self.table_identifier = table_identifier
|
|
@@ -547,6 +575,50 @@ class CreateSchemaMigrationAction(MigrationAction):
|
|
|
547
575
|
op.execute(self.query)
|
|
548
576
|
|
|
549
577
|
|
|
578
|
+
class DistributeSchema(MigrationAction):
|
|
579
|
+
def __init__(self, schema_name: str) -> None:
|
|
580
|
+
self.schema_name = schema_name
|
|
581
|
+
self.query = f"SELECT citus_schema_distribute('{pg_identifier_preparer.quote(schema_name)}')"
|
|
582
|
+
|
|
583
|
+
def execute(self, op: "Operations") -> None:
|
|
584
|
+
op.execute(self.query)
|
|
585
|
+
|
|
586
|
+
|
|
587
|
+
class DistributeReference(MigrationAction):
|
|
588
|
+
def __init__(self, table_identifier: TableIdentifier) -> None:
|
|
589
|
+
self.query = f"SELECT create_reference_table('{table_identifier.pg_escaped_qualified_name}')"
|
|
590
|
+
|
|
591
|
+
def execute(self, op: "Operations") -> None:
|
|
592
|
+
op.execute(self.query)
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
class DistributeTable(MigrationAction):
|
|
596
|
+
def __init__(self, table_identifier: TableIdentifier, column: str) -> None:
|
|
597
|
+
self.query = f"SELECT create_distributed_table('{table_identifier.pg_escaped_qualified_name}', '{column}')"
|
|
598
|
+
|
|
599
|
+
def execute(self, op: "Operations") -> None:
|
|
600
|
+
op.execute(self.query)
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
class UndistributeSchema(MigrationAction):
|
|
604
|
+
def __init__(self, schema_name: str) -> None:
|
|
605
|
+
self.schema_name = schema_name
|
|
606
|
+
self.query = f"SELECT citus_schema_undistribute('{pg_identifier_preparer.quote(schema_name)}')"
|
|
607
|
+
|
|
608
|
+
def execute(self, op: "Operations") -> None:
|
|
609
|
+
op.execute(self.query)
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
class UndistributeTable(MigrationAction):
|
|
613
|
+
def __init__(self, table_identifier: TableIdentifier) -> None:
|
|
614
|
+
self.query = (
|
|
615
|
+
f"SELECT undistribute_table('{table_identifier.pg_escaped_qualified_name}', cascade_via_foreign_keys=>true)"
|
|
616
|
+
)
|
|
617
|
+
|
|
618
|
+
def execute(self, op: "Operations") -> None:
|
|
619
|
+
op.execute(self.query)
|
|
620
|
+
|
|
621
|
+
|
|
550
622
|
class MigrationHandler:
|
|
551
623
|
def __init__(self) -> None:
|
|
552
624
|
self.migrations: list[MigrationAction] = []
|
|
@@ -571,8 +643,15 @@ class MigrationHandler:
|
|
|
571
643
|
return True
|
|
572
644
|
return False
|
|
573
645
|
|
|
574
|
-
def
|
|
646
|
+
def gather_migrations(self) -> Generator[MigrationAction, None, None]:
|
|
575
647
|
for migration in self.migrations:
|
|
576
|
-
migration
|
|
648
|
+
yield migration
|
|
577
649
|
for migration in self.foreign_key_migration:
|
|
650
|
+
yield migration
|
|
651
|
+
|
|
652
|
+
def run_migrations(self, op: "Operations") -> None:
|
|
653
|
+
for migration in self.gather_migrations():
|
|
578
654
|
migration.execute(op)
|
|
655
|
+
|
|
656
|
+
def count(self) -> int:
|
|
657
|
+
return len(list(self.gather_migrations()))
|
|
@@ -0,0 +1,383 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
|
|
5
|
+
import sqlalchemy as sa
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
from multipledispatch import dispatch
|
|
8
|
+
|
|
9
|
+
from spinta.backends import Backend
|
|
10
|
+
from spinta.backends.constants import DistributionType
|
|
11
|
+
from spinta.backends.helpers import TableIdentifier
|
|
12
|
+
from spinta.backends.helpers import get_table_identifier
|
|
13
|
+
from spinta.backends.postgresql.components import PostgreSQL
|
|
14
|
+
from spinta.backends.postgresql.helpers.migrate.actions import (
|
|
15
|
+
MigrationHandler,
|
|
16
|
+
UndistributeSchema,
|
|
17
|
+
UndistributeTable,
|
|
18
|
+
DistributeReference,
|
|
19
|
+
DistributeTable,
|
|
20
|
+
DistributeSchema,
|
|
21
|
+
)
|
|
22
|
+
from spinta.cli.helpers.message import cli_message
|
|
23
|
+
from spinta.components import Context, Model
|
|
24
|
+
from spinta.exceptions import NotImplementedFeature
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclasses.dataclass
|
|
28
|
+
class ShardingPlan:
|
|
29
|
+
schemas: set[str] = dataclasses.field(default_factory=set)
|
|
30
|
+
references: set[TableIdentifier] = dataclasses.field(default_factory=set)
|
|
31
|
+
distributed: dict[TableIdentifier, str] = dataclasses.field(default_factory=dict)
|
|
32
|
+
local: set[TableIdentifier] = dataclasses.field(default_factory=set)
|
|
33
|
+
|
|
34
|
+
_lookup: dict[TableIdentifier | str, DistributionType] = dataclasses.field(init=False, default_factory=dict)
|
|
35
|
+
|
|
36
|
+
def __sub__(self, other) -> "ShardingPlan":
|
|
37
|
+
return ShardingPlan(
|
|
38
|
+
schemas=self.schemas - other.schemas,
|
|
39
|
+
references=self.references - other.references,
|
|
40
|
+
distributed=dict(self.distributed.items() - other.distributed.items()),
|
|
41
|
+
local=self.local - other.local,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
def __post_init__(self) -> None:
|
|
45
|
+
for schema in self.schemas:
|
|
46
|
+
self._lookup[schema] = DistributionType.SCHEMA
|
|
47
|
+
|
|
48
|
+
for table_identifier in self.distributed.keys():
|
|
49
|
+
self._lookup[table_identifier] = DistributionType.TABLE
|
|
50
|
+
|
|
51
|
+
for table_identifier in self.references:
|
|
52
|
+
self._lookup[table_identifier] = DistributionType.COPY
|
|
53
|
+
|
|
54
|
+
for table_identifier in self.local:
|
|
55
|
+
self._lookup[table_identifier] = DistributionType.UNDISTRIBUTED
|
|
56
|
+
|
|
57
|
+
def distribution_type(self, key: TableIdentifier | str) -> DistributionType | None:
|
|
58
|
+
return self._lookup.get(key, None)
|
|
59
|
+
|
|
60
|
+
def discard(self, key: TableIdentifier | str) -> None:
|
|
61
|
+
distribution_type = self._lookup.pop(key, None)
|
|
62
|
+
if distribution_type is None:
|
|
63
|
+
return
|
|
64
|
+
|
|
65
|
+
match distribution_type:
|
|
66
|
+
case DistributionType.SCHEMA:
|
|
67
|
+
self.schemas.discard(key)
|
|
68
|
+
case DistributionType.TABLE:
|
|
69
|
+
self.distributed.pop(key)
|
|
70
|
+
case DistributionType.COPY:
|
|
71
|
+
self.references.discard(key)
|
|
72
|
+
case DistributionType.UNDISTRIBUTED:
|
|
73
|
+
self.local.discard(key)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _generate_current_distribution_query(schemas: list[str] | None = None) -> (str, dict):
|
|
77
|
+
base_query = """
|
|
78
|
+
SELECT
|
|
79
|
+
n.nspname AS schema_name,
|
|
80
|
+
c.relname AS table_name,
|
|
81
|
+
format('%I.%I', n.nspname, c.relname) AS full_table_name,
|
|
82
|
+
d.description AS table_comment,
|
|
83
|
+
COALESCE(ct.citus_table_type, 'local') AS distribution_type,
|
|
84
|
+
ct.distribution_column
|
|
85
|
+
FROM pg_class c
|
|
86
|
+
JOIN pg_namespace n ON n.oid = c.relnamespace
|
|
87
|
+
LEFT JOIN citus_tables ct ON ct.table_name = c.oid::regclass
|
|
88
|
+
LEFT JOIN pg_description d ON d.objoid = c.oid AND d.objsubid = 0
|
|
89
|
+
WHERE c.relkind = 'r'
|
|
90
|
+
AND n.nspname NOT LIKE 'pg_%'
|
|
91
|
+
AND n.nspname <> 'information_schema'
|
|
92
|
+
"""
|
|
93
|
+
|
|
94
|
+
params = {}
|
|
95
|
+
if schemas:
|
|
96
|
+
base_query += " AND n.nspname = ANY(:schemas)"
|
|
97
|
+
params["schemas"] = schemas
|
|
98
|
+
|
|
99
|
+
base_query += " ORDER BY n.nspname, c.relname"
|
|
100
|
+
return base_query, params
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@dispatch(Context)
|
|
104
|
+
def gather_current_sharding_plan(context: Context, **kwargs) -> dict[str, ShardingPlan]:
|
|
105
|
+
store = context.get("store")
|
|
106
|
+
backends = store.backends
|
|
107
|
+
|
|
108
|
+
plans: dict[str, ShardingPlan] = defaultdict(ShardingPlan)
|
|
109
|
+
for backend_name, backend in backends.items():
|
|
110
|
+
plans[backend_name] = gather_current_sharding_plan(context, backend, **kwargs)
|
|
111
|
+
return plans
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@dispatch(Context, Backend)
|
|
115
|
+
def gather_current_sharding_plan(context: Context, backend: Backend, **kwargs) -> ShardingPlan:
|
|
116
|
+
# Currently, only postgresql backend supports citus distribution, instead of erroring, return empty plan.
|
|
117
|
+
return ShardingPlan()
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@dispatch(Context, PostgreSQL)
|
|
121
|
+
def gather_current_sharding_plan(
|
|
122
|
+
context: Context, backend: PostgreSQL, schemas: list[str] | None = None, **kwargs
|
|
123
|
+
) -> ShardingPlan:
|
|
124
|
+
plan = ShardingPlan()
|
|
125
|
+
with backend.begin() as conn:
|
|
126
|
+
query, params = _generate_current_distribution_query(schemas)
|
|
127
|
+
rows = conn.execute(sa.text(query), params).fetchall()
|
|
128
|
+
|
|
129
|
+
for row in rows:
|
|
130
|
+
if not row["table_comment"]:
|
|
131
|
+
continue
|
|
132
|
+
|
|
133
|
+
table_identifier = get_table_identifier(row["table_comment"])
|
|
134
|
+
match row["distribution_type"]:
|
|
135
|
+
case "schema":
|
|
136
|
+
plan.schemas.add(row["schema_name"])
|
|
137
|
+
case "distributed":
|
|
138
|
+
plan.distributed[table_identifier] = row["distribution_column"]
|
|
139
|
+
case "reference":
|
|
140
|
+
plan.references.add(table_identifier)
|
|
141
|
+
case _:
|
|
142
|
+
plan.local.add(table_identifier)
|
|
143
|
+
return plan
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def create_sharding_plan(context: Context, models: list[Model], **kwargs) -> dict[str, ShardingPlan]:
|
|
147
|
+
plans = defaultdict(ShardingPlan)
|
|
148
|
+
for model in models:
|
|
149
|
+
if not isinstance(model.backend, PostgreSQL):
|
|
150
|
+
continue
|
|
151
|
+
|
|
152
|
+
if not model.external or not model.external.dataset:
|
|
153
|
+
continue
|
|
154
|
+
|
|
155
|
+
distribution_strategy = model.distribution_strategy
|
|
156
|
+
plan = plans[model.backend.name]
|
|
157
|
+
table_identifier = get_table_identifier(model)
|
|
158
|
+
match distribution_strategy.distribution_type:
|
|
159
|
+
case DistributionType.SCHEMA:
|
|
160
|
+
plan.schemas.add(table_identifier.pg_schema_name)
|
|
161
|
+
case DistributionType.TABLE:
|
|
162
|
+
plan.distributed[table_identifier] = distribution_strategy.property
|
|
163
|
+
case DistributionType.COPY:
|
|
164
|
+
plan.references.add(table_identifier)
|
|
165
|
+
case _:
|
|
166
|
+
plan.local.add(table_identifier)
|
|
167
|
+
|
|
168
|
+
return plans
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def invalidate_default_distribution(
|
|
172
|
+
context: Context, backend: PostgreSQL, plan: ShardingPlan, verbose: bool = False, **kwargs
|
|
173
|
+
) -> ShardingPlan:
|
|
174
|
+
default_distribution = context.get("config").default_distribution_strategy
|
|
175
|
+
if not default_distribution:
|
|
176
|
+
return plan
|
|
177
|
+
|
|
178
|
+
match default_distribution.distribution_type:
|
|
179
|
+
case DistributionType.SCHEMA:
|
|
180
|
+
return invalidate_default_schema_distributions(context, backend, plan)
|
|
181
|
+
case _:
|
|
182
|
+
if verbose:
|
|
183
|
+
cli_message(
|
|
184
|
+
f"Skipped invalidation of default distribution for {default_distribution.distribution_type.value} type"
|
|
185
|
+
)
|
|
186
|
+
|
|
187
|
+
return plan
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def valid_schema_distribution_foreign_key(plan: ShardingPlan, schema: str, foreign_key: dict) -> bool:
|
|
191
|
+
if foreign_key["referred_schema"] == schema:
|
|
192
|
+
return True
|
|
193
|
+
|
|
194
|
+
for reference in plan.references:
|
|
195
|
+
if (
|
|
196
|
+
reference.pg_schema_name == foreign_key["referred_schema"]
|
|
197
|
+
and reference.pg_table_name == foreign_key["referred_table"]
|
|
198
|
+
):
|
|
199
|
+
return True
|
|
200
|
+
|
|
201
|
+
return False
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
@dispatch(Context, Backend, ShardingPlan)
|
|
205
|
+
def invalidate_default_schema_distributions(
|
|
206
|
+
context: Context, backend: Backend, plan: ShardingPlan, **kwargs
|
|
207
|
+
) -> ShardingPlan:
|
|
208
|
+
raise NotImplementedFeature(f"Ability to invalidate default schema distribution for {backend.type!r} backend type")
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
@dispatch(Context, PostgreSQL, ShardingPlan)
|
|
212
|
+
def invalidate_default_schema_distributions(
|
|
213
|
+
context: Context, backend: PostgreSQL, plan: ShardingPlan, **kwargs
|
|
214
|
+
) -> ShardingPlan:
|
|
215
|
+
if not plan.schemas:
|
|
216
|
+
return plan
|
|
217
|
+
|
|
218
|
+
invalid_schemas = set()
|
|
219
|
+
|
|
220
|
+
inspector = sa.inspect(backend.engine)
|
|
221
|
+
|
|
222
|
+
plan_copy = deepcopy(plan)
|
|
223
|
+
for schema in plan.schemas:
|
|
224
|
+
tables = inspector.get_table_names(schema=schema)
|
|
225
|
+
|
|
226
|
+
for table in tables:
|
|
227
|
+
foreign_keys = inspector.get_foreign_keys(table, schema=schema)
|
|
228
|
+
if not foreign_keys:
|
|
229
|
+
continue
|
|
230
|
+
|
|
231
|
+
for key in foreign_keys:
|
|
232
|
+
if not valid_schema_distribution_foreign_key(plan, schema, key):
|
|
233
|
+
invalid_schemas.add(key["referred_schema"])
|
|
234
|
+
invalid_schemas.add(schema)
|
|
235
|
+
|
|
236
|
+
if not invalid_schemas:
|
|
237
|
+
return plan
|
|
238
|
+
|
|
239
|
+
for invalid_schema in invalid_schemas:
|
|
240
|
+
plan_copy.schemas.discard(invalid_schema)
|
|
241
|
+
|
|
242
|
+
return plan_copy
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def _build_fk_graph(conn: sa.engine.Connection) -> dict[tuple[str, str], set[tuple[str, str]]]:
|
|
246
|
+
graph: dict[tuple[str, str], set[tuple[str, str]]] = defaultdict(set)
|
|
247
|
+
|
|
248
|
+
rows = conn.execute("""
|
|
249
|
+
SELECT
|
|
250
|
+
src_ns.nspname AS source_schema,
|
|
251
|
+
src.relname AS source_table,
|
|
252
|
+
tgt_ns.nspname AS target_schema,
|
|
253
|
+
tgt.relname AS target_table
|
|
254
|
+
FROM pg_constraint con
|
|
255
|
+
JOIN pg_class src
|
|
256
|
+
ON src.oid = con.conrelid
|
|
257
|
+
JOIN pg_namespace src_ns
|
|
258
|
+
ON src_ns.oid = src.relnamespace
|
|
259
|
+
JOIN pg_class tgt
|
|
260
|
+
ON tgt.oid = con.confrelid
|
|
261
|
+
JOIN pg_namespace tgt_ns
|
|
262
|
+
ON tgt_ns.oid = tgt.relnamespace
|
|
263
|
+
WHERE con.contype = 'f'
|
|
264
|
+
""")
|
|
265
|
+
|
|
266
|
+
for src_schema, src_table, tgt_schema, tgt_table in rows.fetchall():
|
|
267
|
+
src = (src_schema, src_table)
|
|
268
|
+
tgt = (tgt_schema, tgt_table)
|
|
269
|
+
|
|
270
|
+
if src == tgt:
|
|
271
|
+
continue
|
|
272
|
+
|
|
273
|
+
graph[src].add(tgt)
|
|
274
|
+
graph[tgt].add(src)
|
|
275
|
+
|
|
276
|
+
visited: set[tuple[str, str]] = set()
|
|
277
|
+
key_to_component: dict[tuple[str, str], set[tuple[str, str]]] = {}
|
|
278
|
+
|
|
279
|
+
for node in graph:
|
|
280
|
+
if node in visited:
|
|
281
|
+
continue
|
|
282
|
+
|
|
283
|
+
stack = [node]
|
|
284
|
+
component: set[tuple[str, str]] = set()
|
|
285
|
+
|
|
286
|
+
while stack:
|
|
287
|
+
cur = stack.pop()
|
|
288
|
+
if cur in visited:
|
|
289
|
+
continue
|
|
290
|
+
|
|
291
|
+
visited.add(cur)
|
|
292
|
+
component.add(cur)
|
|
293
|
+
stack.extend(graph[cur] - visited)
|
|
294
|
+
|
|
295
|
+
for n in component:
|
|
296
|
+
key_to_component[n] = component
|
|
297
|
+
|
|
298
|
+
return key_to_component
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def build_fk_components(
|
|
302
|
+
conn: sa.engine.Connection, tables: set[TableIdentifier]
|
|
303
|
+
) -> dict[TableIdentifier, set[TableIdentifier]]:
|
|
304
|
+
"""
|
|
305
|
+
Build FK-connected components using FULL DB graph,
|
|
306
|
+
then return mapping only for given `tables`.
|
|
307
|
+
"""
|
|
308
|
+
|
|
309
|
+
target_map = {(t.pg_schema_name, t.pg_table_name): t for t in tables}
|
|
310
|
+
graph = _build_fk_graph(conn)
|
|
311
|
+
|
|
312
|
+
result = {}
|
|
313
|
+
for key, table in target_map.items():
|
|
314
|
+
full_component = graph.get(key, {key})
|
|
315
|
+
result[table] = {target_map[k] for k in full_component if k in target_map}
|
|
316
|
+
|
|
317
|
+
return result
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def undistribute_all(
|
|
321
|
+
context: Context,
|
|
322
|
+
backend: PostgreSQL,
|
|
323
|
+
plan: ShardingPlan,
|
|
324
|
+
handler: MigrationHandler,
|
|
325
|
+
progress_bar: tqdm | None = None,
|
|
326
|
+
**kwargs,
|
|
327
|
+
) -> None:
|
|
328
|
+
for schema in plan.schemas:
|
|
329
|
+
handler.add_action(UndistributeSchema(schema_name=schema))
|
|
330
|
+
if progress_bar is not None:
|
|
331
|
+
progress_bar.update(1)
|
|
332
|
+
|
|
333
|
+
if not (plan.references or plan.distributed):
|
|
334
|
+
return
|
|
335
|
+
|
|
336
|
+
processed = set()
|
|
337
|
+
undistributed_tables = plan.references | set(table for table in plan.distributed.keys())
|
|
338
|
+
with backend.begin() as conn:
|
|
339
|
+
component_map = build_fk_components(conn, undistributed_tables)
|
|
340
|
+
|
|
341
|
+
for table in plan.distributed.keys():
|
|
342
|
+
if table in processed:
|
|
343
|
+
continue
|
|
344
|
+
|
|
345
|
+
handler.add_action(UndistributeTable(table_identifier=table))
|
|
346
|
+
if progress_bar is not None:
|
|
347
|
+
progress_bar.update(1)
|
|
348
|
+
component = component_map[table]
|
|
349
|
+
processed.update(component)
|
|
350
|
+
|
|
351
|
+
for table in plan.references:
|
|
352
|
+
if table in processed:
|
|
353
|
+
continue
|
|
354
|
+
|
|
355
|
+
handler.add_action(UndistributeTable(table_identifier=table))
|
|
356
|
+
if progress_bar is not None:
|
|
357
|
+
progress_bar.update(1)
|
|
358
|
+
component = component_map[table]
|
|
359
|
+
processed.update(component)
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
def distribute_all(
|
|
363
|
+
context: Context,
|
|
364
|
+
backend: PostgreSQL,
|
|
365
|
+
plan: ShardingPlan,
|
|
366
|
+
handler: MigrationHandler,
|
|
367
|
+
progress_bar: tqdm | None = None,
|
|
368
|
+
**kwargs,
|
|
369
|
+
) -> None:
|
|
370
|
+
for table in plan.references:
|
|
371
|
+
handler.add_action(DistributeReference(table_identifier=table))
|
|
372
|
+
if progress_bar is not None:
|
|
373
|
+
progress_bar.update(1)
|
|
374
|
+
|
|
375
|
+
for table, column in plan.distributed.items():
|
|
376
|
+
handler.add_action(DistributeTable(table_identifier=table, column=column))
|
|
377
|
+
if progress_bar is not None:
|
|
378
|
+
progress_bar.update(1)
|
|
379
|
+
|
|
380
|
+
for schema in plan.schemas:
|
|
381
|
+
handler.add_action(DistributeSchema(schema_name=schema))
|
|
382
|
+
if progress_bar is not None:
|
|
383
|
+
progress_bar.update(1)
|