spinta 0.2.dev24__py3-none-any.whl → 0.2.dev26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. spinta/backends/components.py +8 -1
  2. spinta/backends/constants.py +10 -0
  3. spinta/backends/helpers.py +27 -20
  4. spinta/backends/postgresql/commands/load.py +2 -1
  5. spinta/backends/postgresql/commands/wait.py +2 -2
  6. spinta/backends/postgresql/commands/wipe.py +1 -1
  7. spinta/backends/postgresql/components.py +7 -1
  8. spinta/backends/postgresql/helpers/migrate/actions.py +82 -3
  9. spinta/backends/postgresql/helpers/migrate/citus.py +383 -0
  10. spinta/backends/postgresql/sqlalchemy.py +19 -0
  11. spinta/cli/admin.py +2 -0
  12. spinta/cli/comment.py +14 -8
  13. spinta/cli/helpers/admin/components.py +3 -0
  14. spinta/cli/helpers/admin/registry.py +14 -0
  15. spinta/cli/helpers/admin/scripts/add_local_ids.py +80 -0
  16. spinta/cli/helpers/admin/scripts/citus_shard.py +126 -0
  17. spinta/cli/helpers/admin/scripts/remove_local_ids.py +55 -0
  18. spinta/cli/helpers/upgrade/scripts/backends/postgresql/comments.py +62 -26
  19. spinta/cli/inspect.py +3 -0
  20. spinta/cli/uncomment.py +80 -18
  21. spinta/components.py +7 -1
  22. spinta/config.py +3 -0
  23. spinta/config.yml +12 -1
  24. spinta/datasets/backends/dataframe/commands/read.py +5 -1
  25. spinta/datasets/backends/dataframe/ufuncs/query/components.py +26 -2
  26. spinta/datasets/backends/dataframe/ufuncs/query/ufuncs.py +33 -26
  27. spinta/datasets/backends/sql/commands/cast.py +89 -19
  28. spinta/datasets/backends/sql/commands/read.py +29 -79
  29. spinta/datasets/backends/sql/helpers.py +61 -0
  30. spinta/dimensions/scope/components.py +20 -2
  31. spinta/dimensions/scope/helpers.py +28 -2
  32. spinta/dimensions/scope/ufuncs.py +51 -0
  33. spinta/exceptions.py +17 -0
  34. spinta/manifests/commands/link.py +2 -0
  35. spinta/manifests/sql/helpers.py +1 -1
  36. spinta/manifests/tabular/helpers.py +2 -2
  37. spinta/testing/citus.py +96 -0
  38. spinta/testing/cli.py +13 -0
  39. spinta/testing/pytest.py +38 -24
  40. spinta/types/config.py +13 -1
  41. spinta/types/helpers.py +53 -3
  42. spinta/types/model.py +67 -4
  43. spinta/urlparams.py +85 -2
  44. spinta/utils/enums.py +17 -8
  45. spinta/utils/url.py +10 -0
  46. {spinta-0.2.dev24.dist-info → spinta-0.2.dev26.dist-info}/METADATA +1 -1
  47. {spinta-0.2.dev24.dist-info → spinta-0.2.dev26.dist-info}/RECORD +50 -44
  48. {spinta-0.2.dev24.dist-info → spinta-0.2.dev26.dist-info}/WHEEL +0 -0
  49. {spinta-0.2.dev24.dist-info → spinta-0.2.dev26.dist-info}/entry_points.txt +0 -0
  50. {spinta-0.2.dev24.dist-info → spinta-0.2.dev26.dist-info}/licenses/LICENSE +0 -0
@@ -1,12 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import contextlib
4
+ import dataclasses
4
5
  from typing import Any, Type
5
6
  from typing import Dict
6
7
  from typing import Optional
7
8
  from typing import Set
8
9
 
9
- from spinta.backends.constants import BackendOrigin, BackendFeatures
10
+ from spinta.backends.constants import BackendOrigin, BackendFeatures, DistributionType
10
11
  from spinta.core.ufuncs import Env
11
12
  from spinta.ufuncs.resultbuilder.components import ResultBuilder
12
13
 
@@ -58,3 +59,9 @@ class Backend:
58
59
 
59
60
 
60
61
  SelectTree = Optional[Dict[str, "SelectTree"]]
62
+
63
+
64
@dataclasses.dataclass
class DistributionStrategy:
    """How a model's table should be distributed across a sharded (Citus) backend.

    Attributes:
        distribution_type: the kind of distribution to apply.
        property: optional name of the property used as the distribution key; the Citus
            sharding-plan builder stores it as the distribution column for table-type
            distribution. ``None`` when no key is needed.
    """

    distribution_type: DistributionType
    # NOTE: this field name shadows the builtin ``property`` inside the class body, so a
    # ``@property``-decorated member added below it would not work as expected.
    property: str | None = None
