infrahub-server 1.1.4__py3-none-any.whl → 1.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62):
  1. infrahub/api/artifact.py +4 -1
  2. infrahub/api/oidc.py +33 -4
  3. infrahub/api/transformation.py +2 -0
  4. infrahub/computed_attribute/tasks.py +9 -5
  5. infrahub/core/diff/calculator.py +127 -35
  6. infrahub/core/diff/coordinator.py +1 -2
  7. infrahub/core/diff/ipam_diff_parser.py +1 -1
  8. infrahub/core/diff/merger/serializer.py +15 -8
  9. infrahub/core/diff/model/path.py +0 -9
  10. infrahub/core/diff/query/all_conflicts.py +64 -0
  11. infrahub/core/diff/query/diff_get.py +17 -22
  12. infrahub/core/diff/query_parser.py +51 -28
  13. infrahub/core/diff/repository/deserializer.py +51 -47
  14. infrahub/core/diff/repository/repository.py +75 -33
  15. infrahub/core/merge.py +7 -5
  16. infrahub/core/query/diff.py +509 -758
  17. infrahub/core/relationship/model.py +11 -3
  18. infrahub/core/schema/schema_branch.py +16 -7
  19. infrahub/core/validators/aggregated_checker.py +34 -4
  20. infrahub/core/validators/node/hierarchy.py +1 -1
  21. infrahub/generators/tasks.py +12 -10
  22. infrahub/git/base.py +50 -8
  23. infrahub/git/integrator.py +30 -19
  24. infrahub/git/models.py +3 -2
  25. infrahub/git/repository.py +2 -2
  26. infrahub/git/tasks.py +19 -11
  27. infrahub/graphql/mutations/artifact_definition.py +10 -2
  28. infrahub/message_bus/messages/check_repository_usercheck.py +1 -0
  29. infrahub/message_bus/operations/check/repository.py +5 -2
  30. infrahub/message_bus/operations/requests/artifact_definition.py +1 -1
  31. infrahub/message_bus/operations/requests/proposed_change.py +4 -0
  32. infrahub/message_bus/types.py +1 -0
  33. infrahub/proposed_change/tasks.py +6 -1
  34. infrahub/storage.py +6 -5
  35. infrahub/transformations/constants.py +1 -0
  36. infrahub/transformations/models.py +3 -1
  37. infrahub/transformations/tasks.py +2 -2
  38. infrahub/webhook/models.py +9 -1
  39. infrahub/workflows/catalogue.py +2 -2
  40. infrahub_sdk/batch.py +2 -2
  41. infrahub_sdk/config.py +1 -1
  42. infrahub_sdk/ctl/check.py +1 -1
  43. infrahub_sdk/ctl/cli_commands.py +4 -4
  44. infrahub_sdk/ctl/exporter.py +2 -2
  45. infrahub_sdk/ctl/importer.py +6 -4
  46. infrahub_sdk/ctl/repository.py +56 -1
  47. infrahub_sdk/ctl/utils.py +2 -2
  48. infrahub_sdk/data.py +1 -1
  49. infrahub_sdk/node.py +1 -4
  50. infrahub_sdk/protocols.py +0 -1
  51. infrahub_sdk/pytest_plugin/items/graphql_query.py +1 -1
  52. infrahub_sdk/pytest_plugin/items/jinja2_transform.py +1 -1
  53. infrahub_sdk/pytest_plugin/items/python_transform.py +1 -1
  54. infrahub_sdk/schema/__init__.py +0 -3
  55. infrahub_sdk/testing/docker.py +30 -0
  56. infrahub_sdk/transfer/exporter/json.py +1 -1
  57. {infrahub_server-1.1.4.dist-info → infrahub_server-1.1.6.dist-info}/METADATA +1 -1
  58. {infrahub_server-1.1.4.dist-info → infrahub_server-1.1.6.dist-info}/RECORD +62 -60
  59. infrahub_testcontainers/docker-compose.test.yml +1 -1
  60. {infrahub_server-1.1.4.dist-info → infrahub_server-1.1.6.dist-info}/LICENSE.txt +0 -0
  61. {infrahub_server-1.1.4.dist-info → infrahub_server-1.1.6.dist-info}/WHEEL +0 -0
  62. {infrahub_server-1.1.4.dist-info → infrahub_server-1.1.6.dist-info}/entry_points.txt +0 -0
