infrahub-server 1.1.5__py3-none-any.whl → 1.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. infrahub/api/oidc.py +1 -0
  2. infrahub/core/attribute.py +4 -1
  3. infrahub/core/branch/tasks.py +7 -4
  4. infrahub/core/diff/calculator.py +21 -39
  5. infrahub/core/diff/combiner.py +11 -7
  6. infrahub/core/diff/coordinator.py +49 -70
  7. infrahub/core/diff/data_check_synchronizer.py +86 -7
  8. infrahub/core/diff/enricher/aggregated.py +3 -3
  9. infrahub/core/diff/enricher/cardinality_one.py +1 -6
  10. infrahub/core/diff/enricher/labels.py +13 -3
  11. infrahub/core/diff/enricher/path_identifier.py +2 -8
  12. infrahub/core/diff/ipam_diff_parser.py +1 -1
  13. infrahub/core/diff/merger/merger.py +5 -3
  14. infrahub/core/diff/merger/serializer.py +15 -8
  15. infrahub/core/diff/model/path.py +42 -24
  16. infrahub/core/diff/query/all_conflicts.py +5 -2
  17. infrahub/core/diff/query/diff_get.py +19 -23
  18. infrahub/core/diff/query/field_specifiers.py +2 -0
  19. infrahub/core/diff/query/field_summary.py +2 -1
  20. infrahub/core/diff/query/filters.py +12 -1
  21. infrahub/core/diff/query/has_conflicts_query.py +5 -2
  22. infrahub/core/diff/query/{drop_tracking_id.py → merge_tracking_id.py} +3 -3
  23. infrahub/core/diff/query/roots_metadata.py +8 -1
  24. infrahub/core/diff/query/save.py +148 -63
  25. infrahub/core/diff/query/summary_counts_enricher.py +220 -0
  26. infrahub/core/diff/query/time_range_query.py +2 -1
  27. infrahub/core/diff/query_parser.py +49 -24
  28. infrahub/core/diff/repository/deserializer.py +74 -71
  29. infrahub/core/diff/repository/repository.py +119 -30
  30. infrahub/core/node/__init__.py +6 -1
  31. infrahub/core/node/constraints/grouped_uniqueness.py +9 -2
  32. infrahub/core/node/ipam.py +6 -1
  33. infrahub/core/node/permissions.py +4 -0
  34. infrahub/core/query/diff.py +223 -230
  35. infrahub/core/query/node.py +8 -2
  36. infrahub/core/query/relationship.py +2 -1
  37. infrahub/core/query/resource_manager.py +3 -1
  38. infrahub/core/relationship/model.py +1 -1
  39. infrahub/core/schema/schema_branch.py +16 -7
  40. infrahub/core/utils.py +1 -0
  41. infrahub/core/validators/uniqueness/query.py +20 -17
  42. infrahub/database/__init__.py +13 -0
  43. infrahub/dependencies/builder/constraint/grouped/node_runner.py +0 -2
  44. infrahub/dependencies/builder/diff/coordinator.py +0 -2
  45. infrahub/git/integrator.py +10 -6
  46. infrahub/graphql/mutations/computed_attribute.py +3 -1
  47. infrahub/graphql/mutations/diff.py +28 -4
  48. infrahub/graphql/mutations/main.py +11 -6
  49. infrahub/graphql/mutations/relationship.py +29 -1
  50. infrahub/graphql/mutations/tasks.py +6 -3
  51. infrahub/graphql/queries/resource_manager.py +7 -3
  52. infrahub/permissions/__init__.py +2 -1
  53. infrahub/permissions/types.py +26 -0
  54. infrahub/proposed_change/tasks.py +6 -1
  55. infrahub/storage.py +6 -5
  56. {infrahub_server-1.1.5.dist-info → infrahub_server-1.1.7.dist-info}/METADATA +41 -7
  57. {infrahub_server-1.1.5.dist-info → infrahub_server-1.1.7.dist-info}/RECORD +64 -64
  58. infrahub_testcontainers/container.py +12 -3
  59. infrahub_testcontainers/docker-compose.test.yml +22 -3
  60. infrahub_testcontainers/haproxy.cfg +43 -0
  61. infrahub_testcontainers/helpers.py +85 -1
  62. infrahub/core/diff/enricher/summary_counts.py +0 -105
  63. infrahub/dependencies/builder/diff/enricher/summary_counts.py +0 -8
  64. {infrahub_server-1.1.5.dist-info → infrahub_server-1.1.7.dist-info}/LICENSE.txt +0 -0
  65. {infrahub_server-1.1.5.dist-info → infrahub_server-1.1.7.dist-info}/WHEEL +0 -0
  66. {infrahub_server-1.1.5.dist-info → infrahub_server-1.1.7.dist-info}/entry_points.txt +0 -0
