infrahub-server 1.3.5__py3-none-any.whl → 1.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
infrahub/api/schema.py CHANGED
@@ -36,6 +36,7 @@ from infrahub.events import EventMeta
36
36
  from infrahub.events.schema_action import SchemaUpdatedEvent
37
37
  from infrahub.exceptions import MigrationError
38
38
  from infrahub.log import get_log_data, get_logger
39
+ from infrahub.permissions import define_global_permission_from_branch
39
40
  from infrahub.types import ATTRIBUTE_PYTHON_TYPES
40
41
  from infrahub.worker import WORKER_IDENTITY
41
42
  from infrahub.workflows.catalogue import SCHEMA_APPLY_MIGRATION, SCHEMA_VALIDATE_MIGRATION
@@ -287,13 +288,8 @@ async def load_schema(
287
288
  context: InfrahubContext = Depends(get_context),
288
289
  ) -> SchemaUpdate:
289
290
  permission_manager.raise_for_permission(
290
- permission=GlobalPermission(
291
- action=GlobalPermissions.MANAGE_SCHEMA.value,
292
- decision=(
293
- PermissionDecision.ALLOW_DEFAULT
294
- if branch.name in (GLOBAL_BRANCH_NAME, registry.default_branch)
295
- else PermissionDecision.ALLOW_OTHER
296
- ).value,
291
+ permission=define_global_permission_from_branch(
292
+ permission=GlobalPermissions.MANAGE_SCHEMA, branch_name=branch.name
297
293
  )
298
294
  )
299
295
 
infrahub/cli/db.py CHANGED
@@ -56,6 +56,7 @@ from infrahub.services.adapters.message_bus.local import BusSimulator
56
56
  from infrahub.services.adapters.workflow.local import WorkflowLocalExecution
57
57
 
58
58
  from .constants import ERROR_BADGE, FAILED_BADGE, SUCCESS_BADGE
59
+ from .db_commands.check_inheritance import check_inheritance
59
60
  from .patch import patch_app
60
61
 
61
62
 
@@ -178,6 +179,30 @@ async def migrate_cmd(
178
179
  await dbdriver.close()
179
180
 
180
181
 
182
@app.command(name="check-inheritance")
async def check_inheritance_cmd(
    ctx: typer.Context,
    fix: bool = typer.Option(False, help="Fix the inheritance of any invalid nodes."),
    config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
) -> None:
    """Check the database for any vertices with incorrect inheritance"""
    # NOTE: the docstring above doubles as the Typer `--help` text, so it is kept verbatim.
    # Exit code 1 is returned when check_inheritance() reports failure, i.e. when
    # invalid nodes were found and --fix was not passed.
    logging.getLogger("infrahub").setLevel(logging.WARNING)
    logging.getLogger("neo4j").setLevel(logging.ERROR)
    logging.getLogger("prefect").setLevel(logging.ERROR)

    config.load_and_exit(config_file_name=config_file)

    context: CliContext = ctx.obj
    dbdriver = await context.init_db(retry=1)
    try:
        await initialize_registry(db=dbdriver)
        success = await check_inheritance(db=dbdriver, fix=fix)
    finally:
        # Close the driver on every path. Previously typer.Exit was raised before
        # close() on failure, and any exception from the check also skipped it.
        await dbdriver.close()

    if not success:
        raise typer.Exit(code=1)
204
+
205
+
181
206
  @app.command(name="update-core-schema")
182
207
  async def update_core_schema_cmd(
183
208
  ctx: typer.Context,
File without changes
@@ -0,0 +1,284 @@
1
+ from __future__ import annotations
2
+
3
+ from collections import defaultdict
4
+ from dataclasses import dataclass
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ from rich import print as rprint
8
+ from rich.console import Console
9
+ from rich.table import Table
10
+
11
+ from infrahub.core import registry
12
+ from infrahub.core.branch.models import Branch
13
+ from infrahub.core.constants import InfrahubKind
14
+ from infrahub.core.migrations.query.node_duplicate import NodeDuplicateQuery, SchemaNodeInfo
15
+ from infrahub.core.query import Query, QueryType
16
+ from infrahub.core.schema import SchemaRoot, internal_schema
17
+ from infrahub.core.schema.manager import SchemaManager
18
+ from infrahub.log import get_logger
19
+
20
+ from ..constants import FAILED_BADGE, SUCCESS_BADGE
21
+
22
+ if TYPE_CHECKING:
23
+ from infrahub.core.schema.node_schema import NodeSchema
24
+ from infrahub.database import InfrahubDatabase
25
+
26
+ log = get_logger()
27
+
28
+
29
class GetSchemaWithUpdatedInheritance(Query):
    """
    Report every SchemaNode whose ``inherit_from`` attribute changed after creation.

    For each such schema, one row is returned per branch on which the value changed,
    carrying the schema's ``name``, ``namespace`` and that ``branch``. The earliest
    HAS_VALUE edge (the initial create) is deliberately not reported.
    """

    name = "get_schema_with_updated_inheritance"
    type = QueryType.READ
    insert_return = False

    async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None:  # noqa: ARG002
        # The namespace/name subqueries resolve the value visible on `branch`: edges on that
        # branch itself, or edges from a lower branch level created no later than `branched_from`.
        cypher = """
        // find inherit_from attributes that have been updated
        MATCH p = (schema_node:SchemaNode)-[has_attr_e:HAS_ATTRIBUTE {status: "active"}]->(a:Attribute {name: "inherit_from"})
        WHERE has_attr_e.to IS NULL
        CALL (a) {
            // only get branches on which the value was updated, we can ignore the initial create
            MATCH (a)-[e:HAS_VALUE]->(:AttributeValue)
            ORDER BY e.from ASC
            // tail leaves out the earliest one, which is the initial create
            RETURN tail(collect(e.branch)) AS branches
        }
        WITH schema_node, a, branches
        WHERE size(branches) > 0
        UNWIND branches AS branch
        WITH DISTINCT schema_node, a, branch

        //get branch details
        CALL (branch) {
            MATCH (b:Branch {name: branch})
            RETURN b.branched_from AS branched_from, b.hierarchy_level AS branch_level
        }

        // get the namespace for the schema
        CALL (schema_node, a, branch, branched_from, branch_level) {
            MATCH (schema_node)-[e1:HAS_ATTRIBUTE]-(:Attribute {name: "namespace"})-[e2:HAS_VALUE]->(av)
            WHERE (
                e1.branch = branch OR
                (e1.branch_level < branch_level AND e1.from <= branched_from)
            ) AND e1.to IS NULL
            AND e1.status = "active"
            AND (
                e2.branch = branch OR
                (e2.branch_level < branch_level AND e2.from <= branched_from)
            ) AND e2.to IS NULL
            AND e2.status = "active"
            ORDER BY e2.branch_level DESC, e1.branch_level DESC, e2.from DESC, e1.from DESC
            RETURN av.value AS namespace
            LIMIT 1
        }

        // get the name for the schema
        CALL (schema_node, a, branch, branched_from, branch_level) {
            MATCH (schema_node)-[e1:HAS_ATTRIBUTE]-(:Attribute {name: "name"})-[e2:HAS_VALUE]->(av)
            WHERE (
                e1.branch = branch OR
                (e1.branch_level < branch_level AND e1.from <= branched_from)
            ) AND e1.to IS NULL
            AND e1.status = "active"
            AND (
                e2.branch = branch OR
                (e2.branch_level < branch_level AND e2.from <= branched_from)
            ) AND e2.to IS NULL
            AND e2.status = "active"
            ORDER BY e2.branch_level DESC, e1.branch_level DESC, e2.from DESC, e1.from DESC
            RETURN av.value AS name
            LIMIT 1
        }
        RETURN name, namespace, branch
        """
        self.return_labels = ["name", "namespace", "branch"]
        self.add_to_query(cypher)

    def get_updated_inheritance_kinds_by_branch(self) -> dict[str, list[str]]:
        """Group the affected schema kinds by the branch on which ``inherit_from`` changed.

        A kind is built as the namespace immediately followed by the name.
        """
        result_by_branch: dict[str, list[str]] = defaultdict(list)
        for row in self.results:
            schema_name = row.get_as_type(label="name", return_type=str)
            schema_namespace = row.get_as_type(label="namespace", return_type=str)
            branch_name = row.get_as_type(label="branch", return_type=str)
            result_by_branch[branch_name].append(f"{schema_namespace}{schema_name}")
        return result_by_branch
110
+
111
+
112
@dataclass
class KindLabelCount:
    """How many nodes of one kind carry one particular set of extra labels.

    As populated by ``GetAllKindsAndLabels.get_kind_label_counts``, ``labels`` omits the
    ``Node`` and ``CoreNode`` labels and the label equal to ``kind`` itself.
    """

    kind: str  # the node kind, as stored on the vertex
    labels: frozenset[str]  # remaining vertex labels for this group
    num_nodes: int  # number of nodes of `kind` with exactly these labels
+
118
+
119
@dataclass
class KindLabelCountCorrected(KindLabelCount):
    """A KindLabelCount whose labels disagree with its schema, bundled with that schema.

    ``node_schema`` provides the expected ``inherit_from`` labels and the name, namespace
    and branch support used when building the repair query.
    """

    node_schema: NodeSchema  # schema of `kind` loaded for the branch being checked
+
123
+
124
class GetAllKindsAndLabels(Query):
    """
    Count nodes per (kind, label set) on the query's branch.

    If ``kinds`` is given, only vertices carrying one of those kind labels are scanned;
    otherwise every ``Node`` vertex is. A vertex is counted when it has an active
    IS_PART_OF edge to Root with no ``to`` timestamp, either on the branch itself or on
    a lower branch level no later than the branch point.
    """

    name = "get_all_kinds_and_labels"
    type = QueryType.READ
    insert_return = False

    def __init__(self, kinds: list[str] | None = None, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.kinds = kinds

    async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None:  # noqa: ARG002
        self.params["branch_name"] = self.branch.name
        self.params["branched_from"] = self.branch.get_branched_from()
        self.params["branch_level"] = self.branch.hierarchy_level

        # `A|B|C` label expression: the MATCH accepts a vertex carrying any of the kinds.
        # NOTE(review): kinds are interpolated without backtick escaping; this assumes they
        # are plain identifiers (they come from schema namespace + name).
        label_filter = "|".join(self.kinds) if self.kinds else "Node"
        cypher = """
        MATCH (n:%(kinds_str)s)-[r:IS_PART_OF]->(:Root)
        WHERE (
            r.branch = $branch_name OR
            (r.branch_level < $branch_level AND r.from <= $branched_from)
        )
        AND r.to IS NULL
        AND r.status = "active"
        RETURN DISTINCT n.kind AS kind, labels(n) AS labels, count(*) AS num_nodes
        ORDER BY kind ASC
        """ % {"kinds_str": label_filter}
        self.return_labels = ["kind", "labels", "num_nodes"]
        self.add_to_query(cypher)

    def get_kind_label_counts(self) -> list[KindLabelCount]:
        """Convert the result rows into KindLabelCount objects.

        The generic ``Node``/``CoreNode`` labels and the label equal to the kind are
        dropped, leaving only the labels that should mirror ``inherit_from``.
        """
        counts: list[KindLabelCount] = []
        for row in self.results:
            kind = row.get_as_type(label="kind", return_type=str)
            num_nodes = row.get_as_type(label="num_nodes", return_type=int)
            raw_labels: list[str] = row.get_as_type(label="labels", return_type=list)
            ignored = {"Node", "CoreNode", kind}
            extra_labels = frozenset(str(label) for label in raw_labels if label not in ignored)
            counts.append(KindLabelCount(kind=kind, labels=extra_labels, num_nodes=num_nodes))
        return counts
168
+
169
+
170
def display_kind_label_counts(kind_label_counts_by_branch: dict[str, list[KindLabelCountCorrected]]) -> None:
    """Print a table of every (branch, kind, labels) group with incorrect inheritance.

    Args:
        kind_label_counts_by_branch: the incorrect groups to display, keyed by branch name.
    """
    table = Table(title="Incorrect Inheritance Nodes")
    for column in ("Branch", "Kind", "Incorrect Labels", "Num Nodes"):
        table.add_column(column)

    for branch_name, kind_label_counts in kind_label_counts_by_branch.items():
        for kind_label_count in kind_label_counts:
            # `labels` is a frozenset of str; its iteration order depends on per-process
            # string hashing, so sort it to make the report stable across runs.
            table.add_row(
                branch_name,
                kind_label_count.kind,
                str(sorted(kind_label_count.labels)),
                str(kind_label_count.num_nodes),
            )

    Console().print(table)
187
+
188
+
189
async def check_inheritance(db: InfrahubDatabase, fix: bool = False) -> bool:
    """
    Detect, and optionally repair, nodes whose labels disagree with their schema's ``inherit_from``.

    1. Find node schemas whose ``inherit_from`` was changed after creation, along with the
       branch on which each change happened.
    2. On each such branch, count the nodes of those kinds grouped by their label set and
       keep the groups whose labels differ from the schema.
    3. Display the incorrect groups in a table.
    4. If ``fix`` is True, run a NodeDuplicateQuery for each incorrect group on its branch.

    Returns True when nothing needs fixing or everything was fixed, and False when
    incorrect nodes were found while ``fix`` is False.
    """
    schema_query = await GetSchemaWithUpdatedInheritance.init(db=db)
    await schema_query.execute(db=db)
    kinds_by_branch = schema_query.get_updated_inheritance_kinds_by_branch()

    if not kinds_by_branch:
        rprint(f"{SUCCESS_BADGE} No schemas have had their inheritance updated")
        return True

    # A fresh SchemaManager seeded with the internal schema is installed on the registry;
    # branch schemas are then loaded from the database through it.
    manager = SchemaManager()
    registry.schema = manager
    manager.register_schema(schema=SchemaRoot(**internal_schema))
    branches_by_name = {b.name: b for b in await Branch.get_list(db=db)}

    incorrect_by_branch = await _find_incorrect_inheritance(
        db=db,
        schema_manager=manager,
        branches_by_name=branches_by_name,
        kinds_by_branch=kinds_by_branch,
    )

    if not incorrect_by_branch:
        rprint(f"{SUCCESS_BADGE} All nodes have the correct inheritance")
        return True

    display_kind_label_counts(incorrect_by_branch)

    if not fix:
        rprint(f"{FAILED_BADGE} Use the --fix flag to fix the inheritance of any invalid nodes")
        return False

    await _fix_incorrect_inheritance(db=db, branches_by_name=branches_by_name, incorrect_by_branch=incorrect_by_branch)
    rprint(f"{SUCCESS_BADGE} All nodes have the correct inheritance")

    # Migrations on the default branch can affect branches created from it.
    if registry.default_branch in incorrect_by_branch:
        migrated_kinds = [group.kind for group in incorrect_by_branch[registry.default_branch]]
        rprint(
            "[bold cyan]Note that migrations were run on the default branch for the following schema kinds: "
            f"{', '.join(migrated_kinds)}. You should rebase any branches that include/will include changes using "
            "the migrated schemas[/bold cyan]"
        )

    return True


async def _find_incorrect_inheritance(
    db: InfrahubDatabase,
    schema_manager: SchemaManager,
    branches_by_name: dict[str, Branch],
    kinds_by_branch: dict[str, list[str]],
) -> dict[str, list[KindLabelCountCorrected]]:
    """Return, per branch, the (kind, labels) groups whose labels differ from the schema's inherit_from."""
    incorrect: dict[str, list[KindLabelCountCorrected]] = defaultdict(list)
    for branch_name, kinds in kinds_by_branch.items():
        rprint(f"Checking branch: {branch_name}", end="...")
        branch = branches_by_name[branch_name]
        schema_branch = await schema_manager.load_schema_from_db(db=db, branch=branch)
        label_query = await GetAllKindsAndLabels.init(db=db, branch=branch, kinds=kinds)
        await label_query.execute(db=db)

        for group in label_query.get_kind_label_counts():
            node_schema = schema_branch.get_node(name=group.kind, duplicate=False)
            if group.labels == frozenset(node_schema.inherit_from):
                continue
            incorrect[branch_name].append(
                KindLabelCountCorrected(
                    kind=group.kind,
                    labels=group.labels,
                    num_nodes=group.num_nodes,
                    node_schema=node_schema,
                )
            )
        rprint("done")
    return incorrect


async def _fix_incorrect_inheritance(
    db: InfrahubDatabase,
    branches_by_name: dict[str, Branch],
    incorrect_by_branch: dict[str, list[KindLabelCountCorrected]],
) -> None:
    """Duplicate the nodes of every incorrect group onto labels derived from its schema."""
    for branch_name, groups in incorrect_by_branch.items():
        for group in groups:
            rprint(f"Fixing kind {group.kind} on branch {branch_name}", end="...")
            schema = group.node_schema
            migration_query = await NodeDuplicateQuery.init(
                db=db,
                branch=branches_by_name[branch_name],
                previous_node=SchemaNodeInfo(
                    name=schema.name,
                    namespace=schema.namespace,
                    branch_support=schema.branch.value,
                    labels=list(group.labels) + [group.kind, InfrahubKind.NODE],
                    kind=group.kind,
                ),
                new_node=SchemaNodeInfo(
                    name=schema.name,
                    namespace=schema.namespace,
                    branch_support=schema.branch.value,
                    labels=list(schema.inherit_from) + [group.kind, InfrahubKind.NODE],
                    kind=group.kind,
                ),
            )
            await migration_query.execute(db=db)
            rprint("done")
@@ -1 +1 @@
1
- GRAPH_VERSION = 34
1
+ GRAPH_VERSION = 35
infrahub/core/manager.py CHANGED
@@ -400,7 +400,7 @@ class NodeManager:
400
400
 
401
401
  results = []
402
402
  for peer in peers_info:
403
- result = await Relationship(schema=schema, branch=branch, at=at, node_id=peer.source_id).load(
403
+ result = Relationship(schema=schema, branch=branch, at=at, node_id=peer.source_id).load(
404
404
  db=db,
405
405
  id=peer.rel_node_id,
406
406
  db_id=peer.rel_node_db_id,
@@ -408,7 +408,7 @@ class NodeManager:
408
408
  data=peer,
409
409
  )
410
410
  if fetch_peers:
411
- await result.set_peer(value=peer_nodes[peer.peer_id])
411
+ result.set_peer(value=peer_nodes[peer.peer_id])
412
412
  results.append(result)
413
413
 
414
414
  return results
@@ -36,6 +36,7 @@ from .m031_check_number_attributes import Migration031
36
36
  from .m032_cleanup_orphaned_branch_relationships import Migration032
37
37
  from .m033_deduplicate_relationship_vertices import Migration033
38
38
  from .m034_find_orphaned_schema_fields import Migration034
39
+ from .m035_orphan_relationships import Migration035
39
40
 
40
41
  if TYPE_CHECKING:
41
42
  from infrahub.core.root import Root
@@ -77,6 +78,7 @@ MIGRATIONS: list[type[GraphMigration | InternalSchemaMigration | ArbitraryMigrat
77
78
  Migration032,
78
79
  Migration033,
79
80
  Migration034,
81
+ Migration035,
80
82
  ]
81
83
 
82
84
 
@@ -89,7 +89,7 @@ class Migration033(GraphMigration):
89
89
  """
90
90
 
91
91
  name: str = "033_deduplicate_relationship_vertices"
92
- minimum_version: int = 31
92
+ minimum_version: int = 32
93
93
  queries: Sequence[type[Query]] = [DeduplicateRelationshipVerticesQuery]
94
94
 
95
95
  async def validate_migration(self, db: InfrahubDatabase) -> MigrationResult: # noqa: ARG002
@@ -0,0 +1,43 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any, Sequence
4
+
5
+ from infrahub.core.migrations.shared import GraphMigration, MigrationResult
6
+
7
+ from ...query import Query, QueryType
8
+
9
+ if TYPE_CHECKING:
10
+ from infrahub.database import InfrahubDatabase
11
+
12
+
13
class CleanupOrphanedRelationshipsQuery(Query):
    """
    Delete Relationship vertices linked through IS_RELATED to fewer than two distinct peers.

    Peers are counted by ``uuid`` rather than by vertex, so several Node vertices that share
    a uuid count as a single peer.

    NOTE(review): a Relationship vertex with no IS_RELATED edge at all is never produced by
    the MATCH below and is therefore not deleted. A relationship whose two ends are the same
    node would count as one distinct peer and be deleted; confirm that cannot occur.
    """

    name = "cleanup_orphaned_relationships"
    type = QueryType.WRITE
    insert_return = False

    async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None:  # noqa: ARG002
        cleanup_cypher = """
        MATCH (rel:Relationship)-[:IS_RELATED]-(peer:Node)
        WITH DISTINCT rel, peer.uuid AS p_uuid
        WITH rel, count(*) AS num_peers
        WHERE num_peers < 2
        DETACH DELETE rel
        """
        self.add_to_query(cleanup_cypher)
27
+
28
+
29
class Migration035(GraphMigration):
    """
    Graph migration that deletes Relationship vertices attached to only a single peer.
    """

    name: str = "035_clean_up_orphaned_relationships"
    minimum_version: int = 34
    queries: Sequence[type[Query]] = [CleanupOrphanedRelationshipsQuery]

    async def validate_migration(self, db: InfrahubDatabase) -> MigrationResult:  # noqa: ARG002
        # No post-migration validation is performed; a default MigrationResult is returned.
        return MigrationResult()

    async def execute(self, db: InfrahubDatabase) -> MigrationResult:
        # Overrides the parent class so the queries run without a transaction, in case
        # there are a lot of relationships to delete.
        return await self.do_execute(db=db)
@@ -21,6 +21,13 @@ class SchemaNodeInfo(BaseModel):
21
21
 
22
22
 
23
23
  class NodeDuplicateQuery(Query):
24
+ """
25
+ Duplicates a Node to use a new kind or inheritance.
26
+ Creates a copy of each affected Node and sets the new kind/inheritance.
27
+ Adds duplicate edges to the new Node that match all the active edges on the old Node.
28
+ Sets all the edges on the old Node to deleted.
29
+ """
30
+
24
31
  name = "node_duplicate"
25
32
  type = QueryType.WRITE
26
33
  insert_return: bool = False
@@ -38,11 +45,26 @@ class NodeDuplicateQuery(Query):
38
45
 
39
46
  def render_match(self) -> str:
40
47
  labels_str = ":".join(self.previous_node.labels)
41
- query = f"""
48
+ query = """
42
49
  // Find all the active nodes
43
- MATCH (node:{labels_str})
50
+ MATCH (node:%(labels_str)s)
44
51
  WITH DISTINCT node
45
- """
52
+ // ----------------
53
+ // Filter out nodes that have already been migrated
54
+ // ----------------
55
+ CALL (node) {
56
+ WITH labels(node) AS node_labels
57
+ UNWIND node_labels AS n_label
58
+ ORDER BY n_label ASC
59
+ WITH collect(n_label) AS sorted_labels
60
+
61
+ RETURN (
62
+ node.kind = $new_node.kind AND
63
+ sorted_labels = $new_sorted_labels
64
+ ) AS already_migrated
65
+ }
66
+ WITH node WHERE already_migrated = FALSE
67
+ """ % {"labels_str": labels_str}
46
68
 
47
69
  return query
48
70
 
@@ -111,6 +133,7 @@ class NodeDuplicateQuery(Query):
111
133
 
112
134
  self.params["new_node"] = self.new_node.model_dump()
113
135
  self.params["previous_node"] = self.previous_node.model_dump()
136
+ self.params["new_sorted_labels"] = sorted(self.new_node.labels + ["Node"])
114
137
 
115
138
  self.params["current_time"] = self.at.to_string()
116
139
  self.params["branch"] = self.branch.name
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  import ipaddress
4
4
  from typing import TYPE_CHECKING, Any
5
5
 
6
+ from infrahub import lock
6
7
  from infrahub.core import registry
7
8
  from infrahub.core.ipam.reconciler import IpamReconciler
8
9
  from infrahub.core.query.ipam import get_ip_addresses
@@ -33,54 +34,55 @@ class CoreIPAddressPool(Node):
33
34
  prefixlen: int | None = None,
34
35
  at: Timestamp | None = None,
35
36
  ) -> Node:
36
- # Check if there is already a resource allocated with this identifier
37
- # if not, pull all existing prefixes and allocated the next available
38
-
39
- if identifier:
40
- query_get = await IPAddressPoolGetReserved.init(db=db, pool_id=self.id, identifier=identifier)
41
- await query_get.execute(db=db)
42
- result = query_get.get_result()
43
-
44
- if result:
45
- address = result.get_node("address")
46
- # TODO add support for branch, if the node is reserved with this id in another branch we should return an error
47
- node = await registry.manager.get_one(db=db, id=address.get("uuid"), branch=branch)
48
-
49
- if node:
50
- return node
51
-
52
- data = data or {}
53
-
54
- address_type = address_type or data.get("address_type") or self.default_address_type.value # type: ignore[attr-defined]
55
- if not address_type:
56
- raise ValueError(
57
- f"IPAddressPool: {self.name.value} | " # type: ignore[attr-defined]
58
- "An address_type or a default_value type must be provided to allocate a new IP address"
59
- )
60
-
61
- ip_namespace = await self.ip_namespace.get_peer(db=db) # type: ignore[attr-defined]
62
-
63
- prefixlen = prefixlen or data.get("prefixlen") or self.default_prefix_length.value # type: ignore[attr-defined]
64
-
65
- next_address = await self.get_next(db=db, prefixlen=prefixlen)
66
-
67
- target_schema = registry.get_node_schema(name=address_type, branch=branch)
68
- node = await Node.init(db=db, schema=target_schema, branch=branch, at=at)
69
- try:
70
- await node.new(db=db, address=str(next_address), ip_namespace=ip_namespace, **data)
71
- except ValidationError as exc:
72
- raise ValueError(f"IPAddressPool: {self.name.value} | {exc!s}") from exc # type: ignore[attr-defined]
73
- await node.save(db=db, at=at)
74
- reconciler = IpamReconciler(db=db, branch=branch)
75
- await reconciler.reconcile(ip_value=next_address, namespace=ip_namespace.id, node_uuid=node.get_id())
76
-
77
- if identifier:
78
- query_set = await IPAddressPoolSetReserved.init(
79
- db=db, pool_id=self.id, identifier=identifier, address_id=node.id, at=at
80
- )
81
- await query_set.execute(db=db)
82
-
83
- return node
37
+ async with lock.registry.get(name=self.get_id(), namespace="resource_pool"):
38
+ # Check if there is already a resource allocated with this identifier
39
+ # if not, pull all existing prefixes and allocated the next available
40
+
41
+ if identifier:
42
+ query_get = await IPAddressPoolGetReserved.init(db=db, pool_id=self.id, identifier=identifier)
43
+ await query_get.execute(db=db)
44
+ result = query_get.get_result()
45
+
46
+ if result:
47
+ address = result.get_node("address")
48
+ # TODO add support for branch, if the node is reserved with this id in another branch we should return an error
49
+ node = await registry.manager.get_one(db=db, id=address.get("uuid"), branch=branch)
50
+
51
+ if node:
52
+ return node
53
+
54
+ data = data or {}
55
+
56
+ address_type = address_type or data.get("address_type") or self.default_address_type.value # type: ignore[attr-defined]
57
+ if not address_type:
58
+ raise ValueError(
59
+ f"IPAddressPool: {self.name.value} | " # type: ignore[attr-defined]
60
+ "An address_type or a default_value type must be provided to allocate a new IP address"
61
+ )
62
+
63
+ ip_namespace = await self.ip_namespace.get_peer(db=db) # type: ignore[attr-defined]
64
+
65
+ prefixlen = prefixlen or data.get("prefixlen") or self.default_prefix_length.value # type: ignore[attr-defined]
66
+
67
+ next_address = await self.get_next(db=db, prefixlen=prefixlen)
68
+
69
+ target_schema = registry.get_node_schema(name=address_type, branch=branch)
70
+ node = await Node.init(db=db, schema=target_schema, branch=branch, at=at)
71
+ try:
72
+ await node.new(db=db, address=str(next_address), ip_namespace=ip_namespace, **data)
73
+ except ValidationError as exc:
74
+ raise ValueError(f"IPAddressPool: {self.name.value} | {exc!s}") from exc # type: ignore[attr-defined]
75
+ await node.save(db=db, at=at)
76
+ reconciler = IpamReconciler(db=db, branch=branch)
77
+ await reconciler.reconcile(ip_value=next_address, namespace=ip_namespace.id, node_uuid=node.get_id())
78
+
79
+ if identifier:
80
+ query_set = await IPAddressPoolSetReserved.init(
81
+ db=db, pool_id=self.id, identifier=identifier, address_id=node.id, at=at
82
+ )
83
+ await query_set.execute(db=db)
84
+
85
+ return node
84
86
 
85
87
  async def get_next(self, db: InfrahubDatabase, prefixlen: int | None = None) -> IPAddressType:
86
88
  resources = await self.resources.get_peers(db=db) # type: ignore[attr-defined]