infrahub-server 1.3.2__py3-none-any.whl → 1.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. infrahub/cli/db.py +194 -13
  2. infrahub/core/branch/enums.py +8 -0
  3. infrahub/core/branch/models.py +28 -5
  4. infrahub/core/branch/tasks.py +5 -7
  5. infrahub/core/diff/coordinator.py +32 -34
  6. infrahub/core/diff/diff_locker.py +26 -0
  7. infrahub/core/graph/__init__.py +1 -1
  8. infrahub/core/initialization.py +4 -3
  9. infrahub/core/merge.py +31 -16
  10. infrahub/core/migrations/graph/__init__.py +24 -0
  11. infrahub/core/migrations/graph/m012_convert_account_generic.py +4 -3
  12. infrahub/core/migrations/graph/m013_convert_git_password_credential.py +4 -3
  13. infrahub/core/migrations/graph/m032_cleanup_orphaned_branch_relationships.py +105 -0
  14. infrahub/core/migrations/graph/m033_deduplicate_relationship_vertices.py +97 -0
  15. infrahub/core/node/__init__.py +3 -0
  16. infrahub/core/node/resource_manager/ip_address_pool.py +5 -3
  17. infrahub/core/node/resource_manager/ip_prefix_pool.py +7 -4
  18. infrahub/core/node/resource_manager/number_pool.py +3 -1
  19. infrahub/core/node/standard.py +4 -0
  20. infrahub/core/query/branch.py +25 -56
  21. infrahub/core/query/node.py +78 -24
  22. infrahub/core/query/relationship.py +11 -8
  23. infrahub/core/relationship/model.py +10 -5
  24. infrahub/dependencies/builder/diff/coordinator.py +3 -0
  25. infrahub/dependencies/builder/diff/locker.py +8 -0
  26. infrahub/graphql/mutations/main.py +7 -2
  27. infrahub/graphql/mutations/tasks.py +2 -0
  28. {infrahub_server-1.3.2.dist-info → infrahub_server-1.3.3.dist-info}/METADATA +1 -1
  29. {infrahub_server-1.3.2.dist-info → infrahub_server-1.3.3.dist-info}/RECORD +35 -30
  30. infrahub_testcontainers/container.py +1 -1
  31. infrahub_testcontainers/docker-compose-cluster.test.yml +3 -0
  32. infrahub_testcontainers/docker-compose.test.yml +1 -0
  33. {infrahub_server-1.3.2.dist-info → infrahub_server-1.3.3.dist-info}/LICENSE.txt +0 -0
  34. {infrahub_server-1.3.2.dist-info → infrahub_server-1.3.3.dist-info}/WHEEL +0 -0
  35. {infrahub_server-1.3.2.dist-info → infrahub_server-1.3.3.dist-info}/entry_points.txt +0 -0
infrahub/core/merge.py CHANGED
@@ -9,7 +9,7 @@ from infrahub.core.models import SchemaUpdateValidationResult
9
9
  from infrahub.core.protocols import CoreRepository
10
10
  from infrahub.core.registry import registry
11
11
  from infrahub.core.timestamp import Timestamp
12
- from infrahub.exceptions import ValidationError
12
+ from infrahub.exceptions import MergeFailedError, ValidationError
13
13
  from infrahub.log import get_logger
14
14
 
15
15
  from ..git.models import GitRepositoryMerge
@@ -18,6 +18,7 @@ from ..workflows.catalogue import GIT_REPOSITORIES_MERGE
18
18
  if TYPE_CHECKING:
19
19
  from infrahub.core.branch import Branch
20
20
  from infrahub.core.diff.coordinator import DiffCoordinator
21
+ from infrahub.core.diff.diff_locker import DiffLocker
21
22
  from infrahub.core.diff.merger.merger import DiffMerger
22
23
  from infrahub.core.diff.model.path import EnrichedDiffRoot
23
24
  from infrahub.core.diff.repository.repository import DiffRepository