@@ -28,7 +28,6 @@ from .model.path import (
28
28
  if TYPE_CHECKING:
29
29
  from infrahub.core.branch import Branch
30
30
  from infrahub.core.query import QueryResult
31
- from infrahub.core.schema import MainSchemaTypes
32
31
  from infrahub.core.schema.manager import SchemaManager
33
32
  from infrahub.core.schema.relationship_schema import RelationshipSchema
34
33
 
@@ -397,8 +396,12 @@ class DiffNodeIntermediate(TrackedStatusUpdates):
397
396
  force_action: DiffAction | None
398
397
  uuid: str
399
398
  kind: str
399
+ db_id: str
400
+ from_time: Timestamp
401
+ status: RelationshipStatus
400
402
  attributes_by_name: dict[str, DiffAttributeIntermediate] = field(default_factory=dict)
401
- relationships_by_name: dict[str, DiffRelationshipIntermediate] = field(default_factory=dict)
403
+ # {(name, identifier): DiffRelationshipIntermediate}
404
+ relationships_by_identifier: dict[tuple[str, str], DiffRelationshipIntermediate] = field(default_factory=dict)
402
405
 
403
406
  def to_diff_node(self, from_time: Timestamp, include_unchanged: bool) -> DiffNode:
404
407
  attributes = []
@@ -408,7 +411,7 @@ class DiffNodeIntermediate(TrackedStatusUpdates):
408
411
  attributes.append(diff_attr)
409
412
  action, changed_at = self.get_action_and_timestamp(from_time=from_time)
410
413
  relationships = []
411
- for rel in self.relationships_by_name.values():
414
+ for rel in self.relationships_by_identifier.values():
412
415
  diff_rel = rel.to_diff_relationship(include_unchanged=include_unchanged)
413
416
  if include_unchanged or diff_rel.action is not DiffAction.UNCHANGED:
414
417
  relationships.append(diff_rel)
@@ -431,7 +434,7 @@ class DiffNodeIntermediate(TrackedStatusUpdates):
431
434
 
432
435
  @property
433
436
  def is_empty(self) -> bool:
434
- return len(self.attributes_by_name) == 0 and len(self.relationships_by_name) == 0
437
+ return len(self.attributes_by_name) == 0 and len(self.relationships_by_identifier) == 0
435
438
 
436
439
 
437
440
  @dataclass
@@ -495,7 +498,7 @@ class DiffQueryParser:
495
498
  for node in diff_root.nodes_by_id.values():
496
499
  for attribute_name in node.attributes_by_name:
497
500
  node_field_specifiers_map[node.uuid].add(attribute_name)
498
- for relationship_diff in node.relationships_by_name.values():
501
+ for relationship_diff in node.relationships_by_identifier.values():
499
502
  node_field_specifiers_map[node.uuid].add(relationship_diff.identifier)
500
503
  return node_field_specifiers_map
501
504
 
@@ -567,35 +570,53 @@ class DiffQueryParser:
567
570
  diff_root.nodes_by_id[node_id] = DiffNodeIntermediate(
568
571
  uuid=node_id,
569
572
  kind=database_path.node_kind,
573
+ db_id=database_path.node_db_id,
574
+ from_time=database_path.node_changed_at,
575
+ status=database_path.node_status,
570
576
  force_action=DiffAction.UPDATED
571
577
  if database_path.node_branch_support is BranchSupportType.AGNOSTIC
572
578
  else None,
573
579
  )
574
580
  diff_node = diff_root.nodes_by_id[node_id]
581
+ # special handling for nodes that have their kind updated, which results in 2 nodes with the same uuid
582
+ if diff_node.db_id != database_path.node_db_id and (
583
+ database_path.node_changed_at > diff_node.from_time
584
+ or (
585
+ database_path.node_changed_at >= diff_node.from_time
586
+ and (diff_node.status, database_path.node_status)
587
+ == (RelationshipStatus.DELETED, RelationshipStatus.ACTIVE)
588
+ )
589
+ ):
590
+ diff_node.kind = database_path.node_kind
591
+ diff_node.db_id = database_path.node_db_id
592
+ diff_node.from_time = database_path.node_changed_at
593
+ diff_node.status = database_path.node_status
575
594
  diff_node.track_database_path(database_path=database_path)
576
595
  return diff_node
577
596
 
578
- def _get_relationship_schema(
579
- self, database_path: DatabasePath, node_schema: MainSchemaTypes
580
- ) -> RelationshipSchema | None:
581
- relationship_schemas = node_schema.get_relationships_by_identifier(id=database_path.attribute_name)
582
- if len(relationship_schemas) == 1:
583
- return relationship_schemas[0]
584
- possible_path_directions = database_path.possible_relationship_directions
585
- for rel_schema in relationship_schemas:
586
- if rel_schema.direction in possible_path_directions:
587
- return rel_schema
597
+ def _get_relationship_schema(self, database_path: DatabasePath) -> RelationshipSchema | None:
598
+ branches_to_check = [database_path.deepest_branch]
599
+ if database_path.deepest_branch == self.diff_branch_name:
600
+ branches_to_check.append(self.base_branch_name)
601
+ for schema_branch_name in branches_to_check:
602
+ node_schema = self.schema_manager.get(
603
+ name=database_path.node_kind, branch=schema_branch_name, duplicate=False
604
+ )
605
+ relationship_schemas = node_schema.get_relationships_by_identifier(id=database_path.attribute_name)
606
+ if len(relationship_schemas) == 1:
607
+ return relationship_schemas[0]
608
+ possible_path_directions = database_path.possible_relationship_directions
609
+ for rel_schema in relationship_schemas:
610
+ if rel_schema.direction in possible_path_directions:
611
+ return rel_schema
588
612
  return None
589
613
 
590
614
  def _update_attribute_level(self, database_path: DatabasePath, diff_node: DiffNodeIntermediate) -> None:
591
- node_schema = self.schema_manager.get(
592
- name=database_path.node_kind, branch=database_path.deepest_branch, duplicate=False
593
- )
594
615
  if "Attribute" in database_path.attribute_node.labels:
595
616
  diff_attribute = self._get_diff_attribute(database_path=database_path, diff_node=diff_node)
596
617
  self._update_attribute_property(database_path=database_path, diff_attribute=diff_attribute)
597
618
  return
598
- relationship_schema = self._get_relationship_schema(database_path=database_path, node_schema=node_schema)
619
+ relationship_schema = self._get_relationship_schema(database_path=database_path)
599
620
  if not relationship_schema:
600
621
  return
601
622
  diff_relationship = self._get_diff_relationship(
@@ -649,7 +670,9 @@ class DiffQueryParser:
649
670
  relationship_schema: RelationshipSchema,
650
671
  database_path: DatabasePath,
651
672
  ) -> DiffRelationshipIntermediate:
652
- diff_relationship = diff_node.relationships_by_name.get(relationship_schema.name)
673
+ diff_relationship = diff_node.relationships_by_identifier.get(
674
+ (relationship_schema.name, relationship_schema.get_identifier())
675
+ )
653
676
  if not diff_relationship:
654
677
  branch_name = database_path.deepest_branch
655
678
  from_time = self.from_time
@@ -663,7 +686,9 @@ class DiffQueryParser:
663
686
  identifier=relationship_schema.get_identifier(),
664
687
  from_time=from_time,
665
688
  )
