infrahub-server 1.1.3__py3-none-any.whl → 1.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. infrahub/api/artifact.py +4 -1
  2. infrahub/api/oidc.py +32 -4
  3. infrahub/api/schema.py +16 -3
  4. infrahub/api/transformation.py +2 -0
  5. infrahub/computed_attribute/tasks.py +9 -5
  6. infrahub/config.py +0 -1
  7. infrahub/core/constants/__init__.py +5 -0
  8. infrahub/core/diff/calculator.py +118 -8
  9. infrahub/core/diff/coordinator.py +1 -2
  10. infrahub/core/diff/model/path.py +0 -9
  11. infrahub/core/diff/query/all_conflicts.py +64 -0
  12. infrahub/core/diff/query_parser.py +51 -28
  13. infrahub/core/diff/repository/repository.py +23 -14
  14. infrahub/core/merge.py +7 -5
  15. infrahub/core/query/diff.py +417 -621
  16. infrahub/core/relationship/model.py +10 -2
  17. infrahub/core/schema/__init__.py +5 -0
  18. infrahub/core/validators/aggregated_checker.py +34 -4
  19. infrahub/core/validators/node/hierarchy.py +1 -1
  20. infrahub/generators/tasks.py +12 -10
  21. infrahub/git/base.py +50 -8
  22. infrahub/git/integrator.py +24 -17
  23. infrahub/git/models.py +3 -2
  24. infrahub/git/repository.py +2 -2
  25. infrahub/git/tasks.py +19 -11
  26. infrahub/graphql/mutations/artifact_definition.py +10 -2
  27. infrahub/message_bus/messages/check_repository_usercheck.py +1 -0
  28. infrahub/message_bus/operations/check/repository.py +5 -2
  29. infrahub/message_bus/operations/requests/artifact_definition.py +1 -1
  30. infrahub/message_bus/operations/requests/proposed_change.py +4 -0
  31. infrahub/message_bus/types.py +1 -0
  32. infrahub/transformations/constants.py +1 -0
  33. infrahub/transformations/models.py +3 -1
  34. infrahub/transformations/tasks.py +2 -2
  35. infrahub/webhook/models.py +9 -1
  36. infrahub/workflows/catalogue.py +2 -2
  37. infrahub_sdk/analyzer.py +1 -1
  38. infrahub_sdk/checks.py +4 -4
  39. infrahub_sdk/client.py +26 -16
  40. infrahub_sdk/ctl/cli_commands.py +4 -4
  41. infrahub_sdk/ctl/exporter.py +2 -2
  42. infrahub_sdk/ctl/importer.py +6 -4
  43. infrahub_sdk/ctl/repository.py +56 -1
  44. infrahub_sdk/generator.py +3 -3
  45. infrahub_sdk/node.py +2 -2
  46. infrahub_sdk/pytest_plugin/items/base.py +0 -5
  47. infrahub_sdk/pytest_plugin/items/graphql_query.py +1 -1
  48. infrahub_sdk/pytest_plugin/items/jinja2_transform.py +1 -1
  49. infrahub_sdk/pytest_plugin/items/python_transform.py +1 -1
  50. infrahub_sdk/repository.py +33 -0
  51. infrahub_sdk/testing/repository.py +14 -8
  52. infrahub_sdk/transforms.py +3 -3
  53. infrahub_sdk/utils.py +8 -3
  54. {infrahub_server-1.1.3.dist-info → infrahub_server-1.1.5.dist-info}/METADATA +2 -1
  55. {infrahub_server-1.1.3.dist-info → infrahub_server-1.1.5.dist-info}/RECORD +59 -57
  56. infrahub_testcontainers/docker-compose.test.yml +1 -1
  57. infrahub_sdk/task_report.py +0 -208
  58. {infrahub_server-1.1.3.dist-info → infrahub_server-1.1.5.dist-info}/LICENSE.txt +0 -0
  59. {infrahub_server-1.1.3.dist-info → infrahub_server-1.1.5.dist-info}/WHEEL +0 -0
  60. {infrahub_server-1.1.3.dist-info → infrahub_server-1.1.5.dist-info}/entry_points.txt +0 -0
