infrahub-server 1.2.9__py3-none-any.whl → 1.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. infrahub/computed_attribute/models.py +13 -0
  2. infrahub/computed_attribute/tasks.py +29 -28
  3. infrahub/core/attribute.py +43 -2
  4. infrahub/core/branch/models.py +8 -9
  5. infrahub/core/diff/calculator.py +61 -8
  6. infrahub/core/diff/combiner.py +37 -29
  7. infrahub/core/diff/enricher/hierarchy.py +4 -6
  8. infrahub/core/diff/merger/merger.py +29 -1
  9. infrahub/core/diff/merger/serializer.py +1 -0
  10. infrahub/core/diff/model/path.py +6 -3
  11. infrahub/core/diff/query/merge.py +264 -28
  12. infrahub/core/diff/query/save.py +6 -5
  13. infrahub/core/diff/query_parser.py +4 -15
  14. infrahub/core/diff/repository/deserializer.py +7 -6
  15. infrahub/core/graph/__init__.py +1 -1
  16. infrahub/core/migrations/graph/m028_delete_diffs.py +38 -0
  17. infrahub/core/query/diff.py +97 -13
  18. infrahub/core/query/node.py +26 -3
  19. infrahub/core/query/relationship.py +96 -35
  20. infrahub/core/relationship/model.py +1 -1
  21. infrahub/core/validators/uniqueness/query.py +7 -0
  22. infrahub/trigger/setup.py +13 -2
  23. infrahub/types.py +1 -1
  24. infrahub/webhook/models.py +2 -1
  25. infrahub/workflows/catalogue.py +9 -0
  26. infrahub_sdk/timestamp.py +2 -2
  27. {infrahub_server-1.2.9.dist-info → infrahub_server-1.2.10.dist-info}/METADATA +3 -3
  28. {infrahub_server-1.2.9.dist-info → infrahub_server-1.2.10.dist-info}/RECORD +33 -32
  29. infrahub_testcontainers/docker-compose.test.yml +2 -2
  30. infrahub_testcontainers/performance_test.py +6 -3
  31. {infrahub_server-1.2.9.dist-info → infrahub_server-1.2.10.dist-info}/LICENSE.txt +0 -0
  32. {infrahub_server-1.2.9.dist-info → infrahub_server-1.2.10.dist-info}/WHEEL +0 -0
  33. {infrahub_server-1.2.9.dist-info → infrahub_server-1.2.10.dist-info}/entry_points.txt +0 -0
@@ -118,6 +118,10 @@ class PythonTransformTarget:
118
118
  class ComputedAttrJinja2TriggerDefinition(TriggerBranchDefinition):
119
119
  type: TriggerType = TriggerType.COMPUTED_ATTR_JINJA2
120
120
  computed_attribute: ComputedAttributeTarget
121
+ template_hash: str
122
+
123
+ def get_description(self) -> str:
124
+ return f"{super().get_description()} | hash:{self.template_hash}"
121
125
 
122
126
  @classmethod
