infrahub-server 1.3.1__py3-none-any.whl → 1.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- infrahub/cli/db.py +194 -13
- infrahub/core/branch/enums.py +8 -0
- infrahub/core/branch/models.py +28 -5
- infrahub/core/branch/tasks.py +5 -7
- infrahub/core/diff/calculator.py +4 -1
- infrahub/core/diff/coordinator.py +32 -34
- infrahub/core/diff/diff_locker.py +26 -0
- infrahub/core/diff/query_parser.py +23 -32
- infrahub/core/graph/__init__.py +1 -1
- infrahub/core/initialization.py +4 -3
- infrahub/core/merge.py +31 -16
- infrahub/core/migrations/graph/__init__.py +24 -0
- infrahub/core/migrations/graph/m012_convert_account_generic.py +4 -3
- infrahub/core/migrations/graph/m013_convert_git_password_credential.py +4 -3
- infrahub/core/migrations/graph/m032_cleanup_orphaned_branch_relationships.py +105 -0
- infrahub/core/migrations/graph/m033_deduplicate_relationship_vertices.py +97 -0
- infrahub/core/node/__init__.py +3 -0
- infrahub/core/node/constraints/grouped_uniqueness.py +88 -132
- infrahub/core/node/resource_manager/ip_address_pool.py +5 -3
- infrahub/core/node/resource_manager/ip_prefix_pool.py +7 -4
- infrahub/core/node/resource_manager/number_pool.py +3 -1
- infrahub/core/node/standard.py +4 -0
- infrahub/core/query/branch.py +25 -56
- infrahub/core/query/node.py +78 -24
- infrahub/core/query/relationship.py +11 -8
- infrahub/core/relationship/model.py +10 -5
- infrahub/core/validators/uniqueness/model.py +17 -0
- infrahub/core/validators/uniqueness/query.py +212 -1
- infrahub/dependencies/builder/diff/coordinator.py +3 -0
- infrahub/dependencies/builder/diff/locker.py +8 -0
- infrahub/graphql/mutations/main.py +25 -4
- infrahub/graphql/mutations/tasks.py +2 -0
- infrahub_sdk/node/node.py +22 -10
- infrahub_sdk/node/related_node.py +7 -0
- {infrahub_server-1.3.1.dist-info → infrahub_server-1.3.3.dist-info}/METADATA +1 -1
- {infrahub_server-1.3.1.dist-info → infrahub_server-1.3.3.dist-info}/RECORD +42 -37
- infrahub_testcontainers/container.py +1 -1
- infrahub_testcontainers/docker-compose-cluster.test.yml +3 -0
- infrahub_testcontainers/docker-compose.test.yml +1 -0
- {infrahub_server-1.3.1.dist-info → infrahub_server-1.3.3.dist-info}/LICENSE.txt +0 -0
- {infrahub_server-1.3.1.dist-info → infrahub_server-1.3.3.dist-info}/WHEEL +0 -0
- {infrahub_server-1.3.1.dist-info → infrahub_server-1.3.3.dist-info}/entry_points.txt +0 -0
infrahub/cli/db.py
CHANGED
|
@@ -8,7 +8,7 @@ from csv import DictReader, DictWriter
|
|
|
8
8
|
from datetime import datetime, timezone
|
|
9
9
|
from enum import Enum
|
|
10
10
|
from pathlib import Path
|
|
11
|
-
from typing import TYPE_CHECKING, Any
|
|
11
|
+
from typing import TYPE_CHECKING, Any, Sequence
|
|
12
12
|
|
|
13
13
|
import typer
|
|
14
14
|
import ujson
|
|
@@ -38,7 +38,7 @@ from infrahub.core.initialization import (
|
|
|
38
38
|
initialization,
|
|
39
39
|
initialize_registry,
|
|
40
40
|
)
|
|
41
|
-
from infrahub.core.migrations.graph import get_graph_migrations
|
|
41
|
+
from infrahub.core.migrations.graph import get_graph_migrations, get_migration_by_number
|
|
42
42
|
from infrahub.core.migrations.schema.models import SchemaApplyMigrationData
|
|
43
43
|
from infrahub.core.migrations.schema.tasks import schema_apply_migrations
|
|
44
44
|
from infrahub.core.schema import SchemaRoot, core_models, internal_schema
|
|
@@ -58,8 +58,15 @@ from infrahub.services.adapters.workflow.local import WorkflowLocalExecution
|
|
|
58
58
|
from .constants import ERROR_BADGE, FAILED_BADGE, SUCCESS_BADGE
|
|
59
59
|
from .patch import patch_app
|
|
60
60
|
|
|
61
|
+
|
|
62
|
+
def get_timestamp_string() -> str:
|
|
63
|
+
"""Generate a timestamp string in the format YYYYMMDD-HHMMSS."""
|
|
64
|
+
return datetime.now(tz=timezone.utc).strftime("%Y%m%d-%H%M%S")
|
|
65
|
+
|
|
66
|
+
|
|
61
67
|
if TYPE_CHECKING:
|
|
62
68
|
from infrahub.cli.context import CliContext
|
|
69
|
+
from infrahub.core.migrations.shared import ArbitraryMigration, GraphMigration, InternalSchemaMigration
|
|
63
70
|
from infrahub.database import InfrahubDatabase
|
|
64
71
|
from infrahub.database.index import IndexManagerBase
|
|
65
72
|
|
|
@@ -154,6 +161,7 @@ async def migrate_cmd(
|
|
|
154
161
|
ctx: typer.Context,
|
|
155
162
|
check: bool = typer.Option(False, help="Check the state of the database without applying the migrations."),
|
|
156
163
|
config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
|
|
164
|
+
migration_number: int | None = typer.Option(None, help="Apply a specific migration by number"),
|
|
157
165
|
) -> None:
|
|
158
166
|
"""Check the current format of the internal graph and apply the necessary migrations"""
|
|
159
167
|
logging.getLogger("infrahub").setLevel(logging.WARNING)
|
|
@@ -165,7 +173,7 @@ async def migrate_cmd(
|
|
|
165
173
|
context: CliContext = ctx.obj
|
|
166
174
|
dbdriver = await context.init_db(retry=1)
|
|
167
175
|
|
|
168
|
-
await migrate_database(db=dbdriver, initialize=True, check=check)
|
|
176
|
+
await migrate_database(db=dbdriver, initialize=True, check=check, migration_number=migration_number)
|
|
169
177
|
|
|
170
178
|
await dbdriver.close()
|
|
171
179
|
|
|
@@ -287,7 +295,9 @@ async def index(
|
|
|
287
295
|
await dbdriver.close()
|
|
288
296
|
|
|
289
297
|
|
|
290
|
-
async def migrate_database(
|
|
298
|
+
async def migrate_database(
|
|
299
|
+
db: InfrahubDatabase, initialize: bool = False, check: bool = False, migration_number: int | str | None = None
|
|
300
|
+
) -> bool:
|
|
291
301
|
"""Apply the latest migrations to the database, this function will print the status directly in the console.
|
|
292
302
|
|
|
293
303
|
Returns a boolean indicating whether a migration failed or if all migrations succeeded.
|
|
@@ -295,6 +305,7 @@ async def migrate_database(db: InfrahubDatabase, initialize: bool = False, check
|
|
|
295
305
|
Args:
|
|
296
306
|
db: The database object.
|
|
297
307
|
check: If True, the function will only check the status of the database and not apply the migrations. Defaults to False.
|
|
308
|
+
migration_number: If provided, the function will only apply the migration with the given number. Defaults to None.
|
|
298
309
|
"""
|
|
299
310
|
rprint("Checking current state of the Database")
|
|
300
311
|
|
|
@@ -302,15 +313,28 @@ async def migrate_database(db: InfrahubDatabase, initialize: bool = False, check
|
|
|
302
313
|
await initialize_registry(db=db)
|
|
303
314
|
|
|
304
315
|
root_node = await get_root_node(db=db)
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
316
|
+
if migration_number:
|
|
317
|
+
migration = get_migration_by_number(migration_number)
|
|
318
|
+
migrations: Sequence[GraphMigration | InternalSchemaMigration | ArbitraryMigration] = [migration]
|
|
319
|
+
if check:
|
|
320
|
+
if root_node.graph_version > migration.minimum_version:
|
|
321
|
+
rprint(
|
|
322
|
+
f"Migration {migration_number} already applied. To apply again, run the command without the --check flag."
|
|
323
|
+
)
|
|
324
|
+
return True
|
|
325
|
+
rprint(
|
|
326
|
+
f"Migration {migration_number} needs to be applied. Run `infrahub db migrate` to apply all outstanding migrations."
|
|
327
|
+
)
|
|
328
|
+
return False
|
|
329
|
+
else:
|
|
330
|
+
migrations = await get_graph_migrations(root=root_node)
|
|
331
|
+
if not migrations:
|
|
332
|
+
rprint(f"Database up-to-date (v{root_node.graph_version}), no migration to execute.")
|
|
333
|
+
return True
|
|
310
334
|
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
335
|
+
rprint(
|
|
336
|
+
f"Database needs to be updated (v{root_node.graph_version} -> v{GRAPH_VERSION}), {len(migrations)} migrations pending"
|
|
337
|
+
)
|
|
314
338
|
|
|
315
339
|
if check:
|
|
316
340
|
return True
|
|
@@ -529,7 +553,7 @@ WITH n, edges + root_edges AS edges, CASE
|
|
|
529
553
|
END AS vertices
|
|
530
554
|
RETURN vertices, edges
|
|
531
555
|
""" % {"id_func": db.get_id_function_name()}
|
|
532
|
-
timestamp_str =
|
|
556
|
+
timestamp_str = get_timestamp_string()
|
|
533
557
|
export_dir /= Path(f"export-{timestamp_str}")
|
|
534
558
|
if not export_dir.exists():
|
|
535
559
|
export_dir.mkdir(parents=True)
|
|
@@ -737,3 +761,160 @@ async def load_export(db: InfrahubDatabase, export_dir: Path, query_limit: int =
|
|
|
737
761
|
await load_edges(db=db, edge_type=edge_type, edge_dicts=edge_dicts)
|
|
738
762
|
rprint("Edges loaded")
|
|
739
763
|
rprint(f"{SUCCESS_BADGE} Export loaded")
|
|
764
|
+
|
|
765
|
+
|
|
766
|
+
@app.command(name="check")
|
|
767
|
+
async def check_cmd(
|
|
768
|
+
ctx: typer.Context,
|
|
769
|
+
output_dir: Path = typer.Option( # noqa: B008
|
|
770
|
+
None, help="Directory to save detailed check results (defaults to infrahub_db_check_YYYYMMDD-HHMMSS)"
|
|
771
|
+
),
|
|
772
|
+
config_file: str = typer.Option(
|
|
773
|
+
"infrahub.toml", envvar="INFRAHUB_CONFIG", help="Location of the configuration file to use for Infrahub"
|
|
774
|
+
),
|
|
775
|
+
) -> None:
|
|
776
|
+
"""Run database sanity checks and output the results to the CSV files."""
|
|
777
|
+
logging.getLogger("infrahub").setLevel(logging.WARNING)
|
|
778
|
+
logging.getLogger("neo4j").setLevel(logging.ERROR)
|
|
779
|
+
logging.getLogger("prefect").setLevel(logging.ERROR)
|
|
780
|
+
|
|
781
|
+
config.load_and_exit(config_file_name=config_file)
|
|
782
|
+
|
|
783
|
+
# Create output directory if not specified
|
|
784
|
+
if output_dir is None:
|
|
785
|
+
timestamp_str = get_timestamp_string()
|
|
786
|
+
output_dir = Path(f"infrahub_db_check_{timestamp_str}")
|
|
787
|
+
|
|
788
|
+
if not output_dir.exists():
|
|
789
|
+
output_dir.mkdir(parents=True)
|
|
790
|
+
|
|
791
|
+
context: CliContext = ctx.obj
|
|
792
|
+
dbdriver = await context.init_db(retry=1)
|
|
793
|
+
|
|
794
|
+
await run_database_checks(db=dbdriver, output_dir=output_dir)
|
|
795
|
+
|
|
796
|
+
await dbdriver.close()
|
|
797
|
+
|
|
798
|
+
|
|
799
|
+
async def run_database_checks(db: InfrahubDatabase, output_dir: Path) -> None:
|
|
800
|
+
"""Run a series of database health checks and output the results to the terminal.
|
|
801
|
+
|
|
802
|
+
Args:
|
|
803
|
+
db: The database object.
|
|
804
|
+
output_dir: Directory to save detailed check results.
|
|
805
|
+
"""
|
|
806
|
+
rprint("Running database health checks...")
|
|
807
|
+
|
|
808
|
+
# Check 1: Duplicate active relationships
|
|
809
|
+
rprint("\n[bold cyan]Check 1: Duplicate Active Relationships[/bold cyan]")
|
|
810
|
+
duplicate_active_rels_query = """
|
|
811
|
+
MATCH (a:Node)-[e1:IS_RELATED {status: "active"}]-(r:Relationship)-[e2:IS_RELATED {branch: e1.branch, status: "active"}]-(b:Node)
|
|
812
|
+
WHERE a.uuid < b.uuid
|
|
813
|
+
AND e1.to IS NULL
|
|
814
|
+
AND e2.to IS NULL
|
|
815
|
+
WITH DISTINCT a.uuid AS a_uuid,
|
|
816
|
+
b.uuid AS b_uuid,
|
|
817
|
+
r.name AS r_name,
|
|
818
|
+
e1.branch AS branch,
|
|
819
|
+
CASE
|
|
820
|
+
WHEN startNode(e1) = a AND startNode(e2) = r THEN "out"
|
|
821
|
+
WHEN startNode(e1) = r AND startNode(e2) = b THEN "in"
|
|
822
|
+
ELSE "bidir"
|
|
823
|
+
END AS direction,
|
|
824
|
+
count(*) AS num_paths,
|
|
825
|
+
collect(DISTINCT a.kind) AS a_kinds,
|
|
826
|
+
collect(DISTINCT b.kind) AS b_kinds
|
|
827
|
+
WHERE num_paths > 1
|
|
828
|
+
RETURN a_uuid, a_kinds, b_uuid, b_kinds, r_name, branch, direction, num_paths
|
|
829
|
+
"""
|
|
830
|
+
|
|
831
|
+
results = await db.execute_query(query=duplicate_active_rels_query)
|
|
832
|
+
if results:
|
|
833
|
+
rprint(f"[red]Found {len(results)} duplicate active relationships[/red]")
|
|
834
|
+
# Write detailed results to file
|
|
835
|
+
output_file = output_dir / "duplicate_active_relationships.csv"
|
|
836
|
+
with output_file.open(mode="w", newline="") as f:
|
|
837
|
+
writer = DictWriter(
|
|
838
|
+
f, fieldnames=["a_uuid", "a_kinds", "b_uuid", "b_kinds", "r_name", "branch", "direction", "num_paths"]
|
|
839
|
+
)
|
|
840
|
+
writer.writeheader()
|
|
841
|
+
for result in results:
|
|
842
|
+
writer.writerow(dict(result))
|
|
843
|
+
rprint(f" Detailed results written to: {output_file}")
|
|
844
|
+
else:
|
|
845
|
+
rprint(f"{SUCCESS_BADGE} No duplicate active relationships found")
|
|
846
|
+
|
|
847
|
+
# Check 2: Duplicated relationship nodes
|
|
848
|
+
rprint("\n[bold cyan]Check 2: Duplicated Relationship Nodes[/bold cyan]")
|
|
849
|
+
duplicate_rel_nodes_query = """
|
|
850
|
+
MATCH (r:Relationship)
|
|
851
|
+
WITH r.uuid AS r_uuid, COUNT(*) AS num_rels
|
|
852
|
+
WHERE num_rels > 1
|
|
853
|
+
MATCH (n:Node)-[:IS_RELATED]-(r:Relationship {uuid: r_uuid})
|
|
854
|
+
WITH DISTINCT r_uuid, n.uuid AS n_uuid, n.kind AS n_kind
|
|
855
|
+
WITH r_uuid, collect([n_uuid, n_kind]) AS node_details
|
|
856
|
+
RETURN r_uuid, node_details
|
|
857
|
+
"""
|
|
858
|
+
|
|
859
|
+
results = await db.execute_query(query=duplicate_rel_nodes_query)
|
|
860
|
+
if results:
|
|
861
|
+
rprint(f"[red]Found {len(results)} duplicated relationship nodes[/red]")
|
|
862
|
+
# Write detailed results to file
|
|
863
|
+
output_file = output_dir / "duplicated_relationship_nodes.csv"
|
|
864
|
+
with output_file.open(mode="w", newline="") as f:
|
|
865
|
+
writer = DictWriter(f, fieldnames=["r_uuid", "node_details"])
|
|
866
|
+
writer.writeheader()
|
|
867
|
+
for result in results:
|
|
868
|
+
writer.writerow(dict(result))
|
|
869
|
+
rprint(f" Detailed results written to: {output_file}")
|
|
870
|
+
else:
|
|
871
|
+
rprint(f"{SUCCESS_BADGE} No duplicated relationship nodes found")
|
|
872
|
+
|
|
873
|
+
# Check 3: Duplicated edges
|
|
874
|
+
rprint("\n[bold cyan]Check 3: Duplicated Edges[/bold cyan]")
|
|
875
|
+
duplicate_edges_query = """
|
|
876
|
+
MATCH (a)
|
|
877
|
+
CALL (a) {
|
|
878
|
+
MATCH (a)-[e]->(b)
|
|
879
|
+
WHERE elementId(a) < elementId(b)
|
|
880
|
+
WITH DISTINCT a, b, type(e) AS e_type, count(*) AS total_num_edges
|
|
881
|
+
WHERE total_num_edges > 1
|
|
882
|
+
MATCH (a)-[e]->(b)
|
|
883
|
+
WHERE type(e) = e_type
|
|
884
|
+
WITH
|
|
885
|
+
elementId(a) AS a_id,
|
|
886
|
+
labels(a) AS a_labels,
|
|
887
|
+
elementId(b) AS b_id,
|
|
888
|
+
labels(b) AS b_labels,
|
|
889
|
+
type(e) AS e_type,
|
|
890
|
+
e.branch AS branch,
|
|
891
|
+
e.status AS status,
|
|
892
|
+
e.from AS time,
|
|
893
|
+
collect(e) AS edges
|
|
894
|
+
WITH a_id, a_labels, b_id, b_labels, e_type, branch, status, time, size(edges) AS num_edges
|
|
895
|
+
WHERE num_edges > 1
|
|
896
|
+
WITH a_id, a_labels, b_id, b_labels, e_type, branch, status, time, num_edges
|
|
897
|
+
RETURN a_id, a_labels, b_id, b_labels, e_type, branch, status, time, num_edges
|
|
898
|
+
}
|
|
899
|
+
RETURN a_id, a_labels, b_id, b_labels, e_type, branch, status, time, num_edges
|
|
900
|
+
"""
|
|
901
|
+
|
|
902
|
+
results = await db.execute_query(query=duplicate_edges_query)
|
|
903
|
+
if results:
|
|
904
|
+
rprint(f"[red]Found {len(results)} sets of duplicated edges[/red]")
|
|
905
|
+
# Write detailed results to file
|
|
906
|
+
output_file = output_dir / "duplicated_edges.csv"
|
|
907
|
+
with output_file.open(mode="w", newline="") as f:
|
|
908
|
+
writer = DictWriter(
|
|
909
|
+
f,
|
|
910
|
+
fieldnames=["a_id", "a_labels", "b_id", "b_labels", "e_type", "branch", "status", "time", "num_edges"],
|
|
911
|
+
)
|
|
912
|
+
writer.writeheader()
|
|
913
|
+
for result in results:
|
|
914
|
+
writer.writerow(dict(result))
|
|
915
|
+
rprint(f" Detailed results written to: {output_file}")
|
|
916
|
+
else:
|
|
917
|
+
rprint(f"{SUCCESS_BADGE} No duplicated edges found")
|
|
918
|
+
|
|
919
|
+
rprint(f"\n{SUCCESS_BADGE} Database health checks completed")
|
|
920
|
+
rprint(f"Detailed results saved to: {output_dir.absolute()}")
|
infrahub/core/branch/models.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import re
|
|
4
|
-
from typing import TYPE_CHECKING, Any, Optional, Union
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Optional, Self, Union
|
|
5
5
|
|
|
6
6
|
from pydantic import Field, field_validator
|
|
7
7
|
|
|
@@ -21,6 +21,8 @@ from infrahub.core.registry import registry
|
|
|
21
21
|
from infrahub.core.timestamp import Timestamp
|
|
22
22
|
from infrahub.exceptions import BranchNotFoundError, InitializationError, ValidationError
|
|
23
23
|
|
|
24
|
+
from .enums import BranchStatus
|
|
25
|
+
|
|
24
26
|
if TYPE_CHECKING:
|
|
25
27
|
from infrahub.database import InfrahubDatabase
|
|
26
28
|
|
|
@@ -29,7 +31,7 @@ class Branch(StandardNode):
|
|
|
29
31
|
name: str = Field(
|
|
30
32
|
max_length=250, min_length=3, description="Name of the branch (git ref standard)", validate_default=True
|
|
31
33
|
)
|
|
32
|
-
status:
|
|
34
|
+
status: BranchStatus = BranchStatus.OPEN
|
|
33
35
|
description: str = ""
|
|
34
36
|
origin_branch: str = "main"
|
|
35
37
|
branched_from: Optional[str] = Field(default=None, validate_default=True)
|
|
@@ -131,14 +133,17 @@ class Branch(StandardNode):
|
|
|
131
133
|
return True
|
|
132
134
|
|
|
133
135
|
@classmethod
|
|
134
|
-
async def get_by_name(cls, name: str, db: InfrahubDatabase) -> Branch:
|
|
136
|
+
async def get_by_name(cls, name: str, db: InfrahubDatabase, ignore_deleting: bool = True) -> Branch:
|
|
135
137
|
query = """
|
|
136
138
|
MATCH (n:Branch)
|
|
137
139
|
WHERE n.name = $name
|
|
140
|
+
AND NOT n.status IN $ignore_statuses
|
|
138
141
|
RETURN n
|
|
139
142
|
"""
|
|
140
143
|
|
|
141
|
-
params = {"name": name}
|
|
144
|
+
params: dict[str, Any] = {"name": name}
|
|
145
|
+
if ignore_deleting:
|
|
146
|
+
params["ignore_statuses"] = [BranchStatus.DELETING.value]
|
|
142
147
|
|
|
143
148
|
results = await db.execute_query(query=query, params=params, name="branch_get_by_name", type=QueryType.READ)
|
|
144
149
|
|
|
@@ -147,6 +152,20 @@ class Branch(StandardNode):
|
|
|
147
152
|
|
|
148
153
|
return cls.from_db(results[0].values()[0])
|
|
149
154
|
|
|
155
|
+
@classmethod
|
|
156
|
+
async def get_list(
|
|
157
|
+
cls,
|
|
158
|
+
db: InfrahubDatabase,
|
|
159
|
+
limit: int = 1000,
|
|
160
|
+
ids: list[str] | None = None,
|
|
161
|
+
name: str | None = None,
|
|
162
|
+
**kwargs: dict[str, Any],
|
|
163
|
+
) -> list[Self]:
|
|
164
|
+
branches = await super().get_list(db=db, limit=limit, ids=ids, name=name, **kwargs)
|
|
165
|
+
branches = [branch for branch in branches if branch.status != BranchStatus.DELETING]
|
|
166
|
+
|
|
167
|
+
return branches
|
|
168
|
+
|
|
150
169
|
@classmethod
|
|
151
170
|
def isinstance(cls, obj: Any) -> bool:
|
|
152
171
|
return isinstance(obj, cls)
|
|
@@ -248,9 +267,13 @@ class Branch(StandardNode):
|
|
|
248
267
|
raise ValidationError(f"Unable to delete {self.name} it is the default branch.")
|
|
249
268
|
if self.is_global:
|
|
250
269
|
raise ValidationError(f"Unable to delete {self.name} this is an internal branch.")
|
|
251
|
-
|
|
270
|
+
|
|
271
|
+
self.status = BranchStatus.DELETING
|
|
272
|
+
await self.save(db=db)
|
|
273
|
+
|
|
252
274
|
query = await DeleteBranchRelationshipsQuery.init(db=db, branch_name=self.name)
|
|
253
275
|
await query.execute(db=db)
|
|
276
|
+
await super().delete(db=db)
|
|
254
277
|
|
|
255
278
|
def get_query_filter_relationships(
|
|
256
279
|
self, rel_labels: list, at: Optional[Union[Timestamp, str]] = None, include_outside_parentheses: bool = False
|
infrahub/core/branch/tasks.py
CHANGED
|
@@ -15,6 +15,7 @@ from infrahub.core.branch import Branch
|
|
|
15
15
|
from infrahub.core.changelog.diff import DiffChangelogCollector, MigrationTracker
|
|
16
16
|
from infrahub.core.constants import MutationAction
|
|
17
17
|
from infrahub.core.diff.coordinator import DiffCoordinator
|
|
18
|
+
from infrahub.core.diff.diff_locker import DiffLocker
|
|
18
19
|
from infrahub.core.diff.ipam_diff_parser import IpamDiffParser
|
|
19
20
|
from infrahub.core.diff.merger.merger import DiffMerger
|
|
20
21
|
from infrahub.core.diff.model.path import BranchTrackingId, EnrichedDiffRoot, EnrichedDiffRootMetadata
|
|
@@ -31,7 +32,7 @@ from infrahub.dependencies.registry import get_component_registry
|
|
|
31
32
|
from infrahub.events.branch_action import BranchCreatedEvent, BranchDeletedEvent, BranchMergedEvent, BranchRebasedEvent
|
|
32
33
|
from infrahub.events.models import EventMeta, InfrahubEvent
|
|
33
34
|
from infrahub.events.node_action import get_node_event
|
|
34
|
-
from infrahub.exceptions import BranchNotFoundError,
|
|
35
|
+
from infrahub.exceptions import BranchNotFoundError, ValidationError
|
|
35
36
|
from infrahub.graphql.mutations.models import BranchCreateModel # noqa: TC001
|
|
36
37
|
from infrahub.services import InfrahubServices # noqa: TC001 needed for prefect flow
|
|
37
38
|
from infrahub.workflows.catalogue import (
|
|
@@ -65,6 +66,7 @@ async def rebase_branch(branch: str, context: InfrahubContext, service: Infrahub
|
|
|
65
66
|
diff_merger=diff_merger,
|
|
66
67
|
diff_repository=diff_repository,
|
|
67
68
|
source_branch=obj,
|
|
69
|
+
diff_locker=DiffLocker(),
|
|
68
70
|
service=service,
|
|
69
71
|
)
|
|
70
72
|
|
|
@@ -221,14 +223,10 @@ async def merge_branch(
|
|
|
221
223
|
diff_merger=diff_merger,
|
|
222
224
|
diff_repository=diff_repository,
|
|
223
225
|
source_branch=obj,
|
|
226
|
+
diff_locker=DiffLocker(),
|
|
224
227
|
service=service,
|
|
225
228
|
)
|
|
226
|
-
|
|
227
|
-
branch_diff = await merger.merge()
|
|
228
|
-
except Exception as exc:
|
|
229
|
-
log.exception("Merge failed, beginning rollback")
|
|
230
|
-
await merger.rollback()
|
|
231
|
-
raise MergeFailedError(branch_name=branch) from exc
|
|
229
|
+
branch_diff = await merger.merge()
|
|
232
230
|
await merger.update_schema()
|
|
233
231
|
|
|
234
232
|
changelog_collector = DiffChangelogCollector(diff=branch_diff, branch=obj, db=db)
|
infrahub/core/diff/calculator.py
CHANGED
|
@@ -181,8 +181,11 @@ class DiffCalculator:
|
|
|
181
181
|
log.info("Diff property-level calculation queries for branch complete")
|
|
182
182
|
|
|
183
183
|
if base_branch.name != diff_branch.name:
|
|
184
|
-
current_node_field_specifiers = diff_parser.get_current_node_field_specifiers()
|
|
185
184
|
new_node_field_specifiers = diff_parser.get_new_node_field_specifiers()
|
|
185
|
+
current_node_field_specifiers = None
|
|
186
|
+
if previous_node_specifiers is not None:
|
|
187
|
+
current_node_field_specifiers = previous_node_specifiers - new_node_field_specifiers
|
|
188
|
+
|
|
186
189
|
base_calculation_request = DiffCalculationRequest(
|
|
187
190
|
base_branch=base_branch,
|
|
188
191
|
diff_branch=base_branch,
|
|
@@ -6,8 +6,7 @@ from uuid import uuid4
|
|
|
6
6
|
|
|
7
7
|
from prefect import flow
|
|
8
8
|
|
|
9
|
-
from infrahub import
|
|
10
|
-
from infrahub.core.branch import Branch # noqa: TC001
|
|
9
|
+
from infrahub.core.branch import Branch
|
|
11
10
|
from infrahub.core.timestamp import Timestamp
|
|
12
11
|
from infrahub.exceptions import ValidationError
|
|
13
12
|
from infrahub.log import get_logger
|
|
@@ -26,12 +25,14 @@ from .model.path import (
|
|
|
26
25
|
|
|
27
26
|
if TYPE_CHECKING:
|
|
28
27
|
from infrahub.core.node import Node
|
|
28
|
+
from infrahub.database import InfrahubDatabase
|
|
29
29
|
|
|
30
30
|
from .calculator import DiffCalculator
|
|
31
31
|
from .combiner import DiffCombiner
|
|
32
32
|
from .conflict_transferer import DiffConflictTransferer
|
|
33
33
|
from .conflicts_enricher import ConflictsEnricher
|
|
34
34
|
from .data_check_synchronizer import DiffDataCheckSynchronizer
|
|
35
|
+
from .diff_locker import DiffLocker
|
|
35
36
|
from .enricher.aggregated import AggregatedDiffEnricher
|
|
36
37
|
from .enricher.labels import DiffLabelsEnricher
|
|
37
38
|
from .repository.repository import DiffRepository
|
|
@@ -59,10 +60,9 @@ class EnrichedDiffRequest:
|
|
|
59
60
|
|
|
60
61
|
|
|
61
62
|
class DiffCoordinator:
|
|
62
|
-
lock_namespace = "diff-update"
|
|
63
|
-
|
|
64
63
|
def __init__(
|
|
65
64
|
self,
|
|
65
|
+
db: InfrahubDatabase,
|
|
66
66
|
diff_repo: DiffRepository,
|
|
67
67
|
diff_calculator: DiffCalculator,
|
|
68
68
|
diff_enricher: AggregatedDiffEnricher,
|
|
@@ -71,7 +71,9 @@ class DiffCoordinator:
|
|
|
71
71
|
labels_enricher: DiffLabelsEnricher,
|
|
72
72
|
data_check_synchronizer: DiffDataCheckSynchronizer,
|
|
73
73
|
conflict_transferer: DiffConflictTransferer,
|
|
74
|
+
diff_locker: DiffLocker,
|
|
74
75
|
) -> None:
|
|
76
|
+
self.db = db
|
|
75
77
|
self.diff_repo = diff_repo
|
|
76
78
|
self.diff_calculator = diff_calculator
|
|
77
79
|
self.diff_enricher = diff_enricher
|
|
@@ -80,7 +82,7 @@ class DiffCoordinator:
|
|
|
80
82
|
self.labels_enricher = labels_enricher
|
|
81
83
|
self.data_check_synchronizer = data_check_synchronizer
|
|
82
84
|
self.conflict_transferer = conflict_transferer
|
|
83
|
-
self.
|
|
85
|
+
self.diff_locker = diff_locker
|
|
84
86
|
|
|
85
87
|
async def run_update(
|
|
86
88
|
self,
|
|
@@ -113,37 +115,35 @@ class DiffCoordinator:
|
|
|
113
115
|
name=name,
|
|
114
116
|
)
|
|
115
117
|
|
|
116
|
-
def _get_lock_name(self, base_branch_name: str, diff_branch_name: str, is_incremental: bool) -> str:
|
|
117
|
-
lock_name = f"{base_branch_name}__{diff_branch_name}"
|
|
118
|
-
if is_incremental:
|
|
119
|
-
lock_name += "__incremental"
|
|
120
|
-
return lock_name
|
|
121
|
-
|
|
122
118
|
async def update_branch_diff(self, base_branch: Branch, diff_branch: Branch) -> EnrichedDiffRootMetadata:
|
|
119
|
+
tracking_id = BranchTrackingId(name=diff_branch.name)
|
|
123
120
|
log.info(f"Received request to update branch diff for {base_branch.name} - {diff_branch.name}")
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
)
|
|
127
|
-
existing_incremental_lock = self.lock_registry.get_existing(
|
|
128
|
-
name=incremental_lock_name, namespace=self.lock_namespace
|
|
121
|
+
existing_incremental_lock = self.diff_locker.get_existing_lock(
|
|
122
|
+
target_branch_name=base_branch.name, source_branch_name=diff_branch.name, is_incremental=True
|
|
129
123
|
)
|
|
130
124
|
if existing_incremental_lock and await existing_incremental_lock.locked():
|
|
131
125
|
log.info(f"Branch diff update for {base_branch.name} - {diff_branch.name} already in progress")
|
|
132
|
-
async with self.
|
|
126
|
+
async with self.diff_locker.acquire_lock(
|
|
127
|
+
target_branch_name=base_branch.name, source_branch_name=diff_branch.name, is_incremental=True
|
|
128
|
+
):
|
|
133
129
|
log.info(f"Existing branch diff update for {base_branch.name} - {diff_branch.name} complete")
|
|
134
|
-
return await self.diff_repo.get_one(
|
|
135
|
-
tracking_id=BranchTrackingId(name=diff_branch.name), diff_branch_name=diff_branch.name
|
|
136
|
-
)
|
|
137
|
-
general_lock_name = self._get_lock_name(
|
|
138
|
-
base_branch_name=base_branch.name, diff_branch_name=diff_branch.name, is_incremental=False
|
|
139
|
-
)
|
|
130
|
+
return await self.diff_repo.get_one(tracking_id=tracking_id, diff_branch_name=diff_branch.name)
|
|
140
131
|
from_time = Timestamp(diff_branch.get_branched_from())
|
|
141
132
|
to_time = Timestamp()
|
|
142
|
-
tracking_id = BranchTrackingId(name=diff_branch.name)
|
|
143
133
|
async with (
|
|
144
|
-
self.
|
|
145
|
-
|
|
134
|
+
self.diff_locker.acquire_lock(
|
|
135
|
+
target_branch_name=base_branch.name, source_branch_name=diff_branch.name, is_incremental=True
|
|
136
|
+
),
|
|
137
|
+
self.diff_locker.acquire_lock(
|
|
138
|
+
target_branch_name=base_branch.name, source_branch_name=diff_branch.name, is_incremental=False
|
|
139
|
+
),
|
|
146
140
|
):
|
|
141
|
+
refreshed_branch = await Branch.get_by_name(db=self.db, name=diff_branch.name)
|
|
142
|
+
if refreshed_branch.get_branched_from() != diff_branch.get_branched_from():
|
|
143
|
+
log.info(
|
|
144
|
+
f"Branch {diff_branch.name} was merged or rebased while waiting for lock, returning latest diff"
|
|
145
|
+
)
|
|
146
|
+
return await self.diff_repo.get_one(tracking_id=tracking_id, diff_branch_name=diff_branch.name)
|
|
147
147
|
log.info(f"Acquired lock to run branch diff update for {base_branch.name} - {diff_branch.name}")
|
|
148
148
|
enriched_diffs, node_identifiers_to_drop = await self._update_diffs(
|
|
149
149
|
base_branch=base_branch,
|
|
@@ -169,10 +169,9 @@ class DiffCoordinator:
|
|
|
169
169
|
name: str,
|
|
170
170
|
) -> EnrichedDiffRootMetadata:
|
|
171
171
|
tracking_id = NameTrackingId(name=name)
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
)
|
|
175
|
-
async with self.lock_registry.get(name=general_lock_name, namespace=self.lock_namespace):
|
|
172
|
+
async with self.diff_locker.acquire_lock(
|
|
173
|
+
target_branch_name=base_branch.name, source_branch_name=diff_branch.name, is_incremental=False
|
|
174
|
+
):
|
|
176
175
|
log.info(f"Acquired lock to run arbitrary diff update for {base_branch.name} - {diff_branch.name}")
|
|
177
176
|
enriched_diffs, node_identifiers_to_drop = await self._update_diffs(
|
|
178
177
|
base_branch=base_branch,
|
|
@@ -196,10 +195,9 @@ class DiffCoordinator:
|
|
|
196
195
|
diff_branch: Branch,
|
|
197
196
|
diff_id: str,
|
|
198
197
|
) -> EnrichedDiffRoot:
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
)
|
|
202
|
-
async with self.lock_registry.get(name=general_lock_name, namespace=self.lock_namespace):
|
|
198
|
+
async with self.diff_locker.acquire_lock(
|
|
199
|
+
target_branch_name=base_branch.name, source_branch_name=diff_branch.name, is_incremental=False
|
|
200
|
+
):
|
|
203
201
|
log.info(f"Acquired lock to recalculate diff for {base_branch.name} - {diff_branch.name}")
|
|
204
202
|
current_branch_diff = await self.diff_repo.get_one(diff_branch_name=diff_branch.name, diff_id=diff_id)
|
|
205
203
|
current_base_diff = await self.diff_repo.get_one(
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from infrahub import lock
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class DiffLocker:
|
|
5
|
+
lock_namespace = "diff-update"
|
|
6
|
+
|
|
7
|
+
def __init__(self) -> None:
|
|
8
|
+
self.lock_registry = lock.registry
|
|
9
|
+
|
|
10
|
+
def get_lock_name(self, base_branch_name: str, diff_branch_name: str, is_incremental: bool) -> str:
|
|
11
|
+
lock_name = f"{base_branch_name}__{diff_branch_name}"
|
|
12
|
+
if is_incremental:
|
|
13
|
+
lock_name += "__incremental"
|
|
14
|
+
return lock_name
|
|
15
|
+
|
|
16
|
+
def get_existing_lock(
|
|
17
|
+
self, target_branch_name: str, source_branch_name: str, is_incremental: bool = False
|
|
18
|
+
) -> lock.InfrahubLock | None:
|
|
19
|
+
name = self.get_lock_name(target_branch_name, source_branch_name, is_incremental)
|
|
20
|
+
return self.lock_registry.get_existing(name=name, namespace=self.lock_namespace)
|
|
21
|
+
|
|
22
|
+
def acquire_lock(
|
|
23
|
+
self, target_branch_name: str, source_branch_name: str, is_incremental: bool = False
|
|
24
|
+
) -> lock.InfrahubLock:
|
|
25
|
+
name = self.get_lock_name(target_branch_name, source_branch_name, is_incremental)
|
|
26
|
+
return self.lock_registry.get(name=name, namespace=self.lock_namespace)
|