666
- diff_node.relationships_by_name[relationship_schema.name] = diff_relationship
689
+ diff_node.relationships_by_identifier[relationship_schema.name, relationship_schema.get_identifier()] = (
690
+ diff_relationship
691
+ )
667
692
  return diff_relationship
668
693
 
669
694
  def _apply_base_branch_previous_values(self) -> None:
@@ -700,8 +725,8 @@ class DiffQueryParser:
700
725
  def _apply_relationship_previous_values(
701
726
  self, diff_node: DiffNodeIntermediate, base_diff_node: DiffNodeIntermediate
702
727
  ) -> None:
703
- for relationship_name, diff_relationship in diff_node.relationships_by_name.items():
704
- base_diff_relationship = base_diff_node.relationships_by_name.get(relationship_name)
728
+ for relationship_key, diff_relationship in diff_node.relationships_by_identifier.items():
729
+ base_diff_relationship = base_diff_node.relationships_by_identifier.get(relationship_key)
705
730
  if not base_diff_relationship:
706
731
  continue
707
732
  for db_id, property_set in diff_relationship.properties_by_db_id.items():
@@ -754,7 +779,7 @@ class DiffQueryParser:
754
779
  continue
755
780
  if ordered_diff_values[-1].changed_at >= self.diff_branched_from_time:
756
781
  return
757
- for relationship_diff in node_diff.relationships_by_name.values():
782
+ for relationship_diff in node_diff.relationships_by_identifier.values():
758
783
  for diff_relationship_property_list in relationship_diff.properties_by_db_id.values():
759
784
  for diff_relationship_property in diff_relationship_property_list:
760
785
  if diff_relationship_property.changed_at >= self.diff_branched_from_time:
@@ -1,5 +1,3 @@
1
- from typing import Iterable
2
-
3
1
  from neo4j.graph import Node as Neo4jNode
4
2
  from neo4j.graph import Path as Neo4jPath
5
3
 
@@ -30,37 +28,47 @@ class EnrichedDiffDeserializer:
30
28
  self._diff_node_rel_group_map: dict[tuple[str, str, str], EnrichedDiffRelationship] = {}
31
29
  self._diff_node_rel_element_map: dict[tuple[str, str, str, str], EnrichedDiffSingleRelationship] = {}
32
30
  self._diff_prop_map: dict[tuple[str, str, str, str] | tuple[str, str, str, str, str], EnrichedDiffProperty] = {}
31
+ # {EnrichedDiffRoot: [(node_uuid, parents_path: Neo4jPath), ...]}
32
+ self._parents_path_map: dict[EnrichedDiffRoot, list[tuple[str, Neo4jPath]]] = {}
33
33
 
34
- def _initialize(self) -> None:
34
+ def initialize(self) -> None:
35
35
  self._diff_root_map = {}
36
36
  self._diff_node_map = {}
37
37
  self._diff_node_attr_map = {}
38
38
  self._diff_node_rel_group_map = {}
39
39
  self._diff_node_rel_element_map = {}
40
40
  self._diff_prop_map = {}
41
+ self._parents_path_map = {}
41
42
 
42
- async def deserialize(
43
- self, database_results: Iterable[QueryResult], include_parents: bool
44
- ) -> list[EnrichedDiffRoot]:
45
- self._initialize()
46
- results = list(database_results)
47
- for result in results:
48
- enriched_root = self._deserialize_diff_root(root_node=result.get_node("diff_root"))
49
- node_node = result.get(label="diff_node")
50
- if not isinstance(node_node, Neo4jNode):
51
- continue
52
- enriched_node = self._deserialize_diff_node(node_node=node_node, enriched_root=enriched_root)
53
- node_conflict_node = result.get(label="diff_node_conflict")
54
- if isinstance(node_conflict_node, Neo4jNode) and not enriched_node.conflict:
55
- conflict = self.deserialize_conflict(diff_conflict_node=node_conflict_node)
56
- enriched_node.conflict = conflict
57
- self._deserialize_attributes(result=result, enriched_root=enriched_root, enriched_node=enriched_node)
58
- self._deserialize_relationships(result=result, enriched_root=enriched_root, enriched_node=enriched_node)
43
+ def _track_parents_path(self, enriched_root: EnrichedDiffRoot, node_uuid: str, parents_path: Neo4jPath) -> None:
44
+ if enriched_root not in self._parents_path_map:
45
+ self._parents_path_map[enriched_root] = []
46
+ self._parents_path_map[enriched_root].append((node_uuid, parents_path))
47
+
48
+ async def read_result(self, result: QueryResult, include_parents: bool) -> None:
49
+ enriched_root = self._deserialize_diff_root(root_node=result.get_node("diff_root"))
50
+ node_node = result.get(label="diff_node")
51
+ if not isinstance(node_node, Neo4jNode):
52
+ return
53
+ enriched_node = self._deserialize_diff_node(node_node=node_node, enriched_root=enriched_root)
59
54
 
60
55
  if include_parents:
61
- for result in results:
62
- enriched_root = self._deserialize_diff_root(root_node=result.get_node("diff_root"))
63
- self._deserialize_parents(result=result, enriched_root=enriched_root)
56
+ parents_path = result.get("parents_path")
57
+ if parents_path and isinstance(parents_path, Neo4jPath):
58
+ self._track_parents_path(
59
+ enriched_root=enriched_root, node_uuid=enriched_node.uuid, parents_path=parents_path
60
+ )
61
+
62
+ node_conflict_node = result.get(label="diff_node_conflict")
63
+ if isinstance(node_conflict_node, Neo4jNode) and not enriched_node.conflict:
64
+ conflict = self.deserialize_conflict(diff_conflict_node=node_conflict_node)
65
+ enriched_node.conflict = conflict
66
+ self._deserialize_attributes(result=result, enriched_root=enriched_root, enriched_node=enriched_node)
67
+ self._deserialize_relationships(result=result, enriched_root=enriched_root, enriched_node=enriched_node)
68
+
69
+ async def deserialize(self, include_parents: bool = True) -> list[EnrichedDiffRoot]:
70
+ if include_parents:
71
+ self._deserialize_parents()
64
72
 
65
73
  return list(self._diff_root_map.values())
66
74
 
@@ -117,30 +125,26 @@ class EnrichedDiffDeserializer:
117
125
  conflict = self.deserialize_conflict(diff_conflict_node=property_conflict)
118
126
  element_property.conflict = conflict
119
127
 
120
- def _deserialize_parents(self, result: QueryResult, enriched_root: EnrichedDiffRoot) -> None:
121
- parents_path = result.get("parents_path")
122
- if not parents_path or not isinstance(parents_path, Neo4jPath):
123
- return
124
-
125
- node_uuid = result.get(label="diff_node").get("uuid")
126
-
127
- # Remove the node itself from the path
128
- parents_path = parents_path.nodes[1:] # type: ignore[union-attr]
129
-
130
- # TODO Ensure the list is even
131
- current_node_uuid = node_uuid
132
- for rel, parent in zip(parents_path[::2], parents_path[1::2]):
133
- enriched_root.add_parent(
134
- node_id=current_node_uuid,
135
- parent_id=parent.get("uuid"),
136
- parent_kind=parent.get("kind"),
137
- parent_label=parent.get("label"),
138
- parent_rel_name=rel.get("name"),
139
- parent_rel_identifier=rel.get("identifier"),
140
- parent_rel_cardinality=RelationshipCardinality(rel.get("cardinality")),
141
- parent_rel_label=rel.get("label"),
142
- )
143
- current_node_uuid = parent.get("uuid")
128
+ def _deserialize_parents(self) -> None:
129
+ for enriched_root, node_path_tuples in self._parents_path_map.items():
130
+ for node_uuid, parents_path in node_path_tuples:
131
+ # Remove the node itself from the path
132
+ parents_path_slice = parents_path.nodes[1:]
133
+
134
+ # TODO Ensure the list is even
135
+ current_node_uuid = node_uuid
136
+ for rel, parent in zip(parents_path_slice[::2], parents_path_slice[1::2]):
137
+ enriched_root.add_parent(
138
+ node_id=current_node_uuid,
139
+ parent_id=parent.get("uuid"),
140
+ parent_kind=parent.get("kind"),
141
+ parent_label=parent.get("label"),
142
+ parent_rel_name=rel.get("name"),
143
+ parent_rel_identifier=rel.get("identifier"),
144
+ parent_rel_cardinality=RelationshipCardinality(rel.get("cardinality")),
145
+ parent_rel_label=rel.get("label"),
146
+ )
147
+ current_node_uuid = parent.get("uuid")
144
148
 
145
149
  @classmethod
146
150
  def _get_str_or_none_property_value(cls, node: Neo4jNode, property_name: str) -> str | None:
@@ -160,10 +164,8 @@ class EnrichedDiffDeserializer:
160
164
  def build_diff_root_metadata(cls, root_node: Neo4jNode) -> EnrichedDiffRootMetadata:
161
165
  from_time = Timestamp(str(root_node.get("from_time")))
162
166
  to_time = Timestamp(str(root_node.get("to_time")))
163
- tracking_id_str = cls._get_str_or_none_property_value(node=root_node, property_name="tracking_id")
164
- tracking_id = None
165
- if tracking_id_str:
166
- tracking_id = deserialize_tracking_id(tracking_id_str=tracking_id_str)
167
+ tracking_id_str = str(root_node.get("tracking_id"))
168
+ tracking_id = deserialize_tracking_id(tracking_id_str=tracking_id_str)
167
169
  return EnrichedDiffRootMetadata(
168
170
  base_branch_name=str(root_node.get("base_branch")),
169
171
  diff_branch_name=str(root_node.get("diff_branch")),
@@ -172,11 +174,12 @@ class EnrichedDiffDeserializer:
172
174
  uuid=str(root_node.get("uuid")),
173
175
  partner_uuid=str(root_node.get("partner_uuid")),
174
176
  tracking_id=tracking_id,
175
- num_added=int(root_node.get("num_added")),
176
- num_updated=int(root_node.get("num_updated")),
177
- num_removed=int(root_node.get("num_removed")),
178
- num_conflicts=int(root_node.get("num_conflicts")),
177
+ num_added=int(root_node.get("num_added", 0)),
178
+ num_updated=int(root_node.get("num_updated", 0)),
179
+ num_removed=int(root_node.get("num_removed", 0)),
180
+ num_conflicts=int(root_node.get("num_conflicts", 0)),
179
181
  contains_conflict=str(root_node.get("contains_conflict")).lower() == "true",
182
+ exists_on_database=True,
180
183
  )
181
184
 
182
185
  def _deserialize_diff_node(self, node_node: Neo4jNode, enriched_root: EnrichedDiffRoot) -> EnrichedDiffNode:
@@ -193,10 +196,10 @@ class EnrichedDiffDeserializer:
193
196
  changed_at=Timestamp(timestamp_str) if timestamp_str else None,
194
197
  action=DiffAction(str(node_node.get("action"))),
195
198
  path_identifier=str(node_node.get("path_identifier")),
196
- num_added=int(node_node.get("num_added")),
197
- num_updated=int(node_node.get("num_updated")),
198
- num_removed=int(node_node.get("num_removed")),
199
- num_conflicts=int(node_node.get("num_conflicts")),
199
+ num_added=int(node_node.get("num_added", 0)),
200
+ num_updated=int(node_node.get("num_updated", 0)),
201
+ num_removed=int(node_node.get("num_removed", 0)),
202
+ num_conflicts=int(node_node.get("num_conflicts", 0)),
200
203
  contains_conflict=str(node_node.get("contains_conflict")).lower() == "true",
201
204
  )
202
205
  self._diff_node_map[node_key] = enriched_node
@@ -216,10 +219,10 @@ class EnrichedDiffDeserializer:
216
219
  changed_at=Timestamp(str(diff_attr_node.get("changed_at"))),
217
220
  path_identifier=str(diff_attr_node.get("path_identifier")),
218
221
  action=DiffAction(str(diff_attr_node.get("action"))),
219
- num_added=int(diff_attr_node.get("num_added")),
220
- num_updated=int(diff_attr_node.get("num_updated")),
221
- num_removed=int(diff_attr_node.get("num_removed")),
222
- num_conflicts=int(diff_attr_node.get("num_conflicts")),
222
+ num_added=int(diff_attr_node.get("num_added", 0)),
223
+ num_updated=int(diff_attr_node.get("num_updated", 0)),
224
+ num_removed=int(diff_attr_node.get("num_removed", 0)),
225
+ num_conflicts=int(diff_attr_node.get("num_conflicts", 0)),
223
226
  contains_conflict=str(diff_attr_node.get("contains_conflict")).lower() == "true",
224
227
  )
225
228
  self._diff_node_attr_map[attr_key] = enriched_attr
@@ -243,10 +246,10 @@ class EnrichedDiffDeserializer:
243
246
  changed_at=Timestamp(timestamp_str) if timestamp_str else None,
244
247
  action=DiffAction(str(relationship_group_node.get("action"))),
245
248
  path_identifier=str(relationship_group_node.get("path_identifier")),
246
- num_added=int(relationship_group_node.get("num_added")),
247
- num_conflicts=int(relationship_group_node.get("num_conflicts")),
248
- num_removed=int(relationship_group_node.get("num_removed")),
249
- num_updated=int(relationship_group_node.get("num_updated")),
249
+ num_added=int(relationship_group_node.get("num_added", 0)),
250
+ num_conflicts=int(relationship_group_node.get("num_conflicts", 0)),
251
+ num_removed=int(relationship_group_node.get("num_removed", 0)),
252
+ num_updated=int(relationship_group_node.get("num_updated", 0)),
250
253
  contains_conflict=str(relationship_group_node.get("contains_conflict")).lower() == "true",
251
254
  )
252
255
 
@@ -278,10 +281,10 @@ class EnrichedDiffDeserializer:
278
281
  peer_id=diff_element_peer_id,
279
282
  peer_label=peer_label,
280
283
  path_identifier=str(relationship_element_node.get("path_identifier")),
281
- num_added=int(relationship_element_node.get("num_added")),
282
- num_updated=int(relationship_element_node.get("num_updated")),
283
- num_removed=int(relationship_element_node.get("num_removed")),
284
- num_conflicts=int(relationship_element_node.get("num_conflicts")),
284
+ num_added=int(relationship_element_node.get("num_added", 0)),
285
+ num_updated=int(relationship_element_node.get("num_updated", 0)),
286
+ num_removed=int(relationship_element_node.get("num_removed", 0)),
287
+ num_conflicts=int(relationship_element_node.get("num_conflicts", 0)),
285
288
  contains_conflict=str(relationship_element_node.get("contains_conflict")).lower() == "true",
286
289
  )
287
290
  enriched_relationship_group.relationships.add(enriched_rel_element)
@@ -1,9 +1,10 @@
1
1
  from collections import defaultdict
2
- from typing import AsyncGenerator, Generator
2
+ from typing import AsyncGenerator, Generator, Iterable
3
3
 
4
4
  from infrahub import config
5
5
  from infrahub.core import registry
6
6
  from infrahub.core.diff.query.field_summary import EnrichedDiffNodeFieldSummaryQuery
7
+ from infrahub.core.diff.query.summary_counts_enricher import DiffSummaryCountsEnricherQuery
7
8
  from infrahub.core.query.diff import DiffCountChanges
8
9
  from infrahub.core.timestamp import Timestamp
9
10
  from infrahub.database import InfrahubDatabase, retry_db_transaction
@@ -26,13 +27,13 @@ from ..query.all_conflicts import EnrichedDiffAllConflictsQuery
26
27
  from ..query.delete_query import EnrichedDiffDeleteQuery
27
28
  from ..query.diff_get import EnrichedDiffGetQuery
28
29
  from ..query.diff_summary import DiffSummaryCounters, DiffSummaryQuery
29
- from ..query.drop_tracking_id import EnrichedDiffDropTrackingIdQuery
30
30
  from ..query.field_specifiers import EnrichedDiffFieldSpecifiersQuery
31
31
  from ..query.filters import EnrichedDiffQueryFilters
32
32
  from ..query.get_conflict_query import EnrichedDiffConflictQuery
33
33
  from ..query.has_conflicts_query import EnrichedDiffHasConflictQuery
34
+ from ..query.merge_tracking_id import EnrichedDiffMergedTrackingIdQuery
34
35
  from ..query.roots_metadata import EnrichedDiffRootsMetadataQuery
35
- from ..query.save import EnrichedDiffRootsCreateQuery, EnrichedNodeBatchCreateQuery, EnrichedNodesLinkQuery
36
+ from ..query.save import EnrichedDiffRootsUpsertQuery, EnrichedNodeBatchCreateQuery, EnrichedNodesLinkQuery
36
37
  from ..query.time_range_query import EnrichedDiffTimeRangeQuery
37
38
  from ..query.update_conflict_query import EnrichedDiffConflictUpdateQuery
38
39
  from .deserializer import EnrichedDiffDeserializer
@@ -47,6 +48,55 @@ class DiffRepository:
47
48
  self.db = db
48
49
  self.deserializer = deserializer
49
50
 
51
+ async def _run_get_diff_query(
52
+ self,
53
+ base_branch_name: str,
54
+ diff_branch_names: list[str],
55
+ batch_size_limit: int,
56
+ limit: int | None = None,
57
+ from_time: Timestamp | None = None,
58
+ to_time: Timestamp | None = None,
59
+ filters: EnrichedDiffQueryFilters | None = None,
60
+ offset: int = 0,
61
+ include_parents: bool = True,
62
+ max_depth: int | None = None,
63
+ tracking_id: TrackingId | None = None,
64
+ diff_ids: list[str] | None = None,
65
+ ) -> list[EnrichedDiffRoot]:
66
+ self.deserializer.initialize()
67
+ final_row_number = None
68
+ if limit:
69
+ final_row_number = offset + limit
70
+ has_more_data = True
71
+ while has_more_data and (final_row_number is None or offset < final_row_number):
72
+ if final_row_number is not None and offset + batch_size_limit > final_row_number:
73
+ batch_size_limit = final_row_number - offset
74
+ get_query = await EnrichedDiffGetQuery.init(
75
+ db=self.db,
76
+ base_branch_name=base_branch_name,
77
+ diff_branch_names=diff_branch_names,
78
+ from_time=from_time,
79
+ to_time=to_time,
80
+ filters=filters,
81
+ max_depth=max_depth,
82
+ limit=batch_size_limit,
83
+ offset=offset,
84
+ tracking_id=tracking_id,
85
+ diff_ids=diff_ids,
86
+ )
87
+ log.info(f"Beginning enriched diff get query {batch_size_limit=}, {offset=}")
88
+ await get_query.execute(db=self.db)
89
+ log.info("Enriched diff get query complete")
90
+ last_result = None
91
+ for query_result in get_query.get_results():
92
+ await self.deserializer.read_result(result=query_result, include_parents=include_parents)
93
+ last_result = query_result
94
+ has_more_data = False
95
+ if last_result:
96
+ has_more_data = last_result.get_as_type("has_more_data", bool)
97
+ offset += batch_size_limit
98
+ return await self.deserializer.deserialize()
99
+
50
100
  async def get(
51
101
  self,
52
102
  base_branch_name: str,
@@ -62,23 +112,21 @@ class DiffRepository:
62
112
  include_empty: bool = False,
63
113
  ) -> list[EnrichedDiffRoot]:
64
114
  final_max_depth = config.SETTINGS.database.max_depth_search_hierarchy
65
- query = await EnrichedDiffGetQuery.init(
66
- db=self.db,
115
+ batch_size_limit = int(config.SETTINGS.database.query_size_limit / 10)
116
+ diff_roots = await self._run_get_diff_query(
67
117
  base_branch_name=base_branch_name,
68
118
  diff_branch_names=diff_branch_names,
119
+ batch_size_limit=batch_size_limit,
120
+ limit=limit,
69
121
  from_time=from_time,
70
122
  to_time=to_time,
71
123
  filters=EnrichedDiffQueryFilters(**dict(filters or {})),
124
+ include_parents=include_parents,
72
125
  max_depth=final_max_depth,
73
- limit=limit,
74
- offset=offset,
126
+ offset=offset or 0,
75
127
  tracking_id=tracking_id,
76
128
  diff_ids=diff_ids,
77
129
  )
78
- await query.execute(db=self.db)
79
- diff_roots = await self.deserializer.deserialize(
80
- database_results=query.get_results(), include_parents=include_parents
81
- )
82
130
  if not include_empty:
83
131
  diff_roots = [dr for dr in diff_roots if len(dr.nodes) > 0]
84
132
  return diff_roots
@@ -91,30 +139,23 @@ class DiffRepository:
91
139
  to_time: Timestamp,
92
140
  ) -> list[EnrichedDiffs]:
93
141
  max_depth = config.SETTINGS.database.max_depth_search_hierarchy