infrahub/api/artifact.py CHANGED
@@ -85,7 +85,10 @@ async def generate_artifact(
85
85
 
86
86
  service = request.app.state.service
87
87
  model = RequestArtifactDefinitionGenerate(
88
- artifact_definition=artifact_definition.id, branch=branch_params.branch.name, limit=payload.nodes
88
+ artifact_definition_id=artifact_definition.id,
89
+ artifact_definition_name=artifact_definition.name.value,
90
+ branch=branch_params.branch.name,
91
+ limit=payload.nodes,
89
92
  )
90
93
 
91
94
  await service.workflow.submit_workflow(workflow=REQUEST_ARTIFACT_DEFINITION_GENERATE, parameters={"model": model})
infrahub/api/oidc.py CHANGED
@@ -1,8 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import TYPE_CHECKING
3
+ from typing import TYPE_CHECKING, Any
4
4
  from urllib.parse import urljoin
5
5
 
6
+ import jwt
6
7
  import ujson
7
8
  from authlib.integrations.httpx_client import AsyncOAuth2Client
8
9
  from fastapi import APIRouter, Depends, Request, Response
@@ -138,7 +139,7 @@ async def token(
138
139
 
139
140
  with trace.get_tracer(__name__).start_as_current_span("sso_token_request") as span:
140
141
  span.set_attribute("token_request_data", ujson.dumps(token_response.json()))
141
- payload = token_response.json()
142
+ payload: dict[str, Any] = token_response.json()
142
143
 
143
144
  headers = {"Authorization": f"{payload.get('token_type')} {payload.get('access_token')}"}
144
145
 
@@ -148,8 +149,10 @@ async def token(
148
149
  userinfo_response = await service.http.post(str(oidc_config.userinfo_endpoint), headers=headers)
149
150
 
150
151
  _validate_response(response=userinfo_response)
151
- user_info = userinfo_response.json()
152
- sso_groups = user_info.get("groups", [])
152
+ user_info: dict[str, Any] = userinfo_response.json()
153
+ sso_groups = user_info.get("groups") or await _get_id_token_groups(
154
+ oidc_config=oidc_config, service=service, payload=payload, client_id=provider.client_id
155
+ )
153
156
 
154
157
  if not sso_groups and config.SETTINGS.security.sso_user_default_group:
155
158
  sso_groups = [config.SETTINGS.security.sso_user_default_group]
@@ -185,3 +188,28 @@ def _validate_response(response: httpx.Response) -> None:
185
188
  body=response.json(),
186
189
  )
187
190
  raise GatewayError(message="Invalid response from Authentication provider")
191
+
192
+
193
+ async def _get_id_token_groups(
194
+ oidc_config: OIDCDiscoveryConfig, service: InfrahubServices, payload: dict[str, Any], client_id: str
195
+ ) -> list[str]:
196
+ id_token = payload.get("id_token")
197
+ if not id_token:
198
+ return []
199
+ jwks = await service.http.get(url=str(oidc_config.jwks_uri))
200
+
201
+ jwk_client = jwt.PyJWKClient(uri=str(oidc_config.jwks_uri), cache_jwk_set=True)
202
+ if jwk_client.jwk_set_cache:
203
+ jwk_client.jwk_set_cache.put(jwks.json())
204
+
205
+ signing_key = jwk_client.get_signing_key_from_jwt(id_token)
206
+
207
+ decoded_token: dict[str, Any] = jwt.decode(
208
+ jwt=id_token,
209
+ key=signing_key.key,
210
+ algorithms=oidc_config.id_token_signing_alg_values_supported,
211
+ audience=client_id,
212
+ issuer=str(oidc_config.issuer),
213
+ )
214
+
215
+ return decoded_token.get("groups", [])
infrahub/api/schema.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import TYPE_CHECKING, Any
3
+ from typing import TYPE_CHECKING, Any, Sequence
4
4
 
5
5
  from fastapi import APIRouter, Depends, Query, Request
6
6
  from pydantic import (
@@ -128,13 +128,26 @@ class SchemaUpdate(BaseModel):
128
128
  return self.hash != self.previous_hash
129
129
 
130
130
 
131
+ def _merge_candidate_schemas(schemas: Sequence[SchemaRoot]) -> SchemaRoot:
132
+ """Merge multiple schemas into one suitable to be loaded."""
133
+ if not schemas:
134
+ raise ValueError("Cannot merge an empty list of schemas")
135
+
136
+ merged = schemas[0]
137
+ for schema in schemas[1:]:
138
+ merged = merged.merge(schema=schema)
139
+
140
+ return merged
141
+
142
+
131
143
  def evaluate_candidate_schemas(
132
144
  branch_schema: SchemaBranch, schemas_to_evaluate: SchemasLoadAPI
133
145
  ) -> tuple[SchemaBranch, SchemaUpdateValidationResult]:
134
146
  candidate_schema = branch_schema.duplicate()
147
+ schema = _merge_candidate_schemas(schemas=schemas_to_evaluate.schemas)
148
+
135
149
  try:
136
- for schema in schemas_to_evaluate.schemas:
137
- candidate_schema.load_schema(schema=schema)
150
+ candidate_schema.load_schema(schema=schema)
138
151
  candidate_schema.process()
139
152
 
140
153
  schema_diff = branch_schema.diff(other=candidate_schema)
infrahub/api/transformation.py CHANGED
@@ -83,6 +83,7 @@ async def transform_python(
83
83
  commit=repository.commit.value, # type: ignore[attr-defined]
84
84
  branch=branch_params.branch.name,
85
85
  transform_location=f"{transform.file_path.value}::{transform.class_name.value}",
86
+ timeout=transform.timeout.value,
86
87
  data=data,
87
88
  )
88
89
 
@@ -140,6 +141,7 @@ async def transform_jinja2(
140
141
  commit=repository.commit.value, # type: ignore[attr-defined]
141
142
  branch=branch_params.branch.name,
142
143
  template_location=transform.template_path.value,
144
+ timeout=transform.timeout.value,
143
145
  data=data,
144
146
  )
145
147
 
infrahub/computed_attribute/tasks.py CHANGED
@@ -4,7 +4,10 @@ from datetime import timedelta
4
4
  from typing import TYPE_CHECKING, Any
5
5
 
6
6
  import ujson
7
- from infrahub_sdk.protocols import CoreNode # noqa: TC002
7
+ from infrahub_sdk.protocols import (
8
+ CoreNode, # noqa: TC002
9
+ CoreTransformPython,
10
+ )
8
11
  from prefect import flow
9
12
  from prefect.automations import AutomationCore
10
13
  from prefect.client.orchestration import get_client
@@ -88,17 +91,18 @@ async def process_transform(
88
91
 
89
92
  for attribute_name, transform_attribute in transform_attributes.items():
90
93
  transform = await service.client.get(
91
- kind="CoreTransformPython",
94
+ kind=CoreTransformPython,
92
95
  branch=branch_name,
93
96
  id=transform_attribute.transform,
94
97
  prefetch_relationships=True,
95
98
  populate_store=True,
96
99
  )
100
+
97
101
  if not transform:
98
102
  continue
99
103
 
100
104
  repo_node = await service.client.get(
101
- kind=transform.repository.peer.typename,
105
+ kind=str(transform.repository.peer.typename),
102
106
  branch=branch_name,
103
107
  id=transform.repository.peer.id,
104
108
  raise_when_missing=True,
@@ -108,7 +112,7 @@ async def process_transform(
108
112
  repository_id=transform.repository.peer.id,
109
113
  name=transform.repository.peer.name.value,
110
114
  service=service,
111
- repository_kind=transform.repository.peer.typename,
115
+ repository_kind=str(transform.repository.peer.typename),
112
116
  commit=repo_node.commit.value,
113
117
  )
114
118
 
@@ -120,7 +124,7 @@ async def process_transform(
120
124
  subscribers=[object_id],
121
125
  )
122
126
 
123
- transformed_data = await repo.execute_python_transform(
127
+ transformed_data = await repo.execute_python_transform.with_options(timeout_seconds=transform.timeout.value)(
124
128
  branch_name=branch_name,
125
129
  commit=repo_node.commit.value,
126
130
  location=f"{transform.file_path.value}::{transform.class_name.value}",
infrahub/config.py CHANGED
@@ -585,7 +585,6 @@ class AnalyticsSettings(BaseSettings):
585
585
 
586
586
  class ExperimentalFeaturesSettings(BaseSettings):
587
587
  model_config = SettingsConfigDict(env_prefix="INFRAHUB_EXPERIMENTAL_")
588
- pull_request: bool = False
589
588
  graphql_enums: bool = False
590
589
 
591
590
 
infrahub/core/constants/__init__.py CHANGED
@@ -120,7 +120,12 @@ class AllowOverrideType(InfrahubStringEnum):
120
120
 
121
121
  class ContentType(InfrahubStringEnum):
122
122
  APPLICATION_JSON = "application/json"
123
+ APPLICATION_YAML = "application/yaml"
124
+ APPLICATION_XML = "application/xml"
123
125
  TEXT_PLAIN = "text/plain"
126
+ TEXT_MARKDOWN = "text/markdown"
127
+ TEXT_CSV = "text/csv"
128
+ IMAGE_SVG = "image/svg+xml"
124
129
 
125
130
 
126
131
  class CheckType(InfrahubStringEnum):
infrahub/core/diff/calculator.py CHANGED
@@ -1,20 +1,68 @@
1
+ from dataclasses import dataclass, field
2
+
3
+ from infrahub import config
1
4
  from infrahub.core import registry
2
5
  from infrahub.core.branch import Branch
3
6
  from infrahub.core.diff.query_parser import DiffQueryParser
4
- from infrahub.core.query.diff import DiffAllPathsQuery
7
+ from infrahub.core.query.diff import DiffAllPathsQuery, DiffCalculationQuery, DiffFieldPathsQuery, DiffNodePathsQuery
5
8
  from infrahub.core.timestamp import Timestamp
6
9
  from infrahub.database import InfrahubDatabase
7
10
  from infrahub.log import get_logger
8
11
 
9
- from .model.path import CalculatedDiffs, NodeFieldSpecifier
12
+ from .model.path import CalculatedDiffs
10
13
 
11
14
  log = get_logger()
12
15
 
13
16
 
17
+ @dataclass
18
+ class DiffCalculationRequest:
19
+ base_branch: Branch
20
+ diff_branch: Branch
21
+ branch_from_time: Timestamp
22
+ from_time: Timestamp
23
+ to_time: Timestamp
24
+ current_node_field_specifiers: dict[str, set[str]] | None = field(default=None)
25
+ new_node_field_specifiers: dict[str, set[str]] | None = field(default=None)
26
+
27
+
14
28
  class DiffCalculator:
15
29
  def __init__(self, db: InfrahubDatabase) -> None:
16
30
  self.db = db
17
31
 
32
+ async def _run_diff_calculation_query(
33
+ self,
34
+ diff_parser: DiffQueryParser,
35
+ query_class: type[DiffCalculationQuery],
36
+ calculation_request: DiffCalculationRequest,
37
+ limit: int,
38
+ ) -> None:
39
+ has_more_data = True
40
+ offset = 0
41
+ while has_more_data:
42
+ diff_query = await query_class.init(
43
+ db=self.db,
44
+ branch=calculation_request.diff_branch,
45
+ base_branch=calculation_request.base_branch,
46
+ diff_branch_from_time=calculation_request.branch_from_time,
47
+ diff_from=calculation_request.from_time,
48
+ diff_to=calculation_request.to_time,
49
+ current_node_field_specifiers=calculation_request.current_node_field_specifiers,
50
+ new_node_field_specifiers=calculation_request.new_node_field_specifiers,
51
+ limit=limit,
52
+ offset=offset,
53
+ )
54
+ log.info(f"Beginning one diff calculation query {limit=}, {offset=}")
55
+ await diff_query.execute(db=self.db)
56
+ log.info("Diff calculation query complete")
57
+ last_result = None
58
+ for query_result in diff_query.get_results():
59
+ diff_parser.read_result(query_result=query_result)
60
+ last_result = query_result
61
+ has_more_data = False
62
+ if last_result:
63
+ has_more_data = last_result.get_as_type("has_more_data", bool)
64
+ offset += limit
65
+
18
66
  async def calculate_diff(
19
67
  self,
20
68
  base_branch: Branch,
@@ -22,7 +70,7 @@ class DiffCalculator:
22
70
  from_time: Timestamp,
23
71
  to_time: Timestamp,
24
72
  include_unchanged: bool = True,
25
- previous_node_specifiers: set[NodeFieldSpecifier] | None = None,
73
+ previous_node_specifiers: dict[str, set[str]] | None = None,
26
74
  ) -> CalculatedDiffs:
27
75
  if diff_branch.name == registry.default_branch:
28
76
  diff_branch_from_time = from_time
@@ -36,6 +84,35 @@ class DiffCalculator:
36
84
  to_time=to_time,
37
85
  previous_node_field_specifiers=previous_node_specifiers,
38
86
  )
87
+ node_limit = int(config.SETTINGS.database.query_size_limit / 10)
88
+ fields_limit = int(config.SETTINGS.database.query_size_limit / 3)
89
+
90
+ calculation_request = DiffCalculationRequest(
91
+ base_branch=base_branch,
92
+ diff_branch=diff_branch,
93
+ branch_from_time=diff_branch_from_time,
94
+ from_time=from_time,
95
+ to_time=to_time,
96
+ )
97
+
98
+ log.info("Beginning diff node-level calculation queries for branch")
99
+ await self._run_diff_calculation_query(
100
+ diff_parser=diff_parser,
101
+ query_class=DiffNodePathsQuery,
102
+ calculation_request=calculation_request,
103
+ limit=node_limit,
104
+ )
105
+ log.info("Diff node-level calculation queries for branch complete")
106
+
107
+ log.info("Beginning diff field-level calculation queries for branch")
108
+ await self._run_diff_calculation_query(
109
+ diff_parser=diff_parser,
110
+ query_class=DiffFieldPathsQuery,
111
+ calculation_request=calculation_request,
112
+ limit=fields_limit,
113
+ )
114
+ log.info("Diff field-level calculation queries for branch complete")
115
+
39
116
  branch_diff_query = await DiffAllPathsQuery.init(
40
117
  db=self.db,
41
118
  branch=diff_branch,
@@ -53,8 +130,43 @@ class DiffCalculator:
53
130
  log.info("Results of query for branch read")
54
131
 
55
132
  if base_branch.name != diff_branch.name:
56
- new_node_field_specifiers = diff_parser.get_new_node_field_specifiers()
57
133
  current_node_field_specifiers = diff_parser.get_current_node_field_specifiers()
134
+ new_node_field_specifiers = diff_parser.get_new_node_field_specifiers()
135
+ calculation_request = DiffCalculationRequest(
136
+ base_branch=base_branch,
137
+ diff_branch=base_branch,
138
+ branch_from_time=diff_branch_from_time,
139
+ from_time=from_time,
140
+ to_time=to_time,
141
+ current_node_field_specifiers=current_node_field_specifiers,
142
+ new_node_field_specifiers=new_node_field_specifiers,
143
+ )
144
+
145
+ log.info("Beginning diff node-level calculation queries for base")
146
+ await self._run_diff_calculation_query(
147
+ diff_parser=diff_parser,
148
+ query_class=DiffNodePathsQuery,
149
+ calculation_request=calculation_request,
150
+ limit=node_limit,
151
+ )
152
+ log.info("Diff node-level calculation queries for base complete")
153
+
154
+ log.info("Beginning diff field-level calculation queries for base")
155
+ await self._run_diff_calculation_query(
156
+ diff_parser=diff_parser,
157
+ query_class=DiffFieldPathsQuery,
158
+ calculation_request=calculation_request,
159
+ limit=fields_limit,
160
+ )
161
+ log.info("Diff field-level calculation queries for base complete")
162
+
163
+ # Temporary until next change
164
+ current_node_field_specifier_tuples: list[tuple[str, str]] = []
165
+ new_node_field_specifier_tuples: list[tuple[str, str]] = []
166
+ for node_uuid, field_names in current_node_field_specifiers.items():
167
+ current_node_field_specifier_tuples.extend((node_uuid, field_name) for field_name in field_names)
168
+ for node_uuid, field_names in new_node_field_specifiers.items():
169
+ new_node_field_specifier_tuples.extend((node_uuid, field_name) for field_name in field_names)
58
170
 
59
171
  base_diff_query = await DiffAllPathsQuery.init(
60
172
  db=self.db,
@@ -63,10 +175,8 @@ class DiffCalculator:
63
175
  diff_branch_from_time=diff_branch_from_time,
64
176
  diff_from=from_time,
65
177
  diff_to=to_time,
66
- current_node_field_specifiers=[
67
- (nfs.node_uuid, nfs.field_name) for nfs in current_node_field_specifiers
68
- ],
69
- new_node_field_specifiers=[(nfs.node_uuid, nfs.field_name) for nfs in new_node_field_specifiers],
178
+ current_node_field_specifiers=current_node_field_specifier_tuples,
179
+ new_node_field_specifiers=new_node_field_specifier_tuples,
70
180
  )
71
181
 
72
182
  log.info("Beginning diff calculation query for base")
infrahub/core/diff/coordinator.py CHANGED
@@ -15,7 +15,6 @@ from .model.path import (
15
15
  EnrichedDiffs,
16
16
  EnrichedDiffsMetadata,
17
17
  NameTrackingId,
18
- NodeFieldSpecifier,
19
18
  TrackingId,
20
19
  )
21
20
 
@@ -44,7 +43,7 @@ class EnrichedDiffRequest:
44
43
  from_time: Timestamp
45
44
  to_time: Timestamp
46
45
  tracking_id: TrackingId | None = field(default=None)
47
- node_field_specifiers: set[NodeFieldSpecifier] = field(default_factory=set)
46
+ node_field_specifiers: dict[str, set[str]] = field(default_factory=dict)
48
47
 
49
48
  def __repr__(self) -> str:
50
49
  return (
infrahub/core/diff/model/path.py CHANGED
@@ -82,15 +82,6 @@ def deserialize_tracking_id(tracking_id_str: str) -> TrackingId:
82
82
  raise ValueError(f"{tracking_id_str} is not a valid TrackingId")
83
83
 
84
84
 
85
- @dataclass
86
- class NodeFieldSpecifier:
87
- node_uuid: str
88
- field_name: str
89
-
90
- def __hash__(self) -> int:
91
- return hash(f"{self.node_uuid}:{self.field_name}")
92
-
93
-
94
85
  @dataclass
95
86
  class NodeDiffFieldSummary:
96
87
  kind: str
infrahub/core/diff/query/all_conflicts.py ADDED
@@ -0,0 +1,64 @@
1
+ from typing import Any, Generator
2
+
3
+ from neo4j.graph import Node as Neo4jNode
4
+
5
+ from infrahub.core.query import Query, QueryType
6
+ from infrahub.database import InfrahubDatabase
7
+
8
+ from ..model.path import TrackingId
9
+
10
+
11
+ class EnrichedDiffAllConflictsQuery(Query):
12
+ name = "enriched_diff_all_conflicts"
13
+ type = QueryType.READ
14
+
15
+ def __init__(
16
+ self, diff_branch_name: str, tracking_id: TrackingId | None = None, diff_id: str | None = None, **kwargs: Any
17
+ ) -> None:
18
+ super().__init__(**kwargs)
19
+ if (diff_id is None and tracking_id is None) or (diff_id and tracking_id):
20
+ raise ValueError("EnrichedDiffAllConflictsQuery requires one and only one of `tracking_id` or `diff_id`")
21
+ self.diff_branch_name = diff_branch_name
22
+ self.tracking_id = tracking_id
23
+ self.diff_id = diff_id
24
+ if self.tracking_id is None and self.diff_id is None:
25
+ raise RuntimeError("tracking_id or diff_id is required")
26
+
27
+ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
28
+ self.params = {
29
+ "diff_branch_name": self.diff_branch_name,
30
+ "diff_id": self.diff_id,
31
+ "tracking_id": self.tracking_id.serialize() if self.tracking_id else None,
32
+ }
33
+ query = """
34
+ MATCH (root:DiffRoot)
35
+ WHERE ($diff_id IS NOT NULL AND root.uuid = $diff_id)
36
+ OR ($tracking_id IS NOT NULL AND root.tracking_id = $tracking_id AND root.diff_branch = $diff_branch_name)
37
+ CALL {
38
+ WITH root
39
+ MATCH (root)-[:DIFF_HAS_NODE]->(node:DiffNode)-[:DIFF_HAS_CONFLICT]->(node_conflict:DiffConflict)
40
+ RETURN node.path_identifier AS path_identifier, node_conflict AS conflict
41
+ UNION
42
+ WITH root
43
+ MATCH (root)-[:DIFF_HAS_NODE]->(node:DiffNode)-[:DIFF_HAS_ATTRIBUTE]->(:DiffAttribute)
44
+ -[:DIFF_HAS_PROPERTY]->(property:DiffProperty)-[:DIFF_HAS_CONFLICT]->(attr_property_conflict:DiffConflict)
45
+ RETURN property.path_identifier AS path_identifier, attr_property_conflict AS conflict
46
+ UNION
47
+ WITH root
48
+ MATCH (root)-[:DIFF_HAS_NODE]->(node:DiffNode)-[:DIFF_HAS_RELATIONSHIP]->(:DiffRelationship)
49
+ -[:DIFF_HAS_ELEMENT]->(element:DiffRelationshipElement)-[:DIFF_HAS_CONFLICT]->(rel_element_conflict:DiffConflict)
50
+ RETURN element.path_identifier AS path_identifier, rel_element_conflict AS conflict
51
+ UNION
52
+ WITH root
53
+ MATCH (root)-[:DIFF_HAS_NODE]->(node:DiffNode)-[:DIFF_HAS_RELATIONSHIP]->(:DiffRelationship)
54
+ -[:DIFF_HAS_ELEMENT]->(:DiffRelationshipElement)-[:DIFF_HAS_PROPERTY]->(property:DiffProperty)
55
+ -[:DIFF_HAS_CONFLICT]->(rel_property_conflict:DiffConflict)
56
+ RETURN property.path_identifier AS path_identifier, rel_property_conflict AS conflict
57
+ }
58
+ """
59
+ self.return_labels = ["path_identifier", "conflict"]
60
+ self.add_to_query(query=query)
61
+
62
+ def get_conflict_paths_and_nodes(self) -> Generator[tuple[str, Neo4jNode], None, None]:
63
+ for result in self.get_results():
64
+ yield (result.get_as_type("path_identifier", str), result.get_node("conflict"))
infrahub/core/diff/query_parser.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from collections import defaultdict
3
4
  from dataclasses import dataclass, field
4
5
  from typing import TYPE_CHECKING, Any, Optional
5
6
  from uuid import uuid4
@@ -22,7 +23,6 @@ from .model.path import (
22
23
  DiffRelationship,
23
24
  DiffRoot,
24
25
  DiffSingleRelationship,
25
- NodeFieldSpecifier,
26
26
  )
27
27
 
28
28
  if TYPE_CHECKING:
@@ -459,7 +459,7 @@ class DiffQueryParser:
459
459
  schema_manager: SchemaManager,
460
460
  from_time: Timestamp,
461
461
  to_time: Optional[Timestamp] = None,
462
- previous_node_field_specifiers: set[NodeFieldSpecifier] | None = None,
462
+ previous_node_field_specifiers: dict[str, set[str]] | None = None,
463
463
  ) -> None:
464
464
  self.base_branch_name = base_branch.name
465
465
  self.diff_branch_name = diff_branch.name
@@ -473,7 +473,9 @@ class DiffQueryParser:
473
473
  self.diff_branched_from_time = Timestamp(diff_branch.get_branched_from())
474
474
  self._diff_root_by_branch: dict[str, DiffRootIntermediate] = {}
475
475
  self._final_diff_root_by_branch: dict[str, DiffRoot] = {}
476
- self._previous_node_field_specifiers = previous_node_field_specifiers or set()
476
+ self._previous_node_field_specifiers = previous_node_field_specifiers or {}
477
+ self._new_node_field_specifiers: dict[str, set[str]] | None = None
478
+ self._current_node_field_specifiers: dict[str, set[str]] | None = None
477
479
 
478
480
  def get_branches(self) -> set[str]:
479
481
  return set(self._final_diff_root_by_branch.keys())
@@ -485,38 +487,59 @@ class DiffQueryParser:
485
487
  return self._final_diff_root_by_branch[branch]
486
488
  return DiffRoot(from_time=self.from_time, to_time=self.to_time, uuid=str(uuid4()), branch=branch, nodes=[])
487
489
 
488
- def get_diff_node_field_specifiers(self) -> set[NodeFieldSpecifier]:
490
+ def get_diff_node_field_specifiers(self) -> dict[str, set[str]]:
489
491
  if self.diff_branch_name not in self._diff_root_by_branch:
490
- return set()
491
- node_field_specifiers: set[NodeFieldSpecifier] = set()
492
+ return {}
493
+ node_field_specifiers_map: dict[str, set[str]] = defaultdict(set)
492
494
  diff_root = self._diff_root_by_branch[self.diff_branch_name]
493
495
  for node in diff_root.nodes_by_id.values():
494
- node_field_specifiers.update(
495
- NodeFieldSpecifier(node_uuid=node.uuid, field_name=attribute_name)
496
- for attribute_name in node.attributes_by_name
497
- )
498
- node_field_specifiers.update(
499
- NodeFieldSpecifier(node_uuid=node.uuid, field_name=relationship_diff.identifier)
500
- for relationship_diff in node.relationships_by_name.values()
501
- )
502
- return node_field_specifiers
503
-
504
- def get_new_node_field_specifiers(self) -> set[NodeFieldSpecifier]:
496
+ for attribute_name in node.attributes_by_name:
497
+ node_field_specifiers_map[node.uuid].add(attribute_name)
498
+ for relationship_diff in node.relationships_by_name.values():
499
+ node_field_specifiers_map[node.uuid].add(relationship_diff.identifier)
500
+ return node_field_specifiers_map
501
+
502
+ def _remove_node_specifiers(
503
+ self, node_specifiers: dict[str, set[str]], node_specifiers_to_remove: dict[str, set[str]]
504
+ ) -> dict[str, set[str]]:
505
+ final_node_specifiers: dict[str, set[str]] = defaultdict(set)
506
+ for node_uuid, field_names_set in node_specifiers.items():
507
+ specifiers_to_remove = node_specifiers_to_remove.get(node_uuid, set())
508
+ final_specifiers = field_names_set - specifiers_to_remove
509
+ if final_specifiers:
510
+ final_node_specifiers[node_uuid] = final_specifiers
511
+ return final_node_specifiers
512
+
513
+ def get_new_node_field_specifiers(self) -> dict[str, set[str]]:
514
+ if self._new_node_field_specifiers is not None:
515
+ return self._new_node_field_specifiers
505
516
  branch_node_specifiers = self.get_diff_node_field_specifiers()
506
- new_node_field_specifiers = branch_node_specifiers - self._previous_node_field_specifiers
517
+ new_node_field_specifiers = self._remove_node_specifiers(
518
+ branch_node_specifiers, self._previous_node_field_specifiers
519
+ )
520
+ self._new_node_field_specifiers = new_node_field_specifiers
507
521
  return new_node_field_specifiers
508
522
 
509
- def get_current_node_field_specifiers(self) -> set[NodeFieldSpecifier]:
523
+ def get_current_node_field_specifiers(self) -> dict[str, set[str]]:
524
+ if self._current_node_field_specifiers is not None:
525
+ return self._current_node_field_specifiers
510
526
  new_node_field_specifiers = self.get_new_node_field_specifiers()
511
- current_node_field_specifiers = self._previous_node_field_specifiers - new_node_field_specifiers
527
+ current_node_field_specifiers = self._remove_node_specifiers(
528
+ self._previous_node_field_specifiers, new_node_field_specifiers
529
+ )
530
+ self._current_node_field_specifiers = current_node_field_specifiers
512
531
  return current_node_field_specifiers
513
532
 
514
533
  def read_result(self, query_result: QueryResult) -> None:
515
534
  path = query_result.get_path(label="diff_path")
516
535
  database_path = DatabasePath.from_cypher_path(cypher_path=path)
517
536
  self._parse_path(database_path=database_path)
537
+ self._current_node_field_specifiers = None
538
+ self._new_node_field_specifiers = None
518
539
 
519
540
  def parse(self, include_unchanged: bool = False) -> None:
541
+ self._new_node_field_specifiers = None
542
+ self._current_node_field_specifiers = None
520
543
  if len(self._diff_root_by_branch) > 1:
521
544
  self._apply_base_branch_previous_values()
522
545
  self._remove_empty_base_diff_root()
@@ -586,11 +609,12 @@ class DiffQueryParser:
586
609
  self, database_path: DatabasePath, diff_node: DiffNodeIntermediate
587
610
  ) -> DiffAttributeIntermediate:
588
611
  attribute_name = database_path.attribute_name
589
- node_field_specifier = NodeFieldSpecifier(node_uuid=diff_node.uuid, field_name=attribute_name)
590
612
  branch_name = database_path.deepest_branch
591
613
  from_time = self.from_time
592
- if branch_name == self.base_branch_name and node_field_specifier in self.get_new_node_field_specifiers():
593
- from_time = self.diff_branched_from_time
614
+ if branch_name == self.base_branch_name:
615
+ new_node_field_specifiers = self.get_new_node_field_specifiers()
616
+ if attribute_name in new_node_field_specifiers.get(diff_node.uuid, set()):
617
+ from_time = self.diff_branched_from_time
594
618
  if attribute_name not in diff_node.attributes_by_name:
595
619
  diff_node.attributes_by_name[attribute_name] = DiffAttributeIntermediate(
596
620
  uuid=database_path.attribute_id,
@@ -627,13 +651,12 @@ class DiffQueryParser:
627
651
  ) -> DiffRelationshipIntermediate:
628
652
  diff_relationship = diff_node.relationships_by_name.get(relationship_schema.name)
629
653
  if not diff_relationship:
630
- node_field_specifier = NodeFieldSpecifier(
631
- node_uuid=diff_node.uuid, field_name=relationship_schema.get_identifier()
632
- )
633
654
  branch_name = database_path.deepest_branch
634
655
  from_time = self.from_time
635
- if branch_name == self.base_branch_name and node_field_specifier in self.get_new_node_field_specifiers():
636
- from_time = self.diff_branched_from_time
656
+ if branch_name == self.base_branch_name:
657
+ new_node_field_specifiers = self.get_new_node_field_specifiers()
658
+ if relationship_schema.get_identifier() in new_node_field_specifiers.get(diff_node.uuid, set()):
659
+ from_time = self.diff_branched_from_time
637
660
  diff_relationship = DiffRelationshipIntermediate(
638
661
  name=relationship_schema.name,
639
662
  cardinality=relationship_schema.cardinality,