@@ -39,6 +40,7 @@ class BranchMerger:
39
40
  diff_coordinator: DiffCoordinator,
40
41
  diff_merger: DiffMerger,
41
42
  diff_repository: DiffRepository,
43
+ diff_locker: DiffLocker,
42
44
  destination_branch: Branch | None = None,
43
45
  service: InfrahubServices | None = None,
44
46
  ):
@@ -48,6 +50,7 @@ class BranchMerger:
48
50
  self.diff_coordinator = diff_coordinator
49
51
  self.diff_merger = diff_merger
50
52
  self.diff_repository = diff_repository
53
+ self.diff_locker = diff_locker
51
54
  self.migrations: list[SchemaUpdateMigrationInfo] = []
52
55
  self._merge_at = Timestamp()
53
56
 
@@ -185,22 +188,34 @@ class BranchMerger:
185
188
  )
186
189
  log.info("Diff updated for merge")
187
190
 
188
- errors: list[str] = []
189
- async for conflict_path, conflict in self.diff_repository.get_all_conflicts_for_diff(
190
- diff_branch_name=self.source_branch.name, tracking_id=BranchTrackingId(name=self.source_branch.name)
191
+ log.info("Acquiring lock for merge")
192
+ async with self.diff_locker.acquire_lock(
193
+ target_branch_name=self.destination_branch.name,
194
+ source_branch_name=self.source_branch.name,
195
+ is_incremental=False,
191
196
  ):
192
- if conflict.selected_branch is None or conflict.resolvable is False:
193
- errors.append(conflict_path)
194
-
195
- if errors:
196
- raise ValidationError(
197
- f"Unable to merge the branch '{self.source_branch.name}', conflict resolution missing: {', '.join(errors)}"
198
- )
199
-
200
- # TODO need to find a way to properly communicate back to the user any issue that could come up during the merge
201
- # From the Graph or From the repositories
202
- self._merge_at = Timestamp(at)
203
- branch_diff = await self.diff_merger.merge_graph(at=self._merge_at)
197
+ log.info("Lock acquired for merge")
198
+ try:
199
+ errors: list[str] = []
200
+ async for conflict_path, conflict in self.diff_repository.get_all_conflicts_for_diff(
201
+ diff_branch_name=self.source_branch.name, tracking_id=BranchTrackingId(name=self.source_branch.name)
202
+ ):
203
+ if conflict.selected_branch is None or conflict.resolvable is False:
204
+ errors.append(conflict_path)
205
+
206
+ if errors:
207
+ raise ValidationError(
208
+ f"Unable to merge the branch '{self.source_branch.name}', conflict resolution missing: {', '.join(errors)}"
209
+ )
210
+
211
+ # TODO need to find a way to properly communicate back to the user any issue that could come up during the merge
212
+ # From the Graph or From the repositories
213
+ self._merge_at = Timestamp(at)
214
+ branch_diff = await self.diff_merger.merge_graph(at=self._merge_at)
215
+ except Exception as exc:
216
+ log.exception("Merge failed, beginning rollback")
217
+ await self.rollback()
218
+ raise MergeFailedError(branch_name=self.source_branch.name) from exc
204
219
  await self.merge_repositories()
205
220
  return branch_diff
206
221
 
@@ -33,6 +33,8 @@ from .m028_delete_diffs import Migration028
33
33
  from .m029_duplicates_cleanup import Migration029
34
34
  from .m030_illegal_edges import Migration030
35
35
  from .m031_check_number_attributes import Migration031
36
+ from .m032_cleanup_orphaned_branch_relationships import Migration032
37
+ from .m033_deduplicate_relationship_vertices import Migration033
36
38
 
37
39
  if TYPE_CHECKING:
38
40
  from infrahub.core.root import Root
@@ -71,6 +73,8 @@ MIGRATIONS: list[type[GraphMigration | InternalSchemaMigration | ArbitraryMigrat
71
73
  Migration029,
72
74
  Migration030,
73
75
  Migration031,
76
+ Migration032,
77
+ Migration033,
74
78
  ]
75
79
 
76
80
 
@@ -85,3 +89,23 @@ async def get_graph_migrations(
85
89
  applicable_migrations.append(migration)
86
90
 
87
91
  return applicable_migrations
92
+
93
+
94
+ def get_migration_by_number(
95
+ migration_number: int | str,
96
+ ) -> GraphMigration | InternalSchemaMigration | ArbitraryMigration:
97
+ # Convert to string and pad with zeros if needed
98
+ try:
99
+ num = int(migration_number)
100
+ migration_str = f"{num:03d}"
101
+ except (ValueError, TypeError) as exc:
102
+ raise ValueError(f"Invalid migration number: {migration_number}") from exc
103
+
104
+ migration_name = f"Migration{migration_str}"
105
+
106
+ # Find the migration in the MIGRATIONS list
107
+ for migration_class in MIGRATIONS:
108
+ if migration_class.__name__ == migration_name:
109
+ return migration_class.init()
110
+
111
+ raise ValueError(f"Migration {migration_number} not found")
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  from typing import TYPE_CHECKING, Any, Sequence
4
4
 
5
5
  from infrahub.core.branch import Branch
6
+ from infrahub.core.branch.enums import BranchStatus
6
7
  from infrahub.core.constants import GLOBAL_BRANCH_NAME, BranchSupportType, InfrahubKind
7
8
  from infrahub.core.migrations.shared import MigrationResult
8
9
  from infrahub.core.query import Query, QueryType
@@ -20,7 +21,7 @@ if TYPE_CHECKING:
20
21
 
21
22
  global_branch = Branch(
22
23
  name=GLOBAL_BRANCH_NAME,
23
- status="OPEN",
24
+ status=BranchStatus.OPEN,
24
25
  description="Global Branch",
25
26
  hierarchy_level=1,
26
27
  is_global=True,
@@ -29,7 +30,7 @@ global_branch = Branch(
29
30
 
30
31
  default_branch = Branch(
31
32
  name="main",
32
- status="OPEN",
33
+ status=BranchStatus.OPEN,
33
34
  description="Default Branch",
34
35
  hierarchy_level=1,
35
36
  is_global=False,
@@ -105,7 +106,7 @@ class Migration012AddLabelData(NodeDuplicateQuery):
105
106
 
106
107
  branch = Branch(
107
108
  name=GLOBAL_BRANCH_NAME,
108
- status="OPEN",
109
+ status=BranchStatus.OPEN,
109
110
  description="Global Branch",
110
111
  hierarchy_level=1,
111
112
  is_global=True,
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  from typing import TYPE_CHECKING, Any, Sequence
4
4
 
5
5
  from infrahub.core.branch import Branch
6
+ from infrahub.core.branch.enums import BranchStatus
6
7
  from infrahub.core.constants import (
7
8
  GLOBAL_BRANCH_NAME,
8
9
  BranchSupportType,
@@ -23,7 +24,7 @@ if TYPE_CHECKING:
23
24
 
24
25
  default_branch = Branch(
25
26
  name="main",
26
- status="OPEN",
27
+ status=BranchStatus.OPEN,
27
28
  description="Default Branch",
28
29
  hierarchy_level=1,
29
30
  is_global=False,
@@ -42,7 +43,7 @@ class Migration013ConvertCoreRepositoryWithCred(Query):
42
43
 
43
44
  global_branch = Branch(
44
45
  name=GLOBAL_BRANCH_NAME,
45
- status="OPEN",
46
+ status=BranchStatus.OPEN,
46
47
  description="Global Branch",
47
48
  hierarchy_level=1,
48
49
  is_global=True,
@@ -176,7 +177,7 @@ class Migration013ConvertCoreRepositoryWithoutCred(Query):
176
177
 
177
178
  global_branch = Branch(
178
179
  name=GLOBAL_BRANCH_NAME,
179
- status="OPEN",
180
+ status=BranchStatus.OPEN,
180
181
  description="Global Branch",
181
182
  hierarchy_level=1,
182
183
  is_global=True,
@@ -0,0 +1,105 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ from infrahub.core.migrations.shared import MigrationResult
6
+ from infrahub.core.query import Query, QueryType
7
+ from infrahub.core.query.branch import DeleteBranchRelationshipsQuery
8
+ from infrahub.log import get_logger
9
+
10
+ from ..shared import ArbitraryMigration
11
+
12
+ if TYPE_CHECKING:
13
+ from infrahub.database import InfrahubDatabase
14
+
15
+ log = get_logger()
16
+
17
+
18
+ class DeletedBranchCleanupQuery(Query):
19
+ """
20
+ Find all unique edge branch names for which there is no Branch object
21
+ """
22
+
23
+ name = "deleted_branch_cleanup"
24
+ type = QueryType.WRITE
25
+ insert_return = False
26
+
27
+ async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None: # noqa: ARG002
28
+ query = """
29
+ MATCH (b:Branch)
30
+ WITH collect(DISTINCT b.name) AS branch_names
31
+ MATCH ()-[e]->()
32
+ WHERE e.branch IS NOT NULL
33
+ AND NOT e.branch IN branch_names
34
+ RETURN DISTINCT (e.branch) AS branch_name
35
+ """
36
+ self.add_to_query(query)
37
+ self.return_labels = ["branch_name"]
38
+
39
+
40
+ class DeleteOrphanRelationshipsQuery(Query):
41
+ """
42
+ Find all Relationship vertices that link to fewer than 2 Node vertices and delete them
43
+ """
44
+
45
+ name = "delete_orphan_relationships"
46
+ type = QueryType.WRITE
47
+ insert_return = False
48
+
49
+ async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None: # noqa: ARG002
50
+ query = """
51
+ MATCH (r:Relationship)-[:IS_RELATED]-(n:Node)
52
+ WITH DISTINCT r, n
53
+ WITH r, count(*) AS node_count
54
+ WHERE node_count < 2
55
+ DETACH DELETE r
56
+ """
57
+ self.add_to_query(query)
58
+
59
+
60
+ class Migration032(ArbitraryMigration):
61
+ """
62
+ Delete edges for branches that were not completely deleted
63
+ """
64
+
65
+ name: str = "032_cleanup_deleted_branches"
66
+ minimum_version: int = 31
67
+
68
+ async def validate_migration(self, db: InfrahubDatabase) -> MigrationResult: # noqa: ARG002
69
+ return MigrationResult()
70
+
71
+ async def execute(self, db: InfrahubDatabase) -> MigrationResult:
72
+ migration_result = MigrationResult()
73
+
74
+ try:
75
+ log.info("Get partially deleted branch names...")
76
+ orphaned_branches_query = await DeletedBranchCleanupQuery.init(db=db)
77
+ await orphaned_branches_query.execute(db=db)
78
+
79
+ orphaned_branch_names = []
80
+ for result in orphaned_branches_query.get_results():
81
+ branch_name = result.get_as_type("branch_name", str)
82
+ orphaned_branch_names.append(branch_name)
83
+
84
+ if not orphaned_branch_names:
85
+ log.info("No partially deleted branches found. All done.")
86
+ return migration_result
87
+
88
+ log.info(f"Found {len(orphaned_branch_names)} orphaned branch names: {orphaned_branch_names}")
89
+
90
+ for branch_name in orphaned_branch_names:
91
+ log.info(f"Cleaning up branch '{branch_name}'...")
92
+ delete_query = await DeleteBranchRelationshipsQuery.init(db=db, branch_name=branch_name)
93
+ await delete_query.execute(db=db)
94
+ log.info(f"Branch '{branch_name}' cleaned up.")
95
+
96
+ log.info("Deleting orphaned relationships...")
97
+ delete_relationships_query = await DeleteOrphanRelationshipsQuery.init(db=db)
98
+ await delete_relationships_query.execute(db=db)
99
+ log.info("Orphaned relationships deleted.")
100
+
101
+ except Exception as exc:
102
+ migration_result.errors.append(str(exc))
103
+ log.exception("Error during branch cleanup")
104
+
105
+ return migration_result
@@ -0,0 +1,97 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any, Sequence
4
+
5
+ from infrahub.core.migrations.shared import GraphMigration, MigrationResult
6
+ from infrahub.log import get_logger
7
+
8
+ from ...query import Query, QueryType
9
+
10
+ if TYPE_CHECKING:
11
+ from infrahub.database import InfrahubDatabase
12
+
13
+ log = get_logger()
14
+
15
+
16
+ class DeduplicateRelationshipVerticesQuery(Query):
17
+ """
18
+ For each group of duplicate Relationships with the same UUID, delete any Relationship that meets the following criteria:
19
+ - is linked to a deleted node (only if the delete time is before the Relationship's from time)
20
+ - is linked to a node on an incorrect branch (ie Relationship added on main, but Node is on a branch)
21
+ """
22
+
23
+ name = "deduplicate_relationship_vertices"
24
+ type = QueryType.WRITE
25
+ insert_return = False
26
+
27
+ async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None: # noqa: ARG002
28
+ query = """
29
+ MATCH (root:Root)
30
+ WITH root.default_branch AS default_branch_name
31
+ // ------------
32
+ // Find all Relationship vertices with duplicate UUIDs
33
+ // ------------
34
+ MATCH (r:Relationship)
35
+ WITH r.uuid AS r_uuid, default_branch_name, count(*) AS num_dups
36
+ WHERE num_dups > 1
37
+ WITH DISTINCT r_uuid, default_branch_name
38
+ // ------------
39
+ // get the branched_from time for each relationship edge and node
40
+ // ------------
41
+ MATCH (rel:Relationship {uuid: r_uuid})
42
+ CALL (rel) {
43
+ MATCH (rel)-[is_rel_e:IS_RELATED {status: "active"}]-(n:Node)
44
+ MATCH (rel_branch:Branch {name: is_rel_e.branch})
45
+ RETURN is_rel_e, rel_branch.branched_from AS rel_branched_from, n
46
+ }
47
+ // ------------
48
+ // for each IS_RELATED edge of the relationship, check the IS_PART_OF edges of the Node vertex
49
+ // to determine if this side of the relationship is legal
50
+ // ------------
51
+ CALL (n, is_rel_e, rel_branched_from, default_branch_name) {
52
+ OPTIONAL MATCH (n)-[is_part_of_e:IS_PART_OF {status: "active"}]->(:Root)
53
+ WHERE (
54
+ // the Node's create time must precede the Relationship's create time
55
+ is_part_of_e.from <= is_rel_e.from AND (is_part_of_e.to >= is_rel_e.from OR is_part_of_e.to IS NULL)
56
+ // the Node must have been created on a branch of equal or lesser depth than the Relationship
57
+ AND is_part_of_e.branch_level <= is_rel_e.branch_level
58
+ // if the Node and Relationships were created on branch_level = 2, then they must be on the same branch
59
+ AND (
60
+ is_part_of_e.branch_level = 1
61
+ OR is_part_of_e.branch = is_rel_e.branch
62
+ )
63
+ // if the Node was created on the default branch, and the Relationship was created on a branch,
64
+ // then the Node must have been created after the branched_from time of the Relationship's branch
65
+ AND (
66
+ is_part_of_e.branch <> default_branch_name
67
+ OR is_rel_e.branch_level = 1
68
+ OR is_part_of_e.from <= rel_branched_from
69
+ )
70
+ )
71
+ WITH is_part_of_e IS NOT NULL AS is_legal
72
+ ORDER BY is_legal DESC
73
+ RETURN is_legal
74
+ LIMIT 1
75
+ }
76
+ WITH rel, is_legal
77
+ ORDER BY rel, is_legal ASC
78
+ WITH rel, head(collect(is_legal)) AS is_legal
79
+ WHERE is_legal = false
80
+ DETACH DELETE rel
81
+ """
82
+ self.add_to_query(query)
83
+
84
+
85
+ class Migration033(GraphMigration):
86
+ """
87
+ Identifies duplicate Relationship vertices that have the same UUID property. Deletes any duplicates that
88
+ are linked to deleted nodes or nodes on in incorrect branch.
89
+ """
90
+
91
+ name: str = "033_deduplicate_relationship_vertices"
92
+ minimum_version: int = 31
93
+ queries: Sequence[type[Query]] = [DeduplicateRelationshipVerticesQuery]
94
+
95
+ async def validate_migration(self, db: InfrahubDatabase) -> MigrationResult: # noqa: ARG002
96
+ result = MigrationResult()
97
+ return result
@@ -82,6 +82,9 @@ class Node(BaseNode, metaclass=BaseNodeMeta):
82
82
  def get_schema(self) -> NonGenericSchemaTypes:
83
83
  return self._schema
84
84
 
85
+ def get_branch(self) -> Branch:
86
+ return self._branch
87
+
85
88
  def get_kind(self) -> str:
86
89
  """Return the main Kind of the Object."""
87
90
  return self._schema.kind
@@ -18,6 +18,7 @@ from .. import Node
18
18
  if TYPE_CHECKING:
19
19
  from infrahub.core.branch import Branch
20
20
  from infrahub.core.ipam.constants import IPAddressType
21
+ from infrahub.core.timestamp import Timestamp
21
22
  from infrahub.database import InfrahubDatabase
22
23
 
23
24
 
@@ -30,6 +31,7 @@ class CoreIPAddressPool(Node):
30
31
  data: dict[str, Any] | None = None,
31
32
  address_type: str | None = None,
32
33
  prefixlen: int | None = None,
34
+ at: Timestamp | None = None,
33
35
  ) -> Node:
34
36
  # Check if there is already a resource allocated with this identifier
35
37
  # if not, pull all existing prefixes and allocated the next available
@@ -63,18 +65,18 @@ class CoreIPAddressPool(Node):
63
65
  next_address = await self.get_next(db=db, prefixlen=prefixlen)
64
66
 
65
67
  target_schema = registry.get_node_schema(name=address_type, branch=branch)
66
- node = await Node.init(db=db, schema=target_schema, branch=branch)
68
+ node = await Node.init(db=db, schema=target_schema, branch=branch, at=at)
67
69
  try:
68
70
  await node.new(db=db, address=str(next_address), ip_namespace=ip_namespace, **data)
69
71
  except ValidationError as exc:
70
72
  raise ValueError(f"IPAddressPool: {self.name.value} | {exc!s}") from exc # type: ignore[attr-defined]
71
- await node.save(db=db)
73
+ await node.save(db=db, at=at)
72
74
  reconciler = IpamReconciler(db=db, branch=branch)
73
75
  await reconciler.reconcile(ip_value=next_address, namespace=ip_namespace.id, node_uuid=node.get_id())
74
76
 
75
77
  if identifier:
76
78
  query_set = await IPAddressPoolSetReserved.init(
77
- db=db, pool_id=self.id, identifier=identifier, address_id=node.id
79
+ db=db, pool_id=self.id, identifier=identifier, address_id=node.id, at=at
78
80
  )
79
81
  await query_set.execute(db=db)
80
82
 
@@ -20,6 +20,7 @@ from .. import Node
20
20
  if TYPE_CHECKING:
21
21
  from infrahub.core.branch import Branch
22
22
  from infrahub.core.ipam.constants import IPNetworkType
23
+ from infrahub.core.timestamp import Timestamp
23
24
  from infrahub.database import InfrahubDatabase
24
25
 
25
26
 
@@ -33,6 +34,7 @@ class CoreIPPrefixPool(Node):
33
34
  prefixlen: int | None = None,
34
35
  member_type: str | None = None,
35
36
  prefix_type: str | None = None,
37
+ at: Timestamp | None = None,
36
38
  ) -> Node:
37
39
  # Check if there is already a resource allocated with this identifier
38
40
  # if not, pull all existing prefixes and allocated the next available
@@ -68,20 +70,21 @@ class CoreIPPrefixPool(Node):
68
70
  )
69
71
 
70
72
  member_type = member_type or data.get("member_type", None) or self.default_member_type.value.value # type: ignore[attr-defined]
73
+ data["member_type"] = member_type
71
74
 
72
75
  target_schema = registry.get_node_schema(name=prefix_type, branch=branch)
73
- node = await Node.init(db=db, schema=target_schema, branch=branch)
76
+ node = await Node.init(db=db, schema=target_schema, branch=branch, at=at)
74
77
  try:
75
- await node.new(db=db, prefix=str(next_prefix), member_type=member_type, ip_namespace=ip_namespace, **data)
78
+ await node.new(db=db, prefix=str(next_prefix), ip_namespace=ip_namespace, **data)
76
79
  except ValidationError as exc:
77
80
  raise ValueError(f"IPPrefixPool: {self.name.value} | {exc!s}") from exc # type: ignore[attr-defined]
78
- await node.save(db=db)
81
+ await node.save(db=db, at=at)
79
82
  reconciler = IpamReconciler(db=db, branch=branch)
80
83
  await reconciler.reconcile(ip_value=next_prefix, namespace=ip_namespace.id, node_uuid=node.get_id())
81
84
 
82
85
  if identifier:
83
86
  query_set = await PrefixPoolSetReserved.init(
84
- db=db, pool_id=self.id, identifier=identifier, prefix_id=node.id
87
+ db=db, pool_id=self.id, identifier=identifier, prefix_id=node.id, at=at
85
88
  )
86
89
  await query_set.execute(db=db)
87
90
 
@@ -12,6 +12,7 @@ from .. import Node
12
12
  if TYPE_CHECKING:
13
13
  from infrahub.core.attribute import BaseAttribute
14
14
  from infrahub.core.branch import Branch
15
+ from infrahub.core.timestamp import Timestamp
15
16
  from infrahub.database import InfrahubDatabase
16
17
 
17
18
 
@@ -41,6 +42,7 @@ class CoreNumberPool(Node):
41
42
  node: Node,
42
43
  attribute: BaseAttribute,
43
44
  identifier: str | None = None,
45
+ at: Timestamp | None = None,
44
46
  ) -> int:
45
47
  identifier = identifier or node.get_id()
46
48
  # Check if there is already a resource allocated with this identifier
@@ -56,7 +58,7 @@ class CoreNumberPool(Node):
56
58
  number = await self.get_next(db=db, branch=branch, attribute=attribute)
57
59
 
58
60
  query_set = await NumberPoolSetReserved.init(
59
- db=db, pool_id=self.get_id(), identifier=identifier, reserved=number
61
+ db=db, pool_id=self.get_id(), identifier=identifier, reserved=number, at=at
60
62
  )
61
63
  await query_set.execute(db=db)
62
64
  return number
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
+ from enum import Enum
4
5
  from typing import TYPE_CHECKING, Any, Optional, Union, get_args, get_origin
5
6
  from uuid import UUID
6
7
 
@@ -191,6 +192,9 @@ class StandardNode(BaseModel):
191
192
  continue
192
193
 
193
194
  attr_value = getattr(self, attr_name)
195
+ if isinstance(attr_value, Enum):
196
+ attr_value = attr_value.value
197
+
194
198
  field_type = self.guess_field_type(field)
195
199
 
196
200
  if attr_value is None:
@@ -3,43 +3,12 @@ from __future__ import annotations
3
3
  from typing import TYPE_CHECKING, Any
4
4
 
5
5
  from infrahub import config
6
- from infrahub.core.constants import RelationshipStatus
7
6
  from infrahub.core.query import Query, QueryType
8
7
 
9
8
  if TYPE_CHECKING:
10
9
  from infrahub.database import InfrahubDatabase
11
10
 
12
11
 
13
- class AddNodeToBranch(Query):
14
- name: str = "node_add_to_branch"
15
- insert_return: bool = False
16
-
17
- type: QueryType = QueryType.WRITE
18
-
19
- def __init__(self, node_id: int, **kwargs: Any):
20
- self.node_id = node_id
21
- super().__init__(**kwargs)
22
-
23
- async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: # noqa: ARG002
24
- query = """
25
- MATCH (root:Root)
26
- MATCH (d) WHERE %(id_func)s(d) = $node_id
27
- WITH root,d
28
- CREATE (d)-[r:IS_PART_OF { branch: $branch, branch_level: $branch_level, from: $now, status: $status }]->(root)
29
- RETURN %(id_func)s(r)
30
- """ % {
31
- "id_func": db.get_id_function_name(),
32
- }
33
-
34
- self.params["node_id"] = db.to_database_id(self.node_id)
35
- self.params["now"] = self.at.to_string()
36
- self.params["branch"] = self.branch.name
37
- self.params["branch_level"] = self.branch.hierarchy_level
38
- self.params["status"] = RelationshipStatus.ACTIVE.value
39
-
40
- self.add_to_query(query)
41
-
42
-
43
12
  class DeleteBranchRelationshipsQuery(Query):
44
13
  name: str = "delete_branch_relationships"
45
14
  insert_return: bool = False
@@ -52,31 +21,31 @@ class DeleteBranchRelationshipsQuery(Query):
52
21
 
53
22
  async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: # noqa: ARG002
54
23
  query = """
55
- MATCH (s)-[r1]-(d)
56
- WHERE r1.branch = $branch_name
57
- DELETE r1
58
-
59
- WITH collect(DISTINCT s) + collect(DISTINCT d) AS nodes
60
-
61
- // Collect node IDs for filtering
62
- WITH nodes, [n in nodes | n.uuid] as nodes_uuids
63
-
64
- // Also delete agnostic relationships that would not have been deleted above
65
- MATCH (s2: Node)-[r2]-(d2)
66
- WHERE NOT exists((s2)-[:IS_PART_OF]-(:Root))
67
- AND s2.uuid IN nodes_uuids
68
- DELETE r2
69
-
70
- WITH nodes, collect(DISTINCT s2) + collect(DISTINCT d2) as additional_nodes
71
-
72
- WITH nodes + additional_nodes as nodes
73
-
74
- // Delete nodes that are no longer connected to any other nodes
75
- UNWIND nodes AS n
76
- WITH DISTINCT n
77
- MATCH (n)
78
- WHERE NOT exists((n)--())
79
- DELETE n
24
+ // delete all relationships on this branch
25
+ MATCH (s)-[r1]-(d)
26
+ WHERE r1.branch = $branch_name
27
+ CALL (r1) {
28
+ DELETE r1
29
+ } IN TRANSACTIONS
30
+
31
+ // check for any orphaned Node vertices and delete them
32
+ WITH collect(DISTINCT s.uuid) + collect(DISTINCT d.uuid) AS nodes_uuids
33
+ MATCH (s2:Node)-[r2]-(d2)
34
+ WHERE NOT exists((s2)-[:IS_PART_OF]-(:Root))
35
+ AND s2.uuid IN nodes_uuids
36
+ CALL (r2) {
37
+ DELETE r2
38
+ } IN TRANSACTIONS
39
+
40
+ // reduce results to a single row
41
+ WITH 1 AS one LIMIT 1
42
+
43
+ // find any orphaned vertices and delete them
44
+ MATCH (n)
45
+ WHERE NOT exists((n)--())
46
+ CALL (n) {
47
+ DELETE n
48
+ } IN TRANSACTIONS
80
49
  """
81
50
  self.params["branch_name"] = self.branch_name
82
51
  self.add_to_query(query)