94
- query = await EnrichedDiffGetQuery.init(
95
- db=self.db,
142
+ batch_size_limit = int(config.SETTINGS.database.query_size_limit / 10)
143
+ diff_branch_roots = await self._run_get_diff_query(
96
144
  base_branch_name=base_branch_name,
97
145
  diff_branch_names=[diff_branch_name],
98
146
  from_time=from_time,
99
147
  to_time=to_time,
100
148
  max_depth=max_depth,
101
- )
102
- await query.execute(db=self.db)
103
- diff_branch_roots = await self.deserializer.deserialize(
104
- database_results=query.get_results(), include_parents=True
149
+ batch_size_limit=batch_size_limit,
105
150
  )
106
151
  diffs_by_uuid = {dbr.uuid: dbr for dbr in diff_branch_roots}
107
- base_partner_query = await EnrichedDiffGetQuery.init(
108
- db=self.db,
152
+ base_branch_roots = await self._run_get_diff_query(
109
153
  base_branch_name=base_branch_name,
110
154
  diff_branch_names=[base_branch_name],
111
155
  max_depth=max_depth,
156
+ batch_size_limit=batch_size_limit,
112
157
  diff_ids=[d.partner_uuid for d in diffs_by_uuid.values()],
113
158
  )
114
- await base_partner_query.execute(db=self.db)
115
- base_branch_roots = await self.deserializer.deserialize(
116
- database_results=base_partner_query.get_results(), include_parents=True
117
- )
118
159
  diffs_by_uuid.update({bbr.uuid: bbr for bbr in base_branch_roots})
119
160
  return [
120
161
  EnrichedDiffs(
@@ -126,14 +167,23 @@ class DiffRepository:
126
167
  for dbr in diff_branch_roots
127
168
  ]
128
169
 
129
- async def hydrate_diff_pair(self, enriched_diffs_metadata: EnrichedDiffsMetadata) -> EnrichedDiffs:
170
+ async def hydrate_diff_pair(
171
+ self,
172
+ enriched_diffs_metadata: EnrichedDiffsMetadata,
173
+ node_uuids: Iterable[str] | None = None,
174
+ ) -> EnrichedDiffs:
175
+ filters = None
176
+ if node_uuids:
177
+ filters = {"ids": list(node_uuids) if node_uuids is not None else None}
130
178
  hydrated_base_diff = await self.get_one(
131
179
  diff_branch_name=enriched_diffs_metadata.base_branch_name,
132
180
  diff_id=enriched_diffs_metadata.base_branch_diff.uuid,
181
+ filters=filters,
133
182
  )
134
183
  hydrated_branch_diff = await self.get_one(
135
184
  diff_branch_name=enriched_diffs_metadata.diff_branch_name,
136
185
  diff_id=enriched_diffs_metadata.diff_branch_diff.uuid,
186
+ filters=filters,
137
187
  )
138
188
  return EnrichedDiffs(
139
189
  base_branch_name=enriched_diffs_metadata.base_branch_name,
@@ -184,17 +234,34 @@ class DiffRepository:
184
234
  yield node_requests
185
235
 
186
236
  @retry_db_transaction(name="enriched_diff_save")
187
- async def save(self, enriched_diffs: EnrichedDiffs) -> None:
237
+ async def save(self, enriched_diffs: EnrichedDiffs | EnrichedDiffsMetadata, do_summary_counts: bool = True) -> None:
238
+ log.info("Updating diff metadata...")
239
+ root_query = await EnrichedDiffRootsUpsertQuery.init(db=self.db, enriched_diffs=enriched_diffs)
240
+ await root_query.execute(db=self.db)
241
+ log.info("Diff metadata updated.")
242
+ if not isinstance(enriched_diffs, EnrichedDiffs):
243
+ return
188
244
  num_nodes = len(enriched_diffs.base_branch_diff.nodes) + len(enriched_diffs.diff_branch_diff.nodes)
189
245
  log.info(f"Saving diff (num_nodes={num_nodes})...")
190
- root_query = await EnrichedDiffRootsCreateQuery.init(db=self.db, enriched_diffs=enriched_diffs)
191
- await root_query.execute(db=self.db)
192
- for node_create_batch in self._get_node_create_request_batch(enriched_diffs=enriched_diffs):
246
+ for batch_num, node_create_batch in enumerate(
247
+ self._get_node_create_request_batch(enriched_diffs=enriched_diffs)
248
+ ):
249
+ log.info(f"Saving node batch #{batch_num}...")
193
250
  node_query = await EnrichedNodeBatchCreateQuery.init(db=self.db, node_create_batch=node_create_batch)
194
251
  await node_query.execute(db=self.db)
252
+ log.info(f"Batch #{batch_num} saved")
195
253
  link_query = await EnrichedNodesLinkQuery.init(db=self.db, enriched_diffs=enriched_diffs)
196
254
  await link_query.execute(db=self.db)
197
255
  log.info("Diff saved.")
256
+ if do_summary_counts:
257
+ node_uuids: list[str] | None = None
258
+ if enriched_diffs.diff_branch_diff.exists_on_database:
259
+ node_uuids = list(enriched_diffs.branch_node_uuids)
260
+ await self.add_summary_counts(
261
+ diff_branch_name=enriched_diffs.diff_branch_name,
262
+ diff_id=enriched_diffs.diff_branch_diff.uuid,
263
+ node_uuids=node_uuids,
264
+ )
198
265
 
199
266
  async def summary(
200
267
  self,
@@ -244,6 +311,7 @@ class DiffRepository:
244
311
  base_branch_names: list[str] | None = None,
245
312
  from_time: Timestamp | None = None,
246
313
  to_time: Timestamp | None = None,
314
+ tracking_id: TrackingId | None = None,
247
315
  ) -> list[EnrichedDiffsMetadata]:
248
316
  if diff_branch_names and base_branch_names:
249
317
  diff_branch_names += base_branch_names
@@ -252,6 +320,7 @@ class DiffRepository:
252
320
  base_branch_names=base_branch_names,
253
321
  from_time=from_time,
254
322
  to_time=to_time,
323
+ tracking_id=tracking_id,
255
324
  )
256
325
  roots_by_id = {root.uuid: root for root in empty_roots}
257
326
  pairs: list[EnrichedDiffsMetadata] = []
@@ -275,6 +344,7 @@ class DiffRepository:
275
344
  base_branch_names: list[str] | None = None,
276
345
  from_time: Timestamp | None = None,
277
346
  to_time: Timestamp | None = None,
347
+ tracking_id: TrackingId | None = None,
278
348
  ) -> list[EnrichedDiffRootMetadata]:
279
349
  query = await EnrichedDiffRootsMetadataQuery.init(
280
350
  db=self.db,
@@ -282,6 +352,7 @@ class DiffRepository:
282
352
  base_branch_names=base_branch_names,
283
353
  from_time=from_time,
284
354
  to_time=to_time,
355
+ tracking_id=tracking_id,
285
356
  )
286
357
  await query.execute(db=self.db)
287
358
  diff_roots = []
@@ -341,8 +412,8 @@ class DiffRepository:
341
412
  await query.execute(db=self.db)
342
413
  return await query.get_field_summaries()
343
414
 
344
- async def drop_tracking_ids(self, tracking_ids: list[TrackingId]) -> None:
345
- query = await EnrichedDiffDropTrackingIdQuery.init(db=self.db, tracking_ids=tracking_ids)
415
+ async def mark_tracking_ids_merged(self, tracking_ids: list[TrackingId]) -> None:
416
+ query = await EnrichedDiffMergedTrackingIdQuery.init(db=self.db, tracking_ids=tracking_ids)
346
417
  await query.execute(db=self.db)
347
418
 
348
419
  async def get_num_changes_in_time_range_by_branch(
@@ -367,3 +438,21 @@ class DiffRepository:
367
438
  break
368
439
  offset += limit
369
440
  return specifiers
441
+
442
+ async def add_summary_counts(
443
+ self,
444
+ diff_branch_name: str,
445
+ tracking_id: TrackingId | None = None,
446
+ diff_id: str | None = None,
447
+ node_uuids: list[str] | None = None,
448
+ ) -> None:
449
+ log.info("Updating summary counts...")
450
+ query = await DiffSummaryCountsEnricherQuery.init(
451
+ db=self.db,
452
+ diff_branch_name=diff_branch_name,
453
+ tracking_id=tracking_id,
454
+ diff_id=diff_id,
455
+ node_uuids=node_uuids,
456
+ )
457
+ await query.execute(db=self.db)
458
+ log.info("Summary counts updated...")