123
127
  def from_computed_attribute(
@@ -139,6 +143,14 @@ class ComputedAttrJinja2TriggerDefinition(TriggerBranchDefinition):
139
143
  # node creation events if the attribute is optional.
140
144
  event_trigger.events.add(NodeCreatedEvent.event_name)
141
145
 
146
+ if (
147
+ computed_attribute.attribute.computed_attribute
148
+ and computed_attribute.attribute.computed_attribute.jinja2_template is None
149
+ ) or not computed_attribute.attribute.computed_attribute:
150
+ raise ValueError("Jinja2 template is required for computed attribute")
151
+
152
+ template_hash = computed_attribute.attribute.computed_attribute.get_hash()
153
+
142
154
  event_trigger.match = {"infrahub.node.kind": trigger_node.kind}
143
155
  if branches_out_of_scope:
144
156
  event_trigger.match["infrahub.branch.name"] = [f"!{branch}" for branch in branches_out_of_scope]
@@ -177,6 +189,7 @@ class ComputedAttrJinja2TriggerDefinition(TriggerBranchDefinition):
177
189
 
178
190
  definition = cls(
179
191
  name=f"{computed_attribute.key_name}{NAME_SEPARATOR}kind{NAME_SEPARATOR}{trigger_node.kind}",
192
+ template_hash=template_hash,
180
193
  branch=branch,
181
194
  computed_attribute=computed_attribute,
182
195
  trigger=event_trigger,
@@ -14,7 +14,7 @@ from infrahub.core.registry import registry
14
14
  from infrahub.events import BranchDeletedEvent
15
15
  from infrahub.git.repository import get_initialized_repo
16
16
  from infrahub.services import InfrahubServices # noqa: TC001 needed for prefect flow
17
- from infrahub.trigger.models import TriggerType
17
+ from infrahub.trigger.models import TriggerSetupReport, TriggerType
18
18
  from infrahub.trigger.setup import setup_triggers
19
19
  from infrahub.workflows.catalogue import (
20
20
  COMPUTED_ATTRIBUTE_PROCESS_JINJA2,
@@ -25,7 +25,11 @@ from infrahub.workflows.catalogue import (
25
25
  from infrahub.workflows.utils import add_tags, wait_for_schema_to_converge
26
26
 
27
27
  from .gather import gather_trigger_computed_attribute_jinja2, gather_trigger_computed_attribute_python
28
- from .models import ComputedAttrJinja2GraphQL, ComputedAttrJinja2GraphQLResponse, PythonTransformTarget
28
+ from .models import (
29
+ ComputedAttrJinja2GraphQL,
30
+ ComputedAttrJinja2GraphQLResponse,
31
+ PythonTransformTarget,
32
+ )
29
33
 
30
34
  if TYPE_CHECKING:
31
35
  from infrahub.core.schema.computed_attribute import ComputedAttribute
@@ -159,10 +163,10 @@ async def trigger_update_python_computed_attributes(
159
163
 
160
164
 
161
165
  @flow(
162
- name="process_computed_attribute_value_jinja2",
163
- flow_run_name="Update value for computed attribute {attribute_name}",
166
+ name="computed-attribute-jinja2-update-value",
167
+ flow_run_name="Update value for computed attribute {node_kind}:{attribute_name}",
164
168
  )
165
- async def update_computed_attribute_value_jinja2(
169
+ async def computed_attribute_jinja2_update_value(
166
170
  branch_name: str,
167
171
  obj: ComputedAttrJinja2GraphQLResponse,
168
172
  node_kind: str,
@@ -246,7 +250,7 @@ async def process_jinja2(
246
250
  batch = await service.client.create_batch()
247
251
  for node in found:
248
252
  batch.add(
249
- task=update_computed_attribute_value_jinja2,
253
+ task=computed_attribute_jinja2_update_value,
250
254
  branch_name=branch_name,
251
255
  obj=node,
252
256
  node_kind=node_schema.kind,
@@ -302,36 +306,33 @@ async def computed_attribute_setup_jinja2(
302
306
 
303
307
  triggers = await gather_trigger_computed_attribute_jinja2()
304
308
 
305
- # Since we can have multiple trigger per NodeKind
306
- # we need to extract the list of unique node that should be processed
307
- # also
308
- # Because the automation in Prefect doesn't capture all information about the computed attribute
309
- # we can't tell right now if a given computed attribute has changed and need to be updated
310
- unique_nodes: set[tuple[str, str, str]] = {
311
- (trigger.branch, trigger.computed_attribute.kind, trigger.computed_attribute.attribute.name)
312
- for trigger in triggers
313
- }
314
- for branch, kind, attribute_name in unique_nodes:
315
- if event_name != BranchDeletedEvent.event_name and branch == branch_name:
316
- await service.workflow.submit_workflow(
317
- workflow=TRIGGER_UPDATE_JINJA_COMPUTED_ATTRIBUTES,
318
- context=context,
319
- parameters={
320
- "branch_name": branch,
321
- "computed_attribute_name": attribute_name,
322
- "computed_attribute_kind": kind,
323
- },
324
- )
325
-
326
309
  # Configure all ComputedAttrJinja2Trigger in Prefect
327
310
  async with get_client(sync_client=False) as prefect_client:
328
- await setup_triggers(
311
+ report: TriggerSetupReport = await setup_triggers(
329
312
  client=prefect_client,
330
313
  triggers=triggers,
331
314
  trigger_type=TriggerType.COMPUTED_ATTR_JINJA2,
332
315
  force_update=False,
333
316
  ) # type: ignore[misc]
334
317
 
318
+ # Since we can have multiple trigger per NodeKind
319
+ # we need to extract the list of unique node that should be processed
320
+ unique_nodes: set[tuple[str, str, str]] = {
321
+ (trigger.branch, trigger.computed_attribute.kind, trigger.computed_attribute.attribute.name) # type: ignore[attr-defined]
322
+ for trigger in report.updated + report.created
323
+ }
324
+ for branch, kind, attribute_name in unique_nodes:
325
+ if event_name != BranchDeletedEvent.event_name and branch == branch_name:
326
+ await service.workflow.submit_workflow(
327
+ workflow=TRIGGER_UPDATE_JINJA_COMPUTED_ATTRIBUTES,
328
+ context=context,
329
+ parameters={
330
+ "branch_name": branch,
331
+ "computed_attribute_name": attribute_name,
332
+ "computed_attribute_kind": kind,
333
+ },
334
+ )
335
+
335
336
  log.info(f"{len(triggers)} Computed Attribute for Jinja2 automation configuration completed")
336
337
 
337
338
 
@@ -782,6 +782,10 @@ class Dropdown(BaseAttribute):
782
782
 
783
783
  return ""
784
784
 
785
+ @staticmethod
786
+ def get_allowed_property_in_path() -> list[str]:
787
+ return ["color", "description", "label", "value"]
788
+
785
789
  @classmethod
786
790
  def validate_content(cls, value: Any, name: str, schema: AttributeSchema) -> None:
787
791
  """Validate the content of the dropdown."""
@@ -817,7 +821,18 @@ class IPNetwork(BaseAttribute):
817
821
 
818
822
  @staticmethod
819
823
  def get_allowed_property_in_path() -> list[str]:
820
- return ["value", "version", "binary_address", "prefixlen"]
824
+ return [
825
+ "binary_address",
826
+ "broadcast_address",
827
+ "hostmask",
828
+ "netmask",
829
+ "num_addresses",
830
+ "prefixlen",
831
+ "value",
832
+ "version",
833
+ "with_hostmask",
834
+ "with_netmask",
835
+ ]
821
836
 
822
837
  @property
823
838
  def obj(self) -> ipaddress.IPv4Network | ipaddress.IPv6Network:
@@ -950,7 +965,17 @@ class IPHost(BaseAttribute):
950
965
 
951
966
  @staticmethod
952
967
  def get_allowed_property_in_path() -> list[str]:
953
- return ["value", "version", "binary_address"]
968
+ return [
969
+ "binary_address",
970
+ "hostmask",
971
+ "ip",
972
+ "netmask",
973
+ "prefixlen",
974
+ "value",
975
+ "version",
976
+ "with_hostmask",
977
+ "with_netmask",
978
+ ]
954
979
 
955
980
  @property
956
981
  def obj(self) -> ipaddress.IPv4Interface | ipaddress.IPv6Interface:
@@ -1170,6 +1195,22 @@ class MacAddress(BaseAttribute):
1170
1195
  """Serialize the value as standard EUI-48 or EUI-64 before storing it in the database."""
1171
1196
  return str(netaddr.EUI(addr=self.value))
1172
1197
 
1198
+ @staticmethod
1199
+ def get_allowed_property_in_path() -> list[str]:
1200
+ return [
1201
+ "bare",
1202
+ "binary",
1203
+ "dot_notation",
1204
+ "ei",
1205
+ "eui48",
1206
+ "eui64",
1207
+ "oui",
1208
+ "semicolon_notation",
1209
+ "split_notation",
1210
+ "value",
1211
+ "version",
1212
+ ]
1213
+
1173
1214
 
1174
1215
  class MacAddressOptional(MacAddress):
1175
1216
  value: str | None
@@ -295,6 +295,7 @@ class Branch(StandardNode):
295
295
  is_isolated: bool = True,
296
296
  branch_agnostic: bool = False,
297
297
  variable_name: str = "r",
298
+ params_prefix: str = "",
298
299
  ) -> tuple[str, dict]:
299
300
  """
300
301
  Generate a CYPHER Query filter based on a path to query a part of the graph at a specific time and on a specific branch.
@@ -306,30 +307,28 @@ class Branch(StandardNode):
306
307
 
307
308
  There is a currently an assumption that the relationship in the path will be named 'r'
308
309
  """
309
-
310
+ pp = params_prefix
310
311
  params: dict[str, Any] = {}
311
312
  at = Timestamp(at)
312
313
  at_str = at.to_string()
313
314
  if branch_agnostic:
314
- filter_str = (
315
- f"{variable_name}.from <= $time1 AND ({variable_name}.to IS NULL or {variable_name}.to >= $time1)"
316
- )
317
- params["time1"] = at_str
315
+ filter_str = f"{variable_name}.from <= ${pp}time1 AND ({variable_name}.to IS NULL or {variable_name}.to >= ${pp}time1)"
316
+ params[f"{pp}time1"] = at_str
318
317
  return filter_str, params
319
318
 
320
319
  branches_times = self.get_branches_and_times_to_query_global(at=at_str, is_isolated=is_isolated)
321
320
 
322
321
  for idx, (branch_name, time_to_query) in enumerate(branches_times.items()):
323
- params[f"branch{idx}"] = list(branch_name)
324
- params[f"time{idx}"] = time_to_query
322
+ params[f"{pp}branch{idx}"] = list(branch_name)
323
+ params[f"{pp}time{idx}"] = time_to_query
325
324
 
326
325
  filters = []
327
326
  for idx in range(len(branches_times)):
328
327
  filters.append(
329
- f"({variable_name}.branch IN $branch{idx} AND {variable_name}.from <= $time{idx} AND {variable_name}.to IS NULL)"
328
+ f"({variable_name}.branch IN ${pp}branch{idx} AND {variable_name}.from <= ${pp}time{idx} AND {variable_name}.to IS NULL)"
330
329
  )
331
330
  filters.append(
332
- f"({variable_name}.branch IN $branch{idx} AND {variable_name}.from <= $time{idx} AND {variable_name}.to >= $time{idx})"
331
+ f"({variable_name}.branch IN ${pp}branch{idx} AND {variable_name}.from <= ${pp}time{idx} AND {variable_name}.to >= ${pp}time{idx})"
333
332
  )
334
333
 
335
334
  filter_str = "(" + "\n OR ".join(filters) + ")"
@@ -7,6 +7,7 @@ from infrahub.core.diff.query_parser import DiffQueryParser
7
7
  from infrahub.core.query.diff import (
8
8
  DiffCalculationQuery,
9
9
  DiffFieldPathsQuery,
10
+ DiffMigratedKindNodesQuery,
10
11
  DiffNodePathsQuery,
11
12
  DiffPropertyPathsQuery,
12
13
  )
@@ -15,7 +16,7 @@ from infrahub.database import InfrahubDatabase
15
16
  from infrahub.log import get_logger
16
17
 
17
18
  from .model.field_specifiers_map import NodeFieldSpecifierMap
18
- from .model.path import CalculatedDiffs
19
+ from .model.path import CalculatedDiffs, DiffNode, DiffRoot, NodeIdentifier
19
20
 
20
21
  log = get_logger()
21
22
 
@@ -59,7 +60,7 @@ class DiffCalculator:
59
60
  )
60
61
  log.info(f"Beginning one diff calculation query {limit=}, {offset=}")
61
62
  await diff_query.execute(db=self.db)
62
- log.info("Diff calculation query complete")
63
+ log.info(f"Diff calculation query complete {limit=}, {offset=}")
63
64
  last_result = None
64
65
  for query_result in diff_query.get_results():
65
66
  diff_parser.read_result(query_result=query_result)
@@ -69,6 +70,56 @@ class DiffCalculator:
69
70
  has_more_data = last_result.get_as_type("has_more_data", bool)
70
71
  offset += limit
71
72
 
73
+ async def _apply_kind_migrated_nodes(
74
+ self, branch_diff: DiffRoot, calculation_request: DiffCalculationRequest
75
+ ) -> None:
76
+ has_more_data = True
77
+ offset = 0
78
+ limit = config.SETTINGS.database.query_size_limit
79
+ diff_nodes_by_identifier = {n.identifier: n for n in branch_diff.nodes}
80
+ diff_nodes_to_add: list[DiffNode] = []
81
+ while has_more_data:
82
+ diff_query = await DiffMigratedKindNodesQuery.init(
83
+ db=self.db,
84
+ branch=calculation_request.diff_branch,
85
+ base_branch=calculation_request.base_branch,
86
+ diff_branch_from_time=calculation_request.branch_from_time,
87
+ diff_from=calculation_request.from_time,
88
+ diff_to=calculation_request.to_time,
89
+ limit=limit,
90
+ offset=offset,
91
+ )
92
+ log.info(f"Getting one batch of migrated kind nodes {limit=}, {offset=}")
93
+ await diff_query.execute(db=self.db)
94
+ log.info(f"Migrated kind nodes query complete {limit=}, {offset=}")
95
+ last_result = None
96
+ for migrated_kind_node in diff_query.get_migrated_kind_nodes():
97
+ migrated_kind_identifier = NodeIdentifier(
98
+ uuid=migrated_kind_node.uuid,
99
+ kind=migrated_kind_node.kind,
100
+ db_id=migrated_kind_node.db_id,
101
+ )
102
+ if migrated_kind_identifier in diff_nodes_by_identifier:
103
+ diff_node = diff_nodes_by_identifier[migrated_kind_identifier]
104
+ diff_node.is_node_kind_migration = True
105
+ continue
106
+ new_diff_node = DiffNode(
107
+ identifier=migrated_kind_identifier,
108
+ changed_at=migrated_kind_node.from_time,
109
+ action=migrated_kind_node.action,
110
+ is_node_kind_migration=True,
111
+ attributes=[],
112
+ relationships=[],
113
+ )
114
+ diff_nodes_by_identifier[migrated_kind_identifier] = new_diff_node
115
+ diff_nodes_to_add.append(new_diff_node)
116
+ last_result = migrated_kind_node
117
+ has_more_data = False
118
+ if last_result:
119
+ has_more_data = last_result.has_more_data
120
+ offset += limit
121
+ branch_diff.nodes.extend(diff_nodes_to_add)
122
+
72
123
  async def calculate_diff(
73
124
  self,
74
125
  base_branch: Branch,
@@ -92,7 +143,7 @@ class DiffCalculator:
92
143
  )
93
144
  node_limit = int(config.SETTINGS.database.query_size_limit / 10)
94
145
  fields_limit = int(config.SETTINGS.database.query_size_limit / 3)
95
- properties_limit = int(config.SETTINGS.database.query_size_limit)
146
+ properties_limit = config.SETTINGS.database.query_size_limit
96
147
 
97
148
  calculation_request = DiffCalculationRequest(
98
149
  base_branch=base_branch,
@@ -132,7 +183,7 @@ class DiffCalculator:
132
183
  if base_branch.name != diff_branch.name:
133
184
  current_node_field_specifiers = diff_parser.get_current_node_field_specifiers()
134
185
  new_node_field_specifiers = diff_parser.get_new_node_field_specifiers()
135
- calculation_request = DiffCalculationRequest(
186
+ base_calculation_request = DiffCalculationRequest(
136
187
  base_branch=base_branch,
137
188
  diff_branch=base_branch,
138
189
  branch_from_time=diff_branch_from_time,
@@ -146,7 +197,7 @@ class DiffCalculator:
146
197
  await self._run_diff_calculation_query(
147
198
  diff_parser=diff_parser,
148
199
  query_class=DiffNodePathsQuery,
149
- calculation_request=calculation_request,
200
+ calculation_request=base_calculation_request,
150
201
  limit=node_limit,
151
202
  )
152
203
  log.info("Diff node-level calculation queries for base complete")
@@ -155,7 +206,7 @@ class DiffCalculator:
155
206
  await self._run_diff_calculation_query(
156
207
  diff_parser=diff_parser,
157
208
  query_class=DiffFieldPathsQuery,
158
- calculation_request=calculation_request,
209
+ calculation_request=base_calculation_request,
159
210
  limit=fields_limit,
160
211
  )
161
212
  log.info("Diff field-level calculation queries for base complete")
@@ -164,7 +215,7 @@ class DiffCalculator:
164
215
  await self._run_diff_calculation_query(
165
216
  diff_parser=diff_parser,
166
217
  query_class=DiffPropertyPathsQuery,
167
- calculation_request=calculation_request,
218
+ calculation_request=base_calculation_request,
168
219
  limit=properties_limit,
169
220
  )
170
221
  log.info("Diff property-level calculation queries for base complete")
@@ -172,9 +223,11 @@ class DiffCalculator:
172
223
  log.info("Parsing calculated diff")
173
224
  diff_parser.parse(include_unchanged=include_unchanged)
174
225
  log.info("Calculated diff parsed")
226
+ branch_diff = diff_parser.get_diff_root_for_branch(branch=diff_branch.name)
227
+ await self._apply_kind_migrated_nodes(branch_diff=branch_diff, calculation_request=calculation_request)
175
228
  return CalculatedDiffs(
176
229
  base_branch_name=base_branch.name,
177
230
  diff_branch_name=diff_branch.name,
178
231
  base_branch_diff=diff_parser.get_diff_root_for_branch(branch=base_branch.name),
179
- diff_branch_diff=diff_parser.get_diff_root_for_branch(branch=diff_branch.name),
232
+ diff_branch_diff=branch_diff,
180
233
  )
@@ -14,6 +14,7 @@ from .model.path import (
14
14
  EnrichedDiffRoot,
15
15
  EnrichedDiffs,
16
16
  EnrichedDiffSingleRelationship,
17
+ NodeIdentifier,
17
18
  )
18
19
 
19
20
 
@@ -26,30 +27,35 @@ class NodePair:
26
27
  class DiffCombiner:
27
28
  def __init__(self) -> None:
28
29
  # {child_uuid: (parent_uuid, parent_rel_name)}
29
- self._child_parent_uuid_map: dict[str, tuple[str, str]] = {}
30
- self._parent_node_uuids: set[str] = set()
31
- self._earlier_nodes_by_uuid: dict[str, EnrichedDiffNode] = {}
32
- self._later_nodes_by_uuid: dict[str, EnrichedDiffNode] = {}
33
- self._common_node_uuids: set[str] = set()
30
+ self._child_parent_identifier_map: dict[NodeIdentifier, tuple[NodeIdentifier, str]] = {}
31
+ self._parent_node_identifiers: set[NodeIdentifier] = set()
32
+ self._earlier_nodes_by_identifier: dict[NodeIdentifier, EnrichedDiffNode] = {}
33
+ self._later_nodes_by_identifier: dict[NodeIdentifier, EnrichedDiffNode] = {}
34
+ self._common_node_identifiers: set[NodeIdentifier] = set()
34
35
  self._diff_branch_name: str | None = None
35
36
 
36
37
  def _initialize(self, earlier_diff: EnrichedDiffRoot, later_diff: EnrichedDiffRoot) -> None:
37
38
  self._diff_branch_name = earlier_diff.diff_branch_name
38
- self._child_parent_uuid_map = {}
39
- self._earlier_nodes_by_uuid = {}
40
- self._later_nodes_by_uuid = {}
41
- self._common_node_uuids = set()
39
+ self._child_parent_identifier_map = {}
40
+ self._earlier_nodes_by_identifier = {}
41
+ self._later_nodes_by_identifier = {}
42
+ self._common_node_identifiers = set()
42
43
  # map the parent of each node (if it exists), preference to the later diff
43
44
  for diff_root in (earlier_diff, later_diff):
44
45
  for child_node in diff_root.nodes:
45
46
  for parent_rel in child_node.relationships:
46
47
  for parent_node in parent_rel.nodes:
47
- self._child_parent_uuid_map[child_node.uuid] = (parent_node.uuid, parent_rel.name)
48
+ self._child_parent_identifier_map[child_node.identifier] = (
49
+ parent_node.identifier,
50
+ parent_rel.name,
51
+ )
48
52
  # UUIDs of all the parents, removing the stale parents from the earlier diff
49
- self._parent_node_uuids = {parent_tuple[0] for parent_tuple in self._child_parent_uuid_map.values()}
50
- self._earlier_nodes_by_uuid = {n.uuid: n for n in earlier_diff.nodes}
51
- self._later_nodes_by_uuid = {n.uuid: n for n in later_diff.nodes}
52
- self._common_node_uuids = set(self._earlier_nodes_by_uuid.keys()) & set(self._later_nodes_by_uuid.keys())
53
+ self._parent_node_identifiers = {parent_tuple[0] for parent_tuple in self._child_parent_identifier_map.values()}
54
+ self._earlier_nodes_by_identifier = {n.identifier: n for n in earlier_diff.nodes}
55
+ self._later_nodes_by_identifier = {n.identifier: n for n in later_diff.nodes}
56
+ self._common_node_identifiers = set(self._earlier_nodes_by_identifier.keys()) & set(
57
+ self._later_nodes_by_identifier.keys()
58
+ )
53
59
 
54
60
  @property
55
61
  def diff_branch_name(self) -> str:
@@ -61,13 +67,13 @@ class DiffCombiner:
61
67
  filtered_node_pairs: list[NodePair] = []
62
68
  for earlier_node in earlier_diff.nodes:
63
69
  later_node: EnrichedDiffNode | None = None
64
- if earlier_node.uuid in self._common_node_uuids:
65
- later_node = self._later_nodes_by_uuid[earlier_node.uuid]
70
+ if earlier_node.identifier in self._common_node_identifiers:
71
+ later_node = self._later_nodes_by_identifier[earlier_node.identifier]
66
72
  # this is an out-of-date parent
67
73
  if (
68
74
  earlier_node.action is DiffAction.UNCHANGED
69
75
  and (later_node is None or later_node.action is DiffAction.UNCHANGED)
70
- and earlier_node.uuid not in self._parent_node_uuids
76
+ and earlier_node.identifier not in self._parent_node_identifiers
71
77
  ):
72
78
  continue
73
79
  if later_node is None:
@@ -79,15 +85,15 @@ class DiffCombiner:
79
85
  filtered_node_pairs.append(NodePair(earlier=earlier_node, later=later_node))
80
86
  for later_node in later_diff.nodes:
81
87
  # these have already been handled
82
- if later_node.uuid in self._common_node_uuids:
88
+ if later_node.identifier in self._common_node_identifiers:
83
89
  continue
84
90
  filtered_node_pairs.append(NodePair(later=later_node))
85
91
  return filtered_node_pairs
86
92
 
87
- def _get_parent_relationship_name(self, node_id: str) -> str | None:
88
- if node_id not in self._child_parent_uuid_map:
93
+ def _get_parent_relationship_name(self, node_id: NodeIdentifier) -> str | None:
94
+ if node_id not in self._child_parent_identifier_map:
89
95
  return None
90
- return self._child_parent_uuid_map[node_id][1]
96
+ return self._child_parent_identifier_map[node_id][1]
91
97
 
92
98
  def _should_include(self, earlier: DiffAction, later: DiffAction) -> bool:
93
99
  actions = {earlier, later}
@@ -284,7 +290,7 @@ class DiffCombiner:
284
290
  self,
285
291
  earlier_relationships: set[EnrichedDiffRelationship],
286
292
  later_relationships: set[EnrichedDiffRelationship],
287
- node_id: str,
293
+ node_id: NodeIdentifier,
288
294
  ) -> set[EnrichedDiffRelationship]:
289
295
  earlier_rels_by_name = {rel.name: rel for rel in earlier_relationships}
290
296
  later_rels_by_name = {rel.name: rel for rel in later_relationships}
@@ -365,7 +371,7 @@ class DiffCombiner:
365
371
  combined_relationships = self._combine_relationships(
366
372
  earlier_relationships=node_pair.earlier.relationships,
367
373
  later_relationships=node_pair.later.relationships,
368
- node_id=node_pair.later.uuid,
374
+ node_id=node_pair.later.identifier,
369
375
  )
370
376
  if all(ca.action is DiffAction.UNCHANGED for ca in combined_attributes) and all(
371
377
  cr.action is DiffAction.UNCHANGED for cr in combined_relationships
@@ -380,7 +386,7 @@ class DiffCombiner:
380
386
  combined_attributes
381
387
  or combined_relationships
382
388
  or combined_conflict
383
- or node_pair.later.uuid in self._parent_node_uuids
389
+ or node_pair.later.identifier in self._parent_node_identifiers
384
390
  ):
385
391
  combined_nodes.add(
386
392
  EnrichedDiffNode(
@@ -388,6 +394,8 @@ class DiffCombiner:
388
394
  label=node_pair.later.label,
389
395
  changed_at=node_pair.later.changed_at or node_pair.earlier.changed_at,
390
396
  action=combined_action,
397
+ is_node_kind_migration=node_pair.earlier.is_node_kind_migration
398
+ or node_pair.later.is_node_kind_migration,
391
399
  path_identifier=node_pair.later.path_identifier,
392
400
  attributes=combined_attributes,
393
401
  relationships=combined_relationships,
@@ -397,12 +405,12 @@ class DiffCombiner:
397
405
  return combined_nodes
398
406
 
399
407
  def _link_child_nodes(self, nodes: Iterable[EnrichedDiffNode]) -> None:
400
- nodes_by_uuid: dict[str, EnrichedDiffNode] = {n.uuid: n for n in nodes}
401
- for child_node in nodes_by_uuid.values():
402
- if child_node.uuid not in self._child_parent_uuid_map:
408
+ nodes_by_identifier: dict[NodeIdentifier, EnrichedDiffNode] = {n.identifier: n for n in nodes}
409
+ for child_node in nodes_by_identifier.values():
410
+ if child_node.identifier not in self._child_parent_identifier_map:
403
411
  continue
404
- parent_uuid, parent_rel_name = self._child_parent_uuid_map[child_node.uuid]
405
- parent_node = nodes_by_uuid[parent_uuid]
412
+ parent_identifier, parent_rel_name = self._child_parent_identifier_map[child_node.identifier]
413
+ parent_node = nodes_by_identifier[parent_identifier]
406
414
  parent_rel = child_node.get_relationship(name=parent_rel_name)
407
415
  parent_rel.nodes.add(parent_node)
408
416
 
@@ -99,7 +99,7 @@ class DiffHierarchyEnricher(DiffEnricherInterface):
99
99
 
100
100
  current_node = node
101
101
  for ancestor in ancestors:
102
- ancestor_identifier = NodeIdentifier(uuid=ancestor.uuid, kind=ancestor.kind, labels=ancestor.labels)
102
+ ancestor_identifier = NodeIdentifier(uuid=ancestor.uuid, kind=ancestor.kind, db_id=ancestor.db_id)
103
103
  parent_request = ParentNodeAddRequest(
104
104
  node_identifier=current_node.identifier,
105
105
  parent_identifier=ancestor_identifier,
@@ -146,13 +146,11 @@ class DiffHierarchyEnricher(DiffEnricherInterface):
146
146
 
147
147
  for peer in query.get_peers():
148
148
  source_identifier = NodeIdentifier(
149
- uuid=str(peer.source_id), kind=peer.source_kind, labels=peer.source_labels
149
+ uuid=str(peer.source_id), kind=peer.source_kind, db_id=peer.source_db_id
150
150
  )
151
151
  parent_peers[source_identifier] = peer
152
152
  if parent_schema.has_parent_relationship:
153
- peer_identifier = NodeIdentifier(
154
- uuid=str(peer.peer_id), kind=peer.peer_kind, labels=peer.peer_labels
155
- )
153
+ peer_identifier = NodeIdentifier(uuid=str(peer.peer_id), kind=peer.peer_kind, db_id=peer.peer_db_id)
156
154
  node_parent_with_parent_map[parent_schema.kind].append(peer_identifier)
157
155
 
158
156
  # Check if the parent are already present
@@ -170,7 +168,7 @@ class DiffHierarchyEnricher(DiffEnricherInterface):
170
168
  parent_rel = [rel for rel in schema_node.relationships if rel.kind == RelationshipKind.PARENT][0]
171
169
 
172
170
  peer_identifier = NodeIdentifier(
173
- uuid=str(peer_parent.peer_id), kind=peer_parent.peer_kind, labels=peer_parent.peer_labels
171
+ uuid=str(peer_parent.peer_id), kind=peer_parent.peer_kind, db_id=peer_parent.peer_db_id
174
172
  )
175
173
  parent_request = ParentNodeAddRequest(
176
174
  node_identifier=node.identifier,
@@ -3,8 +3,14 @@ from __future__ import annotations
3
3
  from typing import TYPE_CHECKING
4
4
 
5
5
  from infrahub.core import registry
6
+ from infrahub.core.constants import DiffAction
6
7
  from infrahub.core.diff.model.path import BranchTrackingId
7
- from infrahub.core.diff.query.merge import DiffMergePropertiesQuery, DiffMergeQuery, DiffMergeRollbackQuery
8
+ from infrahub.core.diff.query.merge import (
9
+ DiffMergeMigratedKindsQuery,
10
+ DiffMergePropertiesQuery,
11
+ DiffMergeQuery,
12
+ DiffMergeRollbackQuery,
13
+ )
8
14
  from infrahub.log import get_logger
9
15
 
10
16
  if TYPE_CHECKING:
@@ -53,6 +59,16 @@ class DiffMerger:
53
59
  )
54
60
  log.info(f"Diff {latest_diff.uuid} retrieved")
55
61
  batch_num = 0
62
+ migrated_kinds_id_map = {}
63
+ for n in enriched_diff.nodes:
64
+ if not n.is_node_kind_migration:
65
+ continue
66
+ if n.uuid not in migrated_kinds_id_map or (
67
+ n.uuid in migrated_kinds_id_map and n.action is DiffAction.ADDED
68
+ ):
69
+ # make sure that we use the ADDED db_id if it exists
70
+ # it will not if a node was migrated and then deleted
71
+ migrated_kinds_id_map[n.uuid] = n.identifier.db_id
56
72
  async for node_diff_dicts, property_diff_dicts in self.serializer.serialize_diff(diff=enriched_diff):
57
73
  if node_diff_dicts:
58
74
  log.info(f"Merging batch of nodes #{batch_num}")
@@ -62,6 +78,7 @@ class DiffMerger:
62
78
  at=at,
63
79
  target_branch=self.destination_branch,
64
80
  node_diff_dicts=node_diff_dicts,
81
+ migrated_kinds_id_map=migrated_kinds_id_map,
65
82
  )
66
83
  await merge_query.execute(db=self.db)
67
84
  if property_diff_dicts:
@@ -72,10 +89,21 @@ class DiffMerger:
72
89
  at=at,
73
90
  target_branch=self.destination_branch,
74
91
  property_diff_dicts=property_diff_dicts,
92
+ migrated_kinds_id_map=migrated_kinds_id_map,
75
93
  )
76
94
  await merge_properties_query.execute(db=self.db)
77
95
  log.info(f"Batch #{batch_num} merged")
78
96
  batch_num += 1
97
+ migrated_kind_uuids = {n.identifier.uuid for n in enriched_diff.nodes if n.is_node_kind_migration}
98
+ if migrated_kind_uuids:
99
+ migrated_merge_query = await DiffMergeMigratedKindsQuery.init(
100
+ db=self.db,
101
+ branch=self.source_branch,
102
+ at=at,
103
+ target_branch=self.destination_branch,
104
+ migrated_uuids=list(migrated_kind_uuids),
105
+ )
106
+ await migrated_merge_query.execute(db=self.db)
79
107
 
80
108
  self.source_branch.branched_from = at.to_string()
81
109
  await self.source_branch.save(db=self.db)
@@ -39,6 +39,7 @@ class DiffMergeSerializer:
39
39
 
40
40
  def _reset_caches(self) -> None:
41
41
  self._attribute_type_cache = {}
42
+ self._conflicted_cardinality_one_relationships = set()
42
43
 
43
44
  @property
44
45
  def source_branch_name(self) -> str: