infrahub-server 1.3.3__py3-none-any.whl → 1.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33):
  1. infrahub/api/schema.py +2 -2
  2. infrahub/core/convert_object_type/conversion.py +10 -0
  3. infrahub/core/diff/enricher/hierarchy.py +7 -3
  4. infrahub/core/diff/query_parser.py +7 -3
  5. infrahub/core/graph/__init__.py +1 -1
  6. infrahub/core/migrations/graph/__init__.py +2 -0
  7. infrahub/core/migrations/graph/m034_find_orphaned_schema_fields.py +84 -0
  8. infrahub/core/migrations/schema/node_attribute_add.py +55 -2
  9. infrahub/core/migrations/shared.py +37 -9
  10. infrahub/core/node/__init__.py +41 -21
  11. infrahub/core/node/resource_manager/number_pool.py +60 -22
  12. infrahub/core/query/resource_manager.py +117 -20
  13. infrahub/core/schema/__init__.py +5 -0
  14. infrahub/core/schema/attribute_parameters.py +6 -0
  15. infrahub/core/schema/attribute_schema.py +6 -0
  16. infrahub/core/schema/manager.py +5 -11
  17. infrahub/core/schema/relationship_schema.py +6 -0
  18. infrahub/core/schema/schema_branch.py +50 -11
  19. infrahub/core/validators/node/attribute.py +15 -0
  20. infrahub/core/validators/tasks.py +12 -4
  21. infrahub/graphql/queries/resource_manager.py +4 -4
  22. infrahub/tasks/registry.py +63 -35
  23. infrahub_sdk/client.py +7 -8
  24. infrahub_sdk/ctl/utils.py +3 -0
  25. infrahub_sdk/node/node.py +6 -6
  26. infrahub_sdk/node/relationship.py +43 -2
  27. infrahub_sdk/yaml.py +13 -7
  28. infrahub_server-1.3.4.dist-info/LICENSE.txt +201 -0
  29. {infrahub_server-1.3.3.dist-info → infrahub_server-1.3.4.dist-info}/METADATA +3 -3
  30. {infrahub_server-1.3.3.dist-info → infrahub_server-1.3.4.dist-info}/RECORD +32 -31
  31. infrahub_server-1.3.3.dist-info/LICENSE.txt +0 -661
  32. {infrahub_server-1.3.3.dist-info → infrahub_server-1.3.4.dist-info}/WHEEL +0 -0
  33. {infrahub_server-1.3.3.dist-info → infrahub_server-1.3.4.dist-info}/entry_points.txt +0 -0
infrahub/api/schema.py CHANGED
@@ -151,9 +151,9 @@ def evaluate_candidate_schemas(
151
151
  branch_schema: SchemaBranch, schemas_to_evaluate: SchemasLoadAPI
152
152
  ) -> tuple[SchemaBranch, SchemaUpdateValidationResult]:
153
153
  candidate_schema = branch_schema.duplicate()
154
- schema = _merge_candidate_schemas(schemas=schemas_to_evaluate.schemas)
155
-
156
154
  try:
155
+ schema = _merge_candidate_schemas(schemas=schemas_to_evaluate.schemas)
156
+
157
157
  candidate_schema.load_schema(schema=schema)
158
158
  candidate_schema.process()
159
159
 
@@ -9,6 +9,7 @@ from infrahub.core.manager import NodeManager
9
9
  from infrahub.core.node import Node
10
10
  from infrahub.core.node.create import create_node
11
11
  from infrahub.core.query.relationship import GetAllPeersIds
12
+ from infrahub.core.query.resource_manager import PoolChangeReserved
12
13
  from infrahub.core.relationship import RelationshipManager
13
14
  from infrahub.core.schema import NodeSchema
14
15
  from infrahub.database import InfrahubDatabase