infrahub/api/artifact.py CHANGED
@@ -85,7 +85,10 @@ async def generate_artifact(
85
85
 
86
86
  service = request.app.state.service
87
87
  model = RequestArtifactDefinitionGenerate(
88
- artifact_definition=artifact_definition.id, branch=branch_params.branch.name, limit=payload.nodes
88
+ artifact_definition_id=artifact_definition.id,
89
+ artifact_definition_name=artifact_definition.name.value,
90
+ branch=branch_params.branch.name,
91
+ limit=payload.nodes,
89
92
  )
90
93
 
91
94
  await service.workflow.submit_workflow(workflow=REQUEST_ARTIFACT_DEFINITION_GENERATE, parameters={"model": model})
infrahub/api/oidc.py CHANGED
@@ -1,8 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import TYPE_CHECKING
3
+ from typing import TYPE_CHECKING, Any
4
4
  from urllib.parse import urljoin
5
5
 
6
+ import jwt
6
7
  import ujson
7
8
  from authlib.integrations.httpx_client import AsyncOAuth2Client
8
9
  from fastapi import APIRouter, Depends, Request, Response
@@ -138,7 +139,7 @@ async def token(
138
139
 
139
140
  with trace.get_tracer(__name__).start_as_current_span("sso_token_request") as span:
140
141
  span.set_attribute("token_request_data", ujson.dumps(token_response.json()))
141
- payload = token_response.json()
142
+ payload: dict[str, Any] = token_response.json()
142
143
 
143
144
  headers = {"Authorization": f"{payload.get('token_type')} {payload.get('access_token')}"}
144
145
 
@@ -148,8 +149,10 @@ async def token(
148
149
  userinfo_response = await service.http.post(str(oidc_config.userinfo_endpoint), headers=headers)
149
150
 
150
151
  _validate_response(response=userinfo_response)
151
- user_info = userinfo_response.json()
152
- sso_groups = user_info.get("groups", [])
152
+ user_info: dict[str, Any] = userinfo_response.json()
153
+ sso_groups = user_info.get("groups") or await _get_id_token_groups(
154
+ oidc_config=oidc_config, service=service, payload=payload, client_id=provider.client_id
155
+ )
153
156
 
154
157
  if not sso_groups and config.SETTINGS.security.sso_user_default_group:
155
158
  sso_groups = [config.SETTINGS.security.sso_user_default_group]
@@ -185,3 +188,29 @@ def _validate_response(response: httpx.Response) -> None:
185
188
  body=response.json(),
186
189
  )
187
190
  raise GatewayError(message="Invalid response from Authentication provider")
191
+
192
+
193
+ async def _get_id_token_groups(
194
+ oidc_config: OIDCDiscoveryConfig, service: InfrahubServices, payload: dict[str, Any], client_id: str
195
+ ) -> list[str]:
196
+ id_token = payload.get("id_token")
197
+ if not id_token:
198
+ return []
199
+ jwks = await service.http.get(url=str(oidc_config.jwks_uri))
200
+
201
+ jwk_client = jwt.PyJWKClient(uri=str(oidc_config.jwks_uri), cache_jwk_set=True)
202
+ if jwk_client.jwk_set_cache:
203
+ jwk_client.jwk_set_cache.put(jwks.json())
204
+
205
+ signing_key = jwk_client.get_signing_key_from_jwt(id_token)
206
+
207
+ decoded_token: dict[str, Any] = jwt.decode(
208
+ jwt=id_token,
209
+ key=signing_key.key,
210
+ algorithms=oidc_config.id_token_signing_alg_values_supported,
211
+ audience=client_id,
212
+ issuer=str(oidc_config.issuer),
213
+ options={"verify_signature": False, "verify_aud": False, "verify_iss": False},
214
+ )
215
+
216
+ return decoded_token.get("groups", [])
infrahub/api/transformation.py CHANGED
@@ -83,6 +83,7 @@ async def transform_python(
83
83
  commit=repository.commit.value, # type: ignore[attr-defined]
84
84
  branch=branch_params.branch.name,
85
85
  transform_location=f"{transform.file_path.value}::{transform.class_name.value}",
86
+ timeout=transform.timeout.value,
86
87
  data=data,
87
88
  )
88
89
 
@@ -140,6 +141,7 @@ async def transform_jinja2(
140
141
  commit=repository.commit.value, # type: ignore[attr-defined]
141
142
  branch=branch_params.branch.name,
142
143
  template_location=transform.template_path.value,
144
+ timeout=transform.timeout.value,
143
145
  data=data,
144
146
  )
145
147
 
infrahub/computed_attribute/tasks.py CHANGED
@@ -4,7 +4,10 @@ from datetime import timedelta
4
4
  from typing import TYPE_CHECKING, Any
5
5
 
6
6
  import ujson
7
- from infrahub_sdk.protocols import CoreNode # noqa: TC002
7
+ from infrahub_sdk.protocols import (
8
+ CoreNode, # noqa: TC002
9
+ CoreTransformPython,
10
+ )
8
11
  from prefect import flow
9
12
  from prefect.automations import AutomationCore
10
13
  from prefect.client.orchestration import get_client
@@ -88,17 +91,18 @@ async def process_transform(
88
91
 
89
92
  for attribute_name, transform_attribute in transform_attributes.items():
90
93
  transform = await service.client.get(
91
- kind="CoreTransformPython",
94
+ kind=CoreTransformPython,
92
95
  branch=branch_name,
93
96
  id=transform_attribute.transform,
94
97
  prefetch_relationships=True,
95
98
  populate_store=True,
96
99
  )
100
+
97
101
  if not transform:
98
102
  continue
99
103
 
100
104
  repo_node = await service.client.get(
101
- kind=transform.repository.peer.typename,
105
+ kind=str(transform.repository.peer.typename),
102
106
  branch=branch_name,
103
107
  id=transform.repository.peer.id,
104
108
  raise_when_missing=True,
@@ -108,7 +112,7 @@ async def process_transform(
108
112
  repository_id=transform.repository.peer.id,
109
113
  name=transform.repository.peer.name.value,
110
114
  service=service,
111
- repository_kind=transform.repository.peer.typename,
115
+ repository_kind=str(transform.repository.peer.typename),
112
116
  commit=repo_node.commit.value,
113
117
  )
114
118
 
@@ -120,7 +124,7 @@ async def process_transform(
120
124
  subscribers=[object_id],
121
125
  )
122
126
 
123
- transformed_data = await repo.execute_python_transform(
127
+ transformed_data = await repo.execute_python_transform.with_options(timeout_seconds=transform.timeout.value)(
124
128
  branch_name=branch_name,
125
129
  commit=repo_node.commit.value,
126
130
  location=f"{transform.file_path.value}::{transform.class_name.value}",
infrahub/core/diff/calculator.py CHANGED
@@ -1,20 +1,73 @@
1
+ from dataclasses import dataclass, field
2
+
3
+ from infrahub import config
1
4
  from infrahub.core import registry
2
5
  from infrahub.core.branch import Branch
3
6
  from infrahub.core.diff.query_parser import DiffQueryParser
4
- from infrahub.core.query.diff import DiffAllPathsQuery
7
+ from infrahub.core.query.diff import (
8
+ DiffCalculationQuery,
9
+ DiffFieldPathsQuery,
10
+ DiffNodePathsQuery,
11
+ DiffPropertyPathsQuery,
12
+ )
5
13
  from infrahub.core.timestamp import Timestamp
6
14
  from infrahub.database import InfrahubDatabase
7
15
  from infrahub.log import get_logger
8
16
 
9
- from .model.path import CalculatedDiffs, NodeFieldSpecifier
17
+ from .model.path import CalculatedDiffs
10
18
 
11
19
  log = get_logger()
12
20
 
13
21
 
22
+ @dataclass
23
+ class DiffCalculationRequest:
24
+ base_branch: Branch
25
+ diff_branch: Branch
26
+ branch_from_time: Timestamp
27
+ from_time: Timestamp
28
+ to_time: Timestamp
29
+ current_node_field_specifiers: dict[str, set[str]] | None = field(default=None)
30
+ new_node_field_specifiers: dict[str, set[str]] | None = field(default=None)
31
+
32
+
14
33
  class DiffCalculator:
15
34
  def __init__(self, db: InfrahubDatabase) -> None:
16
35
  self.db = db
17
36
 
37
+ async def _run_diff_calculation_query(
38
+ self,
39
+ diff_parser: DiffQueryParser,
40
+ query_class: type[DiffCalculationQuery],
41
+ calculation_request: DiffCalculationRequest,
42
+ limit: int,
43
+ ) -> None:
44
+ has_more_data = True
45
+ offset = 0
46
+ while has_more_data:
47
+ diff_query = await query_class.init(
48
+ db=self.db,
49
+ branch=calculation_request.diff_branch,
50
+ base_branch=calculation_request.base_branch,
51
+ diff_branch_from_time=calculation_request.branch_from_time,
52
+ diff_from=calculation_request.from_time,
53
+ diff_to=calculation_request.to_time,
54
+ current_node_field_specifiers=calculation_request.current_node_field_specifiers,
55
+ new_node_field_specifiers=calculation_request.new_node_field_specifiers,
56
+ limit=limit,
57
+ offset=offset,
58
+ )
59
+ log.info(f"Beginning one diff calculation query {limit=}, {offset=}")
60
+ await diff_query.execute(db=self.db)
61
+ log.info("Diff calculation query complete")
62
+ last_result = None
63
+ for query_result in diff_query.get_results():
64
+ diff_parser.read_result(query_result=query_result)
65
+ last_result = query_result
66
+ has_more_data = False
67
+ if last_result:
68
+ has_more_data = last_result.get_as_type("has_more_data", bool)
69
+ offset += limit
70
+
18
71
  async def calculate_diff(
19
72
  self,
20
73
  base_branch: Branch,
@@ -22,7 +75,7 @@ class DiffCalculator:
22
75
  from_time: Timestamp,
23
76
  to_time: Timestamp,
24
77
  include_unchanged: bool = True,
25
- previous_node_specifiers: set[NodeFieldSpecifier] | None = None,
78
+ previous_node_specifiers: dict[str, set[str]] | None = None,
26
79
  ) -> CalculatedDiffs:
27
80
  if diff_branch.name == registry.default_branch:
28
81
  diff_branch_from_time = from_time
@@ -36,46 +89,85 @@ class DiffCalculator:
36
89
  to_time=to_time,
37
90
  previous_node_field_specifiers=previous_node_specifiers,
38
91
  )
39
- branch_diff_query = await DiffAllPathsQuery.init(
40
- db=self.db,
41
- branch=diff_branch,
92
+ node_limit = int(config.SETTINGS.database.query_size_limit / 10)
93
+ fields_limit = int(config.SETTINGS.database.query_size_limit / 3)
94
+ properties_limit = int(config.SETTINGS.database.query_size_limit)
95
+
96
+ calculation_request = DiffCalculationRequest(
42
97
  base_branch=base_branch,
43
- diff_branch_from_time=diff_branch_from_time,
44
- diff_from=from_time,
45
- diff_to=to_time,
98
+ diff_branch=diff_branch,
99
+ branch_from_time=diff_branch_from_time,
100
+ from_time=from_time,
101
+ to_time=to_time,
46
102
  )
47
- log.info("Beginning diff calculation query for branch")
48
- await branch_diff_query.execute(db=self.db)
49
- log.info("Diff calculation query for branch complete")
50
- log.info("Reading results of query for branch")
51
- for query_result in branch_diff_query.get_results():
52
- diff_parser.read_result(query_result=query_result)
53
- log.info("Results of query for branch read")
103
+
104
+ log.info("Beginning diff node-level calculation queries for branch")
105
+ await self._run_diff_calculation_query(
106
+ diff_parser=diff_parser,
107
+ query_class=DiffNodePathsQuery,
108
+ calculation_request=calculation_request,
109
+ limit=node_limit,
110
+ )
111
+ log.info("Diff node-level calculation queries for branch complete")
112
+
113
+ log.info("Beginning diff field-level calculation queries for branch")
114
+ await self._run_diff_calculation_query(
115
+ diff_parser=diff_parser,
116
+ query_class=DiffFieldPathsQuery,
117
+ calculation_request=calculation_request,
118
+ limit=fields_limit,
119
+ )
120
+ log.info("Diff field-level calculation queries for branch complete")
121
+
122
+ log.info("Beginning diff property-level calculation queries for branch")
123
+ await self._run_diff_calculation_query(
124
+ diff_parser=diff_parser,
125
+ query_class=DiffPropertyPathsQuery,
126
+ calculation_request=calculation_request,
127
+ limit=properties_limit,
128
+ )
129
+ log.info("Diff property-level calculation queries for branch complete")
54
130
 
55
131
  if base_branch.name != diff_branch.name:
56
- new_node_field_specifiers = diff_parser.get_new_node_field_specifiers()
57
132
  current_node_field_specifiers = diff_parser.get_current_node_field_specifiers()
58
-
59
- base_diff_query = await DiffAllPathsQuery.init(
60
- db=self.db,
61
- branch=base_branch,
133
+ new_node_field_specifiers = diff_parser.get_new_node_field_specifiers()
134
+ calculation_request = DiffCalculationRequest(
62
135
  base_branch=base_branch,
63
- diff_branch_from_time=diff_branch_from_time,
64
- diff_from=from_time,
65
- diff_to=to_time,
66
- current_node_field_specifiers=[
67
- (nfs.node_uuid, nfs.field_name) for nfs in current_node_field_specifiers
68
- ],
69
- new_node_field_specifiers=[(nfs.node_uuid, nfs.field_name) for nfs in new_node_field_specifiers],
136
+ diff_branch=base_branch,
137
+ branch_from_time=diff_branch_from_time,
138
+ from_time=from_time,
139
+ to_time=to_time,
140
+ current_node_field_specifiers=current_node_field_specifiers,
141
+ new_node_field_specifiers=new_node_field_specifiers,
70
142
  )
71
143
 
72
- log.info("Beginning diff calculation query for base")
73
- await base_diff_query.execute(db=self.db)
74
- log.info("Diff calculation query for base complete")
75
- log.info("Reading results of query for base")
76
- for query_result in base_diff_query.get_results():
77
- diff_parser.read_result(query_result=query_result)
78
- log.info("Results of query for branch read")
144
+ log.info("Beginning diff node-level calculation queries for base")
145
+ await self._run_diff_calculation_query(
146
+ diff_parser=diff_parser,
147
+ query_class=DiffNodePathsQuery,
148
+ calculation_request=calculation_request,
149
+ limit=node_limit,
150
+ )
151
+ log.info("Diff node-level calculation queries for base complete")
152
+
153
+ log.info("Beginning diff field-level calculation queries for base")
154
+ await self._run_diff_calculation_query(
155
+ diff_parser=diff_parser,
156
+ query_class=DiffFieldPathsQuery,
157
+ calculation_request=calculation_request,
158
+ limit=fields_limit,
159
+ )
160
+ log.info("Diff field-level calculation queries for base complete")
161
+
162
+ log.info("Beginning diff property-level calculation queries for base")
163
+ await self._run_diff_calculation_query(
164
+ diff_parser=diff_parser,
165
+ query_class=DiffPropertyPathsQuery,
166
+ calculation_request=calculation_request,
167
+ limit=properties_limit,
168
+ )
169
+ log.info("Diff property-level calculation queries for base complete")
170
+
79
171
  log.info("Parsing calculated diff")
80
172
  diff_parser.parse(include_unchanged=include_unchanged)
81
173
  log.info("Calculated diff parsed")
infrahub/core/diff/coordinator.py CHANGED
@@ -15,7 +15,6 @@ from .model.path import (
15
15
  EnrichedDiffs,
16
16
  EnrichedDiffsMetadata,
17
17
  NameTrackingId,
18
- NodeFieldSpecifier,
19
18
  TrackingId,
20
19
  )
21
20
 
@@ -44,7 +43,7 @@ class EnrichedDiffRequest:
44
43
  from_time: Timestamp
45
44
  to_time: Timestamp
46
45
  tracking_id: TrackingId | None = field(default=None)
47
- node_field_specifiers: set[NodeFieldSpecifier] = field(default_factory=set)
46
+ node_field_specifiers: dict[str, set[str]] = field(default_factory=dict)
48
47
 
49
48
  def __repr__(self) -> str:
50
49
  return (
infrahub/core/diff/ipam_diff_parser.py CHANGED
@@ -116,7 +116,7 @@ class IpamDiffParser:
116
116
  rels = await node_from_db.ip_namespace.get_relationships(db=self.db) # type: ignore[attr-defined]
117
117
  if rels:
118
118
  cnd.namespace_id = rels[0].get_peer_id()
119
- if cnd.ip_value and cnd.namespace_id:
119
+ if cnd.ip_value and cnd.namespace_id and cnd.node_uuid in uuids_missing_data:
120
120
  uuids_missing_data.remove(cnd.node_uuid)
121
121
 
122
122
  async def _add_missing_values(
infrahub/core/diff/merger/serializer.py CHANGED
@@ -294,10 +294,6 @@ class DiffMergeSerializer:
294
294
  actions_and_peers = self._get_actions_and_peers(relationship_diff=relationship_diff)
295
295
  added_peer_ids = [peer_id for action, peer_id in actions_and_peers if action is DiffAction.ADDED]
296
296
  removed_peer_ids = [peer_id for action, peer_id in actions_and_peers if action is DiffAction.REMOVED]
297
- if not added_peer_ids:
298
- added_peer_ids = [relationship_diff.peer_id]
299
- if not removed_peer_ids:
300
- removed_peer_ids = [relationship_diff.peer_id]
301
297
  for action, peer_id in actions_and_peers:
302
298
  if (
303
299
  peer_id
@@ -314,6 +310,7 @@ class DiffMergeSerializer:
314
310
  relationship_diff_properties=relationship_diff.properties,
315
311
  added_peer_ids=added_peer_ids,
316
312
  removed_peer_ids=removed_peer_ids,
313
+ unchanged_peer_id=relationship_diff.peer_id,
317
314
  )
318
315
  return relationship_dicts, relationship_property_dicts
319
316
 
@@ -324,9 +321,14 @@ class DiffMergeSerializer:
324
321
  relationship_diff_properties: set[EnrichedDiffProperty],
325
322
  added_peer_ids: list[str],
326
323
  removed_peer_ids: list[str],
324
+ unchanged_peer_id: str,
327
325
  ) -> list[RelationshipPropertyMergeDict]:
328
- added_property_dicts = self._get_default_property_merge_dicts(action=DiffAction.ADDED)
329
- removed_property_dicts = self._get_default_property_merge_dicts(action=DiffAction.REMOVED)
326
+ added_property_dicts = {}
327
+ removed_property_dicts = {}
328
+ if added_peer_ids:
329
+ added_property_dicts = self._get_default_property_merge_dicts(action=DiffAction.ADDED)
330
+ if removed_peer_ids:
331
+ removed_property_dicts = self._get_default_property_merge_dicts(action=DiffAction.REMOVED)
330
332
  for property_diff in relationship_diff_properties:
331
333
  if property_diff.property_type is DatabaseEdgeType.IS_RELATED:
332
334
  # handled above
@@ -348,10 +350,15 @@ class DiffMergeSerializer:
348
350
  elif action is DiffAction.REMOVED:
349
351
  removed_property_dicts[property_diff.property_type] = property_dict
350
352
  relationship_property_dicts = []
353
+ peers_and_property_dicts: list[tuple[str, dict[DatabaseEdgeType, PropertyMergeDict]]] = []
351
354
  if added_property_dicts:
352
- peers_and_property_dicts = [(peer_id, added_property_dicts) for peer_id in added_peer_ids]
355
+ peers_and_property_dicts += [
356
+ (peer_id, added_property_dicts) for peer_id in (added_peer_ids or [unchanged_peer_id])
357
+ ]
353
358
  if removed_property_dicts:
354
- peers_and_property_dicts += [(peer_id, removed_property_dicts) for peer_id in removed_peer_ids]
359
+ peers_and_property_dicts += [
360
+ (peer_id, removed_property_dicts) for peer_id in (removed_peer_ids or [unchanged_peer_id])
361
+ ]
355
362
  for peer_id, property_dicts in peers_and_property_dicts:
356
363
  if (
357
364
  peer_id
infrahub/core/diff/model/path.py CHANGED
@@ -82,15 +82,6 @@ def deserialize_tracking_id(tracking_id_str: str) -> TrackingId:
82
82
  raise ValueError(f"{tracking_id_str} is not a valid TrackingId")
83
83
 
84
84
 
85
- @dataclass
86
- class NodeFieldSpecifier:
87
- node_uuid: str
88
- field_name: str
89
-
90
- def __hash__(self) -> int:
91
- return hash(f"{self.node_uuid}:{self.field_name}")
92
-
93
-
94
85
  @dataclass
95
86
  class NodeDiffFieldSummary:
96
87
  kind: str
infrahub/core/diff/query/all_conflicts.py ADDED
@@ -0,0 +1,64 @@
1
+ from typing import Any, Generator
2
+
3
+ from neo4j.graph import Node as Neo4jNode
4
+
5
+ from infrahub.core.query import Query, QueryType
6
+ from infrahub.database import InfrahubDatabase
7
+
8
+ from ..model.path import TrackingId
9
+
10
+
11
+ class EnrichedDiffAllConflictsQuery(Query):
12
+ name = "enriched_diff_all_conflicts"
13
+ type = QueryType.READ
14
+
15
+ def __init__(
16
+ self, diff_branch_name: str, tracking_id: TrackingId | None = None, diff_id: str | None = None, **kwargs: Any
17
+ ) -> None:
18
+ super().__init__(**kwargs)
19
+ if (diff_id is None and tracking_id is None) or (diff_id and tracking_id):
20
+ raise ValueError("EnrichedDiffAllConflictsQuery requires one and only one of `tracking_id` or `diff_id`")
21
+ self.diff_branch_name = diff_branch_name
22
+ self.tracking_id = tracking_id
23
+ self.diff_id = diff_id
24
+ if self.tracking_id is None and self.diff_id is None:
25
+ raise RuntimeError("tracking_id or diff_id is required")
26
+
27
+ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
28
+ self.params = {
29
+ "diff_branch_name": self.diff_branch_name,
30
+ "diff_id": self.diff_id,
31
+ "tracking_id": self.tracking_id.serialize() if self.tracking_id else None,
32
+ }
33
+ query = """
34
+ MATCH (root:DiffRoot)
35
+ WHERE ($diff_id IS NOT NULL AND root.uuid = $diff_id)
36
+ OR ($tracking_id IS NOT NULL AND root.tracking_id = $tracking_id AND root.diff_branch = $diff_branch_name)
37
+ CALL {
38
+ WITH root
39
+ MATCH (root)-[:DIFF_HAS_NODE]->(node:DiffNode)-[:DIFF_HAS_CONFLICT]->(node_conflict:DiffConflict)
40
+ RETURN node.path_identifier AS path_identifier, node_conflict AS conflict
41
+ UNION
42
+ WITH root
43
+ MATCH (root)-[:DIFF_HAS_NODE]->(node:DiffNode)-[:DIFF_HAS_ATTRIBUTE]->(:DiffAttribute)
44
+ -[:DIFF_HAS_PROPERTY]->(property:DiffProperty)-[:DIFF_HAS_CONFLICT]->(attr_property_conflict:DiffConflict)
45
+ RETURN property.path_identifier AS path_identifier, attr_property_conflict AS conflict
46
+ UNION
47
+ WITH root
48
+ MATCH (root)-[:DIFF_HAS_NODE]->(node:DiffNode)-[:DIFF_HAS_RELATIONSHIP]->(:DiffRelationship)
49
+ -[:DIFF_HAS_ELEMENT]->(element:DiffRelationshipElement)-[:DIFF_HAS_CONFLICT]->(rel_element_conflict:DiffConflict)
50
+ RETURN element.path_identifier AS path_identifier, rel_element_conflict AS conflict
51
+ UNION
52
+ WITH root
53
+ MATCH (root)-[:DIFF_HAS_NODE]->(node:DiffNode)-[:DIFF_HAS_RELATIONSHIP]->(:DiffRelationship)
54
+ -[:DIFF_HAS_ELEMENT]->(:DiffRelationshipElement)-[:DIFF_HAS_PROPERTY]->(property:DiffProperty)
55
+ -[:DIFF_HAS_CONFLICT]->(rel_property_conflict:DiffConflict)
56
+ RETURN property.path_identifier AS path_identifier, rel_property_conflict AS conflict
57
+ }
58
+ """
59
+ self.return_labels = ["path_identifier", "conflict"]
60
+ self.add_to_query(query=query)
61
+
62
+ def get_conflict_paths_and_nodes(self) -> Generator[tuple[str, Neo4jNode], None, None]:
63
+ for result in self.get_results():
64
+ yield (result.get_as_type("path_identifier", str), result.get_node("conflict"))
infrahub/core/diff/query/diff_get.py CHANGED
@@ -17,8 +17,6 @@ QUERY_MATCH_NODES = """
17
17
  AND ($to_time IS NULL OR diff_root.to_time <= $to_time)
18
18
  AND ($tracking_id IS NULL OR diff_root.tracking_id = $tracking_id)
19
19
  AND ($diff_ids IS NULL OR diff_root.uuid IN $diff_ids)
20
- WITH diff_root
21
- ORDER BY diff_root.base_branch, diff_root.diff_branch, diff_root.from_time, diff_root.to_time
22
20
  // get all the nodes attached to the diffs
23
21
  OPTIONAL MATCH (diff_root)-[:DIFF_HAS_NODE]->(diff_node:DiffNode)
24
22
  """
@@ -79,26 +77,21 @@ class EnrichedDiffGetQuery(Query):
79
77
  self.add_to_query(query=query_filters)
80
78
 
81
79
  query_2 = """
82
- // group by diff node uuid for pagination
83
- WITH diff_node.uuid AS diff_node_uuid, diff_node.kind AS diff_node_kind, collect([diff_root, diff_node]) AS node_root_tuples
84
- // order by kind and latest label for each diff_node uuid
85
- CALL {
86
- WITH node_root_tuples
87
- UNWIND node_root_tuples AS nrt
88
- WITH nrt[0] AS diff_root, nrt[1] AS diff_node
89
- ORDER BY diff_root.from_time DESC
90
- RETURN diff_node.label AS latest_node_label
91
- LIMIT 1
92
- }
93
- WITH diff_node_kind, node_root_tuples, latest_node_label
94
- ORDER BY diff_node_kind, latest_node_label
80
+ WITH diff_root, diff_node
81
+ ORDER BY diff_root.base_branch, diff_root.diff_branch, diff_root.from_time, diff_root.to_time, diff_node.uuid
82
+ // -------------------------------------
83
+ // Limit number of results
84
+ // -------------------------------------
95
85
  SKIP COALESCE($offset, 0)
96
86
  LIMIT $limit
97
- UNWIND node_root_tuples AS nrt
98
- WITH nrt[0] AS diff_root, nrt[1] AS diff_node
99
- WITH diff_root, diff_node
87
+ // -------------------------------------
88
+ // Check if more data after this limited group
89
+ // -------------------------------------
90
+ WITH collect([diff_root, diff_node]) AS limited_results
91
+ WITH limited_results, size(limited_results) = $limit AS has_more_data
92
+ UNWIND limited_results AS one_result
93
+ WITH one_result[0] AS diff_root, one_result[1] AS diff_node, has_more_data
100
94
  // if depth limit, make sure not to exceed it when traversing linked nodes
101
- WITH diff_root, diff_node
102
95
  // -------------------------------------
103
96
  // Retrieve Parents
104
97
  // -------------------------------------
@@ -109,12 +102,12 @@ class EnrichedDiffGetQuery(Query):
109
102
  ORDER BY size(nodes(parents_path)) DESC
110
103
  LIMIT 1
111
104
  }
112
- WITH diff_root, diff_node, parents_path
105
+ WITH diff_root, diff_node, has_more_data, parents_path
113
106
  // -------------------------------------
114
107
  // Retrieve conflicts
115
108
  // -------------------------------------
116
109
  OPTIONAL MATCH (diff_node)-[:DIFF_HAS_CONFLICT]->(diff_node_conflict:DiffConflict)
117
- WITH diff_root, diff_node, parents_path, diff_node_conflict
110
+ WITH diff_root, diff_node, has_more_data, parents_path, diff_node_conflict
118
111
  // -------------------------------------
119
112
  // Retrieve Attributes
120
113
  // -------------------------------------
@@ -128,7 +121,7 @@ class EnrichedDiffGetQuery(Query):
128
121
  RETURN diff_attribute, diff_attr_property, diff_attr_property_conflict
129
122
  ORDER BY diff_attribute.name, diff_attr_property.property_type
130
123
  }
131
- WITH diff_root, diff_node, parents_path, diff_node_conflict, collect([diff_attribute, diff_attr_property, diff_attr_property_conflict]) as diff_attributes
124
+ WITH diff_root, diff_node, has_more_data, parents_path, diff_node_conflict, collect([diff_attribute, diff_attr_property, diff_attr_property_conflict]) as diff_attributes
132
125
  // -------------------------------------
133
126
  // Retrieve Relationships
134
127
  // -------------------------------------
@@ -150,6 +143,7 @@ class EnrichedDiffGetQuery(Query):
150
143
  WITH
151
144
  diff_root,
152
145
  diff_node,
146
+ has_more_data,
153
147
  parents_path,
154
148
  diff_node_conflict,
155
149
  diff_attributes,
@@ -161,6 +155,7 @@ class EnrichedDiffGetQuery(Query):
161
155
  self.return_labels = [
162
156
  "diff_root",
163
157
  "diff_node",
158
+ "has_more_data",
164
159
  "parents_path",
165
160
  "diff_node_conflict",
166
161
  "diff_attributes",