infrahub-server 1.4.9__py3-none-any.whl → 1.5.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. infrahub/actions/tasks.py +200 -16
  2. infrahub/api/artifact.py +3 -0
  3. infrahub/api/query.py +2 -0
  4. infrahub/api/schema.py +3 -0
  5. infrahub/auth.py +5 -5
  6. infrahub/cli/db.py +2 -2
  7. infrahub/config.py +7 -2
  8. infrahub/core/attribute.py +22 -19
  9. infrahub/core/branch/models.py +2 -2
  10. infrahub/core/branch/needs_rebase_status.py +11 -0
  11. infrahub/core/branch/tasks.py +2 -2
  12. infrahub/core/constants/__init__.py +1 -0
  13. infrahub/core/convert_object_type/object_conversion.py +201 -0
  14. infrahub/core/convert_object_type/repository_conversion.py +89 -0
  15. infrahub/core/convert_object_type/schema_mapping.py +27 -3
  16. infrahub/core/diff/query/artifact.py +12 -9
  17. infrahub/core/graph/__init__.py +1 -1
  18. infrahub/core/initialization.py +2 -2
  19. infrahub/core/manager.py +3 -81
  20. infrahub/core/migrations/graph/__init__.py +2 -0
  21. infrahub/core/migrations/graph/m040_profile_attrs_in_db.py +166 -0
  22. infrahub/core/node/__init__.py +26 -3
  23. infrahub/core/node/create.py +79 -38
  24. infrahub/core/node/lock_utils.py +98 -0
  25. infrahub/core/property.py +11 -0
  26. infrahub/core/protocols.py +1 -0
  27. infrahub/core/query/attribute.py +27 -15
  28. infrahub/core/query/node.py +47 -184
  29. infrahub/core/query/relationship.py +43 -26
  30. infrahub/core/query/subquery.py +0 -8
  31. infrahub/core/relationship/model.py +59 -19
  32. infrahub/core/schema/attribute_schema.py +0 -2
  33. infrahub/core/schema/definitions/core/repository.py +7 -0
  34. infrahub/core/schema/relationship_schema.py +0 -1
  35. infrahub/core/schema/schema_branch.py +3 -2
  36. infrahub/generators/models.py +31 -12
  37. infrahub/generators/tasks.py +3 -1
  38. infrahub/git/base.py +38 -1
  39. infrahub/graphql/api/dependencies.py +2 -4
  40. infrahub/graphql/api/endpoints.py +2 -2
  41. infrahub/graphql/app.py +2 -4
  42. infrahub/graphql/initialization.py +2 -3
  43. infrahub/graphql/manager.py +212 -137
  44. infrahub/graphql/middleware.py +12 -0
  45. infrahub/graphql/mutations/branch.py +11 -0
  46. infrahub/graphql/mutations/computed_attribute.py +110 -3
  47. infrahub/graphql/mutations/convert_object_type.py +34 -13
  48. infrahub/graphql/mutations/ipam.py +21 -8
  49. infrahub/graphql/mutations/main.py +37 -153
  50. infrahub/graphql/mutations/profile.py +195 -0
  51. infrahub/graphql/mutations/proposed_change.py +2 -1
  52. infrahub/graphql/mutations/repository.py +22 -83
  53. infrahub/graphql/mutations/webhook.py +1 -1
  54. infrahub/graphql/registry.py +173 -0
  55. infrahub/graphql/schema.py +4 -1
  56. infrahub/lock.py +52 -26
  57. infrahub/locks/__init__.py +0 -0
  58. infrahub/locks/tasks.py +37 -0
  59. infrahub/patch/plan_writer.py +2 -2
  60. infrahub/profiles/__init__.py +0 -0
  61. infrahub/profiles/node_applier.py +101 -0
  62. infrahub/profiles/queries/__init__.py +0 -0
  63. infrahub/profiles/queries/get_profile_data.py +99 -0
  64. infrahub/profiles/tasks.py +63 -0
  65. infrahub/repositories/__init__.py +0 -0
  66. infrahub/repositories/create_repository.py +113 -0
  67. infrahub/tasks/registry.py +6 -4
  68. infrahub/webhook/models.py +1 -1
  69. infrahub/workflows/catalogue.py +38 -3
  70. infrahub/workflows/models.py +17 -2
  71. infrahub_sdk/branch.py +5 -8
  72. infrahub_sdk/client.py +364 -84
  73. infrahub_sdk/convert_object_type.py +61 -0
  74. infrahub_sdk/ctl/check.py +2 -3
  75. infrahub_sdk/ctl/cli_commands.py +16 -12
  76. infrahub_sdk/ctl/config.py +8 -2
  77. infrahub_sdk/ctl/generator.py +2 -3
  78. infrahub_sdk/ctl/repository.py +39 -1
  79. infrahub_sdk/ctl/schema.py +12 -1
  80. infrahub_sdk/ctl/utils.py +4 -0
  81. infrahub_sdk/ctl/validate.py +5 -3
  82. infrahub_sdk/diff.py +4 -5
  83. infrahub_sdk/exceptions.py +2 -0
  84. infrahub_sdk/graphql.py +7 -2
  85. infrahub_sdk/node/attribute.py +2 -0
  86. infrahub_sdk/node/node.py +28 -20
  87. infrahub_sdk/playback.py +1 -2
  88. infrahub_sdk/protocols.py +40 -6
  89. infrahub_sdk/pytest_plugin/plugin.py +7 -4
  90. infrahub_sdk/pytest_plugin/utils.py +40 -0
  91. infrahub_sdk/repository.py +1 -2
  92. infrahub_sdk/schema/main.py +1 -0
  93. infrahub_sdk/spec/object.py +43 -4
  94. infrahub_sdk/spec/range_expansion.py +118 -0
  95. infrahub_sdk/timestamp.py +18 -6
  96. {infrahub_server-1.4.9.dist-info → infrahub_server-1.5.0b0.dist-info}/METADATA +20 -24
  97. {infrahub_server-1.4.9.dist-info → infrahub_server-1.5.0b0.dist-info}/RECORD +102 -84
  98. infrahub_testcontainers/models.py +2 -2
  99. infrahub_testcontainers/performance_test.py +4 -4
  100. infrahub/core/convert_object_type/conversion.py +0 -134
  101. {infrahub_server-1.4.9.dist-info → infrahub_server-1.5.0b0.dist-info}/LICENSE.txt +0 -0
  102. {infrahub_server-1.4.9.dist-info → infrahub_server-1.5.0b0.dist-info}/WHEEL +0 -0
  103. {infrahub_server-1.4.9.dist-info → infrahub_server-1.5.0b0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,166 @@
1
+ from __future__ import annotations
2
+
3
+ from collections import defaultdict
4
+ from typing import TYPE_CHECKING, Any
5
+
6
+ from rich.console import Console
7
+ from rich.progress import Progress
8
+
9
+ from infrahub.core.branch.models import Branch
10
+ from infrahub.core.initialization import initialization
11
+ from infrahub.core.manager import NodeManager
12
+ from infrahub.core.migrations.shared import MigrationResult
13
+ from infrahub.core.query import Query, QueryType
14
+ from infrahub.core.timestamp import Timestamp
15
+ from infrahub.lock import initialize_lock
16
+ from infrahub.log import get_logger
17
+ from infrahub.profiles.node_applier import NodeProfilesApplier
18
+
19
+ from ..shared import ArbitraryMigration
20
+
21
+ if TYPE_CHECKING:
22
+ from infrahub.core.node import Node
23
+ from infrahub.database import InfrahubDatabase
24
+
25
# Module-level logger.
log = get_logger()


class GetProfilesByBranchQuery(Query):
    """
    Collect CoreProfile UUIDs grouped by the branch of each HAS_VALUE edge found on their attributes.

    The query does not filter edges by status or time, so every branch that ever carried an attribute
    value edge for a profile is reported.
    """

    name = "get_profiles_by_branch"
    type = QueryType.READ
    insert_return = False

    async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None:  # noqa: ARG002
        # One row per profile UUID, with the distinct branches of its HAS_VALUE edges collected into a list.
        self.add_to_query(
            """
        MATCH (profile:CoreProfile)-[:HAS_ATTRIBUTE]->(attr:Attribute)-[e:HAS_VALUE]->(:AttributeValue)
        WITH DISTINCT profile.uuid AS profile_uuid, e.branch AS branch
        RETURN profile_uuid, collect(branch) AS branches
        """
        )
        self.return_labels = ["profile_uuid", "branches"]

    def get_profile_ids_by_branch(self) -> dict[str, set[str]]:
        """Invert the query rows into a mapping of branch name -> set of profile UUIDs updated on it."""
        ids_by_branch: defaultdict[str, set[str]] = defaultdict(set)
        for row in self.get_results():
            profile_uuid = row.get_as_type("profile_uuid", str)
            for branch_name in row.get_as_type("branches", list[str]):
                ids_by_branch[branch_name].add(profile_uuid)
        return ids_by_branch
55
+
56
+
57
class GetNodesWithProfileUpdatesByBranchQuery(Query):
    """
    Collect UUIDs of non-profile nodes grouped by the branch of each IS_RELATED edge they have
    towards a ``node__profile`` Relationship node.

    Edge status and time range are not filtered, so any branch that ever touched the relationship is reported.
    """

    name = "get_nodes_with_profile_updates_by_branch"
    type = QueryType.READ
    insert_return = False

    async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None:  # noqa: ARG002
        # Profiles themselves are excluded; only nodes that *use* profiles are returned.
        self.add_to_query(
            """
        MATCH (node:Node)-[e1:IS_RELATED]->(:Relationship {name: "node__profile"})
        WHERE NOT node:CoreProfile
        WITH DISTINCT node.uuid AS node_uuid, e1.branch AS branch
        RETURN node_uuid, collect(branch) AS branches
        """
        )
        self.return_labels = ["node_uuid", "branches"]

    def get_node_ids_by_branch(self) -> dict[str, set[str]]:
        """Invert the query rows into a mapping of branch name -> set of node UUIDs whose profiles changed on it."""
        uuids_by_branch: defaultdict[str, set[str]] = defaultdict(set)
        for row in self.get_results():
            node_uuid = row.get_as_type("node_uuid", str)
            for branch_name in row.get_as_type("branches", list[str]):
                uuids_by_branch[branch_name].add(node_uuid)
        return uuids_by_branch
85
+
86
+
87
class Migration040(ArbitraryMigration):
    """
    Save profile attribute values on each node using the profile in the database.

    Two sets of (branch, node) pairs are processed with ``NodeProfilesApplier.apply_profiles``:
    - for any profile with attribute value edges on a branch (including the default branch),
      every node reachable through the profile's ``related_nodes`` relationship on that branch;
    - every node with an IS_RELATED edge to a ``node__profile`` relationship on a branch.
    A node is saved on its branch only when ``apply_profiles`` reports updated fields.
    """

    name: str = "040_profile_attrs_in_db"
    minimum_version: int = 39

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # One applier per branch name, created lazily by _get_profile_applier.
        self._appliers_by_branch: dict[str, NodeProfilesApplier] = {}

    async def _get_profile_applier(self, db: InfrahubDatabase, branch_name: str) -> NodeProfilesApplier:
        """Return the cached NodeProfilesApplier for ``branch_name``, building it on first request."""
        cached = self._appliers_by_branch.get(branch_name)
        if cached is not None:
            return cached
        branch = await Branch.get_by_name(db=db, name=branch_name)
        applier = NodeProfilesApplier(db=db, branch=branch)
        self._appliers_by_branch[branch_name] = applier
        return applier

    async def validate_migration(self, db: InfrahubDatabase) -> MigrationResult:  # noqa: ARG002
        return MigrationResult()

    async def _load_updated_profiles(self, db: InfrahubDatabase, console: Console) -> dict[str, list[Node]]:
        """Load, for each branch, the profile nodes that have attribute value edges on that branch."""
        console.print("Gathering profiles for each branch...", end="")
        profiles_query = await GetProfilesByBranchQuery.init(db=db)
        await profiles_query.execute(db=db)

        loaded: dict[str, list[Node]] = {}
        for branch_name, profile_ids in profiles_query.get_profile_ids_by_branch().items():
            profiles_map = await NodeManager.get_many(db=db, branch=branch_name, ids=list(profile_ids))
            loaded[branch_name] = list(profiles_map.values())
        console.print("done")
        return loaded

    async def _collect_profile_peer_ids(
        self, db: InfrahubDatabase, profiles_by_branch: dict[str, list[Node]]
    ) -> defaultdict[str, set[str]]:
        """Map each branch to the IDs of the nodes linked to its updated profiles through ``related_nodes``."""
        peer_ids_by_branch: defaultdict[str, set[str]] = defaultdict(set)
        profile_count = sum(len(profiles) for profiles in profiles_by_branch.values())
        with Progress() as progress:
            task_id = progress.add_task(
                "Gathering affected objects for each profile on each branch...", total=profile_count
            )
            for branch_name, profiles in profiles_by_branch.items():
                for profile in profiles:
                    peers = await profile.get_relationship("related_nodes").get_db_peers(db=db)
                    peer_ids_by_branch[branch_name].update({str(peer.peer_id) for peer in peers})
                    progress.update(task_id, advance=1)
        return peer_ids_by_branch

    async def _add_nodes_with_profile_updates(
        self, db: InfrahubDatabase, console: Console, node_ids_by_branch: defaultdict[str, set[str]]
    ) -> None:
        """Merge (in place) the nodes whose profile relationship changed on a branch into ``node_ids_by_branch``."""
        console.print("Identifying nodes with profile updates by branch...", end="")
        nodes_query = await GetNodesWithProfileUpdatesByBranchQuery.init(db=db)
        await nodes_query.execute(db=db)
        for branch_name, node_ids in nodes_query.get_node_ids_by_branch().items():
            node_ids_by_branch[branch_name].update(node_ids)
        console.print("done")

    async def _apply_profiles_to_nodes(
        self, db: InfrahubDatabase, node_ids_by_branch: defaultdict[str, set[str]]
    ) -> None:
        """Apply profiles to every collected node and save the fields that changed, all at one timestamp."""
        right_now = Timestamp()
        node_count = sum(len(node_ids) for node_ids in node_ids_by_branch.values())
        with Progress() as progress:
            task_id = progress.add_task("Applying profiles to nodes...", total=node_count)
            for branch_name, node_ids in node_ids_by_branch.items():
                applier = await self._get_profile_applier(db=db, branch_name=branch_name)
                for node_id in node_ids:
                    node = await NodeManager.get_one(db=db, branch=branch_name, id=node_id, at=right_now)
                    if node:
                        changed_fields = await applier.apply_profiles(node=node)
                        if changed_fields:
                            await node.save(db=db, fields=changed_fields, at=right_now)
                    progress.update(task_id, advance=1)

    async def execute(self, db: InfrahubDatabase) -> MigrationResult:
        console = Console()
        result = MigrationResult()
        # load schemas from database into registry
        initialize_lock()
        await initialization(db=db)

        profiles_by_branch = await self._load_updated_profiles(db=db, console=console)
        node_ids_to_update_by_branch = await self._collect_profile_peer_ids(
            db=db, profiles_by_branch=profiles_by_branch
        )
        await self._add_nodes_with_profile_updates(
            db=db, console=console, node_ids_by_branch=node_ids_to_update_by_branch
        )
        await self._apply_profiles_to_nodes(db=db, node_ids_by_branch=node_ids_to_update_by_branch)
        return result
@@ -42,6 +42,7 @@ from infrahub.types import ATTRIBUTE_TYPES
42
42
  from ...graphql.constants import KIND_GRAPHQL_FIELD_NAME
43
43
  from ...graphql.models import OrderModel
44
44
  from ...log import get_logger
45
+ from ..attribute import BaseAttribute
45
46
  from ..query.relationship import RelationshipDeleteAllQuery
46
47
  from ..relationship import RelationshipManager
47
48
  from ..utils import update_relationships_to
@@ -53,8 +54,6 @@ if TYPE_CHECKING:
53
54
  from infrahub.core.branch import Branch
54
55
  from infrahub.database import InfrahubDatabase
55
56
 
56
- from ..attribute import BaseAttribute
57
-
58
57
  SchemaProtocol = TypeVar("SchemaProtocol")
59
58
 
60
59
  # ---------------------------------------------------------------------------------------
@@ -100,6 +99,28 @@ class Node(BaseNode, metaclass=BaseNodeMeta):
100
99
  def get_updated_at(self) -> Timestamp | None:
      """Return the value stored in ``self._updated_at`` (may be None)."""
      return self._updated_at
102
101
 
102
def get_attribute(self, name: str) -> BaseAttribute:
    """
    Return the member ``name`` of this node, checked to be a BaseAttribute.

    Raises:
        ValueError: if the member exists but is not a BaseAttribute.
        AttributeError: propagated from ``getattr`` when the node has no member ``name``.
    """
    candidate = getattr(self, name)
    if isinstance(candidate, BaseAttribute):
        return candidate
    raise ValueError(f"{name} is not an attribute of {self.get_kind()}")
107
+
108
def get_relationship(self, name: str) -> RelationshipManager:
    """
    Return the member ``name`` of this node, checked to be a RelationshipManager.

    Raises:
        ValueError: if the member exists but is not a RelationshipManager.
        AttributeError: propagated from ``getattr`` when the node has no member ``name``.
    """
    candidate = getattr(self, name)
    if isinstance(candidate, RelationshipManager):
        return candidate
    raise ValueError(f"{name} is not a relationship of {self.get_kind()}")
113
+
114
def uses_profiles(self) -> bool:
    """
    Return True as soon as one schema attribute of this node has ``is_from_profile`` set.

    Schema attribute names that ``get_attribute`` rejects with ValueError are skipped.
    """
    for attribute_name in self.get_schema().attribute_names:
        try:
            candidate = self.get_attribute(attribute_name)
        except ValueError:
            # Not backed by a BaseAttribute on this node; ignore it.
            continue
        if not candidate:
            continue
        if candidate.is_from_profile:
            return True
    return False
123
+
103
124
  async def get_hfid(self, db: InfrahubDatabase, include_kind: bool = False) -> list[str] | None:
104
125
  """Return the Human friendly id of the node."""
105
126
  if not self._schema.human_friendly_id:
@@ -408,7 +429,9 @@ class Node(BaseNode, metaclass=BaseNodeMeta):
408
429
  for attribute_name in template._attributes:
409
430
  if attribute_name in list(fields) + [OBJECT_TEMPLATE_NAME_ATTR]:
410
431
  continue
411
- fields[attribute_name] = {"value": getattr(template, attribute_name).value, "source": template.id}
432
+ attr_value = getattr(template, attribute_name).value
433
+ if attr_value is not None:
434
+ fields[attribute_name] = {"value": attr_value, "source": template.id}
412
435
 
413
436
  for relationship_name in template._relationships:
414
437
  relationship_schema = template._schema.get_relationship(name=relationship_name)
@@ -2,18 +2,23 @@ from __future__ import annotations
2
2
 
3
3
  from typing import TYPE_CHECKING, Any, Mapping
4
4
 
5
+ from infrahub import lock
5
6
  from infrahub.core import registry
6
7
  from infrahub.core.constants import RelationshipCardinality, RelationshipKind
7
8
  from infrahub.core.constraint.node.runner import NodeConstraintRunner
8
- from infrahub.core.manager import NodeManager
9
9
  from infrahub.core.node import Node
10
+ from infrahub.core.node.lock_utils import get_kind_lock_names_on_object_mutation
10
11
  from infrahub.core.protocols import CoreObjectTemplate
12
+ from infrahub.core.schema import GenericSchema
11
13
  from infrahub.dependencies.registry import get_component_registry
14
+ from infrahub.lock import InfrahubMultiLock
15
+ from infrahub.profiles.node_applier import NodeProfilesApplier
12
16
 
13
17
  if TYPE_CHECKING:
14
18
  from infrahub.core.branch import Branch
15
19
  from infrahub.core.relationship.model import RelationshipManager
16
20
  from infrahub.core.schema import MainSchemaTypes, NonGenericSchemaTypes, RelationshipSchema
21
+ from infrahub.core.timestamp import Timestamp
17
22
  from infrahub.database import InfrahubDatabase
18
23
 
19
24
 
@@ -41,10 +46,19 @@ async def extract_peer_data(
41
46
  ) -> Mapping[str, Any]:
42
47
  obj_peer_data: dict[str, Any] = {}
43
48
 
44
- for attr in template_peer.get_schema().attribute_names:
45
- if attr not in obj_peer_schema.attribute_names:
49
+ for attr_name in template_peer.get_schema().attribute_names:
50
+ template_attr = getattr(template_peer, attr_name)
51
+ if template_attr.value is None:
46
52
  continue
47
- obj_peer_data[attr] = {"value": getattr(template_peer, attr).value, "source": template_peer.id}
53
+ if template_attr.is_default:
54
+ # if template attr is_default and the value matches the object schema, then do not set the source
55
+ try:
56
+ if obj_peer_schema.get_attribute(name=attr_name).default_value == template_attr.value:
57
+ continue
58
+ except ValueError:
59
+ pass
60
+
61
+ obj_peer_data[attr_name] = {"value": template_attr.value, "source": template_peer.id}
48
62
 
49
63
  for rel in template_peer.get_schema().relationship_names:
50
64
  rel_manager: RelationshipManager = getattr(template_peer, rel)
@@ -67,6 +81,7 @@ async def handle_template_relationships(
67
81
  template: CoreObjectTemplate,
68
82
  fields: list,
69
83
  constraint_runner: NodeConstraintRunner | None = None,
84
+ at: Timestamp | None = None,
70
85
  ) -> None:
71
86
  if constraint_runner is None:
72
87
  component_registry = get_component_registry()
@@ -94,7 +109,7 @@ async def handle_template_relationships(
94
109
  current_template=template,
95
110
  )
96
111
 
97
- obj_peer = await Node.init(schema=obj_peer_schema, db=db, branch=branch)
112
+ obj_peer = await Node.init(schema=obj_peer_schema, db=db, branch=branch, at=at)
98
113
  await obj_peer.new(db=db, **obj_peer_data)
99
114
  await constraint_runner.check(node=obj_peer, field_filters=list(obj_peer_data))
100
115
  await obj_peer.save(db=db)
@@ -106,6 +121,7 @@ async def handle_template_relationships(
106
121
  obj=obj_peer,
107
122
  template=template_relationship_peer,
108
123
  fields=fields,
124
+ at=at,
109
125
  )
110
126
 
111
127
 
@@ -116,43 +132,20 @@ async def get_profile_ids(db: InfrahubDatabase, obj: Node) -> set[str]:
116
132
  return {pr.peer_id for pr in profile_rels}
117
133
 
118
134
 
119
- async def refresh_for_profile_update(
120
- db: InfrahubDatabase,
121
- branch: Branch,
122
- obj: Node,
123
- schema: NonGenericSchemaTypes,
124
- previous_profile_ids: set[str] | None = None,
125
- ) -> Node:
126
- if not hasattr(obj, "profiles"):
127
- return obj
128
- current_profile_ids = await get_profile_ids(db=db, obj=obj)
129
- if previous_profile_ids is None or previous_profile_ids != current_profile_ids:
130
- refreshed_node = await NodeManager.get_one_by_id_or_default_filter(
131
- db=db,
132
- kind=schema.kind,
133
- id=obj.get_id(),
134
- branch=branch,
135
- include_owner=True,
136
- include_source=True,
137
- )
138
- refreshed_node._node_changelog = obj.node_changelog
139
- return refreshed_node
140
- return obj
141
-
142
-
143
135
  async def _do_create_node(
144
136
  node_class: type[Node],
137
+ node_constraint_runner: NodeConstraintRunner,
145
138
  db: InfrahubDatabase,
146
- data: dict,
147
139
  schema: NonGenericSchemaTypes,
148
- fields_to_validate: list,
149
140
  branch: Branch,
150
- node_constraint_runner: NodeConstraintRunner,
141
+ fields_to_validate: list[str],
142
+ data: dict[str, Any],
143
+ at: Timestamp | None = None,
151
144
  ) -> Node:
152
145
  obj = await node_class.init(db=db, schema=schema, branch=branch)
153
146
  await obj.new(db=db, **data)
154
147
  await node_constraint_runner.check(node=obj, field_filters=fields_to_validate)
155
- await obj.save(db=db)
148
+ await obj.save(db=db, at=at)
156
149
 
157
150
  object_template = await obj.get_object_template(db=db)
158
151
  if object_template:
@@ -162,18 +155,62 @@ async def _do_create_node(
162
155
  template=object_template,
163
156
  obj=obj,
164
157
  fields=fields_to_validate,
158
+ at=at,
165
159
  )
166
160
  return obj
167
161
 
168
162
 
163
async def _do_create_node_with_lock(
    node_class: type[Node],
    node_constraint_runner: NodeConstraintRunner,
    db: InfrahubDatabase,
    schema: NonGenericSchemaTypes,
    branch: Branch,
    fields_to_validate: list[str],
    data: dict[str, Any],
    at: Timestamp | None = None,
) -> Node:
    """
    Create the node via ``_do_create_node`` while holding the object-mutation locks for ``schema.kind``.

    Lock names come from ``get_kind_lock_names_on_object_mutation`` (evaluated against the branch's schema).
    When it returns no names, the node is created without acquiring any lock.
    """

    async def _create() -> Node:
        # Single place where the creation arguments are forwarded, used by both the locked and unlocked paths.
        return await _do_create_node(
            node_class=node_class,
            node_constraint_runner=node_constraint_runner,
            db=db,
            schema=schema,
            branch=branch,
            fields_to_validate=fields_to_validate,
            data=data,
            at=at,
        )

    schema_branch = registry.schema.get_schema_branch(name=branch.name)
    lock_names = get_kind_lock_names_on_object_mutation(
        kind=schema.kind, branch=branch, schema_branch=schema_branch, data=dict(data)
    )
    if not lock_names:
        return await _create()

    async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names):
        return await _create()
200
+
201
+
169
202
  async def create_node(
170
- data: dict,
203
+ data: dict[str, Any],
171
204
  db: InfrahubDatabase,
172
205
  branch: Branch,
173
- schema: NonGenericSchemaTypes,
206
+ schema: MainSchemaTypes,
207
+ at: Timestamp | None = None,
174
208
  ) -> Node:
175
209
  """Create a node in the database if constraint checks succeed."""
176
210
 
211
+ if isinstance(schema, GenericSchema):
212
+ raise ValueError(f"Node of generic schema `{schema.name=}` can not be instantiated.")
213
+
177
214
  component_registry = get_component_registry()
178
215
  node_constraint_runner = await component_registry.get_component(
179
216
  NodeConstraintRunner, db=db.start_session() if not db.is_transaction else db, branch=branch
@@ -184,7 +221,7 @@ async def create_node(
184
221
 
185
222
  fields_to_validate = list(data)
186
223
  if db.is_transaction:
187
- obj = await _do_create_node(
224
+ obj = await _do_create_node_with_lock(
188
225
  node_class=node_class,
189
226
  node_constraint_runner=node_constraint_runner,
190
227
  db=db,
@@ -192,10 +229,11 @@ async def create_node(
192
229
  branch=branch,
193
230
  fields_to_validate=fields_to_validate,
194
231
  data=data,
232
+ at=at,
195
233
  )
196
234
  else:
197
235
  async with db.start_transaction() as dbt:
198
- obj = await _do_create_node(
236
+ obj = await _do_create_node_with_lock(
199
237
  node_class=node_class,
200
238
  node_constraint_runner=node_constraint_runner,
201
239
  db=dbt,
@@ -203,9 +241,12 @@ async def create_node(
203
241
  branch=branch,
204
242
  fields_to_validate=fields_to_validate,
205
243
  data=data,
244
+ at=at,
206
245
  )
207
246
 
208
247
  if await get_profile_ids(db=db, obj=obj):
209
- obj = await refresh_for_profile_update(db=db, branch=branch, schema=schema, obj=obj)
248
+ node_profiles_applier = NodeProfilesApplier(db=db, branch=branch)
249
+ await node_profiles_applier.apply_profiles(node=obj)
250
+ await obj.save(db=db)
210
251
 
211
252
  return obj
@@ -0,0 +1,98 @@
1
+ import hashlib
2
+ from typing import Any
3
+
4
+ from infrahub.core.branch import Branch
5
+ from infrahub.core.constants.infrahubkind import GENERICGROUP, GRAPHQLQUERYGROUP
6
+ from infrahub.core.schema import GenericSchema
7
+ from infrahub.core.schema.schema_branch import SchemaBranch
8
+
9
# Kinds for which concurrent mutations are prevented on every branch, not only on the default branch
# (nodes inheriting from one of these kinds are covered too).
KINDS_CONCURRENT_MUTATIONS_NOT_ALLOWED = [GENERICGROUP]


def _get_kinds_to_lock_on_object_mutation(kind: str, schema_branch: SchemaBranch) -> list[str]:
    """
    Return the kinds to lock while creating / updating an object of schema ``kind``.

    The schema's own kind is included when it defines uniqueness constraints, followed by every
    generic in ``inherit_from`` that defines uniqueness constraints. If one of those generics has
    exactly the same constraints as the node schema, the node schema merely repeats the generic's
    constraint, so the node kind is dropped (at most once) and only the generic is locked.
    For a generic schema, only its own kind can be returned.
    """
    node_schema = schema_branch.get(name=kind, duplicate=False)

    # None when the schema defines no uniqueness constraints of its own.
    own_constraints = node_schema.uniqueness_constraints or None
    lock_kinds: list[str] = [node_schema.kind] if own_constraints else []

    if isinstance(node_schema, GenericSchema):
        return lock_kinds

    own_kind_dropped = False
    for generic_kind in node_schema.inherit_from:
        generic_constraints = schema_branch.get(name=generic_kind, duplicate=False).uniqueness_constraints
        if not generic_constraints:
            continue
        lock_kinds.append(generic_kind)
        if not own_kind_dropped and generic_constraints == own_constraints:
            # The node kind, when present, is always the first entry.
            lock_kinds.pop(0)
            own_kind_dropped = True
    return lock_kinds
44
+
45
+
46
def _should_kind_be_locked_on_any_branch(kind: str, schema_branch: SchemaBranch) -> bool:
    """
    Check whether ``kind``, or any generic it inherits from, is in KINDS_CONCURRENT_MUTATIONS_NOT_ALLOWED.

    Such kinds get mutation locks on every branch, not only on the default branch.
    A generic schema is only matched by its own kind (its parents are not inspected).
    """
    if kind in KINDS_CONCURRENT_MUTATIONS_NOT_ALLOWED:
        return True

    node_schema = schema_branch.get(name=kind, duplicate=False)
    if isinstance(node_schema, GenericSchema):
        return False

    return any(generic_kind in KINDS_CONCURRENT_MUTATIONS_NOT_ALLOWED for generic_kind in node_schema.inherit_from)
62
+
63
+
64
def _hash(value: str) -> str:
    """
    Return the hexadecimal SHA-256 digest of ``value`` encoded with ``str.encode()`` (UTF-8).

    The builtin ``hash`` is deliberately avoided: string hashing is randomized per process,
    so lock names derived from it would differ between processes.
    """
    hasher = hashlib.sha256()
    hasher.update(value.encode())
    return hasher.hexdigest()
68
+
69
+
70
def _get_input_value(field_input: Any) -> Any:
    """
    Return the value carried by a single mutation field input.

    ``field_input`` may be an object exposing ``.value`` (checked first, e.g. a GraphQL input object),
    a plain dict holding a ``"value"`` key, or the raw value itself, which is returned unchanged.
    """
    if hasattr(field_input, "value"):
        return field_input.value
    if isinstance(field_input, dict):
        return field_input.get("value")
    return field_input


def get_kind_lock_names_on_object_mutation(
    kind: str, branch: Branch, schema_branch: SchemaBranch, data: dict[str, Any]
) -> list[str]:
    """
    Return the lock names to hold while mutating (creating/updating) an object of ``kind``.

    Except for some specific kinds (see ``_should_kind_be_locked_on_any_branch``), concurrent mutations
    are only prevented on the default branch, as object validations will be performed at least when
    merging into the default branch.

    Args:
        kind: schema kind of the object being mutated.
        branch: branch on which the mutation happens.
        schema_branch: schema used to resolve ``kind`` and its generics.
        data: mutation input keyed by field name. Only ``data["name"]`` is read, and only for
            CoreGraphQLQueryGroup; it may be an object with ``.value``, a ``{"value": ...}`` dict or a raw value.

    Returns:
        Lock names built with ``build_object_lock_name``; an empty list when no lock is needed.
    """
    if not branch.is_default and not _should_kind_be_locked_on_any_branch(kind=kind, schema_branch=schema_branch):
        return []

    if kind == GRAPHQLQUERYGROUP:
        # Lock on name as well to improve performances
        try:
            name_input = data["name"]
        except KeyError:
            # We might reach here if we are updating a CoreGraphQLQueryGroup without updating the name,
            # in which case we would not need to lock. This is not supposed to happen as current `update`
            # logic first fetches the node with its name.
            return []
        name = _get_input_value(name_input)
        if name is None:
            # No name to scope the lock to: treat it like a missing name instead of crashing in _hash.
            return []
        return [build_object_lock_name(kind + "." + _hash(str(name)))]

    lock_kinds = _get_kinds_to_lock_on_object_mutation(kind, schema_branch)
    return [build_object_lock_name(lock_kind) for lock_kind in lock_kinds]
95
+
96
+
97
def build_object_lock_name(name: str) -> str:
    """Return the name of the object-mutation lock for ``name``, i.e. ``"global.object.<name>"``."""
    prefix = "global.object."
    return f"{prefix}{name}"
infrahub/core/property.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from enum import Enum
3
4
  from typing import TYPE_CHECKING
4
5
  from uuid import UUID
5
6
 
@@ -26,6 +27,10 @@ class NodePropertyData(BaseModel):
26
27
  peer_id: str
27
28
 
28
29
 
30
class ClearValue(Enum):
    """Single-member enum whose only member, ``CLEAR``, has the value ``"clear"``.

    NOTE(review): presumably used as a sentinel meaning "clear this property"; confirm against callers.
    """

    CLEAR = "clear"
32
+
33
+
29
34
  class FlagPropertyMixin:
30
35
  _flag_properties: list[str] = [v.value for v in FlagProperty]
31
36
 
@@ -51,6 +56,7 @@ class NodePropertyMixin:
51
56
  for node in self._node_properties:
52
57
  setattr(self, f"_{node}", None)
53
58
  setattr(self, f"{node}_id", None)
59
+ setattr(self, f"_clear_{node}", False)
54
60
 
55
61
  if not kwargs:
56
62
  return
@@ -79,12 +85,14 @@ class NodePropertyMixin:
79
85
 
80
86
  def clear_owner(self) -> None:
      """Unset the ``owner`` node property and record the explicit clear in ``_clear_owner``."""
      self._set_node_property(name="owner", value=None)
      # Distinguishes "explicitly cleared" from "never set"; read back through is_clear("owner").
      self._clear_owner = True
82
89
 
83
90
  async def get_source(self, db: InfrahubDatabase) -> Node | None:
      """Return the ``source`` node property, resolved through ``_get_node_property`` (may be None)."""
      return await self._get_node_property(name="source", db=db)
85
92
 
86
93
  def clear_source(self) -> None:
      """Unset the ``source`` node property and record the explicit clear in ``_clear_source``."""
      self._set_node_property(name="source", value=None)
      # Distinguishes "explicitly cleared" from "never set"; read back through is_clear("source").
      self._clear_source = True
88
96
 
89
97
  def set_source(self, value: str | Node | UUID) -> None:
90
98
  self._set_node_property(name="source", value=value)
@@ -95,6 +103,9 @@ class NodePropertyMixin:
95
103
  def set_owner(self, value: str | Node | UUID) -> None:
      """Set the ``owner`` node property from an ID string, a Node, or a UUID via ``_set_node_property``."""
      self._set_node_property(name="owner", value=value)
97
105
 
106
def is_clear(self, name: str) -> bool:
    """Return the ``_clear_<name>`` flag of this object, or False when that flag was never set."""
    flag_attribute = f"_clear_{name}"
    return getattr(self, flag_attribute, False)
108
+
98
109
  def _get_node_property_from_cache(self, name: str) -> Node:
99
110
  """Return the node attribute if it's already present locally,
100
111
  Otherwise raise an exception
@@ -125,6 +125,7 @@ class CoreGenericRepository(CoreNode):
125
125
  queries: RelationshipManager
126
126
  checks: RelationshipManager
127
127
  generators: RelationshipManager
128
+ groups_objects: RelationshipManager
128
129
 
129
130
 
130
131
  class CoreGroup(CoreNode):
@@ -133,7 +133,7 @@ class AttributeUpdateNodePropertyQuery(AttributeQuery):
133
133
  def __init__(
134
134
  self,
135
135
  prop_name: str,
136
- prop_id: str,
136
+ prop_id: str | None = None,
137
137
  **kwargs: Any,
138
138
  ):
139
139
  self.prop_name = prop_name
@@ -144,6 +144,8 @@ class AttributeUpdateNodePropertyQuery(AttributeQuery):
144
144
  async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None: # noqa: ARG002
145
145
  at = self.at or self.attr.at
146
146
 
147
+ branch_filter, branch_params = self.branch.get_query_filter_path(at=at)
148
+ self.params.update(branch_params)
147
149
  self.params["attr_uuid"] = self.attr.id
148
150
  self.params["branch"] = self.branch.name
149
151
  self.params["branch_level"] = self.branch.hierarchy_level
@@ -151,18 +153,34 @@ class AttributeUpdateNodePropertyQuery(AttributeQuery):
151
153
  self.params["prop_name"] = self.prop_name
152
154
  self.params["prop_id"] = self.prop_id
153
155
 
154
- rel_name = f"HAS_{self.prop_name.upper()}"
156
+ rel_label = f"HAS_{self.prop_name.upper()}"
155
157
 
156
- query = (
158
+ if self.branch.is_default or self.branch.is_global:
159
+ node_query = """
160
+ MATCH (np:Node { uuid: $prop_id })-[r:IS_PART_OF]->(:Root)
161
+ WHERE r.branch IN $branch0
162
+ AND r.status = "active"
163
+ AND r.from <= $at AND (r.to IS NULL OR r.to > $at)
164
+ WITH np
165
+ LIMIT 1
157
166
  """
167
+ else:
168
+ node_query = """
169
+ MATCH (np:Node { uuid: $prop_id })-[r:IS_PART_OF]->(:Root)
170
+ WHERE %(branch_filter)s
171
+ ORDER BY r.branch_level DESC, r.from DESC, r.status ASC
172
+ LIMIT 1
173
+ WITH np
174
+ WHERE r.status = "active"
175
+ """ % {"branch_filter": branch_filter}
176
+ self.add_to_query(node_query)
177
+
178
+ attr_query = """
158
179
  MATCH (a:Attribute { uuid: $attr_uuid })
159
- MATCH (np:Node { uuid: $prop_id })
160
- CREATE (a)-[r:%s { branch: $branch, branch_level: $branch_level, status: "active", from: $at }]->(np)
161
- """
162
- % rel_name
163
- )
180
+ CREATE (a)-[r:%(rel_label)s { branch: $branch, branch_level: $branch_level, status: "active", from: $at }]->(np)
181
+ """ % {"rel_label": rel_label}
182
+ self.add_to_query(attr_query)
164
183
 
165
- self.add_to_query(query)
166
184
  self.return_labels = ["a", "np", "r"]
167
185
 
168
186
 
@@ -204,7 +222,6 @@ async def default_attribute_query_filter(
204
222
  param_prefix: str | None = None,
205
223
  db: InfrahubDatabase | None = None, # noqa: ARG001
206
224
  partial_match: bool = False,
207
- support_profiles: bool = False,
208
225
  ) -> tuple[list[QueryElement], dict[str, Any], list[str]]:
209
226
  """Generate Query String Snippet to filter the right node."""
210
227
  attribute_value_label = GraphAttributeValueNode.get_default_label()
@@ -251,9 +268,6 @@ async def default_attribute_query_filter(
251
268
  query_where.append(f"toString(av.{filter_name}) =~ ${param_prefix}_{filter_name}")
252
269
  elif filter_name == "isnull":
253
270
  query_filter.append(QueryNode(name="av", labels=[attribute_value_label]))
254
- elif support_profiles:
255
- query_filter.append(QueryNode(name="av", labels=[attribute_value_label]))
256
- query_where.append(f"(av.{filter_name} = ${param_prefix}_{filter_name} OR av.is_default)")
257
271
  else:
258
272
  query_filter.append(
259
273
  QueryNode(
@@ -271,8 +285,6 @@ async def default_attribute_query_filter(
271
285
  if attribute_kind and attribute_kind == "List":
272
286
  query_params[f"{param_prefix}_{filter_name}"] = build_regex_attrs(values=filter_value)
273
287
  query_where.append(f"toString(av.value) =~ ${param_prefix}_{filter_name}")
274
- elif support_profiles:
275
- query_where.append(f"(av.value IN ${param_prefix}_value OR av.is_default)")
276
288
  else:
277
289
  query_where.append(f"av.value IN ${param_prefix}_value")
278
290
  query_params[f"{param_prefix}_value"] = filter_value