@@ -121,4 +122,13 @@ async def convert_object_type(
121
122
  for peer in peers.values():
122
123
  peer.validate_relationships()
123
124
 
125
+ # If the node had some value reserved in any Pools / Resource Manager, we need to change the identifier of the reservation(s)
126
+ query = await PoolChangeReserved.init(
127
+ db=dbt,
128
+ existing_identifier=node.get_id(),
129
+ new_identifier=node_created.get_id(),
130
+ branch=branch,
131
+ )
132
+ await query.execute(db=dbt)
133
+
124
134
  return node_created
@@ -7,6 +7,7 @@ from infrahub.core.query.node import NodeGetHierarchyQuery
7
7
  from infrahub.core.query.relationship import RelationshipGetPeerQuery, RelationshipPeerData
8
8
  from infrahub.core.schema import ProfileSchema, TemplateSchema
9
9
  from infrahub.database import InfrahubDatabase
10
+ from infrahub.exceptions import SchemaNotFoundError
10
11
  from infrahub.log import get_logger
11
12
 
12
13
  from ..model.path import (
@@ -42,9 +43,12 @@ class DiffHierarchyEnricher(DiffEnricherInterface):
42
43
  node_hierarchy_map: dict[str, list[NodeIdentifier]] = defaultdict(list)
43
44
 
44
45
  for node in enriched_diff_root.nodes:
45
- schema_node = self.db.schema.get(
46
- name=node.kind, branch=enriched_diff_root.diff_branch_name, duplicate=False
47
- )
46
+ try:
47
+ schema_node = self.db.schema.get(
48
+ name=node.kind, branch=enriched_diff_root.diff_branch_name, duplicate=False
49
+ )
50
+ except SchemaNotFoundError:
51
+ continue
48
52
 
49
53
  if isinstance(schema_node, ProfileSchema | TemplateSchema):
50
54
  continue
@@ -13,6 +13,7 @@ from infrahub.core.constants import (
13
13
  )
14
14
  from infrahub.core.constants.database import DatabaseEdgeType
15
15
  from infrahub.core.timestamp import Timestamp
16
+ from infrahub.exceptions import SchemaNotFoundError
16
17
 
17
18
  from .model.field_specifiers_map import NodeFieldSpecifierMap
18
19
  from .model.path import (
@@ -566,9 +567,12 @@ class DiffQueryParser:
566
567
  if database_path.deepest_branch == self.diff_branch_name:
567
568
  branches_to_check.append(self.base_branch_name)
568
569
  for schema_branch_name in branches_to_check:
569
- node_schema = self.schema_manager.get(
570
- name=database_path.node_kind, branch=schema_branch_name, duplicate=False
571
- )
570
+ try:
571
+ node_schema = self.schema_manager.get(
572
+ name=database_path.node_kind, branch=schema_branch_name, duplicate=False
573
+ )
574
+ except SchemaNotFoundError:
575
+ continue
572
576
  relationship_schemas = node_schema.get_relationships_by_identifier(id=database_path.attribute_name)
573
577
  if len(relationship_schemas) == 1:
574
578
  return relationship_schemas[0]
@@ -1 +1 @@
1
- GRAPH_VERSION = 33
1
+ GRAPH_VERSION = 34
@@ -35,6 +35,7 @@ from .m030_illegal_edges import Migration030
35
35
  from .m031_check_number_attributes import Migration031
36
36
  from .m032_cleanup_orphaned_branch_relationships import Migration032
37
37
  from .m033_deduplicate_relationship_vertices import Migration033
38
+ from .m034_find_orphaned_schema_fields import Migration034
38
39
 
39
40
  if TYPE_CHECKING:
40
41
  from infrahub.core.root import Root
@@ -75,6 +76,7 @@ MIGRATIONS: list[type[GraphMigration | InternalSchemaMigration | ArbitraryMigrat
75
76
  Migration031,
76
77
  Migration032,
77
78
  Migration033,
79
+ Migration034,
78
80
  ]
79
81
 
80
82
 
@@ -0,0 +1,84 @@
1
+ from __future__ import annotations
2
+
3
+ from collections import defaultdict
4
+ from typing import TYPE_CHECKING, Any, Sequence
5
+
6
+ from infrahub.core.initialization import initialization
7
+ from infrahub.core.manager import NodeManager
8
+ from infrahub.core.migrations.shared import ArbitraryMigration, MigrationResult
9
+ from infrahub.core.timestamp import Timestamp
10
+ from infrahub.lock import initialize_lock
11
+ from infrahub.log import get_logger
12
+
13
+ from ...query import Query, QueryType
14
+
15
+ if TYPE_CHECKING:
16
+ from infrahub.database import InfrahubDatabase
17
+
18
+ log = get_logger()
19
+
20
+
21
+ class FindOrphanedSchemaFieldsQuery(Query):
22
+ name = "find_orphaned_schema_fields"
23
+ type = QueryType.WRITE
24
+
25
+ async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None: # noqa: ARG002
26
+ query = """
27
+ // ------------
28
+ // Find orphaned SchemaRelationship and SchemaAttribute vertices
29
+ // ------------
30
+ MATCH (schema_field:SchemaRelationship|SchemaAttribute)-[e:IS_RELATED]-(rel:Relationship)
31
+ WHERE rel.name IN ["schema__node__relationships", "schema__node__attributes"]
32
+ AND e.status = "deleted" OR e.to IS NOT NULL
33
+ WITH schema_field, e.branch AS branch, CASE
34
+ WHEN e.status = "deleted" THEN e.from
35
+ ELSE e.to
36
+ END AS delete_time
37
+ CALL (schema_field, branch) {
38
+ OPTIONAL MATCH (schema_field)-[is_part_of:IS_PART_OF {branch: branch}]->(:Root)
39
+ WHERE is_part_of.status = "deleted" OR is_part_of.to IS NOT NULL
40
+ RETURN is_part_of IS NOT NULL AS is_deleted
41
+ }
42
+ WITH schema_field, branch, delete_time
43
+ WHERE is_deleted = FALSE
44
+ """
45
+ self.add_to_query(query)
46
+ self.return_labels = ["schema_field.uuid AS schema_field_uuid", "branch", "delete_time"]
47
+
48
+
49
+ class Migration034(ArbitraryMigration):
50
+ """
51
+ Finds active SchemaRelationship and SchemaAttribute vertices with deleted relationships to SchemaNodes or
52
+ SchemaGenerics and deletes them on the same branch at the same time
53
+ """
54
+
55
+ name: str = "034_find_orphaned_schema_fields"
56
+ minimum_version: int = 33
57
+ queries: Sequence[type[Query]] = [FindOrphanedSchemaFieldsQuery]
58
+
59
+ async def validate_migration(self, db: InfrahubDatabase) -> MigrationResult: # noqa: ARG002
60
+ return MigrationResult()
61
+
62
+ async def execute(self, db: InfrahubDatabase) -> MigrationResult:
63
+ try:
64
+ initialize_lock()
65
+ await initialization(db=db)
66
+ query = await FindOrphanedSchemaFieldsQuery.init(db=db)
67
+ await query.execute(db=db)
68
+ schema_field_uuids_by_branch: dict[str, dict[str, str]] = defaultdict(dict)
69
+ for result in query.get_results():
70
+ schema_field_uuid = result.get_as_type("schema_field_uuid", return_type=str)
71
+ branch = result.get_as_type("branch", return_type=str)
72
+ delete_time = result.get_as_type("delete_time", return_type=str)
73
+ schema_field_uuids_by_branch[branch][schema_field_uuid] = delete_time
74
+
75
+ for branch, schema_rel_details in schema_field_uuids_by_branch.items():
76
+ node_map = await NodeManager.get_many(db=db, branch=branch, ids=list(schema_rel_details.keys()))
77
+ for schema_field_uuid, orphan_schema_rel_node in node_map.items():
78
+ delete_time = Timestamp(schema_rel_details[schema_field_uuid])
79
+ await orphan_schema_rel_node.delete(db=db, at=delete_time)
80
+ except Exception as exc:
81
+ log.exception("Error during orphaned schema field cleanup")
82
+ return MigrationResult(errors=[str(exc)])
83
+
84
+ return MigrationResult()
@@ -1,10 +1,22 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Any, Sequence
3
+ from typing import TYPE_CHECKING, Any, Sequence
4
+
5
+ from infrahub.core import registry
6
+ from infrahub.core.node import Node
7
+ from infrahub.exceptions import PoolExhaustedError
8
+ from infrahub.tasks.registry import update_branch_registry
4
9
 
5
10
  from ..query import AttributeMigrationQuery
6
11
  from ..query.attribute_add import AttributeAddQuery
7
- from ..shared import AttributeSchemaMigration
12
+ from ..shared import AttributeSchemaMigration, MigrationResult
13
+
14
+ if TYPE_CHECKING:
15
+ from infrahub.core.node.resource_manager.number_pool import CoreNumberPool
16
+ from infrahub.database import InfrahubDatabase
17
+
18
+ from ...branch import Branch
19
+ from ...timestamp import Timestamp
8
20
 
9
21
 
10
22
  class NodeAttributeAddMigrationQuery01(AttributeMigrationQuery, AttributeAddQuery):
@@ -29,3 +41,44 @@ class NodeAttributeAddMigrationQuery01(AttributeMigrationQuery, AttributeAddQuer
29
41
  class NodeAttributeAddMigration(AttributeSchemaMigration):
30
42
  name: str = "node.attribute.add"
31
43
  queries: Sequence[type[AttributeMigrationQuery]] = [NodeAttributeAddMigrationQuery01] # type: ignore[assignment]
44
+
45
+ async def execute_post_queries(
46
+ self,
47
+ db: InfrahubDatabase,
48
+ result: MigrationResult,
49
+ branch: Branch,
50
+ at: Timestamp, # noqa: ARG002
51
+ ) -> MigrationResult:
52
+ if self.new_attribute_schema.kind != "NumberPool":
53
+ return result
54
+
55
+ number_pool: CoreNumberPool = await Node.fetch_or_create_number_pool( # type: ignore[assignment]
56
+ db=db, branch=branch, schema_node=self.new_schema, schema_attribute=self.new_attribute_schema
57
+ )
58
+
59
+ await update_branch_registry(db=db, branch=branch)
60
+
61
+ nodes: list[Node] = await registry.manager.query(
62
+ db=db, branch=branch, schema=self.new_schema, fields={"id": True, self.new_attribute_schema.name: True}
63
+ )
64
+
65
+ try:
66
+ numbers = await number_pool.get_next_many(
67
+ db=db,
68
+ branch=branch,
69
+ quantity=len(nodes),
70
+ attribute=self.new_attribute_schema,
71
+ )
72
+ except PoolExhaustedError as exc:
73
+ result.errors.append(str(exc))
74
+ return result
75
+
76
+ for node, number in zip(nodes, numbers, strict=True):
77
+ await number_pool.reserve(db=db, number=number, identifier=node.get_id())
78
+ attr = getattr(node, self.new_attribute_schema.name)
79
+ attr.value = number
80
+ attr.source = number_pool.id
81
+
82
+ await node.save(db=db, fields=[self.new_attribute_schema.name])
83
+
84
+ return result
@@ -16,13 +16,13 @@ from infrahub.core.schema import (
16
16
  SchemaRoot,
17
17
  internal_schema,
18
18
  )
19
+ from infrahub.core.timestamp import Timestamp
19
20
 
20
21
  from .query import MigrationQuery # noqa: TC001
21
22
 
22
23
  if TYPE_CHECKING:
23
24
  from infrahub.core.branch import Branch
24
25
  from infrahub.core.schema.schema_branch import SchemaBranch
25
- from infrahub.core.timestamp import Timestamp
26
26
  from infrahub.database import InfrahubDatabase
27
27
 
28
28
 
@@ -47,18 +47,46 @@ class SchemaMigration(BaseModel):
47
47
  previous_node_schema: NodeSchema | GenericSchema | None = None
48
48
  schema_path: SchemaPath
49
49
 
50
+ async def execute_pre_queries(
51
+ self,
52
+ db: InfrahubDatabase, # noqa: ARG002
53
+ result: MigrationResult,
54
+ branch: Branch, # noqa: ARG002
55
+ at: Timestamp, # noqa: ARG002
56
+ ) -> MigrationResult:
57
+ return result
58
+
59
+ async def execute_post_queries(
60
+ self,
61
+ db: InfrahubDatabase, # noqa: ARG002
62
+ result: MigrationResult,
63
+ branch: Branch, # noqa: ARG002
64
+ at: Timestamp, # noqa: ARG002
65
+ ) -> MigrationResult:
66
+ return result
67
+
68
+ async def execute_queries(
69
+ self, db: InfrahubDatabase, result: MigrationResult, branch: Branch, at: Timestamp
70
+ ) -> MigrationResult:
71
+ for migration_query in self.queries:
72
+ try:
73
+ query = await migration_query.init(db=db, branch=branch, at=at, migration=self)
74
+ await query.execute(db=db)
75
+ result.nbr_migrations_executed += query.get_nbr_migrations_executed()
76
+ except Exception as exc:
77
+ result.errors.append(str(exc))
78
+ return result
79
+
80
+ return result
81
+
50
82
  async def execute(self, db: InfrahubDatabase, branch: Branch, at: Timestamp | str | None = None) -> MigrationResult:
51
83
  async with db.start_transaction() as ts:
52
84
  result = MigrationResult()
85
+ at = Timestamp(at)
53
86
 
54
- for migration_query in self.queries:
55
- try:
56
- query = await migration_query.init(db=ts, branch=branch, at=at, migration=self)
57
- await query.execute(db=ts)
58
- result.nbr_migrations_executed += query.get_nbr_migrations_executed()
59
- except Exception as exc:
60
- result.errors.append(str(exc))
61
- return result
87
+ await self.execute_pre_queries(db=ts, result=result, branch=branch, at=at)
88
+ await self.execute_queries(db=ts, result=result, branch=branch, at=at)
89
+ await self.execute_post_queries(db=ts, result=result, branch=branch, at=at)
62
90
 
63
91
  return result
64
92
 
@@ -26,6 +26,7 @@ from infrahub.core.protocols import CoreNumberPool, CoreObjectTemplate
26
26
  from infrahub.core.query.node import NodeCheckIDQuery, NodeCreateAllQuery, NodeDeleteQuery, NodeGetListQuery
27
27
  from infrahub.core.schema import (
28
28
  AttributeSchema,
29
+ GenericSchema,
29
30
  NodeSchema,
30
31
  NonGenericSchemaTypes,
31
32
  ProfileSchema,
@@ -263,11 +264,9 @@ class Node(BaseNode, metaclass=BaseNodeMeta):
263
264
  within the create code.
264
265
  """
265
266
 
266
- number_pool_parameters: NumberPoolParameters | None = None
267
267
  if attribute.schema.kind == "NumberPool" and isinstance(attribute.schema.parameters, NumberPoolParameters):
268
268
  attribute.from_pool = {"id": attribute.schema.parameters.number_pool_id}
269
269
  attribute.is_default = False
270
- number_pool_parameters = attribute.schema.parameters
271
270
 
272
271
  if not attribute.from_pool:
273
272
  return
@@ -277,9 +276,9 @@ class Node(BaseNode, metaclass=BaseNodeMeta):
277
276
  db=db, id=attribute.from_pool["id"], kind=CoreNumberPool
278
277
  )
279
278
  except NodeNotFoundError:
280
- if number_pool_parameters:
281
- number_pool = await self._fetch_or_create_number_pool(
282
- db=db, attribute=attribute, number_pool_parameters=number_pool_parameters
279
+ if attribute.schema.kind == "NumberPool" and isinstance(attribute.schema.parameters, NumberPoolParameters):
280
+ number_pool = await self.fetch_or_create_number_pool(
281
+ db=db, schema_node=self._schema, schema_attribute=attribute.schema, branch=self._branch
283
282
  )
284
283
 
285
284
  else:
@@ -295,7 +294,9 @@ class Node(BaseNode, metaclass=BaseNodeMeta):
295
294
  and number_pool.node_attribute.value == attribute.name
296
295
  ):
297
296
  try:
298
- next_free = await number_pool.get_resource(db=db, branch=self._branch, node=self, attribute=attribute)
297
+ next_free = await number_pool.get_resource(
298
+ db=db, branch=self._branch, node=self, attribute=attribute.schema
299
+ )
299
300
  except PoolExhaustedError:
300
301
  errors.append(
301
302
  ValidationError({f"{attribute.name}.from_pool": f"The pool {number_pool.node.value} is exhausted."})
@@ -313,10 +314,28 @@ class Node(BaseNode, metaclass=BaseNodeMeta):
313
314
  )
314
315
  )
315
316
 
316
- async def _fetch_or_create_number_pool(
317
- self, db: InfrahubDatabase, attribute: BaseAttribute, number_pool_parameters: NumberPoolParameters
317
+ @staticmethod
318
+ async def fetch_or_create_number_pool(
319
+ db: InfrahubDatabase,
320
+ schema_node: NodeSchema | GenericSchema,
321
+ schema_attribute: AttributeSchema,
322
+ branch: Branch | None = None,
318
323
  ) -> CoreNumberPool:
324
+ """Fetch or create a number pool based on the schema attribute parameters.
325
+
326
+ Warning, ideally this method should be outside of the Node class, but it is itself using the Node class to create the pool node.
327
+ """
328
+
329
+ if (
330
+ schema_attribute.kind != "NumberPool"
331
+ or not schema_attribute.parameters
332
+ or not isinstance(schema_attribute.parameters, NumberPoolParameters)
333
+ ):
334
+ raise ValueError("Attribute is not of type NumberPool")
335
+
319
336
  number_pool_from_db: CoreNumberPool | None = None
337
+ number_pool_parameters: NumberPoolParameters = schema_attribute.parameters
338
+
320
339
  lock_definition = NumberPoolLockDefinition(pool_id=str(number_pool_parameters.number_pool_id))
321
340
  async with lock.registry.get(
322
341
  name=lock_definition.lock_name, namespace=lock_definition.namespace_name, local=False
@@ -325,37 +344,37 @@ class Node(BaseNode, metaclass=BaseNodeMeta):
325
344
  number_pool_from_db = await registry.manager.get_one_by_id_or_default_filter(
326
345
  db=db, id=str(number_pool_parameters.number_pool_id), kind=CoreNumberPool
327
346
  )
347
+ return number_pool_from_db # type: ignore[return-value]
348
+
328
349
  except NodeNotFoundError:
329
350
  schema = db.schema.get_node_schema(name="CoreNumberPool", duplicate=False)
330
351
 
331
- pool_node = self._schema.kind
332
- schema_attribute = self._schema.get_attribute(attribute.schema.name)
352
+ pool_node = schema_node.kind
333
353
  if schema_attribute.inherited:
334
- for generic_name in self._schema.inherit_from:
354
+ for generic_name in schema_node.inherit_from:
335
355
  generic_node = db.schema.get_generic_schema(name=generic_name, duplicate=False)
336
- if attribute.schema.name in generic_node.attribute_names:
356
+ if schema_attribute.name in generic_node.attribute_names:
337
357
  pool_node = generic_node.kind
338
358
  break
339
359
 
340
- number_pool = await Node.init(db=db, schema=schema, branch=self._branch)
360
+ number_pool = await Node.init(db=db, schema=schema, branch=branch)
341
361
  await number_pool.new(
342
362
  db=db,
343
363
  id=number_pool_parameters.number_pool_id,
344
- name=f"{pool_node}.{attribute.schema.name} [{number_pool_parameters.number_pool_id}]",
364
+ name=f"{pool_node}.{schema_attribute.name} [{number_pool_parameters.number_pool_id}]",
345
365
  node=pool_node,
346
- node_attribute=attribute.schema.name,
366
+ node_attribute=schema_attribute.name,
347
367
  start_range=number_pool_parameters.start_range,
348
368
  end_range=number_pool_parameters.end_range,
349
369
  pool_type=NumberPoolType.SCHEMA.value,
350
370
  )
351
371
  await number_pool.save(db=db)
352
372
 
353
- # Do a lookup of the number pool to get the correct mapped type from the registry
354
- # without this we don't get access to the .get_resource() method.
355
- created_pool: CoreNumberPool = number_pool_from_db or await registry.manager.get_one_by_id_or_default_filter(
356
- db=db, id=number_pool.id, kind=CoreNumberPool
357
- )
358
- return created_pool
373
+ # Do a lookup of the number pool to get the correct mapped type from the registry
374
+ # without this we don't get access to the .get_resource() method.
375
+ return await registry.manager.get_one_by_id_or_default_filter(
376
+ db=db, id=number_pool.id, kind=CoreNumberPool
377
+ )
359
378
 
360
379
  async def handle_object_template(self, fields: dict, db: InfrahubDatabase, errors: list) -> None:
361
380
  """Fill the `fields` parameters with values from an object template if one is in use."""
@@ -819,6 +838,7 @@ class Node(BaseNode, metaclass=BaseNodeMeta):
819
838
 
820
839
  query = await NodeDeleteQuery.init(db=db, node=self, at=delete_at)
821
840
  await query.execute(db=db)
841
+
822
842
  self._node_changelog = node_changelog
823
843
 
824
844
  async def to_graphql(
@@ -10,17 +10,15 @@ from infrahub.exceptions import PoolExhaustedError
10
10
  from .. import Node
11
11
 
12
12
  if TYPE_CHECKING:
13
- from infrahub.core.attribute import BaseAttribute
14
13
  from infrahub.core.branch import Branch
14
+ from infrahub.core.schema import AttributeSchema
15
15
  from infrahub.core.timestamp import Timestamp
16
16
  from infrahub.database import InfrahubDatabase
17
17
 
18
18
 
19
19
  class CoreNumberPool(Node):
20
20
  def get_attribute_nb_excluded_values(self) -> int:
21
- """
22
- Returns the number of excluded values for the attribute of the number pool.
23
- """
21
+ """Returns the number of excluded values for the attribute of the number pool."""
24
22
 
25
23
  pool_node = registry.schema.get(name=self.node.value) # type: ignore [attr-defined]
26
24
  attribute = [attribute for attribute in pool_node.attributes if attribute.name == self.node_attribute.value][0] # type: ignore [attr-defined]
@@ -35,18 +33,42 @@ class CoreNumberPool(Node):
35
33
  res = len(attribute.parameters.get_excluded_single_values()) + sum_excluded_values
36
34
  return res
37
35
 
36
+ async def get_used(
37
+ self,
38
+ db: InfrahubDatabase,
39
+ branch: Branch,
40
+ ) -> list[int]:
41
+ """Returns a list of used numbers in the pool."""
42
+
43
+ query = await NumberPoolGetUsed.init(db=db, branch=branch, pool=self, branch_agnostic=True)
44
+ await query.execute(db=db)
45
+ used = [result.value for result in query.iter_results()]
46
+ return [item for item in used if item is not None]
47
+
48
+ async def reserve(self, db: InfrahubDatabase, number: int, identifier: str, at: Timestamp | None = None) -> None:
49
+ """Reserve a number in the pool for a specific identifier."""
50
+
51
+ query = await NumberPoolSetReserved.init(
52
+ db=db, pool_id=self.get_id(), identifier=identifier, reserved=number, at=at
53
+ )
54
+ await query.execute(db=db)
55
+
38
56
  async def get_resource(
39
57
  self,
40
58
  db: InfrahubDatabase,
41
59
  branch: Branch,
42
60
  node: Node,
43
- attribute: BaseAttribute,
61
+ attribute: AttributeSchema,
44
62
  identifier: str | None = None,
45
63
  at: Timestamp | None = None,
46
64
  ) -> int:
65
+ # NOTE: ideally we should use the HFID as the identifier (if available)
66
+ # one of the challenge with using the HFID is that it might change over time
67
+ # so we need to ensure that the identifier is stable, or we need to handle the case where the identifier changes
47
68
  identifier = identifier or node.get_id()
69
+
48
70
  # Check if there is already a resource allocated with this identifier
49
- # if not, pull all existing prefixes and allocated the next available
71
+ # if not, pull all existing number and allocate the next available
50
72
  # TODO add support for branch, if the node is reserved with this id in another branch we should return an error
51
73
  query_get = await NumberPoolGetReserved.init(db=db, branch=branch, pool_id=self.id, identifier=identifier)
52
74
  await query_get.execute(db=db)
@@ -56,35 +78,51 @@ class CoreNumberPool(Node):
56
78
 
57
79
  # If we have not returned a value we need to find one if avaiable
58
80
  number = await self.get_next(db=db, branch=branch, attribute=attribute)
59
-
60
- query_set = await NumberPoolSetReserved.init(
61
- db=db, pool_id=self.get_id(), identifier=identifier, reserved=number, at=at
62
- )
63
- await query_set.execute(db=db)
81
+ await self.reserve(db=db, number=number, identifier=identifier, at=at)
64
82
  return number
65
83
 
66
- async def get_next(self, db: InfrahubDatabase, branch: Branch, attribute: BaseAttribute) -> int:
67
- query = await NumberPoolGetUsed.init(db=db, branch=branch, pool=self, branch_agnostic=True)
68
- await query.execute(db=db)
69
- taken = [result.get_as_optional_type("av.value", return_type=int) for result in query.results]
70
- parameters = attribute.schema.parameters
84
+ async def get_next(self, db: InfrahubDatabase, branch: Branch, attribute: AttributeSchema) -> int:
85
+ taken = await self.get_used(db=db, branch=branch)
86
+
71
87
  next_number = find_next_free(
72
88
  start=self.start_range.value, # type: ignore[attr-defined]
73
89
  end=self.end_range.value, # type: ignore[attr-defined]
74
90
  taken=taken,
75
- parameters=parameters if isinstance(parameters, NumberAttributeParameters) else None,
91
+ parameters=attribute.parameters if isinstance(attribute.parameters, NumberAttributeParameters) else None,
76
92
  )
77
93
  if next_number is None:
78
94
  raise PoolExhaustedError("There are no more values available in this pool.")
79
95
 
80
96
  return next_number
81
97
 
98
+ async def get_next_many(
99
+ self, db: InfrahubDatabase, quantity: int, branch: Branch, attribute: AttributeSchema
100
+ ) -> list[int]:
101
+ taken = await self.get_used(db=db, branch=branch)
102
+
103
+ allocated: list[int] = []
104
+
105
+ for _ in range(quantity):
106
+ next_number = find_next_free(
107
+ start=self.start_range.value, # type: ignore[attr-defined]
108
+ end=self.end_range.value, # type: ignore[attr-defined]
109
+ taken=list(set(taken) | set(allocated)),
110
+ parameters=attribute.parameters
111
+ if isinstance(attribute.parameters, NumberAttributeParameters)
112
+ else None,
113
+ )
114
+ if next_number is None:
115
+ raise PoolExhaustedError(
116
+ f"There are no more values available in this pool, couldn't allocate {quantity} values, only {len(allocated)} available."
117
+ )
118
+
119
+ allocated.append(next_number)
120
+
121
+ return allocated
122
+
82
123
 
83
- def find_next_free(
84
- start: int, end: int, taken: list[int | None], parameters: NumberAttributeParameters | None
85
- ) -> int | None:
86
- used_numbers = [number for number in taken if number is not None]
87
- used_set = set(used_numbers)
124
+ def find_next_free(start: int, end: int, taken: list[int], parameters: NumberAttributeParameters | None) -> int | None:
125
+ used_set = set(taken)
88
126
 
89
127
  for num in range(start, end + 1):
90
128
  if num not in used_set: