infrahub-server 1.3.2__py3-none-any.whl → 1.3.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (35)
  1. infrahub/cli/db.py +194 -13
  2. infrahub/core/branch/enums.py +8 -0
  3. infrahub/core/branch/models.py +28 -5
  4. infrahub/core/branch/tasks.py +5 -7
  5. infrahub/core/diff/coordinator.py +32 -34
  6. infrahub/core/diff/diff_locker.py +26 -0
  7. infrahub/core/graph/__init__.py +1 -1
  8. infrahub/core/initialization.py +4 -3
  9. infrahub/core/merge.py +31 -16
  10. infrahub/core/migrations/graph/__init__.py +24 -0
  11. infrahub/core/migrations/graph/m012_convert_account_generic.py +4 -3
  12. infrahub/core/migrations/graph/m013_convert_git_password_credential.py +4 -3
  13. infrahub/core/migrations/graph/m032_cleanup_orphaned_branch_relationships.py +105 -0
  14. infrahub/core/migrations/graph/m033_deduplicate_relationship_vertices.py +97 -0
  15. infrahub/core/node/__init__.py +3 -0
  16. infrahub/core/node/resource_manager/ip_address_pool.py +5 -3
  17. infrahub/core/node/resource_manager/ip_prefix_pool.py +7 -4
  18. infrahub/core/node/resource_manager/number_pool.py +3 -1
  19. infrahub/core/node/standard.py +4 -0
  20. infrahub/core/query/branch.py +25 -56
  21. infrahub/core/query/node.py +78 -24
  22. infrahub/core/query/relationship.py +11 -8
  23. infrahub/core/relationship/model.py +10 -5
  24. infrahub/dependencies/builder/diff/coordinator.py +3 -0
  25. infrahub/dependencies/builder/diff/locker.py +8 -0
  26. infrahub/graphql/mutations/main.py +7 -2
  27. infrahub/graphql/mutations/tasks.py +2 -0
  28. {infrahub_server-1.3.2.dist-info → infrahub_server-1.3.3.dist-info}/METADATA +1 -1
  29. {infrahub_server-1.3.2.dist-info → infrahub_server-1.3.3.dist-info}/RECORD +35 -30
  30. infrahub_testcontainers/container.py +1 -1
  31. infrahub_testcontainers/docker-compose-cluster.test.yml +3 -0
  32. infrahub_testcontainers/docker-compose.test.yml +1 -0
  33. {infrahub_server-1.3.2.dist-info → infrahub_server-1.3.3.dist-info}/LICENSE.txt +0 -0
  34. {infrahub_server-1.3.2.dist-info → infrahub_server-1.3.3.dist-info}/WHEEL +0 -0
  35. {infrahub_server-1.3.2.dist-info → infrahub_server-1.3.3.dist-info}/entry_points.txt +0 -0
infrahub/cli/db.py CHANGED
@@ -8,7 +8,7 @@ from csv import DictReader, DictWriter
  from datetime import datetime, timezone
  from enum import Enum
  from pathlib import Path
- from typing import TYPE_CHECKING, Any
+ from typing import TYPE_CHECKING, Any, Sequence

  import typer
  import ujson
@@ -38,7 +38,7 @@ from infrahub.core.initialization import (
      initialization,
      initialize_registry,
  )
- from infrahub.core.migrations.graph import get_graph_migrations
+ from infrahub.core.migrations.graph import get_graph_migrations, get_migration_by_number
  from infrahub.core.migrations.schema.models import SchemaApplyMigrationData
  from infrahub.core.migrations.schema.tasks import schema_apply_migrations
  from infrahub.core.schema import SchemaRoot, core_models, internal_schema
@@ -58,8 +58,15 @@ from infrahub.services.adapters.workflow.local import WorkflowLocalExecution
  from .constants import ERROR_BADGE, FAILED_BADGE, SUCCESS_BADGE
  from .patch import patch_app

+
+ def get_timestamp_string() -> str:
+     """Generate a timestamp string in the format YYYYMMDD-HHMMSS."""
+     return datetime.now(tz=timezone.utc).strftime("%Y%m%d-%H%M%S")
+
+
  if TYPE_CHECKING:
      from infrahub.cli.context import CliContext
+     from infrahub.core.migrations.shared import ArbitraryMigration, GraphMigration, InternalSchemaMigration
      from infrahub.database import InfrahubDatabase
      from infrahub.database.index import IndexManagerBase

@@ -154,6 +161,7 @@ async def migrate_cmd(
      ctx: typer.Context,
      check: bool = typer.Option(False, help="Check the state of the database without applying the migrations."),
      config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
+     migration_number: int | None = typer.Option(None, help="Apply a specific migration by number"),
  ) -> None:
      """Check the current format of the internal graph and apply the necessary migrations"""
      logging.getLogger("infrahub").setLevel(logging.WARNING)
@@ -165,7 +173,7 @@ async def migrate_cmd(
      context: CliContext = ctx.obj
      dbdriver = await context.init_db(retry=1)

-     await migrate_database(db=dbdriver, initialize=True, check=check)
+     await migrate_database(db=dbdriver, initialize=True, check=check, migration_number=migration_number)

      await dbdriver.close()

@@ -287,7 +295,9 @@ async def index(
      await dbdriver.close()


- async def migrate_database(db: InfrahubDatabase, initialize: bool = False, check: bool = False) -> bool:
+ async def migrate_database(
+     db: InfrahubDatabase, initialize: bool = False, check: bool = False, migration_number: int | str | None = None
+ ) -> bool:
      """Apply the latest migrations to the database, this function will print the status directly in the console.

      Returns a boolean indicating whether a migration failed or if all migrations succeeded.
@@ -295,6 +305,7 @@ async def migrate_database(db: InfrahubDatabase, initialize: bool = False, check
      Args:
          db: The database object.
          check: If True, the function will only check the status of the database and not apply the migrations. Defaults to False.
+         migration_number: If provided, the function will only apply the migration with the given number. Defaults to None.
      """
      rprint("Checking current state of the Database")

@@ -302,15 +313,28 @@ async def migrate_database(db: InfrahubDatabase, initialize: bool = False, check
          await initialize_registry(db=db)

      root_node = await get_root_node(db=db)
-     migrations = await get_graph_migrations(root=root_node)
-
-     if not migrations:
-         rprint(f"Database up-to-date (v{root_node.graph_version}), no migration to execute.")
-         return True
+     if migration_number:
+         migration = get_migration_by_number(migration_number)
+         migrations: Sequence[GraphMigration | InternalSchemaMigration | ArbitraryMigration] = [migration]
+         if check:
+             if root_node.graph_version > migration.minimum_version:
+                 rprint(
+                     f"Migration {migration_number} already applied. To apply again, run the command without the --check flag."
+                 )
+                 return True
+             rprint(
+                 f"Migration {migration_number} needs to be applied. Run `infrahub db migrate` to apply all outstanding migrations."
+             )
+             return False
+     else:
+         migrations = await get_graph_migrations(root=root_node)
+         if not migrations:
+             rprint(f"Database up-to-date (v{root_node.graph_version}), no migration to execute.")
+             return True

-     rprint(
-         f"Database needs to be updated (v{root_node.graph_version} -> v{GRAPH_VERSION}), {len(migrations)} migrations pending"
-     )
+         rprint(
+             f"Database needs to be updated (v{root_node.graph_version} -> v{GRAPH_VERSION}), {len(migrations)} migrations pending"
+         )

      if check:
          return True
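The branch above gives `migrate_database` a targeted mode: with `migration_number` set it resolves a single migration via `get_migration_by_number`, and with check enabled it only reports whether that migration was already applied, by comparing the root node's `graph_version` to the migration's `minimum_version`. Typer derives the CLI flag from the parameter name, so the invocation should look like `infrahub db migrate --check --migration-number 32` to inspect and `infrahub db migrate --migration-number 32` to apply (the number 32 here is only illustrative).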
@@ -529,7 +553,7 @@ WITH n, edges + root_edges AS edges, CASE
  END AS vertices
  RETURN vertices, edges
  """ % {"id_func": db.get_id_function_name()}
-     timestamp_str = datetime.now(tz=timezone.utc).strftime("%Y%m%d-%H%M%S")
+     timestamp_str = get_timestamp_string()
      export_dir /= Path(f"export-{timestamp_str}")
      if not export_dir.exists():
          export_dir.mkdir(parents=True)
@@ -737,3 +761,160 @@ async def load_export(db: InfrahubDatabase, export_dir: Path, query_limit: int =
          await load_edges(db=db, edge_type=edge_type, edge_dicts=edge_dicts)
      rprint("Edges loaded")
      rprint(f"{SUCCESS_BADGE} Export loaded")
+
+
+ @app.command(name="check")
+ async def check_cmd(
+     ctx: typer.Context,
+     output_dir: Path = typer.Option(  # noqa: B008
+         None, help="Directory to save detailed check results (defaults to infrahub_db_check_YYYYMMDD-HHMMSS)"
+     ),
+     config_file: str = typer.Option(
+         "infrahub.toml", envvar="INFRAHUB_CONFIG", help="Location of the configuration file to use for Infrahub"
+     ),
+ ) -> None:
+     """Run database sanity checks and output the results to the CSV files."""
+     logging.getLogger("infrahub").setLevel(logging.WARNING)
+     logging.getLogger("neo4j").setLevel(logging.ERROR)
+     logging.getLogger("prefect").setLevel(logging.ERROR)
+
+     config.load_and_exit(config_file_name=config_file)
+
+     # Create output directory if not specified
+     if output_dir is None:
+         timestamp_str = get_timestamp_string()
+         output_dir = Path(f"infrahub_db_check_{timestamp_str}")
+
+     if not output_dir.exists():
+         output_dir.mkdir(parents=True)
+
+     context: CliContext = ctx.obj
+     dbdriver = await context.init_db(retry=1)
+
+     await run_database_checks(db=dbdriver, output_dir=output_dir)
+
+     await dbdriver.close()
+
+
+ async def run_database_checks(db: InfrahubDatabase, output_dir: Path) -> None:
+     """Run a series of database health checks and output the results to the terminal.
+
+     Args:
+         db: The database object.
+         output_dir: Directory to save detailed check results.
+     """
+     rprint("Running database health checks...")
+
+     # Check 1: Duplicate active relationships
+     rprint("\n[bold cyan]Check 1: Duplicate Active Relationships[/bold cyan]")
+     duplicate_active_rels_query = """
+     MATCH (a:Node)-[e1:IS_RELATED {status: "active"}]-(r:Relationship)-[e2:IS_RELATED {branch: e1.branch, status: "active"}]-(b:Node)
+     WHERE a.uuid < b.uuid
+     AND e1.to IS NULL
+     AND e2.to IS NULL
+     WITH DISTINCT a.uuid AS a_uuid,
+         b.uuid AS b_uuid,
+         r.name AS r_name,
+         e1.branch AS branch,
+         CASE
+             WHEN startNode(e1) = a AND startNode(e2) = r THEN "out"
+             WHEN startNode(e1) = r AND startNode(e2) = b THEN "in"
+             ELSE "bidir"
+         END AS direction,
+         count(*) AS num_paths,
+         collect(DISTINCT a.kind) AS a_kinds,
+         collect(DISTINCT b.kind) AS b_kinds
+     WHERE num_paths > 1
+     RETURN a_uuid, a_kinds, b_uuid, b_kinds, r_name, branch, direction, num_paths
+     """
+
+     results = await db.execute_query(query=duplicate_active_rels_query)
+     if results:
+         rprint(f"[red]Found {len(results)} duplicate active relationships[/red]")
+         # Write detailed results to file
+         output_file = output_dir / "duplicate_active_relationships.csv"
+         with output_file.open(mode="w", newline="") as f:
+             writer = DictWriter(
+                 f, fieldnames=["a_uuid", "a_kinds", "b_uuid", "b_kinds", "r_name", "branch", "direction", "num_paths"]
+             )
+             writer.writeheader()
+             for result in results:
+                 writer.writerow(dict(result))
+         rprint(f" Detailed results written to: {output_file}")
+     else:
+         rprint(f"{SUCCESS_BADGE} No duplicate active relationships found")
+
+     # Check 2: Duplicated relationship nodes
+     rprint("\n[bold cyan]Check 2: Duplicated Relationship Nodes[/bold cyan]")
+     duplicate_rel_nodes_query = """
+     MATCH (r:Relationship)
+     WITH r.uuid AS r_uuid, COUNT(*) AS num_rels
+     WHERE num_rels > 1
+     MATCH (n:Node)-[:IS_RELATED]-(r:Relationship {uuid: r_uuid})
+     WITH DISTINCT r_uuid, n.uuid AS n_uuid, n.kind AS n_kind
+     WITH r_uuid, collect([n_uuid, n_kind]) AS node_details
+     RETURN r_uuid, node_details
+     """
+
+     results = await db.execute_query(query=duplicate_rel_nodes_query)
+     if results:
+         rprint(f"[red]Found {len(results)} duplicated relationship nodes[/red]")
+         # Write detailed results to file
+         output_file = output_dir / "duplicated_relationship_nodes.csv"
+         with output_file.open(mode="w", newline="") as f:
+             writer = DictWriter(f, fieldnames=["r_uuid", "node_details"])
+             writer.writeheader()
+             for result in results:
+                 writer.writerow(dict(result))
+         rprint(f" Detailed results written to: {output_file}")
+     else:
+         rprint(f"{SUCCESS_BADGE} No duplicated relationship nodes found")
+
+     # Check 3: Duplicated edges
+     rprint("\n[bold cyan]Check 3: Duplicated Edges[/bold cyan]")
+     duplicate_edges_query = """
+     MATCH (a)
+     CALL (a) {
+         MATCH (a)-[e]->(b)
+         WHERE elementId(a) < elementId(b)
+         WITH DISTINCT a, b, type(e) AS e_type, count(*) AS total_num_edges
+         WHERE total_num_edges > 1
+         MATCH (a)-[e]->(b)
+         WHERE type(e) = e_type
+         WITH
+             elementId(a) AS a_id,
+             labels(a) AS a_labels,
+             elementId(b) AS b_id,
+             labels(b) AS b_labels,
+             type(e) AS e_type,
+             e.branch AS branch,
+             e.status AS status,
+             e.from AS time,
+             collect(e) AS edges
+         WITH a_id, a_labels, b_id, b_labels, e_type, branch, status, time, size(edges) AS num_edges
+         WHERE num_edges > 1
+         WITH a_id, a_labels, b_id, b_labels, e_type, branch, status, time, num_edges
+         RETURN a_id, a_labels, b_id, b_labels, e_type, branch, status, time, num_edges
+     }
+     RETURN a_id, a_labels, b_id, b_labels, e_type, branch, status, time, num_edges
+     """
+
+     results = await db.execute_query(query=duplicate_edges_query)
+     if results:
+         rprint(f"[red]Found {len(results)} sets of duplicated edges[/red]")
+         # Write detailed results to file
+         output_file = output_dir / "duplicated_edges.csv"
+         with output_file.open(mode="w", newline="") as f:
+             writer = DictWriter(
+                 f,
+                 fieldnames=["a_id", "a_labels", "b_id", "b_labels", "e_type", "branch", "status", "time", "num_edges"],
+             )
+             writer.writeheader()
+             for result in results:
+                 writer.writerow(dict(result))
+         rprint(f" Detailed results written to: {output_file}")
+     else:
+         rprint(f"{SUCCESS_BADGE} No duplicated edges found")
+
+     rprint(f"\n{SUCCESS_BADGE} Database health checks completed")
+     rprint(f"Detailed results saved to: {output_dir.absolute()}")
infrahub/core/branch/enums.py ADDED
@@ -0,0 +1,8 @@
+ from infrahub.utils import InfrahubStringEnum
+
+
+ class BranchStatus(InfrahubStringEnum):
+     OPEN = "OPEN"
+     NEED_REBASE = "NEED_REBASE"
+     CLOSED = "CLOSED"
+     DELETING = "DELETING"
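A short sketch of how the new statuses read in practice, assuming `InfrahubStringEnum` behaves as a string-backed `Enum` (its `.value` usage in `models.py` below suggests as much):

    from infrahub.core.branch.enums import BranchStatus


    def is_listable(status: BranchStatus) -> bool:
        # Branch.get_list() filters out branches in this state (see models.py below).
        return status != BranchStatus.DELETING


    assert is_listable(BranchStatus.OPEN)
    assert not is_listable(BranchStatus.DELETING)
    assert BranchStatus.DELETING.value == "DELETING"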
infrahub/core/branch/models.py CHANGED
@@ -1,7 +1,7 @@
  from __future__ import annotations

  import re
- from typing import TYPE_CHECKING, Any, Optional, Union
+ from typing import TYPE_CHECKING, Any, Optional, Self, Union

  from pydantic import Field, field_validator

@@ -21,6 +21,8 @@ from infrahub.core.registry import registry
  from infrahub.core.timestamp import Timestamp
  from infrahub.exceptions import BranchNotFoundError, InitializationError, ValidationError

+ from .enums import BranchStatus
+
  if TYPE_CHECKING:
      from infrahub.database import InfrahubDatabase

@@ -29,7 +31,7 @@ class Branch(StandardNode):
      name: str = Field(
          max_length=250, min_length=3, description="Name of the branch (git ref standard)", validate_default=True
      )
-     status: str = "OPEN"  # OPEN, CLOSED
+     status: BranchStatus = BranchStatus.OPEN
      description: str = ""
      origin_branch: str = "main"
      branched_from: Optional[str] = Field(default=None, validate_default=True)
@@ -131,14 +133,17 @@
          return True

      @classmethod
-     async def get_by_name(cls, name: str, db: InfrahubDatabase) -> Branch:
+     async def get_by_name(cls, name: str, db: InfrahubDatabase, ignore_deleting: bool = True) -> Branch:
          query = """
          MATCH (n:Branch)
          WHERE n.name = $name
+         AND NOT n.status IN $ignore_statuses
          RETURN n
          """

-         params = {"name": name}
+         params: dict[str, Any] = {"name": name}
+         if ignore_deleting:
+             params["ignore_statuses"] = [BranchStatus.DELETING.value]

          results = await db.execute_query(query=query, params=params, name="branch_get_by_name", type=QueryType.READ)

@@ -147,6 +152,20 @@
          return cls.from_db(results[0].values()[0])

+     @classmethod
+     async def get_list(
+         cls,
+         db: InfrahubDatabase,
+         limit: int = 1000,
+         ids: list[str] | None = None,
+         name: str | None = None,
+         **kwargs: dict[str, Any],
+     ) -> list[Self]:
+         branches = await super().get_list(db=db, limit=limit, ids=ids, name=name, **kwargs)
+         branches = [branch for branch in branches if branch.status != BranchStatus.DELETING]
+
+         return branches
+
      @classmethod
      def isinstance(cls, obj: Any) -> bool:
          return isinstance(obj, cls)
@@ -248,9 +267,13 @@
              raise ValidationError(f"Unable to delete {self.name} it is the default branch.")
          if self.is_global:
              raise ValidationError(f"Unable to delete {self.name} this is an internal branch.")
-         await super().delete(db=db)
+
+         self.status = BranchStatus.DELETING
+         await self.save(db=db)
+
          query = await DeleteBranchRelationshipsQuery.init(db=db, branch_name=self.name)
          await query.execute(db=db)
+         await super().delete(db=db)

      def get_query_filter_relationships(
          self, rel_labels: list, at: Optional[Union[Timestamp, str]] = None, include_outside_parentheses: bool = False
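Deletion is now two-phase: the branch is first flagged `DELETING` and saved, its relationships are removed, and only then is the branch node itself deleted. While that runs, the branch stays hidden from lookups; a sketch of the caller-facing effect:

    from infrahub.core.branch import Branch


    async def visible_branches(db):
        # get_list() drops branches whose status is DELETING, so a branch
        # mid-deletion no longer appears in listings; get_by_name() likewise
        # skips it by default (ignore_deleting=True).
        return await Branch.get_list(db=db)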
infrahub/core/branch/tasks.py CHANGED
@@ -15,6 +15,7 @@ from infrahub.core.branch import Branch
  from infrahub.core.changelog.diff import DiffChangelogCollector, MigrationTracker
  from infrahub.core.constants import MutationAction
  from infrahub.core.diff.coordinator import DiffCoordinator
+ from infrahub.core.diff.diff_locker import DiffLocker
  from infrahub.core.diff.ipam_diff_parser import IpamDiffParser
  from infrahub.core.diff.merger.merger import DiffMerger
  from infrahub.core.diff.model.path import BranchTrackingId, EnrichedDiffRoot, EnrichedDiffRootMetadata
@@ -31,7 +32,7 @@ from infrahub.dependencies.registry import get_component_registry
  from infrahub.events.branch_action import BranchCreatedEvent, BranchDeletedEvent, BranchMergedEvent, BranchRebasedEvent
  from infrahub.events.models import EventMeta, InfrahubEvent
  from infrahub.events.node_action import get_node_event
- from infrahub.exceptions import BranchNotFoundError, MergeFailedError, ValidationError
+ from infrahub.exceptions import BranchNotFoundError, ValidationError
  from infrahub.graphql.mutations.models import BranchCreateModel  # noqa: TC001
  from infrahub.services import InfrahubServices  # noqa: TC001 needed for prefect flow
  from infrahub.workflows.catalogue import (
@@ -65,6 +66,7 @@ async def rebase_branch(branch: str, context: InfrahubContext, service: Infrahub
          diff_merger=diff_merger,
          diff_repository=diff_repository,
          source_branch=obj,
+         diff_locker=DiffLocker(),
          service=service,
      )

@@ -221,14 +223,10 @@ async def merge_branch(
          diff_merger=diff_merger,
          diff_repository=diff_repository,
          source_branch=obj,
+         diff_locker=DiffLocker(),
          service=service,
      )
-     try:
-         branch_diff = await merger.merge()
-     except Exception as exc:
-         log.exception("Merge failed, beginning rollback")
-         await merger.rollback()
-         raise MergeFailedError(branch_name=branch) from exc
+     branch_diff = await merger.merge()
      await merger.update_schema()

      changelog_collector = DiffChangelogCollector(diff=branch_diff, branch=obj, db=db)
infrahub/core/diff/coordinator.py CHANGED
@@ -6,8 +6,7 @@ from uuid import uuid4

  from prefect import flow

- from infrahub import lock
- from infrahub.core.branch import Branch  # noqa: TC001
+ from infrahub.core.branch import Branch
  from infrahub.core.timestamp import Timestamp
  from infrahub.exceptions import ValidationError
  from infrahub.log import get_logger
@@ -26,12 +25,14 @@ from .model.path import (

  if TYPE_CHECKING:
      from infrahub.core.node import Node
+     from infrahub.database import InfrahubDatabase

      from .calculator import DiffCalculator
      from .combiner import DiffCombiner
      from .conflict_transferer import DiffConflictTransferer
      from .conflicts_enricher import ConflictsEnricher
      from .data_check_synchronizer import DiffDataCheckSynchronizer
+     from .diff_locker import DiffLocker
      from .enricher.aggregated import AggregatedDiffEnricher
      from .enricher.labels import DiffLabelsEnricher
      from .repository.repository import DiffRepository
@@ -59,10 +60,9 @@


  class DiffCoordinator:
-     lock_namespace = "diff-update"
-
      def __init__(
          self,
+         db: InfrahubDatabase,
          diff_repo: DiffRepository,
          diff_calculator: DiffCalculator,
          diff_enricher: AggregatedDiffEnricher,
@@ -71,7 +71,9 @@ class DiffCoordinator:
          labels_enricher: DiffLabelsEnricher,
          data_check_synchronizer: DiffDataCheckSynchronizer,
          conflict_transferer: DiffConflictTransferer,
+         diff_locker: DiffLocker,
      ) -> None:
+         self.db = db
          self.diff_repo = diff_repo
          self.diff_calculator = diff_calculator
          self.diff_enricher = diff_enricher
@@ -80,7 +82,7 @@
          self.labels_enricher = labels_enricher
          self.data_check_synchronizer = data_check_synchronizer
          self.conflict_transferer = conflict_transferer
-         self.lock_registry = lock.registry
+         self.diff_locker = diff_locker

      async def run_update(
          self,
@@ -113,37 +115,35 @@
              name=name,
          )

-     def _get_lock_name(self, base_branch_name: str, diff_branch_name: str, is_incremental: bool) -> str:
-         lock_name = f"{base_branch_name}__{diff_branch_name}"
-         if is_incremental:
-             lock_name += "__incremental"
-         return lock_name
-
      async def update_branch_diff(self, base_branch: Branch, diff_branch: Branch) -> EnrichedDiffRootMetadata:
+         tracking_id = BranchTrackingId(name=diff_branch.name)
          log.info(f"Received request to update branch diff for {base_branch.name} - {diff_branch.name}")
-         incremental_lock_name = self._get_lock_name(
-             base_branch_name=base_branch.name, diff_branch_name=diff_branch.name, is_incremental=True
-         )
-         existing_incremental_lock = self.lock_registry.get_existing(
-             name=incremental_lock_name, namespace=self.lock_namespace
+         existing_incremental_lock = self.diff_locker.get_existing_lock(
+             target_branch_name=base_branch.name, source_branch_name=diff_branch.name, is_incremental=True
          )
          if existing_incremental_lock and await existing_incremental_lock.locked():
              log.info(f"Branch diff update for {base_branch.name} - {diff_branch.name} already in progress")
-             async with self.lock_registry.get(name=incremental_lock_name, namespace=self.lock_namespace):
+             async with self.diff_locker.acquire_lock(
+                 target_branch_name=base_branch.name, source_branch_name=diff_branch.name, is_incremental=True
+             ):
                  log.info(f"Existing branch diff update for {base_branch.name} - {diff_branch.name} complete")
-                 return await self.diff_repo.get_one(
-                     tracking_id=BranchTrackingId(name=diff_branch.name), diff_branch_name=diff_branch.name
-                 )
-         general_lock_name = self._get_lock_name(
-             base_branch_name=base_branch.name, diff_branch_name=diff_branch.name, is_incremental=False
-         )
+                 return await self.diff_repo.get_one(tracking_id=tracking_id, diff_branch_name=diff_branch.name)
          from_time = Timestamp(diff_branch.get_branched_from())
          to_time = Timestamp()
-         tracking_id = BranchTrackingId(name=diff_branch.name)
          async with (
-             self.lock_registry.get(name=general_lock_name, namespace=self.lock_namespace),
-             self.lock_registry.get(name=incremental_lock_name, namespace=self.lock_namespace),
+             self.diff_locker.acquire_lock(
+                 target_branch_name=base_branch.name, source_branch_name=diff_branch.name, is_incremental=True
+             ),
+             self.diff_locker.acquire_lock(
+                 target_branch_name=base_branch.name, source_branch_name=diff_branch.name, is_incremental=False
+             ),
          ):
+             refreshed_branch = await Branch.get_by_name(db=self.db, name=diff_branch.name)
+             if refreshed_branch.get_branched_from() != diff_branch.get_branched_from():
+                 log.info(
+                     f"Branch {diff_branch.name} was merged or rebased while waiting for lock, returning latest diff"
+                 )
+                 return await self.diff_repo.get_one(tracking_id=tracking_id, diff_branch_name=diff_branch.name)
              log.info(f"Acquired lock to run branch diff update for {base_branch.name} - {diff_branch.name}")
              enriched_diffs, node_identifiers_to_drop = await self._update_diffs(
                  base_branch=base_branch,
@@ -169,10 +169,9 @@
          name: str,
      ) -> EnrichedDiffRootMetadata:
          tracking_id = NameTrackingId(name=name)
-         general_lock_name = self._get_lock_name(
-             base_branch_name=base_branch.name, diff_branch_name=diff_branch.name, is_incremental=False
-         )
-         async with self.lock_registry.get(name=general_lock_name, namespace=self.lock_namespace):
+         async with self.diff_locker.acquire_lock(
+             target_branch_name=base_branch.name, source_branch_name=diff_branch.name, is_incremental=False
+         ):
              log.info(f"Acquired lock to run arbitrary diff update for {base_branch.name} - {diff_branch.name}")
              enriched_diffs, node_identifiers_to_drop = await self._update_diffs(
                  base_branch=base_branch,
@@ -196,10 +195,9 @@
          diff_branch: Branch,
          diff_id: str,
      ) -> EnrichedDiffRoot:
-         general_lock_name = self._get_lock_name(
-             base_branch_name=base_branch.name, diff_branch_name=diff_branch.name, is_incremental=False
-         )
-         async with self.lock_registry.get(name=general_lock_name, namespace=self.lock_namespace):
+         async with self.diff_locker.acquire_lock(
+             target_branch_name=base_branch.name, source_branch_name=diff_branch.name, is_incremental=False
+         ):
              log.info(f"Acquired lock to recalculate diff for {base_branch.name} - {diff_branch.name}")
              current_branch_diff = await self.diff_repo.get_one(diff_branch_name=diff_branch.name, diff_id=diff_id)
              current_base_diff = await self.diff_repo.get_one(
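The `refreshed_branch` check in `update_branch_diff` is a recheck-under-lock: after waiting on the incremental lock, the coordinator re-reads the branch (hence the new `db` dependency) and returns the stored diff if `branched_from` moved, meaning the branch was merged or rebased in the meantime. Stripped of Infrahub specifics, the pattern looks like this (all names illustrative):

    async def update_once(lock_cm, fetch_version, last_seen_version, compute, load_cached):
        async with lock_cm:
            if await fetch_version() != last_seen_version:
                # Another worker finished while we waited; reuse its result.
                return await load_cached()
            return await compute()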
infrahub/core/diff/diff_locker.py ADDED
@@ -0,0 +1,26 @@
+ from infrahub import lock
+
+
+ class DiffLocker:
+     lock_namespace = "diff-update"
+
+     def __init__(self) -> None:
+         self.lock_registry = lock.registry
+
+     def get_lock_name(self, base_branch_name: str, diff_branch_name: str, is_incremental: bool) -> str:
+         lock_name = f"{base_branch_name}__{diff_branch_name}"
+         if is_incremental:
+             lock_name += "__incremental"
+         return lock_name
+
+     def get_existing_lock(
+         self, target_branch_name: str, source_branch_name: str, is_incremental: bool = False
+     ) -> lock.InfrahubLock | None:
+         name = self.get_lock_name(target_branch_name, source_branch_name, is_incremental)
+         return self.lock_registry.get_existing(name=name, namespace=self.lock_namespace)
+
+     def acquire_lock(
+         self, target_branch_name: str, source_branch_name: str, is_incremental: bool = False
+     ) -> lock.InfrahubLock:
+         name = self.get_lock_name(target_branch_name, source_branch_name, is_incremental)
+         return self.lock_registry.get(name=name, namespace=self.lock_namespace)
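`DiffCoordinator` consumes this class as shown above; a condensed sketch of the locking contract (branch names illustrative):

    from infrahub.core.diff.diff_locker import DiffLocker


    async def guarded_update() -> None:
        locker = DiffLocker()
        # One lock per (target, source) pair; incremental updates use a separate
        # lock whose name carries an "__incremental" suffix.
        existing = locker.get_existing_lock(
            target_branch_name="main", source_branch_name="feature-1", is_incremental=True
        )
        if existing and await existing.locked():
            ...  # an update is already running; callers may wait on the lock instead
        async with locker.acquire_lock(target_branch_name="main", source_branch_name="feature-1"):
            ...  # exclusive section: recompute or persist the diff here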
infrahub/core/graph/__init__.py CHANGED
@@ -1 +1 @@
- GRAPH_VERSION = 31
+ GRAPH_VERSION = 33
infrahub/core/initialization.py CHANGED
@@ -6,6 +6,7 @@ from infrahub import config, lock
  from infrahub.constants.database import DatabaseType
  from infrahub.core import registry
  from infrahub.core.branch import Branch
+ from infrahub.core.branch.enums import BranchStatus
  from infrahub.core.constants import (
      DEFAULT_IP_NAMESPACE,
      GLOBAL_BRANCH_NAME,
@@ -224,7 +225,7 @@ async def create_root_node(db: InfrahubDatabase) -> Root:
  async def create_default_branch(db: InfrahubDatabase) -> Branch:
      branch = Branch(
          name=registry.default_branch,
-         status="OPEN",
+         status=BranchStatus.OPEN,
          description="Default Branch",
          hierarchy_level=1,
          is_default=True,
@@ -241,7 +242,7 @@ async def create_default_branch(db: InfrahubDatabase) -> Branch:
  async def create_global_branch(db: InfrahubDatabase) -> Branch:
      branch = Branch(
          name=GLOBAL_BRANCH_NAME,
-         status="OPEN",
+         status=BranchStatus.OPEN,
          description="Global Branch",
          hierarchy_level=1,
          is_global=True,
@@ -264,7 +265,7 @@ async def create_branch(
      description = description or f"Branch {branch_name}"
      branch = Branch(
          name=branch_name,
-         status="OPEN",
+         status=BranchStatus.OPEN,
          hierarchy_level=2,
          description=description,
          is_default=False,