@@ -35,3 +35,13 @@ class BackendFeatures(enum.Enum):
35
35
 
36
36
  # Backend supports
37
37
  EXPAND = "EXPAND"
38
+
39
+ # Backend supports sharding
40
+ DISTRIBUTE = "DISTRIBUTE"
41
+
42
+
43
class DistributionType(enum.Enum):
    """Supported ways of distributing data on a sharding-capable backend.

    The member values are the lowercase strings used to identify each strategy.
    """

    # Whole schema is distributed (schema-based sharding).
    SCHEMA = "schema"
    # Table is distributed by a distribution column.
    TABLE = "table"
    # Table is copied as a reference table.
    COPY = "copy"
    # Table stays local / is not distributed.
    UNDISTRIBUTED = "undistributed"
@@ -91,7 +91,7 @@ def get_select_tree(
91
91
  select = list(select.keys())
92
92
 
93
93
  select = _apply_always_show_id(context, action, select)
94
- if select is None and action in (Action.GETALL, Action.SEARCH):
94
+ if select is None and action in (Action.GETALL, Action.SEARCH, Action.GETONE):
95
95
  # If select is not given, select everything.
96
96
  select = {"*": {}}
97
97
  return flat_select_to_nested(select)
@@ -344,7 +344,7 @@ def get_ns_reserved_props(action: Action) -> list[str]:
344
344
  return []
345
345
 
346
346
 
347
- @dataclasses.dataclass
347
+ @dataclasses.dataclass(frozen=True)
348
348
  class TableIdentifier:
349
349
  """
350
350
  Represents a table identifier across logical (app) and PostgreSQL layers.
@@ -391,35 +391,42 @@ class TableIdentifier:
391
391
  table_arg: str | None = dataclasses.field(default=None)
392
392
  default_pg_schema: str | None = dataclasses.field(default=None)
393
393
 
394
- logical_name: str = dataclasses.field(init=False)
394
+ logical_name: str = dataclasses.field(init=False, compare=False)
395
395
  # Name with namespace connected with '/', like it is used with Model class
396
- logical_qualified_name: str = dataclasses.field(init=False)
396
+ logical_qualified_name: str = dataclasses.field(init=False, compare=False)
397
397
 
398
- pg_table_name: str = dataclasses.field(init=False)
399
- pg_schema_name: str | None = dataclasses.field(init=False)
398
+ pg_table_name: str = dataclasses.field(init=False, compare=False)
399
+ pg_schema_name: str | None = dataclasses.field(init=False, compare=False)
400
400
  # Used for hashed schema and table names
401
- pg_qualified_name: str = dataclasses.field(init=False)
401
+ pg_qualified_name: str = dataclasses.field(init=False, compare=False)
402
402
  # Escaped qualified name, used for queries
403
- pg_escaped_qualified_name: str = dataclasses.field(init=False)
403
+ pg_escaped_qualified_name: str = dataclasses.field(init=False, compare=False)
404
404
 
405
405
  def __post_init__(self):
406
- self.logical_name = self.base_name + self.table_type.value
406
+ logical_name = self.base_name + self.table_type.value
407
407
  if self.table_arg:
408
- self.logical_name += "/" + self.table_arg
408
+ logical_name += "/" + self.table_arg
409
409
 
410
- self.logical_qualified_name = f"{self.schema}/{self.logical_name}" if self.schema else self.logical_name
410
+ logical_qualified_name = f"{self.schema}/{logical_name}" if self.schema else logical_name
411
411
 
412
- self.pg_table_name = get_pg_name(self.logical_name)
413
- self.pg_schema_name = get_pg_name(self.schema) if self.schema else self.default_pg_schema
414
- self.pg_qualified_name = (
415
- f"{self.pg_schema_name}.{self.pg_table_name}" if self.pg_schema_name else self.pg_table_name
416
- )
417
- self.pg_escaped_qualified_name = (
418
- f"{pg_identifier_preparer.quote(self.pg_schema_name)}.{pg_identifier_preparer.quote(self.pg_table_name)}"
419
- if self.pg_schema_name
420
- else pg_identifier_preparer.quote(self.pg_table_name)
412
+ pg_table_name = get_pg_name(logical_name)
413
+ pg_schema_name = get_pg_name(self.schema) if self.schema else self.default_pg_schema
414
+ pg_qualified_name = f"{pg_schema_name}.{pg_table_name}" if pg_schema_name else pg_table_name
415
+ pg_escaped_qualified_name = (
416
+ f"{pg_identifier_preparer.quote(pg_schema_name)}.{pg_identifier_preparer.quote(pg_table_name)}"
417
+ if pg_schema_name
418
+ else pg_identifier_preparer.quote(pg_table_name)
421
419
  )
422
420
 
421
+ # This is needed because we want to make this dataclass hashable (frozen=True, does that)
422
+ # But because it becomes immutable, we need to set all the attributes manually (the same way dataclass __init__ does).
423
+ object.__setattr__(self, "logical_name", logical_name)
424
+ object.__setattr__(self, "logical_qualified_name", logical_qualified_name)
425
+ object.__setattr__(self, "pg_table_name", pg_table_name)
426
+ object.__setattr__(self, "pg_schema_name", pg_schema_name)
427
+ object.__setattr__(self, "pg_qualified_name", pg_qualified_name)
428
+ object.__setattr__(self, "pg_escaped_qualified_name", pg_escaped_qualified_name)
429
+
423
430
  def change_table_type(self, new_type: TableType, table_arg: str | None = None) -> "TableIdentifier":
424
431
  return dataclasses.replace(self, table_type=new_type, table_arg=table_arg)
425
432
 
@@ -5,6 +5,7 @@ import sqlalchemy as sa
5
5
  from spinta import commands
6
6
  from spinta.backends.postgresql.components import PostgreSQL
7
7
  from spinta.backends.postgresql.helpers.name import PG_NAMING_CONVENTION
8
+ from spinta.backends.postgresql.sqlalchemy import create_postgresql_engine
8
9
  from spinta.components import Context
9
10
  from spinta.utils.sqlalchemy import get_metadata_naming_convention
10
11
 
@@ -12,7 +13,7 @@ from spinta.utils.sqlalchemy import get_metadata_naming_convention
12
13
  @commands.load.register(Context, PostgreSQL, dict)
13
14
  def load(context: Context, backend: PostgreSQL, config: Dict[str, Any]):
14
15
  backend.dsn = config["dsn"]
15
- backend.engine = sa.create_engine(backend.dsn, echo=False)
16
+ backend.engine = create_postgresql_engine(backend.dsn, echo=False)
16
17
  backend.schema = sa.MetaData(backend.engine, naming_convention=get_metadata_naming_convention(PG_NAMING_CONVENTION))
17
18
  backend.tables = {}
18
19
 
@@ -1,7 +1,7 @@
1
- import sqlalchemy as sa
2
1
  import sqlalchemy.exc
3
2
 
4
3
  from spinta import commands
4
+ from spinta.backends.postgresql.sqlalchemy import create_postgresql_engine
5
5
  from spinta.components import Context
6
6
  from spinta.backends.postgresql.components import PostgreSQL
7
7
 
@@ -10,7 +10,7 @@ from spinta.backends.postgresql.components import PostgreSQL
10
10
  def wait(context: Context, backend: PostgreSQL, *, fail: bool = False) -> bool:
11
11
  rc = context.get("rc")
12
12
  dsn = rc.get("backends", backend.name, "dsn", required=True)
13
- engine = sa.create_engine(dsn, connect_args={"connect_timeout": 0})
13
+ engine = create_postgresql_engine(dsn, connect_args={"connect_timeout": 0})
14
14
  try:
15
15
  conn = engine.connect()
16
16
  except sqlalchemy.exc.OperationalError:
@@ -48,7 +48,7 @@ def wipe(context: Context, model: Model, backend: PostgreSQL):
48
48
  if changelog_table_identifier.pg_schema_name
49
49
  else f'"{seqname}"'
50
50
  )
51
- connection.execute(f"ALTER SEQUENCE {seq_escaped_named} RESTART")
51
+ connection.execute(sa.func.setval(seq_escaped_named, 1, False))
52
52
 
53
53
  # Delete data table
54
54
  table = backend.get_table(model)
@@ -26,7 +26,13 @@ class PostgreSQL(Backend):
26
26
  },
27
27
  }
28
28
 
29
- features = {BackendFeatures.FILE_BLOCKS, BackendFeatures.WRITE, BackendFeatures.EXPAND, BackendFeatures.PAGINATION}
29
+ features = {
30
+ BackendFeatures.FILE_BLOCKS,
31
+ BackendFeatures.WRITE,
32
+ BackendFeatures.EXPAND,
33
+ BackendFeatures.PAGINATION,
34
+ BackendFeatures.DISTRIBUTE,
35
+ }
30
36
 
31
37
  engine: Engine = None
32
38
  schema: sa.MetaData = None
@@ -3,7 +3,7 @@ from sqlalchemy.dialects.postgresql import UUID
3
3
  from sqlalchemy.dialects import postgresql
4
4
 
5
5
  import sqlalchemy as sa
6
- from typing import TYPE_CHECKING
6
+ from typing import TYPE_CHECKING, Generator
7
7
 
8
8
  from spinta.backends.helpers import TableIdentifier
9
9
  from spinta.backends.postgresql.helpers.name import name_changed
@@ -91,6 +91,19 @@ class RenameTableMigrationAction(MigrationAction):
91
91
  )
92
92
 
93
93
 
94
class SetTableCommentMigrationAction(MigrationAction):
    """Migration step that sets the comment on an existing table.

    The comment is written through Alembic's ``create_table_comment`` operation using the
    PostgreSQL table and schema names of ``table_identifier``.
    """

    def __init__(self, table_identifier: TableIdentifier, comment: str) -> None:
        self.table_identifier = table_identifier
        self.comment = comment

    def execute(self, op: "Operations") -> None:
        target = self.table_identifier
        op.create_table_comment(
            table_name=target.pg_table_name,
            comment=self.comment,
            schema=target.pg_schema_name,
        )
105
+
106
+
94
107
  class AddColumnMigrationAction(MigrationAction):
95
108
  def __init__(self, table_identifier: TableIdentifier, column: sa.Column) -> None:
96
109
  self.table_identifier = table_identifier
@@ -149,6 +162,21 @@ class AlterColumnMigrationAction(MigrationAction):
149
162
  )
150
163
 
151
164
 
165
class SetColumnCommentMigrationAction(MigrationAction):
    """Migration step that sets the comment on a single column.

    Implemented with Alembic's ``alter_column`` operation, passing only ``comment``.

    Args:
        table_identifier: table owning the column.
        column: column name.
        comment: new comment text.
    """

    def __init__(self, table_identifier: TableIdentifier, column: str, comment: str) -> None:
        self.table_identifier = table_identifier
        self.column = column
        self.comment = comment

    def execute(self, op: "Operations") -> None:
        target = self.table_identifier
        op.alter_column(
            table_name=target.pg_table_name,
            column_name=self.column,
            comment=self.comment,
            schema=target.pg_schema_name,
        )
178
+
179
+
152
180
  class DropConstraintMigrationAction(MigrationAction):
153
181
  def __init__(self, table_identifier: TableIdentifier, constraint_name: str) -> None:
154
182
  self.table_identifier = table_identifier
@@ -547,6 +575,50 @@ class CreateSchemaMigrationAction(MigrationAction):
547
575
  op.execute(self.query)
548
576
 
549
577
 
578
class DistributeSchema(MigrationAction):
    """Distribute a whole schema with Citus schema-based sharding.

    Runs ``citus_schema_distribute`` for ``schema_name``. The (identifier-quoted) schema name
    is passed inside a SQL string literal, so any single quote it contains is doubled to keep
    the literal well-formed.
    """

    def __init__(self, schema_name: str) -> None:
        self.schema_name = schema_name
        # quote() produces a valid identifier, but it ends up inside '...', where a bare
        # single quote would terminate the literal early.
        schema_literal = pg_identifier_preparer.quote(schema_name).replace("'", "''")
        self.query = f"SELECT citus_schema_distribute('{schema_literal}')"

    def execute(self, op: "Operations") -> None:
        op.execute(self.query)
585
+
586
+
587
class DistributeReference(MigrationAction):
    """Turn a table into a Citus reference table via ``create_reference_table``.

    The escaped qualified table name is passed inside a SQL string literal, so single quotes
    in it are doubled to keep the literal well-formed.
    """

    def __init__(self, table_identifier: TableIdentifier) -> None:
        # Kept for introspection/debugging, like the other migration actions do.
        self.table_identifier = table_identifier
        table_literal = table_identifier.pg_escaped_qualified_name.replace("'", "''")
        self.query = f"SELECT create_reference_table('{table_literal}')"

    def execute(self, op: "Operations") -> None:
        op.execute(self.query)
593
+
594
+
595
class DistributeTable(MigrationAction):
    """Hash-distribute a table by ``column`` via Citus ``create_distributed_table``.

    Both the escaped qualified table name and the column name are passed as SQL string
    literals, so single quotes in them are doubled to keep the literals well-formed.
    """

    def __init__(self, table_identifier: TableIdentifier, column: str) -> None:
        self.table_identifier = table_identifier
        self.column = column
        table_literal = table_identifier.pg_escaped_qualified_name.replace("'", "''")
        # str() keeps the previous f-string rendering for non-str values (e.g. None).
        column_literal = str(column).replace("'", "''")
        self.query = f"SELECT create_distributed_table('{table_literal}', '{column_literal}')"

    def execute(self, op: "Operations") -> None:
        op.execute(self.query)
601
+
602
+
603
class UndistributeSchema(MigrationAction):
    """Revert Citus schema-based sharding of a schema via ``citus_schema_undistribute``.

    The identifier-quoted schema name is passed inside a SQL string literal, so single quotes
    in it are doubled to keep the literal well-formed.
    """

    def __init__(self, schema_name: str) -> None:
        self.schema_name = schema_name
        schema_literal = pg_identifier_preparer.quote(schema_name).replace("'", "''")
        self.query = f"SELECT citus_schema_undistribute('{schema_literal}')"

    def execute(self, op: "Operations") -> None:
        op.execute(self.query)
610
+
611
+
612
class UndistributeTable(MigrationAction):
    """Turn a distributed or reference table back into a regular table.

    Runs Citus ``undistribute_table`` with ``cascade_via_foreign_keys`` enabled, so tables
    connected to this one through foreign keys are undistributed by the same call. The escaped
    qualified name is passed inside a SQL string literal, so single quotes in it are doubled.
    """

    def __init__(self, table_identifier: TableIdentifier) -> None:
        self.table_identifier = table_identifier
        table_literal = table_identifier.pg_escaped_qualified_name.replace("'", "''")
        self.query = f"SELECT undistribute_table('{table_literal}', cascade_via_foreign_keys=>true)"

    def execute(self, op: "Operations") -> None:
        op.execute(self.query)
620
+
621
+
550
622
  class MigrationHandler:
551
623
  def __init__(self) -> None:
552
624
  self.migrations: list[MigrationAction] = []
@@ -571,8 +643,15 @@ class MigrationHandler:
571
643
  return True
572
644
  return False
573
645
 
574
- def run_migrations(self, op: "Operations") -> None:
646
+ def gather_migrations(self) -> Generator[MigrationAction, None, None]:
575
647
  for migration in self.migrations:
576
- migration.execute(op)
648
+ yield migration
577
649
  for migration in self.foreign_key_migration:
650
+ yield migration
651
+
652
+ def run_migrations(self, op: "Operations") -> None:
653
+ for migration in self.gather_migrations():
578
654
  migration.execute(op)
655
+
656
+ def count(self) -> int:
657
+ return len(list(self.gather_migrations()))
@@ -0,0 +1,383 @@
1
+ import dataclasses
2
+ from collections import defaultdict
3
+ from copy import deepcopy
4
+
5
+ import sqlalchemy as sa
6
+ from tqdm import tqdm
7
+ from multipledispatch import dispatch
8
+
9
+ from spinta.backends import Backend
10
+ from spinta.backends.constants import DistributionType
11
+ from spinta.backends.helpers import TableIdentifier
12
+ from spinta.backends.helpers import get_table_identifier
13
+ from spinta.backends.postgresql.components import PostgreSQL
14
+ from spinta.backends.postgresql.helpers.migrate.actions import (
15
+ MigrationHandler,
16
+ UndistributeSchema,
17
+ UndistributeTable,
18
+ DistributeReference,
19
+ DistributeTable,
20
+ DistributeSchema,
21
+ )
22
+ from spinta.cli.helpers.message import cli_message
23
+ from spinta.components import Context, Model
24
+ from spinta.exceptions import NotImplementedFeature
25
+
26
+
27
+ @dataclasses.dataclass
28
+ class ShardingPlan:
29
+ schemas: set[str] = dataclasses.field(default_factory=set)
30
+ references: set[TableIdentifier] = dataclasses.field(default_factory=set)
31
+ distributed: dict[TableIdentifier, str] = dataclasses.field(default_factory=dict)
32
+ local: set[TableIdentifier] = dataclasses.field(default_factory=set)
33
+
34
+ _lookup: dict[TableIdentifier | str, DistributionType] = dataclasses.field(init=False, default_factory=dict)
35
+
36
+ def __sub__(self, other) -> "ShardingPlan":
37
+ return ShardingPlan(
38
+ schemas=self.schemas - other.schemas,
39
+ references=self.references - other.references,
40
+ distributed=dict(self.distributed.items() - other.distributed.items()),
41
+ local=self.local - other.local,
42
+ )
43
+
44
+ def __post_init__(self) -> None:
45
+ for schema in self.schemas:
46
+ self._lookup[schema] = DistributionType.SCHEMA
47
+
48
+ for table_identifier in self.distributed.keys():
49
+ self._lookup[table_identifier] = DistributionType.TABLE
50
+
51
+ for table_identifier in self.references:
52
+ self._lookup[table_identifier] = DistributionType.COPY
53
+
54
+ for table_identifier in self.local:
55
+ self._lookup[table_identifier] = DistributionType.UNDISTRIBUTED
56
+
57
+ def distribution_type(self, key: TableIdentifier | str) -> DistributionType | None:
58
+ return self._lookup.get(key, None)
59
+
60
+ def discard(self, key: TableIdentifier | str) -> None:
61
+ distribution_type = self._lookup.pop(key, None)
62
+ if distribution_type is None:
63
+ return
64
+
65
+ match distribution_type:
66
+ case DistributionType.SCHEMA:
67
+ self.schemas.discard(key)
68
+ case DistributionType.TABLE:
69
+ self.distributed.pop(key)
70
+ case DistributionType.COPY:
71
+ self.references.discard(key)
72
+ case DistributionType.UNDISTRIBUTED:
73
+ self.local.discard(key)
74
+
75
+
76
+ def _generate_current_distribution_query(schemas: list[str] | None = None) -> (str, dict):
77
+ base_query = """
78
+ SELECT
79
+ n.nspname AS schema_name,
80
+ c.relname AS table_name,
81
+ format('%I.%I', n.nspname, c.relname) AS full_table_name,
82
+ d.description AS table_comment,
83
+ COALESCE(ct.citus_table_type, 'local') AS distribution_type,
84
+ ct.distribution_column
85
+ FROM pg_class c
86
+ JOIN pg_namespace n ON n.oid = c.relnamespace
87
+ LEFT JOIN citus_tables ct ON ct.table_name = c.oid::regclass
88
+ LEFT JOIN pg_description d ON d.objoid = c.oid AND d.objsubid = 0
89
+ WHERE c.relkind = 'r'
90
+ AND n.nspname NOT LIKE 'pg_%'
91
+ AND n.nspname <> 'information_schema'
92
+ """
93
+
94
+ params = {}
95
+ if schemas:
96
+ base_query += " AND n.nspname = ANY(:schemas)"
97
+ params["schemas"] = schemas
98
+
99
+ base_query += " ORDER BY n.nspname, c.relname"
100
+ return base_query, params
101
+
102
+
103
@dispatch(Context)
def gather_current_sharding_plan(context: Context, **kwargs) -> dict[str, ShardingPlan]:
    """Read the current distribution layout of every backend in the store.

    Delegates to the backend-specific overload for each backend. Returns a ``defaultdict``
    keyed by backend name, so looking up an unknown backend name yields an empty plan.
    """
    all_backends = context.get("store").backends
    return defaultdict(
        ShardingPlan,
        {
            name: gather_current_sharding_plan(context, backend, **kwargs)
            for name, backend in all_backends.items()
        },
    )
112
+
113
+
114
@dispatch(Context, Backend)
def gather_current_sharding_plan(context: Context, backend: Backend, **kwargs) -> ShardingPlan:
    """Fallback for backends without Citus distribution support.

    Only the PostgreSQL backend can report a distribution layout. Every other backend gets an
    empty plan rather than an error, so callers can iterate over all backends uniformly.
    """
    empty_plan = ShardingPlan()
    return empty_plan
118
+
119
+
120
@dispatch(Context, PostgreSQL)
def gather_current_sharding_plan(
    context: Context, backend: PostgreSQL, schemas: list[str] | None = None, **kwargs
) -> ShardingPlan:
    """Read the Citus distribution of ``backend``'s tables from the database.

    Tables without a comment are ignored. For the others the comment is parsed with
    ``get_table_identifier`` and the table is filed by its Citus table type:
    - ``schema``: its pg schema name goes to ``schemas``;
    - ``distributed``: mapped to its distribution column in ``distributed``;
    - ``reference``: added to ``references``;
    - anything else (including ``'local'`` for tables Citus does not track): ``local``.

    Args:
        schemas: optional pg schema names to restrict the scan to.
    """
    current = ShardingPlan()
    sql, bind_params = _generate_current_distribution_query(schemas)
    with backend.begin() as conn:
        fetched = conn.execute(sa.text(sql), bind_params).fetchall()

        for row in fetched:
            comment = row["table_comment"]
            if not comment:
                continue

            identifier = get_table_identifier(comment)
            citus_type = row["distribution_type"]
            if citus_type == "schema":
                current.schemas.add(row["schema_name"])
            elif citus_type == "distributed":
                current.distributed[identifier] = row["distribution_column"]
            elif citus_type == "reference":
                current.references.add(identifier)
            else:
                current.local.add(identifier)
    return current
144
+
145
+
146
def create_sharding_plan(context: Context, models: list[Model], **kwargs) -> dict[str, ShardingPlan]:
    """Build the desired distribution layout from the models' distribution strategies.

    Only models stored in a PostgreSQL backend and having ``external.dataset`` set are
    considered. Returns a ``defaultdict`` mapping backend name to its ShardingPlan.
    """
    plans = defaultdict(ShardingPlan)
    eligible_models = (
        model
        for model in models
        if isinstance(model.backend, PostgreSQL) and model.external and model.external.dataset
    )

    for model in eligible_models:
        strategy = model.distribution_strategy
        plan = plans[model.backend.name]
        identifier = get_table_identifier(model)
        kind = strategy.distribution_type

        if kind == DistributionType.SCHEMA:
            plan.schemas.add(identifier.pg_schema_name)
        elif kind == DistributionType.TABLE:
            plan.distributed[identifier] = strategy.property
        elif kind == DistributionType.COPY:
            plan.references.add(identifier)
        else:
            plan.local.add(identifier)

    return plans
169
+
170
+
171
def invalidate_default_distribution(
    context: Context, backend: PostgreSQL, plan: ShardingPlan, verbose: bool = False, **kwargs
) -> ShardingPlan:
    """Drop entries of ``plan`` that the configured default distribution cannot support.

    Only a schema-type default strategy is checked (via
    ``invalidate_default_schema_distributions``). For other strategy types the plan is
    returned unchanged, with a CLI message when ``verbose`` is set. Without a configured
    default strategy the plan is returned as is.
    """
    strategy = context.get("config").default_distribution_strategy
    if not strategy:
        return plan

    if strategy.distribution_type == DistributionType.SCHEMA:
        return invalidate_default_schema_distributions(context, backend, plan)

    if verbose:
        cli_message(f"Skipped invalidation of default distribution for {strategy.distribution_type.value} type")
    return plan
188
+
189
+
190
def valid_schema_distribution_foreign_key(plan: ShardingPlan, schema: str, foreign_key: dict) -> bool:
    """Tell whether a foreign key is compatible with distributing ``schema`` as a whole.

    A foreign key is acceptable when it points into the same schema, or at a table that
    ``plan`` lists as a reference table.

    Args:
        plan: plan whose ``references`` are consulted.
        schema: pg schema name the foreign key originates from.
        foreign_key: a dict with ``referred_schema`` and ``referred_table`` keys, in the
            shape returned by SQLAlchemy's ``Inspector.get_foreign_keys``.
    """
    target_schema = foreign_key["referred_schema"]
    if target_schema == schema:
        return True

    return any(
        ref.pg_schema_name == target_schema and ref.pg_table_name == foreign_key["referred_table"]
        for ref in plan.references
    )
202
+
203
+
204
@dispatch(Context, Backend, ShardingPlan)
def invalidate_default_schema_distributions(
    context: Context, backend: Backend, plan: ShardingPlan, **kwargs
) -> ShardingPlan:
    """Fallback for backends that cannot validate schema-based distribution.

    Raises:
        NotImplementedFeature: always; only the PostgreSQL overload is implemented.
    """
    feature = f"Ability to invalidate default schema distribution for {backend.type!r} backend type"
    raise NotImplementedFeature(feature)
209
+
210
+
211
@dispatch(Context, PostgreSQL, ShardingPlan)
def invalidate_default_schema_distributions(
    context: Context, backend: PostgreSQL, plan: ShardingPlan, **kwargs
) -> ShardingPlan:
    """Remove schemas from ``plan`` whose foreign keys cross schema boundaries.

    Every foreign key of every table in the planned schemas is checked with
    ``valid_schema_distribution_foreign_key``. When one is invalid, both the source schema and
    the referred schema are dropped from the plan.

    Returns:
        ``plan`` itself when nothing is invalid, otherwise a deep copy of it with the invalid
        schemas removed (the input plan is not modified).
    """
    if not plan.schemas:
        return plan

    inspector = sa.inspect(backend.engine)
    rejected: set = set()
    for schema in plan.schemas:
        for table in inspector.get_table_names(schema=schema):
            for fk in inspector.get_foreign_keys(table, schema=schema):
                if valid_schema_distribution_foreign_key(plan, schema, fk):
                    continue
                rejected.add(fk["referred_schema"])
                rejected.add(schema)

    if not rejected:
        return plan

    pruned = deepcopy(plan)
    for schema in rejected:
        pruned.schemas.discard(schema)
    return pruned
243
+
244
+
245
+ def _build_fk_graph(conn: sa.engine.Connection) -> dict[tuple[str, str], set[tuple[str, str]]]:
246
+ graph: dict[tuple[str, str], set[tuple[str, str]]] = defaultdict(set)
247
+
248
+ rows = conn.execute("""
249
+ SELECT
250
+ src_ns.nspname AS source_schema,
251
+ src.relname AS source_table,
252
+ tgt_ns.nspname AS target_schema,
253
+ tgt.relname AS target_table
254
+ FROM pg_constraint con
255
+ JOIN pg_class src
256
+ ON src.oid = con.conrelid
257
+ JOIN pg_namespace src_ns
258
+ ON src_ns.oid = src.relnamespace
259
+ JOIN pg_class tgt
260
+ ON tgt.oid = con.confrelid
261
+ JOIN pg_namespace tgt_ns
262
+ ON tgt_ns.oid = tgt.relnamespace
263
+ WHERE con.contype = 'f'
264
+ """)
265
+
266
+ for src_schema, src_table, tgt_schema, tgt_table in rows.fetchall():
267
+ src = (src_schema, src_table)
268
+ tgt = (tgt_schema, tgt_table)
269
+
270
+ if src == tgt:
271
+ continue
272
+
273
+ graph[src].add(tgt)
274
+ graph[tgt].add(src)
275
+
276
+ visited: set[tuple[str, str]] = set()
277
+ key_to_component: dict[tuple[str, str], set[tuple[str, str]]] = {}
278
+
279
+ for node in graph:
280
+ if node in visited:
281
+ continue
282
+
283
+ stack = [node]
284
+ component: set[tuple[str, str]] = set()
285
+
286
+ while stack:
287
+ cur = stack.pop()
288
+ if cur in visited:
289
+ continue
290
+
291
+ visited.add(cur)
292
+ component.add(cur)
293
+ stack.extend(graph[cur] - visited)
294
+
295
+ for n in component:
296
+ key_to_component[n] = component
297
+
298
+ return key_to_component
299
+
300
+
301
def build_fk_components(
    conn: sa.engine.Connection, tables: set[TableIdentifier]
) -> dict[TableIdentifier, set[TableIdentifier]]:
    """Map each of ``tables`` to the members of ``tables`` in its foreign-key component.

    Components are computed over the whole database's foreign-key graph (so tables outside
    ``tables`` can still link two requested tables), then restricted to ``tables``. A table
    with no foreign keys maps to a set containing only itself.
    """
    by_key = {(t.pg_schema_name, t.pg_table_name): t for t in tables}
    components = _build_fk_graph(conn)

    return {
        table: {by_key[member] for member in components.get(key, {key}) if member in by_key}
        for key, table in by_key.items()
    }
318
+
319
+
320
def undistribute_all(
    context: Context,
    backend: PostgreSQL,
    plan: ShardingPlan,
    handler: MigrationHandler,
    progress_bar: tqdm | None = None,
    **kwargs,
) -> None:
    """Queue migration actions that undo every distribution listed in ``plan``.

    Schemas are queued first. Then distributed tables, followed by reference tables, are
    queued with ``UndistributeTable``. That action cascades via foreign keys, so only one
    table per foreign-key component is queued; the rest of its component is skipped.
    ``progress_bar`` (if given) advances once per queued action.
    """
    for schema in plan.schemas:
        handler.add_action(UndistributeSchema(schema_name=schema))
        if progress_bar is not None:
            progress_bar.update(1)

    if not plan.references and not plan.distributed:
        return

    tables_to_undistribute = plan.references | set(plan.distributed)
    with backend.begin() as conn:
        component_map = build_fk_components(conn, tables_to_undistribute)

    handled = set()
    for group in (plan.distributed, plan.references):
        for table in group:
            if table in handled:
                continue

            handler.add_action(UndistributeTable(table_identifier=table))
            if progress_bar is not None:
                progress_bar.update(1)
            handled |= component_map[table]
360
+
361
+
362
def distribute_all(
    context: Context,
    backend: PostgreSQL,
    plan: ShardingPlan,
    handler: MigrationHandler,
    progress_bar: tqdm | None = None,
    **kwargs,
) -> None:
    """Queue migration actions that apply every distribution listed in ``plan``.

    Order: reference tables, then hash-distributed tables, then schemas. (Presumably
    reference tables come first so distributed tables can keep foreign keys to them;
    confirm.) ``progress_bar`` (if given) advances once per queued action.
    """

    def queue(action) -> None:
        handler.add_action(action)
        if progress_bar is not None:
            progress_bar.update(1)

    for table in plan.references:
        queue(DistributeReference(table_identifier=table))

    for table, column in plan.distributed.items():
        queue(DistributeTable(table_identifier=table, column=column))

    for schema in plan.schemas:
        queue(DistributeSchema(schema_name=schema))