infrahub-server 1.1.4__py3-none-any.whl → 1.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- infrahub/api/artifact.py +4 -1
- infrahub/api/oidc.py +32 -4
- infrahub/api/transformation.py +2 -0
- infrahub/computed_attribute/tasks.py +9 -5
- infrahub/core/diff/calculator.py +118 -8
- infrahub/core/diff/coordinator.py +1 -2
- infrahub/core/diff/model/path.py +0 -9
- infrahub/core/diff/query/all_conflicts.py +64 -0
- infrahub/core/diff/query_parser.py +51 -28
- infrahub/core/diff/repository/repository.py +23 -14
- infrahub/core/merge.py +7 -5
- infrahub/core/query/diff.py +417 -621
- infrahub/core/relationship/model.py +10 -2
- infrahub/core/validators/aggregated_checker.py +34 -4
- infrahub/core/validators/node/hierarchy.py +1 -1
- infrahub/generators/tasks.py +12 -10
- infrahub/git/base.py +50 -8
- infrahub/git/integrator.py +24 -17
- infrahub/git/models.py +3 -2
- infrahub/git/repository.py +2 -2
- infrahub/git/tasks.py +19 -11
- infrahub/graphql/mutations/artifact_definition.py +10 -2
- infrahub/message_bus/messages/check_repository_usercheck.py +1 -0
- infrahub/message_bus/operations/check/repository.py +5 -2
- infrahub/message_bus/operations/requests/artifact_definition.py +1 -1
- infrahub/message_bus/operations/requests/proposed_change.py +4 -0
- infrahub/message_bus/types.py +1 -0
- infrahub/transformations/constants.py +1 -0
- infrahub/transformations/models.py +3 -1
- infrahub/transformations/tasks.py +2 -2
- infrahub/webhook/models.py +9 -1
- infrahub/workflows/catalogue.py +2 -2
- infrahub_sdk/ctl/cli_commands.py +4 -4
- infrahub_sdk/ctl/exporter.py +2 -2
- infrahub_sdk/ctl/importer.py +6 -4
- infrahub_sdk/ctl/repository.py +56 -1
- infrahub_sdk/pytest_plugin/items/graphql_query.py +1 -1
- infrahub_sdk/pytest_plugin/items/jinja2_transform.py +1 -1
- infrahub_sdk/pytest_plugin/items/python_transform.py +1 -1
- {infrahub_server-1.1.4.dist-info → infrahub_server-1.1.5.dist-info}/METADATA +1 -1
- {infrahub_server-1.1.4.dist-info → infrahub_server-1.1.5.dist-info}/RECORD +45 -43
- infrahub_testcontainers/docker-compose.test.yml +1 -1
- {infrahub_server-1.1.4.dist-info → infrahub_server-1.1.5.dist-info}/LICENSE.txt +0 -0
- {infrahub_server-1.1.4.dist-info → infrahub_server-1.1.5.dist-info}/WHEEL +0 -0
- {infrahub_server-1.1.4.dist-info → infrahub_server-1.1.5.dist-info}/entry_points.txt +0 -0
infrahub/api/artifact.py
CHANGED
|
@@ -85,7 +85,10 @@ async def generate_artifact(
|
|
|
85
85
|
|
|
86
86
|
service = request.app.state.service
|
|
87
87
|
model = RequestArtifactDefinitionGenerate(
|
|
88
|
-
|
|
88
|
+
artifact_definition_id=artifact_definition.id,
|
|
89
|
+
artifact_definition_name=artifact_definition.name.value,
|
|
90
|
+
branch=branch_params.branch.name,
|
|
91
|
+
limit=payload.nodes,
|
|
89
92
|
)
|
|
90
93
|
|
|
91
94
|
await service.workflow.submit_workflow(workflow=REQUEST_ARTIFACT_DEFINITION_GENERATE, parameters={"model": model})
|
infrahub/api/oidc.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from typing import TYPE_CHECKING
|
|
3
|
+
from typing import TYPE_CHECKING, Any
|
|
4
4
|
from urllib.parse import urljoin
|
|
5
5
|
|
|
6
|
+
import jwt
|
|
6
7
|
import ujson
|
|
7
8
|
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
|
8
9
|
from fastapi import APIRouter, Depends, Request, Response
|
|
@@ -138,7 +139,7 @@ async def token(
|
|
|
138
139
|
|
|
139
140
|
with trace.get_tracer(__name__).start_as_current_span("sso_token_request") as span:
|
|
140
141
|
span.set_attribute("token_request_data", ujson.dumps(token_response.json()))
|
|
141
|
-
payload = token_response.json()
|
|
142
|
+
payload: dict[str, Any] = token_response.json()
|
|
142
143
|
|
|
143
144
|
headers = {"Authorization": f"{payload.get('token_type')} {payload.get('access_token')}"}
|
|
144
145
|
|
|
@@ -148,8 +149,10 @@ async def token(
|
|
|
148
149
|
userinfo_response = await service.http.post(str(oidc_config.userinfo_endpoint), headers=headers)
|
|
149
150
|
|
|
150
151
|
_validate_response(response=userinfo_response)
|
|
151
|
-
user_info = userinfo_response.json()
|
|
152
|
-
sso_groups = user_info.get("groups"
|
|
152
|
+
user_info: dict[str, Any] = userinfo_response.json()
|
|
153
|
+
sso_groups = user_info.get("groups") or await _get_id_token_groups(
|
|
154
|
+
oidc_config=oidc_config, service=service, payload=payload, client_id=provider.client_id
|
|
155
|
+
)
|
|
153
156
|
|
|
154
157
|
if not sso_groups and config.SETTINGS.security.sso_user_default_group:
|
|
155
158
|
sso_groups = [config.SETTINGS.security.sso_user_default_group]
|
|
@@ -185,3 +188,28 @@ def _validate_response(response: httpx.Response) -> None:
|
|
|
185
188
|
body=response.json(),
|
|
186
189
|
)
|
|
187
190
|
raise GatewayError(message="Invalid response from Authentication provider")
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
async def _get_id_token_groups(
|
|
194
|
+
oidc_config: OIDCDiscoveryConfig, service: InfrahubServices, payload: dict[str, Any], client_id: str
|
|
195
|
+
) -> list[str]:
|
|
196
|
+
id_token = payload.get("id_token")
|
|
197
|
+
if not id_token:
|
|
198
|
+
return []
|
|
199
|
+
jwks = await service.http.get(url=str(oidc_config.jwks_uri))
|
|
200
|
+
|
|
201
|
+
jwk_client = jwt.PyJWKClient(uri=str(oidc_config.jwks_uri), cache_jwk_set=True)
|
|
202
|
+
if jwk_client.jwk_set_cache:
|
|
203
|
+
jwk_client.jwk_set_cache.put(jwks.json())
|
|
204
|
+
|
|
205
|
+
signing_key = jwk_client.get_signing_key_from_jwt(id_token)
|
|
206
|
+
|
|
207
|
+
decoded_token: dict[str, Any] = jwt.decode(
|
|
208
|
+
jwt=id_token,
|
|
209
|
+
key=signing_key.key,
|
|
210
|
+
algorithms=oidc_config.id_token_signing_alg_values_supported,
|
|
211
|
+
audience=client_id,
|
|
212
|
+
issuer=str(oidc_config.issuer),
|
|
213
|
+
)
|
|
214
|
+
|
|
215
|
+
return decoded_token.get("groups", [])
|
infrahub/api/transformation.py
CHANGED
|
@@ -83,6 +83,7 @@ async def transform_python(
|
|
|
83
83
|
commit=repository.commit.value, # type: ignore[attr-defined]
|
|
84
84
|
branch=branch_params.branch.name,
|
|
85
85
|
transform_location=f"{transform.file_path.value}::{transform.class_name.value}",
|
|
86
|
+
timeout=transform.timeout.value,
|
|
86
87
|
data=data,
|
|
87
88
|
)
|
|
88
89
|
|
|
@@ -140,6 +141,7 @@ async def transform_jinja2(
|
|
|
140
141
|
commit=repository.commit.value, # type: ignore[attr-defined]
|
|
141
142
|
branch=branch_params.branch.name,
|
|
142
143
|
template_location=transform.template_path.value,
|
|
144
|
+
timeout=transform.timeout.value,
|
|
143
145
|
data=data,
|
|
144
146
|
)
|
|
145
147
|
|
|
@@ -4,7 +4,10 @@ from datetime import timedelta
|
|
|
4
4
|
from typing import TYPE_CHECKING, Any
|
|
5
5
|
|
|
6
6
|
import ujson
|
|
7
|
-
from infrahub_sdk.protocols import
|
|
7
|
+
from infrahub_sdk.protocols import (
|
|
8
|
+
CoreNode, # noqa: TC002
|
|
9
|
+
CoreTransformPython,
|
|
10
|
+
)
|
|
8
11
|
from prefect import flow
|
|
9
12
|
from prefect.automations import AutomationCore
|
|
10
13
|
from prefect.client.orchestration import get_client
|
|
@@ -88,17 +91,18 @@ async def process_transform(
|
|
|
88
91
|
|
|
89
92
|
for attribute_name, transform_attribute in transform_attributes.items():
|
|
90
93
|
transform = await service.client.get(
|
|
91
|
-
kind=
|
|
94
|
+
kind=CoreTransformPython,
|
|
92
95
|
branch=branch_name,
|
|
93
96
|
id=transform_attribute.transform,
|
|
94
97
|
prefetch_relationships=True,
|
|
95
98
|
populate_store=True,
|
|
96
99
|
)
|
|
100
|
+
|
|
97
101
|
if not transform:
|
|
98
102
|
continue
|
|
99
103
|
|
|
100
104
|
repo_node = await service.client.get(
|
|
101
|
-
kind=transform.repository.peer.typename,
|
|
105
|
+
kind=str(transform.repository.peer.typename),
|
|
102
106
|
branch=branch_name,
|
|
103
107
|
id=transform.repository.peer.id,
|
|
104
108
|
raise_when_missing=True,
|
|
@@ -108,7 +112,7 @@ async def process_transform(
|
|
|
108
112
|
repository_id=transform.repository.peer.id,
|
|
109
113
|
name=transform.repository.peer.name.value,
|
|
110
114
|
service=service,
|
|
111
|
-
repository_kind=transform.repository.peer.typename,
|
|
115
|
+
repository_kind=str(transform.repository.peer.typename),
|
|
112
116
|
commit=repo_node.commit.value,
|
|
113
117
|
)
|
|
114
118
|
|
|
@@ -120,7 +124,7 @@ async def process_transform(
|
|
|
120
124
|
subscribers=[object_id],
|
|
121
125
|
)
|
|
122
126
|
|
|
123
|
-
transformed_data = await repo.execute_python_transform(
|
|
127
|
+
transformed_data = await repo.execute_python_transform.with_options(timeout_seconds=transform.timeout.value)(
|
|
124
128
|
branch_name=branch_name,
|
|
125
129
|
commit=repo_node.commit.value,
|
|
126
130
|
location=f"{transform.file_path.value}::{transform.class_name.value}",
|
infrahub/core/diff/calculator.py
CHANGED
|
@@ -1,20 +1,68 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
|
|
3
|
+
from infrahub import config
|
|
1
4
|
from infrahub.core import registry
|
|
2
5
|
from infrahub.core.branch import Branch
|
|
3
6
|
from infrahub.core.diff.query_parser import DiffQueryParser
|
|
4
|
-
from infrahub.core.query.diff import DiffAllPathsQuery
|
|
7
|
+
from infrahub.core.query.diff import DiffAllPathsQuery, DiffCalculationQuery, DiffFieldPathsQuery, DiffNodePathsQuery
|
|
5
8
|
from infrahub.core.timestamp import Timestamp
|
|
6
9
|
from infrahub.database import InfrahubDatabase
|
|
7
10
|
from infrahub.log import get_logger
|
|
8
11
|
|
|
9
|
-
from .model.path import CalculatedDiffs
|
|
12
|
+
from .model.path import CalculatedDiffs
|
|
10
13
|
|
|
11
14
|
log = get_logger()
|
|
12
15
|
|
|
13
16
|
|
|
17
|
+
@dataclass
|
|
18
|
+
class DiffCalculationRequest:
|
|
19
|
+
base_branch: Branch
|
|
20
|
+
diff_branch: Branch
|
|
21
|
+
branch_from_time: Timestamp
|
|
22
|
+
from_time: Timestamp
|
|
23
|
+
to_time: Timestamp
|
|
24
|
+
current_node_field_specifiers: dict[str, set[str]] | None = field(default=None)
|
|
25
|
+
new_node_field_specifiers: dict[str, set[str]] | None = field(default=None)
|
|
26
|
+
|
|
27
|
+
|
|
14
28
|
class DiffCalculator:
|
|
15
29
|
def __init__(self, db: InfrahubDatabase) -> None:
|
|
16
30
|
self.db = db
|
|
17
31
|
|
|
32
|
+
async def _run_diff_calculation_query(
|
|
33
|
+
self,
|
|
34
|
+
diff_parser: DiffQueryParser,
|
|
35
|
+
query_class: type[DiffCalculationQuery],
|
|
36
|
+
calculation_request: DiffCalculationRequest,
|
|
37
|
+
limit: int,
|
|
38
|
+
) -> None:
|
|
39
|
+
has_more_data = True
|
|
40
|
+
offset = 0
|
|
41
|
+
while has_more_data:
|
|
42
|
+
diff_query = await query_class.init(
|
|
43
|
+
db=self.db,
|
|
44
|
+
branch=calculation_request.diff_branch,
|
|
45
|
+
base_branch=calculation_request.base_branch,
|
|
46
|
+
diff_branch_from_time=calculation_request.branch_from_time,
|
|
47
|
+
diff_from=calculation_request.from_time,
|
|
48
|
+
diff_to=calculation_request.to_time,
|
|
49
|
+
current_node_field_specifiers=calculation_request.current_node_field_specifiers,
|
|
50
|
+
new_node_field_specifiers=calculation_request.new_node_field_specifiers,
|
|
51
|
+
limit=limit,
|
|
52
|
+
offset=offset,
|
|
53
|
+
)
|
|
54
|
+
log.info(f"Beginning one diff calculation query {limit=}, {offset=}")
|
|
55
|
+
await diff_query.execute(db=self.db)
|
|
56
|
+
log.info("Diff calculation query complete")
|
|
57
|
+
last_result = None
|
|
58
|
+
for query_result in diff_query.get_results():
|
|
59
|
+
diff_parser.read_result(query_result=query_result)
|
|
60
|
+
last_result = query_result
|
|
61
|
+
has_more_data = False
|
|
62
|
+
if last_result:
|
|
63
|
+
has_more_data = last_result.get_as_type("has_more_data", bool)
|
|
64
|
+
offset += limit
|
|
65
|
+
|
|
18
66
|
async def calculate_diff(
|
|
19
67
|
self,
|
|
20
68
|
base_branch: Branch,
|
|
@@ -22,7 +70,7 @@ class DiffCalculator:
|
|
|
22
70
|
from_time: Timestamp,
|
|
23
71
|
to_time: Timestamp,
|
|
24
72
|
include_unchanged: bool = True,
|
|
25
|
-
previous_node_specifiers: set[
|
|
73
|
+
previous_node_specifiers: dict[str, set[str]] | None = None,
|
|
26
74
|
) -> CalculatedDiffs:
|
|
27
75
|
if diff_branch.name == registry.default_branch:
|
|
28
76
|
diff_branch_from_time = from_time
|
|
@@ -36,6 +84,35 @@ class DiffCalculator:
|
|
|
36
84
|
to_time=to_time,
|
|
37
85
|
previous_node_field_specifiers=previous_node_specifiers,
|
|
38
86
|
)
|
|
87
|
+
node_limit = int(config.SETTINGS.database.query_size_limit / 10)
|
|
88
|
+
fields_limit = int(config.SETTINGS.database.query_size_limit / 3)
|
|
89
|
+
|
|
90
|
+
calculation_request = DiffCalculationRequest(
|
|
91
|
+
base_branch=base_branch,
|
|
92
|
+
diff_branch=diff_branch,
|
|
93
|
+
branch_from_time=diff_branch_from_time,
|
|
94
|
+
from_time=from_time,
|
|
95
|
+
to_time=to_time,
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
log.info("Beginning diff node-level calculation queries for branch")
|
|
99
|
+
await self._run_diff_calculation_query(
|
|
100
|
+
diff_parser=diff_parser,
|
|
101
|
+
query_class=DiffNodePathsQuery,
|
|
102
|
+
calculation_request=calculation_request,
|
|
103
|
+
limit=node_limit,
|
|
104
|
+
)
|
|
105
|
+
log.info("Diff node-level calculation queries for branch complete")
|
|
106
|
+
|
|
107
|
+
log.info("Beginning diff field-level calculation queries for branch")
|
|
108
|
+
await self._run_diff_calculation_query(
|
|
109
|
+
diff_parser=diff_parser,
|
|
110
|
+
query_class=DiffFieldPathsQuery,
|
|
111
|
+
calculation_request=calculation_request,
|
|
112
|
+
limit=fields_limit,
|
|
113
|
+
)
|
|
114
|
+
log.info("Diff field-level calculation queries for branch complete")
|
|
115
|
+
|
|
39
116
|
branch_diff_query = await DiffAllPathsQuery.init(
|
|
40
117
|
db=self.db,
|
|
41
118
|
branch=diff_branch,
|
|
@@ -53,8 +130,43 @@ class DiffCalculator:
|
|
|
53
130
|
log.info("Results of query for branch read")
|
|
54
131
|
|
|
55
132
|
if base_branch.name != diff_branch.name:
|
|
56
|
-
new_node_field_specifiers = diff_parser.get_new_node_field_specifiers()
|
|
57
133
|
current_node_field_specifiers = diff_parser.get_current_node_field_specifiers()
|
|
134
|
+
new_node_field_specifiers = diff_parser.get_new_node_field_specifiers()
|
|
135
|
+
calculation_request = DiffCalculationRequest(
|
|
136
|
+
base_branch=base_branch,
|
|
137
|
+
diff_branch=base_branch,
|
|
138
|
+
branch_from_time=diff_branch_from_time,
|
|
139
|
+
from_time=from_time,
|
|
140
|
+
to_time=to_time,
|
|
141
|
+
current_node_field_specifiers=current_node_field_specifiers,
|
|
142
|
+
new_node_field_specifiers=new_node_field_specifiers,
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
log.info("Beginning diff node-level calculation queries for base")
|
|
146
|
+
await self._run_diff_calculation_query(
|
|
147
|
+
diff_parser=diff_parser,
|
|
148
|
+
query_class=DiffNodePathsQuery,
|
|
149
|
+
calculation_request=calculation_request,
|
|
150
|
+
limit=node_limit,
|
|
151
|
+
)
|
|
152
|
+
log.info("Diff node-level calculation queries for base complete")
|
|
153
|
+
|
|
154
|
+
log.info("Beginning diff field-level calculation queries for base")
|
|
155
|
+
await self._run_diff_calculation_query(
|
|
156
|
+
diff_parser=diff_parser,
|
|
157
|
+
query_class=DiffFieldPathsQuery,
|
|
158
|
+
calculation_request=calculation_request,
|
|
159
|
+
limit=fields_limit,
|
|
160
|
+
)
|
|
161
|
+
log.info("Diff field-level calculation queries for base complete")
|
|
162
|
+
|
|
163
|
+
# Temporary until next change
|
|
164
|
+
current_node_field_specifier_tuples: list[tuple[str, str]] = []
|
|
165
|
+
new_node_field_specifier_tuples: list[tuple[str, str]] = []
|
|
166
|
+
for node_uuid, field_names in current_node_field_specifiers.items():
|
|
167
|
+
current_node_field_specifier_tuples.extend((node_uuid, field_name) for field_name in field_names)
|
|
168
|
+
for node_uuid, field_names in new_node_field_specifiers.items():
|
|
169
|
+
new_node_field_specifier_tuples.extend((node_uuid, field_name) for field_name in field_names)
|
|
58
170
|
|
|
59
171
|
base_diff_query = await DiffAllPathsQuery.init(
|
|
60
172
|
db=self.db,
|
|
@@ -63,10 +175,8 @@ class DiffCalculator:
|
|
|
63
175
|
diff_branch_from_time=diff_branch_from_time,
|
|
64
176
|
diff_from=from_time,
|
|
65
177
|
diff_to=to_time,
|
|
66
|
-
current_node_field_specifiers=
|
|
67
|
-
|
|
68
|
-
],
|
|
69
|
-
new_node_field_specifiers=[(nfs.node_uuid, nfs.field_name) for nfs in new_node_field_specifiers],
|
|
178
|
+
current_node_field_specifiers=current_node_field_specifier_tuples,
|
|
179
|
+
new_node_field_specifiers=new_node_field_specifier_tuples,
|
|
70
180
|
)
|
|
71
181
|
|
|
72
182
|
log.info("Beginning diff calculation query for base")
|
|
@@ -15,7 +15,6 @@ from .model.path import (
|
|
|
15
15
|
EnrichedDiffs,
|
|
16
16
|
EnrichedDiffsMetadata,
|
|
17
17
|
NameTrackingId,
|
|
18
|
-
NodeFieldSpecifier,
|
|
19
18
|
TrackingId,
|
|
20
19
|
)
|
|
21
20
|
|
|
@@ -44,7 +43,7 @@ class EnrichedDiffRequest:
|
|
|
44
43
|
from_time: Timestamp
|
|
45
44
|
to_time: Timestamp
|
|
46
45
|
tracking_id: TrackingId | None = field(default=None)
|
|
47
|
-
node_field_specifiers: set[
|
|
46
|
+
node_field_specifiers: dict[str, set[str]] = field(default_factory=dict)
|
|
48
47
|
|
|
49
48
|
def __repr__(self) -> str:
|
|
50
49
|
return (
|
infrahub/core/diff/model/path.py
CHANGED
|
@@ -82,15 +82,6 @@ def deserialize_tracking_id(tracking_id_str: str) -> TrackingId:
|
|
|
82
82
|
raise ValueError(f"{tracking_id_str} is not a valid TrackingId")
|
|
83
83
|
|
|
84
84
|
|
|
85
|
-
@dataclass
|
|
86
|
-
class NodeFieldSpecifier:
|
|
87
|
-
node_uuid: str
|
|
88
|
-
field_name: str
|
|
89
|
-
|
|
90
|
-
def __hash__(self) -> int:
|
|
91
|
-
return hash(f"{self.node_uuid}:{self.field_name}")
|
|
92
|
-
|
|
93
|
-
|
|
94
85
|
@dataclass
|
|
95
86
|
class NodeDiffFieldSummary:
|
|
96
87
|
kind: str
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from typing import Any, Generator
|
|
2
|
+
|
|
3
|
+
from neo4j.graph import Node as Neo4jNode
|
|
4
|
+
|
|
5
|
+
from infrahub.core.query import Query, QueryType
|
|
6
|
+
from infrahub.database import InfrahubDatabase
|
|
7
|
+
|
|
8
|
+
from ..model.path import TrackingId
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class EnrichedDiffAllConflictsQuery(Query):
|
|
12
|
+
name = "enriched_diff_all_conflicts"
|
|
13
|
+
type = QueryType.READ
|
|
14
|
+
|
|
15
|
+
def __init__(
|
|
16
|
+
self, diff_branch_name: str, tracking_id: TrackingId | None = None, diff_id: str | None = None, **kwargs: Any
|
|
17
|
+
) -> None:
|
|
18
|
+
super().__init__(**kwargs)
|
|
19
|
+
if (diff_id is None and tracking_id is None) or (diff_id and tracking_id):
|
|
20
|
+
raise ValueError("EnrichedDiffAllConflictsQuery requires one and only one of `tracking_id` or `diff_id`")
|
|
21
|
+
self.diff_branch_name = diff_branch_name
|
|
22
|
+
self.tracking_id = tracking_id
|
|
23
|
+
self.diff_id = diff_id
|
|
24
|
+
if self.tracking_id is None and self.diff_id is None:
|
|
25
|
+
raise RuntimeError("tracking_id or diff_id is required")
|
|
26
|
+
|
|
27
|
+
async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
|
|
28
|
+
self.params = {
|
|
29
|
+
"diff_branch_name": self.diff_branch_name,
|
|
30
|
+
"diff_id": self.diff_id,
|
|
31
|
+
"tracking_id": self.tracking_id.serialize() if self.tracking_id else None,
|
|
32
|
+
}
|
|
33
|
+
query = """
|
|
34
|
+
MATCH (root:DiffRoot)
|
|
35
|
+
WHERE ($diff_id IS NOT NULL AND root.uuid = $diff_id)
|
|
36
|
+
OR ($tracking_id IS NOT NULL AND root.tracking_id = $tracking_id AND root.diff_branch = $diff_branch_name)
|
|
37
|
+
CALL {
|
|
38
|
+
WITH root
|
|
39
|
+
MATCH (root)-[:DIFF_HAS_NODE]->(node:DiffNode)-[:DIFF_HAS_CONFLICT]->(node_conflict:DiffConflict)
|
|
40
|
+
RETURN node.path_identifier AS path_identifier, node_conflict AS conflict
|
|
41
|
+
UNION
|
|
42
|
+
WITH root
|
|
43
|
+
MATCH (root)-[:DIFF_HAS_NODE]->(node:DiffNode)-[:DIFF_HAS_ATTRIBUTE]->(:DiffAttribute)
|
|
44
|
+
-[:DIFF_HAS_PROPERTY]->(property:DiffProperty)-[:DIFF_HAS_CONFLICT]->(attr_property_conflict:DiffConflict)
|
|
45
|
+
RETURN property.path_identifier AS path_identifier, attr_property_conflict AS conflict
|
|
46
|
+
UNION
|
|
47
|
+
WITH root
|
|
48
|
+
MATCH (root)-[:DIFF_HAS_NODE]->(node:DiffNode)-[:DIFF_HAS_RELATIONSHIP]->(:DiffRelationship)
|
|
49
|
+
-[:DIFF_HAS_ELEMENT]->(element:DiffRelationshipElement)-[:DIFF_HAS_CONFLICT]->(rel_element_conflict:DiffConflict)
|
|
50
|
+
RETURN element.path_identifier AS path_identifier, rel_element_conflict AS conflict
|
|
51
|
+
UNION
|
|
52
|
+
WITH root
|
|
53
|
+
MATCH (root)-[:DIFF_HAS_NODE]->(node:DiffNode)-[:DIFF_HAS_RELATIONSHIP]->(:DiffRelationship)
|
|
54
|
+
-[:DIFF_HAS_ELEMENT]->(:DiffRelationshipElement)-[:DIFF_HAS_PROPERTY]->(property:DiffProperty)
|
|
55
|
+
-[:DIFF_HAS_CONFLICT]->(rel_property_conflict:DiffConflict)
|
|
56
|
+
RETURN property.path_identifier AS path_identifier, rel_property_conflict AS conflict
|
|
57
|
+
}
|
|
58
|
+
"""
|
|
59
|
+
self.return_labels = ["path_identifier", "conflict"]
|
|
60
|
+
self.add_to_query(query=query)
|
|
61
|
+
|
|
62
|
+
def get_conflict_paths_and_nodes(self) -> Generator[tuple[str, Neo4jNode], None, None]:
|
|
63
|
+
for result in self.get_results():
|
|
64
|
+
yield (result.get_as_type("path_identifier", str), result.get_node("conflict"))
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from collections import defaultdict
|
|
3
4
|
from dataclasses import dataclass, field
|
|
4
5
|
from typing import TYPE_CHECKING, Any, Optional
|
|
5
6
|
from uuid import uuid4
|
|
@@ -22,7 +23,6 @@ from .model.path import (
|
|
|
22
23
|
DiffRelationship,
|
|
23
24
|
DiffRoot,
|
|
24
25
|
DiffSingleRelationship,
|
|
25
|
-
NodeFieldSpecifier,
|
|
26
26
|
)
|
|
27
27
|
|
|
28
28
|
if TYPE_CHECKING:
|
|
@@ -459,7 +459,7 @@ class DiffQueryParser:
|
|
|
459
459
|
schema_manager: SchemaManager,
|
|
460
460
|
from_time: Timestamp,
|
|
461
461
|
to_time: Optional[Timestamp] = None,
|
|
462
|
-
previous_node_field_specifiers: set[
|
|
462
|
+
previous_node_field_specifiers: dict[str, set[str]] | None = None,
|
|
463
463
|
) -> None:
|
|
464
464
|
self.base_branch_name = base_branch.name
|
|
465
465
|
self.diff_branch_name = diff_branch.name
|
|
@@ -473,7 +473,9 @@ class DiffQueryParser:
|
|
|
473
473
|
self.diff_branched_from_time = Timestamp(diff_branch.get_branched_from())
|
|
474
474
|
self._diff_root_by_branch: dict[str, DiffRootIntermediate] = {}
|
|
475
475
|
self._final_diff_root_by_branch: dict[str, DiffRoot] = {}
|
|
476
|
-
self._previous_node_field_specifiers = previous_node_field_specifiers or
|
|
476
|
+
self._previous_node_field_specifiers = previous_node_field_specifiers or {}
|
|
477
|
+
self._new_node_field_specifiers: dict[str, set[str]] | None = None
|
|
478
|
+
self._current_node_field_specifiers: dict[str, set[str]] | None = None
|
|
477
479
|
|
|
478
480
|
def get_branches(self) -> set[str]:
|
|
479
481
|
return set(self._final_diff_root_by_branch.keys())
|
|
@@ -485,38 +487,59 @@ class DiffQueryParser:
|
|
|
485
487
|
return self._final_diff_root_by_branch[branch]
|
|
486
488
|
return DiffRoot(from_time=self.from_time, to_time=self.to_time, uuid=str(uuid4()), branch=branch, nodes=[])
|
|
487
489
|
|
|
488
|
-
def get_diff_node_field_specifiers(self) -> set[
|
|
490
|
+
def get_diff_node_field_specifiers(self) -> dict[str, set[str]]:
|
|
489
491
|
if self.diff_branch_name not in self._diff_root_by_branch:
|
|
490
|
-
return
|
|
491
|
-
|
|
492
|
+
return {}
|
|
493
|
+
node_field_specifiers_map: dict[str, set[str]] = defaultdict(set)
|
|
492
494
|
diff_root = self._diff_root_by_branch[self.diff_branch_name]
|
|
493
495
|
for node in diff_root.nodes_by_id.values():
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
496
|
+
for attribute_name in node.attributes_by_name:
|
|
497
|
+
node_field_specifiers_map[node.uuid].add(attribute_name)
|
|
498
|
+
for relationship_diff in node.relationships_by_name.values():
|
|
499
|
+
node_field_specifiers_map[node.uuid].add(relationship_diff.identifier)
|
|
500
|
+
return node_field_specifiers_map
|
|
501
|
+
|
|
502
|
+
def _remove_node_specifiers(
|
|
503
|
+
self, node_specifiers: dict[str, set[str]], node_specifiers_to_remove: dict[str, set[str]]
|
|
504
|
+
) -> dict[str, set[str]]:
|
|
505
|
+
final_node_specifiers: dict[str, set[str]] = defaultdict(set)
|
|
506
|
+
for node_uuid, field_names_set in node_specifiers.items():
|
|
507
|
+
specifiers_to_remove = node_specifiers_to_remove.get(node_uuid, set())
|
|
508
|
+
final_specifiers = field_names_set - specifiers_to_remove
|
|
509
|
+
if final_specifiers:
|
|
510
|
+
final_node_specifiers[node_uuid] = final_specifiers
|
|
511
|
+
return final_node_specifiers
|
|
512
|
+
|
|
513
|
+
def get_new_node_field_specifiers(self) -> dict[str, set[str]]:
|
|
514
|
+
if self._new_node_field_specifiers is not None:
|
|
515
|
+
return self._new_node_field_specifiers
|
|
505
516
|
branch_node_specifiers = self.get_diff_node_field_specifiers()
|
|
506
|
-
new_node_field_specifiers =
|
|
517
|
+
new_node_field_specifiers = self._remove_node_specifiers(
|
|
518
|
+
branch_node_specifiers, self._previous_node_field_specifiers
|
|
519
|
+
)
|
|
520
|
+
self._new_node_field_specifiers = new_node_field_specifiers
|
|
507
521
|
return new_node_field_specifiers
|
|
508
522
|
|
|
509
|
-
def get_current_node_field_specifiers(self) -> set[
|
|
523
|
+
def get_current_node_field_specifiers(self) -> dict[str, set[str]]:
|
|
524
|
+
if self._current_node_field_specifiers is not None:
|
|
525
|
+
return self._current_node_field_specifiers
|
|
510
526
|
new_node_field_specifiers = self.get_new_node_field_specifiers()
|
|
511
|
-
current_node_field_specifiers = self.
|
|
527
|
+
current_node_field_specifiers = self._remove_node_specifiers(
|
|
528
|
+
self._previous_node_field_specifiers, new_node_field_specifiers
|
|
529
|
+
)
|
|
530
|
+
self._current_node_field_specifiers = current_node_field_specifiers
|
|
512
531
|
return current_node_field_specifiers
|
|
513
532
|
|
|
514
533
|
def read_result(self, query_result: QueryResult) -> None:
|
|
515
534
|
path = query_result.get_path(label="diff_path")
|
|
516
535
|
database_path = DatabasePath.from_cypher_path(cypher_path=path)
|
|
517
536
|
self._parse_path(database_path=database_path)
|
|
537
|
+
self._current_node_field_specifiers = None
|
|
538
|
+
self._new_node_field_specifiers = None
|
|
518
539
|
|
|
519
540
|
def parse(self, include_unchanged: bool = False) -> None:
|
|
541
|
+
self._new_node_field_specifiers = None
|
|
542
|
+
self._current_node_field_specifiers = None
|
|
520
543
|
if len(self._diff_root_by_branch) > 1:
|
|
521
544
|
self._apply_base_branch_previous_values()
|
|
522
545
|
self._remove_empty_base_diff_root()
|
|
@@ -586,11 +609,12 @@ class DiffQueryParser:
|
|
|
586
609
|
self, database_path: DatabasePath, diff_node: DiffNodeIntermediate
|
|
587
610
|
) -> DiffAttributeIntermediate:
|
|
588
611
|
attribute_name = database_path.attribute_name
|
|
589
|
-
node_field_specifier = NodeFieldSpecifier(node_uuid=diff_node.uuid, field_name=attribute_name)
|
|
590
612
|
branch_name = database_path.deepest_branch
|
|
591
613
|
from_time = self.from_time
|
|
592
|
-
if branch_name == self.base_branch_name
|
|
593
|
-
|
|
614
|
+
if branch_name == self.base_branch_name:
|
|
615
|
+
new_node_field_specifiers = self.get_new_node_field_specifiers()
|
|
616
|
+
if attribute_name in new_node_field_specifiers.get(diff_node.uuid, set()):
|
|
617
|
+
from_time = self.diff_branched_from_time
|
|
594
618
|
if attribute_name not in diff_node.attributes_by_name:
|
|
595
619
|
diff_node.attributes_by_name[attribute_name] = DiffAttributeIntermediate(
|
|
596
620
|
uuid=database_path.attribute_id,
|
|
@@ -627,13 +651,12 @@ class DiffQueryParser:
|
|
|
627
651
|
) -> DiffRelationshipIntermediate:
|
|
628
652
|
diff_relationship = diff_node.relationships_by_name.get(relationship_schema.name)
|
|
629
653
|
if not diff_relationship:
|
|
630
|
-
node_field_specifier = NodeFieldSpecifier(
|
|
631
|
-
node_uuid=diff_node.uuid, field_name=relationship_schema.get_identifier()
|
|
632
|
-
)
|
|
633
654
|
branch_name = database_path.deepest_branch
|
|
634
655
|
from_time = self.from_time
|
|
635
|
-
if branch_name == self.base_branch_name
|
|
636
|
-
|
|
656
|
+
if branch_name == self.base_branch_name:
|
|
657
|
+
new_node_field_specifiers = self.get_new_node_field_specifiers()
|
|
658
|
+
if relationship_schema.get_identifier() in new_node_field_specifiers.get(diff_node.uuid, set()):
|
|
659
|
+
from_time = self.diff_branched_from_time
|
|
637
660
|
diff_relationship = DiffRelationshipIntermediate(
|
|
638
661
|
name=relationship_schema.name,
|
|
639
662
|
cardinality=relationship_schema.cardinality,
|
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
from
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from typing import AsyncGenerator, Generator
|
|
2
3
|
|
|
3
4
|
from infrahub import config
|
|
4
5
|
from infrahub.core import registry
|
|
@@ -18,10 +19,10 @@ from ..model.path import (
|
|
|
18
19
|
EnrichedDiffsMetadata,
|
|
19
20
|
EnrichedNodeCreateRequest,
|
|
20
21
|
NodeDiffFieldSummary,
|
|
21
|
-
NodeFieldSpecifier,
|
|
22
22
|
TimeRange,
|
|
23
23
|
TrackingId,
|
|
24
24
|
)
|
|
25
|
+
from ..query.all_conflicts import EnrichedDiffAllConflictsQuery
|
|
25
26
|
from ..query.delete_query import EnrichedDiffDeleteQuery
|
|
26
27
|
from ..query.diff_get import EnrichedDiffGetQuery
|
|
27
28
|
from ..query.diff_summary import DiffSummaryCounters, DiffSummaryQuery
|
|
@@ -318,6 +319,19 @@ class DiffRepository:
|
|
|
318
319
|
raise ResourceNotFoundError(f"No conflict with id {conflict_id}")
|
|
319
320
|
return self.deserializer.deserialize_conflict(diff_conflict_node=conflict_node)
|
|
320
321
|
|
|
322
|
+
async def get_all_conflicts_for_diff(
|
|
323
|
+
self,
|
|
324
|
+
diff_branch_name: str,
|
|
325
|
+
tracking_id: TrackingId | None = None,
|
|
326
|
+
diff_id: str | None = None,
|
|
327
|
+
) -> AsyncGenerator[tuple[str, EnrichedDiffConflict], None]:
|
|
328
|
+
query = await EnrichedDiffAllConflictsQuery.init(
|
|
329
|
+
db=self.db, diff_branch_name=diff_branch_name, tracking_id=tracking_id, diff_id=diff_id
|
|
330
|
+
)
|
|
331
|
+
await query.execute(db=self.db)
|
|
332
|
+
for conflict_path, conflict_node in query.get_conflict_paths_and_nodes():
|
|
333
|
+
yield (conflict_path, self.deserializer.deserialize_conflict(diff_conflict_node=conflict_node))
|
|
334
|
+
|
|
321
335
|
async def get_node_field_summaries(
|
|
322
336
|
self, diff_branch_name: str, tracking_id: TrackingId | None = None, diff_id: str | None = None
|
|
323
337
|
) -> list[NodeDiffFieldSummary]:
|
|
@@ -338,23 +352,18 @@ class DiffRepository:
|
|
|
338
352
|
await query.execute(db=self.db)
|
|
339
353
|
return query.get_num_changes_by_branch()
|
|
340
354
|
|
|
341
|
-
async def get_node_field_specifiers(self, diff_id: str) -> set[
|
|
355
|
+
async def get_node_field_specifiers(self, diff_id: str) -> dict[str, set[str]]:
|
|
342
356
|
limit = config.SETTINGS.database.query_size_limit
|
|
343
357
|
offset = 0
|
|
344
|
-
specifiers: set[
|
|
358
|
+
specifiers: dict[str, set[str]] = defaultdict(set)
|
|
345
359
|
while True:
|
|
346
360
|
query = await EnrichedDiffFieldSpecifiersQuery.init(db=self.db, diff_id=diff_id, offset=offset, limit=limit)
|
|
347
361
|
await query.execute(db=self.db)
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
)
|
|
354
|
-
for field_specifier_tuple in query.get_node_field_specifier_tuples()
|
|
355
|
-
}
|
|
356
|
-
if not new_specifiers:
|
|
362
|
+
has_data = False
|
|
363
|
+
for field_specifier_tuple in query.get_node_field_specifier_tuples():
|
|
364
|
+
specifiers[field_specifier_tuple[0]].add(field_specifier_tuple[1])
|
|
365
|
+
has_data = True
|
|
366
|
+
if not has_data:
|
|
357
367
|
break
|
|
358
|
-
specifiers |= new_specifiers
|
|
359
368
|
offset += limit
|
|
360
369
|
return specifiers
|
infrahub/core/merge.py
CHANGED
|
@@ -177,14 +177,16 @@ class BranchMerger:
|
|
|
177
177
|
if self.source_branch.name == registry.default_branch:
|
|
178
178
|
raise ValidationError(f"Unable to merge the branch '{self.source_branch.name}' into itself")
|
|
179
179
|
|
|
180
|
-
log.
|
|
181
|
-
|
|
180
|
+
log.info("Updating diff for merge")
|
|
181
|
+
await self.diff_coordinator.update_branch_diff(
|
|
182
182
|
base_branch=self.destination_branch, diff_branch=self.source_branch
|
|
183
183
|
)
|
|
184
|
-
log.
|
|
185
|
-
|
|
184
|
+
log.info("Diff updated for merge")
|
|
185
|
+
|
|
186
186
|
errors: list[str] = []
|
|
187
|
-
for conflict_path, conflict in
|
|
187
|
+
async for conflict_path, conflict in self.diff_repository.get_all_conflicts_for_diff(
|
|
188
|
+
diff_branch_name=self.source_branch.name, tracking_id=BranchTrackingId(name=self.source_branch.name)
|
|
189
|
+
):
|
|
188
190
|
if conflict.selected_branch is None or conflict.resolvable is False:
|
|
189
191
|
errors.append(conflict_path)
|
|
190